Add MLX backend support in CreateInstanceOptions and validation

This commit is contained in:
2025-09-16 21:38:33 +02:00
parent 468688cdbc
commit 63fea02d66
2 changed files with 67 additions and 12 deletions

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"llamactl/pkg/backends" "llamactl/pkg/backends"
"llamactl/pkg/backends/llamacpp" "llamactl/pkg/backends/llamacpp"
"llamactl/pkg/backends/mlx"
"llamactl/pkg/config" "llamactl/pkg/config"
"log" "log"
) )
@@ -22,8 +23,9 @@ type CreateInstanceOptions struct {
BackendType backends.BackendType `json:"backend_type"` BackendType backends.BackendType `json:"backend_type"`
BackendOptions map[string]any `json:"backend_options,omitempty"` BackendOptions map[string]any `json:"backend_options,omitempty"`
// LlamaServerOptions contains the options for the llama server // Backend-specific options
LlamaServerOptions *llamacpp.LlamaServerOptions `json:"-"` LlamaServerOptions *llamacpp.LlamaServerOptions `json:"-"`
MlxServerOptions *mlx.MlxServerOptions `json:"-"`
} }
// UnmarshalJSON implements custom JSON unmarshaling for CreateInstanceOptions // UnmarshalJSON implements custom JSON unmarshaling for CreateInstanceOptions
@@ -55,6 +57,18 @@ func (c *CreateInstanceOptions) UnmarshalJSON(data []byte) error {
return fmt.Errorf("failed to unmarshal llama.cpp options: %w", err) return fmt.Errorf("failed to unmarshal llama.cpp options: %w", err)
} }
} }
case backends.BackendTypeMlxLm:
if c.BackendOptions != nil {
optionsData, err := json.Marshal(c.BackendOptions)
if err != nil {
return fmt.Errorf("failed to marshal backend options: %w", err)
}
c.MlxServerOptions = &mlx.MlxServerOptions{}
if err := json.Unmarshal(optionsData, c.MlxServerOptions); err != nil {
return fmt.Errorf("failed to unmarshal MLX options: %w", err)
}
}
default: default:
return fmt.Errorf("unknown backend type: %s", c.BackendType) return fmt.Errorf("unknown backend type: %s", c.BackendType)
} }
@@ -72,8 +86,10 @@ func (c *CreateInstanceOptions) MarshalJSON() ([]byte, error) {
Alias: (*Alias)(c), Alias: (*Alias)(c),
} }
// Convert LlamaServerOptions back to BackendOptions map for JSON // Convert backend-specific options back to BackendOptions map for JSON
if c.BackendType == backends.BackendTypeLlamaCpp && c.LlamaServerOptions != nil { switch c.BackendType {
case backends.BackendTypeLlamaCpp:
if c.LlamaServerOptions != nil {
data, err := json.Marshal(c.LlamaServerOptions) data, err := json.Marshal(c.LlamaServerOptions)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to marshal llama server options: %w", err) return nil, fmt.Errorf("failed to marshal llama server options: %w", err)
@@ -86,6 +102,21 @@ func (c *CreateInstanceOptions) MarshalJSON() ([]byte, error) {
aux.BackendOptions = backendOpts aux.BackendOptions = backendOpts
} }
case backends.BackendTypeMlxLm:
if c.MlxServerOptions != nil {
data, err := json.Marshal(c.MlxServerOptions)
if err != nil {
return nil, fmt.Errorf("failed to marshal MLX server options: %w", err)
}
var backendOpts map[string]any
if err := json.Unmarshal(data, &backendOpts); err != nil {
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
}
aux.BackendOptions = backendOpts
}
}
return json.Marshal(aux) return json.Marshal(aux)
} }
@@ -136,6 +167,10 @@ func (c *CreateInstanceOptions) BuildCommandArgs() []string {
if c.LlamaServerOptions != nil { if c.LlamaServerOptions != nil {
return c.LlamaServerOptions.BuildCommandArgs() return c.LlamaServerOptions.BuildCommandArgs()
} }
case backends.BackendTypeMlxLm:
if c.MlxServerOptions != nil {
return c.MlxServerOptions.BuildCommandArgs()
}
} }
return []string{} return []string{}
} }

View File

@@ -44,6 +44,8 @@ func ValidateInstanceOptions(options *instance.CreateInstanceOptions) error {
switch options.BackendType { switch options.BackendType {
case backends.BackendTypeLlamaCpp: case backends.BackendTypeLlamaCpp:
return validateLlamaCppOptions(options) return validateLlamaCppOptions(options)
case backends.BackendTypeMlxLm:
return validateMlxOptions(options)
default: default:
return ValidationError(fmt.Errorf("unsupported backend type: %s", options.BackendType)) return ValidationError(fmt.Errorf("unsupported backend type: %s", options.BackendType))
} }
@@ -68,6 +70,24 @@ func validateLlamaCppOptions(options *instance.CreateInstanceOptions) error {
return nil return nil
} }
// validateMlxOptions validates MLX backend specific options
func validateMlxOptions(options *instance.CreateInstanceOptions) error {
if options.MlxServerOptions == nil {
return ValidationError(fmt.Errorf("MLX server options cannot be nil for MLX backend"))
}
if err := validateStructStrings(options.MlxServerOptions, ""); err != nil {
return err
}
// Basic network validation for port
if options.MlxServerOptions.Port < 0 || options.MlxServerOptions.Port > 65535 {
return ValidationError(fmt.Errorf("invalid port range: %d", options.MlxServerOptions.Port))
}
return nil
}
// validateStructStrings recursively validates all string fields in a struct // validateStructStrings recursively validates all string fields in a struct
func validateStructStrings(v any, fieldPath string) error { func validateStructStrings(v any, fieldPath string) error {
val := reflect.ValueOf(v) val := reflect.ValueOf(v)