Implement mlx and cllm tests and remove redundant code

This commit is contained in:
2025-10-19 19:45:31 +02:00
parent 72fe780e31
commit f42f000539
4 changed files with 390 additions and 254 deletions

View File

@@ -4,44 +4,11 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"llamactl/pkg/backends" "llamactl/pkg/backends"
"llamactl/pkg/testutil"
"reflect" "reflect"
"slices"
"testing" "testing"
) )
func TestLlamaCppBuildCommandArgs_BasicFields(t *testing.T) {
options := backends.LlamaServerOptions{
Model: "/path/to/model.gguf",
Port: 8080,
Host: "localhost",
Verbose: true,
CtxSize: 4096,
GPULayers: 32,
}
args := options.BuildCommandArgs()
// Check individual arguments
expectedPairs := map[string]string{
"--model": "/path/to/model.gguf",
"--port": "8080",
"--host": "localhost",
"--ctx-size": "4096",
"--gpu-layers": "32",
}
for flag, expectedValue := range expectedPairs {
if !containsFlagWithValue(args, flag, expectedValue) {
t.Errorf("Expected %s %s, not found in %v", flag, expectedValue, args)
}
}
// Check standalone boolean flag
if !contains(args, "--verbose") {
t.Errorf("Expected --verbose flag not found in %v", args)
}
}
func TestLlamaCppBuildCommandArgs_BooleanFields(t *testing.T) { func TestLlamaCppBuildCommandArgs_BooleanFields(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -81,13 +48,13 @@ func TestLlamaCppBuildCommandArgs_BooleanFields(t *testing.T) {
args := tt.options.BuildCommandArgs() args := tt.options.BuildCommandArgs()
for _, expectedArg := range tt.expected { for _, expectedArg := range tt.expected {
if !contains(args, expectedArg) { if !testutil.Contains(args, expectedArg) {
t.Errorf("Expected argument %q not found in %v", expectedArg, args) t.Errorf("Expected argument %q not found in %v", expectedArg, args)
} }
} }
for _, excludedArg := range tt.excluded { for _, excludedArg := range tt.excluded {
if contains(args, excludedArg) { if testutil.Contains(args, excludedArg) {
t.Errorf("Excluded argument %q found in %v", excludedArg, args) t.Errorf("Excluded argument %q found in %v", excludedArg, args)
} }
} }
@@ -95,36 +62,6 @@ func TestLlamaCppBuildCommandArgs_BooleanFields(t *testing.T) {
} }
} }
func TestLlamaCppBuildCommandArgs_NumericFields(t *testing.T) {
options := backends.LlamaServerOptions{
Port: 8080,
Threads: 4,
CtxSize: 2048,
GPULayers: 16,
Temperature: 0.7,
TopK: 40,
TopP: 0.9,
}
args := options.BuildCommandArgs()
expectedPairs := map[string]string{
"--port": "8080",
"--threads": "4",
"--ctx-size": "2048",
"--gpu-layers": "16",
"--temp": "0.7",
"--top-k": "40",
"--top-p": "0.9",
}
for flag, expectedValue := range expectedPairs {
if !containsFlagWithValue(args, flag, expectedValue) {
t.Errorf("Expected %s %s, not found in %v", flag, expectedValue, args)
}
}
}
func TestLlamaCppBuildCommandArgs_ZeroValues(t *testing.T) { func TestLlamaCppBuildCommandArgs_ZeroValues(t *testing.T) {
options := backends.LlamaServerOptions{ options := backends.LlamaServerOptions{
Port: 0, // Should be excluded Port: 0, // Should be excluded
@@ -146,7 +83,7 @@ func TestLlamaCppBuildCommandArgs_ZeroValues(t *testing.T) {
} }
for _, excludedArg := range excludedArgs { for _, excludedArg := range excludedArgs {
if contains(args, excludedArg) { if testutil.Contains(args, excludedArg) {
t.Errorf("Zero value argument %q should not be present in %v", excludedArg, args) t.Errorf("Zero value argument %q should not be present in %v", excludedArg, args)
} }
} }
@@ -170,7 +107,7 @@ func TestLlamaCppBuildCommandArgs_ArrayFields(t *testing.T) {
for flag, values := range expectedOccurrences { for flag, values := range expectedOccurrences {
for _, value := range values { for _, value := range values {
if !containsFlagWithValue(args, flag, value) { if !testutil.ContainsFlagWithValue(args, flag, value) {
t.Errorf("Expected %s %s, not found in %v", flag, value, args) t.Errorf("Expected %s %s, not found in %v", flag, value, args)
} }
} }
@@ -187,42 +124,12 @@ func TestLlamaCppBuildCommandArgs_EmptyArrays(t *testing.T) {
excludedArgs := []string{"--lora", "--override-tensor"} excludedArgs := []string{"--lora", "--override-tensor"}
for _, excludedArg := range excludedArgs { for _, excludedArg := range excludedArgs {
if contains(args, excludedArg) { if testutil.Contains(args, excludedArg) {
t.Errorf("Empty array should not generate argument %q in %v", excludedArg, args) t.Errorf("Empty array should not generate argument %q in %v", excludedArg, args)
} }
} }
} }
func TestLlamaCppBuildCommandArgs_FieldNameConversion(t *testing.T) {
// Test snake_case to kebab-case conversion
options := backends.LlamaServerOptions{
CtxSize: 4096,
GPULayers: 32,
ThreadsBatch: 2,
FlashAttn: true,
TopK: 40,
TopP: 0.9,
}
args := options.BuildCommandArgs()
// Check that field names are properly converted
expectedFlags := []string{
"--ctx-size", // ctx_size -> ctx-size
"--gpu-layers", // gpu_layers -> gpu-layers
"--threads-batch", // threads_batch -> threads-batch
"--flash-attn", // flash_attn -> flash-attn
"--top-k", // top_k -> top-k
"--top-p", // top_p -> top-p
}
for _, flag := range expectedFlags {
if !contains(args, flag) {
t.Errorf("Expected flag %q not found in %v", flag, args)
}
}
}
func TestLlamaCppUnmarshalJSON_StandardFields(t *testing.T) { func TestLlamaCppUnmarshalJSON_StandardFields(t *testing.T) {
jsonData := `{ jsonData := `{
"model": "/path/to/model.gguf", "model": "/path/to/model.gguf",
@@ -383,26 +290,81 @@ func TestParseLlamaCommand(t *testing.T) {
name string name string
command string command string
expectErr bool expectErr bool
validate func(*testing.T, *backends.LlamaServerOptions)
}{ }{
{ {
name: "basic command", name: "basic command",
command: "llama-server --model /path/to/model.gguf --gpu-layers 32", command: "llama-server --model /path/to/model.gguf --gpu-layers 32",
expectErr: false, expectErr: false,
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
if opts.Model != "/path/to/model.gguf" {
t.Errorf("expected model '/path/to/model.gguf', got '%s'", opts.Model)
}
if opts.GPULayers != 32 {
t.Errorf("expected gpu_layers 32, got %d", opts.GPULayers)
}
},
}, },
{ {
name: "args only", name: "args only",
command: "--model /path/to/model.gguf --ctx-size 4096", command: "--model /path/to/model.gguf --ctx-size 4096",
expectErr: false, expectErr: false,
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
if opts.Model != "/path/to/model.gguf" {
t.Errorf("expected model '/path/to/model.gguf', got '%s'", opts.Model)
}
if opts.CtxSize != 4096 {
t.Errorf("expected ctx_size 4096, got %d", opts.CtxSize)
}
},
}, },
{ {
name: "mixed flag formats", name: "mixed flag formats",
command: "llama-server --model=/path/model.gguf --gpu-layers 16 --verbose", command: "llama-server --model=/path/model.gguf --gpu-layers 16 --verbose",
expectErr: false, expectErr: false,
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
if opts.Model != "/path/model.gguf" {
t.Errorf("expected model '/path/model.gguf', got '%s'", opts.Model)
}
if opts.GPULayers != 16 {
t.Errorf("expected gpu_layers 16, got %d", opts.GPULayers)
}
if !opts.Verbose {
t.Errorf("expected verbose to be true")
}
},
}, },
{ {
name: "quoted strings", name: "quoted strings",
command: `llama-server --model test.gguf --api-key "sk-1234567890abcdef"`, command: `llama-server --model test.gguf --api-key "sk-1234567890abcdef"`,
expectErr: false, expectErr: false,
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
if opts.APIKey != "sk-1234567890abcdef" {
t.Errorf("expected api_key 'sk-1234567890abcdef', got '%s'", opts.APIKey)
}
},
},
{
name: "multiple value types",
command: "llama-server --model /test/model.gguf --gpu-layers 32 --temp 0.7 --verbose --no-mmap",
expectErr: false,
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
if opts.Model != "/test/model.gguf" {
t.Errorf("expected model '/test/model.gguf', got '%s'", opts.Model)
}
if opts.GPULayers != 32 {
t.Errorf("expected gpu_layers 32, got %d", opts.GPULayers)
}
if opts.Temperature != 0.7 {
t.Errorf("expected temperature 0.7, got %f", opts.Temperature)
}
if !opts.Verbose {
t.Errorf("expected verbose to be true")
}
if !opts.NoMmap {
t.Errorf("expected no_mmap to be true")
}
},
}, },
{ {
name: "empty command", name: "empty command",
@@ -439,40 +401,16 @@ func TestParseLlamaCommand(t *testing.T) {
if result == nil { if result == nil {
t.Errorf("expected result but got nil") t.Errorf("expected result but got nil")
return
}
if tt.validate != nil {
tt.validate(t, result)
} }
}) })
} }
} }
func TestParseLlamaCommandValues(t *testing.T) {
command := "llama-server --model /test/model.gguf --gpu-layers 32 --temp 0.7 --verbose --no-mmap"
result, err := backends.ParseLlamaCommand(command)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Model != "/test/model.gguf" {
t.Errorf("expected model '/test/model.gguf', got '%s'", result.Model)
}
if result.GPULayers != 32 {
t.Errorf("expected gpu_layers 32, got %d", result.GPULayers)
}
if result.Temperature != 0.7 {
t.Errorf("expected temperature 0.7, got %f", result.Temperature)
}
if !result.Verbose {
t.Errorf("expected verbose to be true")
}
if !result.NoMmap {
t.Errorf("expected no_mmap to be true")
}
}
func TestParseLlamaCommandArrays(t *testing.T) { func TestParseLlamaCommandArrays(t *testing.T) {
command := "llama-server --model test.gguf --lora adapter1.bin --lora=adapter2.bin" command := "llama-server --model test.gguf --lora adapter1.bin --lora=adapter2.bin"
result, err := backends.ParseLlamaCommand(command) result, err := backends.ParseLlamaCommand(command)
@@ -491,21 +429,4 @@ func TestParseLlamaCommandArrays(t *testing.T) {
t.Errorf("expected lora[%d]=%s got %s", i, v, result.Lora[i]) t.Errorf("expected lora[%d]=%s got %s", i, v, result.Lora[i])
} }
} }
} }
// Helper functions
func contains(slice []string, item string) bool {
return slices.Contains(slice, item)
}
func containsFlagWithValue(args []string, flag, value string) bool {
for i, arg := range args {
if arg == flag {
// Check if there's a next argument and it matches the expected value
if i+1 < len(args) && args[i+1] == value {
return true
}
}
}
return false
}

