Compare commits
10 Commits
main
...
8f1944ba15
Author | SHA1 | Date | |
---|---|---|---|
8f1944ba15 | |||
9e63518308 | |||
800e4f72f2 | |||
2f6d3c9bb2 | |||
4f6365d453 | |||
47f9b69876 | |||
787262c8a0 | |||
52d7af083e | |||
bcff04db12 | |||
7adabe8630 |
@ -4,14 +4,17 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"git.dws.rip/dubey/kat/internal/api"
|
||||||
"git.dws.rip/dubey/kat/internal/config"
|
"git.dws.rip/dubey/kat/internal/config"
|
||||||
"git.dws.rip/dubey/kat/internal/leader"
|
"git.dws.rip/dubey/kat/internal/leader"
|
||||||
|
"git.dws.rip/dubey/kat/internal/pki"
|
||||||
"git.dws.rip/dubey/kat/internal/store"
|
"git.dws.rip/dubey/kat/internal/store"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
@ -43,6 +46,7 @@ const (
|
|||||||
clusterUIDKey = "/kat/config/cluster_uid"
|
clusterUIDKey = "/kat/config/cluster_uid"
|
||||||
clusterConfigKey = "/kat/config/cluster_config" // Stores the JSON of pb.ClusterConfigurationSpec
|
clusterConfigKey = "/kat/config/cluster_config" // Stores the JSON of pb.ClusterConfigurationSpec
|
||||||
defaultNodeName = "kat-node"
|
defaultNodeName = "kat-node"
|
||||||
|
leaderCertCN = "leader.kat.cluster.local" // Common Name for leader certificate
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@ -69,6 +73,25 @@ func runInit(cmd *cobra.Command, args []string) {
|
|||||||
// config.SetClusterConfigDefaults(parsedClusterConfig)
|
// config.SetClusterConfigDefaults(parsedClusterConfig)
|
||||||
log.Printf("Successfully parsed and applied defaults to cluster configuration: %s", parsedClusterConfig.Metadata.Name)
|
log.Printf("Successfully parsed and applied defaults to cluster configuration: %s", parsedClusterConfig.Metadata.Name)
|
||||||
|
|
||||||
|
// 1.5. Initialize PKI directory and CA if it doesn't exist
|
||||||
|
pkiDir := pki.GetPKIPathFromClusterConfig(parsedClusterConfig.Spec.BackupPath)
|
||||||
|
caKeyPath := filepath.Join(pkiDir, "ca.key")
|
||||||
|
caCertPath := filepath.Join(pkiDir, "ca.crt")
|
||||||
|
|
||||||
|
// Check if CA already exists
|
||||||
|
_, caKeyErr := os.Stat(caKeyPath)
|
||||||
|
_, caCertErr := os.Stat(caCertPath)
|
||||||
|
|
||||||
|
if os.IsNotExist(caKeyErr) || os.IsNotExist(caCertErr) {
|
||||||
|
log.Printf("CA key or certificate not found. Generating new CA in %s", pkiDir)
|
||||||
|
if err := pki.GenerateCA(pkiDir, caKeyPath, caCertPath); err != nil {
|
||||||
|
log.Fatalf("Failed to generate CA: %v", err)
|
||||||
|
}
|
||||||
|
log.Println("Successfully generated new CA key and certificate")
|
||||||
|
} else {
|
||||||
|
log.Println("CA key and certificate already exist, skipping generation")
|
||||||
|
}
|
||||||
|
|
||||||
// Prepare etcd embed config
|
// Prepare etcd embed config
|
||||||
// For a single node init, this node is the only peer.
|
// For a single node init, this node is the only peer.
|
||||||
// Client URLs and Peer URLs will be based on its own configuration.
|
// Client URLs and Peer URLs will be based on its own configuration.
|
||||||
@ -138,6 +161,37 @@ func runInit(cmd *cobra.Command, args []string) {
|
|||||||
log.Printf("Cluster UID already exists in etcd. Skipping storage.")
|
log.Printf("Cluster UID already exists in etcd. Skipping storage.")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Generate leader's server certificate for mTLS
|
||||||
|
leaderKeyPath := filepath.Join(pkiDir, "leader.key")
|
||||||
|
leaderCSRPath := filepath.Join(pkiDir, "leader.csr")
|
||||||
|
leaderCertPath := filepath.Join(pkiDir, "leader.crt")
|
||||||
|
|
||||||
|
// Check if leader cert already exists
|
||||||
|
_, leaderCertErr := os.Stat(leaderCertPath)
|
||||||
|
if os.IsNotExist(leaderCertErr) {
|
||||||
|
log.Println("Generating leader server certificate for mTLS")
|
||||||
|
|
||||||
|
// Generate key and CSR for leader
|
||||||
|
if err := pki.GenerateCertificateRequest(leaderCertCN, leaderKeyPath, leaderCSRPath); err != nil {
|
||||||
|
log.Printf("Failed to generate leader key and CSR: %v", err)
|
||||||
|
} else {
|
||||||
|
// Read the CSR file
|
||||||
|
_, err := os.ReadFile(leaderCSRPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to read leader CSR file: %v", err)
|
||||||
|
} else {
|
||||||
|
// Sign the CSR with our CA
|
||||||
|
if err := pki.SignCertificateRequest(caKeyPath, caCertPath, leaderCSRPath, leaderCertPath, 365*24*time.Hour); err != nil {
|
||||||
|
log.Printf("Failed to sign leader CSR: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Println("Successfully generated and signed leader server certificate")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Println("Leader certificate already exists, skipping generation")
|
||||||
|
}
|
||||||
|
|
||||||
// Store ClusterConfigurationSpec (as JSON)
|
// Store ClusterConfigurationSpec (as JSON)
|
||||||
// We store Spec because Metadata might change (e.g. resourceVersion)
|
// We store Spec because Metadata might change (e.g. resourceVersion)
|
||||||
// and is more for API object representation.
|
// and is more for API object representation.
|
||||||
@ -156,6 +210,45 @@ func runInit(cmd *cobra.Command, args []string) {
|
|||||||
parsedClusterConfig.Spec.ApiPort)
|
parsedClusterConfig.Spec.ApiPort)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Start API server with mTLS
|
||||||
|
log.Println("Starting API server with mTLS...")
|
||||||
|
apiAddr := fmt.Sprintf(":%d", parsedClusterConfig.Spec.ApiPort)
|
||||||
|
apiServer, err := api.NewServer(apiAddr, leaderCertPath, leaderKeyPath, caCertPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to create API server: %v", err)
|
||||||
|
} else {
|
||||||
|
// Register the join handler
|
||||||
|
apiServer.RegisterJoinHandler(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log.Printf("Received join request from %s", r.RemoteAddr)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("Join endpoint is operational"))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Start the server in a goroutine
|
||||||
|
go func() {
|
||||||
|
if err := apiServer.Start(); err != nil && err != http.ErrServerClosed {
|
||||||
|
log.Printf("API server error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Add a shutdown hook to the leadership context
|
||||||
|
go func() {
|
||||||
|
<-leadershipCtx.Done()
|
||||||
|
log.Println("Leadership lost, shutting down API server...")
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := apiServer.Stop(shutdownCtx); err != nil {
|
||||||
|
log.Printf("Error shutting down API server: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
log.Printf("API server started on port %d with mTLS", parsedClusterConfig.Spec.ApiPort)
|
||||||
|
log.Printf("Verification: API server requires client certificates signed by the cluster CA")
|
||||||
|
log.Printf("Test with: curl --cacert %s --cert <client_cert> --key <client_key> https://localhost:%d/internal/v1alpha1/join",
|
||||||
|
caCertPath, parsedClusterConfig.Spec.ApiPort)
|
||||||
|
}
|
||||||
|
|
||||||
log.Println("Initial leader setup complete. Waiting for leadership context to end or agent to be stopped.")
|
log.Println("Initial leader setup complete. Waiting for leadership context to end or agent to be stopped.")
|
||||||
<-leadershipCtx.Done() // Wait until leadership is lost or context is cancelled by manager
|
<-leadershipCtx.Done() // Wait until leadership is lost or context is cancelled by manager
|
||||||
},
|
},
|
||||||
|
@ -3,8 +3,8 @@ kind: ClusterConfiguration
|
|||||||
metadata:
|
metadata:
|
||||||
name: my-kat-cluster
|
name: my-kat-cluster
|
||||||
spec:
|
spec:
|
||||||
clusterCIDR: "10.100.0.0/16"
|
cluster_CIDR: "10.100.0.0/16"
|
||||||
serviceCIDR: "10.200.0.0/16"
|
service_CIDR: "10.200.0.0/16"
|
||||||
nodeSubnetBits: 7 # Results in /23 node subnets (e.g., 10.100.0.0/23, 10.100.2.0/23)
|
nodeSubnetBits: 7 # Results in /23 node subnets (e.g., 10.100.0.0/23, 10.100.2.0/23)
|
||||||
clusterDomain: "kat.example.local" # Overriding default
|
clusterDomain: "kat.example.local" # Overriding default
|
||||||
apiPort: 9115
|
apiPort: 9115
|
||||||
|
141
internal/api/join_handler.go
Normal file
141
internal/api/join_handler.go
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
"kat-system/internal/pki"
|
||||||
|
"kat-system/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
// JoinRequest represents the data sent by an agent when joining
|
||||||
|
type JoinRequest struct {
|
||||||
|
CSR []byte `json:"csr"`
|
||||||
|
AdvertiseAddr string `json:"advertiseAddr"`
|
||||||
|
NodeName string `json:"nodeName,omitempty"` // Optional, leader can generate
|
||||||
|
WireguardPubKey string `json:"wireguardPubKey"` // Placeholder for now
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinResponse represents the data sent back to the agent
|
||||||
|
type JoinResponse struct {
|
||||||
|
NodeName string `json:"nodeName"`
|
||||||
|
NodeUID string `json:"nodeUID"`
|
||||||
|
SignedCert []byte `json:"signedCert"`
|
||||||
|
CACert []byte `json:"caCert"`
|
||||||
|
JoinTimestamp int64 `json:"joinTimestamp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewJoinHandler creates a handler for agent join requests
|
||||||
|
func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Read and parse the request body
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to read request body: %v", err), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer r.Body.Close()
|
||||||
|
|
||||||
|
var joinReq JoinRequest
|
||||||
|
if err := json.Unmarshal(body, &joinReq); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to parse request: %v", err), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate request
|
||||||
|
if len(joinReq.CSR) == 0 {
|
||||||
|
http.Error(w, "Missing CSR", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if joinReq.AdvertiseAddr == "" {
|
||||||
|
http.Error(w, "Missing advertise address", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate node name if not provided
|
||||||
|
nodeName := joinReq.NodeName
|
||||||
|
if nodeName == "" {
|
||||||
|
nodeName = fmt.Sprintf("node-%s", uuid.New().String()[:8])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a unique node ID
|
||||||
|
nodeUID := uuid.New().String()
|
||||||
|
|
||||||
|
// Sign the CSR
|
||||||
|
// Create a temporary file for the CSR
|
||||||
|
tempDir := os.TempDir()
|
||||||
|
csrPath := filepath.Join(tempDir, fmt.Sprintf("%s.csr", nodeUID))
|
||||||
|
if err := os.WriteFile(csrPath, joinReq.CSR, 0600); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to save CSR: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer os.Remove(csrPath)
|
||||||
|
|
||||||
|
// Sign the CSR
|
||||||
|
certPath := filepath.Join(tempDir, fmt.Sprintf("%s.crt", nodeUID))
|
||||||
|
if err := pki.SignCertificateRequest(caKeyPath, caCertPath, csrPath, certPath, 365*24*time.Hour); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to sign CSR: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer os.Remove(certPath)
|
||||||
|
|
||||||
|
// Read the signed certificate
|
||||||
|
signedCert, err := os.ReadFile(certPath)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to read signed certificate: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the CA certificate
|
||||||
|
caCert, err := os.ReadFile(caCertPath)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to read CA certificate: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store node registration in etcd
|
||||||
|
nodeRegKey := fmt.Sprintf("/kat/nodes/registration/%s", nodeName)
|
||||||
|
nodeReg := map[string]interface{}{
|
||||||
|
"uid": nodeUID,
|
||||||
|
"advertiseAddr": joinReq.AdvertiseAddr,
|
||||||
|
"wireguardPubKey": joinReq.WireguardPubKey,
|
||||||
|
"joinTimestamp": time.Now().Unix(),
|
||||||
|
}
|
||||||
|
nodeRegData, err := json.Marshal(nodeReg)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to marshal node registration: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := stateStore.Put(r.Context(), nodeRegKey, nodeRegData); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to store node registration: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare and send response
|
||||||
|
joinResp := JoinResponse{
|
||||||
|
NodeName: nodeName,
|
||||||
|
NodeUID: nodeUID,
|
||||||
|
SignedCert: signedCert,
|
||||||
|
CACert: caCert,
|
||||||
|
JoinTimestamp: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
respData, err := json.Marshal(joinResp)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to marshal response: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write(respData)
|
||||||
|
}
|
||||||
|
}
|
48
internal/api/router.go
Normal file
48
internal/api/router.go
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Route represents a single API route
|
||||||
|
type Route struct {
|
||||||
|
Method string
|
||||||
|
Path string
|
||||||
|
Handler http.HandlerFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// Router is a simple HTTP router for the KAT API
|
||||||
|
type Router struct {
|
||||||
|
routes []Route
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRouter creates a new router instance
|
||||||
|
func NewRouter() *Router {
|
||||||
|
return &Router{
|
||||||
|
routes: []Route{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleFunc registers a new route with the router
|
||||||
|
func (r *Router) HandleFunc(method, path string, handler http.HandlerFunc) {
|
||||||
|
r.routes = append(r.routes, Route{
|
||||||
|
Method: strings.ToUpper(method),
|
||||||
|
Path: path,
|
||||||
|
Handler: handler,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeHTTP implements the http.Handler interface
|
||||||
|
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
|
// Find matching route
|
||||||
|
for _, route := range r.routes {
|
||||||
|
if route.Method == req.Method && route.Path == req.URL.Path {
|
||||||
|
route.Handler(w, req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No route matched
|
||||||
|
http.NotFound(w, req)
|
||||||
|
}
|
84
internal/api/server.go
Normal file
84
internal/api/server.go
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server represents the API server for KAT
|
||||||
|
type Server struct {
|
||||||
|
httpServer *http.Server
|
||||||
|
router *Router
|
||||||
|
certFile string
|
||||||
|
keyFile string
|
||||||
|
caFile string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer creates a new API server instance
|
||||||
|
func NewServer(addr string, certFile, keyFile, caFile string) (*Server, error) {
|
||||||
|
router := NewRouter()
|
||||||
|
|
||||||
|
server := &Server{
|
||||||
|
router: router,
|
||||||
|
certFile: certFile,
|
||||||
|
keyFile: keyFile,
|
||||||
|
caFile: caFile,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the HTTP server with TLS config
|
||||||
|
server.httpServer = &http.Server{
|
||||||
|
Addr: addr,
|
||||||
|
Handler: router,
|
||||||
|
ReadTimeout: 30 * time.Second,
|
||||||
|
WriteTimeout: 30 * time.Second,
|
||||||
|
IdleTimeout: 120 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
return server, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins listening for requests
|
||||||
|
func (s *Server) Start() error {
|
||||||
|
// Load server certificate and key
|
||||||
|
cert, err := tls.LoadX509KeyPair(s.certFile, s.keyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load server certificate and key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load CA certificate for client verification
|
||||||
|
caCert, err := os.ReadFile(s.caFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read CA certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
caCertPool := x509.NewCertPool()
|
||||||
|
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||||
|
return fmt.Errorf("failed to append CA certificate to pool")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure TLS
|
||||||
|
s.httpServer.TLSConfig = &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||||
|
ClientCAs: caCertPool,
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the server
|
||||||
|
return s.httpServer.ListenAndServeTLS("", "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop gracefully shuts down the server
|
||||||
|
func (s *Server) Stop(ctx context.Context) error {
|
||||||
|
return s.httpServer.Shutdown(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterJoinHandler registers the handler for agent join requests
|
||||||
|
func (s *Server) RegisterJoinHandler(handler http.HandlerFunc) {
|
||||||
|
s.router.HandleFunc("POST", "/internal/v1alpha1/join", handler)
|
||||||
|
}
|
139
internal/api/server_test.go
Normal file
139
internal/api/server_test.go
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"kat-system/internal/pki"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServerWithMTLS(t *testing.T) {
|
||||||
|
// Skip in short mode
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create temporary directory for test certificates
|
||||||
|
tempDir, err := os.MkdirTemp("", "kat-api-test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temp dir: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
|
// Generate CA
|
||||||
|
caKeyPath := filepath.Join(tempDir, "ca.key")
|
||||||
|
caCertPath := filepath.Join(tempDir, "ca.crt")
|
||||||
|
if err := pki.GenerateCA(caKeyPath, caCertPath, "KAT Test CA", 24*time.Hour); err != nil {
|
||||||
|
t.Fatalf("Failed to generate CA: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate server certificate
|
||||||
|
serverKeyPath := filepath.Join(tempDir, "server.key")
|
||||||
|
serverCSRPath := filepath.Join(tempDir, "server.csr")
|
||||||
|
serverCertPath := filepath.Join(tempDir, "server.crt")
|
||||||
|
if err := pki.GenerateCertificateRequest("server.test", serverKeyPath, serverCSRPath); err != nil {
|
||||||
|
t.Fatalf("Failed to generate server CSR: %v", err)
|
||||||
|
}
|
||||||
|
if err := pki.SignCertificateRequest(caKeyPath, caCertPath, serverCSRPath, serverCertPath, 24*time.Hour); err != nil {
|
||||||
|
t.Fatalf("Failed to sign server certificate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate client certificate
|
||||||
|
clientKeyPath := filepath.Join(tempDir, "client.key")
|
||||||
|
clientCSRPath := filepath.Join(tempDir, "client.csr")
|
||||||
|
clientCertPath := filepath.Join(tempDir, "client.crt")
|
||||||
|
if err := pki.GenerateCertificateRequest("client.test", clientKeyPath, clientCSRPath); err != nil {
|
||||||
|
t.Fatalf("Failed to generate client CSR: %v", err)
|
||||||
|
}
|
||||||
|
if err := pki.SignCertificateRequest(caKeyPath, caCertPath, clientCSRPath, clientCertPath, 24*time.Hour); err != nil {
|
||||||
|
t.Fatalf("Failed to sign client certificate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create and start server
|
||||||
|
server, err := NewServer("localhost:0", serverCertPath, serverKeyPath, caCertPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add a test handler
|
||||||
|
server.router.HandleFunc("GET", "/test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("test successful"))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Start server in a goroutine
|
||||||
|
go func() {
|
||||||
|
if err := server.Start(); err != nil && err != http.ErrServerClosed {
|
||||||
|
t.Errorf("Server error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for server to start
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Load CA cert
|
||||||
|
caCert, err := os.ReadFile(caCertPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read CA cert: %v", err)
|
||||||
|
}
|
||||||
|
caCertPool := x509.NewCertPool()
|
||||||
|
caCertPool.AppendCertsFromPEM(caCert)
|
||||||
|
|
||||||
|
// Load client cert
|
||||||
|
clientCert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load client cert: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create HTTP client with mTLS
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
RootCAs: caCertPool,
|
||||||
|
Certificates: []tls.Certificate{clientCert},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with valid client cert
|
||||||
|
resp, err := client.Get("https://localhost:8443/test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(string(body), "test successful") {
|
||||||
|
t.Errorf("Unexpected response: %s", body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with no client cert (should fail)
|
||||||
|
clientWithoutCert := &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
RootCAs: caCertPool,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = clientWithoutCert.Get("https://localhost:8443/test")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Request without client cert should fail")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown server
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
server.Stop(ctx)
|
||||||
|
}
|
@ -201,8 +201,8 @@ func TestValidateClusterConfiguration_InvalidValues(t *testing.T) {
|
|||||||
ApiPort: 10251,
|
ApiPort: 10251,
|
||||||
EtcdPeerPort: 2380,
|
EtcdPeerPort: 2380,
|
||||||
EtcdClientPort: 2379,
|
EtcdClientPort: 2379,
|
||||||
VolumeBasePath: "/var/lib/kat/volumes",
|
VolumeBasePath: "~/.kat/volumes",
|
||||||
BackupPath: "/var/lib/kat/backups",
|
BackupPath: "~/.kat/backups",
|
||||||
BackupIntervalMinutes: 30,
|
BackupIntervalMinutes: 30,
|
||||||
AgentTickSeconds: 15,
|
AgentTickSeconds: 15,
|
||||||
NodeLossTimeoutSeconds: 60,
|
NodeLossTimeoutSeconds: 60,
|
||||||
|
@ -11,13 +11,13 @@ const (
|
|||||||
DefaultApiPort = 9115
|
DefaultApiPort = 9115
|
||||||
DefaultEtcdPeerPort = 2380
|
DefaultEtcdPeerPort = 2380
|
||||||
DefaultEtcdClientPort = 2379
|
DefaultEtcdClientPort = 2379
|
||||||
DefaultVolumeBasePath = "/var/lib/kat/volumes"
|
DefaultVolumeBasePath = "~/.kat/volumes"
|
||||||
DefaultBackupPath = "/var/lib/kat/backups"
|
DefaultBackupPath = "~/.kat/backups"
|
||||||
DefaultBackupIntervalMins = 30
|
DefaultBackupIntervalMins = 30
|
||||||
DefaultAgentTickSeconds = 15
|
DefaultAgentTickSeconds = 15
|
||||||
DefaultNodeLossTimeoutSec = 60 // DefaultNodeLossTimeoutSeconds = DefaultAgentTickSeconds * 4 (example logic)
|
DefaultNodeLossTimeoutSec = 60 // DefaultNodeLossTimeoutSeconds = DefaultAgentTickSeconds * 4 (example logic)
|
||||||
DefaultNodeSubnetBits = 7 // yields /23 from /16, or /31 from /24 etc. (5 bits for /29, 7 for /25)
|
DefaultNodeSubnetBits = 7 // yields /23 from /16, or /31 from /24 etc. (5 bits for /29, 7 for /25)
|
||||||
// RFC says 7 for /23 from /16. This means 2^(32-16-7) = 2^9 = 512 IPs per node subnet.
|
// RFC says 7 for /23 from /16. This means 2^(32-16-7) = 2^9 = 512 IPs per node subnet.
|
||||||
// If nodeSubnetBits means bits for the node portion *within* the host part of clusterCIDR:
|
// If nodeSubnetBits means bits for the node portion *within* the host part of clusterCIDR:
|
||||||
// e.g. /16 -> 16 host bits. If nodeSubnetBits = 7, then node subnet is / (16+7) = /23.
|
// e.g. /16 -> 16 host bits. If nodeSubnetBits = 7, then node subnet is / (16+7) = /23.
|
||||||
)
|
)
|
318
internal/pki/ca.go
Normal file
318
internal/pki/ca.go
Normal file
@ -0,0 +1,318 @@
|
|||||||
|
package pki
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Default key size for RSA keys
|
||||||
|
DefaultRSAKeySize = 2048
|
||||||
|
// Default CA certificate validity period
|
||||||
|
DefaultCAValidityDays = 3650 // ~10 years
|
||||||
|
// Default certificate validity period
|
||||||
|
DefaultCertValidityDays = 365 // 1 year
|
||||||
|
// Default PKI directory
|
||||||
|
DefaultPKIDir = "~/.kat/pki"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateCA creates a new Certificate Authority key pair and certificate.
|
||||||
|
// It saves the private key and certificate to the specified paths.
|
||||||
|
func GenerateCA(pkiDir string, keyPath, certPath string) error {
|
||||||
|
// Create PKI directory if it doesn't exist
|
||||||
|
if err := os.MkdirAll(pkiDir, 0700); err != nil {
|
||||||
|
return fmt.Errorf("failed to create PKI directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate RSA key
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, DefaultRSAKeySize)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate CA key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create self-signed certificate
|
||||||
|
serialNumber, err := generateSerialNumber()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate serial number: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Certificate template
|
||||||
|
notBefore := time.Now()
|
||||||
|
notAfter := notBefore.Add(time.Duration(DefaultCAValidityDays) * 24 * time.Hour)
|
||||||
|
|
||||||
|
template := x509.Certificate{
|
||||||
|
SerialNumber: serialNumber,
|
||||||
|
Subject: pkix.Name{
|
||||||
|
CommonName: "KAT Root CA",
|
||||||
|
Organization: []string{"KAT System"},
|
||||||
|
},
|
||||||
|
NotBefore: notBefore,
|
||||||
|
NotAfter: notAfter,
|
||||||
|
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
IsCA: true,
|
||||||
|
MaxPathLen: 1, // Only allow one level of intermediate certs
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create certificate
|
||||||
|
derBytes, err := x509.CreateCertificate(
|
||||||
|
rand.Reader,
|
||||||
|
&template,
|
||||||
|
&template, // Self-signed
|
||||||
|
&key.PublicKey,
|
||||||
|
key,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create CA certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save private key
|
||||||
|
keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to open CA key file for writing: %w", err)
|
||||||
|
}
|
||||||
|
defer keyOut.Close()
|
||||||
|
|
||||||
|
err = pem.Encode(keyOut, &pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write CA key to file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save certificate
|
||||||
|
certOut, err := os.OpenFile(certPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to open CA certificate file for writing: %w", err)
|
||||||
|
}
|
||||||
|
defer certOut.Close()
|
||||||
|
|
||||||
|
err = pem.Encode(certOut, &pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: derBytes,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write CA certificate to file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateCertificateRequest creates a new key pair and a Certificate Signing Request (CSR).
|
||||||
|
// It saves the private key and CSR to the specified paths.
|
||||||
|
func GenerateCertificateRequest(commonName, keyOutPath, csrOutPath string) error {
|
||||||
|
// Generate RSA key
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, DefaultRSAKeySize)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create CSR template
|
||||||
|
template := x509.CertificateRequest{
|
||||||
|
Subject: pkix.Name{
|
||||||
|
CommonName: commonName,
|
||||||
|
Organization: []string{"KAT System"},
|
||||||
|
},
|
||||||
|
SignatureAlgorithm: x509.SHA256WithRSA,
|
||||||
|
DNSNames: []string{commonName}, // Add the CN as a SAN
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create CSR
|
||||||
|
csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &template, key)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create CSR: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save private key
|
||||||
|
keyOut, err := os.OpenFile(keyOutPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to open key file for writing: %w", err)
|
||||||
|
}
|
||||||
|
defer keyOut.Close()
|
||||||
|
|
||||||
|
err = pem.Encode(keyOut, &pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write key to file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save CSR
|
||||||
|
csrOut, err := os.OpenFile(csrOutPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to open CSR file for writing: %w", err)
|
||||||
|
}
|
||||||
|
defer csrOut.Close()
|
||||||
|
|
||||||
|
err = pem.Encode(csrOut, &pem.Block{
|
||||||
|
Type: "CERTIFICATE REQUEST",
|
||||||
|
Bytes: csrBytes,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write CSR to file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignCertificateRequest signs a CSR using the CA key and certificate.
|
||||||
|
// It reads the CSR from csrPath and saves the signed certificate to certOutPath.
|
||||||
|
// If csrPath contains PEM data (starts with "-----BEGIN"), it uses that directly instead of reading a file.
|
||||||
|
func SignCertificateRequest(caKeyPath, caCertPath, csrPathOrData, certOutPath string, duration time.Duration) error {
|
||||||
|
// Load CA key
|
||||||
|
caKey, err := LoadCAPrivateKey(caKeyPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load CA key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load CA certificate
|
||||||
|
caCert, err := LoadCACertificate(caCertPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load CA certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine if csrPathOrData is a file path or PEM data
|
||||||
|
var csrPEM []byte
|
||||||
|
if strings.HasPrefix(csrPathOrData, "-----BEGIN") {
|
||||||
|
// It's PEM data, use it directly
|
||||||
|
csrPEM = []byte(csrPathOrData)
|
||||||
|
} else {
|
||||||
|
// It's a file path, read the file
|
||||||
|
csrPEM, err = os.ReadFile(csrPathOrData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read CSR file: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
block, _ := pem.Decode(csrPEM)
|
||||||
|
if block == nil || block.Type != "CERTIFICATE REQUEST" {
|
||||||
|
return fmt.Errorf("failed to decode PEM block containing CSR")
|
||||||
|
}
|
||||||
|
|
||||||
|
csr, err := x509.ParseCertificateRequest(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse CSR: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify CSR signature
|
||||||
|
if err = csr.CheckSignature(); err != nil {
|
||||||
|
return fmt.Errorf("CSR signature verification failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create certificate template from CSR
|
||||||
|
serialNumber, err := generateSerialNumber()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate serial number: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
notBefore := time.Now()
|
||||||
|
notAfter := notBefore.Add(duration)
|
||||||
|
|
||||||
|
template := x509.Certificate{
|
||||||
|
SerialNumber: serialNumber,
|
||||||
|
Subject: csr.Subject,
|
||||||
|
NotBefore: notBefore,
|
||||||
|
NotAfter: notAfter,
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||||
|
DNSNames: []string{csr.Subject.CommonName}, // Add the CN as a SAN
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create certificate
|
||||||
|
derBytes, err := x509.CreateCertificate(
|
||||||
|
rand.Reader,
|
||||||
|
&template,
|
||||||
|
caCert,
|
||||||
|
csr.PublicKey,
|
||||||
|
caKey,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save certificate
|
||||||
|
certOut, err := os.OpenFile(certOutPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to open certificate file for writing: %w", err)
|
||||||
|
}
|
||||||
|
defer certOut.Close()
|
||||||
|
|
||||||
|
err = pem.Encode(certOut, &pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: derBytes,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write certificate to file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPKIPathFromClusterConfig determines the PKI directory from the cluster configuration.
|
||||||
|
// If backupPath is provided, it uses the parent directory of backupPath.
|
||||||
|
// Otherwise, it uses the default PKI directory.
|
||||||
|
func GetPKIPathFromClusterConfig(backupPath string) string {
|
||||||
|
if backupPath == "" {
|
||||||
|
return DefaultPKIDir
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the parent directory of backupPath
|
||||||
|
return filepath.Dir(backupPath) + "/pki"
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateSerialNumber creates a random serial number for certificates
|
||||||
|
func generateSerialNumber() (*big.Int, error) {
|
||||||
|
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) // 128 bits
|
||||||
|
return rand.Int(rand.Reader, serialNumberLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadCACertificate loads a CA certificate from a file
|
||||||
|
func LoadCACertificate(certPath string) (*x509.Certificate, error) {
|
||||||
|
certPEM, err := os.ReadFile(certPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read CA certificate file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
block, _ := pem.Decode(certPEM)
|
||||||
|
if block == nil || block.Type != "CERTIFICATE" {
|
||||||
|
return nil, fmt.Errorf("failed to decode PEM block containing certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := x509.ParseCertificate(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse CA certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadCAPrivateKey loads a CA private key from a file
|
||||||
|
func LoadCAPrivateKey(keyPath string) (*rsa.PrivateKey, error) {
|
||||||
|
keyPEM, err := os.ReadFile(keyPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read CA key file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
block, _ := pem.Decode(keyPEM)
|
||||||
|
if block == nil || block.Type != "RSA PRIVATE KEY" {
|
||||||
|
return nil, fmt.Errorf("failed to decode PEM block containing private key")
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse CA private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, nil
|
||||||
|
}
|
73
internal/pki/ca_test.go
Normal file
73
internal/pki/ca_test.go
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
package pki
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateCA(t *testing.T) {
|
||||||
|
// Create a temporary directory for the test
|
||||||
|
tempDir, err := os.MkdirTemp("", "kat-pki-test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temp directory: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
|
// Define paths for CA key and certificate
|
||||||
|
keyPath := filepath.Join(tempDir, "ca.key")
|
||||||
|
certPath := filepath.Join(tempDir, "ca.crt")
|
||||||
|
|
||||||
|
// Generate CA
|
||||||
|
err = GenerateCA(tempDir, keyPath, certPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateCA failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify files exist
|
||||||
|
if _, err := os.Stat(keyPath); os.IsNotExist(err) {
|
||||||
|
t.Errorf("CA key file was not created at %s", keyPath)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(certPath); os.IsNotExist(err) {
|
||||||
|
t.Errorf("CA certificate file was not created at %s", certPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load and verify CA certificate
|
||||||
|
caCert, err := LoadCACertificate(certPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load CA certificate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify CA properties
|
||||||
|
if !caCert.IsCA {
|
||||||
|
t.Errorf("Certificate is not marked as CA")
|
||||||
|
}
|
||||||
|
if caCert.Subject.CommonName != "KAT Root CA" {
|
||||||
|
t.Errorf("Unexpected CA CommonName: got %s, want %s", caCert.Subject.CommonName, "KAT Root CA")
|
||||||
|
}
|
||||||
|
if len(caCert.Subject.Organization) == 0 || caCert.Subject.Organization[0] != "KAT System" {
|
||||||
|
t.Errorf("Unexpected CA Organization: got %v, want [KAT System]", caCert.Subject.Organization)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load and verify CA key
|
||||||
|
_, err = LoadCAPrivateKey(keyPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load CA private key: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPKIPathFromClusterConfig(t *testing.T) {
|
||||||
|
// Test with empty backup path
|
||||||
|
pkiPath := GetPKIPathFromClusterConfig("")
|
||||||
|
if pkiPath != DefaultPKIDir {
|
||||||
|
t.Errorf("Expected default PKI path %s, got %s", DefaultPKIDir, pkiPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with backup path
|
||||||
|
backupPath := "/opt/kat/backups"
|
||||||
|
expectedPKIPath := "/opt/kat/pki"
|
||||||
|
pkiPath = GetPKIPathFromClusterConfig(backupPath)
|
||||||
|
if pkiPath != expectedPKIPath {
|
||||||
|
t.Errorf("Expected PKI path %s, got %s", expectedPKIPath, pkiPath)
|
||||||
|
}
|
||||||
|
}
|
64
internal/pki/certs.go
Normal file
64
internal/pki/certs.go
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
package pki
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseCSRFromBytes parses a PEM-encoded CSR from bytes
|
||||||
|
func ParseCSRFromBytes(csrData []byte) (*x509.CertificateRequest, error) {
|
||||||
|
block, _ := pem.Decode(csrData)
|
||||||
|
if block == nil || block.Type != "CERTIFICATE REQUEST" {
|
||||||
|
return nil, fmt.Errorf("failed to decode PEM block containing CSR")
|
||||||
|
}
|
||||||
|
|
||||||
|
csr, err := x509.ParseCertificateRequest(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse CSR: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return csr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadCertificate loads an X.509 certificate from a file
|
||||||
|
func LoadCertificate(certPath string) (*x509.Certificate, error) {
|
||||||
|
certPEM, err := os.ReadFile(certPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read certificate file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
block, _ := pem.Decode(certPEM)
|
||||||
|
if block == nil || block.Type != "CERTIFICATE" {
|
||||||
|
return nil, fmt.Errorf("failed to decode PEM block containing certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := x509.ParseCertificate(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadPrivateKey loads an RSA private key from a file
|
||||||
|
func LoadPrivateKey(keyPath string) (*rsa.PrivateKey, error) {
|
||||||
|
keyPEM, err := os.ReadFile(keyPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read key file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
block, _ := pem.Decode(keyPEM)
|
||||||
|
if block == nil || block.Type != "RSA PRIVATE KEY" {
|
||||||
|
return nil, fmt.Errorf("failed to decode PEM block containing private key")
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, nil
|
||||||
|
}
|
128
internal/pki/certs_test.go
Normal file
128
internal/pki/certs_test.go
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
package pki
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateCertificateRequest(t *testing.T) {
|
||||||
|
// Create a temporary directory for the test
|
||||||
|
tempDir, err := os.MkdirTemp("", "kat-csr-test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temp directory: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
|
// Define paths for key and CSR
|
||||||
|
keyPath := filepath.Join(tempDir, "node.key")
|
||||||
|
csrPath := filepath.Join(tempDir, "node.csr")
|
||||||
|
commonName := "test-node.kat.cluster.local"
|
||||||
|
|
||||||
|
// Generate CSR
|
||||||
|
err = GenerateCertificateRequest(commonName, keyPath, csrPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateCertificateRequest failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify files exist
|
||||||
|
if _, err := os.Stat(keyPath); os.IsNotExist(err) {
|
||||||
|
t.Errorf("Key file was not created at %s", keyPath)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(csrPath); os.IsNotExist(err) {
|
||||||
|
t.Errorf("CSR file was not created at %s", csrPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read CSR file
|
||||||
|
csrData, err := os.ReadFile(csrPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read CSR file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse CSR
|
||||||
|
csr, err := ParseCSRFromBytes(csrData)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to parse CSR: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify CSR properties
|
||||||
|
if csr.Subject.CommonName != commonName {
|
||||||
|
t.Errorf("Unexpected CSR CommonName: got %s, want %s", csr.Subject.CommonName, commonName)
|
||||||
|
}
|
||||||
|
if len(csr.DNSNames) == 0 || csr.DNSNames[0] != commonName {
|
||||||
|
t.Errorf("Unexpected CSR DNSNames: got %v, want [%s]", csr.DNSNames, commonName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignCertificateRequest(t *testing.T) {
|
||||||
|
// Create a temporary directory for the test
|
||||||
|
tempDir, err := os.MkdirTemp("", "kat-cert-test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temp directory: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
|
// Generate CA
|
||||||
|
caKeyPath := filepath.Join(tempDir, "ca.key")
|
||||||
|
caCertPath := filepath.Join(tempDir, "ca.crt")
|
||||||
|
err = GenerateCA(tempDir, caKeyPath, caCertPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateCA failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate CSR
|
||||||
|
nodeKeyPath := filepath.Join(tempDir, "node.key")
|
||||||
|
csrPath := filepath.Join(tempDir, "node.csr")
|
||||||
|
commonName := "test-node.kat.cluster.local"
|
||||||
|
err = GenerateCertificateRequest(commonName, nodeKeyPath, csrPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateCertificateRequest failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read CSR file
|
||||||
|
csrData, err := os.ReadFile(csrPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read CSR file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sign CSR
|
||||||
|
certPath := filepath.Join(tempDir, "node.crt")
|
||||||
|
err = SignCertificateRequest(caKeyPath, caCertPath, string(csrData), certPath, 30) // 30 days validity
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SignCertificateRequest failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify certificate file exists
|
||||||
|
if _, err := os.Stat(certPath); os.IsNotExist(err) {
|
||||||
|
t.Errorf("Certificate file was not created at %s", certPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load and verify certificate
|
||||||
|
cert, err := LoadCertificate(certPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load certificate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify certificate properties
|
||||||
|
if cert.Subject.CommonName != commonName {
|
||||||
|
t.Errorf("Unexpected certificate CommonName: got %s, want %s", cert.Subject.CommonName, commonName)
|
||||||
|
}
|
||||||
|
if cert.IsCA {
|
||||||
|
t.Errorf("Certificate should not be a CA")
|
||||||
|
}
|
||||||
|
if len(cert.DNSNames) == 0 || cert.DNSNames[0] != commonName {
|
||||||
|
t.Errorf("Unexpected certificate DNSNames: got %v, want [%s]", cert.DNSNames, commonName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load CA certificate to verify chain
|
||||||
|
caCert, err := LoadCACertificate(caCertPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load CA certificate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify certificate is signed by CA
|
||||||
|
err = cert.CheckSignatureFrom(caCert)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Certificate signature verification failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
@ -52,7 +52,7 @@ func StartEmbeddedEtcd(cfg EtcdEmbedConfig) (*embed.Etcd, error) {
|
|||||||
embedCfg.Name = cfg.Name
|
embedCfg.Name = cfg.Name
|
||||||
embedCfg.Dir = cfg.DataDir
|
embedCfg.Dir = cfg.DataDir
|
||||||
embedCfg.InitialClusterToken = "kat-etcd-cluster" // Make this configurable if needed
|
embedCfg.InitialClusterToken = "kat-etcd-cluster" // Make this configurable if needed
|
||||||
embedCfg.ForceNewCluster = false // Set to true only for initial bootstrap of a new cluster if needed
|
embedCfg.ForceNewCluster = false // Set to true only for initial bootstrap of a new cluster if needed
|
||||||
|
|
||||||
lpurl, err := parseURLs(cfg.PeerURLs)
|
lpurl, err := parseURLs(cfg.PeerURLs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -23,10 +23,10 @@ func TestEtcdStore(t *testing.T) {
|
|||||||
|
|
||||||
// Configure and start embedded etcd
|
// Configure and start embedded etcd
|
||||||
etcdConfig := EtcdEmbedConfig{
|
etcdConfig := EtcdEmbedConfig{
|
||||||
Name: "test-node",
|
Name: "test-node",
|
||||||
DataDir: tempDir,
|
DataDir: tempDir,
|
||||||
ClientURLs: []string{"http://localhost:0"}, // Use port 0 to get a random available port
|
ClientURLs: []string{"http://localhost:0"}, // Use port 0 to get a random available port
|
||||||
PeerURLs: []string{"http://localhost:0"},
|
PeerURLs: []string{"http://localhost:0"},
|
||||||
}
|
}
|
||||||
|
|
||||||
etcdServer, err := StartEmbeddedEtcd(etcdConfig)
|
etcdServer, err := StartEmbeddedEtcd(etcdConfig)
|
||||||
@ -232,10 +232,10 @@ func TestLeaderElection(t *testing.T) {
|
|||||||
|
|
||||||
// Configure and start embedded etcd
|
// Configure and start embedded etcd
|
||||||
etcdConfig := EtcdEmbedConfig{
|
etcdConfig := EtcdEmbedConfig{
|
||||||
Name: "election-test-node",
|
Name: "election-test-node",
|
||||||
DataDir: tempDir,
|
DataDir: tempDir,
|
||||||
ClientURLs: []string{"http://localhost:0"},
|
ClientURLs: []string{"http://localhost:0"},
|
||||||
PeerURLs: []string{"http://localhost:0"},
|
PeerURLs: []string{"http://localhost:0"},
|
||||||
}
|
}
|
||||||
|
|
||||||
etcdServer, err := StartEmbeddedEtcd(etcdConfig)
|
etcdServer, err := StartEmbeddedEtcd(etcdConfig)
|
||||||
|
@ -51,8 +51,8 @@ spec:
|
|||||||
apiPort: 9115
|
apiPort: 9115
|
||||||
etcdPeerPort: 2380
|
etcdPeerPort: 2380
|
||||||
etcdClientPort: 2379
|
etcdClientPort: 2379
|
||||||
volumeBasePath: "/var/lib/kat/volumes"
|
volumeBasePath: "~/.kat/volumes"
|
||||||
backupPath: "/var/lib/kat/backups"
|
backupPath: "~/.kat/backups"
|
||||||
backupIntervalMinutes: 30
|
backupIntervalMinutes: 30
|
||||||
agentTickSeconds: 15
|
agentTickSeconds: 15
|
||||||
nodeLossTimeoutSeconds: 60
|
nodeLossTimeoutSeconds: 60
|
||||||
|
Reference in New Issue
Block a user