diff --git a/pkg/backends/builder.go b/pkg/backends/builder.go index d224742..0d7402e 100644 --- a/pkg/backends/builder.go +++ b/pkg/backends/builder.go @@ -93,3 +93,22 @@ func BuildDockerCommand(backendConfig *config.BackendSettings, instanceArgs []st return "docker", dockerArgs, nil } + +// convertExtraArgsToFlags converts map[string]string to command flags +// Empty values become boolean flags: {"flag": ""} → ["--flag"] +// Non-empty values: {"flag": "value"} → ["--flag", "value"] +func convertExtraArgsToFlags(extraArgs map[string]string) []string { + var args []string + + for key, value := range extraArgs { + if value == "" { + // Boolean flag + args = append(args, "--"+key) + } else { + // Value flag + args = append(args, "--"+key, value) + } + } + + return args +} diff --git a/pkg/backends/llama.go b/pkg/backends/llama.go index 2b3372a..246f0fe 100644 --- a/pkg/backends/llama.go +++ b/pkg/backends/llama.go @@ -5,7 +5,6 @@ import ( "fmt" "llamactl/pkg/validation" "reflect" - "strconv" ) // llamaMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated @@ -41,7 +40,7 @@ type LlamaServerOptions struct { BatchSize int `json:"batch_size,omitempty"` UBatchSize int `json:"ubatch_size,omitempty"` Keep int `json:"keep,omitempty"` - FlashAttn bool `json:"flash_attn,omitempty"` + FlashAttn string `json:"flash_attn,omitempty"` NoPerf bool `json:"no_perf,omitempty"` Escape bool `json:"escape,omitempty"` NoEscape bool `json:"no_escape,omitempty"` @@ -187,6 +186,10 @@ type LlamaServerOptions struct { FIMQwen7BDefault bool `json:"fim_qwen_7b_default,omitempty"` FIMQwen7BSpec bool `json:"fim_qwen_7b_spec,omitempty"` FIMQwen14BSpec bool `json:"fim_qwen_14b_spec,omitempty"` + + // ExtraArgs are additional command line arguments. + // Example: {"verbose": "", "log-file": "/logs/llama.log"} + ExtraArgs map[string]string `json:"extra_args,omitempty"` } // UnmarshalJSON implements custom JSON unmarshaling to support multiple field names @@ -209,6 +212,15 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error { // Copy to our struct *o = LlamaServerOptions(temp) + // Track which fields we've processed + processedFields := make(map[string]bool) + + // Get all known canonical field names from struct tags + knownFields := getKnownFieldNames(o) + for field := range knownFields { + processedFields[field] = true + } + // Handle alternative field names fieldMappings := map[string]string{ // Common params @@ -220,7 +232,7 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error { "Crb": "cpu_range_batch", // -Crb, --cpu-range-batch lo-hi "c": "ctx_size", // -c, --ctx-size N "n": "predict", // -n, --predict N - "n-predict": "predict", // --n-predict N + "n_predict": "predict", // -n-predict N "b": "batch_size", // -b, --batch-size N "ub": "ubatch_size", // -ub, --ubatch-size N "fa": "flash_attn", // -fa, --flash-attn @@ -234,7 +246,7 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error { "dev": "device", // -dev, --device "ot": "override_tensor", // --override-tensor, -ot "ngl": "gpu_layers", // -ngl, --gpu-layers, --n-gpu-layers N - "n-gpu-layers": "gpu_layers", // --n-gpu-layers N + "n_gpu_layers": "gpu_layers", // --n-gpu-layers N "sm": "split_mode", // -sm, --split-mode "ts": "tensor_split", // -ts, --tensor-split N0,N1,N2,... "mg": "main_gpu", // -mg, --main-gpu INDEX @@ -250,9 +262,9 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error { "hffv": "hf_file_v", // -hffv, --hf-file-v FILE "hft": "hf_token", // -hft, --hf-token TOKEN "v": "verbose", // -v, --verbose, --log-verbose - "log-verbose": "verbose", // --log-verbose + "log_verbose": "verbose", // --log-verbose "lv": "verbosity", // -lv, --verbosity, --log-verbosity N - "log-verbosity": "verbosity", // --log-verbosity N + "log_verbosity": "verbosity", // --log-verbosity N // Sampling params "s": "seed", // -s, --seed SEED @@ -269,21 +281,23 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error { "rerank": "reranking", // --reranking "to": "timeout", // -to, --timeout N "sps": "slot_prompt_similarity", // -sps, --slot-prompt-similarity - "draft": "draft-max", // -draft, --draft-max N - "draft-n": "draft-max", // --draft-n-max N - "draft-n-min": "draft_min", // --draft-n-min N + "draft": "draft_max", // -draft, --draft-max N + "draft_n": "draft_max", // --draft-n-max N + "draft_n_min": "draft_min", // --draft-n-min N "cd": "ctx_size_draft", // -cd, --ctx-size-draft N "devd": "device_draft", // -devd, --device-draft "ngld": "gpu_layers_draft", // -ngld, --gpu-layers-draft - "n-gpu-layers-draft": "gpu_layers_draft", // --n-gpu-layers-draft N + "n_gpu_layers_draft": "gpu_layers_draft", // --n-gpu-layers-draft N "md": "model_draft", // -md, --model-draft FNAME "ctkd": "cache_type_k_draft", // -ctkd, --cache-type-k-draft TYPE "ctvd": "cache_type_v_draft", // -ctvd, --cache-type-v-draft TYPE "mv": "model_vocoder", // -mv, --model-vocoder FNAME } - // Process alternative field names + // Process alternative field names and mark them as processed for altName, canonicalName := range fieldMappings { + processedFields[altName] = true // Mark alternatives as known + if value, exists := raw[altName]; exists { // Use reflection to set the field value v := reflect.ValueOf(o).Elem() @@ -294,36 +308,21 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error { }) if field.IsValid() && field.CanSet() { - switch field.Kind() { - case reflect.Int: - if intVal, ok := value.(float64); ok { - field.SetInt(int64(intVal)) - } else if strVal, ok := value.(string); ok { - if intVal, err := strconv.Atoi(strVal); err == nil { - field.SetInt(int64(intVal)) - } - } - case reflect.Float64: - if floatVal, ok := value.(float64); ok { - field.SetFloat(floatVal) - } else if strVal, ok := value.(string); ok { - if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil { - field.SetFloat(floatVal) - } - } - case reflect.String: - if strVal, ok := value.(string); ok { - field.SetString(strVal) - } - case reflect.Bool: - if boolVal, ok := value.(bool); ok { - field.SetBool(boolVal) - } - } + setFieldValue(field, value) } } } + // Collect unknown fields into ExtraArgs + if o.ExtraArgs == nil { + o.ExtraArgs = make(map[string]string) + } + for key, value := range raw { + if !processedFields[key] { + o.ExtraArgs[key] = fmt.Sprintf("%v", value) + } + } + return nil } @@ -354,6 +353,18 @@ func (o *LlamaServerOptions) Validate() error { return validation.ValidationError(fmt.Errorf("invalid port range: %d", o.Port)) } + // Validate extra_args keys and values + for key, value := range o.ExtraArgs { + if err := validation.ValidateStringForInjection(key); err != nil { + return validation.ValidationError(fmt.Errorf("extra_args key %q: %w", key, err)) + } + if value != "" { + if err := validation.ValidateStringForInjection(value); err != nil { + return validation.ValidationError(fmt.Errorf("extra_args value for %q: %w", key, err)) + } + } + } + return nil } @@ -361,7 +372,12 @@ func (o *LlamaServerOptions) Validate() error { func (o *LlamaServerOptions) BuildCommandArgs() []string { // Llama uses multiple flags for arrays by default (not comma-separated) // Use package-level llamaMultiValuedFlags variable - return BuildCommandArgs(o, llamaMultiValuedFlags) + args := BuildCommandArgs(o, llamaMultiValuedFlags) + + // Append extra args at the end + args = append(args, convertExtraArgsToFlags(o.ExtraArgs)...) + + return args } func (o *LlamaServerOptions) BuildDockerArgs() []string { diff --git a/pkg/backends/llama_test.go b/pkg/backends/llama_test.go index 1698d37..961967b 100644 --- a/pkg/backends/llama_test.go +++ b/pkg/backends/llama_test.go @@ -33,12 +33,11 @@ func TestLlamaCppBuildCommandArgs_BooleanFields(t *testing.T) { { name: "multiple booleans", options: backends.LlamaServerOptions{ - Verbose: true, - FlashAttn: true, - Mlock: false, - NoMmap: true, + Verbose: true, + Mlock: false, + NoMmap: true, }, - expected: []string{"--verbose", "--flash-attn", "--no-mmap"}, + expected: []string{"--verbose", "--no-mmap"}, excluded: []string{"--mlock"}, }, } @@ -346,7 +345,7 @@ func TestParseLlamaCommand(t *testing.T) { }, { name: "multiple value types", - command: "llama-server --model /test/model.gguf --gpu-layers 32 --temp 0.7 --verbose --no-mmap", + command: "llama-server --model /test/model.gguf --n-gpu-layers 32 --temp 0.7 --verbose --no-mmap", expectErr: false, validate: func(t *testing.T, opts *backends.LlamaServerOptions) { if opts.Model != "/test/model.gguf" { @@ -434,3 +433,119 @@ func TestParseLlamaCommandArrays(t *testing.T) { } } } + +func TestLlamaCppBuildCommandArgs_ExtraArgs(t *testing.T) { + options := backends.LlamaServerOptions{ + Model: "/models/test.gguf", + ExtraArgs: map[string]string{ + "flash-attn": "", // boolean flag + "log-file": "/logs/test.log", // value flag + }, + } + + args := options.BuildCommandArgs() + + // Check that extra args are present + if !testutil.Contains(args, "--flash-attn") { + t.Error("Expected --flash-attn flag not found") + } + if !testutil.Contains(args, "--log-file") || !testutil.Contains(args, "/logs/test.log") { + t.Error("Expected --log-file flag or value not found") + } +} + +func TestParseLlamaCommand_ExtraArgs(t *testing.T) { + tests := []struct { + name string + command string + expectErr bool + validate func(*testing.T, *backends.LlamaServerOptions) + }{ + { + name: "extra args with known fields", + command: "llama-server --model /path/to/model.gguf --gpu-layers 32 --unknown-flag value --another-bool-flag", + expectErr: false, + validate: func(t *testing.T, opts *backends.LlamaServerOptions) { + if opts.Model != "/path/to/model.gguf" { + t.Errorf("expected model '/path/to/model.gguf', got '%s'", opts.Model) + } + if opts.GPULayers != 32 { + t.Errorf("expected gpu_layers 32, got %d", opts.GPULayers) + } + if opts.ExtraArgs == nil { + t.Fatal("expected extra_args to be non-nil") + } + if val, ok := opts.ExtraArgs["unknown_flag"]; !ok || val != "value" { + t.Errorf("expected extra_args[unknown_flag]='value', got '%s'", val) + } + if val, ok := opts.ExtraArgs["another_bool_flag"]; !ok || val != "true" { + t.Errorf("expected extra_args[another_bool_flag]='true', got '%s'", val) + } + }, + }, + { + name: "extra args with alternative field names", + command: "llama-server -m /model.gguf -ngl 16 --custom-arg test --new-feature", + expectErr: false, + validate: func(t *testing.T, opts *backends.LlamaServerOptions) { + // Check that alternative names worked for known fields + if opts.Model != "/model.gguf" { + t.Errorf("expected model '/model.gguf', got '%s'", opts.Model) + } + if opts.GPULayers != 16 { + t.Errorf("expected gpu_layers 16, got %d", opts.GPULayers) + } + // Check that unknown args went to ExtraArgs + if opts.ExtraArgs == nil { + t.Fatal("expected extra_args to be non-nil") + } + if val, ok := opts.ExtraArgs["custom_arg"]; !ok || val != "test" { + t.Errorf("expected extra_args[custom_arg]='test', got '%s'", val) + } + if val, ok := opts.ExtraArgs["new_feature"]; !ok || val != "true" { + t.Errorf("expected extra_args[new_feature]='true', got '%s'", val) + } + }, + }, + { + name: "only extra args", + command: "llama-server --experimental-feature --beta-mode enabled", + expectErr: false, + validate: func(t *testing.T, opts *backends.LlamaServerOptions) { + if opts.ExtraArgs == nil { + t.Fatal("expected extra_args to be non-nil") + } + if val, ok := opts.ExtraArgs["experimental_feature"]; !ok || val != "true" { + t.Errorf("expected extra_args[experimental_feature]='true', got '%s'", val) + } + if val, ok := opts.ExtraArgs["beta_mode"]; !ok || val != "enabled" { + t.Errorf("expected extra_args[beta_mode]='enabled', got '%s'", val) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var opts backends.LlamaServerOptions + result, err := opts.ParseCommand(tt.command) + + if tt.expectErr && err == nil { + t.Error("expected error but got none") + return + } + if !tt.expectErr && err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if !tt.expectErr && tt.validate != nil { + llamaOpts, ok := result.(*backends.LlamaServerOptions) + if !ok { + t.Fatal("result is not *LlamaServerOptions") + } + tt.validate(t, llamaOpts) + } + }) + } +} diff --git a/pkg/backends/mlx.go b/pkg/backends/mlx.go index 8911d0b..d066008 100644 --- a/pkg/backends/mlx.go +++ b/pkg/backends/mlx.go @@ -1,6 +1,7 @@ package backends import ( + "encoding/json" "fmt" "llamactl/pkg/validation" ) @@ -29,6 +30,46 @@ type MlxServerOptions struct { TopK int `json:"top_k,omitempty"` MinP float64 `json:"min_p,omitempty"` MaxTokens int `json:"max_tokens,omitempty"` + + // ExtraArgs are additional command line arguments. + // Example: {"verbose": "", "log-file": "/logs/mlx.log"} + ExtraArgs map[string]string `json:"extra_args,omitempty"` +} + +// UnmarshalJSON implements custom JSON unmarshaling to collect unknown fields into ExtraArgs +func (o *MlxServerOptions) UnmarshalJSON(data []byte) error { + // First unmarshal into a map to capture all fields + var raw map[string]any + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + // Create a temporary struct for standard unmarshaling + type tempOptions MlxServerOptions + temp := tempOptions{} + + // Standard unmarshal first + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + // Copy to our struct + *o = MlxServerOptions(temp) + + // Get all known canonical field names from struct tags + knownFields := getKnownFieldNames(o) + + // Collect unknown fields into ExtraArgs + if o.ExtraArgs == nil { + o.ExtraArgs = make(map[string]string) + } + for key, value := range raw { + if !knownFields[key] { + o.ExtraArgs[key] = fmt.Sprintf("%v", value) + } + } + + return nil } func (o *MlxServerOptions) GetPort() int { @@ -57,13 +98,30 @@ func (o *MlxServerOptions) Validate() error { return validation.ValidationError(fmt.Errorf("invalid port range: %d", o.Port)) } + // Validate extra_args keys and values + for key, value := range o.ExtraArgs { + if err := validation.ValidateStringForInjection(key); err != nil { + return validation.ValidationError(fmt.Errorf("extra_args key %q: %w", key, err)) + } + if value != "" { + if err := validation.ValidateStringForInjection(value); err != nil { + return validation.ValidationError(fmt.Errorf("extra_args value for %q: %w", key, err)) + } + } + } + return nil } // BuildCommandArgs converts to command line arguments func (o *MlxServerOptions) BuildCommandArgs() []string { multipleFlags := map[string]struct{}{} // MLX doesn't currently have []string fields - return BuildCommandArgs(o, multipleFlags) + args := BuildCommandArgs(o, multipleFlags) + + // Append extra args at the end + args = append(args, convertExtraArgsToFlags(o.ExtraArgs)...) + + return args } func (o *MlxServerOptions) BuildDockerArgs() []string { diff --git a/pkg/backends/mlx_test.go b/pkg/backends/mlx_test.go index d15be3d..f8a2ee5 100644 --- a/pkg/backends/mlx_test.go +++ b/pkg/backends/mlx_test.go @@ -202,3 +202,75 @@ func TestMlxBuildCommandArgs_ZeroValues(t *testing.T) { } } } + +func TestParseMlxCommand_ExtraArgs(t *testing.T) { + tests := []struct { + name string + command string + expectErr bool + validate func(*testing.T, *backends.MlxServerOptions) + }{ + { + name: "extra args with known fields", + command: "mlx_lm.server --model /path/to/model --port 8080 --unknown-flag value --new-bool-flag", + expectErr: false, + validate: func(t *testing.T, opts *backends.MlxServerOptions) { + if opts.Model != "/path/to/model" { + t.Errorf("expected model '/path/to/model', got '%s'", opts.Model) + } + if opts.Port != 8080 { + t.Errorf("expected port 8080, got %d", opts.Port) + } + if opts.ExtraArgs == nil { + t.Fatal("expected extra_args to be non-nil") + } + if val, ok := opts.ExtraArgs["unknown_flag"]; !ok || val != "value" { + t.Errorf("expected extra_args[unknown_flag]='value', got '%s'", val) + } + if val, ok := opts.ExtraArgs["new_bool_flag"]; !ok || val != "true" { + t.Errorf("expected extra_args[new_bool_flag]='true', got '%s'", val) + } + }, + }, + { + name: "only extra args", + command: "mlx_lm.server --experimental-feature --custom-param test", + expectErr: false, + validate: func(t *testing.T, opts *backends.MlxServerOptions) { + if opts.ExtraArgs == nil { + t.Fatal("expected extra_args to be non-nil") + } + if val, ok := opts.ExtraArgs["experimental_feature"]; !ok || val != "true" { + t.Errorf("expected extra_args[experimental_feature]='true', got '%s'", val) + } + if val, ok := opts.ExtraArgs["custom_param"]; !ok || val != "test" { + t.Errorf("expected extra_args[custom_param]='test', got '%s'", val) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var opts backends.MlxServerOptions + result, err := opts.ParseCommand(tt.command) + + if tt.expectErr && err == nil { + t.Error("expected error but got none") + return + } + if !tt.expectErr && err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if !tt.expectErr && tt.validate != nil { + mlxOpts, ok := result.(*backends.MlxServerOptions) + if !ok { + t.Fatal("result is not *MlxServerOptions") + } + tt.validate(t, mlxOpts) + } + }) + } +} diff --git a/pkg/backends/parser.go b/pkg/backends/parser.go index 8208568..77f6099 100644 --- a/pkg/backends/parser.go +++ b/pkg/backends/parser.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "path/filepath" + "reflect" "regexp" "strconv" "strings" @@ -211,3 +212,65 @@ func parseValue(value string) any { // Return as string return value } + +// setFieldValue sets a field value using reflection, handling type conversions +// Used by UnmarshalJSON implementations to handle alternative field names +func setFieldValue(field reflect.Value, value any) { + switch field.Kind() { + case reflect.Int: + if intVal, ok := value.(float64); ok { + field.SetInt(int64(intVal)) + } else if strVal, ok := value.(string); ok { + if intVal, err := strconv.Atoi(strVal); err == nil { + field.SetInt(int64(intVal)) + } + } + case reflect.Float64: + if floatVal, ok := value.(float64); ok { + field.SetFloat(floatVal) + } else if strVal, ok := value.(string); ok { + if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil { + field.SetFloat(floatVal) + } + } + case reflect.String: + if strVal, ok := value.(string); ok { + field.SetString(strVal) + } + case reflect.Bool: + if boolVal, ok := value.(bool); ok { + field.SetBool(boolVal) + } + case reflect.Slice: + // Handle string slices + if field.Type().Elem().Kind() == reflect.String { + if slice, ok := value.([]any); ok { + strSlice := make([]string, 0, len(slice)) + for _, v := range slice { + if s, ok := v.(string); ok { + strSlice = append(strSlice, s) + } + } + field.Set(reflect.ValueOf(strSlice)) + } + } + } +} + +// getKnownFieldNames extracts all known field names from struct json tags +// Used by UnmarshalJSON implementations to identify unknown fields for ExtraArgs +func getKnownFieldNames(v any) map[string]bool { + fields := make(map[string]bool) + t := reflect.TypeOf(v).Elem() + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + jsonTag := field.Tag.Get("json") + if jsonTag != "" && jsonTag != "-" { + // Handle "name,omitempty" format + name := strings.Split(jsonTag, ",")[0] + fields[name] = true + } + } + return fields +} diff --git a/pkg/backends/vllm.go b/pkg/backends/vllm.go index 34dce4c..1c5d07c 100644 --- a/pkg/backends/vllm.go +++ b/pkg/backends/vllm.go @@ -1,6 +1,7 @@ package backends import ( + "encoding/json" "fmt" "llamactl/pkg/validation" ) @@ -142,6 +143,46 @@ type VllmServerOptions struct { OverridePoolingConfig string `json:"override_pooling_config,omitempty"` OverrideNeuronConfig string `json:"override_neuron_config,omitempty"` OverrideKVCacheALIGNSize int `json:"override_kv_cache_align_size,omitempty"` + + // ExtraArgs are additional command line arguments. + // Example: {"verbose": "", "log-file": "/logs/vllm.log"} + ExtraArgs map[string]string `json:"extra_args,omitempty"` +} + +// UnmarshalJSON implements custom JSON unmarshaling to collect unknown fields into ExtraArgs +func (o *VllmServerOptions) UnmarshalJSON(data []byte) error { + // First unmarshal into a map to capture all fields + var raw map[string]any + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + // Create a temporary struct for standard unmarshaling + type tempOptions VllmServerOptions + temp := tempOptions{} + + // Standard unmarshal first + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + // Copy to our struct + *o = VllmServerOptions(temp) + + // Get all known canonical field names from struct tags + knownFields := getKnownFieldNames(o) + + // Collect unknown fields into ExtraArgs + if o.ExtraArgs == nil { + o.ExtraArgs = make(map[string]string) + } + for key, value := range raw { + if !knownFields[key] { + o.ExtraArgs[key] = fmt.Sprintf("%v", value) + } + } + + return nil } func (o *VllmServerOptions) GetPort() int { @@ -171,6 +212,18 @@ func (o *VllmServerOptions) Validate() error { return validation.ValidationError(fmt.Errorf("invalid port range: %d", o.Port)) } + // Validate extra_args keys and values + for key, value := range o.ExtraArgs { + if err := validation.ValidateStringForInjection(key); err != nil { + return validation.ValidationError(fmt.Errorf("extra_args key %q: %w", key, err)) + } + if value != "" { + if err := validation.ValidateStringForInjection(value); err != nil { + return validation.ValidationError(fmt.Errorf("extra_args value for %q: %w", key, err)) + } + } + } + return nil } @@ -193,6 +246,9 @@ func (o *VllmServerOptions) BuildCommandArgs() []string { flagArgs := BuildCommandArgs(&optionsCopy, vllmMultiValuedFlags) args = append(args, flagArgs...) + // Append extra args at the end + args = append(args, convertExtraArgsToFlags(o.ExtraArgs)...) + return args } @@ -203,6 +259,9 @@ func (o *VllmServerOptions) BuildDockerArgs() []string { flagArgs := BuildCommandArgs(o, vllmMultiValuedFlags) args = append(args, flagArgs...) + // Append extra args at the end + args = append(args, convertExtraArgsToFlags(o.ExtraArgs)...) + return args } diff --git a/pkg/backends/vllm_test.go b/pkg/backends/vllm_test.go index acec8d6..c3b1308 100644 --- a/pkg/backends/vllm_test.go +++ b/pkg/backends/vllm_test.go @@ -321,3 +321,94 @@ func TestVllmBuildCommandArgs_PositionalModel(t *testing.T) { t.Errorf("Expected --port 8080 not found in %v", args) } } + +func TestParseVllmCommand_ExtraArgs(t *testing.T) { + tests := []struct { + name string + command string + expectErr bool + validate func(*testing.T, *backends.VllmServerOptions) + }{ + { + name: "extra args with known fields", + command: "vllm serve llama-model --tensor-parallel-size 2 --unknown-flag value --new-bool-flag", + expectErr: false, + validate: func(t *testing.T, opts *backends.VllmServerOptions) { + if opts.Model != "llama-model" { + t.Errorf("expected model 'llama-model', got '%s'", opts.Model) + } + if opts.TensorParallelSize != 2 { + t.Errorf("expected tensor_parallel_size 2, got %d", opts.TensorParallelSize) + } + if opts.ExtraArgs == nil { + t.Fatal("expected extra_args to be non-nil") + } + if val, ok := opts.ExtraArgs["unknown_flag"]; !ok || val != "value" { + t.Errorf("expected extra_args[unknown_flag]='value', got '%s'", val) + } + if val, ok := opts.ExtraArgs["new_bool_flag"]; !ok || val != "true" { + t.Errorf("expected extra_args[new_bool_flag]='true', got '%s'", val) + } + }, + }, + { + name: "only extra args", + command: "vllm serve model --experimental-feature --custom-param test", + expectErr: false, + validate: func(t *testing.T, opts *backends.VllmServerOptions) { + if opts.ExtraArgs == nil { + t.Fatal("expected extra_args to be non-nil") + } + if val, ok := opts.ExtraArgs["experimental_feature"]; !ok || val != "true" { + t.Errorf("expected extra_args[experimental_feature]='true', got '%s'", val) + } + if val, ok := opts.ExtraArgs["custom_param"]; !ok || val != "test" { + t.Errorf("expected extra_args[custom_param]='test', got '%s'", val) + } + }, + }, + { + name: "extra args without model positional", + command: "vllm serve --model my-model --new-feature enabled --beta-flag", + expectErr: false, + validate: func(t *testing.T, opts *backends.VllmServerOptions) { + if opts.Model != "my-model" { + t.Errorf("expected model 'my-model', got '%s'", opts.Model) + } + if opts.ExtraArgs == nil { + t.Fatal("expected extra_args to be non-nil") + } + if val, ok := opts.ExtraArgs["new_feature"]; !ok || val != "enabled" { + t.Errorf("expected extra_args[new_feature]='enabled', got '%s'", val) + } + if val, ok := opts.ExtraArgs["beta_flag"]; !ok || val != "true" { + t.Errorf("expected extra_args[beta_flag]='true', got '%s'", val) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var opts backends.VllmServerOptions + result, err := opts.ParseCommand(tt.command) + + if tt.expectErr && err == nil { + t.Error("expected error but got none") + return + } + if !tt.expectErr && err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if !tt.expectErr && tt.validate != nil { + vllmOpts, ok := result.(*backends.VllmServerOptions) + if !ok { + t.Fatal("result is not *VllmServerOptions") + } + tt.validate(t, vllmOpts) + } + }) + } +} diff --git a/pkg/validation/validation_test.go b/pkg/validation/validation_test.go index e447666..08ac8f0 100644 --- a/pkg/validation/validation_test.go +++ b/pkg/validation/validation_test.go @@ -239,25 +239,3 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) { }) } } - -func TestValidateInstanceOptions_NonStringFields(t *testing.T) { - // Test that non-string fields don't interfere with validation - options := backends.Options{ - BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: &backends.LlamaServerOptions{ - Port: 8080, - GPULayers: 32, - CtxSize: 4096, - Temperature: 0.7, - TopK: 40, - TopP: 0.9, - Verbose: true, - FlashAttn: false, - }, - } - - err := options.ValidateInstanceOptions() - if err != nil { - t.Errorf("ValidateInstanceOptions with non-string fields should not error, got: %v", err) - } -} diff --git a/webui/src/components/BackendFormField.tsx b/webui/src/components/BackendFormField.tsx index bb49fc1..7dfbf5a 100644 --- a/webui/src/components/BackendFormField.tsx +++ b/webui/src/components/BackendFormField.tsx @@ -3,17 +3,31 @@ import { Input } from '@/components/ui/input' import { Label } from '@/components/ui/label' import { Checkbox } from '@/components/ui/checkbox' import { getBackendFieldType, basicBackendFieldsConfig } from '@/lib/zodFormUtils' +import ExtraArgsInput from '@/components/form/ExtraArgsInput' interface BackendFormFieldProps { fieldKey: string - value: string | number | boolean | string[] | undefined - onChange: (key: string, value: string | number | boolean | string[] | undefined) => void + value: string | number | boolean | string[] | Record | undefined + onChange: (key: string, value: string | number | boolean | string[] | Record | undefined) => void } const BackendFormField: React.FC = ({ fieldKey, value, onChange }) => { + // Special handling for extra_args + if (fieldKey === 'extra_args') { + return ( + | undefined} + onChange={(newValue) => onChange(fieldKey, newValue)} + description="Additional command line arguments to pass to the backend" + /> + ) + } + // Get configuration for basic fields, or use field name for advanced fields const config = basicBackendFieldsConfig[fieldKey] || { label: fieldKey } - + // Get type from Zod schema const fieldType = getBackendFieldType(fieldKey) diff --git a/webui/src/components/form/EnvVarsInput.tsx b/webui/src/components/form/EnvVarsInput.tsx new file mode 100644 index 0000000..476a98a --- /dev/null +++ b/webui/src/components/form/EnvVarsInput.tsx @@ -0,0 +1,27 @@ +import React from 'react' +import KeyValueInput from './KeyValueInput' + +interface EnvVarsInputProps { + id: string + label: string + value: Record | undefined + onChange: (value: Record | undefined) => void + description?: string + disabled?: boolean + className?: string +} + +const EnvVarsInput: React.FC = (props) => { + return ( + + ) +} + +export default EnvVarsInput diff --git a/webui/src/components/form/EnvironmentVariablesInput.tsx b/webui/src/components/form/EnvironmentVariablesInput.tsx deleted file mode 100644 index 47739f0..0000000 --- a/webui/src/components/form/EnvironmentVariablesInput.tsx +++ /dev/null @@ -1,144 +0,0 @@ -import React, { useState } from 'react' -import { Input } from '@/components/ui/input' -import { Label } from '@/components/ui/label' -import { Button } from '@/components/ui/button' -import { X, Plus } from 'lucide-react' - -interface EnvironmentVariablesInputProps { - id: string - label: string - value: Record | undefined - onChange: (value: Record | undefined) => void - description?: string - disabled?: boolean - className?: string -} - -interface EnvVar { - key: string - value: string -} - -const EnvironmentVariablesInput: React.FC = ({ - id, - label, - value, - onChange, - description, - disabled = false, - className -}) => { - // Convert the value object to an array of key-value pairs for editing - const envVarsFromValue = value - ? Object.entries(value).map(([key, val]) => ({ key, value: val })) - : [] - - const [envVars, setEnvVars] = useState( - envVarsFromValue.length > 0 ? envVarsFromValue : [{ key: '', value: '' }] - ) - - // Update parent component when env vars change - const updateParent = (newEnvVars: EnvVar[]) => { - // Filter out empty entries - const validVars = newEnvVars.filter(env => env.key.trim() !== '' && env.value.trim() !== '') - - if (validVars.length === 0) { - onChange(undefined) - } else { - const envObject = validVars.reduce((acc, env) => { - acc[env.key.trim()] = env.value.trim() - return acc - }, {} as Record) - onChange(envObject) - } - } - - const handleKeyChange = (index: number, newKey: string) => { - const newEnvVars = [...envVars] - newEnvVars[index].key = newKey - setEnvVars(newEnvVars) - updateParent(newEnvVars) - } - - const handleValueChange = (index: number, newValue: string) => { - const newEnvVars = [...envVars] - newEnvVars[index].value = newValue - setEnvVars(newEnvVars) - updateParent(newEnvVars) - } - - const addEnvVar = () => { - const newEnvVars = [...envVars, { key: '', value: '' }] - setEnvVars(newEnvVars) - } - - const removeEnvVar = (index: number) => { - if (envVars.length === 1) { - // Reset to empty if it's the last one - const newEnvVars = [{ key: '', value: '' }] - setEnvVars(newEnvVars) - updateParent(newEnvVars) - } else { - const newEnvVars = envVars.filter((_, i) => i !== index) - setEnvVars(newEnvVars) - updateParent(newEnvVars) - } - } - - return ( -
- -
- {envVars.map((envVar, index) => ( -
- handleKeyChange(index, e.target.value)} - disabled={disabled} - className="flex-1" - /> - handleValueChange(index, e.target.value)} - disabled={disabled} - className="flex-1" - /> - -
- ))} - -
- {description && ( -

{description}

- )} -

