basic http request parsing

This commit is contained in:
Tanishq Dubey 2024-07-15 20:28:11 -04:00
parent b740f57488
commit b23320c288

140
main.go
View File

@ -7,6 +7,7 @@ import (
"log" "log"
"net" "net"
"os" "os"
"strconv"
"strings" "strings"
"time" "time"
) )
@ -58,47 +59,37 @@ end of the header section, and an optional message body.
*/ */
type HTTPRequest struct { type HTTPRequest struct {
StartLine RequestLine StartLine RequestLine
Headers map[string]string
MessageBody []byte
} }
type HTTPResponse struct { type HTTPResponse struct {
StartLine StatusLine StartLine StatusLine
} }
func printDiff(s1, s2 string) { // ReadBytesUntil reads a slice of bytes until a given character is reached.
length := len(s1) // It returns the number of bytes read, the bytes until the character, and the remaining bytes.
if len(s2) > length { func ReadBytesUntil(b []byte, c byte) (int, []byte, []byte) {
length = len(s2) i := 0
} for i < len(b) && b[i] != c {
i++
for i := 0; i < length; i++ {
var char1, char2 byte
if i < len(s1) {
char1 = s1[i]
} else {
char1 = ' ' // padding for shorter string
}
if i < len(s2) {
char2 = s2[i]
} else {
char2 = ' ' // padding for shorter string
}
if char1 != char2 {
fmt.Printf("Difference at index %d: '%c' != '%c'\n", i, char1, char2)
} }
if i == len(b)-1 && b[i] != c {
return 0, nil, b
} }
return i, b[:i], b[i+1:]
} }
func ParseHTTPRequest(b []byte) (HTTPRequest, error) { func ParseHTTPRequest(b []byte) (HTTPRequest, error) {
ret := HTTPRequest{} ret := HTTPRequest{}
rs := string(b[:])
method, rs, found := strings.Cut(rs, " ") // Construct startline
if !found { _, mR, br := ReadBytesUntil(b, ' ')
b = br
if mR == nil {
return ret, errors.New("could not find method in request") return ret, errors.New("could not find method in request")
} }
method := string(mR[:])
method = strings.ToUpper(strings.TrimSpace(method)) method = strings.ToUpper(strings.TrimSpace(method))
ret.StartLine = RequestLine{} ret.StartLine = RequestLine{}
@ -108,30 +99,97 @@ func ParseHTTPRequest(b []byte) (HTTPRequest, error) {
case string(HTTPMETHOD_POST): case string(HTTPMETHOD_POST):
ret.StartLine.Method = HTTPMETHOD_POST ret.StartLine.Method = HTTPMETHOD_POST
default: default:
if method != "GET" {
result1 := strings.Compare("GET", method)
fmt.Println(result1)
fmt.Println(len("GET"), len(method))
}
return ret, fmt.Errorf("unsupported method '%s'", method) return ret, fmt.Errorf("unsupported method '%s'", method)
} }
rt, rs, found := strings.Cut(rs, " ") _, rt, br := ReadBytesUntil(b, ' ')
if !found { b = br
if rt == nil {
return ret, errors.New("could not find target in request") return ret, errors.New("could not find target in request")
} }
ret.StartLine.RequestTarget = rt ret.StartLine.RequestTarget = string(rt[:])
hv, rs, found := strings.Cut(rs, "\r\n") _, hv, br := ReadBytesUntil(b, '\r')
if !found { if hv == nil {
return ret, errors.New("could not find http version in request") return ret, errors.New("could not find http version in request")
} }
if hv != "HTTP/1.0" && hv != "HTTP/1.1" {
_, nc, br := ReadBytesUntil(b, '\n')
b = br
if nc == nil {
return ret, errors.New("malformed request")
}
hvs := string(hv[:])
if hvs != "HTTP/1.0" && hvs != "HTTP/1.1" {
return ret, fmt.Errorf("unsupported http version %s", hv) return ret, fmt.Errorf("unsupported http version %s", hv)
} }
ret.StartLine.HTTPVersion = hv ret.StartLine.HTTPVersion = hvs
fmt.Println("rm: ", rs)
// Check for headers
c, hr, br := ReadBytesUntil(b, '\r')
b = br
if hr == nil {
return ret, errors.New("malformed request")
}
if c == 0 {
_, nc, _ := ReadBytesUntil(b, '\n')
if nc == nil {
return ret, errors.New("malformed request")
}
return ret, nil
} else {
_, nc, br := ReadBytesUntil(b, '\n')
b = br
if nc == nil {
return ret, errors.New("malformed request")
}
ret.Headers = map[string]string{}
h := string(hr[:])
for len(h) > 0 {
// We have some headers
k, v, found := strings.Cut(h, ":")
if !found {
return ret, fmt.Errorf("malformed header '%s'", h)
}
ret.Headers[k] = strings.TrimSpace(v)
_, hr, br := ReadBytesUntil(b, '\r')
if hr == nil {
return ret, fmt.Errorf("malformed header '%s'", h)
}
h = string(hr[:])
b = br
_, nc, br := ReadBytesUntil(b, '\n')
b = br
if nc == nil {
return ret, errors.New("malformed request")
}
}
}
// Parse Message Body
// Message body if it exists
if _, ok := ret.Headers["Transfer-Encoding"]; ok {
if _, okc := ret.Headers["Content-Length"]; okc {
return ret, fmt.Errorf("cannot specify both 'Transfer-Encoding' and 'Content-Length'")
}
return ret, fmt.Errorf("unimplemented")
}
if val, ok := ret.Headers["Content-Length"]; ok {
mlen, err := strconv.Atoi(val)
if err != nil {
return ret, fmt.Errorf("malformed Content-Length '%s'", val)
}
if len(b) == mlen {
fmt.Println("mlen:", mlen)
ret.MessageBody = make([]byte, mlen)
copy(ret.MessageBody, b)
} else {
return ret, fmt.Errorf("malformed Content-Length '%s'", val)
}
}
return ret, nil return ret, nil
} }
@ -195,7 +253,7 @@ func main() {
} }
} }
logger.Printf("%s - %s - %s -> 200", req.StartLine.Method, req.StartLine.HTTPVersion, req.StartLine.RequestTarget) logger.Printf("%s - %s - %s -> 200 {%s (Size: %d)}", req.StartLine.Method, req.StartLine.HTTPVersion, req.StartLine.RequestTarget, string(req.MessageBody[:]), len(req.MessageBody))
bw, err := c.Write([]byte("HTTP/1.1 200 OK\r\n\r\n")) bw, err := c.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
if err != nil { if err != nil {
logger.Printf("ERROR %s - (Could not write err) %s", c.RemoteAddr().String(), err) logger.Printf("ERROR %s - (Could not write err) %s", c.RemoteAddr().String(), err)