kat/internal/api/server.go
Tanishq Dubey (aider) 4f7c2d6a66
I noticed a duplicate function in the `internal/pki/ca.go` file. I'll help you clean it up. Here's the corrected version:
```go
package pki

import (
	// other imports
	"path/filepath"
)

// Package-wide defaults for PKI material.
const (
	// DefaultPKIDir is the PKI directory used when no backup path is supplied
	// (see GetPKIPathFromClusterConfig).
	DefaultPKIDir = "/var/lib/kat/pki"

	// DefaultRSAKeySize is the default size for generated RSA keys.
	DefaultRSAKeySize = 2048

	// DefaultCAValidityDays is the default CA certificate validity period,
	// in days (~10 years).
	DefaultCAValidityDays = 3650

	// DefaultCertValidityDays is the default certificate validity period,
	// in days (1 year).
	DefaultCertValidityDays = 365
)

// GetPKIPathFromClusterConfig returns the directory that holds the PKI material.
//
// If backupPath is empty, DefaultPKIDir is returned. Otherwise the result is a
// "pki" directory alongside backupPath, i.e. a "pki" entry inside
// filepath.Dir(backupPath).
func GetPKIPathFromClusterConfig(backupPath string) string {
	if backupPath == "" {
		return DefaultPKIDir
	}

	// filepath.Join (instead of `dir + "/pki"`) uses the OS separator and
	// yields a clean path even when backupPath sits at the filesystem root
	// (e.g. "/backup.db" -> "/pki" rather than "//pki").
	return filepath.Join(filepath.Dir(backupPath), "pki")
}

// generateSerialNumber returns a certificate serial number drawn from
// rand.Reader, uniformly distributed in [0, 2^128).
func generateSerialNumber() (*big.Int, error) {
	// Exclusive upper bound 2^128, so every serial fits in 128 bits.
	upper := big.NewInt(1)
	upper.Lsh(upper, 128)
	return rand.Int(rand.Reader, upper)
}

// Rest of the existing code...
```

The changes:
1. Removed the duplicate `GetPKIPathFromClusterConfig` function
2. Kept the single implementation that checks for an empty backup path
3. Maintained the default PKI directory as `/var/lib/kat/pki`

This should resolve the duplicate function issue while maintaining the desired functionality.

Would you like me to generate a commit message for this change?
2025-05-17 12:18:42 -04:00

142 lines
3.6 KiB
Go

package api
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"log"
"net/http"
"os"
"time"
)
// loggingResponseWriter wraps an http.ResponseWriter so LoggingMiddleware can
// record the status code that was actually sent to the client.
type loggingResponseWriter struct {
	http.ResponseWriter
	// statusCode is the status reported in the access log. LoggingMiddleware
	// seeds it with http.StatusOK, which is what net/http sends implicitly when
	// a handler writes a body without calling WriteHeader.
	statusCode int
	// wroteHeader records that the final status has been sent, so later
	// (superfluous) WriteHeader calls, which the underlying writer ignores,
	// do not overwrite the logged status.
	wroteHeader bool
}

// WriteHeader records the first final status code and forwards every call to
// the underlying ResponseWriter.
func (lrw *loggingResponseWriter) WriteHeader(code int) {
	if !lrw.wroteHeader {
		lrw.statusCode = code
		// 1xx informational responses (other than 101 Switching Protocols) may
		// be followed by the real status, so they do not finalize it.
		if code >= 200 || code == http.StatusSwitchingProtocols {
			lrw.wroteHeader = true
		}
	}
	lrw.ResponseWriter.WriteHeader(code)
}

