finish ruleset_v2 implementation

This commit is contained in:
Kevin Pham
2023-12-05 14:02:54 -06:00
parent 9d77c63697
commit b2f6cf9f1d
4 changed files with 120 additions and 692 deletions

View File

@@ -13,6 +13,7 @@ import (
"strings"
"gopkg.in/yaml.v3"
//"encoding/json"
)
type IRuleset interface {
@@ -21,12 +22,68 @@ type IRuleset interface {
}
// Ruleset is a collection of rules plus an internal index for looking a rule up by domain.
type Ruleset struct {
// rulesetPath is the path the ruleset was created from.
rulesetPath string
// NOTE(review): rules is a by-value domain map that sits next to _rulemap; confirm whether both are still needed.
rules map[string]Rule
// Rules is the list of loaded rules; it is the field exposed to JSON/YAML under the "rules" key.
Rules []Rule `json:"rules" yaml:"rules"`
_rulemap map[string]*Rule // internal map for fast lookups; points at a rule in the Rules slice
}
func (rs Ruleset) GetRule(url *url.URL) (rule Rule, exists bool) {
rule, exists = rs.rules[url.Hostname()]
// UnmarshalYAML customizes how a Ruleset is decoded from YAML, using the
// callback-style UnmarshalYAML signature.
//
// It decodes a document with a top-level "rules" list into rs.Rules and then
// rebuilds the internal domain lookup map. Every domain listed on a rule is
// indexed as written; a domain that does not already start with "www." is
// additionally indexed under its "www." alias. Rules and lookup entries that
// were previously held by rs are replaced. If decoding fails, rs is left
// untouched and the decoder's error is returned.
func (rs *Ruleset) UnmarshalYAML(unmarshal func(interface{}) error) error {
	// Decode through a local type that has no UnmarshalYAML method of its own,
	// so the decoder does not call back into this function.
	type AuxRuleset struct {
		Rules []Rule `yaml:"rules"`
	}

	var aux AuxRuleset
	if err := unmarshal(&aux); err != nil {
		return err
	}

	rs.Rules = aux.Rules
	rs._rulemap = make(map[string]*Rule, len(rs.Rules))

	// Index pointers into rs.Rules rather than copies, so the ruleset holds a
	// single instance of each rule.
	for i := range rs.Rules {
		rule := &rs.Rules[i]
		for _, domain := range rule.Domains {
			rs._rulemap[domain] = rule
			if !strings.HasPrefix(domain, "www.") {
				rs._rulemap["www."+domain] = rule
			}
		}
	}

	return nil
}
// MarshalYAML implements the yaml.Marshaler interface.
// It customizes the marshaling of a Ruleset object into YAML.
//
// The returned value is a plain struct with a single "rules" list; each entry
// carries the rule's domains, request modifications and response
// modifications. The YAML encoder marshals the returned value in place of the
// Ruleset, so the value must be the structure itself and not pre-serialized
// bytes (a byte slice would be encoded as a byte slice, not as a mapping).
// The returned error is always nil.
func (rs *Ruleset) MarshalYAML() (interface{}, error) {
	type AuxRule struct {
		Domains               []string `yaml:"domains"`
		RequestModifications  []_rqm   `yaml:"requestmodifications"`
		ResponseModifications []_rsm   `yaml:"responsemodifications"`
	}
	type Aux struct {
		Rules []AuxRule `yaml:"rules"`
	}

	aux := Aux{Rules: make([]AuxRule, 0, len(rs.Rules))}
	for i := range rs.Rules {
		rule := &rs.Rules[i] // avoid copying the whole Rule per iteration
		aux.Rules = append(aux.Rules, AuxRule{
			Domains:               rule.Domains,
			RequestModifications:  rule._rqms,
			ResponseModifications: rule._rsms,
		})
	}

	return aux, nil
}
// GetRule looks up the rule indexed under the hostname of url.
//
// It returns a pointer to the matching rule and true when the hostname is
// present in the internal lookup map, and (nil, false) otherwise. A nil url
// is treated as "no rule" instead of panicking.
func (rs Ruleset) GetRule(url *url.URL) (rule *Rule, exists bool) {
	// Hostname dereferences its receiver, so a nil URL must be rejected first.
	if url == nil {
		return nil, false
	}

	rule, exists = rs._rulemap[url.Hostname()]
	return rule, exists
}
@@ -38,8 +95,8 @@ func (rs Ruleset) HasRule(url *url.URL) bool {
// NewRuleset loads a new RuleSet from a path
func NewRuleset(path string) (Ruleset, error) {
rs := Ruleset{
rulesetPath: path,
rules: map[string]Rule{},
_rulemap: map[string]*Rule{},
Rules: []Rule{},
}
switch {
@@ -88,18 +145,20 @@ func (rs *Ruleset) loadRulesFromLocalDir(path string) error {
return nil
}
isYAML := filepath.Ext(path) == "yaml" || filepath.Ext(path) == "yml"
isYAML := filepath.Ext(path) == ".yaml" || filepath.Ext(path) == ".yml"
if !isYAML {
return nil
}
err = rs.loadRulesFromLocalFile(path)
tmpRs := Ruleset{_rulemap: make(map[string]*Rule)}
err = tmpRs.loadRulesFromLocalFile(path)
if err != nil {
log.Printf("WARN: failed to load directory ruleset '%s': %s, skipping", path, err)
return nil
}
rs.Rules = append(rs.Rules, tmpRs.Rules...)
log.Printf("INFO: loaded ruleset %s\n", path)
//log.Printf("INFO: loaded ruleset %s\n", path)
return nil
})
@@ -120,21 +179,12 @@ func (rs *Ruleset) loadRulesFromLocalFile(path string) error {
return errors.Join(e, err)
}
rule := Rule{}
err = yaml.Unmarshal(yamlFile, &rule)
err = yaml.Unmarshal(yamlFile, rs)
if err != nil {
e := fmt.Errorf("failed to load rules from local file, possible syntax error in '%s' - %s", path, err)
ee := errors.Join(e, err)
debugPrintRule(string(yamlFile), ee)
return ee
}
for _, domain := range rule.Domains {
rs.rules[domain] = rule
if !strings.HasSuffix(domain, "www.") {
rs.rules["www."+domain] = rule
}
debugPrintRule(string(yamlFile), e)
return e
}
return nil
@@ -144,7 +194,6 @@ func (rs *Ruleset) loadRulesFromLocalFile(path string) error {
// It supports plain and gzip compressed content.
// Returns an error if there's an issue accessing the URL or if there's a syntax error in the YAML.
func (rs *Ruleset) loadRulesFromRemoteFile(rulesURL string) error {
rule := Rule{}
resp, err := http.Get(rulesURL)
if err != nil {
@@ -160,7 +209,7 @@ func (rs *Ruleset) loadRulesFromRemoteFile(rulesURL string) error {
var reader io.Reader
// in case remote server did not set content-encoding gzip header
isGzip := strings.HasSuffix(rulesURL, ".gz") || strings.HasSuffix(rulesURL, ".gzip")
isGzip := strings.HasSuffix(rulesURL, ".gz") || strings.HasSuffix(rulesURL, ".gzip") || resp.Header.Get("content-encoding") == "gzip"
if isGzip {
reader, err = gzip.NewReader(resp.Body)
@@ -171,24 +220,12 @@ func (rs *Ruleset) loadRulesFromRemoteFile(rulesURL string) error {
reader = resp.Body
}
err = yaml.NewDecoder(reader).Decode(&rule)
err = yaml.NewDecoder(reader).Decode(&rs)
if err != nil {
return fmt.Errorf("failed to load rules from remote url '%s' with status code '%s' and possible syntax error - %s", rulesURL, resp.Status, err)
}
if rs.rules == nil {
fmt.Println("nilmap")
rs.rules = make(map[string]Rule)
}
for _, domain := range rule.Domains {
rs.rules[domain] = rule
if !strings.HasSuffix(domain, "www.") {
rs.rules["www."+domain] = rule
}
}
return nil
}
@@ -204,31 +241,11 @@ func (rs *Ruleset) Yaml() (string, error) {
return string(y), nil
}
// GzipYaml returns an io.Reader that streams the Gzip-compressed YAML representation of the RuleSet.
//
// Encoding runs in a background goroutine that writes into an io.Pipe, so the
// goroutine only makes progress while the returned reader is being consumed.
// Any error from YAML encoding or from finalizing the gzip stream is delivered
// to the reader as the error of a subsequent Read; on success the reader ends
// with io.EOF. The error returned by GzipYaml itself is always nil.
func (rs *Ruleset) GzipYaml() (io.Reader, error) {
	pr, pw := io.Pipe()

	go func() {
		gw := gzip.NewWriter(pw)

		err := yaml.NewEncoder(gw).Encode(rs)

		// Close the gzip writer exactly once, and keep its error when encoding
		// itself succeeded: a failed Close means the stream is truncated.
		if closeErr := gw.Close(); err == nil {
			err = closeErr
		}

		// CloseWithError with a nil error makes the reader see a plain EOF.
		pw.CloseWithError(err)
	}()

	return pr, nil
}
// Domains extracts and returns a slice of all domains present in the RuleSet.
func (rs *Ruleset) Domains() []string {
var domains []string
for domain := range rs.rules {
domains = append(domains, domain)
for _, rule := range rs.Rules {
domains = append(domains, rule.Domains...)
}
return domains
}
@@ -240,7 +257,7 @@ func (rs *Ruleset) DomainCount() int {
// Count returns the total number of rules in the RuleSet.
func (rs *Ruleset) Count() int {
return len(rs.rules)
return len(rs.Rules)
}
// PrintStats logs the number of rules and domains loaded in the RuleSet.

View File

@@ -1,6 +1,7 @@
package ruleset_v2
import (
"fmt"
"net/url"
"os"
"path/filepath"
@@ -13,38 +14,40 @@ import (
var (
validYAML = `
domains:
- example.com
- www.example.com
responsemodifications:
- name: APIContent
params: []
- name: SetContentSecurityPolicy
params:
- foobar
- name: SetIncomingCookie
params:
- authorization-bearer
- hunter2
requestmodifications:
- name: ForwardRequestHeaders
params: []
rules:
- domains:
- example.com
- www.example.com
responsemodifications:
- name: APIContent
params: []
- name: SetContentSecurityPolicy
params:
- foobar
- name: SetIncomingCookie
params:
- authorization-bearer
- hunter2
requestmodifications:
- name: ForwardRequestHeaders
params: []
`
invalidYAML = `
domains:
- example.com
- www.example.com
responsemodifications:
- name: APIContent
- name: SetContentSecurityPolicy
- name: INVALIDSetIncomingCookie
params:
- authorization-bearer
- hunter2
requestmodifications:
- name: ForwardRequestHeaders
params: []
rules:
domains:
- example.com
- www.example.com
responsemodifications:
- name: APIContent
- name: SetContentSecurityPolicy
- name: INVALIDSetIncomingCookie
params:
- authorization-bearer
- hunter2
requestmodifications:
- name: ForwardRequestHeaders
params: []
`
)
@@ -62,26 +65,6 @@ func TestLoadRulesFromRemoteFile(t *testing.T) {
return nil
})
app.Get("/valid-config.gz", func(c *fiber.Ctx) error {
c.Set("Content-Type", "application/octet-stream")
rs, err := loadRuleFromString(validYAML)
if err != nil {
t.Errorf("failed to load valid yaml from string: %s", err.Error())
}
s, err := rs.GzipYaml()
if err != nil {
t.Errorf("failed to load gzip serialize yaml: %s", err.Error())
}
err = c.SendStream(s)
if err != nil {
t.Errorf("failed to stream gzip serialized yaml: %s", err.Error())
}
return nil
})
// Start the server in a goroutine
go func() {
if err := app.Listen("127.0.0.1:9999"); err != nil {
@@ -102,25 +85,21 @@ func TestLoadRulesFromRemoteFile(t *testing.T) {
assert.True(t, exists, "expected example.com rule to be present")
assert.Equal(t, r.Domains[0], "example.com")
u, _ = url.Parse("http://www.www.example.com")
u, _ = url.Parse("http://www.www.foobar.com")
_, exists = rs.GetRule(u)
assert.False(t, exists, "expected www.www.example.com rule to NOT be present")
rs, err = NewRuleset("http://127.0.0.1:9999/valid-config.gz")
if err != nil {
t.Errorf("failed to load gzipped ruleset from http server: %s", err.Error())
}
assert.False(t, exists, "expected www.www.foobar.com rule to NOT be present")
u, _ = url.Parse("http://example.com")
r, exists = rs.GetRule(u)
assert.Equal(t, r.Domains[0], "example.com")
os.Setenv("RULESET", "http://127.0.0.1:9999/valid-config.gz")
os.Setenv("RULESET", "http://127.0.0.1:9999/valid-config.yml")
rs = NewRulesetFromEnv()
r, exists = rs.GetRule(u)
assert.True(t, exists, "expected example.com rule to be present")
assert.True(t, exists, "expected example.com rule to be present from env")
if !assert.Equal(t, r.Domains[0], "example.com") {
t.Error("expected no errors loading ruleset from gzip url using environment variable, but got one")
t.Error("expected no errors loading ruleset from url using environment variable, but got one")
}
}
@@ -132,7 +111,10 @@ func loadRuleFromString(yaml string) (Ruleset, error) {
tmpFile.WriteString(yaml)
rs := Ruleset{}
rs := Ruleset{
_rulemap: map[string]*Rule{},
Rules: []Rule{},
}
err := rs.loadRulesFromLocalFile(tmpFile.Name())
return rs, err
@@ -189,8 +171,9 @@ func TestLoadRulesFromLocalDir(t *testing.T) {
}
rs := Ruleset{}
fmt.Println(baseDir)
err = rs.loadRulesFromLocalDir(baseDir)
assert.NoError(t, err)
assert.Equal(t, rs.Count(), len(testCases)*3)
assert.Equal(t, len(testCases)*3, rs.Count())
}