mirror of
https://github.com/lordmathis/lemma.git
synced 2025-11-06 16:04:23 +00:00
Fix session validation
This commit is contained in:
@@ -3,7 +3,6 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -35,8 +34,8 @@ type JWTConfig struct {
|
|||||||
|
|
||||||
// JWTManager defines the interface for managing JWT tokens
|
// JWTManager defines the interface for managing JWT tokens
|
||||||
type JWTManager interface {
|
type JWTManager interface {
|
||||||
GenerateAccessToken(userID int, role string) (string, error)
|
GenerateAccessToken(userID int, role string, sessionID string) (string, error)
|
||||||
GenerateRefreshToken(userID int, role string) (string, error)
|
GenerateRefreshToken(userID int, role string, sessionID string) (string, error)
|
||||||
ValidateToken(tokenString string) (*Claims, error)
|
ValidateToken(tokenString string) (*Claims, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,17 +61,17 @@ func NewJWTService(config JWTConfig) (JWTManager, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GenerateAccessToken creates a new access token for a user with the given userID and role
|
// GenerateAccessToken creates a new access token for a user with the given userID and role
|
||||||
func (s *jwtService) GenerateAccessToken(userID int, role string) (string, error) {
|
func (s *jwtService) GenerateAccessToken(userID int, role, sessionID string) (string, error) {
|
||||||
return s.generateToken(userID, role, AccessToken, s.config.AccessTokenExpiry)
|
return s.generateToken(userID, role, sessionID, AccessToken, s.config.AccessTokenExpiry)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateRefreshToken creates a new refresh token for a user with the given userID and role
|
// GenerateRefreshToken creates a new refresh token for a user with the given userID and role
|
||||||
func (s *jwtService) GenerateRefreshToken(userID int, role string) (string, error) {
|
func (s *jwtService) GenerateRefreshToken(userID int, role, sessionID string) (string, error) {
|
||||||
return s.generateToken(userID, role, RefreshToken, s.config.RefreshTokenExpiry)
|
return s.generateToken(userID, role, sessionID, RefreshToken, s.config.RefreshTokenExpiry)
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateToken is an internal helper function that creates a new JWT token
|
// generateToken is an internal helper function that creates a new JWT token
|
||||||
func (s *jwtService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) {
|
func (s *jwtService) generateToken(userID int, role string, sessionID string, tokenType TokenType, expiry time.Duration) (string, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
// Add a random nonce to ensure uniqueness
|
// Add a random nonce to ensure uniqueness
|
||||||
@@ -86,7 +85,7 @@ func (s *jwtService) generateToken(userID int, role string, tokenType TokenType,
|
|||||||
ExpiresAt: jwt.NewNumericDate(now.Add(expiry)),
|
ExpiresAt: jwt.NewNumericDate(now.Add(expiry)),
|
||||||
IssuedAt: jwt.NewNumericDate(now),
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
NotBefore: jwt.NewNumericDate(now),
|
NotBefore: jwt.NewNumericDate(now),
|
||||||
ID: hex.EncodeToString(nonce),
|
ID: sessionID,
|
||||||
},
|
},
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
Role: role,
|
Role: role,
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
// Package auth_test provides tests for the auth package
|
||||||
package auth_test
|
package auth_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -98,9 +99,9 @@ func TestGenerateAndValidateToken(t *testing.T) {
|
|||||||
|
|
||||||
// Generate token based on type
|
// Generate token based on type
|
||||||
if tc.tokenType == auth.AccessToken {
|
if tc.tokenType == auth.AccessToken {
|
||||||
token, err = service.GenerateAccessToken(tc.userID, tc.role)
|
token, err = service.GenerateAccessToken(tc.userID, tc.role, "")
|
||||||
} else {
|
} else {
|
||||||
token, err = service.GenerateRefreshToken(tc.userID, tc.role)
|
token, err = service.GenerateRefreshToken(tc.userID, tc.role, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package auth_test
|
package auth_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -23,18 +24,18 @@ func newMockSessionManager() *mockSessionManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockSessionManager) CreateSession(userID int, role string) (*models.Session, string, error) {
|
func (m *mockSessionManager) CreateSession(_ int, _ string) (*models.Session, string, error) {
|
||||||
return nil, "", nil // Not needed for these tests
|
return nil, "", nil // Not needed for these tests
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockSessionManager) RefreshSession(refreshToken string) (string, error) {
|
func (m *mockSessionManager) RefreshSession(_ string) (string, error) {
|
||||||
return "", nil // Not needed for these tests
|
return "", nil // Not needed for these tests
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockSessionManager) ValidateSession(sessionID string) (*models.Session, error) {
|
func (m *mockSessionManager) ValidateSession(sessionID string) (*models.Session, error) {
|
||||||
session, exists := m.sessions[sessionID]
|
session, exists := m.sessions[sessionID]
|
||||||
if !exists {
|
if !exists {
|
||||||
return nil, nil
|
return nil, fmt.Errorf("session not found")
|
||||||
}
|
}
|
||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
@@ -87,16 +88,16 @@ func TestAuthenticateMiddleware(t *testing.T) {
|
|||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
setupRequest func() *http.Request
|
setupRequest func(sessionID string) *http.Request
|
||||||
setupSession func(sessionID string)
|
setupSession func(sessionID string)
|
||||||
method string
|
method string
|
||||||
wantStatusCode int
|
wantStatusCode int
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "valid token with valid session",
|
name: "valid token with valid session",
|
||||||
setupRequest: func() *http.Request {
|
setupRequest: func(sessionID string) *http.Request {
|
||||||
req := httptest.NewRequest("GET", "/test", nil)
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
token, _ := jwtService.GenerateAccessToken(1, "admin")
|
token, _ := jwtService.GenerateAccessToken(1, "admin", sessionID)
|
||||||
cookie := cookieManager.GenerateAccessTokenCookie(token)
|
cookie := cookieManager.GenerateAccessTokenCookie(token)
|
||||||
req.AddCookie(cookie)
|
req.AddCookie(cookie)
|
||||||
return req
|
return req
|
||||||
@@ -113,31 +114,31 @@ func TestAuthenticateMiddleware(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "valid token but invalid session",
|
name: "valid token but invalid session",
|
||||||
setupRequest: func() *http.Request {
|
setupRequest: func(sessionID string) *http.Request {
|
||||||
req := httptest.NewRequest("GET", "/test", nil)
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
token, _ := jwtService.GenerateAccessToken(1, "admin")
|
token, _ := jwtService.GenerateAccessToken(1, "admin", sessionID)
|
||||||
cookie := cookieManager.GenerateAccessTokenCookie(token)
|
cookie := cookieManager.GenerateAccessTokenCookie(token)
|
||||||
req.AddCookie(cookie)
|
req.AddCookie(cookie)
|
||||||
return req
|
return req
|
||||||
},
|
},
|
||||||
setupSession: func(sessionID string) {}, // No session setup
|
setupSession: func(_ string) {}, // No session setup
|
||||||
method: "GET",
|
method: "GET",
|
||||||
wantStatusCode: http.StatusUnauthorized,
|
wantStatusCode: http.StatusUnauthorized,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "missing auth cookie",
|
name: "missing auth cookie",
|
||||||
setupRequest: func() *http.Request {
|
setupRequest: func(_ string) *http.Request {
|
||||||
return httptest.NewRequest("GET", "/test", nil)
|
return httptest.NewRequest("GET", "/test", nil)
|
||||||
},
|
},
|
||||||
setupSession: func(sessionID string) {},
|
setupSession: func(_ string) {},
|
||||||
method: "GET",
|
method: "GET",
|
||||||
wantStatusCode: http.StatusUnauthorized,
|
wantStatusCode: http.StatusUnauthorized,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "POST request without CSRF token",
|
name: "POST request without CSRF token",
|
||||||
setupRequest: func() *http.Request {
|
setupRequest: func(sessionID string) *http.Request {
|
||||||
req := httptest.NewRequest("POST", "/test", nil)
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
token, _ := jwtService.GenerateAccessToken(1, "admin")
|
token, _ := jwtService.GenerateAccessToken(1, "admin", sessionID)
|
||||||
cookie := cookieManager.GenerateAccessTokenCookie(token)
|
cookie := cookieManager.GenerateAccessTokenCookie(token)
|
||||||
req.AddCookie(cookie)
|
req.AddCookie(cookie)
|
||||||
return req
|
return req
|
||||||
@@ -154,9 +155,9 @@ func TestAuthenticateMiddleware(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "POST request with valid CSRF token",
|
name: "POST request with valid CSRF token",
|
||||||
setupRequest: func() *http.Request {
|
setupRequest: func(sessionID string) *http.Request {
|
||||||
req := httptest.NewRequest("POST", "/test", nil)
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
token, _ := jwtService.GenerateAccessToken(1, "admin")
|
token, _ := jwtService.GenerateAccessToken(1, "admin", sessionID)
|
||||||
cookie := cookieManager.GenerateAccessTokenCookie(token)
|
cookie := cookieManager.GenerateAccessTokenCookie(token)
|
||||||
req.AddCookie(cookie)
|
req.AddCookie(cookie)
|
||||||
|
|
||||||
@@ -180,12 +181,14 @@ func TestAuthenticateMiddleware(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
req := tc.setupRequest()
|
sessionID := tc.name
|
||||||
|
|
||||||
|
req := tc.setupRequest(sessionID)
|
||||||
w := newMockResponseWriter()
|
w := newMockResponseWriter()
|
||||||
|
|
||||||
// Create test handler
|
// Create test handler
|
||||||
nextCalled := false
|
nextCalled := false
|
||||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
nextCalled = true
|
nextCalled = true
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// SessionManager is an interface for managing user sessions
|
||||||
type SessionManager interface {
|
type SessionManager interface {
|
||||||
CreateSession(userID int, role string) (*models.Session, string, error)
|
CreateSession(userID int, role string) (*models.Session, string, error)
|
||||||
RefreshSession(refreshToken string) (string, error)
|
RefreshSession(refreshToken string) (string, error)
|
||||||
@@ -24,6 +25,7 @@ type sessionManager struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewSessionService creates a new session service with the given database and JWT manager
|
// NewSessionService creates a new session service with the given database and JWT manager
|
||||||
|
// revive:disable:unexported-return
|
||||||
func NewSessionService(db db.SessionStore, jwtManager JWTManager) *sessionManager {
|
func NewSessionService(db db.SessionStore, jwtManager JWTManager) *sessionManager {
|
||||||
return &sessionManager{
|
return &sessionManager{
|
||||||
db: db,
|
db: db,
|
||||||
@@ -33,13 +35,17 @@ func NewSessionService(db db.SessionStore, jwtManager JWTManager) *sessionManage
|
|||||||
|
|
||||||
// CreateSession creates a new user session for a user with the given userID and role
|
// CreateSession creates a new user session for a user with the given userID and role
|
||||||
func (s *sessionManager) CreateSession(userID int, role string) (*models.Session, string, error) {
|
func (s *sessionManager) CreateSession(userID int, role string) (*models.Session, string, error) {
|
||||||
|
|
||||||
|
// Generate a new session ID
|
||||||
|
sessionID := uuid.New().String()
|
||||||
|
|
||||||
// Generate both access and refresh tokens
|
// Generate both access and refresh tokens
|
||||||
accessToken, err := s.jwtManager.GenerateAccessToken(userID, role)
|
accessToken, err := s.jwtManager.GenerateAccessToken(userID, role, sessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("failed to generate access token: %w", err)
|
return nil, "", fmt.Errorf("failed to generate access token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
refreshToken, err := s.jwtManager.GenerateRefreshToken(userID, role)
|
refreshToken, err := s.jwtManager.GenerateRefreshToken(userID, role, sessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("failed to generate refresh token: %w", err)
|
return nil, "", fmt.Errorf("failed to generate refresh token: %w", err)
|
||||||
}
|
}
|
||||||
@@ -52,7 +58,7 @@ func (s *sessionManager) CreateSession(userID int, role string) (*models.Session
|
|||||||
|
|
||||||
// Create a new session record
|
// Create a new session record
|
||||||
session := &models.Session{
|
session := &models.Session{
|
||||||
ID: uuid.New().String(),
|
ID: sessionID,
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: refreshToken,
|
||||||
ExpiresAt: claims.ExpiresAt.Time,
|
ExpiresAt: claims.ExpiresAt.Time,
|
||||||
@@ -87,7 +93,7 @@ func (s *sessionManager) RefreshSession(refreshToken string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Generate a new access token
|
// Generate a new access token
|
||||||
return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role)
|
return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role, session.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateSession checks if a session with the given sessionID is valid
|
// ValidateSession checks if a session with the given sessionID is valid
|
||||||
|
|||||||
@@ -259,7 +259,7 @@ func TestRefreshSession(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "valid refresh token",
|
name: "valid refresh token",
|
||||||
setupSession: func() string {
|
setupSession: func() string {
|
||||||
token, _ := jwtService.GenerateRefreshToken(1, "admin")
|
token, _ := jwtService.GenerateRefreshToken(1, "admin", "test-session-1")
|
||||||
session := &models.Session{
|
session := &models.Session{
|
||||||
ID: "test-session-1",
|
ID: "test-session-1",
|
||||||
UserID: 1,
|
UserID: 1,
|
||||||
@@ -277,7 +277,7 @@ func TestRefreshSession(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "expired refresh token",
|
name: "expired refresh token",
|
||||||
setupSession: func() string {
|
setupSession: func() string {
|
||||||
token, _ := jwtService.GenerateRefreshToken(1, "admin")
|
token, _ := jwtService.GenerateRefreshToken(1, "admin", "test-session-2")
|
||||||
session := &models.Session{
|
session := &models.Session{
|
||||||
ID: "test-session-2",
|
ID: "test-session-2",
|
||||||
UserID: 1,
|
UserID: 1,
|
||||||
|
|||||||
Reference in New Issue
Block a user