diff --git a/pkg/backends/llamacpp/parser.go b/pkg/backends/llamacpp/parser.go index 2dd7f65..d94b0ed 100644 --- a/pkg/backends/llamacpp/parser.go +++ b/pkg/backends/llamacpp/parser.go @@ -2,6 +2,7 @@ package llamacpp import ( "encoding/json" + "errors" "fmt" "path/filepath" "regexp" @@ -30,78 +31,101 @@ func ParseLlamaCommand(command string) (*LlamaServerOptions, error) { // 3. Parse arguments into map options := make(map[string]any) + + // Known multi-valued flags (snake_case form) + multiValued := map[string]struct{}{ + "override_tensor": {}, + "override_kv": {}, + "lora": {}, + "lora_scaled": {}, + "control_vector": {}, + "control_vector_scaled": {}, + "dry_sequence_breaker": {}, + "logit_bias": {}, + } + i := 0 for i < len(args) { arg := args[i] - // Skip non-flag arguments - if !strings.HasPrefix(arg, "-") { + if !strings.HasPrefix(arg, "-") { // skip positional / stray values i++ continue } - // Handle --flag=value format + // Reject malformed flags with more than two leading dashes (e.g. ---model) to surface user mistakes + if strings.HasPrefix(arg, "---") { + return nil, fmt.Errorf("malformed flag: %s", arg) + } + + // Unified parsing for --flag=value vs --flag value + var rawFlag, rawValue string + hasEquals := false if strings.Contains(arg, "=") { parts := strings.SplitN(arg, "=", 2) - flag := strings.TrimPrefix(parts[0], "-") - flag = strings.TrimPrefix(flag, "-") + rawFlag = parts[0] + rawValue = parts[1] // may be empty string + hasEquals = true + } else { + rawFlag = arg + } - // Convert flag from kebab-case to snake_case for consistency with JSON field names - flagName := strings.ReplaceAll(flag, "-", "_") + flagCore := strings.TrimPrefix(strings.TrimPrefix(rawFlag, "-"), "-") + flagName := strings.ReplaceAll(flagCore, "-", "_") - // Convert value to appropriate type - value := parseValue(parts[1]) - - // Handle array flags by checking if flag already exists - if existingValue, exists := options[flagName]; exists { - // Convert to array if not already - switch existing := existingValue.(type) { - case []string: - options[flagName] = append(existing, parts[1]) - case string: - options[flagName] = []string{existing, parts[1]} - default: - options[flagName] = []string{fmt.Sprintf("%v", existing), parts[1]} - } - } else { - options[flagName] = value + // Detect value if not in equals form + valueProvided := hasEquals + if !hasEquals { + if i+1 < len(args) && !isFlag(args[i+1]) { // next token is value + rawValue = args[i+1] + valueProvided = true + } + } + + // Determine if multi-valued flag + _, isMulti := multiValued[flagName] + + // Normalization helper: ensure slice for multi-valued flags + appendValue := func(valStr string) { + if existing, ok := options[flagName]; ok { + // Existing value; ensure slice semantics for multi-valued flags or repeated occurrences + if slice, ok := existing.([]string); ok { + options[flagName] = append(slice, valStr) + return + } + // Convert scalar to slice + options[flagName] = []string{fmt.Sprintf("%v", existing), valStr} + return + } + // First value + if isMulti { + options[flagName] = []string{valStr} + } else { + // We'll parse type below for single-valued flags + options[flagName] = valStr + } + } + + if valueProvided { + // Use raw token for multi-valued flags; else allow typed parsing + appendValue(rawValue) + if !isMulti { // convert to typed value if scalar + if strVal, ok := options[flagName].(string); ok { // still scalar + options[flagName] = parseValue(strVal) + } + } + // Advance index: if we consumed a following token as value (non equals form), skip it + if !hasEquals && i+1 < len(args) && rawValue == args[i+1] { + i += 2 + } else { + i++ } - i++ continue } - // Handle --flag value format - flag := strings.TrimPrefix(arg, "-") - flag = strings.TrimPrefix(flag, "-") - - // Convert flag from kebab-case to snake_case for consistency with JSON field names - flagName := strings.ReplaceAll(flag, "-", "_") - - // Check if next arg is a value (not a flag) - // Special case: allow negative numbers as values - if i+1 < len(args) && !isFlag(args[i+1]) { - value := parseValue(args[i+1]) - - // Handle array flags by checking if flag already exists - if existingValue, exists := options[flagName]; exists { - // Convert to array if not already - switch existing := existingValue.(type) { - case []string: - options[flagName] = append(existing, args[i+1]) - case string: - options[flagName] = []string{existing, args[i+1]} - default: - options[flagName] = []string{fmt.Sprintf("%v", existing), args[i+1]} - } - } else { - options[flagName] = value - } - i += 2 // Skip flag and value - } else { - // Boolean flag - options[flagName] = true - i++ - } + // Boolean flag (no value) + options[flagName] = true + i++ } // 4. Convert to LlamaServerOptions using existing UnmarshalJSON @@ -121,26 +145,28 @@ func ParseLlamaCommand(command string) (*LlamaServerOptions, error) { // parseValue attempts to parse a string value into the most appropriate type func parseValue(value string) any { - // Try to parse as boolean - if strings.ToLower(value) == "true" { + // Surrounding matching quotes (single or double) + if l := len(value); l >= 2 { + if (value[0] == '"' && value[l-1] == '"') || (value[0] == '\'' && value[l-1] == '\'') { + value = value[1 : l-1] + } + } + + lower := strings.ToLower(value) + if lower == "true" { return true } - if strings.ToLower(value) == "false" { + if lower == "false" { return false } - // Try to parse as integer (handle negative numbers) if intVal, err := strconv.Atoi(value); err == nil { return intVal } - - // Try to parse as float if floatVal, err := strconv.ParseFloat(value, 64); err == nil { return floatVal } - - // Default to string (remove quotes if present) - return strings.Trim(value, `""`) + return value } // normalizeMultilineCommand handles multiline commands with backslashes @@ -148,11 +174,11 @@ func normalizeMultilineCommand(command string) string { // Handle escaped newlines (backslash followed by newline) re := regexp.MustCompile(`\\\s*\n\s*`) normalized := re.ReplaceAllString(command, " ") - + // Clean up extra whitespace re = regexp.MustCompile(`\s+`) normalized = re.ReplaceAllString(normalized, " ") - + return strings.TrimSpace(normalized) } @@ -163,30 +189,30 @@ func extractArgumentsFromCommand(command string) ([]string, error) { if err != nil { return nil, err } - + if len(tokens) == 0 { return nil, fmt.Errorf("no command tokens found") } - + // Check if first token looks like an executable firstToken := tokens[0] - + // Case 1: Full path to executable (contains path separator or ends with llama-server) - if strings.Contains(firstToken, string(filepath.Separator)) || - strings.HasSuffix(filepath.Base(firstToken), "llama-server") { + if strings.Contains(firstToken, string(filepath.Separator)) || + strings.HasSuffix(filepath.Base(firstToken), "llama-server") { return tokens[1:], nil // Return everything except the executable } - + // Case 2: Just "llama-server" command if strings.ToLower(firstToken) == "llama-server" { return tokens[1:], nil // Return everything except the command } - + // Case 3: Arguments only (starts with a flag) if strings.HasPrefix(firstToken, "-") { return tokens, nil // Return all tokens as arguments } - + // Case 4: Unknown format - might be a different executable name // Be permissive and assume it's the executable return tokens[1:], nil @@ -199,22 +225,22 @@ func splitCommandTokens(command string) ([]string, error) { inQuotes := false quoteChar := byte(0) escaped := false - + for i := 0; i < len(command); i++ { c := command[i] - + if escaped { current.WriteByte(c) escaped = false continue } - + if c == '\\' { escaped = true current.WriteByte(c) continue } - + if !inQuotes && (c == '"' || c == '\'') { inQuotes = true quoteChar = c @@ -232,15 +258,15 @@ func splitCommandTokens(command string) ([]string, error) { current.WriteByte(c) } } - + if inQuotes { - return nil, fmt.Errorf("unterminated quoted string") + return nil, errors.New("unterminated quoted string") } - + if current.Len() > 0 { tokens = append(tokens, current.String()) } - + return tokens, nil } @@ -250,11 +276,11 @@ func isFlag(arg string) bool { if !strings.HasPrefix(arg, "-") { return false } - + // Special case: if it's a negative number, treat it as a value if _, err := strconv.ParseFloat(arg, 64); err == nil { return false } - + return true -} \ No newline at end of file +} diff --git a/pkg/backends/llamacpp/parser_test.go b/pkg/backends/llamacpp/parser_test.go index a217840..60e6a19 100644 --- a/pkg/backends/llamacpp/parser_test.go +++ b/pkg/backends/llamacpp/parser_test.go @@ -258,7 +258,7 @@ func TestParseLlamaCommandMultiline(t *testing.T) { --ctx-size 4096 \ --batch-size 512 \ --gpu-layers 32` - + result, err := ParseLlamaCommand(command) if err != nil { @@ -318,7 +318,7 @@ func TestParseLlamaCommandUnslothExample(t *testing.T) { --host 0.0.0.0 \ --port 8000 \ --api-key "sk-1234567890abcdef"` - + result, err := ParseLlamaCommand(command) if err != nil { @@ -369,4 +369,45 @@ func TestParseLlamaCommandUnslothExample(t *testing.T) { if result.APIKey != "sk-1234567890abcdef" { t.Errorf("expected api_key 'sk-1234567890abcdef', got '%s'", result.APIKey) } -} \ No newline at end of file +} + +// Focused additional edge case tests (kept minimal per guidance) +func TestParseLlamaCommandSingleQuotedValue(t *testing.T) { + cmd := "llama-server --model 'my model.gguf' --alias 'Test Alias'" + result, err := ParseLlamaCommand(cmd) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Model != "my model.gguf" { + t.Errorf("expected model 'my model.gguf', got '%s'", result.Model) + } + if result.Alias != "Test Alias" { + t.Errorf("expected alias 'Test Alias', got '%s'", result.Alias) + } +} + +func TestParseLlamaCommandMixedArrayForms(t *testing.T) { + // Same multi-value flag using --flag value and --flag=value forms + cmd := "llama-server --lora adapter1.bin --lora=adapter2.bin --lora adapter3.bin" + result, err := ParseLlamaCommand(cmd) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(result.Lora) != 3 { + t.Fatalf("expected 3 lora values, got %d (%v)", len(result.Lora), result.Lora) + } + expected := []string{"adapter1.bin", "adapter2.bin", "adapter3.bin"} + for i, v := range expected { + if result.Lora[i] != v { + t.Errorf("expected lora[%d]=%s got %s", i, v, result.Lora[i]) + } + } +} + +func TestParseLlamaCommandMalformedFlag(t *testing.T) { + cmd := "llama-server ---model test.gguf" + _, err := ParseLlamaCommand(cmd) + if err == nil { + t.Fatalf("expected error for malformed flag but got none") + } +} diff --git a/pkg/server/handlers.go b/pkg/server/handlers.go index 81ae23a..e5e2eb5 100644 --- a/pkg/server/handlers.go +++ b/pkg/server/handlers.go @@ -646,38 +646,41 @@ type ParseCommandRequest struct { // @Produce json // @Param request body ParseCommandRequest true "Command to parse" // @Success 200 {object} instance.CreateInstanceOptions "Parsed options" -// @Failure 400 {string} string "Invalid request or command" -// @Failure 500 {string} string "Internal Server Error" +// @Failure 400 {object} map[string]string "Invalid request or command" +// @Failure 500 {object} map[string]string "Internal Server Error" // @Router /backends/llama-cpp/parse-command [post] func (h *Handler) ParseLlamaCommand() http.HandlerFunc { + type errorResponse struct { + Error string `json:"error"` + Details string `json:"details,omitempty"` + } + writeError := func(w http.ResponseWriter, status int, code, details string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(errorResponse{Error: code, Details: details}) + } return func(w http.ResponseWriter, r *http.Request) { var req ParseCommandRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Invalid request body", http.StatusBadRequest) + writeError(w, http.StatusBadRequest, "invalid_request", "Invalid JSON body") return } - - if req.Command == "" { - http.Error(w, "Command cannot be empty", http.StatusBadRequest) + if strings.TrimSpace(req.Command) == "" { + writeError(w, http.StatusBadRequest, "invalid_command", "Command cannot be empty") return } - - // Parse the command using llamacpp parser llamaOptions, err := llamacpp.ParseLlamaCommand(req.Command) if err != nil { - http.Error(w, "Failed to parse command: "+err.Error(), http.StatusBadRequest) + writeError(w, http.StatusBadRequest, "parse_error", err.Error()) return } - - // Create the full CreateInstanceOptions options := &instance.CreateInstanceOptions{ BackendType: backends.BackendTypeLlamaCpp, LlamaServerOptions: llamaOptions, } - w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(options); err != nil { - http.Error(w, "Failed to encode response", http.StatusInternalServerError) + writeError(w, http.StatusInternalServerError, "encode_error", err.Error()) } } }