diff --git a/pkg/middleware.go b/pkg/middleware.go index 13cc800..6ed3195 100644 --- a/pkg/middleware.go +++ b/pkg/middleware.go @@ -99,15 +99,10 @@ func generateAPIKey(keyType KeyType) string { return fmt.Sprintf("%s-%s", prefix, hex.EncodeToString(randomBytes)) } -// InferenceMiddleware returns middleware for OpenAI inference endpoints -func (a *APIAuthMiddleware) InferenceMiddleware() func(http.Handler) http.Handler { +// AuthMiddleware returns a middleware that checks API keys for the given key type +func (a *APIAuthMiddleware) AuthMiddleware(keyType KeyType) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !a.requireInferenceAuth { - next.ServeHTTP(w, r) - return - } - if r.Method == "OPTIONS" { next.ServeHTTP(w, r) return @@ -119,9 +114,18 @@ func (a *APIAuthMiddleware) InferenceMiddleware() func(http.Handler) http.Handle return } - // Check if key is valid for OpenAI access - // Management keys also work for OpenAI endpoints (higher privilege) - if !a.isValidKey(apiKey, KeyTypeInference) && !a.isValidKey(apiKey, KeyTypeManagement) { + var isValid bool + switch keyType { + case KeyTypeInference: + // Management keys also work for OpenAI endpoints (higher privilege) + isValid = a.isValidKey(apiKey, KeyTypeInference) || a.isValidKey(apiKey, KeyTypeManagement) + case KeyTypeManagement: + isValid = a.isValidKey(apiKey, KeyTypeManagement) + default: + isValid = false + } + + if !isValid { a.unauthorized(w, "Invalid API key") return } @@ -131,43 +135,12 @@ func (a *APIAuthMiddleware) InferenceMiddleware() func(http.Handler) http.Handle } } -// ManagementMiddleware returns middleware for management endpoints -func (a *APIAuthMiddleware) ManagementMiddleware() func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !a.requireManagementAuth { - next.ServeHTTP(w, r) - return - } - - if r.Method == "OPTIONS" { - next.ServeHTTP(w, r) - return - } - - apiKey := a.extractAPIKey(r) - if apiKey == "" { - a.unauthorized(w, "Missing API key") - return - } - - // Only management keys work for management endpoints - if !a.isValidKey(apiKey, KeyTypeManagement) { - a.unauthorized(w, "Insufficient privileges - management key required") - return - } - - next.ServeHTTP(w, r) - }) - } -} - // extractAPIKey extracts the API key from the request func (a *APIAuthMiddleware) extractAPIKey(r *http.Request) string { // Check Authorization header: "Bearer sk-..." if auth := r.Header.Get("Authorization"); auth != "" { - if strings.HasPrefix(auth, "Bearer ") { - return strings.TrimPrefix(auth, "Bearer ") + if after, ok := strings.CutPrefix(auth, "Bearer "); ok { + return after } } diff --git a/pkg/middleware_test.go b/pkg/middleware_test.go new file mode 100644 index 0000000..2e16d1a --- /dev/null +++ b/pkg/middleware_test.go @@ -0,0 +1,354 @@ +package llamactl_test + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + llamactl "llamactl/pkg" +) + +func TestAuthMiddleware(t *testing.T) { + tests := []struct { + name string + keyType llamactl.KeyType + inferenceKeys []string + managementKeys []string + requestKey string + method string + expectedStatus int + }{ + // Valid key tests + { + name: "valid inference key for inference", + keyType: llamactl.KeyTypeInference, + inferenceKeys: []string{"sk-inference-valid123"}, + requestKey: "sk-inference-valid123", + method: "GET", + expectedStatus: http.StatusOK, + }, + { + name: "valid management key for inference", // Management keys work for inference + keyType: llamactl.KeyTypeInference, + managementKeys: []string{"sk-management-admin123"}, + requestKey: "sk-management-admin123", + method: "GET", + expectedStatus: http.StatusOK, + }, + { + name: "valid management key for management", + keyType: llamactl.KeyTypeManagement, + managementKeys: []string{"sk-management-admin123"}, + requestKey: "sk-management-admin123", + method: "GET", + expectedStatus: http.StatusOK, + }, + + // Invalid key tests + { + name: "inference key for management should fail", + keyType: llamactl.KeyTypeManagement, + inferenceKeys: []string{"sk-inference-user123"}, + requestKey: "sk-inference-user123", + method: "GET", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "invalid inference key", + keyType: llamactl.KeyTypeInference, + inferenceKeys: []string{"sk-inference-valid123"}, + requestKey: "sk-inference-invalid", + method: "GET", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "missing inference key", + keyType: llamactl.KeyTypeInference, + inferenceKeys: []string{"sk-inference-valid123"}, + requestKey: "", + method: "GET", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "invalid management key", + keyType: llamactl.KeyTypeManagement, + managementKeys: []string{"sk-management-valid123"}, + requestKey: "sk-management-invalid", + method: "GET", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "missing management key", + keyType: llamactl.KeyTypeManagement, + managementKeys: []string{"sk-management-valid123"}, + requestKey: "", + method: "GET", + expectedStatus: http.StatusUnauthorized, + }, + + // OPTIONS requests should always pass + { + name: "OPTIONS request bypasses inference auth", + keyType: llamactl.KeyTypeInference, + inferenceKeys: []string{"sk-inference-valid123"}, + requestKey: "", + method: "OPTIONS", + expectedStatus: http.StatusOK, + }, + { + name: "OPTIONS request bypasses management auth", + keyType: llamactl.KeyTypeManagement, + managementKeys: []string{"sk-management-valid123"}, + requestKey: "", + method: "OPTIONS", + expectedStatus: http.StatusOK, + }, + + // Cross-key-type validation + { + name: "management key works for inference endpoint", + keyType: llamactl.KeyTypeInference, + inferenceKeys: []string{}, + managementKeys: []string{"sk-management-admin"}, + requestKey: "sk-management-admin", + method: "POST", + expectedStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := llamactl.AuthConfig{ + InferenceKeys: tt.inferenceKeys, + ManagementKeys: tt.managementKeys, + } + middleware := llamactl.NewAPIAuthMiddleware(config) + + // Create test request + req := httptest.NewRequest(tt.method, "/test", nil) + if tt.requestKey != "" { + req.Header.Set("Authorization", "Bearer "+tt.requestKey) + } + + // Create test handler using the appropriate middleware + var handler http.Handler + if tt.keyType == llamactl.KeyTypeInference { + handler = middleware.AuthMiddleware(llamactl.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + } else { + handler = middleware.AuthMiddleware(llamactl.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + } + + // Execute request + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, req) + + if recorder.Code != tt.expectedStatus { + t.Errorf("AuthMiddleware() status = %v, expected %v", recorder.Code, tt.expectedStatus) + } + + // Check that unauthorized responses have proper format + if recorder.Code == http.StatusUnauthorized { + contentType := recorder.Header().Get("Content-Type") + if contentType != "application/json" { + t.Errorf("Unauthorized response Content-Type = %v, expected application/json", contentType) + } + + body := recorder.Body.String() + if !strings.Contains(body, `"type": "authentication_error"`) { + t.Errorf("Unauthorized response missing proper error type: %v", body) + } + } + }) + } +} + +func TestGenerateAPIKey(t *testing.T) { + tests := []struct { + name string + keyType llamactl.KeyType + }{ + {"inference key generation", llamactl.KeyTypeInference}, + {"management key generation", llamactl.KeyTypeManagement}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test auto-generation by creating config that will trigger it + var config llamactl.AuthConfig + if tt.keyType == llamactl.KeyTypeInference { + config.RequireInferenceAuth = true + config.InferenceKeys = []string{} // Empty to trigger generation + } else { + config.RequireManagementAuth = true + config.ManagementKeys = []string{} // Empty to trigger generation + } + + // Create middleware - this should trigger key generation + middleware := llamactl.NewAPIAuthMiddleware(config) + + // Test that auth is required (meaning a key was generated) + req := httptest.NewRequest("GET", "/", nil) + recorder := httptest.NewRecorder() + + var handler http.Handler + if tt.keyType == llamactl.KeyTypeInference { + handler = middleware.AuthMiddleware(llamactl.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + } else { + handler = middleware.AuthMiddleware(llamactl.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + } + + handler.ServeHTTP(recorder, req) + + // Should be unauthorized without a key (proving that a key was generated and auth is working) + if recorder.Code != http.StatusUnauthorized { + t.Errorf("Expected unauthorized without key, got status %v", recorder.Code) + } + + // Test uniqueness by creating another middleware instance + middleware2 := llamactl.NewAPIAuthMiddleware(config) + + req2 := httptest.NewRequest("GET", "/", nil) + recorder2 := httptest.NewRecorder() + + if tt.keyType == llamactl.KeyTypeInference { + handler2 := middleware2.AuthMiddleware(llamactl.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + handler2.ServeHTTP(recorder2, req2) + } else { + handler2 := middleware2.AuthMiddleware(llamactl.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + handler2.ServeHTTP(recorder2, req2) + } + + // Both should require auth (proving keys were generated for both instances) + if recorder2.Code != http.StatusUnauthorized { + t.Errorf("Expected unauthorized for second middleware without key, got status %v", recorder2.Code) + } + }) + } +} + +func TestAutoGeneration(t *testing.T) { + tests := []struct { + name string + requireInference bool + requireManagement bool + providedInference []string + providedManagement []string + shouldGenerateInf bool // Whether inference key should be generated + shouldGenerateMgmt bool // Whether management key should be generated + }{ + { + name: "inference auth required, keys provided - no generation", + requireInference: true, + requireManagement: false, + providedInference: []string{"sk-inference-provided"}, + providedManagement: []string{}, + shouldGenerateInf: false, + shouldGenerateMgmt: false, + }, + { + name: "inference auth required, no keys - should auto-generate", + requireInference: true, + requireManagement: false, + providedInference: []string{}, + providedManagement: []string{}, + shouldGenerateInf: true, + shouldGenerateMgmt: false, + }, + { + name: "management auth required, keys provided - no generation", + requireInference: false, + requireManagement: true, + providedInference: []string{}, + providedManagement: []string{"sk-management-provided"}, + shouldGenerateInf: false, + shouldGenerateMgmt: false, + }, + { + name: "management auth required, no keys - should auto-generate", + requireInference: false, + requireManagement: true, + providedInference: []string{}, + providedManagement: []string{}, + shouldGenerateInf: false, + shouldGenerateMgmt: true, + }, + { + name: "both required, both provided - no generation", + requireInference: true, + requireManagement: true, + providedInference: []string{"sk-inference-provided"}, + providedManagement: []string{"sk-management-provided"}, + shouldGenerateInf: false, + shouldGenerateMgmt: false, + }, + { + name: "both required, none provided - should auto-generate both", + requireInference: true, + requireManagement: true, + providedInference: []string{}, + providedManagement: []string{}, + shouldGenerateInf: true, + shouldGenerateMgmt: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := llamactl.AuthConfig{ + RequireInferenceAuth: tt.requireInference, + RequireManagementAuth: tt.requireManagement, + InferenceKeys: tt.providedInference, + ManagementKeys: tt.providedManagement, + } + + middleware := llamactl.NewAPIAuthMiddleware(config) + + // Test inference behavior if inference auth is required + if tt.requireInference { + req := httptest.NewRequest("GET", "/v1/models", nil) + recorder := httptest.NewRecorder() + + handler := middleware.AuthMiddleware(llamactl.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + handler.ServeHTTP(recorder, req) + + // Should always be unauthorized without a key (since middleware assumes auth is required) + if recorder.Code != http.StatusUnauthorized { + t.Errorf("Expected unauthorized for inference without key, got status %v", recorder.Code) + } + } + + // Test management behavior if management auth is required + if tt.requireManagement { + req := httptest.NewRequest("GET", "/api/v1/instances", nil) + recorder := httptest.NewRecorder() + + handler := middleware.AuthMiddleware(llamactl.KeyTypeManagement)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + handler.ServeHTTP(recorder, req) + + // Should always be unauthorized without a key (since middleware assumes auth is required) + if recorder.Code != http.StatusUnauthorized { + t.Errorf("Expected unauthorized for management without key, got status %v", recorder.Code) + } + } + }) + } +} diff --git a/pkg/routes.go b/pkg/routes.go index 998af38..f0c3c54 100644 --- a/pkg/routes.go +++ b/pkg/routes.go @@ -39,7 +39,7 @@ func SetupRouter(handler *Handler) *chi.Mux { r.Route("/api/v1", func(r chi.Router) { if authMiddleware != nil && handler.config.Auth.RequireManagementAuth { - r.Use(authMiddleware.ManagementMiddleware()) + r.Use(authMiddleware.AuthMiddleware(KeyTypeManagement)) } r.Route("/server", func(r chi.Router) { @@ -74,7 +74,7 @@ func SetupRouter(handler *Handler) *chi.Mux { r.Route(("/v1"), func(r chi.Router) { if authMiddleware != nil && handler.config.Auth.RequireInferenceAuth { - r.Use(authMiddleware.InferenceMiddleware()) + r.Use(authMiddleware.AuthMiddleware(KeyTypeInference)) } r.Get(("/models"), handler.OpenAIListInstances()) // List instances in OpenAI-compatible format