diff --git a/server/internal/db/struct_query.go b/server/internal/db/struct_query.go index ec86764..a4d8ff7 100644 --- a/server/internal/db/struct_query.go +++ b/server/internal/db/struct_query.go @@ -142,17 +142,13 @@ func (q *Query) InsertStruct(s any, table string) (*Query, error) { return q, nil } -func (q *Query) UpdateStruct(s any, table string, where []string, args []any) (*Query, error) { +func (q *Query) UpdateStruct(s any, table string) (*Query, error) { fields, err := StructTagsToFields(s) if err != nil { return nil, err } - if len(where) != len(args) { - return nil, fmt.Errorf("number of where clauses does not match number of arguments") - } - - q = q.Update("users") + q = q.Update(table) for _, f := range fields { value := f.Value @@ -172,13 +168,6 @@ func (q *Query) UpdateStruct(s any, table string, where []string, args []any) (* q = q.Set(f.Name).Placeholder(value) } - for i, w := range where { - if i != 0 && i < len(args) { - q = q.And(w) - } - q = q.Where(w).Placeholder(args[i]) - } - return q, nil } @@ -291,7 +280,6 @@ func (db *database) ScanStructs(rows *sql.Rows, destSlice any) error { if rows == nil { return fmt.Errorf("rows cannot be nil") } - defer rows.Close() // Get the slice value and element type sliceVal := reflect.ValueOf(destSlice) diff --git a/server/internal/db/workspaces.go b/server/internal/db/workspaces.go index 1783ad3..348e841 100644 --- a/server/internal/db/workspaces.go +++ b/server/internal/db/workspaces.go @@ -89,7 +89,8 @@ func (db *database) UpdateWorkspace(workspace *models.Workspace) error { query := db.NewQuery() query, err := query. - UpdateStruct(workspace, "workspaces", []string{"id =", "user_id ="}, []any{workspace.ID, workspace.UserID}) + UpdateStruct(workspace, "workspaces") + query = query.Where("id =").Placeholder(workspace.ID).And("user_id =").Placeholder(workspace.UserID) if err != nil { return fmt.Errorf("failed to create query: %w", err) @@ -120,7 +121,7 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro defer rows.Close() var workspaces []*models.Workspace - err = db.ScanStructs(rows, workspaces) + err = db.ScanStructs(rows, &workspaces) if err != nil { return nil, fmt.Errorf("failed to scan workspaces: %w", err) } @@ -131,12 +132,10 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro // UpdateWorkspaceSettings updates only the settings portion of a workspace func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error { - where := []string{"id ="} - args := []any{workspace.ID} - query := db.NewQuery() query, err := query. - UpdateStruct(workspace, "workspaces", where, args) + UpdateStruct(workspace, "workspaces") + query = query.Where("id =").Placeholder(workspace.ID) if err != nil { return fmt.Errorf("failed to create query: %w", err) @@ -265,7 +264,7 @@ func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) { defer rows.Close() var workspaces []*models.Workspace - err = db.ScanStructs(rows, workspaces) + err = db.ScanStructs(rows, &workspaces) if err != nil { return nil, fmt.Errorf("failed to scan workspaces: %w", err) }