Files
llamactl/pkg/server/middleware.go

333 lines
9.5 KiB
Go
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package server
import (
"context"
"crypto/rand"
"crypto/subtle"
"encoding/hex"
"fmt"
"llamactl/pkg/auth"
"llamactl/pkg/config"
"llamactl/pkg/database"
"log"
"net/http"
"os"
"strings"
"time"
)
// KeyType identifies which class of API key an operation concerns. It selects
// the key prefix in generateAPIKey and the validation path in AuthMiddleware.
type KeyType int
const (
// KeyTypeInference selects inference keys ("sk-inference" prefix).
KeyTypeInference KeyType = iota
// KeyTypeManagement selects management keys ("sk-management" prefix).
KeyTypeManagement
)
// contextKey is a custom type for context keys to avoid collisions
// with context values set under plain string keys.
type contextKey string
const (
// apiKeyContextKey is the request-context key under which the inference
// middleware stores the matched *auth.APIKey. CheckInstancePermission reads
// it back; the value is absent when a management key authenticated the request.
apiKeyContextKey contextKey = "apiKey"
)
// APIAuthMiddleware authenticates HTTP requests by API key. Inference keys are
// verified against authStore; management keys are compared against an
// in-memory set built by NewAPIAuthMiddleware.
type APIAuthMiddleware struct {
// authStore supplies the active inference keys, records key usage and
// answers per-instance permission checks.
authStore database.AuthStore
// requireInferenceAuth gates the authStore lookup in the inference
// middleware; when false only management keys are accepted there.
requireInferenceAuth bool
// requireManagementAuth is copied from the config.
// NOTE(review): no method visible in this file reads it.
requireManagementAuth bool
// managementKeys is the set of accepted management keys: those listed in the
// config, or a single auto-generated key when none were configured.
managementKeys map[string]bool // Config-based management keys
}
// NewAPIAuthMiddleware builds an APIAuthMiddleware from authCfg and authStore.
//
// Every key in authCfg.ManagementKeys is accepted as a management key.
// Config-based inference keys are not used; if any are present a warning is
// logged. When management auth is required but no management key is
// configured, one key is generated, accepted, and printed to standard output
// together with a notice that it changes on every restart.
func NewAPIAuthMiddleware(authCfg config.AuthConfig, authStore database.AuthStore) *APIAuthMiddleware {
	const banner = "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"

	// Collect the configured management keys into a set.
	accepted := make(map[string]bool)
	for _, configured := range authCfg.ManagementKeys {
		accepted[configured] = true
	}

	// Inference keys now live in the database; tell the operator theirs are ignored.
	if len(authCfg.InferenceKeys) > 0 {
		log.Println("⚠️ Config-based inference keys are no longer supported and will be ignored.")
		log.Println(" Please create inference keys in web UI or via management API.")
	}

	// Legacy behaviour: auto-generate a management key when auth is required
	// but the config provides none, and show it once on the terminal.
	if authCfg.RequireManagementAuth && len(authCfg.ManagementKeys) == 0 {
		fresh := generateAPIKey(KeyTypeManagement)
		accepted[fresh] = true

		fmt.Printf("%s\n⚠ MANAGEMENT AUTHENTICATION REQUIRED\n%s\n", banner, banner)
		fmt.Printf("🔑 Generated Management API Key:\n\n %s\n\n", fresh)

		fmt.Printf("%s\n⚠ IMPORTANT\n%s\n", banner, banner)
		fmt.Println("• This key is auto-generated and will change on restart")
		fmt.Println("• For production, add explicit keys to your configuration")
		fmt.Println("• Copy this key before it disappears from the terminal")
		fmt.Println(banner)
	}

	middleware := &APIAuthMiddleware{
		authStore:             authStore,
		requireInferenceAuth:  authCfg.RequireInferenceAuth,
		requireManagementAuth: authCfg.RequireManagementAuth,
		managementKeys:        accepted,
	}
	return middleware
}
// generateAPIKey returns a new random API key for keyType.
//
// The key has the form "<prefix>-<64 hex characters>": the hex part encodes
// 32 bytes (256 bits) read from crypto/rand, and the prefix is "sk-inference",
// "sk-management", or "sk-unknown" for an unrecognised keyType.
//
// If the random source fails, the error is logged and the process exits.
// Returning a key built from predictable data (such as the process ID) would
// hand out a guessable credential, so failing closed is the only safe option.
func generateAPIKey(keyType KeyType) string {
	var prefix string
	switch keyType {
	case KeyTypeInference:
		prefix = "sk-inference"
	case KeyTypeManagement:
		prefix = "sk-management"
	default:
		prefix = "sk-unknown"
	}

	// 32 random bytes (256 bits) of key material.
	randomBytes := make([]byte, 32)
	if _, err := rand.Read(randomBytes); err != nil {
		// Never fall back to a weak key: an attacker could enumerate it.
		log.Printf("Fatal: failed to generate secure random key: %v", err)
		os.Exit(1)
	}

	// Hex-encode and attach the type prefix.
	return fmt.Sprintf("%s-%s", prefix, hex.EncodeToString(randomBytes))
}
// InferenceAuthMiddleware returns middleware that authenticates requests to
// inference endpoints.
//
// OPTIONS requests are passed through without authentication. For every other
// request an API key must be present. When inference auth is enabled the key
// is first verified against the active keys in the auth store; on a match the
// key record is attached to the request context (for CheckInstancePermission)
// and its last-used timestamp is updated in the background. If no stored key
// matches, or the store lookup fails, the key is accepted only if it is a
// valid management key; in that case nothing is added to the context.
func (a *APIAuthMiddleware) InferenceAuthMiddleware() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			if r.Method == "OPTIONS" {
				next.ServeHTTP(w, r)
				return
			}

			presented := a.extractAPIKey(r)
			if presented == "" {
				a.unauthorized(w, "Missing API key")
				return
			}

			// Database-backed inference keys take precedence.
			var matched *auth.APIKey
			if a.requireInferenceAuth {
				stored, err := a.authStore.GetActiveKeys(r.Context())
				if err != nil {
					// Not fatal: the management-key check below still applies.
					log.Printf("Failed to get active inference keys: %v", err)
				} else {
					for _, candidate := range stored {
						if !auth.VerifyKey(presented, candidate.KeyHash) {
							continue
						}
						matched = candidate
						// Record usage off the request path, with its own
						// deadline so it outlives the request context.
						go func(id int) {
							touchCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
							defer cancel()
							if touchErr := a.authStore.TouchKey(touchCtx, id); touchErr != nil {
								log.Printf("Failed to update last used timestamp for key %d: %v", id, touchErr)
							}
						}(candidate.ID)
						break
					}
				}
			}

			switch {
			case matched != nil:
				// Expose the key record for later permission checks.
				r = r.WithContext(context.WithValue(r.Context(), apiKeyContextKey, matched))
			case !a.isValidManagementKey(presented):
				a.unauthorized(w, "Invalid API key")
				return
			}
			// Either a stored key matched or a management key was used.
			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(serve)
	}
}
// ManagementAuthMiddleware returns middleware that authenticates requests to
// management endpoints.
//
// OPTIONS requests are passed through without authentication. Any other
// request must carry an API key that isValidManagementKey accepts; otherwise
// a 401 response is written and the wrapped handler is not called.
func (a *APIAuthMiddleware) ManagementAuthMiddleware() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.Method != "OPTIONS" {
				presented := a.extractAPIKey(r)

				// Work out why the request must be rejected, if at all.
				var reason string
				switch {
				case presented == "":
					reason = "Missing API key"
				case !a.isValidManagementKey(presented):
					reason = "Invalid API key"
				}
				if reason != "" {
					a.unauthorized(w, reason)
					return
				}
			}
			next.ServeHTTP(w, r)
		})
	}
}
// CheckInstancePermission reports whether the key that authenticated the
// request behind ctx may use the instance identified by instanceID.
//
// It returns nil when ctx carries no *auth.APIKey (a management key was
// used), when the key's permission mode is allow-all, or when the auth store
// grants the key access to the instance. A store error is returned unchanged.
// A denied permission is reported as http.ErrBodyNotAllowed.
//
// NOTE(review): http.ErrBodyNotAllowed is an unrelated net/http sentinel that
// is reused here as a generic "permission denied" marker; a dedicated error
// would be clearer, but callers may already compare against this value.
func (a *APIAuthMiddleware) CheckInstancePermission(ctx context.Context, instanceID int) error {
	key, present := ctx.Value(apiKeyContextKey).(*auth.APIKey)
	switch {
	case !present:
		// No key record in the context: a management key was used.
		return nil
	case key.PermissionMode == auth.PermissionModeAllowAll:
		return nil
	}

	// Restricted key: consult the per-instance permissions.
	allowed, err := a.authStore.HasPermission(ctx, key.ID, instanceID)
	if err != nil {
		return err
	}
	if allowed {
		return nil
	}
	return http.ErrBodyNotAllowed // Use this as a generic error to indicate permission denied
}
// AuthMiddleware returns a middleware that checks API keys for the given key
// type (legacy support).
//
// OPTIONS requests are passed through without authentication; every other
// request must carry an API key.
//
//   - KeyTypeInference: when inference auth is enabled the key is verified
//     against the active keys in the auth store. On a match the key record is
//     attached to the request context and its last-used timestamp is updated
//     in the background. If nothing matches (or the store lookup fails, which
//     is logged) the key is accepted only if it is a valid management key.
//   - KeyTypeManagement: the key must be a valid management key.
//   - any other key type: the request is rejected.
func (a *APIAuthMiddleware) AuthMiddleware(keyType KeyType) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.Method == http.MethodOptions {
				next.ServeHTTP(w, r)
				return
			}

			apiKey := a.extractAPIKey(r)
			if apiKey == "" {
				a.unauthorized(w, "Missing API key")
				return
			}

			var isValid bool
			switch keyType {
			case KeyTypeInference:
				// Try database authentication first.
				if a.requireInferenceAuth {
					activeKeys, err := a.authStore.GetActiveKeys(r.Context())
					if err != nil {
						// Surface store failures instead of hiding them; the
						// management-key fallback below still applies.
						log.Printf("Failed to get active inference keys: %v", err)
					} else {
						for _, key := range activeKeys {
							if !auth.VerifyKey(apiKey, key.KeyHash) {
								continue
							}
							// Async update of last_used_at, with its own
							// deadline independent of the request context.
							go func(keyID int) {
								ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
								defer cancel()
								if err := a.authStore.TouchKey(ctx, keyID); err != nil {
									log.Printf("Failed to update last used timestamp for key %d: %v", keyID, err)
								}
							}(key.ID)
							// Add the key record to the context for permission checking.
							r = r.WithContext(context.WithValue(r.Context(), apiKeyContextKey, key))
							isValid = true
							break
						}
					}
				}
				// If no database key matched, try management key (higher privilege).
				if !isValid {
					isValid = a.isValidManagementKey(apiKey)
				}
			case KeyTypeManagement:
				isValid = a.isValidManagementKey(apiKey)
			default:
				isValid = false
			}

			if !isValid {
				a.unauthorized(w, "Invalid API key")
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}
// extractAPIKey returns the API key carried by r, or "" if there is none.
//
// Sources are checked in order; the first non-empty value wins:
//  1. the Authorization header with the Bearer scheme ("Bearer sk-...");
//     the scheme name is matched case-insensitively and surrounding
//     whitespace around the token is ignored,
//  2. the X-API-Key header,
//  3. the api_key query parameter.
//
// An Authorization header that uses another scheme, or a Bearer header with
// an empty token, does not prevent the later sources from being used.
func (a *APIAuthMiddleware) extractAPIKey(r *http.Request) string {
	// Authorization: Bearer <token>. HTTP auth scheme names are
	// case-insensitive, so "bearer" and "BEARER" must be accepted too.
	if header := r.Header.Get("Authorization"); header != "" {
		if scheme, token, ok := strings.Cut(header, " "); ok && strings.EqualFold(scheme, "Bearer") {
			if token = strings.TrimSpace(token); token != "" {
				return token
			}
		}
	}

	// X-API-Key header.
	if apiKey := r.Header.Get("X-API-Key"); apiKey != "" {
		return apiKey
	}

	// api_key query parameter (Get yields "" when it is absent).
	return r.URL.Query().Get("api_key")
}
// isValidManagementKey reports whether providedKey equals one of the accepted
// management keys. Each candidate is compared with
// subtle.ConstantTimeCompare, which yields 0 for inputs of different length,
// so no separate length check is needed.
func (a *APIAuthMiddleware) isValidManagementKey(providedKey string) bool {
	presented := []byte(providedKey)
	for known := range a.managementKeys {
		if subtle.ConstantTimeCompare(presented, []byte(known)) == 1 {
			return true
		}
	}
	return false
}
// unauthorized writes a 401 Unauthorized response with a JSON error body of
// type "authentication_error" carrying message.
//
// message is inserted into the JSON text verbatim, so callers must pass text
// that needs no JSON escaping. A failed write to w is not reported.
func (a *APIAuthMiddleware) unauthorized(w http.ResponseWriter, message string) {
	const bodyFormat = `{"error": {"message": "%s", "type": "authentication_error"}}`

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusUnauthorized)
	fmt.Fprintf(w, bodyFormat, message)
}
// forbidden writes a 403 Forbidden response with a JSON error body of type
// "permission_denied" carrying message.
//
// message is inserted into the JSON text verbatim, so callers must pass text
// that needs no JSON escaping. A failed write to w is not reported.
func (a *APIAuthMiddleware) forbidden(w http.ResponseWriter, message string) {
	const bodyFormat = `{"error": {"message": "%s", "type": "permission_denied"}}`

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusForbidden)
	fmt.Fprintf(w, bodyFormat, message)
}