// dwshttp/main.go (269 lines, 6.3 KiB, Go)

package main
import (
"errors"
"fmt"
"io"
"log"
"net"
"os"
"strconv"
"strings"
"time"
)
// HTTPMethod is the method token that begins an HTTP request-line.
type HTTPMethod string

// The request methods recognised by this server.
const (
	// HTTPMETHOD_GET is the GET method token.
	HTTPMETHOD_GET = HTTPMethod("GET")
	// HTTPMETHOD_POST is the POST method token.
	HTTPMETHOD_POST = HTTPMethod("POST")
)
// RequestLine is the start-line of an HTTP request. From RFC 7230 §3.1.1:
//
//	A request-line begins with a method token, followed by a single space
//	(SP), the request-target, another single space (SP), the protocol
//	version, and ends with CRLF.
//
//	request-line = method SP request-target SP HTTP-version CRLF
//	HTTP-version = HTTP-name "/" DIGIT "." DIGIT
type RequestLine struct {
	Method        HTTPMethod // the method token, e.g. GET
	RequestTarget string     // the request-target, e.g. "/index.html"
	HTTPVersion   string     // the protocol version, e.g. "HTTP/1.1"
}
// StatusLine is the start-line of an HTTP response. From RFC 7230 §3.1.2:
//
//	The first line of a response message is the status-line, consisting
//	of the protocol version, a space (SP), the status code, another
//	space, a possibly empty textual phrase describing the status code,
//	and ending with CRLF.
//
//	status-line = HTTP-version SP status-code SP reason-phrase CRLF
//
// The struct currently has no fields.
type StatusLine struct{}
// HTTPRequest is a parsed HTTP request message. From RFC 7230 §3:
//
//	All HTTP/1.1 messages consist of a start-line followed by a sequence
//	of octets in a format similar to the Internet Message Format
//	[RFC5322]: zero or more header fields (collectively referred to as
//	the "headers" or the "header section"), an empty line indicating the
//	end of the header section, and an optional message body.
//
//	HTTP-message = start-line
//	               *( header-field CRLF )
//	               CRLF
//	               [ message-body ]
type HTTPRequest struct {
	// StartLine is the request-line.
	StartLine RequestLine
	// Headers maps header field names to their values (one value per name).
	// It may be nil when the request has no header fields.
	Headers map[string]string
	// MessageBody holds the message body, if one was present.
	MessageBody []byte
}
// HTTPResponse is an HTTP response message. So far it carries only its
// status-line.
type HTTPResponse struct {
	StartLine StatusLine // the response's status-line
}
// ReadBytesUntil splits b at the first occurrence of the byte c.
//
// It returns the index of c in b (i.e. the number of bytes before c), the
// bytes before c, and the bytes after c; c itself is dropped. When c is the
// first byte, the middle result is empty but non-nil.
//
// If c does not occur in b (including when b is empty), it returns 0, nil and
// b unchanged, so callers can compare the middle result with nil to detect a
// missing delimiter.
func ReadBytesUntil(b []byte, c byte) (int, []byte, []byte) {
	for i, x := range b {
		if x == c {
			return i, b[:i], b[i+1:]
		}
	}
	// Delimiter absent. The previous guard never matched here, and the
	// fall-through b[len(b)+1:] panicked with slice bounds out of range.
	return 0, nil, b
}
// ParseHTTPRequest parses a complete HTTP/1.0 or HTTP/1.1 request held in b.
// The request consists of the request-line, zero or more header fields, the
// empty line ending the header section, and an optional body.
//
// Parsing rules:
//   - The method token is trimmed and upper-cased, and must be GET or POST.
//   - The version must be exactly "HTTP/1.0" or "HTTP/1.1".
//   - Every line must end with CRLF.
//
// Header storage:
//   - Header names are stored exactly as received.
//   - Values are stored with surrounding whitespace trimmed.
//   - Headers is left nil when the request has no header fields.
//
// Body handling:
//   - Transfer-Encoding and Content-Length are looked up case-insensitively,
//     because HTTP field names are case-insensitive.
//   - A request carrying Transfer-Encoding is rejected, since chunked bodies
//     are not implemented.
//   - With Content-Length, the bytes after the header section must be
//     exactly that long.
//   - Without Content-Length, any trailing bytes are ignored.
//
// The parser never panics on truncated or malformed input; it returns an
// error instead. On error, the returned request holds only what was parsed
// before the failure and should not be relied upon.
func ParseHTTPRequest(b []byte) (HTTPRequest, error) {
	ret := HTTPRequest{}
	rest := string(b)

	// request-line = method SP request-target SP HTTP-version CRLF
	method, rest, found := strings.Cut(rest, " ")
	if !found {
		return ret, errors.New("could not find method in request")
	}
	method = strings.ToUpper(strings.TrimSpace(method))
	switch m := HTTPMethod(method); m {
	case HTTPMETHOD_GET, HTTPMETHOD_POST:
		ret.StartLine.Method = m
	default:
		return ret, fmt.Errorf("unsupported method '%s'", method)
	}

	target, rest, found := strings.Cut(rest, " ")
	if !found {
		return ret, errors.New("could not find target in request")
	}
	ret.StartLine.RequestTarget = target

	version, rest, hasCR, hasLF := cutCRLF(rest)
	if !hasCR {
		return ret, errors.New("could not find http version in request")
	}
	if !hasLF {
		return ret, errors.New("malformed request")
	}
	if version != "HTTP/1.0" && version != "HTTP/1.1" {
		return ret, fmt.Errorf("unsupported http version %s", version)
	}
	ret.StartLine.HTTPVersion = version

	// *( header-field CRLF ) CRLF — read lines until the empty one.
	for {
		var line string
		line, rest, hasCR, hasLF = cutCRLF(rest)
		if !hasCR || !hasLF {
			return ret, errors.New("malformed request")
		}
		if line == "" {
			break
		}
		name, value, ok := strings.Cut(line, ":")
		if !ok {
			return ret, fmt.Errorf("malformed header '%s'", line)
		}
		if ret.Headers == nil {
			ret.Headers = map[string]string{}
		}
		ret.Headers[name] = strings.TrimSpace(value)
	}

	// [ message-body ]: whatever follows the header section.
	if _, ok := headerValue(ret.Headers, "Transfer-Encoding"); ok {
		if _, okc := headerValue(ret.Headers, "Content-Length"); okc {
			return ret, errors.New("cannot specify both 'Transfer-Encoding' and 'Content-Length'")
		}
		return ret, errors.New("unimplemented")
	}
	if val, ok := headerValue(ret.Headers, "Content-Length"); ok {
		mlen, err := strconv.Atoi(val)
		if err != nil || mlen < 0 || len(rest) != mlen {
			return ret, fmt.Errorf("malformed Content-Length '%s'", val)
		}
		// Converting to []byte copies, so the body does not alias the caller's b.
		ret.MessageBody = []byte(rest)
	}
	return ret, nil
}

