Refactor MLX and VLLM server options parsing and args building

This commit is contained in:
2025-09-19 19:39:36 +02:00
parent c7136d5206
commit 9eecb37aec
7 changed files with 382 additions and 869 deletions

View File

@@ -1,205 +1,88 @@
package mlx
import (
"encoding/json"
"reflect"
"strconv"
"strings"
)
type MlxServerOptions struct {
// Basic connection options
Model string `json:"model,omitempty"`
Host string `json:"host,omitempty"`
Port int `json:"port,omitempty"`
Model string `json:"model,omitempty"`
Host string `json:"host,omitempty"`
Port int `json:"port,omitempty"`
// Model and adapter options
AdapterPath string `json:"adapter_path,omitempty"`
DraftModel string `json:"draft_model,omitempty"`
NumDraftTokens int `json:"num_draft_tokens,omitempty"`
TrustRemoteCode bool `json:"trust_remote_code,omitempty"`
// Logging and templates
LogLevel string `json:"log_level,omitempty"`
ChatTemplate string `json:"chat_template,omitempty"`
UseDefaultChatTemplate bool `json:"use_default_chat_template,omitempty"`
ChatTemplateArgs string `json:"chat_template_args,omitempty"` // JSON string
LogLevel string `json:"log_level,omitempty"`
ChatTemplate string `json:"chat_template,omitempty"`
UseDefaultChatTemplate bool `json:"use_default_chat_template,omitempty"`
ChatTemplateArgs string `json:"chat_template_args,omitempty"` // JSON string
// Sampling defaults
Temp float64 `json:"temp,omitempty"` // Note: MLX uses "temp" not "temperature"
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
MinP float64 `json:"min_p,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temp float64 `json:"temp,omitempty"` // Note: MLX uses "temp" not "temperature"
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
MinP float64 `json:"min_p,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
}
// UnmarshalJSON implements custom JSON unmarshaling to support multiple field names
func (o *MlxServerOptions) UnmarshalJSON(data []byte) error {
// First unmarshal into a map to handle multiple field names
var raw map[string]any
if err := json.Unmarshal(data, &raw); err != nil {
return err
}
// BuildCommandArgs converts to command line arguments using reflection
func (o *MlxServerOptions) BuildCommandArgs() []string {
var args []string
// Create a temporary struct for standard unmarshaling
type tempOptions MlxServerOptions
temp := tempOptions{}
v := reflect.ValueOf(o).Elem()
t := v.Type()
// Standard unmarshal first
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
fieldType := t.Field(i)
// Copy to our struct
*o = MlxServerOptions(temp)
// Skip unexported fields
if !field.CanInterface() {
continue
}
// Handle alternative field names
fieldMappings := map[string]string{
// Basic connection options
"m": "model",
"host": "host",
"port": "port",
// "python_path": "python_path", // removed
// Model and adapter options
"adapter-path": "adapter_path",
"draft-model": "draft_model",
"num-draft-tokens": "num_draft_tokens",
"trust-remote-code": "trust_remote_code",
// Logging and templates
"log-level": "log_level",
"chat-template": "chat_template",
"use-default-chat-template": "use_default_chat_template",
"chat-template-args": "chat_template_args",
// Sampling defaults
"temperature": "temp", // Support both temp and temperature
"top-p": "top_p",
"top-k": "top_k",
"min-p": "min_p",
"max-tokens": "max_tokens",
}
// Get the JSON tag to determine the flag name
jsonTag := fieldType.Tag.Get("json")
if jsonTag == "" || jsonTag == "-" {
continue
}
// Process alternative field names
for altName, canonicalName := range fieldMappings {
if value, exists := raw[altName]; exists {
// Use reflection to set the field value
v := reflect.ValueOf(o).Elem()
field := v.FieldByNameFunc(func(fieldName string) bool {
field, _ := v.Type().FieldByName(fieldName)
jsonTag := field.Tag.Get("json")
return jsonTag == canonicalName+",omitempty" || jsonTag == canonicalName
})
// Remove ",omitempty" from the tag
flagName := jsonTag
if commaIndex := strings.Index(jsonTag, ","); commaIndex != -1 {
flagName = jsonTag[:commaIndex]
}
if field.IsValid() && field.CanSet() {
switch field.Kind() {
case reflect.Int:
if intVal, ok := value.(float64); ok {
field.SetInt(int64(intVal))
} else if strVal, ok := value.(string); ok {
if intVal, err := strconv.Atoi(strVal); err == nil {
field.SetInt(int64(intVal))
}
}
case reflect.Float64:
if floatVal, ok := value.(float64); ok {
field.SetFloat(floatVal)
} else if strVal, ok := value.(string); ok {
if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil {
field.SetFloat(floatVal)
}
}
case reflect.String:
if strVal, ok := value.(string); ok {
field.SetString(strVal)
}
case reflect.Bool:
if boolVal, ok := value.(bool); ok {
field.SetBool(boolVal)
}
}
// Convert snake_case to kebab-case for CLI flags
flagName = strings.ReplaceAll(flagName, "_", "-")
// Add the appropriate arguments based on field type and value
switch field.Kind() {
case reflect.Bool:
if field.Bool() {
args = append(args, "--"+flagName)
}
case reflect.Int:
if field.Int() != 0 {
args = append(args, "--"+flagName, strconv.FormatInt(field.Int(), 10))
}
case reflect.Float64:
if field.Float() != 0 {
args = append(args, "--"+flagName, strconv.FormatFloat(field.Float(), 'f', -1, 64))
}
case reflect.String:
if field.String() != "" {
args = append(args, "--"+flagName, field.String())
}
}
}
return nil
}
// NewMlxServerOptions creates MlxServerOptions with MLX defaults
func NewMlxServerOptions() *MlxServerOptions {
return &MlxServerOptions{
Host: "127.0.0.1", // MLX default (different from llama-server)
Port: 8080, // MLX default
NumDraftTokens: 3, // MLX default for speculative decoding
LogLevel: "INFO", // MLX default
Temp: 0.0, // MLX default
TopP: 1.0, // MLX default
TopK: 0, // MLX default (disabled)
MinP: 0.0, // MLX default (disabled)
MaxTokens: 512, // MLX default
ChatTemplateArgs: "{}", // MLX default (empty JSON object)
}
}
// BuildCommandArgs converts to command line arguments
func (o *MlxServerOptions) BuildCommandArgs() []string {
var args []string
// Required and basic options
if o.Model != "" {
args = append(args, "--model", o.Model)
}
if o.Host != "" {
args = append(args, "--host", o.Host)
}
if o.Port != 0 {
args = append(args, "--port", strconv.Itoa(o.Port))
}
// Model and adapter options
if o.AdapterPath != "" {
args = append(args, "--adapter-path", o.AdapterPath)
}
if o.DraftModel != "" {
args = append(args, "--draft-model", o.DraftModel)
}
if o.NumDraftTokens != 0 {
args = append(args, "--num-draft-tokens", strconv.Itoa(o.NumDraftTokens))
}
if o.TrustRemoteCode {
args = append(args, "--trust-remote-code")
}
// Logging and templates
if o.LogLevel != "" {
args = append(args, "--log-level", o.LogLevel)
}
if o.ChatTemplate != "" {
args = append(args, "--chat-template", o.ChatTemplate)
}
if o.UseDefaultChatTemplate {
args = append(args, "--use-default-chat-template")
}
if o.ChatTemplateArgs != "" {
args = append(args, "--chat-template-args", o.ChatTemplateArgs)
}
// Sampling defaults
if o.Temp != 0 {
args = append(args, "--temp", strconv.FormatFloat(o.Temp, 'f', -1, 64))
}
if o.TopP != 0 {
args = append(args, "--top-p", strconv.FormatFloat(o.TopP, 'f', -1, 64))
}
if o.TopK != 0 {
args = append(args, "--top-k", strconv.Itoa(o.TopK))
}
if o.MinP != 0 {
args = append(args, "--min-p", strconv.FormatFloat(o.MinP, 'f', -1, 64))
}
if o.MaxTokens != 0 {
args = append(args, "--max-tokens", strconv.Itoa(o.MaxTokens))
}
return args
}
}

View File

@@ -0,0 +1,62 @@
package mlx_test
import (
"llamactl/pkg/backends/mlx"
"testing"
)
func TestBuildCommandArgs(t *testing.T) {
options := &mlx.MlxServerOptions{
Model: "/test/model.mlx",
Host: "127.0.0.1",
Port: 8080,
Temp: 0.7,
TopP: 0.9,
TopK: 40,
MaxTokens: 2048,
TrustRemoteCode: true,
LogLevel: "DEBUG",
ChatTemplate: "custom template",
}
args := options.BuildCommandArgs()
// Check that all expected flags are present
expectedFlags := map[string]string{
"--model": "/test/model.mlx",
"--host": "127.0.0.1",
"--port": "8080",
"--log-level": "DEBUG",
"--chat-template": "custom template",
"--temp": "0.7",
"--top-p": "0.9",
"--top-k": "40",
"--max-tokens": "2048",
}
for i := 0; i < len(args); i++ {
if args[i] == "--trust-remote-code" {
continue // Boolean flag with no value
}
if args[i] == "--use-default-chat-template" {
continue // Boolean flag with no value
}
if expectedValue, exists := expectedFlags[args[i]]; exists && i+1 < len(args) {
if args[i+1] != expectedValue {
t.Errorf("expected %s to have value %s, got %s", args[i], expectedValue, args[i+1])
}
}
}
// Check boolean flags
foundTrustRemoteCode := false
for _, arg := range args {
if arg == "--trust-remote-code" {
foundTrustRemoteCode = true
}
}
if !foundTrustRemoteCode {
t.Errorf("expected --trust-remote-code flag to be present")
}
}

View File

@@ -0,0 +1,101 @@
package mlx_test
import (
"llamactl/pkg/backends/mlx"
"testing"
)
func TestParseMlxCommand(t *testing.T) {
tests := []struct {
name string
command string
expectErr bool
}{
{
name: "basic command",
command: "mlx_lm.server --model /path/to/model --host 0.0.0.0",
expectErr: false,
},
{
name: "args only",
command: "--model /path/to/model --port 8080",
expectErr: false,
},
{
name: "mixed flag formats",
command: "mlx_lm.server --model=/path/model --temp=0.7 --trust-remote-code",
expectErr: false,
},
{
name: "quoted strings",
command: `mlx_lm.server --model test.mlx --chat-template "User: {user}\nAssistant: "`,
expectErr: false,
},
{
name: "empty command",
command: "",
expectErr: true,
},
{
name: "unterminated quote",
command: `mlx_lm.server --model test.mlx --chat-template "unterminated`,
expectErr: true,
},
{
name: "malformed flag",
command: "mlx_lm.server ---model test.mlx",
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := mlx.ParseMlxCommand(tt.command)
if tt.expectErr {
if err == nil {
t.Errorf("expected error but got none")
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if result == nil {
t.Errorf("expected result but got nil")
}
})
}
}
func TestParseMlxCommandValues(t *testing.T) {
command := "mlx_lm.server --model /test/model.mlx --port 8080 --temp 0.7 --trust-remote-code --log-level DEBUG"
result, err := mlx.ParseMlxCommand(command)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Model != "/test/model.mlx" {
t.Errorf("expected model '/test/model.mlx', got '%s'", result.Model)
}
if result.Port != 8080 {
t.Errorf("expected port 8080, got %d", result.Port)
}
if result.Temp != 0.7 {
t.Errorf("expected temp 0.7, got %f", result.Temp)
}
if !result.TrustRemoteCode {
t.Errorf("expected trust_remote_code to be true")
}
if result.LogLevel != "DEBUG" {
t.Errorf("expected log_level 'DEBUG', got '%s'", result.LogLevel)
}
}