add ruleset_v2 loading code
This commit is contained in:
@@ -102,7 +102,6 @@ requestmodifications:
|
|||||||
if len(rule.RequestModifications) != 1 {
|
if len(rule.RequestModifications) != 1 {
|
||||||
t.Errorf("expected number of RequestModifications to be 1, got %d", len(rule.RequestModifications))
|
t.Errorf("expected number of RequestModifications to be 1, got %d", len(rule.RequestModifications))
|
||||||
}
|
}
|
||||||
fmt.Println(rule.RequestModifications[0].Name)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRuleMarshalYAML(t *testing.T) {
|
func TestRuleMarshalYAML(t *testing.T) {
|
||||||
|
|||||||
@@ -1,7 +1,18 @@
|
|||||||
package ruleset_v2
|
package ruleset_v2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"compress/gzip"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
type IRuleset interface {
|
type IRuleset interface {
|
||||||
@@ -14,19 +25,234 @@ type Ruleset struct {
|
|||||||
rules map[string]Rule
|
rules map[string]Rule
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs Ruleset) GetRule(url url.URL) (rule Rule, exists bool) {
|
func (rs Ruleset) GetRule(url *url.URL) (rule Rule, exists bool) {
|
||||||
rule, exists = rs.rules[url.Hostname()]
|
rule, exists = rs.rules[url.Hostname()]
|
||||||
return rule, exists
|
return rule, exists
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs Ruleset) HasRule(url url.URL) bool {
|
func (rs Ruleset) HasRule(url *url.URL) bool {
|
||||||
_, exists := rs.GetRule(url)
|
_, exists := rs.GetRule(url)
|
||||||
return exists
|
return exists
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewRuleset loads a new RuleSet from a path
|
||||||
func NewRuleset(path string) (Ruleset, error) {
|
func NewRuleset(path string) (Ruleset, error) {
|
||||||
rs := Ruleset{
|
rs := Ruleset{
|
||||||
rulesetPath: path,
|
rulesetPath: path,
|
||||||
|
rules: map[string]Rule{},
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://"):
|
||||||
|
err := rs.loadRulesFromRemoteFile(path)
|
||||||
|
return rs, err
|
||||||
|
default:
|
||||||
|
err := rs.loadRulesFromLocalDir(path)
|
||||||
|
return rs, err
|
||||||
}
|
}
|
||||||
return rs, nil
|
}
|
||||||
|
|
||||||
|
// NewRulesetFromEnv creates a new RuleSet based on the RULESET environment variable.
|
||||||
|
// It logs a warning and returns an empty RuleSet if the RULESET environment variable is not set.
|
||||||
|
// If the RULESET is set but the rules cannot be loaded, it panics.
|
||||||
|
func NewRulesetFromEnv() Ruleset {
|
||||||
|
rulesPath, ok := os.LookupEnv("RULESET")
|
||||||
|
if !ok {
|
||||||
|
log.Printf("WARN: No ruleset specified. Set the `RULESET` environment variable to load one for a better success rate.")
|
||||||
|
return Ruleset{}
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleSet, err := NewRuleset(rulesPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ruleSet
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadRulesFromLocalDir loads rules from a local directory specified by the path.
|
||||||
|
// It walks through the directory, loading rules from YAML files.
|
||||||
|
// Returns an error if the directory cannot be accessed
|
||||||
|
// If there is an issue loading any file, it will be skipped
|
||||||
|
func (rs *Ruleset) loadRulesFromLocalDir(path string) error {
|
||||||
|
_, err := os.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("loadRulesFromLocalDir: invalid path - %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = filepath.Walk(path, func(path string, info os.FileInfo, err error) error {
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if info.IsDir() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
isYAML := filepath.Ext(path) == "yaml" || filepath.Ext(path) == "yml"
|
||||||
|
if !isYAML {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rs.loadRulesFromLocalFile(path)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("WARN: failed to load directory ruleset '%s': %s, skipping", path, err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("INFO: loaded ruleset %s\n", path)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadRulesFromLocalFile loads rules from a local YAML file specified by the path.
|
||||||
|
// Returns an error if the file cannot be read or if there's a syntax error in the YAML.
|
||||||
|
func (rs *Ruleset) loadRulesFromLocalFile(path string) error {
|
||||||
|
yamlFile, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
e := fmt.Errorf("failed to read rules from local file: '%s'", path)
|
||||||
|
return errors.Join(e, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := Rule{}
|
||||||
|
err = yaml.Unmarshal(yamlFile, &rule)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
e := fmt.Errorf("failed to load rules from local file, possible syntax error in '%s' - %s", path, err)
|
||||||
|
ee := errors.Join(e, err)
|
||||||
|
debugPrintRule(string(yamlFile), ee)
|
||||||
|
return ee
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, domain := range rule.Domains {
|
||||||
|
rs.rules[domain] = rule
|
||||||
|
if !strings.HasSuffix(domain, "www.") {
|
||||||
|
rs.rules["www."+domain] = rule
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadRulesFromRemoteFile loads rules from a remote URL.
|
||||||
|
// It supports plain and gzip compressed content.
|
||||||
|
// Returns an error if there's an issue accessing the URL or if there's a syntax error in the YAML.
|
||||||
|
func (rs *Ruleset) loadRulesFromRemoteFile(rulesURL string) error {
|
||||||
|
rule := Rule{}
|
||||||
|
|
||||||
|
resp, err := http.Get(rulesURL)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load rules from remote url '%s' - %s", rulesURL, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return fmt.Errorf("failed to load rules from remote url (%s) on '%s' - %s", resp.Status, rulesURL, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var reader io.Reader
|
||||||
|
|
||||||
|
// in case remote server did not set content-encoding gzip header
|
||||||
|
isGzip := strings.HasSuffix(rulesURL, ".gz") || strings.HasSuffix(rulesURL, ".gzip")
|
||||||
|
if isGzip {
|
||||||
|
reader, err = gzip.NewReader(resp.Body)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create gzip reader for URL '%s' with status code '%s': %w", rulesURL, resp.Status, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
reader = resp.Body
|
||||||
|
}
|
||||||
|
|
||||||
|
err = yaml.NewDecoder(reader).Decode(&rule)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load rules from remote url '%s' with status code '%s' and possible syntax error - %s", rulesURL, resp.Status, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rs.rules == nil {
|
||||||
|
fmt.Println("nilmap")
|
||||||
|
rs.rules = make(map[string]Rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, domain := range rule.Domains {
|
||||||
|
rs.rules[domain] = rule
|
||||||
|
if !strings.HasSuffix(domain, "www.") {
|
||||||
|
rs.rules["www."+domain] = rule
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ================= utility methods ==========================
|
||||||
|
|
||||||
|
// Yaml returns the ruleset as a Yaml string
|
||||||
|
func (rs *Ruleset) Yaml() (string, error) {
|
||||||
|
y, err := yaml.Marshal(rs)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(y), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GzipYaml returns an io.Reader that streams the Gzip-compressed YAML representation of the RuleSet.
|
||||||
|
func (rs *Ruleset) GzipYaml() (io.Reader, error) {
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer pw.Close()
|
||||||
|
|
||||||
|
gw := gzip.NewWriter(pw)
|
||||||
|
defer gw.Close()
|
||||||
|
|
||||||
|
if err := yaml.NewEncoder(gw).Encode(rs); err != nil {
|
||||||
|
gw.Close() // Ensure to close the gzip writer
|
||||||
|
pw.CloseWithError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return pr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Domains extracts and returns a slice of all domains present in the RuleSet.
|
||||||
|
func (rs *Ruleset) Domains() []string {
|
||||||
|
var domains []string
|
||||||
|
for domain := range rs.rules {
|
||||||
|
domains = append(domains, domain)
|
||||||
|
}
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
|
// DomainCount returns the count of unique domains present in the RuleSet.
|
||||||
|
func (rs *Ruleset) DomainCount() int {
|
||||||
|
return len(rs.Domains())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count returns the total number of rules in the RuleSet.
|
||||||
|
func (rs *Ruleset) Count() int {
|
||||||
|
return len(rs.rules)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrintStats logs the number of rules and domains loaded in the RuleSet.
|
||||||
|
func (rs *Ruleset) PrintStats() {
|
||||||
|
log.Printf("INFO: Loaded %d rules for %d domains\n", rs.Count(), rs.DomainCount())
|
||||||
|
}
|
||||||
|
|
||||||
|
// debugPrintRule dumps err followed by the raw rule text to standard output,
// framed by BEGIN/END banner lines, to help debug a rule that failed to load.
func debugPrintRule(rule string, err error) {
	const (
		beginBanner = "------------------------------ BEGIN DEBUG RULESET -----------------------------"
		separator   = "--------------------------------------------------------------------------------"
		endBanner   = "------------------------------ END DEBUG RULESET -------------------------------"
	)

	fmt.Printf("%s\n%s\n%s\n%s\n%s\n", beginBanner, err.Error(), separator, rule, endBanner)
}
|
||||||
|
|||||||
196
proxychain/ruleset/ruleset_test.go
Normal file
196
proxychain/ruleset/ruleset_test.go
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
package ruleset_v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
validYAML = `
|
||||||
|
domains:
|
||||||
|
- example.com
|
||||||
|
- www.example.com
|
||||||
|
responsemodifications:
|
||||||
|
- name: APIContent
|
||||||
|
params: []
|
||||||
|
- name: SetContentSecurityPolicy
|
||||||
|
params:
|
||||||
|
- foobar
|
||||||
|
- name: SetIncomingCookie
|
||||||
|
params:
|
||||||
|
- authorization-bearer
|
||||||
|
- hunter2
|
||||||
|
requestmodifications:
|
||||||
|
- name: ForwardRequestHeaders
|
||||||
|
params: []
|
||||||
|
`
|
||||||
|
|
||||||
|
invalidYAML = `
|
||||||
|
domains:
|
||||||
|
- example.com
|
||||||
|
- www.example.com
|
||||||
|
responsemodifications:
|
||||||
|
- name: APIContent
|
||||||
|
- name: SetContentSecurityPolicy
|
||||||
|
- name: INVALIDSetIncomingCookie
|
||||||
|
params:
|
||||||
|
- authorization-bearer
|
||||||
|
- hunter2
|
||||||
|
requestmodifications:
|
||||||
|
- name: ForwardRequestHeaders
|
||||||
|
params: []
|
||||||
|
`
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoadRulesFromRemoteFile(t *testing.T) {
|
||||||
|
app := fiber.New()
|
||||||
|
defer app.Shutdown()
|
||||||
|
|
||||||
|
app.Get("/valid-config.yml", func(c *fiber.Ctx) error {
|
||||||
|
c.SendString(validYAML)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
app.Get("/invalid-config.yml", func(c *fiber.Ctx) error {
|
||||||
|
c.SendString(invalidYAML)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
app.Get("/valid-config.gz", func(c *fiber.Ctx) error {
|
||||||
|
c.Set("Content-Type", "application/octet-stream")
|
||||||
|
|
||||||
|
rs, err := loadRuleFromString(validYAML)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to load valid yaml from string: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := rs.GzipYaml()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to load gzip serialize yaml: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.SendStream(s)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to stream gzip serialized yaml: %s", err.Error())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Start the server in a goroutine
|
||||||
|
go func() {
|
||||||
|
if err := app.Listen("127.0.0.1:9999"); err != nil {
|
||||||
|
t.Errorf("Server failed to start: %s", err.Error())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for the server to start
|
||||||
|
time.Sleep(time.Second * 1)
|
||||||
|
|
||||||
|
rs, err := NewRuleset("http://127.0.0.1:9999/valid-config.yml")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to load plaintext ruleset from http server: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
u, _ := url.Parse("http://example.com")
|
||||||
|
r, exists := rs.GetRule(u)
|
||||||
|
assert.True(t, exists, "expected example.com rule to be present")
|
||||||
|
assert.Equal(t, r.Domains[0], "example.com")
|
||||||
|
|
||||||
|
u, _ = url.Parse("http://www.www.example.com")
|
||||||
|
_, exists = rs.GetRule(u)
|
||||||
|
assert.False(t, exists, "expected www.www.example.com rule to NOT be present")
|
||||||
|
|
||||||
|
rs, err = NewRuleset("http://127.0.0.1:9999/valid-config.gz")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to load gzipped ruleset from http server: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
r, exists = rs.GetRule(u)
|
||||||
|
assert.Equal(t, r.Domains[0], "example.com")
|
||||||
|
|
||||||
|
os.Setenv("RULESET", "http://127.0.0.1:9999/valid-config.gz")
|
||||||
|
|
||||||
|
rs = NewRulesetFromEnv()
|
||||||
|
r, exists = rs.GetRule(u)
|
||||||
|
assert.True(t, exists, "expected example.com rule to be present")
|
||||||
|
if !assert.Equal(t, r.Domains[0], "example.com") {
|
||||||
|
t.Error("expected no errors loading ruleset from gzip url using environment variable, but got one")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadRuleFromString(yaml string) (Ruleset, error) {
|
||||||
|
// Create a temporary file and load it
|
||||||
|
tmpFile, _ := os.CreateTemp("", "ruleset*.yaml")
|
||||||
|
|
||||||
|
defer os.Remove(tmpFile.Name())
|
||||||
|
|
||||||
|
tmpFile.WriteString(yaml)
|
||||||
|
|
||||||
|
rs := Ruleset{}
|
||||||
|
err := rs.loadRulesFromLocalFile(tmpFile.Name())
|
||||||
|
|
||||||
|
return rs, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLoadRulesFromLocalFile tests the loading of rules from a local YAML file.
|
||||||
|
func TestLoadRulesFromLocalFile(t *testing.T) {
|
||||||
|
_, err := loadRuleFromString(validYAML)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to load rules from valid YAML: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = loadRuleFromString(invalidYAML)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected an error when loading invalid YAML, but got none")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLoadRulesFromLocalDir tests the loading of rules from a local nested directory full of yaml rulesets
|
||||||
|
func TestLoadRulesFromLocalDir(t *testing.T) {
|
||||||
|
// Create a temporary directory
|
||||||
|
baseDir, err := os.MkdirTemp("", "ruleset_test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temporary directory: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer os.RemoveAll(baseDir)
|
||||||
|
|
||||||
|
// Create a nested subdirectory
|
||||||
|
nestedDir := filepath.Join(baseDir, "nested")
|
||||||
|
err = os.Mkdir(nestedDir, 0o755)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create nested directory: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a nested subdirectory
|
||||||
|
nestedTwiceDir := filepath.Join(nestedDir, "nestedTwice")
|
||||||
|
err = os.Mkdir(nestedTwiceDir, 0o755)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create twice-nested directory: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []string{"test.yaml", "test2.yaml", "test-3.yaml", "test 4.yaml", "1987.test.yaml.yml", "foobar.example.com.yaml", "foobar.com.yml"}
|
||||||
|
for _, fileName := range testCases {
|
||||||
|
filePath := filepath.Join(nestedDir, "2x-"+fileName)
|
||||||
|
os.WriteFile(filePath, []byte(validYAML), 0o644)
|
||||||
|
|
||||||
|
filePath = filepath.Join(nestedDir, fileName)
|
||||||
|
os.WriteFile(filePath, []byte(validYAML), 0o644)
|
||||||
|
|
||||||
|
filePath = filepath.Join(baseDir, "base-"+fileName)
|
||||||
|
os.WriteFile(filePath, []byte(validYAML), 0o644)
|
||||||
|
}
|
||||||
|
|
||||||
|
rs := Ruleset{}
|
||||||
|
err = rs.loadRulesFromLocalDir(baseDir)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, rs.Count(), len(testCases)*3)
|
||||||
|
}
|
||||||
1
proxychain/ruleset/todo.md
Normal file
1
proxychain/ruleset/todo.md
Normal file
@@ -0,0 +1 @@
|
|||||||
|
ruleset loading rule tests are failing; maybe concurrency issue with assigning to nil map?
|
||||||
Reference in New Issue
Block a user