From 2dccc7ca3574bf26f93e63355df60f0b5cf65de7 Mon Sep 17 00:00:00 2001 From: Kevin Pham Date: Mon, 27 Nov 2023 12:45:23 -0600 Subject: [PATCH] add req/resp header forwarding modifiers --- proxychain/proxychain.go | 9 ++-- .../forward_request_headers.go | 43 ++++++++++++++++ .../forward_response_headers.go | 51 +++++++++++++++++++ 3 files changed, 99 insertions(+), 4 deletions(-) create mode 100644 proxychain/requestmodifers/forward_request_headers.go create mode 100644 proxychain/responsemodifers/forward_response_headers.go diff --git a/proxychain/proxychain.go b/proxychain/proxychain.go index 738aa00..98be1d0 100644 --- a/proxychain/proxychain.go +++ b/proxychain/proxychain.go @@ -90,8 +90,7 @@ type ProxyChain struct { requestModifications []RequestModification onceRequestModifications []RequestModification onceResponseModifications []ResponseModification - resultModifications []ResponseModification - htmlTokenRewriters []rr.IHTMLTokenRewriter + responseModifications []ResponseModification Ruleset *ruleset.RuleSet debugMode bool abortErr error @@ -139,7 +138,7 @@ func (chain *ProxyChain) AddOnceResponseModifications(mods ...ResponseModificati // AddResponseModifications sets the ProxyChain's response modifers // the modifier will not fire until ProxyChain.Execute() is run. func (chain *ProxyChain) AddResponseModifications(mods ...ResponseModification) *ProxyChain { - chain.resultModifications = mods + chain.responseModifications = mods return chain } @@ -339,6 +338,8 @@ func (chain *ProxyChain) _reset() { chain.Request = nil // chain.Response = nil chain.Context = nil + chain.onceResponseModifications = []ResponseModification{} + chain.onceRequestModifications = []RequestModification{} } // NewProxyChain initializes a new ProxyChain @@ -395,7 +396,7 @@ func (chain *ProxyChain) _execute() (io.Reader, error) { */ // Apply ResponseModifiers to proxychain - for _, applyResultModificationsTo := range chain.resultModifications { + for _, applyResultModificationsTo := range chain.responseModifications { err := applyResultModificationsTo(chain) if err != nil { return nil, chain.abort(err) diff --git a/proxychain/requestmodifers/forward_request_headers.go b/proxychain/requestmodifers/forward_request_headers.go new file mode 100644 index 0000000..08d1f9b --- /dev/null +++ b/proxychain/requestmodifers/forward_request_headers.go @@ -0,0 +1,43 @@ +package requestmodifers + +import ( + "ladder/proxychain" + "strings" +) + +var forwardBlacklist map[string]bool + +func init() { + forwardBlacklist = map[string]bool{ + "host": true, + "connection": true, + "keep-alive": true, + "content-length": true, + "content-encoding": true, + "transfer-encoding": true, + "referer": true, + "x-forwarded-for": true, + "x-real-ip": true, + "forwarded": true, + } +} + +// ForwardRequestHeaders forwards the requests headers sent from the client to the upstream server +func ForwardRequestHeaders(ua string) proxychain.RequestModification { + return func(chain *proxychain.ProxyChain) error { + + forwardHeaders := func(key, value []byte) { + k := strings.ToLower(string(key)) + v := string(value) + if forwardBlacklist[k] { + return + } + chain.Request.Header.Set(k, v) + } + + chain.Context.Request(). + Header.VisitAll(forwardHeaders) + + return nil + } +} diff --git a/proxychain/responsemodifers/forward_response_headers.go b/proxychain/responsemodifers/forward_response_headers.go new file mode 100644 index 0000000..5e08a58 --- /dev/null +++ b/proxychain/responsemodifers/forward_response_headers.go @@ -0,0 +1,51 @@ +package responsemodifers + +import ( + "fmt" + "ladder/proxychain" + "net/url" + "strings" +) + +var forwardBlacklist map[string]bool + +func init() { + forwardBlacklist = map[string]bool{ + "content-length": true, + "content-encoding": true, + "transfer-encoding": true, + "strict-transport-security": true, + "connection": true, + "keep-alive": true, + } +} + +// ForwardResponseHeaders forwards the response headers from the upstream server to the client +func ForwardResponseHeaders() proxychain.ResponseModification { + return func(chain *proxychain.ProxyChain) error { + + for uname, headers := range chain.Response.Header { + name := strings.ToLower(uname) + if forwardBlacklist[name] { + continue + } + + // patch location header to forward to proxy instead + if name == "location" { + u, err := url.Parse(chain.Context.BaseURL()) + if err != nil { + return err + } + newLocation := fmt.Sprintf("%s://%s/%s", u.Scheme, u.Host, headers[0]) + chain.Context.Set("location", newLocation) + } + + // forward headers + for _, value := range headers { + chain.Context.Set(name, value) + } + } + + return nil + } +}