Add support for extra args for command parser

This commit is contained in:
2025-11-13 20:41:08 +01:00
parent ae5358ff65
commit 11bfe75a3c
7 changed files with 421 additions and 30 deletions

View File

@@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"llamactl/pkg/validation" "llamactl/pkg/validation"
"reflect" "reflect"
"strconv"
) )
// llamaMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated // llamaMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated
@@ -213,6 +212,15 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error {
// Copy to our struct // Copy to our struct
*o = LlamaServerOptions(temp) *o = LlamaServerOptions(temp)
// Track which fields we've processed
processedFields := make(map[string]bool)
// Get all known canonical field names from struct tags
knownFields := getKnownFieldNames(o)
for field := range knownFields {
processedFields[field] = true
}
// Handle alternative field names // Handle alternative field names
fieldMappings := map[string]string{ fieldMappings := map[string]string{
// Common params // Common params
@@ -273,8 +281,8 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error {
"rerank": "reranking", // --reranking "rerank": "reranking", // --reranking
"to": "timeout", // -to, --timeout N "to": "timeout", // -to, --timeout N
"sps": "slot_prompt_similarity", // -sps, --slot-prompt-similarity "sps": "slot_prompt_similarity", // -sps, --slot-prompt-similarity
"draft": "draft-max", // -draft, --draft-max N "draft": "draft_max", // -draft, --draft-max N
"draft_n": "draft-max", // --draft-n-max N "draft_n": "draft_max", // --draft-n-max N
"draft_n_min": "draft_min", // --draft-n-min N "draft_n_min": "draft_min", // --draft-n-min N
"cd": "ctx_size_draft", // -cd, --ctx-size-draft N "cd": "ctx_size_draft", // -cd, --ctx-size-draft N
"devd": "device_draft", // -devd, --device-draft "devd": "device_draft", // -devd, --device-draft
@@ -286,8 +294,10 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error {
"mv": "model_vocoder", // -mv, --model-vocoder FNAME "mv": "model_vocoder", // -mv, --model-vocoder FNAME
} }
// Process alternative field names // Process alternative field names and mark them as processed
for altName, canonicalName := range fieldMappings { for altName, canonicalName := range fieldMappings {
processedFields[altName] = true // Mark alternatives as known
if value, exists := raw[altName]; exists { if value, exists := raw[altName]; exists {
// Use reflection to set the field value // Use reflection to set the field value
v := reflect.ValueOf(o).Elem() v := reflect.ValueOf(o).Elem()
@@ -298,36 +308,21 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error {
}) })
if field.IsValid() && field.CanSet() { if field.IsValid() && field.CanSet() {
switch field.Kind() { setFieldValue(field, value)
case reflect.Int:
if intVal, ok := value.(float64); ok {
field.SetInt(int64(intVal))
} else if strVal, ok := value.(string); ok {
if intVal, err := strconv.Atoi(strVal); err == nil {
field.SetInt(int64(intVal))
}
}
case reflect.Float64:
if floatVal, ok := value.(float64); ok {
field.SetFloat(floatVal)
} else if strVal, ok := value.(string); ok {
if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil {
field.SetFloat(floatVal)
}
}
case reflect.String:
if strVal, ok := value.(string); ok {
field.SetString(strVal)
}
case reflect.Bool:
if boolVal, ok := value.(bool); ok {
field.SetBool(boolVal)
}
}
} }
} }
} }
// Collect unknown fields into ExtraArgs
if o.ExtraArgs == nil {
o.ExtraArgs = make(map[string]string)
}
for key, value := range raw {
if !processedFields[key] {
o.ExtraArgs[key] = fmt.Sprintf("%v", value)
}
}
return nil return nil
} }

View File

