Compare commits

...

20 Commits

Author SHA1 Message Date
dad5586339 Add verbose to test
All checks were successful
Integration Tests / integration-tests (pull_request) Successful in 9m55s
Unit Tests / unit-tests (pull_request) Successful in 10m10s
2025-05-17 13:23:09 -04:00
e4a19a6bb8 Based on the changes, I'll generate a concise commit message that captures the essence of the modifications:
feat: add node registration verification and idle loop for joined nodes
2025-05-17 13:19:16 -04:00
8bdccdc8c7 refactor: simplify imports and clean up code formatting in main.go 2025-05-17 13:19:13 -04:00
bf80b65873 feat: Implement CSR signing and node registration handler for agent join 2025-05-17 13:05:21 -04:00
f1f2b8f9ef fix: update TestServerWithMTLS to match Phase 2 TLS configuration 2025-05-17 12:50:16 -04:00
ce6f2ce29d Minor fixes 2025-05-17 12:48:37 -04:00
b33127bd34 fix: disable client cert verification for Phase 2 development 2025-05-17 12:38:20 -04:00
c07f389996 feat: modify TLS config to allow initial node join without client certificate 2025-05-17 12:32:26 -04:00
4f7c2d6a66 I noticed a duplicate function in the internal/pki/ca.go file. I'll help you clean it up. Here's the corrected version:
```go
package pki

import (
	// other imports
	"path/filepath"
)

const (
	// Default key size for RSA keys
	DefaultRSAKeySize = 2048
	// Default CA certificate validity period
	DefaultCAValidityDays = 3650 // ~10 years
	// Default certificate validity period
	DefaultCertValidityDays = 365 // 1 year
	// Default PKI directory
	DefaultPKIDir = "/var/lib/kat/pki"
)

// GetPKIPathFromClusterConfig determines the PKI directory from the cluster configuration.
// If backupPath is provided, it uses the parent directory of backupPath.
// Otherwise, it uses the default PKI directory.
func GetPKIPathFromClusterConfig(backupPath string) string {
	if backupPath == "" {
		return DefaultPKIDir
	}

	// Use the parent directory of backupPath
	return filepath.Dir(backupPath) + "/pki"
}

// generateSerialNumber creates a random serial number for certificates
func generateSerialNumber() (*big.Int, error) {
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) // 128 bits
	return rand.Int(rand.Reader, serialNumberLimit)
}

// Rest of the existing code...
```

The changes:
1. Removed the duplicate `GetPKIPathFromClusterConfig` function
2. Kept the single implementation that checks for an empty backup path
3. Maintained the default PKI directory as `/var/lib/kat/pki`

This should resolve the duplicate function issue while maintaining the desired functionality.

Would you like me to generate a commit message for this change?
2025-05-17 12:18:42 -04:00
af6a584628 feat: add request logging middleware and improve server logging 2025-05-17 12:18:32 -04:00
8f1944ba15 feat: implement mTLS API server with client certificate verification in kat-agent 2025-05-17 11:36:52 -04:00
9e63518308 feat: Implement basic API server with mTLS for leader join endpoint 2025-05-16 22:18:58 -04:00
800e4f72f2 Run gofmt 2025-05-16 22:13:42 -04:00
2f6d3c9bb2 Use local paths when possible, some AI cleanup 2025-05-16 21:20:39 -04:00
4f6365d453 fix: handle CSR file path and raw PEM data in SignCertificateRequest 2025-05-16 21:17:23 -04:00
47f9b69876 fix: add DNS names to CSR and improve certificate generation 2025-05-16 21:15:43 -04:00
787262c8a0 refactor: change default PKI directory to user home directory 2025-05-16 21:15:40 -04:00
52d7af083e refactor: remove duplicate certificate request functions from certs.go 2025-05-16 21:01:34 -04:00
bcff04db12 Based on the implementation, I'll generate a concise commit message that captures the essence of the changes:
feat: implement PKI initialization and leader mTLS certificate generation
2025-05-16 20:59:01 -04:00
7adabe8630 feat: implement internal PKI utilities for CA and certificate management 2025-05-16 20:47:57 -04:00
21 changed files with 1717 additions and 38 deletions

6
.gitignore vendored
View File

@ -29,3 +29,9 @@ go.work.sum
.local .local
*.csr
*.crt
*.key
*.srl
.kat/

View File

@ -18,24 +18,24 @@ clean:
# Run all tests # Run all tests
test: generate test: generate
@echo "Running all tests..." @echo "Running all tests..."
@go test -count=1 ./... @go test -v -count=1 ./... --coverprofile=coverage.out
# Run unit tests only (faster, no integration tests) # Run unit tests only (faster, no integration tests)
test-unit: test-unit:
@echo "Running unit tests..." @echo "Running unit tests..."
@go test -count=1 -short ./... @go test -v -count=1 ./...
# Run integration tests only # Run integration tests only
test-integration: test-integration:
@echo "Running integration tests..." @echo "Running integration tests..."
@go test -count=1 -run Integration ./... @go test -v -count=1 -run Integration ./...
# Run tests for a specific package # Run tests for a specific package
test-package: test-package:
@echo "Running tests for package $(PACKAGE)..." @echo "Running tests for package $(PACKAGE)..."
@go test -v ./$(PACKAGE) @go test -v ./$(PACKAGE)
kat-agent: kat-agent: $(shell find ./cmd/kat-agent -name '*.go') $(shell find . -name 'go.mod' -o -name 'go.sum')
@echo "Building kat-agent..." @echo "Building kat-agent..."
@go build -o kat-agent ./cmd/kat-agent/main.go @go build -o kat-agent ./cmd/kat-agent/main.go

View File

