From b23320c288aa05f2a1bf7bd6039e4f52519e9d9c Mon Sep 17 00:00:00 2001 From: Tanishq Dubey Date: Mon, 15 Jul 2024 20:28:11 -0400 Subject: [PATCH] basic http request parsing --- main.go | 142 +++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 100 insertions(+), 42 deletions(-) diff --git a/main.go b/main.go index 10c2caa..45dadcd 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "log" "net" "os" + "strconv" "strings" "time" ) @@ -57,48 +58,38 @@ end of the header section, and an optional message body. [ message-body ] */ type HTTPRequest struct { - StartLine RequestLine + StartLine RequestLine + Headers map[string]string + MessageBody []byte } type HTTPResponse struct { StartLine StatusLine } -func printDiff(s1, s2 string) { - length := len(s1) - if len(s2) > length { - length = len(s2) +// ReadBytesUntil reads a slice of bytes until a given character is reached. +// It returns the number of bytes read, the bytes until the character, and the remaining bytes. +func ReadBytesUntil(b []byte, c byte) (int, []byte, []byte) { + i := 0 + for i < len(b) && b[i] != c { + i++ } - - for i := 0; i < length; i++ { - var char1, char2 byte - - if i < len(s1) { - char1 = s1[i] - } else { - char1 = ' ' // padding for shorter string - } - - if i < len(s2) { - char2 = s2[i] - } else { - char2 = ' ' // padding for shorter string - } - - if char1 != char2 { - fmt.Printf("Difference at index %d: '%c' != '%c'\n", i, char1, char2) - } + if i == len(b)-1 && b[i] != c { + return 0, nil, b } + return i, b[:i], b[i+1:] } func ParseHTTPRequest(b []byte) (HTTPRequest, error) { ret := HTTPRequest{} - rs := string(b[:]) - method, rs, found := strings.Cut(rs, " ") - if !found { + // Construct startline + _, mR, br := ReadBytesUntil(b, ' ') + b = br + if mR == nil { return ret, errors.New("could not find method in request") } + method := string(mR[:]) method = strings.ToUpper(strings.TrimSpace(method)) ret.StartLine = RequestLine{} @@ -108,30 +99,97 @@ func ParseHTTPRequest(b []byte) (HTTPRequest, error) { case string(HTTPMETHOD_POST): ret.StartLine.Method = HTTPMETHOD_POST default: - if method != "GET" { - result1 := strings.Compare("GET", method) - fmt.Println(result1) - fmt.Println(len("GET"), len(method)) - } - return ret, fmt.Errorf("unsupported method '%s'", method) } - rt, rs, found := strings.Cut(rs, " ") - if !found { + _, rt, br := ReadBytesUntil(b, ' ') + b = br + if rt == nil { return ret, errors.New("could not find target in request") } - ret.StartLine.RequestTarget = rt + ret.StartLine.RequestTarget = string(rt[:]) - hv, rs, found := strings.Cut(rs, "\r\n") - if !found { + _, hv, br := ReadBytesUntil(b, '\r') + if hv == nil { return ret, errors.New("could not find http version in request") } - if hv != "HTTP/1.0" && hv != "HTTP/1.1" { + + _, nc, br := ReadBytesUntil(b, '\n') + b = br + if nc == nil { + return ret, errors.New("malformed request") + } + + hvs := string(hv[:]) + if hvs != "HTTP/1.0" && hvs != "HTTP/1.1" { return ret, fmt.Errorf("unsupported http version %s", hv) } - ret.StartLine.HTTPVersion = hv - fmt.Println("rm: ", rs) + ret.StartLine.HTTPVersion = hvs + + // Check for headers + c, hr, br := ReadBytesUntil(b, '\r') + b = br + if hr == nil { + return ret, errors.New("malformed request") + } + if c == 0 { + _, nc, _ := ReadBytesUntil(b, '\n') + if nc == nil { + return ret, errors.New("malformed request") + } + return ret, nil + } else { + _, nc, br := ReadBytesUntil(b, '\n') + b = br + if nc == nil { + return ret, errors.New("malformed request") + } + ret.Headers = map[string]string{} + h := string(hr[:]) + for len(h) > 0 { + // We have some headers + k, v, found := strings.Cut(h, ":") + if !found { + return ret, fmt.Errorf("malformed header '%s'", h) + } + ret.Headers[k] = strings.TrimSpace(v) + _, hr, br := ReadBytesUntil(b, '\r') + if hr == nil { + return ret, fmt.Errorf("malformed header '%s'", h) + } + h = string(hr[:]) + b = br + + _, nc, br := ReadBytesUntil(b, '\n') + b = br + if nc == nil { + return ret, errors.New("malformed request") + } + } + } + + // Parse Message Body + + // Message body if it exists + if _, ok := ret.Headers["Transfer-Encoding"]; ok { + if _, okc := ret.Headers["Content-Length"]; okc { + return ret, fmt.Errorf("cannot specify both 'Transfer-Encoding' and 'Content-Length'") + } + return ret, fmt.Errorf("unimplemented") + } + if val, ok := ret.Headers["Content-Length"]; ok { + mlen, err := strconv.Atoi(val) + if err != nil { + return ret, fmt.Errorf("malformed Content-Length '%s'", val) + } + if len(b) == mlen { + fmt.Println("mlen:", mlen) + ret.MessageBody = make([]byte, mlen) + copy(ret.MessageBody, b) + } else { + return ret, fmt.Errorf("malformed Content-Length '%s'", val) + } + } return ret, nil } @@ -195,7 +253,7 @@ func main() { } } - logger.Printf("%s - %s - %s -> 200", req.StartLine.Method, req.StartLine.HTTPVersion, req.StartLine.RequestTarget) + logger.Printf("%s - %s - %s -> 200 {%s (Size: %d)}", req.StartLine.Method, req.StartLine.HTTPVersion, req.StartLine.RequestTarget, string(req.MessageBody[:]), len(req.MessageBody)) bw, err := c.Write([]byte("HTTP/1.1 200 OK\r\n\r\n")) if err != nil { logger.Printf("ERROR %s - (Could not write err) %s", c.RemoteAddr().String(), err)