@@ -453,3 +453,99 @@ func TestLlamaCppBuildCommandArgs_ExtraArgs(t *testing.T) {
t.Error("Expected --log-file flag or value not found") t.Error("Expected --log-file flag or value not found")
} }
} }
func TestParseLlamaCommand_ExtraArgs(t *testing.T) {
tests := []struct {
name string
command string
expectErr bool
validate func(*testing.T, *backends.LlamaServerOptions)
}{
{
name: "extra args with known fields",
command: "llama-server --model /path/to/model.gguf --gpu-layers 32 --unknown-flag value --another-bool-flag",
expectErr: false,
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
if opts.Model != "/path/to/model.gguf" {
t.Errorf("expected model '/path/to/model.gguf', got '%s'", opts.Model)
}
if opts.GPULayers != 32 {
t.Errorf("expected gpu_layers 32, got %d", opts.GPULayers)
}
if opts.ExtraArgs == nil {
t.Fatal("expected extra_args to be non-nil")
}
if val, ok := opts.ExtraArgs["unknown_flag"]; !ok || val != "value" {
t.Errorf("expected extra_args[unknown_flag]='value', got '%s'", val)
}
if val, ok := opts.ExtraArgs["another_bool_flag"]; !ok || val != "true" {
t.Errorf("expected extra_args[another_bool_flag]='true', got '%s'", val)
}
},
},
{
name: "extra args with alternative field names",
command: "llama-server -m /model.gguf -ngl 16 --custom-arg test --new-feature",
expectErr: false,
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
// Check that alternative names worked for known fields
if opts.Model != "/model.gguf" {
t.Errorf("expected model '/model.gguf', got '%s'", opts.Model)
}
if opts.GPULayers != 16 {
t.Errorf("expected gpu_layers 16, got %d", opts.GPULayers)
}
// Check that unknown args went to ExtraArgs
if opts.ExtraArgs == nil {
t.Fatal("expected extra_args to be non-nil")
}
if val, ok := opts.ExtraArgs["custom_arg"]; !ok || val != "test" {
t.Errorf("expected extra_args[custom_arg]='test', got '%s'", val)
}
if val, ok := opts.ExtraArgs["new_feature"]; !ok || val != "true" {
t.Errorf("expected extra_args[new_feature]='true', got '%s'", val)
}
},
},
{
name: "only extra args",
command: "llama-server --experimental-feature --beta-mode enabled",
expectErr: false,
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
if opts.ExtraArgs == nil {
t.Fatal("expected extra_args to be non-nil")
}
if val, ok := opts.ExtraArgs["experimental_feature"]; !ok || val != "true" {
t.Errorf("expected extra_args[experimental_feature]='true', got '%s'", val)
}
if val, ok := opts.ExtraArgs["beta_mode"]; !ok || val != "enabled" {
t.Errorf("expected extra_args[beta_mode]='enabled', got '%s'", val)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var opts backends.LlamaServerOptions
result, err := opts.ParseCommand(tt.command)
if tt.expectErr && err == nil {
t.Error("expected error but got none")
return
}
if !tt.expectErr && err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if !tt.expectErr && tt.validate != nil {
llamaOpts, ok := result.(*backends.LlamaServerOptions)
if !ok {
t.Fatal("result is not *LlamaServerOptions")
}
tt.validate(t, llamaOpts)
}
})
}
}

View File

@@ -1,6 +1,7 @@
package backends package backends
import ( import (
"encoding/json"
"fmt" "fmt"
"llamactl/pkg/validation" "llamactl/pkg/validation"
) )
@@ -35,6 +36,42 @@ type MlxServerOptions struct {
ExtraArgs map[string]string `json:"extra_args,omitempty"` ExtraArgs map[string]string `json:"extra_args,omitempty"`
} }
// UnmarshalJSON implements custom JSON unmarshaling to collect unknown fields into ExtraArgs
func (o *MlxServerOptions) UnmarshalJSON(data []byte) error {
// First unmarshal into a map to capture all fields
var raw map[string]any
if err := json.Unmarshal(data, &raw); err != nil {
return err
}
// Create a temporary struct for standard unmarshaling
type tempOptions MlxServerOptions
temp := tempOptions{}
// Standard unmarshal first
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
// Copy to our struct
*o = MlxServerOptions(temp)
// Get all known canonical field names from struct tags
knownFields := getKnownFieldNames(o)
// Collect unknown fields into ExtraArgs
if o.ExtraArgs == nil {
o.ExtraArgs = make(map[string]string)
}
for key, value := range raw {
if !knownFields[key] {
o.ExtraArgs[key] = fmt.Sprintf("%v", value)
}
}
return nil
}
func (o *MlxServerOptions) GetPort() int { func (o *MlxServerOptions) GetPort() int {
return o.Port return o.Port
} }

View File

