// Mirror of https://github.com/lordmathis/llamactl.git
// (synced 2025-11-06 17:14:28 +00:00; 102 lines, 2.1 KiB, Go).

package mlx_test
|
|
|
|
import (
|
|
"llamactl/pkg/backends/mlx"
|
|
"testing"
|
|
)
|
|
|
|
func TestParseMlxCommand(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
command string
|
|
expectErr bool
|
|
}{
|
|
{
|
|
name: "basic command",
|
|
command: "mlx_lm.server --model /path/to/model --host 0.0.0.0",
|
|
expectErr: false,
|
|
},
|
|
{
|
|
name: "args only",
|
|
command: "--model /path/to/model --port 8080",
|
|
expectErr: false,
|
|
},
|
|
{
|
|
name: "mixed flag formats",
|
|
command: "mlx_lm.server --model=/path/model --temp=0.7 --trust-remote-code",
|
|
expectErr: false,
|
|
},
|
|
{
|
|
name: "quoted strings",
|
|
command: `mlx_lm.server --model test.mlx --chat-template "User: {user}\nAssistant: "`,
|
|
expectErr: false,
|
|
},
|
|
{
|
|
name: "empty command",
|
|
command: "",
|
|
expectErr: true,
|
|
},
|
|
{
|
|
name: "unterminated quote",
|
|
command: `mlx_lm.server --model test.mlx --chat-template "unterminated`,
|
|
expectErr: true,
|
|
},
|
|
{
|
|
name: "malformed flag",
|
|
command: "mlx_lm.server ---model test.mlx",
|
|
expectErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result, err := mlx.ParseMlxCommand(tt.command)
|
|
|
|
if tt.expectErr {
|
|
if err == nil {
|
|
t.Errorf("expected error but got none")
|
|
}
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
return
|
|
}
|
|
|
|
if result == nil {
|
|
t.Errorf("expected result but got nil")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestParseMlxCommandValues(t *testing.T) {
|
|
command := "mlx_lm.server --model /test/model.mlx --port 8080 --temp 0.7 --trust-remote-code --log-level DEBUG"
|
|
result, err := mlx.ParseMlxCommand(command)
|
|
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
if result.Model != "/test/model.mlx" {
|
|
t.Errorf("expected model '/test/model.mlx', got '%s'", result.Model)
|
|
}
|
|
|
|
if result.Port != 8080 {
|
|
t.Errorf("expected port 8080, got %d", result.Port)
|
|
}
|
|
|
|
if result.Temp != 0.7 {
|
|
t.Errorf("expected temp 0.7, got %f", result.Temp)
|
|
}
|
|
|
|
if !result.TrustRemoteCode {
|
|
t.Errorf("expected trust_remote_code to be true")
|
|
}
|
|
|
|
if result.LogLevel != "DEBUG" {
|
|
t.Errorf("expected log_level 'DEBUG', got '%s'", result.LogLevel)
|
|
}
|
|
}
|