mirror of
https://github.com/lordmathis/llamactl.git
synced 2025-12-22 09:04:22 +00:00
233 lines
6.9 KiB
Go
233 lines
6.9 KiB
Go
package server
|
||
|
||
import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"os"
	"strings"
	"time"

	"llamactl/pkg/auth"
	"llamactl/pkg/config"
	"llamactl/pkg/database"
)
|
||
|
||
// contextKey is a dedicated string type for values this package stores in a
// context.Context, so its keys cannot collide with plain string keys or with
// keys defined by other packages.
type contextKey string

// apiKeyContextKey is the context key under which the authenticated
// *auth.APIKey is stored for later permission checks.
const apiKeyContextKey contextKey = "apiKey"
|
||
|
||
type APIAuthMiddleware struct {
|
||
authStore database.AuthStore
|
||
requireInferenceAuth bool
|
||
requireManagementAuth bool
|
||
managementKeys map[string]bool // Config-based management keys
|
||
}
|
||
|
||
// NewAPIAuthMiddleware creates a new APIAuthMiddleware with the given configuration
|
||
func NewAPIAuthMiddleware(authCfg config.AuthConfig, authStore database.AuthStore) *APIAuthMiddleware {
|
||
// Load management keys from config into managementKeys map
|
||
managementKeys := make(map[string]bool)
|
||
for _, key := range authCfg.ManagementKeys {
|
||
managementKeys[key] = true
|
||
}
|
||
|
||
// Handle legacy auto-generation for management keys if none provided and auth is required
|
||
var generated bool = false
|
||
const banner = "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||
|
||
if authCfg.RequireManagementAuth && len(authCfg.ManagementKeys) == 0 {
|
||
key, err := auth.GenerateKey("llamactl-mgmt")
|
||
if err != nil {
|
||
log.Printf("Warning: Failed to generate management key: %v", err)
|
||
// Fallback to PID-based key for safety
|
||
key = fmt.Sprintf("sk-management-fallback-%d", os.Getpid())
|
||
}
|
||
managementKeys[key] = true
|
||
generated = true
|
||
fmt.Printf("%s\n⚠️ MANAGEMENT AUTHENTICATION REQUIRED\n%s\n", banner, banner)
|
||
fmt.Printf("🔑 Generated Management API Key:\n\n %s\n\n", key)
|
||
}
|
||
|
||
if generated {
|
||
fmt.Printf("%s\n⚠️ IMPORTANT\n%s\n", banner, banner)
|
||
fmt.Println("• This key is auto-generated and will change on restart")
|
||
fmt.Println("• For production, add explicit keys to your configuration")
|
||
fmt.Println("• Copy this key before it disappears from the terminal")
|
||
fmt.Println(banner)
|
||
}
|
||
|
||
return &APIAuthMiddleware{
|
||
authStore: authStore,
|
||
requireInferenceAuth: authCfg.RequireInferenceAuth,
|
||
requireManagementAuth: authCfg.RequireManagementAuth,
|
||
managementKeys: managementKeys,
|
||
}
|
||
}
|
||
|
||
// InferenceAuthMiddleware returns middleware for inference endpoints
|
||
func (a *APIAuthMiddleware) InferenceAuthMiddleware() func(http.Handler) http.Handler {
|
||
return func(next http.Handler) http.Handler {
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method == "OPTIONS" {
|
||
next.ServeHTTP(w, r)
|
||
return
|
||
}
|
||
|
||
// Extract API key from request
|
||
apiKey := a.extractAPIKey(r)
|
||
if apiKey == "" {
|
||
a.unauthorized(w, "Missing API key")
|
||
return
|
||
}
|
||
|
||
// Try database authentication first
|
||
var foundKey *auth.APIKey
|
||
if a.requireInferenceAuth && a.authStore != nil {
|
||
activeKeys, err := a.authStore.GetActiveKeys(r.Context())
|
||
if err != nil {
|
||
log.Printf("Failed to get active inference keys: %v", err)
|
||
// Continue to management key fallback
|
||
} else {
|
||
for _, key := range activeKeys {
|
||
if auth.VerifyKey(apiKey, key.KeyHash) {
|
||
foundKey = key
|
||
// Async update last_used_at
|
||
go func(keyID int) {
|
||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
if err := a.authStore.TouchKey(ctx, keyID); err != nil {
|
||
log.Printf("Failed to update last used timestamp for key %d: %v", keyID, err)
|
||
}
|
||
}(key.ID)
|
||
break
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// If no database key found, try management key authentication (config-based)
|
||
if foundKey == nil {
|
||
if !a.isValidManagementKey(apiKey) {
|
||
a.unauthorized(w, "Invalid API key")
|
||
return
|
||
}
|
||
// Management key was used, continue without adding APIKey to context
|
||
} else {
|
||
// Add APIKey to context for permission checking
|
||
ctx := context.WithValue(r.Context(), apiKeyContextKey, foundKey)
|
||
r = r.WithContext(ctx)
|
||
}
|
||
|
||
next.ServeHTTP(w, r)
|
||
})
|
||
}
|
||
}
|
||
|
||
// ManagementAuthMiddleware returns middleware for management endpoints
|
||
func (a *APIAuthMiddleware) ManagementAuthMiddleware() func(http.Handler) http.Handler {
|
||
return func(next http.Handler) http.Handler {
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method == "OPTIONS" {
|
||
next.ServeHTTP(w, r)
|
||
return
|
||
}
|
||
|
||
// Extract API key from request
|
||
apiKey := a.extractAPIKey(r)
|
||
if apiKey == "" {
|
||
a.unauthorized(w, "Missing API key")
|
||
return
|
||
}
|
||
|
||
// Check if key exists in managementKeys map using constant-time comparison
|
||
if !a.isValidManagementKey(apiKey) {
|
||
a.unauthorized(w, "Invalid API key")
|
||
return
|
||
}
|
||
|
||
next.ServeHTTP(w, r)
|
||
})
|
||
}
|
||
}
|
||
|
||
// CheckInstancePermission checks if the authenticated key has permission for the instance
|
||
func (a *APIAuthMiddleware) CheckInstancePermission(ctx context.Context, instanceID int) error {
|
||
// Extract APIKey from context
|
||
apiKey, ok := ctx.Value(apiKeyContextKey).(*auth.APIKey)
|
||
if !ok {
|
||
// APIKey is nil, management key was used, allow all
|
||
return nil
|
||
}
|
||
|
||
// If permission_mode == "allow_all", allow all
|
||
if apiKey.PermissionMode == auth.PermissionModeAllowAll {
|
||
return nil
|
||
}
|
||
|
||
// Check per-instance permissions
|
||
canInfer, err := a.authStore.HasPermission(ctx, apiKey.ID, instanceID)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to check permission: %w", err)
|
||
}
|
||
|
||
if !canInfer {
|
||
return fmt.Errorf("permission denied: key does not have access to this instance")
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// extractAPIKey extracts the API key from the request
|
||
func (a *APIAuthMiddleware) extractAPIKey(r *http.Request) string {
|
||
// Check Authorization header: "Bearer sk-..."
|
||
if auth := r.Header.Get("Authorization"); auth != "" {
|
||
if after, ok := strings.CutPrefix(auth, "Bearer "); ok {
|
||
return after
|
||
}
|
||
}
|
||
|
||
// Check X-API-Key header
|
||
if apiKey := r.Header.Get("X-API-Key"); apiKey != "" {
|
||
return apiKey
|
||
}
|
||
|
||
// Check query parameter
|
||
if apiKey := r.URL.Query().Get("api_key"); apiKey != "" {
|
||
return apiKey
|
||
}
|
||
|
||
return ""
|
||
}
|
||
|
||
// isValidManagementKey checks if the provided API key is a valid management key
|
||
func (a *APIAuthMiddleware) isValidManagementKey(providedKey string) bool {
|
||
for validKey := range a.managementKeys {
|
||
if len(providedKey) == len(validKey) &&
|
||
subtle.ConstantTimeCompare([]byte(providedKey), []byte(validKey)) == 1 {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// unauthorized sends an unauthorized response
|
||
func (a *APIAuthMiddleware) unauthorized(w http.ResponseWriter, message string) {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(http.StatusUnauthorized)
|
||
response := fmt.Sprintf(`{"error": {"message": "%s", "type": "authentication_error"}}`, message)
|
||
w.Write([]byte(response))
|
||
}
|
||
|
||
// forbidden sends a forbidden response
|
||
func (a *APIAuthMiddleware) forbidden(w http.ResponseWriter, message string) {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(http.StatusForbidden)
|
||
response := fmt.Sprintf(`{"error": {"message": "%s", "type": "permission_denied"}}`, message)
|
||
w.Write([]byte(response))
|
||
}
|