From 3aa8c838e8d4fb4eea03ffff756611b91ff88faf Mon Sep 17 00:00:00 2001 From: LordMathis Date: Wed, 5 Mar 2025 21:20:57 +0100 Subject: [PATCH] Use struct queries in users --- server/internal/db/users.go | 72 +++++++++++++++++-------------------- 1 file changed, 33 insertions(+), 39 deletions(-) diff --git a/server/internal/db/users.go b/server/internal/db/users.go index 5d0cc55..f6b00e2 100644 --- a/server/internal/db/users.go +++ b/server/internal/db/users.go @@ -94,16 +94,16 @@ func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) e // GetUserByID retrieves a user by its ID func (db *database) GetUserByID(id int) (*models.User, error) { - query := db.NewQuery(). - Select("id", "email", "display_name", "password_hash", "role", "created_at", "last_workspace_id"). - From("users"). - Where("id = ").Placeholder(id) - user := &models.User{} - err := db.QueryRow(query.String(), query.Args()...). - Scan(&user.ID, &user.Email, &user.DisplayName, &user.PasswordHash, - &user.Role, &user.CreatedAt, &user.LastWorkspaceID) + query := db.NewQuery() + query, err := query.SelectStruct(user, "users") + if err != nil { + return nil, fmt.Errorf("failed to create query: %w", err) + } + query = query.Where("id = ").Placeholder(id) + row := db.QueryRow(query.String(), query.Args()...) + err = db.ScanStruct(row, user) if err == sql.ErrNoRows { return nil, fmt.Errorf("user not found") } @@ -115,15 +115,16 @@ func (db *database) GetUserByID(id int) (*models.User, error) { // GetUserByEmail retrieves a user by its email func (db *database) GetUserByEmail(email string) (*models.User, error) { - query := db.NewQuery(). - Select("id", "email", "display_name", "password_hash", "role", "created_at", "last_workspace_id"). - From("users"). - Where("email = ").Placeholder(email) - user := &models.User{} - err := db.QueryRow(query.String(), query.Args()...). - Scan(&user.ID, &user.Email, &user.DisplayName, &user.PasswordHash, - &user.Role, &user.CreatedAt, &user.LastWorkspaceID) + query := db.NewQuery() + query, err := query.SelectStruct(user, "users") + if err != nil { + return nil, fmt.Errorf("failed to create query: %w", err) + } + + query = query.Where("email = ").Placeholder(email) + row := db.QueryRow(query.String(), query.Args()...) + err = db.ScanStruct(row, user) if err == sql.ErrNoRows { return nil, fmt.Errorf("user not found") @@ -137,14 +138,12 @@ func (db *database) GetUserByEmail(email string) (*models.User, error) { // UpdateUser updates an existing user record in the database func (db *database) UpdateUser(user *models.User) error { - query := db.NewQuery(). - Update("users"). - Set("email").Placeholder(user.Email). - Set("display_name").Placeholder(user.DisplayName). - Set("password_hash").Placeholder(user.PasswordHash). - Set("role").Placeholder(user.Role). - Set("last_workspace_id").Placeholder(user.LastWorkspaceID). - Where("id = ").Placeholder(user.ID) + query := db.NewQuery() + query, err := query.UpdateStruct(user, "users") + if err != nil { + return fmt.Errorf("failed to create query: %w", err) + } + query = query.Where("id = ").Placeholder(user.ID) result, err := db.Exec(query.String(), query.Args()...) if err != nil { @@ -165,10 +164,12 @@ func (db *database) UpdateUser(user *models.User) error { // GetAllUsers retrieves all users from the database func (db *database) GetAllUsers() ([]*models.User, error) { - query := db.NewQuery(). - Select("id", "email", "display_name", "role", "created_at", "last_workspace_id"). - From("users"). - OrderBy("id ASC") + query := db.NewQuery() + query, err := query.SelectStruct(&models.User{}, "users") + if err != nil { + return nil, fmt.Errorf("failed to create query: %w", err) + } + query = query.OrderBy("id ASC") rows, err := db.Query(query.String(), query.Args()...) if err != nil { @@ -176,17 +177,10 @@ func (db *database) GetAllUsers() ([]*models.User, error) { } defer rows.Close() - var users []*models.User - for rows.Next() { - user := &models.User{} - err := rows.Scan( - &user.ID, &user.Email, &user.DisplayName, &user.Role, - &user.CreatedAt, &user.LastWorkspaceID, - ) - if err != nil { - return nil, fmt.Errorf("failed to scan user row: %w", err) - } - users = append(users, user) + users := []*models.User{} + err = db.ScanStructs(rows, &users) + if err != nil { + return nil, fmt.Errorf("failed to scan users: %w", err) } return users, nil