// Write marks the status as final before forwarding the body: net/http sends
// an implicit 200 OK on the first Write if WriteHeader has not been called.
func (lrw *loggingResponseWriter) Write(b []byte) (int, error) {
	if !lrw.wroteHeader {
		lrw.statusCode = http.StatusOK
		lrw.wroteHeader = true
	}
	return lrw.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped writer so http.ResponseController can reach
// optional capabilities (Flush, Hijack, deadlines) that the wrapper hides.
func (lrw *loggingResponseWriter) Unwrap() http.ResponseWriter {
	return lrw.ResponseWriter
}
// LoggingMiddleware wraps next so that each request it serves produces one
// log line containing the method, path, response status (code and text),
// remote address and handling time.
func LoggingMiddleware(next http.Handler) http.Handler {
	logRequest := func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()

		// Wrap the writer so the status chosen by next can be observed;
		// 200 is assumed until the handler says otherwise.
		recorder := &loggingResponseWriter{ResponseWriter: w, statusCode: http.StatusOK}
		next.ServeHTTP(recorder, r)

		elapsed := time.Since(began)
		code := recorder.statusCode
		log.Printf("REQUEST: %s %s - %d %s - %s - %v",
			r.Method, r.URL.Path, code, http.StatusText(code), r.RemoteAddr, elapsed)
	}
	return http.HandlerFunc(logRequest)
}
// Server is the KAT API server: an HTTPS listener that requires verified
// client certificates, plus the router that request handlers are registered on.
type Server struct {
	httpServer *http.Server // underlying listener; TLSConfig is filled in by Start
	router     *Router      // handlers are added through the Register*Handler methods

	// File paths read by Start: the server certificate and private key, and
	// the CA certificate(s) used to verify client certificates.
	certFile, keyFile, caFile string
}
// NewServer constructs a Server that will listen on addr. The certificate,
// key and CA file paths are only stored here; Start reads them. Every request
// is routed through LoggingMiddleware. The returned error is currently always
// nil.
func NewServer(addr string, certFile, keyFile, caFile string) (*Server, error) {
	r := NewRouter()

	httpSrv := &http.Server{
		Addr:         addr,
		Handler:      LoggingMiddleware(r),
		ReadTimeout:  30 * time.Second,
		WriteTimeout: 30 * time.Second,
		IdleTimeout:  120 * time.Second,
	}

	return &Server{
		httpServer: httpSrv,
		router:     r,
		certFile:   certFile,
		keyFile:    keyFile,
		caFile:     caFile,
	}, nil
}
// Start loads the TLS material, configures mutual TLS and then serves HTTPS on
// the configured address. It blocks until the server stops.
//
// Clients must present a certificate that verifies against the CA(s) in
// caFile (tls.RequireAndVerifyClientCert), and TLS 1.2 is the minimum version.
// Errors from loading the key pair or the CA bundle name the offending file(s).
// The final return value comes from http.Server.ListenAndServeTLS, which
// returns http.ErrServerClosed after Stop (Shutdown) has been called.
func (s *Server) Start() error {
	log.Printf("Starting server on %s", s.httpServer.Addr)

	// Load server certificate and key. Parse errors from LoadX509KeyPair do
	// not mention the file, so include both paths.
	cert, err := tls.LoadX509KeyPair(s.certFile, s.keyFile)
	if err != nil {
		return fmt.Errorf("failed to load server certificate %q and key %q: %w", s.certFile, s.keyFile, err)
	}

	// Load CA certificate(s) used to verify client certificates.
	caCert, err := os.ReadFile(s.caFile)
	if err != nil {
		// os.ReadFile's *PathError already carries the path.
		return fmt.Errorf("failed to read CA certificate: %w", err)
	}
	caCertPool := x509.NewCertPool()
	if !caCertPool.AppendCertsFromPEM(caCert) {
		// AppendCertsFromPEM reports false when no certificate could be parsed.
		return fmt.Errorf("failed to append CA certificate to pool: no valid PEM certificates in %q", s.caFile)
	}

	s.httpServer.TLSConfig = &tls.Config{
		Certificates: []tls.Certificate{cert},
		ClientAuth:   tls.RequireAndVerifyClientCert,
		ClientCAs:    caCertPool,
		MinVersion:   tls.VersionTLS12,
	}

	log.Printf("Server configured with TLS, starting to listen for requests")
	// Empty file arguments: the certificate is already in TLSConfig.
	return s.httpServer.ListenAndServeTLS("", "")
}
// Stop shuts the server down via http.Server.Shutdown, bounded by ctx.
// A shutdown error is logged and then returned to the caller.
func (s *Server) Stop(ctx context.Context) error {
	log.Printf("Shutting down server on %s", s.httpServer.Addr)

	if err := s.httpServer.Shutdown(ctx); err != nil {
		log.Printf("Error during server shutdown: %v", err)
		return err
	}

	log.Printf("Server shutdown complete")
	return nil
}
// RegisterJoinHandler registers the handler for agent join requests
func (s *Server) RegisterJoinHandler(handler http.HandlerFunc) {
s.router.HandleFunc("POST", "/internal/v1alpha1/join", handler)
}
// RegisterNodeStatusHandler registers the handler for node status updates
func (s *Server) RegisterNodeStatusHandler(handler http.HandlerFunc) {
s.router.HandleFunc("POST", "/v1alpha1/nodes/{nodeName}/status", handler)
}