Support llama.cpp router mode for openai endpoints

This commit is contained in:
2025-12-21 23:32:33 +01:00
parent faf026aa54
commit 38790aa507
5 changed files with 174 additions and 183 deletions

View File

@@ -14,6 +14,7 @@ const (
BackendTypeLlamaCpp BackendType = "llama_cpp" BackendTypeLlamaCpp BackendType = "llama_cpp"
BackendTypeMlxLm BackendType = "mlx_lm" BackendTypeMlxLm BackendType = "mlx_lm"
BackendTypeVllm BackendType = "vllm" BackendTypeVllm BackendType = "vllm"
BackendTypeUnknown BackendType = "unknown"
) )
type backend interface { type backend interface {

View File

@@ -7,6 +7,7 @@ import (
"net/http" "net/http"
"time" "time"
"llamactl/pkg/backends"
"llamactl/pkg/config" "llamactl/pkg/config"
) )
@@ -117,6 +118,14 @@ func (i *Instance) WaitForHealthy(timeout int) error {
return i.process.waitForHealthy(timeout) return i.process.waitForHealthy(timeout)
} }
func (i *Instance) GetBackendType() backends.BackendType {
opts := i.GetOptions()
if opts == nil {
return backends.BackendTypeUnknown
}
return opts.BackendOptions.BackendType
}
// GetOptions returns the current options // GetOptions returns the current options
func (i *Instance) GetOptions() *Options { func (i *Instance) GetOptions() *Options {
if i.options == nil { if i.options == nil {

View File

@@ -1,141 +0,0 @@
package instance
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"llamactl/pkg/backends"
"net/http"
"time"
)
// Model represents a model available in a llama.cpp instance
type Model struct {
ID string `json:"id"`
Object string `json:"object"`
OwnedBy string `json:"owned_by"`
Created int64 `json:"created"`
InCache bool `json:"in_cache"`
Path string `json:"path"`
Status ModelStatus `json:"status"`
}
// ModelStatus represents the status of a model in an instance
type ModelStatus struct {
Value string `json:"value"` // "loaded" | "loading" | "unloaded"
Args []string `json:"args"`
}
// IsLlamaCpp checks if this instance is a llama.cpp instance
func (i *Instance) IsLlamaCpp() bool {
opts := i.GetOptions()
if opts == nil {
return false
}
return opts.BackendOptions.BackendType == backends.BackendTypeLlamaCpp
}
// GetModels fetches the models available in this llama.cpp instance
func (i *Instance) GetModels() ([]Model, error) {
if !i.IsLlamaCpp() {
return nil, fmt.Errorf("instance %s is not a llama.cpp instance", i.Name)
}
if !i.IsRunning() {
return nil, fmt.Errorf("instance %s is not running", i.Name)
}
var result struct {
Data []Model `json:"data"`
}
if err := i.doRequest("GET", "/models", nil, &result, 10*time.Second); err != nil {
return nil, fmt.Errorf("failed to fetch models: %w", err)
}
return result.Data, nil
}
// LoadModel loads a model in this llama.cpp instance
func (i *Instance) LoadModel(modelName string) error {
if !i.IsLlamaCpp() {
return fmt.Errorf("instance %s is not a llama.cpp instance", i.Name)
}
if !i.IsRunning() {
return fmt.Errorf("instance %s is not running", i.Name)
}
// Make the load request
reqBody := map[string]string{"model": modelName}
if err := i.doRequest("POST", "/models/load", reqBody, nil, 30*time.Second); err != nil {
return fmt.Errorf("failed to load model: %w", err)
}
return nil
}
// UnloadModel unloads a model from this llama.cpp instance
func (i *Instance) UnloadModel(modelName string) error {
if !i.IsLlamaCpp() {
return fmt.Errorf("instance %s is not a llama.cpp instance", i.Name)
}
if !i.IsRunning() {
return fmt.Errorf("instance %s is not running", i.Name)
}
// Make the unload request
reqBody := map[string]string{"model": modelName}
if err := i.doRequest("POST", "/models/unload", reqBody, nil, 30*time.Second); err != nil {
return fmt.Errorf("failed to unload model: %w", err)
}
return nil
}
// doRequest makes an HTTP request to this instance's backend
func (i *Instance) doRequest(method, path string, reqBody, respBody any, timeout time.Duration) error {
url := fmt.Sprintf("http://%s:%d%s", i.GetHost(), i.GetPort(), path)
var bodyReader io.Reader
if reqBody != nil {
bodyBytes, err := json.Marshal(reqBody)
if err != nil {
return fmt.Errorf("failed to marshal request body: %w", err)
}
bodyReader = bytes.NewReader(bodyBytes)
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, method, url, bodyReader)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
if reqBody != nil {
req.Header.Set("Content-Type", "application/json")
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
return fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes))
}
if respBody != nil {
if err := json.NewDecoder(resp.Body).Decode(respBody); err != nil {
return fmt.Errorf("failed to decode response: %w", err)
}
}
return nil
}

View File

@@ -8,8 +8,6 @@ import (
"net/http" "net/http"
"os/exec" "os/exec"
"strings" "strings"
"github.com/go-chi/chi/v5"
) )
// ParseCommandRequest represents the request body for backend command parsing // ParseCommandRequest represents the request body for backend command parsing
@@ -322,24 +320,41 @@ func (h *Handler) LlamaServerListDevicesHandler() http.HandlerFunc {
// @Router /api/v1/llama-cpp/{name}/models [get] // @Router /api/v1/llama-cpp/{name}/models [get]
func (h *Handler) LlamaCppListModels() http.HandlerFunc { func (h *Handler) LlamaCppListModels() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
inst, err := h.getInstance(r) inst, err := h.validateLlamaCppInstance(r)
if err != nil { if err != nil {
writeError(w, http.StatusBadRequest, "invalid_instance", err.Error()) writeError(w, http.StatusBadRequest, "invalid instance", err.Error())
return return
} }
models, err := inst.GetModels() // Check instance permissions
if err != nil { if err := h.authMiddleware.CheckInstancePermission(r.Context(), inst.ID); err != nil {
writeError(w, http.StatusBadRequest, "get_models_failed", err.Error()) writeError(w, http.StatusForbidden, "permission_denied", err.Error())
return return
} }
response := map[string]any{ // Check if instance is shutting down before autostart logic
"object": "list", if inst.GetStatus() == instance.ShuttingDown {
"data": models, writeError(w, http.StatusServiceUnavailable, "instance_shutting_down", "Instance is shutting down")
return
} }
writeJSON(w, http.StatusOK, response) if !inst.IsRemote() && !inst.IsRunning() {
err := h.ensureInstanceRunning(inst)
if err != nil {
writeError(w, http.StatusInternalServerError, "instance start failed", err.Error())
return
}
}
// Modify request path to /models for proxying
r.URL.Path = "/models"
// Use instance's ServeHTTP which tracks inflight requests and handles shutting down state
err = inst.ServeHTTP(w, r)
if err != nil {
// Error is already handled in ServeHTTP (response written)
return
}
} }
} }
@@ -357,23 +372,41 @@ func (h *Handler) LlamaCppListModels() http.HandlerFunc {
// @Router /api/v1/llama-cpp/{name}/models/{model}/load [post] // @Router /api/v1/llama-cpp/{name}/models/{model}/load [post]
func (h *Handler) LlamaCppLoadModel() http.HandlerFunc { func (h *Handler) LlamaCppLoadModel() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
inst, err := h.getInstance(r) inst, err := h.validateLlamaCppInstance(r)
if err != nil { if err != nil {
writeError(w, http.StatusBadRequest, "invalid_instance", err.Error()) writeError(w, http.StatusBadRequest, "invalid instance", err.Error())
return return
} }
modelName := chi.URLParam(r, "model") // Check instance permissions
if err := h.authMiddleware.CheckInstancePermission(r.Context(), inst.ID); err != nil {
if err := inst.LoadModel(modelName); err != nil { writeError(w, http.StatusForbidden, "permission_denied", err.Error())
writeError(w, http.StatusBadRequest, "load_model_failed", err.Error())
return return
} }
writeJSON(w, http.StatusOK, map[string]string{ // Check if instance is shutting down before autostart logic
"status": "success", if inst.GetStatus() == instance.ShuttingDown {
"message": fmt.Sprintf("Model %s loaded successfully", modelName), writeError(w, http.StatusServiceUnavailable, "instance_shutting_down", "Instance is shutting down")
}) return
}
if !inst.IsRemote() && !inst.IsRunning() {
err := h.ensureInstanceRunning(inst)
if err != nil {
writeError(w, http.StatusInternalServerError, "instance start failed", err.Error())
return
}
}
// Modify request path to /models/load for proxying
r.URL.Path = "/models/load"
// Use instance's ServeHTTP which tracks inflight requests and handles shutting down state
err = inst.ServeHTTP(w, r)
if err != nil {
// Error is already handled in ServeHTTP (response written)
return
}
} }
} }
@@ -391,22 +424,40 @@ func (h *Handler) LlamaCppLoadModel() http.HandlerFunc {
// @Router /api/v1/llama-cpp/{name}/models/{model}/unload [post] // @Router /api/v1/llama-cpp/{name}/models/{model}/unload [post]
func (h *Handler) LlamaCppUnloadModel() http.HandlerFunc { func (h *Handler) LlamaCppUnloadModel() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
inst, err := h.getInstance(r) inst, err := h.validateLlamaCppInstance(r)
if err != nil { if err != nil {
writeError(w, http.StatusBadRequest, "invalid_instance", err.Error()) writeError(w, http.StatusBadRequest, "invalid instance", err.Error())
return return
} }
modelName := chi.URLParam(r, "model") // Check instance permissions
if err := h.authMiddleware.CheckInstancePermission(r.Context(), inst.ID); err != nil {
if err := inst.UnloadModel(modelName); err != nil { writeError(w, http.StatusForbidden, "permission_denied", err.Error())
writeError(w, http.StatusBadRequest, "unload_model_failed", err.Error())
return return
} }
writeJSON(w, http.StatusOK, map[string]string{ // Check if instance is shutting down before autostart logic
"status": "success", if inst.GetStatus() == instance.ShuttingDown {
"message": fmt.Sprintf("Model %s unloaded successfully", modelName), writeError(w, http.StatusServiceUnavailable, "instance_shutting_down", "Instance is shutting down")
}) return
}
if !inst.IsRemote() && !inst.IsRunning() {
err := h.ensureInstanceRunning(inst)
if err != nil {
writeError(w, http.StatusInternalServerError, "instance start failed", err.Error())
return
}
}
// Modify request path to /models/unload for proxying
r.URL.Path = "/models/unload"
// Use instance's ServeHTTP which tracks inflight requests and handles shutting down state
err = inst.ServeHTTP(w, r)
if err != nil {
// Error is already handled in ServeHTTP (response written)
return
}
} }
} }

