From 4df02a6519dfffbd36364c044032d4db4ccb9fea Mon Sep 17 00:00:00 2001 From: LordMathis Date: Fri, 19 Sep 2025 18:05:12 +0200 Subject: [PATCH 01/20] Initial vLLM backend support --- pkg/backends/backend.go | 1 + pkg/backends/vllm/parser.go | 302 +++++++++++++++++++++ pkg/backends/vllm/parser_test.go | 83 ++++++ pkg/backends/vllm/vllm.go | 439 ++++++++++++++++++++++++++++++ pkg/backends/vllm/vllm_test.go | 106 ++++++++ pkg/config/config.go | 7 + pkg/instance/lifecycle.go | 2 + pkg/instance/options.go | 37 ++- pkg/server/handlers.go | 56 +++- pkg/server/routes.go | 3 + pkg/validation/validation.go | 21 ++ vllm_backend_spec.md | 440 +++++++++++++++++++++++++++++++ 12 files changed, 1495 insertions(+), 2 deletions(-) create mode 100644 pkg/backends/vllm/parser.go create mode 100644 pkg/backends/vllm/parser_test.go create mode 100644 pkg/backends/vllm/vllm.go create mode 100644 pkg/backends/vllm/vllm_test.go create mode 100644 vllm_backend_spec.md diff --git a/pkg/backends/backend.go b/pkg/backends/backend.go index 0270945..802fec2 100644 --- a/pkg/backends/backend.go +++ b/pkg/backends/backend.go @@ -5,5 +5,6 @@ type BackendType string const ( BackendTypeLlamaCpp BackendType = "llama_cpp" BackendTypeMlxLm BackendType = "mlx_lm" + BackendTypeVllm BackendType = "vllm" // BackendTypeMlxVlm BackendType = "mlx_vlm" // Future expansion ) diff --git a/pkg/backends/vllm/parser.go b/pkg/backends/vllm/parser.go new file mode 100644 index 0000000..cb9125c --- /dev/null +++ b/pkg/backends/vllm/parser.go @@ -0,0 +1,302 @@ +package vllm + +import ( + "encoding/json" + "errors" + "fmt" + "path/filepath" + "regexp" + "strconv" + "strings" +) + +// ParseVllmCommand parses a vLLM serve command string into VllmServerOptions +// Supports multiple formats: +// 1. Full command: "vllm serve --model MODEL_NAME --other-args" +// 2. Full path: "/usr/local/bin/vllm serve --model MODEL_NAME" +// 3. Serve only: "serve --model MODEL_NAME --other-args" +// 4. Args only: "--model MODEL_NAME --other-args" +// 5. Multiline commands with backslashes +func ParseVllmCommand(command string) (*VllmServerOptions, error) { + // 1. Normalize the command - handle multiline with backslashes + trimmed := normalizeMultilineCommand(command) + if trimmed == "" { + return nil, fmt.Errorf("command cannot be empty") + } + + // 2. Extract arguments from command + args, err := extractArgumentsFromCommand(trimmed) + if err != nil { + return nil, err + } + + // 3. Parse arguments into map + options := make(map[string]any) + + // Known multi-valued flags (snake_case form) + multiValued := map[string]struct{}{ + "middleware": {}, + "api_key": {}, + "allowed_origins": {}, + "allowed_methods": {}, + "allowed_headers": {}, + "lora_modules": {}, + "prompt_adapters": {}, + } + + i := 0 + for i < len(args) { + arg := args[i] + + if !strings.HasPrefix(arg, "-") { // skip positional / stray values + i++ + continue + } + + // Reject malformed flags with more than two leading dashes (e.g. 
---model) to surface user mistakes + if strings.HasPrefix(arg, "---") { + return nil, fmt.Errorf("malformed flag: %s", arg) + } + + // Unified parsing for --flag=value vs --flag value + var rawFlag, rawValue string + hasEquals := false + if strings.Contains(arg, "=") { + parts := strings.SplitN(arg, "=", 2) + rawFlag = parts[0] + rawValue = parts[1] // may be empty string + hasEquals = true + } else { + rawFlag = arg + } + + flagCore := strings.TrimPrefix(strings.TrimPrefix(rawFlag, "-"), "-") + flagName := strings.ReplaceAll(flagCore, "-", "_") + + // Detect value if not in equals form + valueProvided := hasEquals + if !hasEquals { + if i+1 < len(args) && !isFlag(args[i+1]) { // next token is value + rawValue = args[i+1] + valueProvided = true + } + } + + // Determine if multi-valued flag + _, isMulti := multiValued[flagName] + + // Normalization helper: ensure slice for multi-valued flags + appendValue := func(valStr string) { + if existing, ok := options[flagName]; ok { + // Existing value; ensure slice semantics for multi-valued flags or repeated occurrences + if slice, ok := existing.([]string); ok { + options[flagName] = append(slice, valStr) + return + } + // Convert scalar to slice + options[flagName] = []string{fmt.Sprintf("%v", existing), valStr} + return + } + // First value + if isMulti { + options[flagName] = []string{valStr} + } else { + // We'll parse type below for single-valued flags + options[flagName] = valStr + } + } + + if valueProvided { + // Use raw token for multi-valued flags; else allow typed parsing + appendValue(rawValue) + if !isMulti { // convert to typed value if scalar + if strVal, ok := options[flagName].(string); ok { // still scalar + options[flagName] = parseValue(strVal) + } + } + // Advance index: if we consumed a following token as value (non equals form), skip it + if !hasEquals && i+1 < len(args) && rawValue == args[i+1] { + i += 2 + } else { + i++ + } + continue + } + + // Boolean flag (no value) + options[flagName] = true + i++ + } + + // 4. Convert to VllmServerOptions using existing UnmarshalJSON + jsonData, err := json.Marshal(options) + if err != nil { + return nil, fmt.Errorf("failed to marshal parsed options: %w", err) + } + + var vllmOptions VllmServerOptions + if err := json.Unmarshal(jsonData, &vllmOptions); err != nil { + return nil, fmt.Errorf("failed to parse command options: %w", err) + } + + // 5. 
Return VllmServerOptions + return &vllmOptions, nil +} + +// parseValue attempts to parse a string value into the most appropriate type +func parseValue(value string) any { + // Surrounding matching quotes (single or double) + if l := len(value); l >= 2 { + if (value[0] == '"' && value[l-1] == '"') || (value[0] == '\'' && value[l-1] == '\'') { + value = value[1 : l-1] + } + } + + lower := strings.ToLower(value) + if lower == "true" { + return true + } + if lower == "false" { + return false + } + + if intVal, err := strconv.Atoi(value); err == nil { + return intVal + } + if floatVal, err := strconv.ParseFloat(value, 64); err == nil { + return floatVal + } + return value +} + +// normalizeMultilineCommand handles multiline commands with backslashes +func normalizeMultilineCommand(command string) string { + // Handle escaped newlines (backslash followed by newline) + re := regexp.MustCompile(`\\\s*\n\s*`) + normalized := re.ReplaceAllString(command, " ") + + // Clean up extra whitespace + re = regexp.MustCompile(`\s+`) + normalized = re.ReplaceAllString(normalized, " ") + + return strings.TrimSpace(normalized) +} + +// extractArgumentsFromCommand extracts arguments from various command formats +func extractArgumentsFromCommand(command string) ([]string, error) { + // Split command into tokens respecting quotes + tokens, err := splitCommandTokens(command) + if err != nil { + return nil, err + } + + if len(tokens) == 0 { + return nil, fmt.Errorf("no command tokens found") + } + + // Check if first token looks like an executable + firstToken := tokens[0] + + // Case 1: Full path to executable (contains path separator or ends with vllm) + if strings.Contains(firstToken, string(filepath.Separator)) || + strings.HasSuffix(filepath.Base(firstToken), "vllm") { + // Check if second token is "serve" + if len(tokens) > 1 && strings.ToLower(tokens[1]) == "serve" { + return tokens[2:], nil // Return everything except executable and serve + } + return tokens[1:], nil // Return everything except the executable + } + + // Case 2: Just "vllm" command + if strings.ToLower(firstToken) == "vllm" { + // Check if second token is "serve" + if len(tokens) > 1 && strings.ToLower(tokens[1]) == "serve" { + return tokens[2:], nil // Return everything except vllm and serve + } + return tokens[1:], nil // Return everything except vllm + } + + // Case 3: Just "serve" command + if strings.ToLower(firstToken) == "serve" { + return tokens[1:], nil // Return everything except serve + } + + // Case 4: Arguments only (starts with a flag) + if strings.HasPrefix(firstToken, "-") { + return tokens, nil // Return all tokens as arguments + } + + // Case 5: Unknown format - might be a different executable name + // Be permissive and assume it's the executable + if len(tokens) > 1 && strings.ToLower(tokens[1]) == "serve" { + return tokens[2:], nil // Return everything except executable and serve + } + return tokens[1:], nil +} + +// splitCommandTokens splits a command string into tokens, respecting quotes +func splitCommandTokens(command string) ([]string, error) { + var tokens []string + var current strings.Builder + inQuotes := false + quoteChar := byte(0) + escaped := false + + for i := 0; i < len(command); i++ { + c := command[i] + + if escaped { + current.WriteByte(c) + escaped = false + continue + } + + if c == '\\' { + escaped = true + current.WriteByte(c) + continue + } + + if !inQuotes && (c == '"' || c == '\'') { + inQuotes = true + quoteChar = c + current.WriteByte(c) + } else if inQuotes && c == quoteChar { + inQuotes = 
false + quoteChar = 0 + current.WriteByte(c) + } else if !inQuotes && (c == ' ' || c == '\t') { + if current.Len() > 0 { + tokens = append(tokens, current.String()) + current.Reset() + } + } else { + current.WriteByte(c) + } + } + + if inQuotes { + return nil, errors.New("unterminated quoted string") + } + + if current.Len() > 0 { + tokens = append(tokens, current.String()) + } + + return tokens, nil +} + +// isFlag determines if a string is a command line flag or a value +// Handles the special case where negative numbers should be treated as values, not flags +func isFlag(arg string) bool { + if !strings.HasPrefix(arg, "-") { + return false + } + + // Special case: if it's a negative number, treat it as a value + if _, err := strconv.ParseFloat(arg, 64); err == nil { + return false + } + + return true +} \ No newline at end of file diff --git a/pkg/backends/vllm/parser_test.go b/pkg/backends/vllm/parser_test.go new file mode 100644 index 0000000..91921b2 --- /dev/null +++ b/pkg/backends/vllm/parser_test.go @@ -0,0 +1,83 @@ +package vllm + +import ( + "testing" +) + +func TestParseVllmCommand(t *testing.T) { + tests := []struct { + name string + command string + expectErr bool + }{ + { + name: "basic vllm serve command", + command: "vllm serve --model microsoft/DialoGPT-medium", + expectErr: false, + }, + { + name: "serve only command", + command: "serve --model microsoft/DialoGPT-medium", + expectErr: false, + }, + { + name: "args only", + command: "--model microsoft/DialoGPT-medium --tensor-parallel-size 2", + expectErr: false, + }, + { + name: "empty command", + command: "", + expectErr: true, + }, + { + name: "unterminated quote", + command: `vllm serve --model "unterminated`, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := ParseVllmCommand(tt.command) + + if tt.expectErr { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if result == nil { + t.Errorf("expected result but got nil") + } + }) + } +} + +func TestParseVllmCommandValues(t *testing.T) { + command := "vllm serve --model test-model --tensor-parallel-size 4 --gpu-memory-utilization 0.8 --enable-log-outputs" + result, err := ParseVllmCommand(command) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Model != "test-model" { + t.Errorf("expected model 'test-model', got '%s'", result.Model) + } + if result.TensorParallelSize != 4 { + t.Errorf("expected tensor_parallel_size 4, got %d", result.TensorParallelSize) + } + if result.GPUMemoryUtilization != 0.8 { + t.Errorf("expected gpu_memory_utilization 0.8, got %f", result.GPUMemoryUtilization) + } + if !result.EnableLogOutputs { + t.Errorf("expected enable_log_outputs true, got %v", result.EnableLogOutputs) + } +} \ No newline at end of file diff --git a/pkg/backends/vllm/vllm.go b/pkg/backends/vllm/vllm.go new file mode 100644 index 0000000..6378b5e --- /dev/null +++ b/pkg/backends/vllm/vllm.go @@ -0,0 +1,439 @@ +package vllm + +import ( + "encoding/json" + "reflect" + "strconv" + "strings" +) + +type VllmServerOptions struct { + // Basic connection options (auto-assigned by llamactl) + Host string `json:"host,omitempty"` + Port int `json:"port,omitempty"` + + // Model and engine configuration + Model string `json:"model,omitempty"` + Tokenizer string `json:"tokenizer,omitempty"` + SkipTokenizerInit bool `json:"skip_tokenizer_init,omitempty"` + Revision string 
`json:"revision,omitempty"` + CodeRevision string `json:"code_revision,omitempty"` + TokenizerRevision string `json:"tokenizer_revision,omitempty"` + TokenizerMode string `json:"tokenizer_mode,omitempty"` + TrustRemoteCode bool `json:"trust_remote_code,omitempty"` + DownloadDir string `json:"download_dir,omitempty"` + LoadFormat string `json:"load_format,omitempty"` + ConfigFormat string `json:"config_format,omitempty"` + Dtype string `json:"dtype,omitempty"` + KVCacheDtype string `json:"kv_cache_dtype,omitempty"` + QuantizationParamPath string `json:"quantization_param_path,omitempty"` + Seed int `json:"seed,omitempty"` + MaxModelLen int `json:"max_model_len,omitempty"` + GuidedDecodingBackend string `json:"guided_decoding_backend,omitempty"` + DistributedExecutorBackend string `json:"distributed_executor_backend,omitempty"` + WorkerUseRay bool `json:"worker_use_ray,omitempty"` + RayWorkersUseNSight bool `json:"ray_workers_use_nsight,omitempty"` + + // Performance and serving configuration + BlockSize int `json:"block_size,omitempty"` + EnablePrefixCaching bool `json:"enable_prefix_caching,omitempty"` + DisableSlidingWindow bool `json:"disable_sliding_window,omitempty"` + UseV2BlockManager bool `json:"use_v2_block_manager,omitempty"` + NumLookaheadSlots int `json:"num_lookahead_slots,omitempty"` + SwapSpace int `json:"swap_space,omitempty"` + CPUOffloadGB int `json:"cpu_offload_gb,omitempty"` + GPUMemoryUtilization float64 `json:"gpu_memory_utilization,omitempty"` + NumGPUBlocksOverride int `json:"num_gpu_blocks_override,omitempty"` + MaxNumBatchedTokens int `json:"max_num_batched_tokens,omitempty"` + MaxNumSeqs int `json:"max_num_seqs,omitempty"` + MaxLogprobs int `json:"max_logprobs,omitempty"` + DisableLogStats bool `json:"disable_log_stats,omitempty"` + Quantization string `json:"quantization,omitempty"` + RopeScaling string `json:"rope_scaling,omitempty"` + RopeTheta float64 `json:"rope_theta,omitempty"` + EnforceEager bool `json:"enforce_eager,omitempty"` + MaxContextLenToCapture int `json:"max_context_len_to_capture,omitempty"` + MaxSeqLenToCapture int `json:"max_seq_len_to_capture,omitempty"` + DisableCustomAllReduce bool `json:"disable_custom_all_reduce,omitempty"` + TokenizerPoolSize int `json:"tokenizer_pool_size,omitempty"` + TokenizerPoolType string `json:"tokenizer_pool_type,omitempty"` + TokenizerPoolExtraConfig string `json:"tokenizer_pool_extra_config,omitempty"` + EnableLoraBias bool `json:"enable_lora_bias,omitempty"` + LoraExtraVocabSize int `json:"lora_extra_vocab_size,omitempty"` + LoraRank int `json:"lora_rank,omitempty"` + PromptLookbackDistance int `json:"prompt_lookback_distance,omitempty"` + PreemptionMode string `json:"preemption_mode,omitempty"` + + // Distributed and parallel processing + TensorParallelSize int `json:"tensor_parallel_size,omitempty"` + PipelineParallelSize int `json:"pipeline_parallel_size,omitempty"` + MaxParallelLoadingWorkers int `json:"max_parallel_loading_workers,omitempty"` + DisableAsyncOutputProc bool `json:"disable_async_output_proc,omitempty"` + WorkerClass string `json:"worker_class,omitempty"` + EnabledLoraModules string `json:"enabled_lora_modules,omitempty"` + MaxLoraRank int `json:"max_lora_rank,omitempty"` + FullyShardedLoras bool `json:"fully_sharded_loras,omitempty"` + LoraModules string `json:"lora_modules,omitempty"` + PromptAdapters string `json:"prompt_adapters,omitempty"` + MaxPromptAdapterToken int `json:"max_prompt_adapter_token,omitempty"` + Device string `json:"device,omitempty"` + SchedulerDelay float64 
`json:"scheduler_delay,omitempty"` + EnableChunkedPrefill bool `json:"enable_chunked_prefill,omitempty"` + SpeculativeModel string `json:"speculative_model,omitempty"` + SpeculativeModelQuantization string `json:"speculative_model_quantization,omitempty"` + SpeculativeRevision string `json:"speculative_revision,omitempty"` + SpeculativeMaxModelLen int `json:"speculative_max_model_len,omitempty"` + SpeculativeDisableByBatchSize int `json:"speculative_disable_by_batch_size,omitempty"` + NgptSpeculativeLength int `json:"ngpt_speculative_length,omitempty"` + SpeculativeDisableMqa bool `json:"speculative_disable_mqa,omitempty"` + ModelLoaderExtraConfig string `json:"model_loader_extra_config,omitempty"` + IgnorePatterns string `json:"ignore_patterns,omitempty"` + PreloadedLoraModules string `json:"preloaded_lora_modules,omitempty"` + + // OpenAI server specific options + UDS string `json:"uds,omitempty"` + UvicornLogLevel string `json:"uvicorn_log_level,omitempty"` + ResponseRole string `json:"response_role,omitempty"` + SSLKeyfile string `json:"ssl_keyfile,omitempty"` + SSLCertfile string `json:"ssl_certfile,omitempty"` + SSLCACerts string `json:"ssl_ca_certs,omitempty"` + SSLCertReqs int `json:"ssl_cert_reqs,omitempty"` + RootPath string `json:"root_path,omitempty"` + Middleware []string `json:"middleware,omitempty"` + ReturnTokensAsTokenIDS bool `json:"return_tokens_as_token_ids,omitempty"` + DisableFrontendMultiprocessing bool `json:"disable_frontend_multiprocessing,omitempty"` + EnableAutoToolChoice bool `json:"enable_auto_tool_choice,omitempty"` + ToolCallParser string `json:"tool_call_parser,omitempty"` + ToolServer string `json:"tool_server,omitempty"` + ChatTemplate string `json:"chat_template,omitempty"` + ChatTemplateContentFormat string `json:"chat_template_content_format,omitempty"` + AllowCredentials bool `json:"allow_credentials,omitempty"` + AllowedOrigins []string `json:"allowed_origins,omitempty"` + AllowedMethods []string `json:"allowed_methods,omitempty"` + AllowedHeaders []string `json:"allowed_headers,omitempty"` + APIKey []string `json:"api_key,omitempty"` + EnableLogOutputs bool `json:"enable_log_outputs,omitempty"` + EnableTokenUsage bool `json:"enable_token_usage,omitempty"` + EnableAsyncEngineDebug bool `json:"enable_async_engine_debug,omitempty"` + EngineUseRay bool `json:"engine_use_ray,omitempty"` + DisableLogRequests bool `json:"disable_log_requests,omitempty"` + MaxLogLen int `json:"max_log_len,omitempty"` + + // Additional engine configuration + Task string `json:"task,omitempty"` + MultiModalConfig string `json:"multi_modal_config,omitempty"` + LimitMmPerPrompt string `json:"limit_mm_per_prompt,omitempty"` + EnableSleepMode bool `json:"enable_sleep_mode,omitempty"` + EnableChunkingRequest bool `json:"enable_chunking_request,omitempty"` + CompilationConfig string `json:"compilation_config,omitempty"` + DisableSlidingWindowMask bool `json:"disable_sliding_window_mask,omitempty"` + EnableTRTLLMEngineLatency bool `json:"enable_trtllm_engine_latency,omitempty"` + OverridePoolingConfig string `json:"override_pooling_config,omitempty"` + OverrideNeuronConfig string `json:"override_neuron_config,omitempty"` + OverrideKVCacheALIGNSize int `json:"override_kv_cache_align_size,omitempty"` +} + +// NewVllmServerOptions creates a new VllmServerOptions with defaults +func NewVllmServerOptions() *VllmServerOptions { + return &VllmServerOptions{ + Host: "127.0.0.1", + Port: 8000, + TensorParallelSize: 1, + PipelineParallelSize: 1, + GPUMemoryUtilization: 0.9, + BlockSize: 16, + 
SwapSpace: 4, + UvicornLogLevel: "info", + ResponseRole: "assistant", + TokenizerMode: "auto", + TrustRemoteCode: false, + EnablePrefixCaching: false, + EnforceEager: false, + DisableLogStats: false, + DisableLogRequests: false, + MaxLogprobs: 20, + EnableLogOutputs: false, + EnableTokenUsage: false, + AllowCredentials: false, + AllowedOrigins: []string{"*"}, + AllowedMethods: []string{"*"}, + AllowedHeaders: []string{"*"}, + } +} + +// UnmarshalJSON implements custom JSON unmarshaling to support multiple field names +func (o *VllmServerOptions) UnmarshalJSON(data []byte) error { + // First unmarshal into a map to handle multiple field names + var raw map[string]any + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + // Create a temporary struct for standard unmarshaling + type tempOptions VllmServerOptions + temp := tempOptions{} + + // Standard unmarshal first + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + // Copy to our struct + *o = VllmServerOptions(temp) + + // Handle alternative field names (CLI format with dashes) + fieldMappings := map[string]string{ + // Basic options + "tensor-parallel-size": "tensor_parallel_size", + "pipeline-parallel-size": "pipeline_parallel_size", + "max-parallel-loading-workers": "max_parallel_loading_workers", + "disable-async-output-proc": "disable_async_output_proc", + "worker-class": "worker_class", + "enabled-lora-modules": "enabled_lora_modules", + "max-lora-rank": "max_lora_rank", + "fully-sharded-loras": "fully_sharded_loras", + "lora-modules": "lora_modules", + "prompt-adapters": "prompt_adapters", + "max-prompt-adapter-token": "max_prompt_adapter_token", + "scheduler-delay": "scheduler_delay", + "enable-chunked-prefill": "enable_chunked_prefill", + "speculative-model": "speculative_model", + "speculative-model-quantization": "speculative_model_quantization", + "speculative-revision": "speculative_revision", + "speculative-max-model-len": "speculative_max_model_len", + "speculative-disable-by-batch-size": "speculative_disable_by_batch_size", + "ngpt-speculative-length": "ngpt_speculative_length", + "speculative-disable-mqa": "speculative_disable_mqa", + "model-loader-extra-config": "model_loader_extra_config", + "ignore-patterns": "ignore_patterns", + "preloaded-lora-modules": "preloaded_lora_modules", + + // Model configuration + "skip-tokenizer-init": "skip_tokenizer_init", + "code-revision": "code_revision", + "tokenizer-revision": "tokenizer_revision", + "tokenizer-mode": "tokenizer_mode", + "trust-remote-code": "trust_remote_code", + "download-dir": "download_dir", + "load-format": "load_format", + "config-format": "config_format", + "kv-cache-dtype": "kv_cache_dtype", + "quantization-param-path": "quantization_param_path", + "max-model-len": "max_model_len", + "guided-decoding-backend": "guided_decoding_backend", + "distributed-executor-backend": "distributed_executor_backend", + "worker-use-ray": "worker_use_ray", + "ray-workers-use-nsight": "ray_workers_use_nsight", + + // Performance configuration + "block-size": "block_size", + "enable-prefix-caching": "enable_prefix_caching", + "disable-sliding-window": "disable_sliding_window", + "use-v2-block-manager": "use_v2_block_manager", + "num-lookahead-slots": "num_lookahead_slots", + "swap-space": "swap_space", + "cpu-offload-gb": "cpu_offload_gb", + "gpu-memory-utilization": "gpu_memory_utilization", + "num-gpu-blocks-override": "num_gpu_blocks_override", + "max-num-batched-tokens": "max_num_batched_tokens", + "max-num-seqs": "max_num_seqs", + 
"max-logprobs": "max_logprobs", + "disable-log-stats": "disable_log_stats", + "rope-scaling": "rope_scaling", + "rope-theta": "rope_theta", + "enforce-eager": "enforce_eager", + "max-context-len-to-capture": "max_context_len_to_capture", + "max-seq-len-to-capture": "max_seq_len_to_capture", + "disable-custom-all-reduce": "disable_custom_all_reduce", + "tokenizer-pool-size": "tokenizer_pool_size", + "tokenizer-pool-type": "tokenizer_pool_type", + "tokenizer-pool-extra-config": "tokenizer_pool_extra_config", + "enable-lora-bias": "enable_lora_bias", + "lora-extra-vocab-size": "lora_extra_vocab_size", + "lora-rank": "lora_rank", + "prompt-lookback-distance": "prompt_lookback_distance", + "preemption-mode": "preemption_mode", + + // Server configuration + "uvicorn-log-level": "uvicorn_log_level", + "response-role": "response_role", + "ssl-keyfile": "ssl_keyfile", + "ssl-certfile": "ssl_certfile", + "ssl-ca-certs": "ssl_ca_certs", + "ssl-cert-reqs": "ssl_cert_reqs", + "root-path": "root_path", + "return-tokens-as-token-ids": "return_tokens_as_token_ids", + "disable-frontend-multiprocessing": "disable_frontend_multiprocessing", + "enable-auto-tool-choice": "enable_auto_tool_choice", + "tool-call-parser": "tool_call_parser", + "tool-server": "tool_server", + "chat-template": "chat_template", + "chat-template-content-format": "chat_template_content_format", + "allow-credentials": "allow_credentials", + "allowed-origins": "allowed_origins", + "allowed-methods": "allowed_methods", + "allowed-headers": "allowed_headers", + "api-key": "api_key", + "enable-log-outputs": "enable_log_outputs", + "enable-token-usage": "enable_token_usage", + "enable-async-engine-debug": "enable_async_engine_debug", + "engine-use-ray": "engine_use_ray", + "disable-log-requests": "disable_log_requests", + "max-log-len": "max_log_len", + + // Additional options + "multi-modal-config": "multi_modal_config", + "limit-mm-per-prompt": "limit_mm_per_prompt", + "enable-sleep-mode": "enable_sleep_mode", + "enable-chunking-request": "enable_chunking_request", + "compilation-config": "compilation_config", + "disable-sliding-window-mask": "disable_sliding_window_mask", + "enable-trtllm-engine-latency": "enable_trtllm_engine_latency", + "override-pooling-config": "override_pooling_config", + "override-neuron-config": "override_neuron_config", + "override-kv-cache-align-size": "override_kv_cache_align_size", + } + + // Process alternative field names + for altName, canonicalName := range fieldMappings { + if value, exists := raw[altName]; exists { + // Use reflection to set the field value + v := reflect.ValueOf(o).Elem() + field := v.FieldByNameFunc(func(fieldName string) bool { + field, _ := v.Type().FieldByName(fieldName) + jsonTag := field.Tag.Get("json") + return jsonTag == canonicalName+",omitempty" || jsonTag == canonicalName + }) + + if field.IsValid() && field.CanSet() { + switch field.Kind() { + case reflect.Int: + if intVal, ok := value.(float64); ok { + field.SetInt(int64(intVal)) + } else if strVal, ok := value.(string); ok { + if intVal, err := strconv.Atoi(strVal); err == nil { + field.SetInt(int64(intVal)) + } + } + case reflect.Float64: + if floatVal, ok := value.(float64); ok { + field.SetFloat(floatVal) + } else if strVal, ok := value.(string); ok { + if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil { + field.SetFloat(floatVal) + } + } + case reflect.String: + if strVal, ok := value.(string); ok { + field.SetString(strVal) + } + case reflect.Bool: + if boolVal, ok := value.(bool); ok { + 
field.SetBool(boolVal) + } + case reflect.Slice: + if field.Type().Elem().Kind() == reflect.String { + if strVal, ok := value.(string); ok { + // Split comma-separated values + values := strings.Split(strVal, ",") + for i, v := range values { + values[i] = strings.TrimSpace(v) + } + field.Set(reflect.ValueOf(values)) + } else if slice, ok := value.([]interface{}); ok { + var strSlice []string + for _, item := range slice { + if str, ok := item.(string); ok { + strSlice = append(strSlice, str) + } + } + field.Set(reflect.ValueOf(strSlice)) + } + } + } + } + } + } + + return nil +} + +// BuildCommandArgs converts VllmServerOptions to command line arguments +// Note: This does NOT include the "serve" subcommand, that's handled at the instance level +func (o *VllmServerOptions) BuildCommandArgs() []string { + var args []string + + v := reflect.ValueOf(o).Elem() + t := v.Type() + + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + fieldType := t.Field(i) + + // Skip unexported fields + if !field.CanInterface() { + continue + } + + // Get the JSON tag to determine the flag name + jsonTag := fieldType.Tag.Get("json") + if jsonTag == "" || jsonTag == "-" { + continue + } + + // Remove ",omitempty" from the tag + flagName := jsonTag + if commaIndex := strings.Index(jsonTag, ","); commaIndex != -1 { + flagName = jsonTag[:commaIndex] + } + + // Skip host and port as they are handled by llamactl + if flagName == "host" || flagName == "port" { + continue + } + + // Convert snake_case to kebab-case for CLI flags + flagName = strings.ReplaceAll(flagName, "_", "-") + + // Add the appropriate arguments based on field type and value + switch field.Kind() { + case reflect.Bool: + if field.Bool() { + args = append(args, "--"+flagName) + } + case reflect.Int: + if field.Int() != 0 { + args = append(args, "--"+flagName, strconv.FormatInt(field.Int(), 10)) + } + case reflect.Float64: + if field.Float() != 0 { + args = append(args, "--"+flagName, strconv.FormatFloat(field.Float(), 'f', -1, 64)) + } + case reflect.String: + if field.String() != "" { + args = append(args, "--"+flagName, field.String()) + } + case reflect.Slice: + if field.Type().Elem().Kind() == reflect.String { + // Handle []string fields - some are comma-separated, some use multiple flags + if flagName == "api-key" || flagName == "allowed-origins" || flagName == "allowed-methods" || flagName == "allowed-headers" || flagName == "middleware" { + // Multiple flags for these + for j := 0; j < field.Len(); j++ { + args = append(args, "--"+flagName, field.Index(j).String()) + } + } else { + // Comma-separated for others + if field.Len() > 0 { + var values []string + for j := 0; j < field.Len(); j++ { + values = append(values, field.Index(j).String()) + } + args = append(args, "--"+flagName, strings.Join(values, ",")) + } + } + } + } + } + + return args +} \ No newline at end of file diff --git a/pkg/backends/vllm/vllm_test.go b/pkg/backends/vllm/vllm_test.go new file mode 100644 index 0000000..e05320a --- /dev/null +++ b/pkg/backends/vllm/vllm_test.go @@ -0,0 +1,106 @@ +package vllm_test + +import ( + "encoding/json" + "llamactl/pkg/backends/vllm" + "slices" + "testing" +) + +func TestBuildCommandArgs(t *testing.T) { + options := vllm.VllmServerOptions{ + Model: "microsoft/DialoGPT-medium", + Port: 8080, // should be excluded + Host: "localhost", // should be excluded + TensorParallelSize: 2, + GPUMemoryUtilization: 0.8, + EnableLogOutputs: true, + APIKey: []string{"key1", "key2"}, + } + + args := options.BuildCommandArgs() + + // Check core 
functionality + if !containsFlagWithValue(args, "--model", "microsoft/DialoGPT-medium") { + t.Errorf("Expected --model microsoft/DialoGPT-medium not found in %v", args) + } + if !containsFlagWithValue(args, "--tensor-parallel-size", "2") { + t.Errorf("Expected --tensor-parallel-size 2 not found in %v", args) + } + if !contains(args, "--enable-log-outputs") { + t.Errorf("Expected --enable-log-outputs not found in %v", args) + } + + // Host and port should NOT be in the arguments (handled by llamactl) + if contains(args, "--host") || contains(args, "--port") { + t.Errorf("Host and port should not be in command args, found in %v", args) + } + + // Check array handling (multiple flags) + apiKeyCount := 0 + for i := range args { + if args[i] == "--api-key" { + apiKeyCount++ + } + } + if apiKeyCount != 2 { + t.Errorf("Expected 2 --api-key flags, got %d", apiKeyCount) + } +} + +func TestUnmarshalJSON(t *testing.T) { + // Test both underscore and dash formats + jsonData := `{ + "model": "test-model", + "tensor_parallel_size": 4, + "gpu-memory-utilization": 0.9, + "enable-log-outputs": true + }` + + var options vllm.VllmServerOptions + err := json.Unmarshal([]byte(jsonData), &options) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + + if options.Model != "test-model" { + t.Errorf("Expected model 'test-model', got %q", options.Model) + } + if options.TensorParallelSize != 4 { + t.Errorf("Expected tensor_parallel_size 4, got %d", options.TensorParallelSize) + } + if options.GPUMemoryUtilization != 0.9 { + t.Errorf("Expected gpu_memory_utilization 0.9, got %f", options.GPUMemoryUtilization) + } + if !options.EnableLogOutputs { + t.Errorf("Expected enable_log_outputs true, got %v", options.EnableLogOutputs) + } +} + +func TestNewVllmServerOptions(t *testing.T) { + options := vllm.NewVllmServerOptions() + + if options == nil { + t.Fatal("NewVllmServerOptions returned nil") + } + if options.Host != "127.0.0.1" { + t.Errorf("Expected default host '127.0.0.1', got %q", options.Host) + } + if options.Port != 8000 { + t.Errorf("Expected default port 8000, got %d", options.Port) + } +} + +// Helper functions +func contains(slice []string, item string) bool { + return slices.Contains(slice, item) +} + +func containsFlagWithValue(args []string, flag, value string) bool { + for i, arg := range args { + if arg == flag && i+1 < len(args) && args[i+1] == value { + return true + } + } + return false +} \ No newline at end of file diff --git a/pkg/config/config.go b/pkg/config/config.go index 28087db..504ecc3 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -17,6 +17,9 @@ type BackendConfig struct { // Path to mlx_lm executable (MLX-LM backend) MLXLMExecutable string `yaml:"mlx_lm_executable"` + + // Path to vllm executable (vLLM backend) + VllmExecutable string `yaml:"vllm_executable"` } // AppConfig represents the configuration for llamactl @@ -122,6 +125,7 @@ func LoadConfig(configPath string) (AppConfig, error) { Backends: BackendConfig{ LlamaExecutable: "llama-server", MLXLMExecutable: "mlx_lm.server", + VllmExecutable: "vllm", }, Instances: InstancesConfig{ PortRange: [2]int{8000, 9000}, @@ -246,6 +250,9 @@ func loadEnvVars(cfg *AppConfig) { if mlxLMExec := os.Getenv("LLAMACTL_MLX_LM_EXECUTABLE"); mlxLMExec != "" { cfg.Backends.MLXLMExecutable = mlxLMExec } + if vllmExec := os.Getenv("LLAMACTL_VLLM_EXECUTABLE"); vllmExec != "" { + cfg.Backends.VllmExecutable = vllmExec + } if autoRestart := os.Getenv("LLAMACTL_DEFAULT_AUTO_RESTART"); autoRestart != "" { if b, err := 
strconv.ParseBool(autoRestart); err == nil { cfg.Instances.DefaultAutoRestart = b diff --git a/pkg/instance/lifecycle.go b/pkg/instance/lifecycle.go index 04c5fba..c4e23a7 100644 --- a/pkg/instance/lifecycle.go +++ b/pkg/instance/lifecycle.go @@ -52,6 +52,8 @@ func (i *Process) Start() error { executable = i.globalBackendSettings.LlamaExecutable case backends.BackendTypeMlxLm: executable = i.globalBackendSettings.MLXLMExecutable + case backends.BackendTypeVllm: + executable = i.globalBackendSettings.VllmExecutable default: return fmt.Errorf("unsupported backend type: %s", i.options.BackendType) } diff --git a/pkg/instance/options.go b/pkg/instance/options.go index 2b1437f..2e0b2fd 100644 --- a/pkg/instance/options.go +++ b/pkg/instance/options.go @@ -6,6 +6,7 @@ import ( "llamactl/pkg/backends" "llamactl/pkg/backends/llamacpp" "llamactl/pkg/backends/mlx" + "llamactl/pkg/backends/vllm" "llamactl/pkg/config" "log" ) @@ -26,6 +27,7 @@ type CreateInstanceOptions struct { // Backend-specific options LlamaServerOptions *llamacpp.LlamaServerOptions `json:"-"` MlxServerOptions *mlx.MlxServerOptions `json:"-"` + VllmServerOptions *vllm.VllmServerOptions `json:"-"` } // UnmarshalJSON implements custom JSON unmarshaling for CreateInstanceOptions @@ -63,12 +65,24 @@ func (c *CreateInstanceOptions) UnmarshalJSON(data []byte) error { if err != nil { return fmt.Errorf("failed to marshal backend options: %w", err) } - + c.MlxServerOptions = &mlx.MlxServerOptions{} if err := json.Unmarshal(optionsData, c.MlxServerOptions); err != nil { return fmt.Errorf("failed to unmarshal MLX options: %w", err) } } + case backends.BackendTypeVllm: + if c.BackendOptions != nil { + optionsData, err := json.Marshal(c.BackendOptions) + if err != nil { + return fmt.Errorf("failed to marshal backend options: %w", err) + } + + c.VllmServerOptions = &vllm.VllmServerOptions{} + if err := json.Unmarshal(optionsData, c.VllmServerOptions); err != nil { + return fmt.Errorf("failed to unmarshal vLLM options: %w", err) + } + } default: return fmt.Errorf("unknown backend type: %s", c.BackendType) } @@ -114,6 +128,20 @@ func (c *CreateInstanceOptions) MarshalJSON() ([]byte, error) { return nil, fmt.Errorf("failed to unmarshal to map: %w", err) } + aux.BackendOptions = backendOpts + } + case backends.BackendTypeVllm: + if c.VllmServerOptions != nil { + data, err := json.Marshal(c.VllmServerOptions) + if err != nil { + return nil, fmt.Errorf("failed to marshal vLLM server options: %w", err) + } + + var backendOpts map[string]any + if err := json.Unmarshal(data, &backendOpts); err != nil { + return nil, fmt.Errorf("failed to unmarshal to map: %w", err) + } + aux.BackendOptions = backendOpts } } @@ -171,6 +199,13 @@ func (c *CreateInstanceOptions) BuildCommandArgs() []string { if c.MlxServerOptions != nil { return c.MlxServerOptions.BuildCommandArgs() } + case backends.BackendTypeVllm: + if c.VllmServerOptions != nil { + // Prepend "serve" as first argument + args := []string{"serve"} + args = append(args, c.VllmServerOptions.BuildCommandArgs()...) 
+ return args + } } return []string{} } diff --git a/pkg/server/handlers.go b/pkg/server/handlers.go index c4932b2..0d74851 100644 --- a/pkg/server/handlers.go +++ b/pkg/server/handlers.go @@ -8,6 +8,7 @@ import ( "llamactl/pkg/backends" "llamactl/pkg/backends/llamacpp" "llamactl/pkg/backends/mlx" + "llamactl/pkg/backends/vllm" "llamactl/pkg/config" "llamactl/pkg/instance" "llamactl/pkg/manager" @@ -732,7 +733,60 @@ func (h *Handler) ParseMlxCommand() http.HandlerFunc { BackendType: backendType, MlxServerOptions: mlxOptions, } - + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(options); err != nil { + writeError(w, http.StatusInternalServerError, "encode_error", err.Error()) + } + } +} + +// ParseVllmCommand godoc +// @Summary Parse vllm serve command +// @Description Parses a vLLM serve command string into instance options +// @Tags backends +// @Security ApiKeyAuth +// @Accept json +// @Produce json +// @Param request body ParseCommandRequest true "Command to parse" +// @Success 200 {object} instance.CreateInstanceOptions "Parsed options" +// @Failure 400 {object} map[string]string "Invalid request or command" +// @Router /backends/vllm/parse-command [post] +func (h *Handler) ParseVllmCommand() http.HandlerFunc { + type errorResponse struct { + Error string `json:"error"` + Details string `json:"details,omitempty"` + } + writeError := func(w http.ResponseWriter, status int, code, details string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(errorResponse{Error: code, Details: details}) + } + return func(w http.ResponseWriter, r *http.Request) { + var req ParseCommandRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "Invalid JSON body") + return + } + + if strings.TrimSpace(req.Command) == "" { + writeError(w, http.StatusBadRequest, "invalid_command", "Command cannot be empty") + return + } + + vllmOptions, err := vllm.ParseVllmCommand(req.Command) + if err != nil { + writeError(w, http.StatusBadRequest, "parse_error", err.Error()) + return + } + + backendType := backends.BackendTypeVllm + + options := &instance.CreateInstanceOptions{ + BackendType: backendType, + VllmServerOptions: vllmOptions, + } + w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(options); err != nil { writeError(w, http.StatusInternalServerError, "encode_error", err.Error()) diff --git a/pkg/server/routes.go b/pkg/server/routes.go index aa31e1f..898b574 100644 --- a/pkg/server/routes.go +++ b/pkg/server/routes.go @@ -58,6 +58,9 @@ func SetupRouter(handler *Handler) *chi.Mux { r.Route("/mlx", func(r chi.Router) { r.Post("/parse-command", handler.ParseMlxCommand()) }) + r.Route("/vllm", func(r chi.Router) { + r.Post("/parse-command", handler.ParseVllmCommand()) + }) }) // Instance management endpoints diff --git a/pkg/validation/validation.go b/pkg/validation/validation.go index eff1dd3..638e5d2 100644 --- a/pkg/validation/validation.go +++ b/pkg/validation/validation.go @@ -46,6 +46,8 @@ func ValidateInstanceOptions(options *instance.CreateInstanceOptions) error { return validateLlamaCppOptions(options) case backends.BackendTypeMlxLm: return validateMlxOptions(options) + case backends.BackendTypeVllm: + return validateVllmOptions(options) default: return ValidationError(fmt.Errorf("unsupported backend type: %s", options.BackendType)) } @@ -88,6 +90,25 @@ func validateMlxOptions(options 
*instance.CreateInstanceOptions) error { return nil } +// validateVllmOptions validates vLLM backend specific options +func validateVllmOptions(options *instance.CreateInstanceOptions) error { + if options.VllmServerOptions == nil { + return ValidationError(fmt.Errorf("vLLM server options cannot be nil for vLLM backend")) + } + + // Use reflection to check all string fields for injection patterns + if err := validateStructStrings(options.VllmServerOptions, ""); err != nil { + return err + } + + // Basic network validation for port + if options.VllmServerOptions.Port < 0 || options.VllmServerOptions.Port > 65535 { + return ValidationError(fmt.Errorf("invalid port range: %d", options.VllmServerOptions.Port)) + } + + return nil +} + // validateStructStrings recursively validates all string fields in a struct func validateStructStrings(v any, fieldPath string) error { val := reflect.ValueOf(v) diff --git a/vllm_backend_spec.md b/vllm_backend_spec.md new file mode 100644 index 0000000..9fede10 --- /dev/null +++ b/vllm_backend_spec.md @@ -0,0 +1,440 @@ +# vLLM Backend Implementation Specification + +## Overview +This specification outlines the implementation of vLLM backend support for llamactl, following the existing patterns established by the llama.cpp and MLX backends. + +## 1. Backend Configuration + +### Basic Details +- **Backend Type**: `vllm` +- **Executable**: `vllm` (configured via `VllmExecutable`) +- **Subcommand**: `serve` (automatically prepended to arguments) +- **Default Host/Port**: Auto-assigned by llamactl +- **Health Check**: Uses `/health` endpoint (returns HTTP 200 with no content) +- **API Compatibility**: OpenAI-compatible endpoints + +### Example Command +```bash +vllm serve --enable-log-outputs --tensor-parallel-size 2 --gpu-memory-utilization 0.5 --model ISTA-DASLab/gemma-3-27b-it-GPTQ-4b-128g +``` + +## 2. File Structure +Following the existing backend pattern: +``` +pkg/backends/vllm/ +├── vllm.go # VllmServerOptions struct and methods +├── vllm_test.go # Unit tests for VllmServerOptions +├── parser.go # Command parsing logic +└── parser_test.go # Parser tests +``` + +## 3. Core Implementation Files + +### 3.1 `pkg/backends/vllm/vllm.go` + +#### VllmServerOptions Struct +```go +type VllmServerOptions struct { + // Basic connection options (auto-assigned by llamactl) + Host string `json:"host,omitempty"` + Port int `json:"port,omitempty"` + + // Core model options + Model string `json:"model,omitempty"` + + // Common serving options + EnableLogOutputs bool `json:"enable_log_outputs,omitempty"` + TensorParallelSize int `json:"tensor_parallel_size,omitempty"` + GPUMemoryUtilization float64 `json:"gpu_memory_utilization,omitempty"` + + // Additional parameters to be added based on vLLM CLI documentation + // Following the same comprehensive approach as llamacpp.LlamaServerOptions +} +``` + +#### Required Methods +- `UnmarshalJSON()` - Custom unmarshaling with alternative field name support (dash-to-underscore conversion) +- `BuildCommandArgs()` - Convert struct to command line arguments (excluding "serve" subcommand) +- `NewVllmServerOptions()` - Constructor with vLLM defaults + +#### Field Name Mapping +Support both CLI argument names (with dashes) and programmatic names (with underscores), similar to the llama.cpp implementation: +```go +fieldMappings := map[string]string{ + "enable-log-outputs": "enable_log_outputs", + "tensor-parallel-size": "tensor_parallel_size", + "gpu-memory-utilization": "gpu_memory_utilization", + // ... 
other mappings
+}
+```
+
+### 3.2 `pkg/backends/vllm/parser.go`
+
+#### ParseVllmCommand Function
+Following the same pattern as `llamacpp/parser.go` and `mlx/parser.go`:
+
+```go
+func ParseVllmCommand(command string) (*VllmServerOptions, error)
+```
+
+**Supported Input Formats:**
+1. `vllm serve --model MODEL_NAME --other-args`
+2. `/path/to/vllm serve --model MODEL_NAME`
+3. `serve --model MODEL_NAME --other-args`
+4. `--model MODEL_NAME --other-args` (args only)
+5. Multiline commands with backslashes
+
+**Implementation Details:**
+- Handle "serve" subcommand detection and removal
+- Support quoted strings and escaped characters
+- Validate command structure
+- Convert parsed arguments to `VllmServerOptions`
+
+## 4. Backend Integration
+
+### 4.1 Backend Type Definition
+**File**: `pkg/backends/backend.go`
+```go
+const (
+    BackendTypeLlamaCpp BackendType = "llama_cpp"
+    BackendTypeMlxLm    BackendType = "mlx_lm"
+    BackendTypeVllm     BackendType = "vllm" // ADD THIS
+)
+```
+
+### 4.2 Configuration Integration
+**File**: `pkg/config/config.go`
+
+#### BackendConfig Update
+```go
+type BackendConfig struct {
+    LlamaExecutable string `yaml:"llama_executable"`
+    MLXLMExecutable string `yaml:"mlx_lm_executable"`
+    VllmExecutable  string `yaml:"vllm_executable"` // ADD THIS
+}
+```
+
+#### Default Configuration
+- **Default Value**: `"vllm"`
+- **Environment Variable**: `LLAMACTL_VLLM_EXECUTABLE`
+
+#### Environment Variable Loading
+Add to `loadEnvVars()` function:
+```go
+if vllmExec := os.Getenv("LLAMACTL_VLLM_EXECUTABLE"); vllmExec != "" {
+    cfg.Backends.VllmExecutable = vllmExec
+}
+```
+
+### 4.3 Instance Options Integration
+**File**: `pkg/instance/options.go`
+
+#### CreateInstanceOptions Update
+```go
+type CreateInstanceOptions struct {
+    // existing fields...
+    VllmServerOptions *vllm.VllmServerOptions `json:"-"`
+}
+```
+
+#### JSON Marshaling/Unmarshaling
+Update `UnmarshalJSON()` and `MarshalJSON()` methods to handle vLLM backend similar to existing backends.
+
+#### BuildCommandArgs Implementation
+```go
+case backends.BackendTypeVllm:
+    if c.VllmServerOptions != nil {
+        // Prepend "serve" as first argument
+        args := []string{"serve"}
+        args = append(args, c.VllmServerOptions.BuildCommandArgs()...)
+        return args
+    }
+```
+
+**Key Point**: The "serve" subcommand is handled at the instance options level, keeping the `VllmServerOptions.BuildCommandArgs()` method focused only on vLLM-specific parameters.
+
+## 5. Health Check Integration
+
+### 5.1 Standard Health Check for vLLM
+**File**: `pkg/instance/lifecycle.go`
+
+vLLM provides a standard `/health` endpoint that returns HTTP 200 with no content, so no modifications are needed to the existing health check logic. The current `WaitForHealthy()` method will work as-is:
+
+```go
+healthURL := fmt.Sprintf("http://%s:%d/health", host, port)
+```
+
+### 5.2 Startup Time Considerations
+- vLLM typically has longer startup times compared to llama.cpp
+- The existing configurable timeout system should handle this adequately
+- Users may need to adjust `on_demand_start_timeout` for larger models
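+
+To make the wait concrete: readiness is just a poll of `/health` until it returns 200 or the timeout elapses. The sketch below illustrates the shape of that loop; the function name, signature, and one-second poll interval are assumptions for illustration, not llamactl's actual `WaitForHealthy()` implementation.
+
+```go
+import (
+    "context"
+    "fmt"
+    "net/http"
+    "time"
+)
+
+// waitForHealthy is a minimal sketch of a readiness poll against vLLM's
+// /health endpoint, which returns HTTP 200 with an empty body once the
+// model is loaded. All names here are illustrative.
+func waitForHealthy(ctx context.Context, host string, port int, timeout time.Duration) error {
+    healthURL := fmt.Sprintf("http://%s:%d/health", host, port)
+    deadline := time.Now().Add(timeout)
+    for time.Now().Before(deadline) {
+        req, err := http.NewRequestWithContext(ctx, http.MethodGet, healthURL, nil)
+        if err != nil {
+            return err
+        }
+        if resp, err := http.DefaultClient.Do(req); err == nil {
+            resp.Body.Close()
+            if resp.StatusCode == http.StatusOK {
+                return nil // server is ready to accept requests
+            }
+        }
+        select {
+        case <-ctx.Done():
+            return ctx.Err()
+        case <-time.After(time.Second): // poll once per second
+        }
+    }
+    return fmt.Errorf("instance at %s not healthy after %s", healthURL, timeout)
+}
+```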
+
+## 6. Lifecycle Integration
+
+### 6.1 Executable Selection
+**File**: `pkg/instance/lifecycle.go`
+
+Update the `Start()` method to handle vLLM executable:
+
+```go
+switch i.options.BackendType {
+case backends.BackendTypeLlamaCpp:
+    executable = i.globalBackendSettings.LlamaExecutable
+case backends.BackendTypeMlxLm:
+    executable = i.globalBackendSettings.MLXLMExecutable
+case backends.BackendTypeVllm: // ADD THIS
+    executable = i.globalBackendSettings.VllmExecutable
+default:
+    return fmt.Errorf("unsupported backend type: %s", i.options.BackendType)
+}
+
+args := i.options.BuildCommandArgs()
+i.cmd = exec.CommandContext(i.ctx, executable, args...)
+```
+
+### 6.2 Command Execution
+The final executed command will be:
+```bash
+vllm serve --model MODEL_NAME --other-vllm-args
+```
+
+Where:
+- `vllm` comes from `VllmExecutable` configuration
+- `serve` is prepended by `BuildCommandArgs()`
+- Remaining args come from `VllmServerOptions.BuildCommandArgs()`
+
+## 7. Server Handler Integration
+
+### 7.1 New Handler Method
+**File**: `pkg/server/handlers.go`
+
+```go
+// ParseVllmCommand godoc
+// @Summary Parse vllm serve command
+// @Description Parses a vLLM serve command string into instance options
+// @Tags backends
+// @Security ApiKeyAuth
+// @Accept json
+// @Produce json
+// @Param request body ParseCommandRequest true "Command to parse"
+// @Success 200 {object} instance.CreateInstanceOptions "Parsed options"
+// @Failure 400 {object} map[string]string "Invalid request or command"
+// @Router /backends/vllm/parse-command [post]
+func (h *Handler) ParseVllmCommand() http.HandlerFunc {
+    // Implementation similar to ParseMlxCommand()
+    // Uses vllm.ParseVllmCommand() internally
+}
+```
+
+### 7.2 Router Integration
+**File**: `pkg/server/routes.go`
+
+Add vLLM route:
+```go
+r.Route("/backends", func(r chi.Router) {
+    r.Route("/llama-cpp", func(r chi.Router) {
+        r.Post("/parse-command", handler.ParseLlamaCommand())
+    })
+    r.Route("/mlx", func(r chi.Router) {
+        r.Post("/parse-command", handler.ParseMlxCommand())
+    })
+    r.Route("/vllm", func(r chi.Router) { // ADD THIS
+        r.Post("/parse-command", handler.ParseVllmCommand())
+    })
+})
+```
+
+## 8. Validation Integration
+
+### 8.1 Instance Options Validation
+**File**: `pkg/validation/validation.go`
+
+Add vLLM validation case:
+```go
+func ValidateInstanceOptions(options *instance.CreateInstanceOptions) error {
+    // existing validation...
+
+    switch options.BackendType {
+    case backends.BackendTypeLlamaCpp:
+        return validateLlamaCppOptions(options)
+    case backends.BackendTypeMlxLm:
+        return validateMlxOptions(options)
+    case backends.BackendTypeVllm: // ADD THIS
+        return validateVllmOptions(options)
+    default:
+        return ValidationError(fmt.Errorf("unsupported backend type: %s", options.BackendType))
+    }
+}
+
+func validateVllmOptions(options *instance.CreateInstanceOptions) error {
+    if options.VllmServerOptions == nil {
+        return ValidationError(fmt.Errorf("vLLM server options cannot be nil for vLLM backend"))
+    }
+
+    // Basic validation following the same pattern as other backends
+    if err := validateStructStrings(options.VllmServerOptions, ""); err != nil {
+        return err
+    }
+
+    // Port validation
+    if options.VllmServerOptions.Port < 0 || options.VllmServerOptions.Port > 65535 {
+        return ValidationError(fmt.Errorf("invalid port range: %d", options.VllmServerOptions.Port))
+    }
+
+    return nil
+}
+```
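+
+To see what 8.1 buys us, a hypothetical test is sketched below. The exact injection patterns `validateStructStrings` rejects are not spelled out in this spec, so the assumption that a shell-metacharacter payload fails validation is exactly that — an assumption to verify against the real implementation.
+
+```go
+func TestValidateVllmOptions_StringFields(t *testing.T) {
+    opts := &instance.CreateInstanceOptions{
+        BackendType: backends.BackendTypeVllm,
+        VllmServerOptions: &vllm.VllmServerOptions{
+            // Assumed to trip the injection check; adjust to the real patterns.
+            Model: "test-model; rm -rf /",
+        },
+    }
+    if err := validation.ValidateInstanceOptions(opts); err == nil {
+        t.Error("expected validation error for shell-metacharacter model string")
+    }
+
+    // A plain HuggingFace-style model ID should pass.
+    opts.VllmServerOptions.Model = "microsoft/DialoGPT-medium"
+    if err := validation.ValidateInstanceOptions(opts); err != nil {
+        t.Errorf("expected clean options to validate, got: %v", err)
+    }
+}
+```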
+
+## 9. Testing Strategy
+
+### 9.1 Unit Tests
+- **`vllm_test.go`**: Test `VllmServerOptions` marshaling/unmarshaling, BuildCommandArgs()
+- **`parser_test.go`**: Test command parsing for various formats
+- **Integration tests**: Mock vLLM commands and validate parsing
+
+### 9.2 Test Cases
+```go
+func TestBuildCommandArgs_VllmBasic(t *testing.T) {
+    options := VllmServerOptions{
+        Model:              "microsoft/DialoGPT-medium",
+        Port:               8080,
+        Host:               "localhost",
+        EnableLogOutputs:   true,
+        TensorParallelSize: 2,
+    }
+
+    args := options.BuildCommandArgs()
+    // Validate expected arguments (excluding "serve")
+}
+
+func TestParseVllmCommand_FullCommand(t *testing.T) {
+    command := "vllm serve --model ISTA-DASLab/gemma-3-27b-it-GPTQ-4b-128g --tensor-parallel-size 2"
+    result, err := ParseVllmCommand(command)
+    // Validate parsing results
+}
+```
+
+## 10. Example Usage
+
+### 10.1 Parse Existing vLLM Command
+```bash
+curl -X POST http://localhost:8080/api/v1/backends/vllm/parse-command \
+  -H "Authorization: Bearer your-management-key" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "command": "vllm serve --model ISTA-DASLab/gemma-3-27b-it-GPTQ-4b-128g --tensor-parallel-size 2 --gpu-memory-utilization 0.5"
+  }'
+```
+
+### 10.2 Create vLLM Instance
+```bash
+curl -X POST http://localhost:8080/api/v1/instances/my-vllm-model \
+  -H "Authorization: Bearer your-management-key" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "backend_type": "vllm",
+    "backend_options": {
+      "model": "ISTA-DASLab/gemma-3-27b-it-GPTQ-4b-128g",
+      "tensor_parallel_size": 2,
+      "gpu_memory_utilization": 0.5,
+      "enable_log_outputs": true
+    }
+  }'
+```
+
+### 10.3 Use via OpenAI-Compatible API
+```bash
+curl -X POST http://localhost:8080/v1/chat/completions \
+  -H "Authorization: Bearer your-inference-key" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "my-vllm-model",
+    "messages": [{"role": "user", "content": "Hello!"}]
+  }'
+```
+
+## 11. Implementation Checklist
+
+### Phase 1: Core Backend
+- [ ] Create `pkg/backends/vllm/vllm.go`
+- [ ] Implement `VllmServerOptions` struct with basic fields
+- [ ] Implement `BuildCommandArgs()`, `UnmarshalJSON()`, `MarshalJSON()`
+- [ ] Add comprehensive field mappings for CLI args
+- [ ] Create unit tests for `VllmServerOptions`
+
+### Phase 2: Command Parsing
+- [ ] Create `pkg/backends/vllm/parser.go`
+- [ ] Implement `ParseVllmCommand()` function
+- [ ] Handle various command input formats
+- [ ] Create comprehensive parser tests
+- [ ] Test edge cases and error conditions
+
+### Phase 3: Integration
+- [ ] Add `BackendTypeVllm` to `pkg/backends/backend.go`
+- [ ] Update `BackendConfig` in `pkg/config/config.go`
+- [ ] Add environment variable support
+- [ ] Update `CreateInstanceOptions` in `pkg/instance/options.go`
+- [ ] Implement `BuildCommandArgs()` with "serve" prepending
+
+### Phase 4: Lifecycle & Health Checks
+- [ ] Update executable selection in `pkg/instance/lifecycle.go`
+- [ ] Test instance startup and health checking (uses existing `/health` endpoint)
+- [ ] Validate command execution flow
+
+### Phase 5: API Integration
+- [ ] Add `ParseVllmCommand()` handler in `pkg/server/handlers.go`
+- [ ] Add vLLM route in `pkg/server/routes.go`
+- [ ] Update validation in `pkg/validation/validation.go`
+- [ ] Test API endpoints
+
+### Phase 6: Testing & Documentation
+- [ ] Create comprehensive integration tests
+- [ ] Test with actual vLLM installation (if available)
+- [ ] Update documentation
+- [ ] Test OpenAI-compatible proxy functionality
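+
+Before the configuration examples, one worked example ties 4.3 and 6.2 together: the sketch below shows how the `backend_options` from 10.2 should assemble into the final command line. The printed flag order is an assumption (it follows struct field order), so treat the expected output as illustrative; `fmt` and `strings` imports are assumed.
+
+```go
+opts := vllm.VllmServerOptions{
+    Model:                "ISTA-DASLab/gemma-3-27b-it-GPTQ-4b-128g",
+    TensorParallelSize:   2,
+    GPUMemoryUtilization: 0.5,
+    EnableLogOutputs:     true,
+}
+// "serve" is prepended at the instance level, not by BuildCommandArgs().
+args := append([]string{"serve"}, opts.BuildCommandArgs()...)
+fmt.Println("vllm " + strings.Join(args, " "))
+// Roughly: vllm serve --model ISTA-DASLab/gemma-3-27b-it-GPTQ-4b-128g \
+//   --gpu-memory-utilization 0.5 --tensor-parallel-size 2 --enable-log-outputs
+```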
+
+## 12. Configuration Examples
+
+### 12.1 YAML Configuration
+```yaml
+backends:
+  llama_executable: "llama-server"
+  mlx_lm_executable: "mlx_lm.server"
+  vllm_executable: "vllm"
+
+instances:
+  # ... other instance settings
+```
+
+### 12.2 Environment Variables
+```bash
+export LLAMACTL_VLLM_EXECUTABLE="vllm"
+# OR for custom installation
+export LLAMACTL_VLLM_EXECUTABLE="python -m vllm"
+# OR for containerized deployment
+export LLAMACTL_VLLM_EXECUTABLE="docker run --rm --gpus all vllm/vllm-openai"
+```
+
+## 13. Notes and Considerations
+
+### 13.1 Startup Time
+- vLLM instances may take significantly longer to start than llama.cpp
+- Consider documenting recommended timeout values
+- The configurable `on_demand_start_timeout` should accommodate this
+
+### 13.2 Resource Usage
+- vLLM typically requires substantial GPU memory
+- No special handling needed in llamactl (follows existing pattern)
+- Resource management is left to the user/administrator
+
+### 13.3 Model Compatibility
+- Primarily designed for HuggingFace models
+- Supports various quantization formats (GPTQ, AWQ, etc.)
+- Model path validation can be basic (similar to other backends)
+
+### 13.4 Future Enhancements
+- Consider adding vLLM-specific parameter validation
+- Could add model download/caching features
+- May want to add vLLM version detection capabilities
+
+This specification provides a comprehensive roadmap for implementing vLLM backend support while maintaining consistency with the existing llamactl architecture.
\ No newline at end of file

From c7136d520617d62c0f534220dc2762e78d80d244 Mon Sep 17 00:00:00 2001
From: LordMathis
Date: Fri, 19 Sep 2025 18:36:23 +0200
Subject: [PATCH 02/20] Refactor command parsing logic across backends to
 utilize a unified CommandParserConfig structure

---
 pkg/backends/llamacpp/parser.go | 281 ++---------------------------
 pkg/backends/mlx/parser.go      | 229 +----------------------
 pkg/backends/parser.go          | 310 ++++++++++++++++++++++++++++++
 pkg/backends/vllm/parser.go     | 296 ++----------------------------
 4 files changed, 346 insertions(+), 770 deletions(-)
 create mode 100644 pkg/backends/parser.go

diff --git a/pkg/backends/llamacpp/parser.go b/pkg/backends/llamacpp/parser.go
index d94b0ed..b5ce1f9 100644
--- a/pkg/backends/llamacpp/parser.go
+++ b/pkg/backends/llamacpp/parser.go
@@ -1,13 +1,7 @@
 package llamacpp
 
 import (
-	"encoding/json"
-	"errors"
-	"fmt"
-	"path/filepath"
-	"regexp"
-	"strconv"
-	"strings"
+	"llamactl/pkg/backends"
 )
 
 // ParseLlamaCommand parses a llama-server command string into LlamaServerOptions
 // Supports multiple formats:
 // 1. Full command: "llama-server --model file.gguf"
 // 2. Full path: "/usr/local/bin/llama-server --model file.gguf"
 // 3. Args only: "--model file.gguf --gpu-layers 32"
 // 4. Multiline commands with backslashes
 func ParseLlamaCommand(command string) (*LlamaServerOptions, error) {
-	// 1. Normalize the command - handle multiline with backslashes
-	trimmed := normalizeMultilineCommand(command)
-	if trimmed == "" {
-		return nil, fmt.Errorf("command cannot be empty")
-	}
-
-	// 2. Extract arguments from command
-	args, err := extractArgumentsFromCommand(trimmed)
-	if err != nil {
-		return nil, err
-	}
-
-	// 3.
diff --git a/pkg/backends/llamacpp/parser.go b/pkg/backends/llamacpp/parser.go
index d94b0ed..b5ce1f9 100644
--- a/pkg/backends/llamacpp/parser.go
+++ b/pkg/backends/llamacpp/parser.go
@@ -1,13 +1,7 @@
 package llamacpp
 
 import (
-	"encoding/json"
-	"errors"
-	"fmt"
-	"path/filepath"
-	"regexp"
-	"strconv"
-	"strings"
+	"llamactl/pkg/backends"
 )
 
 // ParseLlamaCommand parses a llama-server command string into LlamaServerOptions
@@ -17,270 +11,25 @@ import (
 // 3. Args only: "--model file.gguf --gpu-layers 32"
 // 4. Multiline commands with backslashes
 func ParseLlamaCommand(command string) (*LlamaServerOptions, error) {
-	// 1. Normalize the command - handle multiline with backslashes
-	trimmed := normalizeMultilineCommand(command)
-	if trimmed == "" {
-		return nil, fmt.Errorf("command cannot be empty")
-	}
-
-	// 2. Extract arguments from command
-	args, err := extractArgumentsFromCommand(trimmed)
-	if err != nil {
-		return nil, err
-	}
-
-	// 3. Parse arguments into map
-	options := make(map[string]any)
-
-	// Known multi-valued flags (snake_case form)
-	multiValued := map[string]struct{}{
-		"override_tensor":       {},
-		"override_kv":           {},
-		"lora":                  {},
-		"lora_scaled":           {},
-		"control_vector":        {},
-		"control_vector_scaled": {},
-		"dry_sequence_breaker":  {},
-		"logit_bias":            {},
-	}
-
-	i := 0
-	for i < len(args) {
-		arg := args[i]
-
-		if !strings.HasPrefix(arg, "-") { // skip positional / stray values
-			i++
-			continue
-		}
-
-		// Reject malformed flags with more than two leading dashes (e.g. ---model) to surface user mistakes
-		if strings.HasPrefix(arg, "---") {
-			return nil, fmt.Errorf("malformed flag: %s", arg)
-		}
-
-		// Unified parsing for --flag=value vs --flag value
-		var rawFlag, rawValue string
-		hasEquals := false
-		if strings.Contains(arg, "=") {
-			parts := strings.SplitN(arg, "=", 2)
-			rawFlag = parts[0]
-			rawValue = parts[1] // may be empty string
-			hasEquals = true
-		} else {
-			rawFlag = arg
-		}
-
-		flagCore := strings.TrimPrefix(strings.TrimPrefix(rawFlag, "-"), "-")
-		flagName := strings.ReplaceAll(flagCore, "-", "_")
-
-		// Detect value if not in equals form
-		valueProvided := hasEquals
-		if !hasEquals {
-			if i+1 < len(args) && !isFlag(args[i+1]) { // next token is value
-				rawValue = args[i+1]
-				valueProvided = true
-			}
-		}
-
-		// Determine if multi-valued flag
-		_, isMulti := multiValued[flagName]
-
-		// Normalization helper: ensure slice for multi-valued flags
-		appendValue := func(valStr string) {
-			if existing, ok := options[flagName]; ok {
-				// Existing value; ensure slice semantics for multi-valued flags or repeated occurrences
-				if slice, ok := existing.([]string); ok {
-					options[flagName] = append(slice, valStr)
-					return
-				}
-				// Convert scalar to slice
-				options[flagName] = []string{fmt.Sprintf("%v", existing), valStr}
-				return
-			}
-			// First value
-			if isMulti {
-				options[flagName] = []string{valStr}
-			} else {
-				// We'll parse type below for single-valued flags
-				options[flagName] = valStr
-			}
-		}
-
-		if valueProvided {
-			// Use raw token for multi-valued flags; else allow typed parsing
-			appendValue(rawValue)
-			if !isMulti { // convert to typed value if scalar
-				if strVal, ok := options[flagName].(string); ok { // still scalar
-					options[flagName] = parseValue(strVal)
-				}
-			}
-			// Advance index: if we consumed a following token as value (non equals form), skip it
-			if !hasEquals && i+1 < len(args) && rawValue == args[i+1] {
-				i += 2
-			} else {
-				i++
-			}
-			continue
-		}
-
-		// Boolean flag (no value)
-		options[flagName] = true
-		i++
-	}
-
-	// 4. Convert to LlamaServerOptions using existing UnmarshalJSON
-	jsonData, err := json.Marshal(options)
-	if err != nil {
-		return nil, fmt.Errorf("failed to marshal parsed options: %w", err)
+	config := backends.CommandParserConfig{
+		ExecutableNames: []string{"llama-server"},
+		MultiValuedFlags: map[string]struct{}{
+			"override_tensor":       {},
+			"override_kv":           {},
+			"lora":                  {},
+			"lora_scaled":           {},
+			"control_vector":        {},
+			"control_vector_scaled": {},
+			"dry_sequence_breaker":  {},
+			"logit_bias":            {},
+		},
 	}
 
 	var llamaOptions LlamaServerOptions
-	if err := json.Unmarshal(jsonData, &llamaOptions); err != nil {
-		return nil, fmt.Errorf("failed to parse command options: %w", err)
-	}
-
-	// 5.
Return LlamaServerOptions - return &llamaOptions, nil -} - -// parseValue attempts to parse a string value into the most appropriate type -func parseValue(value string) any { - // Surrounding matching quotes (single or double) - if l := len(value); l >= 2 { - if (value[0] == '"' && value[l-1] == '"') || (value[0] == '\'' && value[l-1] == '\'') { - value = value[1 : l-1] - } - } - - lower := strings.ToLower(value) - if lower == "true" { - return true - } - if lower == "false" { - return false - } - - if intVal, err := strconv.Atoi(value); err == nil { - return intVal - } - if floatVal, err := strconv.ParseFloat(value, 64); err == nil { - return floatVal - } - return value -} - -// normalizeMultilineCommand handles multiline commands with backslashes -func normalizeMultilineCommand(command string) string { - // Handle escaped newlines (backslash followed by newline) - re := regexp.MustCompile(`\\\s*\n\s*`) - normalized := re.ReplaceAllString(command, " ") - - // Clean up extra whitespace - re = regexp.MustCompile(`\s+`) - normalized = re.ReplaceAllString(normalized, " ") - - return strings.TrimSpace(normalized) -} - -// extractArgumentsFromCommand extracts arguments from various command formats -func extractArgumentsFromCommand(command string) ([]string, error) { - // Split command into tokens respecting quotes - tokens, err := splitCommandTokens(command) - if err != nil { + if err := backends.ParseCommand(command, config, &llamaOptions); err != nil { return nil, err } - if len(tokens) == 0 { - return nil, fmt.Errorf("no command tokens found") - } - - // Check if first token looks like an executable - firstToken := tokens[0] - - // Case 1: Full path to executable (contains path separator or ends with llama-server) - if strings.Contains(firstToken, string(filepath.Separator)) || - strings.HasSuffix(filepath.Base(firstToken), "llama-server") { - return tokens[1:], nil // Return everything except the executable - } - - // Case 2: Just "llama-server" command - if strings.ToLower(firstToken) == "llama-server" { - return tokens[1:], nil // Return everything except the command - } - - // Case 3: Arguments only (starts with a flag) - if strings.HasPrefix(firstToken, "-") { - return tokens, nil // Return all tokens as arguments - } - - // Case 4: Unknown format - might be a different executable name - // Be permissive and assume it's the executable - return tokens[1:], nil + return &llamaOptions, nil } -// splitCommandTokens splits a command string into tokens, respecting quotes -func splitCommandTokens(command string) ([]string, error) { - var tokens []string - var current strings.Builder - inQuotes := false - quoteChar := byte(0) - escaped := false - - for i := 0; i < len(command); i++ { - c := command[i] - - if escaped { - current.WriteByte(c) - escaped = false - continue - } - - if c == '\\' { - escaped = true - current.WriteByte(c) - continue - } - - if !inQuotes && (c == '"' || c == '\'') { - inQuotes = true - quoteChar = c - current.WriteByte(c) - } else if inQuotes && c == quoteChar { - inQuotes = false - quoteChar = 0 - current.WriteByte(c) - } else if !inQuotes && (c == ' ' || c == '\t') { - if current.Len() > 0 { - tokens = append(tokens, current.String()) - current.Reset() - } - } else { - current.WriteByte(c) - } - } - - if inQuotes { - return nil, errors.New("unterminated quoted string") - } - - if current.Len() > 0 { - tokens = append(tokens, current.String()) - } - - return tokens, nil -} - -// isFlag determines if a string is a command line flag or a value -// Handles the special 
case where negative numbers should be treated as values, not flags -func isFlag(arg string) bool { - if !strings.HasPrefix(arg, "-") { - return false - } - - // Special case: if it's a negative number, treat it as a value - if _, err := strconv.ParseFloat(arg, 64); err == nil { - return false - } - - return true -} diff --git a/pkg/backends/mlx/parser.go b/pkg/backends/mlx/parser.go index 96b04a9..01fad0e 100644 --- a/pkg/backends/mlx/parser.go +++ b/pkg/backends/mlx/parser.go @@ -1,12 +1,7 @@ package mlx import ( - "encoding/json" - "fmt" - "path/filepath" - "regexp" - "strconv" - "strings" + "llamactl/pkg/backends" ) // ParseMlxCommand parses a mlx_lm.server command string into MlxServerOptions @@ -16,97 +11,16 @@ import ( // 3. Args only: "--model model/path --host 0.0.0.0" // 4. Multiline commands with backslashes func ParseMlxCommand(command string) (*MlxServerOptions, error) { - // 1. Normalize the command - handle multiline with backslashes - trimmed := normalizeMultilineCommand(command) - if trimmed == "" { - return nil, fmt.Errorf("command cannot be empty") - } - - // 2. Extract arguments from command - args, err := extractArgumentsFromCommand(trimmed) - if err != nil { - return nil, err - } - - // 3. Parse arguments into map - options := make(map[string]any) - - i := 0 - for i < len(args) { - arg := args[i] - - if !strings.HasPrefix(arg, "-") { // skip positional / stray values - i++ - continue - } - - // Reject malformed flags with more than two leading dashes (e.g. ---model) to surface user mistakes - if strings.HasPrefix(arg, "---") { - return nil, fmt.Errorf("malformed flag: %s", arg) - } - - // Unified parsing for --flag=value vs --flag value - var rawFlag, rawValue string - hasEquals := false - if strings.Contains(arg, "=") { - parts := strings.SplitN(arg, "=", 2) - rawFlag = parts[0] - rawValue = parts[1] // may be empty string - hasEquals = true - } else { - rawFlag = arg - } - - flagCore := strings.TrimPrefix(strings.TrimPrefix(rawFlag, "-"), "-") - flagName := strings.ReplaceAll(flagCore, "-", "_") - - // Detect value if not in equals form - valueProvided := hasEquals - if !hasEquals { - if i+1 < len(args) && !isFlag(args[i+1]) { // next token is value - rawValue = args[i+1] - valueProvided = true - } - } - - if valueProvided { - // MLX-specific validation for certain flags - if flagName == "log_level" && !isValidLogLevel(rawValue) { - return nil, fmt.Errorf("invalid log level: %s", rawValue) - } - - options[flagName] = parseValue(rawValue) - - // Advance index: if we consumed a following token as value (non equals form), skip it - if !hasEquals && i+1 < len(args) && rawValue == args[i+1] { - i += 2 - } else { - i++ - } - continue - } - - // Boolean flag (no value) - MLX specific boolean flags - if flagName == "trust_remote_code" || flagName == "use_default_chat_template" { - options[flagName] = true - } else { - options[flagName] = true - } - i++ - } - - // 4. 
Convert to MlxServerOptions using existing UnmarshalJSON - jsonData, err := json.Marshal(options) - if err != nil { - return nil, fmt.Errorf("failed to marshal parsed options: %w", err) + config := backends.CommandParserConfig{ + ExecutableNames: []string{"mlx_lm.server"}, + MultiValuedFlags: map[string]struct{}{}, // MLX has no multi-valued flags } var mlxOptions MlxServerOptions - if err := json.Unmarshal(jsonData, &mlxOptions); err != nil { - return nil, fmt.Errorf("failed to parse command options: %w", err) + if err := backends.ParseCommand(command, config, &mlxOptions); err != nil { + return nil, err } - // 5. Return MlxServerOptions return &mlxOptions, nil } @@ -121,134 +35,3 @@ func isValidLogLevel(level string) bool { return false } -// parseValue attempts to parse a string value into the most appropriate type -func parseValue(value string) any { - // Surrounding matching quotes (single or double) - if l := len(value); l >= 2 { - if (value[0] == '"' && value[l-1] == '"') || (value[0] == '\'' && value[l-1] == '\'') { - value = value[1 : l-1] - } - } - - lower := strings.ToLower(value) - if lower == "true" { - return true - } - if lower == "false" { - return false - } - - if intVal, err := strconv.Atoi(value); err == nil { - return intVal - } - if floatVal, err := strconv.ParseFloat(value, 64); err == nil { - return floatVal - } - return value -} - -// normalizeMultilineCommand handles multiline commands with backslashes -func normalizeMultilineCommand(command string) string { - // Handle escaped newlines (backslash followed by newline) - re := regexp.MustCompile(`\\\s*\n\s*`) - normalized := re.ReplaceAllString(command, " ") - - // Clean up extra whitespace - re = regexp.MustCompile(`\s+`) - normalized = re.ReplaceAllString(normalized, " ") - - return strings.TrimSpace(normalized) -} - -// extractArgumentsFromCommand extracts arguments from various command formats -func extractArgumentsFromCommand(command string) ([]string, error) { - // Split command into tokens respecting quotes - tokens, err := splitCommandTokens(command) - if err != nil { - return nil, err - } - - if len(tokens) == 0 { - return nil, fmt.Errorf("no command tokens found") - } - - // Check if first token looks like an executable - firstToken := tokens[0] - - // Case 1: Full path to executable (contains path separator or ends with mlx_lm.server) - if strings.Contains(firstToken, string(filepath.Separator)) || - strings.HasSuffix(filepath.Base(firstToken), "mlx_lm.server") { - return tokens[1:], nil // Return everything except the executable - } - - // Case 2: Just "mlx_lm.server" command - if strings.ToLower(firstToken) == "mlx_lm.server" { - return tokens[1:], nil // Return everything except the command - } - - // Case 3: Arguments only (starts with a flag) - if strings.HasPrefix(firstToken, "-") { - return tokens, nil // Return all tokens as arguments - } - - // Case 4: Unknown format - might be a different executable name - // Be permissive and assume it's the executable - return tokens[1:], nil -} - -// splitCommandTokens splits a command string into tokens, respecting quotes -func splitCommandTokens(command string) ([]string, error) { - var tokens []string - var current strings.Builder - inQuotes := false - quoteChar := byte(0) - escaped := false - - for i := 0; i < len(command); i++ { - c := command[i] - - if escaped { - current.WriteByte(c) - escaped = false - continue - } - - if c == '\\' { - escaped = true - current.WriteByte(c) - continue - } - - if !inQuotes && (c == '"' || c == '\'') { - inQuotes = true 
- quoteChar = c - current.WriteByte(c) - } else if inQuotes && c == quoteChar { - inQuotes = false - quoteChar = 0 - current.WriteByte(c) - } else if !inQuotes && (c == ' ' || c == '\t' || c == '\n') { - if current.Len() > 0 { - tokens = append(tokens, current.String()) - current.Reset() - } - } else { - current.WriteByte(c) - } - } - - if inQuotes { - return nil, fmt.Errorf("unclosed quote in command") - } - - if current.Len() > 0 { - tokens = append(tokens, current.String()) - } - - return tokens, nil -} - -// isFlag checks if a string looks like a command line flag -func isFlag(s string) bool { - return strings.HasPrefix(s, "-") -} \ No newline at end of file diff --git a/pkg/backends/parser.go b/pkg/backends/parser.go new file mode 100644 index 0000000..5023a7e --- /dev/null +++ b/pkg/backends/parser.go @@ -0,0 +1,310 @@ +package backends + +import ( + "encoding/json" + "errors" + "fmt" + "path/filepath" + "regexp" + "strconv" + "strings" +) + +// CommandParserConfig holds configuration for parsing command line arguments +type CommandParserConfig struct { + // ExecutableNames are the names of executables to detect (e.g., "llama-server", "mlx_lm.server") + ExecutableNames []string + // SubcommandNames are optional subcommands (e.g., "serve" for vllm) + SubcommandNames []string + // MultiValuedFlags are flags that can accept multiple values + MultiValuedFlags map[string]struct{} +} + +// ParseCommand parses a command string using the provided configuration +func ParseCommand(command string, config CommandParserConfig, target any) error { + // 1. Normalize the command - handle multiline with backslashes + trimmed := normalizeMultilineCommand(command) + if trimmed == "" { + return fmt.Errorf("command cannot be empty") + } + + // 2. Extract arguments from command + args, err := extractArgumentsFromCommand(trimmed, config) + if err != nil { + return err + } + + // 3. Parse arguments into map + options := make(map[string]any) + + i := 0 + for i < len(args) { + arg := args[i] + + if !strings.HasPrefix(arg, "-") { // skip positional / stray values + i++ + continue + } + + // Reject malformed flags with more than two leading dashes (e.g. 
---model) to surface user mistakes + if strings.HasPrefix(arg, "---") { + return fmt.Errorf("malformed flag: %s", arg) + } + + // Unified parsing for --flag=value vs --flag value + var rawFlag, rawValue string + hasEquals := false + if strings.Contains(arg, "=") { + parts := strings.SplitN(arg, "=", 2) + rawFlag = parts[0] + rawValue = parts[1] // may be empty string + hasEquals = true + } else { + rawFlag = arg + } + + flagCore := strings.TrimPrefix(strings.TrimPrefix(rawFlag, "-"), "-") + flagName := strings.ReplaceAll(flagCore, "-", "_") + + // Detect value if not in equals form + valueProvided := hasEquals + if !hasEquals { + if i+1 < len(args) && !isFlag(args[i+1]) { // next token is value + rawValue = args[i+1] + valueProvided = true + } + } + + // Determine if multi-valued flag + _, isMulti := config.MultiValuedFlags[flagName] + + // Normalization helper: ensure slice for multi-valued flags + appendValue := func(valStr string) { + if existing, ok := options[flagName]; ok { + // Existing value; ensure slice semantics for multi-valued flags or repeated occurrences + if slice, ok := existing.([]string); ok { + options[flagName] = append(slice, valStr) + return + } + // Convert scalar to slice + options[flagName] = []string{fmt.Sprintf("%v", existing), valStr} + return + } + // First value + if isMulti { + options[flagName] = []string{valStr} + } else { + // We'll parse type below for single-valued flags + options[flagName] = valStr + } + } + + if valueProvided { + // Use raw token for multi-valued flags; else allow typed parsing + appendValue(rawValue) + if !isMulti { // convert to typed value if scalar + if strVal, ok := options[flagName].(string); ok { // still scalar + options[flagName] = parseValue(strVal) + } + } + // Advance index: if we consumed a following token as value (non equals form), skip it + if !hasEquals && i+1 < len(args) && rawValue == args[i+1] { + i += 2 + } else { + i++ + } + continue + } + + // Boolean flag (no value) + options[flagName] = true + i++ + } + + // 4. 
Convert to target struct using JSON marshaling + jsonData, err := json.Marshal(options) + if err != nil { + return fmt.Errorf("failed to marshal parsed options: %w", err) + } + + if err := json.Unmarshal(jsonData, target); err != nil { + return fmt.Errorf("failed to parse command options: %w", err) + } + + return nil +} + +// parseValue attempts to parse a string value into the most appropriate type +func parseValue(value string) any { + // Surrounding matching quotes (single or double) + if l := len(value); l >= 2 { + if (value[0] == '"' && value[l-1] == '"') || (value[0] == '\'' && value[l-1] == '\'') { + value = value[1 : l-1] + } + } + + lower := strings.ToLower(value) + if lower == "true" { + return true + } + if lower == "false" { + return false + } + + if intVal, err := strconv.Atoi(value); err == nil { + return intVal + } + if floatVal, err := strconv.ParseFloat(value, 64); err == nil { + return floatVal + } + return value +} + +// normalizeMultilineCommand handles multiline commands with backslashes +func normalizeMultilineCommand(command string) string { + // Handle escaped newlines (backslash followed by newline) + re := regexp.MustCompile(`\\\s*\n\s*`) + normalized := re.ReplaceAllString(command, " ") + + // Clean up extra whitespace + re = regexp.MustCompile(`\s+`) + normalized = re.ReplaceAllString(normalized, " ") + + return strings.TrimSpace(normalized) +} + +// extractArgumentsFromCommand extracts arguments from various command formats +func extractArgumentsFromCommand(command string, config CommandParserConfig) ([]string, error) { + // Split command into tokens respecting quotes + tokens, err := splitCommandTokens(command) + if err != nil { + return nil, err + } + + if len(tokens) == 0 { + return nil, fmt.Errorf("no command tokens found") + } + + firstToken := tokens[0] + + // Check for full path executable + if strings.Contains(firstToken, string(filepath.Separator)) { + baseName := filepath.Base(firstToken) + for _, execName := range config.ExecutableNames { + if strings.HasSuffix(baseName, execName) { + return skipExecutableAndSubcommands(tokens[1:], config.SubcommandNames) + } + } + // Unknown executable, assume it's still an executable + return skipExecutableAndSubcommands(tokens[1:], config.SubcommandNames) + } + + // Check for simple executable names + lowerFirstToken := strings.ToLower(firstToken) + for _, execName := range config.ExecutableNames { + if lowerFirstToken == strings.ToLower(execName) { + return skipExecutableAndSubcommands(tokens[1:], config.SubcommandNames) + } + } + + // Check for subcommands (like "serve" for vllm) + for _, subCmd := range config.SubcommandNames { + if lowerFirstToken == strings.ToLower(subCmd) { + return tokens[1:], nil // Return everything except the subcommand + } + } + + // Arguments only (starts with a flag) + if strings.HasPrefix(firstToken, "-") { + return tokens, nil // Return all tokens as arguments + } + + // Unknown format - might be a different executable name + return skipExecutableAndSubcommands(tokens[1:], config.SubcommandNames) +} + +// skipExecutableAndSubcommands removes subcommands from the beginning of tokens +func skipExecutableAndSubcommands(tokens []string, subcommands []string) ([]string, error) { + if len(tokens) == 0 { + return tokens, nil + } + + // Check if first token is a subcommand + if len(subcommands) > 0 && len(tokens) > 0 { + lowerFirstToken := strings.ToLower(tokens[0]) + for _, subCmd := range subcommands { + if lowerFirstToken == strings.ToLower(subCmd) { + return tokens[1:], nil // Skip the 
subcommand + } + } + } + + return tokens, nil +} + +// splitCommandTokens splits a command string into tokens, respecting quotes +func splitCommandTokens(command string) ([]string, error) { + var tokens []string + var current strings.Builder + inQuotes := false + quoteChar := byte(0) + escaped := false + + for i := 0; i < len(command); i++ { + c := command[i] + + if escaped { + current.WriteByte(c) + escaped = false + continue + } + + if c == '\\' { + escaped = true + current.WriteByte(c) + continue + } + + if !inQuotes && (c == '"' || c == '\'') { + inQuotes = true + quoteChar = c + current.WriteByte(c) + } else if inQuotes && c == quoteChar { + inQuotes = false + quoteChar = 0 + current.WriteByte(c) + } else if !inQuotes && (c == ' ' || c == '\t' || c == '\n') { + if current.Len() > 0 { + tokens = append(tokens, current.String()) + current.Reset() + } + } else { + current.WriteByte(c) + } + } + + if inQuotes { + return nil, errors.New("unterminated quoted string") + } + + if current.Len() > 0 { + tokens = append(tokens, current.String()) + } + + return tokens, nil +} + +// isFlag determines if a string is a command line flag or a value +// Handles the special case where negative numbers should be treated as values, not flags +func isFlag(arg string) bool { + if !strings.HasPrefix(arg, "-") { + return false + } + + // Special case: if it's a negative number, treat it as a value + if _, err := strconv.ParseFloat(arg, 64); err == nil { + return false + } + + return true +} diff --git a/pkg/backends/vllm/parser.go b/pkg/backends/vllm/parser.go index cb9125c..cce935d 100644 --- a/pkg/backends/vllm/parser.go +++ b/pkg/backends/vllm/parser.go @@ -1,13 +1,7 @@ package vllm import ( - "encoding/json" - "errors" - "fmt" - "path/filepath" - "regexp" - "strconv" - "strings" + "llamactl/pkg/backends" ) // ParseVllmCommand parses a vLLM serve command string into VllmServerOptions @@ -18,285 +12,25 @@ import ( // 4. Args only: "--model MODEL_NAME --other-args" // 5. Multiline commands with backslashes func ParseVllmCommand(command string) (*VllmServerOptions, error) { - // 1. Normalize the command - handle multiline with backslashes - trimmed := normalizeMultilineCommand(command) - if trimmed == "" { - return nil, fmt.Errorf("command cannot be empty") - } - - // 2. Extract arguments from command - args, err := extractArgumentsFromCommand(trimmed) - if err != nil { - return nil, err - } - - // 3. Parse arguments into map - options := make(map[string]any) - - // Known multi-valued flags (snake_case form) - multiValued := map[string]struct{}{ - "middleware": {}, - "api_key": {}, - "allowed_origins": {}, - "allowed_methods": {}, - "allowed_headers": {}, - "lora_modules": {}, - "prompt_adapters": {}, - } - - i := 0 - for i < len(args) { - arg := args[i] - - if !strings.HasPrefix(arg, "-") { // skip positional / stray values - i++ - continue - } - - // Reject malformed flags with more than two leading dashes (e.g. 
---model) to surface user mistakes - if strings.HasPrefix(arg, "---") { - return nil, fmt.Errorf("malformed flag: %s", arg) - } - - // Unified parsing for --flag=value vs --flag value - var rawFlag, rawValue string - hasEquals := false - if strings.Contains(arg, "=") { - parts := strings.SplitN(arg, "=", 2) - rawFlag = parts[0] - rawValue = parts[1] // may be empty string - hasEquals = true - } else { - rawFlag = arg - } - - flagCore := strings.TrimPrefix(strings.TrimPrefix(rawFlag, "-"), "-") - flagName := strings.ReplaceAll(flagCore, "-", "_") - - // Detect value if not in equals form - valueProvided := hasEquals - if !hasEquals { - if i+1 < len(args) && !isFlag(args[i+1]) { // next token is value - rawValue = args[i+1] - valueProvided = true - } - } - - // Determine if multi-valued flag - _, isMulti := multiValued[flagName] - - // Normalization helper: ensure slice for multi-valued flags - appendValue := func(valStr string) { - if existing, ok := options[flagName]; ok { - // Existing value; ensure slice semantics for multi-valued flags or repeated occurrences - if slice, ok := existing.([]string); ok { - options[flagName] = append(slice, valStr) - return - } - // Convert scalar to slice - options[flagName] = []string{fmt.Sprintf("%v", existing), valStr} - return - } - // First value - if isMulti { - options[flagName] = []string{valStr} - } else { - // We'll parse type below for single-valued flags - options[flagName] = valStr - } - } - - if valueProvided { - // Use raw token for multi-valued flags; else allow typed parsing - appendValue(rawValue) - if !isMulti { // convert to typed value if scalar - if strVal, ok := options[flagName].(string); ok { // still scalar - options[flagName] = parseValue(strVal) - } - } - // Advance index: if we consumed a following token as value (non equals form), skip it - if !hasEquals && i+1 < len(args) && rawValue == args[i+1] { - i += 2 - } else { - i++ - } - continue - } - - // Boolean flag (no value) - options[flagName] = true - i++ - } - - // 4. Convert to VllmServerOptions using existing UnmarshalJSON - jsonData, err := json.Marshal(options) - if err != nil { - return nil, fmt.Errorf("failed to marshal parsed options: %w", err) + config := backends.CommandParserConfig{ + ExecutableNames: []string{"vllm"}, + SubcommandNames: []string{"serve"}, + MultiValuedFlags: map[string]struct{}{ + "middleware": {}, + "api_key": {}, + "allowed_origins": {}, + "allowed_methods": {}, + "allowed_headers": {}, + "lora_modules": {}, + "prompt_adapters": {}, + }, } var vllmOptions VllmServerOptions - if err := json.Unmarshal(jsonData, &vllmOptions); err != nil { - return nil, fmt.Errorf("failed to parse command options: %w", err) - } - - // 5. 
Return VllmServerOptions - return &vllmOptions, nil -} - -// parseValue attempts to parse a string value into the most appropriate type -func parseValue(value string) any { - // Surrounding matching quotes (single or double) - if l := len(value); l >= 2 { - if (value[0] == '"' && value[l-1] == '"') || (value[0] == '\'' && value[l-1] == '\'') { - value = value[1 : l-1] - } - } - - lower := strings.ToLower(value) - if lower == "true" { - return true - } - if lower == "false" { - return false - } - - if intVal, err := strconv.Atoi(value); err == nil { - return intVal - } - if floatVal, err := strconv.ParseFloat(value, 64); err == nil { - return floatVal - } - return value -} - -// normalizeMultilineCommand handles multiline commands with backslashes -func normalizeMultilineCommand(command string) string { - // Handle escaped newlines (backslash followed by newline) - re := regexp.MustCompile(`\\\s*\n\s*`) - normalized := re.ReplaceAllString(command, " ") - - // Clean up extra whitespace - re = regexp.MustCompile(`\s+`) - normalized = re.ReplaceAllString(normalized, " ") - - return strings.TrimSpace(normalized) -} - -// extractArgumentsFromCommand extracts arguments from various command formats -func extractArgumentsFromCommand(command string) ([]string, error) { - // Split command into tokens respecting quotes - tokens, err := splitCommandTokens(command) - if err != nil { + if err := backends.ParseCommand(command, config, &vllmOptions); err != nil { return nil, err } - if len(tokens) == 0 { - return nil, fmt.Errorf("no command tokens found") - } - - // Check if first token looks like an executable - firstToken := tokens[0] - - // Case 1: Full path to executable (contains path separator or ends with vllm) - if strings.Contains(firstToken, string(filepath.Separator)) || - strings.HasSuffix(filepath.Base(firstToken), "vllm") { - // Check if second token is "serve" - if len(tokens) > 1 && strings.ToLower(tokens[1]) == "serve" { - return tokens[2:], nil // Return everything except executable and serve - } - return tokens[1:], nil // Return everything except the executable - } - - // Case 2: Just "vllm" command - if strings.ToLower(firstToken) == "vllm" { - // Check if second token is "serve" - if len(tokens) > 1 && strings.ToLower(tokens[1]) == "serve" { - return tokens[2:], nil // Return everything except vllm and serve - } - return tokens[1:], nil // Return everything except vllm - } - - // Case 3: Just "serve" command - if strings.ToLower(firstToken) == "serve" { - return tokens[1:], nil // Return everything except serve - } - - // Case 4: Arguments only (starts with a flag) - if strings.HasPrefix(firstToken, "-") { - return tokens, nil // Return all tokens as arguments - } - - // Case 5: Unknown format - might be a different executable name - // Be permissive and assume it's the executable - if len(tokens) > 1 && strings.ToLower(tokens[1]) == "serve" { - return tokens[2:], nil // Return everything except executable and serve - } - return tokens[1:], nil + return &vllmOptions, nil } -// splitCommandTokens splits a command string into tokens, respecting quotes -func splitCommandTokens(command string) ([]string, error) { - var tokens []string - var current strings.Builder - inQuotes := false - quoteChar := byte(0) - escaped := false - - for i := 0; i < len(command); i++ { - c := command[i] - - if escaped { - current.WriteByte(c) - escaped = false - continue - } - - if c == '\\' { - escaped = true - current.WriteByte(c) - continue - } - - if !inQuotes && (c == '"' || c == '\'') { - inQuotes = true 
- quoteChar = c - current.WriteByte(c) - } else if inQuotes && c == quoteChar { - inQuotes = false - quoteChar = 0 - current.WriteByte(c) - } else if !inQuotes && (c == ' ' || c == '\t') { - if current.Len() > 0 { - tokens = append(tokens, current.String()) - current.Reset() - } - } else { - current.WriteByte(c) - } - } - - if inQuotes { - return nil, errors.New("unterminated quoted string") - } - - if current.Len() > 0 { - tokens = append(tokens, current.String()) - } - - return tokens, nil -} - -// isFlag determines if a string is a command line flag or a value -// Handles the special case where negative numbers should be treated as values, not flags -func isFlag(arg string) bool { - if !strings.HasPrefix(arg, "-") { - return false - } - - // Special case: if it's a negative number, treat it as a value - if _, err := strconv.ParseFloat(arg, 64); err == nil { - return false - } - - return true -} \ No newline at end of file From 9eecb37aec83384710a88934c3b2b341f8ffd909 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Fri, 19 Sep 2025 19:39:36 +0200 Subject: [PATCH 03/20] Refactor MLX and VLLM server options parsing and args building --- pkg/backends/llamacpp/parser_test.go | 352 ++------------------- pkg/backends/mlx/mlx.go | 235 ++++---------- pkg/backends/mlx/mlx_test.go | 62 ++++ pkg/backends/mlx/parser_test.go | 101 ++++++ pkg/backends/vllm/parser_test.go | 9 +- pkg/backends/vllm/vllm.go | 453 +++++++-------------------- pkg/backends/vllm/vllm_test.go | 39 +-- 7 files changed, 382 insertions(+), 869 deletions(-) create mode 100644 pkg/backends/mlx/mlx_test.go create mode 100644 pkg/backends/mlx/parser_test.go diff --git a/pkg/backends/llamacpp/parser_test.go b/pkg/backends/llamacpp/parser_test.go index 60e6a19..7072f65 100644 --- a/pkg/backends/llamacpp/parser_test.go +++ b/pkg/backends/llamacpp/parser_test.go @@ -1,6 +1,7 @@ -package llamacpp +package llamacpp_test import ( + "llamactl/pkg/backends/llamacpp" "testing" ) @@ -11,28 +12,23 @@ func TestParseLlamaCommand(t *testing.T) { expectErr bool }{ { - name: "basic command with model", - command: "llama-server --model /path/to/model.gguf", + name: "basic command", + command: "llama-server --model /path/to/model.gguf --gpu-layers 32", expectErr: false, }, { - name: "command with multiple flags", - command: "llama-server --model /path/to/model.gguf --gpu-layers 32 --ctx-size 4096", + name: "args only", + command: "--model /path/to/model.gguf --ctx-size 4096", expectErr: false, }, { - name: "command with short flags", - command: "llama-server -m /path/to/model.gguf -ngl 32 -c 4096", + name: "mixed flag formats", + command: "llama-server --model=/path/model.gguf --gpu-layers 16 --verbose", expectErr: false, }, { - name: "command with equals format", - command: "llama-server --model=/path/to/model.gguf --gpu-layers=32", - expectErr: false, - }, - { - name: "command with boolean flags", - command: "llama-server --model /path/to/model.gguf --verbose --no-mmap", + name: "quoted strings", + command: `llama-server --model test.gguf --api-key "sk-1234567890abcdef"`, expectErr: false, }, { @@ -41,46 +37,20 @@ func TestParseLlamaCommand(t *testing.T) { expectErr: true, }, { - name: "case insensitive command", - command: "LLAMA-SERVER --model /path/to/model.gguf", - expectErr: false, - }, - // New test cases for improved functionality - { - name: "args only without llama-server", - command: "--model /path/to/model.gguf --gpu-layers 32", - expectErr: false, + name: "unterminated quote", + command: `llama-server --model test.gguf --api-key 
"unterminated`, + expectErr: true, }, { - name: "full path to executable", - command: "/usr/local/bin/llama-server --model /path/to/model.gguf", - expectErr: false, - }, - { - name: "negative number handling", - command: "llama-server --gpu-layers -1 --model test.gguf", - expectErr: false, - }, - { - name: "multiline command with backslashes", - command: "llama-server --model /path/to/model.gguf \\\n --ctx-size 4096 \\\n --batch-size 512", - expectErr: false, - }, - { - name: "quoted string with special characters", - command: `llama-server --model test.gguf --chat-template "{% for message in messages %}{{ message.role }}: {{ message.content }}\n{% endfor %}"`, - expectErr: false, - }, - { - name: "unterminated quoted string", - command: `llama-server --model test.gguf --chat-template "unterminated quote`, + name: "malformed flag", + command: "llama-server ---model test.gguf", expectErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := ParseLlamaCommand(tt.command) + result, err := llamacpp.ParseLlamaCommand(tt.command) if tt.expectErr { if err == nil { @@ -96,16 +66,14 @@ func TestParseLlamaCommand(t *testing.T) { if result == nil { t.Errorf("expected result but got nil") - return } }) } } -func TestParseLlamaCommandSpecificValues(t *testing.T) { - // Test specific value parsing - command := "llama-server --model /test/model.gguf --gpu-layers 32 --ctx-size 4096 --verbose" - result, err := ParseLlamaCommand(command) +func TestParseLlamaCommandValues(t *testing.T) { + command := "llama-server --model /test/model.gguf --gpu-layers 32 --temp 0.7 --verbose --no-mmap" + result, err := llamacpp.ParseLlamaCommand(command) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -119,19 +87,22 @@ func TestParseLlamaCommandSpecificValues(t *testing.T) { t.Errorf("expected gpu_layers 32, got %d", result.GPULayers) } - if result.CtxSize != 4096 { - t.Errorf("expected ctx_size 4096, got %d", result.CtxSize) + if result.Temperature != 0.7 { + t.Errorf("expected temperature 0.7, got %f", result.Temperature) } if !result.Verbose { - t.Errorf("expected verbose to be true, got %v", result.Verbose) + t.Errorf("expected verbose to be true") + } + + if !result.NoMmap { + t.Errorf("expected no_mmap to be true") } } -func TestParseLlamaCommandArrayFlags(t *testing.T) { - // Test array flag handling (critical for lora, override-tensor, etc.) 
- command := "llama-server --model test.gguf --lora adapter1.bin --lora adapter2.bin" - result, err := ParseLlamaCommand(command) +func TestParseLlamaCommandArrays(t *testing.T) { + command := "llama-server --model test.gguf --lora adapter1.bin --lora=adapter2.bin" + result, err := llamacpp.ParseLlamaCommand(command) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -141,273 +112,10 @@ func TestParseLlamaCommandArrayFlags(t *testing.T) { t.Errorf("expected 2 lora adapters, got %d", len(result.Lora)) } - if result.Lora[0] != "adapter1.bin" || result.Lora[1] != "adapter2.bin" { - t.Errorf("expected lora adapters [adapter1.bin, adapter2.bin], got %v", result.Lora) - } -} - -func TestParseLlamaCommandMixedFormats(t *testing.T) { - // Test mixing --flag=value and --flag value formats - command := "llama-server --model=/path/model.gguf --gpu-layers 16 --batch-size=512 --verbose" - result, err := ParseLlamaCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Model != "/path/model.gguf" { - t.Errorf("expected model '/path/model.gguf', got '%s'", result.Model) - } - - if result.GPULayers != 16 { - t.Errorf("expected gpu_layers 16, got %d", result.GPULayers) - } - - if result.BatchSize != 512 { - t.Errorf("expected batch_size 512, got %d", result.BatchSize) - } - - if !result.Verbose { - t.Errorf("expected verbose to be true, got %v", result.Verbose) - } -} - -func TestParseLlamaCommandTypeConversion(t *testing.T) { - // Test that values are converted to appropriate types - command := "llama-server --model test.gguf --temp 0.7 --top-k 40 --no-mmap" - result, err := ParseLlamaCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Temperature != 0.7 { - t.Errorf("expected temperature 0.7, got %f", result.Temperature) - } - - if result.TopK != 40 { - t.Errorf("expected top_k 40, got %d", result.TopK) - } - - if !result.NoMmap { - t.Errorf("expected no_mmap to be true, got %v", result.NoMmap) - } -} - -func TestParseLlamaCommandArgsOnly(t *testing.T) { - // Test parsing arguments without llama-server command - command := "--model /path/to/model.gguf --gpu-layers 32 --ctx-size 4096" - result, err := ParseLlamaCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Model != "/path/to/model.gguf" { - t.Errorf("expected model '/path/to/model.gguf', got '%s'", result.Model) - } - - if result.GPULayers != 32 { - t.Errorf("expected gpu_layers 32, got %d", result.GPULayers) - } - - if result.CtxSize != 4096 { - t.Errorf("expected ctx_size 4096, got %d", result.CtxSize) - } -} - -func TestParseLlamaCommandFullPath(t *testing.T) { - // Test full path to executable - command := "/usr/local/bin/llama-server --model test.gguf --gpu-layers 16" - result, err := ParseLlamaCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Model != "test.gguf" { - t.Errorf("expected model 'test.gguf', got '%s'", result.Model) - } - - if result.GPULayers != 16 { - t.Errorf("expected gpu_layers 16, got %d", result.GPULayers) - } -} - -func TestParseLlamaCommandNegativeNumbers(t *testing.T) { - // Test negative number parsing - command := "llama-server --model test.gguf --gpu-layers -1 --seed -12345" - result, err := ParseLlamaCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.GPULayers != -1 { - t.Errorf("expected gpu_layers -1, got %d", result.GPULayers) - } - - if result.Seed != -12345 { - t.Errorf("expected seed -12345, got 
%d", result.Seed) - } -} - -func TestParseLlamaCommandMultiline(t *testing.T) { - // Test multiline command with backslashes - command := `llama-server --model /path/to/model.gguf \ - --ctx-size 4096 \ - --batch-size 512 \ - --gpu-layers 32` - - result, err := ParseLlamaCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Model != "/path/to/model.gguf" { - t.Errorf("expected model '/path/to/model.gguf', got '%s'", result.Model) - } - - if result.CtxSize != 4096 { - t.Errorf("expected ctx_size 4096, got %d", result.CtxSize) - } - - if result.BatchSize != 512 { - t.Errorf("expected batch_size 512, got %d", result.BatchSize) - } - - if result.GPULayers != 32 { - t.Errorf("expected gpu_layers 32, got %d", result.GPULayers) - } -} - -func TestParseLlamaCommandQuotedStrings(t *testing.T) { - // Test quoted strings with special characters - command := `llama-server --model test.gguf --api-key "sk-1234567890abcdef" --chat-template "User: {user}\nAssistant: "` - result, err := ParseLlamaCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Model != "test.gguf" { - t.Errorf("expected model 'test.gguf', got '%s'", result.Model) - } - - if result.APIKey != "sk-1234567890abcdef" { - t.Errorf("expected api_key 'sk-1234567890abcdef', got '%s'", result.APIKey) - } - - expectedTemplate := "User: {user}\\nAssistant: " - if result.ChatTemplate != expectedTemplate { - t.Errorf("expected chat_template '%s', got '%s'", expectedTemplate, result.ChatTemplate) - } -} - -func TestParseLlamaCommandUnslothExample(t *testing.T) { - // Test with realistic unsloth-style command - command := `llama-server --model /path/to/model.gguf \ - --ctx-size 4096 \ - --batch-size 512 \ - --gpu-layers -1 \ - --temp 0.7 \ - --repeat-penalty 1.1 \ - --top-k 40 \ - --top-p 0.95 \ - --host 0.0.0.0 \ - --port 8000 \ - --api-key "sk-1234567890abcdef"` - - result, err := ParseLlamaCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - // Verify key fields - if result.Model != "/path/to/model.gguf" { - t.Errorf("expected model '/path/to/model.gguf', got '%s'", result.Model) - } - - if result.CtxSize != 4096 { - t.Errorf("expected ctx_size 4096, got %d", result.CtxSize) - } - - if result.BatchSize != 512 { - t.Errorf("expected batch_size 512, got %d", result.BatchSize) - } - - if result.GPULayers != -1 { - t.Errorf("expected gpu_layers -1, got %d", result.GPULayers) - } - - if result.Temperature != 0.7 { - t.Errorf("expected temperature 0.7, got %f", result.Temperature) - } - - if result.RepeatPenalty != 1.1 { - t.Errorf("expected repeat_penalty 1.1, got %f", result.RepeatPenalty) - } - - if result.TopK != 40 { - t.Errorf("expected top_k 40, got %d", result.TopK) - } - - if result.TopP != 0.95 { - t.Errorf("expected top_p 0.95, got %f", result.TopP) - } - - if result.Host != "0.0.0.0" { - t.Errorf("expected host '0.0.0.0', got '%s'", result.Host) - } - - if result.Port != 8000 { - t.Errorf("expected port 8000, got %d", result.Port) - } - - if result.APIKey != "sk-1234567890abcdef" { - t.Errorf("expected api_key 'sk-1234567890abcdef', got '%s'", result.APIKey) - } -} - -// Focused additional edge case tests (kept minimal per guidance) -func TestParseLlamaCommandSingleQuotedValue(t *testing.T) { - cmd := "llama-server --model 'my model.gguf' --alias 'Test Alias'" - result, err := ParseLlamaCommand(cmd) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if result.Model != "my model.gguf" { - t.Errorf("expected model 'my 
model.gguf', got '%s'", result.Model) - } - if result.Alias != "Test Alias" { - t.Errorf("expected alias 'Test Alias', got '%s'", result.Alias) - } -} - -func TestParseLlamaCommandMixedArrayForms(t *testing.T) { - // Same multi-value flag using --flag value and --flag=value forms - cmd := "llama-server --lora adapter1.bin --lora=adapter2.bin --lora adapter3.bin" - result, err := ParseLlamaCommand(cmd) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if len(result.Lora) != 3 { - t.Fatalf("expected 3 lora values, got %d (%v)", len(result.Lora), result.Lora) - } - expected := []string{"adapter1.bin", "adapter2.bin", "adapter3.bin"} + expected := []string{"adapter1.bin", "adapter2.bin"} for i, v := range expected { if result.Lora[i] != v { t.Errorf("expected lora[%d]=%s got %s", i, v, result.Lora[i]) } } } - -func TestParseLlamaCommandMalformedFlag(t *testing.T) { - cmd := "llama-server ---model test.gguf" - _, err := ParseLlamaCommand(cmd) - if err == nil { - t.Fatalf("expected error for malformed flag but got none") - } -} diff --git a/pkg/backends/mlx/mlx.go b/pkg/backends/mlx/mlx.go index c3324d2..8527c7b 100644 --- a/pkg/backends/mlx/mlx.go +++ b/pkg/backends/mlx/mlx.go @@ -1,205 +1,88 @@ package mlx import ( - "encoding/json" "reflect" "strconv" + "strings" ) type MlxServerOptions struct { // Basic connection options - Model string `json:"model,omitempty"` - Host string `json:"host,omitempty"` - Port int `json:"port,omitempty"` - + Model string `json:"model,omitempty"` + Host string `json:"host,omitempty"` + Port int `json:"port,omitempty"` + // Model and adapter options AdapterPath string `json:"adapter_path,omitempty"` DraftModel string `json:"draft_model,omitempty"` NumDraftTokens int `json:"num_draft_tokens,omitempty"` TrustRemoteCode bool `json:"trust_remote_code,omitempty"` - + // Logging and templates - LogLevel string `json:"log_level,omitempty"` - ChatTemplate string `json:"chat_template,omitempty"` - UseDefaultChatTemplate bool `json:"use_default_chat_template,omitempty"` - ChatTemplateArgs string `json:"chat_template_args,omitempty"` // JSON string - + LogLevel string `json:"log_level,omitempty"` + ChatTemplate string `json:"chat_template,omitempty"` + UseDefaultChatTemplate bool `json:"use_default_chat_template,omitempty"` + ChatTemplateArgs string `json:"chat_template_args,omitempty"` // JSON string + // Sampling defaults - Temp float64 `json:"temp,omitempty"` // Note: MLX uses "temp" not "temperature" - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - MinP float64 `json:"min_p,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` + Temp float64 `json:"temp,omitempty"` // Note: MLX uses "temp" not "temperature" + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + MinP float64 `json:"min_p,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` } -// UnmarshalJSON implements custom JSON unmarshaling to support multiple field names -func (o *MlxServerOptions) UnmarshalJSON(data []byte) error { - // First unmarshal into a map to handle multiple field names - var raw map[string]any - if err := json.Unmarshal(data, &raw); err != nil { - return err - } +// BuildCommandArgs converts to command line arguments using reflection +func (o *MlxServerOptions) BuildCommandArgs() []string { + var args []string - // Create a temporary struct for standard unmarshaling - type tempOptions MlxServerOptions - temp := tempOptions{} + v := reflect.ValueOf(o).Elem() + t := v.Type() - // Standard unmarshal first - if err := 
json.Unmarshal(data, &temp); err != nil { - return err - } + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + fieldType := t.Field(i) - // Copy to our struct - *o = MlxServerOptions(temp) + // Skip unexported fields + if !field.CanInterface() { + continue + } - // Handle alternative field names - fieldMappings := map[string]string{ - // Basic connection options - "m": "model", - "host": "host", - "port": "port", -// "python_path": "python_path", // removed - - // Model and adapter options - "adapter-path": "adapter_path", - "draft-model": "draft_model", - "num-draft-tokens": "num_draft_tokens", - "trust-remote-code": "trust_remote_code", - - // Logging and templates - "log-level": "log_level", - "chat-template": "chat_template", - "use-default-chat-template": "use_default_chat_template", - "chat-template-args": "chat_template_args", - - // Sampling defaults - "temperature": "temp", // Support both temp and temperature - "top-p": "top_p", - "top-k": "top_k", - "min-p": "min_p", - "max-tokens": "max_tokens", - } + // Get the JSON tag to determine the flag name + jsonTag := fieldType.Tag.Get("json") + if jsonTag == "" || jsonTag == "-" { + continue + } - // Process alternative field names - for altName, canonicalName := range fieldMappings { - if value, exists := raw[altName]; exists { - // Use reflection to set the field value - v := reflect.ValueOf(o).Elem() - field := v.FieldByNameFunc(func(fieldName string) bool { - field, _ := v.Type().FieldByName(fieldName) - jsonTag := field.Tag.Get("json") - return jsonTag == canonicalName+",omitempty" || jsonTag == canonicalName - }) + // Remove ",omitempty" from the tag + flagName := jsonTag + if commaIndex := strings.Index(jsonTag, ","); commaIndex != -1 { + flagName = jsonTag[:commaIndex] + } - if field.IsValid() && field.CanSet() { - switch field.Kind() { - case reflect.Int: - if intVal, ok := value.(float64); ok { - field.SetInt(int64(intVal)) - } else if strVal, ok := value.(string); ok { - if intVal, err := strconv.Atoi(strVal); err == nil { - field.SetInt(int64(intVal)) - } - } - case reflect.Float64: - if floatVal, ok := value.(float64); ok { - field.SetFloat(floatVal) - } else if strVal, ok := value.(string); ok { - if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil { - field.SetFloat(floatVal) - } - } - case reflect.String: - if strVal, ok := value.(string); ok { - field.SetString(strVal) - } - case reflect.Bool: - if boolVal, ok := value.(bool); ok { - field.SetBool(boolVal) - } - } + // Convert snake_case to kebab-case for CLI flags + flagName = strings.ReplaceAll(flagName, "_", "-") + + // Add the appropriate arguments based on field type and value + switch field.Kind() { + case reflect.Bool: + if field.Bool() { + args = append(args, "--"+flagName) + } + case reflect.Int: + if field.Int() != 0 { + args = append(args, "--"+flagName, strconv.FormatInt(field.Int(), 10)) + } + case reflect.Float64: + if field.Float() != 0 { + args = append(args, "--"+flagName, strconv.FormatFloat(field.Float(), 'f', -1, 64)) + } + case reflect.String: + if field.String() != "" { + args = append(args, "--"+flagName, field.String()) } } } - return nil -} - -// NewMlxServerOptions creates MlxServerOptions with MLX defaults -func NewMlxServerOptions() *MlxServerOptions { - return &MlxServerOptions{ - Host: "127.0.0.1", // MLX default (different from llama-server) - Port: 8080, // MLX default - NumDraftTokens: 3, // MLX default for speculative decoding - LogLevel: "INFO", // MLX default - Temp: 0.0, // MLX default - TopP: 1.0, // MLX 
default - TopK: 0, // MLX default (disabled) - MinP: 0.0, // MLX default (disabled) - MaxTokens: 512, // MLX default - ChatTemplateArgs: "{}", // MLX default (empty JSON object) - } -} - -// BuildCommandArgs converts to command line arguments -func (o *MlxServerOptions) BuildCommandArgs() []string { - var args []string - - // Required and basic options - if o.Model != "" { - args = append(args, "--model", o.Model) - } - if o.Host != "" { - args = append(args, "--host", o.Host) - } - if o.Port != 0 { - args = append(args, "--port", strconv.Itoa(o.Port)) - } - - // Model and adapter options - if o.AdapterPath != "" { - args = append(args, "--adapter-path", o.AdapterPath) - } - if o.DraftModel != "" { - args = append(args, "--draft-model", o.DraftModel) - } - if o.NumDraftTokens != 0 { - args = append(args, "--num-draft-tokens", strconv.Itoa(o.NumDraftTokens)) - } - if o.TrustRemoteCode { - args = append(args, "--trust-remote-code") - } - - // Logging and templates - if o.LogLevel != "" { - args = append(args, "--log-level", o.LogLevel) - } - if o.ChatTemplate != "" { - args = append(args, "--chat-template", o.ChatTemplate) - } - if o.UseDefaultChatTemplate { - args = append(args, "--use-default-chat-template") - } - if o.ChatTemplateArgs != "" { - args = append(args, "--chat-template-args", o.ChatTemplateArgs) - } - - // Sampling defaults - if o.Temp != 0 { - args = append(args, "--temp", strconv.FormatFloat(o.Temp, 'f', -1, 64)) - } - if o.TopP != 0 { - args = append(args, "--top-p", strconv.FormatFloat(o.TopP, 'f', -1, 64)) - } - if o.TopK != 0 { - args = append(args, "--top-k", strconv.Itoa(o.TopK)) - } - if o.MinP != 0 { - args = append(args, "--min-p", strconv.FormatFloat(o.MinP, 'f', -1, 64)) - } - if o.MaxTokens != 0 { - args = append(args, "--max-tokens", strconv.Itoa(o.MaxTokens)) - } - return args -} \ No newline at end of file +} diff --git a/pkg/backends/mlx/mlx_test.go b/pkg/backends/mlx/mlx_test.go new file mode 100644 index 0000000..b35f512 --- /dev/null +++ b/pkg/backends/mlx/mlx_test.go @@ -0,0 +1,62 @@ +package mlx_test + +import ( + "llamactl/pkg/backends/mlx" + "testing" +) + +func TestBuildCommandArgs(t *testing.T) { + options := &mlx.MlxServerOptions{ + Model: "/test/model.mlx", + Host: "127.0.0.1", + Port: 8080, + Temp: 0.7, + TopP: 0.9, + TopK: 40, + MaxTokens: 2048, + TrustRemoteCode: true, + LogLevel: "DEBUG", + ChatTemplate: "custom template", + } + + args := options.BuildCommandArgs() + + // Check that all expected flags are present + expectedFlags := map[string]string{ + "--model": "/test/model.mlx", + "--host": "127.0.0.1", + "--port": "8080", + "--log-level": "DEBUG", + "--chat-template": "custom template", + "--temp": "0.7", + "--top-p": "0.9", + "--top-k": "40", + "--max-tokens": "2048", + } + + for i := 0; i < len(args); i++ { + if args[i] == "--trust-remote-code" { + continue // Boolean flag with no value + } + if args[i] == "--use-default-chat-template" { + continue // Boolean flag with no value + } + + if expectedValue, exists := expectedFlags[args[i]]; exists && i+1 < len(args) { + if args[i+1] != expectedValue { + t.Errorf("expected %s to have value %s, got %s", args[i], expectedValue, args[i+1]) + } + } + } + + // Check boolean flags + foundTrustRemoteCode := false + for _, arg := range args { + if arg == "--trust-remote-code" { + foundTrustRemoteCode = true + } + } + if !foundTrustRemoteCode { + t.Errorf("expected --trust-remote-code flag to be present") + } +} diff --git a/pkg/backends/mlx/parser_test.go b/pkg/backends/mlx/parser_test.go new file 
mode 100644 index 0000000..6caae84 --- /dev/null +++ b/pkg/backends/mlx/parser_test.go @@ -0,0 +1,101 @@ +package mlx_test + +import ( + "llamactl/pkg/backends/mlx" + "testing" +) + +func TestParseMlxCommand(t *testing.T) { + tests := []struct { + name string + command string + expectErr bool + }{ + { + name: "basic command", + command: "mlx_lm.server --model /path/to/model --host 0.0.0.0", + expectErr: false, + }, + { + name: "args only", + command: "--model /path/to/model --port 8080", + expectErr: false, + }, + { + name: "mixed flag formats", + command: "mlx_lm.server --model=/path/model --temp=0.7 --trust-remote-code", + expectErr: false, + }, + { + name: "quoted strings", + command: `mlx_lm.server --model test.mlx --chat-template "User: {user}\nAssistant: "`, + expectErr: false, + }, + { + name: "empty command", + command: "", + expectErr: true, + }, + { + name: "unterminated quote", + command: `mlx_lm.server --model test.mlx --chat-template "unterminated`, + expectErr: true, + }, + { + name: "malformed flag", + command: "mlx_lm.server ---model test.mlx", + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := mlx.ParseMlxCommand(tt.command) + + if tt.expectErr { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if result == nil { + t.Errorf("expected result but got nil") + } + }) + } +} + +func TestParseMlxCommandValues(t *testing.T) { + command := "mlx_lm.server --model /test/model.mlx --port 8080 --temp 0.7 --trust-remote-code --log-level DEBUG" + result, err := mlx.ParseMlxCommand(command) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Model != "/test/model.mlx" { + t.Errorf("expected model '/test/model.mlx', got '%s'", result.Model) + } + + if result.Port != 8080 { + t.Errorf("expected port 8080, got %d", result.Port) + } + + if result.Temp != 0.7 { + t.Errorf("expected temp 0.7, got %f", result.Temp) + } + + if !result.TrustRemoteCode { + t.Errorf("expected trust_remote_code to be true") + } + + if result.LogLevel != "DEBUG" { + t.Errorf("expected log_level 'DEBUG', got '%s'", result.LogLevel) + } +} diff --git a/pkg/backends/vllm/parser_test.go b/pkg/backends/vllm/parser_test.go index 91921b2..3a12456 100644 --- a/pkg/backends/vllm/parser_test.go +++ b/pkg/backends/vllm/parser_test.go @@ -1,6 +1,7 @@ -package vllm +package vllm_test import ( + "llamactl/pkg/backends/vllm" "testing" ) @@ -39,7 +40,7 @@ func TestParseVllmCommand(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := ParseVllmCommand(tt.command) + result, err := vllm.ParseVllmCommand(tt.command) if tt.expectErr { if err == nil { @@ -62,7 +63,7 @@ func TestParseVllmCommand(t *testing.T) { func TestParseVllmCommandValues(t *testing.T) { command := "vllm serve --model test-model --tensor-parallel-size 4 --gpu-memory-utilization 0.8 --enable-log-outputs" - result, err := ParseVllmCommand(command) + result, err := vllm.ParseVllmCommand(command) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -80,4 +81,4 @@ func TestParseVllmCommandValues(t *testing.T) { if !result.EnableLogOutputs { t.Errorf("expected enable_log_outputs true, got %v", result.EnableLogOutputs) } -} \ No newline at end of file +} diff --git a/pkg/backends/vllm/vllm.go b/pkg/backends/vllm/vllm.go index 6378b5e..9aa865c 100644 --- a/pkg/backends/vllm/vllm.go +++ b/pkg/backends/vllm/vllm.go @@ -1,7 +1,6 @@ package vllm 
import ( - "encoding/json" "reflect" "strconv" "strings" @@ -13,349 +12,124 @@ type VllmServerOptions struct { Port int `json:"port,omitempty"` // Model and engine configuration - Model string `json:"model,omitempty"` - Tokenizer string `json:"tokenizer,omitempty"` - SkipTokenizerInit bool `json:"skip_tokenizer_init,omitempty"` - Revision string `json:"revision,omitempty"` - CodeRevision string `json:"code_revision,omitempty"` - TokenizerRevision string `json:"tokenizer_revision,omitempty"` - TokenizerMode string `json:"tokenizer_mode,omitempty"` - TrustRemoteCode bool `json:"trust_remote_code,omitempty"` - DownloadDir string `json:"download_dir,omitempty"` - LoadFormat string `json:"load_format,omitempty"` - ConfigFormat string `json:"config_format,omitempty"` - Dtype string `json:"dtype,omitempty"` - KVCacheDtype string `json:"kv_cache_dtype,omitempty"` - QuantizationParamPath string `json:"quantization_param_path,omitempty"` - Seed int `json:"seed,omitempty"` - MaxModelLen int `json:"max_model_len,omitempty"` - GuidedDecodingBackend string `json:"guided_decoding_backend,omitempty"` - DistributedExecutorBackend string `json:"distributed_executor_backend,omitempty"` - WorkerUseRay bool `json:"worker_use_ray,omitempty"` - RayWorkersUseNSight bool `json:"ray_workers_use_nsight,omitempty"` + Model string `json:"model,omitempty"` + Tokenizer string `json:"tokenizer,omitempty"` + SkipTokenizerInit bool `json:"skip_tokenizer_init,omitempty"` + Revision string `json:"revision,omitempty"` + CodeRevision string `json:"code_revision,omitempty"` + TokenizerRevision string `json:"tokenizer_revision,omitempty"` + TokenizerMode string `json:"tokenizer_mode,omitempty"` + TrustRemoteCode bool `json:"trust_remote_code,omitempty"` + DownloadDir string `json:"download_dir,omitempty"` + LoadFormat string `json:"load_format,omitempty"` + ConfigFormat string `json:"config_format,omitempty"` + Dtype string `json:"dtype,omitempty"` + KVCacheDtype string `json:"kv_cache_dtype,omitempty"` + QuantizationParamPath string `json:"quantization_param_path,omitempty"` + Seed int `json:"seed,omitempty"` + MaxModelLen int `json:"max_model_len,omitempty"` + GuidedDecodingBackend string `json:"guided_decoding_backend,omitempty"` + DistributedExecutorBackend string `json:"distributed_executor_backend,omitempty"` + WorkerUseRay bool `json:"worker_use_ray,omitempty"` + RayWorkersUseNSight bool `json:"ray_workers_use_nsight,omitempty"` // Performance and serving configuration - BlockSize int `json:"block_size,omitempty"` - EnablePrefixCaching bool `json:"enable_prefix_caching,omitempty"` - DisableSlidingWindow bool `json:"disable_sliding_window,omitempty"` - UseV2BlockManager bool `json:"use_v2_block_manager,omitempty"` - NumLookaheadSlots int `json:"num_lookahead_slots,omitempty"` - SwapSpace int `json:"swap_space,omitempty"` - CPUOffloadGB int `json:"cpu_offload_gb,omitempty"` - GPUMemoryUtilization float64 `json:"gpu_memory_utilization,omitempty"` - NumGPUBlocksOverride int `json:"num_gpu_blocks_override,omitempty"` - MaxNumBatchedTokens int `json:"max_num_batched_tokens,omitempty"` - MaxNumSeqs int `json:"max_num_seqs,omitempty"` - MaxLogprobs int `json:"max_logprobs,omitempty"` - DisableLogStats bool `json:"disable_log_stats,omitempty"` - Quantization string `json:"quantization,omitempty"` - RopeScaling string `json:"rope_scaling,omitempty"` - RopeTheta float64 `json:"rope_theta,omitempty"` - EnforceEager bool `json:"enforce_eager,omitempty"` - MaxContextLenToCapture int `json:"max_context_len_to_capture,omitempty"` - 
MaxSeqLenToCapture int `json:"max_seq_len_to_capture,omitempty"` - DisableCustomAllReduce bool `json:"disable_custom_all_reduce,omitempty"` - TokenizerPoolSize int `json:"tokenizer_pool_size,omitempty"` - TokenizerPoolType string `json:"tokenizer_pool_type,omitempty"` - TokenizerPoolExtraConfig string `json:"tokenizer_pool_extra_config,omitempty"` - EnableLoraBias bool `json:"enable_lora_bias,omitempty"` - LoraExtraVocabSize int `json:"lora_extra_vocab_size,omitempty"` - LoraRank int `json:"lora_rank,omitempty"` - PromptLookbackDistance int `json:"prompt_lookback_distance,omitempty"` - PreemptionMode string `json:"preemption_mode,omitempty"` + BlockSize int `json:"block_size,omitempty"` + EnablePrefixCaching bool `json:"enable_prefix_caching,omitempty"` + DisableSlidingWindow bool `json:"disable_sliding_window,omitempty"` + UseV2BlockManager bool `json:"use_v2_block_manager,omitempty"` + NumLookaheadSlots int `json:"num_lookahead_slots,omitempty"` + SwapSpace int `json:"swap_space,omitempty"` + CPUOffloadGB int `json:"cpu_offload_gb,omitempty"` + GPUMemoryUtilization float64 `json:"gpu_memory_utilization,omitempty"` + NumGPUBlocksOverride int `json:"num_gpu_blocks_override,omitempty"` + MaxNumBatchedTokens int `json:"max_num_batched_tokens,omitempty"` + MaxNumSeqs int `json:"max_num_seqs,omitempty"` + MaxLogprobs int `json:"max_logprobs,omitempty"` + DisableLogStats bool `json:"disable_log_stats,omitempty"` + Quantization string `json:"quantization,omitempty"` + RopeScaling string `json:"rope_scaling,omitempty"` + RopeTheta float64 `json:"rope_theta,omitempty"` + EnforceEager bool `json:"enforce_eager,omitempty"` + MaxContextLenToCapture int `json:"max_context_len_to_capture,omitempty"` + MaxSeqLenToCapture int `json:"max_seq_len_to_capture,omitempty"` + DisableCustomAllReduce bool `json:"disable_custom_all_reduce,omitempty"` + TokenizerPoolSize int `json:"tokenizer_pool_size,omitempty"` + TokenizerPoolType string `json:"tokenizer_pool_type,omitempty"` + TokenizerPoolExtraConfig string `json:"tokenizer_pool_extra_config,omitempty"` + EnableLoraBias bool `json:"enable_lora_bias,omitempty"` + LoraExtraVocabSize int `json:"lora_extra_vocab_size,omitempty"` + LoraRank int `json:"lora_rank,omitempty"` + PromptLookbackDistance int `json:"prompt_lookback_distance,omitempty"` + PreemptionMode string `json:"preemption_mode,omitempty"` // Distributed and parallel processing - TensorParallelSize int `json:"tensor_parallel_size,omitempty"` - PipelineParallelSize int `json:"pipeline_parallel_size,omitempty"` - MaxParallelLoadingWorkers int `json:"max_parallel_loading_workers,omitempty"` - DisableAsyncOutputProc bool `json:"disable_async_output_proc,omitempty"` - WorkerClass string `json:"worker_class,omitempty"` - EnabledLoraModules string `json:"enabled_lora_modules,omitempty"` - MaxLoraRank int `json:"max_lora_rank,omitempty"` - FullyShardedLoras bool `json:"fully_sharded_loras,omitempty"` - LoraModules string `json:"lora_modules,omitempty"` - PromptAdapters string `json:"prompt_adapters,omitempty"` - MaxPromptAdapterToken int `json:"max_prompt_adapter_token,omitempty"` - Device string `json:"device,omitempty"` - SchedulerDelay float64 `json:"scheduler_delay,omitempty"` - EnableChunkedPrefill bool `json:"enable_chunked_prefill,omitempty"` - SpeculativeModel string `json:"speculative_model,omitempty"` - SpeculativeModelQuantization string `json:"speculative_model_quantization,omitempty"` - SpeculativeRevision string `json:"speculative_revision,omitempty"` - SpeculativeMaxModelLen int 
`json:"speculative_max_model_len,omitempty"` - SpeculativeDisableByBatchSize int `json:"speculative_disable_by_batch_size,omitempty"` - NgptSpeculativeLength int `json:"ngpt_speculative_length,omitempty"` - SpeculativeDisableMqa bool `json:"speculative_disable_mqa,omitempty"` - ModelLoaderExtraConfig string `json:"model_loader_extra_config,omitempty"` - IgnorePatterns string `json:"ignore_patterns,omitempty"` - PreloadedLoraModules string `json:"preloaded_lora_modules,omitempty"` + TensorParallelSize int `json:"tensor_parallel_size,omitempty"` + PipelineParallelSize int `json:"pipeline_parallel_size,omitempty"` + MaxParallelLoadingWorkers int `json:"max_parallel_loading_workers,omitempty"` + DisableAsyncOutputProc bool `json:"disable_async_output_proc,omitempty"` + WorkerClass string `json:"worker_class,omitempty"` + EnabledLoraModules string `json:"enabled_lora_modules,omitempty"` + MaxLoraRank int `json:"max_lora_rank,omitempty"` + FullyShardedLoras bool `json:"fully_sharded_loras,omitempty"` + LoraModules string `json:"lora_modules,omitempty"` + PromptAdapters string `json:"prompt_adapters,omitempty"` + MaxPromptAdapterToken int `json:"max_prompt_adapter_token,omitempty"` + Device string `json:"device,omitempty"` + SchedulerDelay float64 `json:"scheduler_delay,omitempty"` + EnableChunkedPrefill bool `json:"enable_chunked_prefill,omitempty"` + SpeculativeModel string `json:"speculative_model,omitempty"` + SpeculativeModelQuantization string `json:"speculative_model_quantization,omitempty"` + SpeculativeRevision string `json:"speculative_revision,omitempty"` + SpeculativeMaxModelLen int `json:"speculative_max_model_len,omitempty"` + SpeculativeDisableByBatchSize int `json:"speculative_disable_by_batch_size,omitempty"` + NgptSpeculativeLength int `json:"ngpt_speculative_length,omitempty"` + SpeculativeDisableMqa bool `json:"speculative_disable_mqa,omitempty"` + ModelLoaderExtraConfig string `json:"model_loader_extra_config,omitempty"` + IgnorePatterns string `json:"ignore_patterns,omitempty"` + PreloadedLoraModules string `json:"preloaded_lora_modules,omitempty"` // OpenAI server specific options - UDS string `json:"uds,omitempty"` - UvicornLogLevel string `json:"uvicorn_log_level,omitempty"` - ResponseRole string `json:"response_role,omitempty"` - SSLKeyfile string `json:"ssl_keyfile,omitempty"` - SSLCertfile string `json:"ssl_certfile,omitempty"` - SSLCACerts string `json:"ssl_ca_certs,omitempty"` - SSLCertReqs int `json:"ssl_cert_reqs,omitempty"` - RootPath string `json:"root_path,omitempty"` - Middleware []string `json:"middleware,omitempty"` - ReturnTokensAsTokenIDS bool `json:"return_tokens_as_token_ids,omitempty"` - DisableFrontendMultiprocessing bool `json:"disable_frontend_multiprocessing,omitempty"` - EnableAutoToolChoice bool `json:"enable_auto_tool_choice,omitempty"` - ToolCallParser string `json:"tool_call_parser,omitempty"` - ToolServer string `json:"tool_server,omitempty"` - ChatTemplate string `json:"chat_template,omitempty"` - ChatTemplateContentFormat string `json:"chat_template_content_format,omitempty"` - AllowCredentials bool `json:"allow_credentials,omitempty"` - AllowedOrigins []string `json:"allowed_origins,omitempty"` - AllowedMethods []string `json:"allowed_methods,omitempty"` - AllowedHeaders []string `json:"allowed_headers,omitempty"` - APIKey []string `json:"api_key,omitempty"` - EnableLogOutputs bool `json:"enable_log_outputs,omitempty"` - EnableTokenUsage bool `json:"enable_token_usage,omitempty"` - EnableAsyncEngineDebug bool 
`json:"enable_async_engine_debug,omitempty"` - EngineUseRay bool `json:"engine_use_ray,omitempty"` - DisableLogRequests bool `json:"disable_log_requests,omitempty"` - MaxLogLen int `json:"max_log_len,omitempty"` + UDS string `json:"uds,omitempty"` + UvicornLogLevel string `json:"uvicorn_log_level,omitempty"` + ResponseRole string `json:"response_role,omitempty"` + SSLKeyfile string `json:"ssl_keyfile,omitempty"` + SSLCertfile string `json:"ssl_certfile,omitempty"` + SSLCACerts string `json:"ssl_ca_certs,omitempty"` + SSLCertReqs int `json:"ssl_cert_reqs,omitempty"` + RootPath string `json:"root_path,omitempty"` + Middleware []string `json:"middleware,omitempty"` + ReturnTokensAsTokenIDS bool `json:"return_tokens_as_token_ids,omitempty"` + DisableFrontendMultiprocessing bool `json:"disable_frontend_multiprocessing,omitempty"` + EnableAutoToolChoice bool `json:"enable_auto_tool_choice,omitempty"` + ToolCallParser string `json:"tool_call_parser,omitempty"` + ToolServer string `json:"tool_server,omitempty"` + ChatTemplate string `json:"chat_template,omitempty"` + ChatTemplateContentFormat string `json:"chat_template_content_format,omitempty"` + AllowCredentials bool `json:"allow_credentials,omitempty"` + AllowedOrigins []string `json:"allowed_origins,omitempty"` + AllowedMethods []string `json:"allowed_methods,omitempty"` + AllowedHeaders []string `json:"allowed_headers,omitempty"` + APIKey []string `json:"api_key,omitempty"` + EnableLogOutputs bool `json:"enable_log_outputs,omitempty"` + EnableTokenUsage bool `json:"enable_token_usage,omitempty"` + EnableAsyncEngineDebug bool `json:"enable_async_engine_debug,omitempty"` + EngineUseRay bool `json:"engine_use_ray,omitempty"` + DisableLogRequests bool `json:"disable_log_requests,omitempty"` + MaxLogLen int `json:"max_log_len,omitempty"` // Additional engine configuration - Task string `json:"task,omitempty"` - MultiModalConfig string `json:"multi_modal_config,omitempty"` - LimitMmPerPrompt string `json:"limit_mm_per_prompt,omitempty"` - EnableSleepMode bool `json:"enable_sleep_mode,omitempty"` - EnableChunkingRequest bool `json:"enable_chunking_request,omitempty"` - CompilationConfig string `json:"compilation_config,omitempty"` - DisableSlidingWindowMask bool `json:"disable_sliding_window_mask,omitempty"` - EnableTRTLLMEngineLatency bool `json:"enable_trtllm_engine_latency,omitempty"` - OverridePoolingConfig string `json:"override_pooling_config,omitempty"` - OverrideNeuronConfig string `json:"override_neuron_config,omitempty"` - OverrideKVCacheALIGNSize int `json:"override_kv_cache_align_size,omitempty"` -} - -// NewVllmServerOptions creates a new VllmServerOptions with defaults -func NewVllmServerOptions() *VllmServerOptions { - return &VllmServerOptions{ - Host: "127.0.0.1", - Port: 8000, - TensorParallelSize: 1, - PipelineParallelSize: 1, - GPUMemoryUtilization: 0.9, - BlockSize: 16, - SwapSpace: 4, - UvicornLogLevel: "info", - ResponseRole: "assistant", - TokenizerMode: "auto", - TrustRemoteCode: false, - EnablePrefixCaching: false, - EnforceEager: false, - DisableLogStats: false, - DisableLogRequests: false, - MaxLogprobs: 20, - EnableLogOutputs: false, - EnableTokenUsage: false, - AllowCredentials: false, - AllowedOrigins: []string{"*"}, - AllowedMethods: []string{"*"}, - AllowedHeaders: []string{"*"}, - } -} - -// UnmarshalJSON implements custom JSON unmarshaling to support multiple field names -func (o *VllmServerOptions) UnmarshalJSON(data []byte) error { - // First unmarshal into a map to handle multiple field names - var raw 
map[string]any - if err := json.Unmarshal(data, &raw); err != nil { - return err - } - - // Create a temporary struct for standard unmarshaling - type tempOptions VllmServerOptions - temp := tempOptions{} - - // Standard unmarshal first - if err := json.Unmarshal(data, &temp); err != nil { - return err - } - - // Copy to our struct - *o = VllmServerOptions(temp) - - // Handle alternative field names (CLI format with dashes) - fieldMappings := map[string]string{ - // Basic options - "tensor-parallel-size": "tensor_parallel_size", - "pipeline-parallel-size": "pipeline_parallel_size", - "max-parallel-loading-workers": "max_parallel_loading_workers", - "disable-async-output-proc": "disable_async_output_proc", - "worker-class": "worker_class", - "enabled-lora-modules": "enabled_lora_modules", - "max-lora-rank": "max_lora_rank", - "fully-sharded-loras": "fully_sharded_loras", - "lora-modules": "lora_modules", - "prompt-adapters": "prompt_adapters", - "max-prompt-adapter-token": "max_prompt_adapter_token", - "scheduler-delay": "scheduler_delay", - "enable-chunked-prefill": "enable_chunked_prefill", - "speculative-model": "speculative_model", - "speculative-model-quantization": "speculative_model_quantization", - "speculative-revision": "speculative_revision", - "speculative-max-model-len": "speculative_max_model_len", - "speculative-disable-by-batch-size": "speculative_disable_by_batch_size", - "ngpt-speculative-length": "ngpt_speculative_length", - "speculative-disable-mqa": "speculative_disable_mqa", - "model-loader-extra-config": "model_loader_extra_config", - "ignore-patterns": "ignore_patterns", - "preloaded-lora-modules": "preloaded_lora_modules", - - // Model configuration - "skip-tokenizer-init": "skip_tokenizer_init", - "code-revision": "code_revision", - "tokenizer-revision": "tokenizer_revision", - "tokenizer-mode": "tokenizer_mode", - "trust-remote-code": "trust_remote_code", - "download-dir": "download_dir", - "load-format": "load_format", - "config-format": "config_format", - "kv-cache-dtype": "kv_cache_dtype", - "quantization-param-path": "quantization_param_path", - "max-model-len": "max_model_len", - "guided-decoding-backend": "guided_decoding_backend", - "distributed-executor-backend": "distributed_executor_backend", - "worker-use-ray": "worker_use_ray", - "ray-workers-use-nsight": "ray_workers_use_nsight", - - // Performance configuration - "block-size": "block_size", - "enable-prefix-caching": "enable_prefix_caching", - "disable-sliding-window": "disable_sliding_window", - "use-v2-block-manager": "use_v2_block_manager", - "num-lookahead-slots": "num_lookahead_slots", - "swap-space": "swap_space", - "cpu-offload-gb": "cpu_offload_gb", - "gpu-memory-utilization": "gpu_memory_utilization", - "num-gpu-blocks-override": "num_gpu_blocks_override", - "max-num-batched-tokens": "max_num_batched_tokens", - "max-num-seqs": "max_num_seqs", - "max-logprobs": "max_logprobs", - "disable-log-stats": "disable_log_stats", - "rope-scaling": "rope_scaling", - "rope-theta": "rope_theta", - "enforce-eager": "enforce_eager", - "max-context-len-to-capture": "max_context_len_to_capture", - "max-seq-len-to-capture": "max_seq_len_to_capture", - "disable-custom-all-reduce": "disable_custom_all_reduce", - "tokenizer-pool-size": "tokenizer_pool_size", - "tokenizer-pool-type": "tokenizer_pool_type", - "tokenizer-pool-extra-config": "tokenizer_pool_extra_config", - "enable-lora-bias": "enable_lora_bias", - "lora-extra-vocab-size": "lora_extra_vocab_size", - "lora-rank": "lora_rank", - 
"prompt-lookback-distance": "prompt_lookback_distance", - "preemption-mode": "preemption_mode", - - // Server configuration - "uvicorn-log-level": "uvicorn_log_level", - "response-role": "response_role", - "ssl-keyfile": "ssl_keyfile", - "ssl-certfile": "ssl_certfile", - "ssl-ca-certs": "ssl_ca_certs", - "ssl-cert-reqs": "ssl_cert_reqs", - "root-path": "root_path", - "return-tokens-as-token-ids": "return_tokens_as_token_ids", - "disable-frontend-multiprocessing": "disable_frontend_multiprocessing", - "enable-auto-tool-choice": "enable_auto_tool_choice", - "tool-call-parser": "tool_call_parser", - "tool-server": "tool_server", - "chat-template": "chat_template", - "chat-template-content-format": "chat_template_content_format", - "allow-credentials": "allow_credentials", - "allowed-origins": "allowed_origins", - "allowed-methods": "allowed_methods", - "allowed-headers": "allowed_headers", - "api-key": "api_key", - "enable-log-outputs": "enable_log_outputs", - "enable-token-usage": "enable_token_usage", - "enable-async-engine-debug": "enable_async_engine_debug", - "engine-use-ray": "engine_use_ray", - "disable-log-requests": "disable_log_requests", - "max-log-len": "max_log_len", - - // Additional options - "multi-modal-config": "multi_modal_config", - "limit-mm-per-prompt": "limit_mm_per_prompt", - "enable-sleep-mode": "enable_sleep_mode", - "enable-chunking-request": "enable_chunking_request", - "compilation-config": "compilation_config", - "disable-sliding-window-mask": "disable_sliding_window_mask", - "enable-trtllm-engine-latency": "enable_trtllm_engine_latency", - "override-pooling-config": "override_pooling_config", - "override-neuron-config": "override_neuron_config", - "override-kv-cache-align-size": "override_kv_cache_align_size", - } - - // Process alternative field names - for altName, canonicalName := range fieldMappings { - if value, exists := raw[altName]; exists { - // Use reflection to set the field value - v := reflect.ValueOf(o).Elem() - field := v.FieldByNameFunc(func(fieldName string) bool { - field, _ := v.Type().FieldByName(fieldName) - jsonTag := field.Tag.Get("json") - return jsonTag == canonicalName+",omitempty" || jsonTag == canonicalName - }) - - if field.IsValid() && field.CanSet() { - switch field.Kind() { - case reflect.Int: - if intVal, ok := value.(float64); ok { - field.SetInt(int64(intVal)) - } else if strVal, ok := value.(string); ok { - if intVal, err := strconv.Atoi(strVal); err == nil { - field.SetInt(int64(intVal)) - } - } - case reflect.Float64: - if floatVal, ok := value.(float64); ok { - field.SetFloat(floatVal) - } else if strVal, ok := value.(string); ok { - if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil { - field.SetFloat(floatVal) - } - } - case reflect.String: - if strVal, ok := value.(string); ok { - field.SetString(strVal) - } - case reflect.Bool: - if boolVal, ok := value.(bool); ok { - field.SetBool(boolVal) - } - case reflect.Slice: - if field.Type().Elem().Kind() == reflect.String { - if strVal, ok := value.(string); ok { - // Split comma-separated values - values := strings.Split(strVal, ",") - for i, v := range values { - values[i] = strings.TrimSpace(v) - } - field.Set(reflect.ValueOf(values)) - } else if slice, ok := value.([]interface{}); ok { - var strSlice []string - for _, item := range slice { - if str, ok := item.(string); ok { - strSlice = append(strSlice, str) - } - } - field.Set(reflect.ValueOf(strSlice)) - } - } - } - } - } - } - - return nil + Task string `json:"task,omitempty"` + MultiModalConfig string 
`json:"multi_modal_config,omitempty"` + LimitMmPerPrompt string `json:"limit_mm_per_prompt,omitempty"` + EnableSleepMode bool `json:"enable_sleep_mode,omitempty"` + EnableChunkingRequest bool `json:"enable_chunking_request,omitempty"` + CompilationConfig string `json:"compilation_config,omitempty"` + DisableSlidingWindowMask bool `json:"disable_sliding_window_mask,omitempty"` + EnableTRTLLMEngineLatency bool `json:"enable_trtllm_engine_latency,omitempty"` + OverridePoolingConfig string `json:"override_pooling_config,omitempty"` + OverrideNeuronConfig string `json:"override_neuron_config,omitempty"` + OverrideKVCacheALIGNSize int `json:"override_kv_cache_align_size,omitempty"` } // BuildCommandArgs converts VllmServerOptions to command line arguments @@ -387,11 +161,6 @@ func (o *VllmServerOptions) BuildCommandArgs() []string { flagName = jsonTag[:commaIndex] } - // Skip host and port as they are handled by llamactl - if flagName == "host" || flagName == "port" { - continue - } - // Convert snake_case to kebab-case for CLI flags flagName = strings.ReplaceAll(flagName, "_", "-") @@ -436,4 +205,4 @@ func (o *VllmServerOptions) BuildCommandArgs() []string { } return args -} \ No newline at end of file +} diff --git a/pkg/backends/vllm/vllm_test.go b/pkg/backends/vllm/vllm_test.go index e05320a..8a42862 100644 --- a/pkg/backends/vllm/vllm_test.go +++ b/pkg/backends/vllm/vllm_test.go @@ -10,12 +10,12 @@ import ( func TestBuildCommandArgs(t *testing.T) { options := vllm.VllmServerOptions{ Model: "microsoft/DialoGPT-medium", - Port: 8080, // should be excluded - Host: "localhost", // should be excluded + Port: 8080, + Host: "localhost", TensorParallelSize: 2, GPUMemoryUtilization: 0.8, EnableLogOutputs: true, - APIKey: []string{"key1", "key2"}, + AllowedOrigins: []string{"http://localhost:3000", "https://example.com"}, } args := options.BuildCommandArgs() @@ -32,19 +32,22 @@ func TestBuildCommandArgs(t *testing.T) { } // Host and port should NOT be in the arguments (handled by llamactl) - if contains(args, "--host") || contains(args, "--port") { - t.Errorf("Host and port should not be in command args, found in %v", args) + if !contains(args, "--host") { + t.Errorf("Expected --host not found in %v", args) + } + if !contains(args, "--port") { + t.Errorf("Expected --port not found in %v", args) } // Check array handling (multiple flags) - apiKeyCount := 0 + allowedOriginsCount := 0 for i := range args { - if args[i] == "--api-key" { - apiKeyCount++ + if args[i] == "--allowed-origins" { + allowedOriginsCount++ } } - if apiKeyCount != 2 { - t.Errorf("Expected 2 --api-key flags, got %d", apiKeyCount) + if allowedOriginsCount != 2 { + t.Errorf("Expected 2 --allowed-origins flags, got %d", allowedOriginsCount) } } @@ -77,20 +80,6 @@ func TestUnmarshalJSON(t *testing.T) { } } -func TestNewVllmServerOptions(t *testing.T) { - options := vllm.NewVllmServerOptions() - - if options == nil { - t.Fatal("NewVllmServerOptions returned nil") - } - if options.Host != "127.0.0.1" { - t.Errorf("Expected default host '127.0.0.1', got %q", options.Host) - } - if options.Port != 8000 { - t.Errorf("Expected default port 8000, got %d", options.Port) - } -} - // Helper functions func contains(slice []string, item string) bool { return slices.Contains(slice, item) @@ -103,4 +92,4 @@ func containsFlagWithValue(args []string, flag, value string) bool { } } return false -} \ No newline at end of file +} From ec5485bd0e6ba40d72e41c4572e26eb20e636d65 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Fri, 19 Sep 2025 19:46:54 +0200 
Subject: [PATCH 04/20] Refactor command argument building across backends --- pkg/backends/llamacpp/llama.go | 64 ++--------------- pkg/backends/mlx/mlx.go | 115 +++++++++++++++++++------------ pkg/backends/parser.go | 121 +++++++++++++++++++++++++++++++++ pkg/backends/vllm/vllm.go | 85 ++++------------------- 4 files changed, 209 insertions(+), 176 deletions(-) diff --git a/pkg/backends/llamacpp/llama.go b/pkg/backends/llamacpp/llama.go index c838141..cfad2fd 100644 --- a/pkg/backends/llamacpp/llama.go +++ b/pkg/backends/llamacpp/llama.go @@ -2,9 +2,9 @@ package llamacpp import ( "encoding/json" + "llamactl/pkg/backends" "reflect" "strconv" - "strings" ) type LlamaServerOptions struct { @@ -313,64 +313,10 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error { return nil } -// BuildCommandArgs converts InstanceOptions to command line arguments +// BuildCommandArgs converts InstanceOptions to command line arguments using the common builder func (o *LlamaServerOptions) BuildCommandArgs() []string { - var args []string - - v := reflect.ValueOf(o).Elem() - t := v.Type() - - for i := 0; i < v.NumField(); i++ { - field := v.Field(i) - fieldType := t.Field(i) - - // Skip unexported fields - if !field.CanInterface() { - continue - } - - // Get the JSON tag to determine the flag name - jsonTag := fieldType.Tag.Get("json") - if jsonTag == "" || jsonTag == "-" { - continue - } - - // Remove ",omitempty" from the tag - flagName := jsonTag - if commaIndex := strings.Index(jsonTag, ","); commaIndex != -1 { - flagName = jsonTag[:commaIndex] - } - - // Convert snake_case to kebab-case for CLI flags - flagName = strings.ReplaceAll(flagName, "_", "-") - - // Add the appropriate arguments based on field type and value - switch field.Kind() { - case reflect.Bool: - if field.Bool() { - args = append(args, "--"+flagName) - } - case reflect.Int: - if field.Int() != 0 { - args = append(args, "--"+flagName, strconv.FormatInt(field.Int(), 10)) - } - case reflect.Float64: - if field.Float() != 0 { - args = append(args, "--"+flagName, strconv.FormatFloat(field.Float(), 'f', -1, 64)) - } - case reflect.String: - if field.String() != "" { - args = append(args, "--"+flagName, field.String()) - } - case reflect.Slice: - if field.Type().Elem().Kind() == reflect.String { - // Handle []string fields - for j := 0; j < field.Len(); j++ { - args = append(args, "--"+flagName, field.Index(j).String()) - } - } - } + config := backends.ArgsBuilderConfig{ + SliceHandling: backends.SliceAsMultipleFlags, // Llama uses multiple flags for arrays } - - return args + return backends.BuildCommandArgs(o, config) } diff --git a/pkg/backends/mlx/mlx.go b/pkg/backends/mlx/mlx.go index 8527c7b..9b29010 100644 --- a/pkg/backends/mlx/mlx.go +++ b/pkg/backends/mlx/mlx.go @@ -1,9 +1,10 @@ package mlx import ( + "encoding/json" + "llamactl/pkg/backends" "reflect" "strconv" - "strings" ) type MlxServerOptions struct { @@ -32,57 +33,83 @@ type MlxServerOptions struct { MaxTokens int `json:"max_tokens,omitempty"` } -// BuildCommandArgs converts to command line arguments using reflection -func (o *MlxServerOptions) BuildCommandArgs() []string { - var args []string +// UnmarshalJSON implements custom JSON unmarshaling to support multiple field names +func (o *MlxServerOptions) UnmarshalJSON(data []byte) error { + // First unmarshal into a map to handle multiple field names + var raw map[string]any + if err := json.Unmarshal(data, &raw); err != nil { + return err + } - v := reflect.ValueOf(o).Elem() - t := v.Type() + // Create a temporary 
struct for standard unmarshaling + type tempOptions MlxServerOptions + temp := tempOptions{} - for i := 0; i < v.NumField(); i++ { - field := v.Field(i) - fieldType := t.Field(i) + // Standard unmarshal first + if err := json.Unmarshal(data, &temp); err != nil { + return err + } - // Skip unexported fields - if !field.CanInterface() { - continue - } + // Copy to our struct + *o = MlxServerOptions(temp) - // Get the JSON tag to determine the flag name - jsonTag := fieldType.Tag.Get("json") - if jsonTag == "" || jsonTag == "-" { - continue - } + // Handle alternative field names + fieldMappings := map[string]string{ + "m": "model", // -m, --model + "temperature": "temp", // --temperature vs --temp + "top_k": "top_k", // --top-k + "adapter_path": "adapter_path", // --adapter-path + } - // Remove ",omitempty" from the tag - flagName := jsonTag - if commaIndex := strings.Index(jsonTag, ","); commaIndex != -1 { - flagName = jsonTag[:commaIndex] - } + // Process alternative field names + for altName, canonicalName := range fieldMappings { + if value, exists := raw[altName]; exists { + // Use reflection to set the field value + v := reflect.ValueOf(o).Elem() + field := v.FieldByNameFunc(func(fieldName string) bool { + field, _ := v.Type().FieldByName(fieldName) + jsonTag := field.Tag.Get("json") + return jsonTag == canonicalName+",omitempty" || jsonTag == canonicalName + }) - // Convert snake_case to kebab-case for CLI flags - flagName = strings.ReplaceAll(flagName, "_", "-") - - // Add the appropriate arguments based on field type and value - switch field.Kind() { - case reflect.Bool: - if field.Bool() { - args = append(args, "--"+flagName) - } - case reflect.Int: - if field.Int() != 0 { - args = append(args, "--"+flagName, strconv.FormatInt(field.Int(), 10)) - } - case reflect.Float64: - if field.Float() != 0 { - args = append(args, "--"+flagName, strconv.FormatFloat(field.Float(), 'f', -1, 64)) - } - case reflect.String: - if field.String() != "" { - args = append(args, "--"+flagName, field.String()) + if field.IsValid() && field.CanSet() { + switch field.Kind() { + case reflect.Int: + if intVal, ok := value.(float64); ok { + field.SetInt(int64(intVal)) + } else if strVal, ok := value.(string); ok { + if intVal, err := strconv.Atoi(strVal); err == nil { + field.SetInt(int64(intVal)) + } + } + case reflect.Float64: + if floatVal, ok := value.(float64); ok { + field.SetFloat(floatVal) + } else if strVal, ok := value.(string); ok { + if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil { + field.SetFloat(floatVal) + } + } + case reflect.String: + if strVal, ok := value.(string); ok { + field.SetString(strVal) + } + case reflect.Bool: + if boolVal, ok := value.(bool); ok { + field.SetBool(boolVal) + } + } } } } - return args + return nil +} + +// BuildCommandArgs converts to command line arguments using the common builder +func (o *MlxServerOptions) BuildCommandArgs() []string { + config := backends.ArgsBuilderConfig{ + SliceHandling: backends.SliceAsMultipleFlags, // MLX doesn't currently have []string fields, but default to multiple flags + } + return backends.BuildCommandArgs(o, config) } diff --git a/pkg/backends/parser.go b/pkg/backends/parser.go index 5023a7e..0e34398 100644 --- a/pkg/backends/parser.go +++ b/pkg/backends/parser.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "path/filepath" + "reflect" "regexp" "strconv" "strings" @@ -308,3 +309,123 @@ func isFlag(arg string) bool { return true } + +// SliceHandling defines how []string fields should be handled when building command 
args +type SliceHandling int + +const ( + // SliceAsMultipleFlags creates multiple flags: --flag value1 --flag value2 + SliceAsMultipleFlags SliceHandling = iota + // SliceAsCommaSeparated creates single flag with comma-separated values: --flag value1,value2 + SliceAsCommaSeparated + // SliceAsMixed uses different strategies for different flags (requires configuration) + SliceAsMixed +) + +// ArgsBuilderConfig holds configuration for building command line arguments +type ArgsBuilderConfig struct { + // SliceHandling defines the default strategy for []string fields + SliceHandling SliceHandling + // MultipleFlags specifies which flags should use multiple instances when SliceHandling is SliceAsMixed + MultipleFlags map[string]struct{} +} + +// BuildCommandArgs converts a struct to command line arguments using reflection +func BuildCommandArgs(options any, config ArgsBuilderConfig) []string { + var args []string + + v := reflect.ValueOf(options).Elem() + t := v.Type() + + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + fieldType := t.Field(i) + + // Skip unexported fields + if !field.CanInterface() { + continue + } + + // Get the JSON tag to determine the flag name + jsonTag := fieldType.Tag.Get("json") + if jsonTag == "" || jsonTag == "-" { + continue + } + + // Remove ",omitempty" from the tag + flagName := jsonTag + if commaIndex := strings.Index(jsonTag, ","); commaIndex != -1 { + flagName = jsonTag[:commaIndex] + } + + // Convert snake_case to kebab-case for CLI flags + flagName = strings.ReplaceAll(flagName, "_", "-") + + // Add the appropriate arguments based on field type and value + switch field.Kind() { + case reflect.Bool: + if field.Bool() { + args = append(args, "--"+flagName) + } + case reflect.Int: + if field.Int() != 0 { + args = append(args, "--"+flagName, strconv.FormatInt(field.Int(), 10)) + } + case reflect.Float64: + if field.Float() != 0 { + args = append(args, "--"+flagName, strconv.FormatFloat(field.Float(), 'f', -1, 64)) + } + case reflect.String: + if field.String() != "" { + args = append(args, "--"+flagName, field.String()) + } + case reflect.Slice: + if field.Type().Elem().Kind() == reflect.String { + args = append(args, handleStringSlice(field, flagName, config)...) 
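// A hedged illustration of the three strategies, assuming a []string field
// tagged `json:"api_key"` that holds {"k1", "k2"}:
//
//	SliceAsMultipleFlags:  --api-key k1 --api-key k2
//	SliceAsCommaSeparated: --api-key k1,k2
//	SliceAsMixed:          whichever of the two ArgsBuilderConfig.MultipleFlags
//	                       selects for "api-key"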
+ } + } + } + + return args +} + +// handleStringSlice handles []string fields based on the configuration +func handleStringSlice(field reflect.Value, flagName string, config ArgsBuilderConfig) []string { + if field.Len() == 0 { + return nil + } + + var args []string + + switch config.SliceHandling { + case SliceAsMultipleFlags: + // Multiple flags: --flag value1 --flag value2 + for j := 0; j < field.Len(); j++ { + args = append(args, "--"+flagName, field.Index(j).String()) + } + case SliceAsCommaSeparated: + // Comma-separated: --flag value1,value2 + var values []string + for j := 0; j < field.Len(); j++ { + values = append(values, field.Index(j).String()) + } + args = append(args, "--"+flagName, strings.Join(values, ",")) + case SliceAsMixed: + // Check if this specific flag should use multiple instances + if _, useMultiple := config.MultipleFlags[flagName]; useMultiple { + // Multiple flags + for j := 0; j < field.Len(); j++ { + args = append(args, "--"+flagName, field.Index(j).String()) + } + } else { + // Comma-separated + var values []string + for j := 0; j < field.Len(); j++ { + values = append(values, field.Index(j).String()) + } + args = append(args, "--"+flagName, strings.Join(values, ",")) + } + } + + return args +} diff --git a/pkg/backends/vllm/vllm.go b/pkg/backends/vllm/vllm.go index 9aa865c..81e567d 100644 --- a/pkg/backends/vllm/vllm.go +++ b/pkg/backends/vllm/vllm.go @@ -1,9 +1,7 @@ package vllm import ( - "reflect" - "strconv" - "strings" + "llamactl/pkg/backends" ) type VllmServerOptions struct { @@ -132,77 +130,18 @@ type VllmServerOptions struct { OverrideKVCacheALIGNSize int `json:"override_kv_cache_align_size,omitempty"` } -// BuildCommandArgs converts VllmServerOptions to command line arguments +// BuildCommandArgs converts VllmServerOptions to command line arguments using the common builder // Note: This does NOT include the "serve" subcommand, that's handled at the instance level func (o *VllmServerOptions) BuildCommandArgs() []string { - var args []string - - v := reflect.ValueOf(o).Elem() - t := v.Type() - - for i := 0; i < v.NumField(); i++ { - field := v.Field(i) - fieldType := t.Field(i) - - // Skip unexported fields - if !field.CanInterface() { - continue - } - - // Get the JSON tag to determine the flag name - jsonTag := fieldType.Tag.Get("json") - if jsonTag == "" || jsonTag == "-" { - continue - } - - // Remove ",omitempty" from the tag - flagName := jsonTag - if commaIndex := strings.Index(jsonTag, ","); commaIndex != -1 { - flagName = jsonTag[:commaIndex] - } - - // Convert snake_case to kebab-case for CLI flags - flagName = strings.ReplaceAll(flagName, "_", "-") - - // Add the appropriate arguments based on field type and value - switch field.Kind() { - case reflect.Bool: - if field.Bool() { - args = append(args, "--"+flagName) - } - case reflect.Int: - if field.Int() != 0 { - args = append(args, "--"+flagName, strconv.FormatInt(field.Int(), 10)) - } - case reflect.Float64: - if field.Float() != 0 { - args = append(args, "--"+flagName, strconv.FormatFloat(field.Float(), 'f', -1, 64)) - } - case reflect.String: - if field.String() != "" { - args = append(args, "--"+flagName, field.String()) - } - case reflect.Slice: - if field.Type().Elem().Kind() == reflect.String { - // Handle []string fields - some are comma-separated, some use multiple flags - if flagName == "api-key" || flagName == "allowed-origins" || flagName == "allowed-methods" || flagName == "allowed-headers" || flagName == "middleware" { - // Multiple flags for these - for j := 0; j < 
field.Len(); j++ { - args = append(args, "--"+flagName, field.Index(j).String()) - } - } else { - // Comma-separated for others - if field.Len() > 0 { - var values []string - for j := 0; j < field.Len(); j++ { - values = append(values, field.Index(j).String()) - } - args = append(args, "--"+flagName, strings.Join(values, ",")) - } - } - } - } + config := backends.ArgsBuilderConfig{ + SliceHandling: backends.SliceAsMixed, + MultipleFlags: map[string]struct{}{ + "api-key": {}, + "allowed-origins": {}, + "allowed-methods": {}, + "allowed-headers": {}, + "middleware": {}, + }, } - - return args + return backends.BuildCommandArgs(o, config) } From 34a949d22e5438f77dc53d3210e16467886b9efa Mon Sep 17 00:00:00 2001 From: LordMathis Date: Fri, 19 Sep 2025 19:59:46 +0200 Subject: [PATCH 05/20] Refactor command argument building and parsing --- pkg/backends/llamacpp/llama.go | 16 +- pkg/backends/llamacpp/parser.go | 25 +- pkg/backends/mlx/mlx.go | 8 +- pkg/backends/mlx/parser.go | 21 +- pkg/backends/parser.go | 517 +++++++++++--------------------- pkg/backends/vllm/parser.go | 24 +- pkg/backends/vllm/vllm.go | 19 +- 7 files changed, 223 insertions(+), 407 deletions(-) diff --git a/pkg/backends/llamacpp/llama.go b/pkg/backends/llamacpp/llama.go index cfad2fd..7c8a21f 100644 --- a/pkg/backends/llamacpp/llama.go +++ b/pkg/backends/llamacpp/llama.go @@ -313,10 +313,18 @@ func (o *LlamaServerOptions) UnmarshalJSON(data []byte) error { return nil } -// BuildCommandArgs converts InstanceOptions to command line arguments using the common builder +// BuildCommandArgs converts InstanceOptions to command line arguments func (o *LlamaServerOptions) BuildCommandArgs() []string { - config := backends.ArgsBuilderConfig{ - SliceHandling: backends.SliceAsMultipleFlags, // Llama uses multiple flags for arrays + // Llama uses multiple flags for arrays by default (not comma-separated) + multipleFlags := map[string]bool{ + "override-tensor": true, + "override-kv": true, + "lora": true, + "lora-scaled": true, + "control-vector": true, + "control-vector-scaled": true, + "dry-sequence-breaker": true, + "logit-bias": true, } - return backends.BuildCommandArgs(o, config) + return backends.BuildCommandArgs(o, multipleFlags) } diff --git a/pkg/backends/llamacpp/parser.go b/pkg/backends/llamacpp/parser.go index b5ce1f9..b5b850a 100644 --- a/pkg/backends/llamacpp/parser.go +++ b/pkg/backends/llamacpp/parser.go @@ -11,22 +11,21 @@ import ( // 3. Args only: "--model file.gguf --gpu-layers 32" // 4. 
Multiline commands with backslashes func ParseLlamaCommand(command string) (*LlamaServerOptions, error) { - config := backends.CommandParserConfig{ - ExecutableNames: []string{"llama-server"}, - MultiValuedFlags: map[string]struct{}{ - "override_tensor": {}, - "override_kv": {}, - "lora": {}, - "lora_scaled": {}, - "control_vector": {}, - "control_vector_scaled": {}, - "dry_sequence_breaker": {}, - "logit_bias": {}, - }, + executableNames := []string{"llama-server"} + var subcommandNames []string // Llama has no subcommands + multiValuedFlags := map[string]bool{ + "override_tensor": true, + "override_kv": true, + "lora": true, + "lora_scaled": true, + "control_vector": true, + "control_vector_scaled": true, + "dry_sequence_breaker": true, + "logit_bias": true, } var llamaOptions LlamaServerOptions - if err := backends.ParseCommand(command, config, &llamaOptions); err != nil { + if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &llamaOptions); err != nil { return nil, err } diff --git a/pkg/backends/mlx/mlx.go b/pkg/backends/mlx/mlx.go index 9b29010..c72597c 100644 --- a/pkg/backends/mlx/mlx.go +++ b/pkg/backends/mlx/mlx.go @@ -106,10 +106,8 @@ func (o *MlxServerOptions) UnmarshalJSON(data []byte) error { return nil } -// BuildCommandArgs converts to command line arguments using the common builder +// BuildCommandArgs converts to command line arguments func (o *MlxServerOptions) BuildCommandArgs() []string { - config := backends.ArgsBuilderConfig{ - SliceHandling: backends.SliceAsMultipleFlags, // MLX doesn't currently have []string fields, but default to multiple flags - } - return backends.BuildCommandArgs(o, config) + multipleFlags := map[string]bool{} // MLX doesn't currently have []string fields + return backends.BuildCommandArgs(o, multipleFlags) } diff --git a/pkg/backends/mlx/parser.go b/pkg/backends/mlx/parser.go index 01fad0e..ec4cfb2 100644 --- a/pkg/backends/mlx/parser.go +++ b/pkg/backends/mlx/parser.go @@ -11,27 +11,14 @@ import ( // 3. Args only: "--model model/path --host 0.0.0.0" // 4. 
Multiline commands with backslashes func ParseMlxCommand(command string) (*MlxServerOptions, error) { - config := backends.CommandParserConfig{ - ExecutableNames: []string{"mlx_lm.server"}, - MultiValuedFlags: map[string]struct{}{}, // MLX has no multi-valued flags - } + executableNames := []string{"mlx_lm.server"} + var subcommandNames []string // MLX has no subcommands + multiValuedFlags := map[string]bool{} // MLX has no multi-valued flags var mlxOptions MlxServerOptions - if err := backends.ParseCommand(command, config, &mlxOptions); err != nil { + if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &mlxOptions); err != nil { return nil, err } return &mlxOptions, nil } - -// isValidLogLevel validates MLX log levels -func isValidLogLevel(level string) bool { - validLevels := []string{"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"} - for _, valid := range validLevels { - if level == valid { - return true - } - } - return false -} - diff --git a/pkg/backends/parser.go b/pkg/backends/parser.go index 0e34398..721173f 100644 --- a/pkg/backends/parser.go +++ b/pkg/backends/parser.go @@ -2,7 +2,6 @@ package backends import ( "encoding/json" - "errors" "fmt" "path/filepath" "reflect" @@ -11,327 +10,41 @@ import ( "strings" ) -// CommandParserConfig holds configuration for parsing command line arguments -type CommandParserConfig struct { - // ExecutableNames are the names of executables to detect (e.g., "llama-server", "mlx_lm.server") - ExecutableNames []string - // SubcommandNames are optional subcommands (e.g., "serve" for vllm) - SubcommandNames []string - // MultiValuedFlags are flags that can accept multiple values - MultiValuedFlags map[string]struct{} -} - -// ParseCommand parses a command string using the provided configuration -func ParseCommand(command string, config CommandParserConfig, target any) error { - // 1. Normalize the command - handle multiline with backslashes - trimmed := normalizeMultilineCommand(command) - if trimmed == "" { +// ParseCommand parses a command string into a target struct +func ParseCommand(command string, executableNames []string, subcommandNames []string, multiValuedFlags map[string]bool, target any) error { + // Normalize multiline commands + command = normalizeCommand(command) + if command == "" { return fmt.Errorf("command cannot be empty") } - // 2. Extract arguments from command - args, err := extractArgumentsFromCommand(trimmed, config) + // Extract arguments + args, err := extractArgs(command, executableNames, subcommandNames) if err != nil { return err } - // 3. Parse arguments into map - options := make(map[string]any) - - i := 0 - for i < len(args) { - arg := args[i] - - if !strings.HasPrefix(arg, "-") { // skip positional / stray values - i++ - continue - } - - // Reject malformed flags with more than two leading dashes (e.g. 
---model) to surface user mistakes - if strings.HasPrefix(arg, "---") { - return fmt.Errorf("malformed flag: %s", arg) - } - - // Unified parsing for --flag=value vs --flag value - var rawFlag, rawValue string - hasEquals := false - if strings.Contains(arg, "=") { - parts := strings.SplitN(arg, "=", 2) - rawFlag = parts[0] - rawValue = parts[1] // may be empty string - hasEquals = true - } else { - rawFlag = arg - } - - flagCore := strings.TrimPrefix(strings.TrimPrefix(rawFlag, "-"), "-") - flagName := strings.ReplaceAll(flagCore, "-", "_") - - // Detect value if not in equals form - valueProvided := hasEquals - if !hasEquals { - if i+1 < len(args) && !isFlag(args[i+1]) { // next token is value - rawValue = args[i+1] - valueProvided = true - } - } - - // Determine if multi-valued flag - _, isMulti := config.MultiValuedFlags[flagName] - - // Normalization helper: ensure slice for multi-valued flags - appendValue := func(valStr string) { - if existing, ok := options[flagName]; ok { - // Existing value; ensure slice semantics for multi-valued flags or repeated occurrences - if slice, ok := existing.([]string); ok { - options[flagName] = append(slice, valStr) - return - } - // Convert scalar to slice - options[flagName] = []string{fmt.Sprintf("%v", existing), valStr} - return - } - // First value - if isMulti { - options[flagName] = []string{valStr} - } else { - // We'll parse type below for single-valued flags - options[flagName] = valStr - } - } - - if valueProvided { - // Use raw token for multi-valued flags; else allow typed parsing - appendValue(rawValue) - if !isMulti { // convert to typed value if scalar - if strVal, ok := options[flagName].(string); ok { // still scalar - options[flagName] = parseValue(strVal) - } - } - // Advance index: if we consumed a following token as value (non equals form), skip it - if !hasEquals && i+1 < len(args) && rawValue == args[i+1] { - i += 2 - } else { - i++ - } - continue - } - - // Boolean flag (no value) - options[flagName] = true - i++ + // Parse flags into map + options, err := parseFlags(args, multiValuedFlags) + if err != nil { + return err } - // 4. 
Convert to target struct using JSON marshaling + // Convert to target struct via JSON jsonData, err := json.Marshal(options) if err != nil { - return fmt.Errorf("failed to marshal parsed options: %w", err) + return fmt.Errorf("failed to marshal options: %w", err) } if err := json.Unmarshal(jsonData, target); err != nil { - return fmt.Errorf("failed to parse command options: %w", err) + return fmt.Errorf("failed to unmarshal to target: %w", err) } return nil } -// parseValue attempts to parse a string value into the most appropriate type -func parseValue(value string) any { - // Surrounding matching quotes (single or double) - if l := len(value); l >= 2 { - if (value[0] == '"' && value[l-1] == '"') || (value[0] == '\'' && value[l-1] == '\'') { - value = value[1 : l-1] - } - } - - lower := strings.ToLower(value) - if lower == "true" { - return true - } - if lower == "false" { - return false - } - - if intVal, err := strconv.Atoi(value); err == nil { - return intVal - } - if floatVal, err := strconv.ParseFloat(value, 64); err == nil { - return floatVal - } - return value -} - -// normalizeMultilineCommand handles multiline commands with backslashes -func normalizeMultilineCommand(command string) string { - // Handle escaped newlines (backslash followed by newline) - re := regexp.MustCompile(`\\\s*\n\s*`) - normalized := re.ReplaceAllString(command, " ") - - // Clean up extra whitespace - re = regexp.MustCompile(`\s+`) - normalized = re.ReplaceAllString(normalized, " ") - - return strings.TrimSpace(normalized) -} - -// extractArgumentsFromCommand extracts arguments from various command formats -func extractArgumentsFromCommand(command string, config CommandParserConfig) ([]string, error) { - // Split command into tokens respecting quotes - tokens, err := splitCommandTokens(command) - if err != nil { - return nil, err - } - - if len(tokens) == 0 { - return nil, fmt.Errorf("no command tokens found") - } - - firstToken := tokens[0] - - // Check for full path executable - if strings.Contains(firstToken, string(filepath.Separator)) { - baseName := filepath.Base(firstToken) - for _, execName := range config.ExecutableNames { - if strings.HasSuffix(baseName, execName) { - return skipExecutableAndSubcommands(tokens[1:], config.SubcommandNames) - } - } - // Unknown executable, assume it's still an executable - return skipExecutableAndSubcommands(tokens[1:], config.SubcommandNames) - } - - // Check for simple executable names - lowerFirstToken := strings.ToLower(firstToken) - for _, execName := range config.ExecutableNames { - if lowerFirstToken == strings.ToLower(execName) { - return skipExecutableAndSubcommands(tokens[1:], config.SubcommandNames) - } - } - - // Check for subcommands (like "serve" for vllm) - for _, subCmd := range config.SubcommandNames { - if lowerFirstToken == strings.ToLower(subCmd) { - return tokens[1:], nil // Return everything except the subcommand - } - } - - // Arguments only (starts with a flag) - if strings.HasPrefix(firstToken, "-") { - return tokens, nil // Return all tokens as arguments - } - - // Unknown format - might be a different executable name - return skipExecutableAndSubcommands(tokens[1:], config.SubcommandNames) -} - -// skipExecutableAndSubcommands removes subcommands from the beginning of tokens -func skipExecutableAndSubcommands(tokens []string, subcommands []string) ([]string, error) { - if len(tokens) == 0 { - return tokens, nil - } - - // Check if first token is a subcommand - if len(subcommands) > 0 && len(tokens) > 0 { - lowerFirstToken := 
strings.ToLower(tokens[0]) - for _, subCmd := range subcommands { - if lowerFirstToken == strings.ToLower(subCmd) { - return tokens[1:], nil // Skip the subcommand - } - } - } - - return tokens, nil -} - -// splitCommandTokens splits a command string into tokens, respecting quotes -func splitCommandTokens(command string) ([]string, error) { - var tokens []string - var current strings.Builder - inQuotes := false - quoteChar := byte(0) - escaped := false - - for i := 0; i < len(command); i++ { - c := command[i] - - if escaped { - current.WriteByte(c) - escaped = false - continue - } - - if c == '\\' { - escaped = true - current.WriteByte(c) - continue - } - - if !inQuotes && (c == '"' || c == '\'') { - inQuotes = true - quoteChar = c - current.WriteByte(c) - } else if inQuotes && c == quoteChar { - inQuotes = false - quoteChar = 0 - current.WriteByte(c) - } else if !inQuotes && (c == ' ' || c == '\t' || c == '\n') { - if current.Len() > 0 { - tokens = append(tokens, current.String()) - current.Reset() - } - } else { - current.WriteByte(c) - } - } - - if inQuotes { - return nil, errors.New("unterminated quoted string") - } - - if current.Len() > 0 { - tokens = append(tokens, current.String()) - } - - return tokens, nil -} - -// isFlag determines if a string is a command line flag or a value -// Handles the special case where negative numbers should be treated as values, not flags -func isFlag(arg string) bool { - if !strings.HasPrefix(arg, "-") { - return false - } - - // Special case: if it's a negative number, treat it as a value - if _, err := strconv.ParseFloat(arg, 64); err == nil { - return false - } - - return true -} - -// SliceHandling defines how []string fields should be handled when building command args -type SliceHandling int - -const ( - // SliceAsMultipleFlags creates multiple flags: --flag value1 --flag value2 - SliceAsMultipleFlags SliceHandling = iota - // SliceAsCommaSeparated creates single flag with comma-separated values: --flag value1,value2 - SliceAsCommaSeparated - // SliceAsMixed uses different strategies for different flags (requires configuration) - SliceAsMixed -) - -// ArgsBuilderConfig holds configuration for building command line arguments -type ArgsBuilderConfig struct { - // SliceHandling defines the default strategy for []string fields - SliceHandling SliceHandling - // MultipleFlags specifies which flags should use multiple instances when SliceHandling is SliceAsMixed - MultipleFlags map[string]struct{} -} - -// BuildCommandArgs converts a struct to command line arguments using reflection -func BuildCommandArgs(options any, config ArgsBuilderConfig) []string { +// BuildCommandArgs converts a struct to command line arguments +func BuildCommandArgs(options any, multipleFlags map[string]bool) []string { var args []string v := reflect.ValueOf(options).Elem() @@ -341,27 +54,19 @@ func BuildCommandArgs(options any, config ArgsBuilderConfig) []string { field := v.Field(i) fieldType := t.Field(i) - // Skip unexported fields if !field.CanInterface() { continue } - // Get the JSON tag to determine the flag name jsonTag := fieldType.Tag.Get("json") if jsonTag == "" || jsonTag == "-" { continue } - // Remove ",omitempty" from the tag - flagName := jsonTag - if commaIndex := strings.Index(jsonTag, ","); commaIndex != -1 { - flagName = jsonTag[:commaIndex] - } - - // Convert snake_case to kebab-case for CLI flags + // Get flag name from JSON tag + flagName := strings.Split(jsonTag, ",")[0] flagName = strings.ReplaceAll(flagName, "_", "-") - // Add the appropriate 
arguments based on field type and value switch field.Kind() { case reflect.Bool: if field.Bool() { @@ -380,8 +85,20 @@ func BuildCommandArgs(options any, config ArgsBuilderConfig) []string { args = append(args, "--"+flagName, field.String()) } case reflect.Slice: - if field.Type().Elem().Kind() == reflect.String { - args = append(args, handleStringSlice(field, flagName, config)...) + if field.Type().Elem().Kind() == reflect.String && field.Len() > 0 { + if multipleFlags[flagName] { + // Multiple flags: --flag value1 --flag value2 + for j := 0; j < field.Len(); j++ { + args = append(args, "--"+flagName, field.Index(j).String()) + } + } else { + // Comma-separated: --flag value1,value2 + var values []string + for j := 0; j < field.Len(); j++ { + values = append(values, field.Index(j).String()) + } + args = append(args, "--"+flagName, strings.Join(values, ",")) + } } } } @@ -389,43 +106,155 @@ func BuildCommandArgs(options any, config ArgsBuilderConfig) []string { return args } -// handleStringSlice handles []string fields based on the configuration -func handleStringSlice(field reflect.Value, flagName string, config ArgsBuilderConfig) []string { - if field.Len() == 0 { - return nil +// normalizeCommand handles multiline commands with backslashes +func normalizeCommand(command string) string { + re := regexp.MustCompile(`\\\s*\n\s*`) + normalized := re.ReplaceAllString(command, " ") + re = regexp.MustCompile(`\s+`) + return strings.TrimSpace(re.ReplaceAllString(normalized, " ")) +} + +// extractArgs extracts arguments from command, removing executable and subcommands +func extractArgs(command string, executableNames []string, subcommandNames []string) ([]string, error) { + // Check for unterminated quotes + if strings.Count(command, `"`)%2 != 0 || strings.Count(command, `'`)%2 != 0 { + return nil, fmt.Errorf("unterminated quoted string") } - var args []string + tokens := strings.Fields(command) + if len(tokens) == 0 { + return nil, fmt.Errorf("no tokens found") + } - switch config.SliceHandling { - case SliceAsMultipleFlags: - // Multiple flags: --flag value1 --flag value2 - for j := 0; j < field.Len(); j++ { - args = append(args, "--"+flagName, field.Index(j).String()) + // Skip executable + start := 0 + firstToken := tokens[0] + + // Check for executable name (with or without path) + if strings.Contains(firstToken, string(filepath.Separator)) { + baseName := filepath.Base(firstToken) + for _, execName := range executableNames { + if strings.HasSuffix(strings.ToLower(baseName), strings.ToLower(execName)) { + start = 1 + break + } } - case SliceAsCommaSeparated: - // Comma-separated: --flag value1,value2 - var values []string - for j := 0; j < field.Len(); j++ { - values = append(values, field.Index(j).String()) + } else { + for _, execName := range executableNames { + if strings.EqualFold(firstToken, execName) { + start = 1 + break + } } - args = append(args, "--"+flagName, strings.Join(values, ",")) - case SliceAsMixed: - // Check if this specific flag should use multiple instances - if _, useMultiple := config.MultipleFlags[flagName]; useMultiple { - // Multiple flags - for j := 0; j < field.Len(); j++ { - args = append(args, "--"+flagName, field.Index(j).String()) + } + + // Skip subcommand if present + if start < len(tokens) { + for _, subCmd := range subcommandNames { + if strings.EqualFold(tokens[start], subCmd) { + start++ + break + } + } + } + + // Handle case where command starts with subcommand (no executable) + if start == 0 { + for _, subCmd := range subcommandNames { + if 
strings.EqualFold(firstToken, subCmd) { + start = 1 + break + } + } + } + + return tokens[start:], nil +} + +// parseFlags parses command line flags into a map +func parseFlags(args []string, multiValuedFlags map[string]bool) (map[string]any, error) { + options := make(map[string]any) + + for i := 0; i < len(args); i++ { + arg := args[i] + + if !strings.HasPrefix(arg, "-") { + continue + } + + // Check for malformed flags (more than two leading dashes) + if strings.HasPrefix(arg, "---") { + return nil, fmt.Errorf("malformed flag: %s", arg) + } + + // Get flag name and value + var flagName, value string + var hasValue bool + + if strings.Contains(arg, "=") { + parts := strings.SplitN(arg, "=", 2) + flagName = strings.TrimLeft(parts[0], "-") + value = parts[1] + hasValue = true + } else { + flagName = strings.TrimLeft(arg, "-") + if i+1 < len(args) && !strings.HasPrefix(args[i+1], "-") { + value = args[i+1] + hasValue = true + i++ // Skip next arg since we consumed it + } + } + + // Convert kebab-case to snake_case for JSON + flagName = strings.ReplaceAll(flagName, "-", "_") + + if hasValue { + // Handle multi-valued flags + if multiValuedFlags[flagName] { + if existing, ok := options[flagName].([]string); ok { + options[flagName] = append(existing, value) + } else { + options[flagName] = []string{value} + } + } else { + options[flagName] = parseValue(value) } } else { - // Comma-separated - var values []string - for j := 0; j < field.Len(); j++ { - values = append(values, field.Index(j).String()) - } - args = append(args, "--"+flagName, strings.Join(values, ",")) + // Boolean flag + options[flagName] = true } } - return args + return options, nil +} + +// parseValue converts string to appropriate type +func parseValue(value string) any { + // Remove quotes + if len(value) >= 2 { + if (value[0] == '"' && value[len(value)-1] == '"') || (value[0] == '\'' && value[len(value)-1] == '\'') { + value = value[1 : len(value)-1] + } + } + + // Try boolean + switch strings.ToLower(value) { + case "true": + return true + case "false": + return false + } + + // Try integer + if intVal, err := strconv.Atoi(value); err == nil { + return intVal + } + + // Try float + if floatVal, err := strconv.ParseFloat(value, 64); err == nil { + return floatVal + } + + // Return as string + return value } diff --git a/pkg/backends/vllm/parser.go b/pkg/backends/vllm/parser.go index cce935d..5eb3fbf 100644 --- a/pkg/backends/vllm/parser.go +++ b/pkg/backends/vllm/parser.go @@ -12,22 +12,20 @@ import ( // 4. Args only: "--model MODEL_NAME --other-args" // 5. 
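
These helpers sit behind the exported backends.ParseCommand entry point that the per-backend parsers call below. A usage sketch with a made-up miniOptions struct and executable name, assuming ParseCommand keeps the JSON marshal/unmarshal bridge into the target struct:

    package main

    import (
        "fmt"

        "llamactl/pkg/backends"
    )

    // miniOptions is illustrative; kebab-case flags are mapped onto the
    // snake_case json tags by parseFlags.
    type miniOptions struct {
        Model     string  `json:"model,omitempty"`
        GPULayers int     `json:"gpu_layers,omitempty"`
        Temp      float64 `json:"temp,omitempty"`
        Verbose   bool    `json:"verbose,omitempty"`
    }

    func main() {
        var opts miniOptions
        err := backends.ParseCommand(
            "demo-server --model /models/demo.gguf --gpu-layers 16 --temp 0.7 --verbose",
            []string{"demo-server"}, // executable names to strip
            nil,                     // no subcommands
            map[string]bool{},       // no multi-valued flags
            &opts,
        )
        if err != nil {
            panic(err)
        }
        // parseValue inferred an int for --gpu-layers, a float64 for --temp,
        // and a boolean for the valueless --verbose flag.
        fmt.Printf("%+v\n", opts)
        // {Model:/models/demo.gguf GPULayers:16 Temp:0.7 Verbose:true}
    }
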
Multiline commands with backslashes func ParseVllmCommand(command string) (*VllmServerOptions, error) { - config := backends.CommandParserConfig{ - ExecutableNames: []string{"vllm"}, - SubcommandNames: []string{"serve"}, - MultiValuedFlags: map[string]struct{}{ - "middleware": {}, - "api_key": {}, - "allowed_origins": {}, - "allowed_methods": {}, - "allowed_headers": {}, - "lora_modules": {}, - "prompt_adapters": {}, - }, + executableNames := []string{"vllm"} + subcommandNames := []string{"serve"} + multiValuedFlags := map[string]bool{ + "middleware": true, + "api_key": true, + "allowed_origins": true, + "allowed_methods": true, + "allowed_headers": true, + "lora_modules": true, + "prompt_adapters": true, } var vllmOptions VllmServerOptions - if err := backends.ParseCommand(command, config, &vllmOptions); err != nil { + if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &vllmOptions); err != nil { return nil, err } diff --git a/pkg/backends/vllm/vllm.go b/pkg/backends/vllm/vllm.go index 81e567d..2ab6ed8 100644 --- a/pkg/backends/vllm/vllm.go +++ b/pkg/backends/vllm/vllm.go @@ -130,18 +130,15 @@ type VllmServerOptions struct { OverrideKVCacheALIGNSize int `json:"override_kv_cache_align_size,omitempty"` } -// BuildCommandArgs converts VllmServerOptions to command line arguments using the common builder +// BuildCommandArgs converts VllmServerOptions to command line arguments // Note: This does NOT include the "serve" subcommand, that's handled at the instance level func (o *VllmServerOptions) BuildCommandArgs() []string { - config := backends.ArgsBuilderConfig{ - SliceHandling: backends.SliceAsMixed, - MultipleFlags: map[string]struct{}{ - "api-key": {}, - "allowed-origins": {}, - "allowed-methods": {}, - "allowed-headers": {}, - "middleware": {}, - }, + multipleFlags := map[string]bool{ + "api-key": true, + "allowed-origins": true, + "allowed-methods": true, + "allowed-headers": true, + "middleware": true, } - return backends.BuildCommandArgs(o, config) + return backends.BuildCommandArgs(o, multipleFlags) } From 64842e74b07ef9ec8c750d559e5d224690d348bd Mon Sep 17 00:00:00 2001 From: LordMathis Date: Fri, 19 Sep 2025 20:23:25 +0200 Subject: [PATCH 06/20] Refactor command parsing and building --- pkg/backends/builder.go | 70 ++++++++++++++++ pkg/backends/llamacpp/llama.go | 28 +++++++ pkg/backends/llamacpp/llama_test.go | 115 +++++++++++++++++++++++++ pkg/backends/llamacpp/parser.go | 34 -------- pkg/backends/llamacpp/parser_test.go | 121 --------------------------- pkg/backends/mlx/mlx.go | 97 +++++---------------- pkg/backends/mlx/mlx_test.go | 95 +++++++++++++++++++++ pkg/backends/mlx/parser.go | 24 ------ pkg/backends/mlx/parser_test.go | 101 ---------------------- pkg/backends/parser.go | 64 -------------- pkg/backends/vllm/parser.go | 34 -------- pkg/backends/vllm/parser_test.go | 84 ------------------- pkg/backends/vllm/vllm.go | 28 +++++++ pkg/backends/vllm/vllm_test.go | 78 +++++++++++++++++ 14 files changed, 434 insertions(+), 539 deletions(-) create mode 100644 pkg/backends/builder.go delete mode 100644 pkg/backends/llamacpp/parser.go delete mode 100644 pkg/backends/llamacpp/parser_test.go delete mode 100644 pkg/backends/mlx/parser.go delete mode 100644 pkg/backends/mlx/parser_test.go delete mode 100644 pkg/backends/vllm/parser.go delete mode 100644 pkg/backends/vllm/parser_test.go diff --git a/pkg/backends/builder.go b/pkg/backends/builder.go new file mode 100644 index 0000000..23c3bb1 --- /dev/null +++ b/pkg/backends/builder.go @@ -0,0 
+1,70 @@ +package backends + +import ( + "reflect" + "strconv" + "strings" +) + +// BuildCommandArgs converts a struct to command line arguments +func BuildCommandArgs(options any, multipleFlags map[string]bool) []string { + var args []string + + v := reflect.ValueOf(options).Elem() + t := v.Type() + + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + fieldType := t.Field(i) + + if !field.CanInterface() { + continue + } + + jsonTag := fieldType.Tag.Get("json") + if jsonTag == "" || jsonTag == "-" { + continue + } + + // Get flag name from JSON tag + flagName := strings.Split(jsonTag, ",")[0] + flagName = strings.ReplaceAll(flagName, "_", "-") + + switch field.Kind() { + case reflect.Bool: + if field.Bool() { + args = append(args, "--"+flagName) + } + case reflect.Int: + if field.Int() != 0 { + args = append(args, "--"+flagName, strconv.FormatInt(field.Int(), 10)) + } + case reflect.Float64: + if field.Float() != 0 { + args = append(args, "--"+flagName, strconv.FormatFloat(field.Float(), 'f', -1, 64)) + } + case reflect.String: + if field.String() != "" { + args = append(args, "--"+flagName, field.String()) + } + case reflect.Slice: + if field.Type().Elem().Kind() == reflect.String && field.Len() > 0 { + if multipleFlags[flagName] { + // Multiple flags: --flag value1 --flag value2 + for j := 0; j < field.Len(); j++ { + args = append(args, "--"+flagName, field.Index(j).String()) + } + } else { + // Comma-separated: --flag value1,value2 + var values []string + for j := 0; j < field.Len(); j++ { + values = append(values, field.Index(j).String()) + } + args = append(args, "--"+flagName, strings.Join(values, ",")) + } + } + } + } + + return args +} diff --git a/pkg/backends/llamacpp/llama.go b/pkg/backends/llamacpp/llama.go index 7c8a21f..f2a7d31 100644 --- a/pkg/backends/llamacpp/llama.go +++ b/pkg/backends/llamacpp/llama.go @@ -328,3 +328,31 @@ func (o *LlamaServerOptions) BuildCommandArgs() []string { } return backends.BuildCommandArgs(o, multipleFlags) } + +// ParseLlamaCommand parses a llama-server command string into LlamaServerOptions +// Supports multiple formats: +// 1. Full command: "llama-server --model file.gguf" +// 2. Full path: "/usr/local/bin/llama-server --model file.gguf" +// 3. Args only: "--model file.gguf --gpu-layers 32" +// 4. 
Multiline commands with backslashes +func ParseLlamaCommand(command string) (*LlamaServerOptions, error) { + executableNames := []string{"llama-server"} + var subcommandNames []string // Llama has no subcommands + multiValuedFlags := map[string]bool{ + "override_tensor": true, + "override_kv": true, + "lora": true, + "lora_scaled": true, + "control_vector": true, + "control_vector_scaled": true, + "dry_sequence_breaker": true, + "logit_bias": true, + } + + var llamaOptions LlamaServerOptions + if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &llamaOptions); err != nil { + return nil, err + } + + return &llamaOptions, nil +} diff --git a/pkg/backends/llamacpp/llama_test.go b/pkg/backends/llamacpp/llama_test.go index 9c1162e..c779320 100644 --- a/pkg/backends/llamacpp/llama_test.go +++ b/pkg/backends/llamacpp/llama_test.go @@ -378,6 +378,121 @@ func TestUnmarshalJSON_ArrayFields(t *testing.T) { } } +func TestParseLlamaCommand(t *testing.T) { + tests := []struct { + name string + command string + expectErr bool + }{ + { + name: "basic command", + command: "llama-server --model /path/to/model.gguf --gpu-layers 32", + expectErr: false, + }, + { + name: "args only", + command: "--model /path/to/model.gguf --ctx-size 4096", + expectErr: false, + }, + { + name: "mixed flag formats", + command: "llama-server --model=/path/model.gguf --gpu-layers 16 --verbose", + expectErr: false, + }, + { + name: "quoted strings", + command: `llama-server --model test.gguf --api-key "sk-1234567890abcdef"`, + expectErr: false, + }, + { + name: "empty command", + command: "", + expectErr: true, + }, + { + name: "unterminated quote", + command: `llama-server --model test.gguf --api-key "unterminated`, + expectErr: true, + }, + { + name: "malformed flag", + command: "llama-server ---model test.gguf", + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := llamacpp.ParseLlamaCommand(tt.command) + + if tt.expectErr { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if result == nil { + t.Errorf("expected result but got nil") + } + }) + } +} + +func TestParseLlamaCommandValues(t *testing.T) { + command := "llama-server --model /test/model.gguf --gpu-layers 32 --temp 0.7 --verbose --no-mmap" + result, err := llamacpp.ParseLlamaCommand(command) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Model != "/test/model.gguf" { + t.Errorf("expected model '/test/model.gguf', got '%s'", result.Model) + } + + if result.GPULayers != 32 { + t.Errorf("expected gpu_layers 32, got %d", result.GPULayers) + } + + if result.Temperature != 0.7 { + t.Errorf("expected temperature 0.7, got %f", result.Temperature) + } + + if !result.Verbose { + t.Errorf("expected verbose to be true") + } + + if !result.NoMmap { + t.Errorf("expected no_mmap to be true") + } +} + +func TestParseLlamaCommandArrays(t *testing.T) { + command := "llama-server --model test.gguf --lora adapter1.bin --lora=adapter2.bin" + result, err := llamacpp.ParseLlamaCommand(command) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Lora) != 2 { + t.Errorf("expected 2 lora adapters, got %d", len(result.Lora)) + } + + expected := []string{"adapter1.bin", "adapter2.bin"} + for i, v := range expected { + if result.Lora[i] != v { + t.Errorf("expected lora[%d]=%s got %s", i, v, result.Lora[i]) + } + } +} + // Helper 
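
Since parsing and building share the same options struct, a pasted command can be round-tripped: parse it, adjust a field, and rebuild the argument list. A short sketch using the fields exercised by the tests above (output order depends on the struct's field order):

    package main

    import (
        "fmt"

        "llamactl/pkg/backends/llamacpp"
    )

    func main() {
        // Parse an existing invocation, adjust one option, rebuild the args.
        opts, err := llamacpp.ParseLlamaCommand(
            "llama-server --model /models/llama.gguf --gpu-layers 32")
        if err != nil {
            panic(err)
        }
        opts.GPULayers = 48
        fmt.Println(opts.BuildCommandArgs())
        // e.g. [--model /models/llama.gguf --gpu-layers 48]
    }
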
functions func contains(slice []string, item string) bool { return slices.Contains(slice, item) diff --git a/pkg/backends/llamacpp/parser.go b/pkg/backends/llamacpp/parser.go deleted file mode 100644 index b5b850a..0000000 --- a/pkg/backends/llamacpp/parser.go +++ /dev/null @@ -1,34 +0,0 @@ -package llamacpp - -import ( - "llamactl/pkg/backends" -) - -// ParseLlamaCommand parses a llama-server command string into LlamaServerOptions -// Supports multiple formats: -// 1. Full command: "llama-server --model file.gguf" -// 2. Full path: "/usr/local/bin/llama-server --model file.gguf" -// 3. Args only: "--model file.gguf --gpu-layers 32" -// 4. Multiline commands with backslashes -func ParseLlamaCommand(command string) (*LlamaServerOptions, error) { - executableNames := []string{"llama-server"} - var subcommandNames []string // Llama has no subcommands - multiValuedFlags := map[string]bool{ - "override_tensor": true, - "override_kv": true, - "lora": true, - "lora_scaled": true, - "control_vector": true, - "control_vector_scaled": true, - "dry_sequence_breaker": true, - "logit_bias": true, - } - - var llamaOptions LlamaServerOptions - if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &llamaOptions); err != nil { - return nil, err - } - - return &llamaOptions, nil -} - diff --git a/pkg/backends/llamacpp/parser_test.go b/pkg/backends/llamacpp/parser_test.go deleted file mode 100644 index 7072f65..0000000 --- a/pkg/backends/llamacpp/parser_test.go +++ /dev/null @@ -1,121 +0,0 @@ -package llamacpp_test - -import ( - "llamactl/pkg/backends/llamacpp" - "testing" -) - -func TestParseLlamaCommand(t *testing.T) { - tests := []struct { - name string - command string - expectErr bool - }{ - { - name: "basic command", - command: "llama-server --model /path/to/model.gguf --gpu-layers 32", - expectErr: false, - }, - { - name: "args only", - command: "--model /path/to/model.gguf --ctx-size 4096", - expectErr: false, - }, - { - name: "mixed flag formats", - command: "llama-server --model=/path/model.gguf --gpu-layers 16 --verbose", - expectErr: false, - }, - { - name: "quoted strings", - command: `llama-server --model test.gguf --api-key "sk-1234567890abcdef"`, - expectErr: false, - }, - { - name: "empty command", - command: "", - expectErr: true, - }, - { - name: "unterminated quote", - command: `llama-server --model test.gguf --api-key "unterminated`, - expectErr: true, - }, - { - name: "malformed flag", - command: "llama-server ---model test.gguf", - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := llamacpp.ParseLlamaCommand(tt.command) - - if tt.expectErr { - if err == nil { - t.Errorf("expected error but got none") - } - return - } - - if err != nil { - t.Errorf("unexpected error: %v", err) - return - } - - if result == nil { - t.Errorf("expected result but got nil") - } - }) - } -} - -func TestParseLlamaCommandValues(t *testing.T) { - command := "llama-server --model /test/model.gguf --gpu-layers 32 --temp 0.7 --verbose --no-mmap" - result, err := llamacpp.ParseLlamaCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Model != "/test/model.gguf" { - t.Errorf("expected model '/test/model.gguf', got '%s'", result.Model) - } - - if result.GPULayers != 32 { - t.Errorf("expected gpu_layers 32, got %d", result.GPULayers) - } - - if result.Temperature != 0.7 { - t.Errorf("expected temperature 0.7, got %f", result.Temperature) - } - - if !result.Verbose { - 
t.Errorf("expected verbose to be true") - } - - if !result.NoMmap { - t.Errorf("expected no_mmap to be true") - } -} - -func TestParseLlamaCommandArrays(t *testing.T) { - command := "llama-server --model test.gguf --lora adapter1.bin --lora=adapter2.bin" - result, err := llamacpp.ParseLlamaCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(result.Lora) != 2 { - t.Errorf("expected 2 lora adapters, got %d", len(result.Lora)) - } - - expected := []string{"adapter1.bin", "adapter2.bin"} - for i, v := range expected { - if result.Lora[i] != v { - t.Errorf("expected lora[%d]=%s got %s", i, v, result.Lora[i]) - } - } -} diff --git a/pkg/backends/mlx/mlx.go b/pkg/backends/mlx/mlx.go index c72597c..3b83681 100644 --- a/pkg/backends/mlx/mlx.go +++ b/pkg/backends/mlx/mlx.go @@ -1,10 +1,7 @@ package mlx import ( - "encoding/json" "llamactl/pkg/backends" - "reflect" - "strconv" ) type MlxServerOptions struct { @@ -26,88 +23,34 @@ type MlxServerOptions struct { ChatTemplateArgs string `json:"chat_template_args,omitempty"` // JSON string // Sampling defaults - Temp float64 `json:"temp,omitempty"` // Note: MLX uses "temp" not "temperature" + Temp float64 `json:"temp,omitempty"` TopP float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` MinP float64 `json:"min_p,omitempty"` MaxTokens int `json:"max_tokens,omitempty"` } -// UnmarshalJSON implements custom JSON unmarshaling to support multiple field names -func (o *MlxServerOptions) UnmarshalJSON(data []byte) error { - // First unmarshal into a map to handle multiple field names - var raw map[string]any - if err := json.Unmarshal(data, &raw); err != nil { - return err - } - - // Create a temporary struct for standard unmarshaling - type tempOptions MlxServerOptions - temp := tempOptions{} - - // Standard unmarshal first - if err := json.Unmarshal(data, &temp); err != nil { - return err - } - - // Copy to our struct - *o = MlxServerOptions(temp) - - // Handle alternative field names - fieldMappings := map[string]string{ - "m": "model", // -m, --model - "temperature": "temp", // --temperature vs --temp - "top_k": "top_k", // --top-k - "adapter_path": "adapter_path", // --adapter-path - } - - // Process alternative field names - for altName, canonicalName := range fieldMappings { - if value, exists := raw[altName]; exists { - // Use reflection to set the field value - v := reflect.ValueOf(o).Elem() - field := v.FieldByNameFunc(func(fieldName string) bool { - field, _ := v.Type().FieldByName(fieldName) - jsonTag := field.Tag.Get("json") - return jsonTag == canonicalName+",omitempty" || jsonTag == canonicalName - }) - - if field.IsValid() && field.CanSet() { - switch field.Kind() { - case reflect.Int: - if intVal, ok := value.(float64); ok { - field.SetInt(int64(intVal)) - } else if strVal, ok := value.(string); ok { - if intVal, err := strconv.Atoi(strVal); err == nil { - field.SetInt(int64(intVal)) - } - } - case reflect.Float64: - if floatVal, ok := value.(float64); ok { - field.SetFloat(floatVal) - } else if strVal, ok := value.(string); ok { - if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil { - field.SetFloat(floatVal) - } - } - case reflect.String: - if strVal, ok := value.(string); ok { - field.SetString(strVal) - } - case reflect.Bool: - if boolVal, ok := value.(bool); ok { - field.SetBool(boolVal) - } - } - } - } - } - - return nil -} - // BuildCommandArgs converts to command line arguments func (o *MlxServerOptions) BuildCommandArgs() []string { multipleFlags := map[string]bool{} // 
MLX doesn't currently have []string fields return backends.BuildCommandArgs(o, multipleFlags) } + +// ParseMlxCommand parses a mlx_lm.server command string into MlxServerOptions +// Supports multiple formats: +// 1. Full command: "mlx_lm.server --model model/path" +// 2. Full path: "/usr/local/bin/mlx_lm.server --model model/path" +// 3. Args only: "--model model/path --host 0.0.0.0" +// 4. Multiline commands with backslashes +func ParseMlxCommand(command string) (*MlxServerOptions, error) { + executableNames := []string{"mlx_lm.server"} + var subcommandNames []string // MLX has no subcommands + multiValuedFlags := map[string]bool{} // MLX has no multi-valued flags + + var mlxOptions MlxServerOptions + if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &mlxOptions); err != nil { + return nil, err + } + + return &mlxOptions, nil +} diff --git a/pkg/backends/mlx/mlx_test.go b/pkg/backends/mlx/mlx_test.go index b35f512..8baeb5c 100644 --- a/pkg/backends/mlx/mlx_test.go +++ b/pkg/backends/mlx/mlx_test.go @@ -5,6 +5,101 @@ import ( "testing" ) +func TestParseMlxCommand(t *testing.T) { + tests := []struct { + name string + command string + expectErr bool + }{ + { + name: "basic command", + command: "mlx_lm.server --model /path/to/model --host 0.0.0.0", + expectErr: false, + }, + { + name: "args only", + command: "--model /path/to/model --port 8080", + expectErr: false, + }, + { + name: "mixed flag formats", + command: "mlx_lm.server --model=/path/model --temp=0.7 --trust-remote-code", + expectErr: false, + }, + { + name: "quoted strings", + command: `mlx_lm.server --model test.mlx --chat-template "User: {user}\nAssistant: "`, + expectErr: false, + }, + { + name: "empty command", + command: "", + expectErr: true, + }, + { + name: "unterminated quote", + command: `mlx_lm.server --model test.mlx --chat-template "unterminated`, + expectErr: true, + }, + { + name: "malformed flag", + command: "mlx_lm.server ---model test.mlx", + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := mlx.ParseMlxCommand(tt.command) + + if tt.expectErr { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if result == nil { + t.Errorf("expected result but got nil") + } + }) + } +} + +func TestParseMlxCommandValues(t *testing.T) { + command := "mlx_lm.server --model /test/model.mlx --port 8080 --temp 0.7 --trust-remote-code --log-level DEBUG" + result, err := mlx.ParseMlxCommand(command) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Model != "/test/model.mlx" { + t.Errorf("expected model '/test/model.mlx', got '%s'", result.Model) + } + + if result.Port != 8080 { + t.Errorf("expected port 8080, got %d", result.Port) + } + + if result.Temp != 0.7 { + t.Errorf("expected temp 0.7, got %f", result.Temp) + } + + if !result.TrustRemoteCode { + t.Errorf("expected trust_remote_code to be true") + } + + if result.LogLevel != "DEBUG" { + t.Errorf("expected log_level 'DEBUG', got '%s'", result.LogLevel) + } +} + func TestBuildCommandArgs(t *testing.T) { options := &mlx.MlxServerOptions{ Model: "/test/model.mlx", diff --git a/pkg/backends/mlx/parser.go b/pkg/backends/mlx/parser.go deleted file mode 100644 index ec4cfb2..0000000 --- a/pkg/backends/mlx/parser.go +++ /dev/null @@ -1,24 +0,0 @@ -package mlx - -import ( - "llamactl/pkg/backends" -) - -// ParseMlxCommand parses a mlx_lm.server command string 
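
One behavioral consequence of dropping the custom UnmarshalJSON above is that encoding/json now matches only the canonical tags, so alias keys such as "temperature" or "m" are silently ignored. A quick sketch:

    package main

    import (
        "encoding/json"
        "fmt"

        "llamactl/pkg/backends/mlx"
    )

    func main() {
        var opts mlx.MlxServerOptions
        data := []byte(`{"temp": 0.7, "temperature": 0.9}`)
        if err := json.Unmarshal(data, &opts); err != nil {
            panic(err)
        }
        fmt.Println(opts.Temp) // prints 0.7; the "temperature" alias no longer maps
    }
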
into MlxServerOptions -// Supports multiple formats: -// 1. Full command: "mlx_lm.server --model model/path" -// 2. Full path: "/usr/local/bin/mlx_lm.server --model model/path" -// 3. Args only: "--model model/path --host 0.0.0.0" -// 4. Multiline commands with backslashes -func ParseMlxCommand(command string) (*MlxServerOptions, error) { - executableNames := []string{"mlx_lm.server"} - var subcommandNames []string // MLX has no subcommands - multiValuedFlags := map[string]bool{} // MLX has no multi-valued flags - - var mlxOptions MlxServerOptions - if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &mlxOptions); err != nil { - return nil, err - } - - return &mlxOptions, nil -} diff --git a/pkg/backends/mlx/parser_test.go b/pkg/backends/mlx/parser_test.go deleted file mode 100644 index 6caae84..0000000 --- a/pkg/backends/mlx/parser_test.go +++ /dev/null @@ -1,101 +0,0 @@ -package mlx_test - -import ( - "llamactl/pkg/backends/mlx" - "testing" -) - -func TestParseMlxCommand(t *testing.T) { - tests := []struct { - name string - command string - expectErr bool - }{ - { - name: "basic command", - command: "mlx_lm.server --model /path/to/model --host 0.0.0.0", - expectErr: false, - }, - { - name: "args only", - command: "--model /path/to/model --port 8080", - expectErr: false, - }, - { - name: "mixed flag formats", - command: "mlx_lm.server --model=/path/model --temp=0.7 --trust-remote-code", - expectErr: false, - }, - { - name: "quoted strings", - command: `mlx_lm.server --model test.mlx --chat-template "User: {user}\nAssistant: "`, - expectErr: false, - }, - { - name: "empty command", - command: "", - expectErr: true, - }, - { - name: "unterminated quote", - command: `mlx_lm.server --model test.mlx --chat-template "unterminated`, - expectErr: true, - }, - { - name: "malformed flag", - command: "mlx_lm.server ---model test.mlx", - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := mlx.ParseMlxCommand(tt.command) - - if tt.expectErr { - if err == nil { - t.Errorf("expected error but got none") - } - return - } - - if err != nil { - t.Errorf("unexpected error: %v", err) - return - } - - if result == nil { - t.Errorf("expected result but got nil") - } - }) - } -} - -func TestParseMlxCommandValues(t *testing.T) { - command := "mlx_lm.server --model /test/model.mlx --port 8080 --temp 0.7 --trust-remote-code --log-level DEBUG" - result, err := mlx.ParseMlxCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Model != "/test/model.mlx" { - t.Errorf("expected model '/test/model.mlx', got '%s'", result.Model) - } - - if result.Port != 8080 { - t.Errorf("expected port 8080, got %d", result.Port) - } - - if result.Temp != 0.7 { - t.Errorf("expected temp 0.7, got %f", result.Temp) - } - - if !result.TrustRemoteCode { - t.Errorf("expected trust_remote_code to be true") - } - - if result.LogLevel != "DEBUG" { - t.Errorf("expected log_level 'DEBUG', got '%s'", result.LogLevel) - } -} diff --git a/pkg/backends/parser.go b/pkg/backends/parser.go index 721173f..89ad46e 100644 --- a/pkg/backends/parser.go +++ b/pkg/backends/parser.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" "path/filepath" - "reflect" "regexp" "strconv" "strings" @@ -43,69 +42,6 @@ func ParseCommand(command string, executableNames []string, subcommandNames []st return nil } -// BuildCommandArgs converts a struct to command line arguments -func BuildCommandArgs(options any, multipleFlags 
map[string]bool) []string { - var args []string - - v := reflect.ValueOf(options).Elem() - t := v.Type() - - for i := 0; i < v.NumField(); i++ { - field := v.Field(i) - fieldType := t.Field(i) - - if !field.CanInterface() { - continue - } - - jsonTag := fieldType.Tag.Get("json") - if jsonTag == "" || jsonTag == "-" { - continue - } - - // Get flag name from JSON tag - flagName := strings.Split(jsonTag, ",")[0] - flagName = strings.ReplaceAll(flagName, "_", "-") - - switch field.Kind() { - case reflect.Bool: - if field.Bool() { - args = append(args, "--"+flagName) - } - case reflect.Int: - if field.Int() != 0 { - args = append(args, "--"+flagName, strconv.FormatInt(field.Int(), 10)) - } - case reflect.Float64: - if field.Float() != 0 { - args = append(args, "--"+flagName, strconv.FormatFloat(field.Float(), 'f', -1, 64)) - } - case reflect.String: - if field.String() != "" { - args = append(args, "--"+flagName, field.String()) - } - case reflect.Slice: - if field.Type().Elem().Kind() == reflect.String && field.Len() > 0 { - if multipleFlags[flagName] { - // Multiple flags: --flag value1 --flag value2 - for j := 0; j < field.Len(); j++ { - args = append(args, "--"+flagName, field.Index(j).String()) - } - } else { - // Comma-separated: --flag value1,value2 - var values []string - for j := 0; j < field.Len(); j++ { - values = append(values, field.Index(j).String()) - } - args = append(args, "--"+flagName, strings.Join(values, ",")) - } - } - } - } - - return args -} - // normalizeCommand handles multiline commands with backslashes func normalizeCommand(command string) string { re := regexp.MustCompile(`\\\s*\n\s*`) diff --git a/pkg/backends/vllm/parser.go b/pkg/backends/vllm/parser.go deleted file mode 100644 index 5eb3fbf..0000000 --- a/pkg/backends/vllm/parser.go +++ /dev/null @@ -1,34 +0,0 @@ -package vllm - -import ( - "llamactl/pkg/backends" -) - -// ParseVllmCommand parses a vLLM serve command string into VllmServerOptions -// Supports multiple formats: -// 1. Full command: "vllm serve --model MODEL_NAME --other-args" -// 2. Full path: "/usr/local/bin/vllm serve --model MODEL_NAME" -// 3. Serve only: "serve --model MODEL_NAME --other-args" -// 4. Args only: "--model MODEL_NAME --other-args" -// 5. 
Multiline commands with backslashes -func ParseVllmCommand(command string) (*VllmServerOptions, error) { - executableNames := []string{"vllm"} - subcommandNames := []string{"serve"} - multiValuedFlags := map[string]bool{ - "middleware": true, - "api_key": true, - "allowed_origins": true, - "allowed_methods": true, - "allowed_headers": true, - "lora_modules": true, - "prompt_adapters": true, - } - - var vllmOptions VllmServerOptions - if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &vllmOptions); err != nil { - return nil, err - } - - return &vllmOptions, nil -} - diff --git a/pkg/backends/vllm/parser_test.go b/pkg/backends/vllm/parser_test.go deleted file mode 100644 index 3a12456..0000000 --- a/pkg/backends/vllm/parser_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package vllm_test - -import ( - "llamactl/pkg/backends/vllm" - "testing" -) - -func TestParseVllmCommand(t *testing.T) { - tests := []struct { - name string - command string - expectErr bool - }{ - { - name: "basic vllm serve command", - command: "vllm serve --model microsoft/DialoGPT-medium", - expectErr: false, - }, - { - name: "serve only command", - command: "serve --model microsoft/DialoGPT-medium", - expectErr: false, - }, - { - name: "args only", - command: "--model microsoft/DialoGPT-medium --tensor-parallel-size 2", - expectErr: false, - }, - { - name: "empty command", - command: "", - expectErr: true, - }, - { - name: "unterminated quote", - command: `vllm serve --model "unterminated`, - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := vllm.ParseVllmCommand(tt.command) - - if tt.expectErr { - if err == nil { - t.Errorf("expected error but got none") - } - return - } - - if err != nil { - t.Errorf("unexpected error: %v", err) - return - } - - if result == nil { - t.Errorf("expected result but got nil") - } - }) - } -} - -func TestParseVllmCommandValues(t *testing.T) { - command := "vllm serve --model test-model --tensor-parallel-size 4 --gpu-memory-utilization 0.8 --enable-log-outputs" - result, err := vllm.ParseVllmCommand(command) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if result.Model != "test-model" { - t.Errorf("expected model 'test-model', got '%s'", result.Model) - } - if result.TensorParallelSize != 4 { - t.Errorf("expected tensor_parallel_size 4, got %d", result.TensorParallelSize) - } - if result.GPUMemoryUtilization != 0.8 { - t.Errorf("expected gpu_memory_utilization 0.8, got %f", result.GPUMemoryUtilization) - } - if !result.EnableLogOutputs { - t.Errorf("expected enable_log_outputs true, got %v", result.EnableLogOutputs) - } -} diff --git a/pkg/backends/vllm/vllm.go b/pkg/backends/vllm/vllm.go index 2ab6ed8..df080ea 100644 --- a/pkg/backends/vllm/vllm.go +++ b/pkg/backends/vllm/vllm.go @@ -142,3 +142,31 @@ func (o *VllmServerOptions) BuildCommandArgs() []string { } return backends.BuildCommandArgs(o, multipleFlags) } + +// ParseVllmCommand parses a vLLM serve command string into VllmServerOptions +// Supports multiple formats: +// 1. Full command: "vllm serve --model MODEL_NAME --other-args" +// 2. Full path: "/usr/local/bin/vllm serve --model MODEL_NAME" +// 3. Serve only: "serve --model MODEL_NAME --other-args" +// 4. Args only: "--model MODEL_NAME --other-args" +// 5. 
Multiline commands with backslashes +func ParseVllmCommand(command string) (*VllmServerOptions, error) { + executableNames := []string{"vllm"} + subcommandNames := []string{"serve"} + multiValuedFlags := map[string]bool{ + "middleware": true, + "api_key": true, + "allowed_origins": true, + "allowed_methods": true, + "allowed_headers": true, + "lora_modules": true, + "prompt_adapters": true, + } + + var vllmOptions VllmServerOptions + if err := backends.ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &vllmOptions); err != nil { + return nil, err + } + + return &vllmOptions, nil +} diff --git a/pkg/backends/vllm/vllm_test.go b/pkg/backends/vllm/vllm_test.go index 8a42862..40423b9 100644 --- a/pkg/backends/vllm/vllm_test.go +++ b/pkg/backends/vllm/vllm_test.go @@ -7,6 +7,84 @@ import ( "testing" ) +func TestParseVllmCommand(t *testing.T) { + tests := []struct { + name string + command string + expectErr bool + }{ + { + name: "basic vllm serve command", + command: "vllm serve --model microsoft/DialoGPT-medium", + expectErr: false, + }, + { + name: "serve only command", + command: "serve --model microsoft/DialoGPT-medium", + expectErr: false, + }, + { + name: "args only", + command: "--model microsoft/DialoGPT-medium --tensor-parallel-size 2", + expectErr: false, + }, + { + name: "empty command", + command: "", + expectErr: true, + }, + { + name: "unterminated quote", + command: `vllm serve --model "unterminated`, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := vllm.ParseVllmCommand(tt.command) + + if tt.expectErr { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if result == nil { + t.Errorf("expected result but got nil") + } + }) + } +} + +func TestParseVllmCommandValues(t *testing.T) { + command := "vllm serve --model test-model --tensor-parallel-size 4 --gpu-memory-utilization 0.8 --enable-log-outputs" + result, err := vllm.ParseVllmCommand(command) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Model != "test-model" { + t.Errorf("expected model 'test-model', got '%s'", result.Model) + } + if result.TensorParallelSize != 4 { + t.Errorf("expected tensor_parallel_size 4, got %d", result.TensorParallelSize) + } + if result.GPUMemoryUtilization != 0.8 { + t.Errorf("expected gpu_memory_utilization 0.8, got %f", result.GPUMemoryUtilization) + } + if !result.EnableLogOutputs { + t.Errorf("expected enable_log_outputs true, got %v", result.EnableLogOutputs) + } +} + func TestBuildCommandArgs(t *testing.T) { options := vllm.VllmServerOptions{ Model: "microsoft/DialoGPT-medium", From 7eb59aa7e059a3af3a56d62cc0499bd983e2db23 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Fri, 19 Sep 2025 20:46:25 +0200 Subject: [PATCH 07/20] Remove unused JSON unmarshal test and clean up command argument checks --- pkg/backends/vllm/vllm_test.go | 32 -------------------------------- 1 file changed, 32 deletions(-) diff --git a/pkg/backends/vllm/vllm_test.go b/pkg/backends/vllm/vllm_test.go index 40423b9..3f01ff9 100644 --- a/pkg/backends/vllm/vllm_test.go +++ b/pkg/backends/vllm/vllm_test.go @@ -1,7 +1,6 @@ package vllm_test import ( - "encoding/json" "llamactl/pkg/backends/vllm" "slices" "testing" @@ -108,8 +107,6 @@ func TestBuildCommandArgs(t *testing.T) { if !contains(args, "--enable-log-outputs") { t.Errorf("Expected --enable-log-outputs not found in %v", args) } - - // Host and port 
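
As the vllm.go comment notes, BuildCommandArgs leaves out the "serve" subcommand, so the caller prepends it when launching the process. An illustrative sketch; the exec wiring here is an assumption, not the actual instance-level code:

    package main

    import (
        "fmt"
        "os/exec"

        "llamactl/pkg/backends/vllm"
    )

    func main() {
        opts, err := vllm.ParseVllmCommand(
            "vllm serve --model test-model --tensor-parallel-size 2")
        if err != nil {
            panic(err)
        }
        // Re-attach the subcommand that BuildCommandArgs intentionally omits.
        args := append([]string{"serve"}, opts.BuildCommandArgs()...)
        cmd := exec.Command("vllm", args...)
        fmt.Println(cmd.String())
        // e.g. /usr/bin/vllm serve --model test-model --tensor-parallel-size 2
    }
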
should NOT be in the arguments (handled by llamactl) if !contains(args, "--host") { t.Errorf("Expected --host not found in %v", args) } @@ -129,35 +126,6 @@ func TestBuildCommandArgs(t *testing.T) { } } -func TestUnmarshalJSON(t *testing.T) { - // Test both underscore and dash formats - jsonData := `{ - "model": "test-model", - "tensor_parallel_size": 4, - "gpu-memory-utilization": 0.9, - "enable-log-outputs": true - }` - - var options vllm.VllmServerOptions - err := json.Unmarshal([]byte(jsonData), &options) - if err != nil { - t.Fatalf("Unmarshal failed: %v", err) - } - - if options.Model != "test-model" { - t.Errorf("Expected model 'test-model', got %q", options.Model) - } - if options.TensorParallelSize != 4 { - t.Errorf("Expected tensor_parallel_size 4, got %d", options.TensorParallelSize) - } - if options.GPUMemoryUtilization != 0.9 { - t.Errorf("Expected gpu_memory_utilization 0.9, got %f", options.GPUMemoryUtilization) - } - if !options.EnableLogOutputs { - t.Errorf("Expected enable_log_outputs true, got %v", options.EnableLogOutputs) - } -} - // Helper functions func contains(slice []string, item string) bool { return slices.Contains(slice, item) From b66519430754fc6c83e522d8e702ed19c93de587 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sun, 21 Sep 2025 20:58:43 +0200 Subject: [PATCH 08/20] Add vLLM backend support to webui --- webui/src/components/ParseCommandDialog.tsx | 63 ++++- webui/src/components/ZodFormField.tsx | 1 + webui/src/lib/api.ts | 8 + webui/src/lib/zodFormUtils.ts | 59 ++++- webui/src/schemas/backends/index.ts | 4 + webui/src/schemas/backends/llamacpp.ts | 192 ++++++++++++++ webui/src/schemas/backends/mlx.ts | 51 ++++ webui/src/schemas/backends/vllm.ts | 150 +++++++++++ webui/src/schemas/instanceOptions.ts | 274 +++----------------- webui/src/types/instance.ts | 1 + 10 files changed, 545 insertions(+), 258 deletions(-) create mode 100644 webui/src/schemas/backends/index.ts create mode 100644 webui/src/schemas/backends/llamacpp.ts create mode 100644 webui/src/schemas/backends/mlx.ts create mode 100644 webui/src/schemas/backends/vllm.ts diff --git a/webui/src/components/ParseCommandDialog.tsx b/webui/src/components/ParseCommandDialog.tsx index 6b14eaa..fcf79e6 100644 --- a/webui/src/components/ParseCommandDialog.tsx +++ b/webui/src/components/ParseCommandDialog.tsx @@ -9,7 +9,7 @@ import { DialogHeader, DialogTitle, } from "@/components/ui/dialog"; -import { type CreateInstanceOptions } from "@/types/instance"; +import { BackendType, type BackendTypeValue, type CreateInstanceOptions } from "@/types/instance"; import { backendsApi } from "@/lib/api"; import { toast } from "sonner"; @@ -25,6 +25,7 @@ const ParseCommandDialog: React.FC = ({ onParsed, }) => { const [command, setCommand] = useState(''); + const [backendType, setBackendType] = useState(BackendType.LLAMA_CPP); const [loading, setLoading] = useState(false); const [error, setError] = useState(null); @@ -38,18 +39,31 @@ const ParseCommandDialog: React.FC = ({ setError(null); try { - const options = await backendsApi.llamaCpp.parseCommand(command); + let options: CreateInstanceOptions; + + // Parse based on selected backend type + switch (backendType) { + case BackendType.LLAMA_CPP: + options = await backendsApi.llamaCpp.parseCommand(command); + break; + case BackendType.MLX_LM: + options = await backendsApi.mlx.parseCommand(command); + break; + case BackendType.VLLM: + options = await backendsApi.vllm.parseCommand(command); + break; + default: + throw new Error(`Unsupported backend type: ${backendType}`); + 
} + onParsed(options); onOpenChange(false); - // Reset form setCommand(''); setError(null); - // Show success toast toast.success('Command parsed successfully'); } catch (err) { const errorMessage = err instanceof Error ? err.message : 'Failed to parse command'; setError(errorMessage); - // Show error toast toast.error('Failed to parse command', { description: errorMessage }); @@ -60,35 +74,62 @@ const ParseCommandDialog: React.FC = ({ const handleOpenChange = (open: boolean) => { if (!open) { - // Reset form when closing setCommand(''); + setBackendType(BackendType.LLAMA_CPP); setError(null); } onOpenChange(open); }; + const getPlaceholderForBackend = (backendType: BackendTypeValue): string => { + switch (backendType) { + case BackendType.LLAMA_CPP: + return "llama-server --model /path/to/model.gguf --gpu-layers 32 --ctx-size 4096"; + case BackendType.MLX_LM: + return "mlx_lm.server --model mlx-community/Mistral-7B-Instruct-v0.3-4bit --host 0.0.0.0 --port 8080"; + case BackendType.VLLM: + return "vllm serve --model microsoft/DialoGPT-medium --tensor-parallel-size 2 --gpu-memory-utilization 0.9"; + default: + return "Enter your command here..."; + } + }; + return ( - Parse Llama Server Command + Parse Backend Command - Paste your llama-server command to automatically populate the form fields + Select your backend type and paste the command to automatically populate the form fields - +
+
+
+