feat: implement internal PKI utilities for CA and certificate management

This commit is contained in:
Tanishq Dubey 2025-05-16 20:47:57 -04:00
parent 58bdca5703
commit 7adabe8630
No known key found for this signature in database
GPG Key ID: CFC1931B84DFC3F9
4 changed files with 586 additions and 0 deletions

166
internal/pki/ca.go Normal file
View File

@ -0,0 +1,166 @@
package pki
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"os"
"path/filepath"
"time"
)
const (
// Default key size for RSA keys
DefaultRSAKeySize = 2048
// Default CA certificate validity period
DefaultCAValidityDays = 3650 // ~10 years
// Default certificate validity period
DefaultCertValidityDays = 365 // 1 year
// Default PKI directory
DefaultPKIDir = "/var/lib/kat/pki"
)
// GenerateCA creates a new Certificate Authority key pair and certificate.
// It saves the private key and certificate to the specified paths.
func GenerateCA(pkiDir string, keyPath, certPath string) error {
// Create PKI directory if it doesn't exist
if err := os.MkdirAll(pkiDir, 0700); err != nil {
return fmt.Errorf("failed to create PKI directory: %w", err)
}
// Generate RSA key
key, err := rsa.GenerateKey(rand.Reader, DefaultRSAKeySize)
if err != nil {
return fmt.Errorf("failed to generate CA key: %w", err)
}
// Create self-signed certificate
serialNumber, err := generateSerialNumber()
if err != nil {
return fmt.Errorf("failed to generate serial number: %w", err)
}
// Certificate template
notBefore := time.Now()
notAfter := notBefore.Add(time.Duration(DefaultCAValidityDays) * 24 * time.Hour)
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: "KAT Root CA",
Organization: []string{"KAT System"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
MaxPathLen: 1, // Only allow one level of intermediate certs
}
// Create certificate
derBytes, err := x509.CreateCertificate(
rand.Reader,
&template,
&template, // Self-signed
&key.PublicKey,
key,
)
if err != nil {
return fmt.Errorf("failed to create CA certificate: %w", err)
}
// Save private key
keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return fmt.Errorf("failed to open CA key file for writing: %w", err)
}
defer keyOut.Close()
err = pem.Encode(keyOut, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
if err != nil {
return fmt.Errorf("failed to write CA key to file: %w", err)
}
// Save certificate
certOut, err := os.OpenFile(certPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
return fmt.Errorf("failed to open CA certificate file for writing: %w", err)
}
defer certOut.Close()
err = pem.Encode(certOut, &pem.Block{
Type: "CERTIFICATE",
Bytes: derBytes,
})
if err != nil {
return fmt.Errorf("failed to write CA certificate to file: %w", err)
}
return nil
}
// GetPKIPathFromClusterConfig determines the PKI directory from the cluster configuration.
// If backupPath is provided, it uses the parent directory of backupPath.
// Otherwise, it uses the default PKI directory.
func GetPKIPathFromClusterConfig(backupPath string) string {
if backupPath == "" {
return DefaultPKIDir
}
// Use the parent directory of backupPath
return filepath.Dir(backupPath) + "/pki"
}
// generateSerialNumber creates a random serial number for certificates
func generateSerialNumber() (*big.Int, error) {
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) // 128 bits
return rand.Int(rand.Reader, serialNumberLimit)
}
// LoadCACertificate loads a CA certificate from a file
func LoadCACertificate(certPath string) (*x509.Certificate, error) {
certPEM, err := os.ReadFile(certPath)
if err != nil {
return nil, fmt.Errorf("failed to read CA certificate file: %w", err)
}
block, _ := pem.Decode(certPEM)
if block == nil || block.Type != "CERTIFICATE" {
return nil, fmt.Errorf("failed to decode PEM block containing certificate")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse CA certificate: %w", err)
}
return cert, nil
}
// LoadCAPrivateKey loads a CA private key from a file
func LoadCAPrivateKey(keyPath string) (*rsa.PrivateKey, error) {
keyPEM, err := os.ReadFile(keyPath)
if err != nil {
return nil, fmt.Errorf("failed to read CA key file: %w", err)
}
block, _ := pem.Decode(keyPEM)
if block == nil || block.Type != "RSA PRIVATE KEY" {
return nil, fmt.Errorf("failed to decode PEM block containing private key")
}
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse CA private key: %w", err)
}
return key, nil
}

