Files
llamactl/pkg/server/middleware.go

233 lines
6.9 KiB
Go
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package server
import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"os"
	"strings"
	"time"

	"llamactl/pkg/auth"
	"llamactl/pkg/config"
	"llamactl/pkg/database"
)
// contextKey is a dedicated string type for values this package stores in a
// context.Context, so its keys cannot collide with plain string keys set by
// other packages.
type contextKey string

// apiKeyContextKey is the context key under which InferenceAuthMiddleware
// stores the matched *auth.APIKey and from which CheckInstancePermission
// reads it back.
const apiKeyContextKey contextKey = "apiKey"
// APIAuthMiddleware authenticates HTTP requests against database-backed
// inference keys and config-based management keys.
type APIAuthMiddleware struct {
	// authStore is queried for active inference keys, last-used updates and
	// per-instance permissions. It may be nil; InferenceAuthMiddleware skips
	// the database lookup in that case.
	authStore database.AuthStore
	// requireInferenceAuth gates the database key lookup in
	// InferenceAuthMiddleware. requireManagementAuth is copied from the
	// configuration by NewAPIAuthMiddleware.
	requireInferenceAuth, requireManagementAuth bool
	// managementKeys is the set of accepted management keys, taken from the
	// configuration or auto-generated by NewAPIAuthMiddleware.
	managementKeys map[string]bool
}
// NewAPIAuthMiddleware builds an APIAuthMiddleware from authCfg and authStore.
//
// Every entry of authCfg.ManagementKeys is registered as a management key.
// When management auth is required but no key is configured, a single key is
// generated with auth.GenerateKey, registered, and printed to stdout together
// with a notice that it changes on restart. If generation fails, the failure
// is logged and a key derived from the process ID is used instead.
func NewAPIAuthMiddleware(authCfg config.AuthConfig, authStore database.AuthStore) *APIAuthMiddleware {
	const banner = "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"

	// Collect the configured management keys into a set.
	keySet := make(map[string]bool, len(authCfg.ManagementKeys))
	for _, configured := range authCfg.ManagementKeys {
		keySet[configured] = true
	}

	// Legacy behaviour: auto-generate a key when auth is required but the
	// configuration supplies none.
	if authCfg.RequireManagementAuth && len(authCfg.ManagementKeys) == 0 {
		generatedKey, err := auth.GenerateKey("llamactl-mgmt")
		if err != nil {
			log.Printf("Warning: Failed to generate management key: %v", err)
			// NOTE(review): this fallback only varies by PID, so it is far
			// easier to guess than a generated key.
			generatedKey = fmt.Sprintf("sk-management-fallback-%d", os.Getpid())
		}
		keySet[generatedKey] = true

		// Show the key once; it is not persisted anywhere by this function.
		fmt.Printf("%s\n⚠ MANAGEMENT AUTHENTICATION REQUIRED\n%s\n", banner, banner)
		fmt.Printf("🔑 Generated Management API Key:\n\n %s\n\n", generatedKey)
		fmt.Printf("%s\n⚠ IMPORTANT\n%s\n", banner, banner)
		fmt.Println("• This key is auto-generated and will change on restart")
		fmt.Println("• For production, add explicit keys to your configuration")
		fmt.Println("• Copy this key before it disappears from the terminal")
		fmt.Println(banner)
	}

	return &APIAuthMiddleware{
		authStore:             authStore,
		requireInferenceAuth:  authCfg.RequireInferenceAuth,
		requireManagementAuth: authCfg.RequireManagementAuth,
		managementKeys:        keySet,
	}
}
// InferenceAuthMiddleware returns middleware that authenticates requests to
// inference endpoints.
//
// OPTIONS requests are passed through without authentication. For all other
// requests an API key must be present. The key is first checked against the
// active keys in the auth store (only when inference auth is required and a
// store is configured); on a match the *auth.APIKey is attached to the
// request context for later permission checks. If no store key matches, the
// key must be a valid management key, in which case nothing is added to the
// context.
func (a *APIAuthMiddleware) InferenceAuthMiddleware() func(http.Handler) http.Handler {
	// touchKey records the last-used time of a key. It runs detached from the
	// request context so it is not cancelled when the request finishes.
	touchKey := func(keyID int) {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := a.authStore.TouchKey(ctx, keyID); err != nil {
			log.Printf("Failed to update last used timestamp for key %d: %v", keyID, err)
		}
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Presumably CORS preflight; forwarded unauthenticated.
			if r.Method == http.MethodOptions {
				next.ServeHTTP(w, r)
				return
			}

			presented := a.extractAPIKey(r)
			if presented == "" {
				a.unauthorized(w, "Missing API key")
				return
			}

			// Database-backed keys take precedence over management keys.
			var matched *auth.APIKey
			if a.requireInferenceAuth && a.authStore != nil {
				candidates, err := a.authStore.GetActiveKeys(r.Context())
				if err == nil {
					for _, candidate := range candidates {
						if !auth.VerifyKey(presented, candidate.KeyHash) {
							continue
						}
						matched = candidate
						// Best-effort, asynchronous last_used_at update.
						go touchKey(candidate.ID)
						break
					}
				} else {
					// A store failure is not fatal: the management key
					// check below still applies.
					log.Printf("Failed to get active inference keys: %v", err)
				}
			}

			if matched != nil {
				// Expose the key to downstream permission checks.
				r = r.WithContext(context.WithValue(r.Context(), apiKeyContextKey, matched))
			} else if !a.isValidManagementKey(presented) {
				a.unauthorized(w, "Invalid API key")
				return
			}
			// Reaching here with matched == nil means a management key was
			// used; the context intentionally carries no APIKey.
			next.ServeHTTP(w, r)
		})
	}
}
// ManagementAuthMiddleware returns middleware that authenticates requests to
// management endpoints.
//
// OPTIONS requests are passed through without authentication. Every other
// request must carry an API key that matches one of the management keys;
// otherwise a 401 response is written and the next handler is not called.
func (a *APIAuthMiddleware) ManagementAuthMiddleware() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.Method != http.MethodOptions {
				presented := a.extractAPIKey(r)
				switch {
				case presented == "":
					a.unauthorized(w, "Missing API key")
					return
				case !a.isValidManagementKey(presented):
					// isValidManagementKey compares in constant time.
					a.unauthorized(w, "Invalid API key")
					return
				}
			}
			next.ServeHTTP(w, r)
		})
	}
}
// CheckInstancePermission reports whether the key attached to ctx may use the
// instance identified by instanceID.
//
// It returns nil when ctx carries no *auth.APIKey (the request was
// authenticated with a management key), when the key's permission mode is
// auth.PermissionModeAllowAll, or when the auth store grants access. It
// returns an error when the store lookup fails or access is not granted.
func (a *APIAuthMiddleware) CheckInstancePermission(ctx context.Context, instanceID int) error {
	key, found := ctx.Value(apiKeyContextKey).(*auth.APIKey)
	// No key in context, or an allow-all key: nothing further to check.
	if !found || key.PermissionMode == auth.PermissionModeAllowAll {
		return nil
	}

	// Otherwise consult the per-instance permission table.
	allowed, err := a.authStore.HasPermission(ctx, key.ID, instanceID)
	switch {
	case err != nil:
		return fmt.Errorf("failed to check permission: %w", err)
	case !allowed:
		return fmt.Errorf("permission denied: key does not have access to this instance")
	}
	return nil
}
// extractAPIKey returns the API key carried by r, or "" if none is present.
//
// Sources are tried in this order, and the first non-empty value wins:
//  1. the Authorization header with the Bearer scheme ("Bearer sk-..."),
//  2. the X-API-Key header,
//  3. the api_key query parameter.
//
// The Bearer scheme name is matched case-insensitively, as HTTP auth schemes
// are case-insensitive, and surrounding whitespace around the token is
// dropped. An Authorization header with an empty bearer token, or with a
// different scheme, falls through to the remaining sources.
func (a *APIAuthMiddleware) extractAPIKey(r *http.Request) string {
	const bearerPrefix = "bearer "

	// The header value is named authHeader so it does not shadow the
	// imported auth package.
	authHeader := r.Header.Get("Authorization")
	if len(authHeader) >= len(bearerPrefix) &&
		strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
		if token := strings.TrimSpace(authHeader[len(bearerPrefix):]); token != "" {
			return token
		}
	}

	if headerKey := r.Header.Get("X-API-Key"); headerKey != "" {
		return headerKey
	}

	// Url.Query().Get returns "" when the parameter is absent, which is also
	// this function's "no key" result.
	return r.URL.Query().Get("api_key")
}
// isValidManagementKey reports whether providedKey equals one of the
// configured management keys.
//
// Each candidate is compared with subtle.ConstantTimeCompare, which yields 0
// for inputs of different length, so no separate length check is needed. All
// candidates are visited and the per-key results are OR-ed together.
func (a *APIAuthMiddleware) isValidManagementKey(providedKey string) bool {
	candidate := []byte(providedKey)
	matched := 0
	for known := range a.managementKeys {
		matched |= subtle.ConstantTimeCompare(candidate, []byte(known))
	}
	return matched == 1
}
// unauthorized writes a 401 JSON error response of type
// "authentication_error" carrying message.
func (a *APIAuthMiddleware) unauthorized(w http.ResponseWriter, message string) {
	a.writeAuthError(w, http.StatusUnauthorized, "authentication_error", message)
}

// forbidden writes a 403 JSON error response of type "permission_denied"
// carrying message.
func (a *APIAuthMiddleware) forbidden(w http.ResponseWriter, message string) {
	a.writeAuthError(w, http.StatusForbidden, "permission_denied", message)
}

// writeAuthError writes the shared error envelope
//
//	{"error": {"message": <message>, "type": <errType>}}
//
// with the given HTTP status and a JSON content type. message is encoded with
// encoding/json so quotes, backslashes and control characters cannot break
// the document; errType is expected to be a fixed identifier supplied by this
// file and is inserted verbatim.
func (a *APIAuthMiddleware) writeAuthError(w http.ResponseWriter, status int, errType, message string) {
	encodedMessage, err := json.Marshal(message)
	if err != nil {
		// Keep the response well-formed even if the message cannot be encoded.
		encodedMessage = []byte(`""`)
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)

	body := `{"error": {"message": ` + string(encodedMessage) + `, "type": "` + errType + `"}}`
	// The status line is already sent; a failed write means the client went
	// away and there is nothing further to report to it.
	_, _ = w.Write([]byte(body))
}