package store

import (
	"context"
	"fmt"
	"log"
	"net/url"
	"sync"
	"time"

	"go.etcd.io/etcd/client/v3/concurrency"
	"go.etcd.io/etcd/server/v3/embed"
	"go.etcd.io/etcd/server/v3/etcdserver/api/v3client"

	clientv3 "go.etcd.io/etcd/client/v3"
)

// Timeouts and key layout used by the etcd-backed store in this file.
const (
	// defaultDialTimeout is passed as the client DialTimeout when connecting
	// to external etcd endpoints.
	defaultDialTimeout    = 5 * time.Second
	// defaultRequestTimeout bounds individual KV, transaction, resign and
	// leader-lookup requests issued by EtcdStore.
	defaultRequestTimeout = 5 * time.Second
	// leaderElectionPrefix is the key prefix handed to the etcd election API
	// and used for the direct leader lookup fallback.
	leaderElectionPrefix  = "/kat/leader_election/"
)

// EtcdEmbedConfig holds configuration for an embedded etcd server.
// It is consumed by StartEmbeddedEtcd.
type EtcdEmbedConfig struct {
	// Name is the etcd member name; it is also used as the member key when
	// StartEmbeddedEtcd builds the initial cluster string.
	Name           string
	// DataDir is the directory the embedded server stores its data in.
	DataDir        string
	ClientURLs     []string // URLs for client communication
	PeerURLs       []string // URLs for peer communication
	// NOTE(review): StartEmbeddedEtcd currently derives the initial cluster
	// from Name and PeerURLs[0] and does not read this field — confirm
	// whether that is intended.
	InitialCluster string   // e.g., "node1=http://localhost:2380"
	// Add other etcd config fields as needed: LogLevel, etc.
}

// EtcdStore implements the StateStore interface using etcd.
//
// It wraps an etcd v3 client (either dialed against external endpoints or
// bound directly to an embedded server) and keeps the state needed for a
// single leader-election campaign at a time.
type EtcdStore struct {
	// client is used for every KV, watch, transaction and election request.
	client     *clientv3.Client
	etcdServer *embed.Etcd // Holds the embedded server instance, if any

	// For leadership
	// session and election are non-nil only while a campaign is active;
	// they are reset to nil by Resign, Close, or when the campaign goroutine exits.
	session      *concurrency.Session
	election     *concurrency.Election
	// leaderID is the value proposed in the most recent Campaign call.
	leaderID     string
	// leaseTTL is the session TTL (seconds) from the most recent Campaign call;
	// GetLeader also uses it for its temporary observer session.
	leaseTTL     int64
	// campaignCtx is the context returned to the Campaign caller.
	campaignCtx  context.Context
	campaignDone func()     // Cancels campaignCtx
	resignMutex  sync.Mutex // Protects session and election during resign
}

