Implement update struct

This commit is contained in:
2025-03-02 18:40:12 +01:00
parent 204dacd15e
commit ccac439465
8 changed files with 127 additions and 131 deletions

View File

@@ -195,29 +195,6 @@ func (db *database) Close() error {
return nil return nil
} }
// Helper methods for token encryption/decryption func (db *database) NewQuery() *Query {
func (db *database) encryptToken(token string) (string, error) { return NewQuery(db.dbType, db.secretsService)
if token == "" {
return "", nil
}
encrypted, err := db.secretsService.Encrypt(token)
if err != nil {
return "", fmt.Errorf("failed to encrypt token: %w", err)
}
return encrypted, nil
}
func (db *database) decryptToken(token string) (string, error) {
if token == "" {
return "", nil
}
decrypted, err := db.secretsService.Decrypt(token)
if err != nil {
return "", fmt.Errorf("failed to decrypt token: %w", err)
}
return decrypted, nil
} }

View File

@@ -2,6 +2,7 @@ package db
import ( import (
"fmt" "fmt"
"lemma/internal/secrets"
"strings" "strings"
) )
@@ -18,6 +19,7 @@ type Query struct {
builder strings.Builder builder strings.Builder
args []any args []any
dbType DBType dbType DBType
secretsService secrets.Service
pos int // tracks the current placeholder position pos int // tracks the current placeholder position
hasSelect bool hasSelect bool
hasFrom bool hasFrom bool
@@ -32,9 +34,10 @@ type Query struct {
} }
// NewQuery creates a new Query instance // NewQuery creates a new Query instance
func NewQuery(dbType DBType) *Query { func NewQuery(dbType DBType, secretsService secrets.Service) *Query {
return &Query{ return &Query{
dbType: dbType, dbType: dbType,
secretsService: secretsService,
args: make([]any, 0), args: make([]any, 0),
} }
} }

View File

