From ec5485bd0e6ba40d72e41c4572e26eb20e636d65 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Fri, 19 Sep 2025 19:46:54 +0200 Subject: [PATCH] Refactor command argument building across backends --- pkg/backends/llamacpp/llama.go | 64 ++--------------- pkg/backends/mlx/mlx.go | 115 +++++++++++++++++++------------ pkg/backends/parser.go | 121 +++++++++++++++++++++++++++++++++ pkg/backends/vllm/vllm.go | 85 ++++------------------- 4 files changed, 209 insertions(+), 176 deletions(-) diff --git a/pkg/backends/llamacpp/llama.go b/pkg/backends/llamacpp/llama.go index c838141..cfad2fd 100644 --- a/pkg/backends/llamacpp/llama.go +++ b/pkg/backends/llamacpp/llama.go @@ -2,9 +2,9 @@ package llamacpp import ( "encoding/json" + "llamactl/pkg/backends" "reflect" "strconv" - "strings" ) type LlamaServerOptions struct { @@ -313,64 +313,10 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error { return nil } -// BuildCommandArgs converts InstanceOptions to command line arguments +// BuildCommandArgs converts InstanceOptions to command line arguments using the common builder func (o *LlamaServerOptions) BuildCommandArgs() []string { - var args []string - - v := reflect.ValueOf(o).Elem() - t := v.Type() - - for i := 0; i < v.NumField(); i++ { - field := v.Field(i) - fieldType := t.Field(i) - - // Skip unexported fields - if !field.CanInterface() { - continue - } - - // Get the JSON tag to determine the flag name - jsonTag := fieldType.Tag.Get("json") - if jsonTag == "" || jsonTag == "-" { - continue - } - - // Remove ",omitempty" from the tag - flagName := jsonTag - if commaIndex := strings.Index(jsonTag, ","); commaIndex != -1 { - flagName = jsonTag[:commaIndex] - } - - // Convert snake_case to kebab-case for CLI flags - flagName = strings.ReplaceAll(flagName, "_", "-") - - // Add the appropriate arguments based on field type and value - switch field.Kind() { - case reflect.Bool: - if field.Bool() { - args = append(args, "--"+flagName) - } - case reflect.Int: - if field.Int() != 0 { - args = append(args, "--"+flagName, strconv.FormatInt(field.Int(), 10)) - } - case reflect.Float64: - if field.Float() != 0 { - args = append(args, "--"+flagName, strconv.FormatFloat(field.Float(), 'f', -1, 64)) - } - case reflect.String: - if field.String() != "" { - args = append(args, "--"+flagName, field.String()) - } - case reflect.Slice: - if field.Type().Elem().Kind() == reflect.String { - // Handle []string fields - for j := 0; j < field.Len(); j++ { - args = append(args, "--"+flagName, field.Index(j).String()) - } - } - } + config := backends.ArgsBuilderConfig{ + SliceHandling: backends.SliceAsMultipleFlags, // Llama uses multiple flags for arrays } - - return args + return backends.BuildCommandArgs(o, config) } diff --git a/pkg/backends/mlx/mlx.go b/pkg/backends/mlx/mlx.go index 8527c7b..9b29010 100644 --- a/pkg/backends/mlx/mlx.go +++ b/pkg/backends/mlx/mlx.go @@ -1,9 +1,10 @@ package mlx import ( + "encoding/json" + "llamactl/pkg/backends" "reflect" "strconv" - "strings" ) type MlxServerOptions struct { @@ -32,57 +33,83 @@ type MlxServerOptions struct { MaxTokens int `json:"max_tokens,omitempty"` } -// BuildCommandArgs converts to command line arguments using reflection -func (o *MlxServerOptions) BuildCommandArgs() []string { - var args []string +// UnmarshalJSON implements custom JSON unmarshaling to support multiple field names +func (o *MlxServerOptions) UnmarshalJSON(data []byte) error { + // First unmarshal into a map to handle multiple field names + var raw map[string]any + if err := json.Unmarshal(data, &raw); err != nil { + return err + } - v := reflect.ValueOf(o).Elem() - t := v.Type() + // Create a temporary struct for standard unmarshaling + type tempOptions MlxServerOptions + temp := tempOptions{} - for i := 0; i < v.NumField(); i++ { - field := v.Field(i) - fieldType := t.Field(i) + // Standard unmarshal first + if err := json.Unmarshal(data, &temp); err != nil { + return err + } - // Skip unexported fields - if !field.CanInterface() { - continue - } + // Copy to our struct + *o = MlxServerOptions(temp) - // Get the JSON tag to determine the flag name - jsonTag := fieldType.Tag.Get("json") - if jsonTag == "" || jsonTag == "-" { - continue - } + // Handle alternative field names + fieldMappings := map[string]string{ + "m": "model", // -m, --model + "temperature": "temp", // --temperature vs --temp + "top_k": "top_k", // --top-k + "adapter_path": "adapter_path", // --adapter-path + } - // Remove ",omitempty" from the tag - flagName := jsonTag - if commaIndex := strings.Index(jsonTag, ","); commaIndex != -1 { - flagName = jsonTag[:commaIndex] - } + // Process alternative field names + for altName, canonicalName := range fieldMappings { + if value, exists := raw[altName]; exists { + // Use reflection to set the field value + v := reflect.ValueOf(o).Elem() + field := v.FieldByNameFunc(func(fieldName string) bool { + field, _ := v.Type().FieldByName(fieldName) + jsonTag := field.Tag.Get("json") + return jsonTag == canonicalName+",omitempty" || jsonTag == canonicalName + }) - // Convert snake_case to kebab-case for CLI flags - flagName = strings.ReplaceAll(flagName, "_", "-") - - // Add the appropriate arguments based on field type and value - switch field.Kind() { - case reflect.Bool: - if field.Bool() { - args = append(args, "--"+flagName) - } - case reflect.Int: - if field.Int() != 0 { - args = append(args, "--"+flagName, strconv.FormatInt(field.Int(), 10)) - } - case reflect.Float64: - if field.Float() != 0 { - args = append(args, "--"+flagName, strconv.FormatFloat(field.Float(), 'f', -1, 64)) - } - case reflect.String: - if field.String() != "" { - args = append(args, "--"+flagName, field.String()) + if field.IsValid() && field.CanSet() { + switch field.Kind() { + case reflect.Int: + if intVal, ok := value.(float64); ok { + field.SetInt(int64(intVal)) + } else if strVal, ok := value.(string); ok { + if intVal, err := strconv.Atoi(strVal); err == nil { + field.SetInt(int64(intVal)) + } + } + case reflect.Float64: + if floatVal, ok := value.(float64); ok { + field.SetFloat(floatVal) + } else if strVal, ok := value.(string); ok { + if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil { + field.SetFloat(floatVal) + } + } + case reflect.String: + if strVal, ok := value.(string); ok { + field.SetString(strVal) + } + case reflect.Bool: + if boolVal, ok := value.(bool); ok { + field.SetBool(boolVal) + } + } } } } - return args + return nil +} + +// BuildCommandArgs converts to command line arguments using the common builder +func (o *MlxServerOptions) BuildCommandArgs() []string { + config := backends.ArgsBuilderConfig{ + SliceHandling: backends.SliceAsMultipleFlags, // MLX doesn't currently have []string fields, but default to multiple flags + } + return backends.BuildCommandArgs(o, config) } diff --git a/pkg/backends/parser.go b/pkg/backends/parser.go index 5023a7e..0e34398 100644 --- a/pkg/backends/parser.go +++ b/pkg/backends/parser.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "path/filepath" + "reflect" "regexp" "strconv" "strings" @@ -308,3 +309,123 @@ func isFlag(arg string) bool { return true } + +// SliceHandling defines how []string fields should be handled when building command args +type SliceHandling int + +const ( + // SliceAsMultipleFlags creates multiple flags: --flag value1 --flag value2 + SliceAsMultipleFlags SliceHandling = iota + // SliceAsCommaSeparated creates single flag with comma-separated values: --flag value1,value2 + SliceAsCommaSeparated + // SliceAsMixed uses different strategies for different flags (requires configuration) + SliceAsMixed +) + +// ArgsBuilderConfig holds configuration for building command line arguments +type ArgsBuilderConfig struct { + // SliceHandling defines the default strategy for []string fields + SliceHandling SliceHandling + // MultipleFlags specifies which flags should use multiple instances when SliceHandling is SliceAsMixed + MultipleFlags map[string]struct{} +} + +// BuildCommandArgs converts a struct to command line arguments using reflection +func BuildCommandArgs(options any, config ArgsBuilderConfig) []string { + var args []string + + v := reflect.ValueOf(options).Elem() + t := v.Type() + + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + fieldType := t.Field(i) + + // Skip unexported fields + if !field.CanInterface() { + continue + } + + // Get the JSON tag to determine the flag name + jsonTag := fieldType.Tag.Get("json") + if jsonTag == "" || jsonTag == "-" { + continue + } + + // Remove ",omitempty" from the tag + flagName := jsonTag + if commaIndex := strings.Index(jsonTag, ","); commaIndex != -1 { + flagName = jsonTag[:commaIndex] + } + + // Convert snake_case to kebab-case for CLI flags + flagName = strings.ReplaceAll(flagName, "_", "-") + + // Add the appropriate arguments based on field type and value + switch field.Kind() { + case reflect.Bool: + if field.Bool() { + args = append(args, "--"+flagName) + } + case reflect.Int: + if field.Int() != 0 { + args = append(args, "--"+flagName, strconv.FormatInt(field.Int(), 10)) + } + case reflect.Float64: + if field.Float() != 0 { + args = append(args, "--"+flagName, strconv.FormatFloat(field.Float(), 'f', -1, 64)) + } + case reflect.String: + if field.String() != "" { + args = append(args, "--"+flagName, field.String()) + } + case reflect.Slice: + if field.Type().Elem().Kind() == reflect.String { + args = append(args, handleStringSlice(field, flagName, config)...) + } + } + } + + return args +} + +// handleStringSlice handles []string fields based on the configuration +func handleStringSlice(field reflect.Value, flagName string, config ArgsBuilderConfig) []string { + if field.Len() == 0 { + return nil + } + + var args []string + + switch config.SliceHandling { + case SliceAsMultipleFlags: + // Multiple flags: --flag value1 --flag value2 + for j := 0; j < field.Len(); j++ { + args = append(args, "--"+flagName, field.Index(j).String()) + } + case SliceAsCommaSeparated: + // Comma-separated: --flag value1,value2 + var values []string + for j := 0; j < field.Len(); j++ { + values = append(values, field.Index(j).String()) + } + args = append(args, "--"+flagName, strings.Join(values, ",")) + case SliceAsMixed: + // Check if this specific flag should use multiple instances + if _, useMultiple := config.MultipleFlags[flagName]; useMultiple { + // Multiple flags + for j := 0; j < field.Len(); j++ { + args = append(args, "--"+flagName, field.Index(j).String()) + } + } else { + // Comma-separated + var values []string + for j := 0; j < field.Len(); j++ { + values = append(values, field.Index(j).String()) + } + args = append(args, "--"+flagName, strings.Join(values, ",")) + } + } + + return args +} diff --git a/pkg/backends/vllm/vllm.go b/pkg/backends/vllm/vllm.go index 9aa865c..81e567d 100644 --- a/pkg/backends/vllm/vllm.go +++ b/pkg/backends/vllm/vllm.go @@ -1,9 +1,7 @@ package vllm import ( - "reflect" - "strconv" - "strings" + "llamactl/pkg/backends" ) type VllmServerOptions struct { @@ -132,77 +130,18 @@ type VllmServerOptions struct { OverrideKVCacheALIGNSize int `json:"override_kv_cache_align_size,omitempty"` } -// BuildCommandArgs converts VllmServerOptions to command line arguments +// BuildCommandArgs converts VllmServerOptions to command line arguments using the common builder // Note: This does NOT include the "serve" subcommand, that's handled at the instance level func (o *VllmServerOptions) BuildCommandArgs() []string { - var args []string - - v := reflect.ValueOf(o).Elem() - t := v.Type() - - for i := 0; i < v.NumField(); i++ { - field := v.Field(i) - fieldType := t.Field(i) - - // Skip unexported fields - if !field.CanInterface() { - continue - } - - // Get the JSON tag to determine the flag name - jsonTag := fieldType.Tag.Get("json") - if jsonTag == "" || jsonTag == "-" { - continue - } - - // Remove ",omitempty" from the tag - flagName := jsonTag - if commaIndex := strings.Index(jsonTag, ","); commaIndex != -1 { - flagName = jsonTag[:commaIndex] - } - - // Convert snake_case to kebab-case for CLI flags - flagName = strings.ReplaceAll(flagName, "_", "-") - - // Add the appropriate arguments based on field type and value - switch field.Kind() { - case reflect.Bool: - if field.Bool() { - args = append(args, "--"+flagName) - } - case reflect.Int: - if field.Int() != 0 { - args = append(args, "--"+flagName, strconv.FormatInt(field.Int(), 10)) - } - case reflect.Float64: - if field.Float() != 0 { - args = append(args, "--"+flagName, strconv.FormatFloat(field.Float(), 'f', -1, 64)) - } - case reflect.String: - if field.String() != "" { - args = append(args, "--"+flagName, field.String()) - } - case reflect.Slice: - if field.Type().Elem().Kind() == reflect.String { - // Handle []string fields - some are comma-separated, some use multiple flags - if flagName == "api-key" || flagName == "allowed-origins" || flagName == "allowed-methods" || flagName == "allowed-headers" || flagName == "middleware" { - // Multiple flags for these - for j := 0; j < field.Len(); j++ { - args = append(args, "--"+flagName, field.Index(j).String()) - } - } else { - // Comma-separated for others - if field.Len() > 0 { - var values []string - for j := 0; j < field.Len(); j++ { - values = append(values, field.Index(j).String()) - } - args = append(args, "--"+flagName, strings.Join(values, ",")) - } - } - } - } + config := backends.ArgsBuilderConfig{ + SliceHandling: backends.SliceAsMixed, + MultipleFlags: map[string]struct{}{ + "api-key": {}, + "allowed-origins": {}, + "allowed-methods": {}, + "allowed-headers": {}, + "middleware": {}, + }, } - - return args + return backends.BuildCommandArgs(o, config) }