From 5ccf493e04cea3d9bfc06ba664e4b3f967ad996a Mon Sep 17 00:00:00 2001 From: LordMathis Date: Wed, 3 Dec 2025 21:14:44 +0100 Subject: [PATCH] Add permission checks to proxies --- pkg/auth/key.go | 8 +- pkg/server/handlers_auth.go | 114 +++++---- pkg/server/handlers_backends.go | 6 + pkg/server/handlers_instances.go | 6 + pkg/server/handlers_openai.go | 6 + pkg/server/middleware.go | 112 +-------- pkg/server/middleware_test.go | 384 +++++++++++++------------------ pkg/server/routes.go | 6 +- 8 files changed, 271 insertions(+), 371 deletions(-) diff --git a/pkg/auth/key.go b/pkg/auth/key.go index 211647c..9485c1b 100644 --- a/pkg/auth/key.go +++ b/pkg/auth/key.go @@ -33,9 +33,8 @@ type KeyPermission struct { CanViewLogs bool } -// GenerateKey generates a cryptographically secure inference API key -// Format: sk-inference-<64-hex-chars> -func GenerateKey() (string, error) { +// GenerateKey generates a cryptographically secure API key with the given prefix +func GenerateKey(prefix string) (string, error) { // Generate 32 random bytes bytes := make([]byte, 32) _, err := rand.Read(bytes) @@ -46,6 +45,5 @@ func GenerateKey() (string, error) { // Convert to hex (64 characters) hexStr := hex.EncodeToString(bytes) - // Prefix with "sk-inference-" - return fmt.Sprintf("sk-inference-%s", hexStr), nil + return fmt.Sprintf("%s-%s", prefix, hexStr), nil } diff --git a/pkg/server/handlers_auth.go b/pkg/server/handlers_auth.go index 3971711..2be79b0 100644 --- a/pkg/server/handlers_auth.go +++ b/pkg/server/handlers_auth.go @@ -24,6 +24,38 @@ type CreateKeyRequest struct { InstancePermissions []InstancePermission } +type CreateKeyResponse struct { + ID int `json:"id"` + Name string `json:"name"` + UserID string `json:"user_id"` + PermissionMode auth.PermissionMode `json:"permission_mode"` + ExpiresAt *int64 `json:"expires_at"` + Enabled bool `json:"enabled"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + LastUsedAt *int64 `json:"last_used_at"` + Key string `json:"key"` +} + +type KeyResponse struct { + ID int `json:"id"` + Name string `json:"name"` + UserID string `json:"user_id"` + PermissionMode auth.PermissionMode `json:"permission_mode"` + ExpiresAt *int64 `json:"expires_at"` + Enabled bool `json:"enabled"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + LastUsedAt *int64 `json:"last_used_at"` +} + +type KeyPermissionResponse struct { + InstanceID int `json:"instance_id"` + InstanceName string `json:"instance_name"` + CanInfer bool `json:"can_infer"` + CanViewLogs bool `json:"can_view_logs"` +} + // CreateInferenceKey handles POST /api/v1/keys func (h *Handler) CreateInferenceKey() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { @@ -76,7 +108,7 @@ func (h *Handler) CreateInferenceKey() http.HandlerFunc { } // Generate plain-text key - plainTextKey, err := auth.GenerateKey() + plainTextKey, err := auth.GenerateKey("llamactl-") if err != nil { writeError(w, http.StatusInternalServerError, "key_generation_failed", "Failed to generate API key") return @@ -118,18 +150,20 @@ func (h *Handler) CreateInferenceKey() http.HandlerFunc { if err != nil { writeError(w, http.StatusInternalServerError, "creation_failed", fmt.Sprintf("Failed to create API key: %v", err)) return - } // Return response with plain-text key (only shown once) - response := map[string]interface{}{ - "id": apiKey.ID, - "name": apiKey.Name, - "user_id": apiKey.UserID, - "permission_mode": apiKey.PermissionMode, - "expires_at": apiKey.ExpiresAt, - "enabled": apiKey.Enabled, - "created_at": apiKey.CreatedAt, - "updated_at": apiKey.UpdatedAt, - "last_used_at": apiKey.LastUsedAt, - "key": plainTextKey, // Only returned on creation + } + + // Return response with plain-text key (only shown once) + response := CreateKeyResponse{ + ID: apiKey.ID, + Name: apiKey.Name, + UserID: apiKey.UserID, + PermissionMode: apiKey.PermissionMode, + ExpiresAt: apiKey.ExpiresAt, + Enabled: apiKey.Enabled, + CreatedAt: apiKey.CreatedAt, + UpdatedAt: apiKey.UpdatedAt, + LastUsedAt: apiKey.LastUsedAt, + Key: plainTextKey, } w.Header().Set("Content-Type", "application/json") @@ -148,18 +182,18 @@ func (h *Handler) ListInferenceKeys() http.HandlerFunc { } // Remove key_hash from all keys - var response []map[string]interface{} + response := make([]KeyResponse, 0, len(keys)) for _, key := range keys { - response = append(response, map[string]interface{}{ - "id": key.ID, - "name": key.Name, - "user_id": key.UserID, - "permission_mode": key.PermissionMode, - "expires_at": key.ExpiresAt, - "enabled": key.Enabled, - "created_at": key.CreatedAt, - "updated_at": key.UpdatedAt, - "last_used_at": key.LastUsedAt, + response = append(response, KeyResponse{ + ID: key.ID, + Name: key.Name, + UserID: key.UserID, + PermissionMode: key.PermissionMode, + ExpiresAt: key.ExpiresAt, + Enabled: key.Enabled, + CreatedAt: key.CreatedAt, + UpdatedAt: key.UpdatedAt, + LastUsedAt: key.LastUsedAt, }) } @@ -189,16 +223,16 @@ func (h *Handler) GetInferenceKey() http.HandlerFunc { } // Remove key_hash from response - response := map[string]interface{}{ - "id": key.ID, - "name": key.Name, - "user_id": key.UserID, - "permission_mode": key.PermissionMode, - "expires_at": key.ExpiresAt, - "enabled": key.Enabled, - "created_at": key.CreatedAt, - "updated_at": key.UpdatedAt, - "last_used_at": key.LastUsedAt, + response := KeyResponse{ + ID: key.ID, + Name: key.Name, + UserID: key.UserID, + PermissionMode: key.PermissionMode, + ExpiresAt: key.ExpiresAt, + Enabled: key.Enabled, + CreatedAt: key.CreatedAt, + UpdatedAt: key.UpdatedAt, + LastUsedAt: key.LastUsedAt, } w.Header().Set("Content-Type", "application/json") @@ -268,13 +302,13 @@ func (h *Handler) GetInferenceKeyPermissions() http.HandlerFunc { instanceNameMap[inst.ID] = inst.Name } - var response []map[string]interface{} + response := make([]KeyPermissionResponse, 0, len(permissions)) for _, perm := range permissions { - response = append(response, map[string]interface{}{ - "instance_id": perm.InstanceID, - "instance_name": instanceNameMap[perm.InstanceID], - "can_infer": perm.CanInfer, - "can_view_logs": perm.CanViewLogs, + response = append(response, KeyPermissionResponse{ + InstanceID: perm.InstanceID, + InstanceName: instanceNameMap[perm.InstanceID], + CanInfer: perm.CanInfer, + CanViewLogs: perm.CanViewLogs, }) } diff --git a/pkg/server/handlers_backends.go b/pkg/server/handlers_backends.go index 1e249f9..065b24e 100644 --- a/pkg/server/handlers_backends.go +++ b/pkg/server/handlers_backends.go @@ -109,6 +109,12 @@ func (h *Handler) LlamaCppProxy() http.HandlerFunc { return } + // Check instance permissions + if err := h.authMiddleware.CheckInstancePermission(r.Context(), inst.ID); err != nil { + writeError(w, http.StatusForbidden, "permission_denied", err.Error()) + return + } + // Check if instance is shutting down before autostart logic if inst.GetStatus() == instance.ShuttingDown { writeError(w, http.StatusServiceUnavailable, "instance_shutting_down", "Instance is shutting down") diff --git a/pkg/server/handlers_instances.go b/pkg/server/handlers_instances.go index 43bed3e..e155fd1 100644 --- a/pkg/server/handlers_instances.go +++ b/pkg/server/handlers_instances.go @@ -327,6 +327,12 @@ func (h *Handler) InstanceProxy() http.HandlerFunc { return } + // Check instance permissions + if err := h.authMiddleware.CheckInstancePermission(r.Context(), inst.ID); err != nil { + writeError(w, http.StatusForbidden, "permission_denied", err.Error()) + return + } + if !inst.IsRunning() { writeError(w, http.StatusServiceUnavailable, "instance_not_running", "Instance is not running") return diff --git a/pkg/server/handlers_openai.go b/pkg/server/handlers_openai.go index 81aa9e7..a7ad635 100644 --- a/pkg/server/handlers_openai.go +++ b/pkg/server/handlers_openai.go @@ -107,6 +107,12 @@ func (h *Handler) OpenAIProxy() http.HandlerFunc { return } + // Check instance permissions + if err := h.authMiddleware.CheckInstancePermission(r.Context(), inst.ID); err != nil { + writeError(w, http.StatusForbidden, "permission_denied", err.Error()) + return + } + // Check if instance is shutting down before autostart logic if inst.GetStatus() == instance.ShuttingDown { writeError(w, http.StatusServiceUnavailable, "instance_shutting_down", "Instance is shutting down") diff --git a/pkg/server/middleware.go b/pkg/server/middleware.go index 0c1797c..cd4f24f 100644 --- a/pkg/server/middleware.go +++ b/pkg/server/middleware.go @@ -2,9 +2,7 @@ package server import ( "context" - "crypto/rand" "crypto/subtle" - "encoding/hex" "fmt" "llamactl/pkg/auth" "llamactl/pkg/config" @@ -16,13 +14,6 @@ import ( "time" ) -type KeyType int - -const ( - KeyTypeInference KeyType = iota - KeyTypeManagement -) - // contextKey is a custom type for context keys to avoid collisions type contextKey string @@ -56,7 +47,12 @@ func NewAPIAuthMiddleware(authCfg config.AuthConfig, authStore database.AuthStor const banner = "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" if authCfg.RequireManagementAuth && len(authCfg.ManagementKeys) == 0 { - key := generateAPIKey(KeyTypeManagement) + key, err := auth.GenerateKey("llamactl-mgmt-") + if err != nil { + log.Printf("Warning: Failed to generate management key: %v", err) + // Fallback to PID-based key for safety + key = fmt.Sprintf("sk-management-fallback-%d", os.Getpid()) + } managementKeys[key] = true generated = true fmt.Printf("%s\n⚠️ MANAGEMENT AUTHENTICATION REQUIRED\n%s\n", banner, banner) @@ -79,32 +75,6 @@ func NewAPIAuthMiddleware(authCfg config.AuthConfig, authStore database.AuthStor } } -// generateAPIKey creates a cryptographically secure API key -func generateAPIKey(keyType KeyType) string { - // Generate 32 random bytes (256 bits) - randomBytes := make([]byte, 32) - - var prefix string - - switch keyType { - case KeyTypeInference: - prefix = "sk-inference" - case KeyTypeManagement: - prefix = "sk-management" - default: - prefix = "sk-unknown" - } - - if _, err := rand.Read(randomBytes); err != nil { - log.Printf("Warning: Failed to generate secure random key, using fallback") - // Fallback to a less secure method if crypto/rand fails - return fmt.Sprintf("%s-fallback-%d", prefix, os.Getpid()) - } - - // Convert to hex and add prefix - return fmt.Sprintf("%s-%s", prefix, hex.EncodeToString(randomBytes)) -} - // InferenceAuthMiddleware returns middleware for inference endpoints func (a *APIAuthMiddleware) InferenceAuthMiddleware() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { @@ -123,7 +93,7 @@ func (a *APIAuthMiddleware) InferenceAuthMiddleware() func(http.Handler) http.Ha // Try database authentication first var foundKey *auth.APIKey - if a.requireInferenceAuth { + if a.requireInferenceAuth && a.authStore != nil { activeKeys, err := a.authStore.GetActiveKeys(r.Context()) if err != nil { log.Printf("Failed to get active inference keys: %v", err) @@ -208,80 +178,16 @@ func (a *APIAuthMiddleware) CheckInstancePermission(ctx context.Context, instanc // Check per-instance permissions canInfer, err := a.authStore.HasPermission(ctx, apiKey.ID, instanceID) if err != nil { - return err + return fmt.Errorf("failed to check permission: %w", err) } if !canInfer { - return http.ErrBodyNotAllowed // Use this as a generic error to indicate permission denied + return fmt.Errorf("permission denied: key does not have access to this instance") } return nil } -// AuthMiddleware returns a middleware that checks API keys for the given key type (legacy support) -func (a *APIAuthMiddleware) AuthMiddleware(keyType KeyType) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == "OPTIONS" { - next.ServeHTTP(w, r) - return - } - - apiKey := a.extractAPIKey(r) - if apiKey == "" { - a.unauthorized(w, "Missing API key") - return - } - - var isValid bool - switch keyType { - case KeyTypeInference: - // Try database authentication first - if a.requireInferenceAuth { - activeKeys, err := a.authStore.GetActiveKeys(r.Context()) - if err == nil { - for _, key := range activeKeys { - if auth.VerifyKey(apiKey, key.KeyHash) { - foundKey := key - // Async update last_used_at - go func(keyID int) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := a.authStore.TouchKey(ctx, keyID); err != nil { - log.Printf("Failed to update last used timestamp for key %d: %v", keyID, err) - } - }(key.ID) - - // Add APIKey to context for permission checking - ctx := context.WithValue(r.Context(), apiKeyContextKey, foundKey) - r = r.WithContext(ctx) - isValid = true - break - } - } - } - } - - // If no database key found, try management key (higher privilege) - if !isValid { - isValid = a.isValidManagementKey(apiKey) - } - case KeyTypeManagement: - isValid = a.isValidManagementKey(apiKey) - default: - isValid = false - } - - if !isValid { - a.unauthorized(w, "Invalid API key") - return - } - - next.ServeHTTP(w, r) - }) - } -} - // extractAPIKey extracts the API key from the request func (a *APIAuthMiddleware) extractAPIKey(r *http.Request) string { // Check Authorization header: "Bearer sk-..." diff --git a/pkg/server/middleware_test.go b/pkg/server/middleware_test.go index 720362f..6a552e4 100644 --- a/pkg/server/middleware_test.go +++ b/pkg/server/middleware_test.go @@ -9,99 +9,44 @@ import ( "testing" ) -func TestAuthMiddleware(t *testing.T) { +func TestInferenceAuthMiddleware(t *testing.T) { tests := []struct { name string - keyType server.KeyType - inferenceKeys []string managementKeys []string requestKey string method string expectedStatus int }{ - // Valid key tests - using management keys only since config-based inference keys are deprecated { - name: "valid management key for inference", // Management keys work for inference - keyType: server.KeyTypeInference, + name: "valid management key for inference", managementKeys: []string{"sk-management-admin123"}, requestKey: "sk-management-admin123", method: "GET", expectedStatus: http.StatusOK, }, { - name: "valid management key for management", - keyType: server.KeyTypeManagement, - managementKeys: []string{"sk-management-admin123"}, - requestKey: "sk-management-admin123", - method: "GET", - expectedStatus: http.StatusOK, - }, - - // Invalid key tests - { - name: "inference key for management should fail", - keyType: server.KeyTypeManagement, - inferenceKeys: []string{"sk-inference-user123"}, - requestKey: "sk-inference-user123", - method: "GET", - expectedStatus: http.StatusUnauthorized, - }, - { - name: "invalid inference key", - keyType: server.KeyTypeInference, - inferenceKeys: []string{"sk-inference-valid123"}, - requestKey: "sk-inference-invalid", - method: "GET", - expectedStatus: http.StatusUnauthorized, - }, - { - name: "missing inference key", - keyType: server.KeyTypeInference, - inferenceKeys: []string{"sk-inference-valid123"}, - requestKey: "", - method: "GET", - expectedStatus: http.StatusUnauthorized, - }, - { - name: "invalid management key", - keyType: server.KeyTypeManagement, + name: "invalid key", managementKeys: []string{"sk-management-valid123"}, requestKey: "sk-management-invalid", method: "GET", expectedStatus: http.StatusUnauthorized, }, { - name: "missing management key", - keyType: server.KeyTypeManagement, + name: "missing key", managementKeys: []string{"sk-management-valid123"}, requestKey: "", method: "GET", expectedStatus: http.StatusUnauthorized, }, - - // OPTIONS requests should always pass { - name: "OPTIONS request bypasses inference auth", - keyType: server.KeyTypeInference, - inferenceKeys: []string{"sk-inference-valid123"}, - requestKey: "", - method: "OPTIONS", - expectedStatus: http.StatusOK, - }, - { - name: "OPTIONS request bypasses management auth", - keyType: server.KeyTypeManagement, + name: "OPTIONS request bypasses auth", managementKeys: []string{"sk-management-valid123"}, requestKey: "", method: "OPTIONS", expectedStatus: http.StatusOK, }, - - // Cross-key-type validation { name: "management key works for inference endpoint", - keyType: server.KeyTypeInference, - inferenceKeys: []string{}, managementKeys: []string{"sk-management-admin"}, requestKey: "sk-management-admin", method: "POST", @@ -112,8 +57,8 @@ func TestAuthMiddleware(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := config.AuthConfig{ - InferenceKeys: tt.inferenceKeys, - ManagementKeys: tt.managementKeys, + RequireInferenceAuth: true, + ManagementKeys: tt.managementKeys, } middleware := server.NewAPIAuthMiddleware(cfg, nil) @@ -123,24 +68,17 @@ func TestAuthMiddleware(t *testing.T) { req.Header.Set("Authorization", "Bearer "+tt.requestKey) } - // Create test handler using appropriate middleware - var handler http.Handler - if tt.keyType == server.KeyTypeInference { - handler = middleware.AuthMiddleware(server.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - } else { - handler = middleware.AuthMiddleware(server.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - } + // Create test handler + handler := middleware.InferenceAuthMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) // Execute request recorder := httptest.NewRecorder() handler.ServeHTTP(recorder, req) if recorder.Code != tt.expectedStatus { - t.Errorf("AuthMiddleware() status = %v, expected %v", recorder.Code, tt.expectedStatus) + t.Errorf("InferenceAuthMiddleware() status = %v, expected %v", recorder.Code, tt.expectedStatus) } // Check that unauthorized responses have proper format @@ -159,178 +97,171 @@ func TestAuthMiddleware(t *testing.T) { } } -func TestGenerateAPIKey(t *testing.T) { +func TestManagementAuthMiddleware(t *testing.T) { tests := []struct { - name string - keyType server.KeyType - }{ - {"inference key generation", server.KeyTypeInference}, - {"management key generation", server.KeyTypeManagement}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Test auto-generation by creating config that will trigger it - var config config.AuthConfig - if tt.keyType == server.KeyTypeInference { - config.RequireInferenceAuth = true - config.InferenceKeys = []string{} // Empty to trigger generation - } else { - config.RequireManagementAuth = true - config.ManagementKeys = []string{} // Empty to trigger generation - } - - // Create middleware - this should trigger key generation - middleware := server.NewAPIAuthMiddleware(config, nil) - - // Test that auth is required (meaning a key was generated) - req := httptest.NewRequest("GET", "/", nil) - recorder := httptest.NewRecorder() - - var handler http.Handler - if tt.keyType == server.KeyTypeInference { - handler = middleware.AuthMiddleware(server.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - } else { - handler = middleware.AuthMiddleware(server.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - } - - handler.ServeHTTP(recorder, req) - - // Should be unauthorized without a key (proving that a key was generated and auth is working) - if recorder.Code != http.StatusUnauthorized { - t.Errorf("Expected unauthorized without key, got status %v", recorder.Code) - } - - // Test uniqueness by creating another middleware instance - middleware2 := server.NewAPIAuthMiddleware(config, nil) - - req2 := httptest.NewRequest("GET", "/", nil) - recorder2 := httptest.NewRecorder() - - if tt.keyType == server.KeyTypeInference { - handler2 := middleware2.AuthMiddleware(server.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - handler2.ServeHTTP(recorder2, req2) - } else { - handler2 := middleware2.AuthMiddleware(server.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - handler2.ServeHTTP(recorder2, req2) - } - - // Both should require auth (proving keys were generated for both instances) - if recorder2.Code != http.StatusUnauthorized { - t.Errorf("Expected unauthorized for second middleware without key, got status %v", recorder2.Code) - } - }) - } -} - -func TestAutoGeneration(t *testing.T) { - tests := []struct { - name string - requireInference bool - requireManagement bool - providedInference []string - providedManagement []string - shouldGenerateInf bool // Whether inference key should be generated - shouldGenerateMgmt bool // Whether management key should be generated + name string + managementKeys []string + requestKey string + method string + expectedStatus int }{ { - name: "inference auth required, keys provided - no generation", - requireInference: true, - requireManagement: false, - providedInference: []string{"sk-inference-provided"}, - providedManagement: []string{}, - shouldGenerateInf: false, - shouldGenerateMgmt: false, + name: "valid management key", + managementKeys: []string{"sk-management-admin123"}, + requestKey: "sk-management-admin123", + method: "GET", + expectedStatus: http.StatusOK, }, { - name: "inference auth required, no keys - should auto-generate", - requireInference: true, - requireManagement: false, - providedInference: []string{}, - providedManagement: []string{}, - shouldGenerateInf: true, - shouldGenerateMgmt: false, + name: "invalid management key", + managementKeys: []string{"sk-management-valid123"}, + requestKey: "sk-management-invalid", + method: "GET", + expectedStatus: http.StatusUnauthorized, }, { - name: "management auth required, keys provided - no generation", - requireInference: false, - requireManagement: true, - providedInference: []string{}, - providedManagement: []string{"sk-management-provided"}, - shouldGenerateInf: false, - shouldGenerateMgmt: false, + name: "missing management key", + managementKeys: []string{"sk-management-valid123"}, + requestKey: "", + method: "GET", + expectedStatus: http.StatusUnauthorized, }, { - name: "management auth required, no keys - should auto-generate", - requireInference: false, - requireManagement: true, - providedInference: []string{}, - providedManagement: []string{}, - shouldGenerateInf: false, - shouldGenerateMgmt: true, - }, - { - name: "both required, both provided - no generation", - requireInference: true, - requireManagement: true, - providedInference: []string{"sk-inference-provided"}, - providedManagement: []string{"sk-management-provided"}, - shouldGenerateInf: false, - shouldGenerateMgmt: false, - }, - { - name: "both required, none provided - should auto-generate both", - requireInference: true, - requireManagement: true, - providedInference: []string{}, - providedManagement: []string{}, - shouldGenerateInf: true, - shouldGenerateMgmt: true, + name: "OPTIONS request bypasses management auth", + managementKeys: []string{"sk-management-valid123"}, + requestKey: "", + method: "OPTIONS", + expectedStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := config.AuthConfig{ + RequireManagementAuth: true, + ManagementKeys: tt.managementKeys, + } + middleware := server.NewAPIAuthMiddleware(cfg, nil) + + // Create test request + req := httptest.NewRequest(tt.method, "/test", nil) + if tt.requestKey != "" { + req.Header.Set("Authorization", "Bearer "+tt.requestKey) + } + + // Create test handler + handler := middleware.ManagementAuthMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Execute request + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, req) + + if recorder.Code != tt.expectedStatus { + t.Errorf("ManagementAuthMiddleware() status = %v, expected %v", recorder.Code, tt.expectedStatus) + } + + // Check that unauthorized responses have proper format + if recorder.Code == http.StatusUnauthorized { + contentType := recorder.Header().Get("Content-Type") + if contentType != "application/json" { + t.Errorf("Unauthorized response Content-Type = %v, expected application/json", contentType) + } + + body := recorder.Body.String() + if !strings.Contains(body, `"type": "authentication_error"`) { + t.Errorf("Unauthorized response missing proper error type: %v", body) + } + } + }) + } +} + +func TestManagementKeyAutoGeneration(t *testing.T) { + // Test auto-generation for management keys + config := config.AuthConfig{ + RequireManagementAuth: true, + ManagementKeys: []string{}, // Empty to trigger generation + } + + // Create middleware - this should trigger key generation + middleware := server.NewAPIAuthMiddleware(config, nil) + + // Test that auth is required (meaning a key was generated) + req := httptest.NewRequest("GET", "/", nil) + recorder := httptest.NewRecorder() + + handler := middleware.ManagementAuthMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + handler.ServeHTTP(recorder, req) + + // Should be unauthorized without a key (proving that a key was generated and auth is working) + if recorder.Code != http.StatusUnauthorized { + t.Errorf("Expected unauthorized without key, got status %v", recorder.Code) + } + + // Test uniqueness by creating another middleware instance + middleware2 := server.NewAPIAuthMiddleware(config, nil) + + req2 := httptest.NewRequest("GET", "/", nil) + recorder2 := httptest.NewRecorder() + + handler2 := middleware2.ManagementAuthMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + handler2.ServeHTTP(recorder2, req2) + + // Both should require auth (proving keys were generated for both instances) + if recorder2.Code != http.StatusUnauthorized { + t.Errorf("Expected unauthorized for second middleware without key, got status %v", recorder2.Code) + } +} + +func TestAutoGenerationScenarios(t *testing.T) { + tests := []struct { + name string + requireManagement bool + providedManagement []string + shouldGenerate bool + }{ + { + name: "management auth required, keys provided - no generation", + requireManagement: true, + providedManagement: []string{"sk-management-provided"}, + shouldGenerate: false, + }, + { + name: "management auth required, no keys - should auto-generate", + requireManagement: true, + providedManagement: []string{}, + shouldGenerate: true, + }, + { + name: "management auth not required - no generation", + requireManagement: false, + providedManagement: []string{}, + shouldGenerate: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := config.AuthConfig{ - RequireInferenceAuth: tt.requireInference, RequireManagementAuth: tt.requireManagement, - InferenceKeys: tt.providedInference, ManagementKeys: tt.providedManagement, } middleware := server.NewAPIAuthMiddleware(cfg, nil) - // Test inference behavior if inference auth is required - if tt.requireInference { - req := httptest.NewRequest("GET", "/v1/models", nil) - recorder := httptest.NewRecorder() - - handler := middleware.AuthMiddleware(server.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - - handler.ServeHTTP(recorder, req) - - // Should always be unauthorized without a key (since middleware assumes auth is required) - if recorder.Code != http.StatusUnauthorized { - t.Errorf("Expected unauthorized for inference without key, got status %v", recorder.Code) - } - } - // Test management behavior if management auth is required if tt.requireManagement { req := httptest.NewRequest("GET", "/api/v1/instances", nil) recorder := httptest.NewRecorder() - handler := middleware.AuthMiddleware(server.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler := middleware.ManagementAuthMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) @@ -344,3 +275,16 @@ func TestAutoGeneration(t *testing.T) { }) } } + +func TestConfigBasedInferenceKeysDeprecationWarning(t *testing.T) { + // Test that config-based inference keys trigger a warning (captured in logs) + cfg := config.AuthConfig{ + InferenceKeys: []string{"sk-inference-old"}, + } + + // Creating middleware should log a warning, but shouldn't fail + _ = server.NewAPIAuthMiddleware(cfg, nil) + + // If we get here without panic, the test passes + // The warning is logged but not returned as an error +} diff --git a/pkg/server/routes.go b/pkg/server/routes.go index 36a6081..6920a61 100644 --- a/pkg/server/routes.go +++ b/pkg/server/routes.go @@ -39,7 +39,7 @@ func SetupRouter(handler *Handler) *chi.Mux { r.Route("/api/v1", func(r chi.Router) { if authMiddleware != nil && handler.cfg.Auth.RequireManagementAuth { - r.Use(authMiddleware.AuthMiddleware(KeyTypeManagement)) + r.Use(authMiddleware.ManagementAuthMiddleware()) } r.Get("/version", handler.VersionHandler()) @@ -108,7 +108,7 @@ func SetupRouter(handler *Handler) *chi.Mux { r.Route("/v1", func(r chi.Router) { if authMiddleware != nil && handler.cfg.Auth.RequireInferenceAuth { - r.Use(authMiddleware.AuthMiddleware(KeyTypeInference)) + r.Use(authMiddleware.InferenceAuthMiddleware()) } r.Get("/models", handler.OpenAIListInstances()) // List instances in OpenAI-compatible format @@ -136,7 +136,7 @@ func SetupRouter(handler *Handler) *chi.Mux { r.Group(func(r chi.Router) { if authMiddleware != nil && handler.cfg.Auth.RequireInferenceAuth { - r.Use(authMiddleware.AuthMiddleware(KeyTypeInference)) + r.Use(authMiddleware.InferenceAuthMiddleware()) } // This handler auto starts the server if it's not running