Refactor backend options to implement common interface and streamline validation

This commit is contained in:
2025-10-19 20:36:57 +02:00
parent f42f000539
commit d8e0da9cf8
4 changed files with 142 additions and 144 deletions

View File

@@ -14,9 +14,17 @@ const (
BackendTypeLlamaCpp BackendType = "llama_cpp" BackendTypeLlamaCpp BackendType = "llama_cpp"
BackendTypeMlxLm BackendType = "mlx_lm" BackendTypeMlxLm BackendType = "mlx_lm"
BackendTypeVllm BackendType = "vllm" BackendTypeVllm BackendType = "vllm"
// BackendTypeMlxVlm BackendType = "mlx_vlm" // Future expansion
) )
type backend interface {
BuildCommandArgs() []string
BuildDockerArgs() []string
GetPort() int
SetPort(int)
GetHost() string
Validate() error
}
type Options struct { type Options struct {
BackendType BackendType `json:"backend_type"` BackendType BackendType `json:"backend_type"`
BackendOptions map[string]any `json:"backend_options,omitempty"` BackendOptions map[string]any `json:"backend_options,omitempty"`
@@ -135,7 +143,7 @@ func (o *Options) MarshalJSON() ([]byte, error) {
return json.Marshal(aux) return json.Marshal(aux)
} }
func getBackendSettings(o *Options, backendConfig *config.BackendConfig) *config.BackendSettings { func (o *Options) getBackendSettings(backendConfig *config.BackendConfig) *config.BackendSettings {
switch o.BackendType { switch o.BackendType {
case BackendTypeLlamaCpp: case BackendTypeLlamaCpp:
return &backendConfig.LlamaCpp return &backendConfig.LlamaCpp
@@ -148,6 +156,20 @@ func getBackendSettings(o *Options, backendConfig *config.BackendConfig) *config
} }
} }
// getBackend returns the actual backend implementation
func (o *Options) getBackend() backend {
switch o.BackendType {
case BackendTypeLlamaCpp:
return o.LlamaServerOptions
case BackendTypeMlxLm:
return o.MlxServerOptions
case BackendTypeVllm:
return o.VllmServerOptions
default:
return nil
}
}
func (o *Options) isDockerEnabled(backend *config.BackendSettings) bool { func (o *Options) isDockerEnabled(backend *config.BackendSettings) bool {
if backend.Docker != nil && backend.Docker.Enabled && o.BackendType != BackendTypeMlxLm { if backend.Docker != nil && backend.Docker.Enabled && o.BackendType != BackendTypeMlxLm {
return true return true
@@ -156,14 +178,14 @@ func (o *Options) isDockerEnabled(backend *config.BackendSettings) bool {
} }
func (o *Options) IsDockerEnabled(backendConfig *config.BackendConfig) bool { func (o *Options) IsDockerEnabled(backendConfig *config.BackendConfig) bool {
backendSettings := getBackendSettings(o, backendConfig) backendSettings := o.getBackendSettings(backendConfig)
return o.isDockerEnabled(backendSettings) return o.isDockerEnabled(backendSettings)
} }
// GetCommand builds the command to run the backend // GetCommand builds the command to run the backend
func (o *Options) GetCommand(backendConfig *config.BackendConfig) string { func (o *Options) GetCommand(backendConfig *config.BackendConfig) string {
backendSettings := getBackendSettings(o, backendConfig) backendSettings := o.getBackendSettings(backendConfig)
if o.isDockerEnabled(backendSettings) { if o.isDockerEnabled(backendSettings) {
return "docker" return "docker"
@@ -177,42 +199,22 @@ func (o *Options) BuildCommandArgs(backendConfig *config.BackendConfig) []string
var args []string var args []string
backendSettings := getBackendSettings(o, backendConfig) backendSettings := o.getBackendSettings(backendConfig)
backend := o.getBackend()
if backend == nil {
return args
}
if o.isDockerEnabled(backendSettings) { if o.isDockerEnabled(backendSettings) {
// For Docker, start with Docker args // For Docker, start with Docker args
args = append(args, backendSettings.Docker.Args...) args = append(args, backendSettings.Docker.Args...)
args = append(args, backendSettings.Docker.Image) args = append(args, backendSettings.Docker.Image)
args = append(args, backend.BuildDockerArgs()...)
switch o.BackendType {
case BackendTypeLlamaCpp:
if o.LlamaServerOptions != nil {
args = append(args, o.LlamaServerOptions.BuildDockerArgs()...)
}
case BackendTypeVllm:
if o.VllmServerOptions != nil {
args = append(args, o.VllmServerOptions.BuildDockerArgs()...)
}
}
} else { } else {
// For native execution, start with backend args // For native execution, start with backend args
args = append(args, backendSettings.Args...) args = append(args, backendSettings.Args...)
args = append(args, backend.BuildCommandArgs()...)
switch o.BackendType {
case BackendTypeLlamaCpp:
if o.LlamaServerOptions != nil {
args = append(args, o.LlamaServerOptions.BuildCommandArgs()...)
}
case BackendTypeMlxLm:
if o.MlxServerOptions != nil {
args = append(args, o.MlxServerOptions.BuildCommandArgs()...)
}
case BackendTypeVllm:
if o.VllmServerOptions != nil {
args = append(args, o.VllmServerOptions.BuildCommandArgs()...)
}
}
} }
return args return args
@@ -221,7 +223,7 @@ func (o *Options) BuildCommandArgs(backendConfig *config.BackendConfig) []string
// BuildEnvironment builds the environment variables for the backend process // BuildEnvironment builds the environment variables for the backend process
func (o *Options) BuildEnvironment(backendConfig *config.BackendConfig, environment map[string]string) map[string]string { func (o *Options) BuildEnvironment(backendConfig *config.BackendConfig, environment map[string]string) map[string]string {
backendSettings := getBackendSettings(o, backendConfig) backendSettings := o.getBackendSettings(backendConfig)
env := map[string]string{} env := map[string]string{}
if backendSettings.Environment != nil { if backendSettings.Environment != nil {
@@ -242,80 +244,39 @@ func (o *Options) BuildEnvironment(backendConfig *config.BackendConfig, environm
} }
func (o *Options) GetPort() int { func (o *Options) GetPort() int {
if o != nil { backend := o.getBackend()
switch o.BackendType { if backend != nil {
case BackendTypeLlamaCpp: return backend.GetPort()
if o.LlamaServerOptions != nil {
return o.LlamaServerOptions.Port
}
case BackendTypeMlxLm:
if o.MlxServerOptions != nil {
return o.MlxServerOptions.Port
}
case BackendTypeVllm:
if o.VllmServerOptions != nil {
return o.VllmServerOptions.Port
}
}
} }
return 0 return 0
} }
func (o *Options) SetPort(port int) { func (o *Options) SetPort(port int) {
if o != nil { backend := o.getBackend()
switch o.BackendType { if backend != nil {
case BackendTypeLlamaCpp: backend.SetPort(port)
if o.LlamaServerOptions != nil {
o.LlamaServerOptions.Port = port
}
case BackendTypeMlxLm:
if o.MlxServerOptions != nil {
o.MlxServerOptions.Port = port
}
case BackendTypeVllm:
if o.VllmServerOptions != nil {
o.VllmServerOptions.Port = port
}
}
} }
} }
func (o *Options) GetHost() string { func (o *Options) GetHost() string {
if o != nil { backend := o.getBackend()
switch o.BackendType { if backend != nil {
case BackendTypeLlamaCpp: return backend.GetHost()
if o.LlamaServerOptions != nil {
return o.LlamaServerOptions.Host
}
case BackendTypeMlxLm:
if o.MlxServerOptions != nil {
return o.MlxServerOptions.Host
}
case BackendTypeVllm:
if o.VllmServerOptions != nil {
return o.VllmServerOptions.Host
}
}
} }
return "localhost" return "localhost"
} }
func (o *Options) GetResponseHeaders(backendConfig *config.BackendConfig) map[string]string { func (o *Options) GetResponseHeaders(backendConfig *config.BackendConfig) map[string]string {
backendSettings := getBackendSettings(o, backendConfig) backendSettings := o.getBackendSettings(backendConfig)
return backendSettings.ResponseHeaders return backendSettings.ResponseHeaders
} }
// ValidateInstanceOptions performs validation based on backend type // ValidateInstanceOptions performs validation based on backend type
func (o *Options) ValidateInstanceOptions() error { func (o *Options) ValidateInstanceOptions() error {
// Validate based on backend type backend := o.getBackend()
switch o.BackendType { if backend == nil {
case BackendTypeLlamaCpp: return validation.ValidationError(fmt.Errorf("backend options cannot be nil for backend type %s", o.BackendType))
return validateLlamaCppOptions(o.LlamaServerOptions)
case BackendTypeMlxLm:
return validateMlxOptions(o.MlxServerOptions)
case BackendTypeVllm:
return validateVllmOptions(o.VllmServerOptions)
default:
return validation.ValidationError(fmt.Errorf("unsupported backend type: %s", o.BackendType))
} }
return backend.Validate()
} }

View File

@@ -336,6 +336,36 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error {
return nil return nil
} }
func (o *LlamaServerOptions) GetPort() int {
return o.Port
}
func (o *LlamaServerOptions) SetPort(port int) {
o.Port = port
}
func (o *LlamaServerOptions) GetHost() string {
return o.Host
}
func (o *LlamaServerOptions) Validate() error {
if o == nil {
return validation.ValidationError(fmt.Errorf("llama server options cannot be nil for llama.cpp backend"))
}
// Use reflection to check all string fields for injection patterns
if err := validation.ValidateStructStrings(o, ""); err != nil {
return err
}
// Basic network validation for port
if o.Port < 0 || o.Port > 65535 {
return validation.ValidationError(fmt.Errorf("invalid port range: %d", o.Port))
}
return nil
}
// BuildCommandArgs converts InstanceOptions to command line arguments // BuildCommandArgs converts InstanceOptions to command line arguments
func (o *LlamaServerOptions) BuildCommandArgs() []string { func (o *LlamaServerOptions) BuildCommandArgs() []string {
// Llama uses multiple flags for arrays by default (not comma-separated) // Llama uses multiple flags for arrays by default (not comma-separated)
@@ -366,22 +396,3 @@ func ParseLlamaCommand(command string) (*LlamaServerOptions, error) {
return &llamaOptions, nil return &llamaOptions, nil
} }
// validateLlamaCppOptions validates llama.cpp specific options
func validateLlamaCppOptions(options *LlamaServerOptions) error {
if options == nil {
return validation.ValidationError(fmt.Errorf("llama server options cannot be nil for llama.cpp backend"))
}
// Use reflection to check all string fields for injection patterns
if err := validation.ValidateStructStrings(options, ""); err != nil {
return err
}
// Basic network validation for port
if options.Port < 0 || options.Port > 65535 {
return validation.ValidationError(fmt.Errorf("invalid port range: %d", options.Port))
}
return nil
}

View File

@@ -31,12 +31,45 @@ type MlxServerOptions struct {
MaxTokens int `json:"max_tokens,omitempty"` MaxTokens int `json:"max_tokens,omitempty"`
} }
func (o *MlxServerOptions) GetPort() int {
return o.Port
}
func (o *MlxServerOptions) SetPort(port int) {
o.Port = port
}
func (o *MlxServerOptions) GetHost() string {
return o.Host
}
func (o *MlxServerOptions) Validate() error {
if o == nil {
return validation.ValidationError(fmt.Errorf("MLX server options cannot be nil for MLX backend"))
}
if err := validation.ValidateStructStrings(o, ""); err != nil {
return err
}
// Basic network validation for port
if o.Port < 0 || o.Port > 65535 {
return validation.ValidationError(fmt.Errorf("invalid port range: %d", o.Port))
}
return nil
}
// BuildCommandArgs converts to command line arguments // BuildCommandArgs converts to command line arguments
func (o *MlxServerOptions) BuildCommandArgs() []string { func (o *MlxServerOptions) BuildCommandArgs() []string {
multipleFlags := map[string]bool{} // MLX doesn't currently have []string fields multipleFlags := map[string]bool{} // MLX doesn't currently have []string fields
return BuildCommandArgs(o, multipleFlags) return BuildCommandArgs(o, multipleFlags)
} }
func (o *MlxServerOptions) BuildDockerArgs() []string {
return []string{}
}
// ParseMlxCommand parses a mlx_lm.server command string into MlxServerOptions // ParseMlxCommand parses a mlx_lm.server command string into MlxServerOptions
// Supports multiple formats: // Supports multiple formats:
// 1. Full command: "mlx_lm.server --model model/path" // 1. Full command: "mlx_lm.server --model model/path"
@@ -55,21 +88,3 @@ func ParseMlxCommand(command string) (*MlxServerOptions, error) {
return &mlxOptions, nil return &mlxOptions, nil
} }
// validateMlxOptions validates MLX backend specific options
func validateMlxOptions(options *MlxServerOptions) error {
if options == nil {
return validation.ValidationError(fmt.Errorf("MLX server options cannot be nil for MLX backend"))
}
if err := validation.ValidateStructStrings(options, ""); err != nil {
return err
}
// Basic network validation for port
if options.Port < 0 || options.Port > 65535 {
return validation.ValidationError(fmt.Errorf("invalid port range: %d", options.Port))
}
return nil
}

View File

@@ -140,6 +140,36 @@ type VllmServerOptions struct {
OverrideKVCacheALIGNSize int `json:"override_kv_cache_align_size,omitempty"` OverrideKVCacheALIGNSize int `json:"override_kv_cache_align_size,omitempty"`
} }
func (o *VllmServerOptions) GetPort() int {
return o.Port
}
func (o *VllmServerOptions) SetPort(port int) {
o.Port = port
}
func (o *VllmServerOptions) GetHost() string {
return o.Host
}
func (o *VllmServerOptions) Validate() error {
if o == nil {
return validation.ValidationError(fmt.Errorf("vLLM server options cannot be nil for vLLM backend"))
}
// Use reflection to check all string fields for injection patterns
if err := validation.ValidateStructStrings(o, ""); err != nil {
return err
}
// Basic network validation for port
if o.Port < 0 || o.Port > 65535 {
return validation.ValidationError(fmt.Errorf("invalid port range: %d", o.Port))
}
return nil
}
// BuildCommandArgs converts VllmServerOptions to command line arguments // BuildCommandArgs converts VllmServerOptions to command line arguments
// For vLLM native, model is a positional argument after "serve" // For vLLM native, model is a positional argument after "serve"
func (o *VllmServerOptions) BuildCommandArgs() []string { func (o *VllmServerOptions) BuildCommandArgs() []string {
@@ -199,22 +229,3 @@ func ParseVllmCommand(command string) (*VllmServerOptions, error) {
return &vllmOptions, nil return &vllmOptions, nil
} }
// validateVllmOptions validates vLLM backend specific options
func validateVllmOptions(options *VllmServerOptions) error {
if options == nil {
return validation.ValidationError(fmt.Errorf("vLLM server options cannot be nil for vLLM backend"))
}
// Use reflection to check all string fields for injection patterns
if err := validation.ValidateStructStrings(options, ""); err != nil {
return err
}
// Basic network validation for port
if options.Port < 0 || options.Port > 65535 {
return validation.ValidationError(fmt.Errorf("invalid port range: %d", options.Port))
}
return nil
}