// StartEmbeddedEtcd starts an embedded etcd server based on the provided config
// and blocks until the server reports ready, reports a startup error, or a
// 60-second startup deadline passes.
//
// cfg.PeerURLs must contain at least one URL. The listen and advertise URLs are
// set to the same values for both peer and client traffic. The initial cluster
// is always derived as "<Name>=<PeerURLs[0]>"; cfg.InitialCluster is not
// consulted.
//
// On success the running server is returned and the caller owns it (it must
// eventually be closed). On any failure after the server was started, the
// server and its listeners are closed before the error is returned.
func StartEmbeddedEtcd(cfg EtcdEmbedConfig) (*embed.Etcd, error) {
	// Guard against indexing an empty slice when building the initial cluster.
	if len(cfg.PeerURLs) == 0 {
		return nil, fmt.Errorf("invalid peer URLs: at least one peer URL is required")
	}

	embedCfg := embed.NewConfig()
	embedCfg.Name = cfg.Name
	embedCfg.Dir = cfg.DataDir
	embedCfg.InitialClusterToken = "kat-etcd-cluster" // Make this configurable if needed
	embedCfg.ForceNewCluster = false                  // Set to true only for initial bootstrap of a new cluster if needed

	lpurl, err := parseURLs(cfg.PeerURLs)
	if err != nil {
		return nil, fmt.Errorf("invalid peer URLs: %w", err)
	}
	embedCfg.ListenPeerUrls = lpurl

	// Set the advertise peer URLs to match the listen peer URLs
	embedCfg.AdvertisePeerUrls = lpurl

	// Update the initial cluster to use the same URLs
	embedCfg.InitialCluster = fmt.Sprintf("%s=%s", cfg.Name, cfg.PeerURLs[0])

	lcurl, err := parseURLs(cfg.ClientURLs)
	if err != nil {
		return nil, fmt.Errorf("invalid client URLs: %w", err)
	}
	embedCfg.ListenClientUrls = lcurl
	// Advertise the client URLs we actually listen on, mirroring the peer
	// handling above. When no client URLs are supplied the embed defaults
	// are left untouched.
	if len(lcurl) > 0 {
		embedCfg.AdvertiseClientUrls = lcurl
	}

	// TODO: Configure logging, metrics, etc. for embedded etcd
	// embedCfg.Logger = "zap"
	// embedCfg.LogLevel = "info"

	e, err := embed.StartEtcd(embedCfg)
	if err != nil {
		return nil, fmt.Errorf("failed to start embedded etcd: %w", err)
	}

	startupTimer := time.NewTimer(60 * time.Second) // Adjust timeout as needed
	defer startupTimer.Stop()

	select {
	case <-e.Server.ReadyNotify():
		log.Printf("Embedded etcd server is ready (name: %s)", cfg.Name)
		return e, nil
	case startErr := <-e.Err():
		// The server failed while starting; release listeners instead of
		// waiting for the full startup deadline.
		e.Close()
		return nil, fmt.Errorf("embedded etcd server failed during startup: %w", startErr)
	case <-startupTimer.C:
		// Close (rather than only stopping the server) so that the peer and
		// client listeners are released as well.
		e.Close()
		return nil, fmt.Errorf("embedded etcd server took too long to start")
	}
}

// parseURLs converts each string in urlsStr into a url.URL, preserving order.
// It returns a non-nil (possibly empty) slice on success. The first string
// that fails to parse aborts the conversion and yields a nil slice together
// with an error that wraps the underlying parse error.
func parseURLs(urlsStr []string) ([]url.URL, error) {
	parsed := make([]url.URL, 0, len(urlsStr))
	for _, raw := range urlsStr {
		u, err := url.Parse(raw)
		if err != nil {
			return nil, fmt.Errorf("parsing URL '%s': %w", raw, err)
		}
		parsed = append(parsed, *u)
	}
	return parsed, nil
}

// NewEtcdStore creates a new EtcdStore.
//
// When etcdServer is non-nil the store talks to that embedded server through
// an in-process client and endpoints is not used. Otherwise a client is dialed
// against endpoints using defaultDialTimeout, and any dial setup error is
// returned wrapped.
func NewEtcdStore(endpoints []string, etcdServer *embed.Etcd) (*EtcdStore, error) {
	store := &EtcdStore{etcdServer: etcdServer}

	// Embedded server: bind the client directly to it.
	if etcdServer != nil {
		store.client = v3client.New(etcdServer.Server)
		return store, nil
	}

	remote, err := clientv3.New(clientv3.Config{
		Endpoints:   endpoints,
		DialTimeout: defaultDialTimeout,
		// TODO: Add TLS config if connecting to secure external etcd
	})
	if err != nil {
		return nil, fmt.Errorf("failed to create etcd client: %w", err)
	}
	store.client = remote
	return store, nil
}

// Put stores value under key. The request is bounded by defaultRequestTimeout
// in addition to any deadline already carried by ctx.
func (s *EtcdStore) Put(ctx context.Context, key string, value []byte) error {
	timeoutCtx, release := context.WithTimeout(ctx, defaultRequestTimeout)
	defer release()

	if _, err := s.client.Put(timeoutCtx, key, string(value)); err != nil {
		return err
	}
	return nil
}

