package api import ( "encoding/json" "fmt" "io" "net/http" "os" "path/filepath" "time" "github.com/google/uuid" "git.dws.rip/dubey/kat/internal/pki" "git.dws.rip/dubey/kat/internal/store" ) // JoinRequest represents the data sent by an agent when joining type JoinRequest struct { CSR []byte `json:"csr"` AdvertiseAddr string `json:"advertiseAddr"` NodeName string `json:"nodeName,omitempty"` // Optional, leader can generate WireguardPubKey string `json:"wireguardPubKey"` // Placeholder for now } // JoinResponse represents the data sent back to the agent type JoinResponse struct { NodeName string `json:"nodeName"` NodeUID string `json:"nodeUID"` SignedCert []byte `json:"signedCert"` CACert []byte `json:"caCert"` JoinTimestamp int64 `json:"joinTimestamp"` } // NewJoinHandler creates a handler for agent join requests func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Read and parse the request body body, err := io.ReadAll(r.Body) if err != nil { http.Error(w, fmt.Sprintf("Failed to read request body: %v", err), http.StatusBadRequest) return } defer r.Body.Close() var joinReq JoinRequest if err := json.Unmarshal(body, &joinReq); err != nil { http.Error(w, fmt.Sprintf("Failed to parse request: %v", err), http.StatusBadRequest) return } // Validate request if len(joinReq.CSR) == 0 { http.Error(w, "Missing CSR", http.StatusBadRequest) return } if joinReq.AdvertiseAddr == "" { http.Error(w, "Missing advertise address", http.StatusBadRequest) return } // Generate node name if not provided nodeName := joinReq.NodeName if nodeName == "" { nodeName = fmt.Sprintf("node-%s", uuid.New().String()[:8]) } // Generate a unique node ID nodeUID := uuid.New().String() // Sign the CSR // Create a temporary file for the CSR tempDir := os.TempDir() csrPath := filepath.Join(tempDir, fmt.Sprintf("%s.csr", nodeUID)) if err := os.WriteFile(csrPath, joinReq.CSR, 0600); err != nil { http.Error(w, fmt.Sprintf("Failed to save CSR: %v", err), http.StatusInternalServerError) return } defer os.Remove(csrPath) // Sign the CSR certPath := filepath.Join(tempDir, fmt.Sprintf("%s.crt", nodeUID)) if err := pki.SignCertificateRequest(caKeyPath, caCertPath, csrPath, certPath, 365*24*time.Hour); err != nil { http.Error(w, fmt.Sprintf("Failed to sign CSR: %v", err), http.StatusInternalServerError) return } defer os.Remove(certPath) // Read the signed certificate signedCert, err := os.ReadFile(certPath) if err != nil { http.Error(w, fmt.Sprintf("Failed to read signed certificate: %v", err), http.StatusInternalServerError) return } // Read the CA certificate caCert, err := os.ReadFile(caCertPath) if err != nil { http.Error(w, fmt.Sprintf("Failed to read CA certificate: %v", err), http.StatusInternalServerError) return } // Store node registration in etcd nodeRegKey := fmt.Sprintf("/kat/nodes/registration/%s", nodeName) nodeReg := map[string]interface{}{ "uid": nodeUID, "advertiseAddr": joinReq.AdvertiseAddr, "wireguardPubKey": joinReq.WireguardPubKey, "joinTimestamp": time.Now().Unix(), } nodeRegData, err := json.Marshal(nodeReg) if err != nil { http.Error(w, fmt.Sprintf("Failed to marshal node registration: %v", err), http.StatusInternalServerError) return } if err := stateStore.Put(r.Context(), nodeRegKey, nodeRegData); err != nil { http.Error(w, fmt.Sprintf("Failed to store node registration: %v", err), http.StatusInternalServerError) return } // Prepare and send response joinResp := JoinResponse{ NodeName: nodeName, NodeUID: nodeUID, SignedCert: signedCert, CACert: caCert, JoinTimestamp: time.Now().Unix(), } respData, err := json.Marshal(joinResp) if err != nil { http.Error(w, fmt.Sprintf("Failed to marshal response: %v", err), http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(respData) } }