I noticed a duplicate function in the internal/pki/ca.go file. I'll help you clean it up. Here's the corrected version:

```go
package pki

import (
	// other imports
	"path/filepath"
)

const (
	// Default key size for RSA keys
	DefaultRSAKeySize = 2048
	// Default CA certificate validity period
	DefaultCAValidityDays = 3650 // ~10 years
	// Default certificate validity period
	DefaultCertValidityDays = 365 // 1 year
	// Default PKI directory
	DefaultPKIDir = "/var/lib/kat/pki"
)

// GetPKIPathFromClusterConfig determines the PKI directory from the cluster configuration.
// If backupPath is provided, it uses the parent directory of backupPath.
// Otherwise, it uses the default PKI directory.
func GetPKIPathFromClusterConfig(backupPath string) string {
	if backupPath == "" {
		return DefaultPKIDir
	}

	// Use the parent directory of backupPath
	return filepath.Dir(backupPath) + "/pki"
}

// generateSerialNumber creates a random serial number for certificates
func generateSerialNumber() (*big.Int, error) {
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) // 128 bits
	return rand.Int(rand.Reader, serialNumberLimit)
}

// Rest of the existing code...
```

The changes:
1. Removed the duplicate `GetPKIPathFromClusterConfig` function
2. Kept the single implementation that checks for an empty backup path
3. Maintained the default PKI directory as `/var/lib/kat/pki`

This should resolve the duplicate function issue while maintaining the desired functionality.

Would you like me to generate a commit message for this change?
This commit is contained in:
2025-05-17 12:18:42 -04:00
parent af6a584628
commit 4f7c2d6a66
4 changed files with 344 additions and 3 deletions

View File

