Merge branch 'main' into 39-request-header-fields-too-large

This commit is contained in:
mms-gianni
2023-11-15 20:45:59 +01:00
committed by GitHub
6 changed files with 488 additions and 119 deletions

View File

@@ -10,18 +10,26 @@ import (
"regexp"
"strings"
"ladder/pkg/ruleset"
"github.com/PuerkitoBio/goquery"
"github.com/gofiber/fiber/v2"
"gopkg.in/yaml.v3"
)
var (
UserAgent = getenv("USER_AGENT", "Mozilla/5.0 (compatible; Googlebot/2.1; +http://www.google.com/bot.html)")
ForwardedFor = getenv("X_FORWARDED_FOR", "66.249.66.1")
rulesSet = loadRules()
allowedDomains = strings.Split(os.Getenv("ALLOWED_DOMAINS"), ",")
rulesSet = ruleset.NewRulesetFromEnv()
allowedDomains = []string{}
)
func init() {
allowedDomains = strings.Split(os.Getenv("ALLOWED_DOMAINS"), ",")
if os.Getenv("ALLOWED_DOMAINS_RULESET") == "true" {
allowedDomains = append(allowedDomains, rulesSet.Domains()...)
}
}
// extracts a URL from the request ctx. If the URL in the request
// is a relative path, it reconstructs the full URL using the referer header.
func extractUrl(c *fiber.Ctx) (string, error) {
@@ -75,29 +83,39 @@ func extractUrl(c *fiber.Ctx) (string, error) {
}
func ProxySite(c *fiber.Ctx) error {
// Get the url from the URL
url, err := extractUrl(c)
if err != nil {
log.Println("ERROR In URL extraction:", err)
func ProxySite(rulesetPath string) fiber.Handler {
if rulesetPath != "" {
rs, err := ruleset.NewRuleset(rulesetPath)
if err != nil {
panic(err)
}
rulesSet = rs
}
queries := c.Queries()
body, _, resp, err := fetchSite(url, queries)
if err != nil {
log.Println("ERROR:", err)
c.SendStatus(fiber.StatusInternalServerError)
return c.SendString(err.Error())
}
return func(c *fiber.Ctx) error {
// Get the url from the URL
url, err := extractUrl(c)
if err != nil {
log.Println("ERROR In URL extraction:", err)
}
queries := c.Queries()
body, _, resp, err := fetchSite(url, queries)
if err != nil {
log.Println("ERROR:", err)
c.SendStatus(fiber.StatusInternalServerError)
return c.SendString(err.Error())
}
c.Cookie(&fiber.Cookie{})
c.Set("Content-Type", resp.Header.Get("Content-Type"))
c.Set("Content-Security-Policy", resp.Header.Get("Content-Security-Policy"))
return c.SendString(body)
return c.SendString(body)
}
}
func modifyURL(uri string, rule Rule) (string, error) {
func modifyURL(uri string, rule ruleset.Rule) (string, error) {
newUrl, err := url.Parse(uri)
if err != nil {
return "", err
@@ -205,7 +223,7 @@ func fetchSite(urlpath string, queries map[string]string) (string, *http.Request
}
if rule.Headers.CSP != "" {
log.Println(rule.Headers.CSP)
//log.Println(rule.Headers.CSP)
resp.Header.Set("Content-Security-Policy", rule.Headers.CSP)
}
@@ -214,7 +232,7 @@ func fetchSite(urlpath string, queries map[string]string) (string, *http.Request
return body, req, resp, nil
}
func rewriteHtml(bodyB []byte, u *url.URL, rule Rule) string {
func rewriteHtml(bodyB []byte, u *url.URL, rule ruleset.Rule) string {
// Rewrite the HTML
body := string(bodyB)
@@ -248,63 +266,11 @@ func getenv(key, fallback string) string {
return value
}
func loadRules() RuleSet {
rulesUrl := os.Getenv("RULESET")
if rulesUrl == "" {
RulesList := RuleSet{}
return RulesList
}
log.Println("Loading rules")
var ruleSet RuleSet
if strings.HasPrefix(rulesUrl, "http") {
resp, err := http.Get(rulesUrl)
if err != nil {
log.Println("ERROR:", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
log.Println("ERROR:", resp.StatusCode, rulesUrl)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Println("ERROR:", err)
}
yaml.Unmarshal(body, &ruleSet)
if err != nil {
log.Println("ERROR:", err)
}
} else {
yamlFile, err := os.ReadFile(rulesUrl)
if err != nil {
log.Println("ERROR:", err)
}
yaml.Unmarshal(yamlFile, &ruleSet)
}
domains := []string{}
for _, rule := range ruleSet {
domains = append(domains, rule.Domain)
domains = append(domains, rule.Domains...)
if os.Getenv("ALLOWED_DOMAINS_RULESET") == "true" {
allowedDomains = append(allowedDomains, domains...)
}
}
log.Println("Loaded ", len(ruleSet), " rules for", len(domains), "Domains")
return ruleSet
}
func fetchRule(domain string, path string) Rule {
func fetchRule(domain string, path string) ruleset.Rule {
if len(rulesSet) == 0 {
return Rule{}
return ruleset.Rule{}
}
rule := Rule{}
rule := ruleset.Rule{}
for _, rule := range rulesSet {
domains := rule.Domains
if rule.Domain != "" {
@@ -323,7 +289,7 @@ func fetchRule(domain string, path string) Rule {
return rule
}
func applyRules(body string, rule Rule) string {
func applyRules(body string, rule ruleset.Rule) string {
if len(rulesSet) == 0 {
return body
}