Refactor command argument building across backends

This commit is contained in:
2025-09-19 19:46:54 +02:00
parent 9eecb37aec
commit ec5485bd0e
4 changed files with 209 additions and 176 deletions

View File

@@ -2,9 +2,9 @@ package llamacpp
import ( import (
"encoding/json" "encoding/json"
"llamactl/pkg/backends"
"reflect" "reflect"
"strconv" "strconv"
"strings"
) )
type LlamaServerOptions struct { type LlamaServerOptions struct {
@@ -313,64 +313,10 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error {
return nil return nil
} }
// BuildCommandArgs converts InstanceOptions to command line arguments // BuildCommandArgs converts InstanceOptions to command line arguments using the common builder
func (o *LlamaServerOptions) BuildCommandArgs() []string { func (o *LlamaServerOptions) BuildCommandArgs() []string {
var args []string config := backends.ArgsBuilderConfig{
SliceHandling: backends.SliceAsMultipleFlags, // Llama uses multiple flags for arrays
v := reflect.ValueOf(o).Elem()
t := v.Type()
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
fieldType := t.Field(i)
// Skip unexported fields
if !field.CanInterface() {
continue
} }
return backends.BuildCommandArgs(o, config)
// Get the JSON tag to determine the flag name
jsonTag := fieldType.Tag.Get("json")
if jsonTag == "" || jsonTag == "-" {
continue
}
// Remove ",omitempty" from the tag
flagName := jsonTag
if commaIndex := strings.Index(jsonTag, ","); commaIndex != -1 {
flagName = jsonTag[:commaIndex]
}
// Convert snake_case to kebab-case for CLI flags
flagName = strings.ReplaceAll(flagName, "_", "-")
// Add the appropriate arguments based on field type and value
switch field.Kind() {
case reflect.Bool:
if field.Bool() {
args = append(args, "--"+flagName)
}
case reflect.Int:
if field.Int() != 0 {
args = append(args, "--"+flagName, strconv.FormatInt(field.Int(), 10))
}
case reflect.Float64:
if field.Float() != 0 {
args = append(args, "--"+flagName, strconv.FormatFloat(field.Float(), 'f', -1, 64))
}
case reflect.String:
if field.String() != "" {
args = append(args, "--"+flagName, field.String())
}
case reflect.Slice:
if field.Type().Elem().Kind() == reflect.String {
// Handle []string fields
for j := 0; j < field.Len(); j++ {
args = append(args, "--"+flagName, field.Index(j).String())
}
}
}
}
return args
} }

View File

