finish ruleset_v2 implementation

This commit is contained in:
Kevin Pham
2023-12-05 14:02:54 -06:00
parent 9d77c63697
commit b2f6cf9f1d
4 changed files with 120 additions and 692 deletions

View File

@@ -1,347 +0,0 @@
package ruleset
import (
"compress/gzip"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"gopkg.in/yaml.v3"
)
type Regex struct {
Match string `yaml:"match"`
Replace string `yaml:"replace"`
}
type KV struct {
Key string `yaml:"key"`
Value string `yaml:"value"`
}
type RuleSet []Rule
type Rule struct {
Domain string `yaml:"domain,omitempty"`
Domains []string `yaml:"domains,omitempty"`
Paths []string `yaml:"paths,omitempty"`
Headers struct {
UserAgent string `yaml:"user-agent,omitempty"`
XForwardedFor string `yaml:"x-forwarded-for,omitempty"`
Referer string `yaml:"referer,omitempty"`
Cookie string `yaml:"cookie,omitempty"`
CSP string `yaml:"content-security-policy,omitempty"`
} `yaml:"headers,omitempty"`
GoogleCache bool `yaml:"googleCache,omitempty"`
RegexRules []Regex `yaml:"regexRules,omitempty"`
URLMods struct {
Domain []Regex `yaml:"domain,omitempty"`
Path []Regex `yaml:"path,omitempty"`
Query []KV `yaml:"query,omitempty"`
} `yaml:"urlMods,omitempty"`
Injections []struct {
Position string `yaml:"position,omitempty"`
Append string `yaml:"append,omitempty"`
Prepend string `yaml:"prepend,omitempty"`
Replace string `yaml:"replace,omitempty"`
} `yaml:"injections,omitempty"`
}
var remoteRegex = regexp.MustCompile(`^https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()!@:%_\+.~#?&\/\/=]*)`)
// NewRulesetFromEnv creates a new RuleSet based on the RULESET environment variable.
// It logs a warning and returns an empty RuleSet if the RULESET environment variable is not set.
// If the RULESET is set but the rules cannot be loaded, it panics.
func NewRulesetFromEnv() RuleSet {
rulesPath, ok := os.LookupEnv("RULESET")
if !ok {
log.Printf("WARN: No ruleset specified. Set the `RULESET` environment variable to load one for a better success rate.")
return RuleSet{}
}
ruleSet, err := NewRuleset(rulesPath)
if err != nil {
log.Println(err)
}
return ruleSet
}
// NewRuleset loads a RuleSet from a given string of rule paths, separated by semicolons.
// It supports loading rules from both local file paths and remote URLs.
// Returns a RuleSet and an error if any issues occur during loading.
func NewRuleset(rulePaths string) (RuleSet, error) {
var ruleSet RuleSet
var errs []error
rp := strings.Split(rulePaths, ";")
for _, rule := range rp {
var err error
rulePath := strings.Trim(rule, " ")
isRemote := remoteRegex.MatchString(rulePath)
if isRemote {
err = ruleSet.loadRulesFromRemoteFile(rulePath)
} else {
err = ruleSet.loadRulesFromLocalDir(rulePath)
}
if err != nil {
e := fmt.Errorf("WARN: failed to load ruleset from '%s'", rulePath)
errs = append(errs, errors.Join(e, err))
continue
}
}
if len(errs) != 0 {
e := fmt.Errorf("WARN: failed to load %d rulesets", len(rp))
errs = append(errs, e)
// panic if the user specified a local ruleset, but it wasn't found on disk
// don't fail silently
for _, err := range errs {
if errors.Is(os.ErrNotExist, err) {
e := fmt.Errorf("PANIC: ruleset '%s' not found", err)
panic(errors.Join(e, err))
}
}
// else, bubble up any errors, such as syntax or remote host issues
return ruleSet, errors.Join(errs...)
}
ruleSet.PrintStats()
return ruleSet, nil
}
// ================== RULESET loading logic ===================================
// loadRulesFromLocalDir loads rules from a local directory specified by the path.
// It walks through the directory, loading rules from YAML files.
// Returns an error if the directory cannot be accessed
// If there is an issue loading any file, it will be skipped
func (rs *RuleSet) loadRulesFromLocalDir(path string) error {
_, err := os.Stat(path)
if err != nil {
return err
}
yamlRegex := regexp.MustCompile(`.*\.ya?ml`)
err = filepath.Walk(path, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.IsDir() {
return nil
}
if isYaml := yamlRegex.MatchString(path); !isYaml {
return nil
}
err = rs.loadRulesFromLocalFile(path)
if err != nil {
log.Printf("WARN: failed to load directory ruleset '%s': %s, skipping", path, err)
return nil
}
log.Printf("INFO: loaded ruleset %s\n", path)
return nil
})
if err != nil {
return err
}
return nil
}
// loadRulesFromLocalFile loads rules from a local YAML file specified by the path.
// Returns an error if the file cannot be read or if there's a syntax error in the YAML.
func (rs *RuleSet) loadRulesFromLocalFile(path string) error {
yamlFile, err := os.ReadFile(path)
if err != nil {
e := fmt.Errorf("failed to read rules from local file: '%s'", path)
return errors.Join(e, err)
}
var r RuleSet
err = yaml.Unmarshal(yamlFile, &r)
if err != nil {
e := fmt.Errorf("failed to load rules from local file, possible syntax error in '%s'", path)
ee := errors.Join(e, err)
if _, ok := os.LookupEnv("DEBUG"); ok {
debugPrintRule(string(yamlFile), ee)
}
return ee
}
*rs = append(*rs, r...)
return nil
}
// loadRulesFromRemoteFile loads rules from a remote URL.
// It supports plain and gzip compressed content.
// Returns an error if there's an issue accessing the URL or if there's a syntax error in the YAML.
func (rs *RuleSet) loadRulesFromRemoteFile(rulesURL string) error {
var r RuleSet
resp, err := http.Get(rulesURL)
if err != nil {
e := fmt.Errorf("failed to load rules from remote url '%s'", rulesURL)
return errors.Join(e, err)
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
e := fmt.Errorf("failed to load rules from remote url (%s) on '%s'", resp.Status, rulesURL)
return errors.Join(e, err)
}
var reader io.Reader
isGzip := strings.HasSuffix(rulesURL, ".gz") || strings.HasSuffix(rulesURL, ".gzip") || resp.Header.Get("content-encoding") == "gzip"
if isGzip {
reader, err = gzip.NewReader(resp.Body)
if err != nil {
return fmt.Errorf("failed to create gzip reader for URL '%s' with status code '%s': %w", rulesURL, resp.Status, err)
}
} else {
reader = resp.Body
}
err = yaml.NewDecoder(reader).Decode(&r)
if err != nil {
e := fmt.Errorf("failed to load rules from remote url '%s' with status code '%s' and possible syntax error", rulesURL, resp.Status)
ee := errors.Join(e, err)
return ee
}
*rs = append(*rs, r...)
return nil
}
// ================= utility methods ==========================
// Yaml returns the ruleset as a Yaml string
func (rs *RuleSet) Yaml() (string, error) {
y, err := yaml.Marshal(rs)
if err != nil {
return "", err
}
return string(y), nil
}
// GzipYaml returns an io.Reader that streams the Gzip-compressed YAML representation of the RuleSet.
func (rs *RuleSet) GzipYaml() (io.Reader, error) {
pr, pw := io.Pipe()
go func() {
defer pw.Close()
gw := gzip.NewWriter(pw)
defer gw.Close()
if err := yaml.NewEncoder(gw).Encode(rs); err != nil {
gw.Close() // Ensure to close the gzip writer
pw.CloseWithError(err)
return
}
}()
return pr, nil
}
// Domains extracts and returns a slice of all domains present in the RuleSet.
func (rs *RuleSet) Domains() []string {
var domains []string
for _, rule := range *rs {
domains = append(domains, rule.Domain)
domains = append(domains, rule.Domains...)
}
return domains
}
// DomainCount returns the count of unique domains present in the RuleSet.
func (rs *RuleSet) DomainCount() int {
return len(rs.Domains())
}
// Count returns the total number of rules in the RuleSet.
func (rs *RuleSet) Count() int {
return len(*rs)
}
// PrintStats logs the number of rules and domains loaded in the RuleSet.
func (rs *RuleSet) PrintStats() {
log.Printf("INFO: Loaded %d rules for %d domains\n", rs.Count(), rs.DomainCount())
}
// debugPrintRule is a utility function for printing a rule and associated error for debugging purposes.
func debugPrintRule(rule string, err error) {
fmt.Println("------------------------------ BEGIN DEBUG RULESET -----------------------------")
fmt.Printf("%s\n", err.Error())
fmt.Println("--------------------------------------------------------------------------------")
fmt.Println(rule)
fmt.Println("------------------------------ END DEBUG RULESET -------------------------------")
}
// ======================= RuleSetMap implementation =================================================
// RuleSetMap: A map with domain names as keys and pointers to the corresponding Rules as values.
// This type is used to efficiently access rules based on domain names.
type RuleSetMap map[string]*Rule
// ToMap converts a RuleSet into a RuleSetMap. It transforms each Rule in the RuleSet
// into a map entry where the key is the Rule's domain (lowercase)
// and the value is a pointer to the Rule. This method is used to
// efficiently access rules based on domain names.
// The RuleSetMap may be accessed with or without a "www." prefix in the domain.
func (rs *RuleSet) ToMap() RuleSetMap {
rsm := make(RuleSetMap)
addMapEntry := func(d string, rule *Rule) {
d = strings.ToLower(d)
rsm[d] = rule
if strings.HasPrefix(d, "www.") {
d = strings.TrimPrefix(d, "www.")
rsm[d] = rule
} else {
d = fmt.Sprintf("www.%s", d)
rsm[d] = rule
}
}
for i, rule := range *rs {
rulePtr := &(*rs)[i]
addMapEntry(rule.Domain, rulePtr)
for _, domain := range rule.Domains {
addMapEntry(domain, rulePtr)
}
}
return rsm
}

