feat: Implement basic API server with mTLS for leader join endpoint
This commit is contained in:
139
internal/api/server_test.go
Normal file
139
internal/api/server_test.go
Normal file
@ -0,0 +1,139 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"kat-system/internal/pki"
|
||||
)
|
||||
|
||||
func TestServerWithMTLS(t *testing.T) {
|
||||
// Skip in short mode
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping test in short mode")
|
||||
}
|
||||
|
||||
// Create temporary directory for test certificates
|
||||
tempDir, err := os.MkdirTemp("", "kat-api-test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Generate CA
|
||||
caKeyPath := filepath.Join(tempDir, "ca.key")
|
||||
caCertPath := filepath.Join(tempDir, "ca.crt")
|
||||
if err := pki.GenerateCA(caKeyPath, caCertPath, "KAT Test CA", 24*time.Hour); err != nil {
|
||||
t.Fatalf("Failed to generate CA: %v", err)
|
||||
}
|
||||
|
||||
// Generate server certificate
|
||||
serverKeyPath := filepath.Join(tempDir, "server.key")
|
||||
serverCSRPath := filepath.Join(tempDir, "server.csr")
|
||||
serverCertPath := filepath.Join(tempDir, "server.crt")
|
||||
if err := pki.GenerateCertificateRequest("server.test", serverKeyPath, serverCSRPath); err != nil {
|
||||
t.Fatalf("Failed to generate server CSR: %v", err)
|
||||
}
|
||||
if err := pki.SignCertificateRequest(caKeyPath, caCertPath, serverCSRPath, serverCertPath, 24*time.Hour); err != nil {
|
||||
t.Fatalf("Failed to sign server certificate: %v", err)
|
||||
}
|
||||
|
||||
// Generate client certificate
|
||||
clientKeyPath := filepath.Join(tempDir, "client.key")
|
||||
clientCSRPath := filepath.Join(tempDir, "client.csr")
|
||||
clientCertPath := filepath.Join(tempDir, "client.crt")
|
||||
if err := pki.GenerateCertificateRequest("client.test", clientKeyPath, clientCSRPath); err != nil {
|
||||
t.Fatalf("Failed to generate client CSR: %v", err)
|
||||
}
|
||||
if err := pki.SignCertificateRequest(caKeyPath, caCertPath, clientCSRPath, clientCertPath, 24*time.Hour); err != nil {
|
||||
t.Fatalf("Failed to sign client certificate: %v", err)
|
||||
}
|
||||
|
||||
// Create and start server
|
||||
server, err := NewServer("localhost:0", serverCertPath, serverKeyPath, caCertPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create server: %v", err)
|
||||
}
|
||||
|
||||
// Add a test handler
|
||||
server.router.HandleFunc("GET", "/test", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("test successful"))
|
||||
})
|
||||
|
||||
// Start server in a goroutine
|
||||
go func() {
|
||||
if err := server.Start(); err != nil && err != http.ErrServerClosed {
|
||||
t.Errorf("Server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Load CA cert
|
||||
caCert, err := os.ReadFile(caCertPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read CA cert: %v", err)
|
||||
}
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AppendCertsFromPEM(caCert)
|
||||
|
||||
// Load client cert
|
||||
clientCert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load client cert: %v", err)
|
||||
}
|
||||
|
||||
// Create HTTP client with mTLS
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: caCertPool,
|
||||
Certificates: []tls.Certificate{clientCert},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Test with valid client cert
|
||||
resp, err := client.Get("https://localhost:8443/test")
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read response: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(body), "test successful") {
|
||||
t.Errorf("Unexpected response: %s", body)
|
||||
}
|
||||
|
||||
// Test with no client cert (should fail)
|
||||
clientWithoutCert := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: caCertPool,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = clientWithoutCert.Get("https://localhost:8443/test")
|
||||
if err == nil {
|
||||
t.Error("Request without client cert should fail")
|
||||
}
|
||||
|
||||
// Shutdown server
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
server.Stop(ctx)
|
||||
}
|
Reference in New Issue
Block a user