diff --git a/pkg/ruleset/ruleset.go b/pkg/ruleset/ruleset.go deleted file mode 100644 index 24fa577..0000000 --- a/pkg/ruleset/ruleset.go +++ /dev/null @@ -1,347 +0,0 @@ -package ruleset - -import ( - "compress/gzip" - "errors" - "fmt" - "io" - "log" - "net/http" - "os" - "path/filepath" - "regexp" - "strings" - - "gopkg.in/yaml.v3" -) - -type Regex struct { - Match string `yaml:"match"` - Replace string `yaml:"replace"` -} -type KV struct { - Key string `yaml:"key"` - Value string `yaml:"value"` -} - -type RuleSet []Rule - -type Rule struct { - Domain string `yaml:"domain,omitempty"` - Domains []string `yaml:"domains,omitempty"` - Paths []string `yaml:"paths,omitempty"` - Headers struct { - UserAgent string `yaml:"user-agent,omitempty"` - XForwardedFor string `yaml:"x-forwarded-for,omitempty"` - Referer string `yaml:"referer,omitempty"` - Cookie string `yaml:"cookie,omitempty"` - CSP string `yaml:"content-security-policy,omitempty"` - } `yaml:"headers,omitempty"` - GoogleCache bool `yaml:"googleCache,omitempty"` - RegexRules []Regex `yaml:"regexRules,omitempty"` - - URLMods struct { - Domain []Regex `yaml:"domain,omitempty"` - Path []Regex `yaml:"path,omitempty"` - Query []KV `yaml:"query,omitempty"` - } `yaml:"urlMods,omitempty"` - - Injections []struct { - Position string `yaml:"position,omitempty"` - Append string `yaml:"append,omitempty"` - Prepend string `yaml:"prepend,omitempty"` - Replace string `yaml:"replace,omitempty"` - } `yaml:"injections,omitempty"` -} - -var remoteRegex = regexp.MustCompile(`^https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()!@:%_\+.~#?&\/\/=]*)`) - -// NewRulesetFromEnv creates a new RuleSet based on the RULESET environment variable. -// It logs a warning and returns an empty RuleSet if the RULESET environment variable is not set. -// If the RULESET is set but the rules cannot be loaded, it panics. -func NewRulesetFromEnv() RuleSet { - rulesPath, ok := os.LookupEnv("RULESET") - if !ok { - log.Printf("WARN: No ruleset specified. Set the `RULESET` environment variable to load one for a better success rate.") - return RuleSet{} - } - - ruleSet, err := NewRuleset(rulesPath) - if err != nil { - log.Println(err) - } - - return ruleSet -} - -// NewRuleset loads a RuleSet from a given string of rule paths, separated by semicolons. -// It supports loading rules from both local file paths and remote URLs. -// Returns a RuleSet and an error if any issues occur during loading. -func NewRuleset(rulePaths string) (RuleSet, error) { - var ruleSet RuleSet - - var errs []error - - rp := strings.Split(rulePaths, ";") - for _, rule := range rp { - var err error - - rulePath := strings.Trim(rule, " ") - isRemote := remoteRegex.MatchString(rulePath) - - if isRemote { - err = ruleSet.loadRulesFromRemoteFile(rulePath) - } else { - err = ruleSet.loadRulesFromLocalDir(rulePath) - } - - if err != nil { - e := fmt.Errorf("WARN: failed to load ruleset from '%s'", rulePath) - errs = append(errs, errors.Join(e, err)) - - continue - } - } - - if len(errs) != 0 { - e := fmt.Errorf("WARN: failed to load %d rulesets", len(rp)) - errs = append(errs, e) - - // panic if the user specified a local ruleset, but it wasn't found on disk - // don't fail silently - for _, err := range errs { - if errors.Is(os.ErrNotExist, err) { - e := fmt.Errorf("PANIC: ruleset '%s' not found", err) - panic(errors.Join(e, err)) - } - } - - // else, bubble up any errors, such as syntax or remote host issues - return ruleSet, errors.Join(errs...) - } - - ruleSet.PrintStats() - - return ruleSet, nil -} - -// ================== RULESET loading logic =================================== - -// loadRulesFromLocalDir loads rules from a local directory specified by the path. -// It walks through the directory, loading rules from YAML files. -// Returns an error if the directory cannot be accessed -// If there is an issue loading any file, it will be skipped -func (rs *RuleSet) loadRulesFromLocalDir(path string) error { - _, err := os.Stat(path) - if err != nil { - return err - } - - yamlRegex := regexp.MustCompile(`.*\.ya?ml`) - - err = filepath.Walk(path, func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - if info.IsDir() { - return nil - } - if isYaml := yamlRegex.MatchString(path); !isYaml { - return nil - } - - err = rs.loadRulesFromLocalFile(path) - if err != nil { - log.Printf("WARN: failed to load directory ruleset '%s': %s, skipping", path, err) - return nil - } - - log.Printf("INFO: loaded ruleset %s\n", path) - - return nil - }) - - if err != nil { - return err - } - - return nil -} - -// loadRulesFromLocalFile loads rules from a local YAML file specified by the path. -// Returns an error if the file cannot be read or if there's a syntax error in the YAML. -func (rs *RuleSet) loadRulesFromLocalFile(path string) error { - yamlFile, err := os.ReadFile(path) - if err != nil { - e := fmt.Errorf("failed to read rules from local file: '%s'", path) - return errors.Join(e, err) - } - - var r RuleSet - err = yaml.Unmarshal(yamlFile, &r) - - if err != nil { - e := fmt.Errorf("failed to load rules from local file, possible syntax error in '%s'", path) - ee := errors.Join(e, err) - - if _, ok := os.LookupEnv("DEBUG"); ok { - debugPrintRule(string(yamlFile), ee) - } - - return ee - } - - *rs = append(*rs, r...) - - return nil -} - -// loadRulesFromRemoteFile loads rules from a remote URL. -// It supports plain and gzip compressed content. -// Returns an error if there's an issue accessing the URL or if there's a syntax error in the YAML. -func (rs *RuleSet) loadRulesFromRemoteFile(rulesURL string) error { - var r RuleSet - - resp, err := http.Get(rulesURL) - if err != nil { - e := fmt.Errorf("failed to load rules from remote url '%s'", rulesURL) - return errors.Join(e, err) - } - - defer resp.Body.Close() - - if resp.StatusCode >= 400 { - e := fmt.Errorf("failed to load rules from remote url (%s) on '%s'", resp.Status, rulesURL) - return errors.Join(e, err) - } - - var reader io.Reader - - isGzip := strings.HasSuffix(rulesURL, ".gz") || strings.HasSuffix(rulesURL, ".gzip") || resp.Header.Get("content-encoding") == "gzip" - - if isGzip { - reader, err = gzip.NewReader(resp.Body) - - if err != nil { - return fmt.Errorf("failed to create gzip reader for URL '%s' with status code '%s': %w", rulesURL, resp.Status, err) - } - } else { - reader = resp.Body - } - - err = yaml.NewDecoder(reader).Decode(&r) - - if err != nil { - e := fmt.Errorf("failed to load rules from remote url '%s' with status code '%s' and possible syntax error", rulesURL, resp.Status) - ee := errors.Join(e, err) - - return ee - } - - *rs = append(*rs, r...) - - return nil -} - -// ================= utility methods ========================== - -// Yaml returns the ruleset as a Yaml string -func (rs *RuleSet) Yaml() (string, error) { - y, err := yaml.Marshal(rs) - if err != nil { - return "", err - } - - return string(y), nil -} - -// GzipYaml returns an io.Reader that streams the Gzip-compressed YAML representation of the RuleSet. -func (rs *RuleSet) GzipYaml() (io.Reader, error) { - pr, pw := io.Pipe() - - go func() { - defer pw.Close() - - gw := gzip.NewWriter(pw) - defer gw.Close() - - if err := yaml.NewEncoder(gw).Encode(rs); err != nil { - gw.Close() // Ensure to close the gzip writer - pw.CloseWithError(err) - return - } - }() - - return pr, nil -} - -// Domains extracts and returns a slice of all domains present in the RuleSet. -func (rs *RuleSet) Domains() []string { - var domains []string - for _, rule := range *rs { - domains = append(domains, rule.Domain) - domains = append(domains, rule.Domains...) - } - return domains -} - -// DomainCount returns the count of unique domains present in the RuleSet. -func (rs *RuleSet) DomainCount() int { - return len(rs.Domains()) -} - -// Count returns the total number of rules in the RuleSet. -func (rs *RuleSet) Count() int { - return len(*rs) -} - -// PrintStats logs the number of rules and domains loaded in the RuleSet. -func (rs *RuleSet) PrintStats() { - log.Printf("INFO: Loaded %d rules for %d domains\n", rs.Count(), rs.DomainCount()) -} - -// debugPrintRule is a utility function for printing a rule and associated error for debugging purposes. -func debugPrintRule(rule string, err error) { - fmt.Println("------------------------------ BEGIN DEBUG RULESET -----------------------------") - fmt.Printf("%s\n", err.Error()) - fmt.Println("--------------------------------------------------------------------------------") - fmt.Println(rule) - fmt.Println("------------------------------ END DEBUG RULESET -------------------------------") -} - -// ======================= RuleSetMap implementation ================================================= - -// RuleSetMap: A map with domain names as keys and pointers to the corresponding Rules as values. -// This type is used to efficiently access rules based on domain names. -type RuleSetMap map[string]*Rule - -// ToMap converts a RuleSet into a RuleSetMap. It transforms each Rule in the RuleSet -// into a map entry where the key is the Rule's domain (lowercase) -// and the value is a pointer to the Rule. This method is used to -// efficiently access rules based on domain names. -// The RuleSetMap may be accessed with or without a "www." prefix in the domain. -func (rs *RuleSet) ToMap() RuleSetMap { - rsm := make(RuleSetMap) - - addMapEntry := func(d string, rule *Rule) { - d = strings.ToLower(d) - rsm[d] = rule - if strings.HasPrefix(d, "www.") { - d = strings.TrimPrefix(d, "www.") - rsm[d] = rule - } else { - d = fmt.Sprintf("www.%s", d) - rsm[d] = rule - } - } - - for i, rule := range *rs { - rulePtr := &(*rs)[i] - addMapEntry(rule.Domain, rulePtr) - for _, domain := range rule.Domains { - addMapEntry(domain, rulePtr) - } - } - - return rsm -} diff --git a/pkg/ruleset/ruleset_test.go b/pkg/ruleset/ruleset_test.go deleted file mode 100644 index 99bd663..0000000 --- a/pkg/ruleset/ruleset_test.go +++ /dev/null @@ -1,225 +0,0 @@ -package ruleset - -import ( - "os" - "path/filepath" - "testing" - "time" - - "github.com/gofiber/fiber/v2" - "github.com/stretchr/testify/assert" -) - -var ( - validYAML = ` -- domain: example.com - regexRules: - - match: "^http:" - replace: "https:"` - - invalidYAML = ` -- domain: [thisIsATestYamlThatIsMeantToFail.example] - regexRules: - - match: "^http:" - replace: "https:" - - match: "[incomplete"` -) - -func TestLoadRulesFromRemoteFile(t *testing.T) { - app := fiber.New() - defer app.Shutdown() - - app.Get("/valid-config.yml", func(c *fiber.Ctx) error { - c.SendString(validYAML) - return nil - }) - - app.Get("/invalid-config.yml", func(c *fiber.Ctx) error { - c.SendString(invalidYAML) - return nil - }) - - app.Get("/valid-config.gz", func(c *fiber.Ctx) error { - c.Set("Content-Type", "application/octet-stream") - - rs, err := loadRuleFromString(validYAML) - if err != nil { - t.Errorf("failed to load valid yaml from string: %s", err.Error()) - } - - s, err := rs.GzipYaml() - if err != nil { - t.Errorf("failed to load gzip serialize yaml: %s", err.Error()) - } - - err = c.SendStream(s) - if err != nil { - t.Errorf("failed to stream gzip serialized yaml: %s", err.Error()) - } - return nil - }) - - // Start the server in a goroutine - go func() { - if err := app.Listen("127.0.0.1:9999"); err != nil { - t.Errorf("Server failed to start: %s", err.Error()) - } - }() - - // Wait for the server to start - time.Sleep(time.Second * 1) - - rs, err := NewRuleset("http://127.0.0.1:9999/valid-config.yml") - if err != nil { - t.Errorf("failed to load plaintext ruleset from http server: %s", err.Error()) - } - - assert.Equal(t, rs[0].Domain, "example.com") - - rs, err = NewRuleset("http://127.0.0.1:9999/valid-config.gz") - if err != nil { - t.Errorf("failed to load gzipped ruleset from http server: %s", err.Error()) - } - - assert.Equal(t, rs[0].Domain, "example.com") - - os.Setenv("RULESET", "http://127.0.0.1:9999/valid-config.gz") - - rs = NewRulesetFromEnv() - if !assert.Equal(t, rs[0].Domain, "example.com") { - t.Error("expected no errors loading ruleset from gzip url using environment variable, but got one") - } -} - -func loadRuleFromString(yaml string) (RuleSet, error) { - // Create a temporary file and load it - tmpFile, _ := os.CreateTemp("", "ruleset*.yaml") - - defer os.Remove(tmpFile.Name()) - - tmpFile.WriteString(yaml) - - rs := RuleSet{} - err := rs.loadRulesFromLocalFile(tmpFile.Name()) - - return rs, err -} - -// TestLoadRulesFromLocalFile tests the loading of rules from a local YAML file. -func TestLoadRulesFromLocalFile(t *testing.T) { - rs, err := loadRuleFromString(validYAML) - if err != nil { - t.Errorf("Failed to load rules from valid YAML: %s", err) - } - - assert.Equal(t, rs[0].Domain, "example.com") - assert.Equal(t, rs[0].RegexRules[0].Match, "^http:") - assert.Equal(t, rs[0].RegexRules[0].Replace, "https:") - - _, err = loadRuleFromString(invalidYAML) - if err == nil { - t.Errorf("Expected an error when loading invalid YAML, but got none") - } -} - -// TestLoadRulesFromLocalDir tests the loading of rules from a local nested directory full of yaml rulesets -func TestLoadRulesFromLocalDir(t *testing.T) { - // Create a temporary directory - baseDir, err := os.MkdirTemp("", "ruleset_test") - if err != nil { - t.Fatalf("Failed to create temporary directory: %s", err) - } - - defer os.RemoveAll(baseDir) - - // Create a nested subdirectory - nestedDir := filepath.Join(baseDir, "nested") - err = os.Mkdir(nestedDir, 0o755) - - if err != nil { - t.Fatalf("Failed to create nested directory: %s", err) - } - - // Create a nested subdirectory - nestedTwiceDir := filepath.Join(nestedDir, "nestedTwice") - err = os.Mkdir(nestedTwiceDir, 0o755) - if err != nil { - t.Fatalf("Failed to create twice-nested directory: %s", err) - } - - testCases := []string{"test.yaml", "test2.yaml", "test-3.yaml", "test 4.yaml", "1987.test.yaml.yml", "foobar.example.com.yaml", "foobar.com.yml"} - for _, fileName := range testCases { - filePath := filepath.Join(nestedDir, "2x-"+fileName) - os.WriteFile(filePath, []byte(validYAML), 0o644) - - filePath = filepath.Join(nestedDir, fileName) - os.WriteFile(filePath, []byte(validYAML), 0o644) - - filePath = filepath.Join(baseDir, "base-"+fileName) - os.WriteFile(filePath, []byte(validYAML), 0o644) - } - - rs := RuleSet{} - err = rs.loadRulesFromLocalDir(baseDir) - - assert.NoError(t, err) - assert.Equal(t, rs.Count(), len(testCases)*3) - - for _, rule := range rs { - assert.Equal(t, rule.Domain, "example.com") - assert.Equal(t, rule.RegexRules[0].Match, "^http:") - assert.Equal(t, rule.RegexRules[0].Replace, "https:") - } -} - -func TestToMap(t *testing.T) { - // Prepare a ruleset with multiple rules, including "www." prefixed domains - rules := RuleSet{ - { - Domain: "Example.com", - RegexRules: []Regex{{Match: "match1", Replace: "replace1"}}, - }, - { - Domain: "www.AnotherExample.com", - RegexRules: []Regex{{Match: "match2", Replace: "replace2"}}, - }, - { - Domain: "www.foo.bAr.baz.bOol.quX.com", - RegexRules: []Regex{{Match: "match3", Replace: "replace3"}}, - }, - } - - // Convert to RuleSetMap - rsm := rules.ToMap() - - // Test for correct number of entries - if len(rsm) != 6 { - t.Errorf("Expected 6 entries in RuleSetMap, got %d", len(rsm)) - } - - // Test for correct mapping - testDomains := []struct { - domain string - expectedMatch string - }{ - {"example.com", "match1"}, - {"www.example.com", "match1"}, - {"anotherexample.com", "match2"}, - {"www.anotherexample.com", "match2"}, - {"foo.bar.baz.bool.qux.com", "match3"}, - {"no.ruleset.domain.com", ""}, - } - - for _, test := range testDomains { - if test.domain == "no.ruleset.domain.com" { - assert.Empty(t, test.expectedMatch) - continue - } - rule, exists := rsm[test.domain] - if !exists { - t.Errorf("Expected domain %s to exist in RuleSetMap", test.domain) - } else if rule.RegexRules[0].Match != test.expectedMatch { - t.Errorf("Expected match for %s to be %s, got %s", test.domain, test.expectedMatch, rule.RegexRules[0].Match) - } - } -} diff --git a/proxychain/ruleset/ruleset.go b/proxychain/ruleset/ruleset.go index 8475488..9731df6 100644 --- a/proxychain/ruleset/ruleset.go +++ b/proxychain/ruleset/ruleset.go @@ -13,6 +13,7 @@ import ( "strings" "gopkg.in/yaml.v3" + //"encoding/json" ) type IRuleset interface { @@ -21,12 +22,68 @@ type IRuleset interface { } type Ruleset struct { - rulesetPath string - rules map[string]Rule + Rules []Rule `json:"rules" yaml:"rules"` + _rulemap map[string]*Rule // internal map for fast lookups; points at a rule in the Rules slice } -func (rs Ruleset) GetRule(url *url.URL) (rule Rule, exists bool) { - rule, exists = rs.rules[url.Hostname()] +func (rs *Ruleset) UnmarshalYAML(unmarshal func(interface{}) error) error { + type AuxRuleset struct { + Rules []Rule `yaml:"rules"` + } + yr := &AuxRuleset{} + + if err := unmarshal(&yr); err != nil { + return err + } + + rs._rulemap = make(map[string]*Rule) + rs.Rules = yr.Rules + + // create a map of pointers to rules loaded above based on domain string keys + // this way we don't have two copies of the rule in ruleset + for i, rule := range rs.Rules { + rulePtr := &rs.Rules[i] + for _, domain := range rule.Domains { + rs._rulemap[domain] = rulePtr + if !strings.HasPrefix(domain, "www.") { + rs._rulemap["www."+domain] = rulePtr + } + } + } + + return nil +} + +// MarshalYAML implements the yaml.Marshaler interface. +// It customizes the marshaling of a Ruleset object into YAML +func (rs *Ruleset) MarshalYAML() (interface{}, error) { + + type AuxRule struct { + Domains []string `yaml:"domains"` + RequestModifications []_rqm `yaml:"requestmodifications"` + ResponseModifications []_rsm `yaml:"responsemodifications"` + } + type Aux struct { + Rules []AuxRule `yaml:"rules"` + } + + aux := Aux{} + + for _, rule := range rs.Rules { + auxRule := AuxRule{ + Domains: rule.Domains, + RequestModifications: rule._rqms, + ResponseModifications: rule._rsms, + } + aux.Rules = append(aux.Rules, auxRule) + } + + out, err := yaml.Marshal(&aux) + return out, err +} + +func (rs Ruleset) GetRule(url *url.URL) (rule *Rule, exists bool) { + rule, exists = rs._rulemap[url.Hostname()] return rule, exists } @@ -38,8 +95,8 @@ func (rs Ruleset) HasRule(url *url.URL) bool { // NewRuleset loads a new RuleSet from a path func NewRuleset(path string) (Ruleset, error) { rs := Ruleset{ - rulesetPath: path, - rules: map[string]Rule{}, + _rulemap: map[string]*Rule{}, + Rules: []Rule{}, } switch { @@ -88,18 +145,20 @@ func (rs *Ruleset) loadRulesFromLocalDir(path string) error { return nil } - isYAML := filepath.Ext(path) == "yaml" || filepath.Ext(path) == "yml" + isYAML := filepath.Ext(path) == ".yaml" || filepath.Ext(path) == ".yml" if !isYAML { return nil } - err = rs.loadRulesFromLocalFile(path) + tmpRs := Ruleset{_rulemap: make(map[string]*Rule)} + err = tmpRs.loadRulesFromLocalFile(path) if err != nil { log.Printf("WARN: failed to load directory ruleset '%s': %s, skipping", path, err) return nil } + rs.Rules = append(rs.Rules, tmpRs.Rules...) - log.Printf("INFO: loaded ruleset %s\n", path) + //log.Printf("INFO: loaded ruleset %s\n", path) return nil }) @@ -120,21 +179,12 @@ func (rs *Ruleset) loadRulesFromLocalFile(path string) error { return errors.Join(e, err) } - rule := Rule{} - err = yaml.Unmarshal(yamlFile, &rule) + err = yaml.Unmarshal(yamlFile, rs) if err != nil { e := fmt.Errorf("failed to load rules from local file, possible syntax error in '%s' - %s", path, err) - ee := errors.Join(e, err) - debugPrintRule(string(yamlFile), ee) - return ee - } - - for _, domain := range rule.Domains { - rs.rules[domain] = rule - if !strings.HasSuffix(domain, "www.") { - rs.rules["www."+domain] = rule - } + debugPrintRule(string(yamlFile), e) + return e } return nil @@ -144,7 +194,6 @@ func (rs *Ruleset) loadRulesFromLocalFile(path string) error { // It supports plain and gzip compressed content. // Returns an error if there's an issue accessing the URL or if there's a syntax error in the YAML. func (rs *Ruleset) loadRulesFromRemoteFile(rulesURL string) error { - rule := Rule{} resp, err := http.Get(rulesURL) if err != nil { @@ -160,7 +209,7 @@ func (rs *Ruleset) loadRulesFromRemoteFile(rulesURL string) error { var reader io.Reader // in case remote server did not set content-encoding gzip header - isGzip := strings.HasSuffix(rulesURL, ".gz") || strings.HasSuffix(rulesURL, ".gzip") + isGzip := strings.HasSuffix(rulesURL, ".gz") || strings.HasSuffix(rulesURL, ".gzip") || resp.Header.Get("content-encoding") == "gzip" if isGzip { reader, err = gzip.NewReader(resp.Body) @@ -171,24 +220,12 @@ func (rs *Ruleset) loadRulesFromRemoteFile(rulesURL string) error { reader = resp.Body } - err = yaml.NewDecoder(reader).Decode(&rule) + err = yaml.NewDecoder(reader).Decode(&rs) if err != nil { return fmt.Errorf("failed to load rules from remote url '%s' with status code '%s' and possible syntax error - %s", rulesURL, resp.Status, err) } - if rs.rules == nil { - fmt.Println("nilmap") - rs.rules = make(map[string]Rule) - } - - for _, domain := range rule.Domains { - rs.rules[domain] = rule - if !strings.HasSuffix(domain, "www.") { - rs.rules["www."+domain] = rule - } - } - return nil } @@ -204,31 +241,11 @@ func (rs *Ruleset) Yaml() (string, error) { return string(y), nil } -// GzipYaml returns an io.Reader that streams the Gzip-compressed YAML representation of the RuleSet. -func (rs *Ruleset) GzipYaml() (io.Reader, error) { - pr, pw := io.Pipe() - - go func() { - defer pw.Close() - - gw := gzip.NewWriter(pw) - defer gw.Close() - - if err := yaml.NewEncoder(gw).Encode(rs); err != nil { - gw.Close() // Ensure to close the gzip writer - pw.CloseWithError(err) - return - } - }() - - return pr, nil -} - // Domains extracts and returns a slice of all domains present in the RuleSet. func (rs *Ruleset) Domains() []string { var domains []string - for domain := range rs.rules { - domains = append(domains, domain) + for _, rule := range rs.Rules { + domains = append(domains, rule.Domains...) } return domains } @@ -240,7 +257,7 @@ func (rs *Ruleset) DomainCount() int { // Count returns the total number of rules in the RuleSet. func (rs *Ruleset) Count() int { - return len(rs.rules) + return len(rs.Rules) } // PrintStats logs the number of rules and domains loaded in the RuleSet. diff --git a/proxychain/ruleset/ruleset_test.go b/proxychain/ruleset/ruleset_test.go index cec8285..c7311bd 100644 --- a/proxychain/ruleset/ruleset_test.go +++ b/proxychain/ruleset/ruleset_test.go @@ -1,6 +1,7 @@ package ruleset_v2 import ( + "fmt" "net/url" "os" "path/filepath" @@ -13,38 +14,40 @@ import ( var ( validYAML = ` -domains: -- example.com -- www.example.com -responsemodifications: -- name: APIContent - params: [] -- name: SetContentSecurityPolicy - params: - - foobar -- name: SetIncomingCookie - params: - - authorization-bearer - - hunter2 -requestmodifications: -- name: ForwardRequestHeaders - params: [] +rules: + - domains: + - example.com + - www.example.com + responsemodifications: + - name: APIContent + params: [] + - name: SetContentSecurityPolicy + params: + - foobar + - name: SetIncomingCookie + params: + - authorization-bearer + - hunter2 + requestmodifications: + - name: ForwardRequestHeaders + params: [] ` invalidYAML = ` -domains: -- example.com -- www.example.com -responsemodifications: -- name: APIContent -- name: SetContentSecurityPolicy -- name: INVALIDSetIncomingCookie - params: - - authorization-bearer - - hunter2 -requestmodifications: -- name: ForwardRequestHeaders - params: [] +rules: + domains: + - example.com + - www.example.com + responsemodifications: + - name: APIContent + - name: SetContentSecurityPolicy + - name: INVALIDSetIncomingCookie + params: + - authorization-bearer + - hunter2 + requestmodifications: + - name: ForwardRequestHeaders + params: [] ` ) @@ -62,26 +65,6 @@ func TestLoadRulesFromRemoteFile(t *testing.T) { return nil }) - app.Get("/valid-config.gz", func(c *fiber.Ctx) error { - c.Set("Content-Type", "application/octet-stream") - - rs, err := loadRuleFromString(validYAML) - if err != nil { - t.Errorf("failed to load valid yaml from string: %s", err.Error()) - } - - s, err := rs.GzipYaml() - if err != nil { - t.Errorf("failed to load gzip serialize yaml: %s", err.Error()) - } - - err = c.SendStream(s) - if err != nil { - t.Errorf("failed to stream gzip serialized yaml: %s", err.Error()) - } - return nil - }) - // Start the server in a goroutine go func() { if err := app.Listen("127.0.0.1:9999"); err != nil { @@ -102,25 +85,21 @@ func TestLoadRulesFromRemoteFile(t *testing.T) { assert.True(t, exists, "expected example.com rule to be present") assert.Equal(t, r.Domains[0], "example.com") - u, _ = url.Parse("http://www.www.example.com") + u, _ = url.Parse("http://www.www.foobar.com") _, exists = rs.GetRule(u) - assert.False(t, exists, "expected www.www.example.com rule to NOT be present") - - rs, err = NewRuleset("http://127.0.0.1:9999/valid-config.gz") - if err != nil { - t.Errorf("failed to load gzipped ruleset from http server: %s", err.Error()) - } + assert.False(t, exists, "expected www.www.foobar.com rule to NOT be present") + u, _ = url.Parse("http://example.com") r, exists = rs.GetRule(u) assert.Equal(t, r.Domains[0], "example.com") - os.Setenv("RULESET", "http://127.0.0.1:9999/valid-config.gz") + os.Setenv("RULESET", "http://127.0.0.1:9999/valid-config.yml") rs = NewRulesetFromEnv() r, exists = rs.GetRule(u) - assert.True(t, exists, "expected example.com rule to be present") + assert.True(t, exists, "expected example.com rule to be present from env") if !assert.Equal(t, r.Domains[0], "example.com") { - t.Error("expected no errors loading ruleset from gzip url using environment variable, but got one") + t.Error("expected no errors loading ruleset from url using environment variable, but got one") } } @@ -132,7 +111,10 @@ func loadRuleFromString(yaml string) (Ruleset, error) { tmpFile.WriteString(yaml) - rs := Ruleset{} + rs := Ruleset{ + _rulemap: map[string]*Rule{}, + Rules: []Rule{}, + } err := rs.loadRulesFromLocalFile(tmpFile.Name()) return rs, err @@ -189,8 +171,9 @@ func TestLoadRulesFromLocalDir(t *testing.T) { } rs := Ruleset{} + fmt.Println(baseDir) err = rs.loadRulesFromLocalDir(baseDir) assert.NoError(t, err) - assert.Equal(t, rs.Count(), len(testCases)*3) + assert.Equal(t, len(testCases)*3, rs.Count()) }