Add permission checks to proxies

This commit is contained in:
2025-12-03 21:14:44 +01:00
parent 9eee42c673
commit 5ccf493e04
8 changed files with 271 additions and 371 deletions

View File

@@ -2,9 +2,7 @@ package server
import (
"context"
"crypto/rand"
"crypto/subtle"
"encoding/hex"
"fmt"
"llamactl/pkg/auth"
"llamactl/pkg/config"
@@ -16,13 +14,6 @@ import (
"time"
)
type KeyType int
const (
KeyTypeInference KeyType = iota
KeyTypeManagement
)
// contextKey is a custom type for context keys to avoid collisions
type contextKey string
@@ -56,7 +47,12 @@ func NewAPIAuthMiddleware(authCfg config.AuthConfig, authStore database.AuthStor
const banner = "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
if authCfg.RequireManagementAuth && len(authCfg.ManagementKeys) == 0 {
key := generateAPIKey(KeyTypeManagement)
key, err := auth.GenerateKey("llamactl-mgmt-")
if err != nil {
log.Printf("Warning: Failed to generate management key: %v", err)
// Fallback to PID-based key for safety
key = fmt.Sprintf("sk-management-fallback-%d", os.Getpid())
}
managementKeys[key] = true
generated = true
fmt.Printf("%s\n⚠ MANAGEMENT AUTHENTICATION REQUIRED\n%s\n", banner, banner)
@@ -79,32 +75,6 @@ func NewAPIAuthMiddleware(authCfg config.AuthConfig, authStore database.AuthStor
}
}
// generateAPIKey creates a cryptographically secure API key
func generateAPIKey(keyType KeyType) string {
// Generate 32 random bytes (256 bits)
randomBytes := make([]byte, 32)
var prefix string
switch keyType {
case KeyTypeInference:
prefix = "sk-inference"
case KeyTypeManagement:
prefix = "sk-management"
default:
prefix = "sk-unknown"
}
if _, err := rand.Read(randomBytes); err != nil {
log.Printf("Warning: Failed to generate secure random key, using fallback")
// Fallback to a less secure method if crypto/rand fails
return fmt.Sprintf("%s-fallback-%d", prefix, os.Getpid())
}
// Convert to hex and add prefix
return fmt.Sprintf("%s-%s", prefix, hex.EncodeToString(randomBytes))
}
// InferenceAuthMiddleware returns middleware for inference endpoints
func (a *APIAuthMiddleware) InferenceAuthMiddleware() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
@@ -123,7 +93,7 @@ func (a *APIAuthMiddleware) InferenceAuthMiddleware() func(http.Handler) http.Ha
// Try database authentication first
var foundKey *auth.APIKey
if a.requireInferenceAuth {
if a.requireInferenceAuth && a.authStore != nil {
activeKeys, err := a.authStore.GetActiveKeys(r.Context())
if err != nil {
log.Printf("Failed to get active inference keys: %v", err)
@@ -208,80 +178,16 @@ func (a *APIAuthMiddleware) CheckInstancePermission(ctx context.Context, instanc
// Check per-instance permissions
canInfer, err := a.authStore.HasPermission(ctx, apiKey.ID, instanceID)
if err != nil {
return err
return fmt.Errorf("failed to check permission: %w", err)
}
if !canInfer {
return http.ErrBodyNotAllowed // Use this as a generic error to indicate permission denied
return fmt.Errorf("permission denied: key does not have access to this instance")
}
return nil
}
// AuthMiddleware returns a middleware that checks API keys for the given key type (legacy support)
func (a *APIAuthMiddleware) AuthMiddleware(keyType KeyType) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "OPTIONS" {
next.ServeHTTP(w, r)
return
}
apiKey := a.extractAPIKey(r)
if apiKey == "" {
a.unauthorized(w, "Missing API key")
return
}
var isValid bool
switch keyType {
case KeyTypeInference:
// Try database authentication first
if a.requireInferenceAuth {
activeKeys, err := a.authStore.GetActiveKeys(r.Context())
if err == nil {
for _, key := range activeKeys {
if auth.VerifyKey(apiKey, key.KeyHash) {
foundKey := key
// Async update last_used_at
go func(keyID int) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := a.authStore.TouchKey(ctx, keyID); err != nil {
log.Printf("Failed to update last used timestamp for key %d: %v", keyID, err)
}
}(key.ID)
// Add APIKey to context for permission checking
ctx := context.WithValue(r.Context(), apiKeyContextKey, foundKey)
r = r.WithContext(ctx)
isValid = true
break
}
}
}
}
// If no database key found, try management key (higher privilege)
if !isValid {
isValid = a.isValidManagementKey(apiKey)
}
case KeyTypeManagement:
isValid = a.isValidManagementKey(apiKey)
default:
isValid = false
}
if !isValid {
a.unauthorized(w, "Invalid API key")
return
}
next.ServeHTTP(w, r)
})
}
}
// extractAPIKey extracts the API key from the request
func (a *APIAuthMiddleware) extractAPIKey(r *http.Request) string {
// Check Authorization header: "Bearer sk-..."