mirror of
https://github.com/lordmathis/lemma.git
synced 2025-11-05 15:44:21 +00:00
Fix session validation
This commit is contained in:
@@ -3,7 +3,6 @@ package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -35,8 +34,8 @@ type JWTConfig struct {
|
||||
|
||||
// JWTManager defines the interface for managing JWT tokens
|
||||
type JWTManager interface {
|
||||
GenerateAccessToken(userID int, role string) (string, error)
|
||||
GenerateRefreshToken(userID int, role string) (string, error)
|
||||
GenerateAccessToken(userID int, role string, sessionID string) (string, error)
|
||||
GenerateRefreshToken(userID int, role string, sessionID string) (string, error)
|
||||
ValidateToken(tokenString string) (*Claims, error)
|
||||
}
|
||||
|
||||
@@ -62,17 +61,17 @@ func NewJWTService(config JWTConfig) (JWTManager, error) {
|
||||
}
|
||||
|
||||
// GenerateAccessToken creates a new access token for a user with the given userID and role
|
||||
func (s *jwtService) GenerateAccessToken(userID int, role string) (string, error) {
|
||||
return s.generateToken(userID, role, AccessToken, s.config.AccessTokenExpiry)
|
||||
func (s *jwtService) GenerateAccessToken(userID int, role, sessionID string) (string, error) {
|
||||
return s.generateToken(userID, role, sessionID, AccessToken, s.config.AccessTokenExpiry)
|
||||
}
|
||||
|
||||
// GenerateRefreshToken creates a new refresh token for a user with the given userID and role
|
||||
func (s *jwtService) GenerateRefreshToken(userID int, role string) (string, error) {
|
||||
return s.generateToken(userID, role, RefreshToken, s.config.RefreshTokenExpiry)
|
||||
func (s *jwtService) GenerateRefreshToken(userID int, role, sessionID string) (string, error) {
|
||||
return s.generateToken(userID, role, sessionID, RefreshToken, s.config.RefreshTokenExpiry)
|
||||
}
|
||||
|
||||
// generateToken is an internal helper function that creates a new JWT token
|
||||
func (s *jwtService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) {
|
||||
func (s *jwtService) generateToken(userID int, role string, sessionID string, tokenType TokenType, expiry time.Duration) (string, error) {
|
||||
now := time.Now()
|
||||
|
||||
// Add a random nonce to ensure uniqueness
|
||||
@@ -86,7 +85,7 @@ func (s *jwtService) generateToken(userID int, role string, tokenType TokenType,
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(expiry)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
ID: hex.EncodeToString(nonce),
|
||||
ID: sessionID,
|
||||
},
|
||||
UserID: userID,
|
||||
Role: role,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package auth_test provides tests for the auth package
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
@@ -98,9 +99,9 @@ func TestGenerateAndValidateToken(t *testing.T) {
|
||||
|
||||
// Generate token based on type
|
||||
if tc.tokenType == auth.AccessToken {
|
||||
token, err = service.GenerateAccessToken(tc.userID, tc.role)
|
||||
token, err = service.GenerateAccessToken(tc.userID, tc.role, "")
|
||||
} else {
|
||||
token, err = service.GenerateRefreshToken(tc.userID, tc.role)
|
||||
token, err = service.GenerateRefreshToken(tc.userID, tc.role, "")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@@ -23,18 +24,18 @@ func newMockSessionManager() *mockSessionManager {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockSessionManager) CreateSession(userID int, role string) (*models.Session, string, error) {
|
||||
func (m *mockSessionManager) CreateSession(_ int, _ string) (*models.Session, string, error) {
|
||||
return nil, "", nil // Not needed for these tests
|
||||
}
|
||||
|
||||
func (m *mockSessionManager) RefreshSession(refreshToken string) (string, error) {
|
||||
func (m *mockSessionManager) RefreshSession(_ string) (string, error) {
|
||||
return "", nil // Not needed for these tests
|
||||
}
|
||||
|
||||
func (m *mockSessionManager) ValidateSession(sessionID string) (*models.Session, error) {
|
||||
session, exists := m.sessions[sessionID]
|
||||
if !exists {
|
||||
return nil, nil
|
||||
return nil, fmt.Errorf("session not found")
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
@@ -87,16 +88,16 @@ func TestAuthenticateMiddleware(t *testing.T) {
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupRequest func() *http.Request
|
||||
setupRequest func(sessionID string) *http.Request
|
||||
setupSession func(sessionID string)
|
||||
method string
|
||||
wantStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "valid token with valid session",
|
||||
setupRequest: func() *http.Request {
|
||||
setupRequest: func(sessionID string) *http.Request {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
token, _ := jwtService.GenerateAccessToken(1, "admin")
|
||||
token, _ := jwtService.GenerateAccessToken(1, "admin", sessionID)
|
||||
cookie := cookieManager.GenerateAccessTokenCookie(token)
|
||||
req.AddCookie(cookie)
|
||||
return req
|
||||
@@ -113,31 +114,31 @@ func TestAuthenticateMiddleware(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "valid token but invalid session",
|
||||
setupRequest: func() *http.Request {
|
||||
setupRequest: func(sessionID string) *http.Request {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
token, _ := jwtService.GenerateAccessToken(1, "admin")
|
||||
token, _ := jwtService.GenerateAccessToken(1, "admin", sessionID)
|
||||
cookie := cookieManager.GenerateAccessTokenCookie(token)
|
||||
req.AddCookie(cookie)
|
||||
return req
|
||||
},
|
||||
setupSession: func(sessionID string) {}, // No session setup
|
||||
setupSession: func(_ string) {}, // No session setup
|
||||
method: "GET",
|
||||
wantStatusCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "missing auth cookie",
|
||||
setupRequest: func() *http.Request {
|
||||
setupRequest: func(_ string) *http.Request {
|
||||
return httptest.NewRequest("GET", "/test", nil)
|
||||
},
|
||||
setupSession: func(sessionID string) {},
|
||||
setupSession: func(_ string) {},
|
||||
method: "GET",
|
||||
wantStatusCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "POST request without CSRF token",
|
||||
setupRequest: func() *http.Request {
|
||||
setupRequest: func(sessionID string) *http.Request {
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
token, _ := jwtService.GenerateAccessToken(1, "admin")
|
||||
token, _ := jwtService.GenerateAccessToken(1, "admin", sessionID)
|
||||
cookie := cookieManager.GenerateAccessTokenCookie(token)
|
||||
req.AddCookie(cookie)
|
||||
return req
|
||||
@@ -154,9 +155,9 @@ func TestAuthenticateMiddleware(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "POST request with valid CSRF token",
|
||||
setupRequest: func() *http.Request {
|
||||
setupRequest: func(sessionID string) *http.Request {
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
token, _ := jwtService.GenerateAccessToken(1, "admin")
|
||||
token, _ := jwtService.GenerateAccessToken(1, "admin", sessionID)
|
||||
cookie := cookieManager.GenerateAccessTokenCookie(token)
|
||||
req.AddCookie(cookie)
|
||||
|
||||
@@ -180,12 +181,14 @@ func TestAuthenticateMiddleware(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := tc.setupRequest()
|
||||
sessionID := tc.name
|
||||
|
||||
req := tc.setupRequest(sessionID)
|
||||
w := newMockResponseWriter()
|
||||
|
||||
// Create test handler
|
||||
nextCalled := false
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
nextCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// SessionManager is an interface for managing user sessions
|
||||
type SessionManager interface {
|
||||
CreateSession(userID int, role string) (*models.Session, string, error)
|
||||
RefreshSession(refreshToken string) (string, error)
|
||||
@@ -24,6 +25,7 @@ type sessionManager struct {
|
||||
}
|
||||
|
||||
// NewSessionService creates a new session service with the given database and JWT manager
|
||||
// revive:disable:unexported-return
|
||||
func NewSessionService(db db.SessionStore, jwtManager JWTManager) *sessionManager {
|
||||
return &sessionManager{
|
||||
db: db,
|
||||
@@ -33,13 +35,17 @@ func NewSessionService(db db.SessionStore, jwtManager JWTManager) *sessionManage
|
||||
|
||||
// CreateSession creates a new user session for a user with the given userID and role
|
||||
func (s *sessionManager) CreateSession(userID int, role string) (*models.Session, string, error) {
|
||||
|
||||
// Generate a new session ID
|
||||
sessionID := uuid.New().String()
|
||||
|
||||
// Generate both access and refresh tokens
|
||||
accessToken, err := s.jwtManager.GenerateAccessToken(userID, role)
|
||||
accessToken, err := s.jwtManager.GenerateAccessToken(userID, role, sessionID)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to generate access token: %w", err)
|
||||
}
|
||||
|
||||
refreshToken, err := s.jwtManager.GenerateRefreshToken(userID, role)
|
||||
refreshToken, err := s.jwtManager.GenerateRefreshToken(userID, role, sessionID)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to generate refresh token: %w", err)
|
||||
}
|
||||
@@ -52,7 +58,7 @@ func (s *sessionManager) CreateSession(userID int, role string) (*models.Session
|
||||
|
||||
// Create a new session record
|
||||
session := &models.Session{
|
||||
ID: uuid.New().String(),
|
||||
ID: sessionID,
|
||||
UserID: userID,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresAt: claims.ExpiresAt.Time,
|
||||
@@ -87,7 +93,7 @@ func (s *sessionManager) RefreshSession(refreshToken string) (string, error) {
|
||||
}
|
||||
|
||||
// Generate a new access token
|
||||
return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role)
|
||||
return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role, session.ID)
|
||||
}
|
||||
|
||||
// ValidateSession checks if a session with the given sessionID is valid
|
||||
|
||||
@@ -259,7 +259,7 @@ func TestRefreshSession(t *testing.T) {
|
||||
{
|
||||
name: "valid refresh token",
|
||||
setupSession: func() string {
|
||||
token, _ := jwtService.GenerateRefreshToken(1, "admin")
|
||||
token, _ := jwtService.GenerateRefreshToken(1, "admin", "test-session-1")
|
||||
session := &models.Session{
|
||||
ID: "test-session-1",
|
||||
UserID: 1,
|
||||
@@ -277,7 +277,7 @@ func TestRefreshSession(t *testing.T) {
|
||||
{
|
||||
name: "expired refresh token",
|
||||
setupSession: func() string {
|
||||
token, _ := jwtService.GenerateRefreshToken(1, "admin")
|
||||
token, _ := jwtService.GenerateRefreshToken(1, "admin", "test-session-2")
|
||||
session := &models.Session{
|
||||
ID: "test-session-2",
|
||||
UserID: 1,
|
||||
|
||||
Reference in New Issue
Block a user