mirror of
https://github.com/lordmathis/llamactl.git
synced 2025-11-06 00:54:23 +00:00
Implement api key auth
This commit is contained in:
@@ -14,6 +14,7 @@ import (
|
||||
type Config struct {
|
||||
Server ServerConfig `yaml:"server"`
|
||||
Instances InstancesConfig `yaml:"instances"`
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
}
|
||||
|
||||
// ServerConfig contains HTTP server configuration
|
||||
@@ -26,6 +27,9 @@ type ServerConfig struct {
|
||||
|
||||
// Allowed origins for CORS (e.g., "http://localhost:3000")
|
||||
AllowedOrigins []string `yaml:"allowed_origins"`
|
||||
|
||||
// Enable Swagger UI for API documentation
|
||||
EnableSwagger bool `yaml:"enable_swagger"`
|
||||
}
|
||||
|
||||
// InstancesConfig contains instance management configuration
|
||||
@@ -52,6 +56,22 @@ type InstancesConfig struct {
|
||||
DefaultRestartDelay int `yaml:"default_restart_delay"`
|
||||
}
|
||||
|
||||
// AuthConfig contains authentication settings
|
||||
type AuthConfig struct {
|
||||
|
||||
// Require authentication for OpenAI compatible inference endpoints
|
||||
RequireInferenceAuth bool `yaml:"require_inference_auth"`
|
||||
|
||||
// List of keys for OpenAI compatible inference endpoints
|
||||
InferenceKeys []string `yaml:"inference_keys"`
|
||||
|
||||
// Require authentication for management endpoints
|
||||
RequireManagementAuth bool `yaml:"require_management_auth"`
|
||||
|
||||
// List of keys for management endpoints
|
||||
ManagementKeys []string `yaml:"management_keys"`
|
||||
}
|
||||
|
||||
// LoadConfig loads configuration with the following precedence:
|
||||
// 1. Hardcoded defaults
|
||||
// 2. Config file
|
||||
@@ -63,6 +83,7 @@ func LoadConfig(configPath string) (Config, error) {
|
||||
Host: "0.0.0.0",
|
||||
Port: 8080,
|
||||
AllowedOrigins: []string{"*"}, // Default to allow all origins
|
||||
EnableSwagger: false,
|
||||
},
|
||||
Instances: InstancesConfig{
|
||||
PortRange: [2]int{8000, 9000},
|
||||
@@ -73,6 +94,12 @@ func LoadConfig(configPath string) (Config, error) {
|
||||
DefaultMaxRestarts: 3,
|
||||
DefaultRestartDelay: 5,
|
||||
},
|
||||
Auth: AuthConfig{
|
||||
RequireInferenceAuth: true,
|
||||
InferenceKeys: []string{},
|
||||
RequireManagementAuth: true,
|
||||
ManagementKeys: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
// 2. Load from config file
|
||||
@@ -121,6 +148,14 @@ func loadEnvVars(cfg *Config) {
|
||||
cfg.Server.Port = p
|
||||
}
|
||||
}
|
||||
if allowedOrigins := os.Getenv("LLAMACTL_ALLOWED_ORIGINS"); allowedOrigins != "" {
|
||||
cfg.Server.AllowedOrigins = strings.Split(allowedOrigins, ",")
|
||||
}
|
||||
if enableSwagger := os.Getenv("LLAMACTL_ENABLE_SWAGGER"); enableSwagger != "" {
|
||||
if b, err := strconv.ParseBool(enableSwagger); err == nil {
|
||||
cfg.Server.EnableSwagger = b
|
||||
}
|
||||
}
|
||||
|
||||
// Instance config
|
||||
if portRange := os.Getenv("LLAMACTL_INSTANCE_PORT_RANGE"); portRange != "" {
|
||||
@@ -154,6 +189,23 @@ func loadEnvVars(cfg *Config) {
|
||||
cfg.Instances.DefaultRestartDelay = seconds
|
||||
}
|
||||
}
|
||||
// Auth config
|
||||
if requireInferenceAuth := os.Getenv("LLAMACTL_REQUIRE_INFERENCE_AUTH"); requireInferenceAuth != "" {
|
||||
if b, err := strconv.ParseBool(requireInferenceAuth); err == nil {
|
||||
cfg.Auth.RequireInferenceAuth = b
|
||||
}
|
||||
}
|
||||
if inferenceKeys := os.Getenv("LLAMACTL_INFERENCE_KEYS"); inferenceKeys != "" {
|
||||
cfg.Auth.InferenceKeys = strings.Split(inferenceKeys, ",")
|
||||
}
|
||||
if requireManagementAuth := os.Getenv("LLAMACTL_REQUIRE_MANAGEMENT_AUTH"); requireManagementAuth != "" {
|
||||
if b, err := strconv.ParseBool(requireManagementAuth); err == nil {
|
||||
cfg.Auth.RequireManagementAuth = b
|
||||
}
|
||||
}
|
||||
if managementKeys := os.Getenv("LLAMACTL_MANAGEMENT_KEYS"); managementKeys != "" {
|
||||
cfg.Auth.ManagementKeys = strings.Split(managementKeys, ",")
|
||||
}
|
||||
}
|
||||
|
||||
// ParsePortRange parses port range from string formats like "8000-9000" or "8000,9000"
|
||||
|
||||
215
pkg/middleware.go
Normal file
215
pkg/middleware.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package llamactl
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type KeyType int
|
||||
|
||||
const (
|
||||
KeyTypeInference KeyType = iota
|
||||
KeyTypeManagement
|
||||
)
|
||||
|
||||
type APIAuthMiddleware struct {
|
||||
requireInferenceAuth bool
|
||||
inferenceKeys map[string]bool
|
||||
requireManagementAuth bool
|
||||
managementKeys map[string]bool
|
||||
}
|
||||
|
||||
// NewAPIAuthMiddleware creates a new APIAuthMiddleware with the given configuration
|
||||
func NewAPIAuthMiddleware(config AuthConfig) *APIAuthMiddleware {
|
||||
|
||||
var generated bool = false
|
||||
|
||||
inferenceAPIKeys := make(map[string]bool)
|
||||
managementAPIKeys := make(map[string]bool)
|
||||
|
||||
const banner = "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
|
||||
if config.RequireManagementAuth && len(config.ManagementKeys) == 0 {
|
||||
key := generateAPIKey(KeyTypeManagement)
|
||||
managementAPIKeys[key] = true
|
||||
generated = true
|
||||
fmt.Printf("%s\n⚠️ MANAGEMENT AUTHENTICATION REQUIRED\n%s\n", banner, banner)
|
||||
fmt.Printf("🔑 Generated Management API Key:\n\n %s\n\n", key)
|
||||
}
|
||||
for _, key := range config.ManagementKeys {
|
||||
managementAPIKeys[key] = true
|
||||
}
|
||||
|
||||
if config.RequireInferenceAuth && len(config.InferenceKeys) == 0 {
|
||||
key := generateAPIKey(KeyTypeInference)
|
||||
inferenceAPIKeys[key] = true
|
||||
generated = true
|
||||
fmt.Printf("%s\n⚠️ INFERENCE AUTHENTICATION REQUIRED\n%s\n", banner, banner)
|
||||
fmt.Printf("🔑 Generated Inference API Key:\n\n %s\n\n", key)
|
||||
}
|
||||
for _, key := range config.InferenceKeys {
|
||||
inferenceAPIKeys[key] = true
|
||||
}
|
||||
|
||||
if generated {
|
||||
fmt.Printf("%s\n⚠️ IMPORTANT\n%s\n", banner, banner)
|
||||
fmt.Println("• These keys are auto-generated and will change on restart")
|
||||
fmt.Println("• For production, add explicit keys to your configuration")
|
||||
fmt.Println("• Copy these keys before they disappear from the terminal")
|
||||
fmt.Println(banner)
|
||||
}
|
||||
|
||||
return &APIAuthMiddleware{
|
||||
requireInferenceAuth: config.RequireInferenceAuth,
|
||||
inferenceKeys: inferenceAPIKeys,
|
||||
requireManagementAuth: config.RequireManagementAuth,
|
||||
managementKeys: managementAPIKeys,
|
||||
}
|
||||
}
|
||||
|
||||
// generateAPIKey creates a cryptographically secure API key
|
||||
func generateAPIKey(keyType KeyType) string {
|
||||
// Generate 32 random bytes (256 bits)
|
||||
randomBytes := make([]byte, 32)
|
||||
|
||||
var prefix string
|
||||
|
||||
switch keyType {
|
||||
case KeyTypeInference:
|
||||
prefix = "sk-inference"
|
||||
case KeyTypeManagement:
|
||||
prefix = "sk-management"
|
||||
default:
|
||||
prefix = "sk-unknown"
|
||||
}
|
||||
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
log.Printf("Warning: Failed to generate secure random key, using fallback")
|
||||
// Fallback to a less secure method if crypto/rand fails
|
||||
return fmt.Sprintf("%s-fallback-%d", prefix, os.Getpid())
|
||||
}
|
||||
|
||||
// Convert to hex and add prefix
|
||||
return fmt.Sprintf("%s-%s", prefix, hex.EncodeToString(randomBytes))
|
||||
}
|
||||
|
||||
// InferenceMiddleware returns middleware for OpenAI inference endpoints
|
||||
func (a *APIAuthMiddleware) InferenceMiddleware() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !a.requireInferenceAuth {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method == "OPTIONS" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
apiKey := a.extractAPIKey(r)
|
||||
if apiKey == "" {
|
||||
a.unauthorized(w, "Missing API key")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if key is valid for OpenAI access
|
||||
// Management keys also work for OpenAI endpoints (higher privilege)
|
||||
if !a.isValidKey(apiKey, KeyTypeInference) && !a.isValidKey(apiKey, KeyTypeManagement) {
|
||||
a.unauthorized(w, "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ManagementMiddleware returns middleware for management endpoints
|
||||
func (a *APIAuthMiddleware) ManagementMiddleware() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !a.requireManagementAuth {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method == "OPTIONS" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
apiKey := a.extractAPIKey(r)
|
||||
if apiKey == "" {
|
||||
a.unauthorized(w, "Missing API key")
|
||||
return
|
||||
}
|
||||
|
||||
// Only management keys work for management endpoints
|
||||
if !a.isValidKey(apiKey, KeyTypeManagement) {
|
||||
a.unauthorized(w, "Insufficient privileges - management key required")
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// extractAPIKey extracts the API key from the request
|
||||
func (a *APIAuthMiddleware) extractAPIKey(r *http.Request) string {
|
||||
// Check Authorization header: "Bearer sk-..."
|
||||
if auth := r.Header.Get("Authorization"); auth != "" {
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
return strings.TrimPrefix(auth, "Bearer ")
|
||||
}
|
||||
}
|
||||
|
||||
// Check X-API-Key header
|
||||
if apiKey := r.Header.Get("X-API-Key"); apiKey != "" {
|
||||
return apiKey
|
||||
}
|
||||
|
||||
// Check query parameter
|
||||
if apiKey := r.URL.Query().Get("api_key"); apiKey != "" {
|
||||
return apiKey
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// isValidKey checks if the provided API key is valid for the given key type
|
||||
func (a *APIAuthMiddleware) isValidKey(providedKey string, keyType KeyType) bool {
|
||||
var validKeys map[string]bool
|
||||
|
||||
switch keyType {
|
||||
case KeyTypeInference:
|
||||
validKeys = a.inferenceKeys
|
||||
case KeyTypeManagement:
|
||||
validKeys = a.managementKeys
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
for validKey := range validKeys {
|
||||
if len(providedKey) == len(validKey) &&
|
||||
subtle.ConstantTimeCompare([]byte(providedKey), []byte(validKey)) == 1 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// unauthorized sends an unauthorized response
|
||||
func (a *APIAuthMiddleware) unauthorized(w http.ResponseWriter, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
response := fmt.Sprintf(`{"error": {"message": "%s", "type": "authentication_error"}}`, message)
|
||||
w.Write([]byte(response))
|
||||
}
|
||||
@@ -26,12 +26,22 @@ func SetupRouter(handler *Handler) *chi.Mux {
|
||||
MaxAge: 300,
|
||||
}))
|
||||
|
||||
// Add API authentication middleware
|
||||
authMiddleware := NewAPIAuthMiddleware(handler.config.Auth)
|
||||
|
||||
if handler.config.Server.EnableSwagger {
|
||||
r.Get("/swagger/*", httpSwagger.Handler(
|
||||
httpSwagger.URL("/swagger/doc.json"),
|
||||
))
|
||||
}
|
||||
|
||||
// Define routes
|
||||
r.Route("/api/v1", func(r chi.Router) {
|
||||
|
||||
if authMiddleware != nil && handler.config.Auth.RequireManagementAuth {
|
||||
r.Use(authMiddleware.ManagementMiddleware())
|
||||
}
|
||||
|
||||
r.Route("/server", func(r chi.Router) {
|
||||
r.Get("/help", handler.HelpHandler())
|
||||
r.Get("/version", handler.VersionHandler())
|
||||
@@ -61,7 +71,13 @@ func SetupRouter(handler *Handler) *chi.Mux {
|
||||
})
|
||||
})
|
||||
|
||||
r.Get(("/v1/models"), handler.OpenAIListInstances()) // List instances in OpenAI-compatible format
|
||||
r.Route(("/v1"), func(r chi.Router) {
|
||||
|
||||
if authMiddleware != nil && handler.config.Auth.RequireInferenceAuth {
|
||||
r.Use(authMiddleware.InferenceMiddleware())
|
||||
}
|
||||
|
||||
r.Get(("/models"), handler.OpenAIListInstances()) // List instances in OpenAI-compatible format
|
||||
|
||||
// OpenAI-compatible proxy endpoint
|
||||
// Handles all POST requests to /v1/*, including:
|
||||
@@ -71,7 +87,9 @@ func SetupRouter(handler *Handler) *chi.Mux {
|
||||
// - /v1/rerank
|
||||
// - /v1/reranking
|
||||
// The instance/model to use is determined by the request body.
|
||||
r.Post("/v1/*", handler.OpenAIProxy())
|
||||
r.Post("/*", handler.OpenAIProxy())
|
||||
|
||||
})
|
||||
|
||||
// Serve WebUI files
|
||||
if err := webui.SetupWebUI(r); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user