// Get fetches the entry stored under key. The returned KV carries the key's
// mod revision as its Version. An error is returned when the request fails or
// when the key has no entry. The request is bounded by defaultRequestTimeout.
func (s *EtcdStore) Get(ctx context.Context, key string) (*KV, error) {
	timeoutCtx, release := context.WithTimeout(ctx, defaultRequestTimeout)
	defer release()

	result, err := s.client.Get(timeoutCtx, key)
	switch {
	case err != nil:
		return nil, err
	case len(result.Kvs) == 0:
		return nil, fmt.Errorf("key not found: %s", key) // Or a specific error type
	}

	first := result.Kvs[0]
	found := KV{
		Key:     string(first.Key),
		Value:   first.Value,
		Version: first.ModRevision,
	}
	return &found, nil
}

// Delete removes key from the store. The request is bounded by
// defaultRequestTimeout in addition to any deadline already carried by ctx.
func (s *EtcdStore) Delete(ctx context.Context, key string) error {
	timeoutCtx, release := context.WithTimeout(ctx, defaultRequestTimeout)
	defer release()

	if _, err := s.client.Delete(timeoutCtx, key); err != nil {
		return err
	}
	return nil
}

// List returns every entry whose key starts with prefix, in the order etcd
// returns them. Each KV carries the key's mod revision as its Version. The
// result is a non-nil, possibly empty slice. The request is bounded by
// defaultRequestTimeout.
func (s *EtcdStore) List(ctx context.Context, prefix string) ([]KV, error) {
	timeoutCtx, release := context.WithTimeout(ctx, defaultRequestTimeout)
	defer release()

	result, err := s.client.Get(timeoutCtx, prefix, clientv3.WithPrefix())
	if err != nil {
		return nil, err
	}

	entries := make([]KV, 0, len(result.Kvs))
	for _, item := range result.Kvs {
		entries = append(entries, KV{
			Key:     string(item.Key),
			Value:   item.Value,
			Version: item.ModRevision,
		})
	}
	return entries, nil
}

// Watch streams changes for every key that starts with keyOrPrefix (the
// argument is always treated as a prefix). When startRevision is greater than
// zero the watch starts at that revision; otherwise it starts from the current
// one.
//
// Events are delivered on the returned channel, which is closed when ctx is
// cancelled, when the underlying etcd watch ends, or when the watch reports an
// error (the error is logged). Each event carries the previous key-value pair
// in PrevKV when etcd supplies one.
//
// The returned error is always nil.
func (s *EtcdStore) Watch(ctx context.Context, keyOrPrefix string, startRevision int64) (<-chan WatchEvent, error) {
	watchChan := make(chan WatchEvent)

	// WithPrevKV is required for etcd to populate ev.PrevKv; without it the
	// PrevKV mapping below could never trigger.
	opts := []clientv3.OpOption{clientv3.WithPrefix(), clientv3.WithPrevKV()}
	if startRevision > 0 {
		opts = append(opts, clientv3.WithRev(startRevision))
	}

	// Use a derived context so the etcd watch is cancelled (and its
	// resources released) whenever the forwarding goroutine exits, even if
	// the caller's ctx is still alive.
	watchCtx, cancelWatch := context.WithCancel(ctx)
	etcdWatchChan := s.client.Watch(watchCtx, keyOrPrefix, opts...)

	go func() {
		defer close(watchChan)
		defer cancelWatch()

		for resp := range etcdWatchChan {
			if err := resp.Err(); err != nil {
				log.Printf("EtcdStore watch error: %v", err)
				// Depending on error, might need to signal channel consumer
				return
			}
			for _, ev := range resp.Events {
				event := WatchEvent{
					KV: KV{
						Key:     string(ev.Kv.Key),
						Value:   ev.Kv.Value,
						Version: ev.Kv.ModRevision,
					},
				}
				if ev.PrevKv != nil {
					event.PrevKV = &KV{
						Key:     string(ev.PrevKv.Key),
						Value:   ev.PrevKv.Value,
						Version: ev.PrevKv.ModRevision,
					}
				}

				switch ev.Type {
				case clientv3.EventTypePut:
					event.Type = EventTypePut
				case clientv3.EventTypeDelete:
					event.Type = EventTypeDelete
				default:
					log.Printf("EtcdStore unknown event type: %v", ev.Type)
					continue
				}

				// Never block forever on a consumer that has gone away.
				select {
				case watchChan <- event:
				case <-watchCtx.Done():
					log.Printf("EtcdStore watch context cancelled for %s", keyOrPrefix)
					return
				}
			}
		}
	}()

	return watchChan, nil
}