@@ -202,3 +202,75 @@ func TestMlxBuildCommandArgs_ZeroValues(t *testing.T) {
} }
} }
} }
func TestParseMlxCommand_ExtraArgs(t *testing.T) {
tests := []struct {
name string
command string
expectErr bool
validate func(*testing.T, *backends.MlxServerOptions)
}{
{
name: "extra args with known fields",
command: "mlx_lm.server --model /path/to/model --port 8080 --unknown-flag value --new-bool-flag",
expectErr: false,
validate: func(t *testing.T, opts *backends.MlxServerOptions) {
if opts.Model != "/path/to/model" {
t.Errorf("expected model '/path/to/model', got '%s'", opts.Model)
}
if opts.Port != 8080 {
t.Errorf("expected port 8080, got %d", opts.Port)
}
if opts.ExtraArgs == nil {
t.Fatal("expected extra_args to be non-nil")
}
if val, ok := opts.ExtraArgs["unknown_flag"]; !ok || val != "value" {
t.Errorf("expected extra_args[unknown_flag]='value', got '%s'", val)
}
if val, ok := opts.ExtraArgs["new_bool_flag"]; !ok || val != "true" {
t.Errorf("expected extra_args[new_bool_flag]='true', got '%s'", val)
}
},
},
{
name: "only extra args",
command: "mlx_lm.server --experimental-feature --custom-param test",
expectErr: false,
validate: func(t *testing.T, opts *backends.MlxServerOptions) {
if opts.ExtraArgs == nil {
t.Fatal("expected extra_args to be non-nil")
}
if val, ok := opts.ExtraArgs["experimental_feature"]; !ok || val != "true" {
t.Errorf("expected extra_args[experimental_feature]='true', got '%s'", val)
}
if val, ok := opts.ExtraArgs["custom_param"]; !ok || val != "test" {
t.Errorf("expected extra_args[custom_param]='test', got '%s'", val)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var opts backends.MlxServerOptions
result, err := opts.ParseCommand(tt.command)
if tt.expectErr && err == nil {
t.Error("expected error but got none")
return
}
if !tt.expectErr && err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if !tt.expectErr && tt.validate != nil {
mlxOpts, ok := result.(*backends.MlxServerOptions)
if !ok {
t.Fatal("result is not *MlxServerOptions")
}
tt.validate(t, mlxOpts)
}
})
}
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"path/filepath" "path/filepath"
"reflect"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
@@ -211,3 +212,65 @@ func parseValue(value string) any {
// Return as string // Return as string
return value return value
} }
// setFieldValue sets a field value using reflection, handling type conversions
// Used by UnmarshalJSON implementations to handle alternative field names
func setFieldValue(field reflect.Value, value any) {
switch field.Kind() {
case reflect.Int:
if intVal, ok := value.(float64); ok {
field.SetInt(int64(intVal))
} else if strVal, ok := value.(string); ok {
if intVal, err := strconv.Atoi(strVal); err == nil {
field.SetInt(int64(intVal))
}
}
case reflect.Float64:
if floatVal, ok := value.(float64); ok {
field.SetFloat(floatVal)
} else if strVal, ok := value.(string); ok {
if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil {
field.SetFloat(floatVal)
}
}
case reflect.String:
if strVal, ok := value.(string); ok {
field.SetString(strVal)
}
case reflect.Bool:
if boolVal, ok := value.(bool); ok {
field.SetBool(boolVal)
}
case reflect.Slice:
// Handle string slices
if field.Type().Elem().Kind() == reflect.String {
if slice, ok := value.([]any); ok {
strSlice := make([]string, 0, len(slice))
for _, v := range slice {
if s, ok := v.(string); ok {
strSlice = append(strSlice, s)
}
}
field.Set(reflect.ValueOf(strSlice))
}
}
}
}
// getKnownFieldNames extracts all known field names from struct json tags
// Used by UnmarshalJSON implementations to identify unknown fields for ExtraArgs
func getKnownFieldNames(v any) map[string]bool {
fields := make(map[string]bool)
t := reflect.TypeOf(v).Elem()
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
jsonTag := field.Tag.Get("json")
if jsonTag != "" && jsonTag != "-" {
// Handle "name,omitempty" format
name := strings.Split(jsonTag, ",")[0]
fields[name] = true
}
}
return fields
}

View File

