mirror of
https://github.com/lordmathis/lemma.git
synced 2025-11-06 07:54:22 +00:00
Rework db package to make it testable
This commit is contained in:
@@ -1,33 +1,25 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"novamd/internal/db"
|
||||
"novamd/internal/models"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Session represents a user session in the database
|
||||
type Session struct {
|
||||
ID string // Unique session identifier
|
||||
UserID int // ID of the user this session belongs to
|
||||
RefreshToken string // The refresh token associated with this session
|
||||
ExpiresAt time.Time // When this session expires
|
||||
CreatedAt time.Time // When this session was created
|
||||
}
|
||||
|
||||
// SessionService manages user sessions in the database
|
||||
type SessionService struct {
|
||||
db *sql.DB // Database connection
|
||||
jwtManager JWTManager // JWT Manager for token operations
|
||||
db db.SessionStore // Database store for sessions
|
||||
jwtManager JWTManager // JWT Manager for token operations
|
||||
}
|
||||
|
||||
// NewSessionService creates a new session service
|
||||
// Parameters:
|
||||
// - db: database connection
|
||||
// - jwtManager: JWT service for token operations
|
||||
func NewSessionService(db *sql.DB, jwtManager JWTManager) *SessionService {
|
||||
func NewSessionService(db db.SessionStore, jwtManager JWTManager) *SessionService {
|
||||
return &SessionService{
|
||||
db: db,
|
||||
jwtManager: jwtManager,
|
||||
@@ -42,7 +34,7 @@ func NewSessionService(db *sql.DB, jwtManager JWTManager) *SessionService {
|
||||
// - session: the created session
|
||||
// - accessToken: a new access token
|
||||
// - error: any error that occurred
|
||||
func (s *SessionService) CreateSession(userID int, role string) (*Session, string, error) {
|
||||
func (s *SessionService) CreateSession(userID int, role string) (*models.Session, string, error) {
|
||||
// Generate both access and refresh tokens
|
||||
accessToken, err := s.jwtManager.GenerateAccessToken(userID, role)
|
||||
if err != nil {
|
||||
@@ -61,7 +53,7 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin
|
||||
}
|
||||
|
||||
// Create a new session record
|
||||
session := &Session{
|
||||
session := &models.Session{
|
||||
ID: uuid.New().String(),
|
||||
UserID: userID,
|
||||
RefreshToken: refreshToken,
|
||||
@@ -69,14 +61,9 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Store the session in the database
|
||||
_, err = s.db.Exec(`
|
||||
INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to store session: %w", err)
|
||||
// Store the session
|
||||
if err := s.db.CreateSession(session); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return session, accessToken, nil
|
||||
@@ -89,28 +76,18 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin
|
||||
// - string: a new access token
|
||||
// - error: any error that occurred
|
||||
func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
|
||||
// Get session from database
|
||||
_, err := s.db.GetSessionByRefreshToken(refreshToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid session: %w", err)
|
||||
}
|
||||
|
||||
// Validate the refresh token
|
||||
claims, err := s.jwtManager.ValidateToken(refreshToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid refresh token: %w", err)
|
||||
}
|
||||
|
||||
// Check if the session exists and is not expired
|
||||
var session Session
|
||||
err = s.db.QueryRow(`
|
||||
SELECT id, user_id, refresh_token, expires_at, created_at
|
||||
FROM sessions
|
||||
WHERE refresh_token = ? AND expires_at > ?`,
|
||||
refreshToken, time.Now(),
|
||||
).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return "", fmt.Errorf("session not found or expired")
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to fetch session: %w", err)
|
||||
}
|
||||
|
||||
// Generate a new access token
|
||||
return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role)
|
||||
}
|
||||
@@ -121,20 +98,12 @@ func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
|
||||
// Returns:
|
||||
// - error: any error that occurred
|
||||
func (s *SessionService) InvalidateSession(sessionID string) error {
|
||||
_, err := s.db.Exec("DELETE FROM sessions WHERE id = ?", sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to invalidate session: %w", err)
|
||||
}
|
||||
return nil
|
||||
return s.db.DeleteSession(sessionID)
|
||||
}
|
||||
|
||||
// CleanExpiredSessions removes all expired sessions from the database
|
||||
// Returns:
|
||||
// - error: any error that occurred
|
||||
func (s *SessionService) CleanExpiredSessions() error {
|
||||
_, err := s.db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to clean expired sessions: %w", err)
|
||||
}
|
||||
return nil
|
||||
return s.db.CleanExpiredSessions()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user