package cli

import (
	"bytes"
	"crypto/tls"
	"crypto/x509"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"path/filepath"
	"time"

	"git.dws.rip/dubey/kat/internal/pki"
)

// JoinRequest represents the data sent to the leader when joining.
type JoinRequest struct {
	NodeName        string `json:"nodeName"`
	AdvertiseAddr   string `json:"advertiseAddr"`
	CSRData         string `json:"csrData"` // base64 encoded CSR file contents
	WireGuardPubKey string `json:"wireguardPubKey"`
}

// JoinResponse represents the data received from the leader after a successful join.
type JoinResponse struct {
	NodeName             string `json:"nodeName"`
	NodeUID              string `json:"nodeUID"`
	SignedCertificate    string `json:"signedCertificate"` // base64 encoded certificate
	CACertificate        string `json:"caCertificate"`     // base64 encoded CA certificate
	AssignedSubnet       string `json:"assignedSubnet"`
	EtcdJoinInstructions string `json:"etcdJoinInstructions,omitempty"`
}

// JoinCluster asks the leader at leaderAPI to admit this node into the cluster.
//
// It creates pkiDir (mode 0700) if needed and generates a node key and CSR
// there (node.key, node.csr) via pki.GenerateCertificateRequest. It then POSTs a
// JSON JoinRequest to https://<leaderAPI>/internal/v1alpha1/join. On an HTTP 200
// response it writes the returned signed certificate to pkiDir/node.crt and the
// returned CA certificate to pkiDir/ca.crt.
//
// leaderCACert is an optional path to a PEM file used to verify the leader's
// TLS certificate. When it is empty, TLS verification is DISABLED (development
// mode only) and a warning is logged.
//
// On success it returns the decoded JoinResponse. On failure it returns a nil
// response and a wrapped error. Files produced by earlier steps (key, CSR) are
// left on disk.
func JoinCluster(leaderAPI, advertiseAddr, nodeName, leaderCACert string, pkiDir string) (*JoinResponse, error) {
	// Upper bound on how much of the leader's reply is read. The expected
	// payload is a few base64 certificates, so 1 MiB is generous. The cap stops
	// a misbehaving endpoint from exhausting memory, or from flooding the error
	// message below with its body.
	const maxJoinResponseBytes = 1 << 20

	if err := os.MkdirAll(pkiDir, 0700); err != nil {
		return nil, fmt.Errorf("failed to create PKI directory: %w", err)
	}

	nodeKeyPath := filepath.Join(pkiDir, "node.key")
	nodeCSRPath := filepath.Join(pkiDir, "node.csr")
	nodeCertPath := filepath.Join(pkiDir, "node.crt")
	caCertPath := filepath.Join(pkiDir, "ca.crt")

	log.Printf("Generating node key and CSR...")
	if err := pki.GenerateCertificateRequest(nodeName, nodeKeyPath, nodeCSRPath); err != nil {
		return nil, fmt.Errorf("failed to generate key and CSR: %w", err)
	}

	csrData, err := os.ReadFile(nodeCSRPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read CSR file: %w", err)
	}

	joinReq := JoinRequest{
		NodeName:        nodeName,
		AdvertiseAddr:   advertiseAddr,
		CSRData:         base64.StdEncoding.EncodeToString(csrData),
		WireGuardPubKey: "placeholder", // Will be implemented in a future phase
	}
	reqBody, err := json.Marshal(joinReq)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal join request: %w", err)
	}

	// Build the TLS policy once, then a single client around it.
	tlsConfig := &tls.Config{}
	if leaderCACert != "" {
		caPEM, err := os.ReadFile(leaderCACert)
		if err != nil {
			return nil, fmt.Errorf("failed to read leader CA certificate: %w", err)
		}
		pool := x509.NewCertPool()
		if !pool.AppendCertsFromPEM(caPEM) {
			return nil, fmt.Errorf("failed to parse leader CA certificate %q: no PEM certificates found", leaderCACert)
		}
		tlsConfig.RootCAs = pool
	} else {
		// For Phase 2 development, allow insecure connections.
		// This should be removed in production.
		log.Println("WARNING: No leader CA certificate provided. TLS verification disabled (Phase 2 development mode).")
		log.Println("This is expected for the initial join process in Phase 2.")
		tlsConfig.InsecureSkipVerify = true
	}
	client := &http.Client{
		Timeout:   30 * time.Second,
		Transport: &http.Transport{TLSClientConfig: tlsConfig},
	}

	joinURL := fmt.Sprintf("https://%s/internal/v1alpha1/join", leaderAPI)
	log.Printf("Sending join request to %s...", joinURL)
	resp, err := client.Post(joinURL, "application/json", bytes.NewReader(reqBody))
	if err != nil {
		return nil, fmt.Errorf("failed to send join request: %w", err)
	}
	defer resp.Body.Close()

	// Read one byte past the cap so an oversize body is detected rather than
	// silently truncated into a confusing JSON error.
	respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxJoinResponseBytes+1))
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}
	if len(respBody) > maxJoinResponseBytes {
		return nil, fmt.Errorf("join response exceeds %d bytes", maxJoinResponseBytes)
	}

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("join request failed with status %d: %s", resp.StatusCode, string(respBody))
	}

	var joinResp JoinResponse
	if err := json.Unmarshal(respBody, &joinResp); err != nil {
		return nil, fmt.Errorf("failed to parse join response: %w", err)
	}

	// An empty field would otherwise decode to zero bytes and produce an empty,
	// unusable certificate file on disk.
	if joinResp.SignedCertificate == "" {
		return nil, fmt.Errorf("join response is missing the signed certificate")
	}
	if joinResp.CACertificate == "" {
		return nil, fmt.Errorf("join response is missing the CA certificate")
	}

	// Decode both payloads before writing either file, so a malformed CA
	// certificate cannot leave a freshly written node.crt next to a stale ca.crt.
	certData, err := base64.StdEncoding.DecodeString(joinResp.SignedCertificate)
	if err != nil {
		return nil, fmt.Errorf("failed to decode signed certificate: %w", err)
	}
	caCertData, err := base64.StdEncoding.DecodeString(joinResp.CACertificate)
	if err != nil {
		return nil, fmt.Errorf("failed to decode CA certificate: %w", err)
	}

	if err := os.WriteFile(nodeCertPath, certData, 0600); err != nil {
		return nil, fmt.Errorf("failed to save signed certificate: %w", err)
	}
	log.Printf("Saved signed certificate to %s", nodeCertPath)

	if err := os.WriteFile(caCertPath, caCertData, 0600); err != nil {
		return nil, fmt.Errorf("failed to save CA certificate: %w", err)
	}
	log.Printf("Saved CA certificate to %s", caCertPath)

	log.Printf("Successfully joined cluster as node: %s", joinResp.NodeName)
	if joinResp.AssignedSubnet != "" {
		log.Printf("Assigned subnet: %s", joinResp.AssignedSubnet)
	}
	if joinResp.EtcdJoinInstructions != "" {
		log.Printf("Etcd join instructions: %s", joinResp.EtcdJoinInstructions)
	}

	return &joinResp, nil
}