Implement api key auth

This commit is contained in:
2025-07-30 20:15:14 +02:00
parent 72ba008d1e
commit b3540d5b3e
3 changed files with 298 additions and 13 deletions

View File

@@ -14,6 +14,7 @@ import (
type Config struct { type Config struct {
Server ServerConfig `yaml:"server"` Server ServerConfig `yaml:"server"`
Instances InstancesConfig `yaml:"instances"` Instances InstancesConfig `yaml:"instances"`
Auth AuthConfig `yaml:"auth"`
} }
// ServerConfig contains HTTP server configuration // ServerConfig contains HTTP server configuration
@@ -26,6 +27,9 @@ type ServerConfig struct {
// Allowed origins for CORS (e.g., "http://localhost:3000") // Allowed origins for CORS (e.g., "http://localhost:3000")
AllowedOrigins []string `yaml:"allowed_origins"` AllowedOrigins []string `yaml:"allowed_origins"`
// Enable Swagger UI for API documentation
EnableSwagger bool `yaml:"enable_swagger"`
} }
// InstancesConfig contains instance management configuration // InstancesConfig contains instance management configuration
@@ -52,6 +56,22 @@ type InstancesConfig struct {
DefaultRestartDelay int `yaml:"default_restart_delay"` DefaultRestartDelay int `yaml:"default_restart_delay"`
} }
// AuthConfig contains authentication settings
type AuthConfig struct {
// Require authentication for OpenAI compatible inference endpoints
RequireInferenceAuth bool `yaml:"require_inference_auth"`
// List of keys for OpenAI compatible inference endpoints
InferenceKeys []string `yaml:"inference_keys"`
// Require authentication for management endpoints
RequireManagementAuth bool `yaml:"require_management_auth"`
// List of keys for management endpoints
ManagementKeys []string `yaml:"management_keys"`
}
// LoadConfig loads configuration with the following precedence: // LoadConfig loads configuration with the following precedence:
// 1. Hardcoded defaults // 1. Hardcoded defaults
// 2. Config file // 2. Config file
@@ -63,6 +83,7 @@ func LoadConfig(configPath string) (Config, error) {
Host: "0.0.0.0", Host: "0.0.0.0",
Port: 8080, Port: 8080,
AllowedOrigins: []string{"*"}, // Default to allow all origins AllowedOrigins: []string{"*"}, // Default to allow all origins
EnableSwagger: false,
}, },
Instances: InstancesConfig{ Instances: InstancesConfig{
PortRange: [2]int{8000, 9000}, PortRange: [2]int{8000, 9000},
@@ -73,6 +94,12 @@ func LoadConfig(configPath string) (Config, error) {
DefaultMaxRestarts: 3, DefaultMaxRestarts: 3,
DefaultRestartDelay: 5, DefaultRestartDelay: 5,
}, },
Auth: AuthConfig{
RequireInferenceAuth: true,
InferenceKeys: []string{},
RequireManagementAuth: true,
ManagementKeys: []string{},
},
} }
// 2. Load from config file // 2. Load from config file
@@ -121,6 +148,14 @@ func loadEnvVars(cfg *Config) {
cfg.Server.Port = p cfg.Server.Port = p
} }
} }
if allowedOrigins := os.Getenv("LLAMACTL_ALLOWED_ORIGINS"); allowedOrigins != "" {
cfg.Server.AllowedOrigins = strings.Split(allowedOrigins, ",")
}
if enableSwagger := os.Getenv("LLAMACTL_ENABLE_SWAGGER"); enableSwagger != "" {
if b, err := strconv.ParseBool(enableSwagger); err == nil {
cfg.Server.EnableSwagger = b
}
}
// Instance config // Instance config
if portRange := os.Getenv("LLAMACTL_INSTANCE_PORT_RANGE"); portRange != "" { if portRange := os.Getenv("LLAMACTL_INSTANCE_PORT_RANGE"); portRange != "" {
@@ -154,6 +189,23 @@ func loadEnvVars(cfg *Config) {
cfg.Instances.DefaultRestartDelay = seconds cfg.Instances.DefaultRestartDelay = seconds
} }
} }
// Auth config
if requireInferenceAuth := os.Getenv("LLAMACTL_REQUIRE_INFERENCE_AUTH"); requireInferenceAuth != "" {
if b, err := strconv.ParseBool(requireInferenceAuth); err == nil {
cfg.Auth.RequireInferenceAuth = b
}
}
if inferenceKeys := os.Getenv("LLAMACTL_INFERENCE_KEYS"); inferenceKeys != "" {
cfg.Auth.InferenceKeys = strings.Split(inferenceKeys, ",")
}
if requireManagementAuth := os.Getenv("LLAMACTL_REQUIRE_MANAGEMENT_AUTH"); requireManagementAuth != "" {
if b, err := strconv.ParseBool(requireManagementAuth); err == nil {
cfg.Auth.RequireManagementAuth = b
}
}
if managementKeys := os.Getenv("LLAMACTL_MANAGEMENT_KEYS"); managementKeys != "" {
cfg.Auth.ManagementKeys = strings.Split(managementKeys, ",")
}
} }
// ParsePortRange parses port range from string formats like "8000-9000" or "8000,9000" // ParsePortRange parses port range from string formats like "8000-9000" or "8000,9000"

