kat/internal/cli/join.go

169 lines
5.3 KiB
Go

package cli
import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"time"
"git.dws.rip/dubey/kat/internal/pki"
)
// JoinRequest is the JSON payload a node sends to the leader when it asks to
// join the cluster.
type JoinRequest struct {
	// NodeName is the name of the joining node.
	NodeName string `json:"nodeName"`
	// AdvertiseAddr is the advertise address supplied by the joining node.
	AdvertiseAddr string `json:"advertiseAddr"`
	// CSRData carries the node's certificate signing request, base64 encoded.
	CSRData string `json:"csrData"`
	// WireGuardPubKey is serialized as "wireguardPubKey".
	WireGuardPubKey string `json:"wireguardPubKey"`
}
// JoinResponse is the JSON payload the leader returns after a successful join.
type JoinResponse struct {
	// NodeName is the node name echoed back by the leader.
	NodeName string `json:"nodeName"`
	// NodeUID is serialized as "nodeUID".
	NodeUID string `json:"nodeUID"`
	// SignedCertificate is the node's signed certificate, base64 encoded.
	SignedCertificate string `json:"signedCertificate"`
	// CACertificate is the CA certificate, base64 encoded.
	CACertificate string `json:"caCertificate"`
	// AssignedSubnet is the subnet assigned to the node; it may be empty.
	AssignedSubnet string `json:"assignedSubnet"`
	// EtcdJoinInstructions is optional and omitted from JSON when empty.
	EtcdJoinInstructions string `json:"etcdJoinInstructions,omitempty"`
}
// JoinCluster registers this node with the cluster leader.
//
// It creates pkiDir (mode 0700) if needed, generates a node key and CSR there,
// POSTs a JoinRequest to https://<leaderAPI>/internal/v1alpha1/join, and on a
// 200 response stores the returned node certificate (node.crt) and CA
// certificate (ca.crt) in pkiDir with mode 0600.
//
// Parameters:
//   - leaderAPI: host[:port] of the leader; the scheme and path are added here.
//   - advertiseAddr: address forwarded to the leader in the join request.
//   - nodeName: name passed to the CSR generator and sent to the leader.
//   - leaderCACert: path to a PEM file with the CA used to verify the leader.
//     When empty, TLS verification is disabled (Phase 2 development mode).
//   - pkiDir: directory that receives node.key, node.csr, node.crt and ca.crt.
//
// It returns the decoded JoinResponse. An error is returned when any local
// file operation fails, the request cannot be sent, the leader answers with a
// non-200 status, the reply is oversized or not valid JSON, or either
// certificate in the reply is missing or not valid base64. Both certificates
// are validated before either is written, so a malformed reply does not leave
// a partially populated pkiDir.
func JoinCluster(leaderAPI, advertiseAddr, nodeName, leaderCACert string, pkiDir string) (*JoinResponse, error) {
	// Upper bound on the leader's reply. The peer may be unverified (see the
	// insecure branch in newJoinHTTPClient), so never buffer an unbounded body.
	const maxJoinResponseBytes = 1 << 20

	// Create PKI directory if it doesn't exist.
	if err := os.MkdirAll(pkiDir, 0700); err != nil {
		return nil, fmt.Errorf("failed to create PKI directory: %w", err)
	}

	nodeKeyPath := filepath.Join(pkiDir, "node.key")
	nodeCSRPath := filepath.Join(pkiDir, "node.csr")
	nodeCertPath := filepath.Join(pkiDir, "node.crt")
	caCertPath := filepath.Join(pkiDir, "ca.crt")

	// Generate key and CSR, then read the CSR back so it can be sent.
	log.Printf("Generating node key and CSR...")
	if err := pki.GenerateCertificateRequest(nodeName, nodeKeyPath, nodeCSRPath); err != nil {
		return nil, fmt.Errorf("failed to generate key and CSR: %w", err)
	}
	csrData, err := os.ReadFile(nodeCSRPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read CSR file: %w", err)
	}

	reqBody, err := json.Marshal(JoinRequest{
		NodeName:        nodeName,
		AdvertiseAddr:   advertiseAddr,
		CSRData:         base64.StdEncoding.EncodeToString(csrData),
		WireGuardPubKey: "placeholder", // Will be implemented in a future phase
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal join request: %w", err)
	}

	client, err := newJoinHTTPClient(leaderCACert)
	if err != nil {
		return nil, err
	}
	// The client owns a private Transport with no idle timeout; release its
	// pooled connection instead of leaving it (and its goroutines) behind.
	defer client.CloseIdleConnections()

	// Send join request to leader.
	joinURL := fmt.Sprintf("https://%s/internal/v1alpha1/join", leaderAPI)
	log.Printf("Sending join request to %s...", joinURL)
	resp, err := client.Post(joinURL, "application/json", bytes.NewReader(reqBody))
	if err != nil {
		return nil, fmt.Errorf("failed to send join request: %w", err)
	}
	defer resp.Body.Close()

	// Read one byte past the cap so an oversized reply is detected rather
	// than silently truncated.
	respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxJoinResponseBytes+1))
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}
	if len(respBody) > maxJoinResponseBytes {
		return nil, fmt.Errorf("join response exceeds %d bytes", maxJoinResponseBytes)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("join request failed with status %d: %s", resp.StatusCode, string(respBody))
	}

	var joinResp JoinResponse
	if err := json.Unmarshal(respBody, &joinResp); err != nil {
		return nil, fmt.Errorf("failed to parse join response: %w", err)
	}

	// Decode and validate both certificates before touching the disk.
	certData, err := base64.StdEncoding.DecodeString(joinResp.SignedCertificate)
	if err != nil {
		return nil, fmt.Errorf("failed to decode signed certificate: %w", err)
	}
	if len(certData) == 0 {
		return nil, fmt.Errorf("join response did not include a signed certificate")
	}
	caCertData, err := base64.StdEncoding.DecodeString(joinResp.CACertificate)
	if err != nil {
		return nil, fmt.Errorf("failed to decode CA certificate: %w", err)
	}
	if len(caCertData) == 0 {
		return nil, fmt.Errorf("join response did not include a CA certificate")
	}

	if err := os.WriteFile(nodeCertPath, certData, 0600); err != nil {
		return nil, fmt.Errorf("failed to save signed certificate: %w", err)
	}
	log.Printf("Saved signed certificate to %s", nodeCertPath)

	if err := os.WriteFile(caCertPath, caCertData, 0600); err != nil {
		return nil, fmt.Errorf("failed to save CA certificate: %w", err)
	}
	log.Printf("Saved CA certificate to %s", caCertPath)

	log.Printf("Successfully joined cluster as node: %s", joinResp.NodeName)
	if joinResp.AssignedSubnet != "" {
		log.Printf("Assigned subnet: %s", joinResp.AssignedSubnet)
	}
	if joinResp.EtcdJoinInstructions != "" {
		log.Printf("Etcd join instructions: %s", joinResp.EtcdJoinInstructions)
	}
	return &joinResp, nil
}

// newJoinHTTPClient builds the HTTP client used for the join request. With a
// non-empty leaderCACert path the leader is verified against the PEM
// certificates in that file; with an empty path verification is disabled and
// a warning is logged.
func newJoinHTTPClient(leaderCACert string) (*http.Client, error) {
	tlsConfig := &tls.Config{}

	if leaderCACert == "" {
		// For Phase 2 development, allow insecure connections.
		// This should be removed in production.
		log.Println("WARNING: No leader CA certificate provided. TLS verification disabled (Phase 2 development mode).")
		log.Println("This is expected for the initial join process in Phase 2.")
		tlsConfig.InsecureSkipVerify = true
	} else {
		caCert, err := os.ReadFile(leaderCACert)
		if err != nil {
			return nil, fmt.Errorf("failed to read leader CA certificate: %w", err)
		}
		caCertPool := x509.NewCertPool()
		if !caCertPool.AppendCertsFromPEM(caCert) {
			return nil, fmt.Errorf("failed to parse leader CA certificate")
		}
		tlsConfig.RootCAs = caCertPool
	}

	return &http.Client{
		Timeout:   30 * time.Second,
		Transport: &http.Transport{TLSClientConfig: tlsConfig},
	}, nil
}