diff --git a/internal/pki/ca.go b/internal/pki/ca.go new file mode 100644 index 0000000..3f6b2f6 --- /dev/null +++ b/internal/pki/ca.go @@ -0,0 +1,166 @@ +package pki + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "os" + "path/filepath" + "time" +) + +const ( + // Default key size for RSA keys + DefaultRSAKeySize = 2048 + // Default CA certificate validity period + DefaultCAValidityDays = 3650 // ~10 years + // Default certificate validity period + DefaultCertValidityDays = 365 // 1 year + // Default PKI directory + DefaultPKIDir = "/var/lib/kat/pki" +) + +// GenerateCA creates a new Certificate Authority key pair and certificate. +// It saves the private key and certificate to the specified paths. +func GenerateCA(pkiDir string, keyPath, certPath string) error { + // Create PKI directory if it doesn't exist + if err := os.MkdirAll(pkiDir, 0700); err != nil { + return fmt.Errorf("failed to create PKI directory: %w", err) + } + + // Generate RSA key + key, err := rsa.GenerateKey(rand.Reader, DefaultRSAKeySize) + if err != nil { + return fmt.Errorf("failed to generate CA key: %w", err) + } + + // Create self-signed certificate + serialNumber, err := generateSerialNumber() + if err != nil { + return fmt.Errorf("failed to generate serial number: %w", err) + } + + // Certificate template + notBefore := time.Now() + notAfter := notBefore.Add(time.Duration(DefaultCAValidityDays) * 24 * time.Hour) + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: "KAT Root CA", + Organization: []string{"KAT System"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + MaxPathLen: 1, // Only allow one level of intermediate certs + } + + // Create certificate + derBytes, err := x509.CreateCertificate( + rand.Reader, + &template, + &template, // Self-signed + &key.PublicKey, + key, + ) + if err != nil { + return fmt.Errorf("failed to create CA certificate: %w", err) + } + + // Save private key + keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return fmt.Errorf("failed to open CA key file for writing: %w", err) + } + defer keyOut.Close() + + err = pem.Encode(keyOut, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) + if err != nil { + return fmt.Errorf("failed to write CA key to file: %w", err) + } + + // Save certificate + certOut, err := os.OpenFile(certPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + return fmt.Errorf("failed to open CA certificate file for writing: %w", err) + } + defer certOut.Close() + + err = pem.Encode(certOut, &pem.Block{ + Type: "CERTIFICATE", + Bytes: derBytes, + }) + if err != nil { + return fmt.Errorf("failed to write CA certificate to file: %w", err) + } + + return nil +} + +// GetPKIPathFromClusterConfig determines the PKI directory from the cluster configuration. +// If backupPath is provided, it uses the parent directory of backupPath. +// Otherwise, it uses the default PKI directory. +func GetPKIPathFromClusterConfig(backupPath string) string { + if backupPath == "" { + return DefaultPKIDir + } + + // Use the parent directory of backupPath + return filepath.Dir(backupPath) + "/pki" +} + +// generateSerialNumber creates a random serial number for certificates +func generateSerialNumber() (*big.Int, error) { + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) // 128 bits + return rand.Int(rand.Reader, serialNumberLimit) +} + +// LoadCACertificate loads a CA certificate from a file +func LoadCACertificate(certPath string) (*x509.Certificate, error) { + certPEM, err := os.ReadFile(certPath) + if err != nil { + return nil, fmt.Errorf("failed to read CA certificate file: %w", err) + } + + block, _ := pem.Decode(certPEM) + if block == nil || block.Type != "CERTIFICATE" { + return nil, fmt.Errorf("failed to decode PEM block containing certificate") + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse CA certificate: %w", err) + } + + return cert, nil +} + +// LoadCAPrivateKey loads a CA private key from a file +func LoadCAPrivateKey(keyPath string) (*rsa.PrivateKey, error) { + keyPEM, err := os.ReadFile(keyPath) + if err != nil { + return nil, fmt.Errorf("failed to read CA key file: %w", err) + } + + block, _ := pem.Decode(keyPEM) + if block == nil || block.Type != "RSA PRIVATE KEY" { + return nil, fmt.Errorf("failed to decode PEM block containing private key") + } + + key, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse CA private key: %w", err) + } + + return key, nil +} diff --git a/internal/pki/ca_test.go b/internal/pki/ca_test.go new file mode 100644 index 0000000..4bc852a --- /dev/null +++ b/internal/pki/ca_test.go @@ -0,0 +1,73 @@ +package pki + +import ( + "os" + "path/filepath" + "testing" +) + +func TestGenerateCA(t *testing.T) { + // Create a temporary directory for the test + tempDir, err := os.MkdirTemp("", "kat-pki-test") + if err != nil { + t.Fatalf("Failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) + + // Define paths for CA key and certificate + keyPath := filepath.Join(tempDir, "ca.key") + certPath := filepath.Join(tempDir, "ca.crt") + + // Generate CA + err = GenerateCA(tempDir, keyPath, certPath) + if err != nil { + t.Fatalf("GenerateCA failed: %v", err) + } + + // Verify files exist + if _, err := os.Stat(keyPath); os.IsNotExist(err) { + t.Errorf("CA key file was not created at %s", keyPath) + } + if _, err := os.Stat(certPath); os.IsNotExist(err) { + t.Errorf("CA certificate file was not created at %s", certPath) + } + + // Load and verify CA certificate + caCert, err := LoadCACertificate(certPath) + if err != nil { + t.Fatalf("Failed to load CA certificate: %v", err) + } + + // Verify CA properties + if !caCert.IsCA { + t.Errorf("Certificate is not marked as CA") + } + if caCert.Subject.CommonName != "KAT Root CA" { + t.Errorf("Unexpected CA CommonName: got %s, want %s", caCert.Subject.CommonName, "KAT Root CA") + } + if len(caCert.Subject.Organization) == 0 || caCert.Subject.Organization[0] != "KAT System" { + t.Errorf("Unexpected CA Organization: got %v, want [KAT System]", caCert.Subject.Organization) + } + + // Load and verify CA key + _, err = LoadCAPrivateKey(keyPath) + if err != nil { + t.Fatalf("Failed to load CA private key: %v", err) + } +} + +func TestGetPKIPathFromClusterConfig(t *testing.T) { + // Test with empty backup path + pkiPath := GetPKIPathFromClusterConfig("") + if pkiPath != DefaultPKIDir { + t.Errorf("Expected default PKI path %s, got %s", DefaultPKIDir, pkiPath) + } + + // Test with backup path + backupPath := "/opt/kat/backups" + expectedPKIPath := "/opt/kat/pki" + pkiPath = GetPKIPathFromClusterConfig(backupPath) + if pkiPath != expectedPKIPath { + t.Errorf("Expected PKI path %s, got %s", expectedPKIPath, pkiPath) + } +} diff --git a/internal/pki/certs.go b/internal/pki/certs.go new file mode 100644 index 0000000..357b41b --- /dev/null +++ b/internal/pki/certs.go @@ -0,0 +1,219 @@ +package pki + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "net" + "os" + "time" +) + +// GenerateCertificateRequest generates a new RSA key pair and a Certificate Signing Request (CSR). +// It saves the private key and CSR to the specified paths. +func GenerateCertificateRequest(commonName string, keyOutPath, csrOutPath string) error { + // Generate RSA key + key, err := rsa.GenerateKey(rand.Reader, DefaultRSAKeySize) + if err != nil { + return fmt.Errorf("failed to generate key: %w", err) + } + + // Create CSR template + template := x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: commonName, + Organization: []string{"KAT System"}, + }, + DNSNames: []string{commonName}, + } + + // Add IP addresses if commonName is an IP + if ip := net.ParseIP(commonName); ip != nil { + template.IPAddresses = []net.IP{ip} + } + + // Create CSR + csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &template, key) + if err != nil { + return fmt.Errorf("failed to create CSR: %w", err) + } + + // Save private key + keyOut, err := os.OpenFile(keyOutPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return fmt.Errorf("failed to open key file for writing: %w", err) + } + defer keyOut.Close() + + err = pem.Encode(keyOut, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) + if err != nil { + return fmt.Errorf("failed to write key to file: %w", err) + } + + // Save CSR + csrOut, err := os.OpenFile(csrOutPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + return fmt.Errorf("failed to open CSR file for writing: %w", err) + } + defer csrOut.Close() + + err = pem.Encode(csrOut, &pem.Block{ + Type: "CERTIFICATE REQUEST", + Bytes: csrBytes, + }) + if err != nil { + return fmt.Errorf("failed to write CSR to file: %w", err) + } + + return nil +} + +// SignCertificateRequest signs a CSR using the CA key and certificate. +// It loads the CA key and certificate from the specified paths, +// parses the CSR data, and issues a signed certificate. +func SignCertificateRequest(caKeyPath, caCertPath string, csrData []byte, certOutPath string, durationDays int) error { + // Load CA private key + caKey, err := LoadCAPrivateKey(caKeyPath) + if err != nil { + return err + } + + // Load CA certificate + caCert, err := LoadCACertificate(caCertPath) + if err != nil { + return err + } + + // Parse CSR + block, _ := pem.Decode(csrData) + if block == nil || block.Type != "CERTIFICATE REQUEST" { + return fmt.Errorf("failed to decode PEM block containing CSR") + } + + csr, err := x509.ParseCertificateRequest(block.Bytes) + if err != nil { + return fmt.Errorf("failed to parse CSR: %w", err) + } + + // Verify CSR signature + if err := csr.CheckSignature(); err != nil { + return fmt.Errorf("CSR signature verification failed: %w", err) + } + + // Generate serial number + serialNumber, err := generateSerialNumber() + if err != nil { + return fmt.Errorf("failed to generate serial number: %w", err) + } + + // Set certificate validity period + if durationDays <= 0 { + durationDays = DefaultCertValidityDays + } + notBefore := time.Now() + notAfter := notBefore.Add(time.Duration(durationDays) * 24 * time.Hour) + + // Create certificate template + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: csr.Subject, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + IsCA: false, + DNSNames: csr.DNSNames, + IPAddresses: csr.IPAddresses, + } + + // Create certificate + derBytes, err := x509.CreateCertificate( + rand.Reader, + &template, + caCert, + csr.PublicKey, + caKey, + ) + if err != nil { + return fmt.Errorf("failed to create certificate: %w", err) + } + + // Save certificate + certOut, err := os.OpenFile(certOutPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + return fmt.Errorf("failed to open certificate file for writing: %w", err) + } + defer certOut.Close() + + err = pem.Encode(certOut, &pem.Block{ + Type: "CERTIFICATE", + Bytes: derBytes, + }) + if err != nil { + return fmt.Errorf("failed to write certificate to file: %w", err) + } + + return nil +} + +// ParseCSRFromBytes parses a PEM-encoded CSR from bytes +func ParseCSRFromBytes(csrData []byte) (*x509.CertificateRequest, error) { + block, _ := pem.Decode(csrData) + if block == nil || block.Type != "CERTIFICATE REQUEST" { + return nil, fmt.Errorf("failed to decode PEM block containing CSR") + } + + csr, err := x509.ParseCertificateRequest(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse CSR: %w", err) + } + + return csr, nil +} + +// LoadCertificate loads an X.509 certificate from a file +func LoadCertificate(certPath string) (*x509.Certificate, error) { + certPEM, err := os.ReadFile(certPath) + if err != nil { + return nil, fmt.Errorf("failed to read certificate file: %w", err) + } + + block, _ := pem.Decode(certPEM) + if block == nil || block.Type != "CERTIFICATE" { + return nil, fmt.Errorf("failed to decode PEM block containing certificate") + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse certificate: %w", err) + } + + return cert, nil +} + +// LoadPrivateKey loads an RSA private key from a file +func LoadPrivateKey(keyPath string) (*rsa.PrivateKey, error) { + keyPEM, err := os.ReadFile(keyPath) + if err != nil { + return nil, fmt.Errorf("failed to read key file: %w", err) + } + + block, _ := pem.Decode(keyPEM) + if block == nil || block.Type != "RSA PRIVATE KEY" { + return nil, fmt.Errorf("failed to decode PEM block containing private key") + } + + key, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %w", err) + } + + return key, nil +} diff --git a/internal/pki/certs_test.go b/internal/pki/certs_test.go new file mode 100644 index 0000000..d3cc27c --- /dev/null +++ b/internal/pki/certs_test.go @@ -0,0 +1,128 @@ +package pki + +import ( + "os" + "path/filepath" + "testing" +) + +func TestGenerateCertificateRequest(t *testing.T) { + // Create a temporary directory for the test + tempDir, err := os.MkdirTemp("", "kat-csr-test") + if err != nil { + t.Fatalf("Failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) + + // Define paths for key and CSR + keyPath := filepath.Join(tempDir, "node.key") + csrPath := filepath.Join(tempDir, "node.csr") + commonName := "test-node.kat.cluster.local" + + // Generate CSR + err = GenerateCertificateRequest(commonName, keyPath, csrPath) + if err != nil { + t.Fatalf("GenerateCertificateRequest failed: %v", err) + } + + // Verify files exist + if _, err := os.Stat(keyPath); os.IsNotExist(err) { + t.Errorf("Key file was not created at %s", keyPath) + } + if _, err := os.Stat(csrPath); os.IsNotExist(err) { + t.Errorf("CSR file was not created at %s", csrPath) + } + + // Read CSR file + csrData, err := os.ReadFile(csrPath) + if err != nil { + t.Fatalf("Failed to read CSR file: %v", err) + } + + // Parse CSR + csr, err := ParseCSRFromBytes(csrData) + if err != nil { + t.Fatalf("Failed to parse CSR: %v", err) + } + + // Verify CSR properties + if csr.Subject.CommonName != commonName { + t.Errorf("Unexpected CSR CommonName: got %s, want %s", csr.Subject.CommonName, commonName) + } + if len(csr.DNSNames) == 0 || csr.DNSNames[0] != commonName { + t.Errorf("Unexpected CSR DNSNames: got %v, want [%s]", csr.DNSNames, commonName) + } +} + +func TestSignCertificateRequest(t *testing.T) { + // Create a temporary directory for the test + tempDir, err := os.MkdirTemp("", "kat-cert-test") + if err != nil { + t.Fatalf("Failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) + + // Generate CA + caKeyPath := filepath.Join(tempDir, "ca.key") + caCertPath := filepath.Join(tempDir, "ca.crt") + err = GenerateCA(tempDir, caKeyPath, caCertPath) + if err != nil { + t.Fatalf("GenerateCA failed: %v", err) + } + + // Generate CSR + nodeKeyPath := filepath.Join(tempDir, "node.key") + csrPath := filepath.Join(tempDir, "node.csr") + commonName := "test-node.kat.cluster.local" + err = GenerateCertificateRequest(commonName, nodeKeyPath, csrPath) + if err != nil { + t.Fatalf("GenerateCertificateRequest failed: %v", err) + } + + // Read CSR file + csrData, err := os.ReadFile(csrPath) + if err != nil { + t.Fatalf("Failed to read CSR file: %v", err) + } + + // Sign CSR + certPath := filepath.Join(tempDir, "node.crt") + err = SignCertificateRequest(caKeyPath, caCertPath, csrData, certPath, 30) // 30 days validity + if err != nil { + t.Fatalf("SignCertificateRequest failed: %v", err) + } + + // Verify certificate file exists + if _, err := os.Stat(certPath); os.IsNotExist(err) { + t.Errorf("Certificate file was not created at %s", certPath) + } + + // Load and verify certificate + cert, err := LoadCertificate(certPath) + if err != nil { + t.Fatalf("Failed to load certificate: %v", err) + } + + // Verify certificate properties + if cert.Subject.CommonName != commonName { + t.Errorf("Unexpected certificate CommonName: got %s, want %s", cert.Subject.CommonName, commonName) + } + if cert.IsCA { + t.Errorf("Certificate should not be a CA") + } + if len(cert.DNSNames) == 0 || cert.DNSNames[0] != commonName { + t.Errorf("Unexpected certificate DNSNames: got %v, want [%s]", cert.DNSNames, commonName) + } + + // Load CA certificate to verify chain + caCert, err := LoadCACertificate(caCertPath) + if err != nil { + t.Fatalf("Failed to load CA certificate: %v", err) + } + + // Verify certificate is signed by CA + err = cert.CheckSignatureFrom(caCert) + if err != nil { + t.Errorf("Certificate signature verification failed: %v", err) + } +}