73
internal/pki/ca_test.go Normal file
View File

@ -0,0 +1,73 @@
package pki
import (
"os"
"path/filepath"
"testing"
)
func TestGenerateCA(t *testing.T) {
// Create a temporary directory for the test
tempDir, err := os.MkdirTemp("", "kat-pki-test")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tempDir)
// Define paths for CA key and certificate
keyPath := filepath.Join(tempDir, "ca.key")
certPath := filepath.Join(tempDir, "ca.crt")
// Generate CA
err = GenerateCA(tempDir, keyPath, certPath)
if err != nil {
t.Fatalf("GenerateCA failed: %v", err)
}
// Verify files exist
if _, err := os.Stat(keyPath); os.IsNotExist(err) {
t.Errorf("CA key file was not created at %s", keyPath)
}
if _, err := os.Stat(certPath); os.IsNotExist(err) {
t.Errorf("CA certificate file was not created at %s", certPath)
}
// Load and verify CA certificate
caCert, err := LoadCACertificate(certPath)
if err != nil {
t.Fatalf("Failed to load CA certificate: %v", err)
}
// Verify CA properties
if !caCert.IsCA {
t.Errorf("Certificate is not marked as CA")
}
if caCert.Subject.CommonName != "KAT Root CA" {
t.Errorf("Unexpected CA CommonName: got %s, want %s", caCert.Subject.CommonName, "KAT Root CA")
}
if len(caCert.Subject.Organization) == 0 || caCert.Subject.Organization[0] != "KAT System" {
t.Errorf("Unexpected CA Organization: got %v, want [KAT System]", caCert.Subject.Organization)
}
// Load and verify CA key
_, err = LoadCAPrivateKey(keyPath)
if err != nil {
t.Fatalf("Failed to load CA private key: %v", err)
}
}
func TestGetPKIPathFromClusterConfig(t *testing.T) {
// Test with empty backup path
pkiPath := GetPKIPathFromClusterConfig("")
if pkiPath != DefaultPKIDir {
t.Errorf("Expected default PKI path %s, got %s", DefaultPKIDir, pkiPath)
}
// Test with backup path
backupPath := "/opt/kat/backups"
expectedPKIPath := "/opt/kat/pki"
pkiPath = GetPKIPathFromClusterConfig(backupPath)
if pkiPath != expectedPKIPath {
t.Errorf("Expected PKI path %s, got %s", expectedPKIPath, pkiPath)
}
}

219
internal/pki/certs.go Normal file
View File

