diff --git a/server/internal/db/db.go b/server/internal/db/db.go index ff80725..e171158 100644 --- a/server/internal/db/db.go +++ b/server/internal/db/db.go @@ -195,29 +195,6 @@ func (db *database) Close() error { return nil } -// Helper methods for token encryption/decryption -func (db *database) encryptToken(token string) (string, error) { - if token == "" { - return "", nil - } - - encrypted, err := db.secretsService.Encrypt(token) - if err != nil { - return "", fmt.Errorf("failed to encrypt token: %w", err) - } - - return encrypted, nil -} - -func (db *database) decryptToken(token string) (string, error) { - if token == "" { - return "", nil - } - - decrypted, err := db.secretsService.Decrypt(token) - if err != nil { - return "", fmt.Errorf("failed to decrypt token: %w", err) - } - - return decrypted, nil +func (db *database) NewQuery() *Query { + return NewQuery(db.dbType, db.secretsService) } diff --git a/server/internal/db/query.go b/server/internal/db/query.go index a552a8e..192e848 100644 --- a/server/internal/db/query.go +++ b/server/internal/db/query.go @@ -2,6 +2,7 @@ package db import ( "fmt" + "lemma/internal/secrets" "strings" ) @@ -15,27 +16,29 @@ const ( // Query represents a SQL query with its parameters type Query struct { - builder strings.Builder - args []any - dbType DBType - pos int // tracks the current placeholder position - hasSelect bool - hasFrom bool - hasWhere bool - hasOrderBy bool - hasGroupBy bool - hasHaving bool - hasLimit bool - hasOffset bool - isInParens bool - parensDepth int + builder strings.Builder + args []any + dbType DBType + secretsService secrets.Service + pos int // tracks the current placeholder position + hasSelect bool + hasFrom bool + hasWhere bool + hasOrderBy bool + hasGroupBy bool + hasHaving bool + hasLimit bool + hasOffset bool + isInParens bool + parensDepth int } // NewQuery creates a new Query instance -func NewQuery(dbType DBType) *Query { +func NewQuery(dbType DBType, secretsService secrets.Service) *Query { return &Query{ - dbType: dbType, - args: make([]any, 0), + dbType: dbType, + secretsService: secretsService, + args: make([]any, 0), } } diff --git a/server/internal/db/query_test.go b/server/internal/db/query_test.go index 4910328..907b633 100644 --- a/server/internal/db/query_test.go +++ b/server/internal/db/query_test.go @@ -24,7 +24,7 @@ func TestNewQuery(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - q := db.NewQuery(tt.dbType) + q := db.NewQuery(tt.dbType, &mockSecrets{}) // Test that a new query is empty if q.String() != "" { @@ -120,7 +120,7 @@ func TestBasicQueryBuilding(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - q := db.NewQuery(tt.dbType) + q := db.NewQuery(tt.dbType, &mockSecrets{}) q = tt.buildFn(q) gotSQL := q.String() @@ -215,7 +215,7 @@ func TestPlaceholders(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - q := db.NewQuery(tt.dbType) + q := db.NewQuery(tt.dbType, &mockSecrets{}) q = tt.buildFn(q) gotSQL := q.String() @@ -328,7 +328,7 @@ func TestWhereClauseBuilding(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - q := db.NewQuery(tt.dbType) + q := db.NewQuery(tt.dbType, &mockSecrets{}) q = tt.buildFn(q) gotSQL := q.String() @@ -403,7 +403,7 @@ func TestJoinClauseBuilding(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - q := db.NewQuery(tt.dbType) + q := db.NewQuery(tt.dbType, &mockSecrets{}) q = tt.buildFn(q) gotSQL := q.String() @@ -482,7 +482,7 @@ func TestOrderLimitOffset(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - q := db.NewQuery(tt.dbType) + q := db.NewQuery(tt.dbType, &mockSecrets{}) q = tt.buildFn(q) gotSQL := q.String() @@ -575,7 +575,7 @@ func TestInsertUpdateDelete(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - q := db.NewQuery(tt.dbType) + q := db.NewQuery(tt.dbType, &mockSecrets{}) q = tt.buildFn(q) gotSQL := q.String() @@ -641,7 +641,7 @@ func TestHavingClause(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - q := db.NewQuery(tt.dbType) + q := db.NewQuery(tt.dbType, &mockSecrets{}) q = tt.buildFn(q) gotSQL := q.String() @@ -790,7 +790,7 @@ func TestQueryReturning(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - query := db.NewQuery(tc.dbType) + query := db.NewQuery(tc.dbType, &mockSecrets{}) result := tc.buildQuery(query) if result.String() != tc.expectedSQL { @@ -838,7 +838,7 @@ func TestComplexQueries(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - q := db.NewQuery(tt.dbType) + q := db.NewQuery(tt.dbType, &mockSecrets{}) q = tt.buildFn(q) gotSQL := q.String() diff --git a/server/internal/db/sessions.go b/server/internal/db/sessions.go index bdf55db..fdfcc7d 100644 --- a/server/internal/db/sessions.go +++ b/server/internal/db/sessions.go @@ -10,8 +10,8 @@ import ( // CreateSession inserts a new session record into the database func (db *database) CreateSession(session *models.Session) error { - query, err := NewQuery(db.dbType). - InsertStruct(session, "sessions", db.secretsService) + query, err := db.NewQuery(). + InsertStruct(session, "sessions") if err != nil { return fmt.Errorf("failed to create query: %w", err) } @@ -26,7 +26,7 @@ func (db *database) CreateSession(session *models.Session) error { // GetSessionByRefreshToken retrieves a session by its refresh token func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) { session := &models.Session{} - query := NewQuery(db.dbType). + query := db.NewQuery(). Select("id", "user_id", "refresh_token", "expires_at", "created_at"). From("sessions"). Where("refresh_token = "). @@ -48,7 +48,7 @@ func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Sessi // GetSessionByID retrieves a session by its ID func (db *database) GetSessionByID(sessionID string) (*models.Session, error) { session := &models.Session{} - query := NewQuery(db.dbType). + query := db.NewQuery(). Select("id", "user_id", "refresh_token", "expires_at", "created_at"). From("sessions"). Where("id = "). @@ -69,7 +69,7 @@ func (db *database) GetSessionByID(sessionID string) (*models.Session, error) { // DeleteSession removes a session from the database func (db *database) DeleteSession(sessionID string) error { - query := NewQuery(db.dbType). + query := db.NewQuery(). Delete(). From("sessions"). Where("id = "). @@ -95,7 +95,7 @@ func (db *database) DeleteSession(sessionID string) error { // CleanExpiredSessions removes all expired sessions from the database func (db *database) CleanExpiredSessions() error { log := getLogger().WithGroup("sessions") - query := NewQuery(db.dbType). + query := db.NewQuery(). Delete(). From("sessions"). Where("expires_at <="). diff --git a/server/internal/db/struct_query.go b/server/internal/db/struct_query.go index d6c8cde..efa6321 100644 --- a/server/internal/db/struct_query.go +++ b/server/internal/db/struct_query.go @@ -2,7 +2,6 @@ package db import ( "fmt" - "lemma/internal/secrets" "reflect" "strings" "unicode" @@ -98,7 +97,7 @@ func toSnakeCase(s string) string { return res } -func (q *Query) InsertStruct(s any, table string, secretsService secrets.Service) (*Query, error) { +func (q *Query) InsertStruct(s any, table string) (*Query, error) { fields, err := StructTagsToFields(s) if err != nil { return nil, err @@ -115,7 +114,7 @@ func (q *Query) InsertStruct(s any, table string, secretsService secrets.Service } if f.encrypted { - encValue, err := secretsService.Encrypt(value.(string)) + encValue, err := q.secretsService.Encrypt(value.(string)) if err != nil { return nil, err } @@ -133,3 +132,43 @@ func (q *Query) InsertStruct(s any, table string, secretsService secrets.Service q.Insert(table, columns...).Values(len(columns)).AddArgs(values...) return q, nil } + +func (q *Query) UpdateStruct(s any, table string, where []string, args []any) (*Query, error) { + fields, err := StructTagsToFields(s) + if err != nil { + return nil, err + } + + if len(where) != len(args) { + return nil, fmt.Errorf("number of where clauses does not match number of arguments") + } + + q = q.Update("users") + + for _, f := range fields { + value := f.Value + + if f.useDefault { + continue + } + + if f.encrypted { + encValue, err := q.secretsService.Encrypt(value.(string)) + if err != nil { + return nil, err + } + value = encValue + } + + q = q.Set(f.Name).Placeholder(value) + } + + for i, w := range where { + if i != 0 && i < len(args) { + q = q.And(w) + } + q = q.Where(w).Placeholder(args[i]) + } + + return q, nil +} diff --git a/server/internal/db/system.go b/server/internal/db/system.go index d81aef3..a193ee7 100644 --- a/server/internal/db/system.go +++ b/server/internal/db/system.go @@ -49,7 +49,7 @@ func (db *database) EnsureJWTSecret() (string, error) { // GetSystemSetting retrieves a system setting by key func (db *database) GetSystemSetting(key string) (string, error) { var value string - query := NewQuery(db.dbType). + query := db.NewQuery(). Select("value"). From("system_settings"). Where("key = "). @@ -64,7 +64,7 @@ func (db *database) GetSystemSetting(key string) (string, error) { // SetSystemSetting stores or updates a system setting func (db *database) SetSystemSetting(key, value string) error { - query := NewQuery(db.dbType). + query := db.NewQuery(). Insert("system_settings", "key", "value"). Values(2). AddArgs(key, value). @@ -100,7 +100,7 @@ func (db *database) GetSystemStats() (*UserStats, error) { stats := &UserStats{} // Get total users - query := NewQuery(db.dbType). + query := db.NewQuery(). Select("COUNT(*)"). From("users") err := db.QueryRow(query.String()).Scan(&stats.TotalUsers) @@ -109,7 +109,7 @@ func (db *database) GetSystemStats() (*UserStats, error) { } // Get total workspaces - query = NewQuery(db.dbType). + query = db.NewQuery(). Select("COUNT(*)"). From("workspaces") err = db.QueryRow(query.String()).Scan(&stats.TotalWorkspaces) @@ -118,7 +118,7 @@ func (db *database) GetSystemStats() (*UserStats, error) { } // Get active users (users with activity in last 30 days) - query = NewQuery(db.dbType). + query = db.NewQuery(). Select("COUNT(DISTINCT user_id)"). From("sessions"). Where("created_at > datetime('now', '-30 days')") diff --git a/server/internal/db/users.go b/server/internal/db/users.go index ed792d6..5d0cc55 100644 --- a/server/internal/db/users.go +++ b/server/internal/db/users.go @@ -17,8 +17,8 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) { } defer tx.Rollback() - query, err := NewQuery(db.dbType). - InsertStruct(user, "users", db.secretsService) + query, err := db.NewQuery(). + InsertStruct(user, "users") if err != nil { return nil, fmt.Errorf("failed to create query: %w", err) @@ -46,7 +46,7 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) { } // Update user's last workspace ID - query = NewQuery(db.dbType). + query = db.NewQuery(). Update("users"). Set("last_workspace_id"). Placeholder(defaultWorkspace.ID). @@ -72,8 +72,8 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) { func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error { log := getLogger().WithGroup("users") - insertQuery, err := NewQuery(db.dbType). - InsertStruct(workspace, "workspaces", db.secretsService) + insertQuery, err := db.NewQuery(). + InsertStruct(workspace, "workspaces") if err != nil { return fmt.Errorf("failed to create query: %w", err) @@ -94,7 +94,7 @@ func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) e // GetUserByID retrieves a user by its ID func (db *database) GetUserByID(id int) (*models.User, error) { - query := NewQuery(db.dbType). + query := db.NewQuery(). Select("id", "email", "display_name", "password_hash", "role", "created_at", "last_workspace_id"). From("users"). Where("id = ").Placeholder(id) @@ -115,7 +115,7 @@ func (db *database) GetUserByID(id int) (*models.User, error) { // GetUserByEmail retrieves a user by its email func (db *database) GetUserByEmail(email string) (*models.User, error) { - query := NewQuery(db.dbType). + query := db.NewQuery(). Select("id", "email", "display_name", "password_hash", "role", "created_at", "last_workspace_id"). From("users"). Where("email = ").Placeholder(email) @@ -137,7 +137,7 @@ func (db *database) GetUserByEmail(email string) (*models.User, error) { // UpdateUser updates an existing user record in the database func (db *database) UpdateUser(user *models.User) error { - query := NewQuery(db.dbType). + query := db.NewQuery(). Update("users"). Set("email").Placeholder(user.Email). Set("display_name").Placeholder(user.DisplayName). @@ -165,7 +165,7 @@ func (db *database) UpdateUser(user *models.User) error { // GetAllUsers retrieves all users from the database func (db *database) GetAllUsers() ([]*models.User, error) { - query := NewQuery(db.dbType). + query := db.NewQuery(). Select("id", "email", "display_name", "role", "created_at", "last_workspace_id"). From("users"). OrderBy("id ASC") @@ -200,7 +200,7 @@ func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error defer tx.Rollback() // Find workspace ID from name - workspaceQuery := NewQuery(db.dbType). + workspaceQuery := db.NewQuery(). Select("id"). From("workspaces"). Where("user_id = ").Placeholder(userID). @@ -213,7 +213,7 @@ func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error } // Update user's last workspace - updateQuery := NewQuery(db.dbType). + updateQuery := db.NewQuery(). Update("users"). Set("last_workspace_id").Placeholder(workspaceID). Where("id = ").Placeholder(userID) @@ -245,7 +245,7 @@ func (db *database) DeleteUser(id int) error { // Delete all user's workspaces first log.Debug("deleting user workspaces", "user_id", id) - deleteWorkspacesQuery := NewQuery(db.dbType). + deleteWorkspacesQuery := db.NewQuery(). Delete(). From("workspaces"). Where("user_id = ").Placeholder(id) @@ -256,7 +256,7 @@ func (db *database) DeleteUser(id int) error { } // Delete the user - deleteUserQuery := NewQuery(db.dbType). + deleteUserQuery := db.NewQuery(). Delete(). From("users"). Where("id = ").Placeholder(id) @@ -277,7 +277,7 @@ func (db *database) DeleteUser(id int) error { // GetLastWorkspaceName retrieves the name of the last workspace accessed by a user func (db *database) GetLastWorkspaceName(userID int) (string, error) { - query := NewQuery(db.dbType). + query := db.NewQuery(). Select("w.name"). From("workspaces w"). Join(InnerJoin, "users u", "u.last_workspace_id = w.id"). @@ -298,7 +298,7 @@ func (db *database) GetLastWorkspaceName(userID int) (string, error) { // CountAdminUsers returns the number of admin users in the system func (db *database) CountAdminUsers() (int, error) { - query := NewQuery(db.dbType). + query := db.NewQuery(). Select("COUNT(*)"). From("users"). Where("role = ").Placeholder(models.RoleAdmin) diff --git a/server/internal/db/workspaces.go b/server/internal/db/workspaces.go index 120033f..64ebd7e 100644 --- a/server/internal/db/workspaces.go +++ b/server/internal/db/workspaces.go @@ -19,7 +19,7 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error { workspace.SetDefaultSettings() } - query, err := NewQuery(db.dbType). + query, err := db.NewQuery(). InsertStruct(workspace, "workspaces", db.secretsService) if err != nil { @@ -39,7 +39,7 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error { // GetWorkspaceByID retrieves a workspace by its ID func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) { - query := NewQuery(db.dbType). + query := db.NewQuery(). Select( "id", "user_id", "name", "created_at", "theme", "auto_save", "show_hidden_files", @@ -86,7 +86,7 @@ func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) { // GetWorkspaceByName retrieves a workspace by its name and user ID func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) { - query := NewQuery(db.dbType). + query := db.NewQuery(). Select( "id", "user_id", "name", "created_at", "theme", "auto_save", "show_hidden_files", @@ -133,28 +133,14 @@ func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*model // UpdateWorkspace updates a workspace record in the database func (db *database) UpdateWorkspace(workspace *models.Workspace) error { - // Encrypt token before storing - encryptedToken, err := db.encryptToken(workspace.GitToken) - if err != nil { - return fmt.Errorf("failed to encrypt token: %w", err) - } - query := NewQuery(db.dbType). - Update("workspaces"). - Set("name").Placeholder(workspace.Name). - Set("theme").Placeholder(workspace.Theme). - Set("auto_save").Placeholder(workspace.AutoSave). - Set("show_hidden_files").Placeholder(workspace.ShowHiddenFiles). - Set("git_enabled").Placeholder(workspace.GitEnabled). - Set("git_url").Placeholder(workspace.GitURL). - Set("git_user").Placeholder(workspace.GitUser). - Set("git_token").Placeholder(encryptedToken). - Set("git_auto_commit").Placeholder(workspace.GitAutoCommit). - Set("git_commit_msg_template").Placeholder(workspace.GitCommitMsgTemplate). - Set("git_commit_name").Placeholder(workspace.GitCommitName). - Set("git_commit_email").Placeholder(workspace.GitCommitEmail). - Where("id = ").Placeholder(workspace.ID). - And("user_id = ").Placeholder(workspace.UserID) + query := db.NewQuery() + query, err := query. + UpdateStruct(workspace, "workspaces", []string{"id =", "user_id ="}, []interface{}{workspace.ID, workspace.UserID}) + + if err != nil { + return fmt.Errorf("failed to create query: %w", err) + } _, err = db.Exec(query.String(), query.Args()...) if err != nil { @@ -166,7 +152,7 @@ func (db *database) UpdateWorkspace(workspace *models.Workspace) error { // GetWorkspacesByUserID retrieves all workspaces for a user func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) { - query := NewQuery(db.dbType). + query := db.NewQuery(). Select( "id", "user_id", "name", "created_at", "theme", "auto_save", "show_hidden_files", @@ -222,26 +208,17 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro // UpdateWorkspaceSettings updates only the settings portion of a workspace func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error { - // Encrypt token before storing - encryptedToken, err := db.encryptToken(workspace.GitToken) - if err != nil { - return fmt.Errorf("failed to encrypt token: %w", err) - } - query := NewQuery(db.dbType). - Update("workspaces"). - Set("theme").Placeholder(workspace.Theme). - Set("auto_save").Placeholder(workspace.AutoSave). - Set("show_hidden_files").Placeholder(workspace.ShowHiddenFiles). - Set("git_enabled").Placeholder(workspace.GitEnabled). - Set("git_url").Placeholder(workspace.GitURL). - Set("git_user").Placeholder(workspace.GitUser). - Set("git_token").Placeholder(encryptedToken). - Set("git_auto_commit").Placeholder(workspace.GitAutoCommit). - Set("git_commit_msg_template").Placeholder(workspace.GitCommitMsgTemplate). - Set("git_commit_name").Placeholder(workspace.GitCommitName). - Set("git_commit_email").Placeholder(workspace.GitCommitEmail). - Where("id = ").Placeholder(workspace.ID) + where := []string{"id ="} + args := []interface{}{workspace.ID} + + query := db.NewQuery() + query, err := query. + UpdateStruct(workspace, "workspaces", where, args) + + if err != nil { + return fmt.Errorf("failed to create query: %w", err) + } _, err = db.Exec(query.String(), query.Args()...) if err != nil { @@ -255,7 +232,7 @@ func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error { func (db *database) DeleteWorkspace(id int) error { log := getLogger().WithGroup("workspaces") - query := NewQuery(db.dbType). + query := db.NewQuery(). Delete(). From("workspaces"). Where("id = ").Placeholder(id) @@ -273,7 +250,7 @@ func (db *database) DeleteWorkspace(id int) error { func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error { log := getLogger().WithGroup("workspaces") - query := NewQuery(db.dbType). + query := db.NewQuery(). Delete(). From("workspaces"). Where("id = ").Placeholder(id) @@ -294,7 +271,7 @@ func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error { // UpdateLastWorkspaceTx sets the last workspace for a user in a transaction func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error { - query := NewQuery(db.dbType). + query := db.NewQuery(). Update("users"). Set("last_workspace_id").Placeholder(workspaceID). Where("id = ").Placeholder(userID) @@ -314,7 +291,7 @@ func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) e // UpdateLastOpenedFile updates the last opened file path for a workspace func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error { - query := NewQuery(db.dbType). + query := db.NewQuery(). Update("workspaces"). Set("last_opened_file_path").Placeholder(filePath). Where("id = ").Placeholder(workspaceID) @@ -329,7 +306,7 @@ func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error // GetLastOpenedFile retrieves the last opened file path for a workspace func (db *database) GetLastOpenedFile(workspaceID int) (string, error) { - query := NewQuery(db.dbType). + query := db.NewQuery(). Select("last_opened_file_path"). From("workspaces"). Where("id = ").Placeholder(workspaceID) @@ -353,7 +330,7 @@ func (db *database) GetLastOpenedFile(workspaceID int) (string, error) { // GetAllWorkspaces retrieves all workspaces in the database func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) { - query := NewQuery(db.dbType). + query := db.NewQuery(). Select( "id", "user_id", "name", "created_at", "theme", "auto_save", "show_hidden_files",