197 lines
4.8 KiB
Go
197 lines
4.8 KiB
Go
package ruleset_v2
|
|
|
|
import (
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
var (
	// validYAML is a well-formed ruleset fixture covering two domains, three
	// response modifications (one with an empty params list, two with params)
	// and one request modification. The YAML lines sit at column 0 inside the
	// raw string so that Go indentation never leaks into the document.
	validYAML = `
domains:
- example.com
- www.example.com
responsemodifications:
- name: APIContent
  params: []
- name: SetContentSecurityPolicy
  params:
  - foobar
- name: SetIncomingCookie
  params:
  - authorization-bearer
  - hunter2
requestmodifications:
- name: ForwardRequestHeaders
  params: []
`

	// invalidYAML is expected by the tests to fail loading. Presumably the
	// loader rejects it because of the unknown "INVALIDSetIncomingCookie"
	// modifier name (the first two modifiers also omit params) — confirm
	// against the loader's validation rules.
	invalidYAML = `
domains:
- example.com
- www.example.com
responsemodifications:
- name: APIContent
- name: SetContentSecurityPolicy
- name: INVALIDSetIncomingCookie
  params:
  - authorization-bearer
  - hunter2
requestmodifications:
- name: ForwardRequestHeaders
  params: []
`
)
|
|
|
|
func TestLoadRulesFromRemoteFile(t *testing.T) {
|
|
app := fiber.New()
|
|
defer app.Shutdown()
|
|
|
|
app.Get("/valid-config.yml", func(c *fiber.Ctx) error {
|
|
c.SendString(validYAML)
|
|
return nil
|
|
})
|
|
|
|
app.Get("/invalid-config.yml", func(c *fiber.Ctx) error {
|
|
c.SendString(invalidYAML)
|
|
return nil
|
|
})
|
|
|
|
app.Get("/valid-config.gz", func(c *fiber.Ctx) error {
|
|
c.Set("Content-Type", "application/octet-stream")
|
|
|
|
rs, err := loadRuleFromString(validYAML)
|
|
if err != nil {
|
|
t.Errorf("failed to load valid yaml from string: %s", err.Error())
|
|
}
|
|
|
|
s, err := rs.GzipYaml()
|
|
if err != nil {
|
|
t.Errorf("failed to load gzip serialize yaml: %s", err.Error())
|
|
}
|
|
|
|
err = c.SendStream(s)
|
|
if err != nil {
|
|
t.Errorf("failed to stream gzip serialized yaml: %s", err.Error())
|
|
}
|
|
return nil
|
|
})
|
|
|
|
// Start the server in a goroutine
|
|
go func() {
|
|
if err := app.Listen("127.0.0.1:9999"); err != nil {
|
|
t.Errorf("Server failed to start: %s", err.Error())
|
|
}
|
|
}()
|
|
|
|
// Wait for the server to start
|
|
time.Sleep(time.Second * 1)
|
|
|
|
rs, err := NewRuleset("http://127.0.0.1:9999/valid-config.yml")
|
|
if err != nil {
|
|
t.Errorf("failed to load plaintext ruleset from http server: %s", err.Error())
|
|
}
|
|
|
|
u, _ := url.Parse("http://example.com")
|
|
r, exists := rs.GetRule(u)
|
|
assert.True(t, exists, "expected example.com rule to be present")
|
|
assert.Equal(t, r.Domains[0], "example.com")
|
|
|
|
u, _ = url.Parse("http://www.www.example.com")
|
|
_, exists = rs.GetRule(u)
|
|
assert.False(t, exists, "expected www.www.example.com rule to NOT be present")
|
|
|
|
rs, err = NewRuleset("http://127.0.0.1:9999/valid-config.gz")
|
|
if err != nil {
|
|
t.Errorf("failed to load gzipped ruleset from http server: %s", err.Error())
|
|
}
|
|
|
|
r, exists = rs.GetRule(u)
|
|
assert.Equal(t, r.Domains[0], "example.com")
|
|
|
|
os.Setenv("RULESET", "http://127.0.0.1:9999/valid-config.gz")
|
|
|
|
rs = NewRulesetFromEnv()
|
|
r, exists = rs.GetRule(u)
|
|
assert.True(t, exists, "expected example.com rule to be present")
|
|
if !assert.Equal(t, r.Domains[0], "example.com") {
|
|
t.Error("expected no errors loading ruleset from gzip url using environment variable, but got one")
|
|
}
|
|
}
|
|
|
|
func loadRuleFromString(yaml string) (Ruleset, error) {
|
|
// Create a temporary file and load it
|
|
tmpFile, _ := os.CreateTemp("", "ruleset*.yaml")
|
|
|
|
defer os.Remove(tmpFile.Name())
|
|
|
|
tmpFile.WriteString(yaml)
|
|
|
|
rs := Ruleset{}
|
|
err := rs.loadRulesFromLocalFile(tmpFile.Name())
|
|
|
|
return rs, err
|
|
}
|
|
|
|
// TestLoadRulesFromLocalFile tests the loading of rules from a local YAML file.
|
|
func TestLoadRulesFromLocalFile(t *testing.T) {
|
|
_, err := loadRuleFromString(validYAML)
|
|
if err != nil {
|
|
t.Errorf("Failed to load rules from valid YAML: %s", err)
|
|
}
|
|
|
|
_, err = loadRuleFromString(invalidYAML)
|
|
if err == nil {
|
|
t.Errorf("Expected an error when loading invalid YAML, but got none")
|
|
}
|
|
}
|
|
|
|
// TestLoadRulesFromLocalDir tests the loading of rules from a local nested directory full of yaml rulesets
|
|
func TestLoadRulesFromLocalDir(t *testing.T) {
|
|
// Create a temporary directory
|
|
baseDir, err := os.MkdirTemp("", "ruleset_test")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temporary directory: %s", err)
|
|
}
|
|
|
|
defer os.RemoveAll(baseDir)
|
|
|
|
// Create a nested subdirectory
|
|
nestedDir := filepath.Join(baseDir, "nested")
|
|
err = os.Mkdir(nestedDir, 0o755)
|
|
|
|
if err != nil {
|
|
t.Fatalf("Failed to create nested directory: %s", err)
|
|
}
|
|
|
|
// Create a nested subdirectory
|
|
nestedTwiceDir := filepath.Join(nestedDir, "nestedTwice")
|
|
err = os.Mkdir(nestedTwiceDir, 0o755)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create twice-nested directory: %s", err)
|
|
}
|
|
|
|
testCases := []string{"test.yaml", "test2.yaml", "test-3.yaml", "test 4.yaml", "1987.test.yaml.yml", "foobar.example.com.yaml", "foobar.com.yml"}
|
|
for _, fileName := range testCases {
|
|
filePath := filepath.Join(nestedDir, "2x-"+fileName)
|
|
os.WriteFile(filePath, []byte(validYAML), 0o644)
|
|
|
|
filePath = filepath.Join(nestedDir, fileName)
|
|
os.WriteFile(filePath, []byte(validYAML), 0o644)
|
|
|
|
filePath = filepath.Join(baseDir, "base-"+fileName)
|
|
os.WriteFile(filePath, []byte(validYAML), 0o644)
|
|
}
|
|
|
|
rs := Ruleset{}
|
|
err = rs.loadRulesFromLocalDir(baseDir)
|
|
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, rs.Count(), len(testCases)*3)
|
|
}
|