Update multi valued flags in backends

This commit is contained in:
2025-10-25 19:02:46 +02:00
parent bd6436840e
commit ea6c76cc96
6 changed files with 69 additions and 46 deletions

View File

@@ -9,7 +9,7 @@ import (
)
// BuildCommandArgs converts a struct to command line arguments
func BuildCommandArgs(options any, multipleFlags map[string]bool) []string {
func BuildCommandArgs(options any, multipleFlags map[string]struct{}) []string {
var args []string
v := reflect.ValueOf(options).Elem()
@@ -28,9 +28,10 @@ func BuildCommandArgs(options any, multipleFlags map[string]bool) []string {
continue
}
// Get flag name from JSON tag
flagName := strings.Split(jsonTag, ",")[0]
flagName = strings.ReplaceAll(flagName, "_", "-")
// Get flag name from JSON tag (snake_case)
jsonFieldName := strings.Split(jsonTag, ",")[0]
// Convert to kebab-case for CLI flags
flagName := strings.ReplaceAll(jsonFieldName, "_", "-")
switch field.Kind() {
case reflect.Bool:
@@ -51,7 +52,8 @@ func BuildCommandArgs(options any, multipleFlags map[string]bool) []string {
}
case reflect.Slice:
if field.Type().Elem().Kind() == reflect.String && field.Len() > 0 {
if multipleFlags[flagName] {
// Use jsonFieldName (snake_case) for multipleFlags lookup
if _, isMultiValue := multipleFlags[jsonFieldName]; isMultiValue {
// Multiple flags: --flag value1 --flag value2
for j := 0; j < field.Len(); j++ {
args = append(args, "--"+flagName, field.Index(j).String())

View File

@@ -9,25 +9,16 @@ import (
)
// llamaMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated
// Used for both parsing (with underscores) and building (with dashes)
var llamaMultiValuedFlags = map[string]bool{
// Parsing keys (with underscores)
"override_tensor": true,
"override_kv": true,
"lora": true,
"lora_scaled": true,
"control_vector": true,
"control_vector_scaled": true,
"dry_sequence_breaker": true,
"logit_bias": true,
// Building keys (with dashes)
"override-tensor": true,
"override-kv": true,
"lora-scaled": true,
"control-vector": true,
"control-vector-scaled": true,
"dry-sequence-breaker": true,
"logit-bias": true,
// Keys use snake_case as the parser converts kebab-case flags to snake_case before lookup
var llamaMultiValuedFlags = map[string]struct{}{
"override_tensor": {},
"override_kv": {},
"lora": {},
"lora_scaled": {},
"control_vector": {},
"control_vector_scaled": {},
"dry_sequence_breaker": {},
"logit_bias": {},
}
type LlamaServerOptions struct {

View File

@@ -62,7 +62,7 @@ func (o *MlxServerOptions) Validate() error {
// BuildCommandArgs converts to command line arguments
func (o *MlxServerOptions) BuildCommandArgs() []string {
multipleFlags := map[string]bool{} // MLX doesn't currently have []string fields
multipleFlags := map[string]struct{}{} // MLX doesn't currently have []string fields
return BuildCommandArgs(o, multipleFlags)
}
@@ -79,7 +79,7 @@ func (o *MlxServerOptions) BuildDockerArgs() []string {
func (o *MlxServerOptions) ParseCommand(command string) (any, error) {
executableNames := []string{"mlx_lm.server"}
var subcommandNames []string // MLX has no subcommands
multiValuedFlags := map[string]bool{} // MLX has no multi-valued flags
multiValuedFlags := map[string]struct{}{} // MLX has no multi-valued flags
var mlxOptions MlxServerOptions
if err := parseCommand(command, executableNames, subcommandNames, multiValuedFlags, &mlxOptions); err != nil {

View File

@@ -10,7 +10,7 @@ import (
)
// parseCommand parses a command string into a target struct
func parseCommand(command string, executableNames []string, subcommandNames []string, multiValuedFlags map[string]bool, target any) error {
func parseCommand(command string, executableNames []string, subcommandNames []string, multiValuedFlags map[string]struct{}, target any) error {
// Normalize multiline commands
command = normalizeCommand(command)
if command == "" {
@@ -125,7 +125,7 @@ func extractArgs(command string, executableNames []string, subcommandNames []str
}
// parseFlags parses command line flags into a map
func parseFlags(args []string, multiValuedFlags map[string]bool) (map[string]any, error) {
func parseFlags(args []string, multiValuedFlags map[string]struct{}) (map[string]any, error) {
options := make(map[string]any)
for i := 0; i < len(args); i++ {
@@ -163,7 +163,7 @@ func parseFlags(args []string, multiValuedFlags map[string]bool) (map[string]any
if hasValue {
// Handle multi-valued flags
if multiValuedFlags[flagName] {
if _, isMultiValue := multiValuedFlags[flagName]; isMultiValue {
if existing, ok := options[flagName].([]string); ok {
options[flagName] = append(existing, value)
} else {

View File

@@ -6,12 +6,16 @@ import (
)
// vllmMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated
var vllmMultiValuedFlags = map[string]bool{
"api-key": true,
"allowed-origins": true,
"allowed-methods": true,
"allowed-headers": true,
"middleware": true,
// Based on vLLM's CLI argument definitions with action='append' or List types
// Keys use snake_case as the parser converts kebab-case flags to snake_case before lookup
var vllmMultiValuedFlags = map[string]struct{}{
"api_key": {}, // --api-key (action='append')
"allowed_origins": {}, // --allowed-origins (List type)
"allowed_methods": {}, // --allowed-methods (List type)
"allowed_headers": {}, // --allowed-headers (List type)
"middleware": {}, // --middleware (action='append')
"lora_modules": {}, // --lora-modules (custom LoRAParserAction, accepts multiple)
"prompt_adapters": {}, // --prompt-adapters (similar to lora-modules, accepts multiple)
}
type VllmServerOptions struct {
@@ -212,18 +216,9 @@ func (o *VllmServerOptions) BuildDockerArgs() []string {
func (o *VllmServerOptions) ParseCommand(command string) (any, error) {
executableNames := []string{"vllm"}
subcommandNames := []string{"serve"}
multiValuedFlags := map[string]bool{
"middleware": true,
"api_key": true,
"allowed_origins": true,
"allowed_methods": true,
"allowed_headers": true,
"lora_modules": true,
"prompt_adapters": true,
}
var vllmOptions VllmServerOptions
if err := parseCommand(command, executableNames, subcommandNames, multiValuedFlags, &vllmOptions); err != nil {
if err := parseCommand(command, executableNames, subcommandNames, vllmMultiValuedFlags, &vllmOptions); err != nil {
return nil, err
}

View File

@@ -120,6 +120,41 @@ func TestParseVllmCommand(t *testing.T) {
}
}
func TestParseVllmCommandArrays(t *testing.T) {
command := "vllm serve test-model --middleware auth.py --middleware=cors.py --api-key key1 --api-key key2"
var opts backends.VllmServerOptions
resultAny, err := opts.ParseCommand(command)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
result, ok := resultAny.(*backends.VllmServerOptions)
if !ok {
t.Fatalf("expected *VllmServerOptions, got %T", resultAny)
}
expectedMiddleware := []string{"auth.py", "cors.py"}
if len(result.Middleware) != len(expectedMiddleware) {
t.Errorf("expected %d middleware items, got %d", len(expectedMiddleware), len(result.Middleware))
}
for i, v := range expectedMiddleware {
if i >= len(result.Middleware) || result.Middleware[i] != v {
t.Errorf("expected middleware[%d]=%s got %s", i, v, result.Middleware[i])
}
}
expectedAPIKeys := []string{"key1", "key2"}
if len(result.APIKey) != len(expectedAPIKeys) {
t.Errorf("expected %d api keys, got %d", len(expectedAPIKeys), len(result.APIKey))
}
for i, v := range expectedAPIKeys {
if i >= len(result.APIKey) || result.APIKey[i] != v {
t.Errorf("expected api_key[%d]=%s got %s", i, v, result.APIKey[i])
}
}
}
func TestVllmBuildCommandArgs_BooleanFields(t *testing.T) {
tests := []struct {
name string