diff --git a/proxychain/ruleset/rule_test.go b/proxychain/ruleset/rule_test.go index 98f7bc9..6829350 100644 --- a/proxychain/ruleset/rule_test.go +++ b/proxychain/ruleset/rule_test.go @@ -102,7 +102,6 @@ requestmodifications: if len(rule.RequestModifications) != 1 { t.Errorf("expected number of RequestModifications to be 1, got %d", len(rule.RequestModifications)) } - fmt.Println(rule.RequestModifications[0].Name) } func TestRuleMarshalYAML(t *testing.T) { diff --git a/proxychain/ruleset/ruleset.go b/proxychain/ruleset/ruleset.go index 499f0ad..8475488 100644 --- a/proxychain/ruleset/ruleset.go +++ b/proxychain/ruleset/ruleset.go @@ -1,7 +1,18 @@ package ruleset_v2 import ( + "compress/gzip" + "errors" + "fmt" + "io" + "log" + "net/http" "net/url" + "os" + "path/filepath" + "strings" + + "gopkg.in/yaml.v3" ) type IRuleset interface { @@ -14,19 +25,234 @@ type Ruleset struct { rules map[string]Rule } -func (rs Ruleset) GetRule(url url.URL) (rule Rule, exists bool) { +func (rs Ruleset) GetRule(url *url.URL) (rule Rule, exists bool) { rule, exists = rs.rules[url.Hostname()] return rule, exists } -func (rs Ruleset) HasRule(url url.URL) bool { +func (rs Ruleset) HasRule(url *url.URL) bool { _, exists := rs.GetRule(url) return exists } +// NewRuleset loads a new RuleSet from a path func NewRuleset(path string) (Ruleset, error) { rs := Ruleset{ rulesetPath: path, + rules: map[string]Rule{}, + } + + switch { + case strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://"): + err := rs.loadRulesFromRemoteFile(path) + return rs, err + default: + err := rs.loadRulesFromLocalDir(path) + return rs, err } - return rs, nil +} + +// NewRulesetFromEnv creates a new RuleSet based on the RULESET environment variable. +// It logs a warning and returns an empty RuleSet if the RULESET environment variable is not set. +// If the RULESET is set but the rules cannot be loaded, it panics. +func NewRulesetFromEnv() Ruleset { + rulesPath, ok := os.LookupEnv("RULESET") + if !ok { + log.Printf("WARN: No ruleset specified. Set the `RULESET` environment variable to load one for a better success rate.") + return Ruleset{} + } + + ruleSet, err := NewRuleset(rulesPath) + if err != nil { + log.Println(err) + } + + return ruleSet +} + +// loadRulesFromLocalDir loads rules from a local directory specified by the path. +// It walks through the directory, loading rules from YAML files. +// Returns an error if the directory cannot be accessed +// If there is an issue loading any file, it will be skipped +func (rs *Ruleset) loadRulesFromLocalDir(path string) error { + _, err := os.Stat(path) + if err != nil { + return fmt.Errorf("loadRulesFromLocalDir: invalid path - %s", err) + } + + err = filepath.Walk(path, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + + isYAML := filepath.Ext(path) == "yaml" || filepath.Ext(path) == "yml" + if !isYAML { + return nil + } + + err = rs.loadRulesFromLocalFile(path) + if err != nil { + log.Printf("WARN: failed to load directory ruleset '%s': %s, skipping", path, err) + return nil + } + + log.Printf("INFO: loaded ruleset %s\n", path) + + return nil + }) + + if err != nil { + return err + } + + return nil +} + +// loadRulesFromLocalFile loads rules from a local YAML file specified by the path. +// Returns an error if the file cannot be read or if there's a syntax error in the YAML. +func (rs *Ruleset) loadRulesFromLocalFile(path string) error { + yamlFile, err := os.ReadFile(path) + if err != nil { + e := fmt.Errorf("failed to read rules from local file: '%s'", path) + return errors.Join(e, err) + } + + rule := Rule{} + err = yaml.Unmarshal(yamlFile, &rule) + + if err != nil { + e := fmt.Errorf("failed to load rules from local file, possible syntax error in '%s' - %s", path, err) + ee := errors.Join(e, err) + debugPrintRule(string(yamlFile), ee) + return ee + } + + for _, domain := range rule.Domains { + rs.rules[domain] = rule + if !strings.HasSuffix(domain, "www.") { + rs.rules["www."+domain] = rule + } + } + + return nil +} + +// loadRulesFromRemoteFile loads rules from a remote URL. +// It supports plain and gzip compressed content. +// Returns an error if there's an issue accessing the URL or if there's a syntax error in the YAML. +func (rs *Ruleset) loadRulesFromRemoteFile(rulesURL string) error { + rule := Rule{} + + resp, err := http.Get(rulesURL) + if err != nil { + return fmt.Errorf("failed to load rules from remote url '%s' - %s", rulesURL, err) + } + + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + return fmt.Errorf("failed to load rules from remote url (%s) on '%s' - %s", resp.Status, rulesURL, err) + } + + var reader io.Reader + + // in case remote server did not set content-encoding gzip header + isGzip := strings.HasSuffix(rulesURL, ".gz") || strings.HasSuffix(rulesURL, ".gzip") + if isGzip { + reader, err = gzip.NewReader(resp.Body) + + if err != nil { + return fmt.Errorf("failed to create gzip reader for URL '%s' with status code '%s': %w", rulesURL, resp.Status, err) + } + } else { + reader = resp.Body + } + + err = yaml.NewDecoder(reader).Decode(&rule) + + if err != nil { + return fmt.Errorf("failed to load rules from remote url '%s' with status code '%s' and possible syntax error - %s", rulesURL, resp.Status, err) + } + + if rs.rules == nil { + fmt.Println("nilmap") + rs.rules = make(map[string]Rule) + } + + for _, domain := range rule.Domains { + rs.rules[domain] = rule + if !strings.HasSuffix(domain, "www.") { + rs.rules["www."+domain] = rule + } + } + + return nil +} + +// ================= utility methods ========================== + +// Yaml returns the ruleset as a Yaml string +func (rs *Ruleset) Yaml() (string, error) { + y, err := yaml.Marshal(rs) + if err != nil { + return "", err + } + + return string(y), nil +} + +// GzipYaml returns an io.Reader that streams the Gzip-compressed YAML representation of the RuleSet. +func (rs *Ruleset) GzipYaml() (io.Reader, error) { + pr, pw := io.Pipe() + + go func() { + defer pw.Close() + + gw := gzip.NewWriter(pw) + defer gw.Close() + + if err := yaml.NewEncoder(gw).Encode(rs); err != nil { + gw.Close() // Ensure to close the gzip writer + pw.CloseWithError(err) + return + } + }() + + return pr, nil +} + +// Domains extracts and returns a slice of all domains present in the RuleSet. +func (rs *Ruleset) Domains() []string { + var domains []string + for domain := range rs.rules { + domains = append(domains, domain) + } + return domains +} + +// DomainCount returns the count of unique domains present in the RuleSet. +func (rs *Ruleset) DomainCount() int { + return len(rs.Domains()) +} + +// Count returns the total number of rules in the RuleSet. +func (rs *Ruleset) Count() int { + return len(rs.rules) +} + +// PrintStats logs the number of rules and domains loaded in the RuleSet. +func (rs *Ruleset) PrintStats() { + log.Printf("INFO: Loaded %d rules for %d domains\n", rs.Count(), rs.DomainCount()) +} + +// debugPrintRule is a utility function for printing a rule and associated error for debugging purposes. +func debugPrintRule(rule string, err error) { + fmt.Println("------------------------------ BEGIN DEBUG RULESET -----------------------------") + fmt.Printf("%s\n", err.Error()) + fmt.Println("--------------------------------------------------------------------------------") + fmt.Println(rule) + fmt.Println("------------------------------ END DEBUG RULESET -------------------------------") } diff --git a/proxychain/ruleset/ruleset_test.go b/proxychain/ruleset/ruleset_test.go new file mode 100644 index 0000000..cec8285 --- /dev/null +++ b/proxychain/ruleset/ruleset_test.go @@ -0,0 +1,196 @@ +package ruleset_v2 + +import ( + "net/url" + "os" + "path/filepath" + "testing" + "time" + + "github.com/gofiber/fiber/v2" + "github.com/stretchr/testify/assert" +) + +var ( + validYAML = ` +domains: +- example.com +- www.example.com +responsemodifications: +- name: APIContent + params: [] +- name: SetContentSecurityPolicy + params: + - foobar +- name: SetIncomingCookie + params: + - authorization-bearer + - hunter2 +requestmodifications: +- name: ForwardRequestHeaders + params: [] +` + + invalidYAML = ` +domains: +- example.com +- www.example.com +responsemodifications: +- name: APIContent +- name: SetContentSecurityPolicy +- name: INVALIDSetIncomingCookie + params: + - authorization-bearer + - hunter2 +requestmodifications: +- name: ForwardRequestHeaders + params: [] +` +) + +func TestLoadRulesFromRemoteFile(t *testing.T) { + app := fiber.New() + defer app.Shutdown() + + app.Get("/valid-config.yml", func(c *fiber.Ctx) error { + c.SendString(validYAML) + return nil + }) + + app.Get("/invalid-config.yml", func(c *fiber.Ctx) error { + c.SendString(invalidYAML) + return nil + }) + + app.Get("/valid-config.gz", func(c *fiber.Ctx) error { + c.Set("Content-Type", "application/octet-stream") + + rs, err := loadRuleFromString(validYAML) + if err != nil { + t.Errorf("failed to load valid yaml from string: %s", err.Error()) + } + + s, err := rs.GzipYaml() + if err != nil { + t.Errorf("failed to load gzip serialize yaml: %s", err.Error()) + } + + err = c.SendStream(s) + if err != nil { + t.Errorf("failed to stream gzip serialized yaml: %s", err.Error()) + } + return nil + }) + + // Start the server in a goroutine + go func() { + if err := app.Listen("127.0.0.1:9999"); err != nil { + t.Errorf("Server failed to start: %s", err.Error()) + } + }() + + // Wait for the server to start + time.Sleep(time.Second * 1) + + rs, err := NewRuleset("http://127.0.0.1:9999/valid-config.yml") + if err != nil { + t.Errorf("failed to load plaintext ruleset from http server: %s", err.Error()) + } + + u, _ := url.Parse("http://example.com") + r, exists := rs.GetRule(u) + assert.True(t, exists, "expected example.com rule to be present") + assert.Equal(t, r.Domains[0], "example.com") + + u, _ = url.Parse("http://www.www.example.com") + _, exists = rs.GetRule(u) + assert.False(t, exists, "expected www.www.example.com rule to NOT be present") + + rs, err = NewRuleset("http://127.0.0.1:9999/valid-config.gz") + if err != nil { + t.Errorf("failed to load gzipped ruleset from http server: %s", err.Error()) + } + + r, exists = rs.GetRule(u) + assert.Equal(t, r.Domains[0], "example.com") + + os.Setenv("RULESET", "http://127.0.0.1:9999/valid-config.gz") + + rs = NewRulesetFromEnv() + r, exists = rs.GetRule(u) + assert.True(t, exists, "expected example.com rule to be present") + if !assert.Equal(t, r.Domains[0], "example.com") { + t.Error("expected no errors loading ruleset from gzip url using environment variable, but got one") + } +} + +func loadRuleFromString(yaml string) (Ruleset, error) { + // Create a temporary file and load it + tmpFile, _ := os.CreateTemp("", "ruleset*.yaml") + + defer os.Remove(tmpFile.Name()) + + tmpFile.WriteString(yaml) + + rs := Ruleset{} + err := rs.loadRulesFromLocalFile(tmpFile.Name()) + + return rs, err +} + +// TestLoadRulesFromLocalFile tests the loading of rules from a local YAML file. +func TestLoadRulesFromLocalFile(t *testing.T) { + _, err := loadRuleFromString(validYAML) + if err != nil { + t.Errorf("Failed to load rules from valid YAML: %s", err) + } + + _, err = loadRuleFromString(invalidYAML) + if err == nil { + t.Errorf("Expected an error when loading invalid YAML, but got none") + } +} + +// TestLoadRulesFromLocalDir tests the loading of rules from a local nested directory full of yaml rulesets +func TestLoadRulesFromLocalDir(t *testing.T) { + // Create a temporary directory + baseDir, err := os.MkdirTemp("", "ruleset_test") + if err != nil { + t.Fatalf("Failed to create temporary directory: %s", err) + } + + defer os.RemoveAll(baseDir) + + // Create a nested subdirectory + nestedDir := filepath.Join(baseDir, "nested") + err = os.Mkdir(nestedDir, 0o755) + + if err != nil { + t.Fatalf("Failed to create nested directory: %s", err) + } + + // Create a nested subdirectory + nestedTwiceDir := filepath.Join(nestedDir, "nestedTwice") + err = os.Mkdir(nestedTwiceDir, 0o755) + if err != nil { + t.Fatalf("Failed to create twice-nested directory: %s", err) + } + + testCases := []string{"test.yaml", "test2.yaml", "test-3.yaml", "test 4.yaml", "1987.test.yaml.yml", "foobar.example.com.yaml", "foobar.com.yml"} + for _, fileName := range testCases { + filePath := filepath.Join(nestedDir, "2x-"+fileName) + os.WriteFile(filePath, []byte(validYAML), 0o644) + + filePath = filepath.Join(nestedDir, fileName) + os.WriteFile(filePath, []byte(validYAML), 0o644) + + filePath = filepath.Join(baseDir, "base-"+fileName) + os.WriteFile(filePath, []byte(validYAML), 0o644) + } + + rs := Ruleset{} + err = rs.loadRulesFromLocalDir(baseDir) + + assert.NoError(t, err) + assert.Equal(t, rs.Count(), len(testCases)*3) +} diff --git a/proxychain/ruleset/todo.md b/proxychain/ruleset/todo.md new file mode 100644 index 0000000..6d94d89 --- /dev/null +++ b/proxychain/ruleset/todo.md @@ -0,0 +1 @@ +ruleset loading rule tests are failing; maybe concurrency issue with assigning to nil map?