Fix session validation

This commit is contained in:
2024-12-08 17:13:34 +01:00
parent 69af630332
commit 2268ea48f2
5 changed files with 43 additions and 34 deletions

View File

@@ -3,7 +3,6 @@ package auth
import ( import (
"crypto/rand" "crypto/rand"
"encoding/hex"
"fmt" "fmt"
"time" "time"
@@ -35,8 +34,8 @@ type JWTConfig struct {
// JWTManager defines the interface for managing JWT tokens // JWTManager defines the interface for managing JWT tokens
type JWTManager interface { type JWTManager interface {
GenerateAccessToken(userID int, role string) (string, error) GenerateAccessToken(userID int, role string, sessionID string) (string, error)
GenerateRefreshToken(userID int, role string) (string, error) GenerateRefreshToken(userID int, role string, sessionID string) (string, error)
ValidateToken(tokenString string) (*Claims, error) ValidateToken(tokenString string) (*Claims, error)
} }
@@ -62,17 +61,17 @@ func NewJWTService(config JWTConfig) (JWTManager, error) {
} }
// GenerateAccessToken creates a new access token for a user with the given userID and role // GenerateAccessToken creates a new access token for a user with the given userID and role
func (s *jwtService) GenerateAccessToken(userID int, role string) (string, error) { func (s *jwtService) GenerateAccessToken(userID int, role, sessionID string) (string, error) {
return s.generateToken(userID, role, AccessToken, s.config.AccessTokenExpiry) return s.generateToken(userID, role, sessionID, AccessToken, s.config.AccessTokenExpiry)
} }
// GenerateRefreshToken creates a new refresh token for a user with the given userID and role // GenerateRefreshToken creates a new refresh token for a user with the given userID and role
func (s *jwtService) GenerateRefreshToken(userID int, role string) (string, error) { func (s *jwtService) GenerateRefreshToken(userID int, role, sessionID string) (string, error) {
return s.generateToken(userID, role, RefreshToken, s.config.RefreshTokenExpiry) return s.generateToken(userID, role, sessionID, RefreshToken, s.config.RefreshTokenExpiry)
} }
// generateToken is an internal helper function that creates a new JWT token // generateToken is an internal helper function that creates a new JWT token
func (s *jwtService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) { func (s *jwtService) generateToken(userID int, role string, sessionID string, tokenType TokenType, expiry time.Duration) (string, error) {
now := time.Now() now := time.Now()
// Add a random nonce to ensure uniqueness // Add a random nonce to ensure uniqueness
@@ -86,7 +85,7 @@ func (s *jwtService) generateToken(userID int, role string, tokenType TokenType,
ExpiresAt: jwt.NewNumericDate(now.Add(expiry)), ExpiresAt: jwt.NewNumericDate(now.Add(expiry)),
IssuedAt: jwt.NewNumericDate(now), IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now), NotBefore: jwt.NewNumericDate(now),
ID: hex.EncodeToString(nonce), ID: sessionID,
}, },
UserID: userID, UserID: userID,
Role: role, Role: role,

View File

@@ -1,3 +1,4 @@
// Package auth_test provides tests for the auth package
package auth_test package auth_test
import ( import (
@@ -98,9 +99,9 @@ func TestGenerateAndValidateToken(t *testing.T) {
// Generate token based on type // Generate token based on type
if tc.tokenType == auth.AccessToken { if tc.tokenType == auth.AccessToken {
token, err = service.GenerateAccessToken(tc.userID, tc.role) token, err = service.GenerateAccessToken(tc.userID, tc.role, "")
} else { } else {
token, err = service.GenerateRefreshToken(tc.userID, tc.role) token, err = service.GenerateRefreshToken(tc.userID, tc.role, "")
} }
if err != nil { if err != nil {

View File

@@ -1,6 +1,7 @@
package auth_test package auth_test
import ( import (
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
@@ -23,18 +24,18 @@ func newMockSessionManager() *mockSessionManager {
} }
} }
func (m *mockSessionManager) CreateSession(userID int, role string) (*models.Session, string, error) { func (m *mockSessionManager) CreateSession(_ int, _ string) (*models.Session, string, error) {
return nil, "", nil // Not needed for these tests return nil, "", nil // Not needed for these tests
} }
func (m *mockSessionManager) RefreshSession(refreshToken string) (string, error) { func (m *mockSessionManager) RefreshSession(_ string) (string, error) {
return "", nil // Not needed for these tests return "", nil // Not needed for these tests
} }
func (m *mockSessionManager) ValidateSession(sessionID string) (*models.Session, error) { func (m *mockSessionManager) ValidateSession(sessionID string) (*models.Session, error) {
session, exists := m.sessions[sessionID] session, exists := m.sessions[sessionID]
if !exists { if !exists {
return nil, nil return nil, fmt.Errorf("session not found")
} }
return session, nil return session, nil
} }
@@ -87,16 +88,16 @@ func TestAuthenticateMiddleware(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
setupRequest func() *http.Request setupRequest func(sessionID string) *http.Request
setupSession func(sessionID string) setupSession func(sessionID string)
method string method string
wantStatusCode int wantStatusCode int
}{ }{
{ {
name: "valid token with valid session", name: "valid token with valid session",
setupRequest: func() *http.Request { setupRequest: func(sessionID string) *http.Request {
req := httptest.NewRequest("GET", "/test", nil) req := httptest.NewRequest("GET", "/test", nil)
token, _ := jwtService.GenerateAccessToken(1, "admin") token, _ := jwtService.GenerateAccessToken(1, "admin", sessionID)
cookie := cookieManager.GenerateAccessTokenCookie(token) cookie := cookieManager.GenerateAccessTokenCookie(token)
req.AddCookie(cookie) req.AddCookie(cookie)
return req return req
@@ -113,31 +114,31 @@ func TestAuthenticateMiddleware(t *testing.T) {
}, },
{ {
name: "valid token but invalid session", name: "valid token but invalid session",
setupRequest: func() *http.Request { setupRequest: func(sessionID string) *http.Request {
req := httptest.NewRequest("GET", "/test", nil) req := httptest.NewRequest("GET", "/test", nil)
token, _ := jwtService.GenerateAccessToken(1, "admin") token, _ := jwtService.GenerateAccessToken(1, "admin", sessionID)
cookie := cookieManager.GenerateAccessTokenCookie(token) cookie := cookieManager.GenerateAccessTokenCookie(token)
req.AddCookie(cookie) req.AddCookie(cookie)
return req return req
}, },
setupSession: func(sessionID string) {}, // No session setup setupSession: func(_ string) {}, // No session setup
method: "GET", method: "GET",
wantStatusCode: http.StatusUnauthorized, wantStatusCode: http.StatusUnauthorized,
}, },
{ {
name: "missing auth cookie", name: "missing auth cookie",
setupRequest: func() *http.Request { setupRequest: func(_ string) *http.Request {
return httptest.NewRequest("GET", "/test", nil) return httptest.NewRequest("GET", "/test", nil)
}, },
setupSession: func(sessionID string) {}, setupSession: func(_ string) {},
method: "GET", method: "GET",
wantStatusCode: http.StatusUnauthorized, wantStatusCode: http.StatusUnauthorized,
}, },
{ {
name: "POST request without CSRF token", name: "POST request without CSRF token",
setupRequest: func() *http.Request { setupRequest: func(sessionID string) *http.Request {
req := httptest.NewRequest("POST", "/test", nil) req := httptest.NewRequest("POST", "/test", nil)
token, _ := jwtService.GenerateAccessToken(1, "admin") token, _ := jwtService.GenerateAccessToken(1, "admin", sessionID)
cookie := cookieManager.GenerateAccessTokenCookie(token) cookie := cookieManager.GenerateAccessTokenCookie(token)
req.AddCookie(cookie) req.AddCookie(cookie)
return req return req
@@ -154,9 +155,9 @@ func TestAuthenticateMiddleware(t *testing.T) {
}, },
{ {
name: "POST request with valid CSRF token", name: "POST request with valid CSRF token",
setupRequest: func() *http.Request { setupRequest: func(sessionID string) *http.Request {
req := httptest.NewRequest("POST", "/test", nil) req := httptest.NewRequest("POST", "/test", nil)
token, _ := jwtService.GenerateAccessToken(1, "admin") token, _ := jwtService.GenerateAccessToken(1, "admin", sessionID)
cookie := cookieManager.GenerateAccessTokenCookie(token) cookie := cookieManager.GenerateAccessTokenCookie(token)
req.AddCookie(cookie) req.AddCookie(cookie)
@@ -180,12 +181,14 @@ func TestAuthenticateMiddleware(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
req := tc.setupRequest() sessionID := tc.name
req := tc.setupRequest(sessionID)
w := newMockResponseWriter() w := newMockResponseWriter()
// Create test handler // Create test handler
nextCalled := false nextCalled := false
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
nextCalled = true nextCalled = true
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}) })

View File

@@ -9,6 +9,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
// SessionManager is an interface for managing user sessions
type SessionManager interface { type SessionManager interface {
CreateSession(userID int, role string) (*models.Session, string, error) CreateSession(userID int, role string) (*models.Session, string, error)
RefreshSession(refreshToken string) (string, error) RefreshSession(refreshToken string) (string, error)
@@ -24,6 +25,7 @@ type sessionManager struct {
} }
// NewSessionService creates a new session service with the given database and JWT manager // NewSessionService creates a new session service with the given database and JWT manager
// revive:disable:unexported-return
func NewSessionService(db db.SessionStore, jwtManager JWTManager) *sessionManager { func NewSessionService(db db.SessionStore, jwtManager JWTManager) *sessionManager {
return &sessionManager{ return &sessionManager{
db: db, db: db,
@@ -33,13 +35,17 @@ func NewSessionService(db db.SessionStore, jwtManager JWTManager) *sessionManage
// CreateSession creates a new user session for a user with the given userID and role // CreateSession creates a new user session for a user with the given userID and role
func (s *sessionManager) CreateSession(userID int, role string) (*models.Session, string, error) { func (s *sessionManager) CreateSession(userID int, role string) (*models.Session, string, error) {
// Generate a new session ID
sessionID := uuid.New().String()
// Generate both access and refresh tokens // Generate both access and refresh tokens
accessToken, err := s.jwtManager.GenerateAccessToken(userID, role) accessToken, err := s.jwtManager.GenerateAccessToken(userID, role, sessionID)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("failed to generate access token: %w", err) return nil, "", fmt.Errorf("failed to generate access token: %w", err)
} }
refreshToken, err := s.jwtManager.GenerateRefreshToken(userID, role) refreshToken, err := s.jwtManager.GenerateRefreshToken(userID, role, sessionID)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("failed to generate refresh token: %w", err) return nil, "", fmt.Errorf("failed to generate refresh token: %w", err)
} }
@@ -52,7 +58,7 @@ func (s *sessionManager) CreateSession(userID int, role string) (*models.Session
// Create a new session record // Create a new session record
session := &models.Session{ session := &models.Session{
ID: uuid.New().String(), ID: sessionID,
UserID: userID, UserID: userID,
RefreshToken: refreshToken, RefreshToken: refreshToken,
ExpiresAt: claims.ExpiresAt.Time, ExpiresAt: claims.ExpiresAt.Time,
@@ -87,7 +93,7 @@ func (s *sessionManager) RefreshSession(refreshToken string) (string, error) {
} }
// Generate a new access token // Generate a new access token
return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role) return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role, session.ID)
} }
// ValidateSession checks if a session with the given sessionID is valid // ValidateSession checks if a session with the given sessionID is valid

View File

@@ -259,7 +259,7 @@ func TestRefreshSession(t *testing.T) {
{ {
name: "valid refresh token", name: "valid refresh token",
setupSession: func() string { setupSession: func() string {
token, _ := jwtService.GenerateRefreshToken(1, "admin") token, _ := jwtService.GenerateRefreshToken(1, "admin", "test-session-1")
session := &models.Session{ session := &models.Session{
ID: "test-session-1", ID: "test-session-1",
UserID: 1, UserID: 1,
@@ -277,7 +277,7 @@ func TestRefreshSession(t *testing.T) {
{ {
name: "expired refresh token", name: "expired refresh token",
setupSession: func() string { setupSession: func() string {
token, _ := jwtService.GenerateRefreshToken(1, "admin") token, _ := jwtService.GenerateRefreshToken(1, "admin", "test-session-2")
session := &models.Session{ session := &models.Session{
ID: "test-session-2", ID: "test-session-2",
UserID: 1, UserID: 1,