From 976425d660d18eba7e669410cc8d82e6b18d0386 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Wed, 5 Mar 2025 21:07:05 +0100 Subject: [PATCH] Use ScanStruct in sessions --- server/internal/db/sessions.go | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/server/internal/db/sessions.go b/server/internal/db/sessions.go index fdfcc7d..4d69115 100644 --- a/server/internal/db/sessions.go +++ b/server/internal/db/sessions.go @@ -26,15 +26,18 @@ func (db *database) CreateSession(session *models.Session) error { // GetSessionByRefreshToken retrieves a session by its refresh token func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) { session := &models.Session{} - query := db.NewQuery(). - Select("id", "user_id", "refresh_token", "expires_at", "created_at"). - From("sessions"). - Where("refresh_token = "). + query := db.NewQuery() + query, err := query.SelectStruct(session, "sessions") + if err != nil { + return nil, fmt.Errorf("failed to create query: %w", err) + } + query = query.Where("refresh_token = "). Placeholder(refreshToken). And("expires_at >"). Placeholder(time.Now()) - err := db.QueryRow(query.String(), query.Args()...).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt) + row := db.QueryRow(query.String(), query.Args()...) + err = db.ScanStruct(row, session) if err == sql.ErrNoRows { return nil, fmt.Errorf("session not found or expired") } @@ -48,15 +51,18 @@ func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Sessi // GetSessionByID retrieves a session by its ID func (db *database) GetSessionByID(sessionID string) (*models.Session, error) { session := &models.Session{} - query := db.NewQuery(). - Select("id", "user_id", "refresh_token", "expires_at", "created_at"). - From("sessions"). - Where("id = "). + query := db.NewQuery() + query, err := query.SelectStruct(session, "sessions") + if err != nil { + return nil, fmt.Errorf("failed to create query: %w", err) + } + query = query.Where("id = "). Placeholder(sessionID). And("expires_at >"). Placeholder(time.Now()) - err := db.QueryRow(query.String(), query.Args()...).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt) + row := db.QueryRow(query.String(), query.Args()...) + err = db.ScanStruct(row, session) if err == sql.ErrNoRows { return nil, fmt.Errorf("session not found") }