From e87d19d7f53f11d0d5929be7cade4a0f12f0ae1f Mon Sep 17 00:00:00 2001 From: Kevin Pham Date: Tue, 14 Nov 2023 15:42:26 -0600 Subject: [PATCH] add ability to load rulesets from directory --- cmd/main.go | 7 +- handlers/proxy.go | 110 +++++---------- handlers/proxy.test.go | 5 +- pkg/ruleset/ruleset.go | 274 ++++++++++++++++++++++++++++++++++++ pkg/ruleset/ruleset_test.go | 153 ++++++++++++++++++++ 5 files changed, 474 insertions(+), 75 deletions(-) create mode 100644 pkg/ruleset/ruleset.go create mode 100644 pkg/ruleset/ruleset_test.go diff --git a/cmd/main.go b/cmd/main.go index 36c6fef..256c52e 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -36,6 +36,11 @@ func main() { Help: "This will spawn multiple processes listening", }) + ruleset := parser.String("r", "ruleset", &argparse.Options{ + Required: false, + Help: "File, Directory or URL to a ruleset.yml. Overrides RULESET environment variable.", + }) + err := parser.Parse(os.Args) if err != nil { fmt.Print(parser.Usage(err)) @@ -80,7 +85,7 @@ func main() { app.Get("raw/*", handlers.Raw) app.Get("api/*", handlers.Api) app.Get("ruleset", handlers.Raw) - app.Get("/*", handlers.ProxySite) + app.Get("/*", handlers.ProxySite(*ruleset)) log.Fatal(app.Listen(":" + *port)) } diff --git a/handlers/proxy.go b/handlers/proxy.go index 2f89037..18a6f41 100644 --- a/handlers/proxy.go +++ b/handlers/proxy.go @@ -10,34 +10,52 @@ import ( "regexp" "strings" + "ladder/pkg/ruleset" + "github.com/PuerkitoBio/goquery" "github.com/gofiber/fiber/v2" - "gopkg.in/yaml.v3" ) var ( UserAgent = getenv("USER_AGENT", "Mozilla/5.0 (compatible; Googlebot/2.1; +http://www.google.com/bot.html)") ForwardedFor = getenv("X_FORWARDED_FOR", "66.249.66.1") - rulesSet = loadRules() - allowedDomains = strings.Split(os.Getenv("ALLOWED_DOMAINS"), ",") + rulesSet = ruleset.NewRulesetFromEnv() + allowedDomains = []string{} ) -func ProxySite(c *fiber.Ctx) error { - // Get the url from the URL - url := c.Params("*") +func init() { + allowedDomains = strings.Split(os.Getenv("ALLOWED_DOMAINS"), ",") + if os.Getenv("ALLOWED_DOMAINS_RULESET") == "true" { + allowedDomains = append(allowedDomains, rulesSet.Domains()...) + } +} - queries := c.Queries() - body, _, resp, err := fetchSite(url, queries) - if err != nil { - log.Println("ERROR:", err) - c.SendStatus(fiber.StatusInternalServerError) - return c.SendString(err.Error()) +func ProxySite(rulesetPath string) fiber.Handler { + if rulesetPath != "" { + rs, err := ruleset.NewRuleset(rulesetPath) + if err != nil { + panic(err) + } + rulesSet = rs } - c.Set("Content-Type", resp.Header.Get("Content-Type")) - c.Set("Content-Security-Policy", resp.Header.Get("Content-Security-Policy")) + return func(c *fiber.Ctx) error { + // Get the url from the URL + url := c.Params("*") - return c.SendString(body) + queries := c.Queries() + body, _, resp, err := fetchSite(url, queries) + if err != nil { + log.Println("ERROR:", err) + c.SendStatus(fiber.StatusInternalServerError) + return c.SendString(err.Error()) + } + + c.Set("Content-Type", resp.Header.Get("Content-Type")) + c.Set("Content-Security-Policy", resp.Header.Get("Content-Security-Policy")) + + return c.SendString(body) + } } func fetchSite(urlpath string, queries map[string]string) (string, *http.Request, *http.Response, error) { @@ -122,7 +140,7 @@ func fetchSite(urlpath string, queries map[string]string) (string, *http.Request return body, req, resp, nil } -func rewriteHtml(bodyB []byte, u *url.URL, rule Rule) string { +func rewriteHtml(bodyB []byte, u *url.URL, rule ruleset.Rule) string { // Rewrite the HTML body := string(bodyB) @@ -156,63 +174,11 @@ func getenv(key, fallback string) string { return value } -func loadRules() RuleSet { - rulesUrl := os.Getenv("RULESET") - if rulesUrl == "" { - RulesList := RuleSet{} - return RulesList - } - log.Println("Loading rules") - - var ruleSet RuleSet - if strings.HasPrefix(rulesUrl, "http") { - - resp, err := http.Get(rulesUrl) - if err != nil { - log.Println("ERROR:", err) - } - defer resp.Body.Close() - - if resp.StatusCode >= 400 { - log.Println("ERROR:", resp.StatusCode, rulesUrl) - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - log.Println("ERROR:", err) - } - yaml.Unmarshal(body, &ruleSet) - - if err != nil { - log.Println("ERROR:", err) - } - } else { - yamlFile, err := os.ReadFile(rulesUrl) - if err != nil { - log.Println("ERROR:", err) - } - yaml.Unmarshal(yamlFile, &ruleSet) - } - - domains := []string{} - for _, rule := range ruleSet { - - domains = append(domains, rule.Domain) - domains = append(domains, rule.Domains...) - if os.Getenv("ALLOWED_DOMAINS_RULESET") == "true" { - allowedDomains = append(allowedDomains, domains...) - } - } - - log.Println("Loaded ", len(ruleSet), " rules for", len(domains), "Domains") - return ruleSet -} - -func fetchRule(domain string, path string) Rule { +func fetchRule(domain string, path string) ruleset.Rule { if len(rulesSet) == 0 { - return Rule{} + return ruleset.Rule{} } - rule := Rule{} + rule := ruleset.Rule{} for _, rule := range rulesSet { domains := rule.Domains if rule.Domain != "" { @@ -231,7 +197,7 @@ func fetchRule(domain string, path string) Rule { return rule } -func applyRules(body string, rule Rule) string { +func applyRules(body string, rule ruleset.Rule) string { if len(rulesSet) == 0 { return body } diff --git a/handlers/proxy.test.go b/handlers/proxy.test.go index af66e54..07f72bd 100644 --- a/handlers/proxy.test.go +++ b/handlers/proxy.test.go @@ -2,6 +2,7 @@ package handlers import ( + "ladder/pkg/ruleset" "net/http" "net/http/httptest" "net/url" @@ -13,7 +14,7 @@ import ( func TestProxySite(t *testing.T) { app := fiber.New() - app.Get("/:url", ProxySite) + app.Get("/:url", ProxySite("")) req := httptest.NewRequest("GET", "/https://example.com", nil) resp, err := app.Test(req) @@ -51,7 +52,7 @@ func TestRewriteHtml(t *testing.T) { ` - actual := rewriteHtml(bodyB, u, Rule{}) + actual := rewriteHtml(bodyB, u, ruleset.Rule{}) assert.Equal(t, expected, actual) } diff --git a/pkg/ruleset/ruleset.go b/pkg/ruleset/ruleset.go new file mode 100644 index 0000000..a4efd3d --- /dev/null +++ b/pkg/ruleset/ruleset.go @@ -0,0 +1,274 @@ +package ruleset + +import ( + "errors" + "fmt" + "io" + "log" + "net/http" + "os" + "path/filepath" + "regexp" + "strings" + + "compress/gzip" + + "gopkg.in/yaml.v3" +) + +type Regex struct { + Match string `yaml:"match"` + Replace string `yaml:"replace"` +} + +type RuleSet []Rule + +type Rule struct { + Domain string `yaml:"domain,omitempty"` + Domains []string `yaml:"domains,omitempty"` + Paths []string `yaml:"paths,omitempty"` + Headers struct { + UserAgent string `yaml:"user-agent,omitempty"` + XForwardedFor string `yaml:"x-forwarded-for,omitempty"` + Referer string `yaml:"referer,omitempty"` + Cookie string `yaml:"cookie,omitempty"` + CSP string `yaml:"content-security-policy,omitempty"` + } `yaml:"headers,omitempty"` + GoogleCache bool `yaml:"googleCache,omitempty"` + RegexRules []Regex `yaml:"regexRules"` + Injections []struct { + Position string `yaml:"position"` + Append string `yaml:"append"` + Prepend string `yaml:"prepend"` + Replace string `yaml:"replace"` + } `yaml:"injections"` +} + +// NewRulesetFromEnv creates a new RuleSet based on the RULESET environment variable. +// It logs a warning and returns an empty RuleSet if the RULESET environment variable is not set. +// If the RULESET is set but the rules cannot be loaded, it panics. +func NewRulesetFromEnv() RuleSet { + rulesPath, ok := os.LookupEnv("RULESET") + if !ok { + log.Printf("WARN: No ruleset specified. Set the `RULESET` environment variable to load one for a better success rate.") + return RuleSet{} + } + ruleSet, err := NewRuleset(rulesPath) + if err != nil { + log.Panicln(ruleSet) + } + return ruleSet +} + +// NewRuleset loads a RuleSet from a given string of rule paths, separated by semicolons. +// It supports loading rules from both local file paths and remote URLs. +// Returns a RuleSet and an error if any issues occur during loading. +func NewRuleset(rulePaths string) (RuleSet, error) { + ruleSet := RuleSet{} + errs := []error{} + + rp := strings.Split(rulePaths, ";") + for _, rule := range rp { + rulePath := strings.Trim(rule, " ") + var err error + + isRemote, _ := regexp.MatchString(`^https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()!@:%_\+.~#?&\/\/=]*)`, rulePath) + if isRemote { + err = ruleSet.loadRulesFromRemoteFile(rulePath) + } else { + err = ruleSet.loadRulesFromLocalDir(rulePath) + } + + if err != nil { + e := errors.New(fmt.Sprintf("WARN: failed to load ruleset from ''%s", rulePath)) + errs = append(errs, errors.Join(e, err)) + continue + } + } + + if len(errs) != 0 { + e := errors.New(fmt.Sprintf("WARN: failed to load %d rulesets", len(rp))) + errs = append(errs, e) + // panic if the user specified a local ruleset, but it wasn't found on disk + // don't fail silently + for _, err := range errs { + if errors.Is(os.ErrNotExist, err) { + e := fmt.Errorf("PANIC: ruleset '%s' not found", err) + panic(errors.Join(e, err)) + } + } + // else, bubble up any errors, such as syntax or remote host issues + return ruleSet, errors.Join(errs...) + } + ruleSet.PrintStats() + return ruleSet, nil +} + +// ================== RULESET loading logic =================================== + +// loadRulesFromLocalDir loads rules from a local directory specified by the path. +// It walks through the directory, loading rules from YAML files. +// Returns an error if the directory cannot be accessed +// If there is an issue loading any file, it will be skipped +func (rs *RuleSet) loadRulesFromLocalDir(path string) error { + _, err := os.Stat(path) + if err != nil { + return err + } + + yamlRegex := regexp.MustCompile(`.*\.ya?ml`) + + err = filepath.Walk(path, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + if isYaml := yamlRegex.MatchString(path); !isYaml { + return nil + } + + err = rs.loadRulesFromLocalFile(path) + if err != nil { + log.Printf("WARN: failed to load directory ruleset '%s': %s, skipping", path, err) + return nil + } + log.Printf("INFO: loaded ruleset %s\n", path) + return nil + }) + + if err != nil { + return err + } + return nil +} + +// loadRulesFromLocalFile loads rules from a local YAML file specified by the path. +// Returns an error if the file cannot be read or if there's a syntax error in the YAML. +func (rs *RuleSet) loadRulesFromLocalFile(path string) error { + yamlFile, err := os.ReadFile(path) + if err != nil { + e := errors.New(fmt.Sprintf("failed to read rules from local file: '%s'", path)) + return errors.Join(e, err) + } + + var r RuleSet + err = yaml.Unmarshal(yamlFile, &r) + if err != nil { + e := errors.New(fmt.Sprintf("failed to load rules from local file, possible syntax error in '%s'", path)) + ee := errors.Join(e, err) + if _, ok := os.LookupEnv("DEBUG"); ok { + debugPrintRule(string(yamlFile), ee) + } + return ee + } + *rs = append(*rs, r...) + return nil +} + +// loadRulesFromRemoteFile loads rules from a remote URL. +// It supports plain and gzip compressed content. +// Returns an error if there's an issue accessing the URL or if there's a syntax error in the YAML. +func (rs *RuleSet) loadRulesFromRemoteFile(rulesUrl string) error { + var r RuleSet + resp, err := http.Get(rulesUrl) + if err != nil { + e := errors.New(fmt.Sprintf("failed to load rules from remote url '%s'", rulesUrl)) + return errors.Join(e, err) + } + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + e := errors.New(fmt.Sprintf("failed to load rules from remote url (%s) on '%s'", resp.Status, rulesUrl)) + return errors.Join(e, err) + } + + var reader io.Reader + isGzip := strings.HasSuffix(rulesUrl, ".gz") || strings.HasSuffix(rulesUrl, ".gzip") || resp.Header.Get("content-encoding") == "gzip" + + if isGzip { + reader, err = gzip.NewReader(resp.Body) + if err != nil { + return fmt.Errorf("failed to create gzip reader for URL '%s' with status code '%s': %w", rulesUrl, resp.Status, err) + } + } else { + reader = resp.Body + } + + err = yaml.NewDecoder(reader).Decode(&r) + + if err != nil { + e := errors.New(fmt.Sprintf("failed to load rules from remote url '%s' with status code '%s' and possible syntax error", rulesUrl, resp.Status)) + ee := errors.Join(e, err) + return ee + } + + *rs = append(*rs, r...) + return nil +} + +// ================= utility methods ========================== + +// Yaml returns the ruleset as a Yaml string +func (rs *RuleSet) Yaml() (string, error) { + y, err := yaml.Marshal(rs) + if err != nil { + return "", err + } + return string(y), nil +} + +// GzipYaml returns an io.Reader that streams the Gzip-compressed YAML representation of the RuleSet. +func (rs *RuleSet) GzipYaml() (io.Reader, error) { + pr, pw := io.Pipe() + + go func() { + defer pw.Close() + + gw := gzip.NewWriter(pw) + defer gw.Close() + + if err := yaml.NewEncoder(gw).Encode(rs); err != nil { + gw.Close() // Ensure to close the gzip writer + pw.CloseWithError(err) + return + } + }() + + return pr, nil +} + +// Domains extracts and returns a slice of all domains present in the RuleSet. +func (rs *RuleSet) Domains() []string { + var domains []string + for _, rule := range *rs { + domains = append(domains, rule.Domain) + domains = append(domains, rule.Domains...) + } + return domains +} + +// DomainCount returns the count of unique domains present in the RuleSet. +func (rs *RuleSet) DomainCount() int { + return len(rs.Domains()) +} + +// Count returns the total number of rules in the RuleSet. +func (rs *RuleSet) Count() int { + return len(*rs) +} + +// PrintStats logs the number of rules and domains loaded in the RuleSet. +func (rs *RuleSet) PrintStats() { + log.Printf("INFO: Loaded %d rules for %d domains\n", rs.Count(), rs.DomainCount()) +} + +// debugPrintRule is a utility function for printing a rule and associated error for debugging purposes. +func debugPrintRule(rule string, err error) { + fmt.Println("------------------------------ BEGIN DEBUG RULESET -----------------------------") + fmt.Printf("%s\n", err.Error()) + fmt.Println("--------------------------------------------------------------------------------") + fmt.Println(rule) + fmt.Println("------------------------------ END DEBUG RULESET -------------------------------") +} diff --git a/pkg/ruleset/ruleset_test.go b/pkg/ruleset/ruleset_test.go new file mode 100644 index 0000000..85c6f33 --- /dev/null +++ b/pkg/ruleset/ruleset_test.go @@ -0,0 +1,153 @@ +package ruleset + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/gofiber/fiber/v2" + "github.com/stretchr/testify/assert" +) + +var ( + validYAML = ` +- domain: example.com + regexRules: + - match: "^http:" + replace: "https:"` + + invalidYAML = ` +- domain: [thisIsATestYamlThatIsMeantToFail.example] + regexRules: + - match: "^http:" + replace: "https:" + - match: "[incomplete"` +) + +func TestLoadRulesFromRemoteFile(t *testing.T) { + app := fiber.New() + defer app.Shutdown() + + app.Get("/valid-config.yml", func(c *fiber.Ctx) error { + c.SendString(validYAML) + return nil + }) + app.Get("/invalid-config.yml", func(c *fiber.Ctx) error { + c.SendString(invalidYAML) + return nil + }) + + app.Get("/valid-config.gz", func(c *fiber.Ctx) error { + c.Set("Content-Type", "application/octet-stream") + rs, err := loadRuleFromString(validYAML) + if err != nil { + t.Errorf("failed to load valid yaml from string: %s", err.Error()) + } + s, err := rs.GzipYaml() + if err != nil { + t.Errorf("failed to load gzip serialize yaml: %s", err.Error()) + } + + err = c.SendStream(s) + if err != nil { + t.Errorf("failed to stream gzip serialized yaml: %s", err.Error()) + } + return nil + }) + + // Start the server in a goroutine + go func() { + if err := app.Listen("127.0.0.1:9999"); err != nil { + t.Errorf("Server failed to start: %s", err.Error()) + } + }() + + // Wait for the server to start + time.Sleep(time.Second * 1) + + rs, err := NewRuleset("http://127.0.0.1:9999/valid-config.yml") + if err != nil { + t.Errorf("failed to load plaintext ruleset from http server: %s", err.Error()) + } + assert.Equal(t, rs[0].Domain, "example.com") + + rs, err = NewRuleset("http://127.0.0.1:9999/valid-config.gz") + if err != nil { + t.Errorf("failed to load gzipped ruleset from http server: %s", err.Error()) + } + assert.Equal(t, rs[0].Domain, "example.com") + + os.Setenv("RULESET", "http://127.0.0.1:9999/valid-config.gz") + rs = NewRulesetFromEnv() + if !assert.Equal(t, rs[0].Domain, "example.com") { + t.Error("expected no errors loading ruleset from gzip url using environment variable, but got one") + } +} + +func loadRuleFromString(yaml string) (RuleSet, error) { + // Create a temporary file and load it + tmpFile, _ := os.CreateTemp("", "ruleset*.yaml") + defer os.Remove(tmpFile.Name()) + tmpFile.WriteString(yaml) + rs := RuleSet{} + err := rs.loadRulesFromLocalFile(tmpFile.Name()) + return rs, err +} + +// TestLoadRulesFromLocalFile tests the loading of rules from a local YAML file. +func TestLoadRulesFromLocalFile(t *testing.T) { + rs, err := loadRuleFromString(validYAML) + if err != nil { + t.Errorf("Failed to load rules from valid YAML: %s", err) + } + assert.Equal(t, rs[0].Domain, "example.com") + assert.Equal(t, rs[0].RegexRules[0].Match, "^http:") + assert.Equal(t, rs[0].RegexRules[0].Replace, "https:") + + _, err = loadRuleFromString(invalidYAML) + if err == nil { + t.Errorf("Expected an error when loading invalid YAML, but got none") + } +} + +// TestLoadRulesFromLocalDir tests the loading of rules from a local nested directory full of yaml rulesets +func TestLoadRulesFromLocalDir(t *testing.T) { + // Create a temporary directory + baseDir, err := os.MkdirTemp("", "ruleset_test") + if err != nil { + t.Fatalf("Failed to create temporary directory: %s", err) + } + defer os.RemoveAll(baseDir) + + // Create a nested subdirectory + nestedDir := filepath.Join(baseDir, "nested") + err = os.Mkdir(nestedDir, 0755) + if err != nil { + t.Fatalf("Failed to create nested directory: %s", err) + } + + // Create a nested subdirectory + nestedTwiceDir := filepath.Join(nestedDir, "nestedTwice") + err = os.Mkdir(nestedTwiceDir, 0755) + + testCases := []string{"test.yaml", "test2.yaml", "test-3.yaml", "test 4.yaml", "1987.test.yaml.yml", "foobar.example.com.yaml", "foobar.com.yml"} + for _, fileName := range testCases { + filePath := filepath.Join(nestedDir, "2x-"+fileName) + os.WriteFile(filePath, []byte(validYAML), 0644) + filePath = filepath.Join(nestedDir, fileName) + os.WriteFile(filePath, []byte(validYAML), 0644) + filePath = filepath.Join(baseDir, "base-"+fileName) + os.WriteFile(filePath, []byte(validYAML), 0644) + } + rs := RuleSet{} + err = rs.loadRulesFromLocalDir(baseDir) + assert.NoError(t, err) + assert.Equal(t, rs.Count(), len(testCases)*3) + + for _, rule := range rs { + assert.Equal(t, rule.Domain, "example.com") + assert.Equal(t, rule.RegexRules[0].Match, "^http:") + assert.Equal(t, rule.RegexRules[0].Replace, "https:") + } +}