add ability to load rulesets from directory
This commit is contained in:
@@ -36,6 +36,11 @@ func main() {
|
|||||||
Help: "This will spawn multiple processes listening",
|
Help: "This will spawn multiple processes listening",
|
||||||
})
|
})
|
||||||
|
|
||||||
|
ruleset := parser.String("r", "ruleset", &argparse.Options{
|
||||||
|
Required: false,
|
||||||
|
Help: "File, Directory or URL to a ruleset.yml. Overrides RULESET environment variable.",
|
||||||
|
})
|
||||||
|
|
||||||
err := parser.Parse(os.Args)
|
err := parser.Parse(os.Args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Print(parser.Usage(err))
|
fmt.Print(parser.Usage(err))
|
||||||
@@ -80,7 +85,7 @@ func main() {
|
|||||||
app.Get("raw/*", handlers.Raw)
|
app.Get("raw/*", handlers.Raw)
|
||||||
app.Get("api/*", handlers.Api)
|
app.Get("api/*", handlers.Api)
|
||||||
app.Get("ruleset", handlers.Raw)
|
app.Get("ruleset", handlers.Raw)
|
||||||
app.Get("/*", handlers.ProxySite)
|
app.Get("/*", handlers.ProxySite(*ruleset))
|
||||||
|
|
||||||
log.Fatal(app.Listen(":" + *port))
|
log.Fatal(app.Listen(":" + *port))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,19 +10,36 @@ import (
|
|||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"ladder/pkg/ruleset"
|
||||||
|
|
||||||
"github.com/PuerkitoBio/goquery"
|
"github.com/PuerkitoBio/goquery"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
UserAgent = getenv("USER_AGENT", "Mozilla/5.0 (compatible; Googlebot/2.1; +http://www.google.com/bot.html)")
|
UserAgent = getenv("USER_AGENT", "Mozilla/5.0 (compatible; Googlebot/2.1; +http://www.google.com/bot.html)")
|
||||||
ForwardedFor = getenv("X_FORWARDED_FOR", "66.249.66.1")
|
ForwardedFor = getenv("X_FORWARDED_FOR", "66.249.66.1")
|
||||||
rulesSet = loadRules()
|
rulesSet = ruleset.NewRulesetFromEnv()
|
||||||
allowedDomains = strings.Split(os.Getenv("ALLOWED_DOMAINS"), ",")
|
allowedDomains = []string{}
|
||||||
)
|
)
|
||||||
|
|
||||||
func ProxySite(c *fiber.Ctx) error {
|
func init() {
|
||||||
|
allowedDomains = strings.Split(os.Getenv("ALLOWED_DOMAINS"), ",")
|
||||||
|
if os.Getenv("ALLOWED_DOMAINS_RULESET") == "true" {
|
||||||
|
allowedDomains = append(allowedDomains, rulesSet.Domains()...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ProxySite(rulesetPath string) fiber.Handler {
|
||||||
|
if rulesetPath != "" {
|
||||||
|
rs, err := ruleset.NewRuleset(rulesetPath)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
rulesSet = rs
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
// Get the url from the URL
|
// Get the url from the URL
|
||||||
url := c.Params("*")
|
url := c.Params("*")
|
||||||
|
|
||||||
@@ -38,6 +55,7 @@ func ProxySite(c *fiber.Ctx) error {
|
|||||||
c.Set("Content-Security-Policy", resp.Header.Get("Content-Security-Policy"))
|
c.Set("Content-Security-Policy", resp.Header.Get("Content-Security-Policy"))
|
||||||
|
|
||||||
return c.SendString(body)
|
return c.SendString(body)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchSite(urlpath string, queries map[string]string) (string, *http.Request, *http.Response, error) {
|
func fetchSite(urlpath string, queries map[string]string) (string, *http.Request, *http.Response, error) {
|
||||||
@@ -122,7 +140,7 @@ func fetchSite(urlpath string, queries map[string]string) (string, *http.Request
|
|||||||
return body, req, resp, nil
|
return body, req, resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func rewriteHtml(bodyB []byte, u *url.URL, rule Rule) string {
|
func rewriteHtml(bodyB []byte, u *url.URL, rule ruleset.Rule) string {
|
||||||
// Rewrite the HTML
|
// Rewrite the HTML
|
||||||
body := string(bodyB)
|
body := string(bodyB)
|
||||||
|
|
||||||
@@ -156,63 +174,11 @@ func getenv(key, fallback string) string {
|
|||||||
return value
|
return value
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadRules() RuleSet {
|
func fetchRule(domain string, path string) ruleset.Rule {
|
||||||
rulesUrl := os.Getenv("RULESET")
|
|
||||||
if rulesUrl == "" {
|
|
||||||
RulesList := RuleSet{}
|
|
||||||
return RulesList
|
|
||||||
}
|
|
||||||
log.Println("Loading rules")
|
|
||||||
|
|
||||||
var ruleSet RuleSet
|
|
||||||
if strings.HasPrefix(rulesUrl, "http") {
|
|
||||||
|
|
||||||
resp, err := http.Get(rulesUrl)
|
|
||||||
if err != nil {
|
|
||||||
log.Println("ERROR:", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode >= 400 {
|
|
||||||
log.Println("ERROR:", resp.StatusCode, rulesUrl)
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
log.Println("ERROR:", err)
|
|
||||||
}
|
|
||||||
yaml.Unmarshal(body, &ruleSet)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Println("ERROR:", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
yamlFile, err := os.ReadFile(rulesUrl)
|
|
||||||
if err != nil {
|
|
||||||
log.Println("ERROR:", err)
|
|
||||||
}
|
|
||||||
yaml.Unmarshal(yamlFile, &ruleSet)
|
|
||||||
}
|
|
||||||
|
|
||||||
domains := []string{}
|
|
||||||
for _, rule := range ruleSet {
|
|
||||||
|
|
||||||
domains = append(domains, rule.Domain)
|
|
||||||
domains = append(domains, rule.Domains...)
|
|
||||||
if os.Getenv("ALLOWED_DOMAINS_RULESET") == "true" {
|
|
||||||
allowedDomains = append(allowedDomains, domains...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Println("Loaded ", len(ruleSet), " rules for", len(domains), "Domains")
|
|
||||||
return ruleSet
|
|
||||||
}
|
|
||||||
|
|
||||||
func fetchRule(domain string, path string) Rule {
|
|
||||||
if len(rulesSet) == 0 {
|
if len(rulesSet) == 0 {
|
||||||
return Rule{}
|
return ruleset.Rule{}
|
||||||
}
|
}
|
||||||
rule := Rule{}
|
rule := ruleset.Rule{}
|
||||||
for _, rule := range rulesSet {
|
for _, rule := range rulesSet {
|
||||||
domains := rule.Domains
|
domains := rule.Domains
|
||||||
if rule.Domain != "" {
|
if rule.Domain != "" {
|
||||||
@@ -231,7 +197,7 @@ func fetchRule(domain string, path string) Rule {
|
|||||||
return rule
|
return rule
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyRules(body string, rule Rule) string {
|
func applyRules(body string, rule ruleset.Rule) string {
|
||||||
if len(rulesSet) == 0 {
|
if len(rulesSet) == 0 {
|
||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
package handlers
|
package handlers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"ladder/pkg/ruleset"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -13,7 +14,7 @@ import (
|
|||||||
|
|
||||||
func TestProxySite(t *testing.T) {
|
func TestProxySite(t *testing.T) {
|
||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
app.Get("/:url", ProxySite)
|
app.Get("/:url", ProxySite(""))
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/https://example.com", nil)
|
req := httptest.NewRequest("GET", "/https://example.com", nil)
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
@@ -51,7 +52,7 @@ func TestRewriteHtml(t *testing.T) {
|
|||||||
</html>
|
</html>
|
||||||
`
|
`
|
||||||
|
|
||||||
actual := rewriteHtml(bodyB, u, Rule{})
|
actual := rewriteHtml(bodyB, u, ruleset.Rule{})
|
||||||
assert.Equal(t, expected, actual)
|
assert.Equal(t, expected, actual)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
274
pkg/ruleset/ruleset.go
Normal file
274
pkg/ruleset/ruleset.go
Normal file
@@ -0,0 +1,274 @@
|
|||||||
|
package ruleset
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"compress/gzip"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Regex struct {
|
||||||
|
Match string `yaml:"match"`
|
||||||
|
Replace string `yaml:"replace"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type RuleSet []Rule
|
||||||
|
|
||||||
|
type Rule struct {
|
||||||
|
Domain string `yaml:"domain,omitempty"`
|
||||||
|
Domains []string `yaml:"domains,omitempty"`
|
||||||
|
Paths []string `yaml:"paths,omitempty"`
|
||||||
|
Headers struct {
|
||||||
|
UserAgent string `yaml:"user-agent,omitempty"`
|
||||||
|
XForwardedFor string `yaml:"x-forwarded-for,omitempty"`
|
||||||
|
Referer string `yaml:"referer,omitempty"`
|
||||||
|
Cookie string `yaml:"cookie,omitempty"`
|
||||||
|
CSP string `yaml:"content-security-policy,omitempty"`
|
||||||
|
} `yaml:"headers,omitempty"`
|
||||||
|
GoogleCache bool `yaml:"googleCache,omitempty"`
|
||||||
|
RegexRules []Regex `yaml:"regexRules"`
|
||||||
|
Injections []struct {
|
||||||
|
Position string `yaml:"position"`
|
||||||
|
Append string `yaml:"append"`
|
||||||
|
Prepend string `yaml:"prepend"`
|
||||||
|
Replace string `yaml:"replace"`
|
||||||
|
} `yaml:"injections"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRulesetFromEnv creates a new RuleSet based on the RULESET environment variable.
|
||||||
|
// It logs a warning and returns an empty RuleSet if the RULESET environment variable is not set.
|
||||||
|
// If the RULESET is set but the rules cannot be loaded, it panics.
|
||||||
|
func NewRulesetFromEnv() RuleSet {
|
||||||
|
rulesPath, ok := os.LookupEnv("RULESET")
|
||||||
|
if !ok {
|
||||||
|
log.Printf("WARN: No ruleset specified. Set the `RULESET` environment variable to load one for a better success rate.")
|
||||||
|
return RuleSet{}
|
||||||
|
}
|
||||||
|
ruleSet, err := NewRuleset(rulesPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Panicln(ruleSet)
|
||||||
|
}
|
||||||
|
return ruleSet
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRuleset loads a RuleSet from a given string of rule paths, separated by semicolons.
|
||||||
|
// It supports loading rules from both local file paths and remote URLs.
|
||||||
|
// Returns a RuleSet and an error if any issues occur during loading.
|
||||||
|
func NewRuleset(rulePaths string) (RuleSet, error) {
|
||||||
|
ruleSet := RuleSet{}
|
||||||
|
errs := []error{}
|
||||||
|
|
||||||
|
rp := strings.Split(rulePaths, ";")
|
||||||
|
for _, rule := range rp {
|
||||||
|
rulePath := strings.Trim(rule, " ")
|
||||||
|
var err error
|
||||||
|
|
||||||
|
isRemote, _ := regexp.MatchString(`^https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()!@:%_\+.~#?&\/\/=]*)`, rulePath)
|
||||||
|
if isRemote {
|
||||||
|
err = ruleSet.loadRulesFromRemoteFile(rulePath)
|
||||||
|
} else {
|
||||||
|
err = ruleSet.loadRulesFromLocalDir(rulePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
e := errors.New(fmt.Sprintf("WARN: failed to load ruleset from ''%s", rulePath))
|
||||||
|
errs = append(errs, errors.Join(e, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errs) != 0 {
|
||||||
|
e := errors.New(fmt.Sprintf("WARN: failed to load %d rulesets", len(rp)))
|
||||||
|
errs = append(errs, e)
|
||||||
|
// panic if the user specified a local ruleset, but it wasn't found on disk
|
||||||
|
// don't fail silently
|
||||||
|
for _, err := range errs {
|
||||||
|
if errors.Is(os.ErrNotExist, err) {
|
||||||
|
e := fmt.Errorf("PANIC: ruleset '%s' not found", err)
|
||||||
|
panic(errors.Join(e, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// else, bubble up any errors, such as syntax or remote host issues
|
||||||
|
return ruleSet, errors.Join(errs...)
|
||||||
|
}
|
||||||
|
ruleSet.PrintStats()
|
||||||
|
return ruleSet, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ================== RULESET loading logic ===================================
|
||||||
|
|
||||||
|
// loadRulesFromLocalDir loads rules from a local directory specified by the path.
|
||||||
|
// It walks through the directory, loading rules from YAML files.
|
||||||
|
// Returns an error if the directory cannot be accessed
|
||||||
|
// If there is an issue loading any file, it will be skipped
|
||||||
|
func (rs *RuleSet) loadRulesFromLocalDir(path string) error {
|
||||||
|
_, err := os.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
yamlRegex := regexp.MustCompile(`.*\.ya?ml`)
|
||||||
|
|
||||||
|
err = filepath.Walk(path, func(path string, info os.FileInfo, err error) error {
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if info.IsDir() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if isYaml := yamlRegex.MatchString(path); !isYaml {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rs.loadRulesFromLocalFile(path)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("WARN: failed to load directory ruleset '%s': %s, skipping", path, err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
log.Printf("INFO: loaded ruleset %s\n", path)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadRulesFromLocalFile loads rules from a local YAML file specified by the path.
|
||||||
|
// Returns an error if the file cannot be read or if there's a syntax error in the YAML.
|
||||||
|
func (rs *RuleSet) loadRulesFromLocalFile(path string) error {
|
||||||
|
yamlFile, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
e := errors.New(fmt.Sprintf("failed to read rules from local file: '%s'", path))
|
||||||
|
return errors.Join(e, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var r RuleSet
|
||||||
|
err = yaml.Unmarshal(yamlFile, &r)
|
||||||
|
if err != nil {
|
||||||
|
e := errors.New(fmt.Sprintf("failed to load rules from local file, possible syntax error in '%s'", path))
|
||||||
|
ee := errors.Join(e, err)
|
||||||
|
if _, ok := os.LookupEnv("DEBUG"); ok {
|
||||||
|
debugPrintRule(string(yamlFile), ee)
|
||||||
|
}
|
||||||
|
return ee
|
||||||
|
}
|
||||||
|
*rs = append(*rs, r...)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadRulesFromRemoteFile loads rules from a remote URL.
|
||||||
|
// It supports plain and gzip compressed content.
|
||||||
|
// Returns an error if there's an issue accessing the URL or if there's a syntax error in the YAML.
|
||||||
|
func (rs *RuleSet) loadRulesFromRemoteFile(rulesUrl string) error {
|
||||||
|
var r RuleSet
|
||||||
|
resp, err := http.Get(rulesUrl)
|
||||||
|
if err != nil {
|
||||||
|
e := errors.New(fmt.Sprintf("failed to load rules from remote url '%s'", rulesUrl))
|
||||||
|
return errors.Join(e, err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
e := errors.New(fmt.Sprintf("failed to load rules from remote url (%s) on '%s'", resp.Status, rulesUrl))
|
||||||
|
return errors.Join(e, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var reader io.Reader
|
||||||
|
isGzip := strings.HasSuffix(rulesUrl, ".gz") || strings.HasSuffix(rulesUrl, ".gzip") || resp.Header.Get("content-encoding") == "gzip"
|
||||||
|
|
||||||
|
if isGzip {
|
||||||
|
reader, err = gzip.NewReader(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create gzip reader for URL '%s' with status code '%s': %w", rulesUrl, resp.Status, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
reader = resp.Body
|
||||||
|
}
|
||||||
|
|
||||||
|
err = yaml.NewDecoder(reader).Decode(&r)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
e := errors.New(fmt.Sprintf("failed to load rules from remote url '%s' with status code '%s' and possible syntax error", rulesUrl, resp.Status))
|
||||||
|
ee := errors.Join(e, err)
|
||||||
|
return ee
|
||||||
|
}
|
||||||
|
|
||||||
|
*rs = append(*rs, r...)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ================= utility methods ==========================
|
||||||
|
|
||||||
|
// Yaml returns the ruleset as a Yaml string
|
||||||
|
func (rs *RuleSet) Yaml() (string, error) {
|
||||||
|
y, err := yaml.Marshal(rs)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(y), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GzipYaml returns an io.Reader that streams the Gzip-compressed YAML representation of the RuleSet.
|
||||||
|
func (rs *RuleSet) GzipYaml() (io.Reader, error) {
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer pw.Close()
|
||||||
|
|
||||||
|
gw := gzip.NewWriter(pw)
|
||||||
|
defer gw.Close()
|
||||||
|
|
||||||
|
if err := yaml.NewEncoder(gw).Encode(rs); err != nil {
|
||||||
|
gw.Close() // Ensure to close the gzip writer
|
||||||
|
pw.CloseWithError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return pr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Domains extracts and returns a slice of all domains present in the RuleSet.
|
||||||
|
func (rs *RuleSet) Domains() []string {
|
||||||
|
var domains []string
|
||||||
|
for _, rule := range *rs {
|
||||||
|
domains = append(domains, rule.Domain)
|
||||||
|
domains = append(domains, rule.Domains...)
|
||||||
|
}
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
|
// DomainCount returns the count of unique domains present in the RuleSet.
|
||||||
|
func (rs *RuleSet) DomainCount() int {
|
||||||
|
return len(rs.Domains())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count returns the total number of rules in the RuleSet.
|
||||||
|
func (rs *RuleSet) Count() int {
|
||||||
|
return len(*rs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrintStats logs the number of rules and domains loaded in the RuleSet.
|
||||||
|
func (rs *RuleSet) PrintStats() {
|
||||||
|
log.Printf("INFO: Loaded %d rules for %d domains\n", rs.Count(), rs.DomainCount())
|
||||||
|
}
|
||||||
|
|
||||||
|
// debugPrintRule is a utility function for printing a rule and associated error for debugging purposes.
|
||||||
|
func debugPrintRule(rule string, err error) {
|
||||||
|
fmt.Println("------------------------------ BEGIN DEBUG RULESET -----------------------------")
|
||||||
|
fmt.Printf("%s\n", err.Error())
|
||||||
|
fmt.Println("--------------------------------------------------------------------------------")
|
||||||
|
fmt.Println(rule)
|
||||||
|
fmt.Println("------------------------------ END DEBUG RULESET -------------------------------")
|
||||||
|
}
|
||||||
153
pkg/ruleset/ruleset_test.go
Normal file
153
pkg/ruleset/ruleset_test.go
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
package ruleset
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
validYAML = `
|
||||||
|
- domain: example.com
|
||||||
|
regexRules:
|
||||||
|
- match: "^http:"
|
||||||
|
replace: "https:"`
|
||||||
|
|
||||||
|
invalidYAML = `
|
||||||
|
- domain: [thisIsATestYamlThatIsMeantToFail.example]
|
||||||
|
regexRules:
|
||||||
|
- match: "^http:"
|
||||||
|
replace: "https:"
|
||||||
|
- match: "[incomplete"`
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoadRulesFromRemoteFile(t *testing.T) {
|
||||||
|
app := fiber.New()
|
||||||
|
defer app.Shutdown()
|
||||||
|
|
||||||
|
app.Get("/valid-config.yml", func(c *fiber.Ctx) error {
|
||||||
|
c.SendString(validYAML)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
app.Get("/invalid-config.yml", func(c *fiber.Ctx) error {
|
||||||
|
c.SendString(invalidYAML)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
app.Get("/valid-config.gz", func(c *fiber.Ctx) error {
|
||||||
|
c.Set("Content-Type", "application/octet-stream")
|
||||||
|
rs, err := loadRuleFromString(validYAML)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to load valid yaml from string: %s", err.Error())
|
||||||
|
}
|
||||||
|
s, err := rs.GzipYaml()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to load gzip serialize yaml: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.SendStream(s)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to stream gzip serialized yaml: %s", err.Error())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Start the server in a goroutine
|
||||||
|
go func() {
|
||||||
|
if err := app.Listen("127.0.0.1:9999"); err != nil {
|
||||||
|
t.Errorf("Server failed to start: %s", err.Error())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for the server to start
|
||||||
|
time.Sleep(time.Second * 1)
|
||||||
|
|
||||||
|
rs, err := NewRuleset("http://127.0.0.1:9999/valid-config.yml")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to load plaintext ruleset from http server: %s", err.Error())
|
||||||
|
}
|
||||||
|
assert.Equal(t, rs[0].Domain, "example.com")
|
||||||
|
|
||||||
|
rs, err = NewRuleset("http://127.0.0.1:9999/valid-config.gz")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to load gzipped ruleset from http server: %s", err.Error())
|
||||||
|
}
|
||||||
|
assert.Equal(t, rs[0].Domain, "example.com")
|
||||||
|
|
||||||
|
os.Setenv("RULESET", "http://127.0.0.1:9999/valid-config.gz")
|
||||||
|
rs = NewRulesetFromEnv()
|
||||||
|
if !assert.Equal(t, rs[0].Domain, "example.com") {
|
||||||
|
t.Error("expected no errors loading ruleset from gzip url using environment variable, but got one")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadRuleFromString(yaml string) (RuleSet, error) {
|
||||||
|
// Create a temporary file and load it
|
||||||
|
tmpFile, _ := os.CreateTemp("", "ruleset*.yaml")
|
||||||
|
defer os.Remove(tmpFile.Name())
|
||||||
|
tmpFile.WriteString(yaml)
|
||||||
|
rs := RuleSet{}
|
||||||
|
err := rs.loadRulesFromLocalFile(tmpFile.Name())
|
||||||
|
return rs, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLoadRulesFromLocalFile tests the loading of rules from a local YAML file.
|
||||||
|
func TestLoadRulesFromLocalFile(t *testing.T) {
|
||||||
|
rs, err := loadRuleFromString(validYAML)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to load rules from valid YAML: %s", err)
|
||||||
|
}
|
||||||
|
assert.Equal(t, rs[0].Domain, "example.com")
|
||||||
|
assert.Equal(t, rs[0].RegexRules[0].Match, "^http:")
|
||||||
|
assert.Equal(t, rs[0].RegexRules[0].Replace, "https:")
|
||||||
|
|
||||||
|
_, err = loadRuleFromString(invalidYAML)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected an error when loading invalid YAML, but got none")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLoadRulesFromLocalDir tests the loading of rules from a local nested directory full of yaml rulesets
|
||||||
|
func TestLoadRulesFromLocalDir(t *testing.T) {
|
||||||
|
// Create a temporary directory
|
||||||
|
baseDir, err := os.MkdirTemp("", "ruleset_test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temporary directory: %s", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(baseDir)
|
||||||
|
|
||||||
|
// Create a nested subdirectory
|
||||||
|
nestedDir := filepath.Join(baseDir, "nested")
|
||||||
|
err = os.Mkdir(nestedDir, 0755)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create nested directory: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a nested subdirectory
|
||||||
|
nestedTwiceDir := filepath.Join(nestedDir, "nestedTwice")
|
||||||
|
err = os.Mkdir(nestedTwiceDir, 0755)
|
||||||
|
|
||||||
|
testCases := []string{"test.yaml", "test2.yaml", "test-3.yaml", "test 4.yaml", "1987.test.yaml.yml", "foobar.example.com.yaml", "foobar.com.yml"}
|
||||||
|
for _, fileName := range testCases {
|
||||||
|
filePath := filepath.Join(nestedDir, "2x-"+fileName)
|
||||||
|
os.WriteFile(filePath, []byte(validYAML), 0644)
|
||||||
|
filePath = filepath.Join(nestedDir, fileName)
|
||||||
|
os.WriteFile(filePath, []byte(validYAML), 0644)
|
||||||
|
filePath = filepath.Join(baseDir, "base-"+fileName)
|
||||||
|
os.WriteFile(filePath, []byte(validYAML), 0644)
|
||||||
|
}
|
||||||
|
rs := RuleSet{}
|
||||||
|
err = rs.loadRulesFromLocalDir(baseDir)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, rs.Count(), len(testCases)*3)
|
||||||
|
|
||||||
|
for _, rule := range rs {
|
||||||
|
assert.Equal(t, rule.Domain, "example.com")
|
||||||
|
assert.Equal(t, rule.RegexRules[0].Match, "^http:")
|
||||||
|
assert.Equal(t, rule.RegexRules[0].Replace, "https:")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user