package store

import (
	"context"
	"fmt"
	"log"
	"net/url"
	"sync"
	"time"

	clientv3 "go.etcd.io/etcd/client/v3"
	"go.etcd.io/etcd/client/v3/concurrency"
	"go.etcd.io/etcd/server/v3/embed"
	"go.etcd.io/etcd/server/v3/etcdserver/api/v3client"
)

const (
	defaultDialTimeout    = 5 * time.Second
	defaultRequestTimeout = 5 * time.Second
	leaderElectionPrefix  = "/kat/leader_election/"
)

// EtcdEmbedConfig holds configuration for an embedded etcd server.
type EtcdEmbedConfig struct {
	Name           string
	DataDir        string
	ClientURLs     []string // URLs for client communication
	PeerURLs       []string // URLs for peer communication
	InitialCluster string   // e.g., "node1=http://localhost:2380"
	// Add other etcd config fields as needed: LogLevel, etc.
}

// EtcdStore implements the StateStore interface using etcd.
type EtcdStore struct {
	client     *clientv3.Client
	etcdServer *embed.Etcd // Holds the embedded server instance, if any

	// For leadership
	session      *concurrency.Session
	election     *concurrency.Election
	leaderID     string
	leaseTTL     int64
	campaignCtx  context.Context
	campaignDone func()     // Cancels campaignCtx
	resignMutex  sync.Mutex // Protects session and election during resign
}

// StartEmbeddedEtcd starts an embedded etcd server based on the provided config.
func StartEmbeddedEtcd(cfg EtcdEmbedConfig) (*embed.Etcd, error) {
	if len(cfg.PeerURLs) == 0 || len(cfg.ClientURLs) == 0 {
		return nil, fmt.Errorf("at least one peer URL and one client URL are required")
	}

	embedCfg := embed.NewConfig()
	embedCfg.Name = cfg.Name
	embedCfg.Dir = cfg.DataDir
	embedCfg.InitialClusterToken = "kat-etcd-cluster" // Make this configurable if needed
	embedCfg.ForceNewCluster = false                  // Set to true only for initial bootstrap of a new cluster if needed

	lpurl, err := parseURLs(cfg.PeerURLs)
	if err != nil {
		return nil, fmt.Errorf("invalid peer URLs: %w", err)
	}
	embedCfg.ListenPeerUrls = lpurl
	// Advertise the same peer URLs we listen on.
	embedCfg.AdvertisePeerUrls = lpurl

	// Use the configured initial cluster if provided; otherwise derive a
	// single-member cluster from this node's name and first peer URL.
	if cfg.InitialCluster != "" {
		embedCfg.InitialCluster = cfg.InitialCluster
	} else {
		embedCfg.InitialCluster = fmt.Sprintf("%s=%s", cfg.Name, cfg.PeerURLs[0])
	}

	lcurl, err := parseURLs(cfg.ClientURLs)
	if err != nil {
		return nil, fmt.Errorf("invalid client URLs: %w", err)
	}
	embedCfg.ListenClientUrls = lcurl

	// TODO: Configure logging, metrics, etc. for embedded etcd
	// embedCfg.Logger = "zap"
	// embedCfg.LogLevel = "info"

	e, err := embed.StartEtcd(embedCfg)
	if err != nil {
		return nil, fmt.Errorf("failed to start embedded etcd: %w", err)
	}

	select {
	case <-e.Server.ReadyNotify():
		log.Printf("Embedded etcd server is ready (name: %s)", cfg.Name)
	case <-time.After(60 * time.Second): // Adjust timeout as needed
		e.Server.Stop() // trigger a shutdown
		return nil, fmt.Errorf("embedded etcd server took too long to start")
	}
	return e, nil
}

func parseURLs(urlsStr []string) ([]url.URL, error) {
	urls := make([]url.URL, len(urlsStr))
	for i, s := range urlsStr {
		u, err := url.Parse(s)
		if err != nil {
			return nil, fmt.Errorf("parsing URL '%s': %w", s, err)
		}
		urls[i] = *u
	}
	return urls, nil
}
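
// The snippet below is an illustrative sketch (not part of this package's
// API) of how a caller might start a single-node embedded etcd with the
// function above. The node name, data directory, and URLs are assumed
// example values, not KAT defaults:
//
//	etcdServer, err := StartEmbeddedEtcd(EtcdEmbedConfig{
//		Name:       "node1",
//		DataDir:    "/tmp/kat-etcd",
//		ClientURLs: []string{"http://localhost:2379"},
//		PeerURLs:   []string{"http://localhost:2380"},
//	})
//	if err != nil {
//		log.Fatalf("embedded etcd failed to start: %v", err)
//	}
//	defer etcdServer.Close()
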
// NewEtcdStore creates a new EtcdStore.
// If etcdServer is not nil, it assumes it's managing an embedded server.
// endpoints are the etcd client endpoints.
func NewEtcdStore(endpoints []string, etcdServer *embed.Etcd) (*EtcdStore, error) {
	var cli *clientv3.Client
	var err error

	if etcdServer != nil {
		// If an embedded server is provided, use its in-process client directly.
		cli = v3client.New(etcdServer.Server)
	} else {
		cli, err = clientv3.New(clientv3.Config{
			Endpoints:   endpoints,
			DialTimeout: defaultDialTimeout,
			// TODO: Add TLS config if connecting to a secure external etcd
		})
		if err != nil {
			return nil, fmt.Errorf("failed to create etcd client: %w", err)
		}
	}

	return &EtcdStore{
		client:     cli,
		etcdServer: etcdServer,
	}, nil
}

func (s *EtcdStore) Put(ctx context.Context, key string, value []byte) error {
	reqCtx, cancel := context.WithTimeout(ctx, defaultRequestTimeout)
	defer cancel()
	_, err := s.client.Put(reqCtx, key, string(value))
	return err
}

func (s *EtcdStore) Get(ctx context.Context, key string) (*KV, error) {
	reqCtx, cancel := context.WithTimeout(ctx, defaultRequestTimeout)
	defer cancel()
	resp, err := s.client.Get(reqCtx, key)
	if err != nil {
		return nil, err
	}
	if len(resp.Kvs) == 0 {
		return nil, fmt.Errorf("key not found: %s", key) // Or a specific error type
	}
	kv := resp.Kvs[0]
	return &KV{
		Key:     string(kv.Key),
		Value:   kv.Value,
		Version: kv.ModRevision,
	}, nil
}

func (s *EtcdStore) Delete(ctx context.Context, key string) error {
	reqCtx, cancel := context.WithTimeout(ctx, defaultRequestTimeout)
	defer cancel()
	_, err := s.client.Delete(reqCtx, key)
	return err
}

func (s *EtcdStore) List(ctx context.Context, prefix string) ([]KV, error) {
	reqCtx, cancel := context.WithTimeout(ctx, defaultRequestTimeout)
	defer cancel()
	resp, err := s.client.Get(reqCtx, prefix, clientv3.WithPrefix())
	if err != nil {
		return nil, err
	}
	kvs := make([]KV, len(resp.Kvs))
	for i, etcdKv := range resp.Kvs {
		kvs[i] = KV{
			Key:     string(etcdKv.Key),
			Value:   etcdKv.Value,
			Version: etcdKv.ModRevision,
		}
	}
	return kvs, nil
}

func (s *EtcdStore) Watch(ctx context.Context, keyOrPrefix string, startRevision int64) (<-chan WatchEvent, error) {
	watchChan := make(chan WatchEvent)

	// WithPrevKV is required so ev.PrevKv is populated below; without it,
	// event.PrevKV would always be nil.
	opts := []clientv3.OpOption{clientv3.WithPrefix(), clientv3.WithPrevKV()}
	if startRevision > 0 {
		opts = append(opts, clientv3.WithRev(startRevision))
	}

	etcdWatchChan := s.client.Watch(ctx, keyOrPrefix, opts...)

	go func() {
		defer close(watchChan)
		for resp := range etcdWatchChan {
			if err := resp.Err(); err != nil {
				log.Printf("EtcdStore watch error: %v", err)
				// Depending on the error, we might need to signal the channel consumer.
				return
			}
			for _, ev := range resp.Events {
				event := WatchEvent{
					KV: KV{
						Key:     string(ev.Kv.Key),
						Value:   ev.Kv.Value,
						Version: ev.Kv.ModRevision,
					},
				}
				if ev.PrevKv != nil {
					event.PrevKV = &KV{
						Key:     string(ev.PrevKv.Key),
						Value:   ev.PrevKv.Value,
						Version: ev.PrevKv.ModRevision,
					}
				}
				switch ev.Type {
				case clientv3.EventTypePut:
					event.Type = EventTypePut
				case clientv3.EventTypeDelete:
					event.Type = EventTypeDelete
				default:
					log.Printf("EtcdStore unknown event type: %v", ev.Type)
					continue
				}
				select {
				case watchChan <- event:
				case <-ctx.Done():
					log.Printf("EtcdStore watch context cancelled for %s", keyOrPrefix)
					return
				}
			}
		}
	}()

	return watchChan, nil
}
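
// Sketch of consuming the watch channel (illustrative; the prefix and the
// handling logic are assumptions, not part of this package's contract):
//
//	events, err := store.Watch(ctx, "/kat/nodes/", 0)
//	if err != nil {
//		return err
//	}
//	for ev := range events {
//		switch ev.Type {
//		case EventTypePut:
//			log.Printf("put %s (rev %d)", ev.KV.Key, ev.KV.Version)
//		case EventTypeDelete:
//			log.Printf("delete %s", ev.KV.Key)
//		}
//	}
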
func (s *EtcdStore) Close() error {
	s.resignMutex.Lock()
	if s.session != nil {
		// Attempt to close the session gracefully; this also resigns from
		// the election if a campaign was active.
		s.session.Close() // This is synchronous
		s.session = nil
		s.election = nil
		if s.campaignDone != nil {
			s.campaignDone() // Ensure the leadership context is cancelled
			s.campaignDone = nil
		}
	}
	s.resignMutex.Unlock()

	var clientErr error
	if s.client != nil {
		clientErr = s.client.Close()
	}

	// Only close the embedded server if we own it and it's not already closed.
	if s.etcdServer != nil {
		// Wrap in a recover to handle a potential "close of closed channel" panic.
		func() {
			defer func() {
				if r := recover(); r != nil {
					// Log the panic but continue; the server was likely already closed.
					log.Printf("Recovered from panic while closing etcd server: %v", r)
				}
			}()
			s.etcdServer.Close() // This stops the embedded server
			s.etcdServer = nil
		}()
	}

	if clientErr != nil {
		return fmt.Errorf("error closing etcd client: %w", clientErr)
	}
	return nil
}
func (s *EtcdStore) Campaign(ctx context.Context, leaderID string, leaseTTLSeconds int64) (leadershipCtx context.Context, err error) {
	s.resignMutex.Lock()
	defer s.resignMutex.Unlock()

	if s.session != nil {
		return nil, fmt.Errorf("campaign already in progress or session active")
	}

	s.leaderID = leaderID
	s.leaseTTL = leaseTTLSeconds

	// Create a new session.
	session, err := concurrency.NewSession(s.client, concurrency.WithTTL(int(leaseTTLSeconds)))
	if err != nil {
		return nil, fmt.Errorf("failed to create etcd session: %w", err)
	}
	s.session = session

	election := concurrency.NewElection(session, leaderElectionPrefix)
	s.election = election

	// Create a cancellable context for this campaign attempt. It is returned
	// to the caller and cancelled when leadership is lost or Resign is called.
	campaignSpecificCtx, cancelCampaignSpecificCtx := context.WithCancel(ctx)
	s.campaignCtx = campaignSpecificCtx
	s.campaignDone = cancelCampaignSpecificCtx

	go func() {
		defer func() {
			// This block ensures that if the campaign goroutine exits for any
			// reason (session.Done(), campaign error, context cancellation),
			// the leadership context is cancelled.
			s.resignMutex.Lock()
			if s.campaignDone != nil { // Check if not already resigned
				s.campaignDone()
				s.campaignDone = nil // Prevent double cancel
			}
			// Clean up the session if it's still this one.
			if s.session == session {
				s.session.Close() // Attempt to close the session
				s.session = nil
				s.election = nil
			}
			s.resignMutex.Unlock()
		}()

		// Campaign for leadership; this blocks until elected or cancelled.
		// Use the captured local context rather than s.campaignCtx to avoid
		// an unsynchronized read racing with Resign.
		if err := election.Campaign(campaignSpecificCtx, leaderID); err != nil {
			log.Printf("Error during leadership campaign for %s: %v", leaderID, err)
			// An error here usually means the context was cancelled or the session closed.
			return
		}

		// If Campaign returns without error, we are elected. Keep the
		// leadership context alive until the session ends or the campaign
		// context is cancelled.
		log.Printf("Successfully campaigned, %s is now leader", leaderID)

		// Monitor the session; if it closes, leadership is lost.
		select {
		case <-session.Done():
			log.Printf("Etcd session closed for leader %s, leadership lost", leaderID)
		case <-campaignSpecificCtx.Done():
			log.Printf("Leadership campaign context cancelled for %s", leaderID)
		}
	}()

	return campaignSpecificCtx, nil
}

func (s *EtcdStore) Resign(ctx context.Context) error {
	s.resignMutex.Lock()
	defer s.resignMutex.Unlock()

	if s.election == nil || s.session == nil {
		log.Println("Resign called but not currently leading or no active session.")
		return nil // Not an error to resign if not leading
	}

	log.Printf("Resigning leadership for %s", s.leaderID)

	// Cancel the leadership context.
	if s.campaignDone != nil {
		s.campaignDone()
		s.campaignDone = nil
	}

	// Resign from the election. This is best-effort; the context passed to
	// Resign should be short-lived.
	resignCtx, cancel := context.WithTimeout(context.Background(), defaultRequestTimeout)
	defer cancel()
	if err := s.election.Resign(resignCtx); err != nil {
		log.Printf("Error resigning from election: %v. Session will eventually expire.", err)
		// Don't return an error here; session closure will handle it.
	}

	// Close the session to ensure the lease is revoked quickly.
	if s.session != nil {
		err := s.session.Close() // This is synchronous
		s.session = nil
		s.election = nil
		if err != nil {
			return fmt.Errorf("error closing session during resign: %w", err)
		}
	}

	log.Printf("Successfully resigned leadership for %s", s.leaderID)
	return nil
}
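
// Sketch of the leadership lifecycle from a caller's perspective. The leader
// ID "node1" and the 15-second lease TTL are illustrative assumptions:
//
//	leadershipCtx, err := store.Campaign(ctx, "node1", 15)
//	if err != nil {
//		return err
//	}
//	// Campaign returns immediately; leadershipCtx is cancelled when the
//	// campaign fails, leadership is lost, or Resign is called.
//	<-leadershipCtx.Done()
//	_ = store.Resign(context.Background())
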
func (s *EtcdStore) GetLeader(ctx context.Context) (string, error) {
	// This method needs a temporary session if one doesn't exist, or it can
	// read the leader key directly from the election prefix.
	// concurrency.NewElection(...).Leader(ctx) is the canonical way and
	// requires a session. If we are campaigning, we already have one; if we
	// are just an observer, we need a short-lived session.
	s.resignMutex.Lock()
	currentSession := s.session
	s.resignMutex.Unlock()

	var tempSession *concurrency.Session
	var err error

	if currentSession == nil {
		// Create a temporary session to observe the leader. Use a shorter TTL
		// for the observer session if desired, or the same as the campaign TTL.
		ttl := s.leaseTTL
		if ttl == 0 {
			ttl = 10 // Default observer TTL
		}
		tempSession, err = concurrency.NewSession(s.client, concurrency.WithTTL(int(ttl)))
		if err != nil {
			return "", fmt.Errorf("failed to create temporary session for GetLeader: %w", err)
		}
		defer tempSession.Close()
		currentSession = tempSession
	}

	election := concurrency.NewElection(currentSession, leaderElectionPrefix)

	reqCtx, cancel := context.WithTimeout(ctx, defaultRequestTimeout)
	defer cancel()

	// First, try to get the leader using the election API.
	resp, err := election.Leader(reqCtx)
	if err != nil && err != concurrency.ErrElectionNoLeader {
		return "", fmt.Errorf("failed to get leader: %w", err)
	}
	if resp != nil && len(resp.Kvs) > 0 {
		return string(resp.Kvs[0].Value), nil
	}

	// Fall back to reading the election prefix directly, since the election
	// API might not always work as expected.
	getResp, err := s.client.Get(reqCtx, leaderElectionPrefix, clientv3.WithPrefix())
	if err != nil {
		return "", fmt.Errorf("failed to get leader from key-value store: %w", err)
	}

	// In the etcd election pattern, the current leader is the candidate key
	// with the lowest CreateRevision (the oldest key under the prefix), not
	// the most recently modified one.
	var lowestRev int64
	var leaderValue string
	for _, kv := range getResp.Kvs {
		if lowestRev == 0 || kv.CreateRevision < lowestRev {
			lowestRev = kv.CreateRevision
			leaderValue = string(kv.Value)
		}
	}
	return leaderValue, nil
}

func (s *EtcdStore) DoTransaction(ctx context.Context, checks []Compare, onSuccess []Op, onFailure []Op) (bool, error) {
	etcdCmps := make([]clientv3.Cmp, len(checks))
	for i, c := range checks {
		if c.ExpectedVersion == 0 {
			// Key should not exist.
			etcdCmps[i] = clientv3.Compare(clientv3.ModRevision(c.Key), "=", 0)
		} else {
			// Key should exist with the specific version.
			etcdCmps[i] = clientv3.Compare(clientv3.ModRevision(c.Key), "=", c.ExpectedVersion)
		}
	}

	etcdThenOps := make([]clientv3.Op, len(onSuccess))
	for i, o := range onSuccess {
		switch o.Type {
		case OpPut:
			etcdThenOps[i] = clientv3.OpPut(o.Key, string(o.Value))
		case OpDelete:
			etcdThenOps[i] = clientv3.OpDelete(o.Key)
		default:
			return false, fmt.Errorf("unsupported operation type in transaction 'onSuccess': %v", o.Type)
		}
	}

	etcdElseOps := make([]clientv3.Op, len(onFailure))
	for i, o := range onFailure {
		switch o.Type {
		case OpPut:
			etcdElseOps[i] = clientv3.OpPut(o.Key, string(o.Value))
		case OpDelete:
			etcdElseOps[i] = clientv3.OpDelete(o.Key)
		default:
			return false, fmt.Errorf("unsupported operation type in transaction 'onFailure': %v", o.Type)
		}
	}

	reqCtx, cancel := context.WithTimeout(ctx, defaultRequestTimeout)
	defer cancel()

	txn := s.client.Txn(reqCtx)
	if len(etcdCmps) > 0 {
		txn = txn.If(etcdCmps...)
	}
	txn = txn.Then(etcdThenOps...)
	if len(etcdElseOps) > 0 {
		txn = txn.Else(etcdElseOps...)
	}

	resp, err := txn.Commit()
	if err != nil {
		return false, fmt.Errorf("etcd transaction commit failed: %w", err)
	}
	return resp.Succeeded, nil
}
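
// Sketch of a create-if-absent compare-and-swap via DoTransaction. The key
// and payload are illustrative assumptions; ExpectedVersion == 0 encodes
// "key must not exist" per the comparison logic above:
//
//	ok, err := store.DoTransaction(ctx,
//		[]Compare{{Key: "/kat/config", ExpectedVersion: 0}},
//		[]Op{{Type: OpPut, Key: "/kat/config", Value: []byte(`{"v":1}`)}},
//		nil, // no onFailure ops
//	)
//	if err != nil {
//		return err
//	}
//	if !ok {
//		log.Println("key already exists; create skipped")
//	}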