diff --git a/pkg/backends/backend.go b/pkg/backends/backend.go
index ad138a5..2bb0fa8 100644
--- a/pkg/backends/backend.go
+++ b/pkg/backends/backend.go
@@ -14,6 +14,7 @@ const (
 	BackendTypeLlamaCpp BackendType = "llama_cpp"
 	BackendTypeMlxLm    BackendType = "mlx_lm"
 	BackendTypeVllm     BackendType = "vllm"
+	BackendTypeUnknown  BackendType = "unknown"
 )
 
 type backend interface {
diff --git a/pkg/instance/instance.go b/pkg/instance/instance.go
index a8faa1b..6149a4a 100644
--- a/pkg/instance/instance.go
+++ b/pkg/instance/instance.go
@@ -7,6 +7,7 @@ import (
 	"net/http"
 	"time"
 
+	"llamactl/pkg/backends"
 	"llamactl/pkg/config"
 )
 
@@ -117,6 +118,14 @@ func (i *Instance) WaitForHealthy(timeout int) error {
 	return i.process.waitForHealthy(timeout)
 }
 
+func (i *Instance) GetBackendType() backends.BackendType {
+	opts := i.GetOptions()
+	if opts == nil {
+		return backends.BackendTypeUnknown
+	}
+	return opts.BackendOptions.BackendType
+}
+
 // GetOptions returns the current options
 func (i *Instance) GetOptions() *Options {
 	if i.options == nil {
diff --git a/pkg/instance/models.go b/pkg/instance/models.go
deleted file mode 100644
index f911f18..0000000
--- a/pkg/instance/models.go
+++ /dev/null
@@ -1,141 +0,0 @@
-package instance
-
-import (
-	"bytes"
-	"context"
-	"encoding/json"
-	"fmt"
-	"io"
-	"llamactl/pkg/backends"
-	"net/http"
-	"time"
-)
-
-// Model represents a model available in a llama.cpp instance
-type Model struct {
-	ID      string      `json:"id"`
-	Object  string      `json:"object"`
-	OwnedBy string      `json:"owned_by"`
-	Created int64       `json:"created"`
-	InCache bool        `json:"in_cache"`
-	Path    string      `json:"path"`
-	Status  ModelStatus `json:"status"`
-}
-
-// ModelStatus represents the status of a model in an instance
-type ModelStatus struct {
-	Value string   `json:"value"` // "loaded" | "loading" | "unloaded"
-	Args  []string `json:"args"`
-}
-
-// IsLlamaCpp checks if this instance is a llama.cpp instance
-func (i *Instance) IsLlamaCpp() bool {
-	opts := i.GetOptions()
-	if opts == nil {
-		return false
-	}
-	return opts.BackendOptions.BackendType == backends.BackendTypeLlamaCpp
-}
-
-// GetModels fetches the models available in this llama.cpp instance
-func (i *Instance) GetModels() ([]Model, error) {
-	if !i.IsLlamaCpp() {
-		return nil, fmt.Errorf("instance %s is not a llama.cpp instance", i.Name)
-	}
-
-	if !i.IsRunning() {
-		return nil, fmt.Errorf("instance %s is not running", i.Name)
-	}
-
-	var result struct {
-		Data []Model `json:"data"`
-	}
-	if err := i.doRequest("GET", "/models", nil, &result, 10*time.Second); err != nil {
-		return nil, fmt.Errorf("failed to fetch models: %w", err)
-	}
-
-	return result.Data, nil
-}
-
-// LoadModel loads a model in this llama.cpp instance
-func (i *Instance) LoadModel(modelName string) error {
-	if !i.IsLlamaCpp() {
-		return fmt.Errorf("instance %s is not a llama.cpp instance", i.Name)
-	}
-
-	if !i.IsRunning() {
-		return fmt.Errorf("instance %s is not running", i.Name)
-	}
-
-	// Make the load request
-	reqBody := map[string]string{"model": modelName}
-	if err := i.doRequest("POST", "/models/load", reqBody, nil, 30*time.Second); err != nil {
-		return fmt.Errorf("failed to load model: %w", err)
-	}
-
-	return nil
-}
-
-// UnloadModel unloads a model from this llama.cpp instance
-func (i *Instance) UnloadModel(modelName string) error {
-	if !i.IsLlamaCpp() {
-		return fmt.Errorf("instance %s is not a llama.cpp instance", i.Name)
-	}
-
-	if !i.IsRunning() {
-		return fmt.Errorf("instance %s is not running", i.Name)
-	}
-
-	// Make the unload request
-	reqBody := map[string]string{"model": modelName}
-	if err := i.doRequest("POST", "/models/unload", reqBody, nil, 30*time.Second); err != nil {
-		return fmt.Errorf("failed to unload model: %w", err)
-	}
-
-	return nil
-}
-
-// doRequest makes an HTTP request to this instance's backend
-func (i *Instance) doRequest(method, path string, reqBody, respBody any, timeout time.Duration) error {
-	url := fmt.Sprintf("http://%s:%d%s", i.GetHost(), i.GetPort(), path)
-
-	var bodyReader io.Reader
-	if reqBody != nil {
-		bodyBytes, err := json.Marshal(reqBody)
-		if err != nil {
-			return fmt.Errorf("failed to marshal request body: %w", err)
-		}
-		bodyReader = bytes.NewReader(bodyBytes)
-	}
-
-	ctx, cancel := context.WithTimeout(context.Background(), timeout)
-	defer cancel()
-
-	req, err := http.NewRequestWithContext(ctx, method, url, bodyReader)
-	if err != nil {
-		return fmt.Errorf("failed to create request: %w", err)
-	}
-
-	if reqBody != nil {
-		req.Header.Set("Content-Type", "application/json")
-	}
-
-	resp, err := http.DefaultClient.Do(req)
-	if err != nil {
-		return err
-	}
-	defer resp.Body.Close()
-
-	if resp.StatusCode != http.StatusOK {
-		bodyBytes, _ := io.ReadAll(resp.Body)
-		return fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes))
-	}
-
-	if respBody != nil {
-		if err := json.NewDecoder(resp.Body).Decode(respBody); err != nil {
-			return fmt.Errorf("failed to decode response: %w", err)
-		}
-	}
-
-	return nil
-}
diff --git a/pkg/server/handlers_backends.go b/pkg/server/handlers_backends.go
index 2cc9304..8fc9ff6 100644
--- a/pkg/server/handlers_backends.go
+++ b/pkg/server/handlers_backends.go
@@ -8,8 +8,6 @@ import (
 	"net/http"
 	"os/exec"
 	"strings"
-
-	"github.com/go-chi/chi/v5"
 )
 
 // ParseCommandRequest represents the request body for backend command parsing
@@ -322,24 +320,41 @@ func (h *Handler) LlamaServerListDevicesHandler() http.HandlerFunc {
 // @Router /api/v1/llama-cpp/{name}/models [get]
 func (h *Handler) LlamaCppListModels() http.HandlerFunc {
 	return func(w http.ResponseWriter, r *http.Request) {
-		inst, err := h.getInstance(r)
+		inst, err := h.validateLlamaCppInstance(r)
 		if err != nil {
-			writeError(w, http.StatusBadRequest, "invalid_instance", err.Error())
+			writeError(w, http.StatusBadRequest, "invalid instance", err.Error())
 			return
 		}
 
-		models, err := inst.GetModels()
-		if err != nil {
-			writeError(w, http.StatusBadRequest, "get_models_failed", err.Error())
+		// Check instance permissions
+		if err := h.authMiddleware.CheckInstancePermission(r.Context(), inst.ID); err != nil {
+			writeError(w, http.StatusForbidden, "permission_denied", err.Error())
 			return
 		}
 
-		response := map[string]any{
-			"object": "list",
-			"data":   models,
+		// Check if instance is shutting down before autostart logic
+		if inst.GetStatus() == instance.ShuttingDown {
+			writeError(w, http.StatusServiceUnavailable, "instance_shutting_down", "Instance is shutting down")
+			return
 		}
 
-		writeJSON(w, http.StatusOK, response)
+		if !inst.IsRemote() && !inst.IsRunning() {
+			err := h.ensureInstanceRunning(inst)
+			if err != nil {
+				writeError(w, http.StatusInternalServerError, "instance start failed", err.Error())
+				return
+			}
+		}
+
+		// Modify request path to /models for proxying
+		r.URL.Path = "/models"
+
+		// Use instance's ServeHTTP which tracks inflight requests and handles shutting down state
+		err = inst.ServeHTTP(w, r)
+		if err != nil {
+			// Error is already handled in ServeHTTP (response written)
+			return
+		}
 	}
 }
 
@@ -357,23 +372,41 @@ func (h *Handler) LlamaCppListModels() http.HandlerFunc {
 // @Router /api/v1/llama-cpp/{name}/models/{model}/load [post]
 func (h *Handler) LlamaCppLoadModel() http.HandlerFunc {
 	return func(w http.ResponseWriter, r *http.Request) {
-		inst, err := h.getInstance(r)
+		inst, err := h.validateLlamaCppInstance(r)
 		if err != nil {
-			writeError(w, http.StatusBadRequest, "invalid_instance", err.Error())
+			writeError(w, http.StatusBadRequest, "invalid instance", err.Error())
 			return
 		}
 
-		modelName := chi.URLParam(r, "model")
-
-		if err := inst.LoadModel(modelName); err != nil {
-			writeError(w, http.StatusBadRequest, "load_model_failed", err.Error())
+		// Check instance permissions
+		if err := h.authMiddleware.CheckInstancePermission(r.Context(), inst.ID); err != nil {
+			writeError(w, http.StatusForbidden, "permission_denied", err.Error())
 			return
 		}
 
-		writeJSON(w, http.StatusOK, map[string]string{
-			"status":  "success",
-			"message": fmt.Sprintf("Model %s loaded successfully", modelName),
-		})
+		// Check if instance is shutting down before autostart logic
+		if inst.GetStatus() == instance.ShuttingDown {
+			writeError(w, http.StatusServiceUnavailable, "instance_shutting_down", "Instance is shutting down")
+			return
+		}
+
+		if !inst.IsRemote() && !inst.IsRunning() {
+			err := h.ensureInstanceRunning(inst)
+			if err != nil {
+				writeError(w, http.StatusInternalServerError, "instance start failed", err.Error())
+				return
+			}
+		}
+
+		// Modify request path to /models/load for proxying
+		r.URL.Path = "/models/load"
+
+		// Use instance's ServeHTTP which tracks inflight requests and handles shutting down state
+		err = inst.ServeHTTP(w, r)
+		if err != nil {
+			// Error is already handled in ServeHTTP (response written)
+			return
+		}
 	}
 }
 
@@ -391,22 +424,40 @@ func (h *Handler) LlamaCppLoadModel() http.HandlerFunc {
 // @Router /api/v1/llama-cpp/{name}/models/{model}/unload [post]
 func (h *Handler) LlamaCppUnloadModel() http.HandlerFunc {
 	return func(w http.ResponseWriter, r *http.Request) {
-		inst, err := h.getInstance(r)
+		inst, err := h.validateLlamaCppInstance(r)
 		if err != nil {
-			writeError(w, http.StatusBadRequest, "invalid_instance", err.Error())
+			writeError(w, http.StatusBadRequest, "invalid instance", err.Error())
 			return
 		}
 
-		modelName := chi.URLParam(r, "model")
-
-		if err := inst.UnloadModel(modelName); err != nil {
-			writeError(w, http.StatusBadRequest, "unload_model_failed", err.Error())
+		// Check instance permissions
+		if err := h.authMiddleware.CheckInstancePermission(r.Context(), inst.ID); err != nil {
+			writeError(w, http.StatusForbidden, "permission_denied", err.Error())
 			return
 		}
 
-		writeJSON(w, http.StatusOK, map[string]string{
-			"status":  "success",
-			"message": fmt.Sprintf("Model %s unloaded successfully", modelName),
-		})
+		// Check if instance is shutting down before autostart logic
+		if inst.GetStatus() == instance.ShuttingDown {
+			writeError(w, http.StatusServiceUnavailable, "instance_shutting_down", "Instance is shutting down")
+			return
+		}
+
+		if !inst.IsRemote() && !inst.IsRunning() {
+			err := h.ensureInstanceRunning(inst)
+			if err != nil {
+				writeError(w, http.StatusInternalServerError, "instance start failed", err.Error())
+				return
+			}
+		}
+
+		// Modify request path to /models/unload for proxying
+		r.URL.Path = "/models/unload"
+
+		// Use instance's ServeHTTP which tracks inflight requests and handles shutting down state
+		err = inst.ServeHTTP(w, r)
+		if err != nil {
+			// Error is already handled in ServeHTTP (response written)
+			return
+		}
 	}
 }
diff --git a/pkg/server/handlers_openai.go b/pkg/server/handlers_openai.go
index 2240020..06f06b7 100644
--- a/pkg/server/handlers_openai.go
+++ b/pkg/server/handlers_openai.go
@@ -5,9 +5,11 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"llamactl/pkg/backends"
 	"llamactl/pkg/instance"
 	"llamactl/pkg/validation"
 	"net/http"
+	"strings"
 )
 
 // OpenAIListInstancesResponse represents the response structure for listing instances (models) in OpenAI-compatible format
@@ -24,6 +26,53 @@ type OpenAIInstance struct {
 	OwnedBy string `json:"owned_by"`
 }
 
+// LlamaCppModel represents a model available in a llama.cpp instance
+type LlamaCppModel struct {
+	ID      string              `json:"id"`
+	Object  string              `json:"object"`
+	OwnedBy string              `json:"owned_by"`
+	Created int64               `json:"created"`
+	InCache bool                `json:"in_cache"`
+	Path    string              `json:"path"`
+	Status  LlamaCppModelStatus `json:"status"`
+}
+
+// LlamaCppModelStatus represents the status of a model in a llama.cpp instance
+type LlamaCppModelStatus struct {
+	Value string   `json:"value"` // "loaded" | "loading" | "unloaded"
+	Args  []string `json:"args"`
+}
+
+// fetchLlamaCppModels fetches models from a llama.cpp instance using the proxy
+func fetchLlamaCppModels(inst *instance.Instance) ([]LlamaCppModel, error) {
+	// Create a request to the instance's /models endpoint
+	req, err := http.NewRequest("GET", fmt.Sprintf("http://%s:%d/models", inst.GetHost(), inst.GetPort()), nil)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create request: %w", err)
+	}
+
+	// Use a custom response writer to capture the response
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		bodyBytes, _ := io.ReadAll(resp.Body)
+		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes))
+	}
+
+	var result struct {
+		Data []LlamaCppModel `json:"data"`
+	}
+	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+		return nil, fmt.Errorf("failed to decode response: %w", err)
+	}
+
+	return result.Data, nil
+}
+
 // OpenAIListInstances godoc
 // @Summary List instances in OpenAI-compatible format
 // @Description Returns a list of instances in a format compatible with OpenAI API
@@ -46,9 +95,9 @@ func (h *Handler) OpenAIListInstances() http.HandlerFunc {
 
 		// For each llama.cpp instance, try to fetch models and add them as separate entries
 		for _, inst := range instances {
-			if inst.IsLlamaCpp() && inst.IsRunning() {
+			if inst.GetBackendType() == backends.BackendTypeLlamaCpp && inst.IsRunning() {
 				// Try to fetch models from the instance
-				models, err := inst.GetModels()
+				models, err := fetchLlamaCppModels(inst)
 				if err != nil {
 					fmt.Printf("Failed to fetch models from instance %s: %v", inst.Name, err)
 					continue
@@ -56,9 +105,9 @@ func (h *Handler) OpenAIListInstances() http.HandlerFunc {
 
 				for _, model := range models {
 					openaiInstances = append(openaiInstances, OpenAIInstance{
-						ID:      model.ID,
+						ID:      inst.Name + "/" + model.ID,
 						Object:  "model",
-						Created: model.Created,
+						Created: inst.Created,
 						OwnedBy: inst.Name,
 					})
 				}
@@ -115,17 +164,24 @@ func (h *Handler) OpenAIProxy() http.HandlerFunc {
 			return
 		}
 
-		modelName, ok := requestBody["model"].(string)
-		if !ok || modelName == "" {
+		reqModelName, ok := requestBody["model"].(string)
+		if !ok || reqModelName == "" {
 			writeError(w, http.StatusBadRequest, "invalid_request", "Model name is required")
 			return
 		}
 
-		// Resolve model name to instance name (checks instance names first, then model registry)
-		instanceName, err := h.InstanceManager.ResolveInstance(modelName)
-		if err != nil {
-			writeError(w, http.StatusBadRequest, "model_not_found", err.Error())
-			return
+		// Parse instance name and model name from <instance>/<model> format
+		var instanceName string
+		var modelName string
+
+		// Check if model name contains "/"
+		if idx := strings.Index(reqModelName, "/"); idx != -1 {
+			// Split into instance and model parts
+			instanceName = reqModelName[:idx]
+			modelName = reqModelName[idx+1:]
+		} else {
+			instanceName = reqModelName
+			modelName = reqModelName
 		}
 
 		// Validate instance name at the entry point
@@ -154,6 +210,11 @@ func (h *Handler) OpenAIProxy() http.HandlerFunc {
 			return
 		}
 
+		if inst.IsRemote() {
+			// Don't replace model name for remote instances
+			modelName = reqModelName
+		}
+
 		if !inst.IsRemote() && !inst.IsRunning() {
 			err := h.ensureInstanceRunning(inst)
 			if err != nil {
@@ -162,6 +223,16 @@ func (h *Handler) OpenAIProxy() http.HandlerFunc {
 				return
 			}
 		}
+		// Update the request body with just the model name
+		requestBody["model"] = modelName
+
+		// Re-marshal the updated body
+		bodyBytes, err = json.Marshal(requestBody)
+		if err != nil {
+			writeError(w, http.StatusInternalServerError, "marshal_error", "Failed to update request body")
+			return
+		}
+
 		// Recreate the request body from the bytes we read
 		r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
 		r.ContentLength = int64(len(bodyBytes))
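
For reference, a minimal standalone sketch (not part of the patch) of the <instance>/<model> resolution rule that OpenAIProxy now applies to the request's "model" field. The helper name splitModelRef and the names in main are hypothetical; the handler inlines this logic rather than calling a helper.

package main

import (
	"fmt"
	"strings"
)

// splitModelRef mirrors the parsing added in OpenAIProxy: "<instance>/<model>"
// routes the request to <instance> and forwards <model> upstream; a bare name
// is used as both the instance name and the model name.
func splitModelRef(ref string) (instanceName, modelName string) {
	if idx := strings.Index(ref, "/"); idx != -1 {
		return ref[:idx], ref[idx+1:]
	}
	return ref, ref
}

func main() {
	// Illustrative names only.
	fmt.Println(splitModelRef("llama-main/qwen2.5-7b-instruct")) // llama-main qwen2.5-7b-instruct
	fmt.Println(splitModelRef("llama-main"))                     // llama-main llama-main
}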