package database import ( "context" "crypto/rand" "database/sql" "encoding/base64" "fmt" "time" "git.dws.rip/DWS/dyn/internal/models" _ "github.com/mattn/go-sqlite3" ) type DB struct { conn *sql.DB } func New(dbPath string) (*DB, error) { conn, err := sql.Open("sqlite3", dbPath) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } if err := conn.Ping(); err != nil { return nil, fmt.Errorf("failed to ping database: %w", err) } db := &DB{conn: conn} if err := db.migrate(); err != nil { return nil, fmt.Errorf("failed to migrate database: %w", err) } return db, nil } func (db *DB) migrate() error { schema := ` CREATE TABLE IF NOT EXISTS spaces ( token TEXT PRIMARY KEY, subdomain TEXT UNIQUE NOT NULL, last_ip TEXT, updated_at DATETIME, created_at DATETIME DEFAULT CURRENT_TIMESTAMP ); CREATE INDEX IF NOT EXISTS idx_subdomain ON spaces(subdomain); ` _, err := db.conn.Exec(schema) return err } func (db *DB) Close() error { return db.conn.Close() } func (db *DB) CreateSpace(ctx context.Context, subdomain string) (*models.Space, error) { token, err := generateToken() if err != nil { return nil, fmt.Errorf("failed to generate token: %w", err) } space := &models.Space{ Token: token, Subdomain: subdomain, CreatedAt: time.Now(), } query := ` INSERT INTO spaces (token, subdomain, last_ip, updated_at, created_at) VALUES (?, ?, ?, ?, ?) ` _, err = db.conn.ExecContext(ctx, query, space.Token, space.Subdomain, space.LastIP, space.UpdatedAt, space.CreatedAt, ) if err != nil { if isUniqueConstraintError(err) { return nil, fmt.Errorf("subdomain already taken") } return nil, fmt.Errorf("failed to create space: %w", err) } return space, nil } func (db *DB) GetSpaceByToken(ctx context.Context, token string) (*models.Space, error) { query := ` SELECT token, subdomain, last_ip, updated_at, created_at FROM spaces WHERE token = ? ` row := db.conn.QueryRowContext(ctx, query, token) space := &models.Space{} var updatedAt sql.NullTime err := row.Scan( &space.Token, &space.Subdomain, &space.LastIP, &updatedAt, &space.CreatedAt, ) if err == sql.ErrNoRows { return nil, nil } if err != nil { return nil, fmt.Errorf("failed to get space: %w", err) } if updatedAt.Valid { space.UpdatedAt = updatedAt.Time } return space, nil } func (db *DB) GetSpaceBySubdomain(ctx context.Context, subdomain string) (*models.Space, error) { query := ` SELECT token, subdomain, last_ip, updated_at, created_at FROM spaces WHERE subdomain = ? ` row := db.conn.QueryRowContext(ctx, query, subdomain) space := &models.Space{} var updatedAt sql.NullTime err := row.Scan( &space.Token, &space.Subdomain, &space.LastIP, &updatedAt, &space.CreatedAt, ) if err == sql.ErrNoRows { return nil, nil } if err != nil { return nil, fmt.Errorf("failed to get space: %w", err) } if updatedAt.Valid { space.UpdatedAt = updatedAt.Time } return space, nil } func (db *DB) UpdateSpaceIP(ctx context.Context, token string, ip string) error { query := ` UPDATE spaces SET last_ip = ?, updated_at = ? WHERE token = ? ` _, err := db.conn.ExecContext(ctx, query, ip, time.Now(), token) if err != nil { return fmt.Errorf("failed to update space IP: %w", err) } return nil } func (db *DB) SubdomainExists(ctx context.Context, subdomain string) (bool, error) { query := `SELECT 1 FROM spaces WHERE subdomain = ?` row := db.conn.QueryRowContext(ctx, query, subdomain) var exists int err := row.Scan(&exists) if err == sql.ErrNoRows { return false, nil } if err != nil { return false, fmt.Errorf("failed to check subdomain: %w", err) } return true, nil } func generateToken() (string, error) { bytes := make([]byte, 24) if _, err := rand.Read(bytes); err != nil { return "", err } return base64.URLEncoding.EncodeToString(bytes), nil } func isUniqueConstraintError(err error) bool { if err == nil { return false } errStr := err.Error() return contains(errStr, "UNIQUE constraint failed") || contains(errStr, "duplicate key value") } func contains(s, substr string) bool { return len(s) >= len(substr) && (s == substr || containsInternal(s, substr)) } func containsInternal(s, substr string) bool { for i := 0; i <= len(s)-len(substr); i++ { if s[i:i+len(substr)] == substr { return true } } return false }