From ad4af2f82d9449252e8322ffb4da2e087881fdab Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sat, 7 Dec 2024 21:41:37 +0100 Subject: [PATCH] Update auth test --- server/internal/auth/middleware_test.go | 156 ++++++++++++++++--- server/internal/auth/session_test.go | 196 +++++++++++++++++------- 2 files changed, 269 insertions(+), 83 deletions(-) diff --git a/server/internal/auth/middleware_test.go b/server/internal/auth/middleware_test.go index 06c44be..983c3d3 100644 --- a/server/internal/auth/middleware_test.go +++ b/server/internal/auth/middleware_test.go @@ -3,6 +3,7 @@ package auth_test import ( "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -11,6 +12,42 @@ import ( "novamd/internal/models" ) +// Mock SessionManager +type mockSessionManager struct { + sessions map[string]*models.Session +} + +func newMockSessionManager() *mockSessionManager { + return &mockSessionManager{ + sessions: make(map[string]*models.Session), + } +} + +func (m *mockSessionManager) CreateSession(userID int, role string) (*models.Session, string, error) { + return nil, "", nil // Not needed for these tests +} + +func (m *mockSessionManager) RefreshSession(refreshToken string) (string, error) { + return "", nil // Not needed for these tests +} + +func (m *mockSessionManager) ValidateSession(sessionID string) (*models.Session, error) { + session, exists := m.sessions[sessionID] + if !exists { + return nil, nil + } + return session, nil +} + +func (m *mockSessionManager) InvalidateSession(token string) error { + delete(m.sessions, token) + return nil +} + +func (m *mockSessionManager) CleanExpiredSessions() error { + return nil +} + // Complete mockResponseWriter implementation type mockResponseWriter struct { headers http.Header @@ -44,62 +81,122 @@ func TestAuthenticateMiddleware(t *testing.T) { RefreshTokenExpiry: 24 * time.Hour, } jwtService, _ := auth.NewJWTService(config) - middleware := auth.NewMiddleware(jwtService) + sessionManager := newMockSessionManager() + cookieManager := auth.NewCookieService(true, "localhost") + middleware := auth.NewMiddleware(jwtService, sessionManager, cookieManager) testCases := []struct { name string - setupAuth func() string + setupRequest func() *http.Request + setupSession func(sessionID string) + method string wantStatusCode int }{ { - name: "valid token", - setupAuth: func() string { + name: "valid token with valid session", + setupRequest: func() *http.Request { + req := httptest.NewRequest("GET", "/test", nil) token, _ := jwtService.GenerateAccessToken(1, "admin") - return token + cookie := cookieManager.GenerateAccessTokenCookie(token) + req.AddCookie(cookie) + return req }, + setupSession: func(sessionID string) { + sessionManager.sessions[sessionID] = &models.Session{ + ID: sessionID, + UserID: 1, + ExpiresAt: time.Now().Add(15 * time.Minute), + } + }, + method: "GET", wantStatusCode: http.StatusOK, }, { - name: "missing auth header", - setupAuth: func() string { - return "" + name: "valid token but invalid session", + setupRequest: func() *http.Request { + req := httptest.NewRequest("GET", "/test", nil) + token, _ := jwtService.GenerateAccessToken(1, "admin") + cookie := cookieManager.GenerateAccessTokenCookie(token) + req.AddCookie(cookie) + return req }, + setupSession: func(sessionID string) {}, // No session setup + method: "GET", wantStatusCode: http.StatusUnauthorized, }, { - name: "invalid auth format", - setupAuth: func() string { - return "InvalidFormat token" + name: "missing auth cookie", + setupRequest: func() *http.Request { + return httptest.NewRequest("GET", "/test", nil) }, + setupSession: func(sessionID string) {}, + method: "GET", wantStatusCode: http.StatusUnauthorized, }, { - name: "invalid token", - setupAuth: func() string { - return "Bearer invalid.token.here" + name: "POST request without CSRF token", + setupRequest: func() *http.Request { + req := httptest.NewRequest("POST", "/test", nil) + token, _ := jwtService.GenerateAccessToken(1, "admin") + cookie := cookieManager.GenerateAccessTokenCookie(token) + req.AddCookie(cookie) + return req }, - wantStatusCode: http.StatusUnauthorized, + setupSession: func(sessionID string) { + sessionManager.sessions[sessionID] = &models.Session{ + ID: sessionID, + UserID: 1, + ExpiresAt: time.Now().Add(15 * time.Minute), + } + }, + method: "POST", + wantStatusCode: http.StatusForbidden, + }, + { + name: "POST request with valid CSRF token", + setupRequest: func() *http.Request { + req := httptest.NewRequest("POST", "/test", nil) + token, _ := jwtService.GenerateAccessToken(1, "admin") + cookie := cookieManager.GenerateAccessTokenCookie(token) + req.AddCookie(cookie) + + csrfToken := "test-csrf-token" + csrfCookie := cookieManager.GenerateCSRFCookie(csrfToken) + req.AddCookie(csrfCookie) + req.Header.Set("X-CSRF-Token", csrfToken) + return req + }, + setupSession: func(sessionID string) { + sessionManager.sessions[sessionID] = &models.Session{ + ID: sessionID, + UserID: 1, + ExpiresAt: time.Now().Add(15 * time.Minute), + } + }, + method: "POST", + wantStatusCode: http.StatusOK, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - // Create test request - req := httptest.NewRequest("GET", "/test", nil) - if token := tc.setupAuth(); token != "" { - req.Header.Set("Authorization", "Bearer "+token) - } - - // Create response recorder + req := tc.setupRequest() w := newMockResponseWriter() // Create test handler nextCalled := false - next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { nextCalled = true w.WriteHeader(http.StatusOK) }) + // If we have a valid token, set up the session + if cookie, err := req.Cookie("access_token"); err == nil { + if claims, err := jwtService.ValidateToken(cookie.Value); err == nil { + tc.setupSession(claims.ID) + } + } + // Execute middleware middleware.Authenticate(next).ServeHTTP(w, req) @@ -115,6 +212,15 @@ func TestAuthenticateMiddleware(t *testing.T) { if tc.wantStatusCode != http.StatusOK && nextCalled { t.Error("next handler was called when it shouldn't have been") } + + // For unauthorized responses, check if cookies were invalidated + if w.statusCode == http.StatusUnauthorized { + for _, cookie := range w.Header()["Set-Cookie"] { + if strings.Contains(cookie, "Max-Age=0") { + t.Error("cookies were not properly invalidated") + } + } + } }) } } @@ -126,7 +232,7 @@ func TestRequireRole(t *testing.T) { RefreshTokenExpiry: 24 * time.Hour, } jwtService, _ := auth.NewJWTService(config) - middleware := auth.NewMiddleware(jwtService) + middleware := auth.NewMiddleware(jwtService, &mockSessionManager{}, auth.NewCookieService(true, "localhost")) testCases := []struct { name string @@ -198,7 +304,7 @@ func TestRequireWorkspaceAccess(t *testing.T) { SigningKey: "test-key", } jwtService, _ := auth.NewJWTService(config) - middleware := auth.NewMiddleware(jwtService) + middleware := auth.NewMiddleware(jwtService, &mockSessionManager{}, auth.NewCookieService(true, "localhost")) testCases := []struct { name string diff --git a/server/internal/auth/session_test.go b/server/internal/auth/session_test.go index 5457c37..8464fae 100644 --- a/server/internal/auth/session_test.go +++ b/server/internal/auth/session_test.go @@ -13,7 +13,7 @@ import ( // Mock SessionStore type mockSessionStore struct { sessions map[string]*models.Session - sessionsByToken map[string]*models.Session // Added index by refresh token + sessionsByToken map[string]*models.Session } func newMockSessionStore() *mockSessionStore { @@ -29,6 +29,17 @@ func (m *mockSessionStore) CreateSession(session *models.Session) error { return nil } +func (m *mockSessionStore) GetSessionByID(sessionID string) (*models.Session, error) { + session, exists := m.sessions[sessionID] + if !exists { + return nil, errors.New("session not found") + } + if session.ExpiresAt.Before(time.Now()) { + return nil, errors.New("session expired") + } + return session, nil +} + func (m *mockSessionStore) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) { session, exists := m.sessionsByToken[refreshToken] if !exists { @@ -111,9 +122,9 @@ func TestCreateSession(t *testing.T) { } // Verify the session was stored - storedSession, exists := mockDB.sessions[session.ID] - if !exists { - t.Error("session was not stored in database") + storedSession, err := mockDB.GetSessionByID(session.ID) + if err != nil { + t.Errorf("failed to get stored session: %v", err) } if storedSession.RefreshToken != session.RefreshToken { t.Error("stored refresh token doesn't match") @@ -138,6 +149,97 @@ func TestCreateSession(t *testing.T) { } } +func TestValidateSession(t *testing.T) { + config := auth.JWTConfig{ + SigningKey: "test-key", + AccessTokenExpiry: 15 * time.Minute, + RefreshTokenExpiry: 24 * time.Hour, + } + jwtService, _ := auth.NewJWTService(config) + mockDB := newMockSessionStore() + sessionService := auth.NewSessionService(mockDB, jwtService) + + testCases := []struct { + name string + setupSession func() string + wantErr bool + errorContains string + }{ + { + name: "valid session", + setupSession: func() string { + session := &models.Session{ + ID: "test-session-1", + UserID: 1, + ExpiresAt: time.Now().Add(24 * time.Hour), + CreatedAt: time.Now(), + } + if err := mockDB.CreateSession(session); err != nil { + t.Fatalf("failed to create session: %v", err) + } + + return session.ID + }, + wantErr: false, + }, + { + name: "expired session", + setupSession: func() string { + session := &models.Session{ + ID: "test-session-2", + UserID: 1, + ExpiresAt: time.Now().Add(-1 * time.Hour), + CreatedAt: time.Now().Add(-2 * time.Hour), + } + if err := mockDB.CreateSession(session); err != nil { + t.Fatalf("failed to create session: %v", err) + } + return session.ID + }, + wantErr: true, + errorContains: "session expired", + }, + { + name: "non-existent session", + setupSession: func() string { + return "non-existent-session-id" + }, + wantErr: true, + errorContains: "session not found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + sessionID := tc.setupSession() + session, err := sessionService.ValidateSession(sessionID) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } else if tc.errorContains != "" && !strings.Contains(err.Error(), tc.errorContains) { + t.Errorf("error = %v, want error containing %v", err, tc.errorContains) + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if session == nil { + t.Error("expected session, got nil") + return + } + + if session.ID != sessionID { + t.Errorf("session ID = %v, want %v", session.ID, sessionID) + } + }) + } +} + func TestRefreshSession(t *testing.T) { config := auth.JWTConfig{ SigningKey: "test-key", @@ -180,7 +282,7 @@ func TestRefreshSession(t *testing.T) { ID: "test-session-2", UserID: 1, RefreshToken: token, - ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired + ExpiresAt: time.Now().Add(-1 * time.Hour), CreatedAt: time.Now().Add(-2 * time.Hour), } if err := mockDB.CreateSession(session); err != nil { @@ -233,7 +335,7 @@ func TestRefreshSession(t *testing.T) { } } -func TestInvalidateSession(t *testing.T) { +func TestCleanExpiredSessions(t *testing.T) { config := auth.JWTConfig{ SigningKey: "test-key", AccessTokenExpiry: 15 * time.Minute, @@ -243,62 +345,40 @@ func TestInvalidateSession(t *testing.T) { mockDB := newMockSessionStore() sessionService := auth.NewSessionService(mockDB, jwtService) - testCases := []struct { - name string - setupSession func() string - wantErr bool - errorContains string - }{ - { - name: "valid session invalidation", - setupSession: func() string { - session := &models.Session{ - ID: "test-session-1", - UserID: 1, - RefreshToken: "valid-token", - ExpiresAt: time.Now().Add(24 * time.Hour), - CreatedAt: time.Now(), - } - if err := mockDB.CreateSession(session); err != nil { - t.Fatalf("failed to create session: %v", err) - } - return session.ID - }, - wantErr: false, - }, - { - name: "non-existent session", - setupSession: func() string { - return "non-existent-session-id" - }, - wantErr: true, - errorContains: "session not found", - }, + // Create test sessions + validSession := &models.Session{ + ID: "valid-session", + UserID: 1, + ExpiresAt: time.Now().Add(24 * time.Hour), + CreatedAt: time.Now(), + } + if err := mockDB.CreateSession(validSession); err != nil { + t.Fatalf("failed to create valid session: %v", err) } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - sessionID := tc.setupSession() - err := sessionService.InvalidateSession(sessionID) + expiredSession := &models.Session{ + ID: "expired-session", + UserID: 2, + ExpiresAt: time.Now().Add(-1 * time.Hour), + CreatedAt: time.Now().Add(-2 * time.Hour), + } + if err := mockDB.CreateSession(expiredSession); err != nil { + t.Fatalf("failed to create expired session: %v", err) + } - if tc.wantErr { - if err == nil { - t.Error("expected error, got nil") - } else if !strings.Contains(err.Error(), tc.errorContains) { - t.Errorf("error = %v, want error containing %v", err, tc.errorContains) - } - return - } + // Clean expired sessions + err := sessionService.CleanExpiredSessions() + if err != nil { + t.Errorf("unexpected error cleaning sessions: %v", err) + } - if err != nil { - t.Errorf("unexpected error: %v", err) - return - } + // Verify valid session still exists + if _, err := mockDB.GetSessionByID(validSession.ID); err != nil { + t.Error("valid session was incorrectly removed") + } - // Verify session was removed - if _, exists := mockDB.sessions[sessionID]; exists { - t.Error("session still exists after invalidation") - } - }) + // Verify expired session was removed + if _, err := mockDB.GetSessionByID(expiredSession.ID); err == nil { + t.Error("expired session was not removed") } }