mirror of
https://github.com/lordmathis/llamactl.git
synced 2025-11-05 16:44:22 +00:00
Update multi valued flags in backends
This commit is contained in:
@@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
// BuildCommandArgs converts a struct to command line arguments
|
||||
func BuildCommandArgs(options any, multipleFlags map[string]bool) []string {
|
||||
func BuildCommandArgs(options any, multipleFlags map[string]struct{}) []string {
|
||||
var args []string
|
||||
|
||||
v := reflect.ValueOf(options).Elem()
|
||||
@@ -28,9 +28,10 @@ func BuildCommandArgs(options any, multipleFlags map[string]bool) []string {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get flag name from JSON tag
|
||||
flagName := strings.Split(jsonTag, ",")[0]
|
||||
flagName = strings.ReplaceAll(flagName, "_", "-")
|
||||
// Get flag name from JSON tag (snake_case)
|
||||
jsonFieldName := strings.Split(jsonTag, ",")[0]
|
||||
// Convert to kebab-case for CLI flags
|
||||
flagName := strings.ReplaceAll(jsonFieldName, "_", "-")
|
||||
|
||||
switch field.Kind() {
|
||||
case reflect.Bool:
|
||||
@@ -51,7 +52,8 @@ func BuildCommandArgs(options any, multipleFlags map[string]bool) []string {
|
||||
}
|
||||
case reflect.Slice:
|
||||
if field.Type().Elem().Kind() == reflect.String && field.Len() > 0 {
|
||||
if multipleFlags[flagName] {
|
||||
// Use jsonFieldName (snake_case) for multipleFlags lookup
|
||||
if _, isMultiValue := multipleFlags[jsonFieldName]; isMultiValue {
|
||||
// Multiple flags: --flag value1 --flag value2
|
||||
for j := 0; j < field.Len(); j++ {
|
||||
args = append(args, "--"+flagName, field.Index(j).String())
|
||||
|
||||
@@ -9,25 +9,16 @@ import (
|
||||
)
|
||||
|
||||
// llamaMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated
|
||||
// Used for both parsing (with underscores) and building (with dashes)
|
||||
var llamaMultiValuedFlags = map[string]bool{
|
||||
// Parsing keys (with underscores)
|
||||
"override_tensor": true,
|
||||
"override_kv": true,
|
||||
"lora": true,
|
||||
"lora_scaled": true,
|
||||
"control_vector": true,
|
||||
"control_vector_scaled": true,
|
||||
"dry_sequence_breaker": true,
|
||||
"logit_bias": true,
|
||||
// Building keys (with dashes)
|
||||
"override-tensor": true,
|
||||
"override-kv": true,
|
||||
"lora-scaled": true,
|
||||
"control-vector": true,
|
||||
"control-vector-scaled": true,
|
||||
"dry-sequence-breaker": true,
|
||||
"logit-bias": true,
|
||||
// Keys use snake_case as the parser converts kebab-case flags to snake_case before lookup
|
||||
var llamaMultiValuedFlags = map[string]struct{}{
|
||||
"override_tensor": {},
|
||||
"override_kv": {},
|
||||
"lora": {},
|
||||
"lora_scaled": {},
|
||||
"control_vector": {},
|
||||
"control_vector_scaled": {},
|
||||
"dry_sequence_breaker": {},
|
||||
"logit_bias": {},
|
||||
}
|
||||
|
||||
type LlamaServerOptions struct {
|
||||
|
||||
@@ -62,7 +62,7 @@ func (o *MlxServerOptions) Validate() error {
|
||||
|
||||
// BuildCommandArgs converts to command line arguments
|
||||
func (o *MlxServerOptions) BuildCommandArgs() []string {
|
||||
multipleFlags := map[string]bool{} // MLX doesn't currently have []string fields
|
||||
multipleFlags := map[string]struct{}{} // MLX doesn't currently have []string fields
|
||||
return BuildCommandArgs(o, multipleFlags)
|
||||
}
|
||||
|
||||
@@ -79,7 +79,7 @@ func (o *MlxServerOptions) BuildDockerArgs() []string {
|
||||
func (o *MlxServerOptions) ParseCommand(command string) (any, error) {
|
||||
executableNames := []string{"mlx_lm.server"}
|
||||
var subcommandNames []string // MLX has no subcommands
|
||||
multiValuedFlags := map[string]bool{} // MLX has no multi-valued flags
|
||||
multiValuedFlags := map[string]struct{}{} // MLX has no multi-valued flags
|
||||
|
||||
var mlxOptions MlxServerOptions
|
||||
if err := parseCommand(command, executableNames, subcommandNames, multiValuedFlags, &mlxOptions); err != nil {
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
// parseCommand parses a command string into a target struct
|
||||
func parseCommand(command string, executableNames []string, subcommandNames []string, multiValuedFlags map[string]bool, target any) error {
|
||||
func parseCommand(command string, executableNames []string, subcommandNames []string, multiValuedFlags map[string]struct{}, target any) error {
|
||||
// Normalize multiline commands
|
||||
command = normalizeCommand(command)
|
||||
if command == "" {
|
||||
@@ -125,7 +125,7 @@ func extractArgs(command string, executableNames []string, subcommandNames []str
|
||||
}
|
||||
|
||||
// parseFlags parses command line flags into a map
|
||||
func parseFlags(args []string, multiValuedFlags map[string]bool) (map[string]any, error) {
|
||||
func parseFlags(args []string, multiValuedFlags map[string]struct{}) (map[string]any, error) {
|
||||
options := make(map[string]any)
|
||||
|
||||
for i := 0; i < len(args); i++ {
|
||||
@@ -163,7 +163,7 @@ func parseFlags(args []string, multiValuedFlags map[string]bool) (map[string]any
|
||||
|
||||
if hasValue {
|
||||
// Handle multi-valued flags
|
||||
if multiValuedFlags[flagName] {
|
||||
if _, isMultiValue := multiValuedFlags[flagName]; isMultiValue {
|
||||
if existing, ok := options[flagName].([]string); ok {
|
||||
options[flagName] = append(existing, value)
|
||||
} else {
|
||||
|
||||
@@ -6,12 +6,16 @@ import (
|
||||
)
|
||||
|
||||
// vllmMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated
|
||||
var vllmMultiValuedFlags = map[string]bool{
|
||||
"api-key": true,
|
||||
"allowed-origins": true,
|
||||
"allowed-methods": true,
|
||||
"allowed-headers": true,
|
||||
"middleware": true,
|
||||
// Based on vLLM's CLI argument definitions with action='append' or List types
|
||||
// Keys use snake_case as the parser converts kebab-case flags to snake_case before lookup
|
||||
var vllmMultiValuedFlags = map[string]struct{}{
|
||||
"api_key": {}, // --api-key (action='append')
|
||||
"allowed_origins": {}, // --allowed-origins (List type)
|
||||
"allowed_methods": {}, // --allowed-methods (List type)
|
||||
"allowed_headers": {}, // --allowed-headers (List type)
|
||||
"middleware": {}, // --middleware (action='append')
|
||||
"lora_modules": {}, // --lora-modules (custom LoRAParserAction, accepts multiple)
|
||||
"prompt_adapters": {}, // --prompt-adapters (similar to lora-modules, accepts multiple)
|
||||
}
|
||||
|
||||
type VllmServerOptions struct {
|
||||
@@ -212,18 +216,9 @@ func (o *VllmServerOptions) BuildDockerArgs() []string {
|
||||
func (o *VllmServerOptions) ParseCommand(command string) (any, error) {
|
||||
executableNames := []string{"vllm"}
|
||||
subcommandNames := []string{"serve"}
|
||||
multiValuedFlags := map[string]bool{
|
||||
"middleware": true,
|
||||
"api_key": true,
|
||||
"allowed_origins": true,
|
||||
"allowed_methods": true,
|
||||
"allowed_headers": true,
|
||||
"lora_modules": true,
|
||||
"prompt_adapters": true,
|
||||
}
|
||||
|
||||
var vllmOptions VllmServerOptions
|
||||
if err := parseCommand(command, executableNames, subcommandNames, multiValuedFlags, &vllmOptions); err != nil {
|
||||
if err := parseCommand(command, executableNames, subcommandNames, vllmMultiValuedFlags, &vllmOptions); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -120,6 +120,41 @@ func TestParseVllmCommand(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseVllmCommandArrays(t *testing.T) {
|
||||
command := "vllm serve test-model --middleware auth.py --middleware=cors.py --api-key key1 --api-key key2"
|
||||
|
||||
var opts backends.VllmServerOptions
|
||||
resultAny, err := opts.ParseCommand(command)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
result, ok := resultAny.(*backends.VllmServerOptions)
|
||||
if !ok {
|
||||
t.Fatalf("expected *VllmServerOptions, got %T", resultAny)
|
||||
}
|
||||
|
||||
expectedMiddleware := []string{"auth.py", "cors.py"}
|
||||
if len(result.Middleware) != len(expectedMiddleware) {
|
||||
t.Errorf("expected %d middleware items, got %d", len(expectedMiddleware), len(result.Middleware))
|
||||
}
|
||||
for i, v := range expectedMiddleware {
|
||||
if i >= len(result.Middleware) || result.Middleware[i] != v {
|
||||
t.Errorf("expected middleware[%d]=%s got %s", i, v, result.Middleware[i])
|
||||
}
|
||||
}
|
||||
|
||||
expectedAPIKeys := []string{"key1", "key2"}
|
||||
if len(result.APIKey) != len(expectedAPIKeys) {
|
||||
t.Errorf("expected %d api keys, got %d", len(expectedAPIKeys), len(result.APIKey))
|
||||
}
|
||||
for i, v := range expectedAPIKeys {
|
||||
if i >= len(result.APIKey) || result.APIKey[i] != v {
|
||||
t.Errorf("expected api_key[%d]=%s got %s", i, v, result.APIKey[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestVllmBuildCommandArgs_BooleanFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
Reference in New Issue
Block a user