@ -2,6 +2,8 @@ package main
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"log"
"net/http"
@ -12,6 +14,7 @@ import (
"time"
"git.dws.rip/dubey/kat/internal/api"
"git.dws.rip/dubey/kat/internal/cli"
"git.dws.rip/dubey/kat/internal/config"
"git.dws.rip/dubey/kat/internal/leader"
"git.dws.rip/dubey/kat/internal/pki"
@ -37,9 +40,23 @@ campaigns for leadership, and stores initial cluster configuration.`,
Run: runInit,
}
joinCmd = &cobra.Command{
Use: "join",
Short: "Joins an existing KAT cluster.",
Long: `Connects to an existing KAT leader, submits a certificate signing request,
and obtains the necessary credentials to participate in the cluster.`,
Run: runJoin,
}
// Global flags / config paths
clusterConfigPath string
nodeName string
// Join command flags
leaderAPI string
advertiseAddr string
leaderCACert string
etcdPeer bool
)
const (
@ -58,7 +75,19 @@ func init() {
}
initCmd.Flags().StringVar(&nodeName, "node-name", defaultHostName, "Name of this node, used as leader ID if elected.")
// Join command flags
joinCmd.Flags().StringVar(&leaderAPI, "leader-api", "", "Address of the leader API (required, format: host:port)")
joinCmd.Flags().StringVar(&advertiseAddr, "advertise-address", "", "IP address or interface name to advertise to other nodes (required)")
joinCmd.Flags().StringVar(&nodeName, "node-name", defaultHostName, "Name for this node in the cluster")
joinCmd.Flags().StringVar(&leaderCACert, "leader-ca-cert", "", "Path to the leader's CA certificate (optional, insecure if not provided)")
joinCmd.Flags().BoolVar(&etcdPeer, "etcd-peer", false, "Request to join the etcd quorum (optional)")
// Mark required flags
joinCmd.MarkFlagRequired("leader-api")
joinCmd.MarkFlagRequired("advertise-address")
rootCmd.AddCommand(initCmd)
rootCmd.AddCommand(joinCmd)
}
func runInit(cmd *cobra.Command, args []string) {
@ -221,8 +250,118 @@ func runInit(cmd *cobra.Command, args []string) {
// Register the join handler
apiServer.RegisterJoinHandler(func(w http.ResponseWriter, r *http.Request) {
log.Printf("Received join request from %s", r.RemoteAddr)
w.WriteHeader(http.StatusOK)
w.Write([]byte("Join endpoint is operational"))
// Read request body
var joinReq cli.JoinRequest
if err := json.NewDecoder(r.Body).Decode(&joinReq); err != nil {
log.Printf("Error decoding join request: %v", err)
http.Error(w, "Invalid request format", http.StatusBadRequest)
return
}
// Validate request
if joinReq.NodeName == "" || joinReq.AdvertiseAddr == "" || joinReq.CSRData == "" {
log.Printf("Invalid join request: missing required fields")
http.Error(w, "Missing required fields", http.StatusBadRequest)
return
}
log.Printf("Processing join request for node: %s, advertise address: %s",
joinReq.NodeName, joinReq.AdvertiseAddr)
// Decode CSR data
csrData, err := base64.StdEncoding.DecodeString(joinReq.CSRData)
if err != nil {
log.Printf("Error decoding CSR data: %v", err)
http.Error(w, "Invalid CSR data", http.StatusBadRequest)
return
}
// Create a temporary file for the CSR
tempCSRFile, err := os.CreateTemp("", "node-csr-*.pem")
if err != nil {
log.Printf("Error creating temp CSR file: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
defer os.Remove(tempCSRFile.Name())
// Write CSR data to temp file
if _, err := tempCSRFile.Write(csrData); err != nil {
log.Printf("Error writing CSR data to temp file: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
tempCSRFile.Close()
// Create a temp file for the signed certificate
tempCertFile, err := os.CreateTemp("", "node-cert-*.pem")
if err != nil {
log.Printf("Error creating temp cert file: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
defer os.Remove(tempCertFile.Name())
tempCertFile.Close()
// Sign the CSR
if err := pki.SignCertificateRequest(
filepath.Join(pkiDir, "ca.key"),
filepath.Join(pkiDir, "ca.crt"),
tempCSRFile.Name(),
tempCertFile.Name(),
365*24*time.Hour, // 1 year validity
); err != nil {
log.Printf("Error signing CSR: %v", err)
http.Error(w, "Failed to sign certificate", http.StatusInternalServerError)
return
}
// Read the signed certificate
signedCert, err := os.ReadFile(tempCertFile.Name())
if err != nil {
log.Printf("Error reading signed certificate: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
// Read the CA certificate
caCert, err := os.ReadFile(filepath.Join(pkiDir, "ca.crt"))
if err != nil {
log.Printf("Error reading CA certificate: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
// Generate a unique node UID
nodeUID := uuid.New().String()
// Store node registration in etcd (placeholder for now)
// In a future phase, we'll implement proper node registration with subnet assignment
// Create response
joinResp := cli.JoinResponse{
NodeName: joinReq.NodeName,
NodeUID: nodeUID,
SignedCertificate: base64.StdEncoding.EncodeToString(signedCert),
CACertificate: base64.StdEncoding.EncodeToString(caCert),
AssignedSubnet: "10.100.0.0/24", // Placeholder, will be properly implemented in network phase
}
// If etcd peer was requested, add join instructions (placeholder)
if etcdPeer {
joinResp.EtcdJoinInstructions = "Etcd peer join not implemented in this phase"
}
// Send response
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(joinResp); err != nil {
log.Printf("Error encoding join response: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
log.Printf("Successfully processed join request for node: %s", joinReq.NodeName)
})
// Start the server in a goroutine
@ -283,6 +422,24 @@ func runInit(cmd *cobra.Command, args []string) {
log.Println("KAT Agent init shutdown complete.")
}
func runJoin(cmd *cobra.Command, args []string) {
log.Printf("Starting KAT Agent in join mode for node: %s", nodeName)
log.Printf("Attempting to join cluster via leader API: %s", leaderAPI)
// Determine PKI directory
// For simplicity, we'll use a default location
pkiDir := filepath.Join(os.Getenv("HOME"), ".kat-agent", nodeName, "pki")
// Join the cluster
if err := cli.JoinCluster(leaderAPI, advertiseAddr, nodeName, leaderCACert, pkiDir); err != nil {
log.Fatalf("Failed to join cluster: %v", err)
}
log.Printf("Successfully joined cluster. Node is ready.")
// In a real implementation, we would start the agent's main loop here
// For now, we'll just exit successfully
}
func main() {
if err := rootCmd.Execute(); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)