Implement model management for llama.cpp instances
@@ -5,9 +5,12 @@ import (
 	"fmt"
 	"llamactl/pkg/backends"
 	"llamactl/pkg/instance"
+	"log"
 	"net/http"
 	"os/exec"
 	"strings"
+
+	"github.com/go-chi/chi/v5"
 )

 // ParseCommandRequest represents the request body for backend command parsing
@@ -306,3 +309,115 @@ func (h *Handler) LlamaServerVersionHandler() http.HandlerFunc {
 func (h *Handler) LlamaServerListDevicesHandler() http.HandlerFunc {
 	return h.executeLlamaServerCommand("--list-devices", "Failed to list devices")
 }
+
+// LlamaCppListModels godoc
+// @Summary List models in a llama.cpp instance
+// @Description Returns a list of models available in the specified llama.cpp instance
+// @Tags Llama.cpp
+// @Security ApiKeyAuth
+// @Produce json
+// @Param name path string true "Instance Name"
+// @Success 200 {object} map[string]any "Models list response"
+// @Failure 400 {string} string "Invalid instance"
+// @Failure 500 {string} string "Internal Server Error"
+// @Router /api/v1/llama-cpp/{name}/models [get]
+func (h *Handler) LlamaCppListModels() http.HandlerFunc {
+	return func(w http.ResponseWriter, r *http.Request) {
+		inst, err := h.getInstance(r)
+		if err != nil {
+			writeError(w, http.StatusBadRequest, "invalid_instance", err.Error())
+			return
+		}
+
+		models, err := inst.GetModels()
+		if err != nil {
+			writeError(w, http.StatusBadRequest, "get_models_failed", err.Error())
+			return
+		}
+
+		response := map[string]any{
+			"object": "list",
+			"data":   models,
+		}
+
+		writeJSON(w, http.StatusOK, response)
+	}
+}
+
+// LlamaCppLoadModel godoc
+// @Summary Load a model in a llama.cpp instance
+// @Description Loads the specified model in the given llama.cpp instance
+// @Tags Llama.cpp
+// @Security ApiKeyAuth
+// @Produce json
+// @Param name path string true "Instance Name"
+// @Param model path string true "Model Name"
+// @Success 200 {object} map[string]string "Success message"
+// @Failure 400 {string} string "Invalid request"
+// @Failure 500 {string} string "Internal Server Error"
+// @Router /api/v1/llama-cpp/{name}/models/{model}/load [post]
+func (h *Handler) LlamaCppLoadModel() http.HandlerFunc {
+	return func(w http.ResponseWriter, r *http.Request) {
+		inst, err := h.getInstance(r)
+		if err != nil {
+			writeError(w, http.StatusBadRequest, "invalid_instance", err.Error())
+			return
+		}
+
+		modelName := chi.URLParam(r, "model")
+
+		if err := inst.LoadModel(modelName); err != nil {
+			writeError(w, http.StatusBadRequest, "load_model_failed", err.Error())
+			return
+		}
+
+		// Refresh the model registry
+		if err := h.InstanceManager.RefreshModelRegistry(inst); err != nil {
+			log.Printf("Warning: failed to refresh model registry after load: %v", err)
+		}
+
+		writeJSON(w, http.StatusOK, map[string]string{
+			"status":  "success",
+			"message": fmt.Sprintf("Model %s loaded successfully", modelName),
+		})
+	}
+}
+
+// LlamaCppUnloadModel godoc
+// @Summary Unload a model in a llama.cpp instance
+// @Description Unloads the specified model in the given llama.cpp instance
+// @Tags Llama.cpp
+// @Security ApiKeyAuth
+// @Produce json
+// @Param name path string true "Instance Name"
+// @Param model path string true "Model Name"
+// @Success 200 {object} map[string]string "Success message"
+// @Failure 400 {string} string "Invalid request"
+// @Failure 500 {string} string "Internal Server Error"
+// @Router /api/v1/llama-cpp/{name}/models/{model}/unload [post]
+func (h *Handler) LlamaCppUnloadModel() http.HandlerFunc {
+	return func(w http.ResponseWriter, r *http.Request) {
+		inst, err := h.getInstance(r)
+		if err != nil {
+			writeError(w, http.StatusBadRequest, "invalid_instance", err.Error())
+			return
+		}
+
+		modelName := chi.URLParam(r, "model")
+
+		if err := inst.UnloadModel(modelName); err != nil {
+			writeError(w, http.StatusBadRequest, "unload_model_failed", err.Error())
+			return
+		}
+
+		// Refresh the model registry
+		if err := h.InstanceManager.RefreshModelRegistry(inst); err != nil {
+			log.Printf("Warning: failed to refresh model registry after unload: %v", err)
+		}
+
+		writeJSON(w, http.StatusOK, map[string]string{
+			"status":  "success",
+			"message": fmt.Sprintf("Model %s unloaded successfully", modelName),
+		})
+	}
+}
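As a usage illustration (not part of the commit): a minimal Go client exercising the new endpoints. The base URL, instance name "main", and model name "llama-3-8b" are placeholders, and the ApiKeyAuth header required by the handlers is omitted for brevity.

package main

import (
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Assumed llamactl address and instance name; adjust to your deployment.
	base := "http://localhost:8080/api/v1/llama-cpp/main"

	// GET /api/v1/llama-cpp/{name}/models — list models in the instance.
	resp, err := http.Get(base + "/models")
	if err != nil {
		panic(err)
	}
	body, _ := io.ReadAll(resp.Body)
	resp.Body.Close()
	fmt.Println(string(body))

	// POST /api/v1/llama-cpp/{name}/models/{model}/load — load a model.
	// "llama-3-8b" is a hypothetical model name.
	resp, err = http.Post(base+"/models/llama-3-8b/load", "application/json", nil)
	if err != nil {
		panic(err)
	}
	resp.Body.Close()
	fmt.Println("load:", resp.Status)
}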
@@ -3,6 +3,7 @@ package server
 import (
 	"bytes"
 	"encoding/json"
+	"fmt"
 	"io"
 	"llamactl/pkg/instance"
 	"llamactl/pkg/validation"
@@ -40,14 +41,41 @@ func (h *Handler) OpenAIListInstances() http.HandlerFunc {
 			return
 		}

-		openaiInstances := make([]OpenAIInstance, len(instances))
-		for i, inst := range instances {
-			openaiInstances[i] = OpenAIInstance{
+		var openaiInstances []OpenAIInstance
+
+		// For each llama.cpp instance, try to fetch models and add them as separate entries
+		for _, inst := range instances {
+
+			if inst.IsLlamaCpp() && inst.IsRunning() {
+				// Try to fetch models from the instance
+				models, err := inst.GetModels()
+				if err != nil {
+					fmt.Printf("Failed to fetch models from instance %s: %v\n", inst.Name, err)
+					continue
+				}
+
+				for _, model := range models {
+					openaiInstances = append(openaiInstances, OpenAIInstance{
+						ID:      model.ID,
+						Object:  "model",
+						Created: model.Created,
+						OwnedBy: inst.Name,
+					})
+				}
+
+				if len(models) > 1 {
+					// Skip adding the instance name if multiple models are present
+					continue
+				}
+			}
+
+			// Add instance name as single entry (for non-llama.cpp or if model fetch failed)
+			openaiInstances = append(openaiInstances, OpenAIInstance{
 				ID:      inst.Name,
 				Object:  "model",
 				Created: inst.Created,
 				OwnedBy: "llamactl",
-			}
+			})
 		}

 		openaiResponse := OpenAIListInstancesResponse{
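To make the flattening concrete, here is a standalone sketch of the resulting list. The OpenAIInstance field set is taken from the diff; the JSON tag names, timestamps, and instance/model names are assumptions for illustration.

package main

import (
	"encoding/json"
	"fmt"
)

// OpenAIInstance mirrors the fields used in the diff; the JSON tag names
// are assumed from the OpenAI models-list convention.
type OpenAIInstance struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

func main() {
	// A running llama.cpp instance "main" with two models is listed per
	// model (owned_by = instance name); any other instance is listed once
	// under owned_by = "llamactl".
	list := []OpenAIInstance{
		{ID: "qwen2.5-7b", Object: "model", Created: 1735000000, OwnedBy: "main"},
		{ID: "llama-3-8b", Object: "model", Created: 1735000000, OwnedBy: "main"},
		{ID: "vllm0", Object: "model", Created: 1735000000, OwnedBy: "llamactl"},
	}
	out, _ := json.MarshalIndent(list, "", "  ")
	fmt.Println(string(out))
}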
@@ -89,12 +117,19 @@ func (h *Handler) OpenAIProxy() http.HandlerFunc {

 		modelName, ok := requestBody["model"].(string)
 		if !ok || modelName == "" {
-			writeError(w, http.StatusBadRequest, "invalid_request", "Instance name is required")
+			writeError(w, http.StatusBadRequest, "invalid_request", "Model name is required")
 			return
 		}

+		// Resolve model name to instance name (checks instance names first, then model registry)
+		instanceName, err := h.InstanceManager.ResolveInstance(modelName)
+		if err != nil {
+			writeError(w, http.StatusBadRequest, "model_not_found", err.Error())
+			return
+		}
+
 		// Validate instance name at the entry point
-		validatedName, err := validation.ValidateInstanceName(modelName)
+		validatedName, err := validation.ValidateInstanceName(instanceName)
 		if err != nil {
 			writeError(w, http.StatusBadRequest, "invalid_instance_name", err.Error())
 			return
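The diff relies on h.InstanceManager.ResolveInstance but does not show it. A minimal sketch of the lookup order described in its comment follows; the manager type, its fields, and the locking are all assumptions, not llamactl's actual implementation.

package main

import (
	"fmt"
	"sync"
)

// instanceManager is a hypothetical stand-in for llamactl's instance manager.
type instanceManager struct {
	mu            sync.RWMutex
	instances     map[string]struct{} // instance name -> instance
	modelRegistry map[string]string   // model ID -> owning instance name
}

// ResolveInstance sketches the order stated in the diff's comment:
// instance names are checked first, then the model registry.
func (m *instanceManager) ResolveInstance(nameOrModel string) (string, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	if _, ok := m.instances[nameOrModel]; ok {
		return nameOrModel, nil // already an instance name
	}
	if owner, ok := m.modelRegistry[nameOrModel]; ok {
		return owner, nil // a model ID served by some instance
	}
	return "", fmt.Errorf("no instance or model named %q", nameOrModel)
}

func main() {
	m := &instanceManager{
		instances:     map[string]struct{}{"main": {}},
		modelRegistry: map[string]string{"llama-3-8b": "main"},
	}
	fmt.Println(m.ResolveInstance("llama-3-8b")) // main <nil>
}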
@@ -73,6 +73,13 @@ func SetupRouter(handler *Handler) *chi.Mux {
 		})
 	})

+	// Llama.cpp instance-specific endpoints
+	r.Route("/llama-cpp/{name}", func(r chi.Router) {
+		r.Get("/models", handler.LlamaCppListModels())
+		r.Post("/models/{model}/load", handler.LlamaCppLoadModel())
+		r.Post("/models/{model}/unload", handler.LlamaCppUnloadModel())
+	})
+
 	// Node management endpoints
 	r.Route("/nodes", func(r chi.Router) {
 		r.Get("/", handler.ListNodes()) // List all nodes
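With the resolver in place, an OpenAI-compatible client can address a model rather than an instance. A hypothetical request sketch: the /v1/chat/completions path is assumed from the proxy's OpenAI compatibility, and the model ID is a placeholder resolved to its instance via the registry.

package main

import (
	"bytes"
	"fmt"
	"net/http"
)

func main() {
	// "qwen2.5-7b" is a hypothetical model ID; the proxy resolves it to the
	// llama.cpp instance serving it before forwarding the request.
	payload := []byte(`{"model": "qwen2.5-7b", "messages": [{"role": "user", "content": "hi"}]}`)
	resp, err := http.Post("http://localhost:8080/v1/chat/completions",
		"application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println("status:", resp.Status)
}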