- Environment variables that will be passed to the backend process -

-
- ) -} - -export default EnvironmentVariablesInput \ No newline at end of file diff --git a/webui/src/components/form/ExtraArgsInput.tsx b/webui/src/components/form/ExtraArgsInput.tsx new file mode 100644 index 0000000..70f11b8 --- /dev/null +++ b/webui/src/components/form/ExtraArgsInput.tsx @@ -0,0 +1,27 @@ +import React from 'react' +import KeyValueInput from './KeyValueInput' + +interface ExtraArgsInputProps { + id: string + label: string + value: Record | undefined + onChange: (value: Record | undefined) => void + description?: string + disabled?: boolean + className?: string +} + +const ExtraArgsInput: React.FC = (props) => { + return ( + + ) +} + +export default ExtraArgsInput diff --git a/webui/src/components/form/KeyValueInput.tsx b/webui/src/components/form/KeyValueInput.tsx new file mode 100644 index 0000000..62585c4 --- /dev/null +++ b/webui/src/components/form/KeyValueInput.tsx @@ -0,0 +1,171 @@ +import React, { useState, useEffect } from 'react' +import { Input } from '@/components/ui/input' +import { Label } from '@/components/ui/label' +import { Button } from '@/components/ui/button' +import { X, Plus } from 'lucide-react' + +interface KeyValueInputProps { + id: string + label: string + value: Record | undefined + onChange: (value: Record | undefined) => void + description?: string + disabled?: boolean + className?: string + keyPlaceholder?: string + valuePlaceholder?: string + addButtonText?: string + helperText?: string + allowEmptyValues?: boolean // If true, entries with empty values are considered valid +} + +interface KeyValuePair { + key: string + value: string +} + +const KeyValueInput: React.FC = ({ + id, + label, + value, + onChange, + description, + disabled = false, + className, + keyPlaceholder = 'Key', + valuePlaceholder = 'Value', + addButtonText = 'Add Entry', + helperText, + allowEmptyValues = false +}) => { + // Convert the value object to an array of key-value pairs for editing + const pairsFromValue = value + ? Object.entries(value).map(([key, val]) => ({ key, value: val })) + : [] + + const [pairs, setPairs] = useState( + pairsFromValue.length > 0 ? pairsFromValue : [{ key: '', value: '' }] + ) + + // Sync internal state when value prop changes + useEffect(() => { + const newPairsFromValue = value + ? Object.entries(value).map(([key, val]) => ({ key, value: val })) + : [] + + if (newPairsFromValue.length > 0) { + setPairs(newPairsFromValue) + } else if (!value) { + // Reset to single empty row if value is explicitly undefined/null + setPairs([{ key: '', value: '' }]) + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [value]) + + // Update parent component when pairs change + const updateParent = (newPairs: KeyValuePair[]) => { + // Filter based on validation rules + const validPairs = allowEmptyValues + ? newPairs.filter(pair => pair.key.trim() !== '') + : newPairs.filter(pair => pair.key.trim() !== '' && pair.value.trim() !== '') + + if (validPairs.length === 0) { + onChange(undefined) + } else { + const pairsObject = validPairs.reduce((acc, pair) => { + acc[pair.key.trim()] = pair.value.trim() + return acc + }, {} as Record) + onChange(pairsObject) + } + } + + const handleKeyChange = (index: number, newKey: string) => { + const newPairs = [...pairs] + newPairs[index].key = newKey + setPairs(newPairs) + updateParent(newPairs) + } + + const handleValueChange = (index: number, newValue: string) => { + const newPairs = [...pairs] + newPairs[index].value = newValue + setPairs(newPairs) + updateParent(newPairs) + } + + const addPair = () => { + const newPairs = [...pairs, { key: '', value: '' }] + setPairs(newPairs) + } + + const removePair = (index: number) => { + if (pairs.length === 1) { + // Reset to empty if it's the last one + const newPairs = [{ key: '', value: '' }] + setPairs(newPairs) + updateParent(newPairs) + } else { + const newPairs = pairs.filter((_, i) => i !== index) + setPairs(newPairs) + updateParent(newPairs) + } + } + + return ( +
+ +
+ {pairs.map((pair, index) => ( +
+ handleKeyChange(index, e.target.value)} + disabled={disabled} + className="flex-1" + /> + handleValueChange(index, e.target.value)} + disabled={disabled} + className="flex-1" + /> + +
+ ))} + +
+ {description && ( +

{description}

+ )} + {helperText && ( +

{helperText}

+ )} +
+ ) +} + +export default KeyValueInput diff --git a/webui/src/components/instance/BackendConfiguration.tsx b/webui/src/components/instance/BackendConfiguration.tsx index cfcee86..8f10e41 100644 --- a/webui/src/components/instance/BackendConfiguration.tsx +++ b/webui/src/components/instance/BackendConfiguration.tsx @@ -47,8 +47,18 @@ const BackendConfiguration: React.FC = ({ ))} )} + + {/* Extra Args - Always visible as a separate section */} +
+ +
) } -export default BackendConfiguration \ No newline at end of file +export default BackendConfiguration diff --git a/webui/src/components/instance/BackendConfigurationCard.tsx b/webui/src/components/instance/BackendConfigurationCard.tsx index 5bf7c36..799ea2b 100644 --- a/webui/src/components/instance/BackendConfigurationCard.tsx +++ b/webui/src/components/instance/BackendConfigurationCard.tsx @@ -109,6 +109,16 @@ const BackendConfigurationCard: React.FC = ({ )} )} + + {/* Extra Arguments - Always visible */} +
+ )?.extra_args as Record | undefined} + onChange={onBackendFieldChange} + /> +
) diff --git a/webui/src/components/instance/InstanceSettingsCard.tsx b/webui/src/components/instance/InstanceSettingsCard.tsx index 1834eab..7b853cb 100644 --- a/webui/src/components/instance/InstanceSettingsCard.tsx +++ b/webui/src/components/instance/InstanceSettingsCard.tsx @@ -6,7 +6,7 @@ import { Input } from '@/components/ui/input' import AutoRestartConfiguration from '@/components/instance/AutoRestartConfiguration' import NumberInput from '@/components/form/NumberInput' import CheckboxInput from '@/components/form/CheckboxInput' -import EnvironmentVariablesInput from '@/components/form/EnvironmentVariablesInput' +import EnvVarsInput from '@/components/form/EnvVarsInput' import SelectInput from '@/components/form/SelectInput' import { nodesApi, type NodesMap } from '@/lib/api' @@ -132,7 +132,7 @@ const InstanceSettingsCard: React.FC = ({ description="Start instance only when needed" /> - !(key in basicConfig)) + return fieldGetter().filter(key => !(key in basicConfig) && key !== 'extra_args') } // Combined backend fields config for use in BackendFormField diff --git a/webui/src/schemas/backends/llamacpp.ts b/webui/src/schemas/backends/llamacpp.ts index 7dead95..9383ec8 100644 --- a/webui/src/schemas/backends/llamacpp.ts +++ b/webui/src/schemas/backends/llamacpp.ts @@ -167,6 +167,9 @@ export const LlamaCppBackendOptionsSchema = z.object({ fim_qwen_7b_default: z.boolean().optional(), fim_qwen_7b_spec: z.boolean().optional(), fim_qwen_14b_spec: z.boolean().optional(), + + // Extra args + extra_args: z.record(z.string(), z.string()).optional(), }) // Infer the TypeScript type from the schema diff --git a/webui/src/schemas/backends/mlx.ts b/webui/src/schemas/backends/mlx.ts index 917ca81..4267ed9 100644 --- a/webui/src/schemas/backends/mlx.ts +++ b/webui/src/schemas/backends/mlx.ts @@ -25,6 +25,9 @@ export const MlxBackendOptionsSchema = z.object({ top_k: z.number().optional(), min_p: z.number().optional(), max_tokens: z.number().optional(), + + // Extra args + extra_args: z.record(z.string(), z.string()).optional(), }) // Infer the TypeScript type from the schema diff --git a/webui/src/schemas/backends/vllm.ts b/webui/src/schemas/backends/vllm.ts index 7dd700f..0972a8f 100644 --- a/webui/src/schemas/backends/vllm.ts +++ b/webui/src/schemas/backends/vllm.ts @@ -125,6 +125,9 @@ export const VllmBackendOptionsSchema = z.object({ override_pooling_config: z.string().optional(), override_neuron_config: z.string().optional(), override_kv_cache_align_size: z.number().optional(), + + // Extra args + extra_args: z.record(z.string(), z.string()).optional(), }) // Infer the TypeScript type from the schema