package pki

import (
	"os"
	"path/filepath"
	"testing"
)

// TestGenerateCertificateRequest checks that GenerateCertificateRequest writes
// both a key file and a CSR file, and that the CSR read back from disk and
// parsed with ParseCSRFromBytes carries the requested common name as its
// subject CommonName and as its first DNS name.
func TestGenerateCertificateRequest(t *testing.T) {
	// t.TempDir is removed automatically when the test finishes, and a failed
	// removal is reported as a test failure instead of being silently dropped.
	tempDir := t.TempDir()

	// Define paths for key and CSR
	keyPath := filepath.Join(tempDir, "node.key")
	csrPath := filepath.Join(tempDir, "node.csr")
	commonName := "test-node.kat.cluster.local"

	// Generate CSR
	if err := GenerateCertificateRequest(commonName, keyPath, csrPath); err != nil {
		t.Fatalf("GenerateCertificateRequest failed: %v", err)
	}

	// Verify files exist. Any Stat error (not only "does not exist") means the
	// file cannot be relied upon, so report it rather than ignoring it.
	if _, err := os.Stat(keyPath); err != nil {
		t.Errorf("Key file was not created at %s: %v", keyPath, err)
	}
	if _, err := os.Stat(csrPath); err != nil {
		t.Errorf("CSR file was not created at %s: %v", csrPath, err)
	}

	// Read CSR file
	csrData, err := os.ReadFile(csrPath)
	if err != nil {
		t.Fatalf("Failed to read CSR file: %v", err)
	}

	// Parse CSR
	csr, err := ParseCSRFromBytes(csrData)
	if err != nil {
		t.Fatalf("Failed to parse CSR: %v", err)
	}

	// Verify CSR properties
	if csr.Subject.CommonName != commonName {
		t.Errorf("Unexpected CSR CommonName: got %s, want %s", csr.Subject.CommonName, commonName)
	}
	if len(csr.DNSNames) == 0 || csr.DNSNames[0] != commonName {
		t.Errorf("Unexpected CSR DNSNames: got %v, want [%s]", csr.DNSNames, commonName)
	}
}

// TestSignCertificateRequest exercises the full issuance flow: it generates a
// CA with GenerateCA, creates a node CSR with GenerateCertificateRequest, signs
// it with SignCertificateRequest, and then checks that the resulting
// certificate has the expected CommonName and first DNS name, is not a CA, and
// has a signature that verifies against the CA certificate.
func TestSignCertificateRequest(t *testing.T) {
	// t.TempDir is removed automatically when the test finishes, and a failed
	// removal is reported as a test failure instead of being silently dropped.
	tempDir := t.TempDir()

	// Generate CA
	caKeyPath := filepath.Join(tempDir, "ca.key")
	caCertPath := filepath.Join(tempDir, "ca.crt")
	// NOTE(review): the first argument is the temp dir itself; confirm this is
	// the directory GenerateCA is meant to receive.
	if err := GenerateCA(tempDir, caKeyPath, caCertPath); err != nil {
		t.Fatalf("GenerateCA failed: %v", err)
	}

	// Generate CSR
	nodeKeyPath := filepath.Join(tempDir, "node.key")
	csrPath := filepath.Join(tempDir, "node.csr")
	commonName := "test-node.kat.cluster.local"
	if err := GenerateCertificateRequest(commonName, nodeKeyPath, csrPath); err != nil {
		t.Fatalf("GenerateCertificateRequest failed: %v", err)
	}

	// Read CSR file
	csrData, err := os.ReadFile(csrPath)
	if err != nil {
		t.Fatalf("Failed to read CSR file: %v", err)
	}

	// Sign CSR
	certPath := filepath.Join(tempDir, "node.crt")
	if err := SignCertificateRequest(caKeyPath, caCertPath, csrData, certPath, 30); err != nil { // 30 days validity
		t.Fatalf("SignCertificateRequest failed: %v", err)
	}

	// Verify certificate file exists. Every later step reads this file, so a
	// missing or unreadable file is fatal; any Stat error counts, not only
	// "does not exist".
	if _, err := os.Stat(certPath); err != nil {
		t.Fatalf("Certificate file was not created at %s: %v", certPath, err)
	}

	// Load and verify certificate
	cert, err := LoadCertificate(certPath)
	if err != nil {
		t.Fatalf("Failed to load certificate: %v", err)
	}

	// Verify certificate properties
	if cert.Subject.CommonName != commonName {
		t.Errorf("Unexpected certificate CommonName: got %s, want %s", cert.Subject.CommonName, commonName)
	}
	if cert.IsCA {
		t.Errorf("Certificate should not be a CA")
	}
	if len(cert.DNSNames) == 0 || cert.DNSNames[0] != commonName {
		t.Errorf("Unexpected certificate DNSNames: got %v, want [%s]", cert.DNSNames, commonName)
	}

	// Load CA certificate to verify chain
	caCert, err := LoadCACertificate(caCertPath)
	if err != nil {
		t.Fatalf("Failed to load CA certificate: %v", err)
	}

	// Verify certificate is signed by CA
	if err := cert.CheckSignatureFrom(caCert); err != nil {
		t.Errorf("Certificate signature verification failed: %v", err)
	}
}