View File

@@ -2,6 +2,7 @@ package backends_test
import ( import (
"llamactl/pkg/backends" "llamactl/pkg/backends"
"llamactl/pkg/testutil"
"testing" "testing"
) )
@@ -10,26 +11,71 @@ func TestParseMlxCommand(t *testing.T) {
name string name string
command string command string
expectErr bool expectErr bool
validate func(*testing.T, *backends.MlxServerOptions)
}{ }{
{ {
name: "basic command", name: "basic command",
command: "mlx_lm.server --model /path/to/model --host 0.0.0.0", command: "mlx_lm.server --model /path/to/model --host 0.0.0.0",
expectErr: false, expectErr: false,
validate: func(t *testing.T, opts *backends.MlxServerOptions) {
if opts.Model != "/path/to/model" {
t.Errorf("expected model '/path/to/model', got '%s'", opts.Model)
}
if opts.Host != "0.0.0.0" {
t.Errorf("expected host '0.0.0.0', got '%s'", opts.Host)
}
},
}, },
{ {
name: "args only", name: "args only",
command: "--model /path/to/model --port 8080", command: "--model /path/to/model --port 8080",
expectErr: false, expectErr: false,
validate: func(t *testing.T, opts *backends.MlxServerOptions) {
if opts.Model != "/path/to/model" {
t.Errorf("expected model '/path/to/model', got '%s'", opts.Model)
}
if opts.Port != 8080 {
t.Errorf("expected port 8080, got %d", opts.Port)
}
},
}, },
{ {
name: "mixed flag formats", name: "mixed flag formats",
command: "mlx_lm.server --model=/path/model --temp=0.7 --trust-remote-code", command: "mlx_lm.server --model=/path/model --temp=0.7 --trust-remote-code",
expectErr: false, expectErr: false,
validate: func(t *testing.T, opts *backends.MlxServerOptions) {
if opts.Model != "/path/model" {
t.Errorf("expected model '/path/model', got '%s'", opts.Model)
}
if opts.Temp != 0.7 {
t.Errorf("expected temp 0.7, got %f", opts.Temp)
}
if !opts.TrustRemoteCode {
t.Errorf("expected trust_remote_code to be true")
}
},
}, },
{ {
name: "quoted strings", name: "multiple value types",
command: `mlx_lm.server --model test.mlx --chat-template "User: {user}\nAssistant: "`, command: "mlx_lm.server --model /test/model.mlx --port 8080 --temp 0.7 --trust-remote-code --log-level DEBUG",
expectErr: false, expectErr: false,
validate: func(t *testing.T, opts *backends.MlxServerOptions) {
if opts.Model != "/test/model.mlx" {
t.Errorf("expected model '/test/model.mlx', got '%s'", opts.Model)
}
if opts.Port != 8080 {
t.Errorf("expected port 8080, got %d", opts.Port)
}
if opts.Temp != 0.7 {
t.Errorf("expected temp 0.7, got %f", opts.Temp)
}
if !opts.TrustRemoteCode {
t.Errorf("expected trust_remote_code to be true")
}
if opts.LogLevel != "DEBUG" {
t.Errorf("expected log_level 'DEBUG', got '%s'", opts.LogLevel)
}
},
}, },
{ {
name: "empty command", name: "empty command",
@@ -66,92 +112,91 @@ func TestParseMlxCommand(t *testing.T) {
if result == nil { if result == nil {
t.Errorf("expected result but got nil") t.Errorf("expected result but got nil")
return
}
if tt.validate != nil {
tt.validate(t, result)
} }
}) })
} }
} }
func TestParseMlxCommandValues(t *testing.T) { func TestMlxBuildCommandArgs_BooleanFields(t *testing.T) {
command := "mlx_lm.server --model /test/model.mlx --port 8080 --temp 0.7 --trust-remote-code --log-level DEBUG" tests := []struct {
result, err := backends.ParseMlxCommand(command) name string
options backends.MlxServerOptions
if err != nil { expected []string
t.Fatalf("unexpected error: %v", err) excluded []string
}{
{
name: "trust_remote_code true",
options: backends.MlxServerOptions{
TrustRemoteCode: true,
},
expected: []string{"--trust-remote-code"},
},
{
name: "trust_remote_code false",
options: backends.MlxServerOptions{
TrustRemoteCode: false,
},
excluded: []string{"--trust-remote-code"},
},
{
name: "multiple booleans",
options: backends.MlxServerOptions{
TrustRemoteCode: true,
UseDefaultChatTemplate: true,
},
expected: []string{"--trust-remote-code", "--use-default-chat-template"},
},
} }
if result.Model != "/test/model.mlx" { for _, tt := range tests {
t.Errorf("expected model '/test/model.mlx', got '%s'", result.Model) t.Run(tt.name, func(t *testing.T) {
} args := tt.options.BuildCommandArgs()
if result.Port != 8080 { for _, expectedArg := range tt.expected {
t.Errorf("expected port 8080, got %d", result.Port) if !testutil.Contains(args, expectedArg) {
} t.Errorf("Expected argument %q not found in %v", expectedArg, args)
}
}
if result.Temp != 0.7 { for _, excludedArg := range tt.excluded {
t.Errorf("expected temp 0.7, got %f", result.Temp) if testutil.Contains(args, excludedArg) {
} t.Errorf("Excluded argument %q found in %v", excludedArg, args)
}
if !result.TrustRemoteCode { }
t.Errorf("expected trust_remote_code to be true") })
}
if result.LogLevel != "DEBUG" {
t.Errorf("expected log_level 'DEBUG', got '%s'", result.LogLevel)
} }
} }
func TestMlxBuildCommandArgs(t *testing.T) { func TestMlxBuildCommandArgs_ZeroValues(t *testing.T) {
options := &backends.MlxServerOptions{ options := backends.MlxServerOptions{
Model: "/test/model.mlx", Port: 0, // Should be excluded
Host: "127.0.0.1", TopK: 0, // Should be excluded
Port: 8080, Temp: 0, // Should be excluded
Temp: 0.7, Model: "", // Should be excluded
TopP: 0.9, LogLevel: "", // Should be excluded
TopK: 40, TrustRemoteCode: false, // Should be excluded
MaxTokens: 2048,
TrustRemoteCode: true,
LogLevel: "DEBUG",
ChatTemplate: "custom template",
} }
args := options.BuildCommandArgs() args := options.BuildCommandArgs()
// Check that all expected flags are present // Zero values should not appear in arguments
expectedFlags := map[string]string{ excludedArgs := []string{
"--model": "/test/model.mlx", "--port", "0",
"--host": "127.0.0.1", "--top-k", "0",
"--port": "8080", "--temp", "0",
"--log-level": "DEBUG", "--model", "",
"--chat-template": "custom template", "--log-level", "",
"--temp": "0.7", "--trust-remote-code",
"--top-p": "0.9",
"--top-k": "40",
"--max-tokens": "2048",
} }
for i := 0; i < len(args); i++ { for _, excludedArg := range excludedArgs {
if args[i] == "--trust-remote-code" { if testutil.Contains(args, excludedArg) {
continue // Boolean flag with no value t.Errorf("Zero value argument %q should not be present in %v", excludedArg, args)
}
if args[i] == "--use-default-chat-template" {
continue // Boolean flag with no value
}
if expectedValue, exists := expectedFlags[args[i]]; exists && i+1 < len(args) {
if args[i+1] != expectedValue {
t.Errorf("expected %s to have value %s, got %s", args[i], expectedValue, args[i+1])
}
} }
} }
}
// Check boolean flags
foundTrustRemoteCode := false
for _, arg := range args {
if arg == "--trust-remote-code" {
foundTrustRemoteCode = true
}
}
if !foundTrustRemoteCode {
t.Errorf("expected --trust-remote-code flag to be present")
}
}

