feat: Implement agent heartbeat with mTLS and node status tracking
This commit is contained in:
		| @ -266,6 +266,10 @@ func runInit(cmd *cobra.Command, args []string) { | ||||
| 				apiServer.RegisterJoinHandler(joinHandler) | ||||
| 				log.Printf("Registered join handler with CA key: %s, CA cert: %s", caKeyPath, caCertPath) | ||||
| 		 | ||||
| 				// Register the node status handler | ||||
| 				nodeStatusHandler := api.NewNodeStatusHandler(etcdStore) | ||||
| 				apiServer.RegisterNodeStatusHandler(nodeStatusHandler) | ||||
|  | ||||
| 				// Start the server in a goroutine | ||||
| 				go func() { | ||||
| 					if err := apiServer.Start(); err != nil && err != http.ErrServerClosed { | ||||
| @ -333,7 +337,8 @@ func runJoin(cmd *cobra.Command, args []string) { | ||||
| 	pkiDir := filepath.Join(os.Getenv("HOME"), ".kat-agent", nodeName, "pki") | ||||
|  | ||||
| 	// Join the cluster | ||||
| 	if err := cli.JoinCluster(leaderAPI, advertiseAddr, nodeName, leaderCACert, pkiDir); err != nil { | ||||
| 	joinResp, err := cli.JoinCluster(leaderAPI, advertiseAddr, nodeName, leaderCACert, pkiDir) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to join cluster: %v", err) | ||||
| 	} | ||||
|  | ||||
| @ -343,20 +348,36 @@ func runJoin(cmd *cobra.Command, args []string) { | ||||
| 	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) | ||||
| 	defer stop() | ||||
| 	 | ||||
| 	// Stay up in an idle loop until interrupted | ||||
| 	log.Printf("Node %s is now running. Press Ctrl+C to exit.", nodeName) | ||||
| 	ticker := time.NewTicker(30 * time.Second) | ||||
| 	defer ticker.Stop() | ||||
| 	// Create and start the agent with heartbeating | ||||
| 	agent, err := agent.NewAgent( | ||||
| 		joinResp.NodeName, | ||||
| 		joinResp.NodeUID, | ||||
| 		leaderAPI, | ||||
| 		advertiseAddr, | ||||
| 		pkiDir, | ||||
| 		15, // Default heartbeat interval in seconds | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to create agent: %v", err) | ||||
| 	} | ||||
| 	 | ||||
| 	for { | ||||
| 		select { | ||||
| 		case <-ctx.Done(): | ||||
| 			log.Println("Received shutdown signal. Exiting...") | ||||
| 			return | ||||
| 		case <-ticker.C: | ||||
| 			log.Printf("Node %s is still running...", nodeName) | ||||
| 	// Setup mTLS client | ||||
| 	if err := agent.SetupMTLSClient(); err != nil { | ||||
| 		log.Fatalf("Failed to setup mTLS client: %v", err) | ||||
| 	} | ||||
| 	 | ||||
| 	// Start heartbeating | ||||
| 	if err := agent.StartHeartbeat(ctx); err != nil { | ||||
| 		log.Fatalf("Failed to start heartbeat: %v", err) | ||||
| 	} | ||||
| 	 | ||||
| 	log.Printf("Node %s is now running with heartbeat. Press Ctrl+C to exit.", nodeName) | ||||
| 	 | ||||
| 	// Wait for shutdown signal | ||||
| 	<-ctx.Done() | ||||
| 	log.Println("Received shutdown signal. Stopping heartbeat...") | ||||
| 	agent.StopHeartbeat() | ||||
| 	log.Println("Exiting...") | ||||
| } | ||||
|  | ||||
| func runVerify(cmd *cobra.Command, args []string) { | ||||
|  | ||||
							
								
								
									
										248
									
								
								internal/agent/agent.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										248
									
								
								internal/agent/agent.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,248 @@ | ||||
| package agent | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"crypto/tls" | ||||
| 	"crypto/x509" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"os" | ||||
| 	"runtime" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // NodeStatus represents the data sent in a heartbeat | ||||
| type NodeStatus struct { | ||||
| 	NodeName    string    `json:"nodeName"` | ||||
| 	NodeUID     string    `json:"nodeUID"` | ||||
| 	Timestamp   time.Time `json:"timestamp"` | ||||
| 	Resources   Resources `json:"resources"` | ||||
| 	Workloads   []WorkloadStatus `json:"workloadInstances,omitempty"` | ||||
| 	NetworkInfo NetworkInfo `json:"overlayNetwork"` | ||||
| } | ||||
|  | ||||
| // Resources represents the node's resource capacity and usage | ||||
| type Resources struct { | ||||
| 	Capacity    ResourceMetrics `json:"capacity"` | ||||
| 	Allocatable ResourceMetrics `json:"allocatable"` | ||||
| } | ||||
|  | ||||
| // ResourceMetrics contains CPU and memory metrics | ||||
| type ResourceMetrics struct { | ||||
| 	CPU    string `json:"cpu"`    // e.g., "2000m" | ||||
| 	Memory string `json:"memory"` // e.g., "4096Mi" | ||||
| } | ||||
|  | ||||
| // WorkloadStatus represents the status of a workload instance | ||||
| type WorkloadStatus struct { | ||||
| 	WorkloadName  string `json:"workloadName"` | ||||
| 	Namespace     string `json:"namespace"` | ||||
| 	InstanceID    string `json:"instanceID"` | ||||
| 	ContainerID   string `json:"containerID"` | ||||
| 	ImageID       string `json:"imageID"` | ||||
| 	State         string `json:"state"`         // "running", "exited", "paused", "unknown" | ||||
| 	ExitCode      int    `json:"exitCode"` | ||||
| 	HealthStatus  string `json:"healthStatus"`  // "healthy", "unhealthy", "pending_check" | ||||
| 	Restarts      int    `json:"restarts"` | ||||
| } | ||||
|  | ||||
| // NetworkInfo contains information about the node's overlay network | ||||
| type NetworkInfo struct { | ||||
| 	Status       string `json:"status"`       // "connected", "disconnected", "initializing" | ||||
| 	LastPeerSync string `json:"lastPeerSync"` // timestamp | ||||
| } | ||||
|  | ||||
| // Agent represents a KAT agent node | ||||
| type Agent struct { | ||||
| 	NodeName      string | ||||
| 	NodeUID       string | ||||
| 	LeaderAPI     string | ||||
| 	AdvertiseAddr string | ||||
| 	PKIDir        string | ||||
| 	 | ||||
| 	// mTLS client for leader communication | ||||
| 	client        *http.Client | ||||
| 	 | ||||
| 	// Heartbeat configuration | ||||
| 	heartbeatInterval time.Duration | ||||
| 	stopHeartbeat     chan struct{} | ||||
| } | ||||
|  | ||||
| // NewAgent creates a new Agent instance | ||||
| func NewAgent(nodeName, nodeUID, leaderAPI, advertiseAddr, pkiDir string, heartbeatIntervalSeconds int) (*Agent, error) { | ||||
| 	if heartbeatIntervalSeconds <= 0 { | ||||
| 		heartbeatIntervalSeconds = 15 // Default to 15 seconds | ||||
| 	} | ||||
| 	 | ||||
| 	return &Agent{ | ||||
| 		NodeName:          nodeName, | ||||
| 		NodeUID:           nodeUID, | ||||
| 		LeaderAPI:         leaderAPI, | ||||
| 		AdvertiseAddr:     advertiseAddr, | ||||
| 		PKIDir:            pkiDir, | ||||
| 		heartbeatInterval: time.Duration(heartbeatIntervalSeconds) * time.Second, | ||||
| 		stopHeartbeat:     make(chan struct{}), | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| // SetupMTLSClient configures the HTTP client with mTLS using the agent's certificates | ||||
| func (a *Agent) SetupMTLSClient() error { | ||||
| 	// Load client certificate and key | ||||
| 	cert, err := tls.LoadX509KeyPair( | ||||
| 		fmt.Sprintf("%s/node.crt", a.PKIDir), | ||||
| 		fmt.Sprintf("%s/node.key", a.PKIDir), | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to load client certificate and key: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	// Load CA certificate | ||||
| 	caCert, err := os.ReadFile(fmt.Sprintf("%s/ca.crt", a.PKIDir)) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to read CA certificate: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	caCertPool := x509.NewCertPool() | ||||
| 	if !caCertPool.AppendCertsFromPEM(caCert) { | ||||
| 		return fmt.Errorf("failed to append CA certificate to pool") | ||||
| 	} | ||||
|  | ||||
| 	// Create TLS configuration | ||||
| 	tlsConfig := &tls.Config{ | ||||
| 		Certificates: []tls.Certificate{cert}, | ||||
| 		RootCAs:      caCertPool, | ||||
| 		MinVersion:   tls.VersionTLS12, | ||||
| 	} | ||||
|  | ||||
| 	// Create HTTP client with TLS configuration | ||||
| 	a.client = &http.Client{ | ||||
| 		Transport: &http.Transport{ | ||||
| 			TLSClientConfig: tlsConfig, | ||||
| 		}, | ||||
| 		Timeout: 10 * time.Second, | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // StartHeartbeat begins sending periodic heartbeats to the leader | ||||
| func (a *Agent) StartHeartbeat(ctx context.Context) error { | ||||
| 	if a.client == nil { | ||||
| 		if err := a.SetupMTLSClient(); err != nil { | ||||
| 			return fmt.Errorf("failed to setup mTLS client: %w", err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	log.Printf("Starting heartbeat to leader at %s every %v", a.LeaderAPI, a.heartbeatInterval) | ||||
|  | ||||
| 	ticker := time.NewTicker(a.heartbeatInterval) | ||||
| 	defer ticker.Stop() | ||||
|  | ||||
| 	// Send initial heartbeat immediately | ||||
| 	if err := a.sendHeartbeat(); err != nil { | ||||
| 		log.Printf("Initial heartbeat failed: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			select { | ||||
| 			case <-ticker.C: | ||||
| 				if err := a.sendHeartbeat(); err != nil { | ||||
| 					log.Printf("Heartbeat failed: %v", err) | ||||
| 				} | ||||
| 			case <-a.stopHeartbeat: | ||||
| 				log.Printf("Heartbeat stopped") | ||||
| 				return | ||||
| 			case <-ctx.Done(): | ||||
| 				log.Printf("Heartbeat context cancelled") | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // StopHeartbeat stops the heartbeat goroutine | ||||
| func (a *Agent) StopHeartbeat() { | ||||
| 	close(a.stopHeartbeat) | ||||
| } | ||||
|  | ||||
| // sendHeartbeat sends a single heartbeat to the leader | ||||
| func (a *Agent) sendHeartbeat() error { | ||||
| 	// Gather node status | ||||
| 	status := a.gatherNodeStatus() | ||||
|  | ||||
| 	// Marshal to JSON | ||||
| 	statusJSON, err := json.Marshal(status) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to marshal node status: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	// Construct URL | ||||
| 	url := fmt.Sprintf("https://%s/v1alpha1/nodes/%s/status", a.LeaderAPI, a.NodeName) | ||||
|  | ||||
| 	// Create request | ||||
| 	req, err := http.NewRequest("POST", url, bytes.NewBuffer(statusJSON)) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to create request: %w", err) | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", "application/json") | ||||
|  | ||||
| 	// Send request | ||||
| 	resp, err := a.client.Do(req) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to send heartbeat: %w", err) | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	// Check response | ||||
| 	if resp.StatusCode != http.StatusOK { | ||||
| 		return fmt.Errorf("heartbeat returned non-OK status: %d", resp.StatusCode) | ||||
| 	} | ||||
|  | ||||
| 	log.Printf("Heartbeat sent successfully to %s", url) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // gatherNodeStatus collects the current node status | ||||
| func (a *Agent) gatherNodeStatus() NodeStatus { | ||||
| 	// For now, just provide basic information | ||||
| 	// In future phases, this will include actual resource usage, workload status, etc. | ||||
| 	 | ||||
| 	// Get basic system info for initial capacity reporting | ||||
| 	var m runtime.MemStats | ||||
| 	runtime.ReadMemStats(&m) | ||||
| 	 | ||||
| 	// Convert to human-readable format (very simplified for now) | ||||
| 	cpuCapacity := fmt.Sprintf("%dm", runtime.NumCPU() * 1000) | ||||
| 	memCapacity := fmt.Sprintf("%dMi", m.Sys / (1024 * 1024)) | ||||
| 	 | ||||
| 	// For allocatable, we'll just use 90% of capacity for this phase | ||||
| 	cpuAllocatable := fmt.Sprintf("%dm", runtime.NumCPU() * 900) | ||||
| 	memAllocatable := fmt.Sprintf("%dMi", (m.Sys / (1024 * 1024)) * 9 / 10) | ||||
| 	 | ||||
| 	return NodeStatus{ | ||||
| 		NodeName:  a.NodeName, | ||||
| 		NodeUID:   a.NodeUID, | ||||
| 		Timestamp: time.Now(), | ||||
| 		Resources: Resources{ | ||||
| 			Capacity: ResourceMetrics{ | ||||
| 				CPU:    cpuCapacity, | ||||
| 				Memory: memCapacity, | ||||
| 			}, | ||||
| 			Allocatable: ResourceMetrics{ | ||||
| 				CPU:    cpuAllocatable, | ||||
| 				Memory: memAllocatable, | ||||
| 			}, | ||||
| 		}, | ||||
| 		NetworkInfo: NetworkInfo{ | ||||
| 			Status:       "initializing", // Placeholder until network is implemented | ||||
| 			LastPeerSync: time.Now().Format(time.RFC3339), | ||||
| 		}, | ||||
| 		// Workloads will be empty for now | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										137
									
								
								internal/agent/agent_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										137
									
								
								internal/agent/agent_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,137 @@ | ||||
| package agent | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/tls" | ||||
| 	"crypto/x509" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"git.dws.rip/dubey/kat/internal/pki" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
|  | ||||
| func TestAgentHeartbeat(t *testing.T) { | ||||
| 	// Create temporary directory for test PKI files | ||||
| 	tempDir, err := os.MkdirTemp("", "kat-test-agent-*") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Failed to create temp directory: %v", err) | ||||
| 	} | ||||
| 	defer os.RemoveAll(tempDir) | ||||
|  | ||||
| 	// Generate CA for testing | ||||
| 	pkiDir := filepath.Join(tempDir, "pki") | ||||
| 	caKeyPath := filepath.Join(pkiDir, "ca.key") | ||||
| 	caCertPath := filepath.Join(pkiDir, "ca.crt") | ||||
| 	err = pki.GenerateCA(pkiDir, caKeyPath, caCertPath) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Failed to generate test CA: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// Generate node certificate | ||||
| 	nodeKeyPath := filepath.Join(pkiDir, "node.key") | ||||
| 	nodeCSRPath := filepath.Join(pkiDir, "node.csr") | ||||
| 	nodeCertPath := filepath.Join(pkiDir, "node.crt") | ||||
| 	err = pki.GenerateCertificateRequest("test-node", nodeKeyPath, nodeCSRPath) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Failed to generate node key and CSR: %v", err) | ||||
| 	} | ||||
| 	err = pki.SignCertificateRequest(caKeyPath, caCertPath, nodeCSRPath, nodeCertPath, 24*time.Hour) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Failed to sign node CSR: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// Create a test server that requires client certificates | ||||
| 	server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		// Verify the request path | ||||
| 		if r.URL.Path != "/v1alpha1/nodes/test-node/status" { | ||||
| 			t.Errorf("Expected path /v1alpha1/nodes/test-node/status, got %s", r.URL.Path) | ||||
| 			http.Error(w, "Invalid path", http.StatusBadRequest) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		// Verify the request method | ||||
| 		if r.Method != "POST" { | ||||
| 			t.Errorf("Expected method POST, got %s", r.Method) | ||||
| 			http.Error(w, "Invalid method", http.StatusMethodNotAllowed) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		// Parse the request body | ||||
| 		var status NodeStatus | ||||
| 		decoder := json.NewDecoder(r.Body) | ||||
| 		if err := decoder.Decode(&status); err != nil { | ||||
| 			t.Errorf("Failed to decode request body: %v", err) | ||||
| 			http.Error(w, "Invalid request body", http.StatusBadRequest) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		// Verify the node name | ||||
| 		if status.NodeName != "test-node" { | ||||
| 			t.Errorf("Expected node name test-node, got %s", status.NodeName) | ||||
| 			http.Error(w, "Invalid node name", http.StatusBadRequest) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		// Verify that resources are present | ||||
| 		if status.Resources.Capacity.CPU == "" || status.Resources.Capacity.Memory == "" { | ||||
| 			t.Errorf("Missing resource capacity information") | ||||
| 			http.Error(w, "Missing resource information", http.StatusBadRequest) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		// Return success | ||||
| 		w.WriteHeader(http.StatusOK) | ||||
| 	})) | ||||
| 	defer server.Close() | ||||
|  | ||||
| 	// Configure the server to require client certificates | ||||
| 	server.TLS.ClientAuth = tls.RequireAndVerifyClientCert | ||||
| 	server.TLS.ClientCAs = x509.NewCertPool() | ||||
| 	caCertData, err := os.ReadFile(caCertPath) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Failed to read CA certificate: %v", err) | ||||
| 	} | ||||
| 	server.TLS.ClientCAs.AppendCertsFromPEM(caCertData) | ||||
|  | ||||
| 	// Extract the host:port from the server URL | ||||
| 	serverURL := server.URL | ||||
| 	hostPort := serverURL[8:] // Remove "https://" prefix | ||||
|  | ||||
| 	// Create an agent | ||||
| 	agent, err := NewAgent("test-node", "test-uid", hostPort, "192.168.1.100", pkiDir, 1) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Failed to create agent: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// Setup mTLS client | ||||
| 	err = agent.SetupMTLSClient() | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Failed to setup mTLS client: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// Create a context with timeout | ||||
| 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||||
| 	defer cancel() | ||||
|  | ||||
| 	// Start heartbeat | ||||
| 	err = agent.StartHeartbeat(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Failed to start heartbeat: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// Wait for at least one heartbeat | ||||
| 	time.Sleep(2 * time.Second) | ||||
|  | ||||
| 	// Stop heartbeat | ||||
| 	agent.StopHeartbeat() | ||||
|  | ||||
| 	// Test passed if we got here without errors | ||||
| 	fmt.Println("Agent heartbeat test passed") | ||||
| } | ||||
							
								
								
									
										108
									
								
								internal/api/node_status_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										108
									
								
								internal/api/node_status_handler.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,108 @@ | ||||
| package api | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"git.dws.rip/dubey/kat/internal/store" | ||||
| ) | ||||
|  | ||||
| // NodeStatusRequest represents the data sent by an agent in a heartbeat | ||||
| type NodeStatusRequest struct { | ||||
| 	NodeName    string    `json:"nodeName"` | ||||
| 	NodeUID     string    `json:"nodeUID"` | ||||
| 	Timestamp   time.Time `json:"timestamp"` | ||||
| 	Resources   struct { | ||||
| 		Capacity    map[string]string `json:"capacity"` | ||||
| 		Allocatable map[string]string `json:"allocatable"` | ||||
| 	} `json:"resources"` | ||||
| 	WorkloadInstances []struct { | ||||
| 		WorkloadName  string `json:"workloadName"` | ||||
| 		Namespace     string `json:"namespace"` | ||||
| 		InstanceID    string `json:"instanceID"` | ||||
| 		ContainerID   string `json:"containerID"` | ||||
| 		ImageID       string `json:"imageID"` | ||||
| 		State         string `json:"state"` | ||||
| 		ExitCode      int    `json:"exitCode"` | ||||
| 		HealthStatus  string `json:"healthStatus"` | ||||
| 		Restarts      int    `json:"restarts"` | ||||
| 	} `json:"workloadInstances,omitempty"` | ||||
| 	OverlayNetwork struct { | ||||
| 		Status       string `json:"status"` | ||||
| 		LastPeerSync string `json:"lastPeerSync"` | ||||
| 	} `json:"overlayNetwork"` | ||||
| } | ||||
|  | ||||
| // NewNodeStatusHandler creates a handler for node status updates | ||||
| func NewNodeStatusHandler(stateStore store.StateStore) http.HandlerFunc { | ||||
| 	return func(w http.ResponseWriter, r *http.Request) { | ||||
| 		// Extract node name from URL path | ||||
| 		pathParts := strings.Split(r.URL.Path, "/") | ||||
| 		if len(pathParts) < 4 { | ||||
| 			http.Error(w, "Invalid URL path", http.StatusBadRequest) | ||||
| 			return | ||||
| 		} | ||||
| 		nodeName := pathParts[len(pathParts)-2] // /v1alpha1/nodes/{nodeName}/status | ||||
|  | ||||
| 		log.Printf("Received status update from node: %s", nodeName) | ||||
|  | ||||
| 		// Read and parse the request body | ||||
| 		body, err := io.ReadAll(r.Body) | ||||
| 		if err != nil { | ||||
| 			log.Printf("Failed to read request body: %v", err) | ||||
| 			http.Error(w, "Failed to read request body", http.StatusBadRequest) | ||||
| 			return | ||||
| 		} | ||||
| 		defer r.Body.Close() | ||||
|  | ||||
| 		var statusReq NodeStatusRequest | ||||
| 		if err := json.Unmarshal(body, &statusReq); err != nil { | ||||
| 			log.Printf("Failed to parse status request: %v", err) | ||||
| 			http.Error(w, "Failed to parse status request", http.StatusBadRequest) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		// Validate that the node name in the URL matches the one in the request | ||||
| 		if statusReq.NodeName != nodeName { | ||||
| 			log.Printf("Node name mismatch: %s (URL) vs %s (body)", nodeName, statusReq.NodeName) | ||||
| 			http.Error(w, "Node name mismatch", http.StatusBadRequest) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		// Store the node status in etcd | ||||
| 		nodeStatusKey := fmt.Sprintf("/kat/nodes/status/%s", nodeName) | ||||
| 		nodeStatus := map[string]interface{}{ | ||||
| 			"lastHeartbeat": time.Now().Unix(), | ||||
| 			"status":        "Ready", | ||||
| 			"resources":     statusReq.Resources, | ||||
| 			"network":       statusReq.OverlayNetwork, | ||||
| 		} | ||||
|  | ||||
| 		// Add workload instances if present | ||||
| 		if len(statusReq.WorkloadInstances) > 0 { | ||||
| 			nodeStatus["workloadInstances"] = statusReq.WorkloadInstances | ||||
| 		} | ||||
|  | ||||
| 		nodeStatusData, err := json.Marshal(nodeStatus) | ||||
| 		if err != nil { | ||||
| 			log.Printf("Failed to marshal node status: %v", err) | ||||
| 			http.Error(w, "Failed to marshal node status", http.StatusInternalServerError) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		log.Printf("Storing node status in etcd at key: %s", nodeStatusKey) | ||||
| 		if err := stateStore.Put(r.Context(), nodeStatusKey, nodeStatusData); err != nil { | ||||
| 			log.Printf("Failed to store node status: %v", err) | ||||
| 			http.Error(w, "Failed to store node status", http.StatusInternalServerError) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		log.Printf("Successfully stored status update for node: %s", nodeName) | ||||
| 		w.WriteHeader(http.StatusOK) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										108
									
								
								internal/api/node_status_handler_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										108
									
								
								internal/api/node_status_handler_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,108 @@ | ||||
| package api | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"git.dws.rip/dubey/kat/internal/store" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	"github.com/stretchr/testify/mock" | ||||
| ) | ||||
|  | ||||
| func TestNodeStatusHandler(t *testing.T) { | ||||
| 	// Create mock state store | ||||
| 	mockStore := new(MockStateStore) | ||||
| 	mockStore.On("Put", mock.Anything, mock.MatchedBy(func(key string) bool { | ||||
| 		return key == "/kat/nodes/status/test-node" | ||||
| 	}), mock.Anything).Return(nil) | ||||
|  | ||||
| 	// Create node status handler | ||||
| 	handler := NewNodeStatusHandler(mockStore) | ||||
|  | ||||
| 	// Create test request | ||||
| 	statusReq := NodeStatusRequest{ | ||||
| 		NodeName:  "test-node", | ||||
| 		NodeUID:   "test-uid", | ||||
| 		Timestamp: time.Now(), | ||||
| 		Resources: struct { | ||||
| 			Capacity    map[string]string `json:"capacity"` | ||||
| 			Allocatable map[string]string `json:"allocatable"` | ||||
| 		}{ | ||||
| 			Capacity: map[string]string{ | ||||
| 				"cpu":    "2000m", | ||||
| 				"memory": "4096Mi", | ||||
| 			}, | ||||
| 			Allocatable: map[string]string{ | ||||
| 				"cpu":    "1800m", | ||||
| 				"memory": "3800Mi", | ||||
| 			}, | ||||
| 		}, | ||||
| 		OverlayNetwork: struct { | ||||
| 			Status       string `json:"status"` | ||||
| 			LastPeerSync string `json:"lastPeerSync"` | ||||
| 		}{ | ||||
| 			Status:       "connected", | ||||
| 			LastPeerSync: time.Now().Format(time.RFC3339), | ||||
| 		}, | ||||
| 	} | ||||
| 	reqBody, err := json.Marshal(statusReq) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Failed to marshal status request: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// Create HTTP request | ||||
| 	req := httptest.NewRequest("POST", "/v1alpha1/nodes/test-node/status", bytes.NewBuffer(reqBody)) | ||||
| 	req.Header.Set("Content-Type", "application/json") | ||||
| 	w := httptest.NewRecorder() | ||||
|  | ||||
| 	// Call handler | ||||
| 	handler(w, req) | ||||
|  | ||||
| 	// Check response | ||||
| 	resp := w.Result() | ||||
| 	defer resp.Body.Close() | ||||
| 	assert.Equal(t, http.StatusOK, resp.StatusCode) | ||||
|  | ||||
| 	// Verify mock was called | ||||
| 	mockStore.AssertExpectations(t) | ||||
| } | ||||
|  | ||||
| func TestNodeStatusHandlerNameMismatch(t *testing.T) { | ||||
| 	// Create mock state store | ||||
| 	mockStore := new(MockStateStore) | ||||
|  | ||||
| 	// Create node status handler | ||||
| 	handler := NewNodeStatusHandler(mockStore) | ||||
|  | ||||
| 	// Create test request with mismatched node name | ||||
| 	statusReq := NodeStatusRequest{ | ||||
| 		NodeName:  "wrong-node", // This doesn't match the URL path | ||||
| 		NodeUID:   "test-uid", | ||||
| 		Timestamp: time.Now(), | ||||
| 	} | ||||
| 	reqBody, err := json.Marshal(statusReq) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Failed to marshal status request: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// Create HTTP request | ||||
| 	req := httptest.NewRequest("POST", "/v1alpha1/nodes/test-node/status", bytes.NewBuffer(reqBody)) | ||||
| 	req.Header.Set("Content-Type", "application/json") | ||||
| 	w := httptest.NewRecorder() | ||||
|  | ||||
| 	// Call handler | ||||
| 	handler(w, req) | ||||
|  | ||||
| 	// Check response - should be bad request due to name mismatch | ||||
| 	resp := w.Result() | ||||
| 	defer resp.Body.Close() | ||||
| 	assert.Equal(t, http.StatusBadRequest, resp.StatusCode) | ||||
|  | ||||
| 	// Verify mock was not called | ||||
| 	mockStore.AssertNotCalled(t, "Put", mock.Anything, mock.Anything, mock.Anything) | ||||
| } | ||||
| @ -142,4 +142,5 @@ func (s *Server) RegisterJoinHandler(handler http.HandlerFunc) { | ||||
| // RegisterNodeStatusHandler registers the handler for node status updates | ||||
| func (s *Server) RegisterNodeStatusHandler(handler http.HandlerFunc) { | ||||
| 	s.router.HandleFunc("POST", "/v1alpha1/nodes/{nodeName}/status", handler) | ||||
| 	log.Printf("Registered node status handler at /v1alpha1/nodes/{nodeName}/status") | ||||
| } | ||||
|  | ||||
| @ -36,7 +36,7 @@ type JoinResponse struct { | ||||
| } | ||||
|  | ||||
| // JoinCluster sends a join request to the leader and processes the response | ||||
| func JoinCluster(leaderAPI, advertiseAddr, nodeName, leaderCACert string, pkiDir string) error { | ||||
| func JoinCluster(leaderAPI, advertiseAddr, nodeName, leaderCACert string, pkiDir string) (*JoinResponse, error) { | ||||
| 	// Create PKI directory if it doesn't exist | ||||
| 	if err := os.MkdirAll(pkiDir, 0700); err != nil { | ||||
| 		return fmt.Errorf("failed to create PKI directory: %w", err) | ||||
| @ -164,5 +164,5 @@ func JoinCluster(leaderAPI, advertiseAddr, nodeName, leaderCACert string, pkiDir | ||||
| 		log.Printf("Etcd join instructions: %s", joinResp.EtcdJoinInstructions) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| 	return &joinResp, nil | ||||
| } | ||||
|  | ||||
		Reference in New Issue
	
	Block a user