View File

@@ -5,9 +5,11 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"llamactl/pkg/backends"
"llamactl/pkg/instance" "llamactl/pkg/instance"
"llamactl/pkg/validation" "llamactl/pkg/validation"
"net/http" "net/http"
"strings"
) )
// OpenAIListInstancesResponse represents the response structure for listing instances (models) in OpenAI-compatible format // OpenAIListInstancesResponse represents the response structure for listing instances (models) in OpenAI-compatible format
@@ -24,6 +26,53 @@ type OpenAIInstance struct {
OwnedBy string `json:"owned_by"` OwnedBy string `json:"owned_by"`
} }
// LlamaCppModel represents a model available in a llama.cpp instance
type LlamaCppModel struct {
ID string `json:"id"`
Object string `json:"object"`
OwnedBy string `json:"owned_by"`
Created int64 `json:"created"`
InCache bool `json:"in_cache"`
Path string `json:"path"`
Status LlamaCppModelStatus `json:"status"`
}
// LlamaCppModelStatus represents the status of a model in a llama.cpp instance
type LlamaCppModelStatus struct {
Value string `json:"value"` // "loaded" | "loading" | "unloaded"
Args []string `json:"args"`
}
// fetchLlamaCppModels fetches models from a llama.cpp instance using the proxy
func fetchLlamaCppModels(inst *instance.Instance) ([]LlamaCppModel, error) {
// Create a request to the instance's /models endpoint
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s:%d/models", inst.GetHost(), inst.GetPort()), nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
// Use a custom response writer to capture the response
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes))
}
var result struct {
Data []LlamaCppModel `json:"data"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
return result.Data, nil
}
// OpenAIListInstances godoc // OpenAIListInstances godoc
// @Summary List instances in OpenAI-compatible format // @Summary List instances in OpenAI-compatible format
// @Description Returns a list of instances in a format compatible with OpenAI API // @Description Returns a list of instances in a format compatible with OpenAI API
@@ -46,9 +95,9 @@ func (h *Handler) OpenAIListInstances() http.HandlerFunc {
// For each llama.cpp instance, try to fetch models and add them as separate entries // For each llama.cpp instance, try to fetch models and add them as separate entries
for _, inst := range instances { for _, inst := range instances {
if inst.IsLlamaCpp() && inst.IsRunning() { if inst.GetBackendType() == backends.BackendTypeLlamaCpp && inst.IsRunning() {
// Try to fetch models from the instance // Try to fetch models from the instance
models, err := inst.GetModels() models, err := fetchLlamaCppModels(inst)
if err != nil { if err != nil {
fmt.Printf("Failed to fetch models from instance %s: %v", inst.Name, err) fmt.Printf("Failed to fetch models from instance %s: %v", inst.Name, err)
continue continue
@@ -56,9 +105,9 @@ func (h *Handler) OpenAIListInstances() http.HandlerFunc {
for _, model := range models { for _, model := range models {
openaiInstances = append(openaiInstances, OpenAIInstance{ openaiInstances = append(openaiInstances, OpenAIInstance{
ID: model.ID, ID: inst.Name + "/" + model.ID,
Object: "model", Object: "model",
Created: model.Created, Created: inst.Created,
OwnedBy: inst.Name, OwnedBy: inst.Name,
}) })
} }
@@ -115,17 +164,24 @@ func (h *Handler) OpenAIProxy() http.HandlerFunc {
return return
} }
modelName, ok := requestBody["model"].(string) reqModelName, ok := requestBody["model"].(string)
if !ok || modelName == "" { if !ok || reqModelName == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "Model name is required") writeError(w, http.StatusBadRequest, "invalid_request", "Model name is required")
return return
} }
// Resolve model name to instance name (checks instance names first, then model registry) // Parse instance name and model name from <instance_name>/<model_name> format
instanceName, err := h.InstanceManager.ResolveInstance(modelName) var instanceName string
if err != nil { var modelName string
writeError(w, http.StatusBadRequest, "model_not_found", err.Error())
return // Check if model name contains "/"
if idx := strings.Index(reqModelName, "/"); idx != -1 {
// Split into instance and model parts
instanceName = reqModelName[:idx]
modelName = reqModelName[idx+1:]
} else {
instanceName = reqModelName
modelName = reqModelName
} }
// Validate instance name at the entry point // Validate instance name at the entry point
@@ -154,6 +210,11 @@ func (h *Handler) OpenAIProxy() http.HandlerFunc {
return return
} }
if inst.IsRemote() {
// Don't replace model name for remote instances
modelName = reqModelName
}
if !inst.IsRemote() && !inst.IsRunning() { if !inst.IsRemote() && !inst.IsRunning() {
err := h.ensureInstanceRunning(inst) err := h.ensureInstanceRunning(inst)
if err != nil { if err != nil {
@@ -162,6 +223,16 @@ func (h *Handler) OpenAIProxy() http.HandlerFunc {
} }
} }
// Update the request body with just the model name
requestBody["model"] = modelName
// Re-marshal the updated body
bodyBytes, err = json.Marshal(requestBody)
if err != nil {
writeError(w, http.StatusInternalServerError, "marshal_error", "Failed to update request body")
return
}
// Recreate the request body from the bytes we read // Recreate the request body from the bytes we read
r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
r.ContentLength = int64(len(bodyBytes)) r.ContentLength = int64(len(bodyBytes))