package api import ( "bytes" "context" "encoding/base64" "encoding/json" "io" "net/http" "net/http/httptest" "os" "path/filepath" "testing" "git.dws.rip/dubey/kat/internal/pki" "git.dws.rip/dubey/kat/internal/store" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) // MockStateStore for testing type MockStateStore struct { mock.Mock } func (m *MockStateStore) Put(ctx context.Context, key string, value []byte) error { args := m.Called(ctx, key, value) return args.Error(0) } func (m *MockStateStore) Get(ctx context.Context, key string) (*store.KV, error) { args := m.Called(ctx, key) if args.Get(0) == nil { return nil, args.Error(1) } return args.Get(0).(*store.KV), args.Error(1) } func (m *MockStateStore) Delete(ctx context.Context, key string) error { args := m.Called(ctx, key) return args.Error(0) } func (m *MockStateStore) List(ctx context.Context, prefix string) ([]store.KV, error) { args := m.Called(ctx, prefix) return args.Get(0).([]store.KV), args.Error(1) } func (m *MockStateStore) Watch(ctx context.Context, keyOrPrefix string, startRevision int64) (<-chan store.WatchEvent, error) { args := m.Called(ctx, keyOrPrefix, startRevision) return args.Get(0).(chan store.WatchEvent), args.Error(1) } func (m *MockStateStore) Close() error { args := m.Called() return args.Error(0) } func (m *MockStateStore) Campaign(ctx context.Context, leaderID string, leaseTTLSeconds int64) (context.Context, error) { args := m.Called(ctx, leaderID, leaseTTLSeconds) if args.Get(0) == nil { return nil, args.Error(1) } return args.Get(0).(context.Context), args.Error(1) } func (m *MockStateStore) Resign(ctx context.Context) error { args := m.Called(ctx) return args.Error(0) } func (m *MockStateStore) GetLeader(ctx context.Context) (string, error) { args := m.Called(ctx) return args.String(0), args.Error(1) } func (m *MockStateStore) DoTransaction(ctx context.Context, checks []store.Compare, onSuccess []store.Op, onFailure []store.Op) (bool, error) { args := m.Called(ctx, checks, onSuccess, onFailure) return args.Bool(0), args.Error(1) } func TestJoinHandler(t *testing.T) { // Create temporary directory for test PKI files tempDir, err := os.MkdirTemp("", "kat-test-pki-*") if err != nil { t.Fatalf("Failed to create temp directory: %v", err) } defer os.RemoveAll(tempDir) // Generate CA for testing caKeyPath := filepath.Join(tempDir, "ca.key") caCertPath := filepath.Join(tempDir, "ca.crt") err = pki.GenerateCA(tempDir, caKeyPath, caCertPath) if err != nil { t.Fatalf("Failed to generate test CA: %v", err) } // Generate a test CSR nodeKeyPath := filepath.Join(tempDir, "node.key") nodeCSRPath := filepath.Join(tempDir, "node.csr") err = pki.GenerateCertificateRequest("test-node", nodeKeyPath, nodeCSRPath) if err != nil { t.Fatalf("Failed to generate test CSR: %v", err) } // Read the CSR file csrData, err := os.ReadFile(nodeCSRPath) if err != nil { t.Fatalf("Failed to read CSR file: %v", err) } // Create mock state store mockStore := new(MockStateStore) mockStore.On("Put", mock.Anything, mock.MatchedBy(func(key string) bool { return key == "/kat/nodes/registration/test-node" }), mock.Anything).Return(nil) // Create join handler handler := NewJoinHandler(mockStore, caKeyPath, caCertPath) // Create test request joinReq := JoinRequest{ NodeName: "test-node", AdvertiseAddr: "192.168.1.100", CSRData: base64.StdEncoding.EncodeToString(csrData), WireGuardPubKey: "test-pubkey", } reqBody, err := json.Marshal(joinReq) if err != nil { t.Fatalf("Failed to marshal join request: %v", err) } // Create HTTP request req := httptest.NewRequest("POST", "/internal/v1alpha1/join", bytes.NewBuffer(reqBody)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() // Call handler handler(w, req) // Check response resp := w.Result() defer resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) // Read response body respBody, err := io.ReadAll(resp.Body) if err != nil { t.Fatalf("Failed to read response body: %v", err) } // Parse response var joinResp JoinResponse err = json.Unmarshal(respBody, &joinResp) if err != nil { t.Fatalf("Failed to parse response: %v", err) } // Verify response fields assert.Equal(t, "test-node", joinResp.NodeName) assert.NotEmpty(t, joinResp.NodeUID) assert.NotEmpty(t, joinResp.SignedCertificate) assert.NotEmpty(t, joinResp.CACertificate) assert.Equal(t, "10.100.0.0/24", joinResp.AssignedSubnet) // Placeholder value // Verify mock was called mockStore.AssertExpectations(t) }