From 988c4aca409f8b4ec697394c24cf5c502de6b384 Mon Sep 17 00:00:00 2001
From: LordMathis
Date: Tue, 16 Sep 2025 21:14:19 +0200
Subject: [PATCH] Add MLX backend config options

---
 pkg/backends/backend.go    |   2 +
 pkg/backends/mlx/mlx.go    | 208 ++++++++++++++++++++++++++++++
 pkg/backends/mlx/parser.go | 254 +++++++++++++++++++++++++++++++++++++
 pkg/config/config.go       |  28 +++-
 4 files changed, 491 insertions(+), 1 deletion(-)
 create mode 100644 pkg/backends/mlx/mlx.go
 create mode 100644 pkg/backends/mlx/parser.go

diff --git a/pkg/backends/backend.go b/pkg/backends/backend.go
index c28a2cc..0270945 100644
--- a/pkg/backends/backend.go
+++ b/pkg/backends/backend.go
@@ -4,4 +4,6 @@ type BackendType string
 const (
 	BackendTypeLlamaCpp BackendType = "llama_cpp"
+	BackendTypeMlxLm    BackendType = "mlx_lm"
+	// BackendTypeMlxVlm BackendType = "mlx_vlm" // Future expansion
 )
diff --git a/pkg/backends/mlx/mlx.go b/pkg/backends/mlx/mlx.go
new file mode 100644
index 0000000..06f3128
--- /dev/null
+++ b/pkg/backends/mlx/mlx.go
@@ -0,0 +1,208 @@
+package mlx
+
+import (
+	"encoding/json"
+	"reflect"
+	"strconv"
+)
+
+// MlxServerOptions holds the command-line options accepted by mlx_lm.server.
+// JSON tags use the canonical snake_case names; UnmarshalJSON additionally
+// accepts CLI-style dash-separated aliases.
+type MlxServerOptions struct {
+	// Basic connection options
+	Model      string `json:"model,omitempty"`
+	Host       string `json:"host,omitempty"`
+	Port       int    `json:"port,omitempty"`
+	PythonPath string `json:"python_path,omitempty"` // Custom: Python venv path
+
+	// Model and adapter options
+	AdapterPath     string `json:"adapter_path,omitempty"`
+	DraftModel      string `json:"draft_model,omitempty"`
+	NumDraftTokens  int    `json:"num_draft_tokens,omitempty"`
+	TrustRemoteCode bool   `json:"trust_remote_code,omitempty"`
+
+	// Logging and templates
+	LogLevel               string `json:"log_level,omitempty"`
+	ChatTemplate           string `json:"chat_template,omitempty"`
+	UseDefaultChatTemplate bool   `json:"use_default_chat_template,omitempty"`
+	ChatTemplateArgs       string `json:"chat_template_args,omitempty"` // JSON string
+
+	// Sampling defaults
+	Temp      float64 `json:"temp,omitempty"` // Note: MLX uses "temp" not "temperature"
+	TopP      float64 `json:"top_p,omitempty"`
+	TopK      int     `json:"top_k,omitempty"`
+	MinP      float64 `json:"min_p,omitempty"`
+	MaxTokens int     `json:"max_tokens,omitempty"`
+}
+
+// UnmarshalJSON implements custom JSON unmarshaling to support multiple field
+// names. Canonical snake_case keys are handled by the standard unmarshal;
+// alias keys (dash-separated CLI names, "m", "temperature") are then applied
+// on top via reflection. String-encoded numbers are tolerated for aliases;
+// unparseable alias values are silently ignored.
+func (o *MlxServerOptions) UnmarshalJSON(data []byte) error {
+	// First unmarshal into a map so alias keys are visible.
+	var raw map[string]any
+	if err := json.Unmarshal(data, &raw); err != nil {
+		return err
+	}
+
+	// Temporary alias type avoids infinite recursion into this method.
+	type tempOptions MlxServerOptions
+	temp := tempOptions{}
+
+	// Standard unmarshal first (handles all canonical keys).
+	if err := json.Unmarshal(data, &temp); err != nil {
+		return err
+	}
+
+	// Copy to our struct
+	*o = MlxServerOptions(temp)
+
+	// Alias -> canonical JSON tag. Identity entries (e.g. "host" -> "host")
+	// are intentionally omitted: canonical keys are already handled above,
+	// and a type-mismatched canonical value fails the standard unmarshal
+	// before this point, so identity mappings were unreachable dead code.
+	fieldMappings := map[string]string{
+		// Basic connection options
+		"m": "model",
+
+		// Model and adapter options
+		"adapter-path":      "adapter_path",
+		"draft-model":       "draft_model",
+		"num-draft-tokens":  "num_draft_tokens",
+		"trust-remote-code": "trust_remote_code",
+
+		// Logging and templates
+		"log-level":                 "log_level",
+		"chat-template":             "chat_template",
+		"use-default-chat-template": "use_default_chat_template",
+		"chat-template-args":        "chat_template_args",
+
+		// Sampling defaults
+		"temperature": "temp", // Support both temp and temperature
+		"top-p":       "top_p",
+		"top-k":       "top_k",
+		"min-p":       "min_p",
+		"max-tokens":  "max_tokens",
+	}
+
+	// Process alias field names present in the raw map.
+	for altName, canonicalName := range fieldMappings {
+		value, exists := raw[altName]
+		if !exists {
+			continue
+		}
+
+		// Locate the struct field whose json tag matches the canonical name.
+		v := reflect.ValueOf(o).Elem()
+		field := v.FieldByNameFunc(func(fieldName string) bool {
+			f, _ := v.Type().FieldByName(fieldName)
+			jsonTag := f.Tag.Get("json")
+			return jsonTag == canonicalName+",omitempty" || jsonTag == canonicalName
+		})
+		if !field.IsValid() || !field.CanSet() {
+			continue
+		}
+
+		switch field.Kind() {
+		case reflect.Int:
+			// JSON numbers arrive as float64; also accept numeric strings.
+			if intVal, ok := value.(float64); ok {
+				field.SetInt(int64(intVal))
+			} else if strVal, ok := value.(string); ok {
+				if intVal, err := strconv.Atoi(strVal); err == nil {
+					field.SetInt(int64(intVal))
+				}
+			}
+		case reflect.Float64:
+			if floatVal, ok := value.(float64); ok {
+				field.SetFloat(floatVal)
+			} else if strVal, ok := value.(string); ok {
+				if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil {
+					field.SetFloat(floatVal)
+				}
+			}
+		case reflect.String:
+			if strVal, ok := value.(string); ok {
+				field.SetString(strVal)
+			}
+		case reflect.Bool:
+			if boolVal, ok := value.(bool); ok {
+				field.SetBool(boolVal)
+			}
+		}
+	}
+
+	return nil
+}
+
+// NewMlxServerOptions creates MlxServerOptions with MLX defaults.
+func NewMlxServerOptions() *MlxServerOptions {
+	return &MlxServerOptions{
+		Host:             "127.0.0.1", // MLX default (different from llama-server)
+		Port:             8080,        // MLX default
+		NumDraftTokens:   3,           // MLX default for speculative decoding
+		LogLevel:         "INFO",      // MLX default
+		Temp:             0.0,         // MLX default
+		TopP:             1.0,         // MLX default
+		TopK:             0,           // MLX default (disabled)
+		MinP:             0.0,         // MLX default (disabled)
+		MaxTokens:        512,         // MLX default
+		ChatTemplateArgs: "{}",        // MLX default (empty JSON object)
+	}
+}
+
+// BuildCommandArgs converts the options to mlx_lm.server command-line
+// arguments. Zero-valued fields are skipped, so the server's own defaults
+// apply for anything left unset.
+func (o *MlxServerOptions) BuildCommandArgs() []string {
+	var args []string
+
+	// Note: PythonPath is handled in lifecycle.go execution logic
+
+	// Required and basic options
+	if o.Model != "" {
+		args = append(args, "--model", o.Model)
+	}
+	if o.Host != "" {
+		args = append(args, "--host", o.Host)
+	}
+	if o.Port != 0 {
+		args = append(args, "--port", strconv.Itoa(o.Port))
+	}
+
+	// Model and adapter options
+	if o.AdapterPath != "" {
+		args = append(args, "--adapter-path", o.AdapterPath)
+	}
+	if o.DraftModel != "" {
+		args = append(args, "--draft-model", o.DraftModel)
+	}
+	if o.NumDraftTokens != 0 {
+		args = append(args, "--num-draft-tokens", strconv.Itoa(o.NumDraftTokens))
+	}
+	if o.TrustRemoteCode {
+		args = append(args, "--trust-remote-code")
+	}
+
+	// Logging and templates
+	if o.LogLevel != "" {
+		args = append(args, "--log-level", o.LogLevel)
+	}
+	if o.ChatTemplate != "" {
+		args = append(args, "--chat-template", o.ChatTemplate)
+	}
+	if o.UseDefaultChatTemplate {
+		args = append(args, "--use-default-chat-template")
+	}
+	if o.ChatTemplateArgs != "" {
+		args = append(args, "--chat-template-args", o.ChatTemplateArgs)
+	}
+
+	// Sampling defaults
+	if o.Temp != 0 {
+		args = append(args, "--temp", strconv.FormatFloat(o.Temp, 'f', -1, 64))
+	}
+	if o.TopP != 0 {
+		args = append(args, "--top-p", strconv.FormatFloat(o.TopP, 'f', -1, 64))
+	}
+	if o.TopK != 0 {
+		args = append(args, "--top-k", strconv.Itoa(o.TopK))
+	}
+	if o.MinP != 0 {
+		args = append(args, "--min-p", strconv.FormatFloat(o.MinP, 'f', -1, 64))
+	}
+	if o.MaxTokens != 0 {
+		args = append(args, "--max-tokens", strconv.Itoa(o.MaxTokens))
+	}
+
+	return args
+}
diff --git a/pkg/backends/mlx/parser.go b/pkg/backends/mlx/parser.go
new file mode 100644
index 0000000..96b04a9
--- /dev/null
+++ b/pkg/backends/mlx/parser.go
@@ -0,0 +1,254 @@
+package mlx
+
+import (
+	"encoding/json"
+	"fmt"
+	"path/filepath"
+	"regexp"
+	"strconv"
+	"strings"
+)
+
+// ParseMlxCommand parses a mlx_lm.server command string into MlxServerOptions
+// Supports multiple formats:
+// 1. Full command: "mlx_lm.server --model model/path"
+// 2. Full path: "/usr/local/bin/mlx_lm.server --model model/path"
+// 3. Args only: "--model model/path --host 0.0.0.0"
+// 4. Multiline commands with backslashes
+func ParseMlxCommand(command string) (*MlxServerOptions, error) {
+	// 1. Normalize the command - handle multiline with backslashes
+	trimmed := normalizeMultilineCommand(command)
+	if trimmed == "" {
+		return nil, fmt.Errorf("command cannot be empty")
+	}
+
+	// 2. Extract arguments from command
+	args, err := extractArgumentsFromCommand(trimmed)
+	if err != nil {
+		return nil, err
+	}
+
+	// 3. Parse arguments into map
+	options := make(map[string]any)
+
+	i := 0
+	for i < len(args) {
+		arg := args[i]
+
+		if !strings.HasPrefix(arg, "-") { // skip positional / stray values
+			i++
+			continue
+		}
+
+		// Reject malformed flags with more than two leading dashes (e.g. ---model) to surface user mistakes
+		if strings.HasPrefix(arg, "---") {
+			return nil, fmt.Errorf("malformed flag: %s", arg)
+		}
+
+		// Unified parsing for --flag=value vs --flag value
+		var rawFlag, rawValue string
+		hasEquals := false
+		if strings.Contains(arg, "=") {
+			parts := strings.SplitN(arg, "=", 2)
+			rawFlag = parts[0]
+			rawValue = parts[1] // may be empty string
+			hasEquals = true
+		} else {
+			rawFlag = arg
+		}
+
+		flagCore := strings.TrimPrefix(strings.TrimPrefix(rawFlag, "-"), "-")
+		flagName := strings.ReplaceAll(flagCore, "-", "_")
+
+		// Detect a space-separated value when not in equals form. Track
+		// explicitly whether the next token was consumed so the index
+		// advance below cannot drift out of sync with value detection.
+		valueProvided := hasEquals
+		consumedNext := false
+		if !hasEquals && i+1 < len(args) && !isFlag(args[i+1]) {
+			rawValue = args[i+1]
+			valueProvided = true
+			consumedNext = true
+		}
+
+		if valueProvided {
+			// MLX-specific validation for certain flags
+			if flagName == "log_level" && !isValidLogLevel(rawValue) {
+				return nil, fmt.Errorf("invalid log level: %s", rawValue)
+			}
+
+			options[flagName] = parseValue(rawValue)
+
+			if consumedNext {
+				i += 2
+			} else {
+				i++
+			}
+			continue
+		}
+
+		// Flag with no value: treat as a boolean switch
+		// (e.g. --trust-remote-code, --use-default-chat-template).
+		options[flagName] = true
+		i++
+	}
+
+	// 4. Convert to MlxServerOptions using existing UnmarshalJSON
+	jsonData, err := json.Marshal(options)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal parsed options: %w", err)
+	}
+
+	var mlxOptions MlxServerOptions
+	if err := json.Unmarshal(jsonData, &mlxOptions); err != nil {
+		return nil, fmt.Errorf("failed to parse command options: %w", err)
+	}
+
+	// 5. Return MlxServerOptions
+	return &mlxOptions, nil
+}
+
+// isValidLogLevel validates MLX log levels.
+func isValidLogLevel(level string) bool {
+	validLevels := []string{"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
+	for _, valid := range validLevels {
+		if level == valid {
+			return true
+		}
+	}
+	return false
+}
+
+// parseValue attempts to parse a string value into the most appropriate type:
+// bool, int, float64, or (after stripping surrounding quotes) string.
+func parseValue(value string) any {
+	// Strip surrounding matching quotes (single or double)
+	if l := len(value); l >= 2 {
+		if (value[0] == '"' && value[l-1] == '"') || (value[0] == '\'' && value[l-1] == '\'') {
+			value = value[1 : l-1]
+		}
+	}
+
+	lower := strings.ToLower(value)
+	if lower == "true" {
+		return true
+	}
+	if lower == "false" {
+		return false
+	}
+
+	if intVal, err := strconv.Atoi(value); err == nil {
+		return intVal
+	}
+	if floatVal, err := strconv.ParseFloat(value, 64); err == nil {
+		return floatVal
+	}
+	return value
+}
+
+// normalizeMultilineCommand handles multiline commands with backslashes.
+func normalizeMultilineCommand(command string) string {
+	// Handle escaped newlines (backslash followed by newline)
+	re := regexp.MustCompile(`\\\s*\n\s*`)
+	normalized := re.ReplaceAllString(command, " ")
+
+	// Clean up extra whitespace
+	re = regexp.MustCompile(`\s+`)
+	normalized = re.ReplaceAllString(normalized, " ")
+
+	return strings.TrimSpace(normalized)
+}
+
+// extractArgumentsFromCommand extracts arguments from various command formats.
+func extractArgumentsFromCommand(command string) ([]string, error) {
+	// Split command into tokens respecting quotes
+	tokens, err := splitCommandTokens(command)
+	if err != nil {
+		return nil, err
+	}
+
+	if len(tokens) == 0 {
+		return nil, fmt.Errorf("no command tokens found")
+	}
+
+	// Check if first token looks like an executable
+	firstToken := tokens[0]
+
+	// Case 1: Full path to executable (contains path separator or ends with mlx_lm.server)
+	if strings.Contains(firstToken, string(filepath.Separator)) ||
+		strings.HasSuffix(filepath.Base(firstToken), "mlx_lm.server") {
+		return tokens[1:], nil // Return everything except the executable
+	}
+
+	// Case 2: Just "mlx_lm.server" command
+	if strings.ToLower(firstToken) == "mlx_lm.server" {
+		return tokens[1:], nil // Return everything except the command
+	}
+
+	// Case 3: Arguments only (starts with a flag)
+	if strings.HasPrefix(firstToken, "-") {
+		return tokens, nil // Return all tokens as arguments
+	}
+
+	// Case 4: Unknown format - might be a different executable name
+	// Be permissive and assume it's the executable
+	return tokens[1:], nil
+}
+
+// splitCommandTokens splits a command string into tokens, respecting quotes.
+// Quote characters are kept in the tokens (parseValue strips them later);
+// backslash escapes the following character.
+func splitCommandTokens(command string) ([]string, error) {
+	var tokens []string
+	var current strings.Builder
+	inQuotes := false
+	quoteChar := byte(0)
+	escaped := false
+
+	for i := 0; i < len(command); i++ {
+		c := command[i]
+
+		if escaped {
+			current.WriteByte(c)
+			escaped = false
+			continue
+		}
+
+		if c == '\\' {
+			escaped = true
+			current.WriteByte(c)
+			continue
+		}
+
+		if !inQuotes && (c == '"' || c == '\'') {
+			inQuotes = true
+			quoteChar = c
+			current.WriteByte(c)
+		} else if inQuotes && c == quoteChar {
+			inQuotes = false
+			quoteChar = 0
+			current.WriteByte(c)
+		} else if !inQuotes && (c == ' ' || c == '\t' || c == '\n') {
+			if current.Len() > 0 {
+				tokens = append(tokens, current.String())
+				current.Reset()
+			}
+		} else {
+			current.WriteByte(c)
+		}
+	}
+
+	if inQuotes {
+		return nil, fmt.Errorf("unclosed quote in command")
+	}
+
+	if current.Len() > 0 {
+		tokens = append(tokens, current.String())
+	}
+
+	return tokens, nil
+}
+
+// isFlag reports whether s looks like a command-line flag. A leading dash
+// followed by a digit (e.g. "-1", "-0.5") is treated as a negative number
+// value, not a flag, so it can be consumed as a flag's value.
+func isFlag(s string) bool {
+	if !strings.HasPrefix(s, "-") {
+		return false
+	}
+	if len(s) > 1 && s[1] >= '0' && s[1] <= '9' {
+		return false
+	}
+	return true
+}
diff --git a/pkg/config/config.go b/pkg/config/config.go
index 5017662..1b873a5 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -10,9 +10,22 @@ import (
 	"gopkg.in/yaml.v3"
 )
 
+// BackendConfig contains backend executable configurations
+type BackendConfig struct {
+	// Path to llama-server executable (llama.cpp backend)
+	LlamaExecutable string `yaml:"llama_executable"`
+
+	// Path to mlx_lm executable (MLX-LM backend)
+	MLXLMExecutable string `yaml:"mlx_lm_executable"`
+
+	// Optional: Default Python virtual environment path for MLX backends
+	MLXPythonPath string `yaml:"mlx_python_path,omitempty"`
+}
+
 // AppConfig represents the configuration for llamactl
 type AppConfig struct {
 	Server    ServerConfig    `yaml:"server"`
+	Backends  BackendConfig   `yaml:"backends"`
 	Instances InstancesConfig `yaml:"instances"`
 	Auth      AuthConfig      `yaml:"auth"`
 	Version   string          `yaml:"-"`
@@ -112,6 +125,11 @@ func LoadConfig(configPath string) (AppConfig, error) {
 			AllowedOrigins: []string{"*"}, // Default to allow all origins
 			EnableSwagger:  false,
 		},
+		Backends: BackendConfig{
+			LlamaExecutable: "llama-server",
+			MLXLMExecutable: "mlx_lm.server",
+			MLXPythonPath:   "", // Empty means use system Python
+		},
 		Instances: InstancesConfig{
 			PortRange: [2]int{8000, 9000},
 			DataDir:   getDefaultDataDirectory(),
@@ -229,8 +247,16 @@ func loadEnvVars(cfg *AppConfig) {
 			cfg.Instances.EnableLRUEviction = b
 		}
 	}
+	// Backend config
 	if llamaExec := os.Getenv("LLAMACTL_LLAMA_EXECUTABLE"); llamaExec != "" {
-		cfg.Instances.LlamaExecutable = llamaExec
+		cfg.Backends.LlamaExecutable = llamaExec
+		cfg.Instances.LlamaExecutable = llamaExec // Keep for backward compatibility
+	}
+	if mlxLMExec := os.Getenv("LLAMACTL_MLX_LM_EXECUTABLE"); mlxLMExec != "" {
+		cfg.Backends.MLXLMExecutable = mlxLMExec
+	}
+	if mlxPython := os.Getenv("LLAMACTL_MLX_PYTHON_PATH"); mlxPython != "" {
+		cfg.Backends.MLXPythonPath = mlxPython
 	}
 	if autoRestart := os.Getenv("LLAMACTL_DEFAULT_AUTO_RESTART"); autoRestart != "" {
 		if b, err := strconv.ParseBool(autoRestart); err == nil {