mirror of
https://github.com/lordmathis/llamactl.git
synced 2025-12-22 17:14:22 +00:00
Add support for extra args for command parser
This commit is contained in:
@@ -5,7 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"llamactl/pkg/validation"
|
"llamactl/pkg/validation"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// llamaMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated
|
// llamaMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated
|
||||||
@@ -213,6 +212,15 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error {
|
|||||||
// Copy to our struct
|
// Copy to our struct
|
||||||
*o = LlamaServerOptions(temp)
|
*o = LlamaServerOptions(temp)
|
||||||
|
|
||||||
|
// Track which fields we've processed
|
||||||
|
processedFields := make(map[string]bool)
|
||||||
|
|
||||||
|
// Get all known canonical field names from struct tags
|
||||||
|
knownFields := getKnownFieldNames(o)
|
||||||
|
for field := range knownFields {
|
||||||
|
processedFields[field] = true
|
||||||
|
}
|
||||||
|
|
||||||
// Handle alternative field names
|
// Handle alternative field names
|
||||||
fieldMappings := map[string]string{
|
fieldMappings := map[string]string{
|
||||||
// Common params
|
// Common params
|
||||||
@@ -273,8 +281,8 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error {
|
|||||||
"rerank": "reranking", // --reranking
|
"rerank": "reranking", // --reranking
|
||||||
"to": "timeout", // -to, --timeout N
|
"to": "timeout", // -to, --timeout N
|
||||||
"sps": "slot_prompt_similarity", // -sps, --slot-prompt-similarity
|
"sps": "slot_prompt_similarity", // -sps, --slot-prompt-similarity
|
||||||
"draft": "draft-max", // -draft, --draft-max N
|
"draft": "draft_max", // -draft, --draft-max N
|
||||||
"draft_n": "draft-max", // --draft-n-max N
|
"draft_n": "draft_max", // --draft-n-max N
|
||||||
"draft_n_min": "draft_min", // --draft-n-min N
|
"draft_n_min": "draft_min", // --draft-n-min N
|
||||||
"cd": "ctx_size_draft", // -cd, --ctx-size-draft N
|
"cd": "ctx_size_draft", // -cd, --ctx-size-draft N
|
||||||
"devd": "device_draft", // -devd, --device-draft
|
"devd": "device_draft", // -devd, --device-draft
|
||||||
@@ -286,8 +294,10 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error {
|
|||||||
"mv": "model_vocoder", // -mv, --model-vocoder FNAME
|
"mv": "model_vocoder", // -mv, --model-vocoder FNAME
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process alternative field names
|
// Process alternative field names and mark them as processed
|
||||||
for altName, canonicalName := range fieldMappings {
|
for altName, canonicalName := range fieldMappings {
|
||||||
|
processedFields[altName] = true // Mark alternatives as known
|
||||||
|
|
||||||
if value, exists := raw[altName]; exists {
|
if value, exists := raw[altName]; exists {
|
||||||
// Use reflection to set the field value
|
// Use reflection to set the field value
|
||||||
v := reflect.ValueOf(o).Elem()
|
v := reflect.ValueOf(o).Elem()
|
||||||
@@ -298,36 +308,21 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if field.IsValid() && field.CanSet() {
|
if field.IsValid() && field.CanSet() {
|
||||||
switch field.Kind() {
|
setFieldValue(field, value)
|
||||||
case reflect.Int:
|
|
||||||
if intVal, ok := value.(float64); ok {
|
|
||||||
field.SetInt(int64(intVal))
|
|
||||||
} else if strVal, ok := value.(string); ok {
|
|
||||||
if intVal, err := strconv.Atoi(strVal); err == nil {
|
|
||||||
field.SetInt(int64(intVal))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case reflect.Float64:
|
|
||||||
if floatVal, ok := value.(float64); ok {
|
|
||||||
field.SetFloat(floatVal)
|
|
||||||
} else if strVal, ok := value.(string); ok {
|
|
||||||
if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil {
|
|
||||||
field.SetFloat(floatVal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case reflect.String:
|
|
||||||
if strVal, ok := value.(string); ok {
|
|
||||||
field.SetString(strVal)
|
|
||||||
}
|
|
||||||
case reflect.Bool:
|
|
||||||
if boolVal, ok := value.(bool); ok {
|
|
||||||
field.SetBool(boolVal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Collect unknown fields into ExtraArgs
|
||||||
|
if o.ExtraArgs == nil {
|
||||||
|
o.ExtraArgs = make(map[string]string)
|
||||||
|
}
|
||||||
|
for key, value := range raw {
|
||||||
|
if !processedFields[key] {
|
||||||
|
o.ExtraArgs[key] = fmt.Sprintf("%v", value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -453,3 +453,99 @@ func TestLlamaCppBuildCommandArgs_ExtraArgs(t *testing.T) {
|
|||||||
t.Error("Expected --log-file flag or value not found")
|
t.Error("Expected --log-file flag or value not found")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseLlamaCommand_ExtraArgs(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
command string
|
||||||
|
expectErr bool
|
||||||
|
validate func(*testing.T, *backends.LlamaServerOptions)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "extra args with known fields",
|
||||||
|
command: "llama-server --model /path/to/model.gguf --gpu-layers 32 --unknown-flag value --another-bool-flag",
|
||||||
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
|
||||||
|
if opts.Model != "/path/to/model.gguf" {
|
||||||
|
t.Errorf("expected model '/path/to/model.gguf', got '%s'", opts.Model)
|
||||||
|
}
|
||||||
|
if opts.GPULayers != 32 {
|
||||||
|
t.Errorf("expected gpu_layers 32, got %d", opts.GPULayers)
|
||||||
|
}
|
||||||
|
if opts.ExtraArgs == nil {
|
||||||
|
t.Fatal("expected extra_args to be non-nil")
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["unknown_flag"]; !ok || val != "value" {
|
||||||
|
t.Errorf("expected extra_args[unknown_flag]='value', got '%s'", val)
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["another_bool_flag"]; !ok || val != "true" {
|
||||||
|
t.Errorf("expected extra_args[another_bool_flag]='true', got '%s'", val)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extra args with alternative field names",
|
||||||
|
command: "llama-server -m /model.gguf -ngl 16 --custom-arg test --new-feature",
|
||||||
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
|
||||||
|
// Check that alternative names worked for known fields
|
||||||
|
if opts.Model != "/model.gguf" {
|
||||||
|
t.Errorf("expected model '/model.gguf', got '%s'", opts.Model)
|
||||||
|
}
|
||||||
|
if opts.GPULayers != 16 {
|
||||||
|
t.Errorf("expected gpu_layers 16, got %d", opts.GPULayers)
|
||||||
|
}
|
||||||
|
// Check that unknown args went to ExtraArgs
|
||||||
|
if opts.ExtraArgs == nil {
|
||||||
|
t.Fatal("expected extra_args to be non-nil")
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["custom_arg"]; !ok || val != "test" {
|
||||||
|
t.Errorf("expected extra_args[custom_arg]='test', got '%s'", val)
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["new_feature"]; !ok || val != "true" {
|
||||||
|
t.Errorf("expected extra_args[new_feature]='true', got '%s'", val)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only extra args",
|
||||||
|
command: "llama-server --experimental-feature --beta-mode enabled",
|
||||||
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
|
||||||
|
if opts.ExtraArgs == nil {
|
||||||
|
t.Fatal("expected extra_args to be non-nil")
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["experimental_feature"]; !ok || val != "true" {
|
||||||
|
t.Errorf("expected extra_args[experimental_feature]='true', got '%s'", val)
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["beta_mode"]; !ok || val != "enabled" {
|
||||||
|
t.Errorf("expected extra_args[beta_mode]='enabled', got '%s'", val)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var opts backends.LlamaServerOptions
|
||||||
|
result, err := opts.ParseCommand(tt.command)
|
||||||
|
|
||||||
|
if tt.expectErr && err == nil {
|
||||||
|
t.Error("expected error but got none")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !tt.expectErr && err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tt.expectErr && tt.validate != nil {
|
||||||
|
llamaOpts, ok := result.(*backends.LlamaServerOptions)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("result is not *LlamaServerOptions")
|
||||||
|
}
|
||||||
|
tt.validate(t, llamaOpts)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package backends
|
package backends
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"llamactl/pkg/validation"
|
"llamactl/pkg/validation"
|
||||||
)
|
)
|
||||||
@@ -35,6 +36,42 @@ type MlxServerOptions struct {
|
|||||||
ExtraArgs map[string]string `json:"extra_args,omitempty"`
|
ExtraArgs map[string]string `json:"extra_args,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements custom JSON unmarshaling to collect unknown fields into ExtraArgs
|
||||||
|
func (o *MlxServerOptions) UnmarshalJSON(data []byte) error {
|
||||||
|
// First unmarshal into a map to capture all fields
|
||||||
|
var raw map[string]any
|
||||||
|
if err := json.Unmarshal(data, &raw); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a temporary struct for standard unmarshaling
|
||||||
|
type tempOptions MlxServerOptions
|
||||||
|
temp := tempOptions{}
|
||||||
|
|
||||||
|
// Standard unmarshal first
|
||||||
|
if err := json.Unmarshal(data, &temp); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy to our struct
|
||||||
|
*o = MlxServerOptions(temp)
|
||||||
|
|
||||||
|
// Get all known canonical field names from struct tags
|
||||||
|
knownFields := getKnownFieldNames(o)
|
||||||
|
|
||||||
|
// Collect unknown fields into ExtraArgs
|
||||||
|
if o.ExtraArgs == nil {
|
||||||
|
o.ExtraArgs = make(map[string]string)
|
||||||
|
}
|
||||||
|
for key, value := range raw {
|
||||||
|
if !knownFields[key] {
|
||||||
|
o.ExtraArgs[key] = fmt.Sprintf("%v", value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (o *MlxServerOptions) GetPort() int {
|
func (o *MlxServerOptions) GetPort() int {
|
||||||
return o.Port
|
return o.Port
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -202,3 +202,75 @@ func TestMlxBuildCommandArgs_ZeroValues(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseMlxCommand_ExtraArgs(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
command string
|
||||||
|
expectErr bool
|
||||||
|
validate func(*testing.T, *backends.MlxServerOptions)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "extra args with known fields",
|
||||||
|
command: "mlx_lm.server --model /path/to/model --port 8080 --unknown-flag value --new-bool-flag",
|
||||||
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.MlxServerOptions) {
|
||||||
|
if opts.Model != "/path/to/model" {
|
||||||
|
t.Errorf("expected model '/path/to/model', got '%s'", opts.Model)
|
||||||
|
}
|
||||||
|
if opts.Port != 8080 {
|
||||||
|
t.Errorf("expected port 8080, got %d", opts.Port)
|
||||||
|
}
|
||||||
|
if opts.ExtraArgs == nil {
|
||||||
|
t.Fatal("expected extra_args to be non-nil")
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["unknown_flag"]; !ok || val != "value" {
|
||||||
|
t.Errorf("expected extra_args[unknown_flag]='value', got '%s'", val)
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["new_bool_flag"]; !ok || val != "true" {
|
||||||
|
t.Errorf("expected extra_args[new_bool_flag]='true', got '%s'", val)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only extra args",
|
||||||
|
command: "mlx_lm.server --experimental-feature --custom-param test",
|
||||||
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.MlxServerOptions) {
|
||||||
|
if opts.ExtraArgs == nil {
|
||||||
|
t.Fatal("expected extra_args to be non-nil")
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["experimental_feature"]; !ok || val != "true" {
|
||||||
|
t.Errorf("expected extra_args[experimental_feature]='true', got '%s'", val)
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["custom_param"]; !ok || val != "test" {
|
||||||
|
t.Errorf("expected extra_args[custom_param]='test', got '%s'", val)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var opts backends.MlxServerOptions
|
||||||
|
result, err := opts.ParseCommand(tt.command)
|
||||||
|
|
||||||
|
if tt.expectErr && err == nil {
|
||||||
|
t.Error("expected error but got none")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !tt.expectErr && err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tt.expectErr && tt.validate != nil {
|
||||||
|
mlxOpts, ok := result.(*backends.MlxServerOptions)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("result is not *MlxServerOptions")
|
||||||
|
}
|
||||||
|
tt.validate(t, mlxOpts)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -211,3 +212,65 @@ func parseValue(value string) any {
|
|||||||
// Return as string
|
// Return as string
|
||||||
return value
|
return value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setFieldValue sets a field value using reflection, handling type conversions
|
||||||
|
// Used by UnmarshalJSON implementations to handle alternative field names
|
||||||
|
func setFieldValue(field reflect.Value, value any) {
|
||||||
|
switch field.Kind() {
|
||||||
|
case reflect.Int:
|
||||||
|
if intVal, ok := value.(float64); ok {
|
||||||
|
field.SetInt(int64(intVal))
|
||||||
|
} else if strVal, ok := value.(string); ok {
|
||||||
|
if intVal, err := strconv.Atoi(strVal); err == nil {
|
||||||
|
field.SetInt(int64(intVal))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Float64:
|
||||||
|
if floatVal, ok := value.(float64); ok {
|
||||||
|
field.SetFloat(floatVal)
|
||||||
|
} else if strVal, ok := value.(string); ok {
|
||||||
|
if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil {
|
||||||
|
field.SetFloat(floatVal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.String:
|
||||||
|
if strVal, ok := value.(string); ok {
|
||||||
|
field.SetString(strVal)
|
||||||
|
}
|
||||||
|
case reflect.Bool:
|
||||||
|
if boolVal, ok := value.(bool); ok {
|
||||||
|
field.SetBool(boolVal)
|
||||||
|
}
|
||||||
|
case reflect.Slice:
|
||||||
|
// Handle string slices
|
||||||
|
if field.Type().Elem().Kind() == reflect.String {
|
||||||
|
if slice, ok := value.([]any); ok {
|
||||||
|
strSlice := make([]string, 0, len(slice))
|
||||||
|
for _, v := range slice {
|
||||||
|
if s, ok := v.(string); ok {
|
||||||
|
strSlice = append(strSlice, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
field.Set(reflect.ValueOf(strSlice))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getKnownFieldNames extracts all known field names from struct json tags
|
||||||
|
// Used by UnmarshalJSON implementations to identify unknown fields for ExtraArgs
|
||||||
|
func getKnownFieldNames(v any) map[string]bool {
|
||||||
|
fields := make(map[string]bool)
|
||||||
|
t := reflect.TypeOf(v).Elem()
|
||||||
|
|
||||||
|
for i := 0; i < t.NumField(); i++ {
|
||||||
|
field := t.Field(i)
|
||||||
|
jsonTag := field.Tag.Get("json")
|
||||||
|
if jsonTag != "" && jsonTag != "-" {
|
||||||
|
// Handle "name,omitempty" format
|
||||||
|
name := strings.Split(jsonTag, ",")[0]
|
||||||
|
fields[name] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fields
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package backends
|
package backends
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"llamactl/pkg/validation"
|
"llamactl/pkg/validation"
|
||||||
)
|
)
|
||||||
@@ -148,6 +149,42 @@ type VllmServerOptions struct {
|
|||||||
ExtraArgs map[string]string `json:"extra_args,omitempty"`
|
ExtraArgs map[string]string `json:"extra_args,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements custom JSON unmarshaling to collect unknown fields into ExtraArgs
|
||||||
|
func (o *VllmServerOptions) UnmarshalJSON(data []byte) error {
|
||||||
|
// First unmarshal into a map to capture all fields
|
||||||
|
var raw map[string]any
|
||||||
|
if err := json.Unmarshal(data, &raw); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a temporary struct for standard unmarshaling
|
||||||
|
type tempOptions VllmServerOptions
|
||||||
|
temp := tempOptions{}
|
||||||
|
|
||||||
|
// Standard unmarshal first
|
||||||
|
if err := json.Unmarshal(data, &temp); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy to our struct
|
||||||
|
*o = VllmServerOptions(temp)
|
||||||
|
|
||||||
|
// Get all known canonical field names from struct tags
|
||||||
|
knownFields := getKnownFieldNames(o)
|
||||||
|
|
||||||
|
// Collect unknown fields into ExtraArgs
|
||||||
|
if o.ExtraArgs == nil {
|
||||||
|
o.ExtraArgs = make(map[string]string)
|
||||||
|
}
|
||||||
|
for key, value := range raw {
|
||||||
|
if !knownFields[key] {
|
||||||
|
o.ExtraArgs[key] = fmt.Sprintf("%v", value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (o *VllmServerOptions) GetPort() int {
|
func (o *VllmServerOptions) GetPort() int {
|
||||||
return o.Port
|
return o.Port
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -321,3 +321,94 @@ func TestVllmBuildCommandArgs_PositionalModel(t *testing.T) {
|
|||||||
t.Errorf("Expected --port 8080 not found in %v", args)
|
t.Errorf("Expected --port 8080 not found in %v", args)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseVllmCommand_ExtraArgs(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
command string
|
||||||
|
expectErr bool
|
||||||
|
validate func(*testing.T, *backends.VllmServerOptions)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "extra args with known fields",
|
||||||
|
command: "vllm serve llama-model --tensor-parallel-size 2 --unknown-flag value --new-bool-flag",
|
||||||
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
|
||||||
|
if opts.Model != "llama-model" {
|
||||||
|
t.Errorf("expected model 'llama-model', got '%s'", opts.Model)
|
||||||
|
}
|
||||||
|
if opts.TensorParallelSize != 2 {
|
||||||
|
t.Errorf("expected tensor_parallel_size 2, got %d", opts.TensorParallelSize)
|
||||||
|
}
|
||||||
|
if opts.ExtraArgs == nil {
|
||||||
|
t.Fatal("expected extra_args to be non-nil")
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["unknown_flag"]; !ok || val != "value" {
|
||||||
|
t.Errorf("expected extra_args[unknown_flag]='value', got '%s'", val)
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["new_bool_flag"]; !ok || val != "true" {
|
||||||
|
t.Errorf("expected extra_args[new_bool_flag]='true', got '%s'", val)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only extra args",
|
||||||
|
command: "vllm serve model --experimental-feature --custom-param test",
|
||||||
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
|
||||||
|
if opts.ExtraArgs == nil {
|
||||||
|
t.Fatal("expected extra_args to be non-nil")
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["experimental_feature"]; !ok || val != "true" {
|
||||||
|
t.Errorf("expected extra_args[experimental_feature]='true', got '%s'", val)
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["custom_param"]; !ok || val != "test" {
|
||||||
|
t.Errorf("expected extra_args[custom_param]='test', got '%s'", val)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extra args without model positional",
|
||||||
|
command: "vllm serve --model my-model --new-feature enabled --beta-flag",
|
||||||
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
|
||||||
|
if opts.Model != "my-model" {
|
||||||
|
t.Errorf("expected model 'my-model', got '%s'", opts.Model)
|
||||||
|
}
|
||||||
|
if opts.ExtraArgs == nil {
|
||||||
|
t.Fatal("expected extra_args to be non-nil")
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["new_feature"]; !ok || val != "enabled" {
|
||||||
|
t.Errorf("expected extra_args[new_feature]='enabled', got '%s'", val)
|
||||||
|
}
|
||||||
|
if val, ok := opts.ExtraArgs["beta_flag"]; !ok || val != "true" {
|
||||||
|
t.Errorf("expected extra_args[beta_flag]='true', got '%s'", val)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var opts backends.VllmServerOptions
|
||||||
|
result, err := opts.ParseCommand(tt.command)
|
||||||
|
|
||||||
|
if tt.expectErr && err == nil {
|
||||||
|
t.Error("expected error but got none")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !tt.expectErr && err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tt.expectErr && tt.validate != nil {
|
||||||
|
vllmOpts, ok := result.(*backends.VllmServerOptions)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("result is not *VllmServerOptions")
|
||||||
|
}
|
||||||
|
tt.validate(t, vllmOpts)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user