add ability to load rulesets from directory
This commit is contained in:
@@ -10,34 +10,52 @@ import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"ladder/pkg/ruleset"
|
||||
|
||||
"github.com/PuerkitoBio/goquery"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
UserAgent = getenv("USER_AGENT", "Mozilla/5.0 (compatible; Googlebot/2.1; +http://www.google.com/bot.html)")
|
||||
ForwardedFor = getenv("X_FORWARDED_FOR", "66.249.66.1")
|
||||
rulesSet = loadRules()
|
||||
allowedDomains = strings.Split(os.Getenv("ALLOWED_DOMAINS"), ",")
|
||||
rulesSet = ruleset.NewRulesetFromEnv()
|
||||
allowedDomains = []string{}
|
||||
)
|
||||
|
||||
func ProxySite(c *fiber.Ctx) error {
|
||||
// Get the url from the URL
|
||||
url := c.Params("*")
|
||||
func init() {
|
||||
allowedDomains = strings.Split(os.Getenv("ALLOWED_DOMAINS"), ",")
|
||||
if os.Getenv("ALLOWED_DOMAINS_RULESET") == "true" {
|
||||
allowedDomains = append(allowedDomains, rulesSet.Domains()...)
|
||||
}
|
||||
}
|
||||
|
||||
queries := c.Queries()
|
||||
body, _, resp, err := fetchSite(url, queries)
|
||||
if err != nil {
|
||||
log.Println("ERROR:", err)
|
||||
c.SendStatus(fiber.StatusInternalServerError)
|
||||
return c.SendString(err.Error())
|
||||
func ProxySite(rulesetPath string) fiber.Handler {
|
||||
if rulesetPath != "" {
|
||||
rs, err := ruleset.NewRuleset(rulesetPath)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
rulesSet = rs
|
||||
}
|
||||
|
||||
c.Set("Content-Type", resp.Header.Get("Content-Type"))
|
||||
c.Set("Content-Security-Policy", resp.Header.Get("Content-Security-Policy"))
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Get the url from the URL
|
||||
url := c.Params("*")
|
||||
|
||||
return c.SendString(body)
|
||||
queries := c.Queries()
|
||||
body, _, resp, err := fetchSite(url, queries)
|
||||
if err != nil {
|
||||
log.Println("ERROR:", err)
|
||||
c.SendStatus(fiber.StatusInternalServerError)
|
||||
return c.SendString(err.Error())
|
||||
}
|
||||
|
||||
c.Set("Content-Type", resp.Header.Get("Content-Type"))
|
||||
c.Set("Content-Security-Policy", resp.Header.Get("Content-Security-Policy"))
|
||||
|
||||
return c.SendString(body)
|
||||
}
|
||||
}
|
||||
|
||||
func fetchSite(urlpath string, queries map[string]string) (string, *http.Request, *http.Response, error) {
|
||||
@@ -122,7 +140,7 @@ func fetchSite(urlpath string, queries map[string]string) (string, *http.Request
|
||||
return body, req, resp, nil
|
||||
}
|
||||
|
||||
func rewriteHtml(bodyB []byte, u *url.URL, rule Rule) string {
|
||||
func rewriteHtml(bodyB []byte, u *url.URL, rule ruleset.Rule) string {
|
||||
// Rewrite the HTML
|
||||
body := string(bodyB)
|
||||
|
||||
@@ -156,63 +174,11 @@ func getenv(key, fallback string) string {
|
||||
return value
|
||||
}
|
||||
|
||||
func loadRules() RuleSet {
|
||||
rulesUrl := os.Getenv("RULESET")
|
||||
if rulesUrl == "" {
|
||||
RulesList := RuleSet{}
|
||||
return RulesList
|
||||
}
|
||||
log.Println("Loading rules")
|
||||
|
||||
var ruleSet RuleSet
|
||||
if strings.HasPrefix(rulesUrl, "http") {
|
||||
|
||||
resp, err := http.Get(rulesUrl)
|
||||
if err != nil {
|
||||
log.Println("ERROR:", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
log.Println("ERROR:", resp.StatusCode, rulesUrl)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Println("ERROR:", err)
|
||||
}
|
||||
yaml.Unmarshal(body, &ruleSet)
|
||||
|
||||
if err != nil {
|
||||
log.Println("ERROR:", err)
|
||||
}
|
||||
} else {
|
||||
yamlFile, err := os.ReadFile(rulesUrl)
|
||||
if err != nil {
|
||||
log.Println("ERROR:", err)
|
||||
}
|
||||
yaml.Unmarshal(yamlFile, &ruleSet)
|
||||
}
|
||||
|
||||
domains := []string{}
|
||||
for _, rule := range ruleSet {
|
||||
|
||||
domains = append(domains, rule.Domain)
|
||||
domains = append(domains, rule.Domains...)
|
||||
if os.Getenv("ALLOWED_DOMAINS_RULESET") == "true" {
|
||||
allowedDomains = append(allowedDomains, domains...)
|
||||
}
|
||||
}
|
||||
|
||||
log.Println("Loaded ", len(ruleSet), " rules for", len(domains), "Domains")
|
||||
return ruleSet
|
||||
}
|
||||
|
||||
func fetchRule(domain string, path string) Rule {
|
||||
func fetchRule(domain string, path string) ruleset.Rule {
|
||||
if len(rulesSet) == 0 {
|
||||
return Rule{}
|
||||
return ruleset.Rule{}
|
||||
}
|
||||
rule := Rule{}
|
||||
rule := ruleset.Rule{}
|
||||
for _, rule := range rulesSet {
|
||||
domains := rule.Domains
|
||||
if rule.Domain != "" {
|
||||
@@ -231,7 +197,7 @@ func fetchRule(domain string, path string) Rule {
|
||||
return rule
|
||||
}
|
||||
|
||||
func applyRules(body string, rule Rule) string {
|
||||
func applyRules(body string, rule ruleset.Rule) string {
|
||||
if len(rulesSet) == 0 {
|
||||
return body
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"ladder/pkg/ruleset"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
@@ -13,7 +14,7 @@ import (
|
||||
|
||||
func TestProxySite(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Get("/:url", ProxySite)
|
||||
app.Get("/:url", ProxySite(""))
|
||||
|
||||
req := httptest.NewRequest("GET", "/https://example.com", nil)
|
||||
resp, err := app.Test(req)
|
||||
@@ -51,7 +52,7 @@ func TestRewriteHtml(t *testing.T) {
|
||||
</html>
|
||||
`
|
||||
|
||||
actual := rewriteHtml(bodyB, u, Rule{})
|
||||
actual := rewriteHtml(bodyB, u, ruleset.Rule{})
|
||||
assert.Equal(t, expected, actual)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user