// Close releases everything the store holds, in this order: the active
// election session (which also cancels the leadership context), the etcd
// client, and finally the embedded server if one was supplied. Only an error
// from closing the client is reported; a panic raised while closing the
// embedded server is recovered and logged.
func (s *EtcdStore) Close() error {
	// Tear down leadership state under the mutex shared with Campaign/Resign.
	func() {
		s.resignMutex.Lock()
		defer s.resignMutex.Unlock()

		if s.session == nil {
			return
		}
		// Closing the session is synchronous and should also give up any
		// election the session was participating in.
		s.session.Close()
		s.session, s.election = nil, nil
		if stop := s.campaignDone; stop != nil {
			stop() // Ensure leadership context is cancelled
			s.campaignDone = nil
		}
	}()

	var clientErr error
	if s.client != nil {
		clientErr = s.client.Close()
	}

	// Only close the embedded server if we own it and it's not already closed
	if srv := s.etcdServer; srv != nil {
		func() {
			// Guards against a potential "close of closed channel" panic
			// when the server was already shut down elsewhere.
			defer func() {
				if recovered := recover(); recovered != nil {
					log.Printf("Recovered from panic while closing etcd server: %v", recovered)
				}
			}()
			srv.Close() // This stops the embedded server
			s.etcdServer = nil
		}()
	}

	if clientErr == nil {
		return nil
	}
	return fmt.Errorf("error closing etcd client: %w", clientErr)
}

// Campaign starts campaigning for leadership with leaderID as the proposed
// value, using an etcd session whose lease TTL is leaseTTLSeconds.
//
// It returns immediately with a context derived from ctx; the election itself
// runs in a background goroutine. The returned context is cancelled when the
// campaign fails, when the session ends after leadership was won, when ctx is
// cancelled, or when Resign or Close is called. Note that the context is
// returned before the election has been won.
//
// Only one campaign may be active per store: an error is returned if a session
// already exists, or if the etcd session cannot be created.
func (s *EtcdStore) Campaign(ctx context.Context, leaderID string, leaseTTLSeconds int64) (leadershipCtx context.Context, err error) {
	s.resignMutex.Lock()
	defer s.resignMutex.Unlock()

	if s.session != nil {
		return nil, fmt.Errorf("campaign already in progress or session active")
	}

	// Create a new session
	session, err := concurrency.NewSession(s.client, concurrency.WithTTL(int(leaseTTLSeconds)))
	if err != nil {
		return nil, fmt.Errorf("failed to create etcd session: %w", err)
	}
	election := concurrency.NewElection(session, leaderElectionPrefix)

	// Create a cancellable context for this campaign attempt
	// This context will be returned and is cancelled when leadership is lost or Resign is called.
	campaignCtx, cancelCampaign := context.WithCancel(ctx)

	// Record state only once the session exists, so a failed attempt leaves
	// the store untouched.
	s.leaderID = leaderID
	s.leaseTTL = leaseTTLSeconds
	s.session = session
	s.election = election
	s.campaignCtx = campaignCtx
	s.campaignDone = cancelCampaign

	// The goroutine works exclusively with the locals captured above
	// (session, election, campaignCtx, cancelCampaign). Reading them back
	// from the struct would race with, and could tear down, a newer campaign
	// started after a Resign.
	go func() {
		defer func() {
			// Whatever the exit reason (campaign error, session end,
			// context cancellation), this campaign's context must end.
			// Cancelling more than once is harmless.
			cancelCampaign()

			s.resignMutex.Lock()
			defer s.resignMutex.Unlock()

			// Only clean up shared state if it still belongs to this
			// campaign; Resign/Close may already have cleared it and a
			// new campaign may have replaced it.
			if s.session != session {
				return
			}
			s.campaignDone = nil
			// Best-effort: the session may already be gone, in which
			// case the close error carries no useful information.
			_ = session.Close()
			s.session = nil
			s.election = nil
		}()

		// Campaign for leadership in a blocking way
		// Cancelling campaignCtx (or its parent ctx) aborts this.
		if err := election.Campaign(campaignCtx, leaderID); err != nil {
			log.Printf("Error during leadership campaign for %s: %v", leaderID, err)
			// Error here usually means context cancelled or session closed.
			return
		}

		// If Campaign returns without error, it means we are elected.
		// Keep leadership context alive until session is done or campaignCtx is cancelled.
		log.Printf("Successfully campaigned, %s is now leader", leaderID)

		// Monitor the session; if it closes, leadership is lost.
		select {
		case <-session.Done():
			log.Printf("Etcd session closed for leader %s, leadership lost", leaderID)
		case <-campaignCtx.Done():
			log.Printf("Leadership campaign context cancelled for %s", leaderID)
		}
	}()

	return campaignCtx, nil
}

