feat: Implement CSR signing and node registration handler for agent join
This commit is contained in:
parent
f1f2b8f9ef
commit
bf80b65873
@ -248,124 +248,9 @@ func runInit(cmd *cobra.Command, args []string) {
|
||||
log.Printf("Failed to create API server: %v", err)
|
||||
} else {
|
||||
// Register the join handler
|
||||
apiServer.RegisterJoinHandler(func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("Received join request from %s", r.RemoteAddr)
|
||||
|
||||
// In Phase 2, we're not requiring client certificates yet
|
||||
log.Printf("Processing join request without client certificate verification (Phase 2)")
|
||||
|
||||
// Read request body
|
||||
var joinReq cli.JoinRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&joinReq); err != nil {
|
||||
log.Printf("Error decoding join request: %v", err)
|
||||
http.Error(w, "Invalid request format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate request
|
||||
if joinReq.NodeName == "" || joinReq.AdvertiseAddr == "" || joinReq.CSRData == "" {
|
||||
log.Printf("Invalid join request: missing required fields")
|
||||
http.Error(w, "Missing required fields", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Processing join request for node: %s, advertise address: %s",
|
||||
joinReq.NodeName, joinReq.AdvertiseAddr)
|
||||
|
||||
// Decode CSR data
|
||||
csrData, err := base64.StdEncoding.DecodeString(joinReq.CSRData)
|
||||
if err != nil {
|
||||
log.Printf("Error decoding CSR data: %v", err)
|
||||
http.Error(w, "Invalid CSR data", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Create a temporary file for the CSR
|
||||
tempCSRFile, err := os.CreateTemp("", "node-csr-*.pem")
|
||||
if err != nil {
|
||||
log.Printf("Error creating temp CSR file: %v", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer os.Remove(tempCSRFile.Name())
|
||||
|
||||
// Write CSR data to temp file
|
||||
if _, err := tempCSRFile.Write(csrData); err != nil {
|
||||
log.Printf("Error writing CSR data to temp file: %v", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
tempCSRFile.Close()
|
||||
|
||||
// Create a temp file for the signed certificate
|
||||
tempCertFile, err := os.CreateTemp("", "node-cert-*.pem")
|
||||
if err != nil {
|
||||
log.Printf("Error creating temp cert file: %v", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer os.Remove(tempCertFile.Name())
|
||||
tempCertFile.Close()
|
||||
|
||||
// Sign the CSR
|
||||
if err := pki.SignCertificateRequest(
|
||||
filepath.Join(pkiDir, "ca.key"),
|
||||
filepath.Join(pkiDir, "ca.crt"),
|
||||
tempCSRFile.Name(),
|
||||
tempCertFile.Name(),
|
||||
365*24*time.Hour, // 1 year validity
|
||||
); err != nil {
|
||||
log.Printf("Error signing CSR: %v", err)
|
||||
http.Error(w, "Failed to sign certificate", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Read the signed certificate
|
||||
signedCert, err := os.ReadFile(tempCertFile.Name())
|
||||
if err != nil {
|
||||
log.Printf("Error reading signed certificate: %v", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Read the CA certificate
|
||||
caCert, err := os.ReadFile(filepath.Join(pkiDir, "ca.crt"))
|
||||
if err != nil {
|
||||
log.Printf("Error reading CA certificate: %v", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate a unique node UID
|
||||
nodeUID := uuid.New().String()
|
||||
|
||||
// Store node registration in etcd (placeholder for now)
|
||||
// In a future phase, we'll implement proper node registration with subnet assignment
|
||||
|
||||
// Create response
|
||||
joinResp := cli.JoinResponse{
|
||||
NodeName: joinReq.NodeName,
|
||||
NodeUID: nodeUID,
|
||||
SignedCertificate: base64.StdEncoding.EncodeToString(signedCert),
|
||||
CACertificate: base64.StdEncoding.EncodeToString(caCert),
|
||||
AssignedSubnet: "10.100.0.0/24", // Placeholder, will be properly implemented in network phase
|
||||
}
|
||||
|
||||
// If etcd peer was requested, add join instructions (placeholder)
|
||||
if etcdPeer {
|
||||
joinResp.EtcdJoinInstructions = "Etcd peer join not implemented in this phase"
|
||||
}
|
||||
|
||||
// Send response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(joinResp); err != nil {
|
||||
log.Printf("Error encoding join response: %v", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Successfully processed join request for node: %s", joinReq.NodeName)
|
||||
})
|
||||
joinHandler := api.NewJoinHandler(etcdStore, caKeyPath, caCertPath)
|
||||
apiServer.RegisterJoinHandler(joinHandler)
|
||||
log.Printf("Registered join handler with CA key: %s, CA cert: %s", caKeyPath, caCertPath)
|
||||
|
||||
// Start the server in a goroutine
|
||||
go func() {
|
||||
|
@ -1,9 +1,11 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@ -17,27 +19,31 @@ import (
|
||||
|
||||
// JoinRequest represents the data sent by an agent when joining
|
||||
type JoinRequest struct {
|
||||
CSR []byte `json:"csr"`
|
||||
CSRData string `json:"csrData"` // base64 encoded CSR
|
||||
AdvertiseAddr string `json:"advertiseAddr"`
|
||||
NodeName string `json:"nodeName,omitempty"` // Optional, leader can generate
|
||||
WireguardPubKey string `json:"wireguardPubKey"` // Placeholder for now
|
||||
WireGuardPubKey string `json:"wireguardPubKey"` // Placeholder for now
|
||||
}
|
||||
|
||||
// JoinResponse represents the data sent back to the agent
|
||||
type JoinResponse struct {
|
||||
NodeName string `json:"nodeName"`
|
||||
NodeUID string `json:"nodeUID"`
|
||||
SignedCert []byte `json:"signedCert"`
|
||||
CACert []byte `json:"caCert"`
|
||||
JoinTimestamp int64 `json:"joinTimestamp"`
|
||||
SignedCertificate string `json:"signedCertificate"` // base64 encoded certificate
|
||||
CACertificate string `json:"caCertificate"` // base64 encoded CA certificate
|
||||
AssignedSubnet string `json:"assignedSubnet"` // Placeholder for now
|
||||
EtcdJoinInstructions string `json:"etcdJoinInstructions,omitempty"`
|
||||
}
|
||||
|
||||
// NewJoinHandler creates a handler for agent join requests
|
||||
func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("Received join request from %s", r.RemoteAddr)
|
||||
|
||||
// Read and parse the request body
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
log.Printf("Failed to read request body: %v", err)
|
||||
http.Error(w, fmt.Sprintf("Failed to read request body: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
@ -45,16 +51,19 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h
|
||||
|
||||
var joinReq JoinRequest
|
||||
if err := json.Unmarshal(body, &joinReq); err != nil {
|
||||
log.Printf("Failed to parse request: %v", err)
|
||||
http.Error(w, fmt.Sprintf("Failed to parse request: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate request
|
||||
if len(joinReq.CSR) == 0 {
|
||||
http.Error(w, "Missing CSR", http.StatusBadRequest)
|
||||
if joinReq.CSRData == "" {
|
||||
log.Printf("Missing CSR data")
|
||||
http.Error(w, "Missing CSR data", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if joinReq.AdvertiseAddr == "" {
|
||||
log.Printf("Missing advertise address")
|
||||
http.Error(w, "Missing advertise address", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
@ -63,16 +72,26 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h
|
||||
nodeName := joinReq.NodeName
|
||||
if nodeName == "" {
|
||||
nodeName = fmt.Sprintf("node-%s", uuid.New().String()[:8])
|
||||
log.Printf("Generated node name: %s", nodeName)
|
||||
}
|
||||
|
||||
// Generate a unique node ID
|
||||
nodeUID := uuid.New().String()
|
||||
log.Printf("Generated node UID: %s", nodeUID)
|
||||
|
||||
// Decode CSR data
|
||||
csrData, err := base64.StdEncoding.DecodeString(joinReq.CSRData)
|
||||
if err != nil {
|
||||
log.Printf("Failed to decode CSR data: %v", err)
|
||||
http.Error(w, fmt.Sprintf("Failed to decode CSR data: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Sign the CSR
|
||||
// Create a temporary file for the CSR
|
||||
tempDir := os.TempDir()
|
||||
csrPath := filepath.Join(tempDir, fmt.Sprintf("%s.csr", nodeUID))
|
||||
if err := os.WriteFile(csrPath, joinReq.CSR, 0600); err != nil {
|
||||
if err := os.WriteFile(csrPath, csrData, 0600); err != nil {
|
||||
log.Printf("Failed to save CSR: %v", err)
|
||||
http.Error(w, fmt.Sprintf("Failed to save CSR: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@ -81,6 +100,7 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h
|
||||
// Sign the CSR
|
||||
certPath := filepath.Join(tempDir, fmt.Sprintf("%s.crt", nodeUID))
|
||||
if err := pki.SignCertificateRequest(caKeyPath, caCertPath, csrPath, certPath, 365*24*time.Hour); err != nil {
|
||||
log.Printf("Failed to sign CSR: %v", err)
|
||||
http.Error(w, fmt.Sprintf("Failed to sign CSR: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@ -89,6 +109,7 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h
|
||||
// Read the signed certificate
|
||||
signedCert, err := os.ReadFile(certPath)
|
||||
if err != nil {
|
||||
log.Printf("Failed to read signed certificate: %v", err)
|
||||
http.Error(w, fmt.Sprintf("Failed to read signed certificate: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@ -96,6 +117,7 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h
|
||||
// Read the CA certificate
|
||||
caCert, err := os.ReadFile(caCertPath)
|
||||
if err != nil {
|
||||
log.Printf("Failed to read CA certificate: %v", err)
|
||||
http.Error(w, fmt.Sprintf("Failed to read CA certificate: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@ -105,31 +127,36 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h
|
||||
nodeReg := map[string]interface{}{
|
||||
"uid": nodeUID,
|
||||
"advertiseAddr": joinReq.AdvertiseAddr,
|
||||
"wireguardPubKey": joinReq.WireguardPubKey,
|
||||
"wireguardPubKey": joinReq.WireGuardPubKey,
|
||||
"joinTimestamp": time.Now().Unix(),
|
||||
}
|
||||
nodeRegData, err := json.Marshal(nodeReg)
|
||||
if err != nil {
|
||||
log.Printf("Failed to marshal node registration: %v", err)
|
||||
http.Error(w, fmt.Sprintf("Failed to marshal node registration: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Storing node registration in etcd at key: %s", nodeRegKey)
|
||||
if err := stateStore.Put(r.Context(), nodeRegKey, nodeRegData); err != nil {
|
||||
log.Printf("Failed to store node registration: %v", err)
|
||||
http.Error(w, fmt.Sprintf("Failed to store node registration: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
log.Printf("Successfully stored node registration in etcd")
|
||||
|
||||
// Prepare and send response
|
||||
joinResp := JoinResponse{
|
||||
NodeName: nodeName,
|
||||
NodeUID: nodeUID,
|
||||
SignedCert: signedCert,
|
||||
CACert: caCert,
|
||||
JoinTimestamp: time.Now().Unix(),
|
||||
SignedCertificate: base64.StdEncoding.EncodeToString(signedCert),
|
||||
CACertificate: base64.StdEncoding.EncodeToString(caCert),
|
||||
AssignedSubnet: "10.100.0.0/24", // Placeholder for now, will be implemented in network phase
|
||||
}
|
||||
|
||||
respData, err := json.Marshal(joinResp)
|
||||
if err != nil {
|
||||
log.Printf("Failed to marshal response: %v", err)
|
||||
http.Error(w, fmt.Sprintf("Failed to marshal response: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@ -137,5 +164,6 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(respData)
|
||||
log.Printf("Successfully processed join request for node: %s", nodeName)
|
||||
}
|
||||
}
|
||||
|
168
internal/api/join_handler_test.go
Normal file
168
internal/api/join_handler_test.go
Normal file
@ -0,0 +1,168 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"git.dws.rip/dubey/kat/internal/pki"
|
||||
"git.dws.rip/dubey/kat/internal/store"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// MockStateStore for testing
|
||||
type MockStateStore struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockStateStore) Put(ctx context.Context, key string, value []byte) error {
|
||||
args := m.Called(ctx, key, value)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockStateStore) Get(ctx context.Context, key string) (*store.KV, error) {
|
||||
args := m.Called(ctx, key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*store.KV), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockStateStore) Delete(ctx context.Context, key string) error {
|
||||
args := m.Called(ctx, key)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockStateStore) List(ctx context.Context, prefix string) ([]store.KV, error) {
|
||||
args := m.Called(ctx, prefix)
|
||||
return args.Get(0).([]store.KV), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockStateStore) Watch(ctx context.Context, keyOrPrefix string, startRevision int64) (<-chan store.WatchEvent, error) {
|
||||
args := m.Called(ctx, keyOrPrefix, startRevision)
|
||||
return args.Get(0).(chan store.WatchEvent), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockStateStore) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockStateStore) Campaign(ctx context.Context, leaderID string, leaseTTLSeconds int64) (context.Context, error) {
|
||||
args := m.Called(ctx, leaderID, leaseTTLSeconds)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(context.Context), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockStateStore) Resign(ctx context.Context) error {
|
||||
args := m.Called(ctx)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockStateStore) GetLeader(ctx context.Context) (string, error) {
|
||||
args := m.Called(ctx)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockStateStore) DoTransaction(ctx context.Context, checks []store.Compare, onSuccess []store.Op, onFailure []store.Op) (bool, error) {
|
||||
args := m.Called(ctx, checks, onSuccess, onFailure)
|
||||
return args.Bool(0), args.Error(1)
|
||||
}
|
||||
|
||||
func TestJoinHandler(t *testing.T) {
|
||||
// Create temporary directory for test PKI files
|
||||
tempDir, err := os.MkdirTemp("", "kat-test-pki-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Generate CA for testing
|
||||
caKeyPath := filepath.Join(tempDir, "ca.key")
|
||||
caCertPath := filepath.Join(tempDir, "ca.crt")
|
||||
err = pki.GenerateCA(tempDir, caKeyPath, caCertPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate test CA: %v", err)
|
||||
}
|
||||
|
||||
// Generate a test CSR
|
||||
nodeKeyPath := filepath.Join(tempDir, "node.key")
|
||||
nodeCSRPath := filepath.Join(tempDir, "node.csr")
|
||||
err = pki.GenerateCertificateRequest("test-node", nodeKeyPath, nodeCSRPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate test CSR: %v", err)
|
||||
}
|
||||
|
||||
// Read the CSR file
|
||||
csrData, err := os.ReadFile(nodeCSRPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read CSR file: %v", err)
|
||||
}
|
||||
|
||||
// Create mock state store
|
||||
mockStore := new(MockStateStore)
|
||||
mockStore.On("Put", mock.Anything, mock.MatchedBy(func(key string) bool {
|
||||
return key == "/kat/nodes/registration/test-node"
|
||||
}), mock.Anything).Return(nil)
|
||||
|
||||
// Create join handler
|
||||
handler := NewJoinHandler(mockStore, caKeyPath, caCertPath)
|
||||
|
||||
// Create test request
|
||||
joinReq := JoinRequest{
|
||||
NodeName: "test-node",
|
||||
AdvertiseAddr: "192.168.1.100",
|
||||
CSRData: base64.StdEncoding.EncodeToString(csrData),
|
||||
WireGuardPubKey: "test-pubkey",
|
||||
}
|
||||
reqBody, err := json.Marshal(joinReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal join request: %v", err)
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req := httptest.NewRequest("POST", "/internal/v1alpha1/join", bytes.NewBuffer(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Call handler
|
||||
handler(w, req)
|
||||
|
||||
// Check response
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Read response body
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read response body: %v", err)
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var joinResp JoinResponse
|
||||
err = json.Unmarshal(respBody, &joinResp)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
// Verify response fields
|
||||
assert.Equal(t, "test-node", joinResp.NodeName)
|
||||
assert.NotEmpty(t, joinResp.NodeUID)
|
||||
assert.NotEmpty(t, joinResp.SignedCertificate)
|
||||
assert.NotEmpty(t, joinResp.CACertificate)
|
||||
assert.Equal(t, "10.100.0.0/24", joinResp.AssignedSubnet) // Placeholder value
|
||||
|
||||
// Verify mock was called
|
||||
mockStore.AssertExpectations(t)
|
||||
}
|
@ -136,6 +136,7 @@ func (s *Server) Stop(ctx context.Context) error {
|
||||
// RegisterJoinHandler registers the handler for agent join requests
|
||||
func (s *Server) RegisterJoinHandler(handler http.HandlerFunc) {
|
||||
s.router.HandleFunc("POST", "/internal/v1alpha1/join", handler)
|
||||
log.Printf("Registered join handler at /internal/v1alpha1/join")
|
||||
}
|
||||
|
||||
// RegisterNodeStatusHandler registers the handler for node status updates
|
||||
|
Loading…
x
Reference in New Issue
Block a user