mirror of
https://github.com/lordmathis/lemma.git
synced 2025-11-05 15:44:21 +00:00
Implement update struct
This commit is contained in:
@@ -195,29 +195,6 @@ func (db *database) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper methods for token encryption/decryption
|
func (db *database) NewQuery() *Query {
|
||||||
func (db *database) encryptToken(token string) (string, error) {
|
return NewQuery(db.dbType, db.secretsService)
|
||||||
if token == "" {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
encrypted, err := db.secretsService.Encrypt(token)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to encrypt token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return encrypted, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *database) decryptToken(token string) (string, error) {
|
|
||||||
if token == "" {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
decrypted, err := db.secretsService.Decrypt(token)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to decrypt token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return decrypted, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package db
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"lemma/internal/secrets"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -15,27 +16,29 @@ const (
|
|||||||
|
|
||||||
// Query represents a SQL query with its parameters
|
// Query represents a SQL query with its parameters
|
||||||
type Query struct {
|
type Query struct {
|
||||||
builder strings.Builder
|
builder strings.Builder
|
||||||
args []any
|
args []any
|
||||||
dbType DBType
|
dbType DBType
|
||||||
pos int // tracks the current placeholder position
|
secretsService secrets.Service
|
||||||
hasSelect bool
|
pos int // tracks the current placeholder position
|
||||||
hasFrom bool
|
hasSelect bool
|
||||||
hasWhere bool
|
hasFrom bool
|
||||||
hasOrderBy bool
|
hasWhere bool
|
||||||
hasGroupBy bool
|
hasOrderBy bool
|
||||||
hasHaving bool
|
hasGroupBy bool
|
||||||
hasLimit bool
|
hasHaving bool
|
||||||
hasOffset bool
|
hasLimit bool
|
||||||
isInParens bool
|
hasOffset bool
|
||||||
parensDepth int
|
isInParens bool
|
||||||
|
parensDepth int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewQuery creates a new Query instance
|
// NewQuery creates a new Query instance
|
||||||
func NewQuery(dbType DBType) *Query {
|
func NewQuery(dbType DBType, secretsService secrets.Service) *Query {
|
||||||
return &Query{
|
return &Query{
|
||||||
dbType: dbType,
|
dbType: dbType,
|
||||||
args: make([]any, 0),
|
secretsService: secretsService,
|
||||||
|
args: make([]any, 0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ func TestNewQuery(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
q := db.NewQuery(tt.dbType)
|
q := db.NewQuery(tt.dbType, &mockSecrets{})
|
||||||
|
|
||||||
// Test that a new query is empty
|
// Test that a new query is empty
|
||||||
if q.String() != "" {
|
if q.String() != "" {
|
||||||
@@ -120,7 +120,7 @@ func TestBasicQueryBuilding(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
q := db.NewQuery(tt.dbType)
|
q := db.NewQuery(tt.dbType, &mockSecrets{})
|
||||||
q = tt.buildFn(q)
|
q = tt.buildFn(q)
|
||||||
|
|
||||||
gotSQL := q.String()
|
gotSQL := q.String()
|
||||||
@@ -215,7 +215,7 @@ func TestPlaceholders(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
q := db.NewQuery(tt.dbType)
|
q := db.NewQuery(tt.dbType, &mockSecrets{})
|
||||||
q = tt.buildFn(q)
|
q = tt.buildFn(q)
|
||||||
|
|
||||||
gotSQL := q.String()
|
gotSQL := q.String()
|
||||||
@@ -328,7 +328,7 @@ func TestWhereClauseBuilding(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
q := db.NewQuery(tt.dbType)
|
q := db.NewQuery(tt.dbType, &mockSecrets{})
|
||||||
q = tt.buildFn(q)
|
q = tt.buildFn(q)
|
||||||
|
|
||||||
gotSQL := q.String()
|
gotSQL := q.String()
|
||||||
@@ -403,7 +403,7 @@ func TestJoinClauseBuilding(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
q := db.NewQuery(tt.dbType)
|
q := db.NewQuery(tt.dbType, &mockSecrets{})
|
||||||
q = tt.buildFn(q)
|
q = tt.buildFn(q)
|
||||||
|
|
||||||
gotSQL := q.String()
|
gotSQL := q.String()
|
||||||
@@ -482,7 +482,7 @@ func TestOrderLimitOffset(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
q := db.NewQuery(tt.dbType)
|
q := db.NewQuery(tt.dbType, &mockSecrets{})
|
||||||
q = tt.buildFn(q)
|
q = tt.buildFn(q)
|
||||||
|
|
||||||
gotSQL := q.String()
|
gotSQL := q.String()
|
||||||
@@ -575,7 +575,7 @@ func TestInsertUpdateDelete(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
q := db.NewQuery(tt.dbType)
|
q := db.NewQuery(tt.dbType, &mockSecrets{})
|
||||||
q = tt.buildFn(q)
|
q = tt.buildFn(q)
|
||||||
|
|
||||||
gotSQL := q.String()
|
gotSQL := q.String()
|
||||||
@@ -641,7 +641,7 @@ func TestHavingClause(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
q := db.NewQuery(tt.dbType)
|
q := db.NewQuery(tt.dbType, &mockSecrets{})
|
||||||
q = tt.buildFn(q)
|
q = tt.buildFn(q)
|
||||||
|
|
||||||
gotSQL := q.String()
|
gotSQL := q.String()
|
||||||
@@ -790,7 +790,7 @@ func TestQueryReturning(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
query := db.NewQuery(tc.dbType)
|
query := db.NewQuery(tc.dbType, &mockSecrets{})
|
||||||
result := tc.buildQuery(query)
|
result := tc.buildQuery(query)
|
||||||
|
|
||||||
if result.String() != tc.expectedSQL {
|
if result.String() != tc.expectedSQL {
|
||||||
@@ -838,7 +838,7 @@ func TestComplexQueries(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
q := db.NewQuery(tt.dbType)
|
q := db.NewQuery(tt.dbType, &mockSecrets{})
|
||||||
q = tt.buildFn(q)
|
q = tt.buildFn(q)
|
||||||
|
|
||||||
gotSQL := q.String()
|
gotSQL := q.String()
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ import (
|
|||||||
|
|
||||||
// CreateSession inserts a new session record into the database
|
// CreateSession inserts a new session record into the database
|
||||||
func (db *database) CreateSession(session *models.Session) error {
|
func (db *database) CreateSession(session *models.Session) error {
|
||||||
query, err := NewQuery(db.dbType).
|
query, err := db.NewQuery().
|
||||||
InsertStruct(session, "sessions", db.secretsService)
|
InsertStruct(session, "sessions")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create query: %w", err)
|
return fmt.Errorf("failed to create query: %w", err)
|
||||||
}
|
}
|
||||||
@@ -26,7 +26,7 @@ func (db *database) CreateSession(session *models.Session) error {
|
|||||||
// GetSessionByRefreshToken retrieves a session by its refresh token
|
// GetSessionByRefreshToken retrieves a session by its refresh token
|
||||||
func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) {
|
func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) {
|
||||||
session := &models.Session{}
|
session := &models.Session{}
|
||||||
query := NewQuery(db.dbType).
|
query := db.NewQuery().
|
||||||
Select("id", "user_id", "refresh_token", "expires_at", "created_at").
|
Select("id", "user_id", "refresh_token", "expires_at", "created_at").
|
||||||
From("sessions").
|
From("sessions").
|
||||||
Where("refresh_token = ").
|
Where("refresh_token = ").
|
||||||
@@ -48,7 +48,7 @@ func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Sessi
|
|||||||
// GetSessionByID retrieves a session by its ID
|
// GetSessionByID retrieves a session by its ID
|
||||||
func (db *database) GetSessionByID(sessionID string) (*models.Session, error) {
|
func (db *database) GetSessionByID(sessionID string) (*models.Session, error) {
|
||||||
session := &models.Session{}
|
session := &models.Session{}
|
||||||
query := NewQuery(db.dbType).
|
query := db.NewQuery().
|
||||||
Select("id", "user_id", "refresh_token", "expires_at", "created_at").
|
Select("id", "user_id", "refresh_token", "expires_at", "created_at").
|
||||||
From("sessions").
|
From("sessions").
|
||||||
Where("id = ").
|
Where("id = ").
|
||||||
@@ -69,7 +69,7 @@ func (db *database) GetSessionByID(sessionID string) (*models.Session, error) {
|
|||||||
|
|
||||||
// DeleteSession removes a session from the database
|
// DeleteSession removes a session from the database
|
||||||
func (db *database) DeleteSession(sessionID string) error {
|
func (db *database) DeleteSession(sessionID string) error {
|
||||||
query := NewQuery(db.dbType).
|
query := db.NewQuery().
|
||||||
Delete().
|
Delete().
|
||||||
From("sessions").
|
From("sessions").
|
||||||
Where("id = ").
|
Where("id = ").
|
||||||
@@ -95,7 +95,7 @@ func (db *database) DeleteSession(sessionID string) error {
|
|||||||
// CleanExpiredSessions removes all expired sessions from the database
|
// CleanExpiredSessions removes all expired sessions from the database
|
||||||
func (db *database) CleanExpiredSessions() error {
|
func (db *database) CleanExpiredSessions() error {
|
||||||
log := getLogger().WithGroup("sessions")
|
log := getLogger().WithGroup("sessions")
|
||||||
query := NewQuery(db.dbType).
|
query := db.NewQuery().
|
||||||
Delete().
|
Delete().
|
||||||
From("sessions").
|
From("sessions").
|
||||||
Where("expires_at <=").
|
Where("expires_at <=").
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package db
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"lemma/internal/secrets"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"unicode"
|
"unicode"
|
||||||
@@ -98,7 +97,7 @@ func toSnakeCase(s string) string {
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Query) InsertStruct(s any, table string, secretsService secrets.Service) (*Query, error) {
|
func (q *Query) InsertStruct(s any, table string) (*Query, error) {
|
||||||
fields, err := StructTagsToFields(s)
|
fields, err := StructTagsToFields(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -115,7 +114,7 @@ func (q *Query) InsertStruct(s any, table string, secretsService secrets.Service
|
|||||||
}
|
}
|
||||||
|
|
||||||
if f.encrypted {
|
if f.encrypted {
|
||||||
encValue, err := secretsService.Encrypt(value.(string))
|
encValue, err := q.secretsService.Encrypt(value.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -133,3 +132,43 @@ func (q *Query) InsertStruct(s any, table string, secretsService secrets.Service
|
|||||||
q.Insert(table, columns...).Values(len(columns)).AddArgs(values...)
|
q.Insert(table, columns...).Values(len(columns)).AddArgs(values...)
|
||||||
return q, nil
|
return q, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (q *Query) UpdateStruct(s any, table string, where []string, args []any) (*Query, error) {
|
||||||
|
fields, err := StructTagsToFields(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(where) != len(args) {
|
||||||
|
return nil, fmt.Errorf("number of where clauses does not match number of arguments")
|
||||||
|
}
|
||||||
|
|
||||||
|
q = q.Update("users")
|
||||||
|
|
||||||
|
for _, f := range fields {
|
||||||
|
value := f.Value
|
||||||
|
|
||||||
|
if f.useDefault {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.encrypted {
|
||||||
|
encValue, err := q.secretsService.Encrypt(value.(string))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
value = encValue
|
||||||
|
}
|
||||||
|
|
||||||
|
q = q.Set(f.Name).Placeholder(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, w := range where {
|
||||||
|
if i != 0 && i < len(args) {
|
||||||
|
q = q.And(w)
|
||||||
|
}
|
||||||
|
q = q.Where(w).Placeholder(args[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
return q, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ func (db *database) EnsureJWTSecret() (string, error) {
|
|||||||
// GetSystemSetting retrieves a system setting by key
|
// GetSystemSetting retrieves a system setting by key
|
||||||
func (db *database) GetSystemSetting(key string) (string, error) {
|
func (db *database) GetSystemSetting(key string) (string, error) {
|
||||||
var value string
|
var value string
|
||||||
query := NewQuery(db.dbType).
|
query := db.NewQuery().
|
||||||
Select("value").
|
Select("value").
|
||||||
From("system_settings").
|
From("system_settings").
|
||||||
Where("key = ").
|
Where("key = ").
|
||||||
@@ -64,7 +64,7 @@ func (db *database) GetSystemSetting(key string) (string, error) {
|
|||||||
|
|
||||||
// SetSystemSetting stores or updates a system setting
|
// SetSystemSetting stores or updates a system setting
|
||||||
func (db *database) SetSystemSetting(key, value string) error {
|
func (db *database) SetSystemSetting(key, value string) error {
|
||||||
query := NewQuery(db.dbType).
|
query := db.NewQuery().
|
||||||
Insert("system_settings", "key", "value").
|
Insert("system_settings", "key", "value").
|
||||||
Values(2).
|
Values(2).
|
||||||
AddArgs(key, value).
|
AddArgs(key, value).
|
||||||
@@ -100,7 +100,7 @@ func (db *database) GetSystemStats() (*UserStats, error) {
|
|||||||
stats := &UserStats{}
|
stats := &UserStats{}
|
||||||
|
|
||||||
// Get total users
|
// Get total users
|
||||||
query := NewQuery(db.dbType).
|
query := db.NewQuery().
|
||||||
Select("COUNT(*)").
|
Select("COUNT(*)").
|
||||||
From("users")
|
From("users")
|
||||||
err := db.QueryRow(query.String()).Scan(&stats.TotalUsers)
|
err := db.QueryRow(query.String()).Scan(&stats.TotalUsers)
|
||||||
@@ -109,7 +109,7 @@ func (db *database) GetSystemStats() (*UserStats, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get total workspaces
|
// Get total workspaces
|
||||||
query = NewQuery(db.dbType).
|
query = db.NewQuery().
|
||||||
Select("COUNT(*)").
|
Select("COUNT(*)").
|
||||||
From("workspaces")
|
From("workspaces")
|
||||||
err = db.QueryRow(query.String()).Scan(&stats.TotalWorkspaces)
|
err = db.QueryRow(query.String()).Scan(&stats.TotalWorkspaces)
|
||||||
@@ -118,7 +118,7 @@ func (db *database) GetSystemStats() (*UserStats, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get active users (users with activity in last 30 days)
|
// Get active users (users with activity in last 30 days)
|
||||||
query = NewQuery(db.dbType).
|
query = db.NewQuery().
|
||||||
Select("COUNT(DISTINCT user_id)").
|
Select("COUNT(DISTINCT user_id)").
|
||||||
From("sessions").
|
From("sessions").
|
||||||
Where("created_at > datetime('now', '-30 days')")
|
Where("created_at > datetime('now', '-30 days')")
|
||||||
|
|||||||
@@ -17,8 +17,8 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) {
|
|||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
query, err := NewQuery(db.dbType).
|
query, err := db.NewQuery().
|
||||||
InsertStruct(user, "users", db.secretsService)
|
InsertStruct(user, "users")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create query: %w", err)
|
return nil, fmt.Errorf("failed to create query: %w", err)
|
||||||
@@ -46,7 +46,7 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update user's last workspace ID
|
// Update user's last workspace ID
|
||||||
query = NewQuery(db.dbType).
|
query = db.NewQuery().
|
||||||
Update("users").
|
Update("users").
|
||||||
Set("last_workspace_id").
|
Set("last_workspace_id").
|
||||||
Placeholder(defaultWorkspace.ID).
|
Placeholder(defaultWorkspace.ID).
|
||||||
@@ -72,8 +72,8 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) {
|
|||||||
func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
|
func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
|
||||||
log := getLogger().WithGroup("users")
|
log := getLogger().WithGroup("users")
|
||||||
|
|
||||||
insertQuery, err := NewQuery(db.dbType).
|
insertQuery, err := db.NewQuery().
|
||||||
InsertStruct(workspace, "workspaces", db.secretsService)
|
InsertStruct(workspace, "workspaces")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create query: %w", err)
|
return fmt.Errorf("failed to create query: %w", err)
|
||||||
@@ -94,7 +94,7 @@ func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) e
|
|||||||
|
|
||||||
// GetUserByID retrieves a user by its ID
|
// GetUserByID retrieves a user by its ID
|
||||||
func (db *database) GetUserByID(id int) (*models.User, error) {
|
func (db *database) GetUserByID(id int) (*models.User, error) {
|
||||||
query := NewQuery(db.dbType).
|
query := db.NewQuery().
|
||||||
Select("id", "email", "display_name", "password_hash", "role", "created_at", "last_workspace_id").
|
Select("id", "email", "display_name", "password_hash", "role", "created_at", "last_workspace_id").
|
||||||
From("users").
|
From("users").
|
||||||
Where("id = ").Placeholder(id)
|
Where("id = ").Placeholder(id)
|
||||||
@@ -115,7 +115,7 @@ func (db *database) GetUserByID(id int) (*models.User, error) {
|
|||||||
|
|
||||||
// GetUserByEmail retrieves a user by its email
|
// GetUserByEmail retrieves a user by its email
|
||||||
func (db *database) GetUserByEmail(email string) (*models.User, error) {
|
func (db *database) GetUserByEmail(email string) (*models.User, error) {
|
||||||
query := NewQuery(db.dbType).
|
query := db.NewQuery().
|
||||||
Select("id", "email", "display_name", "password_hash", "role", "created_at", "last_workspace_id").
|
Select("id", "email", "display_name", "password_hash", "role", "created_at", "last_workspace_id").
|
||||||
From("users").
|
From("users").
|
||||||
Where("email = ").Placeholder(email)
|
Where("email = ").Placeholder(email)
|
||||||
@@ -137,7 +137,7 @@ func (db *database) GetUserByEmail(email string) (*models.User, error) {
|
|||||||
|
|
||||||
// UpdateUser updates an existing user record in the database
|
// UpdateUser updates an existing user record in the database
|
||||||
func (db *database) UpdateUser(user *models.User) error {
|
func (db *database) UpdateUser(user *models.User) error {
|
||||||
query := NewQuery(db.dbType).
|
query := db.NewQuery().
|
||||||
Update("users").
|
Update("users").
|
||||||
Set("email").Placeholder(user.Email).
|
Set("email").Placeholder(user.Email).
|
||||||
Set("display_name").Placeholder(user.DisplayName).
|
Set("display_name").Placeholder(user.DisplayName).
|
||||||
@@ -165,7 +165,7 @@ func (db *database) UpdateUser(user *models.User) error {
|
|||||||
|
|
||||||
// GetAllUsers retrieves all users from the database
|
// GetAllUsers retrieves all users from the database
|
||||||
func (db *database) GetAllUsers() ([]*models.User, error) {
|
func (db *database) GetAllUsers() ([]*models.User, error) {
|
||||||
query := NewQuery(db.dbType).
|
query := db.NewQuery().
|
||||||
Select("id", "email", "display_name", "role", "created_at", "last_workspace_id").
|
Select("id", "email", "display_name", "role", "created_at", "last_workspace_id").
|
||||||
From("users").
|
From("users").
|
||||||
OrderBy("id ASC")
|
OrderBy("id ASC")
|
||||||
@@ -200,7 +200,7 @@ func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error
|
|||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
// Find workspace ID from name
|
// Find workspace ID from name
|
||||||
workspaceQuery := NewQuery(db.dbType).
|
workspaceQuery := db.NewQuery().
|
||||||
Select("id").
|
Select("id").
|
||||||
From("workspaces").
|
From("workspaces").
|
||||||
Where("user_id = ").Placeholder(userID).
|
Where("user_id = ").Placeholder(userID).
|
||||||
@@ -213,7 +213,7 @@ func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update user's last workspace
|
// Update user's last workspace
|
||||||
updateQuery := NewQuery(db.dbType).
|
updateQuery := db.NewQuery().
|
||||||
Update("users").
|
Update("users").
|
||||||
Set("last_workspace_id").Placeholder(workspaceID).
|
Set("last_workspace_id").Placeholder(workspaceID).
|
||||||
Where("id = ").Placeholder(userID)
|
Where("id = ").Placeholder(userID)
|
||||||
@@ -245,7 +245,7 @@ func (db *database) DeleteUser(id int) error {
|
|||||||
// Delete all user's workspaces first
|
// Delete all user's workspaces first
|
||||||
log.Debug("deleting user workspaces", "user_id", id)
|
log.Debug("deleting user workspaces", "user_id", id)
|
||||||
|
|
||||||
deleteWorkspacesQuery := NewQuery(db.dbType).
|
deleteWorkspacesQuery := db.NewQuery().
|
||||||
Delete().
|
Delete().
|
||||||
From("workspaces").
|
From("workspaces").
|
||||||
Where("user_id = ").Placeholder(id)
|
Where("user_id = ").Placeholder(id)
|
||||||
@@ -256,7 +256,7 @@ func (db *database) DeleteUser(id int) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Delete the user
|
// Delete the user
|
||||||
deleteUserQuery := NewQuery(db.dbType).
|
deleteUserQuery := db.NewQuery().
|
||||||
Delete().
|
Delete().
|
||||||
From("users").
|
From("users").
|
||||||
Where("id = ").Placeholder(id)
|
Where("id = ").Placeholder(id)
|
||||||
@@ -277,7 +277,7 @@ func (db *database) DeleteUser(id int) error {
|
|||||||
|
|
||||||
// GetLastWorkspaceName retrieves the name of the last workspace accessed by a user
|
// GetLastWorkspaceName retrieves the name of the last workspace accessed by a user
|
||||||
func (db *database) GetLastWorkspaceName(userID int) (string, error) {
|
func (db *database) GetLastWorkspaceName(userID int) (string, error) {
|
||||||
query := NewQuery(db.dbType).
|
query := db.NewQuery().
|
||||||
Select("w.name").
|
Select("w.name").
|
||||||
From("workspaces w").
|
From("workspaces w").
|
||||||
Join(InnerJoin, "users u", "u.last_workspace_id = w.id").
|
Join(InnerJoin, "users u", "u.last_workspace_id = w.id").
|
||||||
@@ -298,7 +298,7 @@ func (db *database) GetLastWorkspaceName(userID int) (string, error) {
|
|||||||
|
|
||||||
// CountAdminUsers returns the number of admin users in the system
|
// CountAdminUsers returns the number of admin users in the system
|
||||||
func (db *database) CountAdminUsers() (int, error) {
|
func (db *database) CountAdminUsers() (int, error) {
|
||||||
query := NewQuery(db.dbType).
|
query := db.NewQuery().
|
||||||
Select("COUNT(*)").
|
Select("COUNT(*)").
|
||||||
From("users").
|
From("users").
|
||||||
Where("role = ").Placeholder(models.RoleAdmin)
|
Where("role = ").Placeholder(models.RoleAdmin)
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error {
|
|||||||
workspace.SetDefaultSettings()
|
workspace.SetDefaultSettings()
|
||||||
}
|
}
|
||||||
|
|
||||||
query, err := NewQuery(db.dbType).
|
query, err := db.NewQuery().
|
||||||
InsertStruct(workspace, "workspaces", db.secretsService)
|
InsertStruct(workspace, "workspaces", db.secretsService)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -39,7 +39,7 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error {
|
|||||||
|
|
||||||
// GetWorkspaceByID retrieves a workspace by its ID
|
// GetWorkspaceByID retrieves a workspace by its ID
|
||||||
func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) {
|
func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) {
|
||||||
query := NewQuery(db.dbType).
|
query := db.NewQuery().
|
||||||
Select(
|
Select(
|
||||||
"id", "user_id", "name", "created_at",
|
"id", "user_id", "name", "created_at",
|
||||||
"theme", "auto_save", "show_hidden_files",
|
"theme", "auto_save", "show_hidden_files",
|
||||||
@@ -86,7 +86,7 @@ func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) {
|
|||||||
|
|
||||||
// GetWorkspaceByName retrieves a workspace by its name and user ID
|
// GetWorkspaceByName retrieves a workspace by its name and user ID
|
||||||
func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) {
|
func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) {
|
||||||
query := NewQuery(db.dbType).
|
query := db.NewQuery().
|
||||||
Select(
|
Select(
|
||||||
"id", "user_id", "name", "created_at",
|
"id", "user_id", "name", "created_at",
|
||||||
"theme", "auto_save", "show_hidden_files",
|
"theme", "auto_save", "show_hidden_files",
|
||||||
@@ -133,28 +133,14 @@ func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*model
|
|||||||
|
|
||||||
// UpdateWorkspace updates a workspace record in the database
|
// UpdateWorkspace updates a workspace record in the database
|
||||||
func (db *database) UpdateWorkspace(workspace *models.Workspace) error {
|
func (db *database) UpdateWorkspace(workspace *models.Workspace) error {
|
||||||
// Encrypt token before storing
|
|
||||||
encryptedToken, err := db.encryptToken(workspace.GitToken)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to encrypt token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
query := NewQuery(db.dbType).
|
query := db.NewQuery()
|
||||||
Update("workspaces").
|
query, err := query.
|
||||||
Set("name").Placeholder(workspace.Name).
|
UpdateStruct(workspace, "workspaces", []string{"id =", "user_id ="}, []interface{}{workspace.ID, workspace.UserID})
|
||||||
Set("theme").Placeholder(workspace.Theme).
|
|
||||||
Set("auto_save").Placeholder(workspace.AutoSave).
|
if err != nil {
|
||||||
Set("show_hidden_files").Placeholder(workspace.ShowHiddenFiles).
|
return fmt.Errorf("failed to create query: %w", err)
|
||||||
Set("git_enabled").Placeholder(workspace.GitEnabled).
|
}
|
||||||
Set("git_url").Placeholder(workspace.GitURL).
|
|
||||||
Set("git_user").Placeholder(workspace.GitUser).
|
|
||||||
Set("git_token").Placeholder(encryptedToken).
|
|
||||||
Set("git_auto_commit").Placeholder(workspace.GitAutoCommit).
|
|
||||||
Set("git_commit_msg_template").Placeholder(workspace.GitCommitMsgTemplate).
|
|
||||||
Set("git_commit_name").Placeholder(workspace.GitCommitName).
|
|
||||||
Set("git_commit_email").Placeholder(workspace.GitCommitEmail).
|
|
||||||
Where("id = ").Placeholder(workspace.ID).
|
|
||||||
And("user_id = ").Placeholder(workspace.UserID)
|
|
||||||
|
|
||||||
_, err = db.Exec(query.String(), query.Args()...)
|
_, err = db.Exec(query.String(), query.Args()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -166,7 +152,7 @@ func (db *database) UpdateWorkspace(workspace *models.Workspace) error {
|
|||||||
|
|
||||||
// GetWorkspacesByUserID retrieves all workspaces for a user
|
// GetWorkspacesByUserID retrieves all workspaces for a user
|
||||||
func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
|
func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
|
||||||
query := NewQuery(db.dbType).
|
query := db.NewQuery().
|
||||||
Select(
|
Select(
|
||||||
"id", "user_id", "name", "created_at",
|
"id", "user_id", "name", "created_at",
|
||||||
"theme", "auto_save", "show_hidden_files",
|
"theme", "auto_save", "show_hidden_files",
|
||||||
@@ -222,26 +208,17 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro
|
|||||||
|
|
||||||
// UpdateWorkspaceSettings updates only the settings portion of a workspace
|
// UpdateWorkspaceSettings updates only the settings portion of a workspace
|
||||||
func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
|
func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
|
||||||
// Encrypt token before storing
|
|
||||||
encryptedToken, err := db.encryptToken(workspace.GitToken)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to encrypt token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
query := NewQuery(db.dbType).
|
where := []string{"id ="}
|
||||||
Update("workspaces").
|
args := []interface{}{workspace.ID}
|
||||||
Set("theme").Placeholder(workspace.Theme).
|
|
||||||
Set("auto_save").Placeholder(workspace.AutoSave).
|
query := db.NewQuery()
|
||||||
Set("show_hidden_files").Placeholder(workspace.ShowHiddenFiles).
|
query, err := query.
|
||||||
Set("git_enabled").Placeholder(workspace.GitEnabled).
|
UpdateStruct(workspace, "workspaces", where, args)
|
||||||
Set("git_url").Placeholder(workspace.GitURL).
|
|
||||||
Set("git_user").Placeholder(workspace.GitUser).
|
if err != nil {
|
||||||
Set("git_token").Placeholder(encryptedToken).
|
return fmt.Errorf("failed to create query: %w", err)
|
||||||
Set("git_auto_commit").Placeholder(workspace.GitAutoCommit).
|
}
|
||||||
Set("git_commit_msg_template").Placeholder(workspace.GitCommitMsgTemplate).
|
|
||||||
Set("git_commit_name").Placeholder(workspace.GitCommitName).
|
|
||||||
Set("git_commit_email").Placeholder(workspace.GitCommitEmail).
|
|
||||||
Where("id = ").Placeholder(workspace.ID)
|
|
||||||
|
|
||||||
_, err = db.Exec(query.String(), query.Args()...)
|
_, err = db.Exec(query.String(), query.Args()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -255,7 +232,7 @@ func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
|
|||||||
func (db *database) DeleteWorkspace(id int) error {
|
func (db *database) DeleteWorkspace(id int) error {
|
||||||
log := getLogger().WithGroup("workspaces")
|
log := getLogger().WithGroup("workspaces")
|
||||||
|
|
||||||
query := NewQuery(db.dbType).
|
query := db.NewQuery().
|
||||||
Delete().
|
Delete().
|
||||||
From("workspaces").
|
From("workspaces").
|
||||||
Where("id = ").Placeholder(id)
|
Where("id = ").Placeholder(id)
|
||||||
@@ -273,7 +250,7 @@ func (db *database) DeleteWorkspace(id int) error {
|
|||||||
func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
|
func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
|
||||||
log := getLogger().WithGroup("workspaces")
|
log := getLogger().WithGroup("workspaces")
|
||||||
|
|
||||||
query := NewQuery(db.dbType).
|
query := db.NewQuery().
|
||||||
Delete().
|
Delete().
|
||||||
From("workspaces").
|
From("workspaces").
|
||||||
Where("id = ").Placeholder(id)
|
Where("id = ").Placeholder(id)
|
||||||
@@ -294,7 +271,7 @@ func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
|
|||||||
|
|
||||||
// UpdateLastWorkspaceTx sets the last workspace for a user in a transaction
|
// UpdateLastWorkspaceTx sets the last workspace for a user in a transaction
|
||||||
func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error {
|
func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error {
|
||||||
query := NewQuery(db.dbType).
|
query := db.NewQuery().
|
||||||
Update("users").
|
Update("users").
|
||||||
Set("last_workspace_id").Placeholder(workspaceID).
|
Set("last_workspace_id").Placeholder(workspaceID).
|
||||||
Where("id = ").Placeholder(userID)
|
Where("id = ").Placeholder(userID)
|
||||||
@@ -314,7 +291,7 @@ func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) e
|
|||||||
|
|
||||||
// UpdateLastOpenedFile updates the last opened file path for a workspace
|
// UpdateLastOpenedFile updates the last opened file path for a workspace
|
||||||
func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error {
|
func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error {
|
||||||
query := NewQuery(db.dbType).
|
query := db.NewQuery().
|
||||||
Update("workspaces").
|
Update("workspaces").
|
||||||
Set("last_opened_file_path").Placeholder(filePath).
|
Set("last_opened_file_path").Placeholder(filePath).
|
||||||
Where("id = ").Placeholder(workspaceID)
|
Where("id = ").Placeholder(workspaceID)
|
||||||
@@ -329,7 +306,7 @@ func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error
|
|||||||
|
|
||||||
// GetLastOpenedFile retrieves the last opened file path for a workspace
|
// GetLastOpenedFile retrieves the last opened file path for a workspace
|
||||||
func (db *database) GetLastOpenedFile(workspaceID int) (string, error) {
|
func (db *database) GetLastOpenedFile(workspaceID int) (string, error) {
|
||||||
query := NewQuery(db.dbType).
|
query := db.NewQuery().
|
||||||
Select("last_opened_file_path").
|
Select("last_opened_file_path").
|
||||||
From("workspaces").
|
From("workspaces").
|
||||||
Where("id = ").Placeholder(workspaceID)
|
Where("id = ").Placeholder(workspaceID)
|
||||||
@@ -353,7 +330,7 @@ func (db *database) GetLastOpenedFile(workspaceID int) (string, error) {
|
|||||||
|
|
||||||
// GetAllWorkspaces retrieves all workspaces in the database
|
// GetAllWorkspaces retrieves all workspaces in the database
|
||||||
func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) {
|
func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) {
|
||||||
query := NewQuery(db.dbType).
|
query := db.NewQuery().
|
||||||
Select(
|
Select(
|
||||||
"id", "user_id", "name", "created_at",
|
"id", "user_id", "name", "created_at",
|
||||||
"theme", "auto_save", "show_hidden_files",
|
"theme", "auto_save", "show_hidden_files",
|
||||||
|
|||||||
Reference in New Issue
Block a user