Add MLX backend config options

This commit is contained in:
2025-09-16 21:14:19 +02:00
parent 1f25e9d05b
commit 988c4aca40
4 changed files with 491 additions and 1 deletions

View File

@@ -4,4 +4,6 @@ type BackendType string
const ( const (
BackendTypeLlamaCpp BackendType = "llama_cpp" BackendTypeLlamaCpp BackendType = "llama_cpp"
BackendTypeMlxLm BackendType = "mlx_lm"
// BackendTypeMlxVlm BackendType = "mlx_vlm" // Future expansion
) )

208
pkg/backends/mlx/mlx.go Normal file
View File

@@ -0,0 +1,208 @@
package mlx
import (
"encoding/json"
"reflect"
"strconv"
)
// MlxServerOptions describes the options that can be configured for an MLX
// server instance. The JSON tags give the canonical (snake_case) key of every
// option; UnmarshalJSON additionally accepts alternative spellings, and
// BuildCommandArgs renders the non-zero options as command-line flags.
// Because every field is tagged omitempty, zero values are left out when the
// struct is marshaled.
type MlxServerOptions struct {
// Basic connection options
Model string `json:"model,omitempty"`
Host string `json:"host,omitempty"`
Port int `json:"port,omitempty"`
// PythonPath is not an MLX flag: BuildCommandArgs never emits it.
PythonPath string `json:"python_path,omitempty"` // Custom: Python venv path
// Model and adapter options
AdapterPath string `json:"adapter_path,omitempty"`
DraftModel string `json:"draft_model,omitempty"`
NumDraftTokens int `json:"num_draft_tokens,omitempty"`
TrustRemoteCode bool `json:"trust_remote_code,omitempty"`
// Logging and templates
LogLevel string `json:"log_level,omitempty"`
ChatTemplate string `json:"chat_template,omitempty"`
UseDefaultChatTemplate bool `json:"use_default_chat_template,omitempty"`
// ChatTemplateArgs is kept as an opaque string and passed through verbatim.
ChatTemplateArgs string `json:"chat_template_args,omitempty"` // JSON string
// Sampling defaults
Temp float64 `json:"temp,omitempty"` // Note: MLX uses "temp" not "temperature"
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
MinP float64 `json:"min_p,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
}
// UnmarshalJSON decodes MLX server options from JSON. Besides the canonical
// snake_case keys declared in the struct tags it accepts a set of alternative
// spellings: hyphenated CLI-style names, "m" for the model and "temperature"
// for temp.
//
// The payload is first decoded with the standard rules, so a canonical key
// carrying a value of the wrong type is reported as an error. Alias keys are
// applied afterwards and therefore win over a canonical key present in the
// same payload. Alias values are converted leniently: integer and float
// fields accept a JSON number or a numeric string (a fractional number is
// truncated for integer fields), and an alias whose value cannot be converted
// is silently ignored.
func (o *MlxServerOptions) UnmarshalJSON(data []byte) error {
	// Untyped view of the payload, used only to look up alias keys.
	var payload map[string]any
	if err := json.Unmarshal(data, &payload); err != nil {
		return err
	}

	// Decode through a method-less alias type so that encoding/json does not
	// call back into this method.
	type plainOptions MlxServerOptions
	var decoded plainOptions
	if err := json.Unmarshal(data, &decoded); err != nil {
		return err
	}
	*o = MlxServerOptions(decoded)

	// Alias key -> canonical JSON tag name.
	aliases := map[string]string{
		// Basic connection options
		"m":           "model",
		"host":        "host",
		"port":        "port",
		"python_path": "python_path",

		// Model and adapter options
		"adapter-path":      "adapter_path",
		"draft-model":       "draft_model",
		"num-draft-tokens":  "num_draft_tokens",
		"trust-remote-code": "trust_remote_code",

		// Logging and templates
		"log-level":                 "log_level",
		"chat-template":             "chat_template",
		"use-default-chat-template": "use_default_chat_template",
		"chat-template-args":        "chat_template_args",

		// Sampling defaults
		"temperature": "temp", // Support both temp and temperature
		"top-p":       "top_p",
		"top-k":       "top_k",
		"min-p":       "min_p",
		"max-tokens":  "max_tokens",
	}

	target := reflect.ValueOf(o).Elem()
	targetType := target.Type()

	for alias, tag := range aliases {
		raw, present := payload[alias]
		if !present {
			continue
		}

		// Locate the struct field whose json tag names the canonical key.
		dst := target.FieldByNameFunc(func(name string) bool {
			sf, _ := targetType.FieldByName(name)
			jsonTag := sf.Tag.Get("json")
			return jsonTag == tag || jsonTag == tag+",omitempty"
		})
		if !dst.IsValid() || !dst.CanSet() {
			continue
		}

		switch dst.Kind() {
		case reflect.Int:
			switch n := raw.(type) {
			case float64:
				dst.SetInt(int64(n))
			case string:
				if parsed, err := strconv.Atoi(n); err == nil {
					dst.SetInt(int64(parsed))
				}
			}
		case reflect.Float64:
			switch f := raw.(type) {
			case float64:
				dst.SetFloat(f)
			case string:
				if parsed, err := strconv.ParseFloat(f, 64); err == nil {
					dst.SetFloat(parsed)
				}
			}
		case reflect.String:
			if s, ok := raw.(string); ok {
				dst.SetString(s)
			}
		case reflect.Bool:
			if b, ok := raw.(bool); ok {
				dst.SetBool(b)
			}
		}
	}

	return nil
}
// NewMlxServerOptions returns a freshly allocated MlxServerOptions populated
// with the values this package treats as the MLX server defaults.
//
// Temp, TopK and MinP default to zero (sampling temperature 0, top-k and
// min-p disabled), which is already the Go zero value, so they are not
// assigned explicitly.
func NewMlxServerOptions() *MlxServerOptions {
	opts := new(MlxServerOptions)

	// Connection: loopback host and port 8080 (differs from llama-server).
	opts.Host = "127.0.0.1"
	opts.Port = 8080

	// Speculative decoding.
	opts.NumDraftTokens = 3

	// Logging and templates; the template args default to an empty JSON object.
	opts.LogLevel = "INFO"
	opts.ChatTemplateArgs = "{}"

	// Sampling.
	opts.TopP = 1.0
	opts.MaxTokens = 512

	return opts
}
// BuildCommandArgs renders the options as a flat list of command-line
// arguments. Only options holding a non-zero value are emitted: strings and
// numbers become a "--flag value" pair and booleans become a bare "--flag".
// The result is nil when no option is set.
//
// PythonPath is deliberately not emitted here; per the original note it is
// handled by the execution logic in lifecycle.go.
func (o *MlxServerOptions) BuildCommandArgs() []string {
	var args []string

	// Small emitters, one per value kind, each skipping the zero value.
	text := func(name, v string) {
		if v != "" {
			args = append(args, name, v)
		}
	}
	integer := func(name string, v int) {
		if v != 0 {
			args = append(args, name, strconv.Itoa(v))
		}
	}
	decimal := func(name string, v float64) {
		if v != 0 {
			args = append(args, name, strconv.FormatFloat(v, 'f', -1, 64))
		}
	}
	toggle := func(name string, on bool) {
		if on {
			args = append(args, name)
		}
	}

	// Required and basic options
	text("--model", o.Model)
	text("--host", o.Host)
	integer("--port", o.Port)

	// Model and adapter options
	text("--adapter-path", o.AdapterPath)
	text("--draft-model", o.DraftModel)
	integer("--num-draft-tokens", o.NumDraftTokens)
	toggle("--trust-remote-code", o.TrustRemoteCode)

	// Logging and templates
	text("--log-level", o.LogLevel)
	text("--chat-template", o.ChatTemplate)
	toggle("--use-default-chat-template", o.UseDefaultChatTemplate)
	text("--chat-template-args", o.ChatTemplateArgs)

	// Sampling defaults
	decimal("--temp", o.Temp)
	decimal("--top-p", o.TopP)
	integer("--top-k", o.TopK)
	decimal("--min-p", o.MinP)
	integer("--max-tokens", o.MaxTokens)

	return args
}

254
pkg/backends/mlx/parser.go Normal file
View File

@@ -0,0 +1,254 @@
package mlx
import (
"encoding/json"
"fmt"
"path/filepath"
"regexp"
"strconv"
"strings"
)
// ParseMlxCommand parses a mlx_lm.server command string into MlxServerOptions.
//
// Supported input formats:
//  1. Full command: "mlx_lm.server --model model/path"
//  2. Full path: "/usr/local/bin/mlx_lm.server --model model/path"
//  3. Args only: "--model model/path --host 0.0.0.0"
//  4. Multiline commands with backslash line continuations
//
// Flags may be written as "--flag value", "--flag=value" or as a bare
// "--flag" (recorded as true). Hyphens in flag names are converted to
// underscores before the collected options are round-tripped through JSON
// into MlxServerOptions, so the final type checking is done by its decoder.
// Tokens that do not start with "-" and were not consumed as a flag value are
// ignored.
//
// An error is returned for an empty command, a tokenization failure, a flag
// with three or more leading dashes, an invalid --log-level value, or options
// that cannot be decoded into MlxServerOptions.
func ParseMlxCommand(command string) (*MlxServerOptions, error) {
	// 1. Normalize the command - handle multiline with backslashes
	trimmed := normalizeMultilineCommand(command)
	if trimmed == "" {
		return nil, fmt.Errorf("command cannot be empty")
	}

	// 2. Extract arguments from command
	args, err := extractArgumentsFromCommand(trimmed)
	if err != nil {
		return nil, err
	}

	// 3. Parse arguments into map
	options := make(map[string]any)
	for i := 0; i < len(args); i++ {
		arg := args[i]
		if !strings.HasPrefix(arg, "-") { // skip positional / stray values
			continue
		}
		// Reject malformed flags with more than two leading dashes
		// (e.g. ---model) to surface user mistakes.
		if strings.HasPrefix(arg, "---") {
			return nil, fmt.Errorf("malformed flag: %s", arg)
		}

		// --flag=value form; rawValue may legitimately be empty.
		rawFlag, rawValue, hasValue := strings.Cut(arg, "=")

		// --flag value form: consume the next token unless it is a flag.
		if !hasValue && i+1 < len(args) && !isFlag(args[i+1]) {
			i++
			rawValue = args[i]
			hasValue = true
		}

		flagCore := strings.TrimPrefix(strings.TrimPrefix(rawFlag, "-"), "-")
		flagName := strings.ReplaceAll(flagCore, "-", "_")

		if !hasValue {
			// Bare flag, e.g. --trust-remote-code.
			options[flagName] = true
			continue
		}

		value := parseValue(rawValue)
		if flagName == "log_level" {
			// Validate the unquoted value: tokens keep their surrounding
			// quotes, so checking rawValue would reject --log-level "INFO".
			level, ok := value.(string)
			if !ok || !isValidLogLevel(level) {
				return nil, fmt.Errorf("invalid log level: %s", rawValue)
			}
		}
		options[flagName] = value
	}

	// 4. Convert to MlxServerOptions using existing UnmarshalJSON
	jsonData, err := json.Marshal(options)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal parsed options: %w", err)
	}

	var mlxOptions MlxServerOptions
	if err := json.Unmarshal(jsonData, &mlxOptions); err != nil {
		return nil, fmt.Errorf("failed to parse command options: %w", err)
	}

	// 5. Return MlxServerOptions
	return &mlxOptions, nil
}
// isValidLogLevel reports whether level is one of the accepted MLX log
// levels. The comparison is exact, so the level must be upper-case.
func isValidLogLevel(level string) bool {
	switch level {
	case "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL":
		return true
	default:
		return false
	}
}
// parseValue converts a raw command-line value into the most specific type
// that fits it. One pair of matching surrounding quotes (single or double) is
// stripped first; the result is then returned as
//
//   - bool, for "true" / "false" in any letter case,
//   - int, when strconv.Atoi accepts it,
//   - float64, when it is a finite floating-point literal,
//   - string, otherwise.
//
// Words such as "inf", "nan" or "infinity" are kept as strings even though
// strconv.ParseFloat accepts them: the parsed options are later marshaled to
// JSON, which cannot represent non-finite numbers, so converting them would
// make the whole command fail to parse.
func parseValue(value string) any {
	// Surrounding matching quotes (single or double)
	if n := len(value); n >= 2 {
		first, last := value[0], value[n-1]
		if first == last && (first == '"' || first == '\'') {
			value = value[1 : n-1]
		}
	}

	switch strings.ToLower(value) {
	case "true":
		return true
	case "false":
		return false
	}

	if intVal, err := strconv.Atoi(value); err == nil {
		return intVal
	}

	// Every finite float literal contains a digit, while the spellings of
	// infinity and NaN do not; out-of-range literals are rejected by
	// ParseFloat itself with a range error.
	if strings.ContainsAny(value, "0123456789") {
		if floatVal, err := strconv.ParseFloat(value, 64); err == nil {
			return floatVal
		}
	}

	return value
}
// Patterns used by normalizeMultilineCommand, compiled once at package
// initialization instead of on every call.
var (
	// lineContinuationPattern matches a backslash followed by a newline,
	// together with any whitespace around the newline.
	lineContinuationPattern = regexp.MustCompile(`\\\s*\n\s*`)
	// whitespaceRunPattern matches any run of whitespace.
	whitespaceRunPattern = regexp.MustCompile(`\s+`)
)

// normalizeMultilineCommand flattens a possibly multi-line command into a
// single line: backslash line continuations are replaced by a space, every
// run of whitespace (including remaining newlines) is collapsed to a single
// space, and leading and trailing whitespace is removed.
func normalizeMultilineCommand(command string) string {
	joined := lineContinuationPattern.ReplaceAllString(command, " ")
	collapsed := whitespaceRunPattern.ReplaceAllString(joined, " ")
	return strings.TrimSpace(collapsed)
}
// extractArgumentsFromCommand tokenizes command and returns only its
// arguments, dropping a leading executable if there is one.
//
// The first token decides the format:
//   - it starts with "-": the command consists of arguments only and every
//     token is returned;
//   - anything else is taken to be the executable (a path, "mlx_lm.server",
//     or an unknown program name) and is dropped.
//
// The flag check must come first: in the args-only form the first token may
// itself contain a path separator (e.g. "--model=/models/foo"), and treating
// it as an executable path would silently discard that option.
//
// An error is returned when tokenization fails or yields no tokens.
func extractArgumentsFromCommand(command string) ([]string, error) {
	// Split command into tokens respecting quotes
	tokens, err := splitCommandTokens(command)
	if err != nil {
		return nil, err
	}
	if len(tokens) == 0 {
		return nil, fmt.Errorf("no command tokens found")
	}

	firstToken := tokens[0]
	switch {
	case strings.HasPrefix(firstToken, "-"):
		// Arguments only: keep every token.
		return tokens, nil
	case strings.Contains(firstToken, string(filepath.Separator)),
		strings.HasSuffix(filepath.Base(firstToken), "mlx_lm.server"),
		strings.ToLower(firstToken) == "mlx_lm.server":
		// Recognized executable, given as a path or as the bare command name.
		return tokens[1:], nil
	default:
		// Unknown format - might be a different executable name.
		// Be permissive and assume it is the executable.
		return tokens[1:], nil
	}
}
// splitCommandTokens breaks a command string into whitespace-separated
// tokens while keeping quoted sections together.
//
// Nothing is stripped from the tokens: quote characters and backslashes stay
// in the output exactly as written. A backslash makes the following byte
// literal (so an escaped space or quote neither ends a token nor toggles the
// quoting state), both inside and outside quotes. Space, tab and newline
// separate tokens outside quotes. An error is returned when a quote is still
// open at the end of the input.
func splitCommandTokens(command string) ([]string, error) {
	var (
		tokens  []string
		buf     strings.Builder
		quote   byte // active quote character; 0 while outside quotes
		escaped bool // previous byte was an unescaped backslash
	)

	// flush appends the token collected so far, if any.
	flush := func() {
		if buf.Len() > 0 {
			tokens = append(tokens, buf.String())
			buf.Reset()
		}
	}

	for _, ch := range []byte(command) {
		switch {
		case escaped:
			buf.WriteByte(ch)
			escaped = false
		case ch == '\\':
			buf.WriteByte(ch)
			escaped = true
		case quote != 0:
			// Inside quotes only the matching quote character is special.
			if ch == quote {
				quote = 0
			}
			buf.WriteByte(ch)
		case ch == '"' || ch == '\'':
			quote = ch
			buf.WriteByte(ch)
		case ch == ' ' || ch == '\t' || ch == '\n':
			flush()
		default:
			buf.WriteByte(ch)
		}
	}

	if quote != 0 {
		return nil, fmt.Errorf("unclosed quote in command")
	}
	flush()

	return tokens, nil
}
// isFlag reports whether s looks like a command-line flag, i.e. starts with
// "-". Negative numbers such as "-1" or "-0.5" are not flags: they are values
// belonging to the preceding option, and classifying them as flags would
// make "--temp -0.5" lose its value.
func isFlag(s string) bool {
	if !strings.HasPrefix(s, "-") {
		return false
	}
	// Only a dash followed by a digit or a decimal point can start a number,
	// which keeps flag names such as "-inf" or "-nan" classified as flags.
	if len(s) > 1 && (s[1] == '.' || (s[1] >= '0' && s[1] <= '9')) {
		if _, err := strconv.ParseFloat(s, 64); err == nil {
			return false
		}
	}
	return true
}

View File

@@ -10,9 +10,22 @@ import (
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
// BackendConfig contains backend executable configurations.
// It is stored under the "backends" key of AppConfig; the YAML key of each
// setting is given by its struct tag.
type BackendConfig struct {
// Path to llama-server executable (llama.cpp backend).
// Can be overridden with the LLAMACTL_LLAMA_EXECUTABLE environment variable.
LlamaExecutable string `yaml:"llama_executable"`
// Path to mlx_lm executable (MLX-LM backend).
// Can be overridden with the LLAMACTL_MLX_LM_EXECUTABLE environment variable.
MLXLMExecutable string `yaml:"mlx_lm_executable"`
// Optional: Default Python virtual environment path for MLX backends.
// Can be overridden with the LLAMACTL_MLX_PYTHON_PATH environment variable;
// the default configuration leaves it empty, meaning the system Python is used.
MLXPythonPath string `yaml:"mlx_python_path,omitempty"`
}
// AppConfig represents the configuration for llamactl // AppConfig represents the configuration for llamactl
type AppConfig struct { type AppConfig struct {
Server ServerConfig `yaml:"server"` Server ServerConfig `yaml:"server"`
Backends BackendConfig `yaml:"backends"`
Instances InstancesConfig `yaml:"instances"` Instances InstancesConfig `yaml:"instances"`
Auth AuthConfig `yaml:"auth"` Auth AuthConfig `yaml:"auth"`
Version string `yaml:"-"` Version string `yaml:"-"`
@@ -112,6 +125,11 @@ func LoadConfig(configPath string) (AppConfig, error) {
AllowedOrigins: []string{"*"}, // Default to allow all origins AllowedOrigins: []string{"*"}, // Default to allow all origins
EnableSwagger: false, EnableSwagger: false,
}, },
Backends: BackendConfig{
LlamaExecutable: "llama-server",
MLXLMExecutable: "mlx_lm.server",
MLXPythonPath: "", // Empty means use system Python
},
Instances: InstancesConfig{ Instances: InstancesConfig{
PortRange: [2]int{8000, 9000}, PortRange: [2]int{8000, 9000},
DataDir: getDefaultDataDirectory(), DataDir: getDefaultDataDirectory(),
@@ -229,8 +247,16 @@ func loadEnvVars(cfg *AppConfig) {
cfg.Instances.EnableLRUEviction = b cfg.Instances.EnableLRUEviction = b
} }
} }
// Backend config
if llamaExec := os.Getenv("LLAMACTL_LLAMA_EXECUTABLE"); llamaExec != "" { if llamaExec := os.Getenv("LLAMACTL_LLAMA_EXECUTABLE"); llamaExec != "" {
cfg.Instances.LlamaExecutable = llamaExec cfg.Backends.LlamaExecutable = llamaExec
cfg.Instances.LlamaExecutable = llamaExec // Keep for backward compatibility
}
if mlxLMExec := os.Getenv("LLAMACTL_MLX_LM_EXECUTABLE"); mlxLMExec != "" {
cfg.Backends.MLXLMExecutable = mlxLMExec
}
if mlxPython := os.Getenv("LLAMACTL_MLX_PYTHON_PATH"); mlxPython != "" {
cfg.Backends.MLXPythonPath = mlxPython
} }
if autoRestart := os.Getenv("LLAMACTL_DEFAULT_AUTO_RESTART"); autoRestart != "" { if autoRestart := os.Getenv("LLAMACTL_DEFAULT_AUTO_RESTART"); autoRestart != "" {
if b, err := strconv.ParseBool(autoRestart); err == nil { if b, err := strconv.ParseBool(autoRestart); err == nil {