fix: correct error handling in JoinCluster function to return proper response

This commit is contained in:
Tanishq Dubey 2025-05-18 10:46:01 -04:00 committed by Tanishq Dubey (aider)
parent ee9d14be05
commit 0e50eaa407
No known key found for this signature in database
GPG Key ID: CFC1931B84DFC3F9

View File

@ -39,7 +39,7 @@ type JoinResponse struct {
func JoinCluster(leaderAPI, advertiseAddr, nodeName, leaderCACert string, pkiDir string) (*JoinResponse, error) { func JoinCluster(leaderAPI, advertiseAddr, nodeName, leaderCACert string, pkiDir string) (*JoinResponse, error) {
// Create PKI directory if it doesn't exist // Create PKI directory if it doesn't exist
if err := os.MkdirAll(pkiDir, 0700); err != nil { if err := os.MkdirAll(pkiDir, 0700); err != nil {
return fmt.Errorf("failed to create PKI directory: %w", err) return nil, fmt.Errorf("failed to create PKI directory: %w", err)
} }
// Generate key and CSR // Generate key and CSR
@ -50,13 +50,13 @@ func JoinCluster(leaderAPI, advertiseAddr, nodeName, leaderCACert string, pkiDir
log.Printf("Generating node key and CSR...") log.Printf("Generating node key and CSR...")
if err := pki.GenerateCertificateRequest(nodeName, nodeKeyPath, nodeCSRPath); err != nil { if err := pki.GenerateCertificateRequest(nodeName, nodeKeyPath, nodeCSRPath); err != nil {
return fmt.Errorf("failed to generate key and CSR: %w", err) return nil, fmt.Errorf("failed to generate key and CSR: %w", err)
} }
// Read the CSR file // Read the CSR file
csrData, err := os.ReadFile(nodeCSRPath) csrData, err := os.ReadFile(nodeCSRPath)
if err != nil { if err != nil {
return fmt.Errorf("failed to read CSR file: %w", err) return nil, fmt.Errorf("failed to read CSR file: %w", err)
} }
// Create join request // Create join request
@ -70,7 +70,7 @@ func JoinCluster(leaderAPI, advertiseAddr, nodeName, leaderCACert string, pkiDir
// Marshal request to JSON // Marshal request to JSON
reqBody, err := json.Marshal(joinReq) reqBody, err := json.Marshal(joinReq)
if err != nil { if err != nil {
return fmt.Errorf("failed to marshal join request: %w", err) return nil, fmt.Errorf("failed to marshal join request: %w", err)
} }
// Create HTTP client with TLS configuration // Create HTTP client with TLS configuration
@ -83,13 +83,13 @@ func JoinCluster(leaderAPI, advertiseAddr, nodeName, leaderCACert string, pkiDir
// Read the CA cert file // Read the CA cert file
caCert, err := os.ReadFile(leaderCACert) caCert, err := os.ReadFile(leaderCACert)
if err != nil { if err != nil {
return fmt.Errorf("failed to read leader CA certificate: %w", err) return nil, fmt.Errorf("failed to read leader CA certificate: %w", err)
} }
// Create a cert pool and add the CA cert // Create a cert pool and add the CA cert
caCertPool := x509.NewCertPool() caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(caCert) { if !caCertPool.AppendCertsFromPEM(caCert) {
return fmt.Errorf("failed to parse leader CA certificate") return nil, fmt.Errorf("failed to parse leader CA certificate")
} }
// Configure TLS // Configure TLS
@ -115,44 +115,44 @@ func JoinCluster(leaderAPI, advertiseAddr, nodeName, leaderCACert string, pkiDir
log.Printf("Sending join request to %s...", joinURL) log.Printf("Sending join request to %s...", joinURL)
resp, err := client.Post(joinURL, "application/json", bytes.NewBuffer(reqBody)) resp, err := client.Post(joinURL, "application/json", bytes.NewBuffer(reqBody))
if err != nil { if err != nil {
return fmt.Errorf("failed to send join request: %w", err) return nil, fmt.Errorf("failed to send join request: %w", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
// Read response body // Read response body
respBody, err := io.ReadAll(resp.Body) respBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return fmt.Errorf("failed to read response body: %w", err) return nil, fmt.Errorf("failed to read response body: %w", err)
} }
// Check response status // Check response status
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return fmt.Errorf("join request failed with status %d: %s", resp.StatusCode, string(respBody)) return nil, fmt.Errorf("join request failed with status %d: %s", resp.StatusCode, string(respBody))
} }
// Parse response // Parse response
var joinResp JoinResponse var joinResp JoinResponse
if err := json.Unmarshal(respBody, &joinResp); err != nil { if err := json.Unmarshal(respBody, &joinResp); err != nil {
return fmt.Errorf("failed to parse join response: %w", err) return nil, fmt.Errorf("failed to parse join response: %w", err)
} }
// Save signed certificate // Save signed certificate
certData, err := base64.StdEncoding.DecodeString(joinResp.SignedCertificate) certData, err := base64.StdEncoding.DecodeString(joinResp.SignedCertificate)
if err != nil { if err != nil {
return fmt.Errorf("failed to decode signed certificate: %w", err) return nil, fmt.Errorf("failed to decode signed certificate: %w", err)
} }
if err := os.WriteFile(nodeCertPath, certData, 0600); err != nil { if err := os.WriteFile(nodeCertPath, certData, 0600); err != nil {
return fmt.Errorf("failed to save signed certificate: %w", err) return nil, fmt.Errorf("failed to save signed certificate: %w", err)
} }
log.Printf("Saved signed certificate to %s", nodeCertPath) log.Printf("Saved signed certificate to %s", nodeCertPath)
// Save CA certificate // Save CA certificate
caCertData, err := base64.StdEncoding.DecodeString(joinResp.CACertificate) caCertData, err := base64.StdEncoding.DecodeString(joinResp.CACertificate)
if err != nil { if err != nil {
return fmt.Errorf("failed to decode CA certificate: %w", err) return nil, fmt.Errorf("failed to decode CA certificate: %w", err)
} }
if err := os.WriteFile(caCertPath, caCertData, 0600); err != nil { if err := os.WriteFile(caCertPath, caCertData, 0600); err != nil {
return fmt.Errorf("failed to save CA certificate: %w", err) return nil, fmt.Errorf("failed to save CA certificate: %w", err)
} }
log.Printf("Saved CA certificate to %s", caCertPath) log.Printf("Saved CA certificate to %s", caCertPath)