View File

@@ -2,6 +2,7 @@ package backends_test
import ( import (
"llamactl/pkg/backends" "llamactl/pkg/backends"
"llamactl/pkg/testutil"
"testing" "testing"
) )
@@ -10,26 +11,72 @@ func TestParseVllmCommand(t *testing.T) {
name string name string
command string command string
expectErr bool expectErr bool
validate func(*testing.T, *backends.VllmServerOptions)
}{ }{
{ {
name: "basic vllm serve command", name: "basic vllm serve command",
command: "vllm serve microsoft/DialoGPT-medium", command: "vllm serve microsoft/DialoGPT-medium",
expectErr: false, expectErr: false,
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
if opts.Model != "microsoft/DialoGPT-medium" {
t.Errorf("expected model 'microsoft/DialoGPT-medium', got '%s'", opts.Model)
}
},
}, },
{ {
name: "serve only command", name: "serve only command",
command: "serve microsoft/DialoGPT-medium", command: "serve microsoft/DialoGPT-medium",
expectErr: false, expectErr: false,
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
if opts.Model != "microsoft/DialoGPT-medium" {
t.Errorf("expected model 'microsoft/DialoGPT-medium', got '%s'", opts.Model)
}
},
}, },
{ {
name: "positional model with flags", name: "positional model with flags",
command: "vllm serve microsoft/DialoGPT-medium --tensor-parallel-size 2", command: "vllm serve microsoft/DialoGPT-medium --tensor-parallel-size 2",
expectErr: false, expectErr: false,
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
if opts.Model != "microsoft/DialoGPT-medium" {
t.Errorf("expected model 'microsoft/DialoGPT-medium', got '%s'", opts.Model)
}
if opts.TensorParallelSize != 2 {
t.Errorf("expected tensor_parallel_size 2, got %d", opts.TensorParallelSize)
}
},
}, },
{ {
name: "model with path", name: "model with path",
command: "vllm serve /path/to/model --gpu-memory-utilization 0.8", command: "vllm serve /path/to/model --gpu-memory-utilization 0.8",
expectErr: false, expectErr: false,
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
if opts.Model != "/path/to/model" {
t.Errorf("expected model '/path/to/model', got '%s'", opts.Model)
}
if opts.GPUMemoryUtilization != 0.8 {
t.Errorf("expected gpu_memory_utilization 0.8, got %f", opts.GPUMemoryUtilization)
}
},
},
{
name: "multiple value types",
command: "vllm serve test-model --tensor-parallel-size 4 --gpu-memory-utilization 0.8 --enable-log-outputs",
expectErr: false,
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
if opts.Model != "test-model" {
t.Errorf("expected model 'test-model', got '%s'", opts.Model)
}
if opts.TensorParallelSize != 4 {
t.Errorf("expected tensor_parallel_size 4, got %d", opts.TensorParallelSize)
}
if opts.GPUMemoryUtilization != 0.8 {
t.Errorf("expected gpu_memory_utilization 0.8, got %f", opts.GPUMemoryUtilization)
}
if !opts.EnableLogOutputs {
t.Errorf("expected enable_log_outputs true, got %v", opts.EnableLogOutputs)
}
},
}, },
{ {
name: "empty command", name: "empty command",
@@ -61,34 +108,144 @@ func TestParseVllmCommand(t *testing.T) {
if result == nil { if result == nil {
t.Errorf("expected result but got nil") t.Errorf("expected result but got nil")
return
}
if tt.validate != nil {
tt.validate(t, result)
} }
}) })
} }
} }
func TestParseVllmCommandValues(t *testing.T) { func TestVllmBuildCommandArgs_BooleanFields(t *testing.T) {
command := "vllm serve test-model --tensor-parallel-size 4 --gpu-memory-utilization 0.8 --enable-log-outputs" tests := []struct {
result, err := backends.ParseVllmCommand(command) name string
options backends.VllmServerOptions
if err != nil { expected []string
t.Fatalf("unexpected error: %v", err) excluded []string
}{
{
name: "enable_log_outputs true",
options: backends.VllmServerOptions{
EnableLogOutputs: true,
},
expected: []string{"--enable-log-outputs"},
},
{
name: "enable_log_outputs false",
options: backends.VllmServerOptions{
EnableLogOutputs: false,
},
excluded: []string{"--enable-log-outputs"},
},
{
name: "multiple booleans",
options: backends.VllmServerOptions{
EnableLogOutputs: true,
TrustRemoteCode: true,
EnablePrefixCaching: true,
DisableLogStats: false,
},
expected: []string{"--enable-log-outputs", "--trust-remote-code", "--enable-prefix-caching"},
excluded: []string{"--disable-log-stats"},
},
} }
if result.Model != "test-model" { for _, tt := range tests {
t.Errorf("expected model 'test-model', got '%s'", result.Model) t.Run(tt.name, func(t *testing.T) {
} args := tt.options.BuildCommandArgs()
if result.TensorParallelSize != 4 {
t.Errorf("expected tensor_parallel_size 4, got %d", result.TensorParallelSize) for _, expectedArg := range tt.expected {
} if !testutil.Contains(args, expectedArg) {
if result.GPUMemoryUtilization != 0.8 { t.Errorf("Expected argument %q not found in %v", expectedArg, args)
t.Errorf("expected gpu_memory_utilization 0.8, got %f", result.GPUMemoryUtilization) }
} }
if !result.EnableLogOutputs {
t.Errorf("expected enable_log_outputs true, got %v", result.EnableLogOutputs) for _, excludedArg := range tt.excluded {
if testutil.Contains(args, excludedArg) {
t.Errorf("Excluded argument %q found in %v", excludedArg, args)
}
}
})
} }
} }
func TestVllmBuildCommandArgs(t *testing.T) { func TestVllmBuildCommandArgs_ZeroValues(t *testing.T) {
options := backends.VllmServerOptions{
Port: 0, // Should be excluded
TensorParallelSize: 0, // Should be excluded
GPUMemoryUtilization: 0, // Should be excluded
Model: "", // Should be excluded (positional arg)
Host: "", // Should be excluded
EnableLogOutputs: false, // Should be excluded
}
args := options.BuildCommandArgs()
// Zero values should not appear in arguments
excludedArgs := []string{
"--port", "0",
"--tensor-parallel-size", "0",
"--gpu-memory-utilization", "0",
"--host", "",
"--enable-log-outputs",
}
for _, excludedArg := range excludedArgs {
if testutil.Contains(args, excludedArg) {
t.Errorf("Zero value argument %q should not be present in %v", excludedArg, args)
}
}
// Model should not be present as positional arg when empty
if len(args) > 0 && args[0] == "" {
t.Errorf("Empty model should not be present as positional argument")
}
}
func TestVllmBuildCommandArgs_ArrayFields(t *testing.T) {
options := backends.VllmServerOptions{
AllowedOrigins: []string{"http://localhost:3000", "https://example.com"},
AllowedMethods: []string{"GET", "POST"},
Middleware: []string{"middleware1", "middleware2", "middleware3"},
}
args := options.BuildCommandArgs()
// Check that each array value appears with its flag
expectedOccurrences := map[string][]string{
"--allowed-origins": {"http://localhost:3000", "https://example.com"},
"--allowed-methods": {"GET", "POST"},
"--middleware": {"middleware1", "middleware2", "middleware3"},
}
for flag, values := range expectedOccurrences {
for _, value := range values {
if !testutil.ContainsFlagWithValue(args, flag, value) {
t.Errorf("Expected %s %s, not found in %v", flag, value, args)
}
}
}
}
func TestVllmBuildCommandArgs_EmptyArrays(t *testing.T) {
options := backends.VllmServerOptions{
AllowedOrigins: []string{}, // Empty array should not generate args
Middleware: []string{}, // Empty array should not generate args
}
args := options.BuildCommandArgs()
excludedArgs := []string{"--allowed-origins", "--middleware"}
for _, excludedArg := range excludedArgs {
if testutil.Contains(args, excludedArg) {
t.Errorf("Empty array should not generate argument %q in %v", excludedArg, args)
}
}
}
func TestVllmBuildCommandArgs_PositionalModel(t *testing.T) {
options := backends.VllmServerOptions{ options := backends.VllmServerOptions{
Model: "microsoft/DialoGPT-medium", Model: "microsoft/DialoGPT-medium",
Port: 8080, Port: 8080,
@@ -96,7 +253,6 @@ func TestVllmBuildCommandArgs(t *testing.T) {
TensorParallelSize: 2, TensorParallelSize: 2,
GPUMemoryUtilization: 0.8, GPUMemoryUtilization: 0.8,
EnableLogOutputs: true, EnableLogOutputs: true,
AllowedOrigins: []string{"http://localhost:3000", "https://example.com"},
} }
args := options.BuildCommandArgs() args := options.BuildCommandArgs()
@@ -107,32 +263,24 @@ func TestVllmBuildCommandArgs(t *testing.T) {
} }
// Check that --model flag is NOT present (since model should be positional) // Check that --model flag is NOT present (since model should be positional)
if contains(args, "--model") { if testutil.Contains(args, "--model") {
t.Errorf("Found --model flag, but model should be positional argument in args: %v", args) t.Errorf("Found --model flag, but model should be positional argument in args: %v", args)
} }
// Check other flags // Check other flags
if !containsFlagWithValue(args, "--tensor-parallel-size", "2") { if !testutil.ContainsFlagWithValue(args, "--tensor-parallel-size", "2") {
t.Errorf("Expected --tensor-parallel-size 2 not found in %v", args) t.Errorf("Expected --tensor-parallel-size 2 not found in %v", args)
} }
if !contains(args, "--enable-log-outputs") { if !testutil.ContainsFlagWithValue(args, "--gpu-memory-utilization", "0.8") {
t.Errorf("Expected --gpu-memory-utilization 0.8 not found in %v", args)
}
if !testutil.Contains(args, "--enable-log-outputs") {
t.Errorf("Expected --enable-log-outputs not found in %v", args) t.Errorf("Expected --enable-log-outputs not found in %v", args)
} }
if !contains(args, "--host") { if !testutil.ContainsFlagWithValue(args, "--host", "localhost") {
t.Errorf("Expected --host not found in %v", args) t.Errorf("Expected --host localhost not found in %v", args)
} }
if !contains(args, "--port") { if !testutil.ContainsFlagWithValue(args, "--port", "8080") {
t.Errorf("Expected --port not found in %v", args) t.Errorf("Expected --port 8080 not found in %v", args)
}
// Check array handling (multiple flags)
allowedOriginsCount := 0
for i := range args {
if args[i] == "--allowed-origins" {
allowedOriginsCount++
}
}
if allowedOriginsCount != 2 {
t.Errorf("Expected 2 --allowed-origins flags, got %d", allowedOriginsCount)
} }
} }

View File

@@ -1,5 +1,7 @@
package testutil package testutil
import "slices"
// Helper functions for pointer fields // Helper functions for pointer fields
func BoolPtr(b bool) *bool { func BoolPtr(b bool) *bool {
return &b return &b
@@ -8,3 +10,23 @@ func BoolPtr(b bool) *bool {
func IntPtr(i int) *int { func IntPtr(i int) *int {
return &i return &i
} }
// Helper functions for testing command arguments
// Contains checks if a slice contains a specific item
func Contains(slice []string, item string) bool {
return slices.Contains(slice, item)
}
// ContainsFlagWithValue checks if args contains a flag followed by a specific value
func ContainsFlagWithValue(args []string, flag, value string) bool {
for i, arg := range args {
if arg == flag {
// Check if there's a next argument and it matches the expected value
if i+1 < len(args) && args[i+1] == value {
return true
}
}
}
return false
}