Refactor command argument building across backends

This commit is contained in:
2025-09-19 19:46:54 +02:00
parent 9eecb37aec
commit ec5485bd0e
4 changed files with 209 additions and 176 deletions

View File

@@ -1,9 +1,10 @@
package mlx
import (
"encoding/json"
"llamactl/pkg/backends"
"reflect"
"strconv"
"strings"
)
type MlxServerOptions struct {
@@ -32,57 +33,83 @@ type MlxServerOptions struct {
MaxTokens int `json:"max_tokens,omitempty"`
}
// BuildCommandArgs converts to command line arguments using reflection
func (o *MlxServerOptions) BuildCommandArgs() []string {
var args []string
// UnmarshalJSON implements custom JSON unmarshaling to support multiple field names
func (o *MlxServerOptions) UnmarshalJSON(data []byte) error {
// First unmarshal into a map to handle multiple field names
var raw map[string]any
if err := json.Unmarshal(data, &raw); err != nil {
return err
}
v := reflect.ValueOf(o).Elem()
t := v.Type()
// Create a temporary struct for standard unmarshaling
type tempOptions MlxServerOptions
temp := tempOptions{}
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
fieldType := t.Field(i)
// Standard unmarshal first
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
// Skip unexported fields
if !field.CanInterface() {
continue
}
// Copy to our struct
*o = MlxServerOptions(temp)
// Get the JSON tag to determine the flag name
jsonTag := fieldType.Tag.Get("json")
if jsonTag == "" || jsonTag == "-" {
continue
}
// Handle alternative field names
fieldMappings := map[string]string{
"m": "model", // -m, --model
"temperature": "temp", // --temperature vs --temp
"top_k": "top_k", // --top-k
"adapter_path": "adapter_path", // --adapter-path
}
// Remove ",omitempty" from the tag
flagName := jsonTag
if commaIndex := strings.Index(jsonTag, ","); commaIndex != -1 {
flagName = jsonTag[:commaIndex]
}
// Process alternative field names
for altName, canonicalName := range fieldMappings {
if value, exists := raw[altName]; exists {
// Use reflection to set the field value
v := reflect.ValueOf(o).Elem()
field := v.FieldByNameFunc(func(fieldName string) bool {
field, _ := v.Type().FieldByName(fieldName)
jsonTag := field.Tag.Get("json")
return jsonTag == canonicalName+",omitempty" || jsonTag == canonicalName
})
// Convert snake_case to kebab-case for CLI flags
flagName = strings.ReplaceAll(flagName, "_", "-")
// Add the appropriate arguments based on field type and value
switch field.Kind() {
case reflect.Bool:
if field.Bool() {
args = append(args, "--"+flagName)
}
case reflect.Int:
if field.Int() != 0 {
args = append(args, "--"+flagName, strconv.FormatInt(field.Int(), 10))
}
case reflect.Float64:
if field.Float() != 0 {
args = append(args, "--"+flagName, strconv.FormatFloat(field.Float(), 'f', -1, 64))
}
case reflect.String:
if field.String() != "" {
args = append(args, "--"+flagName, field.String())
if field.IsValid() && field.CanSet() {
switch field.Kind() {
case reflect.Int:
if intVal, ok := value.(float64); ok {
field.SetInt(int64(intVal))
} else if strVal, ok := value.(string); ok {
if intVal, err := strconv.Atoi(strVal); err == nil {
field.SetInt(int64(intVal))
}
}
case reflect.Float64:
if floatVal, ok := value.(float64); ok {
field.SetFloat(floatVal)
} else if strVal, ok := value.(string); ok {
if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil {
field.SetFloat(floatVal)
}
}
case reflect.String:
if strVal, ok := value.(string); ok {
field.SetString(strVal)
}
case reflect.Bool:
if boolVal, ok := value.(bool); ok {
field.SetBool(boolVal)
}
}
}
}
}
return args
return nil
}
// BuildCommandArgs converts to command line arguments using the common builder
func (o *MlxServerOptions) BuildCommandArgs() []string {
config := backends.ArgsBuilderConfig{
SliceHandling: backends.SliceAsMultipleFlags, // MLX doesn't currently have []string fields, but default to multiple flags
}
return backends.BuildCommandArgs(o, config)
}