feat: Implement CSR signing and node registration handler for agent join

This commit is contained in:
Tanishq Dubey 2025-05-17 13:05:21 -04:00
parent f1f2b8f9ef
commit bf80b65873
No known key found for this signature in database
GPG Key ID: CFC1931B84DFC3F9
4 changed files with 217 additions and 135 deletions

View File

@ -248,124 +248,9 @@ func runInit(cmd *cobra.Command, args []string) {
log.Printf("Failed to create API server: %v", err) log.Printf("Failed to create API server: %v", err)
} else { } else {
// Register the join handler // Register the join handler
apiServer.RegisterJoinHandler(func(w http.ResponseWriter, r *http.Request) { joinHandler := api.NewJoinHandler(etcdStore, caKeyPath, caCertPath)
log.Printf("Received join request from %s", r.RemoteAddr) apiServer.RegisterJoinHandler(joinHandler)
log.Printf("Registered join handler with CA key: %s, CA cert: %s", caKeyPath, caCertPath)
// In Phase 2, we're not requiring client certificates yet
log.Printf("Processing join request without client certificate verification (Phase 2)")
// Read request body
var joinReq cli.JoinRequest
if err := json.NewDecoder(r.Body).Decode(&joinReq); err != nil {
log.Printf("Error decoding join request: %v", err)
http.Error(w, "Invalid request format", http.StatusBadRequest)
return
}
// Validate request
if joinReq.NodeName == "" || joinReq.AdvertiseAddr == "" || joinReq.CSRData == "" {
log.Printf("Invalid join request: missing required fields")
http.Error(w, "Missing required fields", http.StatusBadRequest)
return
}
log.Printf("Processing join request for node: %s, advertise address: %s",
joinReq.NodeName, joinReq.AdvertiseAddr)
// Decode CSR data
csrData, err := base64.StdEncoding.DecodeString(joinReq.CSRData)
if err != nil {
log.Printf("Error decoding CSR data: %v", err)
http.Error(w, "Invalid CSR data", http.StatusBadRequest)
return
}
// Create a temporary file for the CSR
tempCSRFile, err := os.CreateTemp("", "node-csr-*.pem")
if err != nil {
log.Printf("Error creating temp CSR file: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
defer os.Remove(tempCSRFile.Name())
// Write CSR data to temp file
if _, err := tempCSRFile.Write(csrData); err != nil {
log.Printf("Error writing CSR data to temp file: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
tempCSRFile.Close()
// Create a temp file for the signed certificate
tempCertFile, err := os.CreateTemp("", "node-cert-*.pem")
if err != nil {
log.Printf("Error creating temp cert file: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
defer os.Remove(tempCertFile.Name())
tempCertFile.Close()
// Sign the CSR
if err := pki.SignCertificateRequest(
filepath.Join(pkiDir, "ca.key"),
filepath.Join(pkiDir, "ca.crt"),
tempCSRFile.Name(),
tempCertFile.Name(),
365*24*time.Hour, // 1 year validity
); err != nil {
log.Printf("Error signing CSR: %v", err)
http.Error(w, "Failed to sign certificate", http.StatusInternalServerError)
return
}
// Read the signed certificate
signedCert, err := os.ReadFile(tempCertFile.Name())
if err != nil {
log.Printf("Error reading signed certificate: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
// Read the CA certificate
caCert, err := os.ReadFile(filepath.Join(pkiDir, "ca.crt"))
if err != nil {
log.Printf("Error reading CA certificate: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
// Generate a unique node UID
nodeUID := uuid.New().String()
// Store node registration in etcd (placeholder for now)
// In a future phase, we'll implement proper node registration with subnet assignment
// Create response
joinResp := cli.JoinResponse{
NodeName: joinReq.NodeName,
NodeUID: nodeUID,
SignedCertificate: base64.StdEncoding.EncodeToString(signedCert),
CACertificate: base64.StdEncoding.EncodeToString(caCert),
AssignedSubnet: "10.100.0.0/24", // Placeholder, will be properly implemented in network phase
}
// If etcd peer was requested, add join instructions (placeholder)
if etcdPeer {
joinResp.EtcdJoinInstructions = "Etcd peer join not implemented in this phase"
}
// Send response
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(joinResp); err != nil {
log.Printf("Error encoding join response: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
log.Printf("Successfully processed join request for node: %s", joinReq.NodeName)
})
// Start the server in a goroutine // Start the server in a goroutine
go func() { go func() {

View File

@ -1,9 +1,11 @@
package api package api
import ( import (
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
@ -17,27 +19,31 @@ import (
// JoinRequest represents the data sent by an agent when joining // JoinRequest represents the data sent by an agent when joining
type JoinRequest struct { type JoinRequest struct {
CSR []byte `json:"csr"` CSRData string `json:"csrData"` // base64 encoded CSR
AdvertiseAddr string `json:"advertiseAddr"` AdvertiseAddr string `json:"advertiseAddr"`
NodeName string `json:"nodeName,omitempty"` // Optional, leader can generate NodeName string `json:"nodeName,omitempty"` // Optional, leader can generate
WireguardPubKey string `json:"wireguardPubKey"` // Placeholder for now WireGuardPubKey string `json:"wireguardPubKey"` // Placeholder for now
} }
// JoinResponse represents the data sent back to the agent // JoinResponse represents the data sent back to the agent
type JoinResponse struct { type JoinResponse struct {
NodeName string `json:"nodeName"` NodeName string `json:"nodeName"`
NodeUID string `json:"nodeUID"` NodeUID string `json:"nodeUID"`
SignedCert []byte `json:"signedCert"` SignedCertificate string `json:"signedCertificate"` // base64 encoded certificate
CACert []byte `json:"caCert"` CACertificate string `json:"caCertificate"` // base64 encoded CA certificate
JoinTimestamp int64 `json:"joinTimestamp"` AssignedSubnet string `json:"assignedSubnet"` // Placeholder for now
EtcdJoinInstructions string `json:"etcdJoinInstructions,omitempty"`
} }
// NewJoinHandler creates a handler for agent join requests // NewJoinHandler creates a handler for agent join requests
func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) http.HandlerFunc { func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
log.Printf("Received join request from %s", r.RemoteAddr)
// Read and parse the request body // Read and parse the request body
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
log.Printf("Failed to read request body: %v", err)
http.Error(w, fmt.Sprintf("Failed to read request body: %v", err), http.StatusBadRequest) http.Error(w, fmt.Sprintf("Failed to read request body: %v", err), http.StatusBadRequest)
return return
} }
@ -45,16 +51,19 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h
var joinReq JoinRequest var joinReq JoinRequest
if err := json.Unmarshal(body, &joinReq); err != nil { if err := json.Unmarshal(body, &joinReq); err != nil {
log.Printf("Failed to parse request: %v", err)
http.Error(w, fmt.Sprintf("Failed to parse request: %v", err), http.StatusBadRequest) http.Error(w, fmt.Sprintf("Failed to parse request: %v", err), http.StatusBadRequest)
return return
} }
// Validate request // Validate request
if len(joinReq.CSR) == 0 { if joinReq.CSRData == "" {
http.Error(w, "Missing CSR", http.StatusBadRequest) log.Printf("Missing CSR data")
http.Error(w, "Missing CSR data", http.StatusBadRequest)
return return
} }
if joinReq.AdvertiseAddr == "" { if joinReq.AdvertiseAddr == "" {
log.Printf("Missing advertise address")
http.Error(w, "Missing advertise address", http.StatusBadRequest) http.Error(w, "Missing advertise address", http.StatusBadRequest)
return return
} }
@ -63,16 +72,26 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h
nodeName := joinReq.NodeName nodeName := joinReq.NodeName
if nodeName == "" { if nodeName == "" {
nodeName = fmt.Sprintf("node-%s", uuid.New().String()[:8]) nodeName = fmt.Sprintf("node-%s", uuid.New().String()[:8])
log.Printf("Generated node name: %s", nodeName)
} }
// Generate a unique node ID // Generate a unique node ID
nodeUID := uuid.New().String() nodeUID := uuid.New().String()
log.Printf("Generated node UID: %s", nodeUID)
// Decode CSR data
csrData, err := base64.StdEncoding.DecodeString(joinReq.CSRData)
if err != nil {
log.Printf("Failed to decode CSR data: %v", err)
http.Error(w, fmt.Sprintf("Failed to decode CSR data: %v", err), http.StatusBadRequest)
return
}
// Sign the CSR
// Create a temporary file for the CSR // Create a temporary file for the CSR
tempDir := os.TempDir() tempDir := os.TempDir()
csrPath := filepath.Join(tempDir, fmt.Sprintf("%s.csr", nodeUID)) csrPath := filepath.Join(tempDir, fmt.Sprintf("%s.csr", nodeUID))
if err := os.WriteFile(csrPath, joinReq.CSR, 0600); err != nil { if err := os.WriteFile(csrPath, csrData, 0600); err != nil {
log.Printf("Failed to save CSR: %v", err)
http.Error(w, fmt.Sprintf("Failed to save CSR: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("Failed to save CSR: %v", err), http.StatusInternalServerError)
return return
} }
@ -81,6 +100,7 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h
// Sign the CSR // Sign the CSR
certPath := filepath.Join(tempDir, fmt.Sprintf("%s.crt", nodeUID)) certPath := filepath.Join(tempDir, fmt.Sprintf("%s.crt", nodeUID))
if err := pki.SignCertificateRequest(caKeyPath, caCertPath, csrPath, certPath, 365*24*time.Hour); err != nil { if err := pki.SignCertificateRequest(caKeyPath, caCertPath, csrPath, certPath, 365*24*time.Hour); err != nil {
log.Printf("Failed to sign CSR: %v", err)
http.Error(w, fmt.Sprintf("Failed to sign CSR: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("Failed to sign CSR: %v", err), http.StatusInternalServerError)
return return
} }
@ -89,6 +109,7 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h
// Read the signed certificate // Read the signed certificate
signedCert, err := os.ReadFile(certPath) signedCert, err := os.ReadFile(certPath)
if err != nil { if err != nil {
log.Printf("Failed to read signed certificate: %v", err)
http.Error(w, fmt.Sprintf("Failed to read signed certificate: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("Failed to read signed certificate: %v", err), http.StatusInternalServerError)
return return
} }
@ -96,6 +117,7 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h
// Read the CA certificate // Read the CA certificate
caCert, err := os.ReadFile(caCertPath) caCert, err := os.ReadFile(caCertPath)
if err != nil { if err != nil {
log.Printf("Failed to read CA certificate: %v", err)
http.Error(w, fmt.Sprintf("Failed to read CA certificate: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("Failed to read CA certificate: %v", err), http.StatusInternalServerError)
return return
} }
@ -105,31 +127,36 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h
nodeReg := map[string]interface{}{ nodeReg := map[string]interface{}{
"uid": nodeUID, "uid": nodeUID,
"advertiseAddr": joinReq.AdvertiseAddr, "advertiseAddr": joinReq.AdvertiseAddr,
"wireguardPubKey": joinReq.WireguardPubKey, "wireguardPubKey": joinReq.WireGuardPubKey,
"joinTimestamp": time.Now().Unix(), "joinTimestamp": time.Now().Unix(),
} }
nodeRegData, err := json.Marshal(nodeReg) nodeRegData, err := json.Marshal(nodeReg)
if err != nil { if err != nil {
log.Printf("Failed to marshal node registration: %v", err)
http.Error(w, fmt.Sprintf("Failed to marshal node registration: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("Failed to marshal node registration: %v", err), http.StatusInternalServerError)
return return
} }
log.Printf("Storing node registration in etcd at key: %s", nodeRegKey)
if err := stateStore.Put(r.Context(), nodeRegKey, nodeRegData); err != nil { if err := stateStore.Put(r.Context(), nodeRegKey, nodeRegData); err != nil {
log.Printf("Failed to store node registration: %v", err)
http.Error(w, fmt.Sprintf("Failed to store node registration: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("Failed to store node registration: %v", err), http.StatusInternalServerError)
return return
} }
log.Printf("Successfully stored node registration in etcd")
// Prepare and send response // Prepare and send response
joinResp := JoinResponse{ joinResp := JoinResponse{
NodeName: nodeName, NodeName: nodeName,
NodeUID: nodeUID, NodeUID: nodeUID,
SignedCert: signedCert, SignedCertificate: base64.StdEncoding.EncodeToString(signedCert),
CACert: caCert, CACertificate: base64.StdEncoding.EncodeToString(caCert),
JoinTimestamp: time.Now().Unix(), AssignedSubnet: "10.100.0.0/24", // Placeholder for now, will be implemented in network phase
} }
respData, err := json.Marshal(joinResp) respData, err := json.Marshal(joinResp)
if err != nil { if err != nil {
log.Printf("Failed to marshal response: %v", err)
http.Error(w, fmt.Sprintf("Failed to marshal response: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("Failed to marshal response: %v", err), http.StatusInternalServerError)
return return
} }
@ -137,5 +164,6 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write(respData) w.Write(respData)
log.Printf("Successfully processed join request for node: %s", nodeName)
} }
} }

View File

@ -0,0 +1,168 @@
package api
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"git.dws.rip/dubey/kat/internal/pki"
"git.dws.rip/dubey/kat/internal/store"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// MockStateStore for testing
type MockStateStore struct {
mock.Mock
}
func (m *MockStateStore) Put(ctx context.Context, key string, value []byte) error {
args := m.Called(ctx, key, value)
return args.Error(0)
}
func (m *MockStateStore) Get(ctx context.Context, key string) (*store.KV, error) {
args := m.Called(ctx, key)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*store.KV), args.Error(1)
}
func (m *MockStateStore) Delete(ctx context.Context, key string) error {
args := m.Called(ctx, key)
return args.Error(0)
}
func (m *MockStateStore) List(ctx context.Context, prefix string) ([]store.KV, error) {
args := m.Called(ctx, prefix)
return args.Get(0).([]store.KV), args.Error(1)
}
func (m *MockStateStore) Watch(ctx context.Context, keyOrPrefix string, startRevision int64) (<-chan store.WatchEvent, error) {
args := m.Called(ctx, keyOrPrefix, startRevision)
return args.Get(0).(chan store.WatchEvent), args.Error(1)
}
func (m *MockStateStore) Close() error {
args := m.Called()
return args.Error(0)
}
func (m *MockStateStore) Campaign(ctx context.Context, leaderID string, leaseTTLSeconds int64) (context.Context, error) {
args := m.Called(ctx, leaderID, leaseTTLSeconds)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(context.Context), args.Error(1)
}
func (m *MockStateStore) Resign(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}
func (m *MockStateStore) GetLeader(ctx context.Context) (string, error) {
args := m.Called(ctx)
return args.String(0), args.Error(1)
}
func (m *MockStateStore) DoTransaction(ctx context.Context, checks []store.Compare, onSuccess []store.Op, onFailure []store.Op) (bool, error) {
args := m.Called(ctx, checks, onSuccess, onFailure)
return args.Bool(0), args.Error(1)
}
func TestJoinHandler(t *testing.T) {
// Create temporary directory for test PKI files
tempDir, err := os.MkdirTemp("", "kat-test-pki-*")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tempDir)
// Generate CA for testing
caKeyPath := filepath.Join(tempDir, "ca.key")
caCertPath := filepath.Join(tempDir, "ca.crt")
err = pki.GenerateCA(tempDir, caKeyPath, caCertPath)
if err != nil {
t.Fatalf("Failed to generate test CA: %v", err)
}
// Generate a test CSR
nodeKeyPath := filepath.Join(tempDir, "node.key")
nodeCSRPath := filepath.Join(tempDir, "node.csr")
err = pki.GenerateCertificateRequest("test-node", nodeKeyPath, nodeCSRPath)
if err != nil {
t.Fatalf("Failed to generate test CSR: %v", err)
}
// Read the CSR file
csrData, err := os.ReadFile(nodeCSRPath)
if err != nil {
t.Fatalf("Failed to read CSR file: %v", err)
}
// Create mock state store
mockStore := new(MockStateStore)
mockStore.On("Put", mock.Anything, mock.MatchedBy(func(key string) bool {
return key == "/kat/nodes/registration/test-node"
}), mock.Anything).Return(nil)
// Create join handler
handler := NewJoinHandler(mockStore, caKeyPath, caCertPath)
// Create test request
joinReq := JoinRequest{
NodeName: "test-node",
AdvertiseAddr: "192.168.1.100",
CSRData: base64.StdEncoding.EncodeToString(csrData),
WireGuardPubKey: "test-pubkey",
}
reqBody, err := json.Marshal(joinReq)
if err != nil {
t.Fatalf("Failed to marshal join request: %v", err)
}
// Create HTTP request
req := httptest.NewRequest("POST", "/internal/v1alpha1/join", bytes.NewBuffer(reqBody))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
// Call handler
handler(w, req)
// Check response
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
// Read response body
respBody, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
// Parse response
var joinResp JoinResponse
err = json.Unmarshal(respBody, &joinResp)
if err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
// Verify response fields
assert.Equal(t, "test-node", joinResp.NodeName)
assert.NotEmpty(t, joinResp.NodeUID)
assert.NotEmpty(t, joinResp.SignedCertificate)
assert.NotEmpty(t, joinResp.CACertificate)
assert.Equal(t, "10.100.0.0/24", joinResp.AssignedSubnet) // Placeholder value
// Verify mock was called
mockStore.AssertExpectations(t)
}

View File

@ -136,6 +136,7 @@ func (s *Server) Stop(ctx context.Context) error {
// RegisterJoinHandler registers the handler for agent join requests // RegisterJoinHandler registers the handler for agent join requests
func (s *Server) RegisterJoinHandler(handler http.HandlerFunc) { func (s *Server) RegisterJoinHandler(handler http.HandlerFunc) {
s.router.HandleFunc("POST", "/internal/v1alpha1/join", handler) s.router.HandleFunc("POST", "/internal/v1alpha1/join", handler)
log.Printf("Registered join handler at /internal/v1alpha1/join")
} }
// RegisterNodeStatusHandler registers the handler for node status updates // RegisterNodeStatusHandler registers the handler for node status updates