mirror of
https://github.com/lordmathis/llamactl.git
synced 2025-11-06 09:04:27 +00:00
Refactor command parsing and building
This commit is contained in:
@@ -1,10 +1,7 @@
|
||||
package mlx
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"llamactl/pkg/backends"
|
||||
"reflect"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type MlxServerOptions struct {
|
||||
@@ -26,88 +23,34 @@ type MlxServerOptions struct {
|
||||
ChatTemplateArgs string `json:"chat_template_args,omitempty"` // JSON string
|
||||
|
||||
// Sampling defaults
|
||||
Temp float64 `json:"temp,omitempty"` // Note: MLX uses "temp" not "temperature"
|
||||
Temp float64 `json:"temp,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
MinP float64 `json:"min_p,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshaling to support multiple field names
|
||||
func (o *MlxServerOptions) UnmarshalJSON(data []byte) error {
|
||||
// First unmarshal into a map to handle multiple field names
|
||||
var raw map[string]any
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create a temporary struct for standard unmarshaling
|
||||
type tempOptions MlxServerOptions
|
||||
temp := tempOptions{}
|
||||
|
||||
// Standard unmarshal first
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Copy to our struct
|
||||
*o = MlxServerOptions(temp)
|
||||
|
||||
// Handle alternative field names
|
||||
fieldMappings := map[string]string{
|
||||
"m": "model", // -m, --model
|
||||
"temperature": "temp", // --temperature vs --temp
|
||||
"top_k": "top_k", // --top-k
|
||||
"adapter_path": "adapter_path", // --adapter-path
|
||||
}
|
||||
|
||||
// Process alternative field names
|
||||
for altName, canonicalName := range fieldMappings {
|
||||
if value, exists := raw[altName]; exists {
|
||||
// Use reflection to set the field value
|
||||
v := reflect.ValueOf(o).Elem()
|
||||
field := v.FieldByNameFunc(func(fieldName string) bool {
|
||||
field, _ := v.Type().FieldByName(fieldName)
|
||||
jsonTag := field.Tag.Get("json")
|
||||
return jsonTag == canonicalName+",omitempty" || jsonTag == canonicalName
|
||||
})
|
||||
|
||||
if field.IsValid() && field.CanSet() {
|
||||
switch field.Kind() {
|
||||
case reflect.Int:
|
||||
if intVal, ok := value.(float64); ok {
|
||||
field.SetInt(int64(intVal))
|
||||
} else if strVal, ok := value.(string); ok {
|
||||
if intVal, err := strconv.Atoi(strVal); err == nil {
|
||||
field.SetInt(int64(intVal))
|
||||
}
|
||||
}
|
||||
case reflect.Float64:
|
||||
if floatVal, ok := value.(float64); ok {
|
||||
field.SetFloat(floatVal)
|
||||
} else if strVal, ok := value.(string); ok {
|
||||
if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil {
|
||||
field.SetFloat(floatVal)
|
||||
}
|
||||
}
|
||||
case reflect.String:
|
||||
if strVal, ok := value.(string); ok {
|
||||
field.SetString(strVal)
|
||||
}
|
||||
case reflect.Bool:
|
||||
if boolVal, ok := value.(bool); ok {
|
||||
field.SetBool(boolVal)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildCommandArgs converts to command line arguments
|
||||
func (o *MlxServerOptions) BuildCommandArgs() []string {
|
||||
multipleFlags := map[string]bool{} // MLX doesn't currently have []string fields
|
||||
return backends.BuildCommandArgs(o, multipleFlags)
|
||||
}
|
||||
|
||||
// ParseMlxCommand parses a mlx_lm.server command string into MlxServerOptions
|
||||
// Supports multiple formats:
|
||||
// 1. Full command: "mlx_lm.server --model model/path"
|
||||
// 2. Full path: "/usr/local/bin/mlx_lm.server --model model/path"
|
||||
// 3. Args only: "--model model/path --host 0.0.0.0"
|
||||
// 4. Multiline commands with backslashes
|
||||
func ParseMlxCommand(command string) (*MlxServerOptions, error) {
|
||||
executableNames := []string{"mlx_lm.server"}
|
||||
var subcommandNames []string // MLX has no subcommands
|
||||
multiValuedFlags := map[string]bool{} // MLX has no multi-valued flags
|
||||
|
||||
var mlxOptions MlxServerOptions
|
||||
if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &mlxOptions); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &mlxOptions, nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,101 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseMlxCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "basic command",
|
||||
command: "mlx_lm.server --model /path/to/model --host 0.0.0.0",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "args only",
|
||||
command: "--model /path/to/model --port 8080",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "mixed flag formats",
|
||||
command: "mlx_lm.server --model=/path/model --temp=0.7 --trust-remote-code",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "quoted strings",
|
||||
command: `mlx_lm.server --model test.mlx --chat-template "User: {user}\nAssistant: "`,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty command",
|
||||
command: "",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "unterminated quote",
|
||||
command: `mlx_lm.server --model test.mlx --chat-template "unterminated`,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "malformed flag",
|
||||
command: "mlx_lm.server ---model test.mlx",
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := mlx.ParseMlxCommand(tt.command)
|
||||
|
||||
if tt.expectErr {
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Errorf("expected result but got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMlxCommandValues(t *testing.T) {
|
||||
command := "mlx_lm.server --model /test/model.mlx --port 8080 --temp 0.7 --trust-remote-code --log-level DEBUG"
|
||||
result, err := mlx.ParseMlxCommand(command)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Model != "/test/model.mlx" {
|
||||
t.Errorf("expected model '/test/model.mlx', got '%s'", result.Model)
|
||||
}
|
||||
|
||||
if result.Port != 8080 {
|
||||
t.Errorf("expected port 8080, got %d", result.Port)
|
||||
}
|
||||
|
||||
if result.Temp != 0.7 {
|
||||
t.Errorf("expected temp 0.7, got %f", result.Temp)
|
||||
}
|
||||
|
||||
if !result.TrustRemoteCode {
|
||||
t.Errorf("expected trust_remote_code to be true")
|
||||
}
|
||||
|
||||
if result.LogLevel != "DEBUG" {
|
||||
t.Errorf("expected log_level 'DEBUG', got '%s'", result.LogLevel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCommandArgs(t *testing.T) {
|
||||
options := &mlx.MlxServerOptions{
|
||||
Model: "/test/model.mlx",
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
package mlx
|
||||
|
||||
import (
|
||||
"llamactl/pkg/backends"
|
||||
)
|
||||
|
||||
// ParseMlxCommand parses a mlx_lm.server command string into MlxServerOptions
|
||||
// Supports multiple formats:
|
||||
// 1. Full command: "mlx_lm.server --model model/path"
|
||||
// 2. Full path: "/usr/local/bin/mlx_lm.server --model model/path"
|
||||
// 3. Args only: "--model model/path --host 0.0.0.0"
|
||||
// 4. Multiline commands with backslashes
|
||||
func ParseMlxCommand(command string) (*MlxServerOptions, error) {
|
||||
executableNames := []string{"mlx_lm.server"}
|
||||
var subcommandNames []string // MLX has no subcommands
|
||||
multiValuedFlags := map[string]bool{} // MLX has no multi-valued flags
|
||||
|
||||
var mlxOptions MlxServerOptions
|
||||
if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &mlxOptions); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &mlxOptions, nil
|
||||
}
|
||||
@@ -1,101 +0,0 @@
|
||||
package mlx_test
|
||||
|
||||
import (
|
||||
"llamactl/pkg/backends/mlx"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseMlxCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "basic command",
|
||||
command: "mlx_lm.server --model /path/to/model --host 0.0.0.0",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "args only",
|
||||
command: "--model /path/to/model --port 8080",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "mixed flag formats",
|
||||
command: "mlx_lm.server --model=/path/model --temp=0.7 --trust-remote-code",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "quoted strings",
|
||||
command: `mlx_lm.server --model test.mlx --chat-template "User: {user}\nAssistant: "`,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty command",
|
||||
command: "",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "unterminated quote",
|
||||
command: `mlx_lm.server --model test.mlx --chat-template "unterminated`,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "malformed flag",
|
||||
command: "mlx_lm.server ---model test.mlx",
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := mlx.ParseMlxCommand(tt.command)
|
||||
|
||||
if tt.expectErr {
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Errorf("expected result but got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMlxCommandValues(t *testing.T) {
|
||||
command := "mlx_lm.server --model /test/model.mlx --port 8080 --temp 0.7 --trust-remote-code --log-level DEBUG"
|
||||
result, err := mlx.ParseMlxCommand(command)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Model != "/test/model.mlx" {
|
||||
t.Errorf("expected model '/test/model.mlx', got '%s'", result.Model)
|
||||
}
|
||||
|
||||
if result.Port != 8080 {
|
||||
t.Errorf("expected port 8080, got %d", result.Port)
|
||||
}
|
||||
|
||||
if result.Temp != 0.7 {
|
||||
t.Errorf("expected temp 0.7, got %f", result.Temp)
|
||||
}
|
||||
|
||||
if !result.TrustRemoteCode {
|
||||
t.Errorf("expected trust_remote_code to be true")
|
||||
}
|
||||
|
||||
if result.LogLevel != "DEBUG" {
|
||||
t.Errorf("expected log_level 'DEBUG', got '%s'", result.LogLevel)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user