@@ -24,7 +24,7 @@ func TestNewQuery(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType) q := db.NewQuery(tt.dbType, &mockSecrets{})
// Test that a new query is empty // Test that a new query is empty
if q.String() != "" { if q.String() != "" {
@@ -120,7 +120,7 @@ func TestBasicQueryBuilding(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType) q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q) q = tt.buildFn(q)
gotSQL := q.String() gotSQL := q.String()
@@ -215,7 +215,7 @@ func TestPlaceholders(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType) q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q) q = tt.buildFn(q)
gotSQL := q.String() gotSQL := q.String()
@@ -328,7 +328,7 @@ func TestWhereClauseBuilding(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType) q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q) q = tt.buildFn(q)
gotSQL := q.String() gotSQL := q.String()
@@ -403,7 +403,7 @@ func TestJoinClauseBuilding(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType) q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q) q = tt.buildFn(q)
gotSQL := q.String() gotSQL := q.String()
@@ -482,7 +482,7 @@ func TestOrderLimitOffset(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType) q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q) q = tt.buildFn(q)
gotSQL := q.String() gotSQL := q.String()
@@ -575,7 +575,7 @@ func TestInsertUpdateDelete(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType) q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q) q = tt.buildFn(q)
gotSQL := q.String() gotSQL := q.String()
@@ -641,7 +641,7 @@ func TestHavingClause(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType) q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q) q = tt.buildFn(q)
gotSQL := q.String() gotSQL := q.String()
@@ -790,7 +790,7 @@ func TestQueryReturning(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
query := db.NewQuery(tc.dbType) query := db.NewQuery(tc.dbType, &mockSecrets{})
result := tc.buildQuery(query) result := tc.buildQuery(query)
if result.String() != tc.expectedSQL { if result.String() != tc.expectedSQL {
@@ -838,7 +838,7 @@ func TestComplexQueries(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType) q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q) q = tt.buildFn(q)
gotSQL := q.String() gotSQL := q.String()

View File

@@ -10,8 +10,8 @@ import (
// CreateSession inserts a new session record into the database // CreateSession inserts a new session record into the database
func (db *database) CreateSession(session *models.Session) error { func (db *database) CreateSession(session *models.Session) error {
query, err := NewQuery(db.dbType). query, err := db.NewQuery().
InsertStruct(session, "sessions", db.secretsService) InsertStruct(session, "sessions")
if err != nil { if err != nil {
return fmt.Errorf("failed to create query: %w", err) return fmt.Errorf("failed to create query: %w", err)
} }
@@ -26,7 +26,7 @@ func (db *database) CreateSession(session *models.Session) error {
// GetSessionByRefreshToken retrieves a session by its refresh token // GetSessionByRefreshToken retrieves a session by its refresh token
func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) { func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) {
session := &models.Session{} session := &models.Session{}
query := NewQuery(db.dbType). query := db.NewQuery().
Select("id", "user_id", "refresh_token", "expires_at", "created_at"). Select("id", "user_id", "refresh_token", "expires_at", "created_at").
From("sessions"). From("sessions").
Where("refresh_token = "). Where("refresh_token = ").
@@ -48,7 +48,7 @@ func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Sessi
// GetSessionByID retrieves a session by its ID // GetSessionByID retrieves a session by its ID
func (db *database) GetSessionByID(sessionID string) (*models.Session, error) { func (db *database) GetSessionByID(sessionID string) (*models.Session, error) {
session := &models.Session{} session := &models.Session{}
query := NewQuery(db.dbType). query := db.NewQuery().
Select("id", "user_id", "refresh_token", "expires_at", "created_at"). Select("id", "user_id", "refresh_token", "expires_at", "created_at").
From("sessions"). From("sessions").
Where("id = "). Where("id = ").
@@ -69,7 +69,7 @@ func (db *database) GetSessionByID(sessionID string) (*models.Session, error) {
// DeleteSession removes a session from the database // DeleteSession removes a session from the database
func (db *database) DeleteSession(sessionID string) error { func (db *database) DeleteSession(sessionID string) error {
query := NewQuery(db.dbType). query := db.NewQuery().
Delete(). Delete().
From("sessions"). From("sessions").
Where("id = "). Where("id = ").
@@ -95,7 +95,7 @@ func (db *database) DeleteSession(sessionID string) error {
// CleanExpiredSessions removes all expired sessions from the database // CleanExpiredSessions removes all expired sessions from the database
func (db *database) CleanExpiredSessions() error { func (db *database) CleanExpiredSessions() error {
log := getLogger().WithGroup("sessions") log := getLogger().WithGroup("sessions")
query := NewQuery(db.dbType). query := db.NewQuery().
Delete(). Delete().
From("sessions"). From("sessions").
Where("expires_at <="). Where("expires_at <=").

View File

@@ -2,7 +2,6 @@ package db
import ( import (
"fmt" "fmt"
"lemma/internal/secrets"
"reflect" "reflect"
"strings" "strings"
"unicode" "unicode"
@@ -98,7 +97,7 @@ func toSnakeCase(s string) string {
return res return res
} }
func (q *Query) InsertStruct(s any, table string, secretsService secrets.Service) (*Query, error) { func (q *Query) InsertStruct(s any, table string) (*Query, error) {
fields, err := StructTagsToFields(s) fields, err := StructTagsToFields(s)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -115,7 +114,7 @@ func (q *Query) InsertStruct(s any, table string, secretsService secrets.Service
} }
if f.encrypted { if f.encrypted {
encValue, err := secretsService.Encrypt(value.(string)) encValue, err := q.secretsService.Encrypt(value.(string))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -133,3 +132,43 @@ func (q *Query) InsertStruct(s any, table string, secretsService secrets.Service
q.Insert(table, columns...).Values(len(columns)).AddArgs(values...) q.Insert(table, columns...).Values(len(columns)).AddArgs(values...)
return q, nil return q, nil
} }
func (q *Query) UpdateStruct(s any, table string, where []string, args []any) (*Query, error) {
fields, err := StructTagsToFields(s)
if err != nil {
return nil, err
}
if len(where) != len(args) {
return nil, fmt.Errorf("number of where clauses does not match number of arguments")
}
q = q.Update("users")
for _, f := range fields {
value := f.Value
if f.useDefault {
continue
}
if f.encrypted {
encValue, err := q.secretsService.Encrypt(value.(string))
if err != nil {
return nil, err
}
value = encValue
}
q = q.Set(f.Name).Placeholder(value)
}
for i, w := range where {
if i != 0 && i < len(args) {
q = q.And(w)
}
q = q.Where(w).Placeholder(args[i])
}
return q, nil
}

View File

@@ -49,7 +49,7 @@ func (db *database) EnsureJWTSecret() (string, error) {
// GetSystemSetting retrieves a system setting by key // GetSystemSetting retrieves a system setting by key
func (db *database) GetSystemSetting(key string) (string, error) { func (db *database) GetSystemSetting(key string) (string, error) {
var value string var value string
query := NewQuery(db.dbType). query := db.NewQuery().
Select("value"). Select("value").
From("system_settings"). From("system_settings").
Where("key = "). Where("key = ").
@@ -64,7 +64,7 @@ func (db *database) GetSystemSetting(key string) (string, error) {
// SetSystemSetting stores or updates a system setting // SetSystemSetting stores or updates a system setting
func (db *database) SetSystemSetting(key, value string) error { func (db *database) SetSystemSetting(key, value string) error {
query := NewQuery(db.dbType). query := db.NewQuery().
Insert("system_settings", "key", "value"). Insert("system_settings", "key", "value").
Values(2). Values(2).
AddArgs(key, value). AddArgs(key, value).
@@ -100,7 +100,7 @@ func (db *database) GetSystemStats() (*UserStats, error) {
stats := &UserStats{} stats := &UserStats{}
// Get total users // Get total users
query := NewQuery(db.dbType). query := db.NewQuery().
Select("COUNT(*)"). Select("COUNT(*)").
From("users") From("users")
err := db.QueryRow(query.String()).Scan(&stats.TotalUsers) err := db.QueryRow(query.String()).Scan(&stats.TotalUsers)
@@ -109,7 +109,7 @@ func (db *database) GetSystemStats() (*UserStats, error) {
} }
// Get total workspaces // Get total workspaces
query = NewQuery(db.dbType). query = db.NewQuery().
Select("COUNT(*)"). Select("COUNT(*)").
From("workspaces") From("workspaces")
err = db.QueryRow(query.String()).Scan(&stats.TotalWorkspaces) err = db.QueryRow(query.String()).Scan(&stats.TotalWorkspaces)
@@ -118,7 +118,7 @@ func (db *database) GetSystemStats() (*UserStats, error) {
} }
// Get active users (users with activity in last 30 days) // Get active users (users with activity in last 30 days)
query = NewQuery(db.dbType). query = db.NewQuery().
Select("COUNT(DISTINCT user_id)"). Select("COUNT(DISTINCT user_id)").
From("sessions"). From("sessions").
Where("created_at > datetime('now', '-30 days')") Where("created_at > datetime('now', '-30 days')")

View File

@@ -17,8 +17,8 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) {
} }
defer tx.Rollback() defer tx.Rollback()
query, err := NewQuery(db.dbType). query, err := db.NewQuery().
InsertStruct(user, "users", db.secretsService) InsertStruct(user, "users")
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create query: %w", err) return nil, fmt.Errorf("failed to create query: %w", err)
@@ -46,7 +46,7 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) {
} }
// Update user's last workspace ID // Update user's last workspace ID
query = NewQuery(db.dbType). query = db.NewQuery().
Update("users"). Update("users").
Set("last_workspace_id"). Set("last_workspace_id").
Placeholder(defaultWorkspace.ID). Placeholder(defaultWorkspace.ID).
@@ -72,8 +72,8 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) {
func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error { func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
log := getLogger().WithGroup("users") log := getLogger().WithGroup("users")
insertQuery, err := NewQuery(db.dbType). insertQuery, err := db.NewQuery().
InsertStruct(workspace, "workspaces", db.secretsService) InsertStruct(workspace, "workspaces")
if err != nil { if err != nil {
return fmt.Errorf("failed to create query: %w", err) return fmt.Errorf("failed to create query: %w", err)
@@ -94,7 +94,7 @@ func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) e
// GetUserByID retrieves a user by its ID // GetUserByID retrieves a user by its ID
func (db *database) GetUserByID(id int) (*models.User, error) { func (db *database) GetUserByID(id int) (*models.User, error) {
query := NewQuery(db.dbType). query := db.NewQuery().
Select("id", "email", "display_name", "password_hash", "role", "created_at", "last_workspace_id"). Select("id", "email", "display_name", "password_hash", "role", "created_at", "last_workspace_id").
From("users"). From("users").
Where("id = ").Placeholder(id) Where("id = ").Placeholder(id)
@@ -115,7 +115,7 @@ func (db *database) GetUserByID(id int) (*models.User, error) {
// GetUserByEmail retrieves a user by its email // GetUserByEmail retrieves a user by its email
func (db *database) GetUserByEmail(email string) (*models.User, error) { func (db *database) GetUserByEmail(email string) (*models.User, error) {
query := NewQuery(db.dbType). query := db.NewQuery().
Select("id", "email", "display_name", "password_hash", "role", "created_at", "last_workspace_id"). Select("id", "email", "display_name", "password_hash", "role", "created_at", "last_workspace_id").
From("users"). From("users").
Where("email = ").Placeholder(email) Where("email = ").Placeholder(email)
@@ -137,7 +137,7 @@ func (db *database) GetUserByEmail(email string) (*models.User, error) {
// UpdateUser updates an existing user record in the database // UpdateUser updates an existing user record in the database
func (db *database) UpdateUser(user *models.User) error { func (db *database) UpdateUser(user *models.User) error {
query := NewQuery(db.dbType). query := db.NewQuery().
Update("users"). Update("users").
Set("email").Placeholder(user.Email). Set("email").Placeholder(user.Email).
Set("display_name").Placeholder(user.DisplayName). Set("display_name").Placeholder(user.DisplayName).
@@ -165,7 +165,7 @@ func (db *database) UpdateUser(user *models.User) error {
// GetAllUsers retrieves all users from the database // GetAllUsers retrieves all users from the database
func (db *database) GetAllUsers() ([]*models.User, error) { func (db *database) GetAllUsers() ([]*models.User, error) {
query := NewQuery(db.dbType). query := db.NewQuery().
Select("id", "email", "display_name", "role", "created_at", "last_workspace_id"). Select("id", "email", "display_name", "role", "created_at", "last_workspace_id").
From("users"). From("users").
OrderBy("id ASC") OrderBy("id ASC")
@@ -200,7 +200,7 @@ func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error
defer tx.Rollback() defer tx.Rollback()
// Find workspace ID from name // Find workspace ID from name
workspaceQuery := NewQuery(db.dbType). workspaceQuery := db.NewQuery().
Select("id"). Select("id").
From("workspaces"). From("workspaces").
Where("user_id = ").Placeholder(userID). Where("user_id = ").Placeholder(userID).
@@ -213,7 +213,7 @@ func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error
} }
// Update user's last workspace // Update user's last workspace
updateQuery := NewQuery(db.dbType). updateQuery := db.NewQuery().
Update("users"). Update("users").
Set("last_workspace_id").Placeholder(workspaceID). Set("last_workspace_id").Placeholder(workspaceID).
Where("id = ").Placeholder(userID) Where("id = ").Placeholder(userID)
@@ -245,7 +245,7 @@ func (db *database) DeleteUser(id int) error {
// Delete all user's workspaces first // Delete all user's workspaces first
log.Debug("deleting user workspaces", "user_id", id) log.Debug("deleting user workspaces", "user_id", id)
deleteWorkspacesQuery := NewQuery(db.dbType). deleteWorkspacesQuery := db.NewQuery().
Delete(). Delete().
From("workspaces"). From("workspaces").
Where("user_id = ").Placeholder(id) Where("user_id = ").Placeholder(id)
@@ -256,7 +256,7 @@ func (db *database) DeleteUser(id int) error {
} }
// Delete the user // Delete the user
deleteUserQuery := NewQuery(db.dbType). deleteUserQuery := db.NewQuery().
Delete(). Delete().
From("users"). From("users").
Where("id = ").Placeholder(id) Where("id = ").Placeholder(id)
@@ -277,7 +277,7 @@ func (db *database) DeleteUser(id int) error {
// GetLastWorkspaceName retrieves the name of the last workspace accessed by a user // GetLastWorkspaceName retrieves the name of the last workspace accessed by a user
func (db *database) GetLastWorkspaceName(userID int) (string, error) { func (db *database) GetLastWorkspaceName(userID int) (string, error) {
query := NewQuery(db.dbType). query := db.NewQuery().
Select("w.name"). Select("w.name").
From("workspaces w"). From("workspaces w").
Join(InnerJoin, "users u", "u.last_workspace_id = w.id"). Join(InnerJoin, "users u", "u.last_workspace_id = w.id").
@@ -298,7 +298,7 @@ func (db *database) GetLastWorkspaceName(userID int) (string, error) {
// CountAdminUsers returns the number of admin users in the system // CountAdminUsers returns the number of admin users in the system
func (db *database) CountAdminUsers() (int, error) { func (db *database) CountAdminUsers() (int, error) {
query := NewQuery(db.dbType). query := db.NewQuery().
Select("COUNT(*)"). Select("COUNT(*)").
From("users"). From("users").
Where("role = ").Placeholder(models.RoleAdmin) Where("role = ").Placeholder(models.RoleAdmin)

View File

@@ -19,7 +19,7 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error {
workspace.SetDefaultSettings() workspace.SetDefaultSettings()
} }
query, err := NewQuery(db.dbType). query, err := db.NewQuery().
InsertStruct(workspace, "workspaces", db.secretsService) InsertStruct(workspace, "workspaces", db.secretsService)
if err != nil { if err != nil {
@@ -39,7 +39,7 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error {
// GetWorkspaceByID retrieves a workspace by its ID // GetWorkspaceByID retrieves a workspace by its ID
func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) { func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) {
query := NewQuery(db.dbType). query := db.NewQuery().
Select( Select(
"id", "user_id", "name", "created_at", "id", "user_id", "name", "created_at",
"theme", "auto_save", "show_hidden_files", "theme", "auto_save", "show_hidden_files",
@@ -86,7 +86,7 @@ func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) {
// GetWorkspaceByName retrieves a workspace by its name and user ID // GetWorkspaceByName retrieves a workspace by its name and user ID
func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) { func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) {
query := NewQuery(db.dbType). query := db.NewQuery().
Select( Select(
"id", "user_id", "name", "created_at", "id", "user_id", "name", "created_at",
"theme", "auto_save", "show_hidden_files", "theme", "auto_save", "show_hidden_files",
@@ -133,28 +133,14 @@ func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*model
// UpdateWorkspace updates a workspace record in the database // UpdateWorkspace updates a workspace record in the database
func (db *database) UpdateWorkspace(workspace *models.Workspace) error { func (db *database) UpdateWorkspace(workspace *models.Workspace) error {
// Encrypt token before storing
encryptedToken, err := db.encryptToken(workspace.GitToken)
if err != nil {
return fmt.Errorf("failed to encrypt token: %w", err)
}
query := NewQuery(db.dbType). query := db.NewQuery()
Update("workspaces"). query, err := query.
Set("name").Placeholder(workspace.Name). UpdateStruct(workspace, "workspaces", []string{"id =", "user_id ="}, []interface{}{workspace.ID, workspace.UserID})
Set("theme").Placeholder(workspace.Theme).
Set("auto_save").Placeholder(workspace.AutoSave). if err != nil {
Set("show_hidden_files").Placeholder(workspace.ShowHiddenFiles). return fmt.Errorf("failed to create query: %w", err)
Set("git_enabled").Placeholder(workspace.GitEnabled). }
Set("git_url").Placeholder(workspace.GitURL).
Set("git_user").Placeholder(workspace.GitUser).
Set("git_token").Placeholder(encryptedToken).
Set("git_auto_commit").Placeholder(workspace.GitAutoCommit).
Set("git_commit_msg_template").Placeholder(workspace.GitCommitMsgTemplate).
Set("git_commit_name").Placeholder(workspace.GitCommitName).
Set("git_commit_email").Placeholder(workspace.GitCommitEmail).
Where("id = ").Placeholder(workspace.ID).
And("user_id = ").Placeholder(workspace.UserID)
_, err = db.Exec(query.String(), query.Args()...) _, err = db.Exec(query.String(), query.Args()...)
if err != nil { if err != nil {
@@ -166,7 +152,7 @@ func (db *database) UpdateWorkspace(workspace *models.Workspace) error {
// GetWorkspacesByUserID retrieves all workspaces for a user // GetWorkspacesByUserID retrieves all workspaces for a user
func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) { func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
query := NewQuery(db.dbType). query := db.NewQuery().
Select( Select(
"id", "user_id", "name", "created_at", "id", "user_id", "name", "created_at",
"theme", "auto_save", "show_hidden_files", "theme", "auto_save", "show_hidden_files",
@@ -222,26 +208,17 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro
// UpdateWorkspaceSettings updates only the settings portion of a workspace // UpdateWorkspaceSettings updates only the settings portion of a workspace
func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error { func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
// Encrypt token before storing
encryptedToken, err := db.encryptToken(workspace.GitToken)
if err != nil {
return fmt.Errorf("failed to encrypt token: %w", err)
}
query := NewQuery(db.dbType). where := []string{"id ="}
Update("workspaces"). args := []interface{}{workspace.ID}
Set("theme").Placeholder(workspace.Theme).
Set("auto_save").Placeholder(workspace.AutoSave). query := db.NewQuery()
Set("show_hidden_files").Placeholder(workspace.ShowHiddenFiles). query, err := query.
Set("git_enabled").Placeholder(workspace.GitEnabled). UpdateStruct(workspace, "workspaces", where, args)
Set("git_url").Placeholder(workspace.GitURL).
Set("git_user").Placeholder(workspace.GitUser). if err != nil {
Set("git_token").Placeholder(encryptedToken). return fmt.Errorf("failed to create query: %w", err)
Set("git_auto_commit").Placeholder(workspace.GitAutoCommit). }
Set("git_commit_msg_template").Placeholder(workspace.GitCommitMsgTemplate).
Set("git_commit_name").Placeholder(workspace.GitCommitName).
Set("git_commit_email").Placeholder(workspace.GitCommitEmail).
Where("id = ").Placeholder(workspace.ID)
_, err = db.Exec(query.String(), query.Args()...) _, err = db.Exec(query.String(), query.Args()...)
if err != nil { if err != nil {
@@ -255,7 +232,7 @@ func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
func (db *database) DeleteWorkspace(id int) error { func (db *database) DeleteWorkspace(id int) error {
log := getLogger().WithGroup("workspaces") log := getLogger().WithGroup("workspaces")
query := NewQuery(db.dbType). query := db.NewQuery().
Delete(). Delete().
From("workspaces"). From("workspaces").
Where("id = ").Placeholder(id) Where("id = ").Placeholder(id)
@@ -273,7 +250,7 @@ func (db *database) DeleteWorkspace(id int) error {
func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error { func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
log := getLogger().WithGroup("workspaces") log := getLogger().WithGroup("workspaces")
query := NewQuery(db.dbType). query := db.NewQuery().
Delete(). Delete().
From("workspaces"). From("workspaces").
Where("id = ").Placeholder(id) Where("id = ").Placeholder(id)
@@ -294,7 +271,7 @@ func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
// UpdateLastWorkspaceTx sets the last workspace for a user in a transaction // UpdateLastWorkspaceTx sets the last workspace for a user in a transaction
func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error { func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error {
query := NewQuery(db.dbType). query := db.NewQuery().
Update("users"). Update("users").
Set("last_workspace_id").Placeholder(workspaceID). Set("last_workspace_id").Placeholder(workspaceID).
Where("id = ").Placeholder(userID) Where("id = ").Placeholder(userID)
@@ -314,7 +291,7 @@ func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) e
// UpdateLastOpenedFile updates the last opened file path for a workspace // UpdateLastOpenedFile updates the last opened file path for a workspace
func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error { func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error {
query := NewQuery(db.dbType). query := db.NewQuery().
Update("workspaces"). Update("workspaces").
Set("last_opened_file_path").Placeholder(filePath). Set("last_opened_file_path").Placeholder(filePath).
Where("id = ").Placeholder(workspaceID) Where("id = ").Placeholder(workspaceID)
@@ -329,7 +306,7 @@ func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error
// GetLastOpenedFile retrieves the last opened file path for a workspace // GetLastOpenedFile retrieves the last opened file path for a workspace
func (db *database) GetLastOpenedFile(workspaceID int) (string, error) { func (db *database) GetLastOpenedFile(workspaceID int) (string, error) {
query := NewQuery(db.dbType). query := db.NewQuery().
Select("last_opened_file_path"). Select("last_opened_file_path").
From("workspaces"). From("workspaces").
Where("id = ").Placeholder(workspaceID) Where("id = ").Placeholder(workspaceID)
@@ -353,7 +330,7 @@ func (db *database) GetLastOpenedFile(workspaceID int) (string, error) {
// GetAllWorkspaces retrieves all workspaces in the database // GetAllWorkspaces retrieves all workspaces in the database
func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) { func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) {
query := NewQuery(db.dbType). query := db.NewQuery().
Select( Select(
"id", "user_id", "name", "created_at", "id", "user_id", "name", "created_at",
"theme", "auto_save", "show_hidden_files", "theme", "auto_save", "show_hidden_files",