@@ -1,9 +1,10 @@
package mlx package mlx
import ( import (
"encoding/json"
"llamactl/pkg/backends"
"reflect" "reflect"
"strconv" "strconv"
"strings"
) )
type MlxServerOptions struct { type MlxServerOptions struct {
@@ -32,57 +33,83 @@ type MlxServerOptions struct {
MaxTokens int `json:"max_tokens,omitempty"` MaxTokens int `json:"max_tokens,omitempty"`
} }
// BuildCommandArgs converts to command line arguments using reflection // UnmarshalJSON implements custom JSON unmarshaling to support multiple field names
func (o *MlxServerOptions) BuildCommandArgs() []string { func (o *MlxServerOptions) UnmarshalJSON(data []byte) error {
var args []string // First unmarshal into a map to handle multiple field names
var raw map[string]any
if err := json.Unmarshal(data, &raw); err != nil {
return err
}
// Create a temporary struct for standard unmarshaling
type tempOptions MlxServerOptions
temp := tempOptions{}
// Standard unmarshal first
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
// Copy to our struct
*o = MlxServerOptions(temp)
// Handle alternative field names
fieldMappings := map[string]string{
"m": "model", // -m, --model
"temperature": "temp", // --temperature vs --temp
"top_k": "top_k", // --top-k
"adapter_path": "adapter_path", // --adapter-path
}
// Process alternative field names
for altName, canonicalName := range fieldMappings {
if value, exists := raw[altName]; exists {
// Use reflection to set the field value
v := reflect.ValueOf(o).Elem() v := reflect.ValueOf(o).Elem()
t := v.Type() field := v.FieldByNameFunc(func(fieldName string) bool {
field, _ := v.Type().FieldByName(fieldName)
jsonTag := field.Tag.Get("json")
return jsonTag == canonicalName+",omitempty" || jsonTag == canonicalName
})
for i := 0; i < v.NumField(); i++ { if field.IsValid() && field.CanSet() {
field := v.Field(i)
fieldType := t.Field(i)
// Skip unexported fields
if !field.CanInterface() {
continue
}
// Get the JSON tag to determine the flag name
jsonTag := fieldType.Tag.Get("json")
if jsonTag == "" || jsonTag == "-" {
continue
}
// Remove ",omitempty" from the tag
flagName := jsonTag
if commaIndex := strings.Index(jsonTag, ","); commaIndex != -1 {
flagName = jsonTag[:commaIndex]
}
// Convert snake_case to kebab-case for CLI flags
flagName = strings.ReplaceAll(flagName, "_", "-")
// Add the appropriate arguments based on field type and value
switch field.Kind() { switch field.Kind() {
case reflect.Bool:
if field.Bool() {
args = append(args, "--"+flagName)
}
case reflect.Int: case reflect.Int:
if field.Int() != 0 { if intVal, ok := value.(float64); ok {
args = append(args, "--"+flagName, strconv.FormatInt(field.Int(), 10)) field.SetInt(int64(intVal))
} else if strVal, ok := value.(string); ok {
if intVal, err := strconv.Atoi(strVal); err == nil {
field.SetInt(int64(intVal))
}
} }
case reflect.Float64: case reflect.Float64:
if field.Float() != 0 { if floatVal, ok := value.(float64); ok {
args = append(args, "--"+flagName, strconv.FormatFloat(field.Float(), 'f', -1, 64)) field.SetFloat(floatVal)
} else if strVal, ok := value.(string); ok {
if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil {
field.SetFloat(floatVal)
}
} }
case reflect.String: case reflect.String:
if field.String() != "" { if strVal, ok := value.(string); ok {
args = append(args, "--"+flagName, field.String()) field.SetString(strVal)
}
case reflect.Bool:
if boolVal, ok := value.(bool); ok {
field.SetBool(boolVal)
}
}
} }
} }
} }
return args return nil
}
// BuildCommandArgs converts to command line arguments using the common builder
func (o *MlxServerOptions) BuildCommandArgs() []string {
config := backends.ArgsBuilderConfig{
SliceHandling: backends.SliceAsMultipleFlags, // MLX doesn't currently have []string fields, but default to multiple flags
}
return backends.BuildCommandArgs(o, config)
} }

View File

@@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"path/filepath" "path/filepath"
"reflect"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
@@ -308,3 +309,123 @@ func isFlag(arg string) bool {
return true return true
} }
// SliceHandling defines how []string fields should be handled when building command args
type SliceHandling int
const (
// SliceAsMultipleFlags creates multiple flags: --flag value1 --flag value2
SliceAsMultipleFlags SliceHandling = iota
// SliceAsCommaSeparated creates single flag with comma-separated values: --flag value1,value2
SliceAsCommaSeparated
// SliceAsMixed uses different strategies for different flags (requires configuration)
SliceAsMixed
)
// ArgsBuilderConfig holds configuration for building command line arguments
type ArgsBuilderConfig struct {
// SliceHandling defines the default strategy for []string fields
SliceHandling SliceHandling
// MultipleFlags specifies which flags should use multiple instances when SliceHandling is SliceAsMixed
MultipleFlags map[string]struct{}
}
// BuildCommandArgs converts a struct to command line arguments using reflection
func BuildCommandArgs(options any, config ArgsBuilderConfig) []string {
var args []string
v := reflect.ValueOf(options).Elem()
t := v.Type()
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
fieldType := t.Field(i)
// Skip unexported fields
if !field.CanInterface() {
continue
}
// Get the JSON tag to determine the flag name
jsonTag := fieldType.Tag.Get("json")
if jsonTag == "" || jsonTag == "-" {
continue
}
// Remove ",omitempty" from the tag
flagName := jsonTag
if commaIndex := strings.Index(jsonTag, ","); commaIndex != -1 {
flagName = jsonTag[:commaIndex]
}
// Convert snake_case to kebab-case for CLI flags
flagName = strings.ReplaceAll(flagName, "_", "-")
// Add the appropriate arguments based on field type and value
switch field.Kind() {
case reflect.Bool:
if field.Bool() {
args = append(args, "--"+flagName)
}
case reflect.Int:
if field.Int() != 0 {
args = append(args, "--"+flagName, strconv.FormatInt(field.Int(), 10))
}
case reflect.Float64:
if field.Float() != 0 {
args = append(args, "--"+flagName, strconv.FormatFloat(field.Float(), 'f', -1, 64))
}
case reflect.String:
if field.String() != "" {
args = append(args, "--"+flagName, field.String())
}
case reflect.Slice:
if field.Type().Elem().Kind() == reflect.String {
args = append(args, handleStringSlice(field, flagName, config)...)
}
}
}
return args
}
// handleStringSlice handles []string fields based on the configuration
func handleStringSlice(field reflect.Value, flagName string, config ArgsBuilderConfig) []string {
if field.Len() == 0 {
return nil
}
var args []string
switch config.SliceHandling {
case SliceAsMultipleFlags:
// Multiple flags: --flag value1 --flag value2
for j := 0; j < field.Len(); j++ {
args = append(args, "--"+flagName, field.Index(j).String())
}
case SliceAsCommaSeparated:
// Comma-separated: --flag value1,value2
var values []string
for j := 0; j < field.Len(); j++ {
values = append(values, field.Index(j).String())
}
args = append(args, "--"+flagName, strings.Join(values, ","))
case SliceAsMixed:
// Check if this specific flag should use multiple instances
if _, useMultiple := config.MultipleFlags[flagName]; useMultiple {
// Multiple flags
for j := 0; j < field.Len(); j++ {
args = append(args, "--"+flagName, field.Index(j).String())
}
} else {
// Comma-separated
var values []string
for j := 0; j < field.Len(); j++ {
values = append(values, field.Index(j).String())
}
args = append(args, "--"+flagName, strings.Join(values, ","))
}
}
return args
}