@ -0,0 +1,219 @@
package pki
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"net"
"os"
"time"
)
// GenerateCertificateRequest generates a new RSA key pair and a Certificate Signing Request (CSR).
// It saves the private key and CSR to the specified paths.
func GenerateCertificateRequest(commonName string, keyOutPath, csrOutPath string) error {
// Generate RSA key
key, err := rsa.GenerateKey(rand.Reader, DefaultRSAKeySize)
if err != nil {
return fmt.Errorf("failed to generate key: %w", err)
}
// Create CSR template
template := x509.CertificateRequest{
Subject: pkix.Name{
CommonName: commonName,
Organization: []string{"KAT System"},
},
DNSNames: []string{commonName},
}
// Add IP addresses if commonName is an IP
if ip := net.ParseIP(commonName); ip != nil {
template.IPAddresses = []net.IP{ip}
}
// Create CSR
csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &template, key)
if err != nil {
return fmt.Errorf("failed to create CSR: %w", err)
}
// Save private key
keyOut, err := os.OpenFile(keyOutPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return fmt.Errorf("failed to open key file for writing: %w", err)
}
defer keyOut.Close()
err = pem.Encode(keyOut, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
if err != nil {
return fmt.Errorf("failed to write key to file: %w", err)
}
// Save CSR
csrOut, err := os.OpenFile(csrOutPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
return fmt.Errorf("failed to open CSR file for writing: %w", err)
}
defer csrOut.Close()
err = pem.Encode(csrOut, &pem.Block{
Type: "CERTIFICATE REQUEST",
Bytes: csrBytes,
})
if err != nil {
return fmt.Errorf("failed to write CSR to file: %w", err)
}
return nil
}
// SignCertificateRequest signs a CSR using the CA key and certificate.
// It loads the CA key and certificate from the specified paths,
// parses the CSR data, and issues a signed certificate.
func SignCertificateRequest(caKeyPath, caCertPath string, csrData []byte, certOutPath string, durationDays int) error {
// Load CA private key
caKey, err := LoadCAPrivateKey(caKeyPath)
if err != nil {
return err
}
// Load CA certificate
caCert, err := LoadCACertificate(caCertPath)
if err != nil {
return err
}
// Parse CSR
block, _ := pem.Decode(csrData)
if block == nil || block.Type != "CERTIFICATE REQUEST" {
return fmt.Errorf("failed to decode PEM block containing CSR")
}
csr, err := x509.ParseCertificateRequest(block.Bytes)
if err != nil {
return fmt.Errorf("failed to parse CSR: %w", err)
}
// Verify CSR signature
if err := csr.CheckSignature(); err != nil {
return fmt.Errorf("CSR signature verification failed: %w", err)
}
// Generate serial number
serialNumber, err := generateSerialNumber()
if err != nil {
return fmt.Errorf("failed to generate serial number: %w", err)
}
// Set certificate validity period
if durationDays <= 0 {
durationDays = DefaultCertValidityDays
}
notBefore := time.Now()
notAfter := notBefore.Add(time.Duration(durationDays) * 24 * time.Hour)
// Create certificate template
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: csr.Subject,
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
BasicConstraintsValid: true,
IsCA: false,
DNSNames: csr.DNSNames,
IPAddresses: csr.IPAddresses,
}
// Create certificate
derBytes, err := x509.CreateCertificate(
rand.Reader,
&template,
caCert,
csr.PublicKey,
caKey,
)
if err != nil {
return fmt.Errorf("failed to create certificate: %w", err)
}
// Save certificate
certOut, err := os.OpenFile(certOutPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
return fmt.Errorf("failed to open certificate file for writing: %w", err)
}
defer certOut.Close()
err = pem.Encode(certOut, &pem.Block{
Type: "CERTIFICATE",
Bytes: derBytes,
})
if err != nil {
return fmt.Errorf("failed to write certificate to file: %w", err)
}
return nil
}
// ParseCSRFromBytes parses a PEM-encoded CSR from bytes
func ParseCSRFromBytes(csrData []byte) (*x509.CertificateRequest, error) {
block, _ := pem.Decode(csrData)
if block == nil || block.Type != "CERTIFICATE REQUEST" {
return nil, fmt.Errorf("failed to decode PEM block containing CSR")
}
csr, err := x509.ParseCertificateRequest(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse CSR: %w", err)
}
return csr, nil
}
// LoadCertificate loads an X.509 certificate from a file
func LoadCertificate(certPath string) (*x509.Certificate, error) {
certPEM, err := os.ReadFile(certPath)
if err != nil {
return nil, fmt.Errorf("failed to read certificate file: %w", err)
}
block, _ := pem.Decode(certPEM)
if block == nil || block.Type != "CERTIFICATE" {
return nil, fmt.Errorf("failed to decode PEM block containing certificate")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse certificate: %w", err)
}
return cert, nil
}
// LoadPrivateKey loads an RSA private key from a file
func LoadPrivateKey(keyPath string) (*rsa.PrivateKey, error) {
keyPEM, err := os.ReadFile(keyPath)
if err != nil {
return nil, fmt.Errorf("failed to read key file: %w", err)
}
block, _ := pem.Decode(keyPEM)
if block == nil || block.Type != "RSA PRIVATE KEY" {
return nil, fmt.Errorf("failed to decode PEM block containing private key")
}
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %w", err)
}
return key, nil
}

128
internal/pki/certs_test.go Normal file
View File