// cutCRLF splits s at its first CR. line is the text before the CR and rest
// the text after the terminating CRLF. hasCR reports whether s contains a CR
// at all; hasLF reports whether that CR is immediately followed by LF.
func cutCRLF(s string) (line, rest string, hasCR, hasLF bool) {
	i := strings.IndexByte(s, '\r')
	if i < 0 {
		return "", s, false, false
	}
	line, rest = s[:i], s[i+1:]
	if !strings.HasPrefix(rest, "\n") {
		return line, rest, true, false
	}
	return line, rest[1:], true, true
}

// headerValue looks up name in h. It prefers an exact key match, then falls
// back to a case-insensitive match, because HTTP field names are
// case-insensitive. A nil h yields "", false.
func headerValue(h map[string]string, name string) (string, bool) {
	if v, ok := h[name]; ok {
		return v, true
	}
	for k, v := range h {
		if strings.EqualFold(k, name) {
			return v, true
		}
	}
	return "", false
}
// main listens for TCP connections on port 4221 and serves each accepted
// connection on its own goroutine via handleConn. Only a failure to listen or
// to accept stops the server; errors on an individual connection are logged
// and end that connection alone.
func main() {
	logger := log.New(os.Stdout, "DWSHTTP", log.LstdFlags)
	l, err := net.Listen("tcp", ":4221")
	if err != nil {
		logger.Fatal(err)
	}
	defer l.Close()
	logger.Print("Started server on 4221")
	for {
		conn, err := l.Accept()
		if err != nil {
			logger.Fatal(err)
		}
		logger.Printf("New connection: %s", conn.RemoteAddr().String())
		go handleConn(logger, conn)
	}
}

// handleConn reads a single request from c and parses it with
// ParseHTTPRequest. It then writes a bare status-line response: 200 when the
// request parsed, 500 otherwise. c is always closed on return.
func handleConn(logger *log.Logger, c net.Conn) {
	defer c.Close()
	remote := c.RemoteAddr().String()

	// Bound how long a slow or silent client can hold this connection.
	if err := c.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
		logger.Printf("ERROR %s - (%s) %s", remote, "connsetup", err)
	}

	buf, err := readRequest(c)
	if err != nil {
		// Log instead of exiting. A read error, including the deadline set
		// above expiring, concerns this client only, not the whole server.
		logger.Printf("ERROR %s - (%s) %s", remote, "unexpected error while reading", err)
		if len(buf) == 0 {
			return
		}
		// Some bytes arrived before the error; try to serve them.
	}

	req, err := ParseHTTPRequest(buf)
	if err != nil {
		logger.Printf("ERROR %s - (%s) %s", remote, "parsing", err)
		if _, err := c.Write([]byte("HTTP/1.1 500 Internal Server Error\r\n\r\n")); err != nil {
			logger.Printf("ERROR %s - (Could not write err) %s", remote, err)
		}
		// Stop here; otherwise a second (200) response would follow the 500.
		return
	}

	logger.Printf("%s - %s - %s -> 200 {%s (Size: %d)}", req.StartLine.Method, req.StartLine.HTTPVersion, req.StartLine.RequestTarget, string(req.MessageBody), len(req.MessageBody))
	if _, err := c.Write([]byte("HTTP/1.1 200 OK\r\n\r\n")); err != nil {
		logger.Printf("ERROR %s - (Could not write err) %s", remote, err)
	}
}

// readRequest reads from r, 256 bytes at a time, until a read returns fewer
// than 256 bytes (taken to mean the request is complete), EOF, or an error.
// Bytes delivered together with an error are kept, as io.Reader permits
// n > 0 alongside err != nil. EOF is not reported as an error; any other
// error is returned along with whatever was read.
//
// NOTE(review): the short-read rule is a heuristic. A request split across
// TCP segments may be cut short. A request whose size is an exact multiple of
// 256 is only returned once the next read fails (e.g. on the read deadline).
func readRequest(r io.Reader) ([]byte, error) {
	const chunkSize = 256
	var buf []byte
	chunk := make([]byte, chunkSize)
	for {
		n, err := r.Read(chunk)
		buf = append(buf, chunk[:n]...)
		if err == io.EOF {
			return buf, nil
		}
		if err != nil {
			return buf, err
		}
		if n > 0 && n < chunkSize {
			return buf, nil
		}
	}
}