package api

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"testing"
	"time"

	"git.dws.rip/dubey/kat/internal/pki"
)

// TestServerWithMTLS starts the API server with a freshly generated CA and
// server certificate and checks that HTTPS requests to a test route succeed
// both with and without a client certificate.
//
// Note: In Phase 2, client certificate verification has been temporarily
// disabled to simplify the initial join process, so the request that presents
// no client certificate is expected to succeed as well.
//
// The test binds a fixed local port and is skipped in short mode.
func TestServerWithMTLS(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping test in short mode")
	}

	const (
		addr    = "localhost:8443"
		testURL = "https://" + addr + "/test"
	)

	// t.TempDir is removed automatically when the test finishes, including on
	// t.Fatalf paths.
	tempDir := t.TempDir()

	// Generate CA
	caKeyPath := filepath.Join(tempDir, "ca.key")
	caCertPath := filepath.Join(tempDir, "ca.crt")
	if err := pki.GenerateCA(tempDir, caKeyPath, caCertPath); err != nil {
		t.Fatalf("Failed to generate CA: %v", err)
	}

	// Generate server certificate
	serverKeyPath := filepath.Join(tempDir, "server.key")
	serverCSRPath := filepath.Join(tempDir, "server.csr")
	serverCertPath := filepath.Join(tempDir, "server.crt")
	if err := pki.GenerateCertificateRequest("localhost", serverKeyPath, serverCSRPath); err != nil {
		t.Fatalf("Failed to generate server CSR: %v", err)
	}
	if err := pki.SignCertificateRequest(caKeyPath, caCertPath, serverCSRPath, serverCertPath, 24*time.Hour); err != nil {
		t.Fatalf("Failed to sign server certificate: %v", err)
	}

	// Generate client certificate
	clientKeyPath := filepath.Join(tempDir, "client.key")
	clientCSRPath := filepath.Join(tempDir, "client.csr")
	clientCertPath := filepath.Join(tempDir, "client.crt")
	if err := pki.GenerateCertificateRequest("client.test", clientKeyPath, clientCSRPath); err != nil {
		t.Fatalf("Failed to generate client CSR: %v", err)
	}
	if err := pki.SignCertificateRequest(caKeyPath, caCertPath, clientCSRPath, clientCertPath, 24*time.Hour); err != nil {
		t.Fatalf("Failed to sign client certificate: %v", err)
	}

	// Create server
	server, err := NewServer(addr, serverCertPath, serverKeyPath, caCertPath)
	if err != nil {
		t.Fatalf("Failed to create server: %v", err)
	}

	// Add a test handler
	server.router.HandleFunc("GET", "/test", func(w http.ResponseWriter, r *http.Request) {
		// A failed write shows up on the client side as a failed request or
		// an unexpected body, so the error is deliberately not handled here.
		_, _ = w.Write([]byte("test successful"))
	})

	// Run the server in the background. The goroutine never touches t; it
	// hands its result to the test goroutine so nothing is logged after the
	// test has completed. Closing the channel lets later receives return
	// immediately once the result has been consumed.
	serveErr := make(chan error, 1)
	go func() {
		serveErr <- server.Start()
		close(serveErr)
	}()

	// Always stop the server and wait for the goroutine, even when the test
	// bails out early via t.Fatalf.
	// NOTE(review): assumes server.Stop is safe to call when Start has already
	// returned with an error — confirm against the Server implementation.
	t.Cleanup(func() {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		server.Stop(ctx)

		select {
		case err := <-serveErr:
			if err != nil && err != http.ErrServerClosed {
				t.Errorf("Server error: %v", err)
			}
		case <-time.After(5 * time.Second):
			t.Errorf("Server did not stop within 5s of shutdown")
		}
	})

	// Give the server time to start, but fail fast if it exits right away
	// (for example because the port is already in use).
	select {
	case err := <-serveErr:
		t.Fatalf("Server exited before serving requests: %v", err)
	case <-time.After(250 * time.Millisecond):
	}

	// Load CA cert
	caCert, err := os.ReadFile(caCertPath)
	if err != nil {
		t.Fatalf("Failed to read CA cert: %v", err)
	}
	caCertPool := x509.NewCertPool()
	if !caCertPool.AppendCertsFromPEM(caCert) {
		t.Fatalf("Failed to parse CA cert from %s", caCertPath)
	}

	// Load client cert
	clientCert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath)
	if err != nil {
		t.Fatalf("Failed to load client cert: %v", err)
	}

	// get issues a GET against the test route and returns the status code and
	// the full response body; the body is always closed before returning.
	get := func(c *http.Client) (int, string, error) {
		resp, err := c.Get(testURL)
		if err != nil {
			return 0, "", err
		}
		defer resp.Body.Close()

		body, err := io.ReadAll(resp.Body)
		return resp.StatusCode, string(body), err
	}

	// Create HTTP client with mTLS. The timeout keeps a misbehaving server
	// from hanging the test run.
	client := &http.Client{
		Timeout: 5 * time.Second,
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{
				RootCAs:      caCertPool,
				Certificates: []tls.Certificate{clientCert},
			},
		},
	}
	defer client.CloseIdleConnections()

	// Test with valid client cert
	status, body, err := get(client)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	if status != http.StatusOK {
		t.Errorf("Unexpected status: got %d, want %d", status, http.StatusOK)
	}
	if !strings.Contains(body, "test successful") {
		t.Errorf("Unexpected response: %s", body)
	}

	// Test with no client cert (should succeed in Phase 2)
	clientWithoutCert := &http.Client{
		Timeout: 5 * time.Second,
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{
				RootCAs: caCertPool,
			},
		},
	}
	defer clientWithoutCert.CloseIdleConnections()

	status, body, err = get(clientWithoutCert)
	if err != nil {
		t.Errorf("Request without client cert should succeed in Phase 2: %v", err)
		return
	}
	if status != http.StatusOK {
		t.Errorf("Unexpected status without client cert: got %d, want %d", status, http.StatusOK)
	}
	if !strings.Contains(body, "test successful") {
		t.Errorf("Unexpected response: %s", body)
	}
}