feat: Implement basic API server with mTLS for leader join endpoint
This commit is contained in:
parent
800e4f72f2
commit
9e63518308
141
internal/api/join_handler.go
Normal file
141
internal/api/join_handler.go
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
"kat-system/internal/pki"
|
||||||
|
"kat-system/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
// JoinRequest represents the data sent by an agent when joining
|
||||||
|
type JoinRequest struct {
|
||||||
|
CSR []byte `json:"csr"`
|
||||||
|
AdvertiseAddr string `json:"advertiseAddr"`
|
||||||
|
NodeName string `json:"nodeName,omitempty"` // Optional, leader can generate
|
||||||
|
WireguardPubKey string `json:"wireguardPubKey"` // Placeholder for now
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinResponse represents the data sent back to the agent
|
||||||
|
type JoinResponse struct {
|
||||||
|
NodeName string `json:"nodeName"`
|
||||||
|
NodeUID string `json:"nodeUID"`
|
||||||
|
SignedCert []byte `json:"signedCert"`
|
||||||
|
CACert []byte `json:"caCert"`
|
||||||
|
JoinTimestamp int64 `json:"joinTimestamp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewJoinHandler creates a handler for agent join requests
|
||||||
|
func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Read and parse the request body
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to read request body: %v", err), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer r.Body.Close()
|
||||||
|
|
||||||
|
var joinReq JoinRequest
|
||||||
|
if err := json.Unmarshal(body, &joinReq); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to parse request: %v", err), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate request
|
||||||
|
if len(joinReq.CSR) == 0 {
|
||||||
|
http.Error(w, "Missing CSR", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if joinReq.AdvertiseAddr == "" {
|
||||||
|
http.Error(w, "Missing advertise address", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate node name if not provided
|
||||||
|
nodeName := joinReq.NodeName
|
||||||
|
if nodeName == "" {
|
||||||
|
nodeName = fmt.Sprintf("node-%s", uuid.New().String()[:8])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a unique node ID
|
||||||
|
nodeUID := uuid.New().String()
|
||||||
|
|
||||||
|
// Sign the CSR
|
||||||
|
// Create a temporary file for the CSR
|
||||||
|
tempDir := os.TempDir()
|
||||||
|
csrPath := filepath.Join(tempDir, fmt.Sprintf("%s.csr", nodeUID))
|
||||||
|
if err := os.WriteFile(csrPath, joinReq.CSR, 0600); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to save CSR: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer os.Remove(csrPath)
|
||||||
|
|
||||||
|
// Sign the CSR
|
||||||
|
certPath := filepath.Join(tempDir, fmt.Sprintf("%s.crt", nodeUID))
|
||||||
|
if err := pki.SignCertificateRequest(caKeyPath, caCertPath, csrPath, certPath, 365*24*time.Hour); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to sign CSR: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer os.Remove(certPath)
|
||||||
|
|
||||||
|
// Read the signed certificate
|
||||||
|
signedCert, err := os.ReadFile(certPath)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to read signed certificate: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the CA certificate
|
||||||
|
caCert, err := os.ReadFile(caCertPath)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to read CA certificate: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store node registration in etcd
|
||||||
|
nodeRegKey := fmt.Sprintf("/kat/nodes/registration/%s", nodeName)
|
||||||
|
nodeReg := map[string]interface{}{
|
||||||
|
"uid": nodeUID,
|
||||||
|
"advertiseAddr": joinReq.AdvertiseAddr,
|
||||||
|
"wireguardPubKey": joinReq.WireguardPubKey,
|
||||||
|
"joinTimestamp": time.Now().Unix(),
|
||||||
|
}
|
||||||
|
nodeRegData, err := json.Marshal(nodeReg)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to marshal node registration: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := stateStore.Put(r.Context(), nodeRegKey, nodeRegData); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to store node registration: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare and send response
|
||||||
|
joinResp := JoinResponse{
|
||||||
|
NodeName: nodeName,
|
||||||
|
NodeUID: nodeUID,
|
||||||
|
SignedCert: signedCert,
|
||||||
|
CACert: caCert,
|
||||||
|
JoinTimestamp: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
respData, err := json.Marshal(joinResp)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to marshal response: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write(respData)
|
||||||
|
}
|
||||||
|
}
|
48
internal/api/router.go
Normal file
48
internal/api/router.go
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Route represents a single API route
|
||||||
|
type Route struct {
|
||||||
|
Method string
|
||||||
|
Path string
|
||||||
|
Handler http.HandlerFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// Router is a simple HTTP router for the KAT API
|
||||||
|
type Router struct {
|
||||||
|
routes []Route
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRouter creates a new router instance
|
||||||
|
func NewRouter() *Router {
|
||||||
|
return &Router{
|
||||||
|
routes: []Route{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleFunc registers a new route with the router
|
||||||
|
func (r *Router) HandleFunc(method, path string, handler http.HandlerFunc) {
|
||||||
|
r.routes = append(r.routes, Route{
|
||||||
|
Method: strings.ToUpper(method),
|
||||||
|
Path: path,
|
||||||
|
Handler: handler,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeHTTP implements the http.Handler interface
|
||||||
|
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
|
// Find matching route
|
||||||
|
for _, route := range r.routes {
|
||||||
|
if route.Method == req.Method && route.Path == req.URL.Path {
|
||||||
|
route.Handler(w, req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No route matched
|
||||||
|
http.NotFound(w, req)
|
||||||
|
}
|
84
internal/api/server.go
Normal file
84
internal/api/server.go
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server represents the API server for KAT
|
||||||
|
type Server struct {
|
||||||
|
httpServer *http.Server
|
||||||
|
router *Router
|
||||||
|
certFile string
|
||||||
|
keyFile string
|
||||||
|
caFile string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer creates a new API server instance
|
||||||
|
func NewServer(addr string, certFile, keyFile, caFile string) (*Server, error) {
|
||||||
|
router := NewRouter()
|
||||||
|
|
||||||
|
server := &Server{
|
||||||
|
router: router,
|
||||||
|
certFile: certFile,
|
||||||
|
keyFile: keyFile,
|
||||||
|
caFile: caFile,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the HTTP server with TLS config
|
||||||
|
server.httpServer = &http.Server{
|
||||||
|
Addr: addr,
|
||||||
|
Handler: router,
|
||||||
|
ReadTimeout: 30 * time.Second,
|
||||||
|
WriteTimeout: 30 * time.Second,
|
||||||
|
IdleTimeout: 120 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
return server, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins listening for requests
|
||||||
|
func (s *Server) Start() error {
|
||||||
|
// Load server certificate and key
|
||||||
|
cert, err := tls.LoadX509KeyPair(s.certFile, s.keyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load server certificate and key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load CA certificate for client verification
|
||||||
|
caCert, err := os.ReadFile(s.caFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read CA certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
caCertPool := x509.NewCertPool()
|
||||||
|
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||||
|
return fmt.Errorf("failed to append CA certificate to pool")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure TLS
|
||||||
|
s.httpServer.TLSConfig = &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||||
|
ClientCAs: caCertPool,
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the server
|
||||||
|
return s.httpServer.ListenAndServeTLS("", "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop gracefully shuts down the server
|
||||||
|
func (s *Server) Stop(ctx context.Context) error {
|
||||||
|
return s.httpServer.Shutdown(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterJoinHandler registers the handler for agent join requests
|
||||||
|
func (s *Server) RegisterJoinHandler(handler http.HandlerFunc) {
|
||||||
|
s.router.HandleFunc("POST", "/internal/v1alpha1/join", handler)
|
||||||
|
}
|
139
internal/api/server_test.go
Normal file
139
internal/api/server_test.go
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"kat-system/internal/pki"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServerWithMTLS(t *testing.T) {
|
||||||
|
// Skip in short mode
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create temporary directory for test certificates
|
||||||
|
tempDir, err := os.MkdirTemp("", "kat-api-test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temp dir: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
|
// Generate CA
|
||||||
|
caKeyPath := filepath.Join(tempDir, "ca.key")
|
||||||
|
caCertPath := filepath.Join(tempDir, "ca.crt")
|
||||||
|
if err := pki.GenerateCA(caKeyPath, caCertPath, "KAT Test CA", 24*time.Hour); err != nil {
|
||||||
|
t.Fatalf("Failed to generate CA: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate server certificate
|
||||||
|
serverKeyPath := filepath.Join(tempDir, "server.key")
|
||||||
|
serverCSRPath := filepath.Join(tempDir, "server.csr")
|
||||||
|
serverCertPath := filepath.Join(tempDir, "server.crt")
|
||||||
|
if err := pki.GenerateCertificateRequest("server.test", serverKeyPath, serverCSRPath); err != nil {
|
||||||
|
t.Fatalf("Failed to generate server CSR: %v", err)
|
||||||
|
}
|
||||||
|
if err := pki.SignCertificateRequest(caKeyPath, caCertPath, serverCSRPath, serverCertPath, 24*time.Hour); err != nil {
|
||||||
|
t.Fatalf("Failed to sign server certificate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate client certificate
|
||||||
|
clientKeyPath := filepath.Join(tempDir, "client.key")
|
||||||
|
clientCSRPath := filepath.Join(tempDir, "client.csr")
|
||||||
|
clientCertPath := filepath.Join(tempDir, "client.crt")
|
||||||
|
if err := pki.GenerateCertificateRequest("client.test", clientKeyPath, clientCSRPath); err != nil {
|
||||||
|
t.Fatalf("Failed to generate client CSR: %v", err)
|
||||||
|
}
|
||||||
|
if err := pki.SignCertificateRequest(caKeyPath, caCertPath, clientCSRPath, clientCertPath, 24*time.Hour); err != nil {
|
||||||
|
t.Fatalf("Failed to sign client certificate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create and start server
|
||||||
|
server, err := NewServer("localhost:0", serverCertPath, serverKeyPath, caCertPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add a test handler
|
||||||
|
server.router.HandleFunc("GET", "/test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("test successful"))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Start server in a goroutine
|
||||||
|
go func() {
|
||||||
|
if err := server.Start(); err != nil && err != http.ErrServerClosed {
|
||||||
|
t.Errorf("Server error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for server to start
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Load CA cert
|
||||||
|
caCert, err := os.ReadFile(caCertPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read CA cert: %v", err)
|
||||||
|
}
|
||||||
|
caCertPool := x509.NewCertPool()
|
||||||
|
caCertPool.AppendCertsFromPEM(caCert)
|
||||||
|
|
||||||
|
// Load client cert
|
||||||
|
clientCert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load client cert: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create HTTP client with mTLS
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
RootCAs: caCertPool,
|
||||||
|
Certificates: []tls.Certificate{clientCert},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with valid client cert
|
||||||
|
resp, err := client.Get("https://localhost:8443/test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(string(body), "test successful") {
|
||||||
|
t.Errorf("Unexpected response: %s", body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with no client cert (should fail)
|
||||||
|
clientWithoutCert := &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
RootCAs: caCertPool,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = clientWithoutCert.Get("https://localhost:8443/test")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Request without client cert should fail")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown server
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
server.Stop(ctx)
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user