mirror of
https://github.com/lordmathis/llamactl.git
synced 2025-12-22 17:14:22 +00:00
Initial api key store implementation
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"llamactl/pkg/config"
|
||||
"llamactl/pkg/database"
|
||||
"llamactl/pkg/instance"
|
||||
"llamactl/pkg/manager"
|
||||
"llamactl/pkg/validation"
|
||||
@@ -52,20 +53,25 @@ type Handler struct {
|
||||
InstanceManager manager.InstanceManager
|
||||
cfg config.AppConfig
|
||||
httpClient *http.Client
|
||||
authStore database.AuthStore
|
||||
authMiddleware *APIAuthMiddleware
|
||||
}
|
||||
|
||||
// NewHandler creates a new Handler instance with the provided instance manager and configuration
|
||||
func NewHandler(im manager.InstanceManager, cfg config.AppConfig) *Handler {
|
||||
return &Handler{
|
||||
func NewHandler(im manager.InstanceManager, cfg config.AppConfig, authStore database.AuthStore) *Handler {
|
||||
handler := &Handler{
|
||||
InstanceManager: im,
|
||||
cfg: cfg,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
authStore: authStore,
|
||||
}
|
||||
handler.authMiddleware = NewAPIAuthMiddleware(cfg.Auth, authStore)
|
||||
return handler
|
||||
}
|
||||
|
||||
// getInstance retrieves an instance by name from the request query parameters
|
||||
// getInstance retrieves an instance by name from request query parameters
|
||||
func (h *Handler) getInstance(r *http.Request) (*instance.Instance, error) {
|
||||
name := chi.URLParam(r, "name")
|
||||
validatedName, err := validation.ValidateInstanceName(name)
|
||||
@@ -81,7 +87,7 @@ func (h *Handler) getInstance(r *http.Request) (*instance.Instance, error) {
|
||||
return inst, nil
|
||||
}
|
||||
|
||||
// ensureInstanceRunning ensures the instance is running by starting it if on-demand start is enabled
|
||||
// ensureInstanceRunning ensures that an instance is running by starting it if on-demand start is enabled
|
||||
// It handles LRU eviction when the maximum number of running instances is reached
|
||||
func (h *Handler) ensureInstanceRunning(inst *instance.Instance) error {
|
||||
options := inst.GetOptions()
|
||||
|
||||
284
pkg/server/handlers_auth.go
Normal file
284
pkg/server/handlers_auth.go
Normal file
@@ -0,0 +1,284 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"llamactl/pkg/auth"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
type InstancePermission struct {
|
||||
InstanceID int `json:"instance_id"`
|
||||
CanInfer bool `json:"can_infer"`
|
||||
CanViewLogs bool `json:"can_view_logs"`
|
||||
}
|
||||
|
||||
type CreateKeyRequest struct {
|
||||
Name string
|
||||
PermissionMode auth.PermissionMode
|
||||
ExpiresAt *int64
|
||||
InstancePermissions []InstancePermission
|
||||
}
|
||||
|
||||
// CreateInferenceKey handles POST /api/v1/keys
|
||||
func (h *Handler) CreateInferenceKey() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var req CreateKeyRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_json", "Invalid JSON in request body")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate request
|
||||
if req.Name == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_name", "Name is required")
|
||||
return
|
||||
}
|
||||
if len(req.Name) > 100 {
|
||||
writeError(w, http.StatusBadRequest, "invalid_name", "Name must be 100 characters or less")
|
||||
return
|
||||
}
|
||||
if req.PermissionMode != auth.PermissionModeAllowAll && req.PermissionMode != auth.PermissionModePerInstance {
|
||||
writeError(w, http.StatusBadRequest, "invalid_permission_mode", "Permission mode must be 'allow_all' or 'per_instance'")
|
||||
return
|
||||
}
|
||||
if req.PermissionMode == auth.PermissionModePerInstance && len(req.InstancePermissions) == 0 {
|
||||
writeError(w, http.StatusBadRequest, "missing_permissions", "Instance permissions required when permission mode is 'per_instance'")
|
||||
return
|
||||
}
|
||||
if req.ExpiresAt != nil && *req.ExpiresAt <= time.Now().Unix() {
|
||||
writeError(w, http.StatusBadRequest, "invalid_expires_at", "Expiration time must be in future")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate instance IDs exist
|
||||
if req.PermissionMode == auth.PermissionModePerInstance {
|
||||
instances, err := h.InstanceManager.ListInstances()
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "fetch_instances_failed", fmt.Sprintf("Failed to fetch instances: %v", err))
|
||||
return
|
||||
}
|
||||
instanceIDMap := make(map[int]bool)
|
||||
for _, inst := range instances {
|
||||
instanceIDMap[inst.ID] = true
|
||||
}
|
||||
|
||||
for _, perm := range req.InstancePermissions {
|
||||
if !instanceIDMap[perm.InstanceID] {
|
||||
writeError(w, http.StatusBadRequest, "invalid_instance_id", fmt.Sprintf("Instance ID %d does not exist", perm.InstanceID))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generate plain-text key
|
||||
plainTextKey, err := auth.GenerateKey()
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "key_generation_failed", "Failed to generate API key")
|
||||
return
|
||||
}
|
||||
|
||||
// Hash key
|
||||
keyHash, err := auth.HashKey(plainTextKey)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "key_hashing_failed", "Failed to hash API key")
|
||||
return
|
||||
}
|
||||
|
||||
// Create APIKey struct
|
||||
now := time.Now().Unix()
|
||||
apiKey := &auth.APIKey{
|
||||
KeyHash: keyHash,
|
||||
Name: req.Name,
|
||||
UserID: "system",
|
||||
PermissionMode: req.PermissionMode,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
Enabled: true,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
// Convert InstancePermissions to KeyPermissions
|
||||
var keyPermissions []auth.KeyPermission
|
||||
for _, perm := range req.InstancePermissions {
|
||||
keyPermissions = append(keyPermissions, auth.KeyPermission{
|
||||
KeyID: 0, // Will be set by database after key creation
|
||||
InstanceID: perm.InstanceID,
|
||||
CanInfer: perm.CanInfer,
|
||||
CanViewLogs: perm.CanViewLogs,
|
||||
})
|
||||
}
|
||||
|
||||
// Create in database
|
||||
err = h.authStore.CreateKey(r.Context(), apiKey, keyPermissions)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "creation_failed", fmt.Sprintf("Failed to create API key: %v", err))
|
||||
return
|
||||
} // Return response with plain-text key (only shown once)
|
||||
response := map[string]interface{}{
|
||||
"id": apiKey.ID,
|
||||
"name": apiKey.Name,
|
||||
"user_id": apiKey.UserID,
|
||||
"permission_mode": apiKey.PermissionMode,
|
||||
"expires_at": apiKey.ExpiresAt,
|
||||
"enabled": apiKey.Enabled,
|
||||
"created_at": apiKey.CreatedAt,
|
||||
"updated_at": apiKey.UpdatedAt,
|
||||
"last_used_at": apiKey.LastUsedAt,
|
||||
"key": plainTextKey, // Only returned on creation
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
}
|
||||
|
||||
// ListInferenceKeys handles GET /api/v1/keys
|
||||
func (h *Handler) ListInferenceKeys() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
keys, err := h.authStore.GetUserKeys(r.Context(), "system")
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "fetch_failed", fmt.Sprintf("Failed to fetch API keys: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Remove key_hash from all keys
|
||||
var response []map[string]interface{}
|
||||
for _, key := range keys {
|
||||
response = append(response, map[string]interface{}{
|
||||
"id": key.ID,
|
||||
"name": key.Name,
|
||||
"user_id": key.UserID,
|
||||
"permission_mode": key.PermissionMode,
|
||||
"expires_at": key.ExpiresAt,
|
||||
"enabled": key.Enabled,
|
||||
"created_at": key.CreatedAt,
|
||||
"updated_at": key.UpdatedAt,
|
||||
"last_used_at": key.LastUsedAt,
|
||||
})
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
}
|
||||
|
||||
// GetInferenceKey handles GET /api/v1/keys/{id}
|
||||
func (h *Handler) GetInferenceKey() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
idStr := chi.URLParam(r, "id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_id", "Invalid key ID")
|
||||
return
|
||||
}
|
||||
|
||||
key, err := h.authStore.GetKeyByID(r.Context(), id)
|
||||
if err != nil {
|
||||
if err.Error() == "API key not found" {
|
||||
writeError(w, http.StatusNotFound, "not_found", "API key not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "fetch_failed", fmt.Sprintf("Failed to fetch API key: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Remove key_hash from response
|
||||
response := map[string]interface{}{
|
||||
"id": key.ID,
|
||||
"name": key.Name,
|
||||
"user_id": key.UserID,
|
||||
"permission_mode": key.PermissionMode,
|
||||
"expires_at": key.ExpiresAt,
|
||||
"enabled": key.Enabled,
|
||||
"created_at": key.CreatedAt,
|
||||
"updated_at": key.UpdatedAt,
|
||||
"last_used_at": key.LastUsedAt,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteInferenceKey handles DELETE /api/v1/keys/{id}
|
||||
func (h *Handler) DeleteInferenceKey() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
idStr := chi.URLParam(r, "id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_id", "Invalid key ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.authStore.DeleteKey(r.Context(), id)
|
||||
if err != nil {
|
||||
if err.Error() == "API key not found" {
|
||||
writeError(w, http.StatusNotFound, "not_found", "API key not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "deletion_failed", fmt.Sprintf("Failed to delete API key: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
||||
// GetInferenceKeyPermissions handles GET /api/v1/keys/{id}/permissions
|
||||
func (h *Handler) GetInferenceKeyPermissions() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
idStr := chi.URLParam(r, "id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_id", "Invalid key ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify key exists
|
||||
_, err = h.authStore.GetKeyByID(r.Context(), id)
|
||||
if err != nil {
|
||||
if err.Error() == "API key not found" {
|
||||
writeError(w, http.StatusNotFound, "not_found", "API key not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "fetch_failed", fmt.Sprintf("Failed to fetch API key: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
permissions, err := h.authStore.GetPermissions(r.Context(), id)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "fetch_failed", fmt.Sprintf("Failed to fetch permissions: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Get instance names for the permissions
|
||||
instances, err := h.InstanceManager.ListInstances()
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "fetch_instances_failed", fmt.Sprintf("Failed to fetch instances: %v", err))
|
||||
return
|
||||
}
|
||||
instanceNameMap := make(map[int]string)
|
||||
for _, inst := range instances {
|
||||
instanceNameMap[inst.ID] = inst.Name
|
||||
}
|
||||
|
||||
var response []map[string]interface{}
|
||||
for _, perm := range permissions {
|
||||
response = append(response, map[string]interface{}{
|
||||
"instance_id": perm.InstanceID,
|
||||
"instance_name": instanceNameMap[perm.InstanceID],
|
||||
"can_infer": perm.CanInfer,
|
||||
"can_view_logs": perm.CanViewLogs,
|
||||
})
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
}
|
||||
@@ -1,15 +1,19 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"llamactl/pkg/auth"
|
||||
"llamactl/pkg/config"
|
||||
"llamactl/pkg/database"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type KeyType int
|
||||
@@ -19,58 +23,59 @@ const (
|
||||
KeyTypeManagement
|
||||
)
|
||||
|
||||
// contextKey is a custom type for context keys to avoid collisions
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
apiKeyContextKey contextKey = "apiKey"
|
||||
)
|
||||
|
||||
type APIAuthMiddleware struct {
|
||||
authStore database.AuthStore
|
||||
requireInferenceAuth bool
|
||||
inferenceKeys map[string]bool
|
||||
requireManagementAuth bool
|
||||
managementKeys map[string]bool
|
||||
managementKeys map[string]bool // Config-based management keys
|
||||
}
|
||||
|
||||
// NewAPIAuthMiddleware creates a new APIAuthMiddleware with the given configuration
|
||||
func NewAPIAuthMiddleware(authCfg config.AuthConfig) *APIAuthMiddleware {
|
||||
func NewAPIAuthMiddleware(authCfg config.AuthConfig, authStore database.AuthStore) *APIAuthMiddleware {
|
||||
// Load management keys from config into managementKeys map
|
||||
managementKeys := make(map[string]bool)
|
||||
for _, key := range authCfg.ManagementKeys {
|
||||
managementKeys[key] = true
|
||||
}
|
||||
|
||||
// If len(authCfg.InferenceKeys) > 0, log warning
|
||||
if len(authCfg.InferenceKeys) > 0 {
|
||||
log.Println("⚠️ Config-based inference keys are no longer supported and will be ignored.")
|
||||
log.Println(" Please create inference keys in web UI or via management API.")
|
||||
}
|
||||
|
||||
// Handle legacy auto-generation for management keys if none provided and auth is required
|
||||
var generated bool = false
|
||||
|
||||
inferenceAPIKeys := make(map[string]bool)
|
||||
managementAPIKeys := make(map[string]bool)
|
||||
|
||||
const banner = "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
|
||||
if authCfg.RequireManagementAuth && len(authCfg.ManagementKeys) == 0 {
|
||||
key := generateAPIKey(KeyTypeManagement)
|
||||
managementAPIKeys[key] = true
|
||||
managementKeys[key] = true
|
||||
generated = true
|
||||
fmt.Printf("%s\n⚠️ MANAGEMENT AUTHENTICATION REQUIRED\n%s\n", banner, banner)
|
||||
fmt.Printf("🔑 Generated Management API Key:\n\n %s\n\n", key)
|
||||
}
|
||||
for _, key := range authCfg.ManagementKeys {
|
||||
managementAPIKeys[key] = true
|
||||
}
|
||||
|
||||
if authCfg.RequireInferenceAuth && len(authCfg.InferenceKeys) == 0 {
|
||||
key := generateAPIKey(KeyTypeInference)
|
||||
inferenceAPIKeys[key] = true
|
||||
generated = true
|
||||
fmt.Printf("%s\n⚠️ INFERENCE AUTHENTICATION REQUIRED\n%s\n", banner, banner)
|
||||
fmt.Printf("🔑 Generated Inference API Key:\n\n %s\n\n", key)
|
||||
}
|
||||
for _, key := range authCfg.InferenceKeys {
|
||||
inferenceAPIKeys[key] = true
|
||||
}
|
||||
|
||||
if generated {
|
||||
fmt.Printf("%s\n⚠️ IMPORTANT\n%s\n", banner, banner)
|
||||
fmt.Println("• These keys are auto-generated and will change on restart")
|
||||
fmt.Println("• This key is auto-generated and will change on restart")
|
||||
fmt.Println("• For production, add explicit keys to your configuration")
|
||||
fmt.Println("• Copy these keys before they disappear from the terminal")
|
||||
fmt.Println("• Copy this key before it disappears from the terminal")
|
||||
fmt.Println(banner)
|
||||
}
|
||||
|
||||
return &APIAuthMiddleware{
|
||||
authStore: authStore,
|
||||
requireInferenceAuth: authCfg.RequireInferenceAuth,
|
||||
inferenceKeys: inferenceAPIKeys,
|
||||
requireManagementAuth: authCfg.RequireManagementAuth,
|
||||
managementKeys: managementAPIKeys,
|
||||
managementKeys: managementKeys,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,7 +105,120 @@ func generateAPIKey(keyType KeyType) string {
|
||||
return fmt.Sprintf("%s-%s", prefix, hex.EncodeToString(randomBytes))
|
||||
}
|
||||
|
||||
// AuthMiddleware returns a middleware that checks API keys for the given key type
|
||||
// InferenceAuthMiddleware returns middleware for inference endpoints
|
||||
func (a *APIAuthMiddleware) InferenceAuthMiddleware() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "OPTIONS" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract API key from request
|
||||
apiKey := a.extractAPIKey(r)
|
||||
if apiKey == "" {
|
||||
a.unauthorized(w, "Missing API key")
|
||||
return
|
||||
}
|
||||
|
||||
// Try database authentication first
|
||||
var foundKey *auth.APIKey
|
||||
if a.requireInferenceAuth {
|
||||
activeKeys, err := a.authStore.GetActiveKeys(r.Context())
|
||||
if err != nil {
|
||||
log.Printf("Failed to get active inference keys: %v", err)
|
||||
// Continue to management key fallback
|
||||
} else {
|
||||
for _, key := range activeKeys {
|
||||
if auth.VerifyKey(apiKey, key.KeyHash) {
|
||||
foundKey = key
|
||||
// Async update last_used_at
|
||||
go func(keyID int) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := a.authStore.TouchKey(ctx, keyID); err != nil {
|
||||
log.Printf("Failed to update last used timestamp for key %d: %v", keyID, err)
|
||||
}
|
||||
}(key.ID)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no database key found, try management key authentication (config-based)
|
||||
if foundKey == nil {
|
||||
if !a.isValidManagementKey(apiKey) {
|
||||
a.unauthorized(w, "Invalid API key")
|
||||
return
|
||||
}
|
||||
// Management key was used, continue without adding APIKey to context
|
||||
} else {
|
||||
// Add APIKey to context for permission checking
|
||||
ctx := context.WithValue(r.Context(), apiKeyContextKey, foundKey)
|
||||
r = r.WithContext(ctx)
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ManagementAuthMiddleware returns middleware for management endpoints
|
||||
func (a *APIAuthMiddleware) ManagementAuthMiddleware() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "OPTIONS" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract API key from request
|
||||
apiKey := a.extractAPIKey(r)
|
||||
if apiKey == "" {
|
||||
a.unauthorized(w, "Missing API key")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if key exists in managementKeys map using constant-time comparison
|
||||
if !a.isValidManagementKey(apiKey) {
|
||||
a.unauthorized(w, "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// CheckInstancePermission checks if the authenticated key has permission for the instance
|
||||
func (a *APIAuthMiddleware) CheckInstancePermission(ctx context.Context, instanceID int) error {
|
||||
// Extract APIKey from context
|
||||
apiKey, ok := ctx.Value(apiKeyContextKey).(*auth.APIKey)
|
||||
if !ok {
|
||||
// APIKey is nil, management key was used, allow all
|
||||
return nil
|
||||
}
|
||||
|
||||
// If permission_mode == "allow_all", allow all
|
||||
if apiKey.PermissionMode == auth.PermissionModeAllowAll {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check per-instance permissions
|
||||
canInfer, err := a.authStore.HasPermission(ctx, apiKey.ID, instanceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !canInfer {
|
||||
return http.ErrBodyNotAllowed // Use this as a generic error to indicate permission denied
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AuthMiddleware returns a middleware that checks API keys for the given key type (legacy support)
|
||||
func (a *APIAuthMiddleware) AuthMiddleware(keyType KeyType) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -118,10 +236,38 @@ func (a *APIAuthMiddleware) AuthMiddleware(keyType KeyType) func(http.Handler) h
|
||||
var isValid bool
|
||||
switch keyType {
|
||||
case KeyTypeInference:
|
||||
// Management keys also work for OpenAI endpoints (higher privilege)
|
||||
isValid = a.isValidKey(apiKey, KeyTypeInference) || a.isValidKey(apiKey, KeyTypeManagement)
|
||||
// Try database authentication first
|
||||
if a.requireInferenceAuth {
|
||||
activeKeys, err := a.authStore.GetActiveKeys(r.Context())
|
||||
if err == nil {
|
||||
for _, key := range activeKeys {
|
||||
if auth.VerifyKey(apiKey, key.KeyHash) {
|
||||
foundKey := key
|
||||
// Async update last_used_at
|
||||
go func(keyID int) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := a.authStore.TouchKey(ctx, keyID); err != nil {
|
||||
log.Printf("Failed to update last used timestamp for key %d: %v", keyID, err)
|
||||
}
|
||||
}(key.ID)
|
||||
|
||||
// Add APIKey to context for permission checking
|
||||
ctx := context.WithValue(r.Context(), apiKeyContextKey, foundKey)
|
||||
r = r.WithContext(ctx)
|
||||
isValid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no database key found, try management key (higher privilege)
|
||||
if !isValid {
|
||||
isValid = a.isValidManagementKey(apiKey)
|
||||
}
|
||||
case KeyTypeManagement:
|
||||
isValid = a.isValidKey(apiKey, KeyTypeManagement)
|
||||
isValid = a.isValidManagementKey(apiKey)
|
||||
default:
|
||||
isValid = false
|
||||
}
|
||||
@@ -158,20 +304,9 @@ func (a *APIAuthMiddleware) extractAPIKey(r *http.Request) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// isValidKey checks if the provided API key is valid for the given key type
|
||||
func (a *APIAuthMiddleware) isValidKey(providedKey string, keyType KeyType) bool {
|
||||
var validKeys map[string]bool
|
||||
|
||||
switch keyType {
|
||||
case KeyTypeInference:
|
||||
validKeys = a.inferenceKeys
|
||||
case KeyTypeManagement:
|
||||
validKeys = a.managementKeys
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
for validKey := range validKeys {
|
||||
// isValidManagementKey checks if the provided API key is a valid management key
|
||||
func (a *APIAuthMiddleware) isValidManagementKey(providedKey string) bool {
|
||||
for validKey := range a.managementKeys {
|
||||
if len(providedKey) == len(validKey) &&
|
||||
subtle.ConstantTimeCompare([]byte(providedKey), []byte(validKey)) == 1 {
|
||||
return true
|
||||
@@ -187,3 +322,11 @@ func (a *APIAuthMiddleware) unauthorized(w http.ResponseWriter, message string)
|
||||
response := fmt.Sprintf(`{"error": {"message": "%s", "type": "authentication_error"}}`, message)
|
||||
w.Write([]byte(response))
|
||||
}
|
||||
|
||||
// forbidden sends a forbidden response
|
||||
func (a *APIAuthMiddleware) forbidden(w http.ResponseWriter, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
response := fmt.Sprintf(`{"error": {"message": "%s", "type": "permission_denied"}}`, message)
|
||||
w.Write([]byte(response))
|
||||
}
|
||||
|
||||
@@ -19,15 +19,7 @@ func TestAuthMiddleware(t *testing.T) {
|
||||
method string
|
||||
expectedStatus int
|
||||
}{
|
||||
// Valid key tests
|
||||
{
|
||||
name: "valid inference key for inference",
|
||||
keyType: server.KeyTypeInference,
|
||||
inferenceKeys: []string{"sk-inference-valid123"},
|
||||
requestKey: "sk-inference-valid123",
|
||||
method: "GET",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
// Valid key tests - using management keys only since config-based inference keys are deprecated
|
||||
{
|
||||
name: "valid management key for inference", // Management keys work for inference
|
||||
keyType: server.KeyTypeInference,
|
||||
@@ -123,7 +115,7 @@ func TestAuthMiddleware(t *testing.T) {
|
||||
InferenceKeys: tt.inferenceKeys,
|
||||
ManagementKeys: tt.managementKeys,
|
||||
}
|
||||
middleware := server.NewAPIAuthMiddleware(cfg)
|
||||
middleware := server.NewAPIAuthMiddleware(cfg, nil)
|
||||
|
||||
// Create test request
|
||||
req := httptest.NewRequest(tt.method, "/test", nil)
|
||||
@@ -131,7 +123,7 @@ func TestAuthMiddleware(t *testing.T) {
|
||||
req.Header.Set("Authorization", "Bearer "+tt.requestKey)
|
||||
}
|
||||
|
||||
// Create test handler using the appropriate middleware
|
||||
// Create test handler using appropriate middleware
|
||||
var handler http.Handler
|
||||
if tt.keyType == server.KeyTypeInference {
|
||||
handler = middleware.AuthMiddleware(server.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -189,7 +181,7 @@ func TestGenerateAPIKey(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create middleware - this should trigger key generation
|
||||
middleware := server.NewAPIAuthMiddleware(config)
|
||||
middleware := server.NewAPIAuthMiddleware(config, nil)
|
||||
|
||||
// Test that auth is required (meaning a key was generated)
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
@@ -214,7 +206,7 @@ func TestGenerateAPIKey(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test uniqueness by creating another middleware instance
|
||||
middleware2 := server.NewAPIAuthMiddleware(config)
|
||||
middleware2 := server.NewAPIAuthMiddleware(config, nil)
|
||||
|
||||
req2 := httptest.NewRequest("GET", "/", nil)
|
||||
recorder2 := httptest.NewRecorder()
|
||||
@@ -314,7 +306,7 @@ func TestAutoGeneration(t *testing.T) {
|
||||
ManagementKeys: tt.providedManagement,
|
||||
}
|
||||
|
||||
middleware := server.NewAPIAuthMiddleware(cfg)
|
||||
middleware := server.NewAPIAuthMiddleware(cfg, nil)
|
||||
|
||||
// Test inference behavior if inference auth is required
|
||||
if tt.requireInference {
|
||||
|
||||
@@ -27,7 +27,7 @@ func SetupRouter(handler *Handler) *chi.Mux {
|
||||
}))
|
||||
|
||||
// Add API authentication middleware
|
||||
authMiddleware := NewAPIAuthMiddleware(handler.cfg.Auth)
|
||||
authMiddleware := NewAPIAuthMiddleware(handler.cfg.Auth, handler.authStore)
|
||||
|
||||
if handler.cfg.Server.EnableSwagger {
|
||||
r.Get("/swagger/*", httpSwagger.Handler(
|
||||
@@ -46,6 +46,17 @@ func SetupRouter(handler *Handler) *chi.Mux {
|
||||
|
||||
r.Get("/config", handler.ConfigHandler())
|
||||
|
||||
// API key management endpoints
|
||||
r.Route("/auth", func(r chi.Router) {
|
||||
r.Route("/keys", func(r chi.Router) {
|
||||
r.Post("/", handler.CreateInferenceKey()) // Create API key
|
||||
r.Get("/", handler.ListInferenceKeys()) // List API keys
|
||||
r.Get("/{id}", handler.GetInferenceKey()) // Get API key details
|
||||
r.Delete("/{id}", handler.DeleteInferenceKey()) // Delete API key
|
||||
r.Get("/{id}/permissions", handler.GetInferenceKeyPermissions()) // Get key permissions
|
||||
})
|
||||
})
|
||||
|
||||
// Backend-specific endpoints
|
||||
r.Route("/backends", func(r chi.Router) {
|
||||
r.Route("/llama-cpp", func(r chi.Router) {
|
||||
@@ -94,13 +105,13 @@ func SetupRouter(handler *Handler) *chi.Mux {
|
||||
})
|
||||
})
|
||||
|
||||
r.Route(("/v1"), func(r chi.Router) {
|
||||
r.Route("/v1", func(r chi.Router) {
|
||||
|
||||
if authMiddleware != nil && handler.cfg.Auth.RequireInferenceAuth {
|
||||
r.Use(authMiddleware.AuthMiddleware(KeyTypeInference))
|
||||
}
|
||||
|
||||
r.Get(("/models"), handler.OpenAIListInstances()) // List instances in OpenAI-compatible format
|
||||
r.Get("/models", handler.OpenAIListInstances()) // List instances in OpenAI-compatible format
|
||||
|
||||
// OpenAI-compatible proxy endpoint
|
||||
// Handles all POST requests to /v1/*, including:
|
||||
@@ -128,7 +139,7 @@ func SetupRouter(handler *Handler) *chi.Mux {
|
||||
r.Use(authMiddleware.AuthMiddleware(KeyTypeInference))
|
||||
}
|
||||
|
||||
// This handler auto start the server if it's not running
|
||||
// This handler auto starts the server if it's not running
|
||||
llamaCppHandler := handler.LlamaCppProxy()
|
||||
|
||||
// llama.cpp server specific proxy endpoints
|
||||
|
||||
Reference in New Issue
Block a user