Split large package into subpackages

This commit is contained in:
2025-08-04 19:23:56 +02:00
parent a3c44dad1e
commit 6a7a9a2d09
21 changed files with 413 additions and 396 deletions

575
pkg/server/handlers.go Normal file
View File

@@ -0,0 +1,575 @@
package server
import (
"bytes"
"encoding/json"
"fmt"
"io"
"llamactl/pkg/config"
"llamactl/pkg/instance"
"llamactl/pkg/manager"
"net/http"
"os/exec"
"strconv"
"strings"
"github.com/go-chi/chi/v5"
)
// Handler bundles the dependencies shared by the HTTP handlers in this package.
type Handler struct {
// InstanceManager is the manager the handlers call to list, create, update,
// start, stop, restart, delete and look up instances.
InstanceManager manager.InstanceManager
// cfg is the configuration passed to NewHandler; SetupRouter reads its
// Server and Auth sections.
cfg config.Config
}
// NewHandler returns a Handler that uses im for instance operations and keeps
// cfg for later use when the router is set up.
func NewHandler(im manager.InstanceManager, cfg config.Config) *Handler {
	h := new(Handler)
	h.InstanceManager = im
	h.cfg = cfg
	return h
}
// HelpHandler godoc
// @Summary Get help for llama server
// @Description Returns the help text for the llama server command
// @Tags server
// @Security ApiKeyAuth
// @Produces text/plain
// @Success 200 {string} string "Help text"
// @Failure 500 {string} string "Internal Server Error"
// @Router /server/help [get]
func (h *Handler) HelpHandler() http.HandlerFunc {
	// The returned handler runs `llama-server --help` on every request and
	// replies with the command's combined stdout/stderr as plain text, or with
	// a 500 if the command cannot be run or exits with an error.
	return func(w http.ResponseWriter, r *http.Request) {
		// Bind the child process to the request context so it is killed when
		// the client goes away instead of running on unobserved.
		helpCmd := exec.CommandContext(r.Context(), "llama-server", "--help")
		output, err := helpCmd.CombinedOutput()
		if err != nil {
			http.Error(w, "Failed to get help: "+err.Error(), http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "text/plain")
		// A failed write means the client is gone; nothing can be reported to it.
		_, _ = w.Write(output)
	}
}
// VersionHandler godoc
// @Summary Get version of llama server
// @Description Returns the version of the llama server command
// @Tags server
// @Security ApiKeyAuth
// @Produces text/plain
// @Success 200 {string} string "Version information"
// @Failure 500 {string} string "Internal Server Error"
// @Router /server/version [get]
func (h *Handler) VersionHandler() http.HandlerFunc {
	// The returned handler runs `llama-server --version` on every request and
	// replies with the command's combined stdout/stderr as plain text, or with
	// a 500 if the command cannot be run or exits with an error.
	return func(w http.ResponseWriter, r *http.Request) {
		// Bind the child process to the request context so it is killed when
		// the client goes away instead of running on unobserved.
		versionCmd := exec.CommandContext(r.Context(), "llama-server", "--version")
		output, err := versionCmd.CombinedOutput()
		if err != nil {
			http.Error(w, "Failed to get version: "+err.Error(), http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "text/plain")
		// A failed write means the client is gone; nothing can be reported to it.
		_, _ = w.Write(output)
	}
}
// ListDevicesHandler godoc
// @Summary List available devices for llama server
// @Description Returns a list of available devices for the llama server
// @Tags server
// @Security ApiKeyAuth
// @Produces text/plain
// @Success 200 {string} string "List of devices"
// @Failure 500 {string} string "Internal Server Error"
// @Router /server/devices [get]
func (h *Handler) ListDevicesHandler() http.HandlerFunc {
	// The returned handler runs `llama-server --list-devices` on every request
	// and replies with the command's combined stdout/stderr as plain text, or
	// with a 500 if the command cannot be run or exits with an error.
	return func(w http.ResponseWriter, r *http.Request) {
		// Bind the child process to the request context so it is killed when
		// the client goes away instead of running on unobserved.
		listCmd := exec.CommandContext(r.Context(), "llama-server", "--list-devices")
		output, err := listCmd.CombinedOutput()
		if err != nil {
			http.Error(w, "Failed to list devices: "+err.Error(), http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "text/plain")
		// A failed write means the client is gone; nothing can be reported to it.
		_, _ = w.Write(output)
	}
}
// ListInstances godoc
// @Summary List all instances
// @Description Returns a list of all instances managed by the server
// @Tags instances
// @Security ApiKeyAuth
// @Produces json
// @Success 200 {array} Instance "List of instances"
// @Failure 500 {string} string "Internal Server Error"
// @Router /instances [get]
func (h *Handler) ListInstances() http.HandlerFunc {
	// serve writes every instance known to the manager as a JSON document.
	serve := func(w http.ResponseWriter, r *http.Request) {
		all, listErr := h.InstanceManager.ListInstances()
		if listErr != nil {
			http.Error(w, "Failed to list instances: "+listErr.Error(), http.StatusInternalServerError)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		encodeErr := json.NewEncoder(w).Encode(all)
		if encodeErr == nil {
			return
		}
		http.Error(w, "Failed to encode instances: "+encodeErr.Error(), http.StatusInternalServerError)
	}
	return serve
}
// CreateInstance godoc
// @Summary Create and start a new instance
// @Description Creates a new instance with the provided configuration options
// @Tags instances
// @Security ApiKeyAuth
// @Accept json
// @Produces json
// @Param name path string true "Instance Name"
// @Param options body CreateInstanceOptions true "Instance configuration options"
// @Success 201 {object} Instance "Created instance details"
// @Failure 400 {string} string "Invalid request body"
// @Failure 500 {string} string "Internal Server Error"
// @Router /instances/{name} [post]
func (h *Handler) CreateInstance() http.HandlerFunc {
	// The returned handler creates the instance named by the {name} URL
	// parameter from the JSON options in the request body and replies with
	// 201 and the created instance as JSON. An empty name or undecodable body
	// yields 400; a manager or encoding failure yields 500.
	return func(w http.ResponseWriter, r *http.Request) {
		name := chi.URLParam(r, "name")
		if name == "" {
			http.Error(w, "Instance name cannot be empty", http.StatusBadRequest)
			return
		}

		var options instance.CreateInstanceOptions
		if err := json.NewDecoder(r.Body).Decode(&options); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}

		inst, err := h.InstanceManager.CreateInstance(name, &options)
		if err != nil {
			http.Error(w, "Failed to create instance: "+err.Error(), http.StatusInternalServerError)
			return
		}

		// Encode before committing the status line: once 201 has been sent, an
		// encoding failure can no longer be reported as a 500.
		var body bytes.Buffer
		if err := json.NewEncoder(&body).Encode(inst); err != nil {
			http.Error(w, "Failed to encode instance: "+err.Error(), http.StatusInternalServerError)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusCreated)
		// A failed write means the client is gone; nothing can be reported to it.
		_, _ = w.Write(body.Bytes())
	}
}
// GetInstance godoc
// @Summary Get details of a specific instance
// @Description Returns the details of a specific instance by name
// @Tags instances
// @Security ApiKeyAuth
// @Produces json
// @Param name path string true "Instance Name"
// @Success 200 {object} Instance "Instance details"
// @Failure 400 {string} string "Invalid name format"
// @Failure 500 {string} string "Internal Server Error"
// @Router /instances/{name} [get]
func (h *Handler) GetInstance() http.HandlerFunc {
	// The returned handler looks up the instance named by the {name} URL
	// parameter and writes it as JSON.
	return func(w http.ResponseWriter, r *http.Request) {
		instanceName := chi.URLParam(r, "name")
		if len(instanceName) == 0 {
			http.Error(w, "Instance name cannot be empty", http.StatusBadRequest)
			return
		}

		found, lookupErr := h.InstanceManager.GetInstance(instanceName)
		if lookupErr != nil {
			http.Error(w, "Failed to get instance: "+lookupErr.Error(), http.StatusInternalServerError)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		encoder := json.NewEncoder(w)
		if encodeErr := encoder.Encode(found); encodeErr != nil {
			http.Error(w, "Failed to encode instance: "+encodeErr.Error(), http.StatusInternalServerError)
		}
	}
}
// UpdateInstance godoc
// @Summary Update an instance's configuration
// @Description Updates the configuration of a specific instance by name
// @Tags instances
// @Security ApiKeyAuth
// @Accept json
// @Produces json
// @Param name path string true "Instance Name"
// @Param options body CreateInstanceOptions true "Instance configuration options"
// @Success 200 {object} Instance "Updated instance details"
// @Failure 400 {string} string "Invalid name format"
// @Failure 500 {string} string "Internal Server Error"
// @Router /instances/{name} [put]
func (h *Handler) UpdateInstance() http.HandlerFunc {
	// The returned handler decodes new options from the JSON request body,
	// applies them to the instance named by the {name} URL parameter, and
	// writes the manager's result as JSON.
	return func(w http.ResponseWriter, r *http.Request) {
		instanceName := chi.URLParam(r, "name")
		if len(instanceName) == 0 {
			http.Error(w, "Instance name cannot be empty", http.StatusBadRequest)
			return
		}

		requested := new(instance.CreateInstanceOptions)
		if decodeErr := json.NewDecoder(r.Body).Decode(requested); decodeErr != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}

		updated, updateErr := h.InstanceManager.UpdateInstance(instanceName, requested)
		if updateErr != nil {
			http.Error(w, "Failed to update instance: "+updateErr.Error(), http.StatusInternalServerError)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		encoder := json.NewEncoder(w)
		if encodeErr := encoder.Encode(updated); encodeErr != nil {
			http.Error(w, "Failed to encode instance: "+encodeErr.Error(), http.StatusInternalServerError)
		}
	}
}
// StartInstance godoc
// @Summary Start a stopped instance
// @Description Starts a specific instance by name
// @Tags instances
// @Security ApiKeyAuth
// @Produces json
// @Param name path string true "Instance Name"
// @Success 200 {object} Instance "Started instance details"
// @Failure 400 {string} string "Invalid name format"
// @Failure 500 {string} string "Internal Server Error"
// @Router /instances/{name}/start [post]
func (h *Handler) StartInstance() http.HandlerFunc {
	// The returned handler asks the manager to start the instance named by the
	// {name} URL parameter and writes the manager's result as JSON.
	return func(w http.ResponseWriter, r *http.Request) {
		instanceName := chi.URLParam(r, "name")
		if len(instanceName) == 0 {
			http.Error(w, "Instance name cannot be empty", http.StatusBadRequest)
			return
		}

		started, startErr := h.InstanceManager.StartInstance(instanceName)
		if startErr != nil {
			http.Error(w, "Failed to start instance: "+startErr.Error(), http.StatusInternalServerError)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		encoder := json.NewEncoder(w)
		if encodeErr := encoder.Encode(started); encodeErr != nil {
			http.Error(w, "Failed to encode instance: "+encodeErr.Error(), http.StatusInternalServerError)
		}
	}
}
// StopInstance godoc
// @Summary Stop a running instance
// @Description Stops a specific instance by name
// @Tags instances
// @Security ApiKeyAuth
// @Produces json
// @Param name path string true "Instance Name"
// @Success 200 {object} Instance "Stopped instance details"
// @Failure 400 {string} string "Invalid name format"
// @Failure 500 {string} string "Internal Server Error"
// @Router /instances/{name}/stop [post]
func (h *Handler) StopInstance() http.HandlerFunc {
	// The returned handler asks the manager to stop the instance named by the
	// {name} URL parameter and writes the manager's result as JSON.
	return func(w http.ResponseWriter, r *http.Request) {
		instanceName := chi.URLParam(r, "name")
		if len(instanceName) == 0 {
			http.Error(w, "Instance name cannot be empty", http.StatusBadRequest)
			return
		}

		stopped, stopErr := h.InstanceManager.StopInstance(instanceName)
		if stopErr != nil {
			http.Error(w, "Failed to stop instance: "+stopErr.Error(), http.StatusInternalServerError)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		encoder := json.NewEncoder(w)
		if encodeErr := encoder.Encode(stopped); encodeErr != nil {
			http.Error(w, "Failed to encode instance: "+encodeErr.Error(), http.StatusInternalServerError)
		}
	}
}
// RestartInstance godoc
// @Summary Restart a running instance
// @Description Restarts a specific instance by name
// @Tags instances
// @Security ApiKeyAuth
// @Produces json
// @Param name path string true "Instance Name"
// @Success 200 {object} Instance "Restarted instance details"
// @Failure 400 {string} string "Invalid name format"
// @Failure 500 {string} string "Internal Server Error"
// @Router /instances/{name}/restart [post]
func (h *Handler) RestartInstance() http.HandlerFunc {
	// The returned handler asks the manager to restart the instance named by
	// the {name} URL parameter and writes the manager's result as JSON.
	return func(w http.ResponseWriter, r *http.Request) {
		instanceName := chi.URLParam(r, "name")
		if len(instanceName) == 0 {
			http.Error(w, "Instance name cannot be empty", http.StatusBadRequest)
			return
		}

		restarted, restartErr := h.InstanceManager.RestartInstance(instanceName)
		if restartErr != nil {
			http.Error(w, "Failed to restart instance: "+restartErr.Error(), http.StatusInternalServerError)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		encoder := json.NewEncoder(w)
		if encodeErr := encoder.Encode(restarted); encodeErr != nil {
			http.Error(w, "Failed to encode instance: "+encodeErr.Error(), http.StatusInternalServerError)
		}
	}
}
// DeleteInstance godoc
// @Summary Delete an instance
// @Description Stops and removes a specific instance by name
// @Tags instances
// @Security ApiKeyAuth
// @Param name path string true "Instance Name"
// @Success 204 "No Content"
// @Failure 400 {string} string "Invalid name format"
// @Failure 500 {string} string "Internal Server Error"
// @Router /instances/{name} [delete]
func (h *Handler) DeleteInstance() http.HandlerFunc {
	// The returned handler asks the manager to delete the instance named by the
	// {name} URL parameter and replies 204 with no body on success.
	return func(w http.ResponseWriter, r *http.Request) {
		instanceName := chi.URLParam(r, "name")
		if len(instanceName) == 0 {
			http.Error(w, "Instance name cannot be empty", http.StatusBadRequest)
			return
		}

		deleteErr := h.InstanceManager.DeleteInstance(instanceName)
		if deleteErr != nil {
			http.Error(w, "Failed to delete instance: "+deleteErr.Error(), http.StatusInternalServerError)
			return
		}

		w.WriteHeader(http.StatusNoContent)
	}
}
// GetInstanceLogs godoc
// @Summary Get logs from a specific instance
// @Description Returns the logs from a specific instance by name with optional line limit
// @Tags instances
// @Security ApiKeyAuth
// @Param name path string true "Instance Name"
// @Param lines query string false "Number of lines to retrieve (default: all lines)"
// @Produces text/plain
// @Success 200 {string} string "Instance logs"
// @Failure 400 {string} string "Invalid name format or lines parameter"
// @Failure 500 {string} string "Internal Server Error"
// @Router /instances/{name}/logs [get]
func (h *Handler) GetInstanceLogs() http.HandlerFunc {
	// The returned handler writes the logs of the instance named by the {name}
	// URL parameter as plain text. The optional "lines" query parameter is
	// parsed as an integer and handed to the instance's GetLogs; when it is
	// absent, -1 is passed.
	return func(w http.ResponseWriter, r *http.Request) {
		instanceName := chi.URLParam(r, "name")
		if len(instanceName) == 0 {
			http.Error(w, "Instance name cannot be empty", http.StatusBadRequest)
			return
		}

		numLines := -1
		if raw := r.URL.Query().Get("lines"); raw != "" {
			parsed, parseErr := strconv.Atoi(raw)
			if parseErr != nil {
				http.Error(w, "Invalid lines parameter: "+parseErr.Error(), http.StatusBadRequest)
				return
			}
			numLines = parsed
		}

		inst, lookupErr := h.InstanceManager.GetInstance(instanceName)
		if lookupErr != nil {
			http.Error(w, "Failed to get instance: "+lookupErr.Error(), http.StatusInternalServerError)
			return
		}

		logs, logsErr := inst.GetLogs(numLines)
		if logsErr != nil {
			http.Error(w, "Failed to get logs: "+logsErr.Error(), http.StatusInternalServerError)
			return
		}

		w.Header().Set("Content-Type", "text/plain")
		w.Write([]byte(logs))
	}
}
// ProxyToInstance godoc
// @Summary Proxy requests to a specific instance
// @Description Forwards HTTP requests to the llama-server instance running on a specific port
// @Tags instances
// @Security ApiKeyAuth
// @Param name path string true "Instance Name"
// @Success 200 "Request successfully proxied to instance"
// @Failure 400 {string} string "Invalid name format"
// @Failure 500 {string} string "Internal Server Error"
// @Failure 503 {string} string "Instance is not running"
// @Router /instances/{name}/proxy [get]
// @Router /instances/{name}/proxy [post]
func (h *Handler) ProxyToInstance() http.HandlerFunc {
	// The returned handler forwards the request to the running instance named
	// by the {name} URL parameter, using the proxy returned by the instance's
	// GetProxy. The "/api/v1/instances/<name>/proxy" prefix is removed from the
	// path for the duration of the proxied call and restored afterwards.
	//
	// Responses produced here: 400 for an empty name or a path that does not
	// carry the expected prefix, 500 when the instance or its proxy cannot be
	// obtained, 503 when the instance is not running.
	return func(w http.ResponseWriter, r *http.Request) {
		name := chi.URLParam(r, "name")
		if name == "" {
			http.Error(w, "Instance name cannot be empty", http.StatusBadRequest)
			return
		}

		inst, err := h.InstanceManager.GetInstance(name)
		if err != nil {
			http.Error(w, "Failed to get instance: "+err.Error(), http.StatusInternalServerError)
			return
		}

		if !inst.Running {
			http.Error(w, "Instance is not running", http.StatusServiceUnavailable)
			return
		}

		// Get the cached proxy for this instance
		proxy, err := inst.GetProxy()
		if err != nil {
			http.Error(w, "Failed to get proxy: "+err.Error(), http.StatusInternalServerError)
			return
		}

		// Strip the "/api/v1/instances/<name>/proxy" prefix from the request
		// URL. The prefix is verified first: slicing by its length alone would
		// panic or forward a wrong path whenever the URL path does not literally
		// start with it (for example when the name parameter is percent-encoded
		// but the path is decoded).
		prefix := fmt.Sprintf("/api/v1/instances/%s/proxy", name)
		if !strings.HasPrefix(r.URL.Path, prefix) {
			http.Error(w, "Invalid proxy path", http.StatusBadRequest)
			return
		}
		proxyPath := strings.TrimPrefix(r.URL.Path, prefix)

		// Ensure the proxy path starts with "/"
		if !strings.HasPrefix(proxyPath, "/") {
			proxyPath = "/" + proxyPath
		}

		// Modify the request to remove the proxy prefix
		originalPath := r.URL.Path
		r.URL.Path = proxyPath

		// Set forwarded headers. The server moves the incoming Host header into
		// r.Host, so it has to be read from there rather than from r.Header.
		r.Header.Set("X-Forwarded-Host", r.Host)
		forwardedProto := "http"
		if r.TLS != nil {
			forwardedProto = "https"
		}
		r.Header.Set("X-Forwarded-Proto", forwardedProto)

		// Restore original path for logging purposes
		defer func() {
			r.URL.Path = originalPath
		}()

		// Forward the request using the cached proxy
		proxy.ServeHTTP(w, r)
	}
}
// OpenAIListInstances godoc
// @Summary List instances in OpenAI-compatible format
// @Description Returns a list of instances in a format compatible with OpenAI API
// @Tags openai
// @Security ApiKeyAuth
// @Produces json
// @Success 200 {object} OpenAIListInstancesResponse "List of OpenAI-compatible instances"
// @Failure 500 {string} string "Internal Server Error"
// @Router /v1/models [get]
func (h *Handler) OpenAIListInstances() http.HandlerFunc {
	// The returned handler maps every managed instance to an OpenAIInstance
	// entry and writes them wrapped in an OpenAIListInstancesResponse as JSON.
	return func(w http.ResponseWriter, r *http.Request) {
		all, listErr := h.InstanceManager.ListInstances()
		if listErr != nil {
			http.Error(w, "Failed to list instances: "+listErr.Error(), http.StatusInternalServerError)
			return
		}

		// One entry per instance, in the order the manager returned them.
		models := make([]OpenAIInstance, 0, len(all))
		for _, inst := range all {
			entry := OpenAIInstance{
				ID:      inst.Name,
				Object:  "model",
				Created: inst.Created,
				OwnedBy: "llamactl",
			}
			models = append(models, entry)
		}

		w.Header().Set("Content-Type", "application/json")
		payload := OpenAIListInstancesResponse{Object: "list", Data: models}
		encodeErr := json.NewEncoder(w).Encode(payload)
		if encodeErr != nil {
			http.Error(w, "Failed to encode instances: "+encodeErr.Error(), http.StatusInternalServerError)
		}
	}
}
// OpenAIProxy godoc
// @Summary OpenAI-compatible proxy endpoint
// @Description Handles all POST requests to /v1/*, routing to the appropriate instance based on the request body. Requires API key authentication via the `Authorization` header.
// @Tags openai
// @Security ApiKeyAuth
// @Accept json
// @Produces json
// @Success 200 "OpenAI response"
// @Failure 400 {string} string "Invalid request body or model name"
// @Failure 500 {string} string "Internal Server Error"
// @Router /v1/ [post]
func (h *Handler) OpenAIProxy() http.HandlerFunc {
	// The returned handler reads the JSON request body, uses its "model" field
	// as the instance name, and forwards the request (with the body restored)
	// to that instance's proxy.
	return func(w http.ResponseWriter, r *http.Request) {
		// The body has to be consumed to find the model name, so the bytes are
		// kept and replayed to the proxy below.
		payload, readErr := io.ReadAll(r.Body)
		if readErr != nil {
			http.Error(w, "Failed to read request body", http.StatusBadRequest)
			return
		}
		r.Body.Close()

		var fields map[string]any
		if decodeErr := json.Unmarshal(payload, &fields); decodeErr != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}

		// "model" must be present, a JSON string, and non-empty.
		modelName, isString := fields["model"].(string)
		if !isString || len(modelName) == 0 {
			http.Error(w, "Model name is required", http.StatusBadRequest)
			return
		}

		// The model name doubles as the instance name.
		target, lookupErr := h.InstanceManager.GetInstance(modelName)
		if lookupErr != nil {
			http.Error(w, "Failed to get instance: "+lookupErr.Error(), http.StatusInternalServerError)
			return
		}
		if !target.Running {
			http.Error(w, "Instance is not running", http.StatusServiceUnavailable)
			return
		}

		proxy, proxyErr := target.GetProxy()
		if proxyErr != nil {
			http.Error(w, "Failed to get proxy: "+proxyErr.Error(), http.StatusInternalServerError)
			return
		}

		// Hand the proxy a fresh reader over the bytes already consumed.
		r.ContentLength = int64(len(payload))
		r.Body = io.NopCloser(bytes.NewReader(payload))

		proxy.ServeHTTP(w, r)
	}
}

189
pkg/server/middleware.go Normal file
View File

@@ -0,0 +1,189 @@
package server
import (
	"crypto/rand"
	"crypto/subtle"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"os"
	"strings"

	"llamactl/pkg/config"
)
// KeyType selects which set of API keys AuthMiddleware checks a request against.
type KeyType int
const (
// KeyTypeInference is used for the OpenAI-compatible /v1 routes. AuthMiddleware
// accepts both inference keys and management keys for this type.
KeyTypeInference KeyType = iota
// KeyTypeManagement is used for the /api/v1 routes. AuthMiddleware accepts
// only management keys for this type.
KeyTypeManagement
)
// APIAuthMiddleware holds the API keys that requests are authenticated against.
// It is built by NewAPIAuthMiddleware.
type APIAuthMiddleware struct {
// requireInferenceAuth is copied from AuthConfig.RequireInferenceAuth.
requireInferenceAuth bool
// inferenceKeys is the set of accepted inference keys; only the map keys are
// consulted, and every stored value is true.
inferenceKeys map[string]bool
// requireManagementAuth is copied from AuthConfig.RequireManagementAuth.
requireManagementAuth bool
// managementKeys is the set of accepted management keys; only the map keys
// are consulted, and every stored value is true.
managementKeys map[string]bool
}
// NewAPIAuthMiddleware builds an APIAuthMiddleware from authCfg.
//
// Every configured management and inference key is registered. When auth is
// required for a key type but no keys are configured for it, one key is
// generated with generateAPIKey, registered, and printed to standard output,
// followed by a notice that generated keys change on restart.
func NewAPIAuthMiddleware(authCfg config.AuthConfig) *APIAuthMiddleware {
	const banner = "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"

	mw := &APIAuthMiddleware{
		requireInferenceAuth:  authCfg.RequireInferenceAuth,
		inferenceKeys:         map[string]bool{},
		requireManagementAuth: authCfg.RequireManagementAuth,
		managementKeys:        map[string]bool{},
	}
	autoGenerated := false

	// Management keys: generate one only if auth is required and none were given.
	if len(authCfg.ManagementKeys) == 0 && authCfg.RequireManagementAuth {
		generatedKey := generateAPIKey(KeyTypeManagement)
		mw.managementKeys[generatedKey] = true
		autoGenerated = true
		fmt.Printf("%s\n⚠ MANAGEMENT AUTHENTICATION REQUIRED\n%s\n", banner, banner)
		fmt.Printf("🔑 Generated Management API Key:\n\n %s\n\n", generatedKey)
	}
	for _, configured := range authCfg.ManagementKeys {
		mw.managementKeys[configured] = true
	}

	// Inference keys: same rule as above.
	if len(authCfg.InferenceKeys) == 0 && authCfg.RequireInferenceAuth {
		generatedKey := generateAPIKey(KeyTypeInference)
		mw.inferenceKeys[generatedKey] = true
		autoGenerated = true
		fmt.Printf("%s\n⚠ INFERENCE AUTHENTICATION REQUIRED\n%s\n", banner, banner)
		fmt.Printf("🔑 Generated Inference API Key:\n\n %s\n\n", generatedKey)
	}
	for _, configured := range authCfg.InferenceKeys {
		mw.inferenceKeys[configured] = true
	}

	if !autoGenerated {
		return mw
	}

	// Generated keys are only ever shown here, so tell the operator to save them.
	fmt.Printf("%s\n⚠ IMPORTANT\n%s\n", banner, banner)
	fmt.Println("• These keys are auto-generated and will change on restart")
	fmt.Println("• For production, add explicit keys to your configuration")
	fmt.Println("• Copy these keys before they disappear from the terminal")
	fmt.Println(banner)
	return mw
}
// generateAPIKey returns a new API key for keyType in the form
// "<prefix>-<64 hex characters>", where the hex part encodes 32 bytes read from
// crypto/rand and the prefix is "sk-inference", "sk-management", or
// "sk-unknown" for any other key type.
//
// If crypto/rand reports an error, a warning is logged and
// "<prefix>-fallback-<pid>" is returned instead.
func generateAPIKey(keyType KeyType) string {
	prefix := "sk-unknown"
	if keyType == KeyTypeInference {
		prefix = "sk-inference"
	} else if keyType == KeyTypeManagement {
		prefix = "sk-management"
	}

	// 32 random bytes (256 bits), hex-encoded on success.
	entropy := make([]byte, 32)
	_, readErr := rand.Read(entropy)
	if readErr == nil {
		return fmt.Sprintf("%s-%s", prefix, hex.EncodeToString(entropy))
	}

	// NOTE(review): the fallback key contains only the process ID and is
	// therefore guessable; consider failing startup instead of issuing it.
	log.Printf("Warning: Failed to generate secure random key, using fallback")
	return fmt.Sprintf("%s-fallback-%d", prefix, os.Getpid())
}
// AuthMiddleware returns middleware that lets a request through only if it
// carries an API key valid for keyType.
//
// OPTIONS requests are passed through without any check. For KeyTypeInference
// both inference and management keys are accepted; for KeyTypeManagement only
// management keys are; any other key type rejects every key. A request without
// a key, or with a key that is not accepted, gets the response written by
// unauthorized.
func (a *APIAuthMiddleware) AuthMiddleware(keyType KeyType) func(http.Handler) http.Handler {
	// accepts reports whether key is allowed for the key type this middleware guards.
	accepts := func(key string) bool {
		switch keyType {
		case KeyTypeManagement:
			return a.isValidKey(key, KeyTypeManagement)
		case KeyTypeInference:
			// Management keys also work for OpenAI endpoints (higher privilege)
			return a.isValidKey(key, KeyTypeInference) || a.isValidKey(key, KeyTypeManagement)
		}
		return false
	}

	return func(next http.Handler) http.Handler {
		guarded := func(w http.ResponseWriter, r *http.Request) {
			if r.Method == http.MethodOptions {
				next.ServeHTTP(w, r)
				return
			}

			key := a.extractAPIKey(r)
			switch {
			case key == "":
				a.unauthorized(w, "Missing API key")
			case !accepts(key):
				a.unauthorized(w, "Invalid API key")
			default:
				next.ServeHTTP(w, r)
			}
		}
		return http.HandlerFunc(guarded)
	}
}
// extractAPIKey returns the API key carried by r, or "" if there is none.
//
// Sources are tried in order and the first non-empty value wins:
//  1. the Authorization header with the Bearer scheme,
//  2. the X-API-Key header,
//  3. the api_key query parameter.
//
// The Bearer scheme name is matched case-insensitively, as HTTP defines
// authentication scheme names to be, and whitespace around the token is
// ignored. An Authorization header with another scheme or an empty token does
// not stop the remaining sources from being checked.
func (a *APIAuthMiddleware) extractAPIKey(r *http.Request) string {
	// Check Authorization header: "Bearer sk-..."
	if auth := r.Header.Get("Authorization"); auth != "" {
		if scheme, token, ok := strings.Cut(auth, " "); ok && strings.EqualFold(scheme, "Bearer") {
			if token = strings.TrimSpace(token); token != "" {
				return token
			}
		}
	}

	// Check X-API-Key header
	if apiKey := r.Header.Get("X-API-Key"); apiKey != "" {
		return apiKey
	}

	// Check query parameter
	if apiKey := r.URL.Query().Get("api_key"); apiKey != "" {
		return apiKey
	}

	return ""
}
// isValidKey reports whether providedKey is one of the keys registered for
// keyType. Key types other than KeyTypeInference and KeyTypeManagement never
// match. Each candidate is compared with subtle.ConstantTimeCompare, which
// itself returns 0 for inputs of different length.
func (a *APIAuthMiddleware) isValidKey(providedKey string, keyType KeyType) bool {
	var candidates map[string]bool
	if keyType == KeyTypeInference {
		candidates = a.inferenceKeys
	} else if keyType == KeyTypeManagement {
		candidates = a.managementKeys
	} else {
		return false
	}

	provided := []byte(providedKey)
	for known := range candidates {
		if subtle.ConstantTimeCompare(provided, []byte(known)) == 1 {
			return true
		}
	}
	return false
}
// unauthorized writes a 401 response with a JSON body of the form
//
//	{"error": {"message": "<message>", "type": "authentication_error"}}
//
// The message is encoded as a JSON string so that quotes, backslashes or
// control characters in it cannot produce an invalid document.
func (a *APIAuthMiddleware) unauthorized(w http.ResponseWriter, message string) {
	quoted, err := json.Marshal(message)
	if err != nil {
		// Still send a well-formed body if the message cannot be encoded.
		quoted = []byte(`"authentication failed"`)
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusUnauthorized)
	// A failed write means the client is gone; nothing can be reported to it.
	_, _ = fmt.Fprintf(w, `{"error": {"message": %s, "type": "authentication_error"}}`, quoted)
}

View File

@@ -0,0 +1,354 @@
package server_test
import (
"llamactl/pkg/config"
"llamactl/pkg/server"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// TestAuthMiddleware drives AuthMiddleware with a table of key configurations
// and request keys (sent as "Authorization: Bearer <key>") and checks the
// resulting status code. For 401 responses it also checks the JSON content
// type and the presence of the authentication_error type in the body.
func TestAuthMiddleware(t *testing.T) {
	tests := []struct {
		name           string
		keyType        server.KeyType
		inferenceKeys  []string
		managementKeys []string
		requestKey     string
		method         string
		expectedStatus int
	}{
		// Valid key tests
		{
			name:           "valid inference key for inference",
			keyType:        server.KeyTypeInference,
			inferenceKeys:  []string{"sk-inference-valid123"},
			requestKey:     "sk-inference-valid123",
			method:         "GET",
			expectedStatus: http.StatusOK,
		},
		{
			name:           "valid management key for inference", // Management keys work for inference
			keyType:        server.KeyTypeInference,
			managementKeys: []string{"sk-management-admin123"},
			requestKey:     "sk-management-admin123",
			method:         "GET",
			expectedStatus: http.StatusOK,
		},
		{
			name:           "valid management key for management",
			keyType:        server.KeyTypeManagement,
			managementKeys: []string{"sk-management-admin123"},
			requestKey:     "sk-management-admin123",
			method:         "GET",
			expectedStatus: http.StatusOK,
		},

		// Invalid key tests
		{
			name:           "inference key for management should fail",
			keyType:        server.KeyTypeManagement,
			inferenceKeys:  []string{"sk-inference-user123"},
			requestKey:     "sk-inference-user123",
			method:         "GET",
			expectedStatus: http.StatusUnauthorized,
		},
		{
			name:           "invalid inference key",
			keyType:        server.KeyTypeInference,
			inferenceKeys:  []string{"sk-inference-valid123"},
			requestKey:     "sk-inference-invalid",
			method:         "GET",
			expectedStatus: http.StatusUnauthorized,
		},
		{
			name:           "missing inference key",
			keyType:        server.KeyTypeInference,
			inferenceKeys:  []string{"sk-inference-valid123"},
			requestKey:     "",
			method:         "GET",
			expectedStatus: http.StatusUnauthorized,
		},
		{
			name:           "invalid management key",
			keyType:        server.KeyTypeManagement,
			managementKeys: []string{"sk-management-valid123"},
			requestKey:     "sk-management-invalid",
			method:         "GET",
			expectedStatus: http.StatusUnauthorized,
		},
		{
			name:           "missing management key",
			keyType:        server.KeyTypeManagement,
			managementKeys: []string{"sk-management-valid123"},
			requestKey:     "",
			method:         "GET",
			expectedStatus: http.StatusUnauthorized,
		},

		// OPTIONS requests should always pass
		{
			name:           "OPTIONS request bypasses inference auth",
			keyType:        server.KeyTypeInference,
			inferenceKeys:  []string{"sk-inference-valid123"},
			requestKey:     "",
			method:         "OPTIONS",
			expectedStatus: http.StatusOK,
		},
		{
			name:           "OPTIONS request bypasses management auth",
			keyType:        server.KeyTypeManagement,
			managementKeys: []string{"sk-management-valid123"},
			requestKey:     "",
			method:         "OPTIONS",
			expectedStatus: http.StatusOK,
		},

		// Cross-key-type validation
		{
			name:           "management key works for inference endpoint",
			keyType:        server.KeyTypeInference,
			inferenceKeys:  []string{},
			managementKeys: []string{"sk-management-admin"},
			requestKey:     "sk-management-admin",
			method:         "POST",
			expectedStatus: http.StatusOK,
		},
	}

	// okHandler stands in for the protected endpoint.
	okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			mw := server.NewAPIAuthMiddleware(config.AuthConfig{
				InferenceKeys:  tc.inferenceKeys,
				ManagementKeys: tc.managementKeys,
			})

			// Build the request, attaching the key only when the case has one.
			req := httptest.NewRequest(tc.method, "/test", nil)
			if tc.requestKey != "" {
				req.Header.Set("Authorization", "Bearer "+tc.requestKey)
			}

			// Anything that is not an inference case is guarded as management.
			guard := server.KeyTypeManagement
			if tc.keyType == server.KeyTypeInference {
				guard = server.KeyTypeInference
			}
			handler := mw.AuthMiddleware(guard)(okHandler)

			recorder := httptest.NewRecorder()
			handler.ServeHTTP(recorder, req)

			if recorder.Code != tc.expectedStatus {
				t.Errorf("AuthMiddleware() status = %v, expected %v", recorder.Code, tc.expectedStatus)
			}

			// Only unauthorized responses have a body format to verify.
			if recorder.Code != http.StatusUnauthorized {
				return
			}
			if contentType := recorder.Header().Get("Content-Type"); contentType != "application/json" {
				t.Errorf("Unauthorized response Content-Type = %v, expected application/json", contentType)
			}
			if body := recorder.Body.String(); !strings.Contains(body, `"type": "authentication_error"`) {
				t.Errorf("Unauthorized response missing proper error type: %v", body)
			}
		})
	}
}
// TestGenerateAPIKey builds two middlewares from a configuration that requires
// auth but supplies no keys, and checks that each rejects a request that
// carries no key. The generated keys themselves are not inspected.
func TestGenerateAPIKey(t *testing.T) {
	tests := []struct {
		name    string
		keyType server.KeyType
	}{
		{"inference key generation", server.KeyTypeInference},
		{"management key generation", server.KeyTypeManagement},
	}

	// statusWithoutKey sends a key-less GET / through mw's middleware for the
	// given key type and returns the recorded status code.
	statusWithoutKey := func(mw *server.APIAuthMiddleware, keyType server.KeyType) int {
		guard := server.KeyTypeManagement
		if keyType == server.KeyTypeInference {
			guard = server.KeyTypeInference
		}
		handler := mw.AuthMiddleware(guard)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
		}))
		recorder := httptest.NewRecorder()
		handler.ServeHTTP(recorder, httptest.NewRequest("GET", "/", nil))
		return recorder.Code
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// Require auth for the key type under test and leave its key list
			// empty so that the constructor has to generate a key.
			var cfg config.AuthConfig
			if tc.keyType == server.KeyTypeInference {
				cfg.RequireInferenceAuth = true
				cfg.InferenceKeys = []string{}
			} else {
				cfg.RequireManagementAuth = true
				cfg.ManagementKeys = []string{}
			}

			first := server.NewAPIAuthMiddleware(cfg)
			if code := statusWithoutKey(first, tc.keyType); code != http.StatusUnauthorized {
				t.Errorf("Expected unauthorized without key, got status %v", code)
			}

			// A second, independently constructed middleware must behave the same.
			second := server.NewAPIAuthMiddleware(cfg)
			if code := statusWithoutKey(second, tc.keyType); code != http.StatusUnauthorized {
				t.Errorf("Expected unauthorized for second middleware without key, got status %v", code)
			}
		})
	}
}
func TestAutoGeneration(t *testing.T) {
tests := []struct {
name string
requireInference bool
requireManagement bool
providedInference []string
providedManagement []string
shouldGenerateInf bool // Whether inference key should be generated
shouldGenerateMgmt bool // Whether management key should be generated
}{
{
name: "inference auth required, keys provided - no generation",
requireInference: true,
requireManagement: false,
providedInference: []string{"sk-inference-provided"},
providedManagement: []string{},
shouldGenerateInf: false,
shouldGenerateMgmt: false,
},
{
name: "inference auth required, no keys - should auto-generate",
requireInference: true,
requireManagement: false,
providedInference: []string{},
providedManagement: []string{},
shouldGenerateInf: true,
shouldGenerateMgmt: false,
},
{
name: "management auth required, keys provided - no generation",
requireInference: false,
requireManagement: true,
providedInference: []string{},
providedManagement: []string{"sk-management-provided"},
shouldGenerateInf: false,
shouldGenerateMgmt: false,
},
{
name: "management auth required, no keys - should auto-generate",
requireInference: false,
requireManagement: true,
providedInference: []string{},
providedManagement: []string{},
shouldGenerateInf: false,
shouldGenerateMgmt: true,
},
{
name: "both required, both provided - no generation",
requireInference: true,
requireManagement: true,
providedInference: []string{"sk-inference-provided"},
providedManagement: []string{"sk-management-provided"},
shouldGenerateInf: false,
shouldGenerateMgmt: false,
},
{
name: "both required, none provided - should auto-generate both",
requireInference: true,
requireManagement: true,
providedInference: []string{},
providedManagement: []string{},
shouldGenerateInf: true,
shouldGenerateMgmt: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := config.AuthConfig{
RequireInferenceAuth: tt.requireInference,
RequireManagementAuth: tt.requireManagement,
InferenceKeys: tt.providedInference,
ManagementKeys: tt.providedManagement,
}
middleware := server.NewAPIAuthMiddleware(cfg)
// Test inference behavior if inference auth is required
if tt.requireInference {
req := httptest.NewRequest("GET", "/v1/models", nil)
recorder := httptest.NewRecorder()
handler := middleware.AuthMiddleware(server.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
handler.ServeHTTP(recorder, req)
// Should always be unauthorized without a key (since middleware assumes auth is required)
if recorder.Code != http.StatusUnauthorized {
t.Errorf("Expected unauthorized for inference without key, got status %v", recorder.Code)
}
}
// Test management behavior if management auth is required
if tt.requireManagement {
req := httptest.NewRequest("GET", "/api/v1/instances", nil)
recorder := httptest.NewRecorder()
handler := middleware.AuthMiddleware(server.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
handler.ServeHTTP(recorder, req)
// Should always be unauthorized without a key (since middleware assumes auth is required)
if recorder.Code != http.StatusUnauthorized {
t.Errorf("Expected unauthorized for management without key, got status %v", recorder.Code)
}
}
})
}
}

13
pkg/server/openai.go Normal file
View File

@@ -0,0 +1,13 @@
package server
// OpenAIListInstancesResponse is the JSON body written by the
// OpenAIListInstances handler.
type OpenAIListInstancesResponse struct {
// Object is set to "list" by OpenAIListInstances.
Object string `json:"object"`
// Data holds one entry per instance returned by the instance manager.
Data []OpenAIInstance `json:"data"`
}
// OpenAIInstance describes one instance as an entry of
// OpenAIListInstancesResponse.Data.
type OpenAIInstance struct {
// ID is filled with the instance name.
ID string `json:"id"`
// Object is set to "model" by OpenAIListInstances.
Object string `json:"object"`
// Created is copied from the instance's Created field.
// NOTE(review): presumably a Unix timestamp — confirm against the instance package.
Created int64 `json:"created"`
// OwnedBy is set to "llamactl" by OpenAIListInstances.
OwnedBy string `json:"owned_by"`
}

100
pkg/server/routes.go Normal file
View File

@@ -0,0 +1,100 @@
package server
import (
"fmt"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/cors"
httpSwagger "github.com/swaggo/http-swagger"
_ "llamactl/docs"
"llamactl/webui"
)
// SetupRouter builds the chi router for the application.
//
// It installs request logging and CORS, optionally mounts the Swagger UI,
// registers the management API under /api/v1 and the OpenAI-compatible API
// under /v1 (each behind API-key auth when the corresponding Require*Auth
// option is set), and finally mounts the web UI. A web UI setup failure is
// printed and does not prevent the router from being returned.
func SetupRouter(handler *Handler) *chi.Mux {
	r := chi.NewRouter()
	r.Use(middleware.Logger)

	// Add CORS middleware
	r.Use(cors.Handler(cors.Options{
		AllowedOrigins: handler.cfg.Server.AllowedOrigins,
		AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
		// X-API-Key is listed because the auth middleware accepts keys sent in
		// that header; without it browsers block cross-origin requests that use it.
		AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token", "X-API-Key"},
		ExposedHeaders:   []string{"Link"},
		AllowCredentials: false,
		MaxAge:           300,
	}))

	// Add API authentication middleware. The constructor always returns a
	// usable value, so only the Require*Auth options decide whether it is used.
	authMiddleware := NewAPIAuthMiddleware(handler.cfg.Auth)

	if handler.cfg.Server.EnableSwagger {
		r.Get("/swagger/*", httpSwagger.Handler(
			httpSwagger.URL("/swagger/doc.json"),
		))
	}

	// Define routes
	r.Route("/api/v1", func(r chi.Router) {
		if handler.cfg.Auth.RequireManagementAuth {
			r.Use(authMiddleware.AuthMiddleware(KeyTypeManagement))
		}

		r.Route("/server", func(r chi.Router) {
			r.Get("/help", handler.HelpHandler())
			r.Get("/version", handler.VersionHandler())
			r.Get("/devices", handler.ListDevicesHandler())
		})

		// Instance management endpoints
		r.Route("/instances", func(r chi.Router) {
			r.Get("/", handler.ListInstances()) // List all instances

			r.Route("/{name}", func(r chi.Router) {
				// Instance management
				r.Get("/", handler.GetInstance())             // Get instance details
				r.Post("/", handler.CreateInstance())         // Create and start new instance
				r.Put("/", handler.UpdateInstance())          // Update instance configuration
				r.Delete("/", handler.DeleteInstance())       // Stop and remove instance
				r.Post("/start", handler.StartInstance())     // Start stopped instance
				r.Post("/stop", handler.StopInstance())       // Stop running instance
				r.Post("/restart", handler.RestartInstance()) // Restart instance
				r.Get("/logs", handler.GetInstanceLogs())     // Get instance logs

				// Llama.cpp server proxy endpoints (proxied to the actual llama.cpp server)
				r.Route("/proxy", func(r chi.Router) {
					r.HandleFunc("/*", handler.ProxyToInstance()) // Proxy all llama.cpp server requests
				})
			})
		})
	})

	r.Route("/v1", func(r chi.Router) {
		if handler.cfg.Auth.RequireInferenceAuth {
			r.Use(authMiddleware.AuthMiddleware(KeyTypeInference))
		}

		r.Get("/models", handler.OpenAIListInstances()) // List instances in OpenAI-compatible format

		// OpenAI-compatible proxy endpoint
		// Handles all POST requests to /v1/*, including:
		// - /v1/completions
		// - /v1/chat/completions
		// - /v1/embeddings
		// - /v1/rerank
		// - /v1/reranking
		// The instance/model to use is determined by the request body.
		r.Post("/*", handler.OpenAIProxy())
	})

	// Serve WebUI files
	if err := webui.SetupWebUI(r); err != nil {
		fmt.Printf("Failed to set up WebUI: %v\n", err)
	}

	return r
}