diff --git a/cmd/kat-agent/main.go b/cmd/kat-agent/main.go index 2a99992..146932c 100644 --- a/cmd/kat-agent/main.go +++ b/cmd/kat-agent/main.go @@ -265,6 +265,10 @@ func runInit(cmd *cobra.Command, args []string) { joinHandler := api.NewJoinHandler(etcdStore, caKeyPath, caCertPath) apiServer.RegisterJoinHandler(joinHandler) log.Printf("Registered join handler with CA key: %s, CA cert: %s", caKeyPath, caCertPath) + + // Register the node status handler + nodeStatusHandler := api.NewNodeStatusHandler(etcdStore) + apiServer.RegisterNodeStatusHandler(nodeStatusHandler) // Start the server in a goroutine go func() { @@ -333,7 +337,8 @@ func runJoin(cmd *cobra.Command, args []string) { pkiDir := filepath.Join(os.Getenv("HOME"), ".kat-agent", nodeName, "pki") // Join the cluster - if err := cli.JoinCluster(leaderAPI, advertiseAddr, nodeName, leaderCACert, pkiDir); err != nil { + joinResp, err := cli.JoinCluster(leaderAPI, advertiseAddr, nodeName, leaderCACert, pkiDir) + if err != nil { log.Fatalf("Failed to join cluster: %v", err) } @@ -343,20 +348,36 @@ func runJoin(cmd *cobra.Command, args []string) { ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer stop() - // Stay up in an idle loop until interrupted - log.Printf("Node %s is now running. Press Ctrl+C to exit.", nodeName) - ticker := time.NewTicker(30 * time.Second) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - log.Println("Received shutdown signal. Exiting...") - return - case <-ticker.C: - log.Printf("Node %s is still running...", nodeName) - } + // Create and start the agent with heartbeating + agent, err := agent.NewAgent( + joinResp.NodeName, + joinResp.NodeUID, + leaderAPI, + advertiseAddr, + pkiDir, + 15, // Default heartbeat interval in seconds + ) + if err != nil { + log.Fatalf("Failed to create agent: %v", err) } + + // Setup mTLS client + if err := agent.SetupMTLSClient(); err != nil { + log.Fatalf("Failed to setup mTLS client: %v", err) + } + + // Start heartbeating + if err := agent.StartHeartbeat(ctx); err != nil { + log.Fatalf("Failed to start heartbeat: %v", err) + } + + log.Printf("Node %s is now running with heartbeat. Press Ctrl+C to exit.", nodeName) + + // Wait for shutdown signal + <-ctx.Done() + log.Println("Received shutdown signal. Stopping heartbeat...") + agent.StopHeartbeat() + log.Println("Exiting...") } func runVerify(cmd *cobra.Command, args []string) { diff --git a/internal/agent/agent.go b/internal/agent/agent.go new file mode 100644 index 0000000..21239a1 --- /dev/null +++ b/internal/agent/agent.go @@ -0,0 +1,248 @@ +package agent + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "fmt" + "log" + "net/http" + "os" + "runtime" + "time" +) + +// NodeStatus represents the data sent in a heartbeat +type NodeStatus struct { + NodeName string `json:"nodeName"` + NodeUID string `json:"nodeUID"` + Timestamp time.Time `json:"timestamp"` + Resources Resources `json:"resources"` + Workloads []WorkloadStatus `json:"workloadInstances,omitempty"` + NetworkInfo NetworkInfo `json:"overlayNetwork"` +} + +// Resources represents the node's resource capacity and usage +type Resources struct { + Capacity ResourceMetrics `json:"capacity"` + Allocatable ResourceMetrics `json:"allocatable"` +} + +// ResourceMetrics contains CPU and memory metrics +type ResourceMetrics struct { + CPU string `json:"cpu"` // e.g., "2000m" + Memory string `json:"memory"` // e.g., "4096Mi" +} + +// WorkloadStatus represents the status of a workload instance +type WorkloadStatus struct { + WorkloadName string `json:"workloadName"` + Namespace string `json:"namespace"` + InstanceID string `json:"instanceID"` + ContainerID string `json:"containerID"` + ImageID string `json:"imageID"` + State string `json:"state"` // "running", "exited", "paused", "unknown" + ExitCode int `json:"exitCode"` + HealthStatus string `json:"healthStatus"` // "healthy", "unhealthy", "pending_check" + Restarts int `json:"restarts"` +} + +// NetworkInfo contains information about the node's overlay network +type NetworkInfo struct { + Status string `json:"status"` // "connected", "disconnected", "initializing" + LastPeerSync string `json:"lastPeerSync"` // timestamp +} + +// Agent represents a KAT agent node +type Agent struct { + NodeName string + NodeUID string + LeaderAPI string + AdvertiseAddr string + PKIDir string + + // mTLS client for leader communication + client *http.Client + + // Heartbeat configuration + heartbeatInterval time.Duration + stopHeartbeat chan struct{} +} + +// NewAgent creates a new Agent instance +func NewAgent(nodeName, nodeUID, leaderAPI, advertiseAddr, pkiDir string, heartbeatIntervalSeconds int) (*Agent, error) { + if heartbeatIntervalSeconds <= 0 { + heartbeatIntervalSeconds = 15 // Default to 15 seconds + } + + return &Agent{ + NodeName: nodeName, + NodeUID: nodeUID, + LeaderAPI: leaderAPI, + AdvertiseAddr: advertiseAddr, + PKIDir: pkiDir, + heartbeatInterval: time.Duration(heartbeatIntervalSeconds) * time.Second, + stopHeartbeat: make(chan struct{}), + }, nil +} + +// SetupMTLSClient configures the HTTP client with mTLS using the agent's certificates +func (a *Agent) SetupMTLSClient() error { + // Load client certificate and key + cert, err := tls.LoadX509KeyPair( + fmt.Sprintf("%s/node.crt", a.PKIDir), + fmt.Sprintf("%s/node.key", a.PKIDir), + ) + if err != nil { + return fmt.Errorf("failed to load client certificate and key: %w", err) + } + + // Load CA certificate + caCert, err := os.ReadFile(fmt.Sprintf("%s/ca.crt", a.PKIDir)) + if err != nil { + return fmt.Errorf("failed to read CA certificate: %w", err) + } + + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM(caCert) { + return fmt.Errorf("failed to append CA certificate to pool") + } + + // Create TLS configuration + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: caCertPool, + MinVersion: tls.VersionTLS12, + } + + // Create HTTP client with TLS configuration + a.client = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tlsConfig, + }, + Timeout: 10 * time.Second, + } + + return nil +} + +// StartHeartbeat begins sending periodic heartbeats to the leader +func (a *Agent) StartHeartbeat(ctx context.Context) error { + if a.client == nil { + if err := a.SetupMTLSClient(); err != nil { + return fmt.Errorf("failed to setup mTLS client: %w", err) + } + } + + log.Printf("Starting heartbeat to leader at %s every %v", a.LeaderAPI, a.heartbeatInterval) + + ticker := time.NewTicker(a.heartbeatInterval) + defer ticker.Stop() + + // Send initial heartbeat immediately + if err := a.sendHeartbeat(); err != nil { + log.Printf("Initial heartbeat failed: %v", err) + } + + go func() { + for { + select { + case <-ticker.C: + if err := a.sendHeartbeat(); err != nil { + log.Printf("Heartbeat failed: %v", err) + } + case <-a.stopHeartbeat: + log.Printf("Heartbeat stopped") + return + case <-ctx.Done(): + log.Printf("Heartbeat context cancelled") + return + } + } + }() + + return nil +} + +// StopHeartbeat stops the heartbeat goroutine +func (a *Agent) StopHeartbeat() { + close(a.stopHeartbeat) +} + +// sendHeartbeat sends a single heartbeat to the leader +func (a *Agent) sendHeartbeat() error { + // Gather node status + status := a.gatherNodeStatus() + + // Marshal to JSON + statusJSON, err := json.Marshal(status) + if err != nil { + return fmt.Errorf("failed to marshal node status: %w", err) + } + + // Construct URL + url := fmt.Sprintf("https://%s/v1alpha1/nodes/%s/status", a.LeaderAPI, a.NodeName) + + // Create request + req, err := http.NewRequest("POST", url, bytes.NewBuffer(statusJSON)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + // Send request + resp, err := a.client.Do(req) + if err != nil { + return fmt.Errorf("failed to send heartbeat: %w", err) + } + defer resp.Body.Close() + + // Check response + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("heartbeat returned non-OK status: %d", resp.StatusCode) + } + + log.Printf("Heartbeat sent successfully to %s", url) + return nil +} + +// gatherNodeStatus collects the current node status +func (a *Agent) gatherNodeStatus() NodeStatus { + // For now, just provide basic information + // In future phases, this will include actual resource usage, workload status, etc. + + // Get basic system info for initial capacity reporting + var m runtime.MemStats + runtime.ReadMemStats(&m) + + // Convert to human-readable format (very simplified for now) + cpuCapacity := fmt.Sprintf("%dm", runtime.NumCPU() * 1000) + memCapacity := fmt.Sprintf("%dMi", m.Sys / (1024 * 1024)) + + // For allocatable, we'll just use 90% of capacity for this phase + cpuAllocatable := fmt.Sprintf("%dm", runtime.NumCPU() * 900) + memAllocatable := fmt.Sprintf("%dMi", (m.Sys / (1024 * 1024)) * 9 / 10) + + return NodeStatus{ + NodeName: a.NodeName, + NodeUID: a.NodeUID, + Timestamp: time.Now(), + Resources: Resources{ + Capacity: ResourceMetrics{ + CPU: cpuCapacity, + Memory: memCapacity, + }, + Allocatable: ResourceMetrics{ + CPU: cpuAllocatable, + Memory: memAllocatable, + }, + }, + NetworkInfo: NetworkInfo{ + Status: "initializing", // Placeholder until network is implemented + LastPeerSync: time.Now().Format(time.RFC3339), + }, + // Workloads will be empty for now + } +} diff --git a/internal/agent/agent_test.go b/internal/agent/agent_test.go new file mode 100644 index 0000000..d8f8094 --- /dev/null +++ b/internal/agent/agent_test.go @@ -0,0 +1,137 @@ +package agent + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "git.dws.rip/dubey/kat/internal/pki" + "github.com/stretchr/testify/assert" +) + +func TestAgentHeartbeat(t *testing.T) { + // Create temporary directory for test PKI files + tempDir, err := os.MkdirTemp("", "kat-test-agent-*") + if err != nil { + t.Fatalf("Failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) + + // Generate CA for testing + pkiDir := filepath.Join(tempDir, "pki") + caKeyPath := filepath.Join(pkiDir, "ca.key") + caCertPath := filepath.Join(pkiDir, "ca.crt") + err = pki.GenerateCA(pkiDir, caKeyPath, caCertPath) + if err != nil { + t.Fatalf("Failed to generate test CA: %v", err) + } + + // Generate node certificate + nodeKeyPath := filepath.Join(pkiDir, "node.key") + nodeCSRPath := filepath.Join(pkiDir, "node.csr") + nodeCertPath := filepath.Join(pkiDir, "node.crt") + err = pki.GenerateCertificateRequest("test-node", nodeKeyPath, nodeCSRPath) + if err != nil { + t.Fatalf("Failed to generate node key and CSR: %v", err) + } + err = pki.SignCertificateRequest(caKeyPath, caCertPath, nodeCSRPath, nodeCertPath, 24*time.Hour) + if err != nil { + t.Fatalf("Failed to sign node CSR: %v", err) + } + + // Create a test server that requires client certificates + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify the request path + if r.URL.Path != "/v1alpha1/nodes/test-node/status" { + t.Errorf("Expected path /v1alpha1/nodes/test-node/status, got %s", r.URL.Path) + http.Error(w, "Invalid path", http.StatusBadRequest) + return + } + + // Verify the request method + if r.Method != "POST" { + t.Errorf("Expected method POST, got %s", r.Method) + http.Error(w, "Invalid method", http.StatusMethodNotAllowed) + return + } + + // Parse the request body + var status NodeStatus + decoder := json.NewDecoder(r.Body) + if err := decoder.Decode(&status); err != nil { + t.Errorf("Failed to decode request body: %v", err) + http.Error(w, "Invalid request body", http.StatusBadRequest) + return + } + + // Verify the node name + if status.NodeName != "test-node" { + t.Errorf("Expected node name test-node, got %s", status.NodeName) + http.Error(w, "Invalid node name", http.StatusBadRequest) + return + } + + // Verify that resources are present + if status.Resources.Capacity.CPU == "" || status.Resources.Capacity.Memory == "" { + t.Errorf("Missing resource capacity information") + http.Error(w, "Missing resource information", http.StatusBadRequest) + return + } + + // Return success + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Configure the server to require client certificates + server.TLS.ClientAuth = tls.RequireAndVerifyClientCert + server.TLS.ClientCAs = x509.NewCertPool() + caCertData, err := os.ReadFile(caCertPath) + if err != nil { + t.Fatalf("Failed to read CA certificate: %v", err) + } + server.TLS.ClientCAs.AppendCertsFromPEM(caCertData) + + // Extract the host:port from the server URL + serverURL := server.URL + hostPort := serverURL[8:] // Remove "https://" prefix + + // Create an agent + agent, err := NewAgent("test-node", "test-uid", hostPort, "192.168.1.100", pkiDir, 1) + if err != nil { + t.Fatalf("Failed to create agent: %v", err) + } + + // Setup mTLS client + err = agent.SetupMTLSClient() + if err != nil { + t.Fatalf("Failed to setup mTLS client: %v", err) + } + + // Create a context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Start heartbeat + err = agent.StartHeartbeat(ctx) + if err != nil { + t.Fatalf("Failed to start heartbeat: %v", err) + } + + // Wait for at least one heartbeat + time.Sleep(2 * time.Second) + + // Stop heartbeat + agent.StopHeartbeat() + + // Test passed if we got here without errors + fmt.Println("Agent heartbeat test passed") +} diff --git a/internal/api/node_status_handler.go b/internal/api/node_status_handler.go new file mode 100644 index 0000000..38a6b7a --- /dev/null +++ b/internal/api/node_status_handler.go @@ -0,0 +1,108 @@ +package api + +import ( + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "strings" + "time" + + "git.dws.rip/dubey/kat/internal/store" +) + +// NodeStatusRequest represents the data sent by an agent in a heartbeat +type NodeStatusRequest struct { + NodeName string `json:"nodeName"` + NodeUID string `json:"nodeUID"` + Timestamp time.Time `json:"timestamp"` + Resources struct { + Capacity map[string]string `json:"capacity"` + Allocatable map[string]string `json:"allocatable"` + } `json:"resources"` + WorkloadInstances []struct { + WorkloadName string `json:"workloadName"` + Namespace string `json:"namespace"` + InstanceID string `json:"instanceID"` + ContainerID string `json:"containerID"` + ImageID string `json:"imageID"` + State string `json:"state"` + ExitCode int `json:"exitCode"` + HealthStatus string `json:"healthStatus"` + Restarts int `json:"restarts"` + } `json:"workloadInstances,omitempty"` + OverlayNetwork struct { + Status string `json:"status"` + LastPeerSync string `json:"lastPeerSync"` + } `json:"overlayNetwork"` +} + +// NewNodeStatusHandler creates a handler for node status updates +func NewNodeStatusHandler(stateStore store.StateStore) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Extract node name from URL path + pathParts := strings.Split(r.URL.Path, "/") + if len(pathParts) < 4 { + http.Error(w, "Invalid URL path", http.StatusBadRequest) + return + } + nodeName := pathParts[len(pathParts)-2] // /v1alpha1/nodes/{nodeName}/status + + log.Printf("Received status update from node: %s", nodeName) + + // Read and parse the request body + body, err := io.ReadAll(r.Body) + if err != nil { + log.Printf("Failed to read request body: %v", err) + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + defer r.Body.Close() + + var statusReq NodeStatusRequest + if err := json.Unmarshal(body, &statusReq); err != nil { + log.Printf("Failed to parse status request: %v", err) + http.Error(w, "Failed to parse status request", http.StatusBadRequest) + return + } + + // Validate that the node name in the URL matches the one in the request + if statusReq.NodeName != nodeName { + log.Printf("Node name mismatch: %s (URL) vs %s (body)", nodeName, statusReq.NodeName) + http.Error(w, "Node name mismatch", http.StatusBadRequest) + return + } + + // Store the node status in etcd + nodeStatusKey := fmt.Sprintf("/kat/nodes/status/%s", nodeName) + nodeStatus := map[string]interface{}{ + "lastHeartbeat": time.Now().Unix(), + "status": "Ready", + "resources": statusReq.Resources, + "network": statusReq.OverlayNetwork, + } + + // Add workload instances if present + if len(statusReq.WorkloadInstances) > 0 { + nodeStatus["workloadInstances"] = statusReq.WorkloadInstances + } + + nodeStatusData, err := json.Marshal(nodeStatus) + if err != nil { + log.Printf("Failed to marshal node status: %v", err) + http.Error(w, "Failed to marshal node status", http.StatusInternalServerError) + return + } + + log.Printf("Storing node status in etcd at key: %s", nodeStatusKey) + if err := stateStore.Put(r.Context(), nodeStatusKey, nodeStatusData); err != nil { + log.Printf("Failed to store node status: %v", err) + http.Error(w, "Failed to store node status", http.StatusInternalServerError) + return + } + + log.Printf("Successfully stored status update for node: %s", nodeName) + w.WriteHeader(http.StatusOK) + } +} diff --git a/internal/api/node_status_handler_test.go b/internal/api/node_status_handler_test.go new file mode 100644 index 0000000..875fb1d --- /dev/null +++ b/internal/api/node_status_handler_test.go @@ -0,0 +1,108 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "git.dws.rip/dubey/kat/internal/store" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestNodeStatusHandler(t *testing.T) { + // Create mock state store + mockStore := new(MockStateStore) + mockStore.On("Put", mock.Anything, mock.MatchedBy(func(key string) bool { + return key == "/kat/nodes/status/test-node" + }), mock.Anything).Return(nil) + + // Create node status handler + handler := NewNodeStatusHandler(mockStore) + + // Create test request + statusReq := NodeStatusRequest{ + NodeName: "test-node", + NodeUID: "test-uid", + Timestamp: time.Now(), + Resources: struct { + Capacity map[string]string `json:"capacity"` + Allocatable map[string]string `json:"allocatable"` + }{ + Capacity: map[string]string{ + "cpu": "2000m", + "memory": "4096Mi", + }, + Allocatable: map[string]string{ + "cpu": "1800m", + "memory": "3800Mi", + }, + }, + OverlayNetwork: struct { + Status string `json:"status"` + LastPeerSync string `json:"lastPeerSync"` + }{ + Status: "connected", + LastPeerSync: time.Now().Format(time.RFC3339), + }, + } + reqBody, err := json.Marshal(statusReq) + if err != nil { + t.Fatalf("Failed to marshal status request: %v", err) + } + + // Create HTTP request + req := httptest.NewRequest("POST", "/v1alpha1/nodes/test-node/status", bytes.NewBuffer(reqBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + // Call handler + handler(w, req) + + // Check response + resp := w.Result() + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // Verify mock was called + mockStore.AssertExpectations(t) +} + +func TestNodeStatusHandlerNameMismatch(t *testing.T) { + // Create mock state store + mockStore := new(MockStateStore) + + // Create node status handler + handler := NewNodeStatusHandler(mockStore) + + // Create test request with mismatched node name + statusReq := NodeStatusRequest{ + NodeName: "wrong-node", // This doesn't match the URL path + NodeUID: "test-uid", + Timestamp: time.Now(), + } + reqBody, err := json.Marshal(statusReq) + if err != nil { + t.Fatalf("Failed to marshal status request: %v", err) + } + + // Create HTTP request + req := httptest.NewRequest("POST", "/v1alpha1/nodes/test-node/status", bytes.NewBuffer(reqBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + // Call handler + handler(w, req) + + // Check response - should be bad request due to name mismatch + resp := w.Result() + defer resp.Body.Close() + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + + // Verify mock was not called + mockStore.AssertNotCalled(t, "Put", mock.Anything, mock.Anything, mock.Anything) +} diff --git a/internal/api/server.go b/internal/api/server.go index caba510..0110fb8 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -142,4 +142,5 @@ func (s *Server) RegisterJoinHandler(handler http.HandlerFunc) { // RegisterNodeStatusHandler registers the handler for node status updates func (s *Server) RegisterNodeStatusHandler(handler http.HandlerFunc) { s.router.HandleFunc("POST", "/v1alpha1/nodes/{nodeName}/status", handler) + log.Printf("Registered node status handler at /v1alpha1/nodes/{nodeName}/status") } diff --git a/internal/cli/join.go b/internal/cli/join.go index 6834321..e8b8901 100644 --- a/internal/cli/join.go +++ b/internal/cli/join.go @@ -36,7 +36,7 @@ type JoinResponse struct { } // JoinCluster sends a join request to the leader and processes the response -func JoinCluster(leaderAPI, advertiseAddr, nodeName, leaderCACert string, pkiDir string) error { +func JoinCluster(leaderAPI, advertiseAddr, nodeName, leaderCACert string, pkiDir string) (*JoinResponse, error) { // Create PKI directory if it doesn't exist if err := os.MkdirAll(pkiDir, 0700); err != nil { return fmt.Errorf("failed to create PKI directory: %w", err) @@ -164,5 +164,5 @@ func JoinCluster(leaderAPI, advertiseAddr, nodeName, leaderCACert string, pkiDir log.Printf("Etcd join instructions: %s", joinResp.EtcdJoinInstructions) } - return nil + return &joinResp, nil }