mirror of
https://github.com/lordmathis/llamactl.git
synced 2025-11-06 00:54:23 +00:00
Add MLX backend support in CreateInstanceOptions and validation
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"llamactl/pkg/backends"
|
"llamactl/pkg/backends"
|
||||||
"llamactl/pkg/backends/llamacpp"
|
"llamactl/pkg/backends/llamacpp"
|
||||||
|
"llamactl/pkg/backends/mlx"
|
||||||
"llamactl/pkg/config"
|
"llamactl/pkg/config"
|
||||||
"log"
|
"log"
|
||||||
)
|
)
|
||||||
@@ -22,8 +23,9 @@ type CreateInstanceOptions struct {
|
|||||||
BackendType backends.BackendType `json:"backend_type"`
|
BackendType backends.BackendType `json:"backend_type"`
|
||||||
BackendOptions map[string]any `json:"backend_options,omitempty"`
|
BackendOptions map[string]any `json:"backend_options,omitempty"`
|
||||||
|
|
||||||
// LlamaServerOptions contains the options for the llama server
|
// Backend-specific options
|
||||||
LlamaServerOptions *llamacpp.LlamaServerOptions `json:"-"`
|
LlamaServerOptions *llamacpp.LlamaServerOptions `json:"-"`
|
||||||
|
MlxServerOptions *mlx.MlxServerOptions `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalJSON implements custom JSON unmarshaling for CreateInstanceOptions
|
// UnmarshalJSON implements custom JSON unmarshaling for CreateInstanceOptions
|
||||||
@@ -55,6 +57,18 @@ func (c *CreateInstanceOptions) UnmarshalJSON(data []byte) error {
|
|||||||
return fmt.Errorf("failed to unmarshal llama.cpp options: %w", err)
|
return fmt.Errorf("failed to unmarshal llama.cpp options: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case backends.BackendTypeMlxLm:
|
||||||
|
if c.BackendOptions != nil {
|
||||||
|
optionsData, err := json.Marshal(c.BackendOptions)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal backend options: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.MlxServerOptions = &mlx.MlxServerOptions{}
|
||||||
|
if err := json.Unmarshal(optionsData, c.MlxServerOptions); err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal MLX options: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unknown backend type: %s", c.BackendType)
|
return fmt.Errorf("unknown backend type: %s", c.BackendType)
|
||||||
}
|
}
|
||||||
@@ -72,19 +86,36 @@ func (c *CreateInstanceOptions) MarshalJSON() ([]byte, error) {
|
|||||||
Alias: (*Alias)(c),
|
Alias: (*Alias)(c),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert LlamaServerOptions back to BackendOptions map for JSON
|
// Convert backend-specific options back to BackendOptions map for JSON
|
||||||
if c.BackendType == backends.BackendTypeLlamaCpp && c.LlamaServerOptions != nil {
|
switch c.BackendType {
|
||||||
data, err := json.Marshal(c.LlamaServerOptions)
|
case backends.BackendTypeLlamaCpp:
|
||||||
if err != nil {
|
if c.LlamaServerOptions != nil {
|
||||||
return nil, fmt.Errorf("failed to marshal llama server options: %w", err)
|
data, err := json.Marshal(c.LlamaServerOptions)
|
||||||
}
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal llama server options: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
var backendOpts map[string]any
|
var backendOpts map[string]any
|
||||||
if err := json.Unmarshal(data, &backendOpts); err != nil {
|
if err := json.Unmarshal(data, &backendOpts); err != nil {
|
||||||
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
|
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
aux.BackendOptions = backendOpts
|
aux.BackendOptions = backendOpts
|
||||||
|
}
|
||||||
|
case backends.BackendTypeMlxLm:
|
||||||
|
if c.MlxServerOptions != nil {
|
||||||
|
data, err := json.Marshal(c.MlxServerOptions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal MLX server options: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var backendOpts map[string]any
|
||||||
|
if err := json.Unmarshal(data, &backendOpts); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
aux.BackendOptions = backendOpts
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return json.Marshal(aux)
|
return json.Marshal(aux)
|
||||||
@@ -136,6 +167,10 @@ func (c *CreateInstanceOptions) BuildCommandArgs() []string {
|
|||||||
if c.LlamaServerOptions != nil {
|
if c.LlamaServerOptions != nil {
|
||||||
return c.LlamaServerOptions.BuildCommandArgs()
|
return c.LlamaServerOptions.BuildCommandArgs()
|
||||||
}
|
}
|
||||||
|
case backends.BackendTypeMlxLm:
|
||||||
|
if c.MlxServerOptions != nil {
|
||||||
|
return c.MlxServerOptions.BuildCommandArgs()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return []string{}
|
return []string{}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,6 +44,8 @@ func ValidateInstanceOptions(options *instance.CreateInstanceOptions) error {
|
|||||||
switch options.BackendType {
|
switch options.BackendType {
|
||||||
case backends.BackendTypeLlamaCpp:
|
case backends.BackendTypeLlamaCpp:
|
||||||
return validateLlamaCppOptions(options)
|
return validateLlamaCppOptions(options)
|
||||||
|
case backends.BackendTypeMlxLm:
|
||||||
|
return validateMlxOptions(options)
|
||||||
default:
|
default:
|
||||||
return ValidationError(fmt.Errorf("unsupported backend type: %s", options.BackendType))
|
return ValidationError(fmt.Errorf("unsupported backend type: %s", options.BackendType))
|
||||||
}
|
}
|
||||||
@@ -68,6 +70,24 @@ func validateLlamaCppOptions(options *instance.CreateInstanceOptions) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateMlxOptions validates MLX backend specific options
|
||||||
|
func validateMlxOptions(options *instance.CreateInstanceOptions) error {
|
||||||
|
if options.MlxServerOptions == nil {
|
||||||
|
return ValidationError(fmt.Errorf("MLX server options cannot be nil for MLX backend"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validateStructStrings(options.MlxServerOptions, ""); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basic network validation for port
|
||||||
|
if options.MlxServerOptions.Port < 0 || options.MlxServerOptions.Port > 65535 {
|
||||||
|
return ValidationError(fmt.Errorf("invalid port range: %d", options.MlxServerOptions.Port))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// validateStructStrings recursively validates all string fields in a struct
|
// validateStructStrings recursively validates all string fields in a struct
|
||||||
func validateStructStrings(v any, fieldPath string) error {
|
func validateStructStrings(v any, fieldPath string) error {
|
||||||
val := reflect.ValueOf(v)
|
val := reflect.ValueOf(v)
|
||||||
|
|||||||
Reference in New Issue
Block a user