package pki import ( "os" "path/filepath" "testing" ) func TestGenerateCertificateRequest(t *testing.T) { // Create a temporary directory for the test tempDir, err := os.MkdirTemp("", "kat-csr-test") if err != nil { t.Fatalf("Failed to create temp directory: %v", err) } defer os.RemoveAll(tempDir) // Define paths for key and CSR keyPath := filepath.Join(tempDir, "node.key") csrPath := filepath.Join(tempDir, "node.csr") commonName := "test-node.kat.cluster.local" // Generate CSR err = GenerateCertificateRequest(commonName, keyPath, csrPath) if err != nil { t.Fatalf("GenerateCertificateRequest failed: %v", err) } // Verify files exist if _, err := os.Stat(keyPath); os.IsNotExist(err) { t.Errorf("Key file was not created at %s", keyPath) } if _, err := os.Stat(csrPath); os.IsNotExist(err) { t.Errorf("CSR file was not created at %s", csrPath) } // Read CSR file csrData, err := os.ReadFile(csrPath) if err != nil { t.Fatalf("Failed to read CSR file: %v", err) } // Parse CSR csr, err := ParseCSRFromBytes(csrData) if err != nil { t.Fatalf("Failed to parse CSR: %v", err) } // Verify CSR properties if csr.Subject.CommonName != commonName { t.Errorf("Unexpected CSR CommonName: got %s, want %s", csr.Subject.CommonName, commonName) } if len(csr.DNSNames) == 0 || csr.DNSNames[0] != commonName { t.Errorf("Unexpected CSR DNSNames: got %v, want [%s]", csr.DNSNames, commonName) } } func TestSignCertificateRequest(t *testing.T) { // Create a temporary directory for the test tempDir, err := os.MkdirTemp("", "kat-cert-test") if err != nil { t.Fatalf("Failed to create temp directory: %v", err) } defer os.RemoveAll(tempDir) // Generate CA caKeyPath := filepath.Join(tempDir, "ca.key") caCertPath := filepath.Join(tempDir, "ca.crt") err = GenerateCA(tempDir, caKeyPath, caCertPath) if err != nil { t.Fatalf("GenerateCA failed: %v", err) } // Generate CSR nodeKeyPath := filepath.Join(tempDir, "node.key") csrPath := filepath.Join(tempDir, "node.csr") commonName := "test-node.kat.cluster.local" err = GenerateCertificateRequest(commonName, nodeKeyPath, csrPath) if err != nil { t.Fatalf("GenerateCertificateRequest failed: %v", err) } // Read CSR file csrData, err := os.ReadFile(csrPath) if err != nil { t.Fatalf("Failed to read CSR file: %v", err) } // Sign CSR certPath := filepath.Join(tempDir, "node.crt") err = SignCertificateRequest(caKeyPath, caCertPath, string(csrData), certPath, 30) // 30 days validity if err != nil { t.Fatalf("SignCertificateRequest failed: %v", err) } // Verify certificate file exists if _, err := os.Stat(certPath); os.IsNotExist(err) { t.Errorf("Certificate file was not created at %s", certPath) } // Load and verify certificate cert, err := LoadCertificate(certPath) if err != nil { t.Fatalf("Failed to load certificate: %v", err) } // Verify certificate properties if cert.Subject.CommonName != commonName { t.Errorf("Unexpected certificate CommonName: got %s, want %s", cert.Subject.CommonName, commonName) } if cert.IsCA { t.Errorf("Certificate should not be a CA") } if len(cert.DNSNames) == 0 || cert.DNSNames[0] != commonName { t.Errorf("Unexpected certificate DNSNames: got %v, want [%s]", cert.DNSNames, commonName) } // Load CA certificate to verify chain caCert, err := LoadCACertificate(caCertPath) if err != nil { t.Fatalf("Failed to load CA certificate: %v", err) } // Verify certificate is signed by CA err = cert.CheckSignatureFrom(caCert) if err != nil { t.Errorf("Certificate signature verification failed: %v", err) } }