From 64842e74b07ef9ec8c750d559e5d224690d348bd Mon Sep 17 00:00:00 2001
From: LordMathis
Date: Fri, 19 Sep 2025 20:23:25 +0200
Subject: [PATCH] Refactor command parsing and building

---
 pkg/backends/builder.go              |  70 ++++++++++++++++
 pkg/backends/llamacpp/llama.go       |  28 +++++++
 pkg/backends/llamacpp/llama_test.go  | 115 +++++++++++++++++++++++++
 pkg/backends/llamacpp/parser.go      |  34 --------
 pkg/backends/llamacpp/parser_test.go | 121 ---------------------------
 pkg/backends/mlx/mlx.go              |  97 +++++----------------
 pkg/backends/mlx/mlx_test.go         |  95 +++++++++++++++++++++
 pkg/backends/mlx/parser.go           |  24 ------
 pkg/backends/mlx/parser_test.go      | 101 ----------------------
 pkg/backends/parser.go               |  64 --------------
 pkg/backends/vllm/parser.go          |  34 --------
 pkg/backends/vllm/parser_test.go     |  84 -------------------
 pkg/backends/vllm/vllm.go            |  28 +++++++
 pkg/backends/vllm/vllm_test.go       |  78 +++++++++++++++++
 14 files changed, 434 insertions(+), 539 deletions(-)
 create mode 100644 pkg/backends/builder.go
 delete mode 100644 pkg/backends/llamacpp/parser.go
 delete mode 100644 pkg/backends/llamacpp/parser_test.go
 delete mode 100644 pkg/backends/mlx/parser.go
 delete mode 100644 pkg/backends/mlx/parser_test.go
 delete mode 100644 pkg/backends/vllm/parser.go
 delete mode 100644 pkg/backends/vllm/parser_test.go

diff --git a/pkg/backends/builder.go b/pkg/backends/builder.go
new file mode 100644
index 0000000..23c3bb1
--- /dev/null
+++ b/pkg/backends/builder.go
@@ -0,0 +1,70 @@
+package backends
+
+import (
+    "reflect"
+    "strconv"
+    "strings"
+)
+
+// BuildCommandArgs converts a struct to command line arguments
+func BuildCommandArgs(options any, multipleFlags map[string]bool) []string {
+    var args []string
+
+    v := reflect.ValueOf(options).Elem()
+    t := v.Type()
+
+    for i := 0; i < v.NumField(); i++ {
+        field := v.Field(i)
+        fieldType := t.Field(i)
+
+        if !field.CanInterface() {
+            continue
+        }
+
+        jsonTag := fieldType.Tag.Get("json")
+        if jsonTag == "" || jsonTag == "-" {
+            continue
+        }
+
+        // Get flag name from JSON tag
+        flagName := strings.Split(jsonTag, ",")[0]
+        flagName = strings.ReplaceAll(flagName, "_", "-")
+
+        switch field.Kind() {
+        case reflect.Bool:
+            if field.Bool() {
+                args = append(args, "--"+flagName)
+            }
+        case reflect.Int:
+            if field.Int() != 0 {
+                args = append(args, "--"+flagName, strconv.FormatInt(field.Int(), 10))
+            }
+        case reflect.Float64:
+            if field.Float() != 0 {
+                args = append(args, "--"+flagName, strconv.FormatFloat(field.Float(), 'f', -1, 64))
+            }
+        case reflect.String:
+            if field.String() != "" {
+                args = append(args, "--"+flagName, field.String())
+            }
+        case reflect.Slice:
+            if field.Type().Elem().Kind() == reflect.String && field.Len() > 0 {
+                if multipleFlags[flagName] {
+                    // Multiple flags: --flag value1 --flag value2
+                    for j := 0; j < field.Len(); j++ {
+                        args = append(args, "--"+flagName, field.Index(j).String())
+                    }
+                } else {
+                    // Comma-separated: --flag value1,value2
+                    var values []string
+                    for j := 0; j < field.Len(); j++ {
+                        values = append(values, field.Index(j).String())
+                    }
+                    args = append(args, "--"+flagName, strings.Join(values, ","))
+                }
+            }
+        }
+    }
+
+    return args
+}
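
Illustration (a minimal usage sketch, not part of the diff): BuildCommandArgs derives each flag name from the field's json tag (underscores become dashes), skips zero values, and emits []string fields either as repeated flags or as one comma-joined value depending on the multipleFlags map. The demoOptions struct below is hypothetical, invented only to show the mapping:

package main

import (
    "fmt"

    "llamactl/pkg/backends"
)

// demoOptions is a hypothetical options struct; only the json tags matter here.
type demoOptions struct {
    Model     string   `json:"model,omitempty"`
    GPULayers int      `json:"gpu_layers,omitempty"`
    Verbose   bool     `json:"verbose,omitempty"`
    Lora      []string `json:"lora,omitempty"`
}

func main() {
    opts := &demoOptions{Model: "m.gguf", GPULayers: 32, Verbose: true, Lora: []string{"a.bin", "b.bin"}}
    // "lora" is marked multi-valued, so it is emitted as repeated flags.
    fmt.Println(backends.BuildCommandArgs(opts, map[string]bool{"lora": true}))
    // Expected: [--model m.gguf --gpu-layers 32 --verbose --lora a.bin --lora b.bin]
}
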
diff --git a/pkg/backends/llamacpp/llama.go b/pkg/backends/llamacpp/llama.go
index 7c8a21f..f2a7d31 100644
--- a/pkg/backends/llamacpp/llama.go
+++ b/pkg/backends/llamacpp/llama.go
@@ -328,3 +328,31 @@ func (o *LlamaServerOptions) BuildCommandArgs() []string {
     }
     return backends.BuildCommandArgs(o, multipleFlags)
 }
+
+// ParseLlamaCommand parses a llama-server command string into LlamaServerOptions
+// Supports multiple formats:
+// 1. Full command: "llama-server --model file.gguf"
+// 2. Full path: "/usr/local/bin/llama-server --model file.gguf"
+// 3. Args only: "--model file.gguf --gpu-layers 32"
+// 4. Multiline commands with backslashes
+func ParseLlamaCommand(command string) (*LlamaServerOptions, error) {
+    executableNames := []string{"llama-server"}
+    var subcommandNames []string // Llama has no subcommands
+    multiValuedFlags := map[string]bool{
+        "override_tensor":       true,
+        "override_kv":           true,
+        "lora":                  true,
+        "lora_scaled":           true,
+        "control_vector":        true,
+        "control_vector_scaled": true,
+        "dry_sequence_breaker":  true,
+        "logit_bias":            true,
+    }
+
+    var llamaOptions LlamaServerOptions
+    if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &llamaOptions); err != nil {
+        return nil, err
+    }
+
+    return &llamaOptions, nil
+}
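
Illustration (a rough sketch of the intended call pattern, not part of the diff): ParseLlamaCommand and the BuildCommandArgs method give a parse/serialize round trip at the options level, which is what the new tests below exercise. The field names (Model, GPULayers) come from LlamaServerOptions as used in the tests; the program itself is hypothetical:

package main

import (
    "fmt"
    "log"

    "llamactl/pkg/backends/llamacpp"
)

func main() {
    opts, err := llamacpp.ParseLlamaCommand("llama-server --model /models/llama.gguf --gpu-layers 32")
    if err != nil {
        log.Fatal(err) // parse errors: empty command, unterminated quotes, malformed flags
    }
    fmt.Println(opts.Model, opts.GPULayers) // /models/llama.gguf 32
    // Regenerate argv from the parsed options, e.g. to hand to exec.Command.
    fmt.Println(opts.BuildCommandArgs())
}
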
diff --git a/pkg/backends/llamacpp/llama_test.go b/pkg/backends/llamacpp/llama_test.go
index 9c1162e..c779320 100644
--- a/pkg/backends/llamacpp/llama_test.go
+++ b/pkg/backends/llamacpp/llama_test.go
@@ -378,6 +378,121 @@ func TestUnmarshalJSON_ArrayFields(t *testing.T) {
     }
 }
 
+func TestParseLlamaCommand(t *testing.T) {
+    tests := []struct {
+        name      string
+        command   string
+        expectErr bool
+    }{
+        {
+            name:      "basic command",
+            command:   "llama-server --model /path/to/model.gguf --gpu-layers 32",
+            expectErr: false,
+        },
+        {
+            name:      "args only",
+            command:   "--model /path/to/model.gguf --ctx-size 4096",
+            expectErr: false,
+        },
+        {
+            name:      "mixed flag formats",
+            command:   "llama-server --model=/path/model.gguf --gpu-layers 16 --verbose",
+            expectErr: false,
+        },
+        {
+            name:      "quoted strings",
+            command:   `llama-server --model test.gguf --api-key "sk-1234567890abcdef"`,
+            expectErr: false,
+        },
+        {
+            name:      "empty command",
+            command:   "",
+            expectErr: true,
+        },
+        {
+            name:      "unterminated quote",
+            command:   `llama-server --model test.gguf --api-key "unterminated`,
+            expectErr: true,
+        },
+        {
+            name:      "malformed flag",
+            command:   "llama-server ---model test.gguf",
+            expectErr: true,
+        },
+    }
+
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            result, err := llamacpp.ParseLlamaCommand(tt.command)
+
+            if tt.expectErr {
+                if err == nil {
+                    t.Errorf("expected error but got none")
+                }
+                return
+            }
+
+            if err != nil {
+                t.Errorf("unexpected error: %v", err)
+                return
+            }
+
+            if result == nil {
+                t.Errorf("expected result but got nil")
+            }
+        })
+    }
+}
+
+func TestParseLlamaCommandValues(t *testing.T) {
+    command := "llama-server --model /test/model.gguf --gpu-layers 32 --temp 0.7 --verbose --no-mmap"
+    result, err := llamacpp.ParseLlamaCommand(command)
+
+    if err != nil {
+        t.Fatalf("unexpected error: %v", err)
+    }
+
+    if result.Model != "/test/model.gguf" {
+        t.Errorf("expected model '/test/model.gguf', got '%s'", result.Model)
+    }
+
+    if result.GPULayers != 32 {
+        t.Errorf("expected gpu_layers 32, got %d", result.GPULayers)
+    }
+
+    if result.Temperature != 0.7 {
+        t.Errorf("expected temperature 0.7, got %f", result.Temperature)
+    }
+
+    if !result.Verbose {
+        t.Errorf("expected verbose to be true")
+    }
+
+    if !result.NoMmap {
+        t.Errorf("expected no_mmap to be true")
+    }
+}
+
+func TestParseLlamaCommandArrays(t *testing.T) {
+    command := "llama-server --model test.gguf --lora adapter1.bin --lora=adapter2.bin"
+    result, err := llamacpp.ParseLlamaCommand(command)
+
+    if err != nil {
+        t.Fatalf("unexpected error: %v", err)
+    }
+
+    if len(result.Lora) != 2 {
+        t.Errorf("expected 2 lora adapters, got %d", len(result.Lora))
+    }
+
+    expected := []string{"adapter1.bin", "adapter2.bin"}
+    for i, v := range expected {
+        if result.Lora[i] != v {
+            t.Errorf("expected lora[%d]=%s got %s", i, v, result.Lora[i])
+        }
+    }
+}
+
 // Helper functions
 func contains(slice []string, item string) bool {
     return slices.Contains(slice, item)
diff --git a/pkg/backends/llamacpp/parser.go b/pkg/backends/llamacpp/parser.go
deleted file mode 100644
index b5b850a..0000000
--- a/pkg/backends/llamacpp/parser.go
+++ /dev/null
@@ -1,34 +0,0 @@
-package llamacpp
-
-import (
-    "llamactl/pkg/backends"
-)
-
-// ParseLlamaCommand parses a llama-server command string into LlamaServerOptions
-// Supports multiple formats:
-// 1. Full command: "llama-server --model file.gguf"
-// 2. Full path: "/usr/local/bin/llama-server --model file.gguf"
-// 3. Args only: "--model file.gguf --gpu-layers 32"
-// 4. Multiline commands with backslashes
-func ParseLlamaCommand(command string) (*LlamaServerOptions, error) {
-    executableNames := []string{"llama-server"}
-    var subcommandNames []string // Llama has no subcommands
-    multiValuedFlags := map[string]bool{
-        "override_tensor":       true,
-        "override_kv":           true,
-        "lora":                  true,
-        "lora_scaled":           true,
-        "control_vector":        true,
-        "control_vector_scaled": true,
-        "dry_sequence_breaker":  true,
-        "logit_bias":            true,
-    }
-
-    var llamaOptions LlamaServerOptions
-    if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &llamaOptions); err != nil {
-        return nil, err
-    }
-
-    return &llamaOptions, nil
-}
-
diff --git a/pkg/backends/llamacpp/parser_test.go b/pkg/backends/llamacpp/parser_test.go
deleted file mode 100644
index 7072f65..0000000
--- a/pkg/backends/llamacpp/parser_test.go
+++ /dev/null
@@ -1,121 +0,0 @@
-package llamacpp_test
-
-import (
-    "llamactl/pkg/backends/llamacpp"
-    "testing"
-)
-
-func TestParseLlamaCommand(t *testing.T) {
-    tests := []struct {
-        name      string
-        command   string
-        expectErr bool
-    }{
-        {
-            name:      "basic command",
-            command:   "llama-server --model /path/to/model.gguf --gpu-layers 32",
-            expectErr: false,
-        },
-        {
-            name:      "args only",
-            command:   "--model /path/to/model.gguf --ctx-size 4096",
-            expectErr: false,
-        },
-        {
-            name:      "mixed flag formats",
-            command:   "llama-server --model=/path/model.gguf --gpu-layers 16 --verbose",
-            expectErr: false,
-        },
-        {
-            name:      "quoted strings",
-            command:   `llama-server --model test.gguf --api-key "sk-1234567890abcdef"`,
-            expectErr: false,
-        },
-        {
-            name:      "empty command",
-            command:   "",
-            expectErr: true,
-        },
-        {
-            name:      "unterminated quote",
-            command:   `llama-server --model test.gguf --api-key "unterminated`,
-            expectErr: true,
-        },
-        {
-            name:      "malformed flag",
-            command:   "llama-server ---model test.gguf",
-            expectErr: true,
-        },
-    }
-
-    for _, tt := range tests {
-        t.Run(tt.name, func(t *testing.T) {
-            result, err := llamacpp.ParseLlamaCommand(tt.command)
-
-            if tt.expectErr {
-                if err == nil {
-                    t.Errorf("expected error but got none")
-                }
-                return
-            }
-
-            if err != nil {
-                t.Errorf("unexpected error: %v", err)
-                return
-            }
-
-            if result == nil {
-                t.Errorf("expected result but got nil")
-            }
-        })
-    }
-}
-
-func TestParseLlamaCommandValues(t *testing.T) {
-    command := "llama-server --model /test/model.gguf --gpu-layers 32 --temp 0.7 --verbose --no-mmap"
-    result, err := llamacpp.ParseLlamaCommand(command)
-
-    if err != nil {
-        t.Fatalf("unexpected error: %v", err)
-    }
-
-    if result.Model != "/test/model.gguf" {
-        t.Errorf("expected model '/test/model.gguf', got '%s'", result.Model)
-    }
-
-    if result.GPULayers != 32 {
-        t.Errorf("expected gpu_layers 32, got %d", result.GPULayers)
-    }
-
-    if result.Temperature != 0.7 {
-        t.Errorf("expected temperature 0.7, got %f", result.Temperature)
-    }
-
-    if !result.Verbose {
-        t.Errorf("expected verbose to be true")
-    }
-
-    if !result.NoMmap {
-        t.Errorf("expected no_mmap to be true")
-    }
-}
-
-func TestParseLlamaCommandArrays(t *testing.T) {
-    command := "llama-server --model test.gguf --lora adapter1.bin --lora=adapter2.bin"
-    result, err := llamacpp.ParseLlamaCommand(command)
-
-    if err != nil {
-        t.Fatalf("unexpected error: %v", err)
-    }
-
-    if len(result.Lora) != 2 {
-        t.Errorf("expected 2 lora adapters, got %d", len(result.Lora))
-    }
-
-    expected := []string{"adapter1.bin", "adapter2.bin"}
-    for i, v := range expected {
-        if result.Lora[i] != v {
-            t.Errorf("expected lora[%d]=%s got %s", i, v, result.Lora[i])
-        }
-    }
-}
diff --git a/pkg/backends/mlx/mlx.go b/pkg/backends/mlx/mlx.go
index c72597c..3b83681 100644
--- a/pkg/backends/mlx/mlx.go
+++ b/pkg/backends/mlx/mlx.go
@@ -1,10 +1,7 @@
 package mlx
 
 import (
-    "encoding/json"
     "llamactl/pkg/backends"
-    "reflect"
-    "strconv"
 )
 
 type MlxServerOptions struct {
@@ -26,88 +23,34 @@ type MlxServerOptions struct {
     ChatTemplateArgs string `json:"chat_template_args,omitempty"` // JSON string
 
     // Sampling defaults
-    Temp      float64 `json:"temp,omitempty"` // Note: MLX uses "temp" not "temperature"
+    Temp      float64 `json:"temp,omitempty"`
     TopP      float64 `json:"top_p,omitempty"`
     TopK      int     `json:"top_k,omitempty"`
     MinP      float64 `json:"min_p,omitempty"`
     MaxTokens int     `json:"max_tokens,omitempty"`
 }
 
-// UnmarshalJSON implements custom JSON unmarshaling to support multiple field names
-func (o *MlxServerOptions) UnmarshalJSON(data []byte) error {
-    // First unmarshal into a map to handle multiple field names
-    var raw map[string]any
-    if err := json.Unmarshal(data, &raw); err != nil {
-        return err
-    }
-
-    // Create a temporary struct for standard unmarshaling
-    type tempOptions MlxServerOptions
-    temp := tempOptions{}
-
-    // Standard unmarshal first
-    if err := json.Unmarshal(data, &temp); err != nil {
-        return err
-    }
-
-    // Copy to our struct
-    *o = MlxServerOptions(temp)
-
-    // Handle alternative field names
-    fieldMappings := map[string]string{
-        "m":            "model",        // -m, --model
-        "temperature":  "temp",         // --temperature vs --temp
-        "top_k":        "top_k",        // --top-k
-        "adapter_path": "adapter_path", // --adapter-path
-    }
-
-    // Process alternative field names
-    for altName, canonicalName := range fieldMappings {
-        if value, exists := raw[altName]; exists {
-            // Use reflection to set the field value
-            v := reflect.ValueOf(o).Elem()
-            field := v.FieldByNameFunc(func(fieldName string) bool {
-                field, _ := v.Type().FieldByName(fieldName)
-                jsonTag := field.Tag.Get("json")
-                return jsonTag == canonicalName+",omitempty" || jsonTag == canonicalName
-            })
-
-            if field.IsValid() && field.CanSet() {
-                switch field.Kind() {
-                case reflect.Int:
-                    if intVal, ok := value.(float64); ok {
-                        field.SetInt(int64(intVal))
-                    } else if strVal, ok := value.(string); ok {
-                        if intVal, err := strconv.Atoi(strVal); err == nil {
-                            field.SetInt(int64(intVal))
-                        }
-                    }
-                case reflect.Float64:
-                    if floatVal, ok := value.(float64); ok {
-                        field.SetFloat(floatVal)
-                    } else if strVal, ok := value.(string); ok {
-                        if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil {
-                            field.SetFloat(floatVal)
-                        }
-                    }
-                case reflect.String:
-                    if strVal, ok := value.(string); ok {
-                        field.SetString(strVal)
-                    }
-                case reflect.Bool:
-                    if boolVal, ok := value.(bool); ok {
-                        field.SetBool(boolVal)
-                    }
-                }
-            }
-        }
-    }
-
-    return nil
-}
-
 // BuildCommandArgs converts to command line arguments
 func (o *MlxServerOptions) BuildCommandArgs() []string {
     multipleFlags := map[string]bool{} // MLX doesn't currently have []string fields
     return backends.BuildCommandArgs(o, multipleFlags)
 }
+
+// ParseMlxCommand parses a mlx_lm.server command string into MlxServerOptions
+// Supports multiple formats:
+// 1. Full command: "mlx_lm.server --model model/path"
+// 2. Full path: "/usr/local/bin/mlx_lm.server --model model/path"
+// 3. Args only: "--model model/path --host 0.0.0.0"
+// 4. Multiline commands with backslashes
+func ParseMlxCommand(command string) (*MlxServerOptions, error) {
+    executableNames := []string{"mlx_lm.server"}
+    var subcommandNames []string          // MLX has no subcommands
+    multiValuedFlags := map[string]bool{} // MLX has no multi-valued flags
+
+    var mlxOptions MlxServerOptions
+    if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &mlxOptions); err != nil {
+        return nil, err
+    }
+
+    return &mlxOptions, nil
+}
diff --git a/pkg/backends/mlx/mlx_test.go b/pkg/backends/mlx/mlx_test.go
index b35f512..8baeb5c 100644
--- a/pkg/backends/mlx/mlx_test.go
+++ b/pkg/backends/mlx/mlx_test.go
@@ -5,6 +5,101 @@ import (
     "testing"
 )
 
+func TestParseMlxCommand(t *testing.T) {
+    tests := []struct {
+        name      string
+        command   string
+        expectErr bool
+    }{
+        {
+            name:      "basic command",
+            command:   "mlx_lm.server --model /path/to/model --host 0.0.0.0",
+            expectErr: false,
+        },
+        {
+            name:      "args only",
+            command:   "--model /path/to/model --port 8080",
+            expectErr: false,
+        },
+        {
+            name:      "mixed flag formats",
+            command:   "mlx_lm.server --model=/path/model --temp=0.7 --trust-remote-code",
+            expectErr: false,
+        },
+        {
+            name:      "quoted strings",
+            command:   `mlx_lm.server --model test.mlx --chat-template "User: {user}\nAssistant: "`,
+            expectErr: false,
+        },
+        {
+            name:      "empty command",
+            command:   "",
+            expectErr: true,
+        },
+        {
+            name:      "unterminated quote",
+            command:   `mlx_lm.server --model test.mlx --chat-template "unterminated`,
+            expectErr: true,
+        },
+        {
+            name:      "malformed flag",
+            command:   "mlx_lm.server ---model test.mlx",
+            expectErr: true,
+        },
+    }
+
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            result, err := mlx.ParseMlxCommand(tt.command)
+
+            if tt.expectErr {
+                if err == nil {
+                    t.Errorf("expected error but got none")
+                }
+                return
+            }
+
+            if err != nil {
+                t.Errorf("unexpected error: %v", err)
+                return
+            }
+
+            if result == nil {
+                t.Errorf("expected result but got nil")
+            }
+        })
+    }
+}
+
+func TestParseMlxCommandValues(t *testing.T) {
+    command := "mlx_lm.server --model /test/model.mlx --port 8080 --temp 0.7 --trust-remote-code --log-level DEBUG"
+    result, err := mlx.ParseMlxCommand(command)
+
+    if err != nil {
+        t.Fatalf("unexpected error: %v", err)
+    }
+
+    if result.Model != "/test/model.mlx" {
+        t.Errorf("expected model '/test/model.mlx', got '%s'", result.Model)
+    }
+
+    if result.Port != 8080 {
+        t.Errorf("expected port 8080, got %d", result.Port)
+    }
+
+    if result.Temp != 0.7 {
+        t.Errorf("expected temp 0.7, got %f", result.Temp)
+    }
+
+    if !result.TrustRemoteCode {
+        t.Errorf("expected trust_remote_code to be true")
+    }
+
+    if result.LogLevel != "DEBUG" {
+        t.Errorf("expected log_level 'DEBUG', got '%s'", result.LogLevel)
+    }
+}
+
 func TestBuildCommandArgs(t *testing.T) {
     options := &mlx.MlxServerOptions{
         Model: "/test/model.mlx",
diff --git a/pkg/backends/mlx/parser.go b/pkg/backends/mlx/parser.go
deleted file mode 100644
index ec4cfb2..0000000
--- a/pkg/backends/mlx/parser.go
+++ /dev/null
@@ -1,24 +0,0 @@
-package mlx
-
-import (
-    "llamactl/pkg/backends"
-)
-
-// ParseMlxCommand parses a mlx_lm.server command string into MlxServerOptions
-// Supports multiple formats:
-// 1. Full command: "mlx_lm.server --model model/path"
-// 2. Full path: "/usr/local/bin/mlx_lm.server --model model/path"
-// 3. Args only: "--model model/path --host 0.0.0.0"
-// 4. Multiline commands with backslashes
-func ParseMlxCommand(command string) (*MlxServerOptions, error) {
-    executableNames := []string{"mlx_lm.server"}
-    var subcommandNames []string          // MLX has no subcommands
-    multiValuedFlags := map[string]bool{} // MLX has no multi-valued flags
-
-    var mlxOptions MlxServerOptions
-    if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &mlxOptions); err != nil {
-        return nil, err
-    }
-
-    return &mlxOptions, nil
-}
diff --git a/pkg/backends/mlx/parser_test.go b/pkg/backends/mlx/parser_test.go
deleted file mode 100644
index 6caae84..0000000
--- a/pkg/backends/mlx/parser_test.go
+++ /dev/null
@@ -1,101 +0,0 @@
-package mlx_test
-
-import (
-    "llamactl/pkg/backends/mlx"
-    "testing"
-)
-
-func TestParseMlxCommand(t *testing.T) {
-    tests := []struct {
-        name      string
-        command   string
-        expectErr bool
-    }{
-        {
-            name:      "basic command",
-            command:   "mlx_lm.server --model /path/to/model --host 0.0.0.0",
-            expectErr: false,
-        },
-        {
-            name:      "args only",
-            command:   "--model /path/to/model --port 8080",
-            expectErr: false,
-        },
-        {
-            name:      "mixed flag formats",
-            command:   "mlx_lm.server --model=/path/model --temp=0.7 --trust-remote-code",
-            expectErr: false,
-        },
-        {
-            name:      "quoted strings",
-            command:   `mlx_lm.server --model test.mlx --chat-template "User: {user}\nAssistant: "`,
-            expectErr: false,
-        },
-        {
-            name:      "empty command",
-            command:   "",
-            expectErr: true,
-        },
-        {
-            name:      "unterminated quote",
-            command:   `mlx_lm.server --model test.mlx --chat-template "unterminated`,
-            expectErr: true,
-        },
-        {
-            name:      "malformed flag",
-            command:   "mlx_lm.server ---model test.mlx",
-            expectErr: true,
-        },
-    }
-
-    for _, tt := range tests {
-        t.Run(tt.name, func(t *testing.T) {
-            result, err := mlx.ParseMlxCommand(tt.command)
-
-            if tt.expectErr {
-                if err == nil {
-                    t.Errorf("expected error but got none")
-                }
-                return
-            }
-
-            if err != nil {
-                t.Errorf("unexpected error: %v", err)
-                return
-            }
-
-            if result == nil {
-                t.Errorf("expected result but got nil")
-            }
-        })
-    }
-}
-
-func TestParseMlxCommandValues(t *testing.T) {
-    command := "mlx_lm.server --model /test/model.mlx --port 8080 --temp 0.7 --trust-remote-code --log-level DEBUG"
-    result, err := mlx.ParseMlxCommand(command)
-
-    if err != nil {
-        t.Fatalf("unexpected error: %v", err)
-    }
-
-    if result.Model != "/test/model.mlx" {
-        t.Errorf("expected model '/test/model.mlx', got '%s'", result.Model)
-    }
-
-    if result.Port != 8080 {
-        t.Errorf("expected port 8080, got %d", result.Port)
-    }
-
-    if result.Temp != 0.7 {
-        t.Errorf("expected temp 0.7, got %f", result.Temp)
-    }
-
-    if !result.TrustRemoteCode {
-        t.Errorf("expected trust_remote_code to be true")
-    }
-
-    if result.LogLevel != "DEBUG" {
-        t.Errorf("expected log_level 'DEBUG', got '%s'", result.LogLevel)
-    }
-}
diff --git a/pkg/backends/parser.go b/pkg/backends/parser.go
index 721173f..89ad46e 100644
--- a/pkg/backends/parser.go
+++ b/pkg/backends/parser.go
@@ -4,7 +4,6 @@ import (
     "encoding/json"
     "fmt"
     "path/filepath"
-    "reflect"
     "regexp"
     "strconv"
     "strings"
@@ -43,69 +42,6 @@ func ParseCommand(command string, executableNames []string, subcommandNames []st
     return nil
 }
 
-// BuildCommandArgs converts a struct to command line arguments
-func BuildCommandArgs(options any, multipleFlags map[string]bool) []string {
-    var args []string
-
-    v := reflect.ValueOf(options).Elem()
-    t := v.Type()
-
-    for i := 0; i < v.NumField(); i++ {
-        field := v.Field(i)
-        fieldType := t.Field(i)
-
-        if !field.CanInterface() {
-            continue
-        }
-
-        jsonTag := fieldType.Tag.Get("json")
-        if jsonTag == "" || jsonTag == "-" {
-            continue
-        }
-
-        // Get flag name from JSON tag
-        flagName := strings.Split(jsonTag, ",")[0]
-        flagName = strings.ReplaceAll(flagName, "_", "-")
-
-        switch field.Kind() {
-        case reflect.Bool:
-            if field.Bool() {
-                args = append(args, "--"+flagName)
-            }
-        case reflect.Int:
-            if field.Int() != 0 {
-                args = append(args, "--"+flagName, strconv.FormatInt(field.Int(), 10))
-            }
-        case reflect.Float64:
-            if field.Float() != 0 {
-                args = append(args, "--"+flagName, strconv.FormatFloat(field.Float(), 'f', -1, 64))
-            }
-        case reflect.String:
-            if field.String() != "" {
-                args = append(args, "--"+flagName, field.String())
-            }
-        case reflect.Slice:
-            if field.Type().Elem().Kind() == reflect.String && field.Len() > 0 {
-                if multipleFlags[flagName] {
-                    // Multiple flags: --flag value1 --flag value2
-                    for j := 0; j < field.Len(); j++ {
-                        args = append(args, "--"+flagName, field.Index(j).String())
-                    }
-                } else {
-                    // Comma-separated: --flag value1,value2
-                    var values []string
-                    for j := 0; j < field.Len(); j++ {
-                        values = append(values, field.Index(j).String())
-                    }
-                    args = append(args, "--"+flagName, strings.Join(values, ","))
-                }
-            }
-        }
-    }
-
-    return args
-}
-
 // normalizeCommand handles multiline commands with backslashes
 func normalizeCommand(command string) string {
     re := regexp.MustCompile(`\\\s*\n\s*`)
diff --git a/pkg/backends/vllm/parser.go b/pkg/backends/vllm/parser.go
deleted file mode 100644
index 5eb3fbf..0000000
--- a/pkg/backends/vllm/parser.go
+++ /dev/null
@@ -1,34 +0,0 @@
-package vllm
-
-import (
-    "llamactl/pkg/backends"
-)
-
-// ParseVllmCommand parses a vLLM serve command string into VllmServerOptions
-// Supports multiple formats:
-// 1. Full command: "vllm serve --model MODEL_NAME --other-args"
-// 2. Full path: "/usr/local/bin/vllm serve --model MODEL_NAME"
-// 3. Serve only: "serve --model MODEL_NAME --other-args"
-// 4. Args only: "--model MODEL_NAME --other-args"
-// 5. Multiline commands with backslashes
-func ParseVllmCommand(command string) (*VllmServerOptions, error) {
-    executableNames := []string{"vllm"}
-    subcommandNames := []string{"serve"}
-    multiValuedFlags := map[string]bool{
-        "middleware":      true,
-        "api_key":         true,
-        "allowed_origins": true,
-        "allowed_methods": true,
-        "allowed_headers": true,
-        "lora_modules":    true,
-        "prompt_adapters": true,
-    }
-
-    var vllmOptions VllmServerOptions
-    if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &vllmOptions); err != nil {
-        return nil, err
-    }
-
-    return &vllmOptions, nil
-}
-
diff --git a/pkg/backends/vllm/parser_test.go b/pkg/backends/vllm/parser_test.go
deleted file mode 100644
index 3a12456..0000000
--- a/pkg/backends/vllm/parser_test.go
+++ /dev/null
@@ -1,84 +0,0 @@
-package vllm_test
-
-import (
-    "llamactl/pkg/backends/vllm"
-    "testing"
-)
-
-func TestParseVllmCommand(t *testing.T) {
-    tests := []struct {
-        name      string
-        command   string
-        expectErr bool
-    }{
-        {
-            name:      "basic vllm serve command",
-            command:   "vllm serve --model microsoft/DialoGPT-medium",
-            expectErr: false,
-        },
-        {
-            name:      "serve only command",
-            command:   "serve --model microsoft/DialoGPT-medium",
-            expectErr: false,
-        },
-        {
-            name:      "args only",
-            command:   "--model microsoft/DialoGPT-medium --tensor-parallel-size 2",
-            expectErr: false,
-        },
-        {
-            name:      "empty command",
-            command:   "",
-            expectErr: true,
-        },
-        {
-            name:      "unterminated quote",
-            command:   `vllm serve --model "unterminated`,
-            expectErr: true,
-        },
-    }
-
-    for _, tt := range tests {
-        t.Run(tt.name, func(t *testing.T) {
-            result, err := vllm.ParseVllmCommand(tt.command)
-
-            if tt.expectErr {
-                if err == nil {
-                    t.Errorf("expected error but got none")
-                }
-                return
-            }
-
-            if err != nil {
-                t.Errorf("unexpected error: %v", err)
-                return
-            }
-
-            if result == nil {
-                t.Errorf("expected result but got nil")
-            }
-        })
-    }
-}
-
-func TestParseVllmCommandValues(t *testing.T) {
-    command := "vllm serve --model test-model --tensor-parallel-size 4 --gpu-memory-utilization 0.8 --enable-log-outputs"
-    result, err := vllm.ParseVllmCommand(command)
-
-    if err != nil {
-        t.Fatalf("unexpected error: %v", err)
-    }
-
-    if result.Model != "test-model" {
-        t.Errorf("expected model 'test-model', got '%s'", result.Model)
-    }
-    if result.TensorParallelSize != 4 {
-        t.Errorf("expected tensor_parallel_size 4, got %d", result.TensorParallelSize)
-    }
-    if result.GPUMemoryUtilization != 0.8 {
-        t.Errorf("expected gpu_memory_utilization 0.8, got %f", result.GPUMemoryUtilization)
-    }
-    if !result.EnableLogOutputs {
-        t.Errorf("expected enable_log_outputs true, got %v", result.EnableLogOutputs)
-    }
-}
diff --git a/pkg/backends/vllm/vllm.go b/pkg/backends/vllm/vllm.go
index 2ab6ed8..df080ea 100644
--- a/pkg/backends/vllm/vllm.go
+++ b/pkg/backends/vllm/vllm.go
@@ -142,3 +142,31 @@ func (o *VllmServerOptions) BuildCommandArgs() []string {
     }
     return backends.BuildCommandArgs(o, multipleFlags)
 }
+
+// ParseVllmCommand parses a vLLM serve command string into VllmServerOptions
+// Supports multiple formats:
+// 1. Full command: "vllm serve --model MODEL_NAME --other-args"
+// 2. Full path: "/usr/local/bin/vllm serve --model MODEL_NAME"
+// 3. Serve only: "serve --model MODEL_NAME --other-args"
+// 4. Args only: "--model MODEL_NAME --other-args"
+// 5. Multiline commands with backslashes
+func ParseVllmCommand(command string) (*VllmServerOptions, error) {
+    executableNames := []string{"vllm"}
+    subcommandNames := []string{"serve"}
+    multiValuedFlags := map[string]bool{
+        "middleware":      true,
+        "api_key":         true,
+        "allowed_origins": true,
+        "allowed_methods": true,
+        "allowed_headers": true,
+        "lora_modules":    true,
+        "prompt_adapters": true,
+    }
+
+    var vllmOptions VllmServerOptions
+    if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &vllmOptions); err != nil {
+        return nil, err
+    }
+
+    return &vllmOptions, nil
+}
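
Illustration (a small sketch, not part of the diff): unlike llama-server and mlx_lm.server, vLLM takes a "serve" subcommand, which is why ParseVllmCommand passes subcommandNames. Per the comment above, these input forms are documented as equivalent:

package main

import (
    "fmt"
    "log"

    "llamactl/pkg/backends/vllm"
)

func main() {
    for _, cmd := range []string{
        "vllm serve --model test-model --tensor-parallel-size 4", // full command
        "serve --model test-model --tensor-parallel-size 4",      // serve only
        "--model test-model --tensor-parallel-size 4",            // args only
    } {
        opts, err := vllm.ParseVllmCommand(cmd)
        if err != nil {
            log.Fatal(err)
        }
        fmt.Println(opts.Model, opts.TensorParallelSize) // test-model 4
    }
}
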
diff --git a/pkg/backends/vllm/vllm_test.go b/pkg/backends/vllm/vllm_test.go
index 8a42862..40423b9 100644
--- a/pkg/backends/vllm/vllm_test.go
+++ b/pkg/backends/vllm/vllm_test.go
@@ -7,6 +7,84 @@ import (
     "testing"
 )
 
+func TestParseVllmCommand(t *testing.T) {
+    tests := []struct {
+        name      string
+        command   string
+        expectErr bool
+    }{
+        {
+            name:      "basic vllm serve command",
+            command:   "vllm serve --model microsoft/DialoGPT-medium",
+            expectErr: false,
+        },
+        {
+            name:      "serve only command",
+            command:   "serve --model microsoft/DialoGPT-medium",
+            expectErr: false,
+        },
+        {
+            name:      "args only",
+            command:   "--model microsoft/DialoGPT-medium --tensor-parallel-size 2",
+            expectErr: false,
+        },
+        {
+            name:      "empty command",
+            command:   "",
+            expectErr: true,
+        },
+        {
+            name:      "unterminated quote",
+            command:   `vllm serve --model "unterminated`,
+            expectErr: true,
+        },
+    }
+
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            result, err := vllm.ParseVllmCommand(tt.command)
+
+            if tt.expectErr {
+                if err == nil {
+                    t.Errorf("expected error but got none")
+                }
+                return
+            }
+
+            if err != nil {
+                t.Errorf("unexpected error: %v", err)
+                return
+            }
+
+            if result == nil {
+                t.Errorf("expected result but got nil")
+            }
+        })
+    }
+}
+
+func TestParseVllmCommandValues(t *testing.T) {
+    command := "vllm serve --model test-model --tensor-parallel-size 4 --gpu-memory-utilization 0.8 --enable-log-outputs"
+    result, err := vllm.ParseVllmCommand(command)
+
+    if err != nil {
+        t.Fatalf("unexpected error: %v", err)
+    }
+
+    if result.Model != "test-model" {
+        t.Errorf("expected model 'test-model', got '%s'", result.Model)
+    }
+    if result.TensorParallelSize != 4 {
+        t.Errorf("expected tensor_parallel_size 4, got %d", result.TensorParallelSize)
+    }
+    if result.GPUMemoryUtilization != 0.8 {
+        t.Errorf("expected gpu_memory_utilization 0.8, got %f", result.GPUMemoryUtilization)
+    }
+    if !result.EnableLogOutputs {
+        t.Errorf("expected enable_log_outputs true, got %v", result.EnableLogOutputs)
+    }
+}
+
 func TestBuildCommandArgs(t *testing.T) {
     options := vllm.VllmServerOptions{
         Model: "microsoft/DialoGPT-medium",