mirror of
https://github.com/lordmathis/llamactl.git
synced 2025-11-05 16:44:22 +00:00
Implement mlx and cllm tests and remove redundant code
This commit is contained in:
@@ -4,44 +4,11 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"llamactl/pkg/backends"
|
||||
"llamactl/pkg/testutil"
|
||||
"reflect"
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLlamaCppBuildCommandArgs_BasicFields(t *testing.T) {
|
||||
options := backends.LlamaServerOptions{
|
||||
Model: "/path/to/model.gguf",
|
||||
Port: 8080,
|
||||
Host: "localhost",
|
||||
Verbose: true,
|
||||
CtxSize: 4096,
|
||||
GPULayers: 32,
|
||||
}
|
||||
|
||||
args := options.BuildCommandArgs()
|
||||
|
||||
// Check individual arguments
|
||||
expectedPairs := map[string]string{
|
||||
"--model": "/path/to/model.gguf",
|
||||
"--port": "8080",
|
||||
"--host": "localhost",
|
||||
"--ctx-size": "4096",
|
||||
"--gpu-layers": "32",
|
||||
}
|
||||
|
||||
for flag, expectedValue := range expectedPairs {
|
||||
if !containsFlagWithValue(args, flag, expectedValue) {
|
||||
t.Errorf("Expected %s %s, not found in %v", flag, expectedValue, args)
|
||||
}
|
||||
}
|
||||
|
||||
// Check standalone boolean flag
|
||||
if !contains(args, "--verbose") {
|
||||
t.Errorf("Expected --verbose flag not found in %v", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLlamaCppBuildCommandArgs_BooleanFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -81,13 +48,13 @@ func TestLlamaCppBuildCommandArgs_BooleanFields(t *testing.T) {
|
||||
args := tt.options.BuildCommandArgs()
|
||||
|
||||
for _, expectedArg := range tt.expected {
|
||||
if !contains(args, expectedArg) {
|
||||
if !testutil.Contains(args, expectedArg) {
|
||||
t.Errorf("Expected argument %q not found in %v", expectedArg, args)
|
||||
}
|
||||
}
|
||||
|
||||
for _, excludedArg := range tt.excluded {
|
||||
if contains(args, excludedArg) {
|
||||
if testutil.Contains(args, excludedArg) {
|
||||
t.Errorf("Excluded argument %q found in %v", excludedArg, args)
|
||||
}
|
||||
}
|
||||
@@ -95,36 +62,6 @@ func TestLlamaCppBuildCommandArgs_BooleanFields(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLlamaCppBuildCommandArgs_NumericFields(t *testing.T) {
|
||||
options := backends.LlamaServerOptions{
|
||||
Port: 8080,
|
||||
Threads: 4,
|
||||
CtxSize: 2048,
|
||||
GPULayers: 16,
|
||||
Temperature: 0.7,
|
||||
TopK: 40,
|
||||
TopP: 0.9,
|
||||
}
|
||||
|
||||
args := options.BuildCommandArgs()
|
||||
|
||||
expectedPairs := map[string]string{
|
||||
"--port": "8080",
|
||||
"--threads": "4",
|
||||
"--ctx-size": "2048",
|
||||
"--gpu-layers": "16",
|
||||
"--temp": "0.7",
|
||||
"--top-k": "40",
|
||||
"--top-p": "0.9",
|
||||
}
|
||||
|
||||
for flag, expectedValue := range expectedPairs {
|
||||
if !containsFlagWithValue(args, flag, expectedValue) {
|
||||
t.Errorf("Expected %s %s, not found in %v", flag, expectedValue, args)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLlamaCppBuildCommandArgs_ZeroValues(t *testing.T) {
|
||||
options := backends.LlamaServerOptions{
|
||||
Port: 0, // Should be excluded
|
||||
@@ -146,7 +83,7 @@ func TestLlamaCppBuildCommandArgs_ZeroValues(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, excludedArg := range excludedArgs {
|
||||
if contains(args, excludedArg) {
|
||||
if testutil.Contains(args, excludedArg) {
|
||||
t.Errorf("Zero value argument %q should not be present in %v", excludedArg, args)
|
||||
}
|
||||
}
|
||||
@@ -170,7 +107,7 @@ func TestLlamaCppBuildCommandArgs_ArrayFields(t *testing.T) {
|
||||
|
||||
for flag, values := range expectedOccurrences {
|
||||
for _, value := range values {
|
||||
if !containsFlagWithValue(args, flag, value) {
|
||||
if !testutil.ContainsFlagWithValue(args, flag, value) {
|
||||
t.Errorf("Expected %s %s, not found in %v", flag, value, args)
|
||||
}
|
||||
}
|
||||
@@ -187,42 +124,12 @@ func TestLlamaCppBuildCommandArgs_EmptyArrays(t *testing.T) {
|
||||
|
||||
excludedArgs := []string{"--lora", "--override-tensor"}
|
||||
for _, excludedArg := range excludedArgs {
|
||||
if contains(args, excludedArg) {
|
||||
if testutil.Contains(args, excludedArg) {
|
||||
t.Errorf("Empty array should not generate argument %q in %v", excludedArg, args)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLlamaCppBuildCommandArgs_FieldNameConversion(t *testing.T) {
|
||||
// Test snake_case to kebab-case conversion
|
||||
options := backends.LlamaServerOptions{
|
||||
CtxSize: 4096,
|
||||
GPULayers: 32,
|
||||
ThreadsBatch: 2,
|
||||
FlashAttn: true,
|
||||
TopK: 40,
|
||||
TopP: 0.9,
|
||||
}
|
||||
|
||||
args := options.BuildCommandArgs()
|
||||
|
||||
// Check that field names are properly converted
|
||||
expectedFlags := []string{
|
||||
"--ctx-size", // ctx_size -> ctx-size
|
||||
"--gpu-layers", // gpu_layers -> gpu-layers
|
||||
"--threads-batch", // threads_batch -> threads-batch
|
||||
"--flash-attn", // flash_attn -> flash-attn
|
||||
"--top-k", // top_k -> top-k
|
||||
"--top-p", // top_p -> top-p
|
||||
}
|
||||
|
||||
for _, flag := range expectedFlags {
|
||||
if !contains(args, flag) {
|
||||
t.Errorf("Expected flag %q not found in %v", flag, args)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLlamaCppUnmarshalJSON_StandardFields(t *testing.T) {
|
||||
jsonData := `{
|
||||
"model": "/path/to/model.gguf",
|
||||
@@ -383,26 +290,81 @@ func TestParseLlamaCommand(t *testing.T) {
|
||||
name string
|
||||
command string
|
||||
expectErr bool
|
||||
validate func(*testing.T, *backends.LlamaServerOptions)
|
||||
}{
|
||||
{
|
||||
name: "basic command",
|
||||
command: "llama-server --model /path/to/model.gguf --gpu-layers 32",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
|
||||
if opts.Model != "/path/to/model.gguf" {
|
||||
t.Errorf("expected model '/path/to/model.gguf', got '%s'", opts.Model)
|
||||
}
|
||||
if opts.GPULayers != 32 {
|
||||
t.Errorf("expected gpu_layers 32, got %d", opts.GPULayers)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "args only",
|
||||
command: "--model /path/to/model.gguf --ctx-size 4096",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
|
||||
if opts.Model != "/path/to/model.gguf" {
|
||||
t.Errorf("expected model '/path/to/model.gguf', got '%s'", opts.Model)
|
||||
}
|
||||
if opts.CtxSize != 4096 {
|
||||
t.Errorf("expected ctx_size 4096, got %d", opts.CtxSize)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "mixed flag formats",
|
||||
command: "llama-server --model=/path/model.gguf --gpu-layers 16 --verbose",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
|
||||
if opts.Model != "/path/model.gguf" {
|
||||
t.Errorf("expected model '/path/model.gguf', got '%s'", opts.Model)
|
||||
}
|
||||
if opts.GPULayers != 16 {
|
||||
t.Errorf("expected gpu_layers 16, got %d", opts.GPULayers)
|
||||
}
|
||||
if !opts.Verbose {
|
||||
t.Errorf("expected verbose to be true")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "quoted strings",
|
||||
command: `llama-server --model test.gguf --api-key "sk-1234567890abcdef"`,
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
|
||||
if opts.APIKey != "sk-1234567890abcdef" {
|
||||
t.Errorf("expected api_key 'sk-1234567890abcdef', got '%s'", opts.APIKey)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple value types",
|
||||
command: "llama-server --model /test/model.gguf --gpu-layers 32 --temp 0.7 --verbose --no-mmap",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
|
||||
if opts.Model != "/test/model.gguf" {
|
||||
t.Errorf("expected model '/test/model.gguf', got '%s'", opts.Model)
|
||||
}
|
||||
if opts.GPULayers != 32 {
|
||||
t.Errorf("expected gpu_layers 32, got %d", opts.GPULayers)
|
||||
}
|
||||
if opts.Temperature != 0.7 {
|
||||
t.Errorf("expected temperature 0.7, got %f", opts.Temperature)
|
||||
}
|
||||
if !opts.Verbose {
|
||||
t.Errorf("expected verbose to be true")
|
||||
}
|
||||
if !opts.NoMmap {
|
||||
t.Errorf("expected no_mmap to be true")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty command",
|
||||
@@ -439,40 +401,16 @@ func TestParseLlamaCommand(t *testing.T) {
|
||||
|
||||
if result == nil {
|
||||
t.Errorf("expected result but got nil")
|
||||
return
|
||||
}
|
||||
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseLlamaCommandValues(t *testing.T) {
|
||||
command := "llama-server --model /test/model.gguf --gpu-layers 32 --temp 0.7 --verbose --no-mmap"
|
||||
result, err := backends.ParseLlamaCommand(command)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Model != "/test/model.gguf" {
|
||||
t.Errorf("expected model '/test/model.gguf', got '%s'", result.Model)
|
||||
}
|
||||
|
||||
if result.GPULayers != 32 {
|
||||
t.Errorf("expected gpu_layers 32, got %d", result.GPULayers)
|
||||
}
|
||||
|
||||
if result.Temperature != 0.7 {
|
||||
t.Errorf("expected temperature 0.7, got %f", result.Temperature)
|
||||
}
|
||||
|
||||
if !result.Verbose {
|
||||
t.Errorf("expected verbose to be true")
|
||||
}
|
||||
|
||||
if !result.NoMmap {
|
||||
t.Errorf("expected no_mmap to be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseLlamaCommandArrays(t *testing.T) {
|
||||
command := "llama-server --model test.gguf --lora adapter1.bin --lora=adapter2.bin"
|
||||
result, err := backends.ParseLlamaCommand(command)
|
||||
@@ -491,21 +429,4 @@ func TestParseLlamaCommandArrays(t *testing.T) {
|
||||
t.Errorf("expected lora[%d]=%s got %s", i, v, result.Lora[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func contains(slice []string, item string) bool {
|
||||
return slices.Contains(slice, item)
|
||||
}
|
||||
|
||||
func containsFlagWithValue(args []string, flag, value string) bool {
|
||||
for i, arg := range args {
|
||||
if arg == flag {
|
||||
// Check if there's a next argument and it matches the expected value
|
||||
if i+1 < len(args) && args[i+1] == value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package backends_test
|
||||
|
||||
import (
|
||||
"llamactl/pkg/backends"
|
||||
"llamactl/pkg/testutil"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -10,26 +11,71 @@ func TestParseMlxCommand(t *testing.T) {
|
||||
name string
|
||||
command string
|
||||
expectErr bool
|
||||
validate func(*testing.T, *backends.MlxServerOptions)
|
||||
}{
|
||||
{
|
||||
name: "basic command",
|
||||
command: "mlx_lm.server --model /path/to/model --host 0.0.0.0",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.MlxServerOptions) {
|
||||
if opts.Model != "/path/to/model" {
|
||||
t.Errorf("expected model '/path/to/model', got '%s'", opts.Model)
|
||||
}
|
||||
if opts.Host != "0.0.0.0" {
|
||||
t.Errorf("expected host '0.0.0.0', got '%s'", opts.Host)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "args only",
|
||||
command: "--model /path/to/model --port 8080",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.MlxServerOptions) {
|
||||
if opts.Model != "/path/to/model" {
|
||||
t.Errorf("expected model '/path/to/model', got '%s'", opts.Model)
|
||||
}
|
||||
if opts.Port != 8080 {
|
||||
t.Errorf("expected port 8080, got %d", opts.Port)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "mixed flag formats",
|
||||
command: "mlx_lm.server --model=/path/model --temp=0.7 --trust-remote-code",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.MlxServerOptions) {
|
||||
if opts.Model != "/path/model" {
|
||||
t.Errorf("expected model '/path/model', got '%s'", opts.Model)
|
||||
}
|
||||
if opts.Temp != 0.7 {
|
||||
t.Errorf("expected temp 0.7, got %f", opts.Temp)
|
||||
}
|
||||
if !opts.TrustRemoteCode {
|
||||
t.Errorf("expected trust_remote_code to be true")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "quoted strings",
|
||||
command: `mlx_lm.server --model test.mlx --chat-template "User: {user}\nAssistant: "`,
|
||||
name: "multiple value types",
|
||||
command: "mlx_lm.server --model /test/model.mlx --port 8080 --temp 0.7 --trust-remote-code --log-level DEBUG",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.MlxServerOptions) {
|
||||
if opts.Model != "/test/model.mlx" {
|
||||
t.Errorf("expected model '/test/model.mlx', got '%s'", opts.Model)
|
||||
}
|
||||
if opts.Port != 8080 {
|
||||
t.Errorf("expected port 8080, got %d", opts.Port)
|
||||
}
|
||||
if opts.Temp != 0.7 {
|
||||
t.Errorf("expected temp 0.7, got %f", opts.Temp)
|
||||
}
|
||||
if !opts.TrustRemoteCode {
|
||||
t.Errorf("expected trust_remote_code to be true")
|
||||
}
|
||||
if opts.LogLevel != "DEBUG" {
|
||||
t.Errorf("expected log_level 'DEBUG', got '%s'", opts.LogLevel)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty command",
|
||||
@@ -66,92 +112,91 @@ func TestParseMlxCommand(t *testing.T) {
|
||||
|
||||
if result == nil {
|
||||
t.Errorf("expected result but got nil")
|
||||
return
|
||||
}
|
||||
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMlxCommandValues(t *testing.T) {
|
||||
command := "mlx_lm.server --model /test/model.mlx --port 8080 --temp 0.7 --trust-remote-code --log-level DEBUG"
|
||||
result, err := backends.ParseMlxCommand(command)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
func TestMlxBuildCommandArgs_BooleanFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
options backends.MlxServerOptions
|
||||
expected []string
|
||||
excluded []string
|
||||
}{
|
||||
{
|
||||
name: "trust_remote_code true",
|
||||
options: backends.MlxServerOptions{
|
||||
TrustRemoteCode: true,
|
||||
},
|
||||
expected: []string{"--trust-remote-code"},
|
||||
},
|
||||
{
|
||||
name: "trust_remote_code false",
|
||||
options: backends.MlxServerOptions{
|
||||
TrustRemoteCode: false,
|
||||
},
|
||||
excluded: []string{"--trust-remote-code"},
|
||||
},
|
||||
{
|
||||
name: "multiple booleans",
|
||||
options: backends.MlxServerOptions{
|
||||
TrustRemoteCode: true,
|
||||
UseDefaultChatTemplate: true,
|
||||
},
|
||||
expected: []string{"--trust-remote-code", "--use-default-chat-template"},
|
||||
},
|
||||
}
|
||||
|
||||
if result.Model != "/test/model.mlx" {
|
||||
t.Errorf("expected model '/test/model.mlx', got '%s'", result.Model)
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
args := tt.options.BuildCommandArgs()
|
||||
|
||||
if result.Port != 8080 {
|
||||
t.Errorf("expected port 8080, got %d", result.Port)
|
||||
}
|
||||
for _, expectedArg := range tt.expected {
|
||||
if !testutil.Contains(args, expectedArg) {
|
||||
t.Errorf("Expected argument %q not found in %v", expectedArg, args)
|
||||
}
|
||||
}
|
||||
|
||||
if result.Temp != 0.7 {
|
||||
t.Errorf("expected temp 0.7, got %f", result.Temp)
|
||||
}
|
||||
|
||||
if !result.TrustRemoteCode {
|
||||
t.Errorf("expected trust_remote_code to be true")
|
||||
}
|
||||
|
||||
if result.LogLevel != "DEBUG" {
|
||||
t.Errorf("expected log_level 'DEBUG', got '%s'", result.LogLevel)
|
||||
for _, excludedArg := range tt.excluded {
|
||||
if testutil.Contains(args, excludedArg) {
|
||||
t.Errorf("Excluded argument %q found in %v", excludedArg, args)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMlxBuildCommandArgs(t *testing.T) {
|
||||
options := &backends.MlxServerOptions{
|
||||
Model: "/test/model.mlx",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Temp: 0.7,
|
||||
TopP: 0.9,
|
||||
TopK: 40,
|
||||
MaxTokens: 2048,
|
||||
TrustRemoteCode: true,
|
||||
LogLevel: "DEBUG",
|
||||
ChatTemplate: "custom template",
|
||||
func TestMlxBuildCommandArgs_ZeroValues(t *testing.T) {
|
||||
options := backends.MlxServerOptions{
|
||||
Port: 0, // Should be excluded
|
||||
TopK: 0, // Should be excluded
|
||||
Temp: 0, // Should be excluded
|
||||
Model: "", // Should be excluded
|
||||
LogLevel: "", // Should be excluded
|
||||
TrustRemoteCode: false, // Should be excluded
|
||||
}
|
||||
|
||||
args := options.BuildCommandArgs()
|
||||
|
||||
// Check that all expected flags are present
|
||||
expectedFlags := map[string]string{
|
||||
"--model": "/test/model.mlx",
|
||||
"--host": "127.0.0.1",
|
||||
"--port": "8080",
|
||||
"--log-level": "DEBUG",
|
||||
"--chat-template": "custom template",
|
||||
"--temp": "0.7",
|
||||
"--top-p": "0.9",
|
||||
"--top-k": "40",
|
||||
"--max-tokens": "2048",
|
||||
// Zero values should not appear in arguments
|
||||
excludedArgs := []string{
|
||||
"--port", "0",
|
||||
"--top-k", "0",
|
||||
"--temp", "0",
|
||||
"--model", "",
|
||||
"--log-level", "",
|
||||
"--trust-remote-code",
|
||||
}
|
||||
|
||||
for i := 0; i < len(args); i++ {
|
||||
if args[i] == "--trust-remote-code" {
|
||||
continue // Boolean flag with no value
|
||||
}
|
||||
if args[i] == "--use-default-chat-template" {
|
||||
continue // Boolean flag with no value
|
||||
}
|
||||
|
||||
if expectedValue, exists := expectedFlags[args[i]]; exists && i+1 < len(args) {
|
||||
if args[i+1] != expectedValue {
|
||||
t.Errorf("expected %s to have value %s, got %s", args[i], expectedValue, args[i+1])
|
||||
}
|
||||
for _, excludedArg := range excludedArgs {
|
||||
if testutil.Contains(args, excludedArg) {
|
||||
t.Errorf("Zero value argument %q should not be present in %v", excludedArg, args)
|
||||
}
|
||||
}
|
||||
|
||||
// Check boolean flags
|
||||
foundTrustRemoteCode := false
|
||||
for _, arg := range args {
|
||||
if arg == "--trust-remote-code" {
|
||||
foundTrustRemoteCode = true
|
||||
}
|
||||
}
|
||||
if !foundTrustRemoteCode {
|
||||
t.Errorf("expected --trust-remote-code flag to be present")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package backends_test
|
||||
|
||||
import (
|
||||
"llamactl/pkg/backends"
|
||||
"llamactl/pkg/testutil"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -10,26 +11,72 @@ func TestParseVllmCommand(t *testing.T) {
|
||||
name string
|
||||
command string
|
||||
expectErr bool
|
||||
validate func(*testing.T, *backends.VllmServerOptions)
|
||||
}{
|
||||
{
|
||||
name: "basic vllm serve command",
|
||||
command: "vllm serve microsoft/DialoGPT-medium",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
|
||||
if opts.Model != "microsoft/DialoGPT-medium" {
|
||||
t.Errorf("expected model 'microsoft/DialoGPT-medium', got '%s'", opts.Model)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "serve only command",
|
||||
command: "serve microsoft/DialoGPT-medium",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
|
||||
if opts.Model != "microsoft/DialoGPT-medium" {
|
||||
t.Errorf("expected model 'microsoft/DialoGPT-medium', got '%s'", opts.Model)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "positional model with flags",
|
||||
command: "vllm serve microsoft/DialoGPT-medium --tensor-parallel-size 2",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
|
||||
if opts.Model != "microsoft/DialoGPT-medium" {
|
||||
t.Errorf("expected model 'microsoft/DialoGPT-medium', got '%s'", opts.Model)
|
||||
}
|
||||
if opts.TensorParallelSize != 2 {
|
||||
t.Errorf("expected tensor_parallel_size 2, got %d", opts.TensorParallelSize)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "model with path",
|
||||
command: "vllm serve /path/to/model --gpu-memory-utilization 0.8",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
|
||||
if opts.Model != "/path/to/model" {
|
||||
t.Errorf("expected model '/path/to/model', got '%s'", opts.Model)
|
||||
}
|
||||
if opts.GPUMemoryUtilization != 0.8 {
|
||||
t.Errorf("expected gpu_memory_utilization 0.8, got %f", opts.GPUMemoryUtilization)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple value types",
|
||||
command: "vllm serve test-model --tensor-parallel-size 4 --gpu-memory-utilization 0.8 --enable-log-outputs",
|
||||
expectErr: false,
|
||||
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
|
||||
if opts.Model != "test-model" {
|
||||
t.Errorf("expected model 'test-model', got '%s'", opts.Model)
|
||||
}
|
||||
if opts.TensorParallelSize != 4 {
|
||||
t.Errorf("expected tensor_parallel_size 4, got %d", opts.TensorParallelSize)
|
||||
}
|
||||
if opts.GPUMemoryUtilization != 0.8 {
|
||||
t.Errorf("expected gpu_memory_utilization 0.8, got %f", opts.GPUMemoryUtilization)
|
||||
}
|
||||
if !opts.EnableLogOutputs {
|
||||
t.Errorf("expected enable_log_outputs true, got %v", opts.EnableLogOutputs)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty command",
|
||||
@@ -61,34 +108,144 @@ func TestParseVllmCommand(t *testing.T) {
|
||||
|
||||
if result == nil {
|
||||
t.Errorf("expected result but got nil")
|
||||
return
|
||||
}
|
||||
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseVllmCommandValues(t *testing.T) {
|
||||
command := "vllm serve test-model --tensor-parallel-size 4 --gpu-memory-utilization 0.8 --enable-log-outputs"
|
||||
result, err := backends.ParseVllmCommand(command)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
func TestVllmBuildCommandArgs_BooleanFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
options backends.VllmServerOptions
|
||||
expected []string
|
||||
excluded []string
|
||||
}{
|
||||
{
|
||||
name: "enable_log_outputs true",
|
||||
options: backends.VllmServerOptions{
|
||||
EnableLogOutputs: true,
|
||||
},
|
||||
expected: []string{"--enable-log-outputs"},
|
||||
},
|
||||
{
|
||||
name: "enable_log_outputs false",
|
||||
options: backends.VllmServerOptions{
|
||||
EnableLogOutputs: false,
|
||||
},
|
||||
excluded: []string{"--enable-log-outputs"},
|
||||
},
|
||||
{
|
||||
name: "multiple booleans",
|
||||
options: backends.VllmServerOptions{
|
||||
EnableLogOutputs: true,
|
||||
TrustRemoteCode: true,
|
||||
EnablePrefixCaching: true,
|
||||
DisableLogStats: false,
|
||||
},
|
||||
expected: []string{"--enable-log-outputs", "--trust-remote-code", "--enable-prefix-caching"},
|
||||
excluded: []string{"--disable-log-stats"},
|
||||
},
|
||||
}
|
||||
|
||||
if result.Model != "test-model" {
|
||||
t.Errorf("expected model 'test-model', got '%s'", result.Model)
|
||||
}
|
||||
if result.TensorParallelSize != 4 {
|
||||
t.Errorf("expected tensor_parallel_size 4, got %d", result.TensorParallelSize)
|
||||
}
|
||||
if result.GPUMemoryUtilization != 0.8 {
|
||||
t.Errorf("expected gpu_memory_utilization 0.8, got %f", result.GPUMemoryUtilization)
|
||||
}
|
||||
if !result.EnableLogOutputs {
|
||||
t.Errorf("expected enable_log_outputs true, got %v", result.EnableLogOutputs)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
args := tt.options.BuildCommandArgs()
|
||||
|
||||
for _, expectedArg := range tt.expected {
|
||||
if !testutil.Contains(args, expectedArg) {
|
||||
t.Errorf("Expected argument %q not found in %v", expectedArg, args)
|
||||
}
|
||||
}
|
||||
|
||||
for _, excludedArg := range tt.excluded {
|
||||
if testutil.Contains(args, excludedArg) {
|
||||
t.Errorf("Excluded argument %q found in %v", excludedArg, args)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVllmBuildCommandArgs(t *testing.T) {
|
||||
func TestVllmBuildCommandArgs_ZeroValues(t *testing.T) {
|
||||
options := backends.VllmServerOptions{
|
||||
Port: 0, // Should be excluded
|
||||
TensorParallelSize: 0, // Should be excluded
|
||||
GPUMemoryUtilization: 0, // Should be excluded
|
||||
Model: "", // Should be excluded (positional arg)
|
||||
Host: "", // Should be excluded
|
||||
EnableLogOutputs: false, // Should be excluded
|
||||
}
|
||||
|
||||
args := options.BuildCommandArgs()
|
||||
|
||||
// Zero values should not appear in arguments
|
||||
excludedArgs := []string{
|
||||
"--port", "0",
|
||||
"--tensor-parallel-size", "0",
|
||||
"--gpu-memory-utilization", "0",
|
||||
"--host", "",
|
||||
"--enable-log-outputs",
|
||||
}
|
||||
|
||||
for _, excludedArg := range excludedArgs {
|
||||
if testutil.Contains(args, excludedArg) {
|
||||
t.Errorf("Zero value argument %q should not be present in %v", excludedArg, args)
|
||||
}
|
||||
}
|
||||
|
||||
// Model should not be present as positional arg when empty
|
||||
if len(args) > 0 && args[0] == "" {
|
||||
t.Errorf("Empty model should not be present as positional argument")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVllmBuildCommandArgs_ArrayFields(t *testing.T) {
|
||||
options := backends.VllmServerOptions{
|
||||
AllowedOrigins: []string{"http://localhost:3000", "https://example.com"},
|
||||
AllowedMethods: []string{"GET", "POST"},
|
||||
Middleware: []string{"middleware1", "middleware2", "middleware3"},
|
||||
}
|
||||
|
||||
args := options.BuildCommandArgs()
|
||||
|
||||
// Check that each array value appears with its flag
|
||||
expectedOccurrences := map[string][]string{
|
||||
"--allowed-origins": {"http://localhost:3000", "https://example.com"},
|
||||
"--allowed-methods": {"GET", "POST"},
|
||||
"--middleware": {"middleware1", "middleware2", "middleware3"},
|
||||
}
|
||||
|
||||
for flag, values := range expectedOccurrences {
|
||||
for _, value := range values {
|
||||
if !testutil.ContainsFlagWithValue(args, flag, value) {
|
||||
t.Errorf("Expected %s %s, not found in %v", flag, value, args)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestVllmBuildCommandArgs_EmptyArrays(t *testing.T) {
|
||||
options := backends.VllmServerOptions{
|
||||
AllowedOrigins: []string{}, // Empty array should not generate args
|
||||
Middleware: []string{}, // Empty array should not generate args
|
||||
}
|
||||
|
||||
args := options.BuildCommandArgs()
|
||||
|
||||
excludedArgs := []string{"--allowed-origins", "--middleware"}
|
||||
for _, excludedArg := range excludedArgs {
|
||||
if testutil.Contains(args, excludedArg) {
|
||||
t.Errorf("Empty array should not generate argument %q in %v", excludedArg, args)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestVllmBuildCommandArgs_PositionalModel(t *testing.T) {
|
||||
options := backends.VllmServerOptions{
|
||||
Model: "microsoft/DialoGPT-medium",
|
||||
Port: 8080,
|
||||
@@ -96,7 +253,6 @@ func TestVllmBuildCommandArgs(t *testing.T) {
|
||||
TensorParallelSize: 2,
|
||||
GPUMemoryUtilization: 0.8,
|
||||
EnableLogOutputs: true,
|
||||
AllowedOrigins: []string{"http://localhost:3000", "https://example.com"},
|
||||
}
|
||||
|
||||
args := options.BuildCommandArgs()
|
||||
@@ -107,32 +263,24 @@ func TestVllmBuildCommandArgs(t *testing.T) {
|
||||
}
|
||||
|
||||
// Check that --model flag is NOT present (since model should be positional)
|
||||
if contains(args, "--model") {
|
||||
if testutil.Contains(args, "--model") {
|
||||
t.Errorf("Found --model flag, but model should be positional argument in args: %v", args)
|
||||
}
|
||||
|
||||
// Check other flags
|
||||
if !containsFlagWithValue(args, "--tensor-parallel-size", "2") {
|
||||
if !testutil.ContainsFlagWithValue(args, "--tensor-parallel-size", "2") {
|
||||
t.Errorf("Expected --tensor-parallel-size 2 not found in %v", args)
|
||||
}
|
||||
if !contains(args, "--enable-log-outputs") {
|
||||
if !testutil.ContainsFlagWithValue(args, "--gpu-memory-utilization", "0.8") {
|
||||
t.Errorf("Expected --gpu-memory-utilization 0.8 not found in %v", args)
|
||||
}
|
||||
if !testutil.Contains(args, "--enable-log-outputs") {
|
||||
t.Errorf("Expected --enable-log-outputs not found in %v", args)
|
||||
}
|
||||
if !contains(args, "--host") {
|
||||
t.Errorf("Expected --host not found in %v", args)
|
||||
if !testutil.ContainsFlagWithValue(args, "--host", "localhost") {
|
||||
t.Errorf("Expected --host localhost not found in %v", args)
|
||||
}
|
||||
if !contains(args, "--port") {
|
||||
t.Errorf("Expected --port not found in %v", args)
|
||||
}
|
||||
|
||||
// Check array handling (multiple flags)
|
||||
allowedOriginsCount := 0
|
||||
for i := range args {
|
||||
if args[i] == "--allowed-origins" {
|
||||
allowedOriginsCount++
|
||||
}
|
||||
}
|
||||
if allowedOriginsCount != 2 {
|
||||
t.Errorf("Expected 2 --allowed-origins flags, got %d", allowedOriginsCount)
|
||||
if !testutil.ContainsFlagWithValue(args, "--port", "8080") {
|
||||
t.Errorf("Expected --port 8080 not found in %v", args)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package testutil
|
||||
|
||||
import "slices"
|
||||
|
||||
// Helper functions for pointer fields
|
||||
func BoolPtr(b bool) *bool {
|
||||
return &b
|
||||
@@ -8,3 +10,23 @@ func BoolPtr(b bool) *bool {
|
||||
func IntPtr(i int) *int {
|
||||
return &i
|
||||
}
|
||||
|
||||
// Helper functions for testing command arguments
|
||||
|
||||
// Contains checks if a slice contains a specific item
|
||||
func Contains(slice []string, item string) bool {
|
||||
return slices.Contains(slice, item)
|
||||
}
|
||||
|
||||
// ContainsFlagWithValue checks if args contains a flag followed by a specific value
|
||||
func ContainsFlagWithValue(args []string, flag, value string) bool {
|
||||
for i, arg := range args {
|
||||
if arg == flag {
|
||||
// Check if there's a next argument and it matches the expected value
|
||||
if i+1 < len(args) && args[i+1] == value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user