diff --git a/internal/api/join_handler.go b/internal/api/join_handler.go new file mode 100644 index 0000000..591b88e --- /dev/null +++ b/internal/api/join_handler.go @@ -0,0 +1,141 @@ +package api + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "time" + + "github.com/google/uuid" + + "kat-system/internal/pki" + "kat-system/internal/store" +) + +// JoinRequest represents the data sent by an agent when joining +type JoinRequest struct { + CSR []byte `json:"csr"` + AdvertiseAddr string `json:"advertiseAddr"` + NodeName string `json:"nodeName,omitempty"` // Optional, leader can generate + WireguardPubKey string `json:"wireguardPubKey"` // Placeholder for now +} + +// JoinResponse represents the data sent back to the agent +type JoinResponse struct { + NodeName string `json:"nodeName"` + NodeUID string `json:"nodeUID"` + SignedCert []byte `json:"signedCert"` + CACert []byte `json:"caCert"` + JoinTimestamp int64 `json:"joinTimestamp"` +} + +// NewJoinHandler creates a handler for agent join requests +func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Read and parse the request body + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to read request body: %v", err), http.StatusBadRequest) + return + } + defer r.Body.Close() + + var joinReq JoinRequest + if err := json.Unmarshal(body, &joinReq); err != nil { + http.Error(w, fmt.Sprintf("Failed to parse request: %v", err), http.StatusBadRequest) + return + } + + // Validate request + if len(joinReq.CSR) == 0 { + http.Error(w, "Missing CSR", http.StatusBadRequest) + return + } + if joinReq.AdvertiseAddr == "" { + http.Error(w, "Missing advertise address", http.StatusBadRequest) + return + } + + // Generate node name if not provided + nodeName := joinReq.NodeName + if nodeName == "" { + nodeName = fmt.Sprintf("node-%s", uuid.New().String()[:8]) + } + + // Generate a unique node ID + nodeUID := uuid.New().String() + + // Sign the CSR + // Create a temporary file for the CSR + tempDir := os.TempDir() + csrPath := filepath.Join(tempDir, fmt.Sprintf("%s.csr", nodeUID)) + if err := os.WriteFile(csrPath, joinReq.CSR, 0600); err != nil { + http.Error(w, fmt.Sprintf("Failed to save CSR: %v", err), http.StatusInternalServerError) + return + } + defer os.Remove(csrPath) + + // Sign the CSR + certPath := filepath.Join(tempDir, fmt.Sprintf("%s.crt", nodeUID)) + if err := pki.SignCertificateRequest(caKeyPath, caCertPath, csrPath, certPath, 365*24*time.Hour); err != nil { + http.Error(w, fmt.Sprintf("Failed to sign CSR: %v", err), http.StatusInternalServerError) + return + } + defer os.Remove(certPath) + + // Read the signed certificate + signedCert, err := os.ReadFile(certPath) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to read signed certificate: %v", err), http.StatusInternalServerError) + return + } + + // Read the CA certificate + caCert, err := os.ReadFile(caCertPath) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to read CA certificate: %v", err), http.StatusInternalServerError) + return + } + + // Store node registration in etcd + nodeRegKey := fmt.Sprintf("/kat/nodes/registration/%s", nodeName) + nodeReg := map[string]interface{}{ + "uid": nodeUID, + "advertiseAddr": joinReq.AdvertiseAddr, + "wireguardPubKey": joinReq.WireguardPubKey, + "joinTimestamp": time.Now().Unix(), + } + nodeRegData, err := json.Marshal(nodeReg) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to marshal node registration: %v", err), http.StatusInternalServerError) + return + } + + if err := stateStore.Put(r.Context(), nodeRegKey, nodeRegData); err != nil { + http.Error(w, fmt.Sprintf("Failed to store node registration: %v", err), http.StatusInternalServerError) + return + } + + // Prepare and send response + joinResp := JoinResponse{ + NodeName: nodeName, + NodeUID: nodeUID, + SignedCert: signedCert, + CACert: caCert, + JoinTimestamp: time.Now().Unix(), + } + + respData, err := json.Marshal(joinResp) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to marshal response: %v", err), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(respData) + } +} diff --git a/internal/api/router.go b/internal/api/router.go new file mode 100644 index 0000000..320c546 --- /dev/null +++ b/internal/api/router.go @@ -0,0 +1,48 @@ +package api + +import ( + "net/http" + "strings" +) + +// Route represents a single API route +type Route struct { + Method string + Path string + Handler http.HandlerFunc +} + +// Router is a simple HTTP router for the KAT API +type Router struct { + routes []Route +} + +// NewRouter creates a new router instance +func NewRouter() *Router { + return &Router{ + routes: []Route{}, + } +} + +// HandleFunc registers a new route with the router +func (r *Router) HandleFunc(method, path string, handler http.HandlerFunc) { + r.routes = append(r.routes, Route{ + Method: strings.ToUpper(method), + Path: path, + Handler: handler, + }) +} + +// ServeHTTP implements the http.Handler interface +func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // Find matching route + for _, route := range r.routes { + if route.Method == req.Method && route.Path == req.URL.Path { + route.Handler(w, req) + return + } + } + + // No route matched + http.NotFound(w, req) +} diff --git a/internal/api/server.go b/internal/api/server.go new file mode 100644 index 0000000..d3aa590 --- /dev/null +++ b/internal/api/server.go @@ -0,0 +1,84 @@ +package api + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "net/http" + "os" + "time" +) + +// Server represents the API server for KAT +type Server struct { + httpServer *http.Server + router *Router + certFile string + keyFile string + caFile string +} + +// NewServer creates a new API server instance +func NewServer(addr string, certFile, keyFile, caFile string) (*Server, error) { + router := NewRouter() + + server := &Server{ + router: router, + certFile: certFile, + keyFile: keyFile, + caFile: caFile, + } + + // Create the HTTP server with TLS config + server.httpServer = &http.Server{ + Addr: addr, + Handler: router, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, + } + + return server, nil +} + +// Start begins listening for requests +func (s *Server) Start() error { + // Load server certificate and key + cert, err := tls.LoadX509KeyPair(s.certFile, s.keyFile) + if err != nil { + return fmt.Errorf("failed to load server certificate and key: %w", err) + } + + // Load CA certificate for client verification + caCert, err := os.ReadFile(s.caFile) + if err != nil { + return fmt.Errorf("failed to read CA certificate: %w", err) + } + + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM(caCert) { + return fmt.Errorf("failed to append CA certificate to pool") + } + + // Configure TLS + s.httpServer.TLSConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: caCertPool, + MinVersion: tls.VersionTLS12, + } + + // Start the server + return s.httpServer.ListenAndServeTLS("", "") +} + +// Stop gracefully shuts down the server +func (s *Server) Stop(ctx context.Context) error { + return s.httpServer.Shutdown(ctx) +} + +// RegisterJoinHandler registers the handler for agent join requests +func (s *Server) RegisterJoinHandler(handler http.HandlerFunc) { + s.router.HandleFunc("POST", "/internal/v1alpha1/join", handler) +} diff --git a/internal/api/server_test.go b/internal/api/server_test.go new file mode 100644 index 0000000..d6ebeae --- /dev/null +++ b/internal/api/server_test.go @@ -0,0 +1,139 @@ +package api + +import ( + "context" + "crypto/tls" + "crypto/x509" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "kat-system/internal/pki" +) + +func TestServerWithMTLS(t *testing.T) { + // Skip in short mode + if testing.Short() { + t.Skip("Skipping test in short mode") + } + + // Create temporary directory for test certificates + tempDir, err := os.MkdirTemp("", "kat-api-test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Generate CA + caKeyPath := filepath.Join(tempDir, "ca.key") + caCertPath := filepath.Join(tempDir, "ca.crt") + if err := pki.GenerateCA(caKeyPath, caCertPath, "KAT Test CA", 24*time.Hour); err != nil { + t.Fatalf("Failed to generate CA: %v", err) + } + + // Generate server certificate + serverKeyPath := filepath.Join(tempDir, "server.key") + serverCSRPath := filepath.Join(tempDir, "server.csr") + serverCertPath := filepath.Join(tempDir, "server.crt") + if err := pki.GenerateCertificateRequest("server.test", serverKeyPath, serverCSRPath); err != nil { + t.Fatalf("Failed to generate server CSR: %v", err) + } + if err := pki.SignCertificateRequest(caKeyPath, caCertPath, serverCSRPath, serverCertPath, 24*time.Hour); err != nil { + t.Fatalf("Failed to sign server certificate: %v", err) + } + + // Generate client certificate + clientKeyPath := filepath.Join(tempDir, "client.key") + clientCSRPath := filepath.Join(tempDir, "client.csr") + clientCertPath := filepath.Join(tempDir, "client.crt") + if err := pki.GenerateCertificateRequest("client.test", clientKeyPath, clientCSRPath); err != nil { + t.Fatalf("Failed to generate client CSR: %v", err) + } + if err := pki.SignCertificateRequest(caKeyPath, caCertPath, clientCSRPath, clientCertPath, 24*time.Hour); err != nil { + t.Fatalf("Failed to sign client certificate: %v", err) + } + + // Create and start server + server, err := NewServer("localhost:0", serverCertPath, serverKeyPath, caCertPath) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Add a test handler + server.router.HandleFunc("GET", "/test", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("test successful")) + }) + + // Start server in a goroutine + go func() { + if err := server.Start(); err != nil && err != http.ErrServerClosed { + t.Errorf("Server error: %v", err) + } + }() + + // Wait for server to start + time.Sleep(100 * time.Millisecond) + + // Load CA cert + caCert, err := os.ReadFile(caCertPath) + if err != nil { + t.Fatalf("Failed to read CA cert: %v", err) + } + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + + // Load client cert + clientCert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath) + if err != nil { + t.Fatalf("Failed to load client cert: %v", err) + } + + // Create HTTP client with mTLS + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: caCertPool, + Certificates: []tls.Certificate{clientCert}, + }, + }, + } + + // Test with valid client cert + resp, err := client.Get("https://localhost:8443/test") + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read response: %v", err) + } + + if !strings.Contains(string(body), "test successful") { + t.Errorf("Unexpected response: %s", body) + } + + // Test with no client cert (should fail) + clientWithoutCert := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: caCertPool, + }, + }, + } + + _, err = clientWithoutCert.Get("https://localhost:8443/test") + if err == nil { + t.Error("Request without client cert should fail") + } + + // Shutdown server + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + server.Stop(ctx) +}