diff --git a/server/internal/db/sessions.go b/server/internal/db/sessions.go index f5b8f81..7df51c6 100644 --- a/server/internal/db/sessions.go +++ b/server/internal/db/sessions.go @@ -10,11 +10,11 @@ import ( // CreateSession inserts a new session record into the database func (db *database) CreateSession(session *models.Session) error { - _, err := db.Exec(` - INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at) - VALUES (?, ?, ?, ?, ?)`, - session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt, - ) + query := NewQuery(db.dbType). + Insert("sessions", "id", "user_id", "refresh_token", "expires_at", "created_at"). + Values(5). + AddArgs(session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt) + _, err := db.Exec(query.String(), query.Args()...) if err != nil { return fmt.Errorf("failed to store session: %w", err) } @@ -25,12 +25,14 @@ func (db *database) CreateSession(session *models.Session) error { // GetSessionByRefreshToken retrieves a session by its refresh token func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) { session := &models.Session{} - err := db.QueryRow(` - SELECT id, user_id, refresh_token, expires_at, created_at - FROM sessions - WHERE refresh_token = ? AND expires_at > ?`, - refreshToken, time.Now(), - ).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt) + query := NewQuery(db.dbType). + Select("id", "user_id", "refresh_token", "expires_at", "created_at"). + From("sessions"). + Where("refresh_token = "). + Placeholder(refreshToken). + And("expires_at >"). + Placeholder(time.Now()) + err := db.QueryRow(query.String(), query.Args()...).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt) if err == sql.ErrNoRows { return nil, fmt.Errorf("session not found or expired") @@ -45,12 +47,14 @@ func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Sessi // GetSessionByID retrieves a session by its ID func (db *database) GetSessionByID(sessionID string) (*models.Session, error) { session := &models.Session{} - err := db.QueryRow(` - SELECT id, user_id, refresh_token, expires_at, created_at - FROM sessions - WHERE id = ? AND expires_at > ?`, - sessionID, time.Now(), - ).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt) + query := NewQuery(db.dbType). + Select("id", "user_id", "refresh_token", "expires_at", "created_at"). + From("sessions"). + Where("id = "). + Placeholder(sessionID). + And("expires_at >"). + Placeholder(time.Now()) + err := db.QueryRow(query.String(), query.Args()...).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt) if err == sql.ErrNoRows { return nil, fmt.Errorf("session not found") @@ -64,7 +68,13 @@ func (db *database) GetSessionByID(sessionID string) (*models.Session, error) { // DeleteSession removes a session from the database func (db *database) DeleteSession(sessionID string) error { - result, err := db.Exec("DELETE FROM sessions WHERE id = ?", sessionID) + query := NewQuery(db.dbType). + Delete(). + From("sessions"). + Where("id = "). + Placeholder(sessionID) + + result, err := db.Exec(query.String(), query.Args()...) if err != nil { return fmt.Errorf("failed to delete session: %w", err) } @@ -84,7 +94,12 @@ func (db *database) DeleteSession(sessionID string) error { // CleanExpiredSessions removes all expired sessions from the database func (db *database) CleanExpiredSessions() error { log := getLogger().WithGroup("sessions") - result, err := db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now()) + query := NewQuery(db.dbType). + Delete(). + From("sessions"). + Where("expires_at <="). + Placeholder(time.Now()) + result, err := db.Exec(query.String(), query.Args()...) if err != nil { return fmt.Errorf("failed to clean expired sessions: %w", err) }