Refactor backend option handling behind a private `backend` interface.

Review notes for this revision of the patch (text before the first
"diff --git" line is ignored by git-apply/patch):

  * getBackend now returns an untyped nil when the receiver is nil or when
    the options pointer for the selected backend type is nil. Returning the
    pointer field directly wraps a nil *T in a non-nil interface value, so
    the "backend != nil" guards in GetPort/SetPort/GetHost/BuildCommandArgs
    would pass and the accessors would then dereference a nil pointer.
  * Line breaks and indentation were reconstructed from the hunk line counts
    and gofmt conventions; if context whitespace differs from the target
    tree, apply with `git apply --ignore-whitespace` (or `patch -l`).
  * The post-image blob hash on the backend.go "index" line predates the
    getBackend change and is therefore stale.

diff --git a/pkg/backends/backend.go b/pkg/backends/backend.go
index 2f829b9..fc32549 100644
--- a/pkg/backends/backend.go
+++ b/pkg/backends/backend.go
@@ -14,9 +14,17 @@ const (
 	BackendTypeLlamaCpp BackendType = "llama_cpp"
 	BackendTypeMlxLm    BackendType = "mlx_lm"
 	BackendTypeVllm     BackendType = "vllm"
-	// BackendTypeMlxVlm BackendType = "mlx_vlm" // Future expansion
 )
 
