diff --git a/go.mod b/go.mod index 0730f78..9cd178a 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,8 @@ module git.dws.rip/dubey/kat -go 1.24.2 +go 1.21 + +require ( + google.golang.org/protobuf v1.31.0 + gopkg.in/yaml.v3 v3.0.1 +) diff --git a/internal/config/parse_test.go b/internal/config/parse_test.go new file mode 100644 index 0000000..a061a7c --- /dev/null +++ b/internal/config/parse_test.go @@ -0,0 +1,332 @@ +package config + +import ( + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" + + "git.dws.rip/dubey/kat/api/v1alpha1" + "google.golang.org/protobuf/types/known/timestamppb" +) + +func createTestClusterKatFile(t *testing.T, content string) string { + t.Helper() + tmpFile, err := ioutil.TempFile(t.TempDir(), "cluster.*.kat") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + if _, err := tmpFile.WriteString(content); err != nil { + tmpFile.Close() + t.Fatalf("Failed to write to temp file: %v", err) + } + if err := tmpFile.Close(); err != nil { + t.Fatalf("Failed to close temp file: %v", err) + } + return tmpFile.Name() +} + +func TestParseClusterConfiguration_Valid(t *testing.T) { + yamlContent := ` +apiVersion: kat.dws.rip/v1alpha1 +kind: ClusterConfiguration +metadata: + name: test-cluster +spec: + clusterCIDR: "10.0.0.0/16" + serviceCIDR: "10.1.0.0/16" + nodeSubnetBits: 8 # /24 for nodes + apiPort: 8080 # Non-default +` + filePath := createTestClusterKatFile(t, yamlContent) + + config, err := ParseClusterConfiguration(filePath) + if err != nil { + t.Fatalf("ParseClusterConfiguration() error = %v, wantErr %v", err, false) + } + + if config.Metadata.Name != "test-cluster" { + t.Errorf("Expected metadata.name 'test-cluster', got '%s'", config.Metadata.Name) + } + if config.Spec.ClusterCidr != "10.0.0.0/16" { + t.Errorf("Expected spec.clusterCIDR '10.0.0.0/16', got '%s'", config.Spec.ClusterCidr) + } + if config.Spec.ApiPort != 8080 { + t.Errorf("Expected spec.apiPort 8080, got %d", config.Spec.ApiPort) + } + // Check a default value + if config.Spec.ClusterDomain != DefaultClusterDomain { + t.Errorf("Expected default spec.clusterDomain '%s', got '%s'", DefaultClusterDomain, config.Spec.ClusterDomain) + } + if config.Spec.NodeSubnetBits != 8 { + t.Errorf("Expected spec.nodeSubnetBits 8, got %d", config.Spec.NodeSubnetBits) + } +} + +func TestParseClusterConfiguration_FileNotFound(t *testing.T) { + _, err := ParseClusterConfiguration("nonexistent.kat") + if err == nil { + t.Fatalf("ParseClusterConfiguration() with non-existent file did not return an error") + } + if !strings.Contains(err.Error(), "file not found") { + t.Errorf("Expected 'file not found' error, got: %v", err) + } +} + +func TestParseClusterConfiguration_InvalidYAML(t *testing.T) { + filePath := createTestClusterKatFile(t, "this: is: not: valid: yaml") + _, err := ParseClusterConfiguration(filePath) + if err == nil { + t.Fatalf("ParseClusterConfiguration() with invalid YAML did not return an error") + } + if !strings.Contains(err.Error(), "unmarshal YAML") { + t.Errorf("Expected 'unmarshal YAML' error, got: %v", err) + } +} + +func TestParseClusterConfiguration_MissingRequiredFields(t *testing.T) { + tests := []struct { + name string + content string + wantErr string + }{ + { + name: "missing metadata name", + content: ` +apiVersion: kat.dws.rip/v1alpha1 +kind: ClusterConfiguration +spec: + clusterCIDR: "10.0.0.0/16" + serviceCIDR: "10.1.0.0/16" +`, + wantErr: "metadata.name is required", + }, + { + name: "missing clusterCIDR", + content: ` +apiVersion: kat.dws.rip/v1alpha1 +kind: ClusterConfiguration +metadata: + name: test-cluster +spec: + serviceCIDR: "10.1.0.0/16" +`, + wantErr: "spec.clusterCIDR is required", + }, + { + name: "invalid kind", + content: ` +apiVersion: kat.dws.rip/v1alpha1 +kind: WrongKind +metadata: + name: test-cluster +spec: + clusterCIDR: "10.0.0.0/16" + serviceCIDR: "10.1.0.0/16" +`, + wantErr: "invalid kind", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + filePath := createTestClusterKatFile(t, tt.content) + _, err := ParseClusterConfiguration(filePath) + if err == nil { + t.Fatalf("ParseClusterConfiguration() did not return an error for %s", tt.name) + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("Expected error containing '%s', got: %v", tt.wantErr, err) + } + }) + } +} + +func TestSetClusterConfigDefaults(t *testing.T) { + config := &v1alpha1.ClusterConfiguration{ + Spec: &v1alpha1.ClusterConfigurationSpec{}, + } + SetClusterConfigDefaults(config) + + if config.Spec.ClusterDomain != DefaultClusterDomain { + t.Errorf("DefaultClusterDomain: got %s, want %s", config.Spec.ClusterDomain, DefaultClusterDomain) + } + if config.Spec.ApiPort != DefaultApiPort { + t.Errorf("DefaultApiPort: got %d, want %d", config.Spec.ApiPort, DefaultApiPort) + } + if config.Spec.AgentPort != DefaultAgentPort { + t.Errorf("DefaultAgentPort: got %d, want %d", config.Spec.AgentPort, DefaultAgentPort) + } + if config.Spec.EtcdClientPort != DefaultEtcdClientPort { + t.Errorf("DefaultEtcdClientPort: got %d, want %d", config.Spec.EtcdClientPort, DefaultEtcdClientPort) + } + if config.Spec.EtcdPeerPort != DefaultEtcdPeerPort { + t.Errorf("DefaultEtcdPeerPort: got %d, want %d", config.Spec.EtcdPeerPort, DefaultEtcdPeerPort) + } + if config.Spec.VolumeBasePath != DefaultVolumeBasePath { + t.Errorf("DefaultVolumeBasePath: got %s, want %s", config.Spec.VolumeBasePath, DefaultVolumeBasePath) + } + if config.Spec.BackupPath != DefaultBackupPath { + t.Errorf("DefaultBackupPath: got %s, want %s", config.Spec.BackupPath, DefaultBackupPath) + } + if config.Spec.BackupIntervalMinutes != DefaultBackupIntervalMins { + t.Errorf("DefaultBackupIntervalMins: got %d, want %d", config.Spec.BackupIntervalMinutes, DefaultBackupIntervalMins) + } + if config.Spec.AgentTickSeconds != DefaultAgentTickSeconds { + t.Errorf("DefaultAgentTickSeconds: got %d, want %d", config.Spec.AgentTickSeconds, DefaultAgentTickSeconds) + } + if config.Spec.NodeLossTimeoutSeconds != DefaultNodeLossTimeoutSec { + t.Errorf("DefaultNodeLossTimeoutSec: got %d, want %d", config.Spec.NodeLossTimeoutSeconds, DefaultNodeLossTimeoutSec) + } + if config.Spec.NodeSubnetBits != DefaultNodeSubnetBits { + t.Errorf("DefaultNodeSubnetBits: got %d, want %d", config.Spec.NodeSubnetBits, DefaultNodeSubnetBits) + } + + // Test NodeLossTimeoutSeconds derivation + configWithTick := &v1alpha1.ClusterConfiguration{ + Spec: &v1alpha1.ClusterConfigurationSpec{AgentTickSeconds: 10}, + } + SetClusterConfigDefaults(configWithTick) + if configWithTick.Spec.NodeLossTimeoutSeconds != 40 { // 10 * 4 + t.Errorf("Derived NodeLossTimeoutSeconds: got %d, want %d", configWithTick.Spec.NodeLossTimeoutSeconds, 40) + } +} + +func TestValidateClusterConfiguration_InvalidValues(t *testing.T) { + baseValidSpec := func() *v1alpha1.ClusterConfigurationSpec { + return &v1alpha1.ClusterConfigurationSpec{ + ClusterCidr: "10.0.0.0/16", + ServiceCidr: "10.1.0.0/16", + NodeSubnetBits: 8, + ClusterDomain: "test.local", + AgentPort: 10250, + ApiPort: 10251, + EtcdPeerPort: 2380, + EtcdClientPort: 2379, + VolumeBasePath: "/var/lib/kat/volumes", + BackupPath: "/var/lib/kat/backups", + BackupIntervalMinutes: 30, + AgentTickSeconds: 15, + NodeLossTimeoutSeconds:60, + } + } + baseValidMetadata := func() *v1alpha1.ObjectMeta { + return &v1alpha1.ObjectMeta{Name: "test"} + } + + tests := []struct { + name string + mutator func(cfg *v1alpha1.ClusterConfiguration) + wantErr string + }{ + {"invalid clusterCIDR", func(cfg *v1alpha1.ClusterConfiguration) { cfg.Spec.ClusterCidr = "invalid" }, "invalid spec.clusterCIDR"}, + {"invalid serviceCIDR", func(cfg *v1alpha1.ClusterConfiguration) { cfg.Spec.ServiceCidr = "invalid" }, "invalid spec.serviceCIDR"}, + {"invalid agentPort low", func(cfg *v1alpha1.ClusterConfiguration) { cfg.Spec.AgentPort = 0 }, "invalid port for agentPort"}, + {"invalid agentPort high", func(cfg *v1alpha1.ClusterConfiguration) { cfg.Spec.AgentPort = 70000 }, "invalid port for agentPort"}, + {"port conflict", func(cfg *v1alpha1.ClusterConfiguration) { cfg.Spec.ApiPort = cfg.Spec.AgentPort }, "port conflict"}, + {"invalid nodeSubnetBits low", func(cfg *v1alpha1.ClusterConfiguration) { cfg.Spec.NodeSubnetBits = 0 }, "invalid spec.nodeSubnetBits"}, + {"invalid nodeSubnetBits high", func(cfg *v1alpha1.ClusterConfiguration) { cfg.Spec.NodeSubnetBits = 32 }, "invalid spec.nodeSubnetBits"}, + {"invalid nodeSubnetBits vs clusterCIDR", func(cfg *v1alpha1.ClusterConfiguration) { cfg.Spec.ClusterCidr = "10.0.0.0/28"; cfg.Spec.NodeSubnetBits = 8 }, "results in an invalid subnet size"}, + {"invalid agentTickSeconds", func(cfg *v1alpha1.ClusterConfiguration) { cfg.Spec.AgentTickSeconds = 0 }, "agentTickSeconds must be positive"}, + {"invalid nodeLossTimeoutSeconds", func(cfg *v1alpha1.ClusterConfiguration) { cfg.Spec.NodeLossTimeoutSeconds = 0 }, "nodeLossTimeoutSeconds must be positive"}, + {"nodeLoss < agentTick", func(cfg *v1alpha1.ClusterConfiguration) { cfg.Spec.NodeLossTimeoutSeconds = cfg.Spec.AgentTickSeconds - 1 }, "nodeLossTimeoutSeconds must be greater"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := &v1alpha1.ClusterConfiguration{Metadata: baseValidMetadata(), Spec: baseValidSpec()} + tt.mutator(config) + err := ValidateClusterConfiguration(config) + if err == nil { + t.Fatalf("ValidateClusterConfiguration() did not return an error for %s", tt.name) + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("Expected error containing '%s', got: %v", tt.wantErr, err) + } + }) + } +} + +func TestParseQuadletDirectory_ValidSimple(t *testing.T) { + files := map[string][]byte{ + "workload.kat": []byte(` +apiVersion: kat.dws.rip/v1alpha1 +kind: Workload +metadata: + name: test-workload +spec: + type: SERVICE + source: + image: "nginx:latest" +`), + "vlb.kat": []byte(` +apiVersion: kat.dws.rip/v1alpha1 +kind: VirtualLoadBalancer +metadata: + name: test-workload # Assumed to match workload name +spec: + ports: + - containerPort: 80 +`), + } + + parsed, err := ParseQuadletDirectory(files) + if err != nil { + t.Fatalf("ParseQuadletDirectory() error = %v", err) + } + if parsed.Workload == nil { + t.Fatal("Parsed Workload is nil") + } + if parsed.Workload.Metadata.Name != "test-workload" { + t.Errorf("Expected Workload name 'test-workload', got '%s'", parsed.Workload.Metadata.Name) + } + if parsed.VirtualLoadBalancer == nil { + t.Fatal("Parsed VirtualLoadBalancer is nil") + } + if parsed.VirtualLoadBalancer.Metadata.Name != "test-workload" { + t.Errorf("Expected VLB name 'test-workload', got '%s'", parsed.VirtualLoadBalancer.Metadata.Name) + } +} + +func TestParseQuadletDirectory_MissingWorkload(t *testing.T) { + files := map[string][]byte{ + "vlb.kat": []byte(`kind: VirtualLoadBalancer`), + } + _, err := ParseQuadletDirectory(files) + if err == nil { + t.Fatal("ParseQuadletDirectory() with missing workload.kat did not return an error") + } + if !strings.Contains(err.Error(), "required Workload definition (workload.kat) not found") { + t.Errorf("Expected 'required Workload' error, got: %v", err) + } +} + +func TestParseQuadletDirectory_MultipleWorkloads(t *testing.T) { + files := map[string][]byte{ + "workload1.kat": []byte(` +apiVersion: kat.dws.rip/v1alpha1 +kind: Workload +metadata: + name: wl1 +spec: + type: SERVICE + source: {image: "img1"}`), + "workload2.kat": []byte(` +apiVersion: kat.dws.rip/v1alpha1 +kind: Workload +metadata: + name: wl2 +spec: + type: SERVICE + source: {image: "img2"}`), + } + + _, err := ParseQuadletDirectory(files) + if err == nil { + t.Fatal("ParseQuadletDirectory() with multiple workload.kat did not return an error") + } + if !strings.Contains(err.Error(), "multiple Workload definitions found") { + t.Errorf("Expected 'multiple Workload' error, got: %v", err) + } +} diff --git a/internal/utils/tar_test.go b/internal/utils/tar_test.go new file mode 100644 index 0000000..b8cea2f --- /dev/null +++ b/internal/utils/tar_test.go @@ -0,0 +1,206 @@ +package utils + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "io" + "os" + "path/filepath" + "strings" + "testing" +) + +func createTestTarGz(t *testing.T, files map[string]string, modifyHeader func(hdr *tar.Header)) io.Reader { + t.Helper() + var buf bytes.Buffer + gzw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gzw) + + for name, content := range files { + hdr := &tar.Header{ + Name: name, + Mode: 0644, + Size: int64(len(content)), + } + if modifyHeader != nil { + modifyHeader(hdr) + } + if err := tw.WriteHeader(hdr); err != nil { + t.Fatalf("Failed to write tar header for %s: %v", name, err) + } + if _, err := tw.Write([]byte(content)); err != nil { + t.Fatalf("Failed to write tar content for %s: %v", name, err) + } + } + + if err := tw.Close(); err != nil { + t.Fatalf("Failed to close tar writer: %v", err) + } + if err := gzw.Close(); err != nil { + t.Fatalf("Failed to close gzip writer: %v", err) + } + return &buf +} + +func TestUntarQuadlets_Valid(t *testing.T) { + inputFiles := map[string]string{ + "workload.kat": "kind: Workload", + "vlb.kat": "kind: VirtualLoadBalancer", + } + reader := createTestTarGz(t, inputFiles, nil) + + outputFiles, err := UntarQuadlets(reader) + if err != nil { + t.Fatalf("UntarQuadlets() error = %v, wantErr %v", err, false) + } + + if len(outputFiles) != len(inputFiles) { + t.Errorf("Expected %d files, got %d", len(inputFiles), len(outputFiles)) + } + for name, content := range inputFiles { + outContent, ok := outputFiles[name] + if !ok { + t.Errorf("Expected file %s not found in output", name) + } + if string(outContent) != content { + t.Errorf("Content mismatch for %s: got '%s', want '%s'", name, string(outContent), content) + } + } +} + +func TestUntarQuadlets_EmptyArchive(t *testing.T) { + reader := createTestTarGz(t, map[string]string{}, nil) + _, err := UntarQuadlets(reader) + if err == nil { + t.Fatal("UntarQuadlets() with empty archive did not return an error") + } + if !strings.Contains(err.Error(), "no .kat files found") { + t.Errorf("Expected 'no .kat files found' error, got: %v", err) + } +} + +func TestUntarQuadlets_NonKatFile(t *testing.T) { + inputFiles := map[string]string{"config.txt": "some data"} + reader := createTestTarGz(t, inputFiles, nil) + _, err := UntarQuadlets(reader) + if err == nil { + t.Fatal("UntarQuadlets() with non-.kat file did not return an error") + } + if !strings.Contains(err.Error(), "only .kat files are allowed") { + t.Errorf("Expected 'only .kat files are allowed' error, got: %v", err) + } +} + +func TestUntarQuadlets_FileInSubdirectory(t *testing.T) { + inputFiles := map[string]string{"subdir/workload.kat": "kind: Workload"} + reader := createTestTarGz(t, inputFiles, nil) + _, err := UntarQuadlets(reader) + if err == nil { + t.Fatal("UntarQuadlets() with file in subdirectory did not return an error") + } + if !strings.Contains(err.Error(), "subdirectories are not allowed") { + t.Errorf("Expected 'subdirectories are not allowed' error, got: %v", err) + } +} + +func TestUntarQuadlets_PathTraversal(t *testing.T) { + inputFiles := map[string]string{"../workload.kat": "kind: Workload"} + reader := createTestTarGz(t, inputFiles, nil) + _, err := UntarQuadlets(reader) + if err == nil { + t.Fatal("UntarQuadlets() with path traversal did not return an error") + } + if !strings.Contains(err.Error(), "contains '..'") { + t.Errorf("Expected 'contains ..' error, got: %v", err) + } +} + +func TestUntarQuadlets_FileTooLarge(t *testing.T) { + largeContent := strings.Repeat("a", int(maxQuadletFileSize)+1) + inputFiles := map[string]string{"large.kat": largeContent} + reader := createTestTarGz(t, inputFiles, nil) + _, err := UntarQuadlets(reader) + if err == nil { + t.Fatal("UntarQuadlets() with large file did not return an error") + } + if !strings.Contains(err.Error(), "file large.kat in tar is too large") { + t.Errorf("Expected 'file ... too large' error, got: %v", err) + } +} + +func TestUntarQuadlets_TotalSizeTooLarge(t *testing.T) { + numFiles := (maxTotalQuadletSize / maxQuadletFileSize) + 2 + fileSize := maxQuadletFileSize / 2 + + inputFiles := make(map[string]string) + content := strings.Repeat("a", int(fileSize)) + for i := 0; i < int(numFiles); i++ { + inputFiles[filepath.Join(".", "file"+string(rune(i+'0'))+".kat")] = content + } + + reader := createTestTarGz(t, inputFiles, nil) + _, err := UntarQuadlets(reader) + if err == nil { + t.Fatal("UntarQuadlets() with total large size did not return an error") + } + if !strings.Contains(err.Error(), "total size of files in tar is too large") { + t.Errorf("Expected 'total size ... too large' error, got: %v", err) + } +} + +func TestUntarQuadlets_TooManyFiles(t *testing.T) { + inputFiles := make(map[string]string) + for i := 0; i <= maxQuadletFiles; i++ { + inputFiles[filepath.Join(".", "file"+string(rune(i+'a'))+".kat")] = "content" + } + reader := createTestTarGz(t, inputFiles, nil) + _, err := UntarQuadlets(reader) + if err == nil { + t.Fatal("UntarQuadlets() with too many files did not return an error") + } + if !strings.Contains(err.Error(), "too many files in quadlet bundle") { + t.Errorf("Expected 'too many files' error, got: %v", err) + } +} + +func TestUntarQuadlets_UnsupportedFileType(t *testing.T) { + reader := createTestTarGz(t, map[string]string{"link.kat": ""}, func(hdr *tar.Header) { + hdr.Typeflag = tar.TypeSymlink + hdr.Linkname = "target.kat" + hdr.Size = 0 + }) + _, err := UntarQuadlets(reader) + if err == nil { + t.Fatal("UntarQuadlets() with symlink did not return an error") + } + if !strings.Contains(err.Error(), "unsupported file type") { + t.Errorf("Expected 'unsupported file type' error, got: %v", err) + } +} + +func TestUntarQuadlets_CorruptedGzip(t *testing.T) { + corruptedInput := bytes.NewBufferString("this is not a valid gzip stream") + _, err := UntarQuadlets(corruptedInput) + if err == nil { + t.Fatal("UntarQuadlets() with corrupted gzip did not return an error") + } + if !strings.Contains(err.Error(), "failed to create gzip reader") && !strings.Contains(err.Error(), "gzip: invalid header") { + t.Errorf("Expected 'gzip format' or 'invalid header' error, got: %v", err) + } +} + +func TestUntarQuadlets_CorruptedTar(t *testing.T) { + var buf bytes.Buffer + gzw := gzip.NewWriter(&buf) + _, _ = gzw.Write([]byte("this is not a valid tar stream but inside gzip")) + _ = gzw.Close() + + _, err := UntarQuadlets(&buf) + if err == nil { + t.Fatal("UntarQuadlets() with corrupted tar did not return an error") + } + if !strings.Contains(err.Error(), "tar") { + t.Errorf("Expected error related to 'tar' format, got: %v", err) + } +}