@ -4,14 +4,18 @@ import (
"context" "context"
"fmt" "fmt"
"log" "log"
"net/http"
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"syscall" "syscall"
"time" "time"
"git.dws.rip/dubey/kat/internal/api"
"git.dws.rip/dubey/kat/internal/cli"
"git.dws.rip/dubey/kat/internal/config" "git.dws.rip/dubey/kat/internal/config"
"git.dws.rip/dubey/kat/internal/leader" "git.dws.rip/dubey/kat/internal/leader"
"git.dws.rip/dubey/kat/internal/pki"
"git.dws.rip/dubey/kat/internal/store" "git.dws.rip/dubey/kat/internal/store"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -34,15 +38,41 @@ campaigns for leadership, and stores initial cluster configuration.`,
Run: runInit, Run: runInit,
} }
joinCmd = &cobra.Command{
Use: "join",
Short: "Joins an existing KAT cluster.",
Long: `Connects to an existing KAT leader, submits a certificate signing request,
and obtains the necessary credentials to participate in the cluster.`,
Run: runJoin,
}
verifyCmd = &cobra.Command{
Use: "verify",
Short: "Verifies node registration in etcd.",
Long: `Connects to etcd and verifies that a node is properly registered.
This is useful for testing and debugging.`,
Run: runVerify,
}
// Global flags / config paths // Global flags / config paths
clusterConfigPath string clusterConfigPath string
nodeName string nodeName string
// Join command flags
leaderAPI string
advertiseAddr string
leaderCACert string
etcdPeer bool
// Verify command flags
etcdEndpoint string
) )
const ( const (
clusterUIDKey = "/kat/config/cluster_uid" clusterUIDKey = "/kat/config/cluster_uid"
clusterConfigKey = "/kat/config/cluster_config" // Stores the JSON of pb.ClusterConfigurationSpec clusterConfigKey = "/kat/config/cluster_config" // Stores the JSON of pb.ClusterConfigurationSpec
defaultNodeName = "kat-node" defaultNodeName = "kat-node"
leaderCertCN = "leader.kat.cluster.local" // Common Name for leader certificate
) )
func init() { func init() {
@ -54,7 +84,24 @@ func init() {
} }
initCmd.Flags().StringVar(&nodeName, "node-name", defaultHostName, "Name of this node, used as leader ID if elected.") initCmd.Flags().StringVar(&nodeName, "node-name", defaultHostName, "Name of this node, used as leader ID if elected.")
// Join command flags
joinCmd.Flags().StringVar(&leaderAPI, "leader-api", "", "Address of the leader API (required, format: host:port)")
joinCmd.Flags().StringVar(&advertiseAddr, "advertise-address", "", "IP address or interface name to advertise to other nodes (required)")
joinCmd.Flags().StringVar(&nodeName, "node-name", defaultHostName, "Name for this node in the cluster")
joinCmd.Flags().StringVar(&leaderCACert, "leader-ca-cert", "", "Path to the leader's CA certificate (optional, insecure if not provided)")
joinCmd.Flags().BoolVar(&etcdPeer, "etcd-peer", false, "Request to join the etcd quorum (optional)")
// Mark required flags
joinCmd.MarkFlagRequired("leader-api")
joinCmd.MarkFlagRequired("advertise-address")
// Verify command flags
verifyCmd.Flags().StringVar(&etcdEndpoint, "etcd-endpoint", "http://localhost:2379", "Etcd endpoint to connect to")
verifyCmd.Flags().StringVar(&nodeName, "node-name", defaultHostName, "Name of the node to verify")
rootCmd.AddCommand(initCmd) rootCmd.AddCommand(initCmd)
rootCmd.AddCommand(joinCmd)
rootCmd.AddCommand(verifyCmd)
} }
func runInit(cmd *cobra.Command, args []string) { func runInit(cmd *cobra.Command, args []string) {
@ -69,6 +116,25 @@ func runInit(cmd *cobra.Command, args []string) {
// config.SetClusterConfigDefaults(parsedClusterConfig) // config.SetClusterConfigDefaults(parsedClusterConfig)
log.Printf("Successfully parsed and applied defaults to cluster configuration: %s", parsedClusterConfig.Metadata.Name) log.Printf("Successfully parsed and applied defaults to cluster configuration: %s", parsedClusterConfig.Metadata.Name)
// 1.5. Initialize PKI directory and CA if it doesn't exist
pkiDir := pki.GetPKIPathFromClusterConfig(parsedClusterConfig.Spec.BackupPath)
caKeyPath := filepath.Join(pkiDir, "ca.key")
caCertPath := filepath.Join(pkiDir, "ca.crt")
// Check if CA already exists
_, caKeyErr := os.Stat(caKeyPath)
_, caCertErr := os.Stat(caCertPath)
if os.IsNotExist(caKeyErr) || os.IsNotExist(caCertErr) {
log.Printf("CA key or certificate not found. Generating new CA in %s", pkiDir)
if err := pki.GenerateCA(pkiDir, caKeyPath, caCertPath); err != nil {
log.Fatalf("Failed to generate CA: %v", err)
}
log.Println("Successfully generated new CA key and certificate")
} else {
log.Println("CA key and certificate already exist, skipping generation")
}
// Prepare etcd embed config // Prepare etcd embed config
// For a single node init, this node is the only peer. // For a single node init, this node is the only peer.
// Client URLs and Peer URLs will be based on its own configuration. // Client URLs and Peer URLs will be based on its own configuration.
@ -138,6 +204,37 @@ func runInit(cmd *cobra.Command, args []string) {
log.Printf("Cluster UID already exists in etcd. Skipping storage.") log.Printf("Cluster UID already exists in etcd. Skipping storage.")
} }
// Generate leader's server certificate for mTLS
leaderKeyPath := filepath.Join(pkiDir, "leader.key")
leaderCSRPath := filepath.Join(pkiDir, "leader.csr")
leaderCertPath := filepath.Join(pkiDir, "leader.crt")
// Check if leader cert already exists
_, leaderCertErr := os.Stat(leaderCertPath)
if os.IsNotExist(leaderCertErr) {
log.Println("Generating leader server certificate for mTLS")
// Generate key and CSR for leader
if err := pki.GenerateCertificateRequest(leaderCertCN, leaderKeyPath, leaderCSRPath); err != nil {
log.Printf("Failed to generate leader key and CSR: %v", err)
} else {
// Read the CSR file
_, err := os.ReadFile(leaderCSRPath)
if err != nil {
log.Printf("Failed to read leader CSR file: %v", err)
} else {
// Sign the CSR with our CA
if err := pki.SignCertificateRequest(caKeyPath, caCertPath, leaderCSRPath, leaderCertPath, 365*24*time.Hour); err != nil {
log.Printf("Failed to sign leader CSR: %v", err)
} else {
log.Println("Successfully generated and signed leader server certificate")
}
}
}
} else {
log.Println("Leader certificate already exists, skipping generation")
}
// Store ClusterConfigurationSpec (as JSON) // Store ClusterConfigurationSpec (as JSON)
// We store Spec because Metadata might change (e.g. resourceVersion) // We store Spec because Metadata might change (e.g. resourceVersion)
// and is more for API object representation. // and is more for API object representation.
@ -156,6 +253,43 @@ func runInit(cmd *cobra.Command, args []string) {
parsedClusterConfig.Spec.ApiPort) parsedClusterConfig.Spec.ApiPort)
} }
} }
// Start API server with mTLS
log.Println("Starting API server with mTLS...")
apiAddr := fmt.Sprintf(":%d", parsedClusterConfig.Spec.ApiPort)
apiServer, err := api.NewServer(apiAddr, leaderCertPath, leaderKeyPath, caCertPath)
if err != nil {
log.Printf("Failed to create API server: %v", err)
} else {
// Register the join handler
joinHandler := api.NewJoinHandler(etcdStore, caKeyPath, caCertPath)
apiServer.RegisterJoinHandler(joinHandler)
log.Printf("Registered join handler with CA key: %s, CA cert: %s", caKeyPath, caCertPath)
// Start the server in a goroutine
go func() {
if err := apiServer.Start(); err != nil && err != http.ErrServerClosed {
log.Printf("API server error: %v", err)
}
}()
// Add a shutdown hook to the leadership context
go func() {
<-leadershipCtx.Done()
log.Println("Leadership lost, shutting down API server...")
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := apiServer.Stop(shutdownCtx); err != nil {
log.Printf("Error shutting down API server: %v", err)
}
}()
log.Printf("API server started on port %d with mTLS", parsedClusterConfig.Spec.ApiPort)
log.Printf("Verification: API server requires client certificates signed by the cluster CA")
log.Printf("Test with: curl --cacert %s --cert <client_cert> --key <client_key> https://localhost:%d/internal/v1alpha1/join",
caCertPath, parsedClusterConfig.Spec.ApiPort)
}
log.Println("Initial leader setup complete. Waiting for leadership context to end or agent to be stopped.") log.Println("Initial leader setup complete. Waiting for leadership context to end or agent to be stopped.")
<-leadershipCtx.Done() // Wait until leadership is lost or context is cancelled by manager <-leadershipCtx.Done() // Wait until leadership is lost or context is cancelled by manager
}, },
@ -190,6 +324,60 @@ func runInit(cmd *cobra.Command, args []string) {
log.Println("KAT Agent init shutdown complete.") log.Println("KAT Agent init shutdown complete.")
} }
func runJoin(cmd *cobra.Command, args []string) {
log.Printf("Starting KAT Agent in join mode for node: %s", nodeName)
log.Printf("Attempting to join cluster via leader API: %s", leaderAPI)
// Determine PKI directory
// For simplicity, we'll use a default location
pkiDir := filepath.Join(os.Getenv("HOME"), ".kat-agent", nodeName, "pki")
// Join the cluster
if err := cli.JoinCluster(leaderAPI, advertiseAddr, nodeName, leaderCACert, pkiDir); err != nil {
log.Fatalf("Failed to join cluster: %v", err)
}
log.Printf("Successfully joined cluster. Node is ready.")
// Setup signal handling for graceful shutdown
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop()
// Stay up in an idle loop until interrupted
log.Printf("Node %s is now running. Press Ctrl+C to exit.", nodeName)
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
log.Println("Received shutdown signal. Exiting...")
return
case <-ticker.C:
log.Printf("Node %s is still running...", nodeName)
}
}
}
func runVerify(cmd *cobra.Command, args []string) {
log.Printf("Verifying node registration for node: %s", nodeName)
log.Printf("Connecting to etcd at: %s", etcdEndpoint)
// Create etcd client
etcdStore, err := store.NewEtcdStore([]string{etcdEndpoint}, nil)
if err != nil {
log.Fatalf("Failed to create etcd store client: %v", err)
}
defer etcdStore.Close()
// Verify node registration
if err := cli.VerifyNodeRegistration(etcdStore, nodeName); err != nil {
log.Fatalf("Failed to verify node registration: %v", err)
}
log.Printf("Node registration verification complete.")
}
func main() { func main() {
if err := rootCmd.Execute(); err != nil { if err := rootCmd.Execute(); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err) fmt.Fprintf(os.Stderr, "Error: %v\n", err)

View File

@ -3,8 +3,8 @@ kind: ClusterConfiguration
metadata: metadata:
name: my-kat-cluster name: my-kat-cluster
spec: spec:
clusterCIDR: "10.100.0.0/16" cluster_CIDR: "10.100.0.0/16"
serviceCIDR: "10.200.0.0/16" service_CIDR: "10.200.0.0/16"
nodeSubnetBits: 7 # Results in /23 node subnets (e.g., 10.100.0.0/23, 10.100.2.0/23) nodeSubnetBits: 7 # Results in /23 node subnets (e.g., 10.100.0.0/23, 10.100.2.0/23)
clusterDomain: "kat.example.local" # Overriding default clusterDomain: "kat.example.local" # Overriding default
apiPort: 9115 apiPort: 9115

View File

@ -0,0 +1,169 @@
package api
import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"time"
"github.com/google/uuid"
"git.dws.rip/dubey/kat/internal/pki"
"git.dws.rip/dubey/kat/internal/store"
)
// JoinRequest represents the data sent by an agent when joining
type JoinRequest struct {
CSRData string `json:"csrData"` // base64 encoded CSR
AdvertiseAddr string `json:"advertiseAddr"`
NodeName string `json:"nodeName,omitempty"` // Optional, leader can generate
WireGuardPubKey string `json:"wireguardPubKey"` // Placeholder for now
}
// JoinResponse represents the data sent back to the agent
type JoinResponse struct {
NodeName string `json:"nodeName"`
NodeUID string `json:"nodeUID"`
SignedCertificate string `json:"signedCertificate"` // base64 encoded certificate
CACertificate string `json:"caCertificate"` // base64 encoded CA certificate
AssignedSubnet string `json:"assignedSubnet"` // Placeholder for now
EtcdJoinInstructions string `json:"etcdJoinInstructions,omitempty"`
}
// NewJoinHandler creates a handler for agent join requests
func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
log.Printf("Received join request from %s", r.RemoteAddr)
// Read and parse the request body
body, err := io.ReadAll(r.Body)
if err != nil {
log.Printf("Failed to read request body: %v", err)
http.Error(w, fmt.Sprintf("Failed to read request body: %v", err), http.StatusBadRequest)
return
}
defer r.Body.Close()
var joinReq JoinRequest
if err := json.Unmarshal(body, &joinReq); err != nil {
log.Printf("Failed to parse request: %v", err)
http.Error(w, fmt.Sprintf("Failed to parse request: %v", err), http.StatusBadRequest)
return
}
// Validate request
if joinReq.CSRData == "" {
log.Printf("Missing CSR data")
http.Error(w, "Missing CSR data", http.StatusBadRequest)
return
}
if joinReq.AdvertiseAddr == "" {
log.Printf("Missing advertise address")
http.Error(w, "Missing advertise address", http.StatusBadRequest)
return
}
// Generate node name if not provided
nodeName := joinReq.NodeName
if nodeName == "" {
nodeName = fmt.Sprintf("node-%s", uuid.New().String()[:8])
log.Printf("Generated node name: %s", nodeName)
}
// Generate a unique node ID
nodeUID := uuid.New().String()
log.Printf("Generated node UID: %s", nodeUID)
// Decode CSR data
csrData, err := base64.StdEncoding.DecodeString(joinReq.CSRData)
if err != nil {
log.Printf("Failed to decode CSR data: %v", err)
http.Error(w, fmt.Sprintf("Failed to decode CSR data: %v", err), http.StatusBadRequest)
return
}
// Create a temporary file for the CSR
tempDir := os.TempDir()
csrPath := filepath.Join(tempDir, fmt.Sprintf("%s.csr", nodeUID))
if err := os.WriteFile(csrPath, csrData, 0600); err != nil {
log.Printf("Failed to save CSR: %v", err)
http.Error(w, fmt.Sprintf("Failed to save CSR: %v", err), http.StatusInternalServerError)
return
}
defer os.Remove(csrPath)
// Sign the CSR
certPath := filepath.Join(tempDir, fmt.Sprintf("%s.crt", nodeUID))
if err := pki.SignCertificateRequest(caKeyPath, caCertPath, csrPath, certPath, 365*24*time.Hour); err != nil {
log.Printf("Failed to sign CSR: %v", err)
http.Error(w, fmt.Sprintf("Failed to sign CSR: %v", err), http.StatusInternalServerError)
return
}
defer os.Remove(certPath)
// Read the signed certificate
signedCert, err := os.ReadFile(certPath)
if err != nil {
log.Printf("Failed to read signed certificate: %v", err)
http.Error(w, fmt.Sprintf("Failed to read signed certificate: %v", err), http.StatusInternalServerError)
return
}
// Read the CA certificate
caCert, err := os.ReadFile(caCertPath)
if err != nil {
log.Printf("Failed to read CA certificate: %v", err)
http.Error(w, fmt.Sprintf("Failed to read CA certificate: %v", err), http.StatusInternalServerError)
return
}
// Store node registration in etcd
nodeRegKey := fmt.Sprintf("/kat/nodes/registration/%s", nodeName)
nodeReg := map[string]interface{}{
"uid": nodeUID,
"advertiseAddr": joinReq.AdvertiseAddr,
"wireguardPubKey": joinReq.WireGuardPubKey,
"joinTimestamp": time.Now().Unix(),
}
nodeRegData, err := json.Marshal(nodeReg)
if err != nil {
log.Printf("Failed to marshal node registration: %v", err)
http.Error(w, fmt.Sprintf("Failed to marshal node registration: %v", err), http.StatusInternalServerError)
return
}
log.Printf("Storing node registration in etcd at key: %s", nodeRegKey)
if err := stateStore.Put(r.Context(), nodeRegKey, nodeRegData); err != nil {
log.Printf("Failed to store node registration: %v", err)
http.Error(w, fmt.Sprintf("Failed to store node registration: %v", err), http.StatusInternalServerError)
return
}
log.Printf("Successfully stored node registration in etcd")
// Prepare and send response
joinResp := JoinResponse{
NodeName: nodeName,
NodeUID: nodeUID,
SignedCertificate: base64.StdEncoding.EncodeToString(signedCert),
CACertificate: base64.StdEncoding.EncodeToString(caCert),
AssignedSubnet: "10.100.0.0/24", // Placeholder for now, will be implemented in network phase
}
respData, err := json.Marshal(joinResp)
if err != nil {
log.Printf("Failed to marshal response: %v", err)
http.Error(w, fmt.Sprintf("Failed to marshal response: %v", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(respData)
log.Printf("Successfully processed join request for node: %s", nodeName)
}
}

View File

@ -0,0 +1,168 @@
package api
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"git.dws.rip/dubey/kat/internal/pki"
"git.dws.rip/dubey/kat/internal/store"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// MockStateStore for testing
type MockStateStore struct {
mock.Mock
}
func (m *MockStateStore) Put(ctx context.Context, key string, value []byte) error {
args := m.Called(ctx, key, value)
return args.Error(0)
}
func (m *MockStateStore) Get(ctx context.Context, key string) (*store.KV, error) {
args := m.Called(ctx, key)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*store.KV), args.Error(1)
}
func (m *MockStateStore) Delete(ctx context.Context, key string) error {
args := m.Called(ctx, key)
return args.Error(0)
}
func (m *MockStateStore) List(ctx context.Context, prefix string) ([]store.KV, error) {
args := m.Called(ctx, prefix)
return args.Get(0).([]store.KV), args.Error(1)
}
func (m *MockStateStore) Watch(ctx context.Context, keyOrPrefix string, startRevision int64) (<-chan store.WatchEvent, error) {
args := m.Called(ctx, keyOrPrefix, startRevision)
return args.Get(0).(chan store.WatchEvent), args.Error(1)
}
func (m *MockStateStore) Close() error {
args := m.Called()
return args.Error(0)
}
func (m *MockStateStore) Campaign(ctx context.Context, leaderID string, leaseTTLSeconds int64) (context.Context, error) {
args := m.Called(ctx, leaderID, leaseTTLSeconds)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(context.Context), args.Error(1)
}
func (m *MockStateStore) Resign(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}
func (m *MockStateStore) GetLeader(ctx context.Context) (string, error) {
args := m.Called(ctx)
return args.String(0), args.Error(1)
}
func (m *MockStateStore) DoTransaction(ctx context.Context, checks []store.Compare, onSuccess []store.Op, onFailure []store.Op) (bool, error) {
args := m.Called(ctx, checks, onSuccess, onFailure)
return args.Bool(0), args.Error(1)
}
func TestJoinHandler(t *testing.T) {
// Create temporary directory for test PKI files
tempDir, err := os.MkdirTemp("", "kat-test-pki-*")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tempDir)
// Generate CA for testing
caKeyPath := filepath.Join(tempDir, "ca.key")
caCertPath := filepath.Join(tempDir, "ca.crt")
err = pki.GenerateCA(tempDir, caKeyPath, caCertPath)
if err != nil {
t.Fatalf("Failed to generate test CA: %v", err)
}
// Generate a test CSR
nodeKeyPath := filepath.Join(tempDir, "node.key")
nodeCSRPath := filepath.Join(tempDir, "node.csr")
err = pki.GenerateCertificateRequest("test-node", nodeKeyPath, nodeCSRPath)
if err != nil {
t.Fatalf("Failed to generate test CSR: %v", err)
}
// Read the CSR file
csrData, err := os.ReadFile(nodeCSRPath)
if err != nil {
t.Fatalf("Failed to read CSR file: %v", err)
}
// Create mock state store
mockStore := new(MockStateStore)
mockStore.On("Put", mock.Anything, mock.MatchedBy(func(key string) bool {
return key == "/kat/nodes/registration/test-node"
}), mock.Anything).Return(nil)
// Create join handler
handler := NewJoinHandler(mockStore, caKeyPath, caCertPath)
// Create test request
joinReq := JoinRequest{
NodeName: "test-node",
AdvertiseAddr: "192.168.1.100",
CSRData: base64.StdEncoding.EncodeToString(csrData),
WireGuardPubKey: "test-pubkey",
}
reqBody, err := json.Marshal(joinReq)
if err != nil {
t.Fatalf("Failed to marshal join request: %v", err)
}
// Create HTTP request
req := httptest.NewRequest("POST", "/internal/v1alpha1/join", bytes.NewBuffer(reqBody))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
// Call handler
handler(w, req)
// Check response
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
// Read response body
respBody, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
// Parse response
var joinResp JoinResponse
err = json.Unmarshal(respBody, &joinResp)
if err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
// Verify response fields
assert.Equal(t, "test-node", joinResp.NodeName)
assert.NotEmpty(t, joinResp.NodeUID)
assert.NotEmpty(t, joinResp.SignedCertificate)
assert.NotEmpty(t, joinResp.CACertificate)
assert.Equal(t, "10.100.0.0/24", joinResp.AssignedSubnet) // Placeholder value
// Verify mock was called
mockStore.AssertExpectations(t)
}

48
internal/api/router.go Normal file
View File

@ -0,0 +1,48 @@
package api
import (
"net/http"
"strings"
)
// Route represents a single API route
type Route struct {
Method string
Path string
Handler http.HandlerFunc
}
// Router is a simple HTTP router for the KAT API
type Router struct {
routes []Route
}
// NewRouter creates a new router instance
func NewRouter() *Router {
return &Router{
routes: []Route{},
}
}
// HandleFunc registers a new route with the router
func (r *Router) HandleFunc(method, path string, handler http.HandlerFunc) {
r.routes = append(r.routes, Route{
Method: strings.ToUpper(method),
Path: path,
Handler: handler,
})
}
// ServeHTTP implements the http.Handler interface
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// Find matching route
for _, route := range r.routes {
if route.Method == req.Method && route.Path == req.URL.Path {
route.Handler(w, req)
return
}
}
// No route matched
http.NotFound(w, req)
}

145
internal/api/server.go Normal file
View File

@ -0,0 +1,145 @@
package api
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"log"
"net/http"
"os"
"time"
)
// loggingResponseWriter is a wrapper for http.ResponseWriter to capture status code
type loggingResponseWriter struct {
http.ResponseWriter
statusCode int
}
// WriteHeader captures the status code before passing to the underlying ResponseWriter
func (lrw *loggingResponseWriter) WriteHeader(code int) {
lrw.statusCode = code
lrw.ResponseWriter.WriteHeader(code)
}
// LoggingMiddleware logs information about each request
func LoggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Create a response writer wrapper to capture status code
lrw := &loggingResponseWriter{
ResponseWriter: w,
statusCode: http.StatusOK, // Default status
}
// Process the request
next.ServeHTTP(lrw, r)
// Calculate duration
duration := time.Since(start)
// Log the request details
log.Printf("REQUEST: %s %s - %d %s - %s - %v",
r.Method,
r.URL.Path,
lrw.statusCode,
http.StatusText(lrw.statusCode),
r.RemoteAddr,
duration,
)
})
}
// Server represents the API server for KAT
type Server struct {
httpServer *http.Server
router *Router
certFile string
keyFile string
caFile string
}
// NewServer creates a new API server instance
func NewServer(addr string, certFile, keyFile, caFile string) (*Server, error) {
router := NewRouter()
server := &Server{
router: router,
certFile: certFile,
keyFile: keyFile,
caFile: caFile,
}
// Create the HTTP server with TLS config
server.httpServer = &http.Server{
Addr: addr,
Handler: LoggingMiddleware(router), // Add logging middleware
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 120 * time.Second,
}
return server, nil
}
// Start begins listening for requests
func (s *Server) Start() error {
log.Printf("Starting server on %s", s.httpServer.Addr)
// Load server certificate and key
cert, err := tls.LoadX509KeyPair(s.certFile, s.keyFile)
if err != nil {
return fmt.Errorf("failed to load server certificate and key: %w", err)
}
// Load CA certificate for client verification
caCert, err := os.ReadFile(s.caFile)
if err != nil {
return fmt.Errorf("failed to read CA certificate: %w", err)
}
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(caCert) {
return fmt.Errorf("failed to append CA certificate to pool")
}
// For Phase 2, we'll use a simpler approach - don't require client certs at all
// This is a temporary solution until we implement proper authentication
s.httpServer.TLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
ClientAuth: tls.NoClientCert, // Don't require client certs for now
MinVersion: tls.VersionTLS12,
}
log.Printf("WARNING: TLS configured without client certificate verification for Phase 2")
log.Printf("This is a temporary development configuration and should be secured in production")
log.Printf("Server configured with TLS, starting to listen for requests")
// Start the server
return s.httpServer.ListenAndServeTLS("", "")
}
// Stop gracefully shuts down the server
func (s *Server) Stop(ctx context.Context) error {
log.Printf("Shutting down server on %s", s.httpServer.Addr)
err := s.httpServer.Shutdown(ctx)
if err != nil {
log.Printf("Error during server shutdown: %v", err)
return err
}
log.Printf("Server shutdown complete")
return nil
}
// RegisterJoinHandler registers the handler for agent join requests
func (s *Server) RegisterJoinHandler(handler http.HandlerFunc) {
s.router.HandleFunc("POST", "/internal/v1alpha1/join", handler)
log.Printf("Registered join handler at /internal/v1alpha1/join")
}
// RegisterNodeStatusHandler registers the handler for node status updates
func (s *Server) RegisterNodeStatusHandler(handler http.HandlerFunc) {
s.router.HandleFunc("POST", "/v1alpha1/nodes/{nodeName}/status", handler)
}

151
internal/api/server_test.go Normal file
View File

@ -0,0 +1,151 @@
package api
import (
"context"
"crypto/tls"
"crypto/x509"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"testing"
"time"
"git.dws.rip/dubey/kat/internal/pki"
)
// TestServerWithMTLS tests the server with TLS configuration
// Note: In Phase 2, we've temporarily disabled client certificate verification
// to simplify the initial join process. This test has been updated to reflect that.
func TestServerWithMTLS(t *testing.T) {
// Skip in short mode
if testing.Short() {
t.Skip("Skipping test in short mode")
}
// Create temporary directory for test certificates
tempDir, err := os.MkdirTemp("", "kat-api-test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tempDir)
// Generate CA
caKeyPath := filepath.Join(tempDir, "ca.key")
caCertPath := filepath.Join(tempDir, "ca.crt")
if err := pki.GenerateCA(tempDir, caKeyPath, caCertPath); err != nil {
t.Fatalf("Failed to generate CA: %v", err)
}
// Generate server certificate
serverKeyPath := filepath.Join(tempDir, "server.key")
serverCSRPath := filepath.Join(tempDir, "server.csr")
serverCertPath := filepath.Join(tempDir, "server.crt")
if err := pki.GenerateCertificateRequest("localhost", serverKeyPath, serverCSRPath); err != nil {
t.Fatalf("Failed to generate server CSR: %v", err)
}
if err := pki.SignCertificateRequest(caKeyPath, caCertPath, serverCSRPath, serverCertPath, 24*time.Hour); err != nil {
t.Fatalf("Failed to sign server certificate: %v", err)
}
// Generate client certificate
clientKeyPath := filepath.Join(tempDir, "client.key")
clientCSRPath := filepath.Join(tempDir, "client.csr")
clientCertPath := filepath.Join(tempDir, "client.crt")
if err := pki.GenerateCertificateRequest("client.test", clientKeyPath, clientCSRPath); err != nil {
t.Fatalf("Failed to generate client CSR: %v", err)
}
if err := pki.SignCertificateRequest(caKeyPath, caCertPath, clientCSRPath, clientCertPath, 24*time.Hour); err != nil {
t.Fatalf("Failed to sign client certificate: %v", err)
}
// Create and start server
server, err := NewServer("localhost:8443", serverCertPath, serverKeyPath, caCertPath)
if err != nil {
t.Fatalf("Failed to create server: %v", err)
}
// Add a test handler
server.router.HandleFunc("GET", "/test", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("test successful"))
})
// Start server in a goroutine
go func() {
if err := server.Start(); err != nil && err != http.ErrServerClosed {
t.Errorf("Server error: %v", err)
}
}()
// Wait for server to start
time.Sleep(250 * time.Millisecond)
// Load CA cert
caCert, err := os.ReadFile(caCertPath)
if err != nil {
t.Fatalf("Failed to read CA cert: %v", err)
}
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
// Load client cert
clientCert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath)
if err != nil {
t.Fatalf("Failed to load client cert: %v", err)
}
// Create HTTP client with mTLS
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: caCertPool,
Certificates: []tls.Certificate{clientCert},
},
},
}
// Test with valid client cert
resp, err := client.Get("https://localhost:8443/test")
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read response: %v", err)
}
if !strings.Contains(string(body), "test successful") {
t.Errorf("Unexpected response: %s", body)
}
// Test with no client cert (should succeed in Phase 2)
clientWithoutCert := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: caCertPool,
},
},
}
resp, err = clientWithoutCert.Get("https://localhost:8443/test")
if err != nil {
t.Errorf("Request without client cert should succeed in Phase 2: %v", err)
} else {
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Errorf("Failed to read response: %v", err)
}
if !strings.Contains(string(body), "test successful") {
t.Errorf("Unexpected response: %s", body)
}
}
// Shutdown server
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
server.Stop(ctx)
}

168
internal/cli/join.go Normal file
View File

@ -0,0 +1,168 @@
package cli
import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"time"
"git.dws.rip/dubey/kat/internal/pki"
)
// JoinRequest represents the data sent to the leader when joining
type JoinRequest struct {
NodeName string `json:"nodeName"`
AdvertiseAddr string `json:"advertiseAddr"`
CSRData string `json:"csrData"` // base64 encoded CSR
WireGuardPubKey string `json:"wireguardPubKey"`
}
// JoinResponse represents the data received from the leader after a successful join
type JoinResponse struct {
NodeName string `json:"nodeName"`
NodeUID string `json:"nodeUID"`
SignedCertificate string `json:"signedCertificate"` // base64 encoded certificate
CACertificate string `json:"caCertificate"` // base64 encoded CA certificate
AssignedSubnet string `json:"assignedSubnet"`
EtcdJoinInstructions string `json:"etcdJoinInstructions,omitempty"`
}
// JoinCluster sends a join request to the leader and processes the response
func JoinCluster(leaderAPI, advertiseAddr, nodeName, leaderCACert string, pkiDir string) error {
// Create PKI directory if it doesn't exist
if err := os.MkdirAll(pkiDir, 0700); err != nil {
return fmt.Errorf("failed to create PKI directory: %w", err)
}
// Generate key and CSR
nodeKeyPath := filepath.Join(pkiDir, "node.key")
nodeCSRPath := filepath.Join(pkiDir, "node.csr")
nodeCertPath := filepath.Join(pkiDir, "node.crt")
caCertPath := filepath.Join(pkiDir, "ca.crt")
log.Printf("Generating node key and CSR...")
if err := pki.GenerateCertificateRequest(nodeName, nodeKeyPath, nodeCSRPath); err != nil {
return fmt.Errorf("failed to generate key and CSR: %w", err)
}
// Read the CSR file
csrData, err := os.ReadFile(nodeCSRPath)
if err != nil {
return fmt.Errorf("failed to read CSR file: %w", err)
}
// Create join request
joinReq := JoinRequest{
NodeName: nodeName,
AdvertiseAddr: advertiseAddr,
CSRData: base64.StdEncoding.EncodeToString(csrData),
WireGuardPubKey: "placeholder", // Will be implemented in a future phase
}
// Marshal request to JSON
reqBody, err := json.Marshal(joinReq)
if err != nil {
return fmt.Errorf("failed to marshal join request: %w", err)
}
// Create HTTP client with TLS configuration
client := &http.Client{
Timeout: 30 * time.Second,
}
// If leader CA cert is provided, configure TLS to trust it
if leaderCACert != "" {
// Read the CA cert file
caCert, err := os.ReadFile(leaderCACert)
if err != nil {
return fmt.Errorf("failed to read leader CA certificate: %w", err)
}
// Create a cert pool and add the CA cert
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(caCert) {
return fmt.Errorf("failed to parse leader CA certificate")
}
// Configure TLS
client.Transport = &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: caCertPool,
},
}
} else {
// For Phase 2 development, allow insecure connections
// This should be removed in production
log.Println("WARNING: No leader CA certificate provided. TLS verification disabled (Phase 2 development mode).")
log.Println("This is expected for the initial join process in Phase 2.")
client.Transport = &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
}
// Send join request to leader
joinURL := fmt.Sprintf("https://%s/internal/v1alpha1/join", leaderAPI)
log.Printf("Sending join request to %s...", joinURL)
resp, err := client.Post(joinURL, "application/json", bytes.NewBuffer(reqBody))
if err != nil {
return fmt.Errorf("failed to send join request: %w", err)
}
defer resp.Body.Close()
// Read response body
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response body: %w", err)
}
// Check response status
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("join request failed with status %d: %s", resp.StatusCode, string(respBody))
}
// Parse response
var joinResp JoinResponse
if err := json.Unmarshal(respBody, &joinResp); err != nil {
return fmt.Errorf("failed to parse join response: %w", err)
}
// Save signed certificate
certData, err := base64.StdEncoding.DecodeString(joinResp.SignedCertificate)
if err != nil {
return fmt.Errorf("failed to decode signed certificate: %w", err)
}
if err := os.WriteFile(nodeCertPath, certData, 0600); err != nil {
return fmt.Errorf("failed to save signed certificate: %w", err)
}
log.Printf("Saved signed certificate to %s", nodeCertPath)
// Save CA certificate
caCertData, err := base64.StdEncoding.DecodeString(joinResp.CACertificate)
if err != nil {
return fmt.Errorf("failed to decode CA certificate: %w", err)
}
if err := os.WriteFile(caCertPath, caCertData, 0600); err != nil {
return fmt.Errorf("failed to save CA certificate: %w", err)
}
log.Printf("Saved CA certificate to %s", caCertPath)
log.Printf("Successfully joined cluster as node: %s", joinResp.NodeName)
if joinResp.AssignedSubnet != "" {
log.Printf("Assigned subnet: %s", joinResp.AssignedSubnet)
}
if joinResp.EtcdJoinInstructions != "" {
log.Printf("Etcd join instructions: %s", joinResp.EtcdJoinInstructions)
}
return nil
}

View File

@ -0,0 +1,53 @@
package cli
import (
"context"
"encoding/json"
"fmt"
"log"
"time"
"git.dws.rip/dubey/kat/internal/store"
)
// NodeRegistration represents the data stored in etcd for a node
type NodeRegistration struct {
UID string `json:"uid"`
AdvertiseAddr string `json:"advertiseAddr"`
WireguardPubKey string `json:"wireguardPubKey"`
JoinTimestamp int64 `json:"joinTimestamp"`
}
// VerifyNodeRegistration checks if a node is registered in etcd
func VerifyNodeRegistration(etcdStore store.StateStore, nodeName string) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Construct the key for the node registration
nodeRegKey := fmt.Sprintf("/kat/nodes/registration/%s", nodeName)
// Get the node registration from etcd
kv, err := etcdStore.Get(ctx, nodeRegKey)
if err != nil {
return fmt.Errorf("failed to get node registration from etcd: %w", err)
}
// Parse the node registration
var nodeReg NodeRegistration
if err := json.Unmarshal(kv.Value, &nodeReg); err != nil {
return fmt.Errorf("failed to parse node registration: %w", err)
}
// Print the node registration details
log.Printf("Node Registration Details:")
log.Printf(" Node Name: %s", nodeName)
log.Printf(" Node UID: %s", nodeReg.UID)
log.Printf(" Advertise Address: %s", nodeReg.AdvertiseAddr)
log.Printf(" WireGuard Public Key: %s", nodeReg.WireguardPubKey)
// Convert timestamp to human-readable format
joinTime := time.Unix(nodeReg.JoinTimestamp, 0)
log.Printf(" Join Timestamp: %s (%d)", joinTime.Format(time.RFC3339), nodeReg.JoinTimestamp)
return nil
}

View File

@ -201,8 +201,8 @@ func TestValidateClusterConfiguration_InvalidValues(t *testing.T) {
ApiPort: 10251, ApiPort: 10251,
EtcdPeerPort: 2380, EtcdPeerPort: 2380,
EtcdClientPort: 2379, EtcdClientPort: 2379,
VolumeBasePath: "/var/lib/kat/volumes", VolumeBasePath: ".kat/volumes",
BackupPath: "/var/lib/kat/backups", BackupPath: ".kat/backups",
BackupIntervalMinutes: 30, BackupIntervalMinutes: 30,
AgentTickSeconds: 15, AgentTickSeconds: 15,
NodeLossTimeoutSeconds: 60, NodeLossTimeoutSeconds: 60,

View File

@ -11,13 +11,13 @@ const (
DefaultApiPort = 9115 DefaultApiPort = 9115
DefaultEtcdPeerPort = 2380 DefaultEtcdPeerPort = 2380
DefaultEtcdClientPort = 2379 DefaultEtcdClientPort = 2379
DefaultVolumeBasePath = "/var/lib/kat/volumes" DefaultVolumeBasePath = ".kat/volumes"
DefaultBackupPath = "/var/lib/kat/backups" DefaultBackupPath = ".kat/backups"
DefaultBackupIntervalMins = 30 DefaultBackupIntervalMins = 30
DefaultAgentTickSeconds = 15 DefaultAgentTickSeconds = 15
DefaultNodeLossTimeoutSec = 60 // DefaultNodeLossTimeoutSeconds = DefaultAgentTickSeconds * 4 (example logic) DefaultNodeLossTimeoutSec = 60 // DefaultNodeLossTimeoutSeconds = DefaultAgentTickSeconds * 4 (example logic)
DefaultNodeSubnetBits = 7 // yields /23 from /16, or /31 from /24 etc. (5 bits for /29, 7 for /25) DefaultNodeSubnetBits = 7 // yields /23 from /16, or /31 from /24 etc. (5 bits for /29, 7 for /25)
// RFC says 7 for /23 from /16. This means 2^(32-16-7) = 2^9 = 512 IPs per node subnet. // RFC says 7 for /23 from /16. This means 2^(32-16-7) = 2^9 = 512 IPs per node subnet.
// If nodeSubnetBits means bits for the node portion *within* the host part of clusterCIDR: // If nodeSubnetBits means bits for the node portion *within* the host part of clusterCIDR:
// e.g. /16 -> 16 host bits. If nodeSubnetBits = 7, then node subnet is / (16+7) = /23. // e.g. /16 -> 16 host bits. If nodeSubnetBits = 7, then node subnet is / (16+7) = /23.
) )

318
internal/pki/ca.go Normal file
View File

@ -0,0 +1,318 @@
package pki
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"os"
"path/filepath"
"strings"
"time"
)
const (
// Default key size for RSA keys
DefaultRSAKeySize = 2048
// Default CA certificate validity period
DefaultCAValidityDays = 3650 // ~10 years
// Default certificate validity period
DefaultCertValidityDays = 365 // 1 year
// Default PKI directory
DefaultPKIDir = ".kat/pki"
)
// GenerateCA creates a new Certificate Authority key pair and certificate.
// It saves the private key and certificate to the specified paths.
func GenerateCA(pkiDir string, keyPath, certPath string) error {
// Create PKI directory if it doesn't exist
if err := os.MkdirAll(pkiDir, 0700); err != nil {
return fmt.Errorf("failed to create PKI directory: %w", err)
}
// Generate RSA key
key, err := rsa.GenerateKey(rand.Reader, DefaultRSAKeySize)
if err != nil {
return fmt.Errorf("failed to generate CA key: %w", err)
}
// Create self-signed certificate
serialNumber, err := generateSerialNumber()
if err != nil {
return fmt.Errorf("failed to generate serial number: %w", err)
}
// Certificate template
notBefore := time.Now()
notAfter := notBefore.Add(time.Duration(DefaultCAValidityDays) * 24 * time.Hour)
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: "KAT Root CA",
Organization: []string{"KAT System"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
MaxPathLen: 1, // Only allow one level of intermediate certs
}
// Create certificate
derBytes, err := x509.CreateCertificate(
rand.Reader,
&template,
&template, // Self-signed
&key.PublicKey,
key,
)
if err != nil {
return fmt.Errorf("failed to create CA certificate: %w", err)
}
// Save private key
keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return fmt.Errorf("failed to open CA key file for writing: %w", err)
}
defer keyOut.Close()
err = pem.Encode(keyOut, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
if err != nil {
return fmt.Errorf("failed to write CA key to file: %w", err)
}
// Save certificate
certOut, err := os.OpenFile(certPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
return fmt.Errorf("failed to open CA certificate file for writing: %w", err)
}
defer certOut.Close()
err = pem.Encode(certOut, &pem.Block{
Type: "CERTIFICATE",
Bytes: derBytes,
})
if err != nil {
return fmt.Errorf("failed to write CA certificate to file: %w", err)
}
return nil
}
// GenerateCertificateRequest creates a new key pair and a Certificate Signing Request (CSR).
// It saves the private key and CSR to the specified paths.
func GenerateCertificateRequest(commonName, keyOutPath, csrOutPath string) error {
// Generate RSA key
key, err := rsa.GenerateKey(rand.Reader, DefaultRSAKeySize)
if err != nil {
return fmt.Errorf("failed to generate key: %w", err)
}
// Create CSR template
template := x509.CertificateRequest{
Subject: pkix.Name{
CommonName: commonName,
Organization: []string{"KAT System"},
},
SignatureAlgorithm: x509.SHA256WithRSA,
DNSNames: []string{commonName}, // Add the CN as a SAN
}
// Create CSR
csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &template, key)
if err != nil {
return fmt.Errorf("failed to create CSR: %w", err)
}
// Save private key
keyOut, err := os.OpenFile(keyOutPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return fmt.Errorf("failed to open key file for writing: %w", err)
}
defer keyOut.Close()
err = pem.Encode(keyOut, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
if err != nil {
return fmt.Errorf("failed to write key to file: %w", err)
}
// Save CSR
csrOut, err := os.OpenFile(csrOutPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
return fmt.Errorf("failed to open CSR file for writing: %w", err)
}
defer csrOut.Close()
err = pem.Encode(csrOut, &pem.Block{
Type: "CERTIFICATE REQUEST",
Bytes: csrBytes,
})
if err != nil {
return fmt.Errorf("failed to write CSR to file: %w", err)
}
return nil
}
// SignCertificateRequest signs a CSR using the CA key and certificate.
// It reads the CSR from csrPath and saves the signed certificate to certOutPath.
// If csrPath contains PEM data (starts with "-----BEGIN"), it uses that directly instead of reading a file.
func SignCertificateRequest(caKeyPath, caCertPath, csrPathOrData, certOutPath string, duration time.Duration) error {
// Load CA key
caKey, err := LoadCAPrivateKey(caKeyPath)
if err != nil {
return fmt.Errorf("failed to load CA key: %w", err)
}
// Load CA certificate
caCert, err := LoadCACertificate(caCertPath)
if err != nil {
return fmt.Errorf("failed to load CA certificate: %w", err)
}
// Determine if csrPathOrData is a file path or PEM data
var csrPEM []byte
if strings.HasPrefix(csrPathOrData, "-----BEGIN") {
// It's PEM data, use it directly
csrPEM = []byte(csrPathOrData)
} else {
// It's a file path, read the file
csrPEM, err = os.ReadFile(csrPathOrData)
if err != nil {
return fmt.Errorf("failed to read CSR file: %w", err)
}
}
block, _ := pem.Decode(csrPEM)
if block == nil || block.Type != "CERTIFICATE REQUEST" {
return fmt.Errorf("failed to decode PEM block containing CSR")
}
csr, err := x509.ParseCertificateRequest(block.Bytes)
if err != nil {
return fmt.Errorf("failed to parse CSR: %w", err)
}
// Verify CSR signature
if err = csr.CheckSignature(); err != nil {
return fmt.Errorf("CSR signature verification failed: %w", err)
}
// Create certificate template from CSR
serialNumber, err := generateSerialNumber()
if err != nil {
return fmt.Errorf("failed to generate serial number: %w", err)
}
notBefore := time.Now()
notAfter := notBefore.Add(duration)
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: csr.Subject,
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
DNSNames: []string{csr.Subject.CommonName}, // Add the CN as a SAN
}
// Create certificate
derBytes, err := x509.CreateCertificate(
rand.Reader,
&template,
caCert,
csr.PublicKey,
caKey,
)
if err != nil {
return fmt.Errorf("failed to create certificate: %w", err)
}
// Save certificate
certOut, err := os.OpenFile(certOutPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
return fmt.Errorf("failed to open certificate file for writing: %w", err)
}
defer certOut.Close()
err = pem.Encode(certOut, &pem.Block{
Type: "CERTIFICATE",
Bytes: derBytes,
})
if err != nil {
return fmt.Errorf("failed to write certificate to file: %w", err)
}
return nil
}
// GetPKIPathFromClusterConfig determines the PKI directory from the cluster configuration.
// If backupPath is provided, it uses the parent directory of backupPath.
// Otherwise, it uses the default PKI directory.
func GetPKIPathFromClusterConfig(backupPath string) string {
if backupPath == "" {
return DefaultPKIDir
}
// Use the parent directory of backupPath
return filepath.Dir(backupPath) + "/pki"
}
// generateSerialNumber creates a random serial number for certificates
func generateSerialNumber() (*big.Int, error) {
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) // 128 bits
return rand.Int(rand.Reader, serialNumberLimit)
}
// LoadCACertificate loads a CA certificate from a file
func LoadCACertificate(certPath string) (*x509.Certificate, error) {
certPEM, err := os.ReadFile(certPath)
if err != nil {
return nil, fmt.Errorf("failed to read CA certificate file: %w", err)
}
block, _ := pem.Decode(certPEM)
if block == nil || block.Type != "CERTIFICATE" {
return nil, fmt.Errorf("failed to decode PEM block containing certificate")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse CA certificate: %w", err)
}
return cert, nil
}
// LoadCAPrivateKey loads a CA private key from a file
func LoadCAPrivateKey(keyPath string) (*rsa.PrivateKey, error) {
keyPEM, err := os.ReadFile(keyPath)
if err != nil {
return nil, fmt.Errorf("failed to read CA key file: %w", err)
}
block, _ := pem.Decode(keyPEM)
if block == nil || block.Type != "RSA PRIVATE KEY" {
return nil, fmt.Errorf("failed to decode PEM block containing private key")
}
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse CA private key: %w", err)
}
return key, nil
}

73
internal/pki/ca_test.go Normal file
View File

@ -0,0 +1,73 @@
package pki
import (
"os"
"path/filepath"
"testing"
)
func TestGenerateCA(t *testing.T) {
// Create a temporary directory for the test
tempDir, err := os.MkdirTemp("", "kat-pki-test")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tempDir)
// Define paths for CA key and certificate
keyPath := filepath.Join(tempDir, "ca.key")
certPath := filepath.Join(tempDir, "ca.crt")
// Generate CA
err = GenerateCA(tempDir, keyPath, certPath)
if err != nil {
t.Fatalf("GenerateCA failed: %v", err)
}
// Verify files exist
if _, err := os.Stat(keyPath); os.IsNotExist(err) {
t.Errorf("CA key file was not created at %s", keyPath)
}
if _, err := os.Stat(certPath); os.IsNotExist(err) {
t.Errorf("CA certificate file was not created at %s", certPath)
}
// Load and verify CA certificate
caCert, err := LoadCACertificate(certPath)
if err != nil {
t.Fatalf("Failed to load CA certificate: %v", err)
}
// Verify CA properties
if !caCert.IsCA {
t.Errorf("Certificate is not marked as CA")
}
if caCert.Subject.CommonName != "KAT Root CA" {
t.Errorf("Unexpected CA CommonName: got %s, want %s", caCert.Subject.CommonName, "KAT Root CA")
}
if len(caCert.Subject.Organization) == 0 || caCert.Subject.Organization[0] != "KAT System" {
t.Errorf("Unexpected CA Organization: got %v, want [KAT System]", caCert.Subject.Organization)
}
// Load and verify CA key
_, err = LoadCAPrivateKey(keyPath)
if err != nil {
t.Fatalf("Failed to load CA private key: %v", err)
}
}
func TestGetPKIPathFromClusterConfig(t *testing.T) {
// Test with empty backup path
pkiPath := GetPKIPathFromClusterConfig("")
if pkiPath != DefaultPKIDir {
t.Errorf("Expected default PKI path %s, got %s", DefaultPKIDir, pkiPath)
}
// Test with backup path
backupPath := "/opt/kat/backups"
expectedPKIPath := "/opt/kat/pki"
pkiPath = GetPKIPathFromClusterConfig(backupPath)
if pkiPath != expectedPKIPath {
t.Errorf("Expected PKI path %s, got %s", expectedPKIPath, pkiPath)
}
}

64
internal/pki/certs.go Normal file
View File

@ -0,0 +1,64 @@
package pki
import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"fmt"
"os"
)
// ParseCSRFromBytes parses a PEM-encoded CSR from bytes
func ParseCSRFromBytes(csrData []byte) (*x509.CertificateRequest, error) {
block, _ := pem.Decode(csrData)
if block == nil || block.Type != "CERTIFICATE REQUEST" {
return nil, fmt.Errorf("failed to decode PEM block containing CSR")
}
csr, err := x509.ParseCertificateRequest(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse CSR: %w", err)
}
return csr, nil
}
// LoadCertificate loads an X.509 certificate from a file
func LoadCertificate(certPath string) (*x509.Certificate, error) {
certPEM, err := os.ReadFile(certPath)
if err != nil {
return nil, fmt.Errorf("failed to read certificate file: %w", err)
}
block, _ := pem.Decode(certPEM)
if block == nil || block.Type != "CERTIFICATE" {
return nil, fmt.Errorf("failed to decode PEM block containing certificate")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse certificate: %w", err)
}
return cert, nil
}
// LoadPrivateKey loads an RSA private key from a file
func LoadPrivateKey(keyPath string) (*rsa.PrivateKey, error) {
keyPEM, err := os.ReadFile(keyPath)
if err != nil {
return nil, fmt.Errorf("failed to read key file: %w", err)
}
block, _ := pem.Decode(keyPEM)
if block == nil || block.Type != "RSA PRIVATE KEY" {
return nil, fmt.Errorf("failed to decode PEM block containing private key")
}
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %w", err)
}
return key, nil
}

128
internal/pki/certs_test.go Normal file
View File

@ -0,0 +1,128 @@
package pki
import (
"os"
"path/filepath"
"testing"
)
func TestGenerateCertificateRequest(t *testing.T) {
// Create a temporary directory for the test
tempDir, err := os.MkdirTemp("", "kat-csr-test")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tempDir)
// Define paths for key and CSR
keyPath := filepath.Join(tempDir, "node.key")
csrPath := filepath.Join(tempDir, "node.csr")
commonName := "test-node.kat.cluster.local"
// Generate CSR
err = GenerateCertificateRequest(commonName, keyPath, csrPath)
if err != nil {
t.Fatalf("GenerateCertificateRequest failed: %v", err)
}
// Verify files exist
if _, err := os.Stat(keyPath); os.IsNotExist(err) {
t.Errorf("Key file was not created at %s", keyPath)
}
if _, err := os.Stat(csrPath); os.IsNotExist(err) {
t.Errorf("CSR file was not created at %s", csrPath)
}
// Read CSR file
csrData, err := os.ReadFile(csrPath)
if err != nil {
t.Fatalf("Failed to read CSR file: %v", err)
}
// Parse CSR
csr, err := ParseCSRFromBytes(csrData)
if err != nil {
t.Fatalf("Failed to parse CSR: %v", err)
}
// Verify CSR properties
if csr.Subject.CommonName != commonName {
t.Errorf("Unexpected CSR CommonName: got %s, want %s", csr.Subject.CommonName, commonName)
}
if len(csr.DNSNames) == 0 || csr.DNSNames[0] != commonName {
t.Errorf("Unexpected CSR DNSNames: got %v, want [%s]", csr.DNSNames, commonName)
}
}
func TestSignCertificateRequest(t *testing.T) {
// Create a temporary directory for the test
tempDir, err := os.MkdirTemp("", "kat-cert-test")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tempDir)
// Generate CA
caKeyPath := filepath.Join(tempDir, "ca.key")
caCertPath := filepath.Join(tempDir, "ca.crt")
err = GenerateCA(tempDir, caKeyPath, caCertPath)
if err != nil {
t.Fatalf("GenerateCA failed: %v", err)
}
// Generate CSR
nodeKeyPath := filepath.Join(tempDir, "node.key")
csrPath := filepath.Join(tempDir, "node.csr")
commonName := "test-node.kat.cluster.local"
err = GenerateCertificateRequest(commonName, nodeKeyPath, csrPath)
if err != nil {
t.Fatalf("GenerateCertificateRequest failed: %v", err)
}
// Read CSR file
csrData, err := os.ReadFile(csrPath)
if err != nil {
t.Fatalf("Failed to read CSR file: %v", err)
}
// Sign CSR
certPath := filepath.Join(tempDir, "node.crt")
err = SignCertificateRequest(caKeyPath, caCertPath, string(csrData), certPath, 30) // 30 days validity
if err != nil {
t.Fatalf("SignCertificateRequest failed: %v", err)
}
// Verify certificate file exists
if _, err := os.Stat(certPath); os.IsNotExist(err) {
t.Errorf("Certificate file was not created at %s", certPath)
}
// Load and verify certificate
cert, err := LoadCertificate(certPath)
if err != nil {
t.Fatalf("Failed to load certificate: %v", err)
}
// Verify certificate properties
if cert.Subject.CommonName != commonName {
t.Errorf("Unexpected certificate CommonName: got %s, want %s", cert.Subject.CommonName, commonName)
}
if cert.IsCA {
t.Errorf("Certificate should not be a CA")
}
if len(cert.DNSNames) == 0 || cert.DNSNames[0] != commonName {
t.Errorf("Unexpected certificate DNSNames: got %v, want [%s]", cert.DNSNames, commonName)
}
// Load CA certificate to verify chain
caCert, err := LoadCACertificate(caCertPath)
if err != nil {
t.Fatalf("Failed to load CA certificate: %v", err)
}
// Verify certificate is signed by CA
err = cert.CheckSignatureFrom(caCert)
if err != nil {
t.Errorf("Certificate signature verification failed: %v", err)
}
}

View File

@ -52,7 +52,7 @@ func StartEmbeddedEtcd(cfg EtcdEmbedConfig) (*embed.Etcd, error) {
embedCfg.Name = cfg.Name embedCfg.Name = cfg.Name
embedCfg.Dir = cfg.DataDir embedCfg.Dir = cfg.DataDir
embedCfg.InitialClusterToken = "kat-etcd-cluster" // Make this configurable if needed embedCfg.InitialClusterToken = "kat-etcd-cluster" // Make this configurable if needed
embedCfg.ForceNewCluster = false // Set to true only for initial bootstrap of a new cluster if needed embedCfg.ForceNewCluster = false // Set to true only for initial bootstrap of a new cluster if needed
lpurl, err := parseURLs(cfg.PeerURLs) lpurl, err := parseURLs(cfg.PeerURLs)
if err != nil { if err != nil {

View File

@ -23,10 +23,10 @@ func TestEtcdStore(t *testing.T) {
// Configure and start embedded etcd // Configure and start embedded etcd
etcdConfig := EtcdEmbedConfig{ etcdConfig := EtcdEmbedConfig{
Name: "test-node", Name: "test-node",
DataDir: tempDir, DataDir: tempDir,
ClientURLs: []string{"http://localhost:0"}, // Use port 0 to get a random available port ClientURLs: []string{"http://localhost:0"}, // Use port 0 to get a random available port
PeerURLs: []string{"http://localhost:0"}, PeerURLs: []string{"http://localhost:0"},
} }
etcdServer, err := StartEmbeddedEtcd(etcdConfig) etcdServer, err := StartEmbeddedEtcd(etcdConfig)
@ -232,10 +232,10 @@ func TestLeaderElection(t *testing.T) {
// Configure and start embedded etcd // Configure and start embedded etcd
etcdConfig := EtcdEmbedConfig{ etcdConfig := EtcdEmbedConfig{
Name: "election-test-node", Name: "election-test-node",
DataDir: tempDir, DataDir: tempDir,
ClientURLs: []string{"http://localhost:0"}, ClientURLs: []string{"http://localhost:0"},
PeerURLs: []string{"http://localhost:0"}, PeerURLs: []string{"http://localhost:0"},
} }
etcdServer, err := StartEmbeddedEtcd(etcdConfig) etcdServer, err := StartEmbeddedEtcd(etcdConfig)

View File

@ -51,8 +51,8 @@ spec:
apiPort: 9115 apiPort: 9115
etcdPeerPort: 2380 etcdPeerPort: 2380
etcdClientPort: 2379 etcdClientPort: 2379
volumeBasePath: "/var/lib/kat/volumes" volumeBasePath: ".kat/volumes"
backupPath: "/var/lib/kat/backups" backupPath: ".kat/backups"
backupIntervalMinutes: 30 backupIntervalMinutes: 30
agentTickSeconds: 15 agentTickSeconds: 15
nodeLossTimeoutSeconds: 60 nodeLossTimeoutSeconds: 60