Implement update struct

This commit is contained in:
2025-03-02 18:40:12 +01:00
parent 204dacd15e
commit ccac439465
8 changed files with 127 additions and 131 deletions

View File

@@ -195,29 +195,6 @@ func (db *database) Close() error {
return nil
}
// Helper methods for token encryption/decryption
func (db *database) encryptToken(token string) (string, error) {
if token == "" {
return "", nil
}
encrypted, err := db.secretsService.Encrypt(token)
if err != nil {
return "", fmt.Errorf("failed to encrypt token: %w", err)
}
return encrypted, nil
}
func (db *database) decryptToken(token string) (string, error) {
if token == "" {
return "", nil
}
decrypted, err := db.secretsService.Decrypt(token)
if err != nil {
return "", fmt.Errorf("failed to decrypt token: %w", err)
}
return decrypted, nil
func (db *database) NewQuery() *Query {
return NewQuery(db.dbType, db.secretsService)
}

View File

@@ -2,6 +2,7 @@ package db
import (
"fmt"
"lemma/internal/secrets"
"strings"
)
@@ -15,27 +16,29 @@ const (
// Query represents a SQL query with its parameters
type Query struct {
builder strings.Builder
args []any
dbType DBType
pos int // tracks the current placeholder position
hasSelect bool
hasFrom bool
hasWhere bool
hasOrderBy bool
hasGroupBy bool
hasHaving bool
hasLimit bool
hasOffset bool
isInParens bool
parensDepth int
builder strings.Builder
args []any
dbType DBType
secretsService secrets.Service
pos int // tracks the current placeholder position
hasSelect bool
hasFrom bool
hasWhere bool
hasOrderBy bool
hasGroupBy bool
hasHaving bool
hasLimit bool
hasOffset bool
isInParens bool
parensDepth int
}
// NewQuery creates a new Query instance
func NewQuery(dbType DBType) *Query {
func NewQuery(dbType DBType, secretsService secrets.Service) *Query {
return &Query{
dbType: dbType,
args: make([]any, 0),
dbType: dbType,
secretsService: secretsService,
args: make([]any, 0),
}
}

View File

@@ -24,7 +24,7 @@ func TestNewQuery(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType)
q := db.NewQuery(tt.dbType, &mockSecrets{})
// Test that a new query is empty
if q.String() != "" {
@@ -120,7 +120,7 @@ func TestBasicQueryBuilding(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType)
q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q)
gotSQL := q.String()
@@ -215,7 +215,7 @@ func TestPlaceholders(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType)
q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q)
gotSQL := q.String()
@@ -328,7 +328,7 @@ func TestWhereClauseBuilding(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType)
q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q)
gotSQL := q.String()
@@ -403,7 +403,7 @@ func TestJoinClauseBuilding(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType)
q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q)
gotSQL := q.String()
@@ -482,7 +482,7 @@ func TestOrderLimitOffset(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType)
q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q)
gotSQL := q.String()
@@ -575,7 +575,7 @@ func TestInsertUpdateDelete(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType)
q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q)
gotSQL := q.String()
@@ -641,7 +641,7 @@ func TestHavingClause(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType)
q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q)
gotSQL := q.String()
@@ -790,7 +790,7 @@ func TestQueryReturning(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
query := db.NewQuery(tc.dbType)
query := db.NewQuery(tc.dbType, &mockSecrets{})
result := tc.buildQuery(query)
if result.String() != tc.expectedSQL {
@@ -838,7 +838,7 @@ func TestComplexQueries(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType)
q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q)
gotSQL := q.String()

View File

@@ -10,8 +10,8 @@ import (
// CreateSession inserts a new session record into the database
func (db *database) CreateSession(session *models.Session) error {
query, err := NewQuery(db.dbType).
InsertStruct(session, "sessions", db.secretsService)
query, err := db.NewQuery().
InsertStruct(session, "sessions")
if err != nil {
return fmt.Errorf("failed to create query: %w", err)
}
@@ -26,7 +26,7 @@ func (db *database) CreateSession(session *models.Session) error {
// GetSessionByRefreshToken retrieves a session by its refresh token
func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) {
session := &models.Session{}
query := NewQuery(db.dbType).
query := db.NewQuery().
Select("id", "user_id", "refresh_token", "expires_at", "created_at").
From("sessions").
Where("refresh_token = ").
@@ -48,7 +48,7 @@ func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Sessi
// GetSessionByID retrieves a session by its ID
func (db *database) GetSessionByID(sessionID string) (*models.Session, error) {
session := &models.Session{}
query := NewQuery(db.dbType).
query := db.NewQuery().
Select("id", "user_id", "refresh_token", "expires_at", "created_at").
From("sessions").
Where("id = ").
@@ -69,7 +69,7 @@ func (db *database) GetSessionByID(sessionID string) (*models.Session, error) {
// DeleteSession removes a session from the database
func (db *database) DeleteSession(sessionID string) error {
query := NewQuery(db.dbType).
query := db.NewQuery().
Delete().
From("sessions").
Where("id = ").
@@ -95,7 +95,7 @@ func (db *database) DeleteSession(sessionID string) error {
// CleanExpiredSessions removes all expired sessions from the database
func (db *database) CleanExpiredSessions() error {
log := getLogger().WithGroup("sessions")
query := NewQuery(db.dbType).
query := db.NewQuery().
Delete().
From("sessions").
Where("expires_at <=").

View File

@@ -2,7 +2,6 @@ package db
import (
"fmt"
"lemma/internal/secrets"
"reflect"
"strings"
"unicode"
@@ -98,7 +97,7 @@ func toSnakeCase(s string) string {
return res
}
func (q *Query) InsertStruct(s any, table string, secretsService secrets.Service) (*Query, error) {
func (q *Query) InsertStruct(s any, table string) (*Query, error) {
fields, err := StructTagsToFields(s)
if err != nil {
return nil, err
@@ -115,7 +114,7 @@ func (q *Query) InsertStruct(s any, table string, secretsService secrets.Service
}
if f.encrypted {
encValue, err := secretsService.Encrypt(value.(string))
encValue, err := q.secretsService.Encrypt(value.(string))
if err != nil {
return nil, err
}
@@ -133,3 +132,43 @@ func (q *Query) InsertStruct(s any, table string, secretsService secrets.Service
q.Insert(table, columns...).Values(len(columns)).AddArgs(values...)
return q, nil
}
func (q *Query) UpdateStruct(s any, table string, where []string, args []any) (*Query, error) {
fields, err := StructTagsToFields(s)
if err != nil {
return nil, err
}
if len(where) != len(args) {
return nil, fmt.Errorf("number of where clauses does not match number of arguments")
}
q = q.Update("users")
for _, f := range fields {
value := f.Value
if f.useDefault {
continue
}
if f.encrypted {
encValue, err := q.secretsService.Encrypt(value.(string))
if err != nil {
return nil, err
}
value = encValue
}
q = q.Set(f.Name).Placeholder(value)
}
for i, w := range where {
if i != 0 && i < len(args) {
q = q.And(w)
}
q = q.Where(w).Placeholder(args[i])
}
return q, nil
}

View File

@@ -49,7 +49,7 @@ func (db *database) EnsureJWTSecret() (string, error) {
// GetSystemSetting retrieves a system setting by key
func (db *database) GetSystemSetting(key string) (string, error) {
var value string
query := NewQuery(db.dbType).
query := db.NewQuery().
Select("value").
From("system_settings").
Where("key = ").
@@ -64,7 +64,7 @@ func (db *database) GetSystemSetting(key string) (string, error) {
// SetSystemSetting stores or updates a system setting
func (db *database) SetSystemSetting(key, value string) error {
query := NewQuery(db.dbType).
query := db.NewQuery().
Insert("system_settings", "key", "value").
Values(2).
AddArgs(key, value).
@@ -100,7 +100,7 @@ func (db *database) GetSystemStats() (*UserStats, error) {
stats := &UserStats{}
// Get total users
query := NewQuery(db.dbType).
query := db.NewQuery().
Select("COUNT(*)").
From("users")
err := db.QueryRow(query.String()).Scan(&stats.TotalUsers)
@@ -109,7 +109,7 @@ func (db *database) GetSystemStats() (*UserStats, error) {
}
// Get total workspaces
query = NewQuery(db.dbType).
query = db.NewQuery().
Select("COUNT(*)").
From("workspaces")
err = db.QueryRow(query.String()).Scan(&stats.TotalWorkspaces)
@@ -118,7 +118,7 @@ func (db *database) GetSystemStats() (*UserStats, error) {
}
// Get active users (users with activity in last 30 days)
query = NewQuery(db.dbType).
query = db.NewQuery().
Select("COUNT(DISTINCT user_id)").
From("sessions").
Where("created_at > datetime('now', '-30 days')")

View File

@@ -17,8 +17,8 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) {
}
defer tx.Rollback()
query, err := NewQuery(db.dbType).
InsertStruct(user, "users", db.secretsService)
query, err := db.NewQuery().
InsertStruct(user, "users")
if err != nil {
return nil, fmt.Errorf("failed to create query: %w", err)
@@ -46,7 +46,7 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) {
}
// Update user's last workspace ID
query = NewQuery(db.dbType).
query = db.NewQuery().
Update("users").
Set("last_workspace_id").
Placeholder(defaultWorkspace.ID).
@@ -72,8 +72,8 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) {
func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
log := getLogger().WithGroup("users")
insertQuery, err := NewQuery(db.dbType).
InsertStruct(workspace, "workspaces", db.secretsService)
insertQuery, err := db.NewQuery().
InsertStruct(workspace, "workspaces")
if err != nil {
return fmt.Errorf("failed to create query: %w", err)
@@ -94,7 +94,7 @@ func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) e
// GetUserByID retrieves a user by its ID
func (db *database) GetUserByID(id int) (*models.User, error) {
query := NewQuery(db.dbType).
query := db.NewQuery().
Select("id", "email", "display_name", "password_hash", "role", "created_at", "last_workspace_id").
From("users").
Where("id = ").Placeholder(id)
@@ -115,7 +115,7 @@ func (db *database) GetUserByID(id int) (*models.User, error) {
// GetUserByEmail retrieves a user by its email
func (db *database) GetUserByEmail(email string) (*models.User, error) {
query := NewQuery(db.dbType).
query := db.NewQuery().
Select("id", "email", "display_name", "password_hash", "role", "created_at", "last_workspace_id").
From("users").
Where("email = ").Placeholder(email)
@@ -137,7 +137,7 @@ func (db *database) GetUserByEmail(email string) (*models.User, error) {
// UpdateUser updates an existing user record in the database
func (db *database) UpdateUser(user *models.User) error {
query := NewQuery(db.dbType).
query := db.NewQuery().
Update("users").
Set("email").Placeholder(user.Email).
Set("display_name").Placeholder(user.DisplayName).
@@ -165,7 +165,7 @@ func (db *database) UpdateUser(user *models.User) error {
// GetAllUsers retrieves all users from the database
func (db *database) GetAllUsers() ([]*models.User, error) {
query := NewQuery(db.dbType).
query := db.NewQuery().
Select("id", "email", "display_name", "role", "created_at", "last_workspace_id").
From("users").
OrderBy("id ASC")
@@ -200,7 +200,7 @@ func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error
defer tx.Rollback()
// Find workspace ID from name
workspaceQuery := NewQuery(db.dbType).
workspaceQuery := db.NewQuery().
Select("id").
From("workspaces").
Where("user_id = ").Placeholder(userID).
@@ -213,7 +213,7 @@ func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error
}
// Update user's last workspace
updateQuery := NewQuery(db.dbType).
updateQuery := db.NewQuery().
Update("users").
Set("last_workspace_id").Placeholder(workspaceID).
Where("id = ").Placeholder(userID)
@@ -245,7 +245,7 @@ func (db *database) DeleteUser(id int) error {
// Delete all user's workspaces first
log.Debug("deleting user workspaces", "user_id", id)
deleteWorkspacesQuery := NewQuery(db.dbType).
deleteWorkspacesQuery := db.NewQuery().
Delete().
From("workspaces").
Where("user_id = ").Placeholder(id)
@@ -256,7 +256,7 @@ func (db *database) DeleteUser(id int) error {
}
// Delete the user
deleteUserQuery := NewQuery(db.dbType).
deleteUserQuery := db.NewQuery().
Delete().
From("users").
Where("id = ").Placeholder(id)
@@ -277,7 +277,7 @@ func (db *database) DeleteUser(id int) error {
// GetLastWorkspaceName retrieves the name of the last workspace accessed by a user
func (db *database) GetLastWorkspaceName(userID int) (string, error) {
query := NewQuery(db.dbType).
query := db.NewQuery().
Select("w.name").
From("workspaces w").
Join(InnerJoin, "users u", "u.last_workspace_id = w.id").
@@ -298,7 +298,7 @@ func (db *database) GetLastWorkspaceName(userID int) (string, error) {
// CountAdminUsers returns the number of admin users in the system
func (db *database) CountAdminUsers() (int, error) {
query := NewQuery(db.dbType).
query := db.NewQuery().
Select("COUNT(*)").
From("users").
Where("role = ").Placeholder(models.RoleAdmin)

View File

@@ -19,7 +19,7 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error {
workspace.SetDefaultSettings()
}
query, err := NewQuery(db.dbType).
query, err := db.NewQuery().
InsertStruct(workspace, "workspaces", db.secretsService)
if err != nil {
@@ -39,7 +39,7 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error {
// GetWorkspaceByID retrieves a workspace by its ID
func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) {
query := NewQuery(db.dbType).
query := db.NewQuery().
Select(
"id", "user_id", "name", "created_at",
"theme", "auto_save", "show_hidden_files",
@@ -86,7 +86,7 @@ func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) {
// GetWorkspaceByName retrieves a workspace by its name and user ID
func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) {
query := NewQuery(db.dbType).
query := db.NewQuery().
Select(
"id", "user_id", "name", "created_at",
"theme", "auto_save", "show_hidden_files",
@@ -133,28 +133,14 @@ func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*model
// UpdateWorkspace updates a workspace record in the database
func (db *database) UpdateWorkspace(workspace *models.Workspace) error {
// Encrypt token before storing
encryptedToken, err := db.encryptToken(workspace.GitToken)
if err != nil {
return fmt.Errorf("failed to encrypt token: %w", err)
}
query := NewQuery(db.dbType).
Update("workspaces").
Set("name").Placeholder(workspace.Name).
Set("theme").Placeholder(workspace.Theme).
Set("auto_save").Placeholder(workspace.AutoSave).
Set("show_hidden_files").Placeholder(workspace.ShowHiddenFiles).
Set("git_enabled").Placeholder(workspace.GitEnabled).
Set("git_url").Placeholder(workspace.GitURL).
Set("git_user").Placeholder(workspace.GitUser).
Set("git_token").Placeholder(encryptedToken).
Set("git_auto_commit").Placeholder(workspace.GitAutoCommit).
Set("git_commit_msg_template").Placeholder(workspace.GitCommitMsgTemplate).
Set("git_commit_name").Placeholder(workspace.GitCommitName).
Set("git_commit_email").Placeholder(workspace.GitCommitEmail).
Where("id = ").Placeholder(workspace.ID).
And("user_id = ").Placeholder(workspace.UserID)
query := db.NewQuery()
query, err := query.
UpdateStruct(workspace, "workspaces", []string{"id =", "user_id ="}, []interface{}{workspace.ID, workspace.UserID})
if err != nil {
return fmt.Errorf("failed to create query: %w", err)
}
_, err = db.Exec(query.String(), query.Args()...)
if err != nil {
@@ -166,7 +152,7 @@ func (db *database) UpdateWorkspace(workspace *models.Workspace) error {
// GetWorkspacesByUserID retrieves all workspaces for a user
func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
query := NewQuery(db.dbType).
query := db.NewQuery().
Select(
"id", "user_id", "name", "created_at",
"theme", "auto_save", "show_hidden_files",
@@ -222,26 +208,17 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro
// UpdateWorkspaceSettings updates only the settings portion of a workspace
func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
// Encrypt token before storing
encryptedToken, err := db.encryptToken(workspace.GitToken)
if err != nil {
return fmt.Errorf("failed to encrypt token: %w", err)
}
query := NewQuery(db.dbType).
Update("workspaces").
Set("theme").Placeholder(workspace.Theme).
Set("auto_save").Placeholder(workspace.AutoSave).
Set("show_hidden_files").Placeholder(workspace.ShowHiddenFiles).
Set("git_enabled").Placeholder(workspace.GitEnabled).
Set("git_url").Placeholder(workspace.GitURL).
Set("git_user").Placeholder(workspace.GitUser).
Set("git_token").Placeholder(encryptedToken).
Set("git_auto_commit").Placeholder(workspace.GitAutoCommit).
Set("git_commit_msg_template").Placeholder(workspace.GitCommitMsgTemplate).
Set("git_commit_name").Placeholder(workspace.GitCommitName).
Set("git_commit_email").Placeholder(workspace.GitCommitEmail).
Where("id = ").Placeholder(workspace.ID)
where := []string{"id ="}
args := []interface{}{workspace.ID}
query := db.NewQuery()
query, err := query.
UpdateStruct(workspace, "workspaces", where, args)
if err != nil {
return fmt.Errorf("failed to create query: %w", err)
}
_, err = db.Exec(query.String(), query.Args()...)
if err != nil {
@@ -255,7 +232,7 @@ func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
func (db *database) DeleteWorkspace(id int) error {
log := getLogger().WithGroup("workspaces")
query := NewQuery(db.dbType).
query := db.NewQuery().
Delete().
From("workspaces").
Where("id = ").Placeholder(id)
@@ -273,7 +250,7 @@ func (db *database) DeleteWorkspace(id int) error {
func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
log := getLogger().WithGroup("workspaces")
query := NewQuery(db.dbType).
query := db.NewQuery().
Delete().
From("workspaces").
Where("id = ").Placeholder(id)
@@ -294,7 +271,7 @@ func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
// UpdateLastWorkspaceTx sets the last workspace for a user in a transaction
func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error {
query := NewQuery(db.dbType).
query := db.NewQuery().
Update("users").
Set("last_workspace_id").Placeholder(workspaceID).
Where("id = ").Placeholder(userID)
@@ -314,7 +291,7 @@ func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) e
// UpdateLastOpenedFile updates the last opened file path for a workspace
func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error {
query := NewQuery(db.dbType).
query := db.NewQuery().
Update("workspaces").
Set("last_opened_file_path").Placeholder(filePath).
Where("id = ").Placeholder(workspaceID)
@@ -329,7 +306,7 @@ func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error
// GetLastOpenedFile retrieves the last opened file path for a workspace
func (db *database) GetLastOpenedFile(workspaceID int) (string, error) {
query := NewQuery(db.dbType).
query := db.NewQuery().
Select("last_opened_file_path").
From("workspaces").
Where("id = ").Placeholder(workspaceID)
@@ -353,7 +330,7 @@ func (db *database) GetLastOpenedFile(workspaceID int) (string, error) {
// GetAllWorkspaces retrieves all workspaces in the database
func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) {
query := NewQuery(db.dbType).
query := db.NewQuery().
Select(
"id", "user_id", "name", "created_at",
"theme", "auto_save", "show_hidden_files",