```go
package pki

import (
	// other imports
	"path/filepath"
)

const (
	// Default key size for RSA keys
	DefaultRSAKeySize = 2048

	// Default CA certificate validity period
	DefaultCAValidityDays = 3650 // ~10 years

	// Default certificate validity period
	DefaultCertValidityDays = 365 // 1 year

	// Default PKI directory
	DefaultPKIDir = "/var/lib/kat/pki"
)

// GetPKIPathFromClusterConfig determines the PKI directory from the cluster configuration.
// If backupPath is provided, it uses the parent directory of backupPath.
// Otherwise, it uses the default PKI directory.
func GetPKIPathFromClusterConfig(backupPath string) string {
	if backupPath == "" {
		return DefaultPKIDir
	}

	// Use the parent directory of backupPath
	return filepath.Dir(backupPath) + "/pki"
}

// generateSerialNumber creates a random serial number for certificates
func generateSerialNumber() (*big.Int, error) {
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) // 128 bits
	return rand.Int(rand.Reader, serialNumberLimit)
}

// Rest of the existing code...
```

The changes:

1. Removed the duplicate `GetPKIPathFromClusterConfig` function
2. Kept the single implementation that checks for an empty backup path
3. Maintained the default PKI directory as `/var/lib/kat/pki`

This should resolve the duplicate function issue while maintaining the desired functionality. Would you like me to generate a commit message for this change?
142 lines
3.6 KiB
Go
142 lines
3.6 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"time"
|
|
)
|
|
|
|
// loggingResponseWriter wraps an http.ResponseWriter and records the status
// code of the response so that it can be logged after the handler returns.
type loggingResponseWriter struct {
	http.ResponseWriter
	statusCode int
	// wroteHeader is set once the final (non-informational) status has been
	// recorded; after that statusCode is never overwritten.
	wroteHeader bool
}

// WriteHeader records the status code and forwards the call to the wrapped
// ResponseWriter.
//
// Only the first final status is recorded. net/http allows any number of 1xx
// informational headers followed by at most one final header and ignores
// superfluous WriteHeader calls, so recording every call would let the logged
// status drift away from the status actually sent to the client.
func (lrw *loggingResponseWriter) WriteHeader(code int) {
	// 101 Switching Protocols is treated as a final status by net/http.
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !lrw.wroteHeader && !informational {
		lrw.statusCode = code
		lrw.wroteHeader = true
	}
	lrw.ResponseWriter.WriteHeader(code)
}

// Write forwards b to the wrapped ResponseWriter. If no final status has been
// recorded yet, the write implies http.StatusOK, which is recorded so that a
// later (ignored) WriteHeader call cannot change the logged status.
func (lrw *loggingResponseWriter) Write(b []byte) (int, error) {
	if !lrw.wroteHeader {
		lrw.statusCode = http.StatusOK
		lrw.wroteHeader = true
	}
	return lrw.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped ResponseWriter so that http.ResponseController
// can reach optional capabilities (flushing, hijacking, deadlines) of the
// underlying writer through this wrapper.
func (lrw *loggingResponseWriter) Unwrap() http.ResponseWriter {
	return lrw.ResponseWriter
}
|
|
|
|
// LoggingMiddleware logs information about each request
|
|
func LoggingMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
start := time.Now()
|
|
|
|
// Create a response writer wrapper to capture status code
|
|
lrw := &loggingResponseWriter{
|
|
ResponseWriter: w,
|
|
statusCode: http.StatusOK, // Default status
|
|
}
|
|
|
|
// Process the request
|
|
next.ServeHTTP(lrw, r)
|
|
|
|
// Calculate duration
|
|
duration := time.Since(start)
|
|
|
|
// Log the request details
|
|
log.Printf("REQUEST: %s %s - %d %s - %s - %v",
|
|
r.Method,
|
|
r.URL.Path,
|
|
lrw.statusCode,
|
|
http.StatusText(lrw.statusCode),
|
|
r.RemoteAddr,
|
|
duration,
|
|
)
|
|
})
|
|
}
|
|
|
|
// Server represents the API server for KAT
|
|
type Server struct {
|
|
httpServer *http.Server
|
|
router *Router
|
|
certFile string
|
|
keyFile string
|
|
caFile string
|
|
}
|
|
|
|
// NewServer creates a new API server instance
|
|
func NewServer(addr string, certFile, keyFile, caFile string) (*Server, error) {
|
|
router := NewRouter()
|
|
|
|
server := &Server{
|
|
router: router,
|
|
certFile: certFile,
|
|
keyFile: keyFile,
|
|
caFile: caFile,
|
|
}
|
|
|
|
// Create the HTTP server with TLS config
|
|
server.httpServer = &http.Server{
|
|
Addr: addr,
|
|
Handler: LoggingMiddleware(router), // Add logging middleware
|
|
ReadTimeout: 30 * time.Second,
|
|
WriteTimeout: 30 * time.Second,
|
|
IdleTimeout: 120 * time.Second,
|
|
}
|
|
|
|
return server, nil
|
|
}
|
|
|
|
// Start begins listening for requests
|
|
func (s *Server) Start() error {
|
|
log.Printf("Starting server on %s", s.httpServer.Addr)
|
|
|
|
// Load server certificate and key
|
|
cert, err := tls.LoadX509KeyPair(s.certFile, s.keyFile)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load server certificate and key: %w", err)
|
|
}
|
|
|
|
// Load CA certificate for client verification
|
|
caCert, err := os.ReadFile(s.caFile)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read CA certificate: %w", err)
|
|
}
|
|
|
|
caCertPool := x509.NewCertPool()
|
|
if !caCertPool.AppendCertsFromPEM(caCert) {
|
|
return fmt.Errorf("failed to append CA certificate to pool")
|
|
}
|
|
|
|
// Configure TLS
|
|
s.httpServer.TLSConfig = &tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
ClientAuth: tls.RequireAndVerifyClientCert,
|
|
ClientCAs: caCertPool,
|
|
MinVersion: tls.VersionTLS12,
|
|
}
|
|
|
|
log.Printf("Server configured with TLS, starting to listen for requests")
|
|
// Start the server
|
|
return s.httpServer.ListenAndServeTLS("", "")
|
|
}
|
|
|
|
// Stop gracefully shuts down the server
|
|
func (s *Server) Stop(ctx context.Context) error {
|
|
log.Printf("Shutting down server on %s", s.httpServer.Addr)
|
|
err := s.httpServer.Shutdown(ctx)
|
|
if err != nil {
|
|
log.Printf("Error during server shutdown: %v", err)
|
|
return err
|
|
}
|
|
log.Printf("Server shutdown complete")
|
|
return nil
|
|
}
|
|
|
|
// RegisterJoinHandler registers the handler for agent join requests
|
|
func (s *Server) RegisterJoinHandler(handler http.HandlerFunc) {
|
|
s.router.HandleFunc("POST", "/internal/v1alpha1/join", handler)
|
|
}
|
|
|
|
// RegisterNodeStatusHandler registers the handler for node status updates
|
|
func (s *Server) RegisterNodeStatusHandler(handler http.HandlerFunc) {
|
|
s.router.HandleFunc("POST", "/v1alpha1/nodes/{nodeName}/status", handler)
|
|
}
|