diff --git a/pkg/backends/llama_test.go b/pkg/backends/llama_test.go index 63fe533..c05a3a5 100644 --- a/pkg/backends/llama_test.go +++ b/pkg/backends/llama_test.go @@ -4,44 +4,11 @@ import ( "encoding/json" "fmt" "llamactl/pkg/backends" + "llamactl/pkg/testutil" "reflect" - "slices" "testing" ) -func TestLlamaCppBuildCommandArgs_BasicFields(t *testing.T) { - options := backends.LlamaServerOptions{ - Model: "/path/to/model.gguf", - Port: 8080, - Host: "localhost", - Verbose: true, - CtxSize: 4096, - GPULayers: 32, - } - - args := options.BuildCommandArgs() - - // Check individual arguments - expectedPairs := map[string]string{ - "--model": "/path/to/model.gguf", - "--port": "8080", - "--host": "localhost", - "--ctx-size": "4096", - "--gpu-layers": "32", - } - - for flag, expectedValue := range expectedPairs { - if !containsFlagWithValue(args, flag, expectedValue) { - t.Errorf("Expected %s %s, not found in %v", flag, expectedValue, args) - } - } - - // Check standalone boolean flag - if !contains(args, "--verbose") { - t.Errorf("Expected --verbose flag not found in %v", args) - } -} - func TestLlamaCppBuildCommandArgs_BooleanFields(t *testing.T) { tests := []struct { name string @@ -81,13 +48,13 @@ func TestLlamaCppBuildCommandArgs_BooleanFields(t *testing.T) { args := tt.options.BuildCommandArgs() for _, expectedArg := range tt.expected { - if !contains(args, expectedArg) { + if !testutil.Contains(args, expectedArg) { t.Errorf("Expected argument %q not found in %v", expectedArg, args) } } for _, excludedArg := range tt.excluded { - if contains(args, excludedArg) { + if testutil.Contains(args, excludedArg) { t.Errorf("Excluded argument %q found in %v", excludedArg, args) } } @@ -95,36 +62,6 @@ func TestLlamaCppBuildCommandArgs_BooleanFields(t *testing.T) { } } -func TestLlamaCppBuildCommandArgs_NumericFields(t *testing.T) { - options := backends.LlamaServerOptions{ - Port: 8080, - Threads: 4, - CtxSize: 2048, - GPULayers: 16, - Temperature: 0.7, - TopK: 40, - TopP: 0.9, - } - - args := options.BuildCommandArgs() - - expectedPairs := map[string]string{ - "--port": "8080", - "--threads": "4", - "--ctx-size": "2048", - "--gpu-layers": "16", - "--temp": "0.7", - "--top-k": "40", - "--top-p": "0.9", - } - - for flag, expectedValue := range expectedPairs { - if !containsFlagWithValue(args, flag, expectedValue) { - t.Errorf("Expected %s %s, not found in %v", flag, expectedValue, args) - } - } -} - func TestLlamaCppBuildCommandArgs_ZeroValues(t *testing.T) { options := backends.LlamaServerOptions{ Port: 0, // Should be excluded @@ -146,7 +83,7 @@ func TestLlamaCppBuildCommandArgs_ZeroValues(t *testing.T) { } for _, excludedArg := range excludedArgs { - if contains(args, excludedArg) { + if testutil.Contains(args, excludedArg) { t.Errorf("Zero value argument %q should not be present in %v", excludedArg, args) } } @@ -170,7 +107,7 @@ func TestLlamaCppBuildCommandArgs_ArrayFields(t *testing.T) { for flag, values := range expectedOccurrences { for _, value := range values { - if !containsFlagWithValue(args, flag, value) { + if !testutil.ContainsFlagWithValue(args, flag, value) { t.Errorf("Expected %s %s, not found in %v", flag, value, args) } } @@ -187,42 +124,12 @@ func TestLlamaCppBuildCommandArgs_EmptyArrays(t *testing.T) { excludedArgs := []string{"--lora", "--override-tensor"} for _, excludedArg := range excludedArgs { - if contains(args, excludedArg) { + if testutil.Contains(args, excludedArg) { t.Errorf("Empty array should not generate argument %q in %v", excludedArg, args) } } } -func TestLlamaCppBuildCommandArgs_FieldNameConversion(t *testing.T) { - // Test snake_case to kebab-case conversion - options := backends.LlamaServerOptions{ - CtxSize: 4096, - GPULayers: 32, - ThreadsBatch: 2, - FlashAttn: true, - TopK: 40, - TopP: 0.9, - } - - args := options.BuildCommandArgs() - - // Check that field names are properly converted - expectedFlags := []string{ - "--ctx-size", // ctx_size -> ctx-size - "--gpu-layers", // gpu_layers -> gpu-layers - "--threads-batch", // threads_batch -> threads-batch - "--flash-attn", // flash_attn -> flash-attn - "--top-k", // top_k -> top-k - "--top-p", // top_p -> top-p - } - - for _, flag := range expectedFlags { - if !contains(args, flag) { - t.Errorf("Expected flag %q not found in %v", flag, args) - } - } -} - func TestLlamaCppUnmarshalJSON_StandardFields(t *testing.T) { jsonData := `{ "model": "/path/to/model.gguf", @@ -383,26 +290,81 @@ func TestParseLlamaCommand(t *testing.T) { name string command string expectErr bool + validate func(*testing.T, *backends.LlamaServerOptions) }{ { name: "basic command", command: "llama-server --model /path/to/model.gguf --gpu-layers 32", expectErr: false, + validate: func(t *testing.T, opts *backends.LlamaServerOptions) { + if opts.Model != "/path/to/model.gguf" { + t.Errorf("expected model '/path/to/model.gguf', got '%s'", opts.Model) + } + if opts.GPULayers != 32 { + t.Errorf("expected gpu_layers 32, got %d", opts.GPULayers) + } + }, }, { name: "args only", command: "--model /path/to/model.gguf --ctx-size 4096", expectErr: false, + validate: func(t *testing.T, opts *backends.LlamaServerOptions) { + if opts.Model != "/path/to/model.gguf" { + t.Errorf("expected model '/path/to/model.gguf', got '%s'", opts.Model) + } + if opts.CtxSize != 4096 { + t.Errorf("expected ctx_size 4096, got %d", opts.CtxSize) + } + }, }, { name: "mixed flag formats", command: "llama-server --model=/path/model.gguf --gpu-layers 16 --verbose", expectErr: false, + validate: func(t *testing.T, opts *backends.LlamaServerOptions) { + if opts.Model != "/path/model.gguf" { + t.Errorf("expected model '/path/model.gguf', got '%s'", opts.Model) + } + if opts.GPULayers != 16 { + t.Errorf("expected gpu_layers 16, got %d", opts.GPULayers) + } + if !opts.Verbose { + t.Errorf("expected verbose to be true") + } + }, }, { name: "quoted strings", command: `llama-server --model test.gguf --api-key "sk-1234567890abcdef"`, expectErr: false, + validate: func(t *testing.T, opts *backends.LlamaServerOptions) { + if opts.APIKey != "sk-1234567890abcdef" { + t.Errorf("expected api_key 'sk-1234567890abcdef', got '%s'", opts.APIKey) + } + }, + }, + { + name: "multiple value types", + command: "llama-server --model /test/model.gguf --gpu-layers 32 --temp 0.7 --verbose --no-mmap", + expectErr: false, + validate: func(t *testing.T, opts *backends.LlamaServerOptions) { + if opts.Model != "/test/model.gguf" { + t.Errorf("expected model '/test/model.gguf', got '%s'", opts.Model) + } + if opts.GPULayers != 32 { + t.Errorf("expected gpu_layers 32, got %d", opts.GPULayers) + } + if opts.Temperature != 0.7 { + t.Errorf("expected temperature 0.7, got %f", opts.Temperature) + } + if !opts.Verbose { + t.Errorf("expected verbose to be true") + } + if !opts.NoMmap { + t.Errorf("expected no_mmap to be true") + } + }, }, { name: "empty command", @@ -439,40 +401,16 @@ func TestParseLlamaCommand(t *testing.T) { if result == nil { t.Errorf("expected result but got nil") + return + } + + if tt.validate != nil { + tt.validate(t, result) } }) } } -func TestParseLlamaCommandValues(t *testing.T) { - command := "llama-server --model /test/model.gguf --gpu-layers 32 --temp 0.7 --verbose --no-mmap" - result, err := backends.ParseLlamaCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Model != "/test/model.gguf" { - t.Errorf("expected model '/test/model.gguf', got '%s'", result.Model) - } - - if result.GPULayers != 32 { - t.Errorf("expected gpu_layers 32, got %d", result.GPULayers) - } - - if result.Temperature != 0.7 { - t.Errorf("expected temperature 0.7, got %f", result.Temperature) - } - - if !result.Verbose { - t.Errorf("expected verbose to be true") - } - - if !result.NoMmap { - t.Errorf("expected no_mmap to be true") - } -} - func TestParseLlamaCommandArrays(t *testing.T) { command := "llama-server --model test.gguf --lora adapter1.bin --lora=adapter2.bin" result, err := backends.ParseLlamaCommand(command) @@ -491,21 +429,4 @@ func TestParseLlamaCommandArrays(t *testing.T) { t.Errorf("expected lora[%d]=%s got %s", i, v, result.Lora[i]) } } -} - -// Helper functions -func contains(slice []string, item string) bool { - return slices.Contains(slice, item) -} - -func containsFlagWithValue(args []string, flag, value string) bool { - for i, arg := range args { - if arg == flag { - // Check if there's a next argument and it matches the expected value - if i+1 < len(args) && args[i+1] == value { - return true - } - } - } - return false -} +} \ No newline at end of file diff --git a/pkg/backends/mlx_test.go b/pkg/backends/mlx_test.go index 4a4c43a..0194551 100644 --- a/pkg/backends/mlx_test.go +++ b/pkg/backends/mlx_test.go @@ -2,6 +2,7 @@ package backends_test import ( "llamactl/pkg/backends" + "llamactl/pkg/testutil" "testing" ) @@ -10,26 +11,71 @@ func TestParseMlxCommand(t *testing.T) { name string command string expectErr bool + validate func(*testing.T, *backends.MlxServerOptions) }{ { name: "basic command", command: "mlx_lm.server --model /path/to/model --host 0.0.0.0", expectErr: false, + validate: func(t *testing.T, opts *backends.MlxServerOptions) { + if opts.Model != "/path/to/model" { + t.Errorf("expected model '/path/to/model', got '%s'", opts.Model) + } + if opts.Host != "0.0.0.0" { + t.Errorf("expected host '0.0.0.0', got '%s'", opts.Host) + } + }, }, { name: "args only", command: "--model /path/to/model --port 8080", expectErr: false, + validate: func(t *testing.T, opts *backends.MlxServerOptions) { + if opts.Model != "/path/to/model" { + t.Errorf("expected model '/path/to/model', got '%s'", opts.Model) + } + if opts.Port != 8080 { + t.Errorf("expected port 8080, got %d", opts.Port) + } + }, }, { name: "mixed flag formats", command: "mlx_lm.server --model=/path/model --temp=0.7 --trust-remote-code", expectErr: false, + validate: func(t *testing.T, opts *backends.MlxServerOptions) { + if opts.Model != "/path/model" { + t.Errorf("expected model '/path/model', got '%s'", opts.Model) + } + if opts.Temp != 0.7 { + t.Errorf("expected temp 0.7, got %f", opts.Temp) + } + if !opts.TrustRemoteCode { + t.Errorf("expected trust_remote_code to be true") + } + }, }, { - name: "quoted strings", - command: `mlx_lm.server --model test.mlx --chat-template "User: {user}\nAssistant: "`, + name: "multiple value types", + command: "mlx_lm.server --model /test/model.mlx --port 8080 --temp 0.7 --trust-remote-code --log-level DEBUG", expectErr: false, + validate: func(t *testing.T, opts *backends.MlxServerOptions) { + if opts.Model != "/test/model.mlx" { + t.Errorf("expected model '/test/model.mlx', got '%s'", opts.Model) + } + if opts.Port != 8080 { + t.Errorf("expected port 8080, got %d", opts.Port) + } + if opts.Temp != 0.7 { + t.Errorf("expected temp 0.7, got %f", opts.Temp) + } + if !opts.TrustRemoteCode { + t.Errorf("expected trust_remote_code to be true") + } + if opts.LogLevel != "DEBUG" { + t.Errorf("expected log_level 'DEBUG', got '%s'", opts.LogLevel) + } + }, }, { name: "empty command", @@ -66,92 +112,91 @@ func TestParseMlxCommand(t *testing.T) { if result == nil { t.Errorf("expected result but got nil") + return + } + + if tt.validate != nil { + tt.validate(t, result) } }) } } -func TestParseMlxCommandValues(t *testing.T) { - command := "mlx_lm.server --model /test/model.mlx --port 8080 --temp 0.7 --trust-remote-code --log-level DEBUG" - result, err := backends.ParseMlxCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) +func TestMlxBuildCommandArgs_BooleanFields(t *testing.T) { + tests := []struct { + name string + options backends.MlxServerOptions + expected []string + excluded []string + }{ + { + name: "trust_remote_code true", + options: backends.MlxServerOptions{ + TrustRemoteCode: true, + }, + expected: []string{"--trust-remote-code"}, + }, + { + name: "trust_remote_code false", + options: backends.MlxServerOptions{ + TrustRemoteCode: false, + }, + excluded: []string{"--trust-remote-code"}, + }, + { + name: "multiple booleans", + options: backends.MlxServerOptions{ + TrustRemoteCode: true, + UseDefaultChatTemplate: true, + }, + expected: []string{"--trust-remote-code", "--use-default-chat-template"}, + }, } - if result.Model != "/test/model.mlx" { - t.Errorf("expected model '/test/model.mlx', got '%s'", result.Model) - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + args := tt.options.BuildCommandArgs() - if result.Port != 8080 { - t.Errorf("expected port 8080, got %d", result.Port) - } + for _, expectedArg := range tt.expected { + if !testutil.Contains(args, expectedArg) { + t.Errorf("Expected argument %q not found in %v", expectedArg, args) + } + } - if result.Temp != 0.7 { - t.Errorf("expected temp 0.7, got %f", result.Temp) - } - - if !result.TrustRemoteCode { - t.Errorf("expected trust_remote_code to be true") - } - - if result.LogLevel != "DEBUG" { - t.Errorf("expected log_level 'DEBUG', got '%s'", result.LogLevel) + for _, excludedArg := range tt.excluded { + if testutil.Contains(args, excludedArg) { + t.Errorf("Excluded argument %q found in %v", excludedArg, args) + } + } + }) } } -func TestMlxBuildCommandArgs(t *testing.T) { - options := &backends.MlxServerOptions{ - Model: "/test/model.mlx", - Host: "127.0.0.1", - Port: 8080, - Temp: 0.7, - TopP: 0.9, - TopK: 40, - MaxTokens: 2048, - TrustRemoteCode: true, - LogLevel: "DEBUG", - ChatTemplate: "custom template", +func TestMlxBuildCommandArgs_ZeroValues(t *testing.T) { + options := backends.MlxServerOptions{ + Port: 0, // Should be excluded + TopK: 0, // Should be excluded + Temp: 0, // Should be excluded + Model: "", // Should be excluded + LogLevel: "", // Should be excluded + TrustRemoteCode: false, // Should be excluded } args := options.BuildCommandArgs() - // Check that all expected flags are present - expectedFlags := map[string]string{ - "--model": "/test/model.mlx", - "--host": "127.0.0.1", - "--port": "8080", - "--log-level": "DEBUG", - "--chat-template": "custom template", - "--temp": "0.7", - "--top-p": "0.9", - "--top-k": "40", - "--max-tokens": "2048", + // Zero values should not appear in arguments + excludedArgs := []string{ + "--port", "0", + "--top-k", "0", + "--temp", "0", + "--model", "", + "--log-level", "", + "--trust-remote-code", } - for i := 0; i < len(args); i++ { - if args[i] == "--trust-remote-code" { - continue // Boolean flag with no value - } - if args[i] == "--use-default-chat-template" { - continue // Boolean flag with no value - } - - if expectedValue, exists := expectedFlags[args[i]]; exists && i+1 < len(args) { - if args[i+1] != expectedValue { - t.Errorf("expected %s to have value %s, got %s", args[i], expectedValue, args[i+1]) - } + for _, excludedArg := range excludedArgs { + if testutil.Contains(args, excludedArg) { + t.Errorf("Zero value argument %q should not be present in %v", excludedArg, args) } } - - // Check boolean flags - foundTrustRemoteCode := false - for _, arg := range args { - if arg == "--trust-remote-code" { - foundTrustRemoteCode = true - } - } - if !foundTrustRemoteCode { - t.Errorf("expected --trust-remote-code flag to be present") - } -} +} \ No newline at end of file diff --git a/pkg/backends/vllm_test.go b/pkg/backends/vllm_test.go index 0133e37..b9e6a13 100644 --- a/pkg/backends/vllm_test.go +++ b/pkg/backends/vllm_test.go @@ -2,6 +2,7 @@ package backends_test import ( "llamactl/pkg/backends" + "llamactl/pkg/testutil" "testing" ) @@ -10,26 +11,72 @@ func TestParseVllmCommand(t *testing.T) { name string command string expectErr bool + validate func(*testing.T, *backends.VllmServerOptions) }{ { name: "basic vllm serve command", command: "vllm serve microsoft/DialoGPT-medium", expectErr: false, + validate: func(t *testing.T, opts *backends.VllmServerOptions) { + if opts.Model != "microsoft/DialoGPT-medium" { + t.Errorf("expected model 'microsoft/DialoGPT-medium', got '%s'", opts.Model) + } + }, }, { name: "serve only command", command: "serve microsoft/DialoGPT-medium", expectErr: false, + validate: func(t *testing.T, opts *backends.VllmServerOptions) { + if opts.Model != "microsoft/DialoGPT-medium" { + t.Errorf("expected model 'microsoft/DialoGPT-medium', got '%s'", opts.Model) + } + }, }, { name: "positional model with flags", command: "vllm serve microsoft/DialoGPT-medium --tensor-parallel-size 2", expectErr: false, + validate: func(t *testing.T, opts *backends.VllmServerOptions) { + if opts.Model != "microsoft/DialoGPT-medium" { + t.Errorf("expected model 'microsoft/DialoGPT-medium', got '%s'", opts.Model) + } + if opts.TensorParallelSize != 2 { + t.Errorf("expected tensor_parallel_size 2, got %d", opts.TensorParallelSize) + } + }, }, { name: "model with path", command: "vllm serve /path/to/model --gpu-memory-utilization 0.8", expectErr: false, + validate: func(t *testing.T, opts *backends.VllmServerOptions) { + if opts.Model != "/path/to/model" { + t.Errorf("expected model '/path/to/model', got '%s'", opts.Model) + } + if opts.GPUMemoryUtilization != 0.8 { + t.Errorf("expected gpu_memory_utilization 0.8, got %f", opts.GPUMemoryUtilization) + } + }, + }, + { + name: "multiple value types", + command: "vllm serve test-model --tensor-parallel-size 4 --gpu-memory-utilization 0.8 --enable-log-outputs", + expectErr: false, + validate: func(t *testing.T, opts *backends.VllmServerOptions) { + if opts.Model != "test-model" { + t.Errorf("expected model 'test-model', got '%s'", opts.Model) + } + if opts.TensorParallelSize != 4 { + t.Errorf("expected tensor_parallel_size 4, got %d", opts.TensorParallelSize) + } + if opts.GPUMemoryUtilization != 0.8 { + t.Errorf("expected gpu_memory_utilization 0.8, got %f", opts.GPUMemoryUtilization) + } + if !opts.EnableLogOutputs { + t.Errorf("expected enable_log_outputs true, got %v", opts.EnableLogOutputs) + } + }, }, { name: "empty command", @@ -61,34 +108,144 @@ func TestParseVllmCommand(t *testing.T) { if result == nil { t.Errorf("expected result but got nil") + return + } + + if tt.validate != nil { + tt.validate(t, result) } }) } } -func TestParseVllmCommandValues(t *testing.T) { - command := "vllm serve test-model --tensor-parallel-size 4 --gpu-memory-utilization 0.8 --enable-log-outputs" - result, err := backends.ParseVllmCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) +func TestVllmBuildCommandArgs_BooleanFields(t *testing.T) { + tests := []struct { + name string + options backends.VllmServerOptions + expected []string + excluded []string + }{ + { + name: "enable_log_outputs true", + options: backends.VllmServerOptions{ + EnableLogOutputs: true, + }, + expected: []string{"--enable-log-outputs"}, + }, + { + name: "enable_log_outputs false", + options: backends.VllmServerOptions{ + EnableLogOutputs: false, + }, + excluded: []string{"--enable-log-outputs"}, + }, + { + name: "multiple booleans", + options: backends.VllmServerOptions{ + EnableLogOutputs: true, + TrustRemoteCode: true, + EnablePrefixCaching: true, + DisableLogStats: false, + }, + expected: []string{"--enable-log-outputs", "--trust-remote-code", "--enable-prefix-caching"}, + excluded: []string{"--disable-log-stats"}, + }, } - if result.Model != "test-model" { - t.Errorf("expected model 'test-model', got '%s'", result.Model) - } - if result.TensorParallelSize != 4 { - t.Errorf("expected tensor_parallel_size 4, got %d", result.TensorParallelSize) - } - if result.GPUMemoryUtilization != 0.8 { - t.Errorf("expected gpu_memory_utilization 0.8, got %f", result.GPUMemoryUtilization) - } - if !result.EnableLogOutputs { - t.Errorf("expected enable_log_outputs true, got %v", result.EnableLogOutputs) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + args := tt.options.BuildCommandArgs() + + for _, expectedArg := range tt.expected { + if !testutil.Contains(args, expectedArg) { + t.Errorf("Expected argument %q not found in %v", expectedArg, args) + } + } + + for _, excludedArg := range tt.excluded { + if testutil.Contains(args, excludedArg) { + t.Errorf("Excluded argument %q found in %v", excludedArg, args) + } + } + }) } } -func TestVllmBuildCommandArgs(t *testing.T) { +func TestVllmBuildCommandArgs_ZeroValues(t *testing.T) { + options := backends.VllmServerOptions{ + Port: 0, // Should be excluded + TensorParallelSize: 0, // Should be excluded + GPUMemoryUtilization: 0, // Should be excluded + Model: "", // Should be excluded (positional arg) + Host: "", // Should be excluded + EnableLogOutputs: false, // Should be excluded + } + + args := options.BuildCommandArgs() + + // Zero values should not appear in arguments + excludedArgs := []string{ + "--port", "0", + "--tensor-parallel-size", "0", + "--gpu-memory-utilization", "0", + "--host", "", + "--enable-log-outputs", + } + + for _, excludedArg := range excludedArgs { + if testutil.Contains(args, excludedArg) { + t.Errorf("Zero value argument %q should not be present in %v", excludedArg, args) + } + } + + // Model should not be present as positional arg when empty + if len(args) > 0 && args[0] == "" { + t.Errorf("Empty model should not be present as positional argument") + } +} + +func TestVllmBuildCommandArgs_ArrayFields(t *testing.T) { + options := backends.VllmServerOptions{ + AllowedOrigins: []string{"http://localhost:3000", "https://example.com"}, + AllowedMethods: []string{"GET", "POST"}, + Middleware: []string{"middleware1", "middleware2", "middleware3"}, + } + + args := options.BuildCommandArgs() + + // Check that each array value appears with its flag + expectedOccurrences := map[string][]string{ + "--allowed-origins": {"http://localhost:3000", "https://example.com"}, + "--allowed-methods": {"GET", "POST"}, + "--middleware": {"middleware1", "middleware2", "middleware3"}, + } + + for flag, values := range expectedOccurrences { + for _, value := range values { + if !testutil.ContainsFlagWithValue(args, flag, value) { + t.Errorf("Expected %s %s, not found in %v", flag, value, args) + } + } + } +} + +func TestVllmBuildCommandArgs_EmptyArrays(t *testing.T) { + options := backends.VllmServerOptions{ + AllowedOrigins: []string{}, // Empty array should not generate args + Middleware: []string{}, // Empty array should not generate args + } + + args := options.BuildCommandArgs() + + excludedArgs := []string{"--allowed-origins", "--middleware"} + for _, excludedArg := range excludedArgs { + if testutil.Contains(args, excludedArg) { + t.Errorf("Empty array should not generate argument %q in %v", excludedArg, args) + } + } +} + +func TestVllmBuildCommandArgs_PositionalModel(t *testing.T) { options := backends.VllmServerOptions{ Model: "microsoft/DialoGPT-medium", Port: 8080, @@ -96,7 +253,6 @@ func TestVllmBuildCommandArgs(t *testing.T) { TensorParallelSize: 2, GPUMemoryUtilization: 0.8, EnableLogOutputs: true, - AllowedOrigins: []string{"http://localhost:3000", "https://example.com"}, } args := options.BuildCommandArgs() @@ -107,32 +263,24 @@ func TestVllmBuildCommandArgs(t *testing.T) { } // Check that --model flag is NOT present (since model should be positional) - if contains(args, "--model") { + if testutil.Contains(args, "--model") { t.Errorf("Found --model flag, but model should be positional argument in args: %v", args) } // Check other flags - if !containsFlagWithValue(args, "--tensor-parallel-size", "2") { + if !testutil.ContainsFlagWithValue(args, "--tensor-parallel-size", "2") { t.Errorf("Expected --tensor-parallel-size 2 not found in %v", args) } - if !contains(args, "--enable-log-outputs") { + if !testutil.ContainsFlagWithValue(args, "--gpu-memory-utilization", "0.8") { + t.Errorf("Expected --gpu-memory-utilization 0.8 not found in %v", args) + } + if !testutil.Contains(args, "--enable-log-outputs") { t.Errorf("Expected --enable-log-outputs not found in %v", args) } - if !contains(args, "--host") { - t.Errorf("Expected --host not found in %v", args) + if !testutil.ContainsFlagWithValue(args, "--host", "localhost") { + t.Errorf("Expected --host localhost not found in %v", args) } - if !contains(args, "--port") { - t.Errorf("Expected --port not found in %v", args) - } - - // Check array handling (multiple flags) - allowedOriginsCount := 0 - for i := range args { - if args[i] == "--allowed-origins" { - allowedOriginsCount++ - } - } - if allowedOriginsCount != 2 { - t.Errorf("Expected 2 --allowed-origins flags, got %d", allowedOriginsCount) + if !testutil.ContainsFlagWithValue(args, "--port", "8080") { + t.Errorf("Expected --port 8080 not found in %v", args) } } diff --git a/pkg/testutil/helpers.go b/pkg/testutil/helpers.go index 7b7fe0c..73c83c5 100644 --- a/pkg/testutil/helpers.go +++ b/pkg/testutil/helpers.go @@ -1,5 +1,7 @@ package testutil +import "slices" + // Helper functions for pointer fields func BoolPtr(b bool) *bool { return &b @@ -8,3 +10,23 @@ func BoolPtr(b bool) *bool { func IntPtr(i int) *int { return &i } + +// Helper functions for testing command arguments + +// Contains checks if a slice contains a specific item +func Contains(slice []string, item string) bool { + return slices.Contains(slice, item) +} + +// ContainsFlagWithValue checks if args contains a flag followed by a specific value +func ContainsFlagWithValue(args []string, flag, value string) bool { + for i, arg := range args { + if arg == flag { + // Check if there's a next argument and it matches the expected value + if i+1 < len(args) && args[i+1] == value { + return true + } + } + } + return false +}