Fix session validation

This commit is contained in:
2024-12-08 17:13:34 +01:00
parent 69af630332
commit 2268ea48f2
5 changed files with 43 additions and 34 deletions

View File

@@ -3,7 +3,6 @@ package auth
import (
"crypto/rand"
"encoding/hex"
"fmt"
"time"
@@ -35,8 +34,8 @@ type JWTConfig struct {
// JWTManager defines the interface for managing JWT tokens
type JWTManager interface {
GenerateAccessToken(userID int, role string) (string, error)
GenerateRefreshToken(userID int, role string) (string, error)
GenerateAccessToken(userID int, role string, sessionID string) (string, error)
GenerateRefreshToken(userID int, role string, sessionID string) (string, error)
ValidateToken(tokenString string) (*Claims, error)
}
@@ -62,17 +61,17 @@ func NewJWTService(config JWTConfig) (JWTManager, error) {
}
// GenerateAccessToken creates a new access token for a user with the given userID and role
func (s *jwtService) GenerateAccessToken(userID int, role string) (string, error) {
return s.generateToken(userID, role, AccessToken, s.config.AccessTokenExpiry)
func (s *jwtService) GenerateAccessToken(userID int, role, sessionID string) (string, error) {
return s.generateToken(userID, role, sessionID, AccessToken, s.config.AccessTokenExpiry)
}
// GenerateRefreshToken creates a new refresh token for a user with the given userID and role
func (s *jwtService) GenerateRefreshToken(userID int, role string) (string, error) {
return s.generateToken(userID, role, RefreshToken, s.config.RefreshTokenExpiry)
func (s *jwtService) GenerateRefreshToken(userID int, role, sessionID string) (string, error) {
return s.generateToken(userID, role, sessionID, RefreshToken, s.config.RefreshTokenExpiry)
}
// generateToken is an internal helper function that creates a new JWT token
func (s *jwtService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) {
func (s *jwtService) generateToken(userID int, role string, sessionID string, tokenType TokenType, expiry time.Duration) (string, error) {
now := time.Now()
// Add a random nonce to ensure uniqueness
@@ -86,7 +85,7 @@ func (s *jwtService) generateToken(userID int, role string, tokenType TokenType,
ExpiresAt: jwt.NewNumericDate(now.Add(expiry)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
ID: hex.EncodeToString(nonce),
ID: sessionID,
},
UserID: userID,
Role: role,

View File

@@ -1,3 +1,4 @@
// Package auth_test provides tests for the auth package
package auth_test
import (
@@ -98,9 +99,9 @@ func TestGenerateAndValidateToken(t *testing.T) {
// Generate token based on type
if tc.tokenType == auth.AccessToken {
token, err = service.GenerateAccessToken(tc.userID, tc.role)
token, err = service.GenerateAccessToken(tc.userID, tc.role, "")
} else {
token, err = service.GenerateRefreshToken(tc.userID, tc.role)
token, err = service.GenerateRefreshToken(tc.userID, tc.role, "")
}
if err != nil {

View File

@@ -1,6 +1,7 @@
package auth_test
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
@@ -23,18 +24,18 @@ func newMockSessionManager() *mockSessionManager {
}
}
func (m *mockSessionManager) CreateSession(userID int, role string) (*models.Session, string, error) {
func (m *mockSessionManager) CreateSession(_ int, _ string) (*models.Session, string, error) {
return nil, "", nil // Not needed for these tests
}
func (m *mockSessionManager) RefreshSession(refreshToken string) (string, error) {
func (m *mockSessionManager) RefreshSession(_ string) (string, error) {
return "", nil // Not needed for these tests
}
func (m *mockSessionManager) ValidateSession(sessionID string) (*models.Session, error) {
session, exists := m.sessions[sessionID]
if !exists {
return nil, nil
return nil, fmt.Errorf("session not found")
}
return session, nil
}
@@ -87,16 +88,16 @@ func TestAuthenticateMiddleware(t *testing.T) {
testCases := []struct {
name string
setupRequest func() *http.Request
setupRequest func(sessionID string) *http.Request
setupSession func(sessionID string)
method string
wantStatusCode int
}{
{
name: "valid token with valid session",
setupRequest: func() *http.Request {
setupRequest: func(sessionID string) *http.Request {
req := httptest.NewRequest("GET", "/test", nil)
token, _ := jwtService.GenerateAccessToken(1, "admin")
token, _ := jwtService.GenerateAccessToken(1, "admin", sessionID)
cookie := cookieManager.GenerateAccessTokenCookie(token)
req.AddCookie(cookie)
return req
@@ -113,31 +114,31 @@ func TestAuthenticateMiddleware(t *testing.T) {
},
{
name: "valid token but invalid session",
setupRequest: func() *http.Request {
setupRequest: func(sessionID string) *http.Request {
req := httptest.NewRequest("GET", "/test", nil)
token, _ := jwtService.GenerateAccessToken(1, "admin")
token, _ := jwtService.GenerateAccessToken(1, "admin", sessionID)
cookie := cookieManager.GenerateAccessTokenCookie(token)
req.AddCookie(cookie)
return req
},
setupSession: func(sessionID string) {}, // No session setup
setupSession: func(_ string) {}, // No session setup
method: "GET",
wantStatusCode: http.StatusUnauthorized,
},
{
name: "missing auth cookie",
setupRequest: func() *http.Request {
setupRequest: func(_ string) *http.Request {
return httptest.NewRequest("GET", "/test", nil)
},
setupSession: func(sessionID string) {},
setupSession: func(_ string) {},
method: "GET",
wantStatusCode: http.StatusUnauthorized,
},
{
name: "POST request without CSRF token",
setupRequest: func() *http.Request {
setupRequest: func(sessionID string) *http.Request {
req := httptest.NewRequest("POST", "/test", nil)
token, _ := jwtService.GenerateAccessToken(1, "admin")
token, _ := jwtService.GenerateAccessToken(1, "admin", sessionID)
cookie := cookieManager.GenerateAccessTokenCookie(token)
req.AddCookie(cookie)
return req
@@ -154,9 +155,9 @@ func TestAuthenticateMiddleware(t *testing.T) {
},
{
name: "POST request with valid CSRF token",
setupRequest: func() *http.Request {
setupRequest: func(sessionID string) *http.Request {
req := httptest.NewRequest("POST", "/test", nil)
token, _ := jwtService.GenerateAccessToken(1, "admin")
token, _ := jwtService.GenerateAccessToken(1, "admin", sessionID)
cookie := cookieManager.GenerateAccessTokenCookie(token)
req.AddCookie(cookie)
@@ -180,12 +181,14 @@ func TestAuthenticateMiddleware(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := tc.setupRequest()
sessionID := tc.name
req := tc.setupRequest(sessionID)
w := newMockResponseWriter()
// Create test handler
nextCalled := false
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
})

View File

@@ -9,6 +9,7 @@ import (
"github.com/google/uuid"
)
// SessionManager is an interface for managing user sessions
type SessionManager interface {
CreateSession(userID int, role string) (*models.Session, string, error)
RefreshSession(refreshToken string) (string, error)
@@ -24,6 +25,7 @@ type sessionManager struct {
}
// NewSessionService creates a new session service with the given database and JWT manager
// revive:disable:unexported-return
func NewSessionService(db db.SessionStore, jwtManager JWTManager) *sessionManager {
return &sessionManager{
db: db,
@@ -33,13 +35,17 @@ func NewSessionService(db db.SessionStore, jwtManager JWTManager) *sessionManage
// CreateSession creates a new user session for a user with the given userID and role
func (s *sessionManager) CreateSession(userID int, role string) (*models.Session, string, error) {
// Generate a new session ID
sessionID := uuid.New().String()
// Generate both access and refresh tokens
accessToken, err := s.jwtManager.GenerateAccessToken(userID, role)
accessToken, err := s.jwtManager.GenerateAccessToken(userID, role, sessionID)
if err != nil {
return nil, "", fmt.Errorf("failed to generate access token: %w", err)
}
refreshToken, err := s.jwtManager.GenerateRefreshToken(userID, role)
refreshToken, err := s.jwtManager.GenerateRefreshToken(userID, role, sessionID)
if err != nil {
return nil, "", fmt.Errorf("failed to generate refresh token: %w", err)
}
@@ -52,7 +58,7 @@ func (s *sessionManager) CreateSession(userID int, role string) (*models.Session
// Create a new session record
session := &models.Session{
ID: uuid.New().String(),
ID: sessionID,
UserID: userID,
RefreshToken: refreshToken,
ExpiresAt: claims.ExpiresAt.Time,
@@ -87,7 +93,7 @@ func (s *sessionManager) RefreshSession(refreshToken string) (string, error) {
}
// Generate a new access token
return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role)
return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role, session.ID)
}
// ValidateSession checks if a session with the given sessionID is valid

View File

@@ -259,7 +259,7 @@ func TestRefreshSession(t *testing.T) {
{
name: "valid refresh token",
setupSession: func() string {
token, _ := jwtService.GenerateRefreshToken(1, "admin")
token, _ := jwtService.GenerateRefreshToken(1, "admin", "test-session-1")
session := &models.Session{
ID: "test-session-1",
UserID: 1,
@@ -277,7 +277,7 @@ func TestRefreshSession(t *testing.T) {
{
name: "expired refresh token",
setupSession: func() string {
token, _ := jwtService.GenerateRefreshToken(1, "admin")
token, _ := jwtService.GenerateRefreshToken(1, "admin", "test-session-2")
session := &models.Session{
ID: "test-session-2",
UserID: 1,