Rework db package to make it testable

This commit is contained in:
2024-11-21 22:36:12 +01:00
parent 2faefb6db5
commit 807e96a76c
16 changed files with 274 additions and 161 deletions

View File

@@ -1,33 +1,25 @@
package auth
import (
"database/sql"
"fmt"
"novamd/internal/db"
"novamd/internal/models"
"time"
"github.com/google/uuid"
)
// Session represents a user session in the database
type Session struct {
ID string // Unique session identifier
UserID int // ID of the user this session belongs to
RefreshToken string // The refresh token associated with this session
ExpiresAt time.Time // When this session expires
CreatedAt time.Time // When this session was created
}
// SessionService manages user sessions in the database
type SessionService struct {
db *sql.DB // Database connection
jwtManager JWTManager // JWT Manager for token operations
db db.SessionStore // Database store for sessions
jwtManager JWTManager // JWT Manager for token operations
}
// NewSessionService creates a new session service
// Parameters:
// - db: database connection
// - jwtManager: JWT service for token operations
func NewSessionService(db *sql.DB, jwtManager JWTManager) *SessionService {
func NewSessionService(db db.SessionStore, jwtManager JWTManager) *SessionService {
return &SessionService{
db: db,
jwtManager: jwtManager,
@@ -42,7 +34,7 @@ func NewSessionService(db *sql.DB, jwtManager JWTManager) *SessionService {
// - session: the created session
// - accessToken: a new access token
// - error: any error that occurred
func (s *SessionService) CreateSession(userID int, role string) (*Session, string, error) {
func (s *SessionService) CreateSession(userID int, role string) (*models.Session, string, error) {
// Generate both access and refresh tokens
accessToken, err := s.jwtManager.GenerateAccessToken(userID, role)
if err != nil {
@@ -61,7 +53,7 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin
}
// Create a new session record
session := &Session{
session := &models.Session{
ID: uuid.New().String(),
UserID: userID,
RefreshToken: refreshToken,
@@ -69,14 +61,9 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin
CreatedAt: time.Now(),
}
// Store the session in the database
_, err = s.db.Exec(`
INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at)
VALUES (?, ?, ?, ?, ?)`,
session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt,
)
if err != nil {
return nil, "", fmt.Errorf("failed to store session: %w", err)
// Store the session
if err := s.db.CreateSession(session); err != nil {
return nil, "", err
}
return session, accessToken, nil
@@ -89,28 +76,18 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin
// - string: a new access token
// - error: any error that occurred
func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
// Get session from database
_, err := s.db.GetSessionByRefreshToken(refreshToken)
if err != nil {
return "", fmt.Errorf("invalid session: %w", err)
}
// Validate the refresh token
claims, err := s.jwtManager.ValidateToken(refreshToken)
if err != nil {
return "", fmt.Errorf("invalid refresh token: %w", err)
}
// Check if the session exists and is not expired
var session Session
err = s.db.QueryRow(`
SELECT id, user_id, refresh_token, expires_at, created_at
FROM sessions
WHERE refresh_token = ? AND expires_at > ?`,
refreshToken, time.Now(),
).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt)
if err == sql.ErrNoRows {
return "", fmt.Errorf("session not found or expired")
}
if err != nil {
return "", fmt.Errorf("failed to fetch session: %w", err)
}
// Generate a new access token
return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role)
}
@@ -121,20 +98,12 @@ func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
// Returns:
// - error: any error that occurred
func (s *SessionService) InvalidateSession(sessionID string) error {
_, err := s.db.Exec("DELETE FROM sessions WHERE id = ?", sessionID)
if err != nil {
return fmt.Errorf("failed to invalidate session: %w", err)
}
return nil
return s.db.DeleteSession(sessionID)
}
// CleanExpiredSessions removes all expired sessions from the database
// Returns:
// - error: any error that occurred
func (s *SessionService) CleanExpiredSessions() error {
_, err := s.db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now())
if err != nil {
return fmt.Errorf("failed to clean expired sessions: %w", err)
}
return nil
return s.db.CleanExpiredSessions()
}