diff --git a/README.md b/README.md index 7b917e6..0f27290 100644 --- a/README.md +++ b/README.md @@ -197,6 +197,7 @@ server: host: "0.0.0.0" # Server host to bind to port: 8080 # Server port to bind to allowed_origins: ["*"] # Allowed CORS origins (default: all) + allowed_headers: ["*"] # Allowed CORS headers (default: all) enable_swagger: false # Enable Swagger UI for API docs backends: diff --git a/docs/getting-started/configuration.md b/docs/getting-started/configuration.md index 1ed750e..be4fc6d 100644 --- a/docs/getting-started/configuration.md +++ b/docs/getting-started/configuration.md @@ -17,6 +17,7 @@ server: host: "0.0.0.0" # Server host to bind to port: 8080 # Server port to bind to allowed_origins: ["*"] # Allowed CORS origins (default: all) + allowed_headers: ["*"] # Allowed CORS headers (default: all) enable_swagger: false # Enable Swagger UI for API docs backends: @@ -104,6 +105,7 @@ server: host: "0.0.0.0" # Server host to bind to (default: "0.0.0.0") port: 8080 # Server port to bind to (default: 8080) allowed_origins: ["*"] # CORS allowed origins (default: ["*"]) + allowed_headers: ["*"] # CORS allowed headers (default: ["*"]) enable_swagger: false # Enable Swagger UI (default: false) ``` diff --git a/llamactl.yaml b/llamactl.yaml new file mode 100644 index 0000000..1c616eb --- /dev/null +++ b/llamactl.yaml @@ -0,0 +1,5 @@ +auth: + management_keys: + - test-mgmt + inference_keys: + - test-inf diff --git a/pkg/config/config.go b/pkg/config/config.go index 1d86f4c..77637b8 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -58,6 +58,9 @@ type ServerConfig struct { // Allowed origins for CORS (e.g., "http://localhost:3000") AllowedOrigins []string `yaml:"allowed_origins"` + // Allowed headers for CORS (e.g., "Accept", "Authorization", "Content-Type", "X-CSRF-Token") + AllowedHeaders []string `yaml:"allowed_headers"` + // Enable Swagger UI for API documentation EnableSwagger bool `yaml:"enable_swagger"` @@ -143,6 +146,7 @@ func LoadConfig(configPath string) (AppConfig, error) { Host: "0.0.0.0", Port: 8080, AllowedOrigins: []string{"*"}, // Default to allow all origins + AllowedHeaders: []string{"*"}, // Default to allow all headers EnableSwagger: false, }, Backends: BackendConfig{ diff --git a/pkg/manager/manager.go b/pkg/manager/manager.go index 354fbd2..a5a7138 100644 --- a/pkg/manager/manager.go +++ b/pkg/manager/manager.go @@ -314,19 +314,32 @@ func (im *instanceManager) loadInstance(name, path string) error { } // autoStartInstances starts instances that were running when persisted and have auto-restart enabled +// For instances with auto-restart disabled, it sets their status to Stopped func (im *instanceManager) autoStartInstances() { im.mu.RLock() var instancesToStart []*instance.Process + var instancesToStop []*instance.Process for _, inst := range im.instances { if inst.IsRunning() && // Was running when persisted inst.GetOptions() != nil && - inst.GetOptions().AutoRestart != nil && - *inst.GetOptions().AutoRestart { - instancesToStart = append(instancesToStart, inst) + inst.GetOptions().AutoRestart != nil { + if *inst.GetOptions().AutoRestart { + instancesToStart = append(instancesToStart, inst) + } else { + // Instance was running but auto-restart is disabled, mark as stopped + instancesToStop = append(instancesToStop, inst) + } } } im.mu.RUnlock() + // Stop instances that have auto-restart disabled + for _, inst := range instancesToStop { + log.Printf("Instance %s was running but auto-restart is disabled, setting status to stopped", inst.Name) + inst.SetStatus(instance.Stopped) + } + + // Start instances that have auto-restart enabled for _, inst := range instancesToStart { log.Printf("Auto-starting instance %s", inst.Name) // Reset running state before starting (since Start() expects stopped instance) diff --git a/pkg/manager/manager_test.go b/pkg/manager/manager_test.go index ed9bde5..e9fa1bc 100644 --- a/pkg/manager/manager_test.go +++ b/pkg/manager/manager_test.go @@ -209,3 +209,66 @@ func createTestManager() manager.InstanceManager { } return manager.NewInstanceManager(backendConfig, cfg, nil) } + +func TestAutoRestartDisabledInstanceStatus(t *testing.T) { + tempDir := t.TempDir() + + backendConfig := config.BackendConfig{ + LlamaCpp: config.BackendSettings{ + Command: "llama-server", + }, + } + + cfg := config.InstancesConfig{ + PortRange: [2]int{8000, 9000}, + InstancesDir: tempDir, + MaxInstances: 10, + TimeoutCheckInterval: 5, + } + + // Create first manager and instance with auto-restart disabled + manager1 := manager.NewInstanceManager(backendConfig, cfg) + + autoRestart := false + options := &instance.CreateInstanceOptions{ + BackendType: backends.BackendTypeLlamaCpp, + AutoRestart: &autoRestart, + LlamaServerOptions: &llamacpp.LlamaServerOptions{ + Model: "/path/to/model.gguf", + Port: 8080, + }, + } + + inst, err := manager1.CreateInstance("test-instance", options) + if err != nil { + t.Fatalf("CreateInstance failed: %v", err) + } + + // Simulate instance being in running state when persisted + // (this would happen if the instance was running when llamactl was stopped) + inst.SetStatus(instance.Running) + + // Shutdown first manager + manager1.Shutdown() + + // Create second manager (simulating restart of llamactl) + manager2 := manager.NewInstanceManager(backendConfig, cfg) + + // Get the loaded instance + loadedInst, err := manager2.GetInstance("test-instance") + if err != nil { + t.Fatalf("GetInstance failed: %v", err) + } + + // The instance should be marked as Stopped, not Running + // because auto-restart is disabled + if loadedInst.IsRunning() { + t.Errorf("Expected instance with auto-restart disabled to be stopped after manager restart, but it was running") + } + + if loadedInst.GetStatus() != instance.Stopped { + t.Errorf("Expected instance status to be Stopped, got %v", loadedInst.GetStatus()) + } + + manager2.Shutdown() +} diff --git a/pkg/server/handlers_backends.go b/pkg/server/handlers_backends.go index 5f55cd4..7d6cab0 100644 --- a/pkg/server/handlers_backends.go +++ b/pkg/server/handlers_backends.go @@ -2,6 +2,7 @@ package server import ( "encoding/json" + "fmt" "llamactl/pkg/backends" "llamactl/pkg/backends/llamacpp" "llamactl/pkg/backends/mlx" @@ -10,6 +11,8 @@ import ( "net/http" "os/exec" "strings" + + "github.com/go-chi/chi/v5" ) // ParseCommandRequest represents the request body for command parsing @@ -17,6 +20,84 @@ type ParseCommandRequest struct { Command string `json:"command"` } +func (h *Handler) LlamaCppProxy(onDemandStart bool) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + + // Get the instance name from the URL parameter + name := chi.URLParam(r, "name") + if name == "" { + http.Error(w, "Instance name cannot be empty", http.StatusBadRequest) + return + } + + // Route to the appropriate inst based on instance name + inst, err := h.InstanceManager.GetInstance(name) + if err != nil { + http.Error(w, "Invalid instance: "+err.Error(), http.StatusBadRequest) + return + } + + options := inst.GetOptions() + if options == nil { + http.Error(w, "Cannot obtain Instance's options", http.StatusInternalServerError) + return + } + + if options.BackendType != backends.BackendTypeLlamaCpp { + http.Error(w, "Instance is not a llama.cpp server.", http.StatusBadRequest) + return + } + + if !inst.IsRunning() { + + if !(onDemandStart && options.OnDemandStart != nil && *options.OnDemandStart) { + http.Error(w, "Instance is not running", http.StatusServiceUnavailable) + return + } + + if h.InstanceManager.IsMaxRunningInstancesReached() { + if h.cfg.Instances.EnableLRUEviction { + err := h.InstanceManager.EvictLRUInstance() + if err != nil { + http.Error(w, "Cannot start Instance, failed to evict instance "+err.Error(), http.StatusInternalServerError) + return + } + } else { + http.Error(w, "Cannot start Instance, maximum number of instances reached", http.StatusConflict) + return + } + } + + // If on-demand start is enabled, start the instance + if _, err := h.InstanceManager.StartInstance(name); err != nil { + http.Error(w, "Failed to start instance: "+err.Error(), http.StatusInternalServerError) + return + } + + // Wait for the instance to become healthy before proceeding + if err := inst.WaitForHealthy(h.cfg.Instances.OnDemandStartTimeout); err != nil { // 2 minutes timeout + http.Error(w, "Instance failed to become healthy: "+err.Error(), http.StatusServiceUnavailable) + return + } + } + + proxy, err := inst.GetProxy() + if err != nil { + http.Error(w, "Failed to get proxy: "+err.Error(), http.StatusInternalServerError) + return + } + + // Strip the "/llama-cpp/" prefix from the request URL + prefix := fmt.Sprintf("/llama-cpp/%s", name) + r.URL.Path = strings.TrimPrefix(r.URL.Path, prefix) + + // Update the last request time for the instance + inst.UpdateLastRequestTime() + + proxy.ServeHTTP(w, r) + } +} + // ParseLlamaCommand godoc // @Summary Parse llama-server command // @Description Parses a llama-server command string into instance options diff --git a/pkg/server/handlers_instances.go b/pkg/server/handlers_instances.go index e8b108c..6a90cba 100644 --- a/pkg/server/handlers_instances.go +++ b/pkg/server/handlers_instances.go @@ -102,7 +102,7 @@ func (h *Handler) GetInstance() http.HandlerFunc { inst, err := h.InstanceManager.GetInstance(name) if err != nil { - http.Error(w, "Failed to get instance: "+err.Error(), http.StatusInternalServerError) + http.Error(w, "Invalid instance: "+err.Error(), http.StatusBadRequest) return } @@ -361,12 +361,6 @@ func (h *Handler) ProxyToInstance() http.HandlerFunc { return } - // Check if this is a remote instance - if inst.IsRemote() { - h.RemoteInstanceProxy(w, r, name, inst) - return - } - if !inst.IsRunning() { http.Error(w, "Instance is not running", http.StatusServiceUnavailable) return @@ -381,29 +375,15 @@ func (h *Handler) ProxyToInstance() http.HandlerFunc { // Strip the "/api/v1/instances//proxy" prefix from the request URL prefix := fmt.Sprintf("/api/v1/instances/%s/proxy", name) - proxyPath := r.URL.Path[len(prefix):] - - // Ensure the proxy path starts with "/" - if !strings.HasPrefix(proxyPath, "/") { - proxyPath = "/" + proxyPath - } + r.URL.Path = strings.TrimPrefix(r.URL.Path, prefix) // Update the last request time for the instance inst.UpdateLastRequestTime() - // Modify the request to remove the proxy prefix - originalPath := r.URL.Path - r.URL.Path = proxyPath - // Set forwarded headers r.Header.Set("X-Forwarded-Host", r.Header.Get("Host")) r.Header.Set("X-Forwarded-Proto", "http") - // Restore original path for logging purposes - defer func() { - r.URL.Path = originalPath - }() - // Forward the request using the cached proxy proxy.ServeHTTP(w, r) } diff --git a/pkg/server/handlers_openai.go b/pkg/server/handlers_openai.go index ef4c29d..07196f0 100644 --- a/pkg/server/handlers_openai.go +++ b/pkg/server/handlers_openai.go @@ -87,7 +87,7 @@ func (h *Handler) OpenAIProxy() http.HandlerFunc { // Route to the appropriate inst based on instance name inst, err := h.InstanceManager.GetInstance(modelName) if err != nil { - http.Error(w, "Failed to get instance: "+err.Error(), http.StatusInternalServerError) + http.Error(w, "Invalid instance: "+err.Error(), http.StatusBadRequest) return } @@ -98,7 +98,8 @@ func (h *Handler) OpenAIProxy() http.HandlerFunc { } if !inst.IsRunning() { - allowOnDemand := inst.GetOptions() != nil && inst.GetOptions().OnDemandStart != nil && *inst.GetOptions().OnDemandStart + options := inst.GetOptions() + allowOnDemand := options != nil && options.OnDemandStart != nil && *options.OnDemandStart if !allowOnDemand { http.Error(w, "Instance is not running", http.StatusServiceUnavailable) return diff --git a/pkg/server/routes.go b/pkg/server/routes.go index d14baec..6ced6a7 100644 --- a/pkg/server/routes.go +++ b/pkg/server/routes.go @@ -20,7 +20,7 @@ func SetupRouter(handler *Handler) *chi.Mux { r.Use(cors.Handler(cors.Options{ AllowedOrigins: handler.cfg.Server.AllowedOrigins, AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, - AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, + AllowedHeaders: handler.cfg.Server.AllowedHeaders, ExposedHeaders: []string{"Link"}, AllowCredentials: false, MaxAge: 300, @@ -112,6 +112,51 @@ func SetupRouter(handler *Handler) *chi.Mux { }) + r.Route("/llama-cpp/{name}", func(r chi.Router) { + + // Public Routes + // Allow llama-cpp server to serve its own WebUI if it is running. + // Don't auto start the server since it can be accessed without an API key + r.Get("/", handler.LlamaCppProxy(false)) + + // Private Routes + r.Group(func(r chi.Router) { + + if authMiddleware != nil && handler.cfg.Auth.RequireInferenceAuth { + r.Use(authMiddleware.AuthMiddleware(KeyTypeInference)) + } + + // This handler auto start the server if it's not running + llamaCppHandler := handler.LlamaCppProxy(true) + + // llama.cpp server specific proxy endpoints + r.Get("/props", llamaCppHandler) + // /slots endpoint is secured (see: https://github.com/ggml-org/llama.cpp/pull/15630) + r.Get("/slots", llamaCppHandler) + r.Post("/apply-template", llamaCppHandler) + r.Post("/completion", llamaCppHandler) + r.Post("/detokenize", llamaCppHandler) + r.Post("/embeddings", llamaCppHandler) + r.Post("/infill", llamaCppHandler) + r.Post("/metrics", llamaCppHandler) + r.Post("/props", llamaCppHandler) + r.Post("/reranking", llamaCppHandler) + r.Post("/tokenize", llamaCppHandler) + + // OpenAI-compatible proxy endpoint + // Handles all POST requests to /v1/*, including: + // - /v1/completions + // - /v1/chat/completions + // - /v1/embeddings + // - /v1/rerank + // - /v1/reranking + // llamaCppHandler is used here because some users of llama.cpp endpoints depend + // on "model" field being optional, and handler.OpenAIProxy requires it. + r.Post("/v1/*", llamaCppHandler) + }) + + }) + // Serve WebUI files if err := webui.SetupWebUI(r); err != nil { fmt.Printf("Failed to set up WebUI: %v\n", err) diff --git a/webui/src/lib/__tests__/api.test.ts b/webui/src/lib/__tests__/api.test.ts index 87e8ac7..2eda209 100644 --- a/webui/src/lib/__tests__/api.test.ts +++ b/webui/src/lib/__tests__/api.test.ts @@ -11,11 +11,13 @@ describe('API Error Handling', () => { }) it('converts HTTP errors to meaningful messages', async () => { - mockFetch.mockResolvedValue({ + const mockResponse = { ok: false, status: 409, - text: () => Promise.resolve('Instance already exists') - }) + text: () => Promise.resolve('Instance already exists'), + clone: function() { return this } + } + mockFetch.mockResolvedValue(mockResponse) await expect(instancesApi.create('existing', {})) .rejects @@ -23,11 +25,13 @@ describe('API Error Handling', () => { }) it('handles empty error responses gracefully', async () => { - mockFetch.mockResolvedValue({ + const mockResponse = { ok: false, status: 500, - text: () => Promise.resolve('') - }) + text: () => Promise.resolve(''), + clone: function() { return this } + } + mockFetch.mockResolvedValue(mockResponse) await expect(instancesApi.list()) .rejects diff --git a/webui/src/lib/api.ts b/webui/src/lib/api.ts index 7e15bc9..ea17e9a 100644 --- a/webui/src/lib/api.ts +++ b/webui/src/lib/api.ts @@ -49,11 +49,8 @@ async function apiCall( } else { // Handle empty responses for JSON endpoints const contentLength = response.headers.get('content-length'); - if (contentLength === '0' || contentLength === null) { - const text = await response.text(); - if (text.trim() === '') { - return {} as T; // Return empty object for empty JSON responses - } + if (contentLength === '0') { + return {} as T; // Return empty object for empty JSON responses } const data = await response.json() as T; return data; diff --git a/webui/src/lib/errorUtils.ts b/webui/src/lib/errorUtils.ts index 1860bf9..85cdf03 100644 --- a/webui/src/lib/errorUtils.ts +++ b/webui/src/lib/errorUtils.ts @@ -26,7 +26,8 @@ export async function handleApiError(response: Response): Promise { } if (!response.ok) { - const errorMessage = await parseErrorResponse(response) + // Clone the response before reading to avoid consuming the body stream + const errorMessage = await parseErrorResponse(response.clone()) throw new Error(errorMessage) } } \ No newline at end of file