mirror of
https://github.com/lordmathis/llamactl.git
synced 2025-12-22 17:14:22 +00:00
Initial api key store implementation
This commit is contained in:
@@ -1,15 +1,19 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"llamactl/pkg/auth"
|
||||
"llamactl/pkg/config"
|
||||
"llamactl/pkg/database"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type KeyType int
|
||||
@@ -19,58 +23,59 @@ const (
|
||||
KeyTypeManagement
|
||||
)
|
||||
|
||||
// contextKey is a custom type for context keys to avoid collisions
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
apiKeyContextKey contextKey = "apiKey"
|
||||
)
|
||||
|
||||
type APIAuthMiddleware struct {
|
||||
authStore database.AuthStore
|
||||
requireInferenceAuth bool
|
||||
inferenceKeys map[string]bool
|
||||
requireManagementAuth bool
|
||||
managementKeys map[string]bool
|
||||
managementKeys map[string]bool // Config-based management keys
|
||||
}
|
||||
|
||||
// NewAPIAuthMiddleware creates a new APIAuthMiddleware with the given configuration
|
||||
func NewAPIAuthMiddleware(authCfg config.AuthConfig) *APIAuthMiddleware {
|
||||
func NewAPIAuthMiddleware(authCfg config.AuthConfig, authStore database.AuthStore) *APIAuthMiddleware {
|
||||
// Load management keys from config into managementKeys map
|
||||
managementKeys := make(map[string]bool)
|
||||
for _, key := range authCfg.ManagementKeys {
|
||||
managementKeys[key] = true
|
||||
}
|
||||
|
||||
// If len(authCfg.InferenceKeys) > 0, log warning
|
||||
if len(authCfg.InferenceKeys) > 0 {
|
||||
log.Println("⚠️ Config-based inference keys are no longer supported and will be ignored.")
|
||||
log.Println(" Please create inference keys in web UI or via management API.")
|
||||
}
|
||||
|
||||
// Handle legacy auto-generation for management keys if none provided and auth is required
|
||||
var generated bool = false
|
||||
|
||||
inferenceAPIKeys := make(map[string]bool)
|
||||
managementAPIKeys := make(map[string]bool)
|
||||
|
||||
const banner = "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
|
||||
if authCfg.RequireManagementAuth && len(authCfg.ManagementKeys) == 0 {
|
||||
key := generateAPIKey(KeyTypeManagement)
|
||||
managementAPIKeys[key] = true
|
||||
managementKeys[key] = true
|
||||
generated = true
|
||||
fmt.Printf("%s\n⚠️ MANAGEMENT AUTHENTICATION REQUIRED\n%s\n", banner, banner)
|
||||
fmt.Printf("🔑 Generated Management API Key:\n\n %s\n\n", key)
|
||||
}
|
||||
for _, key := range authCfg.ManagementKeys {
|
||||
managementAPIKeys[key] = true
|
||||
}
|
||||
|
||||
if authCfg.RequireInferenceAuth && len(authCfg.InferenceKeys) == 0 {
|
||||
key := generateAPIKey(KeyTypeInference)
|
||||
inferenceAPIKeys[key] = true
|
||||
generated = true
|
||||
fmt.Printf("%s\n⚠️ INFERENCE AUTHENTICATION REQUIRED\n%s\n", banner, banner)
|
||||
fmt.Printf("🔑 Generated Inference API Key:\n\n %s\n\n", key)
|
||||
}
|
||||
for _, key := range authCfg.InferenceKeys {
|
||||
inferenceAPIKeys[key] = true
|
||||
}
|
||||
|
||||
if generated {
|
||||
fmt.Printf("%s\n⚠️ IMPORTANT\n%s\n", banner, banner)
|
||||
fmt.Println("• These keys are auto-generated and will change on restart")
|
||||
fmt.Println("• This key is auto-generated and will change on restart")
|
||||
fmt.Println("• For production, add explicit keys to your configuration")
|
||||
fmt.Println("• Copy these keys before they disappear from the terminal")
|
||||
fmt.Println("• Copy this key before it disappears from the terminal")
|
||||
fmt.Println(banner)
|
||||
}
|
||||
|
||||
return &APIAuthMiddleware{
|
||||
authStore: authStore,
|
||||
requireInferenceAuth: authCfg.RequireInferenceAuth,
|
||||
inferenceKeys: inferenceAPIKeys,
|
||||
requireManagementAuth: authCfg.RequireManagementAuth,
|
||||
managementKeys: managementAPIKeys,
|
||||
managementKeys: managementKeys,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,7 +105,120 @@ func generateAPIKey(keyType KeyType) string {
|
||||
return fmt.Sprintf("%s-%s", prefix, hex.EncodeToString(randomBytes))
|
||||
}
|
||||
|
||||
// AuthMiddleware returns a middleware that checks API keys for the given key type
|
||||
// InferenceAuthMiddleware returns middleware for inference endpoints
|
||||
func (a *APIAuthMiddleware) InferenceAuthMiddleware() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "OPTIONS" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract API key from request
|
||||
apiKey := a.extractAPIKey(r)
|
||||
if apiKey == "" {
|
||||
a.unauthorized(w, "Missing API key")
|
||||
return
|
||||
}
|
||||
|
||||
// Try database authentication first
|
||||
var foundKey *auth.APIKey
|
||||
if a.requireInferenceAuth {
|
||||
activeKeys, err := a.authStore.GetActiveKeys(r.Context())
|
||||
if err != nil {
|
||||
log.Printf("Failed to get active inference keys: %v", err)
|
||||
// Continue to management key fallback
|
||||
} else {
|
||||
for _, key := range activeKeys {
|
||||
if auth.VerifyKey(apiKey, key.KeyHash) {
|
||||
foundKey = key
|
||||
// Async update last_used_at
|
||||
go func(keyID int) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := a.authStore.TouchKey(ctx, keyID); err != nil {
|
||||
log.Printf("Failed to update last used timestamp for key %d: %v", keyID, err)
|
||||
}
|
||||
}(key.ID)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no database key found, try management key authentication (config-based)
|
||||
if foundKey == nil {
|
||||
if !a.isValidManagementKey(apiKey) {
|
||||
a.unauthorized(w, "Invalid API key")
|
||||
return
|
||||
}
|
||||
// Management key was used, continue without adding APIKey to context
|
||||
} else {
|
||||
// Add APIKey to context for permission checking
|
||||
ctx := context.WithValue(r.Context(), apiKeyContextKey, foundKey)
|
||||
r = r.WithContext(ctx)
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ManagementAuthMiddleware returns middleware for management endpoints
|
||||
func (a *APIAuthMiddleware) ManagementAuthMiddleware() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "OPTIONS" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract API key from request
|
||||
apiKey := a.extractAPIKey(r)
|
||||
if apiKey == "" {
|
||||
a.unauthorized(w, "Missing API key")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if key exists in managementKeys map using constant-time comparison
|
||||
if !a.isValidManagementKey(apiKey) {
|
||||
a.unauthorized(w, "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// CheckInstancePermission checks if the authenticated key has permission for the instance
|
||||
func (a *APIAuthMiddleware) CheckInstancePermission(ctx context.Context, instanceID int) error {
|
||||
// Extract APIKey from context
|
||||
apiKey, ok := ctx.Value(apiKeyContextKey).(*auth.APIKey)
|
||||
if !ok {
|
||||
// APIKey is nil, management key was used, allow all
|
||||
return nil
|
||||
}
|
||||
|
||||
// If permission_mode == "allow_all", allow all
|
||||
if apiKey.PermissionMode == auth.PermissionModeAllowAll {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check per-instance permissions
|
||||
canInfer, err := a.authStore.HasPermission(ctx, apiKey.ID, instanceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !canInfer {
|
||||
return http.ErrBodyNotAllowed // Use this as a generic error to indicate permission denied
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AuthMiddleware returns a middleware that checks API keys for the given key type (legacy support)
|
||||
func (a *APIAuthMiddleware) AuthMiddleware(keyType KeyType) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -118,10 +236,38 @@ func (a *APIAuthMiddleware) AuthMiddleware(keyType KeyType) func(http.Handler) h
|
||||
var isValid bool
|
||||
switch keyType {
|
||||
case KeyTypeInference:
|
||||
// Management keys also work for OpenAI endpoints (higher privilege)
|
||||
isValid = a.isValidKey(apiKey, KeyTypeInference) || a.isValidKey(apiKey, KeyTypeManagement)
|
||||
// Try database authentication first
|
||||
if a.requireInferenceAuth {
|
||||
activeKeys, err := a.authStore.GetActiveKeys(r.Context())
|
||||
if err == nil {
|
||||
for _, key := range activeKeys {
|
||||
if auth.VerifyKey(apiKey, key.KeyHash) {
|
||||
foundKey := key
|
||||
// Async update last_used_at
|
||||
go func(keyID int) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := a.authStore.TouchKey(ctx, keyID); err != nil {
|
||||
log.Printf("Failed to update last used timestamp for key %d: %v", keyID, err)
|
||||
}
|
||||
}(key.ID)
|
||||
|
||||
// Add APIKey to context for permission checking
|
||||
ctx := context.WithValue(r.Context(), apiKeyContextKey, foundKey)
|
||||
r = r.WithContext(ctx)
|
||||
isValid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no database key found, try management key (higher privilege)
|
||||
if !isValid {
|
||||
isValid = a.isValidManagementKey(apiKey)
|
||||
}
|
||||
case KeyTypeManagement:
|
||||
isValid = a.isValidKey(apiKey, KeyTypeManagement)
|
||||
isValid = a.isValidManagementKey(apiKey)
|
||||
default:
|
||||
isValid = false
|
||||
}
|
||||
@@ -158,20 +304,9 @@ func (a *APIAuthMiddleware) extractAPIKey(r *http.Request) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// isValidKey checks if the provided API key is valid for the given key type
|
||||
func (a *APIAuthMiddleware) isValidKey(providedKey string, keyType KeyType) bool {
|
||||
var validKeys map[string]bool
|
||||
|
||||
switch keyType {
|
||||
case KeyTypeInference:
|
||||
validKeys = a.inferenceKeys
|
||||
case KeyTypeManagement:
|
||||
validKeys = a.managementKeys
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
for validKey := range validKeys {
|
||||
// isValidManagementKey checks if the provided API key is a valid management key
|
||||
func (a *APIAuthMiddleware) isValidManagementKey(providedKey string) bool {
|
||||
for validKey := range a.managementKeys {
|
||||
if len(providedKey) == len(validKey) &&
|
||||
subtle.ConstantTimeCompare([]byte(providedKey), []byte(validKey)) == 1 {
|
||||
return true
|
||||
@@ -187,3 +322,11 @@ func (a *APIAuthMiddleware) unauthorized(w http.ResponseWriter, message string)
|
||||
response := fmt.Sprintf(`{"error": {"message": "%s", "type": "authentication_error"}}`, message)
|
||||
w.Write([]byte(response))
|
||||
}
|
||||
|
||||
// forbidden sends a forbidden response
|
||||
func (a *APIAuthMiddleware) forbidden(w http.ResponseWriter, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
response := fmt.Sprintf(`{"error": {"message": "%s", "type": "permission_denied"}}`, message)
|
||||
w.Write([]byte(response))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user