View File

@@ -1,9 +1,7 @@
package vllm package vllm
import ( import (
"reflect" "llamactl/pkg/backends"
"strconv"
"strings"
) )
type VllmServerOptions struct { type VllmServerOptions struct {
@@ -132,77 +130,18 @@ type VllmServerOptions struct {
OverrideKVCacheALIGNSize int `json:"override_kv_cache_align_size,omitempty"` OverrideKVCacheALIGNSize int `json:"override_kv_cache_align_size,omitempty"`
} }
// BuildCommandArgs converts VllmServerOptions to command line arguments // BuildCommandArgs converts VllmServerOptions to command line arguments using the common builder
// Note: This does NOT include the "serve" subcommand, that's handled at the instance level // Note: This does NOT include the "serve" subcommand, that's handled at the instance level
func (o *VllmServerOptions) BuildCommandArgs() []string { func (o *VllmServerOptions) BuildCommandArgs() []string {
var args []string config := backends.ArgsBuilderConfig{
SliceHandling: backends.SliceAsMixed,
v := reflect.ValueOf(o).Elem() MultipleFlags: map[string]struct{}{
t := v.Type() "api-key": {},
"allowed-origins": {},
for i := 0; i < v.NumField(); i++ { "allowed-methods": {},
field := v.Field(i) "allowed-headers": {},
fieldType := t.Field(i) "middleware": {},
},
// Skip unexported fields
if !field.CanInterface() {
continue
} }
return backends.BuildCommandArgs(o, config)
// Get the JSON tag to determine the flag name
jsonTag := fieldType.Tag.Get("json")
if jsonTag == "" || jsonTag == "-" {
continue
}
// Remove ",omitempty" from the tag
flagName := jsonTag
if commaIndex := strings.Index(jsonTag, ","); commaIndex != -1 {
flagName = jsonTag[:commaIndex]
}
// Convert snake_case to kebab-case for CLI flags
flagName = strings.ReplaceAll(flagName, "_", "-")
// Add the appropriate arguments based on field type and value
switch field.Kind() {
case reflect.Bool:
if field.Bool() {
args = append(args, "--"+flagName)
}
case reflect.Int:
if field.Int() != 0 {
args = append(args, "--"+flagName, strconv.FormatInt(field.Int(), 10))
}
case reflect.Float64:
if field.Float() != 0 {
args = append(args, "--"+flagName, strconv.FormatFloat(field.Float(), 'f', -1, 64))
}
case reflect.String:
if field.String() != "" {
args = append(args, "--"+flagName, field.String())
}
case reflect.Slice:
if field.Type().Elem().Kind() == reflect.String {
// Handle []string fields - some are comma-separated, some use multiple flags
if flagName == "api-key" || flagName == "allowed-origins" || flagName == "allowed-methods" || flagName == "allowed-headers" || flagName == "middleware" {
// Multiple flags for these
for j := 0; j < field.Len(); j++ {
args = append(args, "--"+flagName, field.Index(j).String())
}
} else {
// Comma-separated for others
if field.Len() > 0 {
var values []string
for j := 0; j < field.Len(); j++ {
values = append(values, field.Index(j).String())
}
args = append(args, "--"+flagName, strings.Join(values, ","))
}
}
}
}
}
return args
} }