feat: Implement basic API server with mTLS for leader join endpoint

This commit is contained in:
Tanishq Dubey 2025-05-16 22:18:58 -04:00
parent 800e4f72f2
commit 9e63518308
No known key found for this signature in database
GPG Key ID: CFC1931B84DFC3F9
4 changed files with 412 additions and 0 deletions

View File

@ -0,0 +1,141 @@
package api
import (
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"time"
"github.com/google/uuid"
"kat-system/internal/pki"
"kat-system/internal/store"
)
// JoinRequest represents the data sent by an agent when joining
type JoinRequest struct {
CSR []byte `json:"csr"`
AdvertiseAddr string `json:"advertiseAddr"`
NodeName string `json:"nodeName,omitempty"` // Optional, leader can generate
WireguardPubKey string `json:"wireguardPubKey"` // Placeholder for now
}
// JoinResponse represents the data sent back to the agent
type JoinResponse struct {
NodeName string `json:"nodeName"`
NodeUID string `json:"nodeUID"`
SignedCert []byte `json:"signedCert"`
CACert []byte `json:"caCert"`
JoinTimestamp int64 `json:"joinTimestamp"`
}
// NewJoinHandler creates a handler for agent join requests
func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Read and parse the request body
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to read request body: %v", err), http.StatusBadRequest)
return
}
defer r.Body.Close()
var joinReq JoinRequest
if err := json.Unmarshal(body, &joinReq); err != nil {
http.Error(w, fmt.Sprintf("Failed to parse request: %v", err), http.StatusBadRequest)
return
}
// Validate request
if len(joinReq.CSR) == 0 {
http.Error(w, "Missing CSR", http.StatusBadRequest)
return
}
if joinReq.AdvertiseAddr == "" {
http.Error(w, "Missing advertise address", http.StatusBadRequest)
return
}
// Generate node name if not provided
nodeName := joinReq.NodeName
if nodeName == "" {
nodeName = fmt.Sprintf("node-%s", uuid.New().String()[:8])
}
// Generate a unique node ID
nodeUID := uuid.New().String()
// Sign the CSR
// Create a temporary file for the CSR
tempDir := os.TempDir()
csrPath := filepath.Join(tempDir, fmt.Sprintf("%s.csr", nodeUID))
if err := os.WriteFile(csrPath, joinReq.CSR, 0600); err != nil {
http.Error(w, fmt.Sprintf("Failed to save CSR: %v", err), http.StatusInternalServerError)
return
}
defer os.Remove(csrPath)
// Sign the CSR
certPath := filepath.Join(tempDir, fmt.Sprintf("%s.crt", nodeUID))
if err := pki.SignCertificateRequest(caKeyPath, caCertPath, csrPath, certPath, 365*24*time.Hour); err != nil {
http.Error(w, fmt.Sprintf("Failed to sign CSR: %v", err), http.StatusInternalServerError)
return
}
defer os.Remove(certPath)
// Read the signed certificate
signedCert, err := os.ReadFile(certPath)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to read signed certificate: %v", err), http.StatusInternalServerError)
return
}
// Read the CA certificate
caCert, err := os.ReadFile(caCertPath)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to read CA certificate: %v", err), http.StatusInternalServerError)
return
}
// Store node registration in etcd
nodeRegKey := fmt.Sprintf("/kat/nodes/registration/%s", nodeName)
nodeReg := map[string]interface{}{
"uid": nodeUID,
"advertiseAddr": joinReq.AdvertiseAddr,
"wireguardPubKey": joinReq.WireguardPubKey,
"joinTimestamp": time.Now().Unix(),
}
nodeRegData, err := json.Marshal(nodeReg)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to marshal node registration: %v", err), http.StatusInternalServerError)
return
}
if err := stateStore.Put(r.Context(), nodeRegKey, nodeRegData); err != nil {
http.Error(w, fmt.Sprintf("Failed to store node registration: %v", err), http.StatusInternalServerError)
return
}
// Prepare and send response
joinResp := JoinResponse{
NodeName: nodeName,
NodeUID: nodeUID,
SignedCert: signedCert,
CACert: caCert,
JoinTimestamp: time.Now().Unix(),
}
respData, err := json.Marshal(joinResp)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to marshal response: %v", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(respData)
}
}

48
internal/api/router.go Normal file
View File

@ -0,0 +1,48 @@
package api
import (
"net/http"
"strings"
)
// Route represents a single API route
type Route struct {
Method string
Path string
Handler http.HandlerFunc
}
// Router is a simple HTTP router for the KAT API
type Router struct {
routes []Route
}
// NewRouter creates a new router instance
func NewRouter() *Router {
return &Router{
routes: []Route{},
}
}
// HandleFunc registers a new route with the router
func (r *Router) HandleFunc(method, path string, handler http.HandlerFunc) {
r.routes = append(r.routes, Route{
Method: strings.ToUpper(method),
Path: path,
Handler: handler,
})
}
// ServeHTTP implements the http.Handler interface
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// Find matching route
for _, route := range r.routes {
if route.Method == req.Method && route.Path == req.URL.Path {
route.Handler(w, req)
return
}
}
// No route matched
http.NotFound(w, req)
}

