mirror of
https://github.com/lordmathis/llamactl.git
synced 2025-12-23 09:34:23 +00:00
Add permission checks to proxies
This commit is contained in:
@@ -2,9 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"llamactl/pkg/auth"
|
||||
"llamactl/pkg/config"
|
||||
@@ -16,13 +14,6 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type KeyType int
|
||||
|
||||
const (
|
||||
KeyTypeInference KeyType = iota
|
||||
KeyTypeManagement
|
||||
)
|
||||
|
||||
// contextKey is a custom type for context keys to avoid collisions
|
||||
type contextKey string
|
||||
|
||||
@@ -56,7 +47,12 @@ func NewAPIAuthMiddleware(authCfg config.AuthConfig, authStore database.AuthStor
|
||||
const banner = "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
|
||||
if authCfg.RequireManagementAuth && len(authCfg.ManagementKeys) == 0 {
|
||||
key := generateAPIKey(KeyTypeManagement)
|
||||
key, err := auth.GenerateKey("llamactl-mgmt-")
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to generate management key: %v", err)
|
||||
// Fallback to PID-based key for safety
|
||||
key = fmt.Sprintf("sk-management-fallback-%d", os.Getpid())
|
||||
}
|
||||
managementKeys[key] = true
|
||||
generated = true
|
||||
fmt.Printf("%s\n⚠️ MANAGEMENT AUTHENTICATION REQUIRED\n%s\n", banner, banner)
|
||||
@@ -79,32 +75,6 @@ func NewAPIAuthMiddleware(authCfg config.AuthConfig, authStore database.AuthStor
|
||||
}
|
||||
}
|
||||
|
||||
// generateAPIKey creates a cryptographically secure API key
|
||||
func generateAPIKey(keyType KeyType) string {
|
||||
// Generate 32 random bytes (256 bits)
|
||||
randomBytes := make([]byte, 32)
|
||||
|
||||
var prefix string
|
||||
|
||||
switch keyType {
|
||||
case KeyTypeInference:
|
||||
prefix = "sk-inference"
|
||||
case KeyTypeManagement:
|
||||
prefix = "sk-management"
|
||||
default:
|
||||
prefix = "sk-unknown"
|
||||
}
|
||||
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
log.Printf("Warning: Failed to generate secure random key, using fallback")
|
||||
// Fallback to a less secure method if crypto/rand fails
|
||||
return fmt.Sprintf("%s-fallback-%d", prefix, os.Getpid())
|
||||
}
|
||||
|
||||
// Convert to hex and add prefix
|
||||
return fmt.Sprintf("%s-%s", prefix, hex.EncodeToString(randomBytes))
|
||||
}
|
||||
|
||||
// InferenceAuthMiddleware returns middleware for inference endpoints
|
||||
func (a *APIAuthMiddleware) InferenceAuthMiddleware() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
@@ -123,7 +93,7 @@ func (a *APIAuthMiddleware) InferenceAuthMiddleware() func(http.Handler) http.Ha
|
||||
|
||||
// Try database authentication first
|
||||
var foundKey *auth.APIKey
|
||||
if a.requireInferenceAuth {
|
||||
if a.requireInferenceAuth && a.authStore != nil {
|
||||
activeKeys, err := a.authStore.GetActiveKeys(r.Context())
|
||||
if err != nil {
|
||||
log.Printf("Failed to get active inference keys: %v", err)
|
||||
@@ -208,80 +178,16 @@ func (a *APIAuthMiddleware) CheckInstancePermission(ctx context.Context, instanc
|
||||
// Check per-instance permissions
|
||||
canInfer, err := a.authStore.HasPermission(ctx, apiKey.ID, instanceID)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to check permission: %w", err)
|
||||
}
|
||||
|
||||
if !canInfer {
|
||||
return http.ErrBodyNotAllowed // Use this as a generic error to indicate permission denied
|
||||
return fmt.Errorf("permission denied: key does not have access to this instance")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AuthMiddleware returns a middleware that checks API keys for the given key type (legacy support)
|
||||
func (a *APIAuthMiddleware) AuthMiddleware(keyType KeyType) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "OPTIONS" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
apiKey := a.extractAPIKey(r)
|
||||
if apiKey == "" {
|
||||
a.unauthorized(w, "Missing API key")
|
||||
return
|
||||
}
|
||||
|
||||
var isValid bool
|
||||
switch keyType {
|
||||
case KeyTypeInference:
|
||||
// Try database authentication first
|
||||
if a.requireInferenceAuth {
|
||||
activeKeys, err := a.authStore.GetActiveKeys(r.Context())
|
||||
if err == nil {
|
||||
for _, key := range activeKeys {
|
||||
if auth.VerifyKey(apiKey, key.KeyHash) {
|
||||
foundKey := key
|
||||
// Async update last_used_at
|
||||
go func(keyID int) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := a.authStore.TouchKey(ctx, keyID); err != nil {
|
||||
log.Printf("Failed to update last used timestamp for key %d: %v", keyID, err)
|
||||
}
|
||||
}(key.ID)
|
||||
|
||||
// Add APIKey to context for permission checking
|
||||
ctx := context.WithValue(r.Context(), apiKeyContextKey, foundKey)
|
||||
r = r.WithContext(ctx)
|
||||
isValid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no database key found, try management key (higher privilege)
|
||||
if !isValid {
|
||||
isValid = a.isValidManagementKey(apiKey)
|
||||
}
|
||||
case KeyTypeManagement:
|
||||
isValid = a.isValidManagementKey(apiKey)
|
||||
default:
|
||||
isValid = false
|
||||
}
|
||||
|
||||
if !isValid {
|
||||
a.unauthorized(w, "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// extractAPIKey extracts the API key from the request
|
||||
func (a *APIAuthMiddleware) extractAPIKey(r *http.Request) string {
|
||||
// Check Authorization header: "Bearer sk-..."
|
||||
|
||||
Reference in New Issue
Block a user