+type backend interface {
+	BuildCommandArgs() []string
+	BuildDockerArgs() []string
+	GetPort() int
+	SetPort(int)
+	GetHost() string
+	Validate() error
+}
+
 type Options struct {
 	BackendType    BackendType    `json:"backend_type"`
 	BackendOptions map[string]any `json:"backend_options,omitempty"`
@@ -135,7 +143,7 @@ func (o *Options) MarshalJSON() ([]byte, error) {
 	return json.Marshal(aux)
 }
 
-func getBackendSettings(o *Options, backendConfig *config.BackendConfig) *config.BackendSettings {
+func (o *Options) getBackendSettings(backendConfig *config.BackendConfig) *config.BackendSettings {
 	switch o.BackendType {
 	case BackendTypeLlamaCpp:
 		return &backendConfig.LlamaCpp
@@ -148,6 +156,34 @@ func getBackendSettings(o *Options, backendConfig *config.BackendConfig) *config
 	}
 }
 
+// getBackend returns the backend implementation selected by o.BackendType.
+//
+// It returns an untyped nil when o is nil, when the backend type is unknown,
+// or when the matching options pointer is nil. Returning the pointer field
+// directly would wrap a nil *T in a non-nil interface value, so callers'
+// "backend != nil" checks would pass and then dereference a nil pointer.
+func (o *Options) getBackend() backend {
+	if o == nil {
+		return nil
+	}
+
+	switch o.BackendType {
+	case BackendTypeLlamaCpp:
+		if o.LlamaServerOptions != nil {
+			return o.LlamaServerOptions
+		}
+	case BackendTypeMlxLm:
+		if o.MlxServerOptions != nil {
+			return o.MlxServerOptions
+		}
+	case BackendTypeVllm:
+		if o.VllmServerOptions != nil {
+			return o.VllmServerOptions
+		}
+	}
+	return nil
+}
+
 func (o *Options) isDockerEnabled(backend *config.BackendSettings) bool {
 	if backend.Docker != nil && backend.Docker.Enabled && o.BackendType != BackendTypeMlxLm {
 		return true
@@ -156,14 +192,14 @@ func (o *Options) isDockerEnabled(backend *config.BackendSettings) bool {
 }
 
 func (o *Options) IsDockerEnabled(backendConfig *config.BackendConfig) bool {
-	backendSettings := getBackendSettings(o, backendConfig)
+	backendSettings := o.getBackendSettings(backendConfig)
 	return o.isDockerEnabled(backendSettings)
 }
 
 // GetCommand builds the command to run the backend
 func (o *Options) GetCommand(backendConfig *config.BackendConfig) string {
 
-	backendSettings := getBackendSettings(o, backendConfig)
+	backendSettings := o.getBackendSettings(backendConfig)
 
 	if o.isDockerEnabled(backendSettings) {
 		return "docker"
@@ -177,42 +213,22 @@ func (o *Options) BuildCommandArgs(backendConfig *config.BackendConfig) []string
 
 	var args []string
 
-	backendSettings := getBackendSettings(o, backendConfig)
+	backendSettings := o.getBackendSettings(backendConfig)
+	backend := o.getBackend()
+	if backend == nil {
+		return args
+	}
 
 	if o.isDockerEnabled(backendSettings) {
 		// For Docker, start with Docker args
 		args = append(args, backendSettings.Docker.Args...)
 		args = append(args, backendSettings.Docker.Image)
-
-		switch o.BackendType {
-		case BackendTypeLlamaCpp:
-			if o.LlamaServerOptions != nil {
-				args = append(args, o.LlamaServerOptions.BuildDockerArgs()...)
-			}
-		case BackendTypeVllm:
-			if o.VllmServerOptions != nil {
-				args = append(args, o.VllmServerOptions.BuildDockerArgs()...)
-			}
-		}
+		args = append(args, backend.BuildDockerArgs()...)
 
 	} else {
 		// For native execution, start with backend args
 		args = append(args, backendSettings.Args...)
-
-		switch o.BackendType {
-		case BackendTypeLlamaCpp:
-			if o.LlamaServerOptions != nil {
-				args = append(args, o.LlamaServerOptions.BuildCommandArgs()...)
-			}
-		case BackendTypeMlxLm:
-			if o.MlxServerOptions != nil {
-				args = append(args, o.MlxServerOptions.BuildCommandArgs()...)
-			}
-		case BackendTypeVllm:
-			if o.VllmServerOptions != nil {
-				args = append(args, o.VllmServerOptions.BuildCommandArgs()...)
-			}
-		}
+		args = append(args, backend.BuildCommandArgs()...)
 	}
 
 	return args
@@ -221,7 +237,7 @@ func (o *Options) BuildCommandArgs(backendConfig *config.BackendConfig) []string
 // BuildEnvironment builds the environment variables for the backend process
 func (o *Options) BuildEnvironment(backendConfig *config.BackendConfig, environment map[string]string) map[string]string {
 
-	backendSettings := getBackendSettings(o, backendConfig)
+	backendSettings := o.getBackendSettings(backendConfig)
 	env := map[string]string{}
 
 	if backendSettings.Environment != nil {
@@ -242,80 +258,39 @@ func (o *Options) BuildEnvironment(backendConfig *config.BackendConfig, environm
 }
 
 func (o *Options) GetPort() int {
-	if o != nil {
-		switch o.BackendType {
-		case BackendTypeLlamaCpp:
-			if o.LlamaServerOptions != nil {
-				return o.LlamaServerOptions.Port
-			}
-		case BackendTypeMlxLm:
-			if o.MlxServerOptions != nil {
-				return o.MlxServerOptions.Port
-			}
-		case BackendTypeVllm:
-			if o.VllmServerOptions != nil {
-				return o.VllmServerOptions.Port
-			}
-		}
+	backend := o.getBackend()
+	if backend != nil {
+		return backend.GetPort()
 	}
 	return 0
 }
 
 func (o *Options) SetPort(port int) {
-	if o != nil {
-		switch o.BackendType {
-		case BackendTypeLlamaCpp:
-			if o.LlamaServerOptions != nil {
-				o.LlamaServerOptions.Port = port
-			}
-		case BackendTypeMlxLm:
-			if o.MlxServerOptions != nil {
-				o.MlxServerOptions.Port = port
-			}
-		case BackendTypeVllm:
-			if o.VllmServerOptions != nil {
-				o.VllmServerOptions.Port = port
-			}
-		}
+	backend := o.getBackend()
+	if backend != nil {
+		backend.SetPort(port)
 	}
 }
 
 func (o *Options) GetHost() string {
-	if o != nil {
-		switch o.BackendType {
-		case BackendTypeLlamaCpp:
-			if o.LlamaServerOptions != nil {
-				return o.LlamaServerOptions.Host
-			}
-		case BackendTypeMlxLm:
-			if o.MlxServerOptions != nil {
-				return o.MlxServerOptions.Host
-			}
-		case BackendTypeVllm:
-			if o.VllmServerOptions != nil {
-				return o.VllmServerOptions.Host
-			}
-		}
+	backend := o.getBackend()
+	if backend != nil {
+		return backend.GetHost()
 	}
 	return "localhost"
 }
 
 func (o *Options) GetResponseHeaders(backendConfig *config.BackendConfig) map[string]string {
-	backendSettings := getBackendSettings(o, backendConfig)
+	backendSettings := o.getBackendSettings(backendConfig)
 	return backendSettings.ResponseHeaders
 }
 
 // ValidateInstanceOptions performs validation based on backend type
 func (o *Options) ValidateInstanceOptions() error {
-	// Validate based on backend type
-	switch o.BackendType {
-	case BackendTypeLlamaCpp:
-		return validateLlamaCppOptions(o.LlamaServerOptions)
-	case BackendTypeMlxLm:
-		return validateMlxOptions(o.MlxServerOptions)
-	case BackendTypeVllm:
-		return validateVllmOptions(o.VllmServerOptions)
-	default:
-		return validation.ValidationError(fmt.Errorf("unsupported backend type: %s", o.BackendType))
+	backend := o.getBackend()
+	if backend == nil {
+		return validation.ValidationError(fmt.Errorf("backend options cannot be nil for backend type %s", o.BackendType))
 	}
+
+	return backend.Validate()
 }
diff --git a/pkg/backends/llama.go b/pkg/backends/llama.go
index 7cdc182..dc07457 100644
--- a/pkg/backends/llama.go
+++ b/pkg/backends/llama.go
@@ -336,6 +336,36 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error {
 	return nil
 }
 
+func (o *LlamaServerOptions) GetPort() int {
+	return o.Port
+}
+
+func (o *LlamaServerOptions) SetPort(port int) {
+	o.Port = port
+}
+
+func (o *LlamaServerOptions) GetHost() string {
+	return o.Host
+}
+
+func (o *LlamaServerOptions) Validate() error {
+	if o == nil {
+		return validation.ValidationError(fmt.Errorf("llama server options cannot be nil for llama.cpp backend"))
+	}
+
+	// Use reflection to check all string fields for injection patterns
+	if err := validation.ValidateStructStrings(o, ""); err != nil {
+		return err
+	}
+
+	// Basic network validation for port
+	if o.Port < 0 || o.Port > 65535 {
+		return validation.ValidationError(fmt.Errorf("invalid port range: %d", o.Port))
+	}
+
+	return nil
+}
+
 // BuildCommandArgs converts InstanceOptions to command line arguments
 func (o *LlamaServerOptions) BuildCommandArgs() []string {
 	// Llama uses multiple flags for arrays by default (not comma-separated)
@@ -366,22 +396,3 @@ func ParseLlamaCommand(command string) (*LlamaServerOptions, error) {
 
 	return &llamaOptions, nil
 }
-
-// validateLlamaCppOptions validates llama.cpp specific options
-func validateLlamaCppOptions(options *LlamaServerOptions) error {
-	if options == nil {
-		return validation.ValidationError(fmt.Errorf("llama server options cannot be nil for llama.cpp backend"))
-	}
-
-	// Use reflection to check all string fields for injection patterns
-	if err := validation.ValidateStructStrings(options, ""); err != nil {
-		return err
-	}
-
-	// Basic network validation for port
-	if options.Port < 0 || options.Port > 65535 {
-		return validation.ValidationError(fmt.Errorf("invalid port range: %d", options.Port))
-	}
-
-	return nil
-}
diff --git a/pkg/backends/mlx.go b/pkg/backends/mlx.go
index 00f62da..d0ec602 100644
--- a/pkg/backends/mlx.go
+++ b/pkg/backends/mlx.go
@@ -31,12 +31,45 @@ type MlxServerOptions struct {
 	MaxTokens int `json:"max_tokens,omitempty"`
 }
 
+func (o *MlxServerOptions) GetPort() int {
+	return o.Port
+}
+
+func (o *MlxServerOptions) SetPort(port int) {
+	o.Port = port
+}
+
+func (o *MlxServerOptions) GetHost() string {
+	return o.Host
+}
+
+func (o *MlxServerOptions) Validate() error {
+	if o == nil {
+		return validation.ValidationError(fmt.Errorf("MLX server options cannot be nil for MLX backend"))
+	}
+
+	if err := validation.ValidateStructStrings(o, ""); err != nil {
+		return err
+	}
+
+	// Basic network validation for port
+	if o.Port < 0 || o.Port > 65535 {
+		return validation.ValidationError(fmt.Errorf("invalid port range: %d", o.Port))
+	}
+
+	return nil
+}
+
 // BuildCommandArgs converts to command line arguments
 func (o *MlxServerOptions) BuildCommandArgs() []string {
 	multipleFlags := map[string]bool{} // MLX doesn't currently have []string fields
 	return BuildCommandArgs(o, multipleFlags)
 }
 
+func (o *MlxServerOptions) BuildDockerArgs() []string {
+	return []string{}
+}
+
 // ParseMlxCommand parses a mlx_lm.server command string into MlxServerOptions
 // Supports multiple formats:
 // 1. Full command: "mlx_lm.server --model model/path"
@@ -55,21 +88,3 @@ func ParseMlxCommand(command string) (*MlxServerOptions, error) {
 
 	return &mlxOptions, nil
 }
-
-// validateMlxOptions validates MLX backend specific options
-func validateMlxOptions(options *MlxServerOptions) error {
-	if options == nil {
-		return validation.ValidationError(fmt.Errorf("MLX server options cannot be nil for MLX backend"))
-	}
-
-	if err := validation.ValidateStructStrings(options, ""); err != nil {
-		return err
-	}
-
-	// Basic network validation for port
-	if options.Port < 0 || options.Port > 65535 {
-		return validation.ValidationError(fmt.Errorf("invalid port range: %d", options.Port))
-	}
-
-	return nil
-}
diff --git a/pkg/backends/vllm.go b/pkg/backends/vllm.go
index 0a8c6f3..857eab3 100644
--- a/pkg/backends/vllm.go
+++ b/pkg/backends/vllm.go
@@ -140,6 +140,36 @@ type VllmServerOptions struct {
 	OverrideKVCacheALIGNSize int `json:"override_kv_cache_align_size,omitempty"`
 }
 
+func (o *VllmServerOptions) GetPort() int {
+	return o.Port
+}
+
+func (o *VllmServerOptions) SetPort(port int) {
+	o.Port = port
+}
+
+func (o *VllmServerOptions) GetHost() string {
+	return o.Host
+}
+
+func (o *VllmServerOptions) Validate() error {
+	if o == nil {
+		return validation.ValidationError(fmt.Errorf("vLLM server options cannot be nil for vLLM backend"))
+	}
+
+	// Use reflection to check all string fields for injection patterns
+	if err := validation.ValidateStructStrings(o, ""); err != nil {
+		return err
+	}
+
+	// Basic network validation for port
+	if o.Port < 0 || o.Port > 65535 {
+		return validation.ValidationError(fmt.Errorf("invalid port range: %d", o.Port))
+	}
+
+	return nil
+}
+
 // BuildCommandArgs converts VllmServerOptions to command line arguments
 // For vLLM native, model is a positional argument after "serve"
 func (o *VllmServerOptions) BuildCommandArgs() []string {
@@ -199,22 +229,3 @@ func ParseVllmCommand(command string) (*VllmServerOptions, error) {
 
 	return &vllmOptions, nil
 }
-
-// validateVllmOptions validates vLLM backend specific options
-func validateVllmOptions(options *VllmServerOptions) error {
-	if options == nil {
-		return validation.ValidationError(fmt.Errorf("vLLM server options cannot be nil for vLLM backend"))
-	}
-
-	// Use reflection to check all string fields for injection patterns
-	if err := validation.ValidateStructStrings(options, ""); err != nil {
-		return err
-	}
-
-	// Basic network validation for port
-	if options.Port < 0 || options.Port > 65535 {
-		return validation.ValidationError(fmt.Errorf("invalid port range: %d", options.Port))
-	}
-
-	return nil
-}