@@ -1,6 +1,7 @@
package backends package backends
import ( import (
"encoding/json"
"fmt" "fmt"
"llamactl/pkg/validation" "llamactl/pkg/validation"
) )
@@ -148,6 +149,42 @@ type VllmServerOptions struct {
ExtraArgs map[string]string `json:"extra_args,omitempty"` ExtraArgs map[string]string `json:"extra_args,omitempty"`
} }
// UnmarshalJSON implements custom JSON unmarshaling to collect unknown fields into ExtraArgs
func (o *VllmServerOptions) UnmarshalJSON(data []byte) error {
// First unmarshal into a map to capture all fields
var raw map[string]any
if err := json.Unmarshal(data, &raw); err != nil {
return err
}
// Create a temporary struct for standard unmarshaling
type tempOptions VllmServerOptions
temp := tempOptions{}
// Standard unmarshal first
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
// Copy to our struct
*o = VllmServerOptions(temp)
// Get all known canonical field names from struct tags
knownFields := getKnownFieldNames(o)
// Collect unknown fields into ExtraArgs
if o.ExtraArgs == nil {
o.ExtraArgs = make(map[string]string)
}
for key, value := range raw {
if !knownFields[key] {
o.ExtraArgs[key] = fmt.Sprintf("%v", value)
}
}
return nil
}
func (o *VllmServerOptions) GetPort() int { func (o *VllmServerOptions) GetPort() int {
return o.Port return o.Port
} }

View File

@@ -321,3 +321,94 @@ func TestVllmBuildCommandArgs_PositionalModel(t *testing.T) {
t.Errorf("Expected --port 8080 not found in %v", args) t.Errorf("Expected --port 8080 not found in %v", args)
} }
} }
func TestParseVllmCommand_ExtraArgs(t *testing.T) {
tests := []struct {
name string
command string
expectErr bool
validate func(*testing.T, *backends.VllmServerOptions)
}{
{
name: "extra args with known fields",
command: "vllm serve llama-model --tensor-parallel-size 2 --unknown-flag value --new-bool-flag",
expectErr: false,
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
if opts.Model != "llama-model" {
t.Errorf("expected model 'llama-model', got '%s'", opts.Model)
}
if opts.TensorParallelSize != 2 {
t.Errorf("expected tensor_parallel_size 2, got %d", opts.TensorParallelSize)
}
if opts.ExtraArgs == nil {
t.Fatal("expected extra_args to be non-nil")
}
if val, ok := opts.ExtraArgs["unknown_flag"]; !ok || val != "value" {
t.Errorf("expected extra_args[unknown_flag]='value', got '%s'", val)
}
if val, ok := opts.ExtraArgs["new_bool_flag"]; !ok || val != "true" {
t.Errorf("expected extra_args[new_bool_flag]='true', got '%s'", val)
}
},
},
{
name: "only extra args",
command: "vllm serve model --experimental-feature --custom-param test",
expectErr: false,
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
if opts.ExtraArgs == nil {
t.Fatal("expected extra_args to be non-nil")
}
if val, ok := opts.ExtraArgs["experimental_feature"]; !ok || val != "true" {
t.Errorf("expected extra_args[experimental_feature]='true', got '%s'", val)
}
if val, ok := opts.ExtraArgs["custom_param"]; !ok || val != "test" {
t.Errorf("expected extra_args[custom_param]='test', got '%s'", val)
}
},
},
{
name: "extra args without model positional",
command: "vllm serve --model my-model --new-feature enabled --beta-flag",
expectErr: false,
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
if opts.Model != "my-model" {
t.Errorf("expected model 'my-model', got '%s'", opts.Model)
}
if opts.ExtraArgs == nil {
t.Fatal("expected extra_args to be non-nil")
}
if val, ok := opts.ExtraArgs["new_feature"]; !ok || val != "enabled" {
t.Errorf("expected extra_args[new_feature]='enabled', got '%s'", val)
}
if val, ok := opts.ExtraArgs["beta_flag"]; !ok || val != "true" {
t.Errorf("expected extra_args[beta_flag]='true', got '%s'", val)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var opts backends.VllmServerOptions
result, err := opts.ParseCommand(tt.command)
if tt.expectErr && err == nil {
t.Error("expected error but got none")
return
}
if !tt.expectErr && err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if !tt.expectErr && tt.validate != nil {
vllmOpts, ok := result.(*backends.VllmServerOptions)
if !ok {
t.Fatal("result is not *VllmServerOptions")
}
tt.validate(t, vllmOpts)
}
})
}
}