Files
dyn/internal/testutil/mock_technitium.go

239 lines
5.7 KiB
Go

package testutil
import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"sort"
	"sync"
)
// MockTechnitiumServer simulates the Technitium DNS API for testing.
// It wraps a running httptest.Server and keeps records added through the
// API in memory.
//
// Records is guarded by mu: the HTTP handlers and the GetRecords,
// GetRecordCount and ClearRecords helpers all take the lock. Tests should use
// those helpers rather than reading or writing Records directly while the
// server may be serving requests.
type MockTechnitiumServer struct {
// Server is the httptest server started by NewMockTechnitiumServer.
Server *httptest.Server
// Records holds stored records keyed by "<domain>:<type>".
Records map[string]MockDNSRecord
// mu guards Records.
mu sync.RWMutex
// Username and Password are the HTTP Basic credentials accepted by authenticate.
Username string
Password string
// Token is accepted when sent as the literal header "Authorization: Basic <Token>"
// (the token is compared raw, not base64-decoded).
Token string
}
// MockDNSRecord represents a DNS record stored in the mock server.
// It is also the element type of the "records" array returned by the
// records/get endpoint, so the JSON tags define that wire format.
type MockDNSRecord struct {
// Domain is the record name as sent in the "domain" form field.
Domain string `json:"domain"`
// Type is the record type as sent in the "type" form field (stored verbatim).
Type string `json:"type"`
// IPAddress is the "ipAddress" form value; it may be empty because the add
// handler does not require it.
IPAddress string `json:"ipAddress"`
// TTL is set to 300 by the add handler.
TTL int `json:"ttl"`
}
// NewMockTechnitiumServer creates a mock Technitium server, starts it on a
// local httptest listener, and returns it. The default credentials are
// Username "admin", Password "test-password" and Token "test-api-token";
// authenticate reads these fields on every request, so tests may change them
// before issuing requests.
//
// Routes:
//   - /api/dns/records/add    -> handleAddRecord
//   - /api/dns/records/delete -> handleDeleteRecord
//   - /api/dns/records/get    -> handleGetRecords
//   - /                       -> plain-text health check (exact path only)
//
// Call Close when finished to shut the server down.
func NewMockTechnitiumServer() *MockTechnitiumServer {
	mock := &MockTechnitiumServer{
		Records:  make(map[string]MockDNSRecord),
		Username: "admin",
		Password: "test-password",
		Token:    "test-api-token",
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/api/dns/records/add", mock.handleAddRecord)
	mux.HandleFunc("/api/dns/records/delete", mock.handleDeleteRecord)
	mux.HandleFunc("/api/dns/records/get", mock.handleGetRecords)

	// Health check endpoint. ServeMux treats the "/" pattern as a catch-all,
	// so without the path check every unknown or misspelled API path would
	// answer 200 with a non-JSON body and hide client bugs. Only the root
	// path is the health check; everything else is a 404.
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/" {
			http.NotFound(w, r)
			return
		}
		w.WriteHeader(http.StatusOK)
		// The status line is already sent; a write error here cannot be reported.
		_, _ = w.Write([]byte("Technitium DNS Server Mock"))
	})

	mock.Server = httptest.NewServer(mux)
	return mock
}
// Close stops the underlying httptest server, shutting down the mock.
func (m *MockTechnitiumServer) Close() {
	srv := m.Server
	srv.Close()
}
// URL reports the base URL (scheme, host and port) of the running mock
// server, as assigned by httptest.
func (m *MockTechnitiumServer) URL() string {
	srv := m.Server
	return srv.URL
}
// GetRecords returns a snapshot of all stored DNS records, keyed by
// "<domain>:<type>", for test assertions. The returned map is a fresh copy
// taken under the read lock, so callers may inspect or modify it without
// racing the HTTP handlers.
func (m *MockTechnitiumServer) GetRecords() map[string]MockDNSRecord {
	m.mu.RLock()
	snapshot := make(map[string]MockDNSRecord, len(m.Records))
	for key, rec := range m.Records {
		snapshot[key] = rec
	}
	m.mu.RUnlock()
	return snapshot
}
// GetRecordCount reports how many records are currently stored, read under
// the read lock.
func (m *MockTechnitiumServer) GetRecordCount() int {
	m.mu.RLock()
	count := len(m.Records)
	m.mu.RUnlock()
	return count
}
// ClearRecords removes all stored records by swapping in a new, empty map
// under the write lock. The previous map is not modified, so any map obtained
// earlier from GetRecords is unaffected.
func (m *MockTechnitiumServer) ClearRecords() {
	empty := make(map[string]MockDNSRecord)
	m.mu.Lock()
	m.Records = empty
	m.mu.Unlock()
}
// authenticate reports whether r carries credentials the mock accepts.
// Two forms are recognized:
//   - an Authorization header equal to "Basic " followed by the raw
//     (not base64-encoded) m.Token;
//   - standard HTTP Basic auth whose username and password equal
//     m.Username and m.Password.
//
// NOTE(review): if Token is set to "", a bare "Authorization: Basic " header
// is accepted by the first check.
func (m *MockTechnitiumServer) authenticate(r *http.Request) bool {
	if r.Header.Get("Authorization") == "Basic "+m.Token {
		return true
	}
	user, pass, hasBasic := r.BasicAuth()
	return hasBasic && user == m.Username && pass == m.Password
}
// handleAddRecord serves /api/dns/records/add.
//
// It accepts only POST and requires credentials accepted by authenticate.
// "domain", "type" and "ipAddress" are read with FormValue, which also
// consults the URL query. The record is stored under the key
// "<domain>:<type>" with a fixed TTL of 300, replacing any record already
// stored under that key. On success it responds 200 with
// {"status":"ok","response":{"domain":...,"type":...}}.
//
// Error responses:
//   - 405 (plain text) for methods other than POST
//   - 401 with a JSON {"status":"error","error":{...}} envelope for bad credentials
//   - 400 (plain text) if the form cannot be parsed or domain/type is empty
func (m *MockTechnitiumServer) handleAddRecord(w http.ResponseWriter, r *http.Request) {
	switch {
	case r.Method != http.MethodPost:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	case !m.authenticate(r):
		unauthorized := map[string]interface{}{
			"status": "error",
			"error": map[string]string{
				"code":    "Unauthorized",
				"message": "Invalid credentials",
			},
		}
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusUnauthorized)
		json.NewEncoder(w).Encode(unauthorized)
		return
	}

	if parseErr := r.ParseForm(); parseErr != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}

	domain, recordType := r.FormValue("domain"), r.FormValue("type")
	if domain == "" || recordType == "" {
		http.Error(w, "Missing required fields", http.StatusBadRequest)
		return
	}

	rec := MockDNSRecord{
		Domain:    domain,
		Type:      recordType,
		IPAddress: r.FormValue("ipAddress"),
		TTL:       300,
	}
	key := fmt.Sprintf("%s:%s", domain, recordType)

	m.mu.Lock()
	m.Records[key] = rec
	m.mu.Unlock()

	body := map[string]interface{}{
		"status": "ok",
		"response": map[string]interface{}{
			"domain": domain,
			"type":   recordType,
		},
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	json.NewEncoder(w).Encode(body)
}
// handleDeleteRecord serves /api/dns/records/delete.
//
// It accepts only POST and requires credentials accepted by authenticate.
// It removes the record stored under "<domain>:<type>", using the "domain"
// and "type" values read with FormValue (which also consults the URL query).
// It responds 200 with {"status":"ok"} even when no such record existed.
//
// Error responses:
//   - 405 (plain text) for methods other than POST
//   - 401 with a JSON {"status":"error","error":{...}} envelope for bad credentials
//   - 400 (plain text) if the form cannot be parsed or domain/type is empty,
//     matching the validation in handleAddRecord
func (m *MockTechnitiumServer) handleDeleteRecord(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	if !m.authenticate(r) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusUnauthorized)
		json.NewEncoder(w).Encode(map[string]interface{}{
			"status": "error",
			"error": map[string]string{
				"code":    "Unauthorized",
				"message": "Invalid credentials",
			},
		})
		return
	}

	if err := r.ParseForm(); err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}
	domain := r.FormValue("domain")
	recordType := r.FormValue("type")

	// handleAddRecord never stores a record with an empty domain or type, so
	// such a request can only be a client bug. Report it instead of returning
	// a success that silently deleted nothing.
	if domain == "" || recordType == "" {
		http.Error(w, "Missing required fields", http.StatusBadRequest)
		return
	}

	key := fmt.Sprintf("%s:%s", domain, recordType)
	m.mu.Lock()
	delete(m.Records, key)
	m.mu.Unlock()

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	json.NewEncoder(w).Encode(map[string]interface{}{
		"status": "ok",
	})
}
// handleGetRecords serves /api/dns/records/get. Any HTTP method is accepted.
//
// It requires credentials accepted by authenticate. It returns every stored
// record, or only those whose Domain equals the "domain" URL query parameter
// when that parameter is non-empty, as
// {"status":"ok","response":{"domain":...,"records":[...]}}.
//
// "records" is always a JSON array ([] when nothing matches, never null). It
// is sorted by Domain, then Type, so responses are deterministic despite the
// random iteration order of the backing map.
//
// Bad credentials get a 401 with a JSON {"status":"error","error":{...}} envelope.
func (m *MockTechnitiumServer) handleGetRecords(w http.ResponseWriter, r *http.Request) {
	if !m.authenticate(r) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusUnauthorized)
		json.NewEncoder(w).Encode(map[string]interface{}{
			"status": "error",
			"error": map[string]string{
				"code":    "Unauthorized",
				"message": "Invalid credentials",
			},
		})
		return
	}

	domain := r.URL.Query().Get("domain")

	m.mu.RLock()
	// Non-nil so the JSON encoder emits [] rather than null for "no records".
	records := make([]MockDNSRecord, 0, len(m.Records))
	for _, record := range m.Records {
		if domain == "" || record.Domain == domain {
			records = append(records, record)
		}
	}
	m.mu.RUnlock()

	// Keys are "<domain>:<type>", so (Domain, Type) is unique and gives a total order.
	sort.Slice(records, func(i, j int) bool {
		if records[i].Domain != records[j].Domain {
			return records[i].Domain < records[j].Domain
		}
		return records[i].Type < records[j].Type
	})

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	json.NewEncoder(w).Encode(map[string]interface{}{
		"status": "ok",
		"response": map[string]interface{}{
			"domain":  domain,
			"records": records,
		},
	})
}