From 55f671c3544f13c9577965c40c31962918f78436 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sun, 19 Oct 2025 17:41:08 +0200 Subject: [PATCH] Refactor backend options handling and validation --- pkg/backends/backend.go | 303 ++++++++++++++++++++++++++++++++++- pkg/backends/llama.go | 21 +++ pkg/backends/mlx.go | 23 +++ pkg/backends/vllm.go | 24 +++ pkg/instance/instance.go | 115 ++++--------- pkg/instance/options.go | 227 +++++++------------------- pkg/instance/process.go | 47 +----- pkg/instance/proxy.go | 18 +-- pkg/manager/operations.go | 36 +---- pkg/validation/validation.go | 91 +---------- 10 files changed, 480 insertions(+), 425 deletions(-) diff --git a/pkg/backends/backend.go b/pkg/backends/backend.go index 1dcb87c..2f829b9 100644 --- a/pkg/backends/backend.go +++ b/pkg/backends/backend.go @@ -1,5 +1,13 @@ package backends +import ( + "encoding/json" + "fmt" + "llamactl/pkg/config" + "llamactl/pkg/validation" + "maps" +) + type BackendType string const ( @@ -13,10 +21,301 @@ type Options struct { BackendType BackendType `json:"backend_type"` BackendOptions map[string]any `json:"backend_options,omitempty"` - Nodes map[string]struct{} `json:"-"` - // Backend-specific options LlamaServerOptions *LlamaServerOptions `json:"-"` MlxServerOptions *MlxServerOptions `json:"-"` VllmServerOptions *VllmServerOptions `json:"-"` } + +func (o *Options) UnmarshalJSON(data []byte) error { + // Use anonymous struct to avoid recursion + type Alias Options + aux := &struct { + *Alias + }{ + Alias: (*Alias)(o), + } + + if err := json.Unmarshal(data, aux); err != nil { + return err + } + + // Parse backend-specific options + switch o.BackendType { + case BackendTypeLlamaCpp: + if o.BackendOptions != nil { + // Convert map to JSON and then unmarshal to LlamaServerOptions + optionsData, err := json.Marshal(o.BackendOptions) + if err != nil { + return fmt.Errorf("failed to marshal backend options: %w", err) + } + + o.LlamaServerOptions = &LlamaServerOptions{} + if err := json.Unmarshal(optionsData, o.LlamaServerOptions); err != nil { + return fmt.Errorf("failed to unmarshal llama.cpp options: %w", err) + } + } + case BackendTypeMlxLm: + if o.BackendOptions != nil { + optionsData, err := json.Marshal(o.BackendOptions) + if err != nil { + return fmt.Errorf("failed to marshal backend options: %w", err) + } + + o.MlxServerOptions = &MlxServerOptions{} + if err := json.Unmarshal(optionsData, o.MlxServerOptions); err != nil { + return fmt.Errorf("failed to unmarshal MLX options: %w", err) + } + } + case BackendTypeVllm: + if o.BackendOptions != nil { + optionsData, err := json.Marshal(o.BackendOptions) + if err != nil { + return fmt.Errorf("failed to marshal backend options: %w", err) + } + + o.VllmServerOptions = &VllmServerOptions{} + if err := json.Unmarshal(optionsData, o.VllmServerOptions); err != nil { + return fmt.Errorf("failed to unmarshal vLLM options: %w", err) + } + } + } + + return nil +} + +func (o *Options) MarshalJSON() ([]byte, error) { + // Use anonymous struct to avoid recursion + type Alias Options + aux := &struct { + *Alias + }{ + Alias: (*Alias)(o), + } + + // Prepare BackendOptions map + if o.BackendOptions == nil { + o.BackendOptions = make(map[string]any) + } + + // Populate BackendOptions based on backend-specific options + switch o.BackendType { + case BackendTypeLlamaCpp: + if o.LlamaServerOptions != nil { + optionsData, err := json.Marshal(o.LlamaServerOptions) + if err != nil { + return nil, fmt.Errorf("failed to marshal llama.cpp options: %w", err) + } + if err := json.Unmarshal(optionsData, &o.BackendOptions); err != nil { + return nil, fmt.Errorf("failed to unmarshal llama.cpp options to map: %w", err) + } + } + case BackendTypeMlxLm: + if o.MlxServerOptions != nil { + optionsData, err := json.Marshal(o.MlxServerOptions) + if err != nil { + return nil, fmt.Errorf("failed to marshal MLX options: %w", err) + } + if err := json.Unmarshal(optionsData, &o.BackendOptions); err != nil { + return nil, fmt.Errorf("failed to unmarshal MLX options to map: %w", err) + } + } + case BackendTypeVllm: + if o.VllmServerOptions != nil { + optionsData, err := json.Marshal(o.VllmServerOptions) + if err != nil { + return nil, fmt.Errorf("failed to marshal vLLM options: %w", err) + } + if err := json.Unmarshal(optionsData, &o.BackendOptions); err != nil { + return nil, fmt.Errorf("failed to unmarshal vLLM options to map: %w", err) + } + } + } + + return json.Marshal(aux) +} + +func getBackendSettings(o *Options, backendConfig *config.BackendConfig) *config.BackendSettings { + switch o.BackendType { + case BackendTypeLlamaCpp: + return &backendConfig.LlamaCpp + case BackendTypeMlxLm: + return &backendConfig.MLX + case BackendTypeVllm: + return &backendConfig.VLLM + default: + return nil + } +} + +func (o *Options) isDockerEnabled(backend *config.BackendSettings) bool { + if backend.Docker != nil && backend.Docker.Enabled && o.BackendType != BackendTypeMlxLm { + return true + } + return false +} + +func (o *Options) IsDockerEnabled(backendConfig *config.BackendConfig) bool { + backendSettings := getBackendSettings(o, backendConfig) + return o.isDockerEnabled(backendSettings) +} + +// GetCommand builds the command to run the backend +func (o *Options) GetCommand(backendConfig *config.BackendConfig) string { + + backendSettings := getBackendSettings(o, backendConfig) + + if o.isDockerEnabled(backendSettings) { + return "docker" + } + + return backendSettings.Command +} + +// buildCommandArgs builds command line arguments for the backend +func (o *Options) BuildCommandArgs(backendConfig *config.BackendConfig) []string { + + var args []string + + backendSettings := getBackendSettings(o, backendConfig) + + if o.isDockerEnabled(backendSettings) { + // For Docker, start with Docker args + args = append(args, backendSettings.Docker.Args...) + args = append(args, backendSettings.Docker.Image) + + switch o.BackendType { + case BackendTypeLlamaCpp: + if o.LlamaServerOptions != nil { + args = append(args, o.LlamaServerOptions.BuildDockerArgs()...) + } + case BackendTypeVllm: + if o.VllmServerOptions != nil { + args = append(args, o.VllmServerOptions.BuildDockerArgs()...) + } + } + + } else { + // For native execution, start with backend args + args = append(args, backendSettings.Args...) + + switch o.BackendType { + case BackendTypeLlamaCpp: + if o.LlamaServerOptions != nil { + args = append(args, o.LlamaServerOptions.BuildCommandArgs()...) + } + case BackendTypeMlxLm: + if o.MlxServerOptions != nil { + args = append(args, o.MlxServerOptions.BuildCommandArgs()...) + } + case BackendTypeVllm: + if o.VllmServerOptions != nil { + args = append(args, o.VllmServerOptions.BuildCommandArgs()...) + } + } + } + + return args +} + +// BuildEnvironment builds the environment variables for the backend process +func (o *Options) BuildEnvironment(backendConfig *config.BackendConfig, environment map[string]string) map[string]string { + + backendSettings := getBackendSettings(o, backendConfig) + env := map[string]string{} + + if backendSettings.Environment != nil { + maps.Copy(env, backendSettings.Environment) + } + + if o.isDockerEnabled(backendSettings) { + if backendSettings.Docker.Environment != nil { + maps.Copy(env, backendSettings.Docker.Environment) + } + } + + if environment != nil { + maps.Copy(env, environment) + } + + return env +} + +func (o *Options) GetPort() int { + if o != nil { + switch o.BackendType { + case BackendTypeLlamaCpp: + if o.LlamaServerOptions != nil { + return o.LlamaServerOptions.Port + } + case BackendTypeMlxLm: + if o.MlxServerOptions != nil { + return o.MlxServerOptions.Port + } + case BackendTypeVllm: + if o.VllmServerOptions != nil { + return o.VllmServerOptions.Port + } + } + } + return 0 +} + +func (o *Options) SetPort(port int) { + if o != nil { + switch o.BackendType { + case BackendTypeLlamaCpp: + if o.LlamaServerOptions != nil { + o.LlamaServerOptions.Port = port + } + case BackendTypeMlxLm: + if o.MlxServerOptions != nil { + o.MlxServerOptions.Port = port + } + case BackendTypeVllm: + if o.VllmServerOptions != nil { + o.VllmServerOptions.Port = port + } + } + } +} + +func (o *Options) GetHost() string { + if o != nil { + switch o.BackendType { + case BackendTypeLlamaCpp: + if o.LlamaServerOptions != nil { + return o.LlamaServerOptions.Host + } + case BackendTypeMlxLm: + if o.MlxServerOptions != nil { + return o.MlxServerOptions.Host + } + case BackendTypeVllm: + if o.VllmServerOptions != nil { + return o.VllmServerOptions.Host + } + } + } + return "localhost" +} + +func (o *Options) GetResponseHeaders(backendConfig *config.BackendConfig) map[string]string { + backendSettings := getBackendSettings(o, backendConfig) + return backendSettings.ResponseHeaders +} + +// ValidateInstanceOptions performs validation based on backend type +func (o *Options) ValidateInstanceOptions() error { + // Validate based on backend type + switch o.BackendType { + case BackendTypeLlamaCpp: + return validateLlamaCppOptions(o.LlamaServerOptions) + case BackendTypeMlxLm: + return validateMlxOptions(o.MlxServerOptions) + case BackendTypeVllm: + return validateVllmOptions(o.VllmServerOptions) + default: + return validation.ValidationError(fmt.Errorf("unsupported backend type: %s", o.BackendType)) + } +} diff --git a/pkg/backends/llama.go b/pkg/backends/llama.go index da4a42a..7cdc182 100644 --- a/pkg/backends/llama.go +++ b/pkg/backends/llama.go @@ -2,6 +2,8 @@ package backends import ( "encoding/json" + "fmt" + "llamactl/pkg/validation" "reflect" "strconv" ) @@ -364,3 +366,22 @@ func ParseLlamaCommand(command string) (*LlamaServerOptions, error) { return &llamaOptions, nil } + +// validateLlamaCppOptions validates llama.cpp specific options +func validateLlamaCppOptions(options *LlamaServerOptions) error { + if options == nil { + return validation.ValidationError(fmt.Errorf("llama server options cannot be nil for llama.cpp backend")) + } + + // Use reflection to check all string fields for injection patterns + if err := validation.ValidateStructStrings(options, ""); err != nil { + return err + } + + // Basic network validation for port + if options.Port < 0 || options.Port > 65535 { + return validation.ValidationError(fmt.Errorf("invalid port range: %d", options.Port)) + } + + return nil +} diff --git a/pkg/backends/mlx.go b/pkg/backends/mlx.go index 66ab1fa..00f62da 100644 --- a/pkg/backends/mlx.go +++ b/pkg/backends/mlx.go @@ -1,5 +1,10 @@ package backends +import ( + "fmt" + "llamactl/pkg/validation" +) + type MlxServerOptions struct { // Basic connection options Model string `json:"model,omitempty"` @@ -50,3 +55,21 @@ func ParseMlxCommand(command string) (*MlxServerOptions, error) { return &mlxOptions, nil } + +// validateMlxOptions validates MLX backend specific options +func validateMlxOptions(options *MlxServerOptions) error { + if options == nil { + return validation.ValidationError(fmt.Errorf("MLX server options cannot be nil for MLX backend")) + } + + if err := validation.ValidateStructStrings(options, ""); err != nil { + return err + } + + // Basic network validation for port + if options.Port < 0 || options.Port > 65535 { + return validation.ValidationError(fmt.Errorf("invalid port range: %d", options.Port)) + } + + return nil +} diff --git a/pkg/backends/vllm.go b/pkg/backends/vllm.go index 047aca8..0a8c6f3 100644 --- a/pkg/backends/vllm.go +++ b/pkg/backends/vllm.go @@ -1,5 +1,10 @@ package backends +import ( + "fmt" + "llamactl/pkg/validation" +) + // vllmMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated var vllmMultiValuedFlags = map[string]bool{ "api-key": true, @@ -194,3 +199,22 @@ func ParseVllmCommand(command string) (*VllmServerOptions, error) { return &vllmOptions, nil } + +// validateVllmOptions validates vLLM backend specific options +func validateVllmOptions(options *VllmServerOptions) error { + if options == nil { + return validation.ValidationError(fmt.Errorf("vLLM server options cannot be nil for vLLM backend")) + } + + // Use reflection to check all string fields for injection patterns + if err := validation.ValidateStructStrings(options, ""); err != nil { + return err + } + + // Basic network validation for port + if options.Port < 0 || options.Port > 65535 { + return validation.ValidationError(fmt.Errorf("invalid port range: %d", options.Port)) + } + + return nil +} diff --git a/pkg/instance/instance.go b/pkg/instance/instance.go index ee6757d..762d49e 100644 --- a/pkg/instance/instance.go +++ b/pkg/instance/instance.go @@ -3,7 +3,6 @@ package instance import ( "encoding/json" "fmt" - "llamactl/pkg/backends" "llamactl/pkg/config" "log" "net/http/httputil" @@ -124,48 +123,6 @@ func (i *Instance) IsRunning() bool { return i.status.isRunning() } -func (i *Instance) GetPort() int { - opts := i.GetOptions() - if opts != nil { - switch opts.BackendType { - case backends.BackendTypeLlamaCpp: - if opts.LlamaServerOptions != nil { - return opts.LlamaServerOptions.Port - } - case backends.BackendTypeMlxLm: - if opts.MlxServerOptions != nil { - return opts.MlxServerOptions.Port - } - case backends.BackendTypeVllm: - if opts.VllmServerOptions != nil { - return opts.VllmServerOptions.Port - } - } - } - return 0 -} - -func (i *Instance) GetHost() string { - opts := i.GetOptions() - if opts != nil { - switch opts.BackendType { - case backends.BackendTypeLlamaCpp: - if opts.LlamaServerOptions != nil { - return opts.LlamaServerOptions.Host - } - case backends.BackendTypeMlxLm: - if opts.MlxServerOptions != nil { - return opts.MlxServerOptions.Host - } - case backends.BackendTypeVllm: - if opts.VllmServerOptions != nil { - return opts.VllmServerOptions.Host - } - } - } - return "" -} - // SetOptions sets the options func (i *Instance) SetOptions(opts *Options) { if opts == nil { @@ -198,6 +155,20 @@ func (i *Instance) SetTimeProvider(tp TimeProvider) { } } +func (i *Instance) GetHost() string { + if i.options == nil { + return "localhost" + } + return i.options.GetHost() +} + +func (i *Instance) GetPort() int { + if i.options == nil { + return 0 + } + return i.options.GetPort() +} + // GetProxy returns the reverse proxy for this instance func (i *Instance) GetProxy() (*httputil.ReverseProxy, error) { if i.proxy == nil { @@ -266,39 +237,31 @@ func (i *Instance) ShouldTimeout() bool { return i.proxy.shouldTimeout() } -// getBackendHostPort extracts the host and port from instance options -// Returns the configured host and port for the backend -func (i *Instance) getBackendHostPort() (string, int) { +func (i *Instance) getCommand() string { opts := i.GetOptions() if opts == nil { - return "localhost", 0 + return "" } - var host string - var port int - switch opts.BackendType { - case backends.BackendTypeLlamaCpp: - if opts.LlamaServerOptions != nil { - host = opts.LlamaServerOptions.Host - port = opts.LlamaServerOptions.Port - } - case backends.BackendTypeMlxLm: - if opts.MlxServerOptions != nil { - host = opts.MlxServerOptions.Host - port = opts.MlxServerOptions.Port - } - case backends.BackendTypeVllm: - if opts.VllmServerOptions != nil { - host = opts.VllmServerOptions.Host - port = opts.VllmServerOptions.Port - } + return opts.BackendOptions.GetCommand(i.globalBackendSettings) +} + +func (i *Instance) buildCommandArgs() []string { + opts := i.GetOptions() + if opts == nil { + return nil } - if host == "" { - host = "localhost" + return opts.BackendOptions.BuildCommandArgs(i.globalBackendSettings) +} + +func (i *Instance) buildEnvironment() map[string]string { + opts := i.GetOptions() + if opts == nil { + return nil } - return host, port + return opts.BackendOptions.BuildEnvironment(i.globalBackendSettings, opts.Environment) } // MarshalJSON implements json.Marshaler for Instance @@ -307,21 +270,7 @@ func (i *Instance) MarshalJSON() ([]byte, error) { opts := i.GetOptions() // Determine if docker is enabled for this instance's backend - var dockerEnabled bool - if opts != nil { - switch opts.BackendType { - case backends.BackendTypeLlamaCpp: - if i.globalBackendSettings != nil && i.globalBackendSettings.LlamaCpp.Docker != nil && i.globalBackendSettings.LlamaCpp.Docker.Enabled { - dockerEnabled = true - } - case backends.BackendTypeVllm: - if i.globalBackendSettings != nil && i.globalBackendSettings.VLLM.Docker != nil && i.globalBackendSettings.VLLM.Docker.Enabled { - dockerEnabled = true - } - case backends.BackendTypeMlxLm: - // MLX does not support docker currently - } - } + dockerEnabled := opts.BackendOptions.IsDockerEnabled(i.globalBackendSettings) return json.Marshal(&struct { Name string `json:"name"` diff --git a/pkg/instance/options.go b/pkg/instance/options.go index d53ec09..375bcbe 100644 --- a/pkg/instance/options.go +++ b/pkg/instance/options.go @@ -6,7 +6,6 @@ import ( "llamactl/pkg/backends" "llamactl/pkg/config" "log" - "maps" "slices" "sync" ) @@ -21,18 +20,12 @@ type Options struct { OnDemandStart *bool `json:"on_demand_start,omitempty"` // Idle timeout IdleTimeout *int `json:"idle_timeout,omitempty"` // minutes - //Environment variables + // Environment variables Environment map[string]string `json:"environment,omitempty"` - - BackendType backends.BackendType `json:"backend_type"` - BackendOptions map[string]any `json:"backend_options,omitempty"` - + // Assigned nodes Nodes map[string]struct{} `json:"-"` - - // Backend-specific options - LlamaServerOptions *backends.LlamaServerOptions `json:"-"` - MlxServerOptions *backends.MlxServerOptions `json:"-"` - VllmServerOptions *backends.VllmServerOptions `json:"-"` + // Backend options + BackendOptions backends.Options `json:"-"` } // options wraps Options with thread-safe access (unexported). @@ -62,6 +55,18 @@ func (o *options) set(opts *Options) { o.opts = opts } +func (o *options) GetHost() string { + o.mu.RLock() + defer o.mu.RUnlock() + return o.opts.BackendOptions.GetHost() +} + +func (o *options) GetPort() int { + o.mu.RLock() + defer o.mu.RUnlock() + return o.opts.BackendOptions.GetPort() +} + // MarshalJSON implements json.Marshaler for options wrapper func (o *options) MarshalJSON() ([]byte, error) { o.mu.RLock() @@ -85,7 +90,9 @@ func (c *Options) UnmarshalJSON(data []byte) error { // Use anonymous struct to avoid recursion type Alias Options aux := &struct { - Nodes []string `json:"nodes,omitempty"` // Accept JSON array + Nodes []string `json:"nodes,omitempty"` + BackendType backends.BackendType `json:"backend_type"` + BackendOptions map[string]any `json:"backend_options,omitempty"` *Alias }{ Alias: (*Alias)(c), @@ -103,47 +110,27 @@ func (c *Options) UnmarshalJSON(data []byte) error { } } - // Parse backend-specific options - switch c.BackendType { - case backends.BackendTypeLlamaCpp: - if c.BackendOptions != nil { - // Convert map to JSON and then unmarshal to LlamaServerOptions - optionsData, err := json.Marshal(c.BackendOptions) - if err != nil { - return fmt.Errorf("failed to marshal backend options: %w", err) - } + // Create backend options struct and unmarshal + c.BackendOptions = backends.Options{ + BackendType: aux.BackendType, + BackendOptions: aux.BackendOptions, + } - c.LlamaServerOptions = &backends.LlamaServerOptions{} - if err := json.Unmarshal(optionsData, c.LlamaServerOptions); err != nil { - return fmt.Errorf("failed to unmarshal llama.cpp options: %w", err) - } - } - case backends.BackendTypeMlxLm: - if c.BackendOptions != nil { - optionsData, err := json.Marshal(c.BackendOptions) - if err != nil { - return fmt.Errorf("failed to marshal backend options: %w", err) - } + // Marshal the backend options to JSON for proper unmarshaling + backendJson, err := json.Marshal(struct { + BackendType backends.BackendType `json:"backend_type"` + BackendOptions map[string]any `json:"backend_options,omitempty"` + }{ + BackendType: aux.BackendType, + BackendOptions: aux.BackendOptions, + }) + if err != nil { + return fmt.Errorf("failed to marshal backend options: %w", err) + } - c.MlxServerOptions = &backends.MlxServerOptions{} - if err := json.Unmarshal(optionsData, c.MlxServerOptions); err != nil { - return fmt.Errorf("failed to unmarshal MLX options: %w", err) - } - } - case backends.BackendTypeVllm: - if c.BackendOptions != nil { - optionsData, err := json.Marshal(c.BackendOptions) - if err != nil { - return fmt.Errorf("failed to marshal backend options: %w", err) - } - - c.VllmServerOptions = &backends.VllmServerOptions{} - if err := json.Unmarshal(optionsData, c.VllmServerOptions); err != nil { - return fmt.Errorf("failed to unmarshal vLLM options: %w", err) - } - } - default: - return fmt.Errorf("unknown backend type: %s", c.BackendType) + // Unmarshal into the backends.Options struct to trigger its custom unmarshaling + if err := json.Unmarshal(backendJson, &c.BackendOptions); err != nil { + return fmt.Errorf("failed to unmarshal backend options: %w", err) } return nil @@ -154,7 +141,9 @@ func (c *Options) MarshalJSON() ([]byte, error) { // Use anonymous struct to avoid recursion type Alias Options aux := struct { - Nodes []string `json:"nodes,omitempty"` // Output as JSON array + Nodes []string `json:"nodes,omitempty"` // Output as JSON array + BackendType backends.BackendType `json:"backend_type"` + BackendOptions map[string]any `json:"backend_options,omitempty"` *Alias }{ Alias: (*Alias)(c), @@ -170,52 +159,25 @@ func (c *Options) MarshalJSON() ([]byte, error) { slices.Sort(aux.Nodes) } - // Convert backend-specific options back to BackendOptions map for JSON - switch c.BackendType { - case backends.BackendTypeLlamaCpp: - if c.LlamaServerOptions != nil { - data, err := json.Marshal(c.LlamaServerOptions) - if err != nil { - return nil, fmt.Errorf("failed to marshal llama server options: %w", err) - } + // Set backend type + aux.BackendType = c.BackendOptions.BackendType - var backendOpts map[string]any - if err := json.Unmarshal(data, &backendOpts); err != nil { - return nil, fmt.Errorf("failed to unmarshal to map: %w", err) - } - - aux.BackendOptions = backendOpts - } - case backends.BackendTypeMlxLm: - if c.MlxServerOptions != nil { - data, err := json.Marshal(c.MlxServerOptions) - if err != nil { - return nil, fmt.Errorf("failed to marshal MLX server options: %w", err) - } - - var backendOpts map[string]any - if err := json.Unmarshal(data, &backendOpts); err != nil { - return nil, fmt.Errorf("failed to unmarshal to map: %w", err) - } - - aux.BackendOptions = backendOpts - } - case backends.BackendTypeVllm: - if c.VllmServerOptions != nil { - data, err := json.Marshal(c.VllmServerOptions) - if err != nil { - return nil, fmt.Errorf("failed to marshal vLLM server options: %w", err) - } - - var backendOpts map[string]any - if err := json.Unmarshal(data, &backendOpts); err != nil { - return nil, fmt.Errorf("failed to unmarshal to map: %w", err) - } - - aux.BackendOptions = backendOpts - } + // Marshal the backends.Options struct to get the properly formatted backend options + backendData, err := json.Marshal(c.BackendOptions) + if err != nil { + return nil, fmt.Errorf("failed to marshal backend options: %w", err) } + // Unmarshal into a temporary struct to extract the backend_options map + var tempBackend struct { + BackendOptions map[string]any `json:"backend_options,omitempty"` + } + if err := json.Unmarshal(backendData, &tempBackend); err != nil { + return nil, fmt.Errorf("failed to unmarshal backend data: %w", err) + } + + aux.BackendOptions = tempBackend.BackendOptions + return json.Marshal(aux) } @@ -257,78 +219,3 @@ func (c *Options) validateAndApplyDefaults(name string, globalSettings *config.I } } } - -// getCommand builds the command to run the backend -func (c *Options) getCommand(backendConfig *config.BackendSettings) string { - - if backendConfig.Docker != nil && backendConfig.Docker.Enabled && c.BackendType != backends.BackendTypeMlxLm { - return "docker" - } - - return backendConfig.Command -} - -// buildCommandArgs builds command line arguments for the backend -func (c *Options) buildCommandArgs(backendConfig *config.BackendSettings) []string { - - var args []string - - if backendConfig.Docker != nil && backendConfig.Docker.Enabled && c.BackendType != backends.BackendTypeMlxLm { - // For Docker, start with Docker args - args = append(args, backendConfig.Docker.Args...) - args = append(args, backendConfig.Docker.Image) - - switch c.BackendType { - case backends.BackendTypeLlamaCpp: - if c.LlamaServerOptions != nil { - args = append(args, c.LlamaServerOptions.BuildDockerArgs()...) - } - case backends.BackendTypeVllm: - if c.VllmServerOptions != nil { - args = append(args, c.VllmServerOptions.BuildDockerArgs()...) - } - } - - } else { - // For native execution, start with backend args - args = append(args, backendConfig.Args...) - - switch c.BackendType { - case backends.BackendTypeLlamaCpp: - if c.LlamaServerOptions != nil { - args = append(args, c.LlamaServerOptions.BuildCommandArgs()...) - } - case backends.BackendTypeMlxLm: - if c.MlxServerOptions != nil { - args = append(args, c.MlxServerOptions.BuildCommandArgs()...) - } - case backends.BackendTypeVllm: - if c.VllmServerOptions != nil { - args = append(args, c.VllmServerOptions.BuildCommandArgs()...) - } - } - } - - return args -} - -// buildEnvironment builds the environment variables for the backend process -func (c *Options) buildEnvironment(backendConfig *config.BackendSettings) map[string]string { - env := map[string]string{} - - if backendConfig.Environment != nil { - maps.Copy(env, backendConfig.Environment) - } - - if backendConfig.Docker != nil && backendConfig.Docker.Enabled && c.BackendType != backends.BackendTypeMlxLm { - if backendConfig.Docker.Environment != nil { - maps.Copy(env, backendConfig.Docker.Environment) - } - } - - if c.Environment != nil { - maps.Copy(env, c.Environment) - } - - return env -} diff --git a/pkg/instance/process.go b/pkg/instance/process.go index 9c7cfec..3429e61 100644 --- a/pkg/instance/process.go +++ b/pkg/instance/process.go @@ -12,9 +12,6 @@ import ( "sync" "syscall" "time" - - "llamactl/pkg/backends" - "llamactl/pkg/config" ) // process manages the OS process lifecycle for a local instance. @@ -216,7 +213,8 @@ func (p *process) waitForHealthy(timeout int) error { defer cancel() // Get host/port from instance - host, port := p.instance.getBackendHostPort() + host := p.instance.options.GetHost() + port := p.instance.options.GetPort() healthURL := fmt.Sprintf("http://%s:%d/health", host, port) // Create a dedicated HTTP client for health checks @@ -386,26 +384,15 @@ func (p *process) handleAutoRestart(err error) { // buildCommand builds the command to execute using backend-specific logic func (p *process) buildCommand() (*exec.Cmd, error) { - // Get options - opts := p.instance.GetOptions() - if opts == nil { - return nil, fmt.Errorf("instance options are nil") - } - - // Get backend configuration - backendConfig, err := p.getBackendConfig() - if err != nil { - return nil, err - } // Build the environment variables - env := opts.buildEnvironment(backendConfig) + env := p.instance.buildEnvironment() // Get the command to execute - command := opts.getCommand(backendConfig) + command := p.instance.getCommand() // Build command arguments - args := opts.buildCommandArgs(backendConfig) + args := p.instance.buildCommandArgs() // Create the exec.Cmd cmd := exec.CommandContext(p.ctx, command, args...) @@ -420,27 +407,3 @@ func (p *process) buildCommand() (*exec.Cmd, error) { return cmd, nil } - -// getBackendConfig resolves the backend configuration for the current instance -func (p *process) getBackendConfig() (*config.BackendSettings, error) { - opts := p.instance.GetOptions() - if opts == nil { - return nil, fmt.Errorf("instance options are nil") - } - - var backendTypeStr string - - switch opts.BackendType { - case backends.BackendTypeLlamaCpp: - backendTypeStr = "llama-cpp" - case backends.BackendTypeMlxLm: - backendTypeStr = "mlx" - case backends.BackendTypeVllm: - backendTypeStr = "vllm" - default: - return nil, fmt.Errorf("unsupported backend type: %s", opts.BackendType) - } - - settings := p.instance.globalBackendSettings.GetBackendSettings(backendTypeStr) - return &settings, nil -} diff --git a/pkg/instance/proxy.go b/pkg/instance/proxy.go index 321095b..26990f9 100644 --- a/pkg/instance/proxy.go +++ b/pkg/instance/proxy.go @@ -2,7 +2,6 @@ package instance import ( "fmt" - "llamactl/pkg/backends" "net/http" "net/http/httputil" "net/url" @@ -68,8 +67,11 @@ func (p *proxy) build() (*httputil.ReverseProxy, error) { } // Get host/port from process - host, port := p.instance.getBackendHostPort() - + host := p.instance.options.GetHost() + port := p.instance.options.GetPort() + if port == 0 { + return nil, fmt.Errorf("instance %s has no port assigned", p.instance.Name) + } targetURL, err := url.Parse(fmt.Sprintf("http://%s:%d", host, port)) if err != nil { return nil, fmt.Errorf("failed to parse target URL for instance %s: %w", p.instance.Name, err) @@ -78,15 +80,7 @@ func (p *proxy) build() (*httputil.ReverseProxy, error) { proxy := httputil.NewSingleHostReverseProxy(targetURL) // Get response headers from backend config - var responseHeaders map[string]string - switch options.BackendType { - case backends.BackendTypeLlamaCpp: - responseHeaders = p.instance.globalBackendSettings.LlamaCpp.ResponseHeaders - case backends.BackendTypeVllm: - responseHeaders = p.instance.globalBackendSettings.VLLM.ResponseHeaders - case backends.BackendTypeMlxLm: - responseHeaders = p.instance.globalBackendSettings.MLX.ResponseHeaders - } + responseHeaders := options.BackendOptions.GetResponseHeaders(p.instance.globalBackendSettings) proxy.ModifyResponse = func(resp *http.Response) error { // Remove CORS headers from backend response to avoid conflicts diff --git a/pkg/manager/operations.go b/pkg/manager/operations.go index 7129794..fd150e8 100644 --- a/pkg/manager/operations.go +++ b/pkg/manager/operations.go @@ -2,7 +2,6 @@ package manager import ( "fmt" - "llamactl/pkg/backends" "llamactl/pkg/instance" "llamactl/pkg/validation" "os" @@ -86,7 +85,7 @@ func (im *instanceManager) CreateInstance(name string, options *instance.Options return nil, err } - err = validation.ValidateInstanceOptions(options) + err = options.BackendOptions.ValidateInstanceOptions() if err != nil { return nil, err } @@ -232,7 +231,7 @@ func (im *instanceManager) UpdateInstance(name string, options *instance.Options return nil, fmt.Errorf("instance options cannot be nil") } - err := validation.ValidateInstanceOptions(options) + err := options.BackendOptions.ValidateInstanceOptions() if err != nil { return nil, err } @@ -493,39 +492,12 @@ func (im *instanceManager) GetInstanceLogs(name string, numLines int) (string, e // getPortFromOptions extracts the port from backend-specific options func (im *instanceManager) getPortFromOptions(options *instance.Options) int { - switch options.BackendType { - case backends.BackendTypeLlamaCpp: - if options.LlamaServerOptions != nil { - return options.LlamaServerOptions.Port - } - case backends.BackendTypeMlxLm: - if options.MlxServerOptions != nil { - return options.MlxServerOptions.Port - } - case backends.BackendTypeVllm: - if options.VllmServerOptions != nil { - return options.VllmServerOptions.Port - } - } - return 0 + return options.BackendOptions.GetPort() } // setPortInOptions sets the port in backend-specific options func (im *instanceManager) setPortInOptions(options *instance.Options, port int) { - switch options.BackendType { - case backends.BackendTypeLlamaCpp: - if options.LlamaServerOptions != nil { - options.LlamaServerOptions.Port = port - } - case backends.BackendTypeMlxLm: - if options.MlxServerOptions != nil { - options.MlxServerOptions.Port = port - } - case backends.BackendTypeVllm: - if options.VllmServerOptions != nil { - options.VllmServerOptions.Port = port - } - } + options.BackendOptions.SetPort(port) } // assignAndValidatePort assigns a port if not specified and validates it's not in use diff --git a/pkg/validation/validation.go b/pkg/validation/validation.go index 6d6638d..ca7b8ec 100644 --- a/pkg/validation/validation.go +++ b/pkg/validation/validation.go @@ -2,8 +2,6 @@ package validation import ( "fmt" - "llamactl/pkg/backends" - "llamactl/pkg/instance" "reflect" "regexp" ) @@ -24,8 +22,8 @@ var ( type ValidationError error -// validateStringForInjection checks if a string contains dangerous patterns -func validateStringForInjection(value string) error { +// ValidateStringForInjection checks if a string contains dangerous patterns +func ValidateStringForInjection(value string) error { for _, pattern := range dangerousPatterns { if pattern.MatchString(value) { return ValidationError(fmt.Errorf("value contains potentially dangerous characters: %s", value)) @@ -34,83 +32,8 @@ func validateStringForInjection(value string) error { return nil } -// ValidateInstanceOptions performs validation based on backend type -func ValidateInstanceOptions(options *instance.Options) error { - if options == nil { - return ValidationError(fmt.Errorf("options cannot be nil")) - } - - // Validate based on backend type - switch options.BackendType { - case backends.BackendTypeLlamaCpp: - return validateLlamaCppOptions(options) - case backends.BackendTypeMlxLm: - return validateMlxOptions(options) - case backends.BackendTypeVllm: - return validateVllmOptions(options) - default: - return ValidationError(fmt.Errorf("unsupported backend type: %s", options.BackendType)) - } -} - -// validateLlamaCppOptions validates llama.cpp specific options -func validateLlamaCppOptions(options *instance.Options) error { - if options.LlamaServerOptions == nil { - return ValidationError(fmt.Errorf("llama server options cannot be nil for llama.cpp backend")) - } - - // Use reflection to check all string fields for injection patterns - if err := validateStructStrings(options.LlamaServerOptions, ""); err != nil { - return err - } - - // Basic network validation for port - if options.LlamaServerOptions.Port < 0 || options.LlamaServerOptions.Port > 65535 { - return ValidationError(fmt.Errorf("invalid port range: %d", options.LlamaServerOptions.Port)) - } - - return nil -} - -// validateMlxOptions validates MLX backend specific options -func validateMlxOptions(options *instance.Options) error { - if options.MlxServerOptions == nil { - return ValidationError(fmt.Errorf("MLX server options cannot be nil for MLX backend")) - } - - if err := validateStructStrings(options.MlxServerOptions, ""); err != nil { - return err - } - - // Basic network validation for port - if options.MlxServerOptions.Port < 0 || options.MlxServerOptions.Port > 65535 { - return ValidationError(fmt.Errorf("invalid port range: %d", options.MlxServerOptions.Port)) - } - - return nil -} - -// validateVllmOptions validates vLLM backend specific options -func validateVllmOptions(options *instance.Options) error { - if options.VllmServerOptions == nil { - return ValidationError(fmt.Errorf("vLLM server options cannot be nil for vLLM backend")) - } - - // Use reflection to check all string fields for injection patterns - if err := validateStructStrings(options.VllmServerOptions, ""); err != nil { - return err - } - - // Basic network validation for port - if options.VllmServerOptions.Port < 0 || options.VllmServerOptions.Port > 65535 { - return ValidationError(fmt.Errorf("invalid port range: %d", options.VllmServerOptions.Port)) - } - - return nil -} - -// validateStructStrings recursively validates all string fields in a struct -func validateStructStrings(v any, fieldPath string) error { +// ValidateStructStrings recursively validates all string fields in a struct +func ValidateStructStrings(v any, fieldPath string) error { val := reflect.ValueOf(v) if val.Kind() == reflect.Ptr { val = val.Elem() @@ -136,21 +59,21 @@ func validateStructStrings(v any, fieldPath string) error { switch field.Kind() { case reflect.String: - if err := validateStringForInjection(field.String()); err != nil { + if err := ValidateStringForInjection(field.String()); err != nil { return ValidationError(fmt.Errorf("field %s: %w", fieldName, err)) } case reflect.Slice: if field.Type().Elem().Kind() == reflect.String { for j := 0; j < field.Len(); j++ { - if err := validateStringForInjection(field.Index(j).String()); err != nil { + if err := ValidateStringForInjection(field.Index(j).String()); err != nil { return ValidationError(fmt.Errorf("field %s[%d]: %w", fieldName, j, err)) } } } case reflect.Struct: - if err := validateStructStrings(field.Interface(), fieldName); err != nil { + if err := ValidateStructStrings(field.Interface(), fieldName); err != nil { return err } }