diff --git a/cmd/kat-agent/main.go b/cmd/kat-agent/main.go index d227d53..c94acbc 100644 --- a/cmd/kat-agent/main.go +++ b/cmd/kat-agent/main.go @@ -248,124 +248,9 @@ func runInit(cmd *cobra.Command, args []string) { log.Printf("Failed to create API server: %v", err) } else { // Register the join handler - apiServer.RegisterJoinHandler(func(w http.ResponseWriter, r *http.Request) { - log.Printf("Received join request from %s", r.RemoteAddr) - - // In Phase 2, we're not requiring client certificates yet - log.Printf("Processing join request without client certificate verification (Phase 2)") - - // Read request body - var joinReq cli.JoinRequest - if err := json.NewDecoder(r.Body).Decode(&joinReq); err != nil { - log.Printf("Error decoding join request: %v", err) - http.Error(w, "Invalid request format", http.StatusBadRequest) - return - } - - // Validate request - if joinReq.NodeName == "" || joinReq.AdvertiseAddr == "" || joinReq.CSRData == "" { - log.Printf("Invalid join request: missing required fields") - http.Error(w, "Missing required fields", http.StatusBadRequest) - return - } - - log.Printf("Processing join request for node: %s, advertise address: %s", - joinReq.NodeName, joinReq.AdvertiseAddr) - - // Decode CSR data - csrData, err := base64.StdEncoding.DecodeString(joinReq.CSRData) - if err != nil { - log.Printf("Error decoding CSR data: %v", err) - http.Error(w, "Invalid CSR data", http.StatusBadRequest) - return - } - - // Create a temporary file for the CSR - tempCSRFile, err := os.CreateTemp("", "node-csr-*.pem") - if err != nil { - log.Printf("Error creating temp CSR file: %v", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) - return - } - defer os.Remove(tempCSRFile.Name()) - - // Write CSR data to temp file - if _, err := tempCSRFile.Write(csrData); err != nil { - log.Printf("Error writing CSR data to temp file: %v", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) - return - } - tempCSRFile.Close() - - // Create a temp file for the signed certificate - tempCertFile, err := os.CreateTemp("", "node-cert-*.pem") - if err != nil { - log.Printf("Error creating temp cert file: %v", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) - return - } - defer os.Remove(tempCertFile.Name()) - tempCertFile.Close() - - // Sign the CSR - if err := pki.SignCertificateRequest( - filepath.Join(pkiDir, "ca.key"), - filepath.Join(pkiDir, "ca.crt"), - tempCSRFile.Name(), - tempCertFile.Name(), - 365*24*time.Hour, // 1 year validity - ); err != nil { - log.Printf("Error signing CSR: %v", err) - http.Error(w, "Failed to sign certificate", http.StatusInternalServerError) - return - } - - // Read the signed certificate - signedCert, err := os.ReadFile(tempCertFile.Name()) - if err != nil { - log.Printf("Error reading signed certificate: %v", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) - return - } - - // Read the CA certificate - caCert, err := os.ReadFile(filepath.Join(pkiDir, "ca.crt")) - if err != nil { - log.Printf("Error reading CA certificate: %v", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) - return - } - - // Generate a unique node UID - nodeUID := uuid.New().String() - - // Store node registration in etcd (placeholder for now) - // In a future phase, we'll implement proper node registration with subnet assignment - - // Create response - joinResp := cli.JoinResponse{ - NodeName: joinReq.NodeName, - NodeUID: nodeUID, - SignedCertificate: base64.StdEncoding.EncodeToString(signedCert), - CACertificate: base64.StdEncoding.EncodeToString(caCert), - AssignedSubnet: "10.100.0.0/24", // Placeholder, will be properly implemented in network phase - } - - // If etcd peer was requested, add join instructions (placeholder) - if etcdPeer { - joinResp.EtcdJoinInstructions = "Etcd peer join not implemented in this phase" - } - - // Send response - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(joinResp); err != nil { - log.Printf("Error encoding join response: %v", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) - return - } - - log.Printf("Successfully processed join request for node: %s", joinReq.NodeName) - }) + joinHandler := api.NewJoinHandler(etcdStore, caKeyPath, caCertPath) + apiServer.RegisterJoinHandler(joinHandler) + log.Printf("Registered join handler with CA key: %s, CA cert: %s", caKeyPath, caCertPath) // Start the server in a goroutine go func() { diff --git a/internal/api/join_handler.go b/internal/api/join_handler.go index 30808f2..804331e 100644 --- a/internal/api/join_handler.go +++ b/internal/api/join_handler.go @@ -1,9 +1,11 @@ package api import ( + "encoding/base64" "encoding/json" "fmt" "io" + "log" "net/http" "os" "path/filepath" @@ -17,27 +19,31 @@ import ( // JoinRequest represents the data sent by an agent when joining type JoinRequest struct { - CSR []byte `json:"csr"` + CSRData string `json:"csrData"` // base64 encoded CSR AdvertiseAddr string `json:"advertiseAddr"` NodeName string `json:"nodeName,omitempty"` // Optional, leader can generate - WireguardPubKey string `json:"wireguardPubKey"` // Placeholder for now + WireGuardPubKey string `json:"wireguardPubKey"` // Placeholder for now } // JoinResponse represents the data sent back to the agent type JoinResponse struct { - NodeName string `json:"nodeName"` - NodeUID string `json:"nodeUID"` - SignedCert []byte `json:"signedCert"` - CACert []byte `json:"caCert"` - JoinTimestamp int64 `json:"joinTimestamp"` + NodeName string `json:"nodeName"` + NodeUID string `json:"nodeUID"` + SignedCertificate string `json:"signedCertificate"` // base64 encoded certificate + CACertificate string `json:"caCertificate"` // base64 encoded CA certificate + AssignedSubnet string `json:"assignedSubnet"` // Placeholder for now + EtcdJoinInstructions string `json:"etcdJoinInstructions,omitempty"` } // NewJoinHandler creates a handler for agent join requests func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + log.Printf("Received join request from %s", r.RemoteAddr) + // Read and parse the request body body, err := io.ReadAll(r.Body) if err != nil { + log.Printf("Failed to read request body: %v", err) http.Error(w, fmt.Sprintf("Failed to read request body: %v", err), http.StatusBadRequest) return } @@ -45,16 +51,19 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h var joinReq JoinRequest if err := json.Unmarshal(body, &joinReq); err != nil { + log.Printf("Failed to parse request: %v", err) http.Error(w, fmt.Sprintf("Failed to parse request: %v", err), http.StatusBadRequest) return } // Validate request - if len(joinReq.CSR) == 0 { - http.Error(w, "Missing CSR", http.StatusBadRequest) + if joinReq.CSRData == "" { + log.Printf("Missing CSR data") + http.Error(w, "Missing CSR data", http.StatusBadRequest) return } if joinReq.AdvertiseAddr == "" { + log.Printf("Missing advertise address") http.Error(w, "Missing advertise address", http.StatusBadRequest) return } @@ -63,16 +72,26 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h nodeName := joinReq.NodeName if nodeName == "" { nodeName = fmt.Sprintf("node-%s", uuid.New().String()[:8]) + log.Printf("Generated node name: %s", nodeName) } // Generate a unique node ID nodeUID := uuid.New().String() + log.Printf("Generated node UID: %s", nodeUID) + + // Decode CSR data + csrData, err := base64.StdEncoding.DecodeString(joinReq.CSRData) + if err != nil { + log.Printf("Failed to decode CSR data: %v", err) + http.Error(w, fmt.Sprintf("Failed to decode CSR data: %v", err), http.StatusBadRequest) + return + } - // Sign the CSR // Create a temporary file for the CSR tempDir := os.TempDir() csrPath := filepath.Join(tempDir, fmt.Sprintf("%s.csr", nodeUID)) - if err := os.WriteFile(csrPath, joinReq.CSR, 0600); err != nil { + if err := os.WriteFile(csrPath, csrData, 0600); err != nil { + log.Printf("Failed to save CSR: %v", err) http.Error(w, fmt.Sprintf("Failed to save CSR: %v", err), http.StatusInternalServerError) return } @@ -81,6 +100,7 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h // Sign the CSR certPath := filepath.Join(tempDir, fmt.Sprintf("%s.crt", nodeUID)) if err := pki.SignCertificateRequest(caKeyPath, caCertPath, csrPath, certPath, 365*24*time.Hour); err != nil { + log.Printf("Failed to sign CSR: %v", err) http.Error(w, fmt.Sprintf("Failed to sign CSR: %v", err), http.StatusInternalServerError) return } @@ -89,6 +109,7 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h // Read the signed certificate signedCert, err := os.ReadFile(certPath) if err != nil { + log.Printf("Failed to read signed certificate: %v", err) http.Error(w, fmt.Sprintf("Failed to read signed certificate: %v", err), http.StatusInternalServerError) return } @@ -96,6 +117,7 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h // Read the CA certificate caCert, err := os.ReadFile(caCertPath) if err != nil { + log.Printf("Failed to read CA certificate: %v", err) http.Error(w, fmt.Sprintf("Failed to read CA certificate: %v", err), http.StatusInternalServerError) return } @@ -105,31 +127,36 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h nodeReg := map[string]interface{}{ "uid": nodeUID, "advertiseAddr": joinReq.AdvertiseAddr, - "wireguardPubKey": joinReq.WireguardPubKey, + "wireguardPubKey": joinReq.WireGuardPubKey, "joinTimestamp": time.Now().Unix(), } nodeRegData, err := json.Marshal(nodeReg) if err != nil { + log.Printf("Failed to marshal node registration: %v", err) http.Error(w, fmt.Sprintf("Failed to marshal node registration: %v", err), http.StatusInternalServerError) return } + log.Printf("Storing node registration in etcd at key: %s", nodeRegKey) if err := stateStore.Put(r.Context(), nodeRegKey, nodeRegData); err != nil { + log.Printf("Failed to store node registration: %v", err) http.Error(w, fmt.Sprintf("Failed to store node registration: %v", err), http.StatusInternalServerError) return } + log.Printf("Successfully stored node registration in etcd") // Prepare and send response joinResp := JoinResponse{ - NodeName: nodeName, - NodeUID: nodeUID, - SignedCert: signedCert, - CACert: caCert, - JoinTimestamp: time.Now().Unix(), + NodeName: nodeName, + NodeUID: nodeUID, + SignedCertificate: base64.StdEncoding.EncodeToString(signedCert), + CACertificate: base64.StdEncoding.EncodeToString(caCert), + AssignedSubnet: "10.100.0.0/24", // Placeholder for now, will be implemented in network phase } respData, err := json.Marshal(joinResp) if err != nil { + log.Printf("Failed to marshal response: %v", err) http.Error(w, fmt.Sprintf("Failed to marshal response: %v", err), http.StatusInternalServerError) return } @@ -137,5 +164,6 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(respData) + log.Printf("Successfully processed join request for node: %s", nodeName) } } diff --git a/internal/api/join_handler_test.go b/internal/api/join_handler_test.go new file mode 100644 index 0000000..985ff44 --- /dev/null +++ b/internal/api/join_handler_test.go @@ -0,0 +1,168 @@ +package api + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "git.dws.rip/dubey/kat/internal/pki" + "git.dws.rip/dubey/kat/internal/store" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// MockStateStore for testing +type MockStateStore struct { + mock.Mock +} + +func (m *MockStateStore) Put(ctx context.Context, key string, value []byte) error { + args := m.Called(ctx, key, value) + return args.Error(0) +} + +func (m *MockStateStore) Get(ctx context.Context, key string) (*store.KV, error) { + args := m.Called(ctx, key) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*store.KV), args.Error(1) +} + +func (m *MockStateStore) Delete(ctx context.Context, key string) error { + args := m.Called(ctx, key) + return args.Error(0) +} + +func (m *MockStateStore) List(ctx context.Context, prefix string) ([]store.KV, error) { + args := m.Called(ctx, prefix) + return args.Get(0).([]store.KV), args.Error(1) +} + +func (m *MockStateStore) Watch(ctx context.Context, keyOrPrefix string, startRevision int64) (<-chan store.WatchEvent, error) { + args := m.Called(ctx, keyOrPrefix, startRevision) + return args.Get(0).(chan store.WatchEvent), args.Error(1) +} + +func (m *MockStateStore) Close() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockStateStore) Campaign(ctx context.Context, leaderID string, leaseTTLSeconds int64) (context.Context, error) { + args := m.Called(ctx, leaderID, leaseTTLSeconds) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(context.Context), args.Error(1) +} + +func (m *MockStateStore) Resign(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +func (m *MockStateStore) GetLeader(ctx context.Context) (string, error) { + args := m.Called(ctx) + return args.String(0), args.Error(1) +} + +func (m *MockStateStore) DoTransaction(ctx context.Context, checks []store.Compare, onSuccess []store.Op, onFailure []store.Op) (bool, error) { + args := m.Called(ctx, checks, onSuccess, onFailure) + return args.Bool(0), args.Error(1) +} + +func TestJoinHandler(t *testing.T) { + // Create temporary directory for test PKI files + tempDir, err := os.MkdirTemp("", "kat-test-pki-*") + if err != nil { + t.Fatalf("Failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) + + // Generate CA for testing + caKeyPath := filepath.Join(tempDir, "ca.key") + caCertPath := filepath.Join(tempDir, "ca.crt") + err = pki.GenerateCA(tempDir, caKeyPath, caCertPath) + if err != nil { + t.Fatalf("Failed to generate test CA: %v", err) + } + + // Generate a test CSR + nodeKeyPath := filepath.Join(tempDir, "node.key") + nodeCSRPath := filepath.Join(tempDir, "node.csr") + err = pki.GenerateCertificateRequest("test-node", nodeKeyPath, nodeCSRPath) + if err != nil { + t.Fatalf("Failed to generate test CSR: %v", err) + } + + // Read the CSR file + csrData, err := os.ReadFile(nodeCSRPath) + if err != nil { + t.Fatalf("Failed to read CSR file: %v", err) + } + + // Create mock state store + mockStore := new(MockStateStore) + mockStore.On("Put", mock.Anything, mock.MatchedBy(func(key string) bool { + return key == "/kat/nodes/registration/test-node" + }), mock.Anything).Return(nil) + + // Create join handler + handler := NewJoinHandler(mockStore, caKeyPath, caCertPath) + + // Create test request + joinReq := JoinRequest{ + NodeName: "test-node", + AdvertiseAddr: "192.168.1.100", + CSRData: base64.StdEncoding.EncodeToString(csrData), + WireGuardPubKey: "test-pubkey", + } + reqBody, err := json.Marshal(joinReq) + if err != nil { + t.Fatalf("Failed to marshal join request: %v", err) + } + + // Create HTTP request + req := httptest.NewRequest("POST", "/internal/v1alpha1/join", bytes.NewBuffer(reqBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + // Call handler + handler(w, req) + + // Check response + resp := w.Result() + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // Read response body + respBody, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + // Parse response + var joinResp JoinResponse + err = json.Unmarshal(respBody, &joinResp) + if err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + // Verify response fields + assert.Equal(t, "test-node", joinResp.NodeName) + assert.NotEmpty(t, joinResp.NodeUID) + assert.NotEmpty(t, joinResp.SignedCertificate) + assert.NotEmpty(t, joinResp.CACertificate) + assert.Equal(t, "10.100.0.0/24", joinResp.AssignedSubnet) // Placeholder value + + // Verify mock was called + mockStore.AssertExpectations(t) +} diff --git a/internal/api/server.go b/internal/api/server.go index 18ce1d7..caba510 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -136,6 +136,7 @@ func (s *Server) Stop(ctx context.Context) error { // RegisterJoinHandler registers the handler for agent join requests func (s *Server) RegisterJoinHandler(handler http.HandlerFunc) { s.router.HandleFunc("POST", "/internal/v1alpha1/join", handler) + log.Printf("Registered join handler at /internal/v1alpha1/join") } // RegisterNodeStatusHandler registers the handler for node status updates