84
internal/api/server.go Normal file
View File

@ -0,0 +1,84 @@
package api
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net/http"
"os"
"time"
)
// Server represents the API server for KAT
type Server struct {
httpServer *http.Server
router *Router
certFile string
keyFile string
caFile string
}
// NewServer creates a new API server instance
func NewServer(addr string, certFile, keyFile, caFile string) (*Server, error) {
router := NewRouter()
server := &Server{
router: router,
certFile: certFile,
keyFile: keyFile,
caFile: caFile,
}
// Create the HTTP server with TLS config
server.httpServer = &http.Server{
Addr: addr,
Handler: router,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 120 * time.Second,
}
return server, nil
}
// Start begins listening for requests
func (s *Server) Start() error {
// Load server certificate and key
cert, err := tls.LoadX509KeyPair(s.certFile, s.keyFile)
if err != nil {
return fmt.Errorf("failed to load server certificate and key: %w", err)
}
// Load CA certificate for client verification
caCert, err := os.ReadFile(s.caFile)
if err != nil {
return fmt.Errorf("failed to read CA certificate: %w", err)
}
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(caCert) {
return fmt.Errorf("failed to append CA certificate to pool")
}
// Configure TLS
s.httpServer.TLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: caCertPool,
MinVersion: tls.VersionTLS12,
}
// Start the server
return s.httpServer.ListenAndServeTLS("", "")
}
// Stop gracefully shuts down the server
func (s *Server) Stop(ctx context.Context) error {
return s.httpServer.Shutdown(ctx)
}
// RegisterJoinHandler registers the handler for agent join requests
func (s *Server) RegisterJoinHandler(handler http.HandlerFunc) {
s.router.HandleFunc("POST", "/internal/v1alpha1/join", handler)
}

139
internal/api/server_test.go Normal file
View File

@ -0,0 +1,139 @@
package api
import (
"context"
"crypto/tls"
"crypto/x509"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"testing"
"time"
"kat-system/internal/pki"
)
func TestServerWithMTLS(t *testing.T) {
// Skip in short mode
if testing.Short() {
t.Skip("Skipping test in short mode")
}
// Create temporary directory for test certificates
tempDir, err := os.MkdirTemp("", "kat-api-test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tempDir)
// Generate CA
caKeyPath := filepath.Join(tempDir, "ca.key")
caCertPath := filepath.Join(tempDir, "ca.crt")
if err := pki.GenerateCA(caKeyPath, caCertPath, "KAT Test CA", 24*time.Hour); err != nil {
t.Fatalf("Failed to generate CA: %v", err)
}
// Generate server certificate
serverKeyPath := filepath.Join(tempDir, "server.key")
serverCSRPath := filepath.Join(tempDir, "server.csr")
serverCertPath := filepath.Join(tempDir, "server.crt")
if err := pki.GenerateCertificateRequest("server.test", serverKeyPath, serverCSRPath); err != nil {
t.Fatalf("Failed to generate server CSR: %v", err)
}
if err := pki.SignCertificateRequest(caKeyPath, caCertPath, serverCSRPath, serverCertPath, 24*time.Hour); err != nil {
t.Fatalf("Failed to sign server certificate: %v", err)
}
// Generate client certificate
clientKeyPath := filepath.Join(tempDir, "client.key")
clientCSRPath := filepath.Join(tempDir, "client.csr")
clientCertPath := filepath.Join(tempDir, "client.crt")
if err := pki.GenerateCertificateRequest("client.test", clientKeyPath, clientCSRPath); err != nil {
t.Fatalf("Failed to generate client CSR: %v", err)
}
if err := pki.SignCertificateRequest(caKeyPath, caCertPath, clientCSRPath, clientCertPath, 24*time.Hour); err != nil {
t.Fatalf("Failed to sign client certificate: %v", err)
}
// Create and start server
server, err := NewServer("localhost:0", serverCertPath, serverKeyPath, caCertPath)
if err != nil {
t.Fatalf("Failed to create server: %v", err)
}
// Add a test handler
server.router.HandleFunc("GET", "/test", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("test successful"))
})
// Start server in a goroutine
go func() {
if err := server.Start(); err != nil && err != http.ErrServerClosed {
t.Errorf("Server error: %v", err)
}
}()
// Wait for server to start
time.Sleep(100 * time.Millisecond)
// Load CA cert
caCert, err := os.ReadFile(caCertPath)
if err != nil {
t.Fatalf("Failed to read CA cert: %v", err)
}
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
// Load client cert
clientCert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath)
if err != nil {
t.Fatalf("Failed to load client cert: %v", err)
}
// Create HTTP client with mTLS
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: caCertPool,
Certificates: []tls.Certificate{clientCert},
},
},
}
// Test with valid client cert
resp, err := client.Get("https://localhost:8443/test")
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read response: %v", err)
}
if !strings.Contains(string(body), "test successful") {
t.Errorf("Unexpected response: %s", body)
}
// Test with no client cert (should fail)
clientWithoutCert := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: caCertPool,
},
},
}
_, err = clientWithoutCert.Get("https://localhost:8443/test")
if err == nil {
t.Error("Request without client cert should fail")
}
// Shutdown server
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
server.Stop(ctx)
}