mirror of
https://github.com/lordmathis/llamactl.git
synced 2025-12-22 17:14:22 +00:00
Merge pull request #90 from lordmathis/feat/custom-args
feat: Add support for custom args
This commit is contained in:
@@ -93,3 +93,22 @@ func BuildDockerCommand(backendConfig *config.BackendSettings, instanceArgs []st
|
|||||||
|
|
||||||
return "docker", dockerArgs, nil
|
return "docker", dockerArgs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// convertExtraArgsToFlags converts map[string]string to command flags
|
||||||
|
// Empty values become boolean flags: {"flag": ""} → ["--flag"]
|
||||||
|
// Non-empty values: {"flag": "value"} → ["--flag", "value"]
|
||||||
|
func convertExtraArgsToFlags(extraArgs map[string]string) []string {
|
||||||
|
var args []string
|
||||||
|
|
||||||
|
for key, value := range extraArgs {
|
||||||
|
if value == "" {
|
||||||
|
// Boolean flag
|
||||||
|
args = append(args, "--"+key)
|
||||||
|
} else {
|
||||||
|
// Value flag
|
||||||
|
args = append(args, "--"+key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"llamactl/pkg/validation"
|
"llamactl/pkg/validation"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// llamaMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated
|
// llamaMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated
|
||||||
@@ -41,7 +40,7 @@ type LlamaServerOptions struct {
|
|||||||
BatchSize int `json:"batch_size,omitempty"`
|
BatchSize int `json:"batch_size,omitempty"`
|
||||||
UBatchSize int `json:"ubatch_size,omitempty"`
|
UBatchSize int `json:"ubatch_size,omitempty"`
|
||||||
Keep int `json:"keep,omitempty"`
|
Keep int `json:"keep,omitempty"`
|
||||||
FlashAttn bool `json:"flash_attn,omitempty"`
|
FlashAttn string `json:"flash_attn,omitempty"`
|
||||||
NoPerf bool `json:"no_perf,omitempty"`
|
NoPerf bool `json:"no_perf,omitempty"`
|
||||||
Escape bool `json:"escape,omitempty"`
|
Escape bool `json:"escape,omitempty"`
|
||||||
NoEscape bool `json:"no_escape,omitempty"`
|
NoEscape bool `json:"no_escape,omitempty"`
|
||||||
@@ -187,6 +186,10 @@ type LlamaServerOptions struct {
|
|||||||
FIMQwen7BDefault bool `json:"fim_qwen_7b_default,omitempty"`
|
FIMQwen7BDefault bool `json:"fim_qwen_7b_default,omitempty"`
|
||||||
FIMQwen7BSpec bool `json:"fim_qwen_7b_spec,omitempty"`
|
FIMQwen7BSpec bool `json:"fim_qwen_7b_spec,omitempty"`
|
||||||
FIMQwen14BSpec bool `json:"fim_qwen_14b_spec,omitempty"`
|
FIMQwen14BSpec bool `json:"fim_qwen_14b_spec,omitempty"`
|
||||||
|
|
||||||
|
// ExtraArgs are additional command line arguments.
|
||||||
|
// Example: {"verbose": "", "log-file": "/logs/llama.log"}
|
||||||
|
ExtraArgs map[string]string `json:"extra_args,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalJSON implements custom JSON unmarshaling to support multiple field names
|
// UnmarshalJSON implements custom JSON unmarshaling to support multiple field names
|
||||||
@@ -209,6 +212,15 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error {
|
|||||||
// Copy to our struct
|
// Copy to our struct
|
||||||
*o = LlamaServerOptions(temp)
|
*o = LlamaServerOptions(temp)
|
||||||
|
|
||||||
|
// Track which fields we've processed
|
||||||
|
processedFields := make(map[string]bool)
|
||||||
|
|
||||||
|
// Get all known canonical field names from struct tags
|
||||||
|
knownFields := getKnownFieldNames(o)
|
||||||
|
for field := range knownFields {
|
||||||
|
processedFields[field] = true
|
||||||
|
}
|
||||||
|
|
||||||
// Handle alternative field names
|
// Handle alternative field names
|
||||||
fieldMappings := map[string]string{
|
fieldMappings := map[string]string{
|
||||||
// Common params
|
// Common params
|
||||||
@@ -220,7 +232,7 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error {
|
|||||||
"Crb": "cpu_range_batch", // -Crb, --cpu-range-batch lo-hi
|
"Crb": "cpu_range_batch", // -Crb, --cpu-range-batch lo-hi
|
||||||
"c": "ctx_size", // -c, --ctx-size N
|
"c": "ctx_size", // -c, --ctx-size N
|
||||||
"n": "predict", // -n, --predict N
|
"n": "predict", // -n, --predict N
|
||||||
"n-predict": "predict", // --n-predict N
|
"n_predict": "predict", // -n-predict N
|
||||||
"b": "batch_size", // -b, --batch-size N
|
"b": "batch_size", // -b, --batch-size N
|
||||||
"ub": "ubatch_size", // -ub, --ubatch-size N
|
"ub": "ubatch_size", // -ub, --ubatch-size N
|
||||||
"fa": "flash_attn", // -fa, --flash-attn
|
"fa": "flash_attn", // -fa, --flash-attn
|
||||||
@@ -234,7 +246,7 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error {
|
|||||||
"dev": "device", // -dev, --device <dev1,dev2,..>
|
"dev": "device", // -dev, --device <dev1,dev2,..>
|
||||||
"ot": "override_tensor", // --override-tensor, -ot
|
"ot": "override_tensor", // --override-tensor, -ot
|
||||||
"ngl": "gpu_layers", // -ngl, --gpu-layers, --n-gpu-layers N
|
"ngl": "gpu_layers", // -ngl, --gpu-layers, --n-gpu-layers N
|
||||||
"n-gpu-layers": "gpu_layers", // --n-gpu-layers N
|
"n_gpu_layers": "gpu_layers", // --n-gpu-layers N
|
||||||
"sm": "split_mode", // -sm, --split-mode
|
"sm": "split_mode", // -sm, --split-mode
|
||||||
"ts": "tensor_split", // -ts, --tensor-split N0,N1,N2,...
|
"ts": "tensor_split", // -ts, --tensor-split N0,N1,N2,...
|
||||||
"mg": "main_gpu", // -mg, --main-gpu INDEX
|
"mg": "main_gpu", // -mg, --main-gpu INDEX
|
||||||
@@ -250,9 +262,9 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error {
|
|||||||
"hffv": "hf_file_v", // -hffv, --hf-file-v FILE
|
"hffv": "hf_file_v", // -hffv, --hf-file-v FILE
|
||||||
"hft": "hf_token", // -hft, --hf-token TOKEN
|
"hft": "hf_token", // -hft, --hf-token TOKEN
|
||||||
"v": "verbose", // -v, --verbose, --log-verbose
|
"v": "verbose", // -v, --verbose, --log-verbose
|
||||||
"log-verbose": "verbose", // --log-verbose
|
"log_verbose": "verbose", // --log-verbose
|
||||||
"lv": "verbosity", // -lv, --verbosity, --log-verbosity N
|
"lv": "verbosity", // -lv, --verbosity, --log-verbosity N
|
||||||
"log-verbosity": "verbosity", // --log-verbosity N
|
"log_verbosity": "verbosity", // --log-verbosity N
|
||||||
|
|
||||||
// Sampling params
|
// Sampling params
|
||||||
"s": "seed", // -s, --seed SEED
|
"s": "seed", // -s, --seed SEED
|
||||||
@@ -269,21 +281,23 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error {
|
|||||||
"rerank": "reranking", // --reranking
|
"rerank": "reranking", // --reranking
|
||||||
"to": "timeout", // -to, --timeout N
|
"to": "timeout", // -to, --timeout N
|
||||||
"sps": "slot_prompt_similarity", // -sps, --slot-prompt-similarity
|
"sps": "slot_prompt_similarity", // -sps, --slot-prompt-similarity
|
||||||
"draft": "draft-max", // -draft, --draft-max N
|
"draft": "draft_max", // -draft, --draft-max N
|
||||||
"draft-n": "draft-max", // --draft-n-max N
|
"draft_n": "draft_max", // --draft-n-max N
|
||||||
"draft-n-min": "draft_min", // --draft-n-min N
|
"draft_n_min": "draft_min", // --draft-n-min N
|
||||||
"cd": "ctx_size_draft", // -cd, --ctx-size-draft N
|
"cd": "ctx_size_draft", // -cd, --ctx-size-draft N
|
||||||
"devd": "device_draft", // -devd, --device-draft
|
"devd": "device_draft", // -devd, --device-draft
|
||||||
"ngld": "gpu_layers_draft", // -ngld, --gpu-layers-draft
|
"ngld": "gpu_layers_draft", // -ngld, --gpu-layers-draft
|
||||||
"n-gpu-layers-draft": "gpu_layers_draft", // --n-gpu-layers-draft N
|
"n_gpu_layers_draft": "gpu_layers_draft", // --n-gpu-layers-draft N
|
||||||
"md": "model_draft", // -md, --model-draft FNAME
|
"md": "model_draft", // -md, --model-draft FNAME
|
||||||
"ctkd": "cache_type_k_draft", // -ctkd, --cache-type-k-draft TYPE
|
"ctkd": "cache_type_k_draft", // -ctkd, --cache-type-k-draft TYPE
|
||||||
"ctvd": "cache_type_v_draft", // -ctvd, --cache-type-v-draft TYPE
|
"ctvd": "cache_type_v_draft", // -ctvd, --cache-type-v-draft TYPE
|
||||||
"mv": "model_vocoder", // -mv, --model-vocoder FNAME
|
"mv": "model_vocoder", // -mv, --model-vocoder FNAME
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process alternative field names
|
// Process alternative field names and mark them as processed
|
||||||
for altName, canonicalName := range fieldMappings {
|
for altName, canonicalName := range fieldMappings {
|
||||||
|
processedFields[altName] = true // Mark alternatives as known
|
||||||
|
|
||||||
if value, exists := raw[altName]; exists {
|
if value, exists := raw[altName]; exists {
|
||||||
// Use reflection to set the field value
|
// Use reflection to set the field value
|
||||||
v := reflect.ValueOf(o).Elem()
|
v := reflect.ValueOf(o).Elem()
|
||||||
@@ -294,36 +308,21 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if field.IsValid() && field.CanSet() {
|
if field.IsValid() && field.CanSet() {
|
||||||
switch field.Kind() {
|
setFieldValue(field, value)
|
||||||
case reflect.Int:
|
|
||||||
if intVal, ok := value.(float64); ok {
|
|
||||||
field.SetInt(int64(intVal))
|
|
||||||
} else if strVal, ok := value.(string); ok {
|
|
||||||
if intVal, err := strconv.Atoi(strVal); err == nil {
|
|
||||||
field.SetInt(int64(intVal))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case reflect.Float64:
|
|
||||||
if floatVal, ok := value.(float64); ok {
|
|
||||||
field.SetFloat(floatVal)
|
|
||||||
} else if strVal, ok := value.(string); ok {
|
|
||||||
if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil {
|
|
||||||
field.SetFloat(floatVal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case reflect.String:
|
|
||||||
if strVal, ok := value.(string); ok {
|
|
||||||
field.SetString(strVal)
|
|
||||||
}
|
|
||||||
case reflect.Bool:
|
|
||||||
if boolVal, ok := value.(bool); ok {
|
|
||||||
field.SetBool(boolVal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Collect unknown fields into ExtraArgs
|
||||||
|
if o.ExtraArgs == nil {
|
||||||
|
o.ExtraArgs = make(map[string]string)
|
||||||
|
}
|
||||||
|
for key, value := range raw {
|
||||||
|
if !processedFields[key] {
|
||||||
|
o.ExtraArgs[key] = fmt.Sprintf("%v", value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -354,6 +353,18 @@ func (o *LlamaServerOptions) Validate() error {
|
|||||||
return validation.ValidationError(fmt.Errorf("invalid port range: %d", o.Port))
|
return validation.ValidationError(fmt.Errorf("invalid port range: %d", o.Port))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate extra_args keys and values
|
||||||
|
for key, value := range o.ExtraArgs {
|
||||||
|
if err := validation.ValidateStringForInjection(key); err != nil {
|
||||||
|
return validation.ValidationError(fmt.Errorf("extra_args key %q: %w", key, err))
|
||||||
|
}
|
||||||
|
if value != "" {
|
||||||
|
if err := validation.ValidateStringForInjection(value); err != nil {
|
||||||
|
return validation.ValidationError(fmt.Errorf("extra_args value for %q: %w", key, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -361,7 +372,12 @@ func (o *LlamaServerOptions) Validate() error {
|
|||||||
func (o *LlamaServerOptions) BuildCommandArgs() []string {
|
func (o *LlamaServerOptions) BuildCommandArgs() []string {
|
||||||
// Llama uses multiple flags for arrays by default (not comma-separated)
|
// Llama uses multiple flags for arrays by default (not comma-separated)
|
||||||
// Use package-level llamaMultiValuedFlags variable
|
// Use package-level llamaMultiValuedFlags variable
|
||||||
return BuildCommandArgs(o, llamaMultiValuedFlags)
|
args := BuildCommandArgs(o, llamaMultiValuedFlags)
|
||||||
|
|
||||||
|
// Append extra args at the end
|
||||||
|
args = append(args, convertExtraArgsToFlags(o.ExtraArgs)...)
|
||||||
|
|
||||||
|
return args
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *LlamaServerOptions) BuildDockerArgs() []string {
|
func (o *LlamaServerOptions) BuildDockerArgs() []string {
|
||||||
|
|||||||
@@ -33,12 +33,11 @@ func TestLlamaCppBuildCommandArgs_BooleanFields(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "multiple booleans",
|
name: "multiple booleans",
|
||||||
options: backends.LlamaServerOptions{
|
options: backends.LlamaServerOptions{
|
||||||
Verbose: true,
|
Verbose: true,
|
||||||
FlashAttn: true,
|
Mlock: false,
|
||||||
Mlock: false,
|
NoMmap: true,
|
||||||
NoMmap: true,
|
|
||||||
},
|
},
|
||||||
expected: []string{"--verbose", "--flash-attn", "--no-mmap"},
|
expected: []string{"--verbose", "--no-mmap"},
|
||||||
excluded: []string{"--mlock"},
|
excluded: []string{"--mlock"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -346,7 +345,7 @@ func TestParseLlamaCommand(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple value types",
|
name: "multiple value types",
|
||||||
command: "llama-server --model /test/model.gguf --gpu-layers 32 --temp 0.7 --verbose --no-mmap",
|
command: "llama-server --model /test/model.gguf --n-gpu-layers 32 --temp 0.7 --verbose --no-mmap",
|
||||||
expectErr: false,
|
expectErr: false,
|
||||||
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
|
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
|
||||||
if opts.Model != "/test/model.gguf" {
|
if opts.Model != "/test/model.gguf" {
|
||||||
@@ -434,3 +433,119 @@ func TestParseLlamaCommandArrays(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLlamaCppBuildCommandArgs_ExtraArgs(t *testing.T) {
|
||||||
|
options := backends.LlamaServerOptions{
|
||||||
|
Model: "/models/test.gguf",
|
||||||
|
ExtraArgs: map[string]string{
|
||||||
|
"flash-attn": "", // boolean flag
|
||||||
|
"log-file": "/logs/test.log", // value flag
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
args := options.BuildCommandArgs()
|
||||||
|
|
||||||
|
// Check that extra args are present
|
||||||
|
if !testutil.Contains(args, "--flash-attn") {
|
||||||
|
t.Error("Expected --flash-attn flag not found")
|
||||||
|
}
|
||||||
|
if !testutil.Contains(args, "--log-file") || !testutil.Contains(args, "/logs/test.log") {
|
||||||
|
t.Error("Expected --log-file flag or value not found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseLlamaCommand_ExtraArgs(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
command string
|
||||||
|
expectErr bool
|
||||||
|
validate func(*testing.T, *backends.LlamaServerOptions)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "extra args with known fields",
|
||||||
|
command: "llama-server --model /path/to/model.gguf --gpu-layers 32 --unknown-flag value --another-bool-flag",
|
||||||
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
|
||||||
|
if opts.Model != "/path/to/model.gguf" {
|
||||||
|
t.Errorf("expected model '/path/to/model.gguf', got '%s'", opts.Model)
|
||||||
|
}
|
||||||
|
if opts.GPULayers != 32 {
|
||||||
|
t.Errorf("expected gpu_layers 32, got %d", opts.GPULayers)
|
||||||
|
}
|
||||||
|
if opts.ExtraArgs == nil {
|
||||||
|
t.Fatal("expected extra_args to be non-nil")
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["unknown_flag"]; !ok || val != "value" {
|
||||||
|
t.Errorf("expected extra_args[unknown_flag]='value', got '%s'", val)
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["another_bool_flag"]; !ok || val != "true" {
|
||||||
|
t.Errorf("expected extra_args[another_bool_flag]='true', got '%s'", val)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extra args with alternative field names",
|
||||||
|
command: "llama-server -m /model.gguf -ngl 16 --custom-arg test --new-feature",
|
||||||
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
|
||||||
|
// Check that alternative names worked for known fields
|
||||||
|
if opts.Model != "/model.gguf" {
|
||||||
|
t.Errorf("expected model '/model.gguf', got '%s'", opts.Model)
|
||||||
|
}
|
||||||
|
if opts.GPULayers != 16 {
|
||||||
|
t.Errorf("expected gpu_layers 16, got %d", opts.GPULayers)
|
||||||
|
}
|
||||||
|
// Check that unknown args went to ExtraArgs
|
||||||
|
if opts.ExtraArgs == nil {
|
||||||
|
t.Fatal("expected extra_args to be non-nil")
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["custom_arg"]; !ok || val != "test" {
|
||||||
|
t.Errorf("expected extra_args[custom_arg]='test', got '%s'", val)
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["new_feature"]; !ok || val != "true" {
|
||||||
|
t.Errorf("expected extra_args[new_feature]='true', got '%s'", val)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only extra args",
|
||||||
|
command: "llama-server --experimental-feature --beta-mode enabled",
|
||||||
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
|
||||||
|
if opts.ExtraArgs == nil {
|
||||||
|
t.Fatal("expected extra_args to be non-nil")
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["experimental_feature"]; !ok || val != "true" {
|
||||||
|
t.Errorf("expected extra_args[experimental_feature]='true', got '%s'", val)
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["beta_mode"]; !ok || val != "enabled" {
|
||||||
|
t.Errorf("expected extra_args[beta_mode]='enabled', got '%s'", val)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var opts backends.LlamaServerOptions
|
||||||
|
result, err := opts.ParseCommand(tt.command)
|
||||||
|
|
||||||
|
if tt.expectErr && err == nil {
|
||||||
|
t.Error("expected error but got none")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !tt.expectErr && err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tt.expectErr && tt.validate != nil {
|
||||||
|
llamaOpts, ok := result.(*backends.LlamaServerOptions)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("result is not *LlamaServerOptions")
|
||||||
|
}
|
||||||
|
tt.validate(t, llamaOpts)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package backends
|
package backends
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"llamactl/pkg/validation"
|
"llamactl/pkg/validation"
|
||||||
)
|
)
|
||||||
@@ -29,6 +30,46 @@ type MlxServerOptions struct {
|
|||||||
TopK int `json:"top_k,omitempty"`
|
TopK int `json:"top_k,omitempty"`
|
||||||
MinP float64 `json:"min_p,omitempty"`
|
MinP float64 `json:"min_p,omitempty"`
|
||||||
MaxTokens int `json:"max_tokens,omitempty"`
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
|
|
||||||
|
// ExtraArgs are additional command line arguments.
|
||||||
|
// Example: {"verbose": "", "log-file": "/logs/mlx.log"}
|
||||||
|
ExtraArgs map[string]string `json:"extra_args,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements custom JSON unmarshaling to collect unknown fields into ExtraArgs
|
||||||
|
func (o *MlxServerOptions) UnmarshalJSON(data []byte) error {
|
||||||
|
// First unmarshal into a map to capture all fields
|
||||||
|
var raw map[string]any
|
||||||
|
if err := json.Unmarshal(data, &raw); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a temporary struct for standard unmarshaling
|
||||||
|
type tempOptions MlxServerOptions
|
||||||
|
temp := tempOptions{}
|
||||||
|
|
||||||
|
// Standard unmarshal first
|
||||||
|
if err := json.Unmarshal(data, &temp); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy to our struct
|
||||||
|
*o = MlxServerOptions(temp)
|
||||||
|
|
||||||
|
// Get all known canonical field names from struct tags
|
||||||
|
knownFields := getKnownFieldNames(o)
|
||||||
|
|
||||||
|
// Collect unknown fields into ExtraArgs
|
||||||
|
if o.ExtraArgs == nil {
|
||||||
|
o.ExtraArgs = make(map[string]string)
|
||||||
|
}
|
||||||
|
for key, value := range raw {
|
||||||
|
if !knownFields[key] {
|
||||||
|
o.ExtraArgs[key] = fmt.Sprintf("%v", value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *MlxServerOptions) GetPort() int {
|
func (o *MlxServerOptions) GetPort() int {
|
||||||
@@ -57,13 +98,30 @@ func (o *MlxServerOptions) Validate() error {
|
|||||||
return validation.ValidationError(fmt.Errorf("invalid port range: %d", o.Port))
|
return validation.ValidationError(fmt.Errorf("invalid port range: %d", o.Port))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate extra_args keys and values
|
||||||
|
for key, value := range o.ExtraArgs {
|
||||||
|
if err := validation.ValidateStringForInjection(key); err != nil {
|
||||||
|
return validation.ValidationError(fmt.Errorf("extra_args key %q: %w", key, err))
|
||||||
|
}
|
||||||
|
if value != "" {
|
||||||
|
if err := validation.ValidateStringForInjection(value); err != nil {
|
||||||
|
return validation.ValidationError(fmt.Errorf("extra_args value for %q: %w", key, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildCommandArgs converts to command line arguments
|
// BuildCommandArgs converts to command line arguments
|
||||||
func (o *MlxServerOptions) BuildCommandArgs() []string {
|
func (o *MlxServerOptions) BuildCommandArgs() []string {
|
||||||
multipleFlags := map[string]struct{}{} // MLX doesn't currently have []string fields
|
multipleFlags := map[string]struct{}{} // MLX doesn't currently have []string fields
|
||||||
return BuildCommandArgs(o, multipleFlags)
|
args := BuildCommandArgs(o, multipleFlags)
|
||||||
|
|
||||||
|
// Append extra args at the end
|
||||||
|
args = append(args, convertExtraArgsToFlags(o.ExtraArgs)...)
|
||||||
|
|
||||||
|
return args
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *MlxServerOptions) BuildDockerArgs() []string {
|
func (o *MlxServerOptions) BuildDockerArgs() []string {
|
||||||
|
|||||||
@@ -202,3 +202,75 @@ func TestMlxBuildCommandArgs_ZeroValues(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseMlxCommand_ExtraArgs(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
command string
|
||||||
|
expectErr bool
|
||||||
|
validate func(*testing.T, *backends.MlxServerOptions)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "extra args with known fields",
|
||||||
|
command: "mlx_lm.server --model /path/to/model --port 8080 --unknown-flag value --new-bool-flag",
|
||||||
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.MlxServerOptions) {
|
||||||
|
if opts.Model != "/path/to/model" {
|
||||||
|
t.Errorf("expected model '/path/to/model', got '%s'", opts.Model)
|
||||||
|
}
|
||||||
|
if opts.Port != 8080 {
|
||||||
|
t.Errorf("expected port 8080, got %d", opts.Port)
|
||||||
|
}
|
||||||
|
if opts.ExtraArgs == nil {
|
||||||
|
t.Fatal("expected extra_args to be non-nil")
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["unknown_flag"]; !ok || val != "value" {
|
||||||
|
t.Errorf("expected extra_args[unknown_flag]='value', got '%s'", val)
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["new_bool_flag"]; !ok || val != "true" {
|
||||||
|
t.Errorf("expected extra_args[new_bool_flag]='true', got '%s'", val)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only extra args",
|
||||||
|
command: "mlx_lm.server --experimental-feature --custom-param test",
|
||||||
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.MlxServerOptions) {
|
||||||
|
if opts.ExtraArgs == nil {
|
||||||
|
t.Fatal("expected extra_args to be non-nil")
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["experimental_feature"]; !ok || val != "true" {
|
||||||
|
t.Errorf("expected extra_args[experimental_feature]='true', got '%s'", val)
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["custom_param"]; !ok || val != "test" {
|
||||||
|
t.Errorf("expected extra_args[custom_param]='test', got '%s'", val)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var opts backends.MlxServerOptions
|
||||||
|
result, err := opts.ParseCommand(tt.command)
|
||||||
|
|
||||||
|
if tt.expectErr && err == nil {
|
||||||
|
t.Error("expected error but got none")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !tt.expectErr && err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tt.expectErr && tt.validate != nil {
|
||||||
|
mlxOpts, ok := result.(*backends.MlxServerOptions)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("result is not *MlxServerOptions")
|
||||||
|
}
|
||||||
|
tt.validate(t, mlxOpts)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -211,3 +212,65 @@ func parseValue(value string) any {
|
|||||||
// Return as string
|
// Return as string
|
||||||
return value
|
return value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setFieldValue sets a field value using reflection, handling type conversions
|
||||||
|
// Used by UnmarshalJSON implementations to handle alternative field names
|
||||||
|
func setFieldValue(field reflect.Value, value any) {
|
||||||
|
switch field.Kind() {
|
||||||
|
case reflect.Int:
|
||||||
|
if intVal, ok := value.(float64); ok {
|
||||||
|
field.SetInt(int64(intVal))
|
||||||
|
} else if strVal, ok := value.(string); ok {
|
||||||
|
if intVal, err := strconv.Atoi(strVal); err == nil {
|
||||||
|
field.SetInt(int64(intVal))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Float64:
|
||||||
|
if floatVal, ok := value.(float64); ok {
|
||||||
|
field.SetFloat(floatVal)
|
||||||
|
} else if strVal, ok := value.(string); ok {
|
||||||
|
if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil {
|
||||||
|
field.SetFloat(floatVal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.String:
|
||||||
|
if strVal, ok := value.(string); ok {
|
||||||
|
field.SetString(strVal)
|
||||||
|
}
|
||||||
|
case reflect.Bool:
|
||||||
|
if boolVal, ok := value.(bool); ok {
|
||||||
|
field.SetBool(boolVal)
|
||||||
|
}
|
||||||
|
case reflect.Slice:
|
||||||
|
// Handle string slices
|
||||||
|
if field.Type().Elem().Kind() == reflect.String {
|
||||||
|
if slice, ok := value.([]any); ok {
|
||||||
|
strSlice := make([]string, 0, len(slice))
|
||||||
|
for _, v := range slice {
|
||||||
|
if s, ok := v.(string); ok {
|
||||||
|
strSlice = append(strSlice, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
field.Set(reflect.ValueOf(strSlice))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getKnownFieldNames extracts all known field names from struct json tags
|
||||||
|
// Used by UnmarshalJSON implementations to identify unknown fields for ExtraArgs
|
||||||
|
func getKnownFieldNames(v any) map[string]bool {
|
||||||
|
fields := make(map[string]bool)
|
||||||
|
t := reflect.TypeOf(v).Elem()
|
||||||
|
|
||||||
|
for i := 0; i < t.NumField(); i++ {
|
||||||
|
field := t.Field(i)
|
||||||
|
jsonTag := field.Tag.Get("json")
|
||||||
|
if jsonTag != "" && jsonTag != "-" {
|
||||||
|
// Handle "name,omitempty" format
|
||||||
|
name := strings.Split(jsonTag, ",")[0]
|
||||||
|
fields[name] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fields
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package backends
|
package backends
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"llamactl/pkg/validation"
|
"llamactl/pkg/validation"
|
||||||
)
|
)
|
||||||
@@ -142,6 +143,46 @@ type VllmServerOptions struct {
|
|||||||
OverridePoolingConfig string `json:"override_pooling_config,omitempty"`
|
OverridePoolingConfig string `json:"override_pooling_config,omitempty"`
|
||||||
OverrideNeuronConfig string `json:"override_neuron_config,omitempty"`
|
OverrideNeuronConfig string `json:"override_neuron_config,omitempty"`
|
||||||
OverrideKVCacheALIGNSize int `json:"override_kv_cache_align_size,omitempty"`
|
OverrideKVCacheALIGNSize int `json:"override_kv_cache_align_size,omitempty"`
|
||||||
|
|
||||||
|
// ExtraArgs are additional command line arguments.
|
||||||
|
// Example: {"verbose": "", "log-file": "/logs/vllm.log"}
|
||||||
|
ExtraArgs map[string]string `json:"extra_args,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements custom JSON unmarshaling to collect unknown fields into ExtraArgs
|
||||||
|
func (o *VllmServerOptions) UnmarshalJSON(data []byte) error {
|
||||||
|
// First unmarshal into a map to capture all fields
|
||||||
|
var raw map[string]any
|
||||||
|
if err := json.Unmarshal(data, &raw); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a temporary struct for standard unmarshaling
|
||||||
|
type tempOptions VllmServerOptions
|
||||||
|
temp := tempOptions{}
|
||||||
|
|
||||||
|
// Standard unmarshal first
|
||||||
|
if err := json.Unmarshal(data, &temp); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy to our struct
|
||||||
|
*o = VllmServerOptions(temp)
|
||||||
|
|
||||||
|
// Get all known canonical field names from struct tags
|
||||||
|
knownFields := getKnownFieldNames(o)
|
||||||
|
|
||||||
|
// Collect unknown fields into ExtraArgs
|
||||||
|
if o.ExtraArgs == nil {
|
||||||
|
o.ExtraArgs = make(map[string]string)
|
||||||
|
}
|
||||||
|
for key, value := range raw {
|
||||||
|
if !knownFields[key] {
|
||||||
|
o.ExtraArgs[key] = fmt.Sprintf("%v", value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *VllmServerOptions) GetPort() int {
|
func (o *VllmServerOptions) GetPort() int {
|
||||||
@@ -171,6 +212,18 @@ func (o *VllmServerOptions) Validate() error {
|
|||||||
return validation.ValidationError(fmt.Errorf("invalid port range: %d", o.Port))
|
return validation.ValidationError(fmt.Errorf("invalid port range: %d", o.Port))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate extra_args keys and values
|
||||||
|
for key, value := range o.ExtraArgs {
|
||||||
|
if err := validation.ValidateStringForInjection(key); err != nil {
|
||||||
|
return validation.ValidationError(fmt.Errorf("extra_args key %q: %w", key, err))
|
||||||
|
}
|
||||||
|
if value != "" {
|
||||||
|
if err := validation.ValidateStringForInjection(value); err != nil {
|
||||||
|
return validation.ValidationError(fmt.Errorf("extra_args value for %q: %w", key, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -193,6 +246,9 @@ func (o *VllmServerOptions) BuildCommandArgs() []string {
|
|||||||
flagArgs := BuildCommandArgs(&optionsCopy, vllmMultiValuedFlags)
|
flagArgs := BuildCommandArgs(&optionsCopy, vllmMultiValuedFlags)
|
||||||
args = append(args, flagArgs...)
|
args = append(args, flagArgs...)
|
||||||
|
|
||||||
|
// Append extra args at the end
|
||||||
|
args = append(args, convertExtraArgsToFlags(o.ExtraArgs)...)
|
||||||
|
|
||||||
return args
|
return args
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -203,6 +259,9 @@ func (o *VllmServerOptions) BuildDockerArgs() []string {
|
|||||||
flagArgs := BuildCommandArgs(o, vllmMultiValuedFlags)
|
flagArgs := BuildCommandArgs(o, vllmMultiValuedFlags)
|
||||||
args = append(args, flagArgs...)
|
args = append(args, flagArgs...)
|
||||||
|
|
||||||
|
// Append extra args at the end
|
||||||
|
args = append(args, convertExtraArgsToFlags(o.ExtraArgs)...)
|
||||||
|
|
||||||
return args
|
return args
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -321,3 +321,94 @@ func TestVllmBuildCommandArgs_PositionalModel(t *testing.T) {
|
|||||||
t.Errorf("Expected --port 8080 not found in %v", args)
|
t.Errorf("Expected --port 8080 not found in %v", args)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseVllmCommand_ExtraArgs(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
command string
|
||||||
|
expectErr bool
|
||||||
|
validate func(*testing.T, *backends.VllmServerOptions)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "extra args with known fields",
|
||||||
|
command: "vllm serve llama-model --tensor-parallel-size 2 --unknown-flag value --new-bool-flag",
|
||||||
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
|
||||||
|
if opts.Model != "llama-model" {
|
||||||
|
t.Errorf("expected model 'llama-model', got '%s'", opts.Model)
|
||||||
|
}
|
||||||
|
if opts.TensorParallelSize != 2 {
|
||||||
|
t.Errorf("expected tensor_parallel_size 2, got %d", opts.TensorParallelSize)
|
||||||
|
}
|
||||||
|
if opts.ExtraArgs == nil {
|
||||||
|
t.Fatal("expected extra_args to be non-nil")
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["unknown_flag"]; !ok || val != "value" {
|
||||||
|
t.Errorf("expected extra_args[unknown_flag]='value', got '%s'", val)
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["new_bool_flag"]; !ok || val != "true" {
|
||||||
|
t.Errorf("expected extra_args[new_bool_flag]='true', got '%s'", val)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only extra args",
|
||||||
|
command: "vllm serve model --experimental-feature --custom-param test",
|
||||||
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
|
||||||
|
if opts.ExtraArgs == nil {
|
||||||
|
t.Fatal("expected extra_args to be non-nil")
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["experimental_feature"]; !ok || val != "true" {
|
||||||
|
t.Errorf("expected extra_args[experimental_feature]='true', got '%s'", val)
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["custom_param"]; !ok || val != "test" {
|
||||||
|
t.Errorf("expected extra_args[custom_param]='test', got '%s'", val)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extra args without model positional",
|
||||||
|
command: "vllm serve --model my-model --new-feature enabled --beta-flag",
|
||||||
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
|
||||||
|
if opts.Model != "my-model" {
|
||||||
|
t.Errorf("expected model 'my-model', got '%s'", opts.Model)
|
||||||
|
}
|
||||||
|
if opts.ExtraArgs == nil {
|
||||||
|
t.Fatal("expected extra_args to be non-nil")
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["new_feature"]; !ok || val != "enabled" {
|
||||||
|
t.Errorf("expected extra_args[new_feature]='enabled', got '%s'", val)
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["beta_flag"]; !ok || val != "true" {
|
||||||
|
t.Errorf("expected extra_args[beta_flag]='true', got '%s'", val)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var opts backends.VllmServerOptions
|
||||||
|
result, err := opts.ParseCommand(tt.command)
|
||||||
|
|
||||||
|
if tt.expectErr && err == nil {
|
||||||
|
t.Error("expected error but got none")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !tt.expectErr && err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tt.expectErr && tt.validate != nil {
|
||||||
|
vllmOpts, ok := result.(*backends.VllmServerOptions)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("result is not *VllmServerOptions")
|
||||||
|
}
|
||||||
|
tt.validate(t, vllmOpts)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -239,25 +239,3 @@ func TestValidateInstanceOptions_MultipleFieldInjection(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateInstanceOptions_NonStringFields(t *testing.T) {
|
|
||||||
// Test that non-string fields don't interfere with validation
|
|
||||||
options := backends.Options{
|
|
||||||
BackendType: backends.BackendTypeLlamaCpp,
|
|
||||||
LlamaServerOptions: &backends.LlamaServerOptions{
|
|
||||||
Port: 8080,
|
|
||||||
GPULayers: 32,
|
|
||||||
CtxSize: 4096,
|
|
||||||
Temperature: 0.7,
|
|
||||||
TopK: 40,
|
|
||||||
TopP: 0.9,
|
|
||||||
Verbose: true,
|
|
||||||
FlashAttn: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
err := options.ValidateInstanceOptions()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("ValidateInstanceOptions with non-string fields should not error, got: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -3,17 +3,31 @@ import { Input } from '@/components/ui/input'
|
|||||||
import { Label } from '@/components/ui/label'
|
import { Label } from '@/components/ui/label'
|
||||||
import { Checkbox } from '@/components/ui/checkbox'
|
import { Checkbox } from '@/components/ui/checkbox'
|
||||||
import { getBackendFieldType, basicBackendFieldsConfig } from '@/lib/zodFormUtils'
|
import { getBackendFieldType, basicBackendFieldsConfig } from '@/lib/zodFormUtils'
|
||||||
|
import ExtraArgsInput from '@/components/form/ExtraArgsInput'
|
||||||
|
|
||||||
interface BackendFormFieldProps {
|
interface BackendFormFieldProps {
|
||||||
fieldKey: string
|
fieldKey: string
|
||||||
value: string | number | boolean | string[] | undefined
|
value: string | number | boolean | string[] | Record<string, string> | undefined
|
||||||
onChange: (key: string, value: string | number | boolean | string[] | undefined) => void
|
onChange: (key: string, value: string | number | boolean | string[] | Record<string, string> | undefined) => void
|
||||||
}
|
}
|
||||||
|
|
||||||
const BackendFormField: React.FC<BackendFormFieldProps> = ({ fieldKey, value, onChange }) => {
|
const BackendFormField: React.FC<BackendFormFieldProps> = ({ fieldKey, value, onChange }) => {
|
||||||
|
// Special handling for extra_args
|
||||||
|
if (fieldKey === 'extra_args') {
|
||||||
|
return (
|
||||||
|
<ExtraArgsInput
|
||||||
|
id={fieldKey}
|
||||||
|
label="Extra Arguments"
|
||||||
|
value={value as Record<string, string> | undefined}
|
||||||
|
onChange={(newValue) => onChange(fieldKey, newValue)}
|
||||||
|
description="Additional command line arguments to pass to the backend"
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// Get configuration for basic fields, or use field name for advanced fields
|
// Get configuration for basic fields, or use field name for advanced fields
|
||||||
const config = basicBackendFieldsConfig[fieldKey] || { label: fieldKey }
|
const config = basicBackendFieldsConfig[fieldKey] || { label: fieldKey }
|
||||||
|
|
||||||
// Get type from Zod schema
|
// Get type from Zod schema
|
||||||
const fieldType = getBackendFieldType(fieldKey)
|
const fieldType = getBackendFieldType(fieldKey)
|
||||||
|
|
||||||
|
|||||||
27
webui/src/components/form/EnvVarsInput.tsx
Normal file
27
webui/src/components/form/EnvVarsInput.tsx
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
import React from 'react'
|
||||||
|
import KeyValueInput from './KeyValueInput'
|
||||||
|
|
||||||
|
interface EnvVarsInputProps {
|
||||||
|
id: string
|
||||||
|
label: string
|
||||||
|
value: Record<string, string> | undefined
|
||||||
|
onChange: (value: Record<string, string> | undefined) => void
|
||||||
|
description?: string
|
||||||
|
disabled?: boolean
|
||||||
|
className?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
const EnvVarsInput: React.FC<EnvVarsInputProps> = (props) => {
|
||||||
|
return (
|
||||||
|
<KeyValueInput
|
||||||
|
{...props}
|
||||||
|
keyPlaceholder="Variable name"
|
||||||
|
valuePlaceholder="Variable value"
|
||||||
|
addButtonText="Add Variable"
|
||||||
|
helperText="Environment variables that will be passed to the backend process"
|
||||||
|
allowEmptyValues={false}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default EnvVarsInput
|
||||||
@@ -1,144 +0,0 @@
|
|||||||
import React, { useState } from 'react'
|
|
||||||
import { Input } from '@/components/ui/input'
|
|
||||||
import { Label } from '@/components/ui/label'
|
|
||||||
import { Button } from '@/components/ui/button'
|
|
||||||
import { X, Plus } from 'lucide-react'
|
|
||||||
|
|
||||||
interface EnvironmentVariablesInputProps {
|
|
||||||
id: string
|
|
||||||
label: string
|
|
||||||
value: Record<string, string> | undefined
|
|
||||||
onChange: (value: Record<string, string> | undefined) => void
|
|
||||||
description?: string
|
|
||||||
disabled?: boolean
|
|
||||||
className?: string
|
|
||||||
}
|
|
||||||
|
|
||||||
interface EnvVar {
|
|
||||||
key: string
|
|
||||||
value: string
|
|
||||||
}
|
|
||||||
|
|
||||||
const EnvironmentVariablesInput: React.FC<EnvironmentVariablesInputProps> = ({
|
|
||||||
id,
|
|
||||||
label,
|
|
||||||
value,
|
|
||||||
onChange,
|
|
||||||
description,
|
|
||||||
disabled = false,
|
|
||||||
className
|
|
||||||
}) => {
|
|
||||||
// Convert the value object to an array of key-value pairs for editing
|
|
||||||
const envVarsFromValue = value
|
|
||||||
? Object.entries(value).map(([key, val]) => ({ key, value: val }))
|
|
||||||
: []
|
|
||||||
|
|
||||||
const [envVars, setEnvVars] = useState<EnvVar[]>(
|
|
||||||
envVarsFromValue.length > 0 ? envVarsFromValue : [{ key: '', value: '' }]
|
|
||||||
)
|
|
||||||
|
|
||||||
// Update parent component when env vars change
|
|
||||||
const updateParent = (newEnvVars: EnvVar[]) => {
|
|
||||||
// Filter out empty entries
|
|
||||||
const validVars = newEnvVars.filter(env => env.key.trim() !== '' && env.value.trim() !== '')
|
|
||||||
|
|
||||||
if (validVars.length === 0) {
|
|
||||||
onChange(undefined)
|
|
||||||
} else {
|
|
||||||
const envObject = validVars.reduce((acc, env) => {
|
|
||||||
acc[env.key.trim()] = env.value.trim()
|
|
||||||
return acc
|
|
||||||
}, {} as Record<string, string>)
|
|
||||||
onChange(envObject)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const handleKeyChange = (index: number, newKey: string) => {
|
|
||||||
const newEnvVars = [...envVars]
|
|
||||||
newEnvVars[index].key = newKey
|
|
||||||
setEnvVars(newEnvVars)
|
|
||||||
updateParent(newEnvVars)
|
|
||||||
}
|
|
||||||
|
|
||||||
const handleValueChange = (index: number, newValue: string) => {
|
|
||||||
const newEnvVars = [...envVars]
|
|
||||||
newEnvVars[index].value = newValue
|
|
||||||
setEnvVars(newEnvVars)
|
|
||||||
updateParent(newEnvVars)
|
|
||||||
}
|
|
||||||
|
|
||||||
const addEnvVar = () => {
|
|
||||||
const newEnvVars = [...envVars, { key: '', value: '' }]
|
|
||||||
setEnvVars(newEnvVars)
|
|
||||||
}
|
|
||||||
|
|
||||||
const removeEnvVar = (index: number) => {
|
|
||||||
if (envVars.length === 1) {
|
|
||||||
// Reset to empty if it's the last one
|
|
||||||
const newEnvVars = [{ key: '', value: '' }]
|
|
||||||
setEnvVars(newEnvVars)
|
|
||||||
updateParent(newEnvVars)
|
|
||||||
} else {
|
|
||||||
const newEnvVars = envVars.filter((_, i) => i !== index)
|
|
||||||
setEnvVars(newEnvVars)
|
|
||||||
updateParent(newEnvVars)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className={`grid gap-2 ${className || ''}`}>
|
|
||||||
<Label htmlFor={id}>
|
|
||||||
{label}
|
|
||||||
</Label>
|
|
||||||
<div className="space-y-2">
|
|
||||||
{envVars.map((envVar, index) => (
|
|
||||||
<div key={index} className="flex gap-2 items-center">
|
|
||||||
<Input
|
|
||||||
placeholder="Variable name"
|
|
||||||
value={envVar.key}
|
|
||||||
onChange={(e) => handleKeyChange(index, e.target.value)}
|
|
||||||
disabled={disabled}
|
|
||||||
className="flex-1"
|
|
||||||
/>
|
|
||||||
<Input
|
|
||||||
placeholder="Variable value"
|
|
||||||
value={envVar.value}
|
|
||||||
onChange={(e) => handleValueChange(index, e.target.value)}
|
|
||||||
disabled={disabled}
|
|
||||||
className="flex-1"
|
|
||||||
/>
|
|
||||||
<Button
|
|
||||||
type="button"
|
|
||||||
variant="outline"
|
|
||||||
size="sm"
|
|
||||||
onClick={() => removeEnvVar(index)}
|
|
||||||
disabled={disabled}
|
|
||||||
className="shrink-0"
|
|
||||||
>
|
|
||||||
<X className="h-4 w-4" />
|
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
))}
|
|
||||||
<Button
|
|
||||||
type="button"
|
|
||||||
variant="outline"
|
|
||||||
size="sm"
|
|
||||||
onClick={addEnvVar}
|
|
||||||
disabled={disabled}
|
|
||||||
className="w-fit"
|
|
||||||
>
|
|
||||||
<Plus className="h-4 w-4 mr-2" />
|
|
||||||
Add Variable
|
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
{description && (
|
|
||||||
<p className="text-sm text-muted-foreground">{description}</p>
|
|
||||||
)}
|
|
||||||
<p className="text-xs text-muted-foreground">
|
|
||||||
Environment variables that will be passed to the backend process
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
export default EnvironmentVariablesInput
|
|
||||||
27
webui/src/components/form/ExtraArgsInput.tsx
Normal file
27
webui/src/components/form/ExtraArgsInput.tsx
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
import React from 'react'
|
||||||
|
import KeyValueInput from './KeyValueInput'
|
||||||
|
|
||||||
|
interface ExtraArgsInputProps {
|
||||||
|
id: string
|
||||||
|
label: string
|
||||||
|
value: Record<string, string> | undefined
|
||||||
|
onChange: (value: Record<string, string> | undefined) => void
|
||||||
|
description?: string
|
||||||
|
disabled?: boolean
|
||||||
|
className?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
const ExtraArgsInput: React.FC<ExtraArgsInputProps> = (props) => {
|
||||||
|
return (
|
||||||
|
<KeyValueInput
|
||||||
|
{...props}
|
||||||
|
keyPlaceholder="Flag name (without --)"
|
||||||
|
valuePlaceholder="Value (empty for boolean flags)"
|
||||||
|
addButtonText="Add Argument"
|
||||||
|
helperText="Additional command line arguments to pass to the backend. Leave value empty for boolean flags."
|
||||||
|
allowEmptyValues={true}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default ExtraArgsInput
|
||||||
171
webui/src/components/form/KeyValueInput.tsx
Normal file
171
webui/src/components/form/KeyValueInput.tsx
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
import React, { useState, useEffect } from 'react'
|
||||||
|
import { Input } from '@/components/ui/input'
|
||||||
|
import { Label } from '@/components/ui/label'
|
||||||
|
import { Button } from '@/components/ui/button'
|
||||||
|
import { X, Plus } from 'lucide-react'
|
||||||
|
|
||||||
|
interface KeyValueInputProps {
|
||||||
|
id: string
|
||||||
|
label: string
|
||||||
|
value: Record<string, string> | undefined
|
||||||
|
onChange: (value: Record<string, string> | undefined) => void
|
||||||
|
description?: string
|
||||||
|
disabled?: boolean
|
||||||
|
className?: string
|
||||||
|
keyPlaceholder?: string
|
||||||
|
valuePlaceholder?: string
|
||||||
|
addButtonText?: string
|
||||||
|
helperText?: string
|
||||||
|
allowEmptyValues?: boolean // If true, entries with empty values are considered valid
|
||||||
|
}
|
||||||
|
|
||||||
|
interface KeyValuePair {
|
||||||
|
key: string
|
||||||
|
value: string
|
||||||
|
}
|
||||||
|
|
||||||
|
const KeyValueInput: React.FC<KeyValueInputProps> = ({
|
||||||
|
id,
|
||||||
|
label,
|
||||||
|
value,
|
||||||
|
onChange,
|
||||||
|
description,
|
||||||
|
disabled = false,
|
||||||
|
className,
|
||||||
|
keyPlaceholder = 'Key',
|
||||||
|
valuePlaceholder = 'Value',
|
||||||
|
addButtonText = 'Add Entry',
|
||||||
|
helperText,
|
||||||
|
allowEmptyValues = false
|
||||||
|
}) => {
|
||||||
|
// Convert the value object to an array of key-value pairs for editing
|
||||||
|
const pairsFromValue = value
|
||||||
|
? Object.entries(value).map(([key, val]) => ({ key, value: val }))
|
||||||
|
: []
|
||||||
|
|
||||||
|
const [pairs, setPairs] = useState<KeyValuePair[]>(
|
||||||
|
pairsFromValue.length > 0 ? pairsFromValue : [{ key: '', value: '' }]
|
||||||
|
)
|
||||||
|
|
||||||
|
// Sync internal state when value prop changes
|
||||||
|
useEffect(() => {
|
||||||
|
const newPairsFromValue = value
|
||||||
|
? Object.entries(value).map(([key, val]) => ({ key, value: val }))
|
||||||
|
: []
|
||||||
|
|
||||||
|
if (newPairsFromValue.length > 0) {
|
||||||
|
setPairs(newPairsFromValue)
|
||||||
|
} else if (!value) {
|
||||||
|
// Reset to single empty row if value is explicitly undefined/null
|
||||||
|
setPairs([{ key: '', value: '' }])
|
||||||
|
}
|
||||||
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
|
}, [value])
|
||||||
|
|
||||||
|
// Update parent component when pairs change
|
||||||
|
const updateParent = (newPairs: KeyValuePair[]) => {
|
||||||
|
// Filter based on validation rules
|
||||||
|
const validPairs = allowEmptyValues
|
||||||
|
? newPairs.filter(pair => pair.key.trim() !== '')
|
||||||
|
: newPairs.filter(pair => pair.key.trim() !== '' && pair.value.trim() !== '')
|
||||||
|
|
||||||
|
if (validPairs.length === 0) {
|
||||||
|
onChange(undefined)
|
||||||
|
} else {
|
||||||
|
const pairsObject = validPairs.reduce((acc, pair) => {
|
||||||
|
acc[pair.key.trim()] = pair.value.trim()
|
||||||
|
return acc
|
||||||
|
}, {} as Record<string, string>)
|
||||||
|
onChange(pairsObject)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleKeyChange = (index: number, newKey: string) => {
|
||||||
|
const newPairs = [...pairs]
|
||||||
|
newPairs[index].key = newKey
|
||||||
|
setPairs(newPairs)
|
||||||
|
updateParent(newPairs)
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleValueChange = (index: number, newValue: string) => {
|
||||||
|
const newPairs = [...pairs]
|
||||||
|
newPairs[index].value = newValue
|
||||||
|
setPairs(newPairs)
|
||||||
|
updateParent(newPairs)
|
||||||
|
}
|
||||||
|
|
||||||
|
const addPair = () => {
|
||||||
|
const newPairs = [...pairs, { key: '', value: '' }]
|
||||||
|
setPairs(newPairs)
|
||||||
|
}
|
||||||
|
|
||||||
|
const removePair = (index: number) => {
|
||||||
|
if (pairs.length === 1) {
|
||||||
|
// Reset to empty if it's the last one
|
||||||
|
const newPairs = [{ key: '', value: '' }]
|
||||||
|
setPairs(newPairs)
|
||||||
|
updateParent(newPairs)
|
||||||
|
} else {
|
||||||
|
const newPairs = pairs.filter((_, i) => i !== index)
|
||||||
|
setPairs(newPairs)
|
||||||
|
updateParent(newPairs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className={`grid gap-2 ${className || ''}`}>
|
||||||
|
<Label htmlFor={id}>
|
||||||
|
{label}
|
||||||
|
</Label>
|
||||||
|
<div className="space-y-2">
|
||||||
|
{pairs.map((pair, index) => (
|
||||||
|
<div key={index} className="flex gap-2 items-center">
|
||||||
|
<Input
|
||||||
|
placeholder={keyPlaceholder}
|
||||||
|
value={pair.key}
|
||||||
|
onChange={(e) => handleKeyChange(index, e.target.value)}
|
||||||
|
disabled={disabled}
|
||||||
|
className="flex-1"
|
||||||
|
/>
|
||||||
|
<Input
|
||||||
|
placeholder={valuePlaceholder}
|
||||||
|
value={pair.value}
|
||||||
|
onChange={(e) => handleValueChange(index, e.target.value)}
|
||||||
|
disabled={disabled}
|
||||||
|
className="flex-1"
|
||||||
|
/>
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
variant="outline"
|
||||||
|
size="sm"
|
||||||
|
onClick={() => removePair(index)}
|
||||||
|
disabled={disabled}
|
||||||
|
className="shrink-0"
|
||||||
|
>
|
||||||
|
<X className="h-4 w-4" />
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
variant="outline"
|
||||||
|
size="sm"
|
||||||
|
onClick={addPair}
|
||||||
|
disabled={disabled}
|
||||||
|
className="w-fit"
|
||||||
|
>
|
||||||
|
<Plus className="h-4 w-4 mr-2" />
|
||||||
|
{addButtonText}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
{description && (
|
||||||
|
<p className="text-sm text-muted-foreground">{description}</p>
|
||||||
|
)}
|
||||||
|
{helperText && (
|
||||||
|
<p className="text-xs text-muted-foreground">{helperText}</p>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default KeyValueInput
|
||||||
@@ -47,8 +47,18 @@ const BackendConfiguration: React.FC<BackendConfigurationProps> = ({
|
|||||||
))}
|
))}
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
{/* Extra Args - Always visible as a separate section */}
|
||||||
|
<div className="space-y-4">
|
||||||
|
<BackendFormField
|
||||||
|
key="extra_args"
|
||||||
|
fieldKey="extra_args"
|
||||||
|
value={(formData.backend_options as any)?.extra_args}
|
||||||
|
onChange={onBackendFieldChange}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
export default BackendConfiguration
|
export default BackendConfiguration
|
||||||
|
|||||||
@@ -109,6 +109,16 @@ const BackendConfigurationCard: React.FC<BackendConfigurationCardProps> = ({
|
|||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
{/* Extra Arguments - Always visible */}
|
||||||
|
<div className="space-y-4">
|
||||||
|
<BackendFormField
|
||||||
|
key="extra_args"
|
||||||
|
fieldKey="extra_args"
|
||||||
|
value={(formData.backend_options as Record<string, unknown>)?.extra_args as Record<string, string> | undefined}
|
||||||
|
onChange={onBackendFieldChange}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
</CardContent>
|
</CardContent>
|
||||||
</Card>
|
</Card>
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import { Input } from '@/components/ui/input'
|
|||||||
import AutoRestartConfiguration from '@/components/instance/AutoRestartConfiguration'
|
import AutoRestartConfiguration from '@/components/instance/AutoRestartConfiguration'
|
||||||
import NumberInput from '@/components/form/NumberInput'
|
import NumberInput from '@/components/form/NumberInput'
|
||||||
import CheckboxInput from '@/components/form/CheckboxInput'
|
import CheckboxInput from '@/components/form/CheckboxInput'
|
||||||
import EnvironmentVariablesInput from '@/components/form/EnvironmentVariablesInput'
|
import EnvVarsInput from '@/components/form/EnvVarsInput'
|
||||||
import SelectInput from '@/components/form/SelectInput'
|
import SelectInput from '@/components/form/SelectInput'
|
||||||
import { nodesApi, type NodesMap } from '@/lib/api'
|
import { nodesApi, type NodesMap } from '@/lib/api'
|
||||||
|
|
||||||
@@ -132,7 +132,7 @@ const InstanceSettingsCard: React.FC<InstanceSettingsCardProps> = ({
|
|||||||
description="Start instance only when needed"
|
description="Start instance only when needed"
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<EnvironmentVariablesInput
|
<EnvVarsInput
|
||||||
id="environment"
|
id="environment"
|
||||||
label="Environment Variables"
|
label="Environment Variables"
|
||||||
value={formData.environment}
|
value={formData.environment}
|
||||||
|
|||||||
@@ -126,7 +126,7 @@ export function getAdvancedBackendFields(backendType?: string): string[] {
|
|||||||
const fieldGetter = backendFieldGetters[normalizedType] || getAllLlamaCppFieldKeys
|
const fieldGetter = backendFieldGetters[normalizedType] || getAllLlamaCppFieldKeys
|
||||||
const basicConfig = backendFieldConfigs[normalizedType] || basicLlamaCppFieldsConfig
|
const basicConfig = backendFieldConfigs[normalizedType] || basicLlamaCppFieldsConfig
|
||||||
|
|
||||||
return fieldGetter().filter(key => !(key in basicConfig))
|
return fieldGetter().filter(key => !(key in basicConfig) && key !== 'extra_args')
|
||||||
}
|
}
|
||||||
|
|
||||||
// Combined backend fields config for use in BackendFormField
|
// Combined backend fields config for use in BackendFormField
|
||||||
|
|||||||
@@ -167,6 +167,9 @@ export const LlamaCppBackendOptionsSchema = z.object({
|
|||||||
fim_qwen_7b_default: z.boolean().optional(),
|
fim_qwen_7b_default: z.boolean().optional(),
|
||||||
fim_qwen_7b_spec: z.boolean().optional(),
|
fim_qwen_7b_spec: z.boolean().optional(),
|
||||||
fim_qwen_14b_spec: z.boolean().optional(),
|
fim_qwen_14b_spec: z.boolean().optional(),
|
||||||
|
|
||||||
|
// Extra args
|
||||||
|
extra_args: z.record(z.string(), z.string()).optional(),
|
||||||
})
|
})
|
||||||
|
|
||||||
// Infer the TypeScript type from the schema
|
// Infer the TypeScript type from the schema
|
||||||
|
|||||||
@@ -25,6 +25,9 @@ export const MlxBackendOptionsSchema = z.object({
|
|||||||
top_k: z.number().optional(),
|
top_k: z.number().optional(),
|
||||||
min_p: z.number().optional(),
|
min_p: z.number().optional(),
|
||||||
max_tokens: z.number().optional(),
|
max_tokens: z.number().optional(),
|
||||||
|
|
||||||
|
// Extra args
|
||||||
|
extra_args: z.record(z.string(), z.string()).optional(),
|
||||||
})
|
})
|
||||||
|
|
||||||
// Infer the TypeScript type from the schema
|
// Infer the TypeScript type from the schema
|
||||||
|
|||||||
@@ -125,6 +125,9 @@ export const VllmBackendOptionsSchema = z.object({
|
|||||||
override_pooling_config: z.string().optional(),
|
override_pooling_config: z.string().optional(),
|
||||||
override_neuron_config: z.string().optional(),
|
override_neuron_config: z.string().optional(),
|
||||||
override_kv_cache_align_size: z.number().optional(),
|
override_kv_cache_align_size: z.number().optional(),
|
||||||
|
|
||||||
|
// Extra args
|
||||||
|
extra_args: z.record(z.string(), z.string()).optional(),
|
||||||
})
|
})
|
||||||
|
|
||||||
// Infer the TypeScript type from the schema
|
// Infer the TypeScript type from the schema
|
||||||
|
|||||||
Reference in New Issue
Block a user