// Resign gives up leadership (or an in-flight campaign) and closes the
// election session. It is a no-op returning nil when no campaign is active.
//
// The leadership context handed out by Campaign is cancelled first. The
// election resign itself is best-effort: a failure is only logged. An error is
// returned only if closing the session fails.
//
// ctx is currently not used; the resign request runs under its own
// defaultRequestTimeout derived from context.Background().
func (s *EtcdStore) Resign(ctx context.Context) error {
	s.resignMutex.Lock()
	defer s.resignMutex.Unlock()

	if s.session == nil || s.election == nil {
		log.Println("Resign called but not currently leading or no active session.")
		return nil // Not an error to resign if not leading
	}

	log.Printf("Resigning leadership for %s", s.leaderID)

	// Cancel the leadership context
	if stop := s.campaignDone; stop != nil {
		stop()
		s.campaignDone = nil
	}

	// Step down from the election. Failure is tolerated because the session
	// is closed right below, which releases the lease anyway.
	bounded, release := context.WithTimeout(context.Background(), defaultRequestTimeout)
	defer release()
	resignErr := s.election.Resign(bounded)
	if resignErr != nil {
		log.Printf("Error resigning from election: %v. Session will eventually expire.", resignErr)
	}

	// Close the session to ensure lease is revoked quickly.
	if active := s.session; active != nil {
		closeErr := active.Close() // This is synchronous
		s.session, s.election = nil, nil
		if closeErr != nil {
			return fmt.Errorf("error closing session during resign: %w", closeErr)
		}
	}

	log.Printf("Successfully resigned leadership for %s", s.leaderID)
	return nil
}

