142 lines
4.1 KiB
Go
142 lines
4.1 KiB
Go
package api
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
|
|
"kat-system/internal/pki"
|
|
"kat-system/internal/store"
|
|
)
|
|
|
|
// JoinRequest is the JSON body an agent sends when it asks to join.
// It is decoded by the handler returned from NewJoinHandler.
type JoinRequest struct {
	// CSR is the agent's certificate signing request. Being a []byte, it is
	// carried as a base64 string in the JSON encoding.
	CSR []byte `json:"csr"`

	// AdvertiseAddr is the address the agent advertises for itself.
	AdvertiseAddr string `json:"advertiseAddr"`

	// NodeName is optional; when it is left empty the leader generates a
	// name for the node.
	NodeName string `json:"nodeName,omitempty"`

	// WireguardPubKey is the agent's WireGuard public key (placeholder for
	// now).
	WireguardPubKey string `json:"wireguardPubKey"`
}
|
|
|
|
// JoinResponse is the JSON reply returned to an agent after a successful join.
// The []byte fields are carried as base64 strings in the JSON encoding.
type JoinResponse struct {
	NodeName      string `json:"nodeName"`      // name the node was registered under
	NodeUID       string `json:"nodeUID"`       // unique ID generated for this join
	SignedCert    []byte `json:"signedCert"`    // certificate issued from the agent's CSR
	CACert        []byte `json:"caCert"`        // certificate of the signing CA
	JoinTimestamp int64  `json:"joinTimestamp"` // join time as a Unix timestamp (seconds)
}
|
|
|
|
// NewJoinHandler creates a handler for agent join requests
|
|
func NewJoinHandler(stateStore store.StateStore, caKeyPath, caCertPath string) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
// Read and parse the request body
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
http.Error(w, fmt.Sprintf("Failed to read request body: %v", err), http.StatusBadRequest)
|
|
return
|
|
}
|
|
defer r.Body.Close()
|
|
|
|
var joinReq JoinRequest
|
|
if err := json.Unmarshal(body, &joinReq); err != nil {
|
|
http.Error(w, fmt.Sprintf("Failed to parse request: %v", err), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Validate request
|
|
if len(joinReq.CSR) == 0 {
|
|
http.Error(w, "Missing CSR", http.StatusBadRequest)
|
|
return
|
|
}
|
|
if joinReq.AdvertiseAddr == "" {
|
|
http.Error(w, "Missing advertise address", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Generate node name if not provided
|
|
nodeName := joinReq.NodeName
|
|
if nodeName == "" {
|
|
nodeName = fmt.Sprintf("node-%s", uuid.New().String()[:8])
|
|
}
|
|
|
|
// Generate a unique node ID
|
|
nodeUID := uuid.New().String()
|
|
|
|
// Sign the CSR
|
|
// Create a temporary file for the CSR
|
|
tempDir := os.TempDir()
|
|
csrPath := filepath.Join(tempDir, fmt.Sprintf("%s.csr", nodeUID))
|
|
if err := os.WriteFile(csrPath, joinReq.CSR, 0600); err != nil {
|
|
http.Error(w, fmt.Sprintf("Failed to save CSR: %v", err), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer os.Remove(csrPath)
|
|
|
|
// Sign the CSR
|
|
certPath := filepath.Join(tempDir, fmt.Sprintf("%s.crt", nodeUID))
|
|
if err := pki.SignCertificateRequest(caKeyPath, caCertPath, csrPath, certPath, 365*24*time.Hour); err != nil {
|
|
http.Error(w, fmt.Sprintf("Failed to sign CSR: %v", err), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer os.Remove(certPath)
|
|
|
|
// Read the signed certificate
|
|
signedCert, err := os.ReadFile(certPath)
|
|
if err != nil {
|
|
http.Error(w, fmt.Sprintf("Failed to read signed certificate: %v", err), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Read the CA certificate
|
|
caCert, err := os.ReadFile(caCertPath)
|
|
if err != nil {
|
|
http.Error(w, fmt.Sprintf("Failed to read CA certificate: %v", err), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Store node registration in etcd
|
|
nodeRegKey := fmt.Sprintf("/kat/nodes/registration/%s", nodeName)
|
|
nodeReg := map[string]interface{}{
|
|
"uid": nodeUID,
|
|
"advertiseAddr": joinReq.AdvertiseAddr,
|
|
"wireguardPubKey": joinReq.WireguardPubKey,
|
|
"joinTimestamp": time.Now().Unix(),
|
|
}
|
|
nodeRegData, err := json.Marshal(nodeReg)
|
|
if err != nil {
|
|
http.Error(w, fmt.Sprintf("Failed to marshal node registration: %v", err), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if err := stateStore.Put(r.Context(), nodeRegKey, nodeRegData); err != nil {
|
|
http.Error(w, fmt.Sprintf("Failed to store node registration: %v", err), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Prepare and send response
|
|
joinResp := JoinResponse{
|
|
NodeName: nodeName,
|
|
NodeUID: nodeUID,
|
|
SignedCert: signedCert,
|
|
CACert: caCert,
|
|
JoinTimestamp: time.Now().Unix(),
|
|
}
|
|
|
|
respData, err := json.Marshal(joinResp)
|
|
if err != nil {
|
|
http.Error(w, fmt.Sprintf("Failed to marshal response: %v", err), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(respData)
|
|
}
|
|
}
|