From 2268ea48f28044b56187d67138506f3f653041ec Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sun, 8 Dec 2024 17:13:34 +0100 Subject: [PATCH] Fix session validation --- server/internal/auth/jwt.go | 17 ++++++------ server/internal/auth/jwt_test.go | 5 ++-- server/internal/auth/middleware_test.go | 37 +++++++++++++------------ server/internal/auth/session.go | 14 +++++++--- server/internal/auth/session_test.go | 4 +-- 5 files changed, 43 insertions(+), 34 deletions(-) diff --git a/server/internal/auth/jwt.go b/server/internal/auth/jwt.go index 66b6c24..9df0de0 100644 --- a/server/internal/auth/jwt.go +++ b/server/internal/auth/jwt.go @@ -3,7 +3,6 @@ package auth import ( "crypto/rand" - "encoding/hex" "fmt" "time" @@ -35,8 +34,8 @@ type JWTConfig struct { // JWTManager defines the interface for managing JWT tokens type JWTManager interface { - GenerateAccessToken(userID int, role string) (string, error) - GenerateRefreshToken(userID int, role string) (string, error) + GenerateAccessToken(userID int, role string, sessionID string) (string, error) + GenerateRefreshToken(userID int, role string, sessionID string) (string, error) ValidateToken(tokenString string) (*Claims, error) } @@ -62,17 +61,17 @@ func NewJWTService(config JWTConfig) (JWTManager, error) { } // GenerateAccessToken creates a new access token for a user with the given userID and role -func (s *jwtService) GenerateAccessToken(userID int, role string) (string, error) { - return s.generateToken(userID, role, AccessToken, s.config.AccessTokenExpiry) +func (s *jwtService) GenerateAccessToken(userID int, role, sessionID string) (string, error) { + return s.generateToken(userID, role, sessionID, AccessToken, s.config.AccessTokenExpiry) } // GenerateRefreshToken creates a new refresh token for a user with the given userID and role -func (s *jwtService) GenerateRefreshToken(userID int, role string) (string, error) { - return s.generateToken(userID, role, RefreshToken, s.config.RefreshTokenExpiry) +func (s *jwtService) GenerateRefreshToken(userID int, role, sessionID string) (string, error) { + return s.generateToken(userID, role, sessionID, RefreshToken, s.config.RefreshTokenExpiry) } // generateToken is an internal helper function that creates a new JWT token -func (s *jwtService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) { +func (s *jwtService) generateToken(userID int, role string, sessionID string, tokenType TokenType, expiry time.Duration) (string, error) { now := time.Now() // Add a random nonce to ensure uniqueness @@ -86,7 +85,7 @@ func (s *jwtService) generateToken(userID int, role string, tokenType TokenType, ExpiresAt: jwt.NewNumericDate(now.Add(expiry)), IssuedAt: jwt.NewNumericDate(now), NotBefore: jwt.NewNumericDate(now), - ID: hex.EncodeToString(nonce), + ID: sessionID, }, UserID: userID, Role: role, diff --git a/server/internal/auth/jwt_test.go b/server/internal/auth/jwt_test.go index 61ca928..14b1a50 100644 --- a/server/internal/auth/jwt_test.go +++ b/server/internal/auth/jwt_test.go @@ -1,3 +1,4 @@ +// Package auth_test provides tests for the auth package package auth_test import ( @@ -98,9 +99,9 @@ func TestGenerateAndValidateToken(t *testing.T) { // Generate token based on type if tc.tokenType == auth.AccessToken { - token, err = service.GenerateAccessToken(tc.userID, tc.role) + token, err = service.GenerateAccessToken(tc.userID, tc.role, "") } else { - token, err = service.GenerateRefreshToken(tc.userID, tc.role) + token, err = service.GenerateRefreshToken(tc.userID, tc.role, "") } if err != nil { diff --git a/server/internal/auth/middleware_test.go b/server/internal/auth/middleware_test.go index 983c3d3..3bac129 100644 --- a/server/internal/auth/middleware_test.go +++ b/server/internal/auth/middleware_test.go @@ -1,6 +1,7 @@ package auth_test import ( + "fmt" "net/http" "net/http/httptest" "strings" @@ -23,18 +24,18 @@ func newMockSessionManager() *mockSessionManager { } } -func (m *mockSessionManager) CreateSession(userID int, role string) (*models.Session, string, error) { +func (m *mockSessionManager) CreateSession(_ int, _ string) (*models.Session, string, error) { return nil, "", nil // Not needed for these tests } -func (m *mockSessionManager) RefreshSession(refreshToken string) (string, error) { +func (m *mockSessionManager) RefreshSession(_ string) (string, error) { return "", nil // Not needed for these tests } func (m *mockSessionManager) ValidateSession(sessionID string) (*models.Session, error) { session, exists := m.sessions[sessionID] if !exists { - return nil, nil + return nil, fmt.Errorf("session not found") } return session, nil } @@ -87,16 +88,16 @@ func TestAuthenticateMiddleware(t *testing.T) { testCases := []struct { name string - setupRequest func() *http.Request + setupRequest func(sessionID string) *http.Request setupSession func(sessionID string) method string wantStatusCode int }{ { name: "valid token with valid session", - setupRequest: func() *http.Request { + setupRequest: func(sessionID string) *http.Request { req := httptest.NewRequest("GET", "/test", nil) - token, _ := jwtService.GenerateAccessToken(1, "admin") + token, _ := jwtService.GenerateAccessToken(1, "admin", sessionID) cookie := cookieManager.GenerateAccessTokenCookie(token) req.AddCookie(cookie) return req @@ -113,31 +114,31 @@ func TestAuthenticateMiddleware(t *testing.T) { }, { name: "valid token but invalid session", - setupRequest: func() *http.Request { + setupRequest: func(sessionID string) *http.Request { req := httptest.NewRequest("GET", "/test", nil) - token, _ := jwtService.GenerateAccessToken(1, "admin") + token, _ := jwtService.GenerateAccessToken(1, "admin", sessionID) cookie := cookieManager.GenerateAccessTokenCookie(token) req.AddCookie(cookie) return req }, - setupSession: func(sessionID string) {}, // No session setup + setupSession: func(_ string) {}, // No session setup method: "GET", wantStatusCode: http.StatusUnauthorized, }, { name: "missing auth cookie", - setupRequest: func() *http.Request { + setupRequest: func(_ string) *http.Request { return httptest.NewRequest("GET", "/test", nil) }, - setupSession: func(sessionID string) {}, + setupSession: func(_ string) {}, method: "GET", wantStatusCode: http.StatusUnauthorized, }, { name: "POST request without CSRF token", - setupRequest: func() *http.Request { + setupRequest: func(sessionID string) *http.Request { req := httptest.NewRequest("POST", "/test", nil) - token, _ := jwtService.GenerateAccessToken(1, "admin") + token, _ := jwtService.GenerateAccessToken(1, "admin", sessionID) cookie := cookieManager.GenerateAccessTokenCookie(token) req.AddCookie(cookie) return req @@ -154,9 +155,9 @@ func TestAuthenticateMiddleware(t *testing.T) { }, { name: "POST request with valid CSRF token", - setupRequest: func() *http.Request { + setupRequest: func(sessionID string) *http.Request { req := httptest.NewRequest("POST", "/test", nil) - token, _ := jwtService.GenerateAccessToken(1, "admin") + token, _ := jwtService.GenerateAccessToken(1, "admin", sessionID) cookie := cookieManager.GenerateAccessTokenCookie(token) req.AddCookie(cookie) @@ -180,12 +181,14 @@ func TestAuthenticateMiddleware(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - req := tc.setupRequest() + sessionID := tc.name + + req := tc.setupRequest(sessionID) w := newMockResponseWriter() // Create test handler nextCalled := false - next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { nextCalled = true w.WriteHeader(http.StatusOK) }) diff --git a/server/internal/auth/session.go b/server/internal/auth/session.go index 21d0090..bb3df67 100644 --- a/server/internal/auth/session.go +++ b/server/internal/auth/session.go @@ -9,6 +9,7 @@ import ( "github.com/google/uuid" ) +// SessionManager is an interface for managing user sessions type SessionManager interface { CreateSession(userID int, role string) (*models.Session, string, error) RefreshSession(refreshToken string) (string, error) @@ -24,6 +25,7 @@ type sessionManager struct { } // NewSessionService creates a new session service with the given database and JWT manager +// revive:disable:unexported-return func NewSessionService(db db.SessionStore, jwtManager JWTManager) *sessionManager { return &sessionManager{ db: db, @@ -33,13 +35,17 @@ func NewSessionService(db db.SessionStore, jwtManager JWTManager) *sessionManage // CreateSession creates a new user session for a user with the given userID and role func (s *sessionManager) CreateSession(userID int, role string) (*models.Session, string, error) { + + // Generate a new session ID + sessionID := uuid.New().String() + // Generate both access and refresh tokens - accessToken, err := s.jwtManager.GenerateAccessToken(userID, role) + accessToken, err := s.jwtManager.GenerateAccessToken(userID, role, sessionID) if err != nil { return nil, "", fmt.Errorf("failed to generate access token: %w", err) } - refreshToken, err := s.jwtManager.GenerateRefreshToken(userID, role) + refreshToken, err := s.jwtManager.GenerateRefreshToken(userID, role, sessionID) if err != nil { return nil, "", fmt.Errorf("failed to generate refresh token: %w", err) } @@ -52,7 +58,7 @@ func (s *sessionManager) CreateSession(userID int, role string) (*models.Session // Create a new session record session := &models.Session{ - ID: uuid.New().String(), + ID: sessionID, UserID: userID, RefreshToken: refreshToken, ExpiresAt: claims.ExpiresAt.Time, @@ -87,7 +93,7 @@ func (s *sessionManager) RefreshSession(refreshToken string) (string, error) { } // Generate a new access token - return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role) + return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role, session.ID) } // ValidateSession checks if a session with the given sessionID is valid diff --git a/server/internal/auth/session_test.go b/server/internal/auth/session_test.go index 8464fae..91410fb 100644 --- a/server/internal/auth/session_test.go +++ b/server/internal/auth/session_test.go @@ -259,7 +259,7 @@ func TestRefreshSession(t *testing.T) { { name: "valid refresh token", setupSession: func() string { - token, _ := jwtService.GenerateRefreshToken(1, "admin") + token, _ := jwtService.GenerateRefreshToken(1, "admin", "test-session-1") session := &models.Session{ ID: "test-session-1", UserID: 1, @@ -277,7 +277,7 @@ func TestRefreshSession(t *testing.T) { { name: "expired refresh token", setupSession: func() string { - token, _ := jwtService.GenerateRefreshToken(1, "admin") + token, _ := jwtService.GenerateRefreshToken(1, "admin", "test-session-2") session := &models.Session{ ID: "test-session-2", UserID: 1,