View File

@@ -1,225 +0,0 @@
package ruleset
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
)
var (
validYAML = `
- domain: example.com
regexRules:
- match: "^http:"
replace: "https:"`
invalidYAML = `
- domain: [thisIsATestYamlThatIsMeantToFail.example]
regexRules:
- match: "^http:"
replace: "https:"
- match: "[incomplete"`
)
func TestLoadRulesFromRemoteFile(t *testing.T) {
app := fiber.New()
defer app.Shutdown()
app.Get("/valid-config.yml", func(c *fiber.Ctx) error {
c.SendString(validYAML)
return nil
})
app.Get("/invalid-config.yml", func(c *fiber.Ctx) error {
c.SendString(invalidYAML)
return nil
})
app.Get("/valid-config.gz", func(c *fiber.Ctx) error {
c.Set("Content-Type", "application/octet-stream")
rs, err := loadRuleFromString(validYAML)
if err != nil {
t.Errorf("failed to load valid yaml from string: %s", err.Error())
}
s, err := rs.GzipYaml()
if err != nil {
t.Errorf("failed to load gzip serialize yaml: %s", err.Error())
}
err = c.SendStream(s)
if err != nil {
t.Errorf("failed to stream gzip serialized yaml: %s", err.Error())
}
return nil
})
// Start the server in a goroutine
go func() {
if err := app.Listen("127.0.0.1:9999"); err != nil {
t.Errorf("Server failed to start: %s", err.Error())
}
}()
// Wait for the server to start
time.Sleep(time.Second * 1)
rs, err := NewRuleset("http://127.0.0.1:9999/valid-config.yml")
if err != nil {
t.Errorf("failed to load plaintext ruleset from http server: %s", err.Error())
}
assert.Equal(t, rs[0].Domain, "example.com")
rs, err = NewRuleset("http://127.0.0.1:9999/valid-config.gz")
if err != nil {
t.Errorf("failed to load gzipped ruleset from http server: %s", err.Error())
}
assert.Equal(t, rs[0].Domain, "example.com")
os.Setenv("RULESET", "http://127.0.0.1:9999/valid-config.gz")
rs = NewRulesetFromEnv()
if !assert.Equal(t, rs[0].Domain, "example.com") {
t.Error("expected no errors loading ruleset from gzip url using environment variable, but got one")
}
}
func loadRuleFromString(yaml string) (RuleSet, error) {
// Create a temporary file and load it
tmpFile, _ := os.CreateTemp("", "ruleset*.yaml")
defer os.Remove(tmpFile.Name())
tmpFile.WriteString(yaml)
rs := RuleSet{}
err := rs.loadRulesFromLocalFile(tmpFile.Name())
return rs, err
}
// TestLoadRulesFromLocalFile tests the loading of rules from a local YAML file.
func TestLoadRulesFromLocalFile(t *testing.T) {
rs, err := loadRuleFromString(validYAML)
if err != nil {
t.Errorf("Failed to load rules from valid YAML: %s", err)
}
assert.Equal(t, rs[0].Domain, "example.com")
assert.Equal(t, rs[0].RegexRules[0].Match, "^http:")
assert.Equal(t, rs[0].RegexRules[0].Replace, "https:")
_, err = loadRuleFromString(invalidYAML)
if err == nil {
t.Errorf("Expected an error when loading invalid YAML, but got none")
}
}
// TestLoadRulesFromLocalDir tests the loading of rules from a local nested directory full of yaml rulesets
func TestLoadRulesFromLocalDir(t *testing.T) {
// Create a temporary directory
baseDir, err := os.MkdirTemp("", "ruleset_test")
if err != nil {
t.Fatalf("Failed to create temporary directory: %s", err)
}
defer os.RemoveAll(baseDir)
// Create a nested subdirectory
nestedDir := filepath.Join(baseDir, "nested")
err = os.Mkdir(nestedDir, 0o755)
if err != nil {
t.Fatalf("Failed to create nested directory: %s", err)
}
// Create a nested subdirectory
nestedTwiceDir := filepath.Join(nestedDir, "nestedTwice")
err = os.Mkdir(nestedTwiceDir, 0o755)
if err != nil {
t.Fatalf("Failed to create twice-nested directory: %s", err)
}
testCases := []string{"test.yaml", "test2.yaml", "test-3.yaml", "test 4.yaml", "1987.test.yaml.yml", "foobar.example.com.yaml", "foobar.com.yml"}
for _, fileName := range testCases {
filePath := filepath.Join(nestedDir, "2x-"+fileName)
os.WriteFile(filePath, []byte(validYAML), 0o644)
filePath = filepath.Join(nestedDir, fileName)
os.WriteFile(filePath, []byte(validYAML), 0o644)
filePath = filepath.Join(baseDir, "base-"+fileName)
os.WriteFile(filePath, []byte(validYAML), 0o644)
}
rs := RuleSet{}
err = rs.loadRulesFromLocalDir(baseDir)
assert.NoError(t, err)
assert.Equal(t, rs.Count(), len(testCases)*3)
for _, rule := range rs {
assert.Equal(t, rule.Domain, "example.com")
assert.Equal(t, rule.RegexRules[0].Match, "^http:")
assert.Equal(t, rule.RegexRules[0].Replace, "https:")
}
}
func TestToMap(t *testing.T) {
// Prepare a ruleset with multiple rules, including "www." prefixed domains
rules := RuleSet{
{
Domain: "Example.com",
RegexRules: []Regex{{Match: "match1", Replace: "replace1"}},
},
{
Domain: "www.AnotherExample.com",
RegexRules: []Regex{{Match: "match2", Replace: "replace2"}},
},
{
Domain: "www.foo.bAr.baz.bOol.quX.com",
RegexRules: []Regex{{Match: "match3", Replace: "replace3"}},
},
}
// Convert to RuleSetMap
rsm := rules.ToMap()
// Test for correct number of entries
if len(rsm) != 6 {
t.Errorf("Expected 6 entries in RuleSetMap, got %d", len(rsm))
}
// Test for correct mapping
testDomains := []struct {
domain string
expectedMatch string
}{
{"example.com", "match1"},
{"www.example.com", "match1"},
{"anotherexample.com", "match2"},
{"www.anotherexample.com", "match2"},
{"foo.bar.baz.bool.qux.com", "match3"},
{"no.ruleset.domain.com", ""},
}
for _, test := range testDomains {
if test.domain == "no.ruleset.domain.com" {
assert.Empty(t, test.expectedMatch)
continue
}
rule, exists := rsm[test.domain]
if !exists {
t.Errorf("Expected domain %s to exist in RuleSetMap", test.domain)
} else if rule.RegexRules[0].Match != test.expectedMatch {
t.Errorf("Expected match for %s to be %s, got %s", test.domain, test.expectedMatch, rule.RegexRules[0].Match)
}
}
}

