mirror of
https://github.com/lordmathis/llamactl.git
synced 2025-11-06 00:54:23 +00:00
Update multi valued flags in backends
This commit is contained in:
@@ -9,7 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// BuildCommandArgs converts a struct to command line arguments
|
// BuildCommandArgs converts a struct to command line arguments
|
||||||
func BuildCommandArgs(options any, multipleFlags map[string]bool) []string {
|
func BuildCommandArgs(options any, multipleFlags map[string]struct{}) []string {
|
||||||
var args []string
|
var args []string
|
||||||
|
|
||||||
v := reflect.ValueOf(options).Elem()
|
v := reflect.ValueOf(options).Elem()
|
||||||
@@ -28,9 +28,10 @@ func BuildCommandArgs(options any, multipleFlags map[string]bool) []string {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get flag name from JSON tag
|
// Get flag name from JSON tag (snake_case)
|
||||||
flagName := strings.Split(jsonTag, ",")[0]
|
jsonFieldName := strings.Split(jsonTag, ",")[0]
|
||||||
flagName = strings.ReplaceAll(flagName, "_", "-")
|
// Convert to kebab-case for CLI flags
|
||||||
|
flagName := strings.ReplaceAll(jsonFieldName, "_", "-")
|
||||||
|
|
||||||
switch field.Kind() {
|
switch field.Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
@@ -51,7 +52,8 @@ func BuildCommandArgs(options any, multipleFlags map[string]bool) []string {
|
|||||||
}
|
}
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
if field.Type().Elem().Kind() == reflect.String && field.Len() > 0 {
|
if field.Type().Elem().Kind() == reflect.String && field.Len() > 0 {
|
||||||
if multipleFlags[flagName] {
|
// Use jsonFieldName (snake_case) for multipleFlags lookup
|
||||||
|
if _, isMultiValue := multipleFlags[jsonFieldName]; isMultiValue {
|
||||||
// Multiple flags: --flag value1 --flag value2
|
// Multiple flags: --flag value1 --flag value2
|
||||||
for j := 0; j < field.Len(); j++ {
|
for j := 0; j < field.Len(); j++ {
|
||||||
args = append(args, "--"+flagName, field.Index(j).String())
|
args = append(args, "--"+flagName, field.Index(j).String())
|
||||||
|
|||||||
@@ -9,25 +9,16 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// llamaMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated
|
// llamaMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated
|
||||||
// Used for both parsing (with underscores) and building (with dashes)
|
// Keys use snake_case as the parser converts kebab-case flags to snake_case before lookup
|
||||||
var llamaMultiValuedFlags = map[string]bool{
|
var llamaMultiValuedFlags = map[string]struct{}{
|
||||||
// Parsing keys (with underscores)
|
"override_tensor": {},
|
||||||
"override_tensor": true,
|
"override_kv": {},
|
||||||
"override_kv": true,
|
"lora": {},
|
||||||
"lora": true,
|
"lora_scaled": {},
|
||||||
"lora_scaled": true,
|
"control_vector": {},
|
||||||
"control_vector": true,
|
"control_vector_scaled": {},
|
||||||
"control_vector_scaled": true,
|
"dry_sequence_breaker": {},
|
||||||
"dry_sequence_breaker": true,
|
"logit_bias": {},
|
||||||
"logit_bias": true,
|
|
||||||
// Building keys (with dashes)
|
|
||||||
"override-tensor": true,
|
|
||||||
"override-kv": true,
|
|
||||||
"lora-scaled": true,
|
|
||||||
"control-vector": true,
|
|
||||||
"control-vector-scaled": true,
|
|
||||||
"dry-sequence-breaker": true,
|
|
||||||
"logit-bias": true,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type LlamaServerOptions struct {
|
type LlamaServerOptions struct {
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ func (o *MlxServerOptions) Validate() error {
|
|||||||
|
|
||||||
// BuildCommandArgs converts to command line arguments
|
// BuildCommandArgs converts to command line arguments
|
||||||
func (o *MlxServerOptions) BuildCommandArgs() []string {
|
func (o *MlxServerOptions) BuildCommandArgs() []string {
|
||||||
multipleFlags := map[string]bool{} // MLX doesn't currently have []string fields
|
multipleFlags := map[string]struct{}{} // MLX doesn't currently have []string fields
|
||||||
return BuildCommandArgs(o, multipleFlags)
|
return BuildCommandArgs(o, multipleFlags)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -78,8 +78,8 @@ func (o *MlxServerOptions) BuildDockerArgs() []string {
|
|||||||
// 4. Multiline commands with backslashes
|
// 4. Multiline commands with backslashes
|
||||||
func (o *MlxServerOptions) ParseCommand(command string) (any, error) {
|
func (o *MlxServerOptions) ParseCommand(command string) (any, error) {
|
||||||
executableNames := []string{"mlx_lm.server"}
|
executableNames := []string{"mlx_lm.server"}
|
||||||
var subcommandNames []string // MLX has no subcommands
|
var subcommandNames []string // MLX has no subcommands
|
||||||
multiValuedFlags := map[string]bool{} // MLX has no multi-valued flags
|
multiValuedFlags := map[string]struct{}{} // MLX has no multi-valued flags
|
||||||
|
|
||||||
var mlxOptions MlxServerOptions
|
var mlxOptions MlxServerOptions
|
||||||
if err := parseCommand(command, executableNames, subcommandNames, multiValuedFlags, &mlxOptions); err != nil {
|
if err := parseCommand(command, executableNames, subcommandNames, multiValuedFlags, &mlxOptions); err != nil {
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// parseCommand parses a command string into a target struct
|
// parseCommand parses a command string into a target struct
|
||||||
func parseCommand(command string, executableNames []string, subcommandNames []string, multiValuedFlags map[string]bool, target any) error {
|
func parseCommand(command string, executableNames []string, subcommandNames []string, multiValuedFlags map[string]struct{}, target any) error {
|
||||||
// Normalize multiline commands
|
// Normalize multiline commands
|
||||||
command = normalizeCommand(command)
|
command = normalizeCommand(command)
|
||||||
if command == "" {
|
if command == "" {
|
||||||
@@ -125,7 +125,7 @@ func extractArgs(command string, executableNames []string, subcommandNames []str
|
|||||||
}
|
}
|
||||||
|
|
||||||
// parseFlags parses command line flags into a map
|
// parseFlags parses command line flags into a map
|
||||||
func parseFlags(args []string, multiValuedFlags map[string]bool) (map[string]any, error) {
|
func parseFlags(args []string, multiValuedFlags map[string]struct{}) (map[string]any, error) {
|
||||||
options := make(map[string]any)
|
options := make(map[string]any)
|
||||||
|
|
||||||
for i := 0; i < len(args); i++ {
|
for i := 0; i < len(args); i++ {
|
||||||
@@ -163,7 +163,7 @@ func parseFlags(args []string, multiValuedFlags map[string]bool) (map[string]any
|
|||||||
|
|
||||||
if hasValue {
|
if hasValue {
|
||||||
// Handle multi-valued flags
|
// Handle multi-valued flags
|
||||||
if multiValuedFlags[flagName] {
|
if _, isMultiValue := multiValuedFlags[flagName]; isMultiValue {
|
||||||
if existing, ok := options[flagName].([]string); ok {
|
if existing, ok := options[flagName].([]string); ok {
|
||||||
options[flagName] = append(existing, value)
|
options[flagName] = append(existing, value)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -6,12 +6,16 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// vllmMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated
|
// vllmMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated
|
||||||
var vllmMultiValuedFlags = map[string]bool{
|
// Based on vLLM's CLI argument definitions with action='append' or List types
|
||||||
"api-key": true,
|
// Keys use snake_case as the parser converts kebab-case flags to snake_case before lookup
|
||||||
"allowed-origins": true,
|
var vllmMultiValuedFlags = map[string]struct{}{
|
||||||
"allowed-methods": true,
|
"api_key": {}, // --api-key (action='append')
|
||||||
"allowed-headers": true,
|
"allowed_origins": {}, // --allowed-origins (List type)
|
||||||
"middleware": true,
|
"allowed_methods": {}, // --allowed-methods (List type)
|
||||||
|
"allowed_headers": {}, // --allowed-headers (List type)
|
||||||
|
"middleware": {}, // --middleware (action='append')
|
||||||
|
"lora_modules": {}, // --lora-modules (custom LoRAParserAction, accepts multiple)
|
||||||
|
"prompt_adapters": {}, // --prompt-adapters (similar to lora-modules, accepts multiple)
|
||||||
}
|
}
|
||||||
|
|
||||||
type VllmServerOptions struct {
|
type VllmServerOptions struct {
|
||||||
@@ -212,18 +216,9 @@ func (o *VllmServerOptions) BuildDockerArgs() []string {
|
|||||||
func (o *VllmServerOptions) ParseCommand(command string) (any, error) {
|
func (o *VllmServerOptions) ParseCommand(command string) (any, error) {
|
||||||
executableNames := []string{"vllm"}
|
executableNames := []string{"vllm"}
|
||||||
subcommandNames := []string{"serve"}
|
subcommandNames := []string{"serve"}
|
||||||
multiValuedFlags := map[string]bool{
|
|
||||||
"middleware": true,
|
|
||||||
"api_key": true,
|
|
||||||
"allowed_origins": true,
|
|
||||||
"allowed_methods": true,
|
|
||||||
"allowed_headers": true,
|
|
||||||
"lora_modules": true,
|
|
||||||
"prompt_adapters": true,
|
|
||||||
}
|
|
||||||
|
|
||||||
var vllmOptions VllmServerOptions
|
var vllmOptions VllmServerOptions
|
||||||
if err := parseCommand(command, executableNames, subcommandNames, multiValuedFlags, &vllmOptions); err != nil {
|
if err := parseCommand(command, executableNames, subcommandNames, vllmMultiValuedFlags, &vllmOptions); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -120,6 +120,41 @@ func TestParseVllmCommand(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseVllmCommandArrays(t *testing.T) {
|
||||||
|
command := "vllm serve test-model --middleware auth.py --middleware=cors.py --api-key key1 --api-key key2"
|
||||||
|
|
||||||
|
var opts backends.VllmServerOptions
|
||||||
|
resultAny, err := opts.ParseCommand(command)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, ok := resultAny.(*backends.VllmServerOptions)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected *VllmServerOptions, got %T", resultAny)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedMiddleware := []string{"auth.py", "cors.py"}
|
||||||
|
if len(result.Middleware) != len(expectedMiddleware) {
|
||||||
|
t.Errorf("expected %d middleware items, got %d", len(expectedMiddleware), len(result.Middleware))
|
||||||
|
}
|
||||||
|
for i, v := range expectedMiddleware {
|
||||||
|
if i >= len(result.Middleware) || result.Middleware[i] != v {
|
||||||
|
t.Errorf("expected middleware[%d]=%s got %s", i, v, result.Middleware[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedAPIKeys := []string{"key1", "key2"}
|
||||||
|
if len(result.APIKey) != len(expectedAPIKeys) {
|
||||||
|
t.Errorf("expected %d api keys, got %d", len(expectedAPIKeys), len(result.APIKey))
|
||||||
|
}
|
||||||
|
for i, v := range expectedAPIKeys {
|
||||||
|
if i >= len(result.APIKey) || result.APIKey[i] != v {
|
||||||
|
t.Errorf("expected api_key[%d]=%s got %s", i, v, result.APIKey[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestVllmBuildCommandArgs_BooleanFields(t *testing.T) {
|
func TestVllmBuildCommandArgs_BooleanFields(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
Reference in New Issue
Block a user