// GetLeader returns the value proposed by the current election leader, or an
// empty string with a nil error when no leader is found.
//
// The lookup goes through the etcd election API, reusing the campaign session
// when one is active and otherwise creating a short-lived observer session
// that is closed before returning. If the election API reports no leader, the
// keys under leaderElectionPrefix are queried directly as a fallback. Requests
// are bounded by defaultRequestTimeout.
func (s *EtcdStore) GetLeader(ctx context.Context) (string, error) {
	// Snapshot the shared state under the mutex that guards its writers
	// (Campaign, Resign, Close).
	s.resignMutex.Lock()
	currentSession := s.session
	ttl := s.leaseTTL
	s.resignMutex.Unlock()

	if currentSession == nil {
		// Create a temporary session to observe leader
		// Use a shorter TTL for observer session if desired, or same as campaign TTL
		if ttl == 0 {
			ttl = 10 // Default observer TTL
		}
		tempSession, err := concurrency.NewSession(s.client, concurrency.WithTTL(int(ttl)))
		if err != nil {
			return "", fmt.Errorf("failed to create temporary session for GetLeader: %w", err)
		}
		defer tempSession.Close()
		currentSession = tempSession
	}

	election := concurrency.NewElection(currentSession, leaderElectionPrefix)
	reqCtx, cancel := context.WithTimeout(ctx, defaultRequestTimeout)
	defer cancel()

	// First try to get the leader using the election API
	resp, err := election.Leader(reqCtx)
	if err != nil && err != concurrency.ErrElectionNoLeader {
		return "", fmt.Errorf("failed to get leader: %w", err)
	}

	if resp != nil && len(resp.Kvs) > 0 {
		return string(resp.Kvs[0].Value), nil
	}

	// Fallback: look at the election keys directly. The leader of an etcd
	// election is the candidate whose key was created first (lowest create
	// revision); later keys belong to candidates still waiting. Asking etcd
	// for the first-created key under the prefix therefore yields the
	// leader, whereas the most recently modified key would not.
	getResp, err := s.client.Get(reqCtx, leaderElectionPrefix, clientv3.WithFirstCreate()...)
	if err != nil {
		return "", fmt.Errorf("failed to get leader from key-value store: %w", err)
	}
	if len(getResp.Kvs) == 0 {
		return "", nil
	}
	return string(getResp.Kvs[0].Value), nil
}

// DoTransaction runs an atomic compare-and-apply against etcd.
//
// Every entry in checks compares the key's mod revision with ExpectedVersion;
// an ExpectedVersion of 0 means the key must not exist. If all checks hold,
// the onSuccess operations are applied, otherwise the onFailure operations
// are. Only put and delete operations are supported; any other operation type
// yields an error before anything is sent to etcd.
//
// It returns whether the checks held, or an error if the commit failed. The
// commit is bounded by defaultRequestTimeout.
func (s *EtcdStore) DoTransaction(ctx context.Context, checks []Compare, onSuccess []Op, onFailure []Op) (bool, error) {
	guards := make([]clientv3.Cmp, 0, len(checks))
	for _, check := range checks {
		if check.ExpectedVersion != 0 {
			// Key should exist with specific version
			guards = append(guards, clientv3.Compare(clientv3.ModRevision(check.Key), "=", check.ExpectedVersion))
			continue
		}
		// Key should not exist
		guards = append(guards, clientv3.Compare(clientv3.ModRevision(check.Key), "=", 0))
	}

	thenOps := make([]clientv3.Op, 0, len(onSuccess))
	for _, op := range onSuccess {
		switch op.Type {
		case OpPut:
			thenOps = append(thenOps, clientv3.OpPut(op.Key, string(op.Value)))
		case OpDelete:
			thenOps = append(thenOps, clientv3.OpDelete(op.Key))
		default:
			return false, fmt.Errorf("unsupported operation type in transaction 'onSuccess': %v", op.Type)
		}
	}

	elseOps := make([]clientv3.Op, 0, len(onFailure))
	for _, op := range onFailure {
		switch op.Type {
		case OpPut:
			elseOps = append(elseOps, clientv3.OpPut(op.Key, string(op.Value)))
		case OpDelete:
			elseOps = append(elseOps, clientv3.OpDelete(op.Key))
		default:
			return false, fmt.Errorf("unsupported operation type in transaction 'onFailure': %v", op.Type)
		}
	}

	timeoutCtx, release := context.WithTimeout(ctx, defaultRequestTimeout)
	defer release()

	// Assemble the transaction; the If and Else clauses are attached only
	// when they have content.
	txn := s.client.Txn(timeoutCtx)
	if len(guards) != 0 {
		txn = txn.If(guards...)
	}
	txn = txn.Then(thenOps...)
	if len(elseOps) != 0 {
		txn = txn.Else(elseOps...)
	}

	outcome, err := txn.Commit()
	if err != nil {
		return false, fmt.Errorf("etcd transaction commit failed: %w", err)
	}
	return outcome.Succeeded, nil
}