diff --git a/pkg/backends/backend.go b/pkg/backends/backend.go index ae22b1c..db18920 100644 --- a/pkg/backends/backend.go +++ b/pkg/backends/backend.go @@ -23,6 +23,7 @@ type backend interface { SetPort(int) GetHost() string Validate() error + ParseCommand(string) (any, error) } var backendConstructors = map[BackendType]func() backend{ diff --git a/pkg/backends/builder.go b/pkg/backends/builder.go index d5b5c0c..d224742 100644 --- a/pkg/backends/builder.go +++ b/pkg/backends/builder.go @@ -9,7 +9,7 @@ import ( ) // BuildCommandArgs converts a struct to command line arguments -func BuildCommandArgs(options any, multipleFlags map[string]bool) []string { +func BuildCommandArgs(options any, multipleFlags map[string]struct{}) []string { var args []string v := reflect.ValueOf(options).Elem() @@ -28,9 +28,10 @@ func BuildCommandArgs(options any, multipleFlags map[string]bool) []string { continue } - // Get flag name from JSON tag - flagName := strings.Split(jsonTag, ",")[0] - flagName = strings.ReplaceAll(flagName, "_", "-") + // Get flag name from JSON tag (snake_case) + jsonFieldName := strings.Split(jsonTag, ",")[0] + // Convert to kebab-case for CLI flags + flagName := strings.ReplaceAll(jsonFieldName, "_", "-") switch field.Kind() { case reflect.Bool: @@ -51,7 +52,8 @@ func BuildCommandArgs(options any, multipleFlags map[string]bool) []string { } case reflect.Slice: if field.Type().Elem().Kind() == reflect.String && field.Len() > 0 { - if multipleFlags[flagName] { + // Use jsonFieldName (snake_case) for multipleFlags lookup + if _, isMultiValue := multipleFlags[jsonFieldName]; isMultiValue { // Multiple flags: --flag value1 --flag value2 for j := 0; j < field.Len(); j++ { args = append(args, "--"+flagName, field.Index(j).String()) diff --git a/pkg/backends/llama.go b/pkg/backends/llama.go index dc07457..2b3372a 100644 --- a/pkg/backends/llama.go +++ b/pkg/backends/llama.go @@ -9,25 +9,16 @@ import ( ) // llamaMultiValuedFlags defines flags that should be repeated 
for each value rather than comma-separated -// Used for both parsing (with underscores) and building (with dashes) -var llamaMultiValuedFlags = map[string]bool{ - // Parsing keys (with underscores) - "override_tensor": true, - "override_kv": true, - "lora": true, - "lora_scaled": true, - "control_vector": true, - "control_vector_scaled": true, - "dry_sequence_breaker": true, - "logit_bias": true, - // Building keys (with dashes) - "override-tensor": true, - "override-kv": true, - "lora-scaled": true, - "control-vector": true, - "control-vector-scaled": true, - "dry-sequence-breaker": true, - "logit-bias": true, +// Keys use snake_case as the parser converts kebab-case flags to snake_case before lookup +var llamaMultiValuedFlags = map[string]struct{}{ + "override_tensor": {}, + "override_kv": {}, + "lora": {}, + "lora_scaled": {}, + "control_vector": {}, + "control_vector_scaled": {}, + "dry_sequence_breaker": {}, + "logit_bias": {}, } type LlamaServerOptions struct { @@ -378,19 +369,19 @@ func (o *LlamaServerOptions) BuildDockerArgs() []string { return o.BuildCommandArgs() } -// ParseLlamaCommand parses a llama-server command string into LlamaServerOptions +// ParseCommand parses a llama-server command string into LlamaServerOptions // Supports multiple formats: // 1. Full command: "llama-server --model file.gguf" // 2. Full path: "/usr/local/bin/llama-server --model file.gguf" // 3. Args only: "--model file.gguf --gpu-layers 32" // 4. 
Multiline commands with backslashes -func ParseLlamaCommand(command string) (*LlamaServerOptions, error) { +func (o *LlamaServerOptions) ParseCommand(command string) (any, error) { executableNames := []string{"llama-server"} var subcommandNames []string // Llama has no subcommands // Use package-level llamaMultiValuedFlags variable var llamaOptions LlamaServerOptions - if err := ParseCommand(command, executableNames, subcommandNames, llamaMultiValuedFlags, &llamaOptions); err != nil { + if err := parseCommand(command, executableNames, subcommandNames, llamaMultiValuedFlags, &llamaOptions); err != nil { return nil, err } diff --git a/pkg/backends/llama_test.go b/pkg/backends/llama_test.go index c05a3a5..1698d37 100644 --- a/pkg/backends/llama_test.go +++ b/pkg/backends/llama_test.go @@ -385,7 +385,9 @@ func TestParseLlamaCommand(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := backends.ParseLlamaCommand(tt.command) + var opts backends.LlamaServerOptions + resultAny, err := opts.ParseCommand(tt.command) + result, _ := resultAny.(*backends.LlamaServerOptions) if tt.expectErr { if err == nil { @@ -413,7 +415,9 @@ func TestParseLlamaCommand(t *testing.T) { func TestParseLlamaCommandArrays(t *testing.T) { command := "llama-server --model test.gguf --lora adapter1.bin --lora=adapter2.bin" - result, err := backends.ParseLlamaCommand(command) + var opts backends.LlamaServerOptions + resultAny, err := opts.ParseCommand(command) + result, _ := resultAny.(*backends.LlamaServerOptions) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -429,4 +433,4 @@ func TestParseLlamaCommandArrays(t *testing.T) { t.Errorf("expected lora[%d]=%s got %s", i, v, result.Lora[i]) } } -} \ No newline at end of file +} diff --git a/pkg/backends/mlx.go b/pkg/backends/mlx.go index d0ec602..8911d0b 100644 --- a/pkg/backends/mlx.go +++ b/pkg/backends/mlx.go @@ -62,7 +62,7 @@ func (o *MlxServerOptions) Validate() error { // BuildCommandArgs converts 
to command line arguments func (o *MlxServerOptions) BuildCommandArgs() []string { - multipleFlags := map[string]bool{} // MLX doesn't currently have []string fields + multipleFlags := map[string]struct{}{} // MLX doesn't currently have []string fields return BuildCommandArgs(o, multipleFlags) } @@ -70,19 +70,19 @@ func (o *MlxServerOptions) BuildDockerArgs() []string { return []string{} } -// ParseMlxCommand parses a mlx_lm.server command string into MlxServerOptions +// ParseCommand parses a mlx_lm.server command string into MlxServerOptions // Supports multiple formats: // 1. Full command: "mlx_lm.server --model model/path" // 2. Full path: "/usr/local/bin/mlx_lm.server --model model/path" // 3. Args only: "--model model/path --host 0.0.0.0" // 4. Multiline commands with backslashes -func ParseMlxCommand(command string) (*MlxServerOptions, error) { +func (o *MlxServerOptions) ParseCommand(command string) (any, error) { executableNames := []string{"mlx_lm.server"} - var subcommandNames []string // MLX has no subcommands - multiValuedFlags := map[string]bool{} // MLX has no multi-valued flags + var subcommandNames []string // MLX has no subcommands + multiValuedFlags := map[string]struct{}{} // MLX has no multi-valued flags var mlxOptions MlxServerOptions - if err := ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &mlxOptions); err != nil { + if err := parseCommand(command, executableNames, subcommandNames, multiValuedFlags, &mlxOptions); err != nil { return nil, err } diff --git a/pkg/backends/mlx_test.go b/pkg/backends/mlx_test.go index 0194551..d15be3d 100644 --- a/pkg/backends/mlx_test.go +++ b/pkg/backends/mlx_test.go @@ -96,7 +96,9 @@ func TestParseMlxCommand(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := backends.ParseMlxCommand(tt.command) + var opts backends.MlxServerOptions + resultAny, err := opts.ParseCommand(tt.command) + result, _ := resultAny.(*backends.MlxServerOptions) if 
tt.expectErr { if err == nil { @@ -174,11 +176,11 @@ func TestMlxBuildCommandArgs_BooleanFields(t *testing.T) { func TestMlxBuildCommandArgs_ZeroValues(t *testing.T) { options := backends.MlxServerOptions{ - Port: 0, // Should be excluded - TopK: 0, // Should be excluded - Temp: 0, // Should be excluded - Model: "", // Should be excluded - LogLevel: "", // Should be excluded + Port: 0, // Should be excluded + TopK: 0, // Should be excluded + Temp: 0, // Should be excluded + Model: "", // Should be excluded + LogLevel: "", // Should be excluded TrustRemoteCode: false, // Should be excluded } @@ -199,4 +201,4 @@ func TestMlxBuildCommandArgs_ZeroValues(t *testing.T) { t.Errorf("Zero value argument %q should not be present in %v", excludedArg, args) } } -} \ No newline at end of file +} diff --git a/pkg/backends/parser.go b/pkg/backends/parser.go index df585c9..8208568 100644 --- a/pkg/backends/parser.go +++ b/pkg/backends/parser.go @@ -9,8 +9,8 @@ import ( "strings" ) -// ParseCommand parses a command string into a target struct -func ParseCommand(command string, executableNames []string, subcommandNames []string, multiValuedFlags map[string]bool, target any) error { +// parseCommand parses a command string into a target struct +func parseCommand(command string, executableNames []string, subcommandNames []string, multiValuedFlags map[string]struct{}, target any) error { // Normalize multiline commands command = normalizeCommand(command) if command == "" { @@ -125,7 +125,7 @@ func extractArgs(command string, executableNames []string, subcommandNames []str } // parseFlags parses command line flags into a map -func parseFlags(args []string, multiValuedFlags map[string]bool) (map[string]any, error) { +func parseFlags(args []string, multiValuedFlags map[string]struct{}) (map[string]any, error) { options := make(map[string]any) for i := 0; i < len(args); i++ { @@ -163,7 +163,7 @@ func parseFlags(args []string, multiValuedFlags map[string]bool) (map[string]any if hasValue { 
// Handle multi-valued flags - if multiValuedFlags[flagName] { + if _, isMultiValue := multiValuedFlags[flagName]; isMultiValue { if existing, ok := options[flagName].([]string); ok { options[flagName] = append(existing, value) } else { diff --git a/pkg/backends/vllm.go b/pkg/backends/vllm.go index 857eab3..34dce4c 100644 --- a/pkg/backends/vllm.go +++ b/pkg/backends/vllm.go @@ -6,12 +6,16 @@ import ( ) // vllmMultiValuedFlags defines flags that should be repeated for each value rather than comma-separated -var vllmMultiValuedFlags = map[string]bool{ - "api-key": true, - "allowed-origins": true, - "allowed-methods": true, - "allowed-headers": true, - "middleware": true, +// Based on vLLM's CLI argument definitions with action='append' or List types +// Keys use snake_case as the parser converts kebab-case flags to snake_case before lookup +var vllmMultiValuedFlags = map[string]struct{}{ + "api_key": {}, // --api-key (action='append') + "allowed_origins": {}, // --allowed-origins (List type) + "allowed_methods": {}, // --allowed-methods (List type) + "allowed_headers": {}, // --allowed-headers (List type) + "middleware": {}, // --middleware (action='append') + "lora_modules": {}, // --lora-modules (custom LoRAParserAction, accepts multiple) + "prompt_adapters": {}, // --prompt-adapters (similar to lora-modules, accepts multiple) } type VllmServerOptions struct { @@ -202,28 +206,19 @@ func (o *VllmServerOptions) BuildDockerArgs() []string { return args } -// ParseVllmCommand parses a vLLM serve command string into VllmServerOptions +// ParseCommand parses a vLLM serve command string into VllmServerOptions // Supports multiple formats: // 1. Full command: "vllm serve --model MODEL_NAME --other-args" // 2. Full path: "/usr/local/bin/vllm serve --model MODEL_NAME" // 3. Serve only: "serve --model MODEL_NAME --other-args" // 4. Args only: "--model MODEL_NAME --other-args" // 5. 
Multiline commands with backslashes -func ParseVllmCommand(command string) (*VllmServerOptions, error) { +func (o *VllmServerOptions) ParseCommand(command string) (any, error) { executableNames := []string{"vllm"} subcommandNames := []string{"serve"} - multiValuedFlags := map[string]bool{ - "middleware": true, - "api_key": true, - "allowed_origins": true, - "allowed_methods": true, - "allowed_headers": true, - "lora_modules": true, - "prompt_adapters": true, - } var vllmOptions VllmServerOptions - if err := ParseCommand(command, executableNames, subcommandNames, multiValuedFlags, &vllmOptions); err != nil { + if err := parseCommand(command, executableNames, subcommandNames, vllmMultiValuedFlags, &vllmOptions); err != nil { return nil, err } diff --git a/pkg/backends/vllm_test.go b/pkg/backends/vllm_test.go index b9e6a13..acec8d6 100644 --- a/pkg/backends/vllm_test.go +++ b/pkg/backends/vllm_test.go @@ -92,7 +92,9 @@ func TestParseVllmCommand(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := backends.ParseVllmCommand(tt.command) + var opts backends.VllmServerOptions + resultAny, err := opts.ParseCommand(tt.command) + result, _ := resultAny.(*backends.VllmServerOptions) if tt.expectErr { if err == nil { @@ -118,6 +120,41 @@ func TestParseVllmCommand(t *testing.T) { } } +func TestParseVllmCommandArrays(t *testing.T) { + command := "vllm serve test-model --middleware auth.py --middleware=cors.py --api-key key1 --api-key key2" + + var opts backends.VllmServerOptions + resultAny, err := opts.ParseCommand(command) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + result, ok := resultAny.(*backends.VllmServerOptions) + if !ok { + t.Fatalf("expected *VllmServerOptions, got %T", resultAny) + } + + expectedMiddleware := []string{"auth.py", "cors.py"} + if len(result.Middleware) != len(expectedMiddleware) { + t.Errorf("expected %d middleware items, got %d", len(expectedMiddleware), len(result.Middleware)) + } + for 
i, v := range expectedMiddleware { + if i >= len(result.Middleware) || result.Middleware[i] != v { + t.Errorf("expected middleware[%d]=%s got %s", i, v, result.Middleware[i]) + } + } + + expectedAPIKeys := []string{"key1", "key2"} + if len(result.APIKey) != len(expectedAPIKeys) { + t.Errorf("expected %d api keys, got %d", len(expectedAPIKeys), len(result.APIKey)) + } + for i, v := range expectedAPIKeys { + if i >= len(result.APIKey) || result.APIKey[i] != v { + t.Errorf("expected api_key[%d]=%s got %s", i, v, result.APIKey[i]) + } + } +} + func TestVllmBuildCommandArgs_BooleanFields(t *testing.T) { tests := []struct { name string @@ -173,11 +210,11 @@ func TestVllmBuildCommandArgs_BooleanFields(t *testing.T) { func TestVllmBuildCommandArgs_ZeroValues(t *testing.T) { options := backends.VllmServerOptions{ - Port: 0, // Should be excluded - TensorParallelSize: 0, // Should be excluded - GPUMemoryUtilization: 0, // Should be excluded - Model: "", // Should be excluded (positional arg) - Host: "", // Should be excluded + Port: 0, // Should be excluded + TensorParallelSize: 0, // Should be excluded + GPUMemoryUtilization: 0, // Should be excluded + Model: "", // Should be excluded (positional arg) + Host: "", // Should be excluded EnableLogOutputs: false, // Should be excluded } diff --git a/pkg/server/handlers.go b/pkg/server/handlers.go index 4ddbfea..78b83c5 100644 --- a/pkg/server/handlers.go +++ b/pkg/server/handlers.go @@ -1,18 +1,60 @@ package server import ( + "encoding/json" + "fmt" "llamactl/pkg/config" + "llamactl/pkg/instance" "llamactl/pkg/manager" + "llamactl/pkg/validation" + "log" "net/http" "time" + + "github.com/go-chi/chi/v5" ) +// errorResponse represents an error response returned by the API +type errorResponse struct { + Error string `json:"error"` + Details string `json:"details,omitempty"` +} + +// writeError writes a JSON error response with the specified HTTP status code +func writeError(w http.ResponseWriter, status int, code, details string) 
{ + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if err := json.NewEncoder(w).Encode(errorResponse{Error: code, Details: details}); err != nil { + log.Printf("Failed to encode error response: %v", err) + } +} + +// writeJSON writes a JSON response with the specified HTTP status code +func writeJSON(w http.ResponseWriter, status int, data any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if err := json.NewEncoder(w).Encode(data); err != nil { + log.Printf("Failed to encode JSON response: %v", err) + } +} + +// writeText writes a plain text response with the specified HTTP status code +func writeText(w http.ResponseWriter, status int, data string) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(status) + if _, err := w.Write([]byte(data)); err != nil { + log.Printf("Failed to write text response: %v", err) + } +} + +// Handler provides HTTP handlers for the llamactl server API type Handler struct { InstanceManager manager.InstanceManager cfg config.AppConfig httpClient *http.Client } +// NewHandler creates a new Handler instance with the provided instance manager and configuration func NewHandler(im manager.InstanceManager, cfg config.AppConfig) *Handler { return &Handler{ InstanceManager: im, @@ -22,3 +64,52 @@ func NewHandler(im manager.InstanceManager, cfg config.AppConfig) *Handler { }, } } + +// getInstance retrieves an instance by name from the request's URL path parameter +func (h *Handler) getInstance(r *http.Request) (*instance.Instance, error) { + name := chi.URLParam(r, "name") + validatedName, err := validation.ValidateInstanceName(name) + if err != nil { + return nil, fmt.Errorf("invalid instance name: %w", err) + } + + inst, err := h.InstanceManager.GetInstance(validatedName) + if err != nil { + return nil, fmt.Errorf("failed to get instance by name: %w", err) + } + + return inst, nil +} + +// ensureInstanceRunning ensures the instance is running by starting it if on-demand 
start is enabled +// It handles LRU eviction when the maximum number of running instances is reached +func (h *Handler) ensureInstanceRunning(inst *instance.Instance) error { + options := inst.GetOptions() + allowOnDemand := options != nil && options.OnDemandStart != nil && *options.OnDemandStart + if !allowOnDemand { + return fmt.Errorf("instance is not running and on-demand start is not enabled") + } + + if h.InstanceManager.IsMaxRunningInstancesReached() { + if h.cfg.Instances.EnableLRUEviction { + err := h.InstanceManager.EvictLRUInstance() + if err != nil { + return fmt.Errorf("cannot start instance, failed to evict instance: %w", err) + } + } else { + return fmt.Errorf("cannot start instance, maximum number of instances reached") + } + } + + // If on-demand start is enabled, start the instance + if _, err := h.InstanceManager.StartInstance(inst.Name); err != nil { + return fmt.Errorf("failed to start instance: %w", err) + } + + // Wait for the instance to become healthy before proceeding + if err := inst.WaitForHealthy(h.cfg.Instances.OnDemandStartTimeout); err != nil { + return fmt.Errorf("instance failed to become healthy: %w", err) + } + + return nil +} diff --git a/pkg/server/handlers_backends.go b/pkg/server/handlers_backends.go index d3132af..47ef02d 100644 --- a/pkg/server/handlers_backends.go +++ b/pkg/server/handlers_backends.go @@ -5,99 +5,148 @@ import ( "fmt" "llamactl/pkg/backends" "llamactl/pkg/instance" - "llamactl/pkg/validation" "net/http" "os/exec" "strings" - - "github.com/go-chi/chi/v5" ) -// ParseCommandRequest represents the request body for command parsing +// ParseCommandRequest represents the request body for backend command parsing type ParseCommandRequest struct { Command string `json:"command"` } -func (h *Handler) LlamaCppProxy(onDemandStart bool) http.HandlerFunc { +// validateLlamaCppInstance validates that the instance specified in the request is a llama.cpp instance +func (h *Handler) validateLlamaCppInstance(r *http.Request) 
(*instance.Instance, error) { + inst, err := h.getInstance(r) + if err != nil { + return nil, fmt.Errorf("invalid instance: %w", err) + } + + options := inst.GetOptions() + if options == nil { + return nil, fmt.Errorf("cannot obtain instance's options") + } + + if options.BackendOptions.BackendType != backends.BackendTypeLlamaCpp { + return nil, fmt.Errorf("instance is not a llama.cpp server") + } + + return inst, nil +} + +// stripLlamaCppPrefix removes the llama.cpp proxy prefix from the request URL path +func (h *Handler) stripLlamaCppPrefix(r *http.Request, instName string) { + // Strip the "/llama-cpp/" prefix from the request URL + prefix := fmt.Sprintf("/llama-cpp/%s", instName) + r.URL.Path = strings.TrimPrefix(r.URL.Path, prefix) +} + +// LlamaCppUIProxy godoc +// @Summary Proxy requests to llama.cpp UI for the instance +// @Description Proxies requests to the llama.cpp UI for the specified instance +// @Tags backends +// @Security ApiKeyAuth +// @Produce html +// @Param name path string true "Instance Name" +// @Success 200 {string} string "Proxied HTML response" +// @Failure 400 {string} string "Invalid instance" +// @Failure 500 {string} string "Internal Server Error" +// @Router /llama-cpp/{name}/ [get] +func (h *Handler) LlamaCppUIProxy() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - // Get the instance name from the URL parameter - name := chi.URLParam(r, "name") - - // Validate instance name at the entry point - validatedName, err := validation.ValidateInstanceName(name) + inst, err := h.validateLlamaCppInstance(r) if err != nil { - http.Error(w, "Invalid instance name: "+err.Error(), http.StatusBadRequest) - return - } - - // Route to the appropriate inst based on instance name - inst, err := h.InstanceManager.GetInstance(validatedName) - if err != nil { - http.Error(w, "Invalid instance: "+err.Error(), http.StatusBadRequest) - return - } - - options := inst.GetOptions() - if options == nil { - http.Error(w, "Cannot 
obtain Instance's options", http.StatusInternalServerError) - return - } - - if options.BackendOptions.BackendType != backends.BackendTypeLlamaCpp { - http.Error(w, "Instance is not a llama.cpp server.", http.StatusBadRequest) + writeError(w, http.StatusBadRequest, "invalid instance", err.Error()) return } if !inst.IsRemote() && !inst.IsRunning() { + writeError(w, http.StatusBadRequest, "instance is not running", "Instance is not running") + return + } - if !(onDemandStart && options.OnDemandStart != nil && *options.OnDemandStart) { - http.Error(w, "Instance is not running", http.StatusServiceUnavailable) - return - } + proxy, err := inst.GetProxy() + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to get proxy", err.Error()) + return + } - if h.InstanceManager.IsMaxRunningInstancesReached() { - if h.cfg.Instances.EnableLRUEviction { - err := h.InstanceManager.EvictLRUInstance() - if err != nil { - http.Error(w, "Cannot start Instance, failed to evict instance "+err.Error(), http.StatusInternalServerError) - return - } - } else { - http.Error(w, "Cannot start Instance, maximum number of instances reached", http.StatusConflict) - return - } - } + if !inst.IsRemote() { + h.stripLlamaCppPrefix(r, inst.Name) + } - // If on-demand start is enabled, start the instance - if _, err := h.InstanceManager.StartInstance(validatedName); err != nil { - http.Error(w, "Failed to start instance: "+err.Error(), http.StatusInternalServerError) - return - } + proxy.ServeHTTP(w, r) + } +} - // Wait for the instance to become healthy before proceeding - if err := inst.WaitForHealthy(h.cfg.Instances.OnDemandStartTimeout); err != nil { // 2 minutes timeout - http.Error(w, "Instance failed to become healthy: "+err.Error(), http.StatusServiceUnavailable) +// LlamaCppProxy godoc +// @Summary Proxy requests to llama.cpp server instance +// @Description Proxies requests to the specified llama.cpp server instance, starting it on-demand if configured +// @Tags backends 
+// @Security ApiKeyAuth +// @Produce json +// @Param name path string true "Instance Name" +// @Success 200 {object} map[string]any "Proxied response" +// @Failure 400 {string} string "Invalid instance" +// @Failure 500 {string} string "Internal Server Error" +// @Router /llama-cpp/{name}/* [post] +func (h *Handler) LlamaCppProxy() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + + inst, err := h.validateLlamaCppInstance(r) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid instance", err.Error()) + return + } + + if !inst.IsRemote() && !inst.IsRunning() { + err := h.ensureInstanceRunning(inst) + if err != nil { + writeError(w, http.StatusInternalServerError, "instance start failed", err.Error()) return } } proxy, err := inst.GetProxy() if err != nil { - http.Error(w, "Failed to get proxy: "+err.Error(), http.StatusInternalServerError) + writeError(w, http.StatusInternalServerError, "failed to get proxy", err.Error()) return } if !inst.IsRemote() { - // Strip the "/llama-cpp/" prefix from the request URL - prefix := fmt.Sprintf("/llama-cpp/%s", validatedName) - r.URL.Path = strings.TrimPrefix(r.URL.Path, prefix) + h.stripLlamaCppPrefix(r, inst.Name) } proxy.ServeHTTP(w, r) } } +// parseHelper parses a backend command and returns the parsed options +func parseHelper(w http.ResponseWriter, r *http.Request, backend interface { + ParseCommand(string) (any, error) +}) (any, bool) { + var req ParseCommandRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "Invalid JSON body") + return nil, false + } + + if strings.TrimSpace(req.Command) == "" { + writeError(w, http.StatusBadRequest, "invalid_command", "Command cannot be empty") + return nil, false + } + + // Parse command using the backend's ParseCommand method + parsedOptions, err := backend.ParseCommand(req.Command) + if err != nil { + writeError(w, http.StatusBadRequest, "parse_error", err.Error()) 
+ return nil, false + } + + return parsedOptions, true +} + // ParseLlamaCommand godoc // @Summary Parse llama-server command // @Description Parses a llama-server command string into instance options @@ -111,40 +160,20 @@ func (h *Handler) LlamaCppProxy(onDemandStart bool) http.HandlerFunc { // @Failure 500 {object} map[string]string "Internal Server Error" // @Router /backends/llama-cpp/parse-command [post] func (h *Handler) ParseLlamaCommand() http.HandlerFunc { - type errorResponse struct { - Error string `json:"error"` - Details string `json:"details,omitempty"` - } - writeError := func(w http.ResponseWriter, status int, code, details string) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - _ = json.NewEncoder(w).Encode(errorResponse{Error: code, Details: details}) - } return func(w http.ResponseWriter, r *http.Request) { - var req ParseCommandRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeError(w, http.StatusBadRequest, "invalid_request", "Invalid JSON body") - return - } - if strings.TrimSpace(req.Command) == "" { - writeError(w, http.StatusBadRequest, "invalid_command", "Command cannot be empty") - return - } - llamaOptions, err := backends.ParseLlamaCommand(req.Command) - if err != nil { - writeError(w, http.StatusBadRequest, "parse_error", err.Error()) + parsedOptions, ok := parseHelper(w, r, &backends.LlamaServerOptions{}) + if !ok { return } + options := &instance.Options{ BackendOptions: backends.Options{ BackendType: backends.BackendTypeLlamaCpp, - LlamaServerOptions: llamaOptions, + LlamaServerOptions: parsedOptions.(*backends.LlamaServerOptions), }, } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(options); err != nil { - writeError(w, http.StatusInternalServerError, "encode_error", err.Error()) - } + + writeJSON(w, http.StatusOK, options) } } @@ -160,47 +189,20 @@ func (h *Handler) ParseLlamaCommand() http.HandlerFunc { // @Failure 400 
{object} map[string]string "Invalid request or command" // @Router /backends/mlx/parse-command [post] func (h *Handler) ParseMlxCommand() http.HandlerFunc { - type errorResponse struct { - Error string `json:"error"` - Details string `json:"details,omitempty"` - } - writeError := func(w http.ResponseWriter, status int, code, details string) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - _ = json.NewEncoder(w).Encode(errorResponse{Error: code, Details: details}) - } return func(w http.ResponseWriter, r *http.Request) { - var req ParseCommandRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeError(w, http.StatusBadRequest, "invalid_request", "Invalid JSON body") + parsedOptions, ok := parseHelper(w, r, &backends.MlxServerOptions{}) + if !ok { return } - if strings.TrimSpace(req.Command) == "" { - writeError(w, http.StatusBadRequest, "invalid_command", "Command cannot be empty") - return - } - - mlxOptions, err := backends.ParseMlxCommand(req.Command) - if err != nil { - writeError(w, http.StatusBadRequest, "parse_error", err.Error()) - return - } - - // Currently only support mlx_lm backend type - backendType := backends.BackendTypeMlxLm - options := &instance.Options{ BackendOptions: backends.Options{ - BackendType: backendType, - MlxServerOptions: mlxOptions, + BackendType: backends.BackendTypeMlxLm, + MlxServerOptions: parsedOptions.(*backends.MlxServerOptions), }, } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(options); err != nil { - writeError(w, http.StatusInternalServerError, "encode_error", err.Error()) - } + writeJSON(w, http.StatusOK, options) } } @@ -216,46 +218,33 @@ func (h *Handler) ParseMlxCommand() http.HandlerFunc { // @Failure 400 {object} map[string]string "Invalid request or command" // @Router /backends/vllm/parse-command [post] func (h *Handler) ParseVllmCommand() http.HandlerFunc { - type errorResponse struct { - Error string 
`json:"error"` - Details string `json:"details,omitempty"` - } - writeError := func(w http.ResponseWriter, status int, code, details string) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - _ = json.NewEncoder(w).Encode(errorResponse{Error: code, Details: details}) - } return func(w http.ResponseWriter, r *http.Request) { - var req ParseCommandRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeError(w, http.StatusBadRequest, "invalid_request", "Invalid JSON body") + parsedOptions, ok := parseHelper(w, r, &backends.VllmServerOptions{}) + if !ok { return } - if strings.TrimSpace(req.Command) == "" { - writeError(w, http.StatusBadRequest, "invalid_command", "Command cannot be empty") - return - } - - vllmOptions, err := backends.ParseVllmCommand(req.Command) - if err != nil { - writeError(w, http.StatusBadRequest, "parse_error", err.Error()) - return - } - - backendType := backends.BackendTypeVllm - options := &instance.Options{ BackendOptions: backends.Options{ - BackendType: backendType, - VllmServerOptions: vllmOptions, + BackendType: backends.BackendTypeVllm, + VllmServerOptions: parsedOptions.(*backends.VllmServerOptions), }, } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(options); err != nil { - writeError(w, http.StatusInternalServerError, "encode_error", err.Error()) + writeJSON(w, http.StatusOK, options) + } +} + +// executeLlamaServerCommand executes a llama-server command with the specified flag and returns the output +func (h *Handler) executeLlamaServerCommand(flag, errorMsg string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + cmd := exec.Command("llama-server", flag) + output, err := cmd.CombinedOutput() + if err != nil { + writeError(w, http.StatusInternalServerError, "command failed", errorMsg+": "+err.Error()) + return } + writeText(w, http.StatusOK, string(output)) } } @@ -269,16 +258,7 @@ func (h *Handler) 
ParseVllmCommand() http.HandlerFunc { // @Failure 500 {string} string "Internal Server Error" // @Router /backends/llama-cpp/help [get] func (h *Handler) LlamaServerHelpHandler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - helpCmd := exec.Command("llama-server", "--help") - output, err := helpCmd.CombinedOutput() - if err != nil { - http.Error(w, "Failed to get help: "+err.Error(), http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "text/plain") - w.Write(output) - } + return h.executeLlamaServerCommand("--help", "Failed to get help") } // LlamaServerVersionHandler godoc @@ -291,16 +271,7 @@ func (h *Handler) LlamaServerHelpHandler() http.HandlerFunc { // @Failure 500 {string} string "Internal Server Error" // @Router /backends/llama-cpp/version [get] func (h *Handler) LlamaServerVersionHandler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - versionCmd := exec.Command("llama-server", "--version") - output, err := versionCmd.CombinedOutput() - if err != nil { - http.Error(w, "Failed to get version: "+err.Error(), http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "text/plain") - w.Write(output) - } + return h.executeLlamaServerCommand("--version", "Failed to get version") } // LlamaServerListDevicesHandler godoc @@ -313,14 +284,5 @@ func (h *Handler) LlamaServerVersionHandler() http.HandlerFunc { // @Failure 500 {string} string "Internal Server Error" // @Router /backends/llama-cpp/devices [get] func (h *Handler) LlamaServerListDevicesHandler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - listCmd := exec.Command("llama-server", "--list-devices") - output, err := listCmd.CombinedOutput() - if err != nil { - http.Error(w, "Failed to list devices: "+err.Error(), http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "text/plain") - w.Write(output) - } + return 
h.executeLlamaServerCommand("--list-devices", "Failed to list devices") } diff --git a/pkg/server/handlers_instances.go b/pkg/server/handlers_instances.go index 7a444d0..24fe3e7 100644 --- a/pkg/server/handlers_instances.go +++ b/pkg/server/handlers_instances.go @@ -26,15 +26,11 @@ func (h *Handler) ListInstances() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { instances, err := h.InstanceManager.ListInstances() if err != nil { - http.Error(w, "Failed to list instances: "+err.Error(), http.StatusInternalServerError) + writeError(w, http.StatusInternalServerError, "list_failed", "Failed to list instances: "+err.Error()) return } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(instances); err != nil { - http.Error(w, "Failed to encode instances: "+err.Error(), http.StatusInternalServerError) - return - } + writeJSON(w, http.StatusOK, instances) } } @@ -54,31 +50,25 @@ func (h *Handler) ListInstances() http.HandlerFunc { func (h *Handler) CreateInstance() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { name := chi.URLParam(r, "name") - validatedName, err := validation.ValidateInstanceName(name) if err != nil { - http.Error(w, "Invalid instance name: "+err.Error(), http.StatusBadRequest) + writeError(w, http.StatusBadRequest, "invalid_instance_name", err.Error()) return } var options instance.Options if err := json.NewDecoder(r.Body).Decode(&options); err != nil { - http.Error(w, "Invalid request body", http.StatusBadRequest) + writeError(w, http.StatusBadRequest, "invalid_request", "Invalid request body") return } inst, err := h.InstanceManager.CreateInstance(validatedName, &options) if err != nil { - http.Error(w, "Failed to create instance: "+err.Error(), http.StatusInternalServerError) + writeError(w, http.StatusInternalServerError, "create_failed", "Failed to create instance: "+err.Error()) return } - w.Header().Set("Content-Type", "application/json") - 
w.WriteHeader(http.StatusCreated) - if err := json.NewEncoder(w).Encode(inst); err != nil { - http.Error(w, "Failed to encode instance: "+err.Error(), http.StatusInternalServerError) - return - } + writeJSON(w, http.StatusCreated, inst) } } @@ -96,24 +86,19 @@ func (h *Handler) CreateInstance() http.HandlerFunc { func (h *Handler) GetInstance() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { name := chi.URLParam(r, "name") - validatedName, err := validation.ValidateInstanceName(name) if err != nil { - http.Error(w, "Invalid instance name: "+err.Error(), http.StatusBadRequest) + writeError(w, http.StatusBadRequest, "invalid_instance_name", err.Error()) return } inst, err := h.InstanceManager.GetInstance(validatedName) if err != nil { - http.Error(w, "Invalid instance: "+err.Error(), http.StatusBadRequest) + writeError(w, http.StatusBadRequest, "invalid_instance", err.Error()) return } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(inst); err != nil { - http.Error(w, "Failed to encode instance: "+err.Error(), http.StatusInternalServerError) - return - } + writeJSON(w, http.StatusOK, inst) } } @@ -133,30 +118,25 @@ func (h *Handler) GetInstance() http.HandlerFunc { func (h *Handler) UpdateInstance() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { name := chi.URLParam(r, "name") - validatedName, err := validation.ValidateInstanceName(name) if err != nil { - http.Error(w, "Invalid instance name: "+err.Error(), http.StatusBadRequest) + writeError(w, http.StatusBadRequest, "invalid_instance_name", err.Error()) return } var options instance.Options if err := json.NewDecoder(r.Body).Decode(&options); err != nil { - http.Error(w, "Invalid request body", http.StatusBadRequest) + writeError(w, http.StatusBadRequest, "invalid_request", "Invalid request body") return } inst, err := h.InstanceManager.UpdateInstance(validatedName, &options) if err != nil { - http.Error(w, "Failed to 
update instance: "+err.Error(), http.StatusInternalServerError) + writeError(w, http.StatusInternalServerError, "update_failed", "Failed to update instance: "+err.Error()) return } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(inst); err != nil { - http.Error(w, "Failed to encode instance: "+err.Error(), http.StatusInternalServerError) - return - } + writeJSON(w, http.StatusOK, inst) } } @@ -174,10 +154,9 @@ func (h *Handler) UpdateInstance() http.HandlerFunc { func (h *Handler) StartInstance() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { name := chi.URLParam(r, "name") - validatedName, err := validation.ValidateInstanceName(name) if err != nil { - http.Error(w, "Invalid instance name: "+err.Error(), http.StatusBadRequest) + writeError(w, http.StatusBadRequest, "invalid_instance_name", err.Error()) return } @@ -185,19 +164,15 @@ func (h *Handler) StartInstance() http.HandlerFunc { if err != nil { // Check if error is due to maximum running instances limit if _, ok := err.(manager.MaxRunningInstancesError); ok { - http.Error(w, err.Error(), http.StatusConflict) + writeError(w, http.StatusConflict, "max_instances_reached", err.Error()) return } - http.Error(w, "Failed to start instance: "+err.Error(), http.StatusInternalServerError) + writeError(w, http.StatusInternalServerError, "start_failed", "Failed to start instance: "+err.Error()) return } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(inst); err != nil { - http.Error(w, "Failed to encode instance: "+err.Error(), http.StatusInternalServerError) - return - } + writeJSON(w, http.StatusOK, inst) } } @@ -215,24 +190,19 @@ func (h *Handler) StartInstance() http.HandlerFunc { func (h *Handler) StopInstance() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { name := chi.URLParam(r, "name") - validatedName, err := validation.ValidateInstanceName(name) if err != nil { - 
http.Error(w, "Invalid instance name: "+err.Error(), http.StatusBadRequest) + writeError(w, http.StatusBadRequest, "invalid_instance_name", err.Error()) return } inst, err := h.InstanceManager.StopInstance(validatedName) if err != nil { - http.Error(w, "Failed to stop instance: "+err.Error(), http.StatusInternalServerError) + writeError(w, http.StatusInternalServerError, "stop_failed", "Failed to stop instance: "+err.Error()) return } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(inst); err != nil { - http.Error(w, "Failed to encode instance: "+err.Error(), http.StatusInternalServerError) - return - } + writeJSON(w, http.StatusOK, inst) } } @@ -250,24 +220,19 @@ func (h *Handler) StopInstance() http.HandlerFunc { func (h *Handler) RestartInstance() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { name := chi.URLParam(r, "name") - validatedName, err := validation.ValidateInstanceName(name) if err != nil { - http.Error(w, "Invalid instance name: "+err.Error(), http.StatusBadRequest) + writeError(w, http.StatusBadRequest, "invalid_instance_name", err.Error()) return } inst, err := h.InstanceManager.RestartInstance(validatedName) if err != nil { - http.Error(w, "Failed to restart instance: "+err.Error(), http.StatusInternalServerError) + writeError(w, http.StatusInternalServerError, "restart_failed", "Failed to restart instance: "+err.Error()) return } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(inst); err != nil { - http.Error(w, "Failed to encode instance: "+err.Error(), http.StatusInternalServerError) - return - } + writeJSON(w, http.StatusOK, inst) } } @@ -284,15 +249,14 @@ func (h *Handler) RestartInstance() http.HandlerFunc { func (h *Handler) DeleteInstance() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { name := chi.URLParam(r, "name") - validatedName, err := validation.ValidateInstanceName(name) if err != nil { - 
http.Error(w, "Invalid instance name: "+err.Error(), http.StatusBadRequest) + writeError(w, http.StatusBadRequest, "invalid_instance_name", err.Error()) return } if err := h.InstanceManager.DeleteInstance(validatedName); err != nil { - http.Error(w, "Failed to delete instance: "+err.Error(), http.StatusInternalServerError) + writeError(w, http.StatusInternalServerError, "delete_failed", "Failed to delete instance: "+err.Error()) return } @@ -315,10 +279,9 @@ func (h *Handler) DeleteInstance() http.HandlerFunc { func (h *Handler) GetInstanceLogs() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { name := chi.URLParam(r, "name") - validatedName, err := validation.ValidateInstanceName(name) if err != nil { - http.Error(w, "Invalid instance name: "+err.Error(), http.StatusBadRequest) + writeError(w, http.StatusBadRequest, "invalid_instance_name", err.Error()) return } @@ -327,7 +290,7 @@ func (h *Handler) GetInstanceLogs() http.HandlerFunc { if lines != "" { parsedLines, err := strconv.Atoi(lines) if err != nil { - http.Error(w, "Invalid lines parameter: "+err.Error(), http.StatusBadRequest) + writeError(w, http.StatusBadRequest, "invalid_parameter", "Invalid lines parameter: "+err.Error()) return } numLines = parsedLines @@ -336,17 +299,16 @@ func (h *Handler) GetInstanceLogs() http.HandlerFunc { // Use the instance manager which handles both local and remote instances logs, err := h.InstanceManager.GetInstanceLogs(validatedName, numLines) if err != nil { - http.Error(w, "Failed to get logs: "+err.Error(), http.StatusInternalServerError) + writeError(w, http.StatusInternalServerError, "logs_failed", "Failed to get logs: "+err.Error()) return } - w.Header().Set("Content-Type", "text/plain") - w.Write([]byte(logs)) + writeText(w, http.StatusOK, logs) } } -// ProxyToInstance godoc -// @Summary Proxy requests to a specific instance +// InstanceProxy godoc +// @Summary Proxy requests to a specific instance, does not autostart instance if stopped // 
@Description Forwards HTTP requests to the llama-server instance running on a specific port // @Tags instances // @Security ApiKeyAuth @@ -357,38 +319,28 @@ func (h *Handler) GetInstanceLogs() http.HandlerFunc { // @Failure 503 {string} string "Instance is not running" // @Router /instances/{name}/proxy [get] // @Router /instances/{name}/proxy [post] -func (h *Handler) ProxyToInstance() http.HandlerFunc { +func (h *Handler) InstanceProxy() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - name := chi.URLParam(r, "name") - - validatedName, err := validation.ValidateInstanceName(name) + inst, err := h.getInstance(r) if err != nil { - http.Error(w, "Invalid instance name: "+err.Error(), http.StatusBadRequest) - return - } - - inst, err := h.InstanceManager.GetInstance(validatedName) - if err != nil { - http.Error(w, "Failed to get instance: "+err.Error(), http.StatusInternalServerError) + writeError(w, http.StatusBadRequest, "invalid_instance", err.Error()) return } if !inst.IsRunning() { - http.Error(w, "Instance is not running", http.StatusServiceUnavailable) + writeError(w, http.StatusServiceUnavailable, "instance_not_running", "Instance is not running") return } - // Get the cached proxy for this instance proxy, err := inst.GetProxy() if err != nil { - http.Error(w, "Failed to get proxy: "+err.Error(), http.StatusInternalServerError) + writeError(w, http.StatusInternalServerError, "proxy_failed", "Failed to get proxy: "+err.Error()) return } - // Check if this is a remote instance if !inst.IsRemote() { // Strip the "/api/v1/instances//proxy" prefix from the request URL - prefix := fmt.Sprintf("/api/v1/instances/%s/proxy", validatedName) + prefix := fmt.Sprintf("/api/v1/instances/%s/proxy", inst.Name) r.URL.Path = strings.TrimPrefix(r.URL.Path, prefix) } @@ -396,7 +348,6 @@ func (h *Handler) ProxyToInstance() http.HandlerFunc { r.Header.Set("X-Forwarded-Host", r.Header.Get("Host")) r.Header.Set("X-Forwarded-Proto", "http") - // Forward the 
request using the cached proxy proxy.ServeHTTP(w, r) } } diff --git a/pkg/server/handlers_nodes.go b/pkg/server/handlers_nodes.go index 98a4b43..7c84b0a 100644 --- a/pkg/server/handlers_nodes.go +++ b/pkg/server/handlers_nodes.go @@ -1,13 +1,12 @@ package server import ( - "encoding/json" "net/http" "github.com/go-chi/chi/v5" ) -// NodeResponse represents a sanitized node configuration for API responses +// NodeResponse represents a node configuration in API responses type NodeResponse struct { Address string `json:"address"` } @@ -31,11 +30,7 @@ func (h *Handler) ListNodes() http.HandlerFunc { } } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(nodeResponses); err != nil { - http.Error(w, "Failed to encode nodes: "+err.Error(), http.StatusInternalServerError) - return - } + writeJSON(w, http.StatusOK, nodeResponses) } } @@ -55,13 +50,13 @@ func (h *Handler) GetNode() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { name := chi.URLParam(r, "name") if name == "" { - http.Error(w, "Node name cannot be empty", http.StatusBadRequest) + writeError(w, http.StatusBadRequest, "invalid_request", "Node name cannot be empty") return } nodeConfig, exists := h.cfg.Nodes[name] if !exists { - http.Error(w, "Node not found", http.StatusNotFound) + writeError(w, http.StatusNotFound, "not_found", "Node not found") return } @@ -70,10 +65,6 @@ func (h *Handler) GetNode() http.HandlerFunc { Address: nodeConfig.Address, } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(nodeResponse); err != nil { - http.Error(w, "Failed to encode node: "+err.Error(), http.StatusInternalServerError) - return - } + writeJSON(w, http.StatusOK, nodeResponse) } } diff --git a/pkg/server/handlers_openai.go b/pkg/server/handlers_openai.go index 9ad3207..35ac746 100644 --- a/pkg/server/handlers_openai.go +++ b/pkg/server/handlers_openai.go @@ -8,6 +8,20 @@ import ( "net/http" ) +// 
OpenAIListInstancesResponse represents the response structure for listing instances (models) in OpenAI-compatible format +type OpenAIListInstancesResponse struct { + Object string `json:"object"` + Data []OpenAIInstance `json:"data"` +} + +// OpenAIInstance represents a single instance (model) in OpenAI-compatible format +type OpenAIInstance struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + OwnedBy string `json:"owned_by"` +} + // OpenAIListInstances godoc // @Summary List instances in OpenAI-compatible format // @Description Returns a list of instances in a format compatible with OpenAI API @@ -21,7 +35,7 @@ func (h *Handler) OpenAIListInstances() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { instances, err := h.InstanceManager.ListInstances() if err != nil { - http.Error(w, "Failed to list instances: "+err.Error(), http.StatusInternalServerError) + writeError(w, http.StatusInternalServerError, "list_failed", "Failed to list instances: "+err.Error()) return } @@ -40,11 +54,7 @@ func (h *Handler) OpenAIListInstances() http.HandlerFunc { Data: openaiInstances, } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(openaiResponse); err != nil { - http.Error(w, "Failed to encode instances: "+err.Error(), http.StatusInternalServerError) - return - } + writeJSON(w, http.StatusOK, openaiResponse) } } @@ -64,7 +74,7 @@ func (h *Handler) OpenAIProxy() http.HandlerFunc { // Read the entire body first bodyBytes, err := io.ReadAll(r.Body) if err != nil { - http.Error(w, "Failed to read request body", http.StatusBadRequest) + writeError(w, http.StatusBadRequest, "invalid_request", "Failed to read request body") return } r.Body.Close() @@ -72,67 +82,41 @@ func (h *Handler) OpenAIProxy() http.HandlerFunc { // Parse the body to extract instance name var requestBody map[string]any if err := json.Unmarshal(bodyBytes, &requestBody); err != nil { - http.Error(w, "Invalid 
request body", http.StatusBadRequest) + writeError(w, http.StatusBadRequest, "invalid_request", "Invalid request body") return } modelName, ok := requestBody["model"].(string) if !ok || modelName == "" { - http.Error(w, "Instance name is required", http.StatusBadRequest) + writeError(w, http.StatusBadRequest, "invalid_request", "Instance name is required") return } // Validate instance name at the entry point validatedName, err := validation.ValidateInstanceName(modelName) if err != nil { - http.Error(w, "Invalid instance name: "+err.Error(), http.StatusBadRequest) + writeError(w, http.StatusBadRequest, "invalid_instance_name", err.Error()) return } // Route to the appropriate inst based on instance name inst, err := h.InstanceManager.GetInstance(validatedName) if err != nil { - http.Error(w, "Invalid instance: "+err.Error(), http.StatusBadRequest) + writeError(w, http.StatusBadRequest, "invalid_instance", err.Error()) return } if !inst.IsRemote() && !inst.IsRunning() { - options := inst.GetOptions() - allowOnDemand := options != nil && options.OnDemandStart != nil && *options.OnDemandStart - if !allowOnDemand { - http.Error(w, "Instance is not running", http.StatusServiceUnavailable) - return - } - - if h.InstanceManager.IsMaxRunningInstancesReached() { - if h.cfg.Instances.EnableLRUEviction { - err := h.InstanceManager.EvictLRUInstance() - if err != nil { - http.Error(w, "Cannot start Instance, failed to evict instance "+err.Error(), http.StatusInternalServerError) - return - } - } else { - http.Error(w, "Cannot start Instance, maximum number of instances reached", http.StatusConflict) - return - } - } - - // If on-demand start is enabled, start the instance - if _, err := h.InstanceManager.StartInstance(validatedName); err != nil { - http.Error(w, "Failed to start instance: "+err.Error(), http.StatusInternalServerError) - return - } - - // Wait for the instance to become healthy before proceeding - if err := 
inst.WaitForHealthy(h.cfg.Instances.OnDemandStartTimeout); err != nil { // 2 minutes timeout - http.Error(w, "Instance failed to become healthy: "+err.Error(), http.StatusServiceUnavailable) + err := h.ensureInstanceRunning(inst) + if err != nil { + writeError(w, http.StatusInternalServerError, "instance_start_failed", err.Error()) return } } proxy, err := inst.GetProxy() if err != nil { - http.Error(w, "Failed to get proxy: "+err.Error(), http.StatusInternalServerError) + writeError(w, http.StatusInternalServerError, "proxy_failed", err.Error()) return } diff --git a/pkg/server/handlers_system.go b/pkg/server/handlers_system.go index e3bb016..2e61288 100644 --- a/pkg/server/handlers_system.go +++ b/pkg/server/handlers_system.go @@ -16,7 +16,7 @@ import ( // @Router /version [get] func (h *Handler) VersionHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/plain") - fmt.Fprintf(w, "Version: %s\nCommit: %s\nBuild Time: %s\n", h.cfg.Version, h.cfg.CommitHash, h.cfg.BuildTime) + versionInfo := fmt.Sprintf("Version: %s\nCommit: %s\nBuild Time: %s\n", h.cfg.Version, h.cfg.CommitHash, h.cfg.BuildTime) + writeText(w, http.StatusOK, versionInfo) } } diff --git a/pkg/server/openai.go b/pkg/server/openai.go deleted file mode 100644 index 98d1043..0000000 --- a/pkg/server/openai.go +++ /dev/null @@ -1,13 +0,0 @@ -package server - -type OpenAIListInstancesResponse struct { - Object string `json:"object"` - Data []OpenAIInstance `json:"data"` -} - -type OpenAIInstance struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - OwnedBy string `json:"owned_by"` -} diff --git a/pkg/server/routes.go b/pkg/server/routes.go index 6ced6a7..ffe89ec 100644 --- a/pkg/server/routes.go +++ b/pkg/server/routes.go @@ -86,7 +86,7 @@ func SetupRouter(handler *Handler) *chi.Mux { // Llama.cpp server proxy endpoints (proxied to the actual llama.cpp server) r.Route("/proxy", func(r 
chi.Router) { - r.HandleFunc("/*", handler.ProxyToInstance()) // Proxy all llama.cpp server requests + r.HandleFunc("/*", handler.InstanceProxy()) // Proxy all llama.cpp server requests }) }) }) @@ -117,7 +117,7 @@ func SetupRouter(handler *Handler) *chi.Mux { // Public Routes // Allow llama-cpp server to serve its own WebUI if it is running. // Don't auto start the server since it can be accessed without an API key - r.Get("/", handler.LlamaCppProxy(false)) + r.Get("/", handler.LlamaCppUIProxy()) // Private Routes r.Group(func(r chi.Router) { @@ -127,7 +127,7 @@ func SetupRouter(handler *Handler) *chi.Mux { } // This handler auto start the server if it's not running - llamaCppHandler := handler.LlamaCppProxy(true) + llamaCppHandler := handler.LlamaCppProxy() // llama.cpp server specific proxy endpoints r.Get("/props", llamaCppHandler)