mirror of https://github.com/lordmathis/llamactl.git
synced 2025-11-05 16:44:22 +00:00

Refactor command parsing and building

pkg/backends/builder.go (new file, 70 lines)
@@ -0,0 +1,70 @@
package backends

import (
	"reflect"
	"strconv"
	"strings"
)

// BuildCommandArgs converts a struct to command line arguments
func BuildCommandArgs(options any, multipleFlags map[string]bool) []string {
	var args []string

	v := reflect.ValueOf(options).Elem()
	t := v.Type()

	for i := 0; i < v.NumField(); i++ {
		field := v.Field(i)
		fieldType := t.Field(i)

		if !field.CanInterface() {
			continue
		}

		jsonTag := fieldType.Tag.Get("json")
		if jsonTag == "" || jsonTag == "-" {
			continue
		}

		// Get flag name from JSON tag
		flagName := strings.Split(jsonTag, ",")[0]
		flagName = strings.ReplaceAll(flagName, "_", "-")

		switch field.Kind() {
		case reflect.Bool:
			if field.Bool() {
				args = append(args, "--"+flagName)
			}
		case reflect.Int:
			if field.Int() != 0 {
				args = append(args, "--"+flagName, strconv.FormatInt(field.Int(), 10))
			}
		case reflect.Float64:
			if field.Float() != 0 {
				args = append(args, "--"+flagName, strconv.FormatFloat(field.Float(), 'f', -1, 64))
			}
		case reflect.String:
			if field.String() != "" {
				args = append(args, "--"+flagName, field.String())
			}
		case reflect.Slice:
			if field.Type().Elem().Kind() == reflect.String && field.Len() > 0 {
				if multipleFlags[flagName] {
					// Multiple flags: --flag value1 --flag value2
					for j := 0; j < field.Len(); j++ {
						args = append(args, "--"+flagName, field.Index(j).String())
					}
				} else {
					// Comma-separated: --flag value1,value2
					var values []string
					for j := 0; j < field.Len(); j++ {
						values = append(values, field.Index(j).String())
					}
					args = append(args, "--"+flagName, strings.Join(values, ","))
				}
			}
		}
	}

	return args
}
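For illustration, a minimal sketch of how this helper maps json-tagged struct fields onto flags. The exampleOptions type and the values below are hypothetical, not part of the commit; only backends.BuildCommandArgs is from the diff above.

package main

import (
	"fmt"

	"llamactl/pkg/backends"
)

// exampleOptions is a hypothetical options struct; fields and tags are
// chosen only to exercise each branch of BuildCommandArgs.
type exampleOptions struct {
	Model     string   `json:"model,omitempty"`
	GPULayers int      `json:"gpu_layers,omitempty"`
	Verbose   bool     `json:"verbose,omitempty"`
	Lora      []string `json:"lora,omitempty"`
}

func main() {
	opts := &exampleOptions{
		Model:     "m.gguf",
		GPULayers: 32,
		Verbose:   true,
		Lora:      []string{"a.bin", "b.bin"},
	}
	// "lora" is marked repeatable; "gpu_layers" is rewritten to "--gpu-layers".
	args := backends.BuildCommandArgs(opts, map[string]bool{"lora": true})
	fmt.Println(args)
	// [--model m.gguf --gpu-layers 32 --verbose --lora a.bin --lora b.bin]
}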
@@ -328,3 +328,31 @@ func (o *LlamaServerOptions) BuildCommandArgs() []string {
	}
	return backends.BuildCommandArgs(o, multipleFlags)
}

// ParseLlamaCommand parses a llama-server command string into LlamaServerOptions
// Supports multiple formats:
// 1. Full command: "llama-server --model file.gguf"
// 2. Full path: "/usr/local/bin/llama-server --model file.gguf"
// 3. Args only: "--model file.gguf --gpu-layers 32"
// 4. Multiline commands with backslashes
func ParseLlamaCommand(command string) (*LlamaServerOptions, error) {
	executableNames := []string{"llama-server"}
	var subcommandNames []string // Llama has no subcommands
	multiValuedFlags := map[string]bool{
		"override_tensor":       true,
		"override_kv":           true,
		"lora":                  true,
		"lora_scaled":           true,
		"control_vector":        true,
		"control_vector_scaled": true,
		"dry_sequence_breaker":  true,
		"logit_bias":            true,
	}

	var llamaOptions LlamaServerOptions
	if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &llamaOptions); err != nil {
		return nil, err
	}

	return &llamaOptions, nil
}
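A brief usage sketch of the parser above, exercising the multiline (format 4) and repeated-flag inputs. The call site, model path, and log handling are illustrative assumptions; the field names match the tests later in this commit.

package main

import (
	"fmt"
	"log"

	"llamactl/pkg/backends/llamacpp"
)

func main() {
	// Backslash continuations are normalized away before parsing.
	opts, err := llamacpp.ParseLlamaCommand(`llama-server --model /models/llama.gguf \
  --gpu-layers 32 --lora adapter1.bin --lora adapter2.bin`)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(opts.Model, opts.GPULayers, opts.Lora)
	// /models/llama.gguf 32 [adapter1.bin adapter2.bin]
}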
@@ -378,6 +378,121 @@ func TestUnmarshalJSON_ArrayFields(t *testing.T) {
	}
}

func TestParseLlamaCommand(t *testing.T) {
	tests := []struct {
		name      string
		command   string
		expectErr bool
	}{
		{
			name:      "basic command",
			command:   "llama-server --model /path/to/model.gguf --gpu-layers 32",
			expectErr: false,
		},
		{
			name:      "args only",
			command:   "--model /path/to/model.gguf --ctx-size 4096",
			expectErr: false,
		},
		{
			name:      "mixed flag formats",
			command:   "llama-server --model=/path/model.gguf --gpu-layers 16 --verbose",
			expectErr: false,
		},
		{
			name:      "quoted strings",
			command:   `llama-server --model test.gguf --api-key "sk-1234567890abcdef"`,
			expectErr: false,
		},
		{
			name:      "empty command",
			command:   "",
			expectErr: true,
		},
		{
			name:      "unterminated quote",
			command:   `llama-server --model test.gguf --api-key "unterminated`,
			expectErr: true,
		},
		{
			name:      "malformed flag",
			command:   "llama-server ---model test.gguf",
			expectErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := llamacpp.ParseLlamaCommand(tt.command)

			if tt.expectErr {
				if err == nil {
					t.Errorf("expected error but got none")
				}
				return
			}

			if err != nil {
				t.Errorf("unexpected error: %v", err)
				return
			}

			if result == nil {
				t.Errorf("expected result but got nil")
			}
		})
	}
}

func TestParseLlamaCommandValues(t *testing.T) {
	command := "llama-server --model /test/model.gguf --gpu-layers 32 --temp 0.7 --verbose --no-mmap"
	result, err := llamacpp.ParseLlamaCommand(command)

	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if result.Model != "/test/model.gguf" {
		t.Errorf("expected model '/test/model.gguf', got '%s'", result.Model)
	}

	if result.GPULayers != 32 {
		t.Errorf("expected gpu_layers 32, got %d", result.GPULayers)
	}

	if result.Temperature != 0.7 {
		t.Errorf("expected temperature 0.7, got %f", result.Temperature)
	}

	if !result.Verbose {
		t.Errorf("expected verbose to be true")
	}

	if !result.NoMmap {
		t.Errorf("expected no_mmap to be true")
	}
}

func TestParseLlamaCommandArrays(t *testing.T) {
	command := "llama-server --model test.gguf --lora adapter1.bin --lora=adapter2.bin"
	result, err := llamacpp.ParseLlamaCommand(command)

	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(result.Lora) != 2 {
		t.Errorf("expected 2 lora adapters, got %d", len(result.Lora))
	}

	expected := []string{"adapter1.bin", "adapter2.bin"}
	for i, v := range expected {
		if result.Lora[i] != v {
			t.Errorf("expected lora[%d]=%s got %s", i, v, result.Lora[i])
		}
	}
}

// Helper functions
func contains(slice []string, item string) bool {
	return slices.Contains(slice, item)
}
@@ -1,34 +0,0 @@
package llamacpp

import (
	"llamactl/pkg/backends"
)

// ParseLlamaCommand parses a llama-server command string into LlamaServerOptions
// Supports multiple formats:
// 1. Full command: "llama-server --model file.gguf"
// 2. Full path: "/usr/local/bin/llama-server --model file.gguf"
// 3. Args only: "--model file.gguf --gpu-layers 32"
// 4. Multiline commands with backslashes
func ParseLlamaCommand(command string) (*LlamaServerOptions, error) {
	executableNames := []string{"llama-server"}
	var subcommandNames []string // Llama has no subcommands
	multiValuedFlags := map[string]bool{
		"override_tensor":       true,
		"override_kv":           true,
		"lora":                  true,
		"lora_scaled":           true,
		"control_vector":        true,
		"control_vector_scaled": true,
		"dry_sequence_breaker":  true,
		"logit_bias":            true,
	}

	var llamaOptions LlamaServerOptions
	if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &llamaOptions); err != nil {
		return nil, err
	}

	return &llamaOptions, nil
}
@@ -1,121 +0,0 @@
package llamacpp_test

import (
	"llamactl/pkg/backends/llamacpp"
	"testing"
)

func TestParseLlamaCommand(t *testing.T) {
	tests := []struct {
		name      string
		command   string
		expectErr bool
	}{
		{
			name:      "basic command",
			command:   "llama-server --model /path/to/model.gguf --gpu-layers 32",
			expectErr: false,
		},
		{
			name:      "args only",
			command:   "--model /path/to/model.gguf --ctx-size 4096",
			expectErr: false,
		},
		{
			name:      "mixed flag formats",
			command:   "llama-server --model=/path/model.gguf --gpu-layers 16 --verbose",
			expectErr: false,
		},
		{
			name:      "quoted strings",
			command:   `llama-server --model test.gguf --api-key "sk-1234567890abcdef"`,
			expectErr: false,
		},
		{
			name:      "empty command",
			command:   "",
			expectErr: true,
		},
		{
			name:      "unterminated quote",
			command:   `llama-server --model test.gguf --api-key "unterminated`,
			expectErr: true,
		},
		{
			name:      "malformed flag",
			command:   "llama-server ---model test.gguf",
			expectErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := llamacpp.ParseLlamaCommand(tt.command)

			if tt.expectErr {
				if err == nil {
					t.Errorf("expected error but got none")
				}
				return
			}

			if err != nil {
				t.Errorf("unexpected error: %v", err)
				return
			}

			if result == nil {
				t.Errorf("expected result but got nil")
			}
		})
	}
}

func TestParseLlamaCommandValues(t *testing.T) {
	command := "llama-server --model /test/model.gguf --gpu-layers 32 --temp 0.7 --verbose --no-mmap"
	result, err := llamacpp.ParseLlamaCommand(command)

	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if result.Model != "/test/model.gguf" {
		t.Errorf("expected model '/test/model.gguf', got '%s'", result.Model)
	}

	if result.GPULayers != 32 {
		t.Errorf("expected gpu_layers 32, got %d", result.GPULayers)
	}

	if result.Temperature != 0.7 {
		t.Errorf("expected temperature 0.7, got %f", result.Temperature)
	}

	if !result.Verbose {
		t.Errorf("expected verbose to be true")
	}

	if !result.NoMmap {
		t.Errorf("expected no_mmap to be true")
	}
}

func TestParseLlamaCommandArrays(t *testing.T) {
	command := "llama-server --model test.gguf --lora adapter1.bin --lora=adapter2.bin"
	result, err := llamacpp.ParseLlamaCommand(command)

	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(result.Lora) != 2 {
		t.Errorf("expected 2 lora adapters, got %d", len(result.Lora))
	}

	expected := []string{"adapter1.bin", "adapter2.bin"}
	for i, v := range expected {
		if result.Lora[i] != v {
			t.Errorf("expected lora[%d]=%s got %s", i, v, result.Lora[i])
		}
	}
}
@@ -1,10 +1,7 @@
package mlx

import (
	"encoding/json"
	"llamactl/pkg/backends"
	"reflect"
	"strconv"
)

type MlxServerOptions struct {
@@ -26,88 +23,34 @@ type MlxServerOptions struct {
	ChatTemplateArgs string `json:"chat_template_args,omitempty"` // JSON string

	// Sampling defaults
	Temp      float64 `json:"temp,omitempty"` // Note: MLX uses "temp" not "temperature"
	Temp      float64 `json:"temp,omitempty"`
	TopP      float64 `json:"top_p,omitempty"`
	TopK      int     `json:"top_k,omitempty"`
	MinP      float64 `json:"min_p,omitempty"`
	MaxTokens int     `json:"max_tokens,omitempty"`
}

// UnmarshalJSON implements custom JSON unmarshaling to support multiple field names
func (o *MlxServerOptions) UnmarshalJSON(data []byte) error {
	// First unmarshal into a map to handle multiple field names
	var raw map[string]any
	if err := json.Unmarshal(data, &raw); err != nil {
		return err
	}

	// Create a temporary struct for standard unmarshaling
	type tempOptions MlxServerOptions
	temp := tempOptions{}

	// Standard unmarshal first
	if err := json.Unmarshal(data, &temp); err != nil {
		return err
	}

	// Copy to our struct
	*o = MlxServerOptions(temp)

	// Handle alternative field names
	fieldMappings := map[string]string{
		"m":            "model",        // -m, --model
		"temperature":  "temp",         // --temperature vs --temp
		"top_k":        "top_k",        // --top-k
		"adapter_path": "adapter_path", // --adapter-path
	}

	// Process alternative field names
	for altName, canonicalName := range fieldMappings {
		if value, exists := raw[altName]; exists {
			// Use reflection to set the field value
			v := reflect.ValueOf(o).Elem()
			field := v.FieldByNameFunc(func(fieldName string) bool {
				field, _ := v.Type().FieldByName(fieldName)
				jsonTag := field.Tag.Get("json")
				return jsonTag == canonicalName+",omitempty" || jsonTag == canonicalName
			})

			if field.IsValid() && field.CanSet() {
				switch field.Kind() {
				case reflect.Int:
					if intVal, ok := value.(float64); ok {
						field.SetInt(int64(intVal))
					} else if strVal, ok := value.(string); ok {
						if intVal, err := strconv.Atoi(strVal); err == nil {
							field.SetInt(int64(intVal))
						}
					}
				case reflect.Float64:
					if floatVal, ok := value.(float64); ok {
						field.SetFloat(floatVal)
					} else if strVal, ok := value.(string); ok {
						if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil {
							field.SetFloat(floatVal)
						}
					}
				case reflect.String:
					if strVal, ok := value.(string); ok {
						field.SetString(strVal)
					}
				case reflect.Bool:
					if boolVal, ok := value.(bool); ok {
						field.SetBool(boolVal)
					}
				}
			}
		}
	}

	return nil
}

// BuildCommandArgs converts to command line arguments
func (o *MlxServerOptions) BuildCommandArgs() []string {
	multipleFlags := map[string]bool{} // MLX doesn't currently have []string fields
	return backends.BuildCommandArgs(o, multipleFlags)
}

// ParseMlxCommand parses a mlx_lm.server command string into MlxServerOptions
// Supports multiple formats:
// 1. Full command: "mlx_lm.server --model model/path"
// 2. Full path: "/usr/local/bin/mlx_lm.server --model model/path"
// 3. Args only: "--model model/path --host 0.0.0.0"
// 4. Multiline commands with backslashes
func ParseMlxCommand(command string) (*MlxServerOptions, error) {
	executableNames := []string{"mlx_lm.server"}
	var subcommandNames []string          // MLX has no subcommands
	multiValuedFlags := map[string]bool{} // MLX has no multi-valued flags

	var mlxOptions MlxServerOptions
	if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &mlxOptions); err != nil {
		return nil, err
	}

	return &mlxOptions, nil
}
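A hedged round-trip sketch for the MLX backend: parse a command into MlxServerOptions, then rebuild args through the shared builder. The call site and model path are hypothetical; ParseMlxCommand, BuildCommandArgs, and the field names come from the hunks above and the tests below.

package main

import (
	"fmt"
	"log"

	"llamactl/pkg/backends/mlx"
)

func main() {
	opts, err := mlx.ParseMlxCommand("mlx_lm.server --model /models/test.mlx --temp 0.7 --port 8080")
	if err != nil {
		log.Fatal(err)
	}
	// Rebuild args from the parsed options; output order follows struct field order,
	// so it contains "--model", "/models/test.mlx", "--port", "8080", "--temp", "0.7".
	fmt.Println(opts.BuildCommandArgs())
}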
@@ -5,6 +5,101 @@ import (
	"testing"
)

func TestParseMlxCommand(t *testing.T) {
	tests := []struct {
		name      string
		command   string
		expectErr bool
	}{
		{
			name:      "basic command",
			command:   "mlx_lm.server --model /path/to/model --host 0.0.0.0",
			expectErr: false,
		},
		{
			name:      "args only",
			command:   "--model /path/to/model --port 8080",
			expectErr: false,
		},
		{
			name:      "mixed flag formats",
			command:   "mlx_lm.server --model=/path/model --temp=0.7 --trust-remote-code",
			expectErr: false,
		},
		{
			name:      "quoted strings",
			command:   `mlx_lm.server --model test.mlx --chat-template "User: {user}\nAssistant: "`,
			expectErr: false,
		},
		{
			name:      "empty command",
			command:   "",
			expectErr: true,
		},
		{
			name:      "unterminated quote",
			command:   `mlx_lm.server --model test.mlx --chat-template "unterminated`,
			expectErr: true,
		},
		{
			name:      "malformed flag",
			command:   "mlx_lm.server ---model test.mlx",
			expectErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := mlx.ParseMlxCommand(tt.command)

			if tt.expectErr {
				if err == nil {
					t.Errorf("expected error but got none")
				}
				return
			}

			if err != nil {
				t.Errorf("unexpected error: %v", err)
				return
			}

			if result == nil {
				t.Errorf("expected result but got nil")
			}
		})
	}
}

func TestParseMlxCommandValues(t *testing.T) {
	command := "mlx_lm.server --model /test/model.mlx --port 8080 --temp 0.7 --trust-remote-code --log-level DEBUG"
	result, err := mlx.ParseMlxCommand(command)

	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if result.Model != "/test/model.mlx" {
		t.Errorf("expected model '/test/model.mlx', got '%s'", result.Model)
	}

	if result.Port != 8080 {
		t.Errorf("expected port 8080, got %d", result.Port)
	}

	if result.Temp != 0.7 {
		t.Errorf("expected temp 0.7, got %f", result.Temp)
	}

	if !result.TrustRemoteCode {
		t.Errorf("expected trust_remote_code to be true")
	}

	if result.LogLevel != "DEBUG" {
		t.Errorf("expected log_level 'DEBUG', got '%s'", result.LogLevel)
	}
}

func TestBuildCommandArgs(t *testing.T) {
	options := &mlx.MlxServerOptions{
		Model: "/test/model.mlx",
@@ -1,24 +0,0 @@
package mlx

import (
	"llamactl/pkg/backends"
)

// ParseMlxCommand parses a mlx_lm.server command string into MlxServerOptions
// Supports multiple formats:
// 1. Full command: "mlx_lm.server --model model/path"
// 2. Full path: "/usr/local/bin/mlx_lm.server --model model/path"
// 3. Args only: "--model model/path --host 0.0.0.0"
// 4. Multiline commands with backslashes
func ParseMlxCommand(command string) (*MlxServerOptions, error) {
	executableNames := []string{"mlx_lm.server"}
	var subcommandNames []string          // MLX has no subcommands
	multiValuedFlags := map[string]bool{} // MLX has no multi-valued flags

	var mlxOptions MlxServerOptions
	if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &mlxOptions); err != nil {
		return nil, err
	}

	return &mlxOptions, nil
}
@@ -1,101 +0,0 @@
package mlx_test

import (
	"llamactl/pkg/backends/mlx"
	"testing"
)

func TestParseMlxCommand(t *testing.T) {
	tests := []struct {
		name      string
		command   string
		expectErr bool
	}{
		{
			name:      "basic command",
			command:   "mlx_lm.server --model /path/to/model --host 0.0.0.0",
			expectErr: false,
		},
		{
			name:      "args only",
			command:   "--model /path/to/model --port 8080",
			expectErr: false,
		},
		{
			name:      "mixed flag formats",
			command:   "mlx_lm.server --model=/path/model --temp=0.7 --trust-remote-code",
			expectErr: false,
		},
		{
			name:      "quoted strings",
			command:   `mlx_lm.server --model test.mlx --chat-template "User: {user}\nAssistant: "`,
			expectErr: false,
		},
		{
			name:      "empty command",
			command:   "",
			expectErr: true,
		},
		{
			name:      "unterminated quote",
			command:   `mlx_lm.server --model test.mlx --chat-template "unterminated`,
			expectErr: true,
		},
		{
			name:      "malformed flag",
			command:   "mlx_lm.server ---model test.mlx",
			expectErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := mlx.ParseMlxCommand(tt.command)

			if tt.expectErr {
				if err == nil {
					t.Errorf("expected error but got none")
				}
				return
			}

			if err != nil {
				t.Errorf("unexpected error: %v", err)
				return
			}

			if result == nil {
				t.Errorf("expected result but got nil")
			}
		})
	}
}

func TestParseMlxCommandValues(t *testing.T) {
	command := "mlx_lm.server --model /test/model.mlx --port 8080 --temp 0.7 --trust-remote-code --log-level DEBUG"
	result, err := mlx.ParseMlxCommand(command)

	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if result.Model != "/test/model.mlx" {
		t.Errorf("expected model '/test/model.mlx', got '%s'", result.Model)
	}

	if result.Port != 8080 {
		t.Errorf("expected port 8080, got %d", result.Port)
	}

	if result.Temp != 0.7 {
		t.Errorf("expected temp 0.7, got %f", result.Temp)
	}

	if !result.TrustRemoteCode {
		t.Errorf("expected trust_remote_code to be true")
	}

	if result.LogLevel != "DEBUG" {
		t.Errorf("expected log_level 'DEBUG', got '%s'", result.LogLevel)
	}
}
@@ -4,7 +4,6 @@ import (
	"encoding/json"
	"fmt"
	"path/filepath"
	"reflect"
	"regexp"
	"strconv"
	"strings"
@@ -43,69 +42,6 @@ func ParseCommand(command string, executableNames []string, subcommandNames []st
	return nil
}

// BuildCommandArgs converts a struct to command line arguments
func BuildCommandArgs(options any, multipleFlags map[string]bool) []string {
	var args []string

	v := reflect.ValueOf(options).Elem()
	t := v.Type()

	for i := 0; i < v.NumField(); i++ {
		field := v.Field(i)
		fieldType := t.Field(i)

		if !field.CanInterface() {
			continue
		}

		jsonTag := fieldType.Tag.Get("json")
		if jsonTag == "" || jsonTag == "-" {
			continue
		}

		// Get flag name from JSON tag
		flagName := strings.Split(jsonTag, ",")[0]
		flagName = strings.ReplaceAll(flagName, "_", "-")

		switch field.Kind() {
		case reflect.Bool:
			if field.Bool() {
				args = append(args, "--"+flagName)
			}
		case reflect.Int:
			if field.Int() != 0 {
				args = append(args, "--"+flagName, strconv.FormatInt(field.Int(), 10))
			}
		case reflect.Float64:
			if field.Float() != 0 {
				args = append(args, "--"+flagName, strconv.FormatFloat(field.Float(), 'f', -1, 64))
			}
		case reflect.String:
			if field.String() != "" {
				args = append(args, "--"+flagName, field.String())
			}
		case reflect.Slice:
			if field.Type().Elem().Kind() == reflect.String && field.Len() > 0 {
				if multipleFlags[flagName] {
					// Multiple flags: --flag value1 --flag value2
					for j := 0; j < field.Len(); j++ {
						args = append(args, "--"+flagName, field.Index(j).String())
					}
				} else {
					// Comma-separated: --flag value1,value2
					var values []string
					for j := 0; j < field.Len(); j++ {
						values = append(values, field.Index(j).String())
					}
					args = append(args, "--"+flagName, strings.Join(values, ","))
				}
			}
		}
	}

	return args
}

// normalizeCommand handles multiline commands with backslashes
func normalizeCommand(command string) string {
	re := regexp.MustCompile(`\\\s*\n\s*`)
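The hunk above is cut off mid-function. As a hedged sketch only, the regex suggests the function joins backslash-continued lines along these lines; the single-space replacement string and the final TrimSpace are assumptions, not confirmed by the diff.

package backends

import (
	"regexp"
	"strings"
)

// normalizeCommandSketch is a hypothetical reconstruction: replace each
// backslash-newline continuation with a space, then trim the result.
func normalizeCommandSketch(command string) string {
	re := regexp.MustCompile(`\\\s*\n\s*`)
	// "llama-server \\\n  --model m.gguf" -> "llama-server --model m.gguf"
	return strings.TrimSpace(re.ReplaceAllString(command, " "))
}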
@@ -1,34 +0,0 @@
package vllm

import (
	"llamactl/pkg/backends"
)

// ParseVllmCommand parses a vLLM serve command string into VllmServerOptions
// Supports multiple formats:
// 1. Full command: "vllm serve --model MODEL_NAME --other-args"
// 2. Full path: "/usr/local/bin/vllm serve --model MODEL_NAME"
// 3. Serve only: "serve --model MODEL_NAME --other-args"
// 4. Args only: "--model MODEL_NAME --other-args"
// 5. Multiline commands with backslashes
func ParseVllmCommand(command string) (*VllmServerOptions, error) {
	executableNames := []string{"vllm"}
	subcommandNames := []string{"serve"}
	multiValuedFlags := map[string]bool{
		"middleware":      true,
		"api_key":         true,
		"allowed_origins": true,
		"allowed_methods": true,
		"allowed_headers": true,
		"lora_modules":    true,
		"prompt_adapters": true,
	}

	var vllmOptions VllmServerOptions
	if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &vllmOptions); err != nil {
		return nil, err
	}

	return &vllmOptions, nil
}
@@ -1,84 +0,0 @@
package vllm_test

import (
	"llamactl/pkg/backends/vllm"
	"testing"
)

func TestParseVllmCommand(t *testing.T) {
	tests := []struct {
		name      string
		command   string
		expectErr bool
	}{
		{
			name:      "basic vllm serve command",
			command:   "vllm serve --model microsoft/DialoGPT-medium",
			expectErr: false,
		},
		{
			name:      "serve only command",
			command:   "serve --model microsoft/DialoGPT-medium",
			expectErr: false,
		},
		{
			name:      "args only",
			command:   "--model microsoft/DialoGPT-medium --tensor-parallel-size 2",
			expectErr: false,
		},
		{
			name:      "empty command",
			command:   "",
			expectErr: true,
		},
		{
			name:      "unterminated quote",
			command:   `vllm serve --model "unterminated`,
			expectErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := vllm.ParseVllmCommand(tt.command)

			if tt.expectErr {
				if err == nil {
					t.Errorf("expected error but got none")
				}
				return
			}

			if err != nil {
				t.Errorf("unexpected error: %v", err)
				return
			}

			if result == nil {
				t.Errorf("expected result but got nil")
			}
		})
	}
}

func TestParseVllmCommandValues(t *testing.T) {
	command := "vllm serve --model test-model --tensor-parallel-size 4 --gpu-memory-utilization 0.8 --enable-log-outputs"
	result, err := vllm.ParseVllmCommand(command)

	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if result.Model != "test-model" {
		t.Errorf("expected model 'test-model', got '%s'", result.Model)
	}
	if result.TensorParallelSize != 4 {
		t.Errorf("expected tensor_parallel_size 4, got %d", result.TensorParallelSize)
	}
	if result.GPUMemoryUtilization != 0.8 {
		t.Errorf("expected gpu_memory_utilization 0.8, got %f", result.GPUMemoryUtilization)
	}
	if !result.EnableLogOutputs {
		t.Errorf("expected enable_log_outputs true, got %v", result.EnableLogOutputs)
	}
}
@@ -142,3 +142,31 @@ func (o *VllmServerOptions) BuildCommandArgs() []string {
	}
	return backends.BuildCommandArgs(o, multipleFlags)
}

// ParseVllmCommand parses a vLLM serve command string into VllmServerOptions
// Supports multiple formats:
// 1. Full command: "vllm serve --model MODEL_NAME --other-args"
// 2. Full path: "/usr/local/bin/vllm serve --model MODEL_NAME"
// 3. Serve only: "serve --model MODEL_NAME --other-args"
// 4. Args only: "--model MODEL_NAME --other-args"
// 5. Multiline commands with backslashes
func ParseVllmCommand(command string) (*VllmServerOptions, error) {
	executableNames := []string{"vllm"}
	subcommandNames := []string{"serve"}
	multiValuedFlags := map[string]bool{
		"middleware":      true,
		"api_key":         true,
		"allowed_origins": true,
		"allowed_methods": true,
		"allowed_headers": true,
		"lora_modules":    true,
		"prompt_adapters": true,
	}

	var vllmOptions VllmServerOptions
	if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &vllmOptions); err != nil {
		return nil, err
	}

	return &vllmOptions, nil
}
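A hedged usage sketch (hypothetical call site), showing the parser accepting the command with and without the vllm executable prefix, per formats 1 and 3 in the doc comment; error returns are ignored here for brevity.

package main

import (
	"fmt"

	"llamactl/pkg/backends/vllm"
)

func main() {
	// Both forms should parse to the same options; "serve" is recognized as a subcommand.
	a, _ := vllm.ParseVllmCommand("vllm serve --model test-model --tensor-parallel-size 4")
	b, _ := vllm.ParseVllmCommand("serve --model test-model --tensor-parallel-size 4")
	fmt.Println(a.Model == b.Model, a.TensorParallelSize == b.TensorParallelSize) // true true
}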
@@ -7,6 +7,84 @@ import (
	"testing"
)

func TestParseVllmCommand(t *testing.T) {
	tests := []struct {
		name      string
		command   string
		expectErr bool
	}{
		{
			name:      "basic vllm serve command",
			command:   "vllm serve --model microsoft/DialoGPT-medium",
			expectErr: false,
		},
		{
			name:      "serve only command",
			command:   "serve --model microsoft/DialoGPT-medium",
			expectErr: false,
		},
		{
			name:      "args only",
			command:   "--model microsoft/DialoGPT-medium --tensor-parallel-size 2",
			expectErr: false,
		},
		{
			name:      "empty command",
			command:   "",
			expectErr: true,
		},
		{
			name:      "unterminated quote",
			command:   `vllm serve --model "unterminated`,
			expectErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := vllm.ParseVllmCommand(tt.command)

			if tt.expectErr {
				if err == nil {
					t.Errorf("expected error but got none")
				}
				return
			}

			if err != nil {
				t.Errorf("unexpected error: %v", err)
				return
			}

			if result == nil {
				t.Errorf("expected result but got nil")
			}
		})
	}
}

func TestParseVllmCommandValues(t *testing.T) {
	command := "vllm serve --model test-model --tensor-parallel-size 4 --gpu-memory-utilization 0.8 --enable-log-outputs"
	result, err := vllm.ParseVllmCommand(command)

	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if result.Model != "test-model" {
		t.Errorf("expected model 'test-model', got '%s'", result.Model)
	}
	if result.TensorParallelSize != 4 {
		t.Errorf("expected tensor_parallel_size 4, got %d", result.TensorParallelSize)
	}
	if result.GPUMemoryUtilization != 0.8 {
		t.Errorf("expected gpu_memory_utilization 0.8, got %f", result.GPUMemoryUtilization)
	}
	if !result.EnableLogOutputs {
		t.Errorf("expected enable_log_outputs true, got %v", result.EnableLogOutputs)
	}
}

func TestBuildCommandArgs(t *testing.T) {
	options := vllm.VllmServerOptions{
		Model: "microsoft/DialoGPT-medium",