View File

@@ -13,6 +13,7 @@ import (
"strings" "strings"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
//"encoding/json"
) )
type IRuleset interface { type IRuleset interface {
@@ -21,12 +22,68 @@ type IRuleset interface {
} }
type Ruleset struct { type Ruleset struct {
rulesetPath string Rules []Rule `json:"rules" yaml:"rules"`
rules map[string]Rule _rulemap map[string]*Rule // internal map for fast lookups; points at a rule in the Rules slice
} }
func (rs Ruleset) GetRule(url *url.URL) (rule Rule, exists bool) { func (rs *Ruleset) UnmarshalYAML(unmarshal func(interface{}) error) error {
rule, exists = rs.rules[url.Hostname()] type AuxRuleset struct {
Rules []Rule `yaml:"rules"`
}
yr := &AuxRuleset{}
if err := unmarshal(&yr); err != nil {
return err
}
rs._rulemap = make(map[string]*Rule)
rs.Rules = yr.Rules
// create a map of pointers to rules loaded above based on domain string keys
// this way we don't have two copies of the rule in ruleset
for i, rule := range rs.Rules {
rulePtr := &rs.Rules[i]
for _, domain := range rule.Domains {
rs._rulemap[domain] = rulePtr
if !strings.HasPrefix(domain, "www.") {
rs._rulemap["www."+domain] = rulePtr
}
}
}
return nil
}
// MarshalYAML implements the yaml.Marshaler interface.
// It customizes the marshaling of a Ruleset object into YAML
func (rs *Ruleset) MarshalYAML() (interface{}, error) {
type AuxRule struct {
Domains []string `yaml:"domains"`
RequestModifications []_rqm `yaml:"requestmodifications"`
ResponseModifications []_rsm `yaml:"responsemodifications"`
}
type Aux struct {
Rules []AuxRule `yaml:"rules"`
}
aux := Aux{}
for _, rule := range rs.Rules {
auxRule := AuxRule{
Domains: rule.Domains,
RequestModifications: rule._rqms,
ResponseModifications: rule._rsms,
}
aux.Rules = append(aux.Rules, auxRule)
}
out, err := yaml.Marshal(&aux)
return out, err
}
func (rs Ruleset) GetRule(url *url.URL) (rule *Rule, exists bool) {
rule, exists = rs._rulemap[url.Hostname()]
return rule, exists return rule, exists
} }
@@ -38,8 +95,8 @@ func (rs Ruleset) HasRule(url *url.URL) bool {
// NewRuleset loads a new RuleSet from a path // NewRuleset loads a new RuleSet from a path
func NewRuleset(path string) (Ruleset, error) { func NewRuleset(path string) (Ruleset, error) {
rs := Ruleset{ rs := Ruleset{
rulesetPath: path, _rulemap: map[string]*Rule{},
rules: map[string]Rule{}, Rules: []Rule{},
} }
switch { switch {
@@ -88,18 +145,20 @@ func (rs *Ruleset) loadRulesFromLocalDir(path string) error {
return nil return nil
} }
isYAML := filepath.Ext(path) == "yaml" || filepath.Ext(path) == "yml" isYAML := filepath.Ext(path) == ".yaml" || filepath.Ext(path) == ".yml"
if !isYAML { if !isYAML {
return nil return nil
} }
err = rs.loadRulesFromLocalFile(path) tmpRs := Ruleset{_rulemap: make(map[string]*Rule)}
err = tmpRs.loadRulesFromLocalFile(path)
if err != nil { if err != nil {
log.Printf("WARN: failed to load directory ruleset '%s': %s, skipping", path, err) log.Printf("WARN: failed to load directory ruleset '%s': %s, skipping", path, err)
return nil return nil
} }
rs.Rules = append(rs.Rules, tmpRs.Rules...)
log.Printf("INFO: loaded ruleset %s\n", path) //log.Printf("INFO: loaded ruleset %s\n", path)
return nil return nil
}) })
@@ -120,21 +179,12 @@ func (rs *Ruleset) loadRulesFromLocalFile(path string) error {
return errors.Join(e, err) return errors.Join(e, err)
} }
rule := Rule{} err = yaml.Unmarshal(yamlFile, rs)
err = yaml.Unmarshal(yamlFile, &rule)
if err != nil { if err != nil {
e := fmt.Errorf("failed to load rules from local file, possible syntax error in '%s' - %s", path, err) e := fmt.Errorf("failed to load rules from local file, possible syntax error in '%s' - %s", path, err)
ee := errors.Join(e, err) debugPrintRule(string(yamlFile), e)
debugPrintRule(string(yamlFile), ee) return e
return ee
}
for _, domain := range rule.Domains {
rs.rules[domain] = rule
if !strings.HasSuffix(domain, "www.") {
rs.rules["www."+domain] = rule
}
} }
return nil return nil
@@ -144,7 +194,6 @@ func (rs *Ruleset) loadRulesFromLocalFile(path string) error {
// It supports plain and gzip compressed content. // It supports plain and gzip compressed content.
// Returns an error if there's an issue accessing the URL or if there's a syntax error in the YAML. // Returns an error if there's an issue accessing the URL or if there's a syntax error in the YAML.
func (rs *Ruleset) loadRulesFromRemoteFile(rulesURL string) error { func (rs *Ruleset) loadRulesFromRemoteFile(rulesURL string) error {
rule := Rule{}
resp, err := http.Get(rulesURL) resp, err := http.Get(rulesURL)
if err != nil { if err != nil {
@@ -160,7 +209,7 @@ func (rs *Ruleset) loadRulesFromRemoteFile(rulesURL string) error {
var reader io.Reader var reader io.Reader
// in case remote server did not set content-encoding gzip header // in case remote server did not set content-encoding gzip header
isGzip := strings.HasSuffix(rulesURL, ".gz") || strings.HasSuffix(rulesURL, ".gzip") isGzip := strings.HasSuffix(rulesURL, ".gz") || strings.HasSuffix(rulesURL, ".gzip") || resp.Header.Get("content-encoding") == "gzip"
if isGzip { if isGzip {
reader, err = gzip.NewReader(resp.Body) reader, err = gzip.NewReader(resp.Body)
@@ -171,24 +220,12 @@ func (rs *Ruleset) loadRulesFromRemoteFile(rulesURL string) error {
reader = resp.Body reader = resp.Body
} }
err = yaml.NewDecoder(reader).Decode(&rule) err = yaml.NewDecoder(reader).Decode(&rs)
if err != nil { if err != nil {
return fmt.Errorf("failed to load rules from remote url '%s' with status code '%s' and possible syntax error - %s", rulesURL, resp.Status, err) return fmt.Errorf("failed to load rules from remote url '%s' with status code '%s' and possible syntax error - %s", rulesURL, resp.Status, err)
} }
if rs.rules == nil {
fmt.Println("nilmap")
rs.rules = make(map[string]Rule)
}
for _, domain := range rule.Domains {
rs.rules[domain] = rule
if !strings.HasSuffix(domain, "www.") {
rs.rules["www."+domain] = rule
}
}
return nil return nil
} }
@@ -204,31 +241,11 @@ func (rs *Ruleset) Yaml() (string, error) {
return string(y), nil return string(y), nil
} }
// GzipYaml returns an io.Reader that streams the Gzip-compressed YAML representation of the RuleSet.
func (rs *Ruleset) GzipYaml() (io.Reader, error) {
pr, pw := io.Pipe()
go func() {
defer pw.Close()
gw := gzip.NewWriter(pw)
defer gw.Close()
if err := yaml.NewEncoder(gw).Encode(rs); err != nil {
gw.Close() // Ensure to close the gzip writer
pw.CloseWithError(err)
return
}
}()
return pr, nil
}
// Domains extracts and returns a slice of all domains present in the RuleSet. // Domains extracts and returns a slice of all domains present in the RuleSet.
func (rs *Ruleset) Domains() []string { func (rs *Ruleset) Domains() []string {
var domains []string var domains []string
for domain := range rs.rules { for _, rule := range rs.Rules {
domains = append(domains, domain) domains = append(domains, rule.Domains...)
} }
return domains return domains
} }
@@ -240,7 +257,7 @@ func (rs *Ruleset) DomainCount() int {
// Count returns the total number of rules in the RuleSet. // Count returns the total number of rules in the RuleSet.
func (rs *Ruleset) Count() int { func (rs *Ruleset) Count() int {
return len(rs.rules) return len(rs.Rules)
} }
// PrintStats logs the number of rules and domains loaded in the RuleSet. // PrintStats logs the number of rules and domains loaded in the RuleSet.

View File

@@ -1,6 +1,7 @@
package ruleset_v2 package ruleset_v2
import ( import (
"fmt"
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
@@ -13,38 +14,40 @@ import (
var ( var (
validYAML = ` validYAML = `
domains: rules:
- example.com - domains:
- www.example.com - example.com
responsemodifications: - www.example.com
- name: APIContent responsemodifications:
params: [] - name: APIContent
- name: SetContentSecurityPolicy params: []
params: - name: SetContentSecurityPolicy
- foobar params:
- name: SetIncomingCookie - foobar
params: - name: SetIncomingCookie
- authorization-bearer params:
- hunter2 - authorization-bearer
requestmodifications: - hunter2
- name: ForwardRequestHeaders requestmodifications:
params: [] - name: ForwardRequestHeaders
params: []
` `
invalidYAML = ` invalidYAML = `
domains: rules:
- example.com domains:
- www.example.com - example.com
responsemodifications: - www.example.com
- name: APIContent responsemodifications:
- name: SetContentSecurityPolicy - name: APIContent
- name: INVALIDSetIncomingCookie - name: SetContentSecurityPolicy
params: - name: INVALIDSetIncomingCookie
- authorization-bearer params:
- hunter2 - authorization-bearer
requestmodifications: - hunter2
- name: ForwardRequestHeaders requestmodifications:
params: [] - name: ForwardRequestHeaders
params: []
` `
) )
@@ -62,26 +65,6 @@ func TestLoadRulesFromRemoteFile(t *testing.T) {
return nil return nil
}) })
app.Get("/valid-config.gz", func(c *fiber.Ctx) error {
c.Set("Content-Type", "application/octet-stream")
rs, err := loadRuleFromString(validYAML)
if err != nil {
t.Errorf("failed to load valid yaml from string: %s", err.Error())
}
s, err := rs.GzipYaml()
if err != nil {
t.Errorf("failed to load gzip serialize yaml: %s", err.Error())
}
err = c.SendStream(s)
if err != nil {
t.Errorf("failed to stream gzip serialized yaml: %s", err.Error())
}
return nil
})
// Start the server in a goroutine // Start the server in a goroutine
go func() { go func() {
if err := app.Listen("127.0.0.1:9999"); err != nil { if err := app.Listen("127.0.0.1:9999"); err != nil {
@@ -102,25 +85,21 @@ func TestLoadRulesFromRemoteFile(t *testing.T) {
assert.True(t, exists, "expected example.com rule to be present") assert.True(t, exists, "expected example.com rule to be present")
assert.Equal(t, r.Domains[0], "example.com") assert.Equal(t, r.Domains[0], "example.com")
u, _ = url.Parse("http://www.www.example.com") u, _ = url.Parse("http://www.www.foobar.com")
_, exists = rs.GetRule(u) _, exists = rs.GetRule(u)
assert.False(t, exists, "expected www.www.example.com rule to NOT be present") assert.False(t, exists, "expected www.www.foobar.com rule to NOT be present")
rs, err = NewRuleset("http://127.0.0.1:9999/valid-config.gz")
if err != nil {
t.Errorf("failed to load gzipped ruleset from http server: %s", err.Error())
}
u, _ = url.Parse("http://example.com")
r, exists = rs.GetRule(u) r, exists = rs.GetRule(u)
assert.Equal(t, r.Domains[0], "example.com") assert.Equal(t, r.Domains[0], "example.com")
os.Setenv("RULESET", "http://127.0.0.1:9999/valid-config.gz") os.Setenv("RULESET", "http://127.0.0.1:9999/valid-config.yml")
rs = NewRulesetFromEnv() rs = NewRulesetFromEnv()
r, exists = rs.GetRule(u) r, exists = rs.GetRule(u)
assert.True(t, exists, "expected example.com rule to be present") assert.True(t, exists, "expected example.com rule to be present from env")
if !assert.Equal(t, r.Domains[0], "example.com") { if !assert.Equal(t, r.Domains[0], "example.com") {
t.Error("expected no errors loading ruleset from gzip url using environment variable, but got one") t.Error("expected no errors loading ruleset from url using environment variable, but got one")
} }
} }
@@ -132,7 +111,10 @@ func loadRuleFromString(yaml string) (Ruleset, error) {
tmpFile.WriteString(yaml) tmpFile.WriteString(yaml)
rs := Ruleset{} rs := Ruleset{
_rulemap: map[string]*Rule{},
Rules: []Rule{},
}
err := rs.loadRulesFromLocalFile(tmpFile.Name()) err := rs.loadRulesFromLocalFile(tmpFile.Name())
return rs, err return rs, err
@@ -189,8 +171,9 @@ func TestLoadRulesFromLocalDir(t *testing.T) {
} }
rs := Ruleset{} rs := Ruleset{}
fmt.Println(baseDir)
err = rs.loadRulesFromLocalDir(baseDir) err = rs.loadRulesFromLocalDir(baseDir)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, rs.Count(), len(testCases)*3) assert.Equal(t, len(testCases)*3, rs.Count())
} }