Implement update struct

This commit is contained in:
2025-03-02 18:40:12 +01:00
parent 204dacd15e
commit ccac439465
8 changed files with 127 additions and 131 deletions

View File

@@ -195,29 +195,6 @@ func (db *database) Close() error {
return nil return nil
} }
// Helper methods for token encryption/decryption func (db *database) NewQuery() *Query {
func (db *database) encryptToken(token string) (string, error) { return NewQuery(db.dbType, db.secretsService)
if token == "" {
return "", nil
}
encrypted, err := db.secretsService.Encrypt(token)
if err != nil {
return "", fmt.Errorf("failed to encrypt token: %w", err)
}
return encrypted, nil
}
func (db *database) decryptToken(token string) (string, error) {
if token == "" {
return "", nil
}
decrypted, err := db.secretsService.Decrypt(token)
if err != nil {
return "", fmt.Errorf("failed to decrypt token: %w", err)
}
return decrypted, nil
} }

View File

@@ -2,6 +2,7 @@ package db
import ( import (
"fmt" "fmt"
"lemma/internal/secrets"
"strings" "strings"
) )
@@ -15,27 +16,29 @@ const (
// Query represents a SQL query with its parameters // Query represents a SQL query with its parameters
type Query struct { type Query struct {
builder strings.Builder builder strings.Builder
args []any args []any
dbType DBType dbType DBType
pos int // tracks the current placeholder position secretsService secrets.Service
hasSelect bool pos int // tracks the current placeholder position
hasFrom bool hasSelect bool
hasWhere bool hasFrom bool
hasOrderBy bool hasWhere bool
hasGroupBy bool hasOrderBy bool
hasHaving bool hasGroupBy bool
hasLimit bool hasHaving bool
hasOffset bool hasLimit bool
isInParens bool hasOffset bool
parensDepth int isInParens bool
parensDepth int
} }
// NewQuery creates a new Query instance // NewQuery creates a new Query instance
func NewQuery(dbType DBType) *Query { func NewQuery(dbType DBType, secretsService secrets.Service) *Query {
return &Query{ return &Query{
dbType: dbType, dbType: dbType,
args: make([]any, 0), secretsService: secretsService,
args: make([]any, 0),
} }
} }

View File

@@ -24,7 +24,7 @@ func TestNewQuery(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType) q := db.NewQuery(tt.dbType, &mockSecrets{})
// Test that a new query is empty // Test that a new query is empty
if q.String() != "" { if q.String() != "" {
@@ -120,7 +120,7 @@ func TestBasicQueryBuilding(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType) q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q) q = tt.buildFn(q)
gotSQL := q.String() gotSQL := q.String()
@@ -215,7 +215,7 @@ func TestPlaceholders(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType) q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q) q = tt.buildFn(q)
gotSQL := q.String() gotSQL := q.String()
@@ -328,7 +328,7 @@ func TestWhereClauseBuilding(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType) q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q) q = tt.buildFn(q)
gotSQL := q.String() gotSQL := q.String()
@@ -403,7 +403,7 @@ func TestJoinClauseBuilding(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType) q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q) q = tt.buildFn(q)
gotSQL := q.String() gotSQL := q.String()
@@ -482,7 +482,7 @@ func TestOrderLimitOffset(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType) q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q) q = tt.buildFn(q)
gotSQL := q.String() gotSQL := q.String()
@@ -575,7 +575,7 @@ func TestInsertUpdateDelete(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType) q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q) q = tt.buildFn(q)
gotSQL := q.String() gotSQL := q.String()
@@ -641,7 +641,7 @@ func TestHavingClause(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType) q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q) q = tt.buildFn(q)
gotSQL := q.String() gotSQL := q.String()
@@ -790,7 +790,7 @@ func TestQueryReturning(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
query := db.NewQuery(tc.dbType) query := db.NewQuery(tc.dbType, &mockSecrets{})
result := tc.buildQuery(query) result := tc.buildQuery(query)
if result.String() != tc.expectedSQL { if result.String() != tc.expectedSQL {
@@ -838,7 +838,7 @@ func TestComplexQueries(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType) q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q) q = tt.buildFn(q)
gotSQL := q.String() gotSQL := q.String()

View File

@@ -10,8 +10,8 @@ import (
// CreateSession inserts a new session record into the database // CreateSession inserts a new session record into the database
func (db *database) CreateSession(session *models.Session) error { func (db *database) CreateSession(session *models.Session) error {
query, err := NewQuery(db.dbType). query, err := db.NewQuery().
InsertStruct(session, "sessions", db.secretsService) InsertStruct(session, "sessions")
if err != nil { if err != nil {
return fmt.Errorf("failed to create query: %w", err) return fmt.Errorf("failed to create query: %w", err)
} }
@@ -26,7 +26,7 @@ func (db *database) CreateSession(session *models.Session) error {
// GetSessionByRefreshToken retrieves a session by its refresh token // GetSessionByRefreshToken retrieves a session by its refresh token
func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) { func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) {
session := &models.Session{} session := &models.Session{}
query := NewQuery(db.dbType). query := db.NewQuery().
Select("id", "user_id", "refresh_token", "expires_at", "created_at"). Select("id", "user_id", "refresh_token", "expires_at", "created_at").
From("sessions"). From("sessions").
Where("refresh_token = "). Where("refresh_token = ").
@@ -48,7 +48,7 @@ func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Sessi
// GetSessionByID retrieves a session by its ID // GetSessionByID retrieves a session by its ID
func (db *database) GetSessionByID(sessionID string) (*models.Session, error) { func (db *database) GetSessionByID(sessionID string) (*models.Session, error) {
session := &models.Session{} session := &models.Session{}
query := NewQuery(db.dbType). query := db.NewQuery().
Select("id", "user_id", "refresh_token", "expires_at", "created_at"). Select("id", "user_id", "refresh_token", "expires_at", "created_at").
From("sessions"). From("sessions").
Where("id = "). Where("id = ").
@@ -69,7 +69,7 @@ func (db *database) GetSessionByID(sessionID string) (*models.Session, error) {
// DeleteSession removes a session from the database // DeleteSession removes a session from the database
func (db *database) DeleteSession(sessionID string) error { func (db *database) DeleteSession(sessionID string) error {
query := NewQuery(db.dbType). query := db.NewQuery().
Delete(). Delete().
From("sessions"). From("sessions").
Where("id = "). Where("id = ").
@@ -95,7 +95,7 @@ func (db *database) DeleteSession(sessionID string) error {
// CleanExpiredSessions removes all expired sessions from the database // CleanExpiredSessions removes all expired sessions from the database
func (db *database) CleanExpiredSessions() error { func (db *database) CleanExpiredSessions() error {
log := getLogger().WithGroup("sessions") log := getLogger().WithGroup("sessions")
query := NewQuery(db.dbType). query := db.NewQuery().
Delete(). Delete().
From("sessions"). From("sessions").
Where("expires_at <="). Where("expires_at <=").

View File

@@ -2,7 +2,6 @@ package db
import ( import (
"fmt" "fmt"
"lemma/internal/secrets"
"reflect" "reflect"
"strings" "strings"
"unicode" "unicode"
@@ -98,7 +97,7 @@ func toSnakeCase(s string) string {
return res return res
} }
func (q *Query) InsertStruct(s any, table string, secretsService secrets.Service) (*Query, error) { func (q *Query) InsertStruct(s any, table string) (*Query, error) {
fields, err := StructTagsToFields(s) fields, err := StructTagsToFields(s)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -115,7 +114,7 @@ func (q *Query) InsertStruct(s any, table string, secretsService secrets.Service
} }
if f.encrypted { if f.encrypted {
encValue, err := secretsService.Encrypt(value.(string)) encValue, err := q.secretsService.Encrypt(value.(string))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -133,3 +132,43 @@ func (q *Query) InsertStruct(s any, table string, secretsService secrets.Service
q.Insert(table, columns...).Values(len(columns)).AddArgs(values...) q.Insert(table, columns...).Values(len(columns)).AddArgs(values...)
return q, nil return q, nil
} }
func (q *Query) UpdateStruct(s any, table string, where []string, args []any) (*Query, error) {
fields, err := StructTagsToFields(s)
if err != nil {
return nil, err
}
if len(where) != len(args) {
return nil, fmt.Errorf("number of where clauses does not match number of arguments")
}
q = q.Update("users")
for _, f := range fields {
value := f.Value
if f.useDefault {
continue
}
if f.encrypted {
encValue, err := q.secretsService.Encrypt(value.(string))
if err != nil {
return nil, err
}
value = encValue
}
q = q.Set(f.Name).Placeholder(value)
}
for i, w := range where {
if i != 0 && i < len(args) {
q = q.And(w)
}
q = q.Where(w).Placeholder(args[i])
}
return q, nil
}

View File

@@ -49,7 +49,7 @@ func (db *database) EnsureJWTSecret() (string, error) {
// GetSystemSetting retrieves a system setting by key // GetSystemSetting retrieves a system setting by key
func (db *database) GetSystemSetting(key string) (string, error) { func (db *database) GetSystemSetting(key string) (string, error) {
var value string var value string
query := NewQuery(db.dbType). query := db.NewQuery().
Select("value"). Select("value").
From("system_settings"). From("system_settings").
Where("key = "). Where("key = ").
@@ -64,7 +64,7 @@ func (db *database) GetSystemSetting(key string) (string, error) {
// SetSystemSetting stores or updates a system setting // SetSystemSetting stores or updates a system setting
func (db *database) SetSystemSetting(key, value string) error { func (db *database) SetSystemSetting(key, value string) error {
query := NewQuery(db.dbType). query := db.NewQuery().
Insert("system_settings", "key", "value"). Insert("system_settings", "key", "value").
Values(2). Values(2).
AddArgs(key, value). AddArgs(key, value).
@@ -100,7 +100,7 @@ func (db *database) GetSystemStats() (*UserStats, error) {
stats := &UserStats{} stats := &UserStats{}
// Get total users // Get total users
query := NewQuery(db.dbType). query := db.NewQuery().
Select("COUNT(*)"). Select("COUNT(*)").
From("users") From("users")
err := db.QueryRow(query.String()).Scan(&stats.TotalUsers) err := db.QueryRow(query.String()).Scan(&stats.TotalUsers)
@@ -109,7 +109,7 @@ func (db *database) GetSystemStats() (*UserStats, error) {
} }
// Get total workspaces // Get total workspaces
query = NewQuery(db.dbType). query = db.NewQuery().
Select("COUNT(*)"). Select("COUNT(*)").
From("workspaces") From("workspaces")
err = db.QueryRow(query.String()).Scan(&stats.TotalWorkspaces) err = db.QueryRow(query.String()).Scan(&stats.TotalWorkspaces)
@@ -118,7 +118,7 @@ func (db *database) GetSystemStats() (*UserStats, error) {
} }
// Get active users (users with activity in last 30 days) // Get active users (users with activity in last 30 days)
query = NewQuery(db.dbType). query = db.NewQuery().
Select("COUNT(DISTINCT user_id)"). Select("COUNT(DISTINCT user_id)").
From("sessions"). From("sessions").
Where("created_at > datetime('now', '-30 days')") Where("created_at > datetime('now', '-30 days')")

View File

@@ -17,8 +17,8 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) {
} }
defer tx.Rollback() defer tx.Rollback()
query, err := NewQuery(db.dbType). query, err := db.NewQuery().
InsertStruct(user, "users", db.secretsService) InsertStruct(user, "users")
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create query: %w", err) return nil, fmt.Errorf("failed to create query: %w", err)
@@ -46,7 +46,7 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) {
} }
// Update user's last workspace ID // Update user's last workspace ID
query = NewQuery(db.dbType). query = db.NewQuery().
Update("users"). Update("users").
Set("last_workspace_id"). Set("last_workspace_id").
Placeholder(defaultWorkspace.ID). Placeholder(defaultWorkspace.ID).
@@ -72,8 +72,8 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) {
func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error { func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
log := getLogger().WithGroup("users") log := getLogger().WithGroup("users")
insertQuery, err := NewQuery(db.dbType). insertQuery, err := db.NewQuery().
InsertStruct(workspace, "workspaces", db.secretsService) InsertStruct(workspace, "workspaces")
if err != nil { if err != nil {
return fmt.Errorf("failed to create query: %w", err) return fmt.Errorf("failed to create query: %w", err)
@@ -94,7 +94,7 @@ func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) e
// GetUserByID retrieves a user by its ID // GetUserByID retrieves a user by its ID
func (db *database) GetUserByID(id int) (*models.User, error) { func (db *database) GetUserByID(id int) (*models.User, error) {
query := NewQuery(db.dbType). query := db.NewQuery().
Select("id", "email", "display_name", "password_hash", "role", "created_at", "last_workspace_id"). Select("id", "email", "display_name", "password_hash", "role", "created_at", "last_workspace_id").
From("users"). From("users").
Where("id = ").Placeholder(id) Where("id = ").Placeholder(id)
@@ -115,7 +115,7 @@ func (db *database) GetUserByID(id int) (*models.User, error) {
// GetUserByEmail retrieves a user by its email // GetUserByEmail retrieves a user by its email
func (db *database) GetUserByEmail(email string) (*models.User, error) { func (db *database) GetUserByEmail(email string) (*models.User, error) {
query := NewQuery(db.dbType). query := db.NewQuery().
Select("id", "email", "display_name", "password_hash", "role", "created_at", "last_workspace_id"). Select("id", "email", "display_name", "password_hash", "role", "created_at", "last_workspace_id").
From("users"). From("users").
Where("email = ").Placeholder(email) Where("email = ").Placeholder(email)
@@ -137,7 +137,7 @@ func (db *database) GetUserByEmail(email string) (*models.User, error) {
// UpdateUser updates an existing user record in the database // UpdateUser updates an existing user record in the database
func (db *database) UpdateUser(user *models.User) error { func (db *database) UpdateUser(user *models.User) error {
query := NewQuery(db.dbType). query := db.NewQuery().
Update("users"). Update("users").
Set("email").Placeholder(user.Email). Set("email").Placeholder(user.Email).
Set("display_name").Placeholder(user.DisplayName). Set("display_name").Placeholder(user.DisplayName).
@@ -165,7 +165,7 @@ func (db *database) UpdateUser(user *models.User) error {
// GetAllUsers retrieves all users from the database // GetAllUsers retrieves all users from the database
func (db *database) GetAllUsers() ([]*models.User, error) { func (db *database) GetAllUsers() ([]*models.User, error) {
query := NewQuery(db.dbType). query := db.NewQuery().
Select("id", "email", "display_name", "role", "created_at", "last_workspace_id"). Select("id", "email", "display_name", "role", "created_at", "last_workspace_id").
From("users"). From("users").
OrderBy("id ASC") OrderBy("id ASC")
@@ -200,7 +200,7 @@ func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error
defer tx.Rollback() defer tx.Rollback()
// Find workspace ID from name // Find workspace ID from name
workspaceQuery := NewQuery(db.dbType). workspaceQuery := db.NewQuery().
Select("id"). Select("id").
From("workspaces"). From("workspaces").
Where("user_id = ").Placeholder(userID). Where("user_id = ").Placeholder(userID).
@@ -213,7 +213,7 @@ func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error
} }
// Update user's last workspace // Update user's last workspace
updateQuery := NewQuery(db.dbType). updateQuery := db.NewQuery().
Update("users"). Update("users").
Set("last_workspace_id").Placeholder(workspaceID). Set("last_workspace_id").Placeholder(workspaceID).
Where("id = ").Placeholder(userID) Where("id = ").Placeholder(userID)
@@ -245,7 +245,7 @@ func (db *database) DeleteUser(id int) error {
// Delete all user's workspaces first // Delete all user's workspaces first
log.Debug("deleting user workspaces", "user_id", id) log.Debug("deleting user workspaces", "user_id", id)
deleteWorkspacesQuery := NewQuery(db.dbType). deleteWorkspacesQuery := db.NewQuery().
Delete(). Delete().
From("workspaces"). From("workspaces").
Where("user_id = ").Placeholder(id) Where("user_id = ").Placeholder(id)
@@ -256,7 +256,7 @@ func (db *database) DeleteUser(id int) error {
} }
// Delete the user // Delete the user
deleteUserQuery := NewQuery(db.dbType). deleteUserQuery := db.NewQuery().
Delete(). Delete().
From("users"). From("users").
Where("id = ").Placeholder(id) Where("id = ").Placeholder(id)
@@ -277,7 +277,7 @@ func (db *database) DeleteUser(id int) error {
// GetLastWorkspaceName retrieves the name of the last workspace accessed by a user // GetLastWorkspaceName retrieves the name of the last workspace accessed by a user
func (db *database) GetLastWorkspaceName(userID int) (string, error) { func (db *database) GetLastWorkspaceName(userID int) (string, error) {
query := NewQuery(db.dbType). query := db.NewQuery().
Select("w.name"). Select("w.name").
From("workspaces w"). From("workspaces w").
Join(InnerJoin, "users u", "u.last_workspace_id = w.id"). Join(InnerJoin, "users u", "u.last_workspace_id = w.id").
@@ -298,7 +298,7 @@ func (db *database) GetLastWorkspaceName(userID int) (string, error) {
// CountAdminUsers returns the number of admin users in the system // CountAdminUsers returns the number of admin users in the system
func (db *database) CountAdminUsers() (int, error) { func (db *database) CountAdminUsers() (int, error) {
query := NewQuery(db.dbType). query := db.NewQuery().
Select("COUNT(*)"). Select("COUNT(*)").
From("users"). From("users").
Where("role = ").Placeholder(models.RoleAdmin) Where("role = ").Placeholder(models.RoleAdmin)

View File

@@ -19,7 +19,7 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error {
workspace.SetDefaultSettings() workspace.SetDefaultSettings()
} }
query, err := NewQuery(db.dbType). query, err := db.NewQuery().
InsertStruct(workspace, "workspaces", db.secretsService) InsertStruct(workspace, "workspaces", db.secretsService)
if err != nil { if err != nil {
@@ -39,7 +39,7 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error {
// GetWorkspaceByID retrieves a workspace by its ID // GetWorkspaceByID retrieves a workspace by its ID
func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) { func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) {
query := NewQuery(db.dbType). query := db.NewQuery().
Select( Select(
"id", "user_id", "name", "created_at", "id", "user_id", "name", "created_at",
"theme", "auto_save", "show_hidden_files", "theme", "auto_save", "show_hidden_files",
@@ -86,7 +86,7 @@ func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) {
// GetWorkspaceByName retrieves a workspace by its name and user ID // GetWorkspaceByName retrieves a workspace by its name and user ID
func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) { func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) {
query := NewQuery(db.dbType). query := db.NewQuery().
Select( Select(
"id", "user_id", "name", "created_at", "id", "user_id", "name", "created_at",
"theme", "auto_save", "show_hidden_files", "theme", "auto_save", "show_hidden_files",
@@ -133,28 +133,14 @@ func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*model
// UpdateWorkspace updates a workspace record in the database // UpdateWorkspace updates a workspace record in the database
func (db *database) UpdateWorkspace(workspace *models.Workspace) error { func (db *database) UpdateWorkspace(workspace *models.Workspace) error {
// Encrypt token before storing
encryptedToken, err := db.encryptToken(workspace.GitToken)
if err != nil {
return fmt.Errorf("failed to encrypt token: %w", err)
}
query := NewQuery(db.dbType). query := db.NewQuery()
Update("workspaces"). query, err := query.
Set("name").Placeholder(workspace.Name). UpdateStruct(workspace, "workspaces", []string{"id =", "user_id ="}, []interface{}{workspace.ID, workspace.UserID})
Set("theme").Placeholder(workspace.Theme).
Set("auto_save").Placeholder(workspace.AutoSave). if err != nil {
Set("show_hidden_files").Placeholder(workspace.ShowHiddenFiles). return fmt.Errorf("failed to create query: %w", err)
Set("git_enabled").Placeholder(workspace.GitEnabled). }
Set("git_url").Placeholder(workspace.GitURL).
Set("git_user").Placeholder(workspace.GitUser).
Set("git_token").Placeholder(encryptedToken).
Set("git_auto_commit").Placeholder(workspace.GitAutoCommit).
Set("git_commit_msg_template").Placeholder(workspace.GitCommitMsgTemplate).
Set("git_commit_name").Placeholder(workspace.GitCommitName).
Set("git_commit_email").Placeholder(workspace.GitCommitEmail).
Where("id = ").Placeholder(workspace.ID).
And("user_id = ").Placeholder(workspace.UserID)
_, err = db.Exec(query.String(), query.Args()...) _, err = db.Exec(query.String(), query.Args()...)
if err != nil { if err != nil {
@@ -166,7 +152,7 @@ func (db *database) UpdateWorkspace(workspace *models.Workspace) error {
// GetWorkspacesByUserID retrieves all workspaces for a user // GetWorkspacesByUserID retrieves all workspaces for a user
func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) { func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
query := NewQuery(db.dbType). query := db.NewQuery().
Select( Select(
"id", "user_id", "name", "created_at", "id", "user_id", "name", "created_at",
"theme", "auto_save", "show_hidden_files", "theme", "auto_save", "show_hidden_files",
@@ -222,26 +208,17 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro
// UpdateWorkspaceSettings updates only the settings portion of a workspace // UpdateWorkspaceSettings updates only the settings portion of a workspace
func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error { func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
// Encrypt token before storing
encryptedToken, err := db.encryptToken(workspace.GitToken)
if err != nil {
return fmt.Errorf("failed to encrypt token: %w", err)
}
query := NewQuery(db.dbType). where := []string{"id ="}
Update("workspaces"). args := []interface{}{workspace.ID}
Set("theme").Placeholder(workspace.Theme).
Set("auto_save").Placeholder(workspace.AutoSave). query := db.NewQuery()
Set("show_hidden_files").Placeholder(workspace.ShowHiddenFiles). query, err := query.
Set("git_enabled").Placeholder(workspace.GitEnabled). UpdateStruct(workspace, "workspaces", where, args)
Set("git_url").Placeholder(workspace.GitURL).
Set("git_user").Placeholder(workspace.GitUser). if err != nil {
Set("git_token").Placeholder(encryptedToken). return fmt.Errorf("failed to create query: %w", err)
Set("git_auto_commit").Placeholder(workspace.GitAutoCommit). }
Set("git_commit_msg_template").Placeholder(workspace.GitCommitMsgTemplate).
Set("git_commit_name").Placeholder(workspace.GitCommitName).
Set("git_commit_email").Placeholder(workspace.GitCommitEmail).
Where("id = ").Placeholder(workspace.ID)
_, err = db.Exec(query.String(), query.Args()...) _, err = db.Exec(query.String(), query.Args()...)
if err != nil { if err != nil {
@@ -255,7 +232,7 @@ func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
func (db *database) DeleteWorkspace(id int) error { func (db *database) DeleteWorkspace(id int) error {
log := getLogger().WithGroup("workspaces") log := getLogger().WithGroup("workspaces")
query := NewQuery(db.dbType). query := db.NewQuery().
Delete(). Delete().
From("workspaces"). From("workspaces").
Where("id = ").Placeholder(id) Where("id = ").Placeholder(id)
@@ -273,7 +250,7 @@ func (db *database) DeleteWorkspace(id int) error {
func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error { func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
log := getLogger().WithGroup("workspaces") log := getLogger().WithGroup("workspaces")
query := NewQuery(db.dbType). query := db.NewQuery().
Delete(). Delete().
From("workspaces"). From("workspaces").
Where("id = ").Placeholder(id) Where("id = ").Placeholder(id)
@@ -294,7 +271,7 @@ func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
// UpdateLastWorkspaceTx sets the last workspace for a user in a transaction // UpdateLastWorkspaceTx sets the last workspace for a user in a transaction
func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error { func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error {
query := NewQuery(db.dbType). query := db.NewQuery().
Update("users"). Update("users").
Set("last_workspace_id").Placeholder(workspaceID). Set("last_workspace_id").Placeholder(workspaceID).
Where("id = ").Placeholder(userID) Where("id = ").Placeholder(userID)
@@ -314,7 +291,7 @@ func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) e
// UpdateLastOpenedFile updates the last opened file path for a workspace // UpdateLastOpenedFile updates the last opened file path for a workspace
func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error { func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error {
query := NewQuery(db.dbType). query := db.NewQuery().
Update("workspaces"). Update("workspaces").
Set("last_opened_file_path").Placeholder(filePath). Set("last_opened_file_path").Placeholder(filePath).
Where("id = ").Placeholder(workspaceID) Where("id = ").Placeholder(workspaceID)
@@ -329,7 +306,7 @@ func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error
// GetLastOpenedFile retrieves the last opened file path for a workspace // GetLastOpenedFile retrieves the last opened file path for a workspace
func (db *database) GetLastOpenedFile(workspaceID int) (string, error) { func (db *database) GetLastOpenedFile(workspaceID int) (string, error) {
query := NewQuery(db.dbType). query := db.NewQuery().
Select("last_opened_file_path"). Select("last_opened_file_path").
From("workspaces"). From("workspaces").
Where("id = ").Placeholder(workspaceID) Where("id = ").Placeholder(workspaceID)
@@ -353,7 +330,7 @@ func (db *database) GetLastOpenedFile(workspaceID int) (string, error) {
// GetAllWorkspaces retrieves all workspaces in the database // GetAllWorkspaces retrieves all workspaces in the database
func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) { func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) {
query := NewQuery(db.dbType). query := db.NewQuery().
Select( Select(
"id", "user_id", "name", "created_at", "id", "user_id", "name", "created_at",
"theme", "auto_save", "show_hidden_files", "theme", "auto_save", "show_hidden_files",