diff --git a/pkg/backends/llamacpp/parser_test.go b/pkg/backends/llamacpp/parser_test.go index 60e6a19..7072f65 100644 --- a/pkg/backends/llamacpp/parser_test.go +++ b/pkg/backends/llamacpp/parser_test.go @@ -1,6 +1,7 @@ -package llamacpp +package llamacpp_test import ( + "llamactl/pkg/backends/llamacpp" "testing" ) @@ -11,28 +12,23 @@ func TestParseLlamaCommand(t *testing.T) { expectErr bool }{ { - name: "basic command with model", - command: "llama-server --model /path/to/model.gguf", + name: "basic command", + command: "llama-server --model /path/to/model.gguf --gpu-layers 32", expectErr: false, }, { - name: "command with multiple flags", - command: "llama-server --model /path/to/model.gguf --gpu-layers 32 --ctx-size 4096", + name: "args only", + command: "--model /path/to/model.gguf --ctx-size 4096", expectErr: false, }, { - name: "command with short flags", - command: "llama-server -m /path/to/model.gguf -ngl 32 -c 4096", + name: "mixed flag formats", + command: "llama-server --model=/path/model.gguf --gpu-layers 16 --verbose", expectErr: false, }, { - name: "command with equals format", - command: "llama-server --model=/path/to/model.gguf --gpu-layers=32", - expectErr: false, - }, - { - name: "command with boolean flags", - command: "llama-server --model /path/to/model.gguf --verbose --no-mmap", + name: "quoted strings", + command: `llama-server --model test.gguf --api-key "sk-1234567890abcdef"`, expectErr: false, }, { @@ -41,46 +37,20 @@ func TestParseLlamaCommand(t *testing.T) { expectErr: true, }, { - name: "case insensitive command", - command: "LLAMA-SERVER --model /path/to/model.gguf", - expectErr: false, - }, - // New test cases for improved functionality - { - name: "args only without llama-server", - command: "--model /path/to/model.gguf --gpu-layers 32", - expectErr: false, + name: "unterminated quote", + command: `llama-server --model test.gguf --api-key "unterminated`, + expectErr: true, }, { - name: "full path to executable", - command: "/usr/local/bin/llama-server --model /path/to/model.gguf", - expectErr: false, - }, - { - name: "negative number handling", - command: "llama-server --gpu-layers -1 --model test.gguf", - expectErr: false, - }, - { - name: "multiline command with backslashes", - command: "llama-server --model /path/to/model.gguf \\\n --ctx-size 4096 \\\n --batch-size 512", - expectErr: false, - }, - { - name: "quoted string with special characters", - command: `llama-server --model test.gguf --chat-template "{% for message in messages %}{{ message.role }}: {{ message.content }}\n{% endfor %}"`, - expectErr: false, - }, - { - name: "unterminated quoted string", - command: `llama-server --model test.gguf --chat-template "unterminated quote`, + name: "malformed flag", + command: "llama-server ---model test.gguf", expectErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := ParseLlamaCommand(tt.command) + result, err := llamacpp.ParseLlamaCommand(tt.command) if tt.expectErr { if err == nil { @@ -96,16 +66,14 @@ func TestParseLlamaCommand(t *testing.T) { if result == nil { t.Errorf("expected result but got nil") - return } }) } } -func TestParseLlamaCommandSpecificValues(t *testing.T) { - // Test specific value parsing - command := "llama-server --model /test/model.gguf --gpu-layers 32 --ctx-size 4096 --verbose" - result, err := ParseLlamaCommand(command) +func TestParseLlamaCommandValues(t *testing.T) { + command := "llama-server --model /test/model.gguf --gpu-layers 32 --temp 0.7 --verbose --no-mmap" + result, err := 
llamacpp.ParseLlamaCommand(command) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -119,19 +87,22 @@ func TestParseLlamaCommandSpecificValues(t *testing.T) { t.Errorf("expected gpu_layers 32, got %d", result.GPULayers) } - if result.CtxSize != 4096 { - t.Errorf("expected ctx_size 4096, got %d", result.CtxSize) + if result.Temperature != 0.7 { + t.Errorf("expected temperature 0.7, got %f", result.Temperature) } if !result.Verbose { - t.Errorf("expected verbose to be true, got %v", result.Verbose) + t.Errorf("expected verbose to be true") + } + + if !result.NoMmap { + t.Errorf("expected no_mmap to be true") } } -func TestParseLlamaCommandArrayFlags(t *testing.T) { - // Test array flag handling (critical for lora, override-tensor, etc.) - command := "llama-server --model test.gguf --lora adapter1.bin --lora adapter2.bin" - result, err := ParseLlamaCommand(command) +func TestParseLlamaCommandArrays(t *testing.T) { + command := "llama-server --model test.gguf --lora adapter1.bin --lora=adapter2.bin" + result, err := llamacpp.ParseLlamaCommand(command) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -141,273 +112,10 @@ func TestParseLlamaCommandArrayFlags(t *testing.T) { t.Errorf("expected 2 lora adapters, got %d", len(result.Lora)) } - if result.Lora[0] != "adapter1.bin" || result.Lora[1] != "adapter2.bin" { - t.Errorf("expected lora adapters [adapter1.bin, adapter2.bin], got %v", result.Lora) - } -} - -func TestParseLlamaCommandMixedFormats(t *testing.T) { - // Test mixing --flag=value and --flag value formats - command := "llama-server --model=/path/model.gguf --gpu-layers 16 --batch-size=512 --verbose" - result, err := ParseLlamaCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Model != "/path/model.gguf" { - t.Errorf("expected model '/path/model.gguf', got '%s'", result.Model) - } - - if result.GPULayers != 16 { - t.Errorf("expected gpu_layers 16, got %d", result.GPULayers) - } - - if result.BatchSize != 512 { - t.Errorf("expected batch_size 512, got %d", result.BatchSize) - } - - if !result.Verbose { - t.Errorf("expected verbose to be true, got %v", result.Verbose) - } -} - -func TestParseLlamaCommandTypeConversion(t *testing.T) { - // Test that values are converted to appropriate types - command := "llama-server --model test.gguf --temp 0.7 --top-k 40 --no-mmap" - result, err := ParseLlamaCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Temperature != 0.7 { - t.Errorf("expected temperature 0.7, got %f", result.Temperature) - } - - if result.TopK != 40 { - t.Errorf("expected top_k 40, got %d", result.TopK) - } - - if !result.NoMmap { - t.Errorf("expected no_mmap to be true, got %v", result.NoMmap) - } -} - -func TestParseLlamaCommandArgsOnly(t *testing.T) { - // Test parsing arguments without llama-server command - command := "--model /path/to/model.gguf --gpu-layers 32 --ctx-size 4096" - result, err := ParseLlamaCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Model != "/path/to/model.gguf" { - t.Errorf("expected model '/path/to/model.gguf', got '%s'", result.Model) - } - - if result.GPULayers != 32 { - t.Errorf("expected gpu_layers 32, got %d", result.GPULayers) - } - - if result.CtxSize != 4096 { - t.Errorf("expected ctx_size 4096, got %d", result.CtxSize) - } -} - -func TestParseLlamaCommandFullPath(t *testing.T) { - // Test full path to executable - command := "/usr/local/bin/llama-server --model test.gguf --gpu-layers 16" - 
result, err := ParseLlamaCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Model != "test.gguf" { - t.Errorf("expected model 'test.gguf', got '%s'", result.Model) - } - - if result.GPULayers != 16 { - t.Errorf("expected gpu_layers 16, got %d", result.GPULayers) - } -} - -func TestParseLlamaCommandNegativeNumbers(t *testing.T) { - // Test negative number parsing - command := "llama-server --model test.gguf --gpu-layers -1 --seed -12345" - result, err := ParseLlamaCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.GPULayers != -1 { - t.Errorf("expected gpu_layers -1, got %d", result.GPULayers) - } - - if result.Seed != -12345 { - t.Errorf("expected seed -12345, got %d", result.Seed) - } -} - -func TestParseLlamaCommandMultiline(t *testing.T) { - // Test multiline command with backslashes - command := `llama-server --model /path/to/model.gguf \ - --ctx-size 4096 \ - --batch-size 512 \ - --gpu-layers 32` - - result, err := ParseLlamaCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Model != "/path/to/model.gguf" { - t.Errorf("expected model '/path/to/model.gguf', got '%s'", result.Model) - } - - if result.CtxSize != 4096 { - t.Errorf("expected ctx_size 4096, got %d", result.CtxSize) - } - - if result.BatchSize != 512 { - t.Errorf("expected batch_size 512, got %d", result.BatchSize) - } - - if result.GPULayers != 32 { - t.Errorf("expected gpu_layers 32, got %d", result.GPULayers) - } -} - -func TestParseLlamaCommandQuotedStrings(t *testing.T) { - // Test quoted strings with special characters - command := `llama-server --model test.gguf --api-key "sk-1234567890abcdef" --chat-template "User: {user}\nAssistant: "` - result, err := ParseLlamaCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Model != "test.gguf" { - t.Errorf("expected model 'test.gguf', got '%s'", result.Model) - } - - if result.APIKey != "sk-1234567890abcdef" { - t.Errorf("expected api_key 'sk-1234567890abcdef', got '%s'", result.APIKey) - } - - expectedTemplate := "User: {user}\\nAssistant: " - if result.ChatTemplate != expectedTemplate { - t.Errorf("expected chat_template '%s', got '%s'", expectedTemplate, result.ChatTemplate) - } -} - -func TestParseLlamaCommandUnslothExample(t *testing.T) { - // Test with realistic unsloth-style command - command := `llama-server --model /path/to/model.gguf \ - --ctx-size 4096 \ - --batch-size 512 \ - --gpu-layers -1 \ - --temp 0.7 \ - --repeat-penalty 1.1 \ - --top-k 40 \ - --top-p 0.95 \ - --host 0.0.0.0 \ - --port 8000 \ - --api-key "sk-1234567890abcdef"` - - result, err := ParseLlamaCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - // Verify key fields - if result.Model != "/path/to/model.gguf" { - t.Errorf("expected model '/path/to/model.gguf', got '%s'", result.Model) - } - - if result.CtxSize != 4096 { - t.Errorf("expected ctx_size 4096, got %d", result.CtxSize) - } - - if result.BatchSize != 512 { - t.Errorf("expected batch_size 512, got %d", result.BatchSize) - } - - if result.GPULayers != -1 { - t.Errorf("expected gpu_layers -1, got %d", result.GPULayers) - } - - if result.Temperature != 0.7 { - t.Errorf("expected temperature 0.7, got %f", result.Temperature) - } - - if result.RepeatPenalty != 1.1 { - t.Errorf("expected repeat_penalty 1.1, got %f", result.RepeatPenalty) - } - - if result.TopK != 40 { - t.Errorf("expected top_k 40, got %d", result.TopK) - } - - if 
result.TopP != 0.95 { - t.Errorf("expected top_p 0.95, got %f", result.TopP) - } - - if result.Host != "0.0.0.0" { - t.Errorf("expected host '0.0.0.0', got '%s'", result.Host) - } - - if result.Port != 8000 { - t.Errorf("expected port 8000, got %d", result.Port) - } - - if result.APIKey != "sk-1234567890abcdef" { - t.Errorf("expected api_key 'sk-1234567890abcdef', got '%s'", result.APIKey) - } -} - -// Focused additional edge case tests (kept minimal per guidance) -func TestParseLlamaCommandSingleQuotedValue(t *testing.T) { - cmd := "llama-server --model 'my model.gguf' --alias 'Test Alias'" - result, err := ParseLlamaCommand(cmd) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if result.Model != "my model.gguf" { - t.Errorf("expected model 'my model.gguf', got '%s'", result.Model) - } - if result.Alias != "Test Alias" { - t.Errorf("expected alias 'Test Alias', got '%s'", result.Alias) - } -} - -func TestParseLlamaCommandMixedArrayForms(t *testing.T) { - // Same multi-value flag using --flag value and --flag=value forms - cmd := "llama-server --lora adapter1.bin --lora=adapter2.bin --lora adapter3.bin" - result, err := ParseLlamaCommand(cmd) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if len(result.Lora) != 3 { - t.Fatalf("expected 3 lora values, got %d (%v)", len(result.Lora), result.Lora) - } - expected := []string{"adapter1.bin", "adapter2.bin", "adapter3.bin"} + expected := []string{"adapter1.bin", "adapter2.bin"} for i, v := range expected { if result.Lora[i] != v { t.Errorf("expected lora[%d]=%s got %s", i, v, result.Lora[i]) } } } - -func TestParseLlamaCommandMalformedFlag(t *testing.T) { - cmd := "llama-server ---model test.gguf" - _, err := ParseLlamaCommand(cmd) - if err == nil { - t.Fatalf("expected error for malformed flag but got none") - } -} diff --git a/pkg/backends/mlx/mlx.go b/pkg/backends/mlx/mlx.go index c3324d2..8527c7b 100644 --- a/pkg/backends/mlx/mlx.go +++ b/pkg/backends/mlx/mlx.go @@ -1,205 +1,88 @@ package mlx import ( - "encoding/json" "reflect" "strconv" + "strings" ) type MlxServerOptions struct { // Basic connection options - Model string `json:"model,omitempty"` - Host string `json:"host,omitempty"` - Port int `json:"port,omitempty"` - + Model string `json:"model,omitempty"` + Host string `json:"host,omitempty"` + Port int `json:"port,omitempty"` + // Model and adapter options AdapterPath string `json:"adapter_path,omitempty"` DraftModel string `json:"draft_model,omitempty"` NumDraftTokens int `json:"num_draft_tokens,omitempty"` TrustRemoteCode bool `json:"trust_remote_code,omitempty"` - + // Logging and templates - LogLevel string `json:"log_level,omitempty"` - ChatTemplate string `json:"chat_template,omitempty"` - UseDefaultChatTemplate bool `json:"use_default_chat_template,omitempty"` - ChatTemplateArgs string `json:"chat_template_args,omitempty"` // JSON string - + LogLevel string `json:"log_level,omitempty"` + ChatTemplate string `json:"chat_template,omitempty"` + UseDefaultChatTemplate bool `json:"use_default_chat_template,omitempty"` + ChatTemplateArgs string `json:"chat_template_args,omitempty"` // JSON string + // Sampling defaults - Temp float64 `json:"temp,omitempty"` // Note: MLX uses "temp" not "temperature" - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - MinP float64 `json:"min_p,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` + Temp float64 `json:"temp,omitempty"` // Note: MLX uses "temp" not "temperature" + TopP float64 `json:"top_p,omitempty"` + TopK int 
`json:"top_k,omitempty"` + MinP float64 `json:"min_p,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` } -// UnmarshalJSON implements custom JSON unmarshaling to support multiple field names -func (o *MlxServerOptions) UnmarshalJSON(data []byte) error { - // First unmarshal into a map to handle multiple field names - var raw map[string]any - if err := json.Unmarshal(data, &raw); err != nil { - return err - } +// BuildCommandArgs converts to command line arguments using reflection +func (o *MlxServerOptions) BuildCommandArgs() []string { + var args []string - // Create a temporary struct for standard unmarshaling - type tempOptions MlxServerOptions - temp := tempOptions{} + v := reflect.ValueOf(o).Elem() + t := v.Type() - // Standard unmarshal first - if err := json.Unmarshal(data, &temp); err != nil { - return err - } + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + fieldType := t.Field(i) - // Copy to our struct - *o = MlxServerOptions(temp) + // Skip unexported fields + if !field.CanInterface() { + continue + } - // Handle alternative field names - fieldMappings := map[string]string{ - // Basic connection options - "m": "model", - "host": "host", - "port": "port", -// "python_path": "python_path", // removed - - // Model and adapter options - "adapter-path": "adapter_path", - "draft-model": "draft_model", - "num-draft-tokens": "num_draft_tokens", - "trust-remote-code": "trust_remote_code", - - // Logging and templates - "log-level": "log_level", - "chat-template": "chat_template", - "use-default-chat-template": "use_default_chat_template", - "chat-template-args": "chat_template_args", - - // Sampling defaults - "temperature": "temp", // Support both temp and temperature - "top-p": "top_p", - "top-k": "top_k", - "min-p": "min_p", - "max-tokens": "max_tokens", - } + // Get the JSON tag to determine the flag name + jsonTag := fieldType.Tag.Get("json") + if jsonTag == "" || jsonTag == "-" { + continue + } - // Process alternative field names - for altName, canonicalName := range fieldMappings { - if value, exists := raw[altName]; exists { - // Use reflection to set the field value - v := reflect.ValueOf(o).Elem() - field := v.FieldByNameFunc(func(fieldName string) bool { - field, _ := v.Type().FieldByName(fieldName) - jsonTag := field.Tag.Get("json") - return jsonTag == canonicalName+",omitempty" || jsonTag == canonicalName - }) + // Remove ",omitempty" from the tag + flagName := jsonTag + if commaIndex := strings.Index(jsonTag, ","); commaIndex != -1 { + flagName = jsonTag[:commaIndex] + } - if field.IsValid() && field.CanSet() { - switch field.Kind() { - case reflect.Int: - if intVal, ok := value.(float64); ok { - field.SetInt(int64(intVal)) - } else if strVal, ok := value.(string); ok { - if intVal, err := strconv.Atoi(strVal); err == nil { - field.SetInt(int64(intVal)) - } - } - case reflect.Float64: - if floatVal, ok := value.(float64); ok { - field.SetFloat(floatVal) - } else if strVal, ok := value.(string); ok { - if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil { - field.SetFloat(floatVal) - } - } - case reflect.String: - if strVal, ok := value.(string); ok { - field.SetString(strVal) - } - case reflect.Bool: - if boolVal, ok := value.(bool); ok { - field.SetBool(boolVal) - } - } + // Convert snake_case to kebab-case for CLI flags + flagName = strings.ReplaceAll(flagName, "_", "-") + + // Add the appropriate arguments based on field type and value + switch field.Kind() { + case reflect.Bool: + if field.Bool() { + args = append(args, "--"+flagName) + 
} + case reflect.Int: + if field.Int() != 0 { + args = append(args, "--"+flagName, strconv.FormatInt(field.Int(), 10)) + } + case reflect.Float64: + if field.Float() != 0 { + args = append(args, "--"+flagName, strconv.FormatFloat(field.Float(), 'f', -1, 64)) + } + case reflect.String: + if field.String() != "" { + args = append(args, "--"+flagName, field.String()) } } } - return nil -} - -// NewMlxServerOptions creates MlxServerOptions with MLX defaults -func NewMlxServerOptions() *MlxServerOptions { - return &MlxServerOptions{ - Host: "127.0.0.1", // MLX default (different from llama-server) - Port: 8080, // MLX default - NumDraftTokens: 3, // MLX default for speculative decoding - LogLevel: "INFO", // MLX default - Temp: 0.0, // MLX default - TopP: 1.0, // MLX default - TopK: 0, // MLX default (disabled) - MinP: 0.0, // MLX default (disabled) - MaxTokens: 512, // MLX default - ChatTemplateArgs: "{}", // MLX default (empty JSON object) - } -} - -// BuildCommandArgs converts to command line arguments -func (o *MlxServerOptions) BuildCommandArgs() []string { - var args []string - - // Required and basic options - if o.Model != "" { - args = append(args, "--model", o.Model) - } - if o.Host != "" { - args = append(args, "--host", o.Host) - } - if o.Port != 0 { - args = append(args, "--port", strconv.Itoa(o.Port)) - } - - // Model and adapter options - if o.AdapterPath != "" { - args = append(args, "--adapter-path", o.AdapterPath) - } - if o.DraftModel != "" { - args = append(args, "--draft-model", o.DraftModel) - } - if o.NumDraftTokens != 0 { - args = append(args, "--num-draft-tokens", strconv.Itoa(o.NumDraftTokens)) - } - if o.TrustRemoteCode { - args = append(args, "--trust-remote-code") - } - - // Logging and templates - if o.LogLevel != "" { - args = append(args, "--log-level", o.LogLevel) - } - if o.ChatTemplate != "" { - args = append(args, "--chat-template", o.ChatTemplate) - } - if o.UseDefaultChatTemplate { - args = append(args, "--use-default-chat-template") - } - if o.ChatTemplateArgs != "" { - args = append(args, "--chat-template-args", o.ChatTemplateArgs) - } - - // Sampling defaults - if o.Temp != 0 { - args = append(args, "--temp", strconv.FormatFloat(o.Temp, 'f', -1, 64)) - } - if o.TopP != 0 { - args = append(args, "--top-p", strconv.FormatFloat(o.TopP, 'f', -1, 64)) - } - if o.TopK != 0 { - args = append(args, "--top-k", strconv.Itoa(o.TopK)) - } - if o.MinP != 0 { - args = append(args, "--min-p", strconv.FormatFloat(o.MinP, 'f', -1, 64)) - } - if o.MaxTokens != 0 { - args = append(args, "--max-tokens", strconv.Itoa(o.MaxTokens)) - } - return args -} \ No newline at end of file +} diff --git a/pkg/backends/mlx/mlx_test.go b/pkg/backends/mlx/mlx_test.go new file mode 100644 index 0000000..b35f512 --- /dev/null +++ b/pkg/backends/mlx/mlx_test.go @@ -0,0 +1,62 @@ +package mlx_test + +import ( + "llamactl/pkg/backends/mlx" + "testing" +) + +func TestBuildCommandArgs(t *testing.T) { + options := &mlx.MlxServerOptions{ + Model: "/test/model.mlx", + Host: "127.0.0.1", + Port: 8080, + Temp: 0.7, + TopP: 0.9, + TopK: 40, + MaxTokens: 2048, + TrustRemoteCode: true, + LogLevel: "DEBUG", + ChatTemplate: "custom template", + } + + args := options.BuildCommandArgs() + + // Check that all expected flags are present + expectedFlags := map[string]string{ + "--model": "/test/model.mlx", + "--host": "127.0.0.1", + "--port": "8080", + "--log-level": "DEBUG", + "--chat-template": "custom template", + "--temp": "0.7", + "--top-p": "0.9", + "--top-k": "40", + "--max-tokens": "2048", + } + + for i 
:= 0; i < len(args); i++ { + if args[i] == "--trust-remote-code" { + continue // Boolean flag with no value + } + if args[i] == "--use-default-chat-template" { + continue // Boolean flag with no value + } + + if expectedValue, exists := expectedFlags[args[i]]; exists && i+1 < len(args) { + if args[i+1] != expectedValue { + t.Errorf("expected %s to have value %s, got %s", args[i], expectedValue, args[i+1]) + } + } + } + + // Check boolean flags + foundTrustRemoteCode := false + for _, arg := range args { + if arg == "--trust-remote-code" { + foundTrustRemoteCode = true + } + } + if !foundTrustRemoteCode { + t.Errorf("expected --trust-remote-code flag to be present") + } +} diff --git a/pkg/backends/mlx/parser_test.go b/pkg/backends/mlx/parser_test.go new file mode 100644 index 0000000..6caae84 --- /dev/null +++ b/pkg/backends/mlx/parser_test.go @@ -0,0 +1,101 @@ +package mlx_test + +import ( + "llamactl/pkg/backends/mlx" + "testing" +) + +func TestParseMlxCommand(t *testing.T) { + tests := []struct { + name string + command string + expectErr bool + }{ + { + name: "basic command", + command: "mlx_lm.server --model /path/to/model --host 0.0.0.0", + expectErr: false, + }, + { + name: "args only", + command: "--model /path/to/model --port 8080", + expectErr: false, + }, + { + name: "mixed flag formats", + command: "mlx_lm.server --model=/path/model --temp=0.7 --trust-remote-code", + expectErr: false, + }, + { + name: "quoted strings", + command: `mlx_lm.server --model test.mlx --chat-template "User: {user}\nAssistant: "`, + expectErr: false, + }, + { + name: "empty command", + command: "", + expectErr: true, + }, + { + name: "unterminated quote", + command: `mlx_lm.server --model test.mlx --chat-template "unterminated`, + expectErr: true, + }, + { + name: "malformed flag", + command: "mlx_lm.server ---model test.mlx", + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := mlx.ParseMlxCommand(tt.command) + + if tt.expectErr { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if result == nil { + t.Errorf("expected result but got nil") + } + }) + } +} + +func TestParseMlxCommandValues(t *testing.T) { + command := "mlx_lm.server --model /test/model.mlx --port 8080 --temp 0.7 --trust-remote-code --log-level DEBUG" + result, err := mlx.ParseMlxCommand(command) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Model != "/test/model.mlx" { + t.Errorf("expected model '/test/model.mlx', got '%s'", result.Model) + } + + if result.Port != 8080 { + t.Errorf("expected port 8080, got %d", result.Port) + } + + if result.Temp != 0.7 { + t.Errorf("expected temp 0.7, got %f", result.Temp) + } + + if !result.TrustRemoteCode { + t.Errorf("expected trust_remote_code to be true") + } + + if result.LogLevel != "DEBUG" { + t.Errorf("expected log_level 'DEBUG', got '%s'", result.LogLevel) + } +} diff --git a/pkg/backends/vllm/parser_test.go b/pkg/backends/vllm/parser_test.go index 91921b2..3a12456 100644 --- a/pkg/backends/vllm/parser_test.go +++ b/pkg/backends/vllm/parser_test.go @@ -1,6 +1,7 @@ -package vllm +package vllm_test import ( + "llamactl/pkg/backends/vllm" "testing" ) @@ -39,7 +40,7 @@ func TestParseVllmCommand(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := ParseVllmCommand(tt.command) + result, err := vllm.ParseVllmCommand(tt.command) if tt.expectErr { if err == nil { 
@@ -62,7 +63,7 @@ func TestParseVllmCommand(t *testing.T) { func TestParseVllmCommandValues(t *testing.T) { command := "vllm serve --model test-model --tensor-parallel-size 4 --gpu-memory-utilization 0.8 --enable-log-outputs" - result, err := ParseVllmCommand(command) + result, err := vllm.ParseVllmCommand(command) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -80,4 +81,4 @@ func TestParseVllmCommandValues(t *testing.T) { if !result.EnableLogOutputs { t.Errorf("expected enable_log_outputs true, got %v", result.EnableLogOutputs) } -} \ No newline at end of file +} diff --git a/pkg/backends/vllm/vllm.go b/pkg/backends/vllm/vllm.go index 6378b5e..9aa865c 100644 --- a/pkg/backends/vllm/vllm.go +++ b/pkg/backends/vllm/vllm.go @@ -1,7 +1,6 @@ package vllm import ( - "encoding/json" "reflect" "strconv" "strings" @@ -13,349 +12,124 @@ type VllmServerOptions struct { Port int `json:"port,omitempty"` // Model and engine configuration - Model string `json:"model,omitempty"` - Tokenizer string `json:"tokenizer,omitempty"` - SkipTokenizerInit bool `json:"skip_tokenizer_init,omitempty"` - Revision string `json:"revision,omitempty"` - CodeRevision string `json:"code_revision,omitempty"` - TokenizerRevision string `json:"tokenizer_revision,omitempty"` - TokenizerMode string `json:"tokenizer_mode,omitempty"` - TrustRemoteCode bool `json:"trust_remote_code,omitempty"` - DownloadDir string `json:"download_dir,omitempty"` - LoadFormat string `json:"load_format,omitempty"` - ConfigFormat string `json:"config_format,omitempty"` - Dtype string `json:"dtype,omitempty"` - KVCacheDtype string `json:"kv_cache_dtype,omitempty"` - QuantizationParamPath string `json:"quantization_param_path,omitempty"` - Seed int `json:"seed,omitempty"` - MaxModelLen int `json:"max_model_len,omitempty"` - GuidedDecodingBackend string `json:"guided_decoding_backend,omitempty"` - DistributedExecutorBackend string `json:"distributed_executor_backend,omitempty"` - WorkerUseRay bool `json:"worker_use_ray,omitempty"` - RayWorkersUseNSight bool `json:"ray_workers_use_nsight,omitempty"` + Model string `json:"model,omitempty"` + Tokenizer string `json:"tokenizer,omitempty"` + SkipTokenizerInit bool `json:"skip_tokenizer_init,omitempty"` + Revision string `json:"revision,omitempty"` + CodeRevision string `json:"code_revision,omitempty"` + TokenizerRevision string `json:"tokenizer_revision,omitempty"` + TokenizerMode string `json:"tokenizer_mode,omitempty"` + TrustRemoteCode bool `json:"trust_remote_code,omitempty"` + DownloadDir string `json:"download_dir,omitempty"` + LoadFormat string `json:"load_format,omitempty"` + ConfigFormat string `json:"config_format,omitempty"` + Dtype string `json:"dtype,omitempty"` + KVCacheDtype string `json:"kv_cache_dtype,omitempty"` + QuantizationParamPath string `json:"quantization_param_path,omitempty"` + Seed int `json:"seed,omitempty"` + MaxModelLen int `json:"max_model_len,omitempty"` + GuidedDecodingBackend string `json:"guided_decoding_backend,omitempty"` + DistributedExecutorBackend string `json:"distributed_executor_backend,omitempty"` + WorkerUseRay bool `json:"worker_use_ray,omitempty"` + RayWorkersUseNSight bool `json:"ray_workers_use_nsight,omitempty"` // Performance and serving configuration - BlockSize int `json:"block_size,omitempty"` - EnablePrefixCaching bool `json:"enable_prefix_caching,omitempty"` - DisableSlidingWindow bool `json:"disable_sliding_window,omitempty"` - UseV2BlockManager bool `json:"use_v2_block_manager,omitempty"` - NumLookaheadSlots int 
`json:"num_lookahead_slots,omitempty"` - SwapSpace int `json:"swap_space,omitempty"` - CPUOffloadGB int `json:"cpu_offload_gb,omitempty"` - GPUMemoryUtilization float64 `json:"gpu_memory_utilization,omitempty"` - NumGPUBlocksOverride int `json:"num_gpu_blocks_override,omitempty"` - MaxNumBatchedTokens int `json:"max_num_batched_tokens,omitempty"` - MaxNumSeqs int `json:"max_num_seqs,omitempty"` - MaxLogprobs int `json:"max_logprobs,omitempty"` - DisableLogStats bool `json:"disable_log_stats,omitempty"` - Quantization string `json:"quantization,omitempty"` - RopeScaling string `json:"rope_scaling,omitempty"` - RopeTheta float64 `json:"rope_theta,omitempty"` - EnforceEager bool `json:"enforce_eager,omitempty"` - MaxContextLenToCapture int `json:"max_context_len_to_capture,omitempty"` - MaxSeqLenToCapture int `json:"max_seq_len_to_capture,omitempty"` - DisableCustomAllReduce bool `json:"disable_custom_all_reduce,omitempty"` - TokenizerPoolSize int `json:"tokenizer_pool_size,omitempty"` - TokenizerPoolType string `json:"tokenizer_pool_type,omitempty"` - TokenizerPoolExtraConfig string `json:"tokenizer_pool_extra_config,omitempty"` - EnableLoraBias bool `json:"enable_lora_bias,omitempty"` - LoraExtraVocabSize int `json:"lora_extra_vocab_size,omitempty"` - LoraRank int `json:"lora_rank,omitempty"` - PromptLookbackDistance int `json:"prompt_lookback_distance,omitempty"` - PreemptionMode string `json:"preemption_mode,omitempty"` + BlockSize int `json:"block_size,omitempty"` + EnablePrefixCaching bool `json:"enable_prefix_caching,omitempty"` + DisableSlidingWindow bool `json:"disable_sliding_window,omitempty"` + UseV2BlockManager bool `json:"use_v2_block_manager,omitempty"` + NumLookaheadSlots int `json:"num_lookahead_slots,omitempty"` + SwapSpace int `json:"swap_space,omitempty"` + CPUOffloadGB int `json:"cpu_offload_gb,omitempty"` + GPUMemoryUtilization float64 `json:"gpu_memory_utilization,omitempty"` + NumGPUBlocksOverride int `json:"num_gpu_blocks_override,omitempty"` + MaxNumBatchedTokens int `json:"max_num_batched_tokens,omitempty"` + MaxNumSeqs int `json:"max_num_seqs,omitempty"` + MaxLogprobs int `json:"max_logprobs,omitempty"` + DisableLogStats bool `json:"disable_log_stats,omitempty"` + Quantization string `json:"quantization,omitempty"` + RopeScaling string `json:"rope_scaling,omitempty"` + RopeTheta float64 `json:"rope_theta,omitempty"` + EnforceEager bool `json:"enforce_eager,omitempty"` + MaxContextLenToCapture int `json:"max_context_len_to_capture,omitempty"` + MaxSeqLenToCapture int `json:"max_seq_len_to_capture,omitempty"` + DisableCustomAllReduce bool `json:"disable_custom_all_reduce,omitempty"` + TokenizerPoolSize int `json:"tokenizer_pool_size,omitempty"` + TokenizerPoolType string `json:"tokenizer_pool_type,omitempty"` + TokenizerPoolExtraConfig string `json:"tokenizer_pool_extra_config,omitempty"` + EnableLoraBias bool `json:"enable_lora_bias,omitempty"` + LoraExtraVocabSize int `json:"lora_extra_vocab_size,omitempty"` + LoraRank int `json:"lora_rank,omitempty"` + PromptLookbackDistance int `json:"prompt_lookback_distance,omitempty"` + PreemptionMode string `json:"preemption_mode,omitempty"` // Distributed and parallel processing - TensorParallelSize int `json:"tensor_parallel_size,omitempty"` - PipelineParallelSize int `json:"pipeline_parallel_size,omitempty"` - MaxParallelLoadingWorkers int `json:"max_parallel_loading_workers,omitempty"` - DisableAsyncOutputProc bool `json:"disable_async_output_proc,omitempty"` - WorkerClass string `json:"worker_class,omitempty"` - 
EnabledLoraModules string `json:"enabled_lora_modules,omitempty"` - MaxLoraRank int `json:"max_lora_rank,omitempty"` - FullyShardedLoras bool `json:"fully_sharded_loras,omitempty"` - LoraModules string `json:"lora_modules,omitempty"` - PromptAdapters string `json:"prompt_adapters,omitempty"` - MaxPromptAdapterToken int `json:"max_prompt_adapter_token,omitempty"` - Device string `json:"device,omitempty"` - SchedulerDelay float64 `json:"scheduler_delay,omitempty"` - EnableChunkedPrefill bool `json:"enable_chunked_prefill,omitempty"` - SpeculativeModel string `json:"speculative_model,omitempty"` - SpeculativeModelQuantization string `json:"speculative_model_quantization,omitempty"` - SpeculativeRevision string `json:"speculative_revision,omitempty"` - SpeculativeMaxModelLen int `json:"speculative_max_model_len,omitempty"` - SpeculativeDisableByBatchSize int `json:"speculative_disable_by_batch_size,omitempty"` - NgptSpeculativeLength int `json:"ngpt_speculative_length,omitempty"` - SpeculativeDisableMqa bool `json:"speculative_disable_mqa,omitempty"` - ModelLoaderExtraConfig string `json:"model_loader_extra_config,omitempty"` - IgnorePatterns string `json:"ignore_patterns,omitempty"` - PreloadedLoraModules string `json:"preloaded_lora_modules,omitempty"` + TensorParallelSize int `json:"tensor_parallel_size,omitempty"` + PipelineParallelSize int `json:"pipeline_parallel_size,omitempty"` + MaxParallelLoadingWorkers int `json:"max_parallel_loading_workers,omitempty"` + DisableAsyncOutputProc bool `json:"disable_async_output_proc,omitempty"` + WorkerClass string `json:"worker_class,omitempty"` + EnabledLoraModules string `json:"enabled_lora_modules,omitempty"` + MaxLoraRank int `json:"max_lora_rank,omitempty"` + FullyShardedLoras bool `json:"fully_sharded_loras,omitempty"` + LoraModules string `json:"lora_modules,omitempty"` + PromptAdapters string `json:"prompt_adapters,omitempty"` + MaxPromptAdapterToken int `json:"max_prompt_adapter_token,omitempty"` + Device string `json:"device,omitempty"` + SchedulerDelay float64 `json:"scheduler_delay,omitempty"` + EnableChunkedPrefill bool `json:"enable_chunked_prefill,omitempty"` + SpeculativeModel string `json:"speculative_model,omitempty"` + SpeculativeModelQuantization string `json:"speculative_model_quantization,omitempty"` + SpeculativeRevision string `json:"speculative_revision,omitempty"` + SpeculativeMaxModelLen int `json:"speculative_max_model_len,omitempty"` + SpeculativeDisableByBatchSize int `json:"speculative_disable_by_batch_size,omitempty"` + NgptSpeculativeLength int `json:"ngpt_speculative_length,omitempty"` + SpeculativeDisableMqa bool `json:"speculative_disable_mqa,omitempty"` + ModelLoaderExtraConfig string `json:"model_loader_extra_config,omitempty"` + IgnorePatterns string `json:"ignore_patterns,omitempty"` + PreloadedLoraModules string `json:"preloaded_lora_modules,omitempty"` // OpenAI server specific options - UDS string `json:"uds,omitempty"` - UvicornLogLevel string `json:"uvicorn_log_level,omitempty"` - ResponseRole string `json:"response_role,omitempty"` - SSLKeyfile string `json:"ssl_keyfile,omitempty"` - SSLCertfile string `json:"ssl_certfile,omitempty"` - SSLCACerts string `json:"ssl_ca_certs,omitempty"` - SSLCertReqs int `json:"ssl_cert_reqs,omitempty"` - RootPath string `json:"root_path,omitempty"` - Middleware []string `json:"middleware,omitempty"` - ReturnTokensAsTokenIDS bool `json:"return_tokens_as_token_ids,omitempty"` - DisableFrontendMultiprocessing bool `json:"disable_frontend_multiprocessing,omitempty"` - 
EnableAutoToolChoice bool `json:"enable_auto_tool_choice,omitempty"` - ToolCallParser string `json:"tool_call_parser,omitempty"` - ToolServer string `json:"tool_server,omitempty"` - ChatTemplate string `json:"chat_template,omitempty"` - ChatTemplateContentFormat string `json:"chat_template_content_format,omitempty"` - AllowCredentials bool `json:"allow_credentials,omitempty"` - AllowedOrigins []string `json:"allowed_origins,omitempty"` - AllowedMethods []string `json:"allowed_methods,omitempty"` - AllowedHeaders []string `json:"allowed_headers,omitempty"` - APIKey []string `json:"api_key,omitempty"` - EnableLogOutputs bool `json:"enable_log_outputs,omitempty"` - EnableTokenUsage bool `json:"enable_token_usage,omitempty"` - EnableAsyncEngineDebug bool `json:"enable_async_engine_debug,omitempty"` - EngineUseRay bool `json:"engine_use_ray,omitempty"` - DisableLogRequests bool `json:"disable_log_requests,omitempty"` - MaxLogLen int `json:"max_log_len,omitempty"` + UDS string `json:"uds,omitempty"` + UvicornLogLevel string `json:"uvicorn_log_level,omitempty"` + ResponseRole string `json:"response_role,omitempty"` + SSLKeyfile string `json:"ssl_keyfile,omitempty"` + SSLCertfile string `json:"ssl_certfile,omitempty"` + SSLCACerts string `json:"ssl_ca_certs,omitempty"` + SSLCertReqs int `json:"ssl_cert_reqs,omitempty"` + RootPath string `json:"root_path,omitempty"` + Middleware []string `json:"middleware,omitempty"` + ReturnTokensAsTokenIDS bool `json:"return_tokens_as_token_ids,omitempty"` + DisableFrontendMultiprocessing bool `json:"disable_frontend_multiprocessing,omitempty"` + EnableAutoToolChoice bool `json:"enable_auto_tool_choice,omitempty"` + ToolCallParser string `json:"tool_call_parser,omitempty"` + ToolServer string `json:"tool_server,omitempty"` + ChatTemplate string `json:"chat_template,omitempty"` + ChatTemplateContentFormat string `json:"chat_template_content_format,omitempty"` + AllowCredentials bool `json:"allow_credentials,omitempty"` + AllowedOrigins []string `json:"allowed_origins,omitempty"` + AllowedMethods []string `json:"allowed_methods,omitempty"` + AllowedHeaders []string `json:"allowed_headers,omitempty"` + APIKey []string `json:"api_key,omitempty"` + EnableLogOutputs bool `json:"enable_log_outputs,omitempty"` + EnableTokenUsage bool `json:"enable_token_usage,omitempty"` + EnableAsyncEngineDebug bool `json:"enable_async_engine_debug,omitempty"` + EngineUseRay bool `json:"engine_use_ray,omitempty"` + DisableLogRequests bool `json:"disable_log_requests,omitempty"` + MaxLogLen int `json:"max_log_len,omitempty"` // Additional engine configuration - Task string `json:"task,omitempty"` - MultiModalConfig string `json:"multi_modal_config,omitempty"` - LimitMmPerPrompt string `json:"limit_mm_per_prompt,omitempty"` - EnableSleepMode bool `json:"enable_sleep_mode,omitempty"` - EnableChunkingRequest bool `json:"enable_chunking_request,omitempty"` - CompilationConfig string `json:"compilation_config,omitempty"` - DisableSlidingWindowMask bool `json:"disable_sliding_window_mask,omitempty"` - EnableTRTLLMEngineLatency bool `json:"enable_trtllm_engine_latency,omitempty"` - OverridePoolingConfig string `json:"override_pooling_config,omitempty"` - OverrideNeuronConfig string `json:"override_neuron_config,omitempty"` - OverrideKVCacheALIGNSize int `json:"override_kv_cache_align_size,omitempty"` -} - -// NewVllmServerOptions creates a new VllmServerOptions with defaults -func NewVllmServerOptions() *VllmServerOptions { - return &VllmServerOptions{ - Host: "127.0.0.1", - Port: 8000, - 
TensorParallelSize: 1, - PipelineParallelSize: 1, - GPUMemoryUtilization: 0.9, - BlockSize: 16, - SwapSpace: 4, - UvicornLogLevel: "info", - ResponseRole: "assistant", - TokenizerMode: "auto", - TrustRemoteCode: false, - EnablePrefixCaching: false, - EnforceEager: false, - DisableLogStats: false, - DisableLogRequests: false, - MaxLogprobs: 20, - EnableLogOutputs: false, - EnableTokenUsage: false, - AllowCredentials: false, - AllowedOrigins: []string{"*"}, - AllowedMethods: []string{"*"}, - AllowedHeaders: []string{"*"}, - } -} - -// UnmarshalJSON implements custom JSON unmarshaling to support multiple field names -func (o *VllmServerOptions) UnmarshalJSON(data []byte) error { - // First unmarshal into a map to handle multiple field names - var raw map[string]any - if err := json.Unmarshal(data, &raw); err != nil { - return err - } - - // Create a temporary struct for standard unmarshaling - type tempOptions VllmServerOptions - temp := tempOptions{} - - // Standard unmarshal first - if err := json.Unmarshal(data, &temp); err != nil { - return err - } - - // Copy to our struct - *o = VllmServerOptions(temp) - - // Handle alternative field names (CLI format with dashes) - fieldMappings := map[string]string{ - // Basic options - "tensor-parallel-size": "tensor_parallel_size", - "pipeline-parallel-size": "pipeline_parallel_size", - "max-parallel-loading-workers": "max_parallel_loading_workers", - "disable-async-output-proc": "disable_async_output_proc", - "worker-class": "worker_class", - "enabled-lora-modules": "enabled_lora_modules", - "max-lora-rank": "max_lora_rank", - "fully-sharded-loras": "fully_sharded_loras", - "lora-modules": "lora_modules", - "prompt-adapters": "prompt_adapters", - "max-prompt-adapter-token": "max_prompt_adapter_token", - "scheduler-delay": "scheduler_delay", - "enable-chunked-prefill": "enable_chunked_prefill", - "speculative-model": "speculative_model", - "speculative-model-quantization": "speculative_model_quantization", - "speculative-revision": "speculative_revision", - "speculative-max-model-len": "speculative_max_model_len", - "speculative-disable-by-batch-size": "speculative_disable_by_batch_size", - "ngpt-speculative-length": "ngpt_speculative_length", - "speculative-disable-mqa": "speculative_disable_mqa", - "model-loader-extra-config": "model_loader_extra_config", - "ignore-patterns": "ignore_patterns", - "preloaded-lora-modules": "preloaded_lora_modules", - - // Model configuration - "skip-tokenizer-init": "skip_tokenizer_init", - "code-revision": "code_revision", - "tokenizer-revision": "tokenizer_revision", - "tokenizer-mode": "tokenizer_mode", - "trust-remote-code": "trust_remote_code", - "download-dir": "download_dir", - "load-format": "load_format", - "config-format": "config_format", - "kv-cache-dtype": "kv_cache_dtype", - "quantization-param-path": "quantization_param_path", - "max-model-len": "max_model_len", - "guided-decoding-backend": "guided_decoding_backend", - "distributed-executor-backend": "distributed_executor_backend", - "worker-use-ray": "worker_use_ray", - "ray-workers-use-nsight": "ray_workers_use_nsight", - - // Performance configuration - "block-size": "block_size", - "enable-prefix-caching": "enable_prefix_caching", - "disable-sliding-window": "disable_sliding_window", - "use-v2-block-manager": "use_v2_block_manager", - "num-lookahead-slots": "num_lookahead_slots", - "swap-space": "swap_space", - "cpu-offload-gb": "cpu_offload_gb", - "gpu-memory-utilization": "gpu_memory_utilization", - "num-gpu-blocks-override": 
"num_gpu_blocks_override", - "max-num-batched-tokens": "max_num_batched_tokens", - "max-num-seqs": "max_num_seqs", - "max-logprobs": "max_logprobs", - "disable-log-stats": "disable_log_stats", - "rope-scaling": "rope_scaling", - "rope-theta": "rope_theta", - "enforce-eager": "enforce_eager", - "max-context-len-to-capture": "max_context_len_to_capture", - "max-seq-len-to-capture": "max_seq_len_to_capture", - "disable-custom-all-reduce": "disable_custom_all_reduce", - "tokenizer-pool-size": "tokenizer_pool_size", - "tokenizer-pool-type": "tokenizer_pool_type", - "tokenizer-pool-extra-config": "tokenizer_pool_extra_config", - "enable-lora-bias": "enable_lora_bias", - "lora-extra-vocab-size": "lora_extra_vocab_size", - "lora-rank": "lora_rank", - "prompt-lookback-distance": "prompt_lookback_distance", - "preemption-mode": "preemption_mode", - - // Server configuration - "uvicorn-log-level": "uvicorn_log_level", - "response-role": "response_role", - "ssl-keyfile": "ssl_keyfile", - "ssl-certfile": "ssl_certfile", - "ssl-ca-certs": "ssl_ca_certs", - "ssl-cert-reqs": "ssl_cert_reqs", - "root-path": "root_path", - "return-tokens-as-token-ids": "return_tokens_as_token_ids", - "disable-frontend-multiprocessing": "disable_frontend_multiprocessing", - "enable-auto-tool-choice": "enable_auto_tool_choice", - "tool-call-parser": "tool_call_parser", - "tool-server": "tool_server", - "chat-template": "chat_template", - "chat-template-content-format": "chat_template_content_format", - "allow-credentials": "allow_credentials", - "allowed-origins": "allowed_origins", - "allowed-methods": "allowed_methods", - "allowed-headers": "allowed_headers", - "api-key": "api_key", - "enable-log-outputs": "enable_log_outputs", - "enable-token-usage": "enable_token_usage", - "enable-async-engine-debug": "enable_async_engine_debug", - "engine-use-ray": "engine_use_ray", - "disable-log-requests": "disable_log_requests", - "max-log-len": "max_log_len", - - // Additional options - "multi-modal-config": "multi_modal_config", - "limit-mm-per-prompt": "limit_mm_per_prompt", - "enable-sleep-mode": "enable_sleep_mode", - "enable-chunking-request": "enable_chunking_request", - "compilation-config": "compilation_config", - "disable-sliding-window-mask": "disable_sliding_window_mask", - "enable-trtllm-engine-latency": "enable_trtllm_engine_latency", - "override-pooling-config": "override_pooling_config", - "override-neuron-config": "override_neuron_config", - "override-kv-cache-align-size": "override_kv_cache_align_size", - } - - // Process alternative field names - for altName, canonicalName := range fieldMappings { - if value, exists := raw[altName]; exists { - // Use reflection to set the field value - v := reflect.ValueOf(o).Elem() - field := v.FieldByNameFunc(func(fieldName string) bool { - field, _ := v.Type().FieldByName(fieldName) - jsonTag := field.Tag.Get("json") - return jsonTag == canonicalName+",omitempty" || jsonTag == canonicalName - }) - - if field.IsValid() && field.CanSet() { - switch field.Kind() { - case reflect.Int: - if intVal, ok := value.(float64); ok { - field.SetInt(int64(intVal)) - } else if strVal, ok := value.(string); ok { - if intVal, err := strconv.Atoi(strVal); err == nil { - field.SetInt(int64(intVal)) - } - } - case reflect.Float64: - if floatVal, ok := value.(float64); ok { - field.SetFloat(floatVal) - } else if strVal, ok := value.(string); ok { - if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil { - field.SetFloat(floatVal) - } - } - case reflect.String: - if strVal, ok := 
value.(string); ok { - field.SetString(strVal) - } - case reflect.Bool: - if boolVal, ok := value.(bool); ok { - field.SetBool(boolVal) - } - case reflect.Slice: - if field.Type().Elem().Kind() == reflect.String { - if strVal, ok := value.(string); ok { - // Split comma-separated values - values := strings.Split(strVal, ",") - for i, v := range values { - values[i] = strings.TrimSpace(v) - } - field.Set(reflect.ValueOf(values)) - } else if slice, ok := value.([]interface{}); ok { - var strSlice []string - for _, item := range slice { - if str, ok := item.(string); ok { - strSlice = append(strSlice, str) - } - } - field.Set(reflect.ValueOf(strSlice)) - } - } - } - } - } - } - - return nil + Task string `json:"task,omitempty"` + MultiModalConfig string `json:"multi_modal_config,omitempty"` + LimitMmPerPrompt string `json:"limit_mm_per_prompt,omitempty"` + EnableSleepMode bool `json:"enable_sleep_mode,omitempty"` + EnableChunkingRequest bool `json:"enable_chunking_request,omitempty"` + CompilationConfig string `json:"compilation_config,omitempty"` + DisableSlidingWindowMask bool `json:"disable_sliding_window_mask,omitempty"` + EnableTRTLLMEngineLatency bool `json:"enable_trtllm_engine_latency,omitempty"` + OverridePoolingConfig string `json:"override_pooling_config,omitempty"` + OverrideNeuronConfig string `json:"override_neuron_config,omitempty"` + OverrideKVCacheALIGNSize int `json:"override_kv_cache_align_size,omitempty"` } // BuildCommandArgs converts VllmServerOptions to command line arguments @@ -387,11 +161,6 @@ func (o *VllmServerOptions) BuildCommandArgs() []string { flagName = jsonTag[:commaIndex] } - // Skip host and port as they are handled by llamactl - if flagName == "host" || flagName == "port" { - continue - } - // Convert snake_case to kebab-case for CLI flags flagName = strings.ReplaceAll(flagName, "_", "-") @@ -436,4 +205,4 @@ func (o *VllmServerOptions) BuildCommandArgs() []string { } return args -} \ No newline at end of file +} diff --git a/pkg/backends/vllm/vllm_test.go b/pkg/backends/vllm/vllm_test.go index e05320a..8a42862 100644 --- a/pkg/backends/vllm/vllm_test.go +++ b/pkg/backends/vllm/vllm_test.go @@ -10,12 +10,12 @@ import ( func TestBuildCommandArgs(t *testing.T) { options := vllm.VllmServerOptions{ Model: "microsoft/DialoGPT-medium", - Port: 8080, // should be excluded - Host: "localhost", // should be excluded + Port: 8080, + Host: "localhost", TensorParallelSize: 2, GPUMemoryUtilization: 0.8, EnableLogOutputs: true, - APIKey: []string{"key1", "key2"}, + AllowedOrigins: []string{"http://localhost:3000", "https://example.com"}, } args := options.BuildCommandArgs() @@ -32,19 +32,22 @@ func TestBuildCommandArgs(t *testing.T) { } // Host and port should NOT be in the arguments (handled by llamactl) - if contains(args, "--host") || contains(args, "--port") { - t.Errorf("Host and port should not be in command args, found in %v", args) + if !contains(args, "--host") { + t.Errorf("Expected --host not found in %v", args) + } + if !contains(args, "--port") { + t.Errorf("Expected --port not found in %v", args) } // Check array handling (multiple flags) - apiKeyCount := 0 + allowedOriginsCount := 0 for i := range args { - if args[i] == "--api-key" { - apiKeyCount++ + if args[i] == "--allowed-origins" { + allowedOriginsCount++ } } - if apiKeyCount != 2 { - t.Errorf("Expected 2 --api-key flags, got %d", apiKeyCount) + if allowedOriginsCount != 2 { + t.Errorf("Expected 2 --allowed-origins flags, got %d", allowedOriginsCount) } } @@ -77,20 +80,6 @@ func TestUnmarshalJSON(t 
*testing.T) { } } -func TestNewVllmServerOptions(t *testing.T) { - options := vllm.NewVllmServerOptions() - - if options == nil { - t.Fatal("NewVllmServerOptions returned nil") - } - if options.Host != "127.0.0.1" { - t.Errorf("Expected default host '127.0.0.1', got %q", options.Host) - } - if options.Port != 8000 { - t.Errorf("Expected default port 8000, got %d", options.Port) - } -} - // Helper functions func contains(slice []string, item string) bool { return slices.Contains(slice, item) @@ -103,4 +92,4 @@ func containsFlagWithValue(args []string, flag, value string) bool { } } return false -} \ No newline at end of file +}
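
Note on the pattern: with this change the MLX backend builds its CLI arguments the same way the vLLM backend already does, by reflecting over the options struct, reading each field's json tag, stripping ",omitempty", converting snake_case to kebab-case, and skipping zero values. The sketch below is a minimal, self-contained illustration of that approach under stated assumptions; the ExampleOptions type and the main function are hypothetical and are not part of this change.

// Minimal sketch (assumed, not from this PR) of the reflection-based
// flag building both backends now use: exported fields become
// "--flag value" pairs named after their json tags, with zero values
// and boolean false omitted.
package main

import (
	"fmt"
	"reflect"
	"strconv"
	"strings"
)

type ExampleOptions struct {
	Model     string  `json:"model,omitempty"`
	GPULayers int     `json:"gpu_layers,omitempty"`
	Temp      float64 `json:"temp,omitempty"`
	Verbose   bool    `json:"verbose,omitempty"`
}

func (o *ExampleOptions) BuildCommandArgs() []string {
	var args []string
	v := reflect.ValueOf(o).Elem()
	t := v.Type()
	for i := 0; i < v.NumField(); i++ {
		field := v.Field(i)
		tag := t.Field(i).Tag.Get("json")
		if tag == "" || tag == "-" {
			continue
		}
		// "gpu_layers,omitempty" -> "gpu-layers"
		flag := strings.ReplaceAll(strings.SplitN(tag, ",", 2)[0], "_", "-")
		switch field.Kind() {
		case reflect.Bool:
			if field.Bool() {
				args = append(args, "--"+flag)
			}
		case reflect.Int:
			if field.Int() != 0 {
				args = append(args, "--"+flag, strconv.FormatInt(field.Int(), 10))
			}
		case reflect.Float64:
			if field.Float() != 0 {
				args = append(args, "--"+flag, strconv.FormatFloat(field.Float(), 'f', -1, 64))
			}
		case reflect.String:
			if field.String() != "" {
				args = append(args, "--"+flag, field.String())
			}
		}
	}
	return args
}

func main() {
	opts := &ExampleOptions{Model: "/m.gguf", GPULayers: 32, Temp: 0.7, Verbose: true}
	fmt.Println(opts.BuildCommandArgs())
	// Output: [--model /m.gguf --gpu-layers 32 --temp 0.7 --verbose]
}

Because flag names are derived directly from the json tags, adding a new option only requires adding a struct field; no per-flag code in BuildCommandArgs and no custom UnmarshalJSON mapping, which is what the deletions in mlx.go and vllm.go above remove.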