feat: Implement CSR signing and node registration handler for agent join
This commit is contained in:
		| @ -248,124 +248,9 @@ func runInit(cmd *cobra.Command, args []string) { | |||||||
| 				log.Printf("Failed to create API server: %v", err) | 				log.Printf("Failed to create API server: %v", err) | ||||||
| 			} else { | 			} else { | ||||||
| 				// Register the join handler | 				// Register the join handler | ||||||
| 				apiServer.RegisterJoinHandler(func(w http.ResponseWriter, r *http.Request) { | 				joinHandler := api.NewJoinHandler(etcdStore, caKeyPath, caCertPath) | ||||||
| 					log.Printf("Received join request from %s", r.RemoteAddr) | 				apiServer.RegisterJoinHandler(joinHandler) | ||||||
| 					 | 				log.Printf("Registered join handler with CA key: %s, CA cert: %s", caKeyPath, caCertPath) | ||||||
| 					// In Phase 2, we're not requiring client certificates yet |  | ||||||
| 					log.Printf("Processing join request without client certificate verification (Phase 2)") |  | ||||||
| 					 |  | ||||||
| 					// Read request body |  | ||||||
| 					var joinReq cli.JoinRequest |  | ||||||
| 					if err := json.NewDecoder(r.Body).Decode(&joinReq); err != nil { |  | ||||||
| 						log.Printf("Error decoding join request: %v", err) |  | ||||||
| 						http.Error(w, "Invalid request format", http.StatusBadRequest) |  | ||||||
| 						return |  | ||||||
| 					} |  | ||||||
| 					 |  | ||||||
| 					// Validate request |  | ||||||
| 					if joinReq.NodeName == "" || joinReq.AdvertiseAddr == "" || joinReq.CSRData == "" { |  | ||||||
| 						log.Printf("Invalid join request: missing required fields") |  | ||||||
| 						http.Error(w, "Missing required fields", http.StatusBadRequest) |  | ||||||
| 						return |  | ||||||
| 					} |  | ||||||
| 					 |  | ||||||
| 					log.Printf("Processing join request for node: %s, advertise address: %s",  |  | ||||||
| 						joinReq.NodeName, joinReq.AdvertiseAddr) |  | ||||||
| 					 |  | ||||||
| 					// Decode CSR data |  | ||||||
| 					csrData, err := base64.StdEncoding.DecodeString(joinReq.CSRData) |  | ||||||
| 					if err != nil { |  | ||||||
| 						log.Printf("Error decoding CSR data: %v", err) |  | ||||||
| 						http.Error(w, "Invalid CSR data", http.StatusBadRequest) |  | ||||||
| 						return |  | ||||||
| 					} |  | ||||||
| 					 |  | ||||||
| 					// Create a temporary file for the CSR |  | ||||||
| 					tempCSRFile, err := os.CreateTemp("", "node-csr-*.pem") |  | ||||||
| 					if err != nil { |  | ||||||
| 						log.Printf("Error creating temp CSR file: %v", err) |  | ||||||
| 						http.Error(w, "Internal server error", http.StatusInternalServerError) |  | ||||||
| 						return |  | ||||||
| 					} |  | ||||||
| 					defer os.Remove(tempCSRFile.Name()) |  | ||||||
| 					 |  | ||||||
| 					// Write CSR data to temp file |  | ||||||
| 					if _, err := tempCSRFile.Write(csrData); err != nil { |  | ||||||
| 						log.Printf("Error writing CSR data to temp file: %v", err) |  | ||||||
| 						http.Error(w, "Internal server error", http.StatusInternalServerError) |  | ||||||
| 						return |  | ||||||
| 					} |  | ||||||
| 					tempCSRFile.Close() |  | ||||||
| 					 |  | ||||||
| 					// Create a temp file for the signed certificate |  | ||||||
| 					tempCertFile, err := os.CreateTemp("", "node-cert-*.pem") |  | ||||||
| 					if err != nil { |  | ||||||
| 						log.Printf("Error creating temp cert file: %v", err) |  | ||||||
| 						http.Error(w, "Internal server error", http.StatusInternalServerError) |  | ||||||
| 						return |  | ||||||
| 					} |  | ||||||
| 					defer os.Remove(tempCertFile.Name()) |  | ||||||
| 					tempCertFile.Close() |  | ||||||
| 					 |  | ||||||
| 					// Sign the CSR |  | ||||||
| 					if err := pki.SignCertificateRequest( |  | ||||||
| 						filepath.Join(pkiDir, "ca.key"), |  | ||||||
| 						filepath.Join(pkiDir, "ca.crt"), |  | ||||||
| 						tempCSRFile.Name(), |  | ||||||
| 						tempCertFile.Name(), |  | ||||||
| 						365*24*time.Hour, // 1 year validity |  | ||||||
| 					); err != nil { |  | ||||||
| 						log.Printf("Error signing CSR: %v", err) |  | ||||||
| 						http.Error(w, "Failed to sign certificate", http.StatusInternalServerError) |  | ||||||
| 						return |  | ||||||
| 					} |  | ||||||
| 					 |  | ||||||
| 					// Read the signed certificate |  | ||||||
| 					signedCert, err := os.ReadFile(tempCertFile.Name()) |  | ||||||
| 					if err != nil { |  | ||||||
| 						log.Printf("Error reading signed certificate: %v", err) |  | ||||||
| 						http.Error(w, "Internal server error", http.StatusInternalServerError) |  | ||||||
| 						return |  | ||||||
| 					} |  | ||||||
| 					 |  | ||||||
| 					// Read the CA certificate |  | ||||||
| 					caCert, err := os.ReadFile(filepath.Join(pkiDir, "ca.crt")) |  | ||||||
| 					if err != nil { |  | ||||||
| 						log.Printf("Error reading CA certificate: %v", err) |  | ||||||
| 						http.Error(w, "Internal server error", http.StatusInternalServerError) |  | ||||||
| 						return |  | ||||||
| 					} |  | ||||||
| 					 |  | ||||||
| 					// Generate a unique node UID |  | ||||||
| 					nodeUID := uuid.New().String() |  | ||||||
| 					 |  | ||||||
| 					// Store node registration in etcd (placeholder for now) |  | ||||||
| 					// In a future phase, we'll implement proper node registration with subnet assignment |  | ||||||
| 					 |  | ||||||
| 					// Create response |  | ||||||
| 					joinResp := cli.JoinResponse{ |  | ||||||
| 						NodeName:          joinReq.NodeName, |  | ||||||
| 						NodeUID:           nodeUID, |  | ||||||
| 						SignedCertificate: base64.StdEncoding.EncodeToString(signedCert), |  | ||||||
| 						CACertificate:     base64.StdEncoding.EncodeToString(caCert), |  | ||||||
| 						AssignedSubnet:    "10.100.0.0/24", // Placeholder, will be properly implemented in network phase |  | ||||||
| 					} |  | ||||||
| 					 |  | ||||||
| 					// If etcd peer was requested, add join instructions (placeholder) |  | ||||||
| 					if etcdPeer { |  | ||||||
| 						joinResp.EtcdJoinInstructions = "Etcd peer join not implemented in this phase" |  | ||||||
| 					} |  | ||||||
| 					 |  | ||||||
| 					// Send response |  | ||||||
| 					w.Header().Set("Content-Type", "application/json") |  | ||||||
| 					if err := json.NewEncoder(w).Encode(joinResp); err != nil { |  | ||||||
| 						log.Printf("Error encoding join response: %v", err) |  | ||||||
| 						http.Error(w, "Internal server error", http.StatusInternalServerError) |  | ||||||
| 						return |  | ||||||
| 					} |  | ||||||
| 					 |  | ||||||
| 					log.Printf("Successfully processed join request for node: %s", joinReq.NodeName) |  | ||||||
| 				}) |  | ||||||
|  |  | ||||||
| 				// Start the server in a goroutine | 				// Start the server in a goroutine | ||||||
| 				go func() { | 				go func() { | ||||||
|  | |||||||
| @ -1,9 +1,11 @@ | |||||||
| package api | package api | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"encoding/base64" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
|  | 	"log" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"os" | 	"os" | ||||||
| 	"path/filepath" | 	"path/filepath" | ||||||
| @ -17,27 +19,31 @@ import ( | |||||||
|  |  | ||||||
| // JoinRequest represents the data sent by an agent when joining | // JoinRequest represents the data sent by an agent when joining | ||||||
| type JoinRequest struct { | type JoinRequest struct { | ||||||
| 	CSR             []byte `json:"csr"` | 	CSRData         string `json:"csrData"`         // base64 encoded CSR | ||||||
| 	AdvertiseAddr   string `json:"advertiseAddr"` | 	AdvertiseAddr   string `json:"advertiseAddr"` | ||||||
| 	NodeName        string `json:"nodeName,omitempty"` // Optional, leader can generate | 	NodeName        string `json:"nodeName,omitempty"` // Optional, leader can generate | ||||||
| 	WireguardPubKey string `json:"wireguardPubKey"`    // Placeholder for now | 	WireGuardPubKey string `json:"wireguardPubKey"`    // Placeholder for now | ||||||
| } | } | ||||||
|  |  | ||||||
| // JoinResponse represents the data sent back to the agent | // JoinResponse represents the data sent back to the agent | ||||||
| type JoinResponse struct { | type JoinResponse struct { | ||||||
| 	NodeName          string `json:"nodeName"` | 	NodeName          string `json:"nodeName"` | ||||||
| 	NodeUID           string `json:"nodeUID"` | 	NodeUID           string `json:"nodeUID"` | ||||||
| 	SignedCert    []byte `json:"signedCert"` | 	SignedCertificate string `json:"signedCertificate"` // base64 encoded certificate | ||||||
| 	CACert        []byte `json:"caCert"` | 	CACertificate     string `json:"caCertificate"`     // base64 encoded CA certificate | ||||||
| 	JoinTimestamp int64  `json:"joinTimestamp"` | 	AssignedSubnet    string `json:"assignedSubnet"`    // Placeholder for now | ||||||
|  | 	EtcdJoinInstructions string `json:"etcdJoinInstructions,omitempty"` | ||||||
| } | } | ||||||
|  |  | ||||||
| // NewJoinHandler creates a handler for agent join requests | // NewJoinHandler creates a handler for agent join requests | ||||||
| func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) http.HandlerFunc { | func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) http.HandlerFunc { | ||||||
| 	return func(w http.ResponseWriter, r *http.Request) { | 	return func(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 		log.Printf("Received join request from %s", r.RemoteAddr) | ||||||
|  | 		 | ||||||
| 		// Read and parse the request body | 		// Read and parse the request body | ||||||
| 		body, err := io.ReadAll(r.Body) | 		body, err := io.ReadAll(r.Body) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
|  | 			log.Printf("Failed to read request body: %v", err) | ||||||
| 			http.Error(w, fmt.Sprintf("Failed to read request body: %v", err), http.StatusBadRequest) | 			http.Error(w, fmt.Sprintf("Failed to read request body: %v", err), http.StatusBadRequest) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| @ -45,16 +51,19 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h | |||||||
|  |  | ||||||
| 		var joinReq JoinRequest | 		var joinReq JoinRequest | ||||||
| 		if err := json.Unmarshal(body, &joinReq); err != nil { | 		if err := json.Unmarshal(body, &joinReq); err != nil { | ||||||
|  | 			log.Printf("Failed to parse request: %v", err) | ||||||
| 			http.Error(w, fmt.Sprintf("Failed to parse request: %v", err), http.StatusBadRequest) | 			http.Error(w, fmt.Sprintf("Failed to parse request: %v", err), http.StatusBadRequest) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Validate request | 		// Validate request | ||||||
| 		if len(joinReq.CSR) == 0 { | 		if joinReq.CSRData == "" { | ||||||
| 			http.Error(w, "Missing CSR", http.StatusBadRequest) | 			log.Printf("Missing CSR data") | ||||||
|  | 			http.Error(w, "Missing CSR data", http.StatusBadRequest) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		if joinReq.AdvertiseAddr == "" { | 		if joinReq.AdvertiseAddr == "" { | ||||||
|  | 			log.Printf("Missing advertise address") | ||||||
| 			http.Error(w, "Missing advertise address", http.StatusBadRequest) | 			http.Error(w, "Missing advertise address", http.StatusBadRequest) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| @ -63,16 +72,26 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h | |||||||
| 		nodeName := joinReq.NodeName | 		nodeName := joinReq.NodeName | ||||||
| 		if nodeName == "" { | 		if nodeName == "" { | ||||||
| 			nodeName = fmt.Sprintf("node-%s", uuid.New().String()[:8]) | 			nodeName = fmt.Sprintf("node-%s", uuid.New().String()[:8]) | ||||||
|  | 			log.Printf("Generated node name: %s", nodeName) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Generate a unique node ID | 		// Generate a unique node ID | ||||||
| 		nodeUID := uuid.New().String() | 		nodeUID := uuid.New().String() | ||||||
|  | 		log.Printf("Generated node UID: %s", nodeUID) | ||||||
|  |  | ||||||
|  | 		// Decode CSR data | ||||||
|  | 		csrData, err := base64.StdEncoding.DecodeString(joinReq.CSRData) | ||||||
|  | 		if err != nil { | ||||||
|  | 			log.Printf("Failed to decode CSR data: %v", err) | ||||||
|  | 			http.Error(w, fmt.Sprintf("Failed to decode CSR data: %v", err), http.StatusBadRequest) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		// Sign the CSR |  | ||||||
| 		// Create a temporary file for the CSR | 		// Create a temporary file for the CSR | ||||||
| 		tempDir := os.TempDir() | 		tempDir := os.TempDir() | ||||||
| 		csrPath := filepath.Join(tempDir, fmt.Sprintf("%s.csr", nodeUID)) | 		csrPath := filepath.Join(tempDir, fmt.Sprintf("%s.csr", nodeUID)) | ||||||
| 		if err := os.WriteFile(csrPath, joinReq.CSR, 0600); err != nil { | 		if err := os.WriteFile(csrPath, csrData, 0600); err != nil { | ||||||
|  | 			log.Printf("Failed to save CSR: %v", err) | ||||||
| 			http.Error(w, fmt.Sprintf("Failed to save CSR: %v", err), http.StatusInternalServerError) | 			http.Error(w, fmt.Sprintf("Failed to save CSR: %v", err), http.StatusInternalServerError) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| @ -81,6 +100,7 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h | |||||||
| 		// Sign the CSR | 		// Sign the CSR | ||||||
| 		certPath := filepath.Join(tempDir, fmt.Sprintf("%s.crt", nodeUID)) | 		certPath := filepath.Join(tempDir, fmt.Sprintf("%s.crt", nodeUID)) | ||||||
| 		if err := pki.SignCertificateRequest(caKeyPath, caCertPath, csrPath, certPath, 365*24*time.Hour); err != nil { | 		if err := pki.SignCertificateRequest(caKeyPath, caCertPath, csrPath, certPath, 365*24*time.Hour); err != nil { | ||||||
|  | 			log.Printf("Failed to sign CSR: %v", err) | ||||||
| 			http.Error(w, fmt.Sprintf("Failed to sign CSR: %v", err), http.StatusInternalServerError) | 			http.Error(w, fmt.Sprintf("Failed to sign CSR: %v", err), http.StatusInternalServerError) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| @ -89,6 +109,7 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h | |||||||
| 		// Read the signed certificate | 		// Read the signed certificate | ||||||
| 		signedCert, err := os.ReadFile(certPath) | 		signedCert, err := os.ReadFile(certPath) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
|  | 			log.Printf("Failed to read signed certificate: %v", err) | ||||||
| 			http.Error(w, fmt.Sprintf("Failed to read signed certificate: %v", err), http.StatusInternalServerError) | 			http.Error(w, fmt.Sprintf("Failed to read signed certificate: %v", err), http.StatusInternalServerError) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| @ -96,6 +117,7 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h | |||||||
| 		// Read the CA certificate | 		// Read the CA certificate | ||||||
| 		caCert, err := os.ReadFile(caCertPath) | 		caCert, err := os.ReadFile(caCertPath) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
|  | 			log.Printf("Failed to read CA certificate: %v", err) | ||||||
| 			http.Error(w, fmt.Sprintf("Failed to read CA certificate: %v", err), http.StatusInternalServerError) | 			http.Error(w, fmt.Sprintf("Failed to read CA certificate: %v", err), http.StatusInternalServerError) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| @ -105,31 +127,36 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h | |||||||
| 		nodeReg := map[string]interface{}{ | 		nodeReg := map[string]interface{}{ | ||||||
| 			"uid":             nodeUID, | 			"uid":             nodeUID, | ||||||
| 			"advertiseAddr":   joinReq.AdvertiseAddr, | 			"advertiseAddr":   joinReq.AdvertiseAddr, | ||||||
| 			"wireguardPubKey": joinReq.WireguardPubKey, | 			"wireguardPubKey": joinReq.WireGuardPubKey, | ||||||
| 			"joinTimestamp":   time.Now().Unix(), | 			"joinTimestamp":   time.Now().Unix(), | ||||||
| 		} | 		} | ||||||
| 		nodeRegData, err := json.Marshal(nodeReg) | 		nodeRegData, err := json.Marshal(nodeReg) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
|  | 			log.Printf("Failed to marshal node registration: %v", err) | ||||||
| 			http.Error(w, fmt.Sprintf("Failed to marshal node registration: %v", err), http.StatusInternalServerError) | 			http.Error(w, fmt.Sprintf("Failed to marshal node registration: %v", err), http.StatusInternalServerError) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		log.Printf("Storing node registration in etcd at key: %s", nodeRegKey) | ||||||
| 		if err := stateStore.Put(r.Context(), nodeRegKey, nodeRegData); err != nil { | 		if err := stateStore.Put(r.Context(), nodeRegKey, nodeRegData); err != nil { | ||||||
|  | 			log.Printf("Failed to store node registration: %v", err) | ||||||
| 			http.Error(w, fmt.Sprintf("Failed to store node registration: %v", err), http.StatusInternalServerError) | 			http.Error(w, fmt.Sprintf("Failed to store node registration: %v", err), http.StatusInternalServerError) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  | 		log.Printf("Successfully stored node registration in etcd") | ||||||
|  |  | ||||||
| 		// Prepare and send response | 		// Prepare and send response | ||||||
| 		joinResp := JoinResponse{ | 		joinResp := JoinResponse{ | ||||||
| 			NodeName:          nodeName, | 			NodeName:          nodeName, | ||||||
| 			NodeUID:           nodeUID, | 			NodeUID:           nodeUID, | ||||||
| 			SignedCert:    signedCert, | 			SignedCertificate: base64.StdEncoding.EncodeToString(signedCert), | ||||||
| 			CACert:        caCert, | 			CACertificate:     base64.StdEncoding.EncodeToString(caCert), | ||||||
| 			JoinTimestamp: time.Now().Unix(), | 			AssignedSubnet:    "10.100.0.0/24", // Placeholder for now, will be implemented in network phase | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		respData, err := json.Marshal(joinResp) | 		respData, err := json.Marshal(joinResp) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
|  | 			log.Printf("Failed to marshal response: %v", err) | ||||||
| 			http.Error(w, fmt.Sprintf("Failed to marshal response: %v", err), http.StatusInternalServerError) | 			http.Error(w, fmt.Sprintf("Failed to marshal response: %v", err), http.StatusInternalServerError) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| @ -137,5 +164,6 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h | |||||||
| 		w.Header().Set("Content-Type", "application/json") | 		w.Header().Set("Content-Type", "application/json") | ||||||
| 		w.WriteHeader(http.StatusOK) | 		w.WriteHeader(http.StatusOK) | ||||||
| 		w.Write(respData) | 		w.Write(respData) | ||||||
|  | 		log.Printf("Successfully processed join request for node: %s", nodeName) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
							
								
								
									
										168
									
								
								internal/api/join_handler_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										168
									
								
								internal/api/join_handler_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,168 @@ | |||||||
|  | package api | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"context" | ||||||
|  | 	"encoding/base64" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/http/httptest" | ||||||
|  | 	"os" | ||||||
|  | 	"path/filepath" | ||||||
|  | 	"testing" | ||||||
|  |  | ||||||
|  | 	"git.dws.rip/dubey/kat/internal/pki" | ||||||
|  | 	"git.dws.rip/dubey/kat/internal/store" | ||||||
|  | 	"github.com/stretchr/testify/assert" | ||||||
|  | 	"github.com/stretchr/testify/mock" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // MockStateStore for testing | ||||||
|  | type MockStateStore struct { | ||||||
|  | 	mock.Mock | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *MockStateStore) Put(ctx context.Context, key string, value []byte) error { | ||||||
|  | 	args := m.Called(ctx, key, value) | ||||||
|  | 	return args.Error(0) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *MockStateStore) Get(ctx context.Context, key string) (*store.KV, error) { | ||||||
|  | 	args := m.Called(ctx, key) | ||||||
|  | 	if args.Get(0) == nil { | ||||||
|  | 		return nil, args.Error(1) | ||||||
|  | 	} | ||||||
|  | 	return args.Get(0).(*store.KV), args.Error(1) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *MockStateStore) Delete(ctx context.Context, key string) error { | ||||||
|  | 	args := m.Called(ctx, key) | ||||||
|  | 	return args.Error(0) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *MockStateStore) List(ctx context.Context, prefix string) ([]store.KV, error) { | ||||||
|  | 	args := m.Called(ctx, prefix) | ||||||
|  | 	return args.Get(0).([]store.KV), args.Error(1) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *MockStateStore) Watch(ctx context.Context, keyOrPrefix string, startRevision int64) (<-chan store.WatchEvent, error) { | ||||||
|  | 	args := m.Called(ctx, keyOrPrefix, startRevision) | ||||||
|  | 	return args.Get(0).(chan store.WatchEvent), args.Error(1) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *MockStateStore) Close() error { | ||||||
|  | 	args := m.Called() | ||||||
|  | 	return args.Error(0) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *MockStateStore) Campaign(ctx context.Context, leaderID string, leaseTTLSeconds int64) (context.Context, error) { | ||||||
|  | 	args := m.Called(ctx, leaderID, leaseTTLSeconds) | ||||||
|  | 	if args.Get(0) == nil { | ||||||
|  | 		return nil, args.Error(1) | ||||||
|  | 	} | ||||||
|  | 	return args.Get(0).(context.Context), args.Error(1) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *MockStateStore) Resign(ctx context.Context) error { | ||||||
|  | 	args := m.Called(ctx) | ||||||
|  | 	return args.Error(0) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *MockStateStore) GetLeader(ctx context.Context) (string, error) { | ||||||
|  | 	args := m.Called(ctx) | ||||||
|  | 	return args.String(0), args.Error(1) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *MockStateStore) DoTransaction(ctx context.Context, checks []store.Compare, onSuccess []store.Op, onFailure []store.Op) (bool, error) { | ||||||
|  | 	args := m.Called(ctx, checks, onSuccess, onFailure) | ||||||
|  | 	return args.Bool(0), args.Error(1) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestJoinHandler(t *testing.T) { | ||||||
|  | 	// Create temporary directory for test PKI files | ||||||
|  | 	tempDir, err := os.MkdirTemp("", "kat-test-pki-*") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Failed to create temp directory: %v", err) | ||||||
|  | 	} | ||||||
|  | 	defer os.RemoveAll(tempDir) | ||||||
|  |  | ||||||
|  | 	// Generate CA for testing | ||||||
|  | 	caKeyPath := filepath.Join(tempDir, "ca.key") | ||||||
|  | 	caCertPath := filepath.Join(tempDir, "ca.crt") | ||||||
|  | 	err = pki.GenerateCA(tempDir, caKeyPath, caCertPath) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Failed to generate test CA: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Generate a test CSR | ||||||
|  | 	nodeKeyPath := filepath.Join(tempDir, "node.key") | ||||||
|  | 	nodeCSRPath := filepath.Join(tempDir, "node.csr") | ||||||
|  | 	err = pki.GenerateCertificateRequest("test-node", nodeKeyPath, nodeCSRPath) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Failed to generate test CSR: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Read the CSR file | ||||||
|  | 	csrData, err := os.ReadFile(nodeCSRPath) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Failed to read CSR file: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Create mock state store | ||||||
|  | 	mockStore := new(MockStateStore) | ||||||
|  | 	mockStore.On("Put", mock.Anything, mock.MatchedBy(func(key string) bool { | ||||||
|  | 		return key == "/kat/nodes/registration/test-node" | ||||||
|  | 	}), mock.Anything).Return(nil) | ||||||
|  |  | ||||||
|  | 	// Create join handler | ||||||
|  | 	handler := NewJoinHandler(mockStore, caKeyPath, caCertPath) | ||||||
|  |  | ||||||
|  | 	// Create test request | ||||||
|  | 	joinReq := JoinRequest{ | ||||||
|  | 		NodeName:        "test-node", | ||||||
|  | 		AdvertiseAddr:   "192.168.1.100", | ||||||
|  | 		CSRData:         base64.StdEncoding.EncodeToString(csrData), | ||||||
|  | 		WireGuardPubKey: "test-pubkey", | ||||||
|  | 	} | ||||||
|  | 	reqBody, err := json.Marshal(joinReq) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Failed to marshal join request: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Create HTTP request | ||||||
|  | 	req := httptest.NewRequest("POST", "/internal/v1alpha1/join", bytes.NewBuffer(reqBody)) | ||||||
|  | 	req.Header.Set("Content-Type", "application/json") | ||||||
|  | 	w := httptest.NewRecorder() | ||||||
|  |  | ||||||
|  | 	// Call handler | ||||||
|  | 	handler(w, req) | ||||||
|  |  | ||||||
|  | 	// Check response | ||||||
|  | 	resp := w.Result() | ||||||
|  | 	defer resp.Body.Close() | ||||||
|  | 	assert.Equal(t, http.StatusOK, resp.StatusCode) | ||||||
|  |  | ||||||
|  | 	// Read response body | ||||||
|  | 	respBody, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Failed to read response body: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Parse response | ||||||
|  | 	var joinResp JoinResponse | ||||||
|  | 	err = json.Unmarshal(respBody, &joinResp) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Failed to parse response: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Verify response fields | ||||||
|  | 	assert.Equal(t, "test-node", joinResp.NodeName) | ||||||
|  | 	assert.NotEmpty(t, joinResp.NodeUID) | ||||||
|  | 	assert.NotEmpty(t, joinResp.SignedCertificate) | ||||||
|  | 	assert.NotEmpty(t, joinResp.CACertificate) | ||||||
|  | 	assert.Equal(t, "10.100.0.0/24", joinResp.AssignedSubnet) // Placeholder value | ||||||
|  |  | ||||||
|  | 	// Verify mock was called | ||||||
|  | 	mockStore.AssertExpectations(t) | ||||||
|  | } | ||||||
| @ -136,6 +136,7 @@ func (s *Server) Stop(ctx context.Context) error { | |||||||
| // RegisterJoinHandler registers the handler for agent join requests | // RegisterJoinHandler registers the handler for agent join requests | ||||||
| func (s *Server) RegisterJoinHandler(handler http.HandlerFunc) { | func (s *Server) RegisterJoinHandler(handler http.HandlerFunc) { | ||||||
| 	s.router.HandleFunc("POST", "/internal/v1alpha1/join", handler) | 	s.router.HandleFunc("POST", "/internal/v1alpha1/join", handler) | ||||||
|  | 	log.Printf("Registered join handler at /internal/v1alpha1/join") | ||||||
| } | } | ||||||
|  |  | ||||||
| // RegisterNodeStatusHandler registers the handler for node status updates | // RegisterNodeStatusHandler registers the handler for node status updates | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user