package api import ( "context" "crypto/tls" "crypto/x509" "io" "net/http" "os" "path/filepath" "strings" "testing" "time" "git.dws.rip/dubey/kat/internal/pki" ) // TestServerWithMTLS tests the server with TLS configuration // Note: In Phase 2, we've temporarily disabled client certificate verification // to simplify the initial join process. This test has been updated to reflect that. func TestServerWithMTLS(t *testing.T) { // Skip in short mode if testing.Short() { t.Skip("Skipping test in short mode") } // Create temporary directory for test certificates tempDir, err := os.MkdirTemp("", "kat-api-test") if err != nil { t.Fatalf("Failed to create temp dir: %v", err) } defer os.RemoveAll(tempDir) // Generate CA caKeyPath := filepath.Join(tempDir, "ca.key") caCertPath := filepath.Join(tempDir, "ca.crt") if err := pki.GenerateCA(tempDir, caKeyPath, caCertPath); err != nil { t.Fatalf("Failed to generate CA: %v", err) } // Generate server certificate serverKeyPath := filepath.Join(tempDir, "server.key") serverCSRPath := filepath.Join(tempDir, "server.csr") serverCertPath := filepath.Join(tempDir, "server.crt") if err := pki.GenerateCertificateRequest("localhost", serverKeyPath, serverCSRPath); err != nil { t.Fatalf("Failed to generate server CSR: %v", err) } if err := pki.SignCertificateRequest(caKeyPath, caCertPath, serverCSRPath, serverCertPath, 24*time.Hour); err != nil { t.Fatalf("Failed to sign server certificate: %v", err) } // Generate client certificate clientKeyPath := filepath.Join(tempDir, "client.key") clientCSRPath := filepath.Join(tempDir, "client.csr") clientCertPath := filepath.Join(tempDir, "client.crt") if err := pki.GenerateCertificateRequest("client.test", clientKeyPath, clientCSRPath); err != nil { t.Fatalf("Failed to generate client CSR: %v", err) } if err := pki.SignCertificateRequest(caKeyPath, caCertPath, clientCSRPath, clientCertPath, 24*time.Hour); err != nil { t.Fatalf("Failed to sign client certificate: %v", err) } // Create and start server server, err := NewServer("localhost:8443", serverCertPath, serverKeyPath, caCertPath) if err != nil { t.Fatalf("Failed to create server: %v", err) } // Add a test handler server.router.HandleFunc("GET", "/test", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("test successful")) }) // Start server in a goroutine go func() { if err := server.Start(); err != nil && err != http.ErrServerClosed { t.Errorf("Server error: %v", err) } }() // Wait for server to start time.Sleep(250 * time.Millisecond) // Load CA cert caCert, err := os.ReadFile(caCertPath) if err != nil { t.Fatalf("Failed to read CA cert: %v", err) } caCertPool := x509.NewCertPool() caCertPool.AppendCertsFromPEM(caCert) // Load client cert clientCert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath) if err != nil { t.Fatalf("Failed to load client cert: %v", err) } // Create HTTP client with mTLS client := &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{ RootCAs: caCertPool, Certificates: []tls.Certificate{clientCert}, }, }, } // Test with valid client cert resp, err := client.Get("https://localhost:8443/test") if err != nil { t.Fatalf("Request failed: %v", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatalf("Failed to read response: %v", err) } if !strings.Contains(string(body), "test successful") { t.Errorf("Unexpected response: %s", body) } // Test with no client cert (should succeed in Phase 2) clientWithoutCert := &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{ RootCAs: caCertPool, }, }, } resp, err = clientWithoutCert.Get("https://localhost:8443/test") if err != nil { t.Errorf("Request without client cert should succeed in Phase 2: %v", err) } else { defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Errorf("Failed to read response: %v", err) } if !strings.Contains(string(body), "test successful") { t.Errorf("Unexpected response: %s", body) } } // Shutdown server ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() server.Stop(ctx) }