@ -0,0 +1,128 @@
package pki
import (
"os"
"path/filepath"
"testing"
)
func TestGenerateCertificateRequest(t *testing.T) {
// Create a temporary directory for the test
tempDir, err := os.MkdirTemp("", "kat-csr-test")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tempDir)
// Define paths for key and CSR
keyPath := filepath.Join(tempDir, "node.key")
csrPath := filepath.Join(tempDir, "node.csr")
commonName := "test-node.kat.cluster.local"
// Generate CSR
err = GenerateCertificateRequest(commonName, keyPath, csrPath)
if err != nil {
t.Fatalf("GenerateCertificateRequest failed: %v", err)
}
// Verify files exist
if _, err := os.Stat(keyPath); os.IsNotExist(err) {
t.Errorf("Key file was not created at %s", keyPath)
}
if _, err := os.Stat(csrPath); os.IsNotExist(err) {
t.Errorf("CSR file was not created at %s", csrPath)
}
// Read CSR file
csrData, err := os.ReadFile(csrPath)
if err != nil {
t.Fatalf("Failed to read CSR file: %v", err)
}
// Parse CSR
csr, err := ParseCSRFromBytes(csrData)
if err != nil {
t.Fatalf("Failed to parse CSR: %v", err)
}
// Verify CSR properties
if csr.Subject.CommonName != commonName {
t.Errorf("Unexpected CSR CommonName: got %s, want %s", csr.Subject.CommonName, commonName)
}
if len(csr.DNSNames) == 0 || csr.DNSNames[0] != commonName {
t.Errorf("Unexpected CSR DNSNames: got %v, want [%s]", csr.DNSNames, commonName)
}
}
func TestSignCertificateRequest(t *testing.T) {
// Create a temporary directory for the test
tempDir, err := os.MkdirTemp("", "kat-cert-test")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tempDir)
// Generate CA
caKeyPath := filepath.Join(tempDir, "ca.key")
caCertPath := filepath.Join(tempDir, "ca.crt")
err = GenerateCA(tempDir, caKeyPath, caCertPath)
if err != nil {
t.Fatalf("GenerateCA failed: %v", err)
}
// Generate CSR
nodeKeyPath := filepath.Join(tempDir, "node.key")
csrPath := filepath.Join(tempDir, "node.csr")
commonName := "test-node.kat.cluster.local"
err = GenerateCertificateRequest(commonName, nodeKeyPath, csrPath)
if err != nil {
t.Fatalf("GenerateCertificateRequest failed: %v", err)
}
// Read CSR file
csrData, err := os.ReadFile(csrPath)
if err != nil {
t.Fatalf("Failed to read CSR file: %v", err)
}
// Sign CSR
certPath := filepath.Join(tempDir, "node.crt")
err = SignCertificateRequest(caKeyPath, caCertPath, csrData, certPath, 30) // 30 days validity
if err != nil {
t.Fatalf("SignCertificateRequest failed: %v", err)
}
// Verify certificate file exists
if _, err := os.Stat(certPath); os.IsNotExist(err) {
t.Errorf("Certificate file was not created at %s", certPath)
}
// Load and verify certificate
cert, err := LoadCertificate(certPath)
if err != nil {
t.Fatalf("Failed to load certificate: %v", err)
}
// Verify certificate properties
if cert.Subject.CommonName != commonName {
t.Errorf("Unexpected certificate CommonName: got %s, want %s", cert.Subject.CommonName, commonName)
}
if cert.IsCA {
t.Errorf("Certificate should not be a CA")
}
if len(cert.DNSNames) == 0 || cert.DNSNames[0] != commonName {
t.Errorf("Unexpected certificate DNSNames: got %v, want [%s]", cert.DNSNames, commonName)
}
// Load CA certificate to verify chain
caCert, err := LoadCACertificate(caCertPath)
if err != nil {
t.Fatalf("Failed to load CA certificate: %v", err)
}
// Verify certificate is signed by CA
err = cert.CheckSignatureFrom(caCert)
if err != nil {
t.Errorf("Certificate signature verification failed: %v", err)
}
}