From ea6c76cc96ef5c58cd7f59ec8e8d05f14614ffed Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sat, 25 Oct 2025 19:02:46 +0200 Subject: [PATCH] Update multi valued flags in backends --- pkg/backends/builder.go | 12 +++++++----- pkg/backends/llama.go | 29 ++++++++++------------------- pkg/backends/mlx.go | 6 +++--- pkg/backends/parser.go | 6 +++--- pkg/backends/vllm.go | 27 +++++++++++---------------- pkg/backends/vllm_test.go | 35 +++++++++++++++++++++++++++++++++++ 6 files changed, 69 insertions(+), 46 deletions(-) diff --git a/pkg/backends/builder.go b/pkg/backends/builder.go index d5b5c0c..d224742 100644 --- a/pkg/backends/builder.go +++ b/pkg/backends/builder.go @@ -9,7 +9,7 @@ import ( ) // BuildCommandArgs converts a struct to command line arguments -func BuildCommandArgs(options any, multipleFlags map[string]bool) []string { +func BuildCommandArgs(options any, multipleFlags map[string]struct{}) []string { var args []string v := reflect.ValueOf(options).Elem() @@ -28,9 +28,10 @@ func BuildCommandArgs(options any, multipleFlags map[string]bool) []string { continue } - // Get flag name from JSON tag - flagName := strings.Split(jsonTag, ",")[0] - flagName = strings.ReplaceAll(flagName, "_", "-") + // Get flag name from JSON tag (snake_case) + jsonFieldName := strings.Split(jsonTag, ",")[0] + // Convert to kebab-case for CLI flags + flagName := strings.ReplaceAll(jsonFieldName, "_", "-") switch field.Kind() { case reflect.Bool: @@ -51,7 +52,8 @@ func BuildCommandArgs(options any, multipleFlags map[string]bool) []string { } case reflect.Slice: if field.Type().Elem().Kind() == reflect.String && field.Len() > 0 { - if multipleFlags[flagName] { + // Use jsonFieldName (snake_case) for multipleFlags lookup + if _, isMultiValue := multipleFlags[jsonFieldName]; isMultiValue { // Multiple flags: --flag value1 --flag value2 for j := 0; j < field.Len(); j++ { args = append(args, "--"+flagName, field.Index(j).String()) diff --git a/pkg/backends/llama.go b/pkg/backends/llama.go index bb0f205..2b3372a 100644 --- a/pkg/backends/llama.go +++ b/pkg/backends/llama.go @@ -9,25 +9,16 @@ import ( ) // llamaMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated -// Used for both parsing (with underscores) and building (with dashes) -var llamaMultiValuedFlags = map[string]bool{ - // Parsing keys (with underscores) - "override_tensor": true, - "override_kv": true, - "lora": true, - "lora_scaled": true, - "control_vector": true, - "control_vector_scaled": true, - "dry_sequence_breaker": true, - "logit_bias": true, - // Building keys (with dashes) - "override-tensor": true, - "override-kv": true, - "lora-scaled": true, - "control-vector": true, - "control-vector-scaled": true, - "dry-sequence-breaker": true, - "logit-bias": true, +// Keys use snake_case as the parser converts kebab-case flags to snake_case before lookup +var llamaMultiValuedFlags = map[string]struct{}{ + "override_tensor": {}, + "override_kv": {}, + "lora": {}, + "lora_scaled": {}, + "control_vector": {}, + "control_vector_scaled": {}, + "dry_sequence_breaker": {}, + "logit_bias": {}, } type LlamaServerOptions struct { diff --git a/pkg/backends/mlx.go b/pkg/backends/mlx.go index 4b70e3c..8911d0b 100644 --- a/pkg/backends/mlx.go +++ b/pkg/backends/mlx.go @@ -62,7 +62,7 @@ func (o *MlxServerOptions) Validate() error { // BuildCommandArgs converts to command line arguments func (o *MlxServerOptions) BuildCommandArgs() []string { - multipleFlags := map[string]bool{} // MLX doesn't currently have []string fields + multipleFlags := map[string]struct{}{} // MLX doesn't currently have []string fields return BuildCommandArgs(o, multipleFlags) } @@ -78,8 +78,8 @@ func (o *MlxServerOptions) BuildDockerArgs() []string { // 4. Multiline commands with backslashes func (o *MlxServerOptions) ParseCommand(command string) (any, error) { executableNames := []string{"mlx_lm.server"} - var subcommandNames []string // MLX has no subcommands - multiValuedFlags := map[string]bool{} // MLX has no multi-valued flags + var subcommandNames []string // MLX has no subcommands + multiValuedFlags := map[string]struct{}{} // MLX has no multi-valued flags var mlxOptions MlxServerOptions if err := parseCommand(command, executableNames, subcommandNames, multiValuedFlags, &mlxOptions); err != nil { diff --git a/pkg/backends/parser.go b/pkg/backends/parser.go index 7a73249..8208568 100644 --- a/pkg/backends/parser.go +++ b/pkg/backends/parser.go @@ -10,7 +10,7 @@ import ( ) // parseCommand parses a command string into a target struct -func parseCommand(command string, executableNames []string, subcommandNames []string, multiValuedFlags map[string]bool, target any) error { +func parseCommand(command string, executableNames []string, subcommandNames []string, multiValuedFlags map[string]struct{}, target any) error { // Normalize multiline commands command = normalizeCommand(command) if command == "" { @@ -125,7 +125,7 @@ func extractArgs(command string, executableNames []string, subcommandNames []str } // parseFlags parses command line flags into a map -func parseFlags(args []string, multiValuedFlags map[string]bool) (map[string]any, error) { +func parseFlags(args []string, multiValuedFlags map[string]struct{}) (map[string]any, error) { options := make(map[string]any) for i := 0; i < len(args); i++ { @@ -163,7 +163,7 @@ func parseFlags(args []string, multiValuedFlags map[string]bool) (map[string]any if hasValue { // Handle multi-valued flags - if multiValuedFlags[flagName] { + if _, isMultiValue := multiValuedFlags[flagName]; isMultiValue { if existing, ok := options[flagName].([]string); ok { options[flagName] = append(existing, value) } else { diff --git a/pkg/backends/vllm.go b/pkg/backends/vllm.go index a7f669d..34dce4c 100644 --- a/pkg/backends/vllm.go +++ b/pkg/backends/vllm.go @@ -6,12 +6,16 @@ import ( ) // vllmMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated -var vllmMultiValuedFlags = map[string]bool{ - "api-key": true, - "allowed-origins": true, - "allowed-methods": true, - "allowed-headers": true, - "middleware": true, +// Based on vLLM's CLI argument definitions with action='append' or List types +// Keys use snake_case as the parser converts kebab-case flags to snake_case before lookup +var vllmMultiValuedFlags = map[string]struct{}{ + "api_key": {}, // --api-key (action='append') + "allowed_origins": {}, // --allowed-origins (List type) + "allowed_methods": {}, // --allowed-methods (List type) + "allowed_headers": {}, // --allowed-headers (List type) + "middleware": {}, // --middleware (action='append') + "lora_modules": {}, // --lora-modules (custom LoRAParserAction, accepts multiple) + "prompt_adapters": {}, // --prompt-adapters (similar to lora-modules, accepts multiple) } type VllmServerOptions struct { @@ -212,18 +216,9 @@ func (o *VllmServerOptions) BuildDockerArgs() []string { func (o *VllmServerOptions) ParseCommand(command string) (any, error) { executableNames := []string{"vllm"} subcommandNames := []string{"serve"} - multiValuedFlags := map[string]bool{ - "middleware": true, - "api_key": true, - "allowed_origins": true, - "allowed_methods": true, - "allowed_headers": true, - "lora_modules": true, - "prompt_adapters": true, - } var vllmOptions VllmServerOptions - if err := parseCommand(command, executableNames, subcommandNames, multiValuedFlags, &vllmOptions); err != nil { + if err := parseCommand(command, executableNames, subcommandNames, vllmMultiValuedFlags, &vllmOptions); err != nil { return nil, err } diff --git a/pkg/backends/vllm_test.go b/pkg/backends/vllm_test.go index 9f93fd1..acec8d6 100644 --- a/pkg/backends/vllm_test.go +++ b/pkg/backends/vllm_test.go @@ -120,6 +120,41 @@ func TestParseVllmCommand(t *testing.T) { } } +func TestParseVllmCommandArrays(t *testing.T) { + command := "vllm serve test-model --middleware auth.py --middleware=cors.py --api-key key1 --api-key key2" + + var opts backends.VllmServerOptions + resultAny, err := opts.ParseCommand(command) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + result, ok := resultAny.(*backends.VllmServerOptions) + if !ok { + t.Fatalf("expected *VllmServerOptions, got %T", resultAny) + } + + expectedMiddleware := []string{"auth.py", "cors.py"} + if len(result.Middleware) != len(expectedMiddleware) { + t.Errorf("expected %d middleware items, got %d", len(expectedMiddleware), len(result.Middleware)) + } + for i, v := range expectedMiddleware { + if i >= len(result.Middleware) || result.Middleware[i] != v { + t.Errorf("expected middleware[%d]=%s got %s", i, v, result.Middleware[i]) + } + } + + expectedAPIKeys := []string{"key1", "key2"} + if len(result.APIKey) != len(expectedAPIKeys) { + t.Errorf("expected %d api keys, got %d", len(expectedAPIKeys), len(result.APIKey)) + } + for i, v := range expectedAPIKeys { + if i >= len(result.APIKey) || result.APIKey[i] != v { + t.Errorf("expected api_key[%d]=%s got %s", i, v, result.APIKey[i]) + } + } +} + func TestVllmBuildCommandArgs_BooleanFields(t *testing.T) { tests := []struct { name string