Rework UpdateStruct

This commit is contained in:
2025-03-03 22:04:38 +01:00
parent 829b359e82
commit 0f97927219
2 changed files with 8 additions and 21 deletions

View File

@@ -142,17 +142,13 @@ func (q *Query) InsertStruct(s any, table string) (*Query, error) {
return q, nil return q, nil
} }
func (q *Query) UpdateStruct(s any, table string, where []string, args []any) (*Query, error) { func (q *Query) UpdateStruct(s any, table string) (*Query, error) {
fields, err := StructTagsToFields(s) fields, err := StructTagsToFields(s)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(where) != len(args) { q = q.Update(table)
return nil, fmt.Errorf("number of where clauses does not match number of arguments")
}
q = q.Update("users")
for _, f := range fields { for _, f := range fields {
value := f.Value value := f.Value
@@ -172,13 +168,6 @@ func (q *Query) UpdateStruct(s any, table string, where []string, args []any) (*
q = q.Set(f.Name).Placeholder(value) q = q.Set(f.Name).Placeholder(value)
} }
for i, w := range where {
if i != 0 && i < len(args) {
q = q.And(w)
}
q = q.Where(w).Placeholder(args[i])
}
return q, nil return q, nil
} }
@@ -291,7 +280,6 @@ func (db *database) ScanStructs(rows *sql.Rows, destSlice any) error {
if rows == nil { if rows == nil {
return fmt.Errorf("rows cannot be nil") return fmt.Errorf("rows cannot be nil")
} }
defer rows.Close()
// Get the slice value and element type // Get the slice value and element type
sliceVal := reflect.ValueOf(destSlice) sliceVal := reflect.ValueOf(destSlice)

View File

@@ -89,7 +89,8 @@ func (db *database) UpdateWorkspace(workspace *models.Workspace) error {
query := db.NewQuery() query := db.NewQuery()
query, err := query. query, err := query.
UpdateStruct(workspace, "workspaces", []string{"id =", "user_id ="}, []any{workspace.ID, workspace.UserID}) UpdateStruct(workspace, "workspaces")
query = query.Where("id =").Placeholder(workspace.ID).And("user_id =").Placeholder(workspace.UserID)
if err != nil { if err != nil {
return fmt.Errorf("failed to create query: %w", err) return fmt.Errorf("failed to create query: %w", err)
@@ -120,7 +121,7 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro
defer rows.Close() defer rows.Close()
var workspaces []*models.Workspace var workspaces []*models.Workspace
err = db.ScanStructs(rows, workspaces) err = db.ScanStructs(rows, &workspaces)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to scan workspaces: %w", err) return nil, fmt.Errorf("failed to scan workspaces: %w", err)
} }
@@ -131,12 +132,10 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro
// UpdateWorkspaceSettings updates only the settings portion of a workspace // UpdateWorkspaceSettings updates only the settings portion of a workspace
func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error { func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
where := []string{"id ="}
args := []any{workspace.ID}
query := db.NewQuery() query := db.NewQuery()
query, err := query. query, err := query.
UpdateStruct(workspace, "workspaces", where, args) UpdateStruct(workspace, "workspaces")
query = query.Where("id =").Placeholder(workspace.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to create query: %w", err) return fmt.Errorf("failed to create query: %w", err)
@@ -265,7 +264,7 @@ func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) {
defer rows.Close() defer rows.Close()
var workspaces []*models.Workspace var workspaces []*models.Workspace
err = db.ScanStructs(rows, workspaces) err = db.ScanStructs(rows, &workspaces)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to scan workspaces: %w", err) return nil, fmt.Errorf("failed to scan workspaces: %w", err)
} }