mirror of
https://github.com/lordmathis/llamactl.git
synced 2025-11-06 09:04:27 +00:00
Implement mlx and cllm tests and remove redundant code
This commit is contained in:
@@ -4,44 +4,11 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"llamactl/pkg/backends"
|
"llamactl/pkg/backends"
|
||||||
|
"llamactl/pkg/testutil"
|
||||||
"reflect"
|
"reflect"
|
||||||
"slices"
|
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestLlamaCppBuildCommandArgs_BasicFields(t *testing.T) {
|
|
||||||
options := backends.LlamaServerOptions{
|
|
||||||
Model: "/path/to/model.gguf",
|
|
||||||
Port: 8080,
|
|
||||||
Host: "localhost",
|
|
||||||
Verbose: true,
|
|
||||||
CtxSize: 4096,
|
|
||||||
GPULayers: 32,
|
|
||||||
}
|
|
||||||
|
|
||||||
args := options.BuildCommandArgs()
|
|
||||||
|
|
||||||
// Check individual arguments
|
|
||||||
expectedPairs := map[string]string{
|
|
||||||
"--model": "/path/to/model.gguf",
|
|
||||||
"--port": "8080",
|
|
||||||
"--host": "localhost",
|
|
||||||
"--ctx-size": "4096",
|
|
||||||
"--gpu-layers": "32",
|
|
||||||
}
|
|
||||||
|
|
||||||
for flag, expectedValue := range expectedPairs {
|
|
||||||
if !containsFlagWithValue(args, flag, expectedValue) {
|
|
||||||
t.Errorf("Expected %s %s, not found in %v", flag, expectedValue, args)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check standalone boolean flag
|
|
||||||
if !contains(args, "--verbose") {
|
|
||||||
t.Errorf("Expected --verbose flag not found in %v", args)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLlamaCppBuildCommandArgs_BooleanFields(t *testing.T) {
|
func TestLlamaCppBuildCommandArgs_BooleanFields(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -81,13 +48,13 @@ func TestLlamaCppBuildCommandArgs_BooleanFields(t *testing.T) {
|
|||||||
args := tt.options.BuildCommandArgs()
|
args := tt.options.BuildCommandArgs()
|
||||||
|
|
||||||
for _, expectedArg := range tt.expected {
|
for _, expectedArg := range tt.expected {
|
||||||
if !contains(args, expectedArg) {
|
if !testutil.Contains(args, expectedArg) {
|
||||||
t.Errorf("Expected argument %q not found in %v", expectedArg, args)
|
t.Errorf("Expected argument %q not found in %v", expectedArg, args)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, excludedArg := range tt.excluded {
|
for _, excludedArg := range tt.excluded {
|
||||||
if contains(args, excludedArg) {
|
if testutil.Contains(args, excludedArg) {
|
||||||
t.Errorf("Excluded argument %q found in %v", excludedArg, args)
|
t.Errorf("Excluded argument %q found in %v", excludedArg, args)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -95,36 +62,6 @@ func TestLlamaCppBuildCommandArgs_BooleanFields(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLlamaCppBuildCommandArgs_NumericFields(t *testing.T) {
|
|
||||||
options := backends.LlamaServerOptions{
|
|
||||||
Port: 8080,
|
|
||||||
Threads: 4,
|
|
||||||
CtxSize: 2048,
|
|
||||||
GPULayers: 16,
|
|
||||||
Temperature: 0.7,
|
|
||||||
TopK: 40,
|
|
||||||
TopP: 0.9,
|
|
||||||
}
|
|
||||||
|
|
||||||
args := options.BuildCommandArgs()
|
|
||||||
|
|
||||||
expectedPairs := map[string]string{
|
|
||||||
"--port": "8080",
|
|
||||||
"--threads": "4",
|
|
||||||
"--ctx-size": "2048",
|
|
||||||
"--gpu-layers": "16",
|
|
||||||
"--temp": "0.7",
|
|
||||||
"--top-k": "40",
|
|
||||||
"--top-p": "0.9",
|
|
||||||
}
|
|
||||||
|
|
||||||
for flag, expectedValue := range expectedPairs {
|
|
||||||
if !containsFlagWithValue(args, flag, expectedValue) {
|
|
||||||
t.Errorf("Expected %s %s, not found in %v", flag, expectedValue, args)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLlamaCppBuildCommandArgs_ZeroValues(t *testing.T) {
|
func TestLlamaCppBuildCommandArgs_ZeroValues(t *testing.T) {
|
||||||
options := backends.LlamaServerOptions{
|
options := backends.LlamaServerOptions{
|
||||||
Port: 0, // Should be excluded
|
Port: 0, // Should be excluded
|
||||||
@@ -146,7 +83,7 @@ func TestLlamaCppBuildCommandArgs_ZeroValues(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, excludedArg := range excludedArgs {
|
for _, excludedArg := range excludedArgs {
|
||||||
if contains(args, excludedArg) {
|
if testutil.Contains(args, excludedArg) {
|
||||||
t.Errorf("Zero value argument %q should not be present in %v", excludedArg, args)
|
t.Errorf("Zero value argument %q should not be present in %v", excludedArg, args)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -170,7 +107,7 @@ func TestLlamaCppBuildCommandArgs_ArrayFields(t *testing.T) {
|
|||||||
|
|
||||||
for flag, values := range expectedOccurrences {
|
for flag, values := range expectedOccurrences {
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
if !containsFlagWithValue(args, flag, value) {
|
if !testutil.ContainsFlagWithValue(args, flag, value) {
|
||||||
t.Errorf("Expected %s %s, not found in %v", flag, value, args)
|
t.Errorf("Expected %s %s, not found in %v", flag, value, args)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -187,42 +124,12 @@ func TestLlamaCppBuildCommandArgs_EmptyArrays(t *testing.T) {
|
|||||||
|
|
||||||
excludedArgs := []string{"--lora", "--override-tensor"}
|
excludedArgs := []string{"--lora", "--override-tensor"}
|
||||||
for _, excludedArg := range excludedArgs {
|
for _, excludedArg := range excludedArgs {
|
||||||
if contains(args, excludedArg) {
|
if testutil.Contains(args, excludedArg) {
|
||||||
t.Errorf("Empty array should not generate argument %q in %v", excludedArg, args)
|
t.Errorf("Empty array should not generate argument %q in %v", excludedArg, args)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLlamaCppBuildCommandArgs_FieldNameConversion(t *testing.T) {
|
|
||||||
// Test snake_case to kebab-case conversion
|
|
||||||
options := backends.LlamaServerOptions{
|
|
||||||
CtxSize: 4096,
|
|
||||||
GPULayers: 32,
|
|
||||||
ThreadsBatch: 2,
|
|
||||||
FlashAttn: true,
|
|
||||||
TopK: 40,
|
|
||||||
TopP: 0.9,
|
|
||||||
}
|
|
||||||
|
|
||||||
args := options.BuildCommandArgs()
|
|
||||||
|
|
||||||
// Check that field names are properly converted
|
|
||||||
expectedFlags := []string{
|
|
||||||
"--ctx-size", // ctx_size -> ctx-size
|
|
||||||
"--gpu-layers", // gpu_layers -> gpu-layers
|
|
||||||
"--threads-batch", // threads_batch -> threads-batch
|
|
||||||
"--flash-attn", // flash_attn -> flash-attn
|
|
||||||
"--top-k", // top_k -> top-k
|
|
||||||
"--top-p", // top_p -> top-p
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, flag := range expectedFlags {
|
|
||||||
if !contains(args, flag) {
|
|
||||||
t.Errorf("Expected flag %q not found in %v", flag, args)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLlamaCppUnmarshalJSON_StandardFields(t *testing.T) {
|
func TestLlamaCppUnmarshalJSON_StandardFields(t *testing.T) {
|
||||||
jsonData := `{
|
jsonData := `{
|
||||||
"model": "/path/to/model.gguf",
|
"model": "/path/to/model.gguf",
|
||||||
@@ -383,26 +290,81 @@ func TestParseLlamaCommand(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
command string
|
command string
|
||||||
expectErr bool
|
expectErr bool
|
||||||
|
validate func(*testing.T, *backends.LlamaServerOptions)
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "basic command",
|
name: "basic command",
|
||||||
command: "llama-server --model /path/to/model.gguf --gpu-layers 32",
|
command: "llama-server --model /path/to/model.gguf --gpu-layers 32",
|
||||||
expectErr: false,
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
|
||||||
|
if opts.Model != "/path/to/model.gguf" {
|
||||||
|
t.Errorf("expected model '/path/to/model.gguf', got '%s'", opts.Model)
|
||||||
|
}
|
||||||
|
if opts.GPULayers != 32 {
|
||||||
|
t.Errorf("expected gpu_layers 32, got %d", opts.GPULayers)
|
||||||
|
}
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "args only",
|
name: "args only",
|
||||||
command: "--model /path/to/model.gguf --ctx-size 4096",
|
command: "--model /path/to/model.gguf --ctx-size 4096",
|
||||||
expectErr: false,
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
|
||||||
|
if opts.Model != "/path/to/model.gguf" {
|
||||||
|
t.Errorf("expected model '/path/to/model.gguf', got '%s'", opts.Model)
|
||||||
|
}
|
||||||
|
if opts.CtxSize != 4096 {
|
||||||
|
t.Errorf("expected ctx_size 4096, got %d", opts.CtxSize)
|
||||||
|
}
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "mixed flag formats",
|
name: "mixed flag formats",
|
||||||
command: "llama-server --model=/path/model.gguf --gpu-layers 16 --verbose",
|
command: "llama-server --model=/path/model.gguf --gpu-layers 16 --verbose",
|
||||||
expectErr: false,
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
|
||||||
|
if opts.Model != "/path/model.gguf" {
|
||||||
|
t.Errorf("expected model '/path/model.gguf', got '%s'", opts.Model)
|
||||||
|
}
|
||||||
|
if opts.GPULayers != 16 {
|
||||||
|
t.Errorf("expected gpu_layers 16, got %d", opts.GPULayers)
|
||||||
|
}
|
||||||
|
if !opts.Verbose {
|
||||||
|
t.Errorf("expected verbose to be true")
|
||||||
|
}
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "quoted strings",
|
name: "quoted strings",
|
||||||
command: `llama-server --model test.gguf --api-key "sk-1234567890abcdef"`,
|
command: `llama-server --model test.gguf --api-key "sk-1234567890abcdef"`,
|
||||||
expectErr: false,
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
|
||||||
|
if opts.APIKey != "sk-1234567890abcdef" {
|
||||||
|
t.Errorf("expected api_key 'sk-1234567890abcdef', got '%s'", opts.APIKey)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple value types",
|
||||||
|
command: "llama-server --model /test/model.gguf --gpu-layers 32 --temp 0.7 --verbose --no-mmap",
|
||||||
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.LlamaServerOptions) {
|
||||||
|
if opts.Model != "/test/model.gguf" {
|
||||||
|
t.Errorf("expected model '/test/model.gguf', got '%s'", opts.Model)
|
||||||
|
}
|
||||||
|
if opts.GPULayers != 32 {
|
||||||
|
t.Errorf("expected gpu_layers 32, got %d", opts.GPULayers)
|
||||||
|
}
|
||||||
|
if opts.Temperature != 0.7 {
|
||||||
|
t.Errorf("expected temperature 0.7, got %f", opts.Temperature)
|
||||||
|
}
|
||||||
|
if !opts.Verbose {
|
||||||
|
t.Errorf("expected verbose to be true")
|
||||||
|
}
|
||||||
|
if !opts.NoMmap {
|
||||||
|
t.Errorf("expected no_mmap to be true")
|
||||||
|
}
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "empty command",
|
name: "empty command",
|
||||||
@@ -439,40 +401,16 @@ func TestParseLlamaCommand(t *testing.T) {
|
|||||||
|
|
||||||
if result == nil {
|
if result == nil {
|
||||||
t.Errorf("expected result but got nil")
|
t.Errorf("expected result but got nil")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.validate != nil {
|
||||||
|
tt.validate(t, result)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseLlamaCommandValues(t *testing.T) {
|
|
||||||
command := "llama-server --model /test/model.gguf --gpu-layers 32 --temp 0.7 --verbose --no-mmap"
|
|
||||||
result, err := backends.ParseLlamaCommand(command)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.Model != "/test/model.gguf" {
|
|
||||||
t.Errorf("expected model '/test/model.gguf', got '%s'", result.Model)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.GPULayers != 32 {
|
|
||||||
t.Errorf("expected gpu_layers 32, got %d", result.GPULayers)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.Temperature != 0.7 {
|
|
||||||
t.Errorf("expected temperature 0.7, got %f", result.Temperature)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !result.Verbose {
|
|
||||||
t.Errorf("expected verbose to be true")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !result.NoMmap {
|
|
||||||
t.Errorf("expected no_mmap to be true")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseLlamaCommandArrays(t *testing.T) {
|
func TestParseLlamaCommandArrays(t *testing.T) {
|
||||||
command := "llama-server --model test.gguf --lora adapter1.bin --lora=adapter2.bin"
|
command := "llama-server --model test.gguf --lora adapter1.bin --lora=adapter2.bin"
|
||||||
result, err := backends.ParseLlamaCommand(command)
|
result, err := backends.ParseLlamaCommand(command)
|
||||||
@@ -492,20 +430,3 @@ func TestParseLlamaCommandArrays(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper functions
|
|
||||||
func contains(slice []string, item string) bool {
|
|
||||||
return slices.Contains(slice, item)
|
|
||||||
}
|
|
||||||
|
|
||||||
func containsFlagWithValue(args []string, flag, value string) bool {
|
|
||||||
for i, arg := range args {
|
|
||||||
if arg == flag {
|
|
||||||
// Check if there's a next argument and it matches the expected value
|
|
||||||
if i+1 < len(args) && args[i+1] == value {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package backends_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"llamactl/pkg/backends"
|
"llamactl/pkg/backends"
|
||||||
|
"llamactl/pkg/testutil"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -10,26 +11,71 @@ func TestParseMlxCommand(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
command string
|
command string
|
||||||
expectErr bool
|
expectErr bool
|
||||||
|
validate func(*testing.T, *backends.MlxServerOptions)
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "basic command",
|
name: "basic command",
|
||||||
command: "mlx_lm.server --model /path/to/model --host 0.0.0.0",
|
command: "mlx_lm.server --model /path/to/model --host 0.0.0.0",
|
||||||
expectErr: false,
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.MlxServerOptions) {
|
||||||
|
if opts.Model != "/path/to/model" {
|
||||||
|
t.Errorf("expected model '/path/to/model', got '%s'", opts.Model)
|
||||||
|
}
|
||||||
|
if opts.Host != "0.0.0.0" {
|
||||||
|
t.Errorf("expected host '0.0.0.0', got '%s'", opts.Host)
|
||||||
|
}
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "args only",
|
name: "args only",
|
||||||
command: "--model /path/to/model --port 8080",
|
command: "--model /path/to/model --port 8080",
|
||||||
expectErr: false,
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.MlxServerOptions) {
|
||||||
|
if opts.Model != "/path/to/model" {
|
||||||
|
t.Errorf("expected model '/path/to/model', got '%s'", opts.Model)
|
||||||
|
}
|
||||||
|
if opts.Port != 8080 {
|
||||||
|
t.Errorf("expected port 8080, got %d", opts.Port)
|
||||||
|
}
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "mixed flag formats",
|
name: "mixed flag formats",
|
||||||
command: "mlx_lm.server --model=/path/model --temp=0.7 --trust-remote-code",
|
command: "mlx_lm.server --model=/path/model --temp=0.7 --trust-remote-code",
|
||||||
expectErr: false,
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.MlxServerOptions) {
|
||||||
|
if opts.Model != "/path/model" {
|
||||||
|
t.Errorf("expected model '/path/model', got '%s'", opts.Model)
|
||||||
|
}
|
||||||
|
if opts.Temp != 0.7 {
|
||||||
|
t.Errorf("expected temp 0.7, got %f", opts.Temp)
|
||||||
|
}
|
||||||
|
if !opts.TrustRemoteCode {
|
||||||
|
t.Errorf("expected trust_remote_code to be true")
|
||||||
|
}
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "quoted strings",
|
name: "multiple value types",
|
||||||
command: `mlx_lm.server --model test.mlx --chat-template "User: {user}\nAssistant: "`,
|
command: "mlx_lm.server --model /test/model.mlx --port 8080 --temp 0.7 --trust-remote-code --log-level DEBUG",
|
||||||
expectErr: false,
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.MlxServerOptions) {
|
||||||
|
if opts.Model != "/test/model.mlx" {
|
||||||
|
t.Errorf("expected model '/test/model.mlx', got '%s'", opts.Model)
|
||||||
|
}
|
||||||
|
if opts.Port != 8080 {
|
||||||
|
t.Errorf("expected port 8080, got %d", opts.Port)
|
||||||
|
}
|
||||||
|
if opts.Temp != 0.7 {
|
||||||
|
t.Errorf("expected temp 0.7, got %f", opts.Temp)
|
||||||
|
}
|
||||||
|
if !opts.TrustRemoteCode {
|
||||||
|
t.Errorf("expected trust_remote_code to be true")
|
||||||
|
}
|
||||||
|
if opts.LogLevel != "DEBUG" {
|
||||||
|
t.Errorf("expected log_level 'DEBUG', got '%s'", opts.LogLevel)
|
||||||
|
}
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "empty command",
|
name: "empty command",
|
||||||
@@ -66,92 +112,91 @@ func TestParseMlxCommand(t *testing.T) {
|
|||||||
|
|
||||||
if result == nil {
|
if result == nil {
|
||||||
t.Errorf("expected result but got nil")
|
t.Errorf("expected result but got nil")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.validate != nil {
|
||||||
|
tt.validate(t, result)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseMlxCommandValues(t *testing.T) {
|
func TestMlxBuildCommandArgs_BooleanFields(t *testing.T) {
|
||||||
command := "mlx_lm.server --model /test/model.mlx --port 8080 --temp 0.7 --trust-remote-code --log-level DEBUG"
|
tests := []struct {
|
||||||
result, err := backends.ParseMlxCommand(command)
|
name string
|
||||||
|
options backends.MlxServerOptions
|
||||||
if err != nil {
|
expected []string
|
||||||
t.Fatalf("unexpected error: %v", err)
|
excluded []string
|
||||||
}
|
}{
|
||||||
|
{
|
||||||
if result.Model != "/test/model.mlx" {
|
name: "trust_remote_code true",
|
||||||
t.Errorf("expected model '/test/model.mlx', got '%s'", result.Model)
|
options: backends.MlxServerOptions{
|
||||||
}
|
|
||||||
|
|
||||||
if result.Port != 8080 {
|
|
||||||
t.Errorf("expected port 8080, got %d", result.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.Temp != 0.7 {
|
|
||||||
t.Errorf("expected temp 0.7, got %f", result.Temp)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !result.TrustRemoteCode {
|
|
||||||
t.Errorf("expected trust_remote_code to be true")
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.LogLevel != "DEBUG" {
|
|
||||||
t.Errorf("expected log_level 'DEBUG', got '%s'", result.LogLevel)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMlxBuildCommandArgs(t *testing.T) {
|
|
||||||
options := &backends.MlxServerOptions{
|
|
||||||
Model: "/test/model.mlx",
|
|
||||||
Host: "127.0.0.1",
|
|
||||||
Port: 8080,
|
|
||||||
Temp: 0.7,
|
|
||||||
TopP: 0.9,
|
|
||||||
TopK: 40,
|
|
||||||
MaxTokens: 2048,
|
|
||||||
TrustRemoteCode: true,
|
TrustRemoteCode: true,
|
||||||
LogLevel: "DEBUG",
|
},
|
||||||
ChatTemplate: "custom template",
|
expected: []string{"--trust-remote-code"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "trust_remote_code false",
|
||||||
|
options: backends.MlxServerOptions{
|
||||||
|
TrustRemoteCode: false,
|
||||||
|
},
|
||||||
|
excluded: []string{"--trust-remote-code"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple booleans",
|
||||||
|
options: backends.MlxServerOptions{
|
||||||
|
TrustRemoteCode: true,
|
||||||
|
UseDefaultChatTemplate: true,
|
||||||
|
},
|
||||||
|
expected: []string{"--trust-remote-code", "--use-default-chat-template"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
args := tt.options.BuildCommandArgs()
|
||||||
|
|
||||||
|
for _, expectedArg := range tt.expected {
|
||||||
|
if !testutil.Contains(args, expectedArg) {
|
||||||
|
t.Errorf("Expected argument %q not found in %v", expectedArg, args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, excludedArg := range tt.excluded {
|
||||||
|
if testutil.Contains(args, excludedArg) {
|
||||||
|
t.Errorf("Excluded argument %q found in %v", excludedArg, args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMlxBuildCommandArgs_ZeroValues(t *testing.T) {
|
||||||
|
options := backends.MlxServerOptions{
|
||||||
|
Port: 0, // Should be excluded
|
||||||
|
TopK: 0, // Should be excluded
|
||||||
|
Temp: 0, // Should be excluded
|
||||||
|
Model: "", // Should be excluded
|
||||||
|
LogLevel: "", // Should be excluded
|
||||||
|
TrustRemoteCode: false, // Should be excluded
|
||||||
}
|
}
|
||||||
|
|
||||||
args := options.BuildCommandArgs()
|
args := options.BuildCommandArgs()
|
||||||
|
|
||||||
// Check that all expected flags are present
|
// Zero values should not appear in arguments
|
||||||
expectedFlags := map[string]string{
|
excludedArgs := []string{
|
||||||
"--model": "/test/model.mlx",
|
"--port", "0",
|
||||||
"--host": "127.0.0.1",
|
"--top-k", "0",
|
||||||
"--port": "8080",
|
"--temp", "0",
|
||||||
"--log-level": "DEBUG",
|
"--model", "",
|
||||||
"--chat-template": "custom template",
|
"--log-level", "",
|
||||||
"--temp": "0.7",
|
"--trust-remote-code",
|
||||||
"--top-p": "0.9",
|
|
||||||
"--top-k": "40",
|
|
||||||
"--max-tokens": "2048",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < len(args); i++ {
|
for _, excludedArg := range excludedArgs {
|
||||||
if args[i] == "--trust-remote-code" {
|
if testutil.Contains(args, excludedArg) {
|
||||||
continue // Boolean flag with no value
|
t.Errorf("Zero value argument %q should not be present in %v", excludedArg, args)
|
||||||
}
|
|
||||||
if args[i] == "--use-default-chat-template" {
|
|
||||||
continue // Boolean flag with no value
|
|
||||||
}
|
|
||||||
|
|
||||||
if expectedValue, exists := expectedFlags[args[i]]; exists && i+1 < len(args) {
|
|
||||||
if args[i+1] != expectedValue {
|
|
||||||
t.Errorf("expected %s to have value %s, got %s", args[i], expectedValue, args[i+1])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check boolean flags
|
|
||||||
foundTrustRemoteCode := false
|
|
||||||
for _, arg := range args {
|
|
||||||
if arg == "--trust-remote-code" {
|
|
||||||
foundTrustRemoteCode = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !foundTrustRemoteCode {
|
|
||||||
t.Errorf("expected --trust-remote-code flag to be present")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package backends_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"llamactl/pkg/backends"
|
"llamactl/pkg/backends"
|
||||||
|
"llamactl/pkg/testutil"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -10,26 +11,72 @@ func TestParseVllmCommand(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
command string
|
command string
|
||||||
expectErr bool
|
expectErr bool
|
||||||
|
validate func(*testing.T, *backends.VllmServerOptions)
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "basic vllm serve command",
|
name: "basic vllm serve command",
|
||||||
command: "vllm serve microsoft/DialoGPT-medium",
|
command: "vllm serve microsoft/DialoGPT-medium",
|
||||||
expectErr: false,
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
|
||||||
|
if opts.Model != "microsoft/DialoGPT-medium" {
|
||||||
|
t.Errorf("expected model 'microsoft/DialoGPT-medium', got '%s'", opts.Model)
|
||||||
|
}
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "serve only command",
|
name: "serve only command",
|
||||||
command: "serve microsoft/DialoGPT-medium",
|
command: "serve microsoft/DialoGPT-medium",
|
||||||
expectErr: false,
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
|
||||||
|
if opts.Model != "microsoft/DialoGPT-medium" {
|
||||||
|
t.Errorf("expected model 'microsoft/DialoGPT-medium', got '%s'", opts.Model)
|
||||||
|
}
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "positional model with flags",
|
name: "positional model with flags",
|
||||||
command: "vllm serve microsoft/DialoGPT-medium --tensor-parallel-size 2",
|
command: "vllm serve microsoft/DialoGPT-medium --tensor-parallel-size 2",
|
||||||
expectErr: false,
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
|
||||||
|
if opts.Model != "microsoft/DialoGPT-medium" {
|
||||||
|
t.Errorf("expected model 'microsoft/DialoGPT-medium', got '%s'", opts.Model)
|
||||||
|
}
|
||||||
|
if opts.TensorParallelSize != 2 {
|
||||||
|
t.Errorf("expected tensor_parallel_size 2, got %d", opts.TensorParallelSize)
|
||||||
|
}
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "model with path",
|
name: "model with path",
|
||||||
command: "vllm serve /path/to/model --gpu-memory-utilization 0.8",
|
command: "vllm serve /path/to/model --gpu-memory-utilization 0.8",
|
||||||
expectErr: false,
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
|
||||||
|
if opts.Model != "/path/to/model" {
|
||||||
|
t.Errorf("expected model '/path/to/model', got '%s'", opts.Model)
|
||||||
|
}
|
||||||
|
if opts.GPUMemoryUtilization != 0.8 {
|
||||||
|
t.Errorf("expected gpu_memory_utilization 0.8, got %f", opts.GPUMemoryUtilization)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple value types",
|
||||||
|
command: "vllm serve test-model --tensor-parallel-size 4 --gpu-memory-utilization 0.8 --enable-log-outputs",
|
||||||
|
expectErr: false,
|
||||||
|
validate: func(t *testing.T, opts *backends.VllmServerOptions) {
|
||||||
|
if opts.Model != "test-model" {
|
||||||
|
t.Errorf("expected model 'test-model', got '%s'", opts.Model)
|
||||||
|
}
|
||||||
|
if opts.TensorParallelSize != 4 {
|
||||||
|
t.Errorf("expected tensor_parallel_size 4, got %d", opts.TensorParallelSize)
|
||||||
|
}
|
||||||
|
if opts.GPUMemoryUtilization != 0.8 {
|
||||||
|
t.Errorf("expected gpu_memory_utilization 0.8, got %f", opts.GPUMemoryUtilization)
|
||||||
|
}
|
||||||
|
if !opts.EnableLogOutputs {
|
||||||
|
t.Errorf("expected enable_log_outputs true, got %v", opts.EnableLogOutputs)
|
||||||
|
}
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "empty command",
|
name: "empty command",
|
||||||
@@ -61,34 +108,144 @@ func TestParseVllmCommand(t *testing.T) {
|
|||||||
|
|
||||||
if result == nil {
|
if result == nil {
|
||||||
t.Errorf("expected result but got nil")
|
t.Errorf("expected result but got nil")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.validate != nil {
|
||||||
|
tt.validate(t, result)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseVllmCommandValues(t *testing.T) {
|
func TestVllmBuildCommandArgs_BooleanFields(t *testing.T) {
|
||||||
command := "vllm serve test-model --tensor-parallel-size 4 --gpu-memory-utilization 0.8 --enable-log-outputs"
|
tests := []struct {
|
||||||
result, err := backends.ParseVllmCommand(command)
|
name string
|
||||||
|
options backends.VllmServerOptions
|
||||||
if err != nil {
|
expected []string
|
||||||
t.Fatalf("unexpected error: %v", err)
|
excluded []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "enable_log_outputs true",
|
||||||
|
options: backends.VllmServerOptions{
|
||||||
|
EnableLogOutputs: true,
|
||||||
|
},
|
||||||
|
expected: []string{"--enable-log-outputs"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enable_log_outputs false",
|
||||||
|
options: backends.VllmServerOptions{
|
||||||
|
EnableLogOutputs: false,
|
||||||
|
},
|
||||||
|
excluded: []string{"--enable-log-outputs"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple booleans",
|
||||||
|
options: backends.VllmServerOptions{
|
||||||
|
EnableLogOutputs: true,
|
||||||
|
TrustRemoteCode: true,
|
||||||
|
EnablePrefixCaching: true,
|
||||||
|
DisableLogStats: false,
|
||||||
|
},
|
||||||
|
expected: []string{"--enable-log-outputs", "--trust-remote-code", "--enable-prefix-caching"},
|
||||||
|
excluded: []string{"--disable-log-stats"},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if result.Model != "test-model" {
|
for _, tt := range tests {
|
||||||
t.Errorf("expected model 'test-model', got '%s'", result.Model)
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
}
|
args := tt.options.BuildCommandArgs()
|
||||||
if result.TensorParallelSize != 4 {
|
|
||||||
t.Errorf("expected tensor_parallel_size 4, got %d", result.TensorParallelSize)
|
for _, expectedArg := range tt.expected {
|
||||||
}
|
if !testutil.Contains(args, expectedArg) {
|
||||||
if result.GPUMemoryUtilization != 0.8 {
|
t.Errorf("Expected argument %q not found in %v", expectedArg, args)
|
||||||
t.Errorf("expected gpu_memory_utilization 0.8, got %f", result.GPUMemoryUtilization)
|
|
||||||
}
|
|
||||||
if !result.EnableLogOutputs {
|
|
||||||
t.Errorf("expected enable_log_outputs true, got %v", result.EnableLogOutputs)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestVllmBuildCommandArgs(t *testing.T) {
|
for _, excludedArg := range tt.excluded {
|
||||||
|
if testutil.Contains(args, excludedArg) {
|
||||||
|
t.Errorf("Excluded argument %q found in %v", excludedArg, args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVllmBuildCommandArgs_ZeroValues(t *testing.T) {
|
||||||
|
options := backends.VllmServerOptions{
|
||||||
|
Port: 0, // Should be excluded
|
||||||
|
TensorParallelSize: 0, // Should be excluded
|
||||||
|
GPUMemoryUtilization: 0, // Should be excluded
|
||||||
|
Model: "", // Should be excluded (positional arg)
|
||||||
|
Host: "", // Should be excluded
|
||||||
|
EnableLogOutputs: false, // Should be excluded
|
||||||
|
}
|
||||||
|
|
||||||
|
args := options.BuildCommandArgs()
|
||||||
|
|
||||||
|
// Zero values should not appear in arguments
|
||||||
|
excludedArgs := []string{
|
||||||
|
"--port", "0",
|
||||||
|
"--tensor-parallel-size", "0",
|
||||||
|
"--gpu-memory-utilization", "0",
|
||||||
|
"--host", "",
|
||||||
|
"--enable-log-outputs",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, excludedArg := range excludedArgs {
|
||||||
|
if testutil.Contains(args, excludedArg) {
|
||||||
|
t.Errorf("Zero value argument %q should not be present in %v", excludedArg, args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model should not be present as positional arg when empty
|
||||||
|
if len(args) > 0 && args[0] == "" {
|
||||||
|
t.Errorf("Empty model should not be present as positional argument")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVllmBuildCommandArgs_ArrayFields(t *testing.T) {
|
||||||
|
options := backends.VllmServerOptions{
|
||||||
|
AllowedOrigins: []string{"http://localhost:3000", "https://example.com"},
|
||||||
|
AllowedMethods: []string{"GET", "POST"},
|
||||||
|
Middleware: []string{"middleware1", "middleware2", "middleware3"},
|
||||||
|
}
|
||||||
|
|
||||||
|
args := options.BuildCommandArgs()
|
||||||
|
|
||||||
|
// Check that each array value appears with its flag
|
||||||
|
expectedOccurrences := map[string][]string{
|
||||||
|
"--allowed-origins": {"http://localhost:3000", "https://example.com"},
|
||||||
|
"--allowed-methods": {"GET", "POST"},
|
||||||
|
"--middleware": {"middleware1", "middleware2", "middleware3"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for flag, values := range expectedOccurrences {
|
||||||
|
for _, value := range values {
|
||||||
|
if !testutil.ContainsFlagWithValue(args, flag, value) {
|
||||||
|
t.Errorf("Expected %s %s, not found in %v", flag, value, args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVllmBuildCommandArgs_EmptyArrays(t *testing.T) {
|
||||||
|
options := backends.VllmServerOptions{
|
||||||
|
AllowedOrigins: []string{}, // Empty array should not generate args
|
||||||
|
Middleware: []string{}, // Empty array should not generate args
|
||||||
|
}
|
||||||
|
|
||||||
|
args := options.BuildCommandArgs()
|
||||||
|
|
||||||
|
excludedArgs := []string{"--allowed-origins", "--middleware"}
|
||||||
|
for _, excludedArg := range excludedArgs {
|
||||||
|
if testutil.Contains(args, excludedArg) {
|
||||||
|
t.Errorf("Empty array should not generate argument %q in %v", excludedArg, args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVllmBuildCommandArgs_PositionalModel(t *testing.T) {
|
||||||
options := backends.VllmServerOptions{
|
options := backends.VllmServerOptions{
|
||||||
Model: "microsoft/DialoGPT-medium",
|
Model: "microsoft/DialoGPT-medium",
|
||||||
Port: 8080,
|
Port: 8080,
|
||||||
@@ -96,7 +253,6 @@ func TestVllmBuildCommandArgs(t *testing.T) {
|
|||||||
TensorParallelSize: 2,
|
TensorParallelSize: 2,
|
||||||
GPUMemoryUtilization: 0.8,
|
GPUMemoryUtilization: 0.8,
|
||||||
EnableLogOutputs: true,
|
EnableLogOutputs: true,
|
||||||
AllowedOrigins: []string{"http://localhost:3000", "https://example.com"},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
args := options.BuildCommandArgs()
|
args := options.BuildCommandArgs()
|
||||||
@@ -107,32 +263,24 @@ func TestVllmBuildCommandArgs(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check that --model flag is NOT present (since model should be positional)
|
// Check that --model flag is NOT present (since model should be positional)
|
||||||
if contains(args, "--model") {
|
if testutil.Contains(args, "--model") {
|
||||||
t.Errorf("Found --model flag, but model should be positional argument in args: %v", args)
|
t.Errorf("Found --model flag, but model should be positional argument in args: %v", args)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check other flags
|
// Check other flags
|
||||||
if !containsFlagWithValue(args, "--tensor-parallel-size", "2") {
|
if !testutil.ContainsFlagWithValue(args, "--tensor-parallel-size", "2") {
|
||||||
t.Errorf("Expected --tensor-parallel-size 2 not found in %v", args)
|
t.Errorf("Expected --tensor-parallel-size 2 not found in %v", args)
|
||||||
}
|
}
|
||||||
if !contains(args, "--enable-log-outputs") {
|
if !testutil.ContainsFlagWithValue(args, "--gpu-memory-utilization", "0.8") {
|
||||||
|
t.Errorf("Expected --gpu-memory-utilization 0.8 not found in %v", args)
|
||||||
|
}
|
||||||
|
if !testutil.Contains(args, "--enable-log-outputs") {
|
||||||
t.Errorf("Expected --enable-log-outputs not found in %v", args)
|
t.Errorf("Expected --enable-log-outputs not found in %v", args)
|
||||||
}
|
}
|
||||||
if !contains(args, "--host") {
|
if !testutil.ContainsFlagWithValue(args, "--host", "localhost") {
|
||||||
t.Errorf("Expected --host not found in %v", args)
|
t.Errorf("Expected --host localhost not found in %v", args)
|
||||||
}
|
}
|
||||||
if !contains(args, "--port") {
|
if !testutil.ContainsFlagWithValue(args, "--port", "8080") {
|
||||||
t.Errorf("Expected --port not found in %v", args)
|
t.Errorf("Expected --port 8080 not found in %v", args)
|
||||||
}
|
|
||||||
|
|
||||||
// Check array handling (multiple flags)
|
|
||||||
allowedOriginsCount := 0
|
|
||||||
for i := range args {
|
|
||||||
if args[i] == "--allowed-origins" {
|
|
||||||
allowedOriginsCount++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if allowedOriginsCount != 2 {
|
|
||||||
t.Errorf("Expected 2 --allowed-origins flags, got %d", allowedOriginsCount)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
package testutil
|
package testutil
|
||||||
|
|
||||||
|
import "slices"
|
||||||
|
|
||||||
// Helper functions for pointer fields
|
// Helper functions for pointer fields
|
||||||
func BoolPtr(b bool) *bool {
|
func BoolPtr(b bool) *bool {
|
||||||
return &b
|
return &b
|
||||||
@@ -8,3 +10,23 @@ func BoolPtr(b bool) *bool {
|
|||||||
func IntPtr(i int) *int {
|
func IntPtr(i int) *int {
|
||||||
return &i
|
return &i
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper functions for testing command arguments
|
||||||
|
|
||||||
|
// Contains checks if a slice contains a specific item
|
||||||
|
func Contains(slice []string, item string) bool {
|
||||||
|
return slices.Contains(slice, item)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContainsFlagWithValue checks if args contains a flag followed by a specific value
|
||||||
|
func ContainsFlagWithValue(args []string, flag, value string) bool {
|
||||||
|
for i, arg := range args {
|
||||||
|
if arg == flag {
|
||||||
|
// Check if there's a next argument and it matches the expected value
|
||||||
|
if i+1 < len(args) && args[i+1] == value {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user