diff --git a/pkg/instance/options.go b/pkg/instance/options.go index b9a2cca..2b1437f 100644 --- a/pkg/instance/options.go +++ b/pkg/instance/options.go @@ -5,6 +5,7 @@ import ( "fmt" "llamactl/pkg/backends" "llamactl/pkg/backends/llamacpp" + "llamactl/pkg/backends/mlx" "llamactl/pkg/config" "log" ) @@ -22,8 +23,9 @@ type CreateInstanceOptions struct { BackendType backends.BackendType `json:"backend_type"` BackendOptions map[string]any `json:"backend_options,omitempty"` - // LlamaServerOptions contains the options for the llama server + // Backend-specific options LlamaServerOptions *llamacpp.LlamaServerOptions `json:"-"` + MlxServerOptions *mlx.MlxServerOptions `json:"-"` } // UnmarshalJSON implements custom JSON unmarshaling for CreateInstanceOptions @@ -55,6 +57,18 @@ func (c *CreateInstanceOptions) UnmarshalJSON(data []byte) error { return fmt.Errorf("failed to unmarshal llama.cpp options: %w", err) } } + case backends.BackendTypeMlxLm: + if c.BackendOptions != nil { + optionsData, err := json.Marshal(c.BackendOptions) + if err != nil { + return fmt.Errorf("failed to marshal backend options: %w", err) + } + + c.MlxServerOptions = &mlx.MlxServerOptions{} + if err := json.Unmarshal(optionsData, c.MlxServerOptions); err != nil { + return fmt.Errorf("failed to unmarshal MLX options: %w", err) + } + } default: return fmt.Errorf("unknown backend type: %s", c.BackendType) } @@ -72,19 +86,36 @@ func (c *CreateInstanceOptions) MarshalJSON() ([]byte, error) { Alias: (*Alias)(c), } - // Convert LlamaServerOptions back to BackendOptions map for JSON - if c.BackendType == backends.BackendTypeLlamaCpp && c.LlamaServerOptions != nil { - data, err := json.Marshal(c.LlamaServerOptions) - if err != nil { - return nil, fmt.Errorf("failed to marshal llama server options: %w", err) - } + // Convert backend-specific options back to BackendOptions map for JSON + switch c.BackendType { + case backends.BackendTypeLlamaCpp: + if c.LlamaServerOptions != nil { + data, err := json.Marshal(c.LlamaServerOptions) + if err != nil { + return nil, fmt.Errorf("failed to marshal llama server options: %w", err) + } - var backendOpts map[string]any - if err := json.Unmarshal(data, &backendOpts); err != nil { - return nil, fmt.Errorf("failed to unmarshal to map: %w", err) - } + var backendOpts map[string]any + if err := json.Unmarshal(data, &backendOpts); err != nil { + return nil, fmt.Errorf("failed to unmarshal to map: %w", err) + } - aux.BackendOptions = backendOpts + aux.BackendOptions = backendOpts + } + case backends.BackendTypeMlxLm: + if c.MlxServerOptions != nil { + data, err := json.Marshal(c.MlxServerOptions) + if err != nil { + return nil, fmt.Errorf("failed to marshal MLX server options: %w", err) + } + + var backendOpts map[string]any + if err := json.Unmarshal(data, &backendOpts); err != nil { + return nil, fmt.Errorf("failed to unmarshal to map: %w", err) + } + + aux.BackendOptions = backendOpts + } } return json.Marshal(aux) @@ -136,6 +167,10 @@ func (c *CreateInstanceOptions) BuildCommandArgs() []string { if c.LlamaServerOptions != nil { return c.LlamaServerOptions.BuildCommandArgs() } + case backends.BackendTypeMlxLm: + if c.MlxServerOptions != nil { + return c.MlxServerOptions.BuildCommandArgs() + } } return []string{} } diff --git a/pkg/validation/validation.go b/pkg/validation/validation.go index 77873ca..eff1dd3 100644 --- a/pkg/validation/validation.go +++ b/pkg/validation/validation.go @@ -44,6 +44,8 @@ func ValidateInstanceOptions(options *instance.CreateInstanceOptions) error { switch options.BackendType { case backends.BackendTypeLlamaCpp: return validateLlamaCppOptions(options) + case backends.BackendTypeMlxLm: + return validateMlxOptions(options) default: return ValidationError(fmt.Errorf("unsupported backend type: %s", options.BackendType)) } @@ -68,6 +70,24 @@ func validateLlamaCppOptions(options *instance.CreateInstanceOptions) error { return nil } +// validateMlxOptions validates MLX backend specific options +func validateMlxOptions(options *instance.CreateInstanceOptions) error { + if options.MlxServerOptions == nil { + return ValidationError(fmt.Errorf("MLX server options cannot be nil for MLX backend")) + } + + if err := validateStructStrings(options.MlxServerOptions, ""); err != nil { + return err + } + + // Basic network validation for port + if options.MlxServerOptions.Port < 0 || options.MlxServerOptions.Port > 65535 { + return ValidationError(fmt.Errorf("invalid port range: %d", options.MlxServerOptions.Port)) + } + + return nil +} + // validateStructStrings recursively validates all string fields in a struct func validateStructStrings(v any, fieldPath string) error { val := reflect.ValueOf(v)