feat: Implement CSR signing and node registration handler for agent join
This commit is contained in:
parent
f1f2b8f9ef
commit
bf80b65873
@ -248,124 +248,9 @@ func runInit(cmd *cobra.Command, args []string) {
|
|||||||
log.Printf("Failed to create API server: %v", err)
|
log.Printf("Failed to create API server: %v", err)
|
||||||
} else {
|
} else {
|
||||||
// Register the join handler
|
// Register the join handler
|
||||||
apiServer.RegisterJoinHandler(func(w http.ResponseWriter, r *http.Request) {
|
joinHandler := api.NewJoinHandler(etcdStore, caKeyPath, caCertPath)
|
||||||
log.Printf("Received join request from %s", r.RemoteAddr)
|
apiServer.RegisterJoinHandler(joinHandler)
|
||||||
|
log.Printf("Registered join handler with CA key: %s, CA cert: %s", caKeyPath, caCertPath)
|
||||||
// In Phase 2, we're not requiring client certificates yet
|
|
||||||
log.Printf("Processing join request without client certificate verification (Phase 2)")
|
|
||||||
|
|
||||||
// Read request body
|
|
||||||
var joinReq cli.JoinRequest
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&joinReq); err != nil {
|
|
||||||
log.Printf("Error decoding join request: %v", err)
|
|
||||||
http.Error(w, "Invalid request format", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate request
|
|
||||||
if joinReq.NodeName == "" || joinReq.AdvertiseAddr == "" || joinReq.CSRData == "" {
|
|
||||||
log.Printf("Invalid join request: missing required fields")
|
|
||||||
http.Error(w, "Missing required fields", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("Processing join request for node: %s, advertise address: %s",
|
|
||||||
joinReq.NodeName, joinReq.AdvertiseAddr)
|
|
||||||
|
|
||||||
// Decode CSR data
|
|
||||||
csrData, err := base64.StdEncoding.DecodeString(joinReq.CSRData)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error decoding CSR data: %v", err)
|
|
||||||
http.Error(w, "Invalid CSR data", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a temporary file for the CSR
|
|
||||||
tempCSRFile, err := os.CreateTemp("", "node-csr-*.pem")
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error creating temp CSR file: %v", err)
|
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer os.Remove(tempCSRFile.Name())
|
|
||||||
|
|
||||||
// Write CSR data to temp file
|
|
||||||
if _, err := tempCSRFile.Write(csrData); err != nil {
|
|
||||||
log.Printf("Error writing CSR data to temp file: %v", err)
|
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tempCSRFile.Close()
|
|
||||||
|
|
||||||
// Create a temp file for the signed certificate
|
|
||||||
tempCertFile, err := os.CreateTemp("", "node-cert-*.pem")
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error creating temp cert file: %v", err)
|
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer os.Remove(tempCertFile.Name())
|
|
||||||
tempCertFile.Close()
|
|
||||||
|
|
||||||
// Sign the CSR
|
|
||||||
if err := pki.SignCertificateRequest(
|
|
||||||
filepath.Join(pkiDir, "ca.key"),
|
|
||||||
filepath.Join(pkiDir, "ca.crt"),
|
|
||||||
tempCSRFile.Name(),
|
|
||||||
tempCertFile.Name(),
|
|
||||||
365*24*time.Hour, // 1 year validity
|
|
||||||
); err != nil {
|
|
||||||
log.Printf("Error signing CSR: %v", err)
|
|
||||||
http.Error(w, "Failed to sign certificate", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read the signed certificate
|
|
||||||
signedCert, err := os.ReadFile(tempCertFile.Name())
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error reading signed certificate: %v", err)
|
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read the CA certificate
|
|
||||||
caCert, err := os.ReadFile(filepath.Join(pkiDir, "ca.crt"))
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error reading CA certificate: %v", err)
|
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate a unique node UID
|
|
||||||
nodeUID := uuid.New().String()
|
|
||||||
|
|
||||||
// Store node registration in etcd (placeholder for now)
|
|
||||||
// In a future phase, we'll implement proper node registration with subnet assignment
|
|
||||||
|
|
||||||
// Create response
|
|
||||||
joinResp := cli.JoinResponse{
|
|
||||||
NodeName: joinReq.NodeName,
|
|
||||||
NodeUID: nodeUID,
|
|
||||||
SignedCertificate: base64.StdEncoding.EncodeToString(signedCert),
|
|
||||||
CACertificate: base64.StdEncoding.EncodeToString(caCert),
|
|
||||||
AssignedSubnet: "10.100.0.0/24", // Placeholder, will be properly implemented in network phase
|
|
||||||
}
|
|
||||||
|
|
||||||
// If etcd peer was requested, add join instructions (placeholder)
|
|
||||||
if etcdPeer {
|
|
||||||
joinResp.EtcdJoinInstructions = "Etcd peer join not implemented in this phase"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send response
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
if err := json.NewEncoder(w).Encode(joinResp); err != nil {
|
|
||||||
log.Printf("Error encoding join response: %v", err)
|
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("Successfully processed join request for node: %s", joinReq.NodeName)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Start the server in a goroutine
|
// Start the server in a goroutine
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@ -17,27 +19,31 @@ import (
|
|||||||
|
|
||||||
// JoinRequest represents the data sent by an agent when joining
|
// JoinRequest represents the data sent by an agent when joining
|
||||||
type JoinRequest struct {
|
type JoinRequest struct {
|
||||||
CSR []byte `json:"csr"`
|
CSRData string `json:"csrData"` // base64 encoded CSR
|
||||||
AdvertiseAddr string `json:"advertiseAddr"`
|
AdvertiseAddr string `json:"advertiseAddr"`
|
||||||
NodeName string `json:"nodeName,omitempty"` // Optional, leader can generate
|
NodeName string `json:"nodeName,omitempty"` // Optional, leader can generate
|
||||||
WireguardPubKey string `json:"wireguardPubKey"` // Placeholder for now
|
WireGuardPubKey string `json:"wireguardPubKey"` // Placeholder for now
|
||||||
}
|
}
|
||||||
|
|
||||||
// JoinResponse represents the data sent back to the agent
|
// JoinResponse represents the data sent back to the agent
|
||||||
type JoinResponse struct {
|
type JoinResponse struct {
|
||||||
NodeName string `json:"nodeName"`
|
NodeName string `json:"nodeName"`
|
||||||
NodeUID string `json:"nodeUID"`
|
NodeUID string `json:"nodeUID"`
|
||||||
SignedCert []byte `json:"signedCert"`
|
SignedCertificate string `json:"signedCertificate"` // base64 encoded certificate
|
||||||
CACert []byte `json:"caCert"`
|
CACertificate string `json:"caCertificate"` // base64 encoded CA certificate
|
||||||
JoinTimestamp int64 `json:"joinTimestamp"`
|
AssignedSubnet string `json:"assignedSubnet"` // Placeholder for now
|
||||||
|
EtcdJoinInstructions string `json:"etcdJoinInstructions,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewJoinHandler creates a handler for agent join requests
|
// NewJoinHandler creates a handler for agent join requests
|
||||||
func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) http.HandlerFunc {
|
func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log.Printf("Received join request from %s", r.RemoteAddr)
|
||||||
|
|
||||||
// Read and parse the request body
|
// Read and parse the request body
|
||||||
body, err := io.ReadAll(r.Body)
|
body, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Printf("Failed to read request body: %v", err)
|
||||||
http.Error(w, fmt.Sprintf("Failed to read request body: %v", err), http.StatusBadRequest)
|
http.Error(w, fmt.Sprintf("Failed to read request body: %v", err), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -45,16 +51,19 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h
|
|||||||
|
|
||||||
var joinReq JoinRequest
|
var joinReq JoinRequest
|
||||||
if err := json.Unmarshal(body, &joinReq); err != nil {
|
if err := json.Unmarshal(body, &joinReq); err != nil {
|
||||||
|
log.Printf("Failed to parse request: %v", err)
|
||||||
http.Error(w, fmt.Sprintf("Failed to parse request: %v", err), http.StatusBadRequest)
|
http.Error(w, fmt.Sprintf("Failed to parse request: %v", err), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate request
|
// Validate request
|
||||||
if len(joinReq.CSR) == 0 {
|
if joinReq.CSRData == "" {
|
||||||
http.Error(w, "Missing CSR", http.StatusBadRequest)
|
log.Printf("Missing CSR data")
|
||||||
|
http.Error(w, "Missing CSR data", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if joinReq.AdvertiseAddr == "" {
|
if joinReq.AdvertiseAddr == "" {
|
||||||
|
log.Printf("Missing advertise address")
|
||||||
http.Error(w, "Missing advertise address", http.StatusBadRequest)
|
http.Error(w, "Missing advertise address", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -63,16 +72,26 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h
|
|||||||
nodeName := joinReq.NodeName
|
nodeName := joinReq.NodeName
|
||||||
if nodeName == "" {
|
if nodeName == "" {
|
||||||
nodeName = fmt.Sprintf("node-%s", uuid.New().String()[:8])
|
nodeName = fmt.Sprintf("node-%s", uuid.New().String()[:8])
|
||||||
|
log.Printf("Generated node name: %s", nodeName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate a unique node ID
|
// Generate a unique node ID
|
||||||
nodeUID := uuid.New().String()
|
nodeUID := uuid.New().String()
|
||||||
|
log.Printf("Generated node UID: %s", nodeUID)
|
||||||
|
|
||||||
|
// Decode CSR data
|
||||||
|
csrData, err := base64.StdEncoding.DecodeString(joinReq.CSRData)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to decode CSR data: %v", err)
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to decode CSR data: %v", err), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Sign the CSR
|
|
||||||
// Create a temporary file for the CSR
|
// Create a temporary file for the CSR
|
||||||
tempDir := os.TempDir()
|
tempDir := os.TempDir()
|
||||||
csrPath := filepath.Join(tempDir, fmt.Sprintf("%s.csr", nodeUID))
|
csrPath := filepath.Join(tempDir, fmt.Sprintf("%s.csr", nodeUID))
|
||||||
if err := os.WriteFile(csrPath, joinReq.CSR, 0600); err != nil {
|
if err := os.WriteFile(csrPath, csrData, 0600); err != nil {
|
||||||
|
log.Printf("Failed to save CSR: %v", err)
|
||||||
http.Error(w, fmt.Sprintf("Failed to save CSR: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to save CSR: %v", err), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -81,6 +100,7 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h
|
|||||||
// Sign the CSR
|
// Sign the CSR
|
||||||
certPath := filepath.Join(tempDir, fmt.Sprintf("%s.crt", nodeUID))
|
certPath := filepath.Join(tempDir, fmt.Sprintf("%s.crt", nodeUID))
|
||||||
if err := pki.SignCertificateRequest(caKeyPath, caCertPath, csrPath, certPath, 365*24*time.Hour); err != nil {
|
if err := pki.SignCertificateRequest(caKeyPath, caCertPath, csrPath, certPath, 365*24*time.Hour); err != nil {
|
||||||
|
log.Printf("Failed to sign CSR: %v", err)
|
||||||
http.Error(w, fmt.Sprintf("Failed to sign CSR: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to sign CSR: %v", err), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -89,6 +109,7 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h
|
|||||||
// Read the signed certificate
|
// Read the signed certificate
|
||||||
signedCert, err := os.ReadFile(certPath)
|
signedCert, err := os.ReadFile(certPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Printf("Failed to read signed certificate: %v", err)
|
||||||
http.Error(w, fmt.Sprintf("Failed to read signed certificate: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to read signed certificate: %v", err), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -96,6 +117,7 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h
|
|||||||
// Read the CA certificate
|
// Read the CA certificate
|
||||||
caCert, err := os.ReadFile(caCertPath)
|
caCert, err := os.ReadFile(caCertPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Printf("Failed to read CA certificate: %v", err)
|
||||||
http.Error(w, fmt.Sprintf("Failed to read CA certificate: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to read CA certificate: %v", err), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -105,31 +127,36 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h
|
|||||||
nodeReg := map[string]interface{}{
|
nodeReg := map[string]interface{}{
|
||||||
"uid": nodeUID,
|
"uid": nodeUID,
|
||||||
"advertiseAddr": joinReq.AdvertiseAddr,
|
"advertiseAddr": joinReq.AdvertiseAddr,
|
||||||
"wireguardPubKey": joinReq.WireguardPubKey,
|
"wireguardPubKey": joinReq.WireGuardPubKey,
|
||||||
"joinTimestamp": time.Now().Unix(),
|
"joinTimestamp": time.Now().Unix(),
|
||||||
}
|
}
|
||||||
nodeRegData, err := json.Marshal(nodeReg)
|
nodeRegData, err := json.Marshal(nodeReg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Printf("Failed to marshal node registration: %v", err)
|
||||||
http.Error(w, fmt.Sprintf("Failed to marshal node registration: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to marshal node registration: %v", err), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Printf("Storing node registration in etcd at key: %s", nodeRegKey)
|
||||||
if err := stateStore.Put(r.Context(), nodeRegKey, nodeRegData); err != nil {
|
if err := stateStore.Put(r.Context(), nodeRegKey, nodeRegData); err != nil {
|
||||||
|
log.Printf("Failed to store node registration: %v", err)
|
||||||
http.Error(w, fmt.Sprintf("Failed to store node registration: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to store node registration: %v", err), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
log.Printf("Successfully stored node registration in etcd")
|
||||||
|
|
||||||
// Prepare and send response
|
// Prepare and send response
|
||||||
joinResp := JoinResponse{
|
joinResp := JoinResponse{
|
||||||
NodeName: nodeName,
|
NodeName: nodeName,
|
||||||
NodeUID: nodeUID,
|
NodeUID: nodeUID,
|
||||||
SignedCert: signedCert,
|
SignedCertificate: base64.StdEncoding.EncodeToString(signedCert),
|
||||||
CACert: caCert,
|
CACertificate: base64.StdEncoding.EncodeToString(caCert),
|
||||||
JoinTimestamp: time.Now().Unix(),
|
AssignedSubnet: "10.100.0.0/24", // Placeholder for now, will be implemented in network phase
|
||||||
}
|
}
|
||||||
|
|
||||||
respData, err := json.Marshal(joinResp)
|
respData, err := json.Marshal(joinResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Printf("Failed to marshal response: %v", err)
|
||||||
http.Error(w, fmt.Sprintf("Failed to marshal response: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to marshal response: %v", err), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -137,5 +164,6 @@ func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) h
|
|||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write(respData)
|
w.Write(respData)
|
||||||
|
log.Printf("Successfully processed join request for node: %s", nodeName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
168
internal/api/join_handler_test.go
Normal file
168
internal/api/join_handler_test.go
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.dws.rip/dubey/kat/internal/pki"
|
||||||
|
"git.dws.rip/dubey/kat/internal/store"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockStateStore for testing
|
||||||
|
type MockStateStore struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockStateStore) Put(ctx context.Context, key string, value []byte) error {
|
||||||
|
args := m.Called(ctx, key, value)
|
||||||
|
return args.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockStateStore) Get(ctx context.Context, key string) (*store.KV, error) {
|
||||||
|
args := m.Called(ctx, key)
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil, args.Error(1)
|
||||||
|
}
|
||||||
|
return args.Get(0).(*store.KV), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockStateStore) Delete(ctx context.Context, key string) error {
|
||||||
|
args := m.Called(ctx, key)
|
||||||
|
return args.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockStateStore) List(ctx context.Context, prefix string) ([]store.KV, error) {
|
||||||
|
args := m.Called(ctx, prefix)
|
||||||
|
return args.Get(0).([]store.KV), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockStateStore) Watch(ctx context.Context, keyOrPrefix string, startRevision int64) (<-chan store.WatchEvent, error) {
|
||||||
|
args := m.Called(ctx, keyOrPrefix, startRevision)
|
||||||
|
return args.Get(0).(chan store.WatchEvent), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockStateStore) Close() error {
|
||||||
|
args := m.Called()
|
||||||
|
return args.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockStateStore) Campaign(ctx context.Context, leaderID string, leaseTTLSeconds int64) (context.Context, error) {
|
||||||
|
args := m.Called(ctx, leaderID, leaseTTLSeconds)
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil, args.Error(1)
|
||||||
|
}
|
||||||
|
return args.Get(0).(context.Context), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockStateStore) Resign(ctx context.Context) error {
|
||||||
|
args := m.Called(ctx)
|
||||||
|
return args.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockStateStore) GetLeader(ctx context.Context) (string, error) {
|
||||||
|
args := m.Called(ctx)
|
||||||
|
return args.String(0), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockStateStore) DoTransaction(ctx context.Context, checks []store.Compare, onSuccess []store.Op, onFailure []store.Op) (bool, error) {
|
||||||
|
args := m.Called(ctx, checks, onSuccess, onFailure)
|
||||||
|
return args.Bool(0), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJoinHandler(t *testing.T) {
|
||||||
|
// Create temporary directory for test PKI files
|
||||||
|
tempDir, err := os.MkdirTemp("", "kat-test-pki-*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temp directory: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
|
// Generate CA for testing
|
||||||
|
caKeyPath := filepath.Join(tempDir, "ca.key")
|
||||||
|
caCertPath := filepath.Join(tempDir, "ca.crt")
|
||||||
|
err = pki.GenerateCA(tempDir, caKeyPath, caCertPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate test CA: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a test CSR
|
||||||
|
nodeKeyPath := filepath.Join(tempDir, "node.key")
|
||||||
|
nodeCSRPath := filepath.Join(tempDir, "node.csr")
|
||||||
|
err = pki.GenerateCertificateRequest("test-node", nodeKeyPath, nodeCSRPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate test CSR: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the CSR file
|
||||||
|
csrData, err := os.ReadFile(nodeCSRPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read CSR file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create mock state store
|
||||||
|
mockStore := new(MockStateStore)
|
||||||
|
mockStore.On("Put", mock.Anything, mock.MatchedBy(func(key string) bool {
|
||||||
|
return key == "/kat/nodes/registration/test-node"
|
||||||
|
}), mock.Anything).Return(nil)
|
||||||
|
|
||||||
|
// Create join handler
|
||||||
|
handler := NewJoinHandler(mockStore, caKeyPath, caCertPath)
|
||||||
|
|
||||||
|
// Create test request
|
||||||
|
joinReq := JoinRequest{
|
||||||
|
NodeName: "test-node",
|
||||||
|
AdvertiseAddr: "192.168.1.100",
|
||||||
|
CSRData: base64.StdEncoding.EncodeToString(csrData),
|
||||||
|
WireGuardPubKey: "test-pubkey",
|
||||||
|
}
|
||||||
|
reqBody, err := json.Marshal(joinReq)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to marshal join request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create HTTP request
|
||||||
|
req := httptest.NewRequest("POST", "/internal/v1alpha1/join", bytes.NewBuffer(reqBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Call handler
|
||||||
|
handler(w, req)
|
||||||
|
|
||||||
|
// Check response
|
||||||
|
resp := w.Result()
|
||||||
|
defer resp.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
// Read response body
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read response body: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse response
|
||||||
|
var joinResp JoinResponse
|
||||||
|
err = json.Unmarshal(respBody, &joinResp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify response fields
|
||||||
|
assert.Equal(t, "test-node", joinResp.NodeName)
|
||||||
|
assert.NotEmpty(t, joinResp.NodeUID)
|
||||||
|
assert.NotEmpty(t, joinResp.SignedCertificate)
|
||||||
|
assert.NotEmpty(t, joinResp.CACertificate)
|
||||||
|
assert.Equal(t, "10.100.0.0/24", joinResp.AssignedSubnet) // Placeholder value
|
||||||
|
|
||||||
|
// Verify mock was called
|
||||||
|
mockStore.AssertExpectations(t)
|
||||||
|
}
|
@ -136,6 +136,7 @@ func (s *Server) Stop(ctx context.Context) error {
|
|||||||
// RegisterJoinHandler registers the handler for agent join requests
|
// RegisterJoinHandler registers the handler for agent join requests
|
||||||
func (s *Server) RegisterJoinHandler(handler http.HandlerFunc) {
|
func (s *Server) RegisterJoinHandler(handler http.HandlerFunc) {
|
||||||
s.router.HandleFunc("POST", "/internal/v1alpha1/join", handler)
|
s.router.HandleFunc("POST", "/internal/v1alpha1/join", handler)
|
||||||
|
log.Printf("Registered join handler at /internal/v1alpha1/join")
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterNodeStatusHandler registers the handler for node status updates
|
// RegisterNodeStatusHandler registers the handler for node status updates
|
||||||
|
Loading…
x
Reference in New Issue
Block a user