// kat/internal/pki/certs_test.go
//
// 129 lines
// 3.5 KiB
// Go

package pki
import (
"os"
"path/filepath"
"testing"
)
// TestGenerateCertificateRequest checks that GenerateCertificateRequest writes
// a key file and a CSR file to the requested paths, and that the CSR parsed by
// ParseCSRFromBytes carries the requested common name both as its subject CN
// and as its first DNS name.
func TestGenerateCertificateRequest(t *testing.T) {
	// t.TempDir is removed automatically when the test finishes (including on
	// t.Fatal), so no manual cleanup or cleanup-error handling is needed.
	tempDir := t.TempDir()

	keyPath := filepath.Join(tempDir, "node.key")
	csrPath := filepath.Join(tempDir, "node.csr")
	commonName := "test-node.kat.cluster.local"

	if err := GenerateCertificateRequest(commonName, keyPath, csrPath); err != nil {
		t.Fatalf("GenerateCertificateRequest failed: %v", err)
	}

	// Treat any Stat error as a failure, not only "does not exist": a
	// permission or I/O error must not be silently accepted as success.
	if _, err := os.Stat(keyPath); err != nil {
		t.Errorf("Key file was not created at %s: %v", keyPath, err)
	}
	// The CSR is read below, so there is no point continuing without it.
	if _, err := os.Stat(csrPath); err != nil {
		t.Fatalf("CSR file was not created at %s: %v", csrPath, err)
	}

	csrData, err := os.ReadFile(csrPath)
	if err != nil {
		t.Fatalf("Failed to read CSR file: %v", err)
	}

	csr, err := ParseCSRFromBytes(csrData)
	if err != nil {
		t.Fatalf("Failed to parse CSR: %v", err)
	}

	if csr.Subject.CommonName != commonName {
		t.Errorf("Unexpected CSR CommonName: got %s, want %s", csr.Subject.CommonName, commonName)
	}
	if len(csr.DNSNames) == 0 || csr.DNSNames[0] != commonName {
		t.Errorf("Unexpected CSR DNSNames: got %v, want [%s]", csr.DNSNames, commonName)
	}
}
// TestSignCertificateRequest exercises the full issuance path. It generates a
// CA, creates a node CSR, and signs that CSR with the CA. It then checks that
// the resulting certificate has the requested CN and first DNS name, is not
// marked as a CA, and verifies against the CA certificate's signature.
func TestSignCertificateRequest(t *testing.T) {
	// Validity period passed to SignCertificateRequest (in days, per the
	// original test's annotation).
	const validityDays = 30

	// t.TempDir is removed automatically when the test finishes (including on
	// t.Fatal), so no manual cleanup or cleanup-error handling is needed.
	tempDir := t.TempDir()

	caKeyPath := filepath.Join(tempDir, "ca.key")
	caCertPath := filepath.Join(tempDir, "ca.crt")
	// NOTE(review): the first argument is the temp directory itself; confirm
	// against GenerateCA's signature what this parameter is meant to be.
	if err := GenerateCA(tempDir, caKeyPath, caCertPath); err != nil {
		t.Fatalf("GenerateCA failed: %v", err)
	}

	nodeKeyPath := filepath.Join(tempDir, "node.key")
	csrPath := filepath.Join(tempDir, "node.csr")
	commonName := "test-node.kat.cluster.local"
	if err := GenerateCertificateRequest(commonName, nodeKeyPath, csrPath); err != nil {
		t.Fatalf("GenerateCertificateRequest failed: %v", err)
	}

	csrData, err := os.ReadFile(csrPath)
	if err != nil {
		t.Fatalf("Failed to read CSR file: %v", err)
	}

	certPath := filepath.Join(tempDir, "node.crt")
	if err := SignCertificateRequest(caKeyPath, caCertPath, string(csrData), certPath, validityDays); err != nil {
		t.Fatalf("SignCertificateRequest failed: %v", err)
	}

	// Treat any Stat error as a failure, not only "does not exist". The
	// certificate is loaded next, so stop here if it is unavailable.
	if _, err := os.Stat(certPath); err != nil {
		t.Fatalf("Certificate file was not created at %s: %v", certPath, err)
	}

	cert, err := LoadCertificate(certPath)
	if err != nil {
		t.Fatalf("Failed to load certificate: %v", err)
	}

	if cert.Subject.CommonName != commonName {
		t.Errorf("Unexpected certificate CommonName: got %s, want %s", cert.Subject.CommonName, commonName)
	}
	if cert.IsCA {
		t.Errorf("Certificate should not be a CA")
	}
	if len(cert.DNSNames) == 0 || cert.DNSNames[0] != commonName {
		t.Errorf("Unexpected certificate DNSNames: got %v, want [%s]", cert.DNSNames, commonName)
	}

	// Confirm the leaf was actually signed by the generated CA.
	caCert, err := LoadCACertificate(caCertPath)
	if err != nil {
		t.Fatalf("Failed to load CA certificate: %v", err)
	}
	if err := cert.CheckSignatureFrom(caCert); err != nil {
		t.Errorf("Certificate signature verification failed: %v", err)
	}
}