215
pkg/middleware.go Normal file
View File

@@ -0,0 +1,215 @@
package llamactl
import (
"crypto/rand"
"crypto/subtle"
"encoding/hex"
"fmt"
"log"
"net/http"
"os"
"strings"
)
type KeyType int
const (
KeyTypeInference KeyType = iota
KeyTypeManagement
)
type APIAuthMiddleware struct {
requireInferenceAuth bool
inferenceKeys map[string]bool
requireManagementAuth bool
managementKeys map[string]bool
}
// NewAPIAuthMiddleware creates a new APIAuthMiddleware with the given configuration
func NewAPIAuthMiddleware(config AuthConfig) *APIAuthMiddleware {
var generated bool = false
inferenceAPIKeys := make(map[string]bool)
managementAPIKeys := make(map[string]bool)
const banner = "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
if config.RequireManagementAuth && len(config.ManagementKeys) == 0 {
key := generateAPIKey(KeyTypeManagement)
managementAPIKeys[key] = true
generated = true
fmt.Printf("%s\n⚠ MANAGEMENT AUTHENTICATION REQUIRED\n%s\n", banner, banner)
fmt.Printf("🔑 Generated Management API Key:\n\n %s\n\n", key)
}
for _, key := range config.ManagementKeys {
managementAPIKeys[key] = true
}
if config.RequireInferenceAuth && len(config.InferenceKeys) == 0 {
key := generateAPIKey(KeyTypeInference)
inferenceAPIKeys[key] = true
generated = true
fmt.Printf("%s\n⚠ INFERENCE AUTHENTICATION REQUIRED\n%s\n", banner, banner)
fmt.Printf("🔑 Generated Inference API Key:\n\n %s\n\n", key)
}
for _, key := range config.InferenceKeys {
inferenceAPIKeys[key] = true
}
if generated {
fmt.Printf("%s\n⚠ IMPORTANT\n%s\n", banner, banner)
fmt.Println("• These keys are auto-generated and will change on restart")
fmt.Println("• For production, add explicit keys to your configuration")
fmt.Println("• Copy these keys before they disappear from the terminal")
fmt.Println(banner)
}
return &APIAuthMiddleware{
requireInferenceAuth: config.RequireInferenceAuth,
inferenceKeys: inferenceAPIKeys,
requireManagementAuth: config.RequireManagementAuth,
managementKeys: managementAPIKeys,
}
}
// generateAPIKey creates a cryptographically secure API key
func generateAPIKey(keyType KeyType) string {
// Generate 32 random bytes (256 bits)
randomBytes := make([]byte, 32)
var prefix string
switch keyType {
case KeyTypeInference:
prefix = "sk-inference"
case KeyTypeManagement:
prefix = "sk-management"
default:
prefix = "sk-unknown"
}
if _, err := rand.Read(randomBytes); err != nil {
log.Printf("Warning: Failed to generate secure random key, using fallback")
// Fallback to a less secure method if crypto/rand fails
return fmt.Sprintf("%s-fallback-%d", prefix, os.Getpid())
}
// Convert to hex and add prefix
return fmt.Sprintf("%s-%s", prefix, hex.EncodeToString(randomBytes))
}
// InferenceMiddleware returns middleware for OpenAI inference endpoints
func (a *APIAuthMiddleware) InferenceMiddleware() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !a.requireInferenceAuth {
next.ServeHTTP(w, r)
return
}
if r.Method == "OPTIONS" {
next.ServeHTTP(w, r)
return
}
apiKey := a.extractAPIKey(r)
if apiKey == "" {
a.unauthorized(w, "Missing API key")
return
}
// Check if key is valid for OpenAI access
// Management keys also work for OpenAI endpoints (higher privilege)
if !a.isValidKey(apiKey, KeyTypeInference) && !a.isValidKey(apiKey, KeyTypeManagement) {
a.unauthorized(w, "Invalid API key")
return
}
next.ServeHTTP(w, r)
})
}
}
// ManagementMiddleware returns middleware for management endpoints
func (a *APIAuthMiddleware) ManagementMiddleware() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !a.requireManagementAuth {
next.ServeHTTP(w, r)
return
}
if r.Method == "OPTIONS" {
next.ServeHTTP(w, r)
return
}
apiKey := a.extractAPIKey(r)
if apiKey == "" {
a.unauthorized(w, "Missing API key")
return
}
// Only management keys work for management endpoints
if !a.isValidKey(apiKey, KeyTypeManagement) {
a.unauthorized(w, "Insufficient privileges - management key required")
return
}
next.ServeHTTP(w, r)
})
}
}
// extractAPIKey extracts the API key from the request
func (a *APIAuthMiddleware) extractAPIKey(r *http.Request) string {
// Check Authorization header: "Bearer sk-..."
if auth := r.Header.Get("Authorization"); auth != "" {
if strings.HasPrefix(auth, "Bearer ") {
return strings.TrimPrefix(auth, "Bearer ")
}
}
// Check X-API-Key header
if apiKey := r.Header.Get("X-API-Key"); apiKey != "" {
return apiKey
}
// Check query parameter
if apiKey := r.URL.Query().Get("api_key"); apiKey != "" {
return apiKey
}
return ""
}
// isValidKey checks if the provided API key is valid for the given key type
func (a *APIAuthMiddleware) isValidKey(providedKey string, keyType KeyType) bool {
var validKeys map[string]bool
switch keyType {
case KeyTypeInference:
validKeys = a.inferenceKeys
case KeyTypeManagement:
validKeys = a.managementKeys
default:
return false
}
for validKey := range validKeys {
if len(providedKey) == len(validKey) &&
subtle.ConstantTimeCompare([]byte(providedKey), []byte(validKey)) == 1 {
return true
}
}
return false
}
// unauthorized sends an unauthorized response
func (a *APIAuthMiddleware) unauthorized(w http.ResponseWriter, message string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
response := fmt.Sprintf(`{"error": {"message": "%s", "type": "authentication_error"}}`, message)
w.Write([]byte(response))
}

View File

@@ -26,12 +26,22 @@ func SetupRouter(handler *Handler) *chi.Mux {
MaxAge: 300, MaxAge: 300,
})) }))
// Add API authentication middleware
authMiddleware := NewAPIAuthMiddleware(handler.config.Auth)
if handler.config.Server.EnableSwagger {
r.Get("/swagger/*", httpSwagger.Handler( r.Get("/swagger/*", httpSwagger.Handler(
httpSwagger.URL("/swagger/doc.json"), httpSwagger.URL("/swagger/doc.json"),
)) ))
}
// Define routes // Define routes
r.Route("/api/v1", func(r chi.Router) { r.Route("/api/v1", func(r chi.Router) {
if authMiddleware != nil && handler.config.Auth.RequireManagementAuth {
r.Use(authMiddleware.ManagementMiddleware())
}
r.Route("/server", func(r chi.Router) { r.Route("/server", func(r chi.Router) {
r.Get("/help", handler.HelpHandler()) r.Get("/help", handler.HelpHandler())
r.Get("/version", handler.VersionHandler()) r.Get("/version", handler.VersionHandler())
@@ -61,7 +71,13 @@ func SetupRouter(handler *Handler) *chi.Mux {
}) })
}) })
r.Get(("/v1/models"), handler.OpenAIListInstances()) // List instances in OpenAI-compatible format r.Route(("/v1"), func(r chi.Router) {
if authMiddleware != nil && handler.config.Auth.RequireInferenceAuth {
r.Use(authMiddleware.InferenceMiddleware())
}
r.Get(("/models"), handler.OpenAIListInstances()) // List instances in OpenAI-compatible format
// OpenAI-compatible proxy endpoint // OpenAI-compatible proxy endpoint
// Handles all POST requests to /v1/*, including: // Handles all POST requests to /v1/*, including:
@@ -71,7 +87,9 @@ func SetupRouter(handler *Handler) *chi.Mux {
// - /v1/rerank // - /v1/rerank
// - /v1/reranking // - /v1/reranking
// The instance/model to use is determined by the request body. // The instance/model to use is determined by the request body.
r.Post("/v1/*", handler.OpenAIProxy()) r.Post("/*", handler.OpenAIProxy())
})
// Serve WebUI files // Serve WebUI files
if err := webui.SetupWebUI(r); err != nil { if err := webui.SetupWebUI(r); err != nil {