From 511889e56de28744349c7c70f42431fcfc617eea Mon Sep 17 00:00:00 2001 From: LordMathis Date: Fri, 14 Nov 2025 18:38:31 +0100 Subject: [PATCH] Implement per instance command override on backend --- pkg/backends/backend.go | 42 ++++++++++++++++----- pkg/backends/llama_test.go | 77 ++++++++++++++++++++++++++++++++++++++ pkg/backends/mlx_test.go | 55 +++++++++++++++++++++++++++ pkg/instance/instance.go | 2 +- pkg/instance/options.go | 28 ++++++++++++++ 5 files changed, 193 insertions(+), 11 deletions(-) diff --git a/pkg/backends/backend.go b/pkg/backends/backend.go index 022e778..2fa85be 100644 --- a/pkg/backends/backend.go +++ b/pkg/backends/backend.go @@ -142,27 +142,49 @@ func (o *Options) getBackend() backend { } } -func (o *Options) isDockerEnabled(backend *config.BackendSettings) bool { - if backend.Docker != nil && backend.Docker.Enabled && o.BackendType != BackendTypeMlxLm { - return true +// isDockerEnabled checks if Docker is enabled with an optional override +func (o *Options) isDockerEnabled(backend *config.BackendSettings, dockerEnabledOverride *bool) bool { + // Check if backend supports Docker + if backend.Docker == nil { + return false } - return false + + // MLX doesn't support Docker + if o.BackendType == BackendTypeMlxLm { + return false + } + + // Check for instance-level override + if dockerEnabledOverride != nil { + return *dockerEnabledOverride + } + + // Fall back to config value + return backend.Docker.Enabled } func (o *Options) IsDockerEnabled(backendConfig *config.BackendConfig) bool { backendSettings := o.getBackendSettings(backendConfig) - return o.isDockerEnabled(backendSettings) + return o.isDockerEnabled(backendSettings, nil) } // GetCommand builds the command to run the backend -func (o *Options) GetCommand(backendConfig *config.BackendConfig) string { - +func (o *Options) GetCommand(backendConfig *config.BackendConfig, dockerEnabled *bool, commandOverride string) string { backendSettings := o.getBackendSettings(backendConfig) - if o.isDockerEnabled(backendSettings) { + // Determine if Docker is enabled + useDocker := o.isDockerEnabled(backendSettings, dockerEnabled) + + if useDocker { return "docker" } + // Check for command override (only applies when not in Docker mode) + if commandOverride != "" { + return commandOverride + } + + // Fall back to config command return backendSettings.Command } @@ -177,7 +199,7 @@ func (o *Options) BuildCommandArgs(backendConfig *config.BackendConfig) []string return args } - if o.isDockerEnabled(backendSettings) { + if o.isDockerEnabled(backendSettings, nil) { // For Docker, start with Docker args args = append(args, backendSettings.Docker.Args...) args = append(args, backendSettings.Docker.Image) @@ -202,7 +224,7 @@ func (o *Options) BuildEnvironment(backendConfig *config.BackendConfig, environm maps.Copy(env, backendSettings.Environment) } - if o.isDockerEnabled(backendSettings) { + if o.isDockerEnabled(backendSettings, nil) { if backendSettings.Docker.Environment != nil { maps.Copy(env, backendSettings.Docker.Environment) } diff --git a/pkg/backends/llama_test.go b/pkg/backends/llama_test.go index 961967b..4440092 100644 --- a/pkg/backends/llama_test.go +++ b/pkg/backends/llama_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "llamactl/pkg/backends" + "llamactl/pkg/config" "llamactl/pkg/testutil" "reflect" "testing" @@ -549,3 +550,79 @@ func TestParseLlamaCommand_ExtraArgs(t *testing.T) { }) } } +func TestLlamaCppGetCommand_WithOverrides(t *testing.T) { + tests := []struct { + name string + dockerInConfig bool + dockerEnabled *bool + commandOverride string + expected string + }{ + { + name: "no overrides - use config command", + dockerInConfig: false, + dockerEnabled: nil, + commandOverride: "", + expected: "/usr/bin/llama-server", + }, + { + name: "override to enable docker", + dockerInConfig: false, + dockerEnabled: boolPtr(true), + commandOverride: "", + expected: "docker", + }, + { + name: "override to disable docker", + dockerInConfig: true, + dockerEnabled: boolPtr(false), + commandOverride: "", + expected: "/usr/bin/llama-server", + }, + { + name: "command override", + dockerInConfig: false, + dockerEnabled: nil, + commandOverride: "/custom/llama-server", + expected: "/custom/llama-server", + }, + { + name: "docker takes precedence over command override", + dockerInConfig: false, + dockerEnabled: boolPtr(true), + commandOverride: "/custom/llama-server", + expected: "docker", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + backendConfig := &config.BackendConfig{ + LlamaCpp: config.BackendSettings{ + Command: "/usr/bin/llama-server", + Docker: &config.DockerSettings{ + Enabled: tt.dockerInConfig, + Image: "test-image", + }, + }, + } + + opts := backends.Options{ + BackendType: backends.BackendTypeLlamaCpp, + LlamaServerOptions: &backends.LlamaServerOptions{ + Model: "test-model.gguf", + }, + } + + result := opts.GetCommand(backendConfig, tt.dockerEnabled, tt.commandOverride) + if result != tt.expected { + t.Errorf("GetCommand() = %v, want %v", result, tt.expected) + } + }) + } +} + +// Helper function to create bool pointer +func boolPtr(b bool) *bool { + return &b +} diff --git a/pkg/backends/mlx_test.go b/pkg/backends/mlx_test.go index f8a2ee5..f24d1a5 100644 --- a/pkg/backends/mlx_test.go +++ b/pkg/backends/mlx_test.go @@ -2,6 +2,7 @@ package backends_test import ( "llamactl/pkg/backends" + "llamactl/pkg/config" "llamactl/pkg/testutil" "testing" ) @@ -274,3 +275,57 @@ func TestParseMlxCommand_ExtraArgs(t *testing.T) { }) } } +func TestMlxGetCommand_NoDocker(t *testing.T) { + // MLX backend should never use Docker + backendConfig := &config.BackendConfig{ + MLX: config.BackendSettings{ + Command: "/usr/bin/mlx-server", + Docker: &config.DockerSettings{ + Enabled: true, // Even if enabled in config + Image: "test-image", + }, + }, + } + + opts := backends.Options{ + BackendType: backends.BackendTypeMlxLm, + MlxServerOptions: &backends.MlxServerOptions{ + Model: "test-model", + }, + } + + tests := []struct { + name string + dockerEnabled *bool + commandOverride string + expected string + }{ + { + name: "ignores docker in config", + dockerEnabled: nil, + commandOverride: "", + expected: "/usr/bin/mlx-server", + }, + { + name: "ignores docker override", + dockerEnabled: boolPtr(true), + commandOverride: "", + expected: "/usr/bin/mlx-server", + }, + { + name: "respects command override", + dockerEnabled: nil, + commandOverride: "/custom/mlx-server", + expected: "/custom/mlx-server", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := opts.GetCommand(backendConfig, tt.dockerEnabled, tt.commandOverride) + if result != tt.expected { + t.Errorf("GetCommand() = %v, want %v", result, tt.expected) + } + }) + } +} diff --git a/pkg/instance/instance.go b/pkg/instance/instance.go index 5e5dc27..e9a16f6 100644 --- a/pkg/instance/instance.go +++ b/pkg/instance/instance.go @@ -255,7 +255,7 @@ func (i *Instance) getCommand() string { return "" } - return opts.BackendOptions.GetCommand(i.globalBackendSettings) + return opts.BackendOptions.GetCommand(i.globalBackendSettings, opts.DockerEnabled, opts.CommandOverride) } func (i *Instance) buildCommandArgs() []string { diff --git a/pkg/instance/options.go b/pkg/instance/options.go index 0c4b582..59cb31f 100644 --- a/pkg/instance/options.go +++ b/pkg/instance/options.go @@ -5,6 +5,7 @@ import ( "fmt" "llamactl/pkg/backends" "llamactl/pkg/config" + "llamactl/pkg/validation" "log" "slices" "sync" @@ -22,6 +23,11 @@ type Options struct { IdleTimeout *int `json:"idle_timeout,omitempty"` // minutes // Environment variables Environment map[string]string `json:"environment,omitempty"` + + // Execution context overrides + DockerEnabled *bool `json:"docker_enabled,omitempty"` + CommandOverride string `json:"command_override,omitempty"` + // Assigned nodes Nodes map[string]struct{} `json:"-"` // Backend options @@ -200,6 +206,28 @@ func (c *Options) validateAndApplyDefaults(name string, globalSettings *config.I *c.IdleTimeout = 0 } + // Validate docker_enabled and command_override relationship + if c.DockerEnabled != nil && *c.DockerEnabled && c.CommandOverride != "" { + log.Printf("Instance %s: command_override cannot be set when docker_enabled is true, ignoring command_override", name) + c.CommandOverride = "" // Clear invalid configuration + } + + // Validate command_override if set + if c.CommandOverride != "" { + if err := validation.ValidateStringForInjection(c.CommandOverride); err != nil { + log.Printf("Instance %s: invalid command_override: %v, clearing value", name, err) + c.CommandOverride = "" // Clear invalid value + } + } + + // Validate docker_enabled for MLX backend + if c.BackendOptions.BackendType == backends.BackendTypeMlxLm { + if c.DockerEnabled != nil && *c.DockerEnabled { + log.Printf("Instance %s: docker_enabled is not supported for MLX backend, ignoring", name) + c.DockerEnabled = nil // Clear invalid configuration + } + } + // Apply defaults from global settings for nil fields if globalSettings != nil { if c.AutoRestart == nil {