package api import ( "encoding/base64" "encoding/json" "fmt" "io" "log" "net/http" "os" "path/filepath" "time" "github.com/google/uuid" "git.dws.rip/dubey/kat/internal/pki" "git.dws.rip/dubey/kat/internal/store" ) // JoinRequest represents the data sent by an agent when joining type JoinRequest struct { CSRData string `json:"csrData"` // base64 encoded CSR AdvertiseAddr string `json:"advertiseAddr"` NodeName string `json:"nodeName,omitempty"` // Optional, leader can generate WireGuardPubKey string `json:"wireguardPubKey"` // Placeholder for now } // JoinResponse represents the data sent back to the agent type JoinResponse struct { NodeName string `json:"nodeName"` NodeUID string `json:"nodeUID"` SignedCertificate string `json:"signedCertificate"` // base64 encoded certificate CACertificate string `json:"caCertificate"` // base64 encoded CA certificate AssignedSubnet string `json:"assignedSubnet"` // Placeholder for now EtcdJoinInstructions string `json:"etcdJoinInstructions,omitempty"` } // NewJoinHandler creates a handler for agent join requests func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { log.Printf("Received join request from %s", r.RemoteAddr) // Read and parse the request body body, err := io.ReadAll(r.Body) if err != nil { log.Printf("Failed to read request body: %v", err) http.Error(w, fmt.Sprintf("Failed to read request body: %v", err), http.StatusBadRequest) return } defer r.Body.Close() var joinReq JoinRequest if err := json.Unmarshal(body, &joinReq); err != nil { log.Printf("Failed to parse request: %v", err) http.Error(w, fmt.Sprintf("Failed to parse request: %v", err), http.StatusBadRequest) return } // Validate request if joinReq.CSRData == "" { log.Printf("Missing CSR data") http.Error(w, "Missing CSR data", http.StatusBadRequest) return } if joinReq.AdvertiseAddr == "" { log.Printf("Missing advertise address") http.Error(w, "Missing advertise address", http.StatusBadRequest) return } // Generate node name if not provided nodeName := joinReq.NodeName if nodeName == "" { nodeName = fmt.Sprintf("node-%s", uuid.New().String()[:8]) log.Printf("Generated node name: %s", nodeName) } // Generate a unique node ID nodeUID := uuid.New().String() log.Printf("Generated node UID: %s", nodeUID) // Decode CSR data csrData, err := base64.StdEncoding.DecodeString(joinReq.CSRData) if err != nil { log.Printf("Failed to decode CSR data: %v", err) http.Error(w, fmt.Sprintf("Failed to decode CSR data: %v", err), http.StatusBadRequest) return } // Create a temporary file for the CSR tempDir := os.TempDir() csrPath := filepath.Join(tempDir, fmt.Sprintf("%s.csr", nodeUID)) if err := os.WriteFile(csrPath, csrData, 0600); err != nil { log.Printf("Failed to save CSR: %v", err) http.Error(w, fmt.Sprintf("Failed to save CSR: %v", err), http.StatusInternalServerError) return } defer os.Remove(csrPath) // Sign the CSR certPath := filepath.Join(tempDir, fmt.Sprintf("%s.crt", nodeUID)) if err := pki.SignCertificateRequest(caKeyPath, caCertPath, csrPath, certPath, 365*24*time.Hour); err != nil { log.Printf("Failed to sign CSR: %v", err) http.Error(w, fmt.Sprintf("Failed to sign CSR: %v", err), http.StatusInternalServerError) return } defer os.Remove(certPath) // Read the signed certificate signedCert, err := os.ReadFile(certPath) if err != nil { log.Printf("Failed to read signed certificate: %v", err) http.Error(w, fmt.Sprintf("Failed to read signed certificate: %v", err), http.StatusInternalServerError) return } // Read the CA certificate caCert, err := os.ReadFile(caCertPath) if err != nil { log.Printf("Failed to read CA certificate: %v", err) http.Error(w, fmt.Sprintf("Failed to read CA certificate: %v", err), http.StatusInternalServerError) return } // Store node registration in etcd nodeRegKey := fmt.Sprintf("/kat/nodes/registration/%s", nodeName) nodeReg := map[string]interface{}{ "uid": nodeUID, "advertiseAddr": joinReq.AdvertiseAddr, "wireguardPubKey": joinReq.WireGuardPubKey, "joinTimestamp": time.Now().Unix(), } nodeRegData, err := json.Marshal(nodeReg) if err != nil { log.Printf("Failed to marshal node registration: %v", err) http.Error(w, fmt.Sprintf("Failed to marshal node registration: %v", err), http.StatusInternalServerError) return } log.Printf("Storing node registration in etcd at key: %s", nodeRegKey) if err := stateStore.Put(r.Context(), nodeRegKey, nodeRegData); err != nil { log.Printf("Failed to store node registration: %v", err) http.Error(w, fmt.Sprintf("Failed to store node registration: %v", err), http.StatusInternalServerError) return } log.Printf("Successfully stored node registration in etcd") // Prepare and send response joinResp := JoinResponse{ NodeName: nodeName, NodeUID: nodeUID, SignedCertificate: base64.StdEncoding.EncodeToString(signedCert), CACertificate: base64.StdEncoding.EncodeToString(caCert), AssignedSubnet: "10.100.0.0/24", // Placeholder for now, will be implemented in network phase } respData, err := json.Marshal(joinResp) if err != nil { log.Printf("Failed to marshal response: %v", err) http.Error(w, fmt.Sprintf("Failed to marshal response: %v", err), http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(respData) log.Printf("Successfully processed join request for node: %s", nodeName) } }