diff --git a/server/internal/db/struct_query.go b/server/internal/db/struct_query.go index 3484307..ec86764 100644 --- a/server/internal/db/struct_query.go +++ b/server/internal/db/struct_query.go @@ -4,6 +4,7 @@ import ( "database/sql" "fmt" "reflect" + "sort" "strings" "unicode" ) @@ -81,6 +82,11 @@ func StructTagsToFields(s any) ([]DBField, error) { encrypted: encrypted, }) } + + sort.Slice(fields, func(i, j int) bool { + return fields[i].Name < fields[j].Name + }) + return fields, nil } @@ -176,17 +182,34 @@ func (q *Query) UpdateStruct(s any, table string, where []string, args []any) (* return q, nil } -func (db *database) ScanStruct(row *sql.Row, dest any) error { - // Get the fields of the destination struct - fields, err := StructTagsToFields(dest) +func (q *Query) SelectStruct(s any, table string) (*Query, error) { + fields, err := StructTagsToFields(s) + if err != nil { + return nil, err + } + + columns := make([]string, 0, len(fields)) + for _, f := range fields { + columns = append(columns, f.Name) + } + + q = q.Select(columns...).From(table) + return q, nil +} + +// Scanner is an interface that both sql.Row and sql.Rows satisfy +type Scanner interface { + Scan(dest ...interface{}) error +} + +// scanStructInstance is an internal function that handles the scanning logic for a single instance +func (db *database) scanStructInstance(destVal reflect.Value, scanner Scanner) error { + fields, err := StructTagsToFields(destVal.Interface()) if err != nil { return fmt.Errorf("failed to extract struct fields: %w", err) } - // Create a slice of pointers to hold the scan destinations scanDest := make([]interface{}, len(fields)) - destVal := reflect.ValueOf(dest).Elem() - var fieldsToDecrypt []string nullStringIndexes := make(map[int]reflect.Value) @@ -202,8 +225,8 @@ func (db *database) ScanStruct(row *sql.Row, dest any) error { } if structField.Kind() == reflect.String { + // Handle null strings separately nullStringIndexes[i] = structField - var ns sql.NullString scanDest[i] = &ns } else { @@ -211,12 +234,12 @@ func (db *database) ScanStruct(row *sql.Row, dest any) error { } } - // Scan the row into the destination pointers - if err := row.Scan(scanDest...); err != nil { + // Scan using the scanner interface + if err := scanner.Scan(scanDest...); err != nil { return err } - // Set null strings to nil if they are null + // Set null strings to their values if they are valid for i, field := range nullStringIndexes { ns := scanDest[i].(*sql.NullString) if ns.Valid { @@ -227,11 +250,94 @@ func (db *database) ScanStruct(row *sql.Row, dest any) error { // Decrypt encrypted fields for _, fieldName := range fieldsToDecrypt { field := destVal.FieldByName(fieldName) - decValue, err := db.secretsService.Decrypt(field.Interface().(string)) - if err != nil { - return err + if !field.IsZero() { + decValue, err := db.secretsService.Decrypt(field.Interface().(string)) + if err != nil { + return err + } + field.SetString(decValue) } - field.SetString(decValue) + } + + return nil +} + +// ScanStruct scans a single row into a struct +func (db *database) ScanStruct(row *sql.Row, dest any) error { + if row == nil { + return fmt.Errorf("row cannot be nil") + } + + if row.Err() != nil { + return row.Err() + } + + // Get the destination value + destVal := reflect.ValueOf(dest) + if destVal.Kind() != reflect.Ptr || destVal.IsNil() { + return fmt.Errorf("destination must be a non-nil pointer to a struct, got %T", dest) + } + + destVal = destVal.Elem() + if destVal.Kind() != reflect.Struct { + return fmt.Errorf("destination must be a pointer to a struct, got pointer to %s", destVal.Kind()) + } + + return db.scanStructInstance(destVal, row) +} + +// ScanStructs scans multiple rows into a slice of structs +func (db *database) ScanStructs(rows *sql.Rows, destSlice any) error { + if rows == nil { + return fmt.Errorf("rows cannot be nil") + } + defer rows.Close() + + // Get the slice value and element type + sliceVal := reflect.ValueOf(destSlice) + if sliceVal.Kind() != reflect.Ptr || sliceVal.IsNil() { + return fmt.Errorf("destination must be a non-nil pointer to a slice, got %T", destSlice) + } + + sliceVal = sliceVal.Elem() + if sliceVal.Kind() != reflect.Slice { + return fmt.Errorf("destination must be a pointer to a slice, got pointer to %s", sliceVal.Kind()) + } + + // Get the element type of the slice + elemType := sliceVal.Type().Elem() + + // Check if we have a direct struct type or a pointer to struct + isPtr := elemType.Kind() == reflect.Ptr + structType := elemType + if isPtr { + structType = elemType.Elem() + } + if structType.Kind() != reflect.Struct { + return fmt.Errorf("slice element type must be a struct or pointer to struct, got %s", elemType.String()) + } + + // Process each row + for rows.Next() { + // Create a new instance of the struct for each row + newElem := reflect.New(structType).Elem() + + // Scan this row into the new element + if err := db.scanStructInstance(newElem, rows); err != nil { + return err + } + + // Add the new element to the result slice + if isPtr { + sliceVal.Set(reflect.Append(sliceVal, newElem.Addr())) + } else { + sliceVal.Set(reflect.Append(sliceVal, newElem)) + } + } + + // Check for errors from iterating over rows + if err := rows.Err(); err != nil { + return err } return nil diff --git a/server/internal/db/workspaces.go b/server/internal/db/workspaces.go index 64ebd7e..1783ad3 100644 --- a/server/internal/db/workspaces.go +++ b/server/internal/db/workspaces.go @@ -20,7 +20,7 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error { } query, err := db.NewQuery(). - InsertStruct(workspace, "workspaces", db.secretsService) + InsertStruct(workspace, "workspaces") if err != nil { return fmt.Errorf("failed to create query: %w", err) @@ -39,30 +39,16 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error { // GetWorkspaceByID retrieves a workspace by its ID func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) { - query := db.NewQuery(). - Select( - "id", "user_id", "name", "created_at", - "theme", "auto_save", "show_hidden_files", - "git_enabled", "git_url", "git_user", "git_token", - "git_auto_commit", "git_commit_msg_template", - "git_commit_name", "git_commit_email", - "last_opened_file_path"). - From("workspaces"). - Where("id = ").Placeholder(id) - workspace := &models.Workspace{} - var encryptedToken string + query := db.NewQuery() + query, err := query.SelectStruct(workspace, "workspaces") + if err != nil { + return nil, fmt.Errorf("failed to create query: %w", err) + } + query = query.Where("id = ").Placeholder(id) - var lastOpenedFile sql.NullString - - err := db.QueryRow(query.String(), query.Args()...).Scan( - &workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt, - &workspace.Theme, &workspace.AutoSave, &workspace.ShowHiddenFiles, - &workspace.GitEnabled, &workspace.GitURL, &workspace.GitUser, &encryptedToken, - &workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate, - &workspace.GitCommitName, &workspace.GitCommitEmail, - &lastOpenedFile, - ) + row := db.QueryRow(query.String(), query.Args()...) + err = db.ScanStruct(row, workspace) if err == sql.ErrNoRows { return nil, fmt.Errorf("workspace not found") @@ -71,45 +57,22 @@ func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) { return nil, fmt.Errorf("failed to fetch workspace: %w", err) } - if lastOpenedFile.Valid { - workspace.LastOpenedFilePath = lastOpenedFile.String - } - - // Decrypt token - workspace.GitToken, err = db.decryptToken(encryptedToken) - if err != nil { - return nil, fmt.Errorf("failed to decrypt token: %w", err) - } - return workspace, nil } // GetWorkspaceByName retrieves a workspace by its name and user ID func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) { - query := db.NewQuery(). - Select( - "id", "user_id", "name", "created_at", - "theme", "auto_save", "show_hidden_files", - "git_enabled", "git_url", "git_user", "git_token", - "git_auto_commit", "git_commit_msg_template", - "git_commit_name", "git_commit_email", - "last_opened_file_path"). - From("workspaces"). - Where("user_id = ").Placeholder(userID). + workspace := &models.Workspace{} + query := db.NewQuery() + query, err := query.SelectStruct(workspace, "workspaces") + if err != nil { + return nil, fmt.Errorf("failed to create query: %w", err) + } + query = query.Where("user_id = ").Placeholder(userID). And("name = ").Placeholder(workspaceName) - workspace := &models.Workspace{} - var encryptedToken string - var lastOpenedFile sql.NullString - - err := db.QueryRow(query.String(), query.Args()...).Scan( - &workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt, - &workspace.Theme, &workspace.AutoSave, &workspace.ShowHiddenFiles, - &workspace.GitEnabled, &workspace.GitURL, &workspace.GitUser, &encryptedToken, - &workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate, - &workspace.GitCommitName, &workspace.GitCommitEmail, - &lastOpenedFile, - ) + row := db.QueryRow(query.String(), query.Args()...) + err = db.ScanStruct(row, workspace) if err == sql.ErrNoRows { return nil, fmt.Errorf("workspace not found") @@ -118,16 +81,6 @@ func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*model return nil, fmt.Errorf("failed to fetch workspace: %w", err) } - if lastOpenedFile.Valid { - workspace.LastOpenedFilePath = lastOpenedFile.String - } - - // Decrypt token - workspace.GitToken, err = db.decryptToken(encryptedToken) - if err != nil { - return nil, fmt.Errorf("failed to decrypt token: %w", err) - } - return workspace, nil } @@ -136,7 +89,7 @@ func (db *database) UpdateWorkspace(workspace *models.Workspace) error { query := db.NewQuery() query, err := query. - UpdateStruct(workspace, "workspaces", []string{"id =", "user_id ="}, []interface{}{workspace.ID, workspace.UserID}) + UpdateStruct(workspace, "workspaces", []string{"id =", "user_id ="}, []any{workspace.ID, workspace.UserID}) if err != nil { return fmt.Errorf("failed to create query: %w", err) @@ -152,16 +105,13 @@ func (db *database) UpdateWorkspace(workspace *models.Workspace) error { // GetWorkspacesByUserID retrieves all workspaces for a user func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) { - query := db.NewQuery(). - Select( - "id", "user_id", "name", "created_at", - "theme", "auto_save", "show_hidden_files", - "git_enabled", "git_url", "git_user", "git_token", - "git_auto_commit", "git_commit_msg_template", - "git_commit_name", "git_commit_email", - "last_opened_file_path"). - From("workspaces"). - Where("user_id = ").Placeholder(userID) + workspace := &models.Workspace{} + query := db.NewQuery() + query, err := query.SelectStruct(workspace, "workspaces") + if err != nil { + return nil, fmt.Errorf("failed to create query: %w", err) + } + query = query.Where("user_id = ").Placeholder(userID) rows, err := db.Query(query.String(), query.Args()...) if err != nil { @@ -170,37 +120,9 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro defer rows.Close() var workspaces []*models.Workspace - for rows.Next() { - workspace := &models.Workspace{} - var encryptedToken string - var lastOpenedFile sql.NullString - err := rows.Scan( - &workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt, - &workspace.Theme, &workspace.AutoSave, &workspace.ShowHiddenFiles, - &workspace.GitEnabled, &workspace.GitURL, &workspace.GitUser, &encryptedToken, - &workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate, - &workspace.GitCommitName, &workspace.GitCommitEmail, - &lastOpenedFile, - ) - if err != nil { - return nil, fmt.Errorf("failed to scan workspace row: %w", err) - } - - if lastOpenedFile.Valid { - workspace.LastOpenedFilePath = lastOpenedFile.String - } - - // Decrypt token - workspace.GitToken, err = db.decryptToken(encryptedToken) - if err != nil { - return nil, fmt.Errorf("failed to decrypt token: %w", err) - } - - workspaces = append(workspaces, workspace) - } - - if err = rows.Err(); err != nil { - return nil, fmt.Errorf("error iterating workspace rows: %w", err) + err = db.ScanStructs(rows, workspaces) + if err != nil { + return nil, fmt.Errorf("failed to scan workspaces: %w", err) } return workspaces, nil @@ -210,7 +132,7 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error { where := []string{"id ="} - args := []interface{}{workspace.ID} + args := []any{workspace.ID} query := db.NewQuery() query, err := query. @@ -330,15 +252,11 @@ func (db *database) GetLastOpenedFile(workspaceID int) (string, error) { // GetAllWorkspaces retrieves all workspaces in the database func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) { - query := db.NewQuery(). - Select( - "id", "user_id", "name", "created_at", - "theme", "auto_save", "show_hidden_files", - "git_enabled", "git_url", "git_user", "git_token", - "git_auto_commit", "git_commit_msg_template", - "git_commit_name", "git_commit_email", - "last_opened_file_path"). - From("workspaces") + query := db.NewQuery() + query, err := query.SelectStruct(&models.Workspace{}, "workspaces") + if err != nil { + return nil, fmt.Errorf("failed to create query: %w", err) + } rows, err := db.Query(query.String(), query.Args()...) if err != nil { @@ -347,38 +265,9 @@ func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) { defer rows.Close() var workspaces []*models.Workspace - for rows.Next() { - workspace := &models.Workspace{} - var encryptedToken string - var lastOpenedFile sql.NullString - - err := rows.Scan( - &workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt, - &workspace.Theme, &workspace.AutoSave, &workspace.ShowHiddenFiles, - &workspace.GitEnabled, &workspace.GitURL, &workspace.GitUser, &encryptedToken, - &workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate, - &workspace.GitCommitName, &workspace.GitCommitEmail, - &lastOpenedFile, - ) - if err != nil { - return nil, fmt.Errorf("failed to scan workspace row: %w", err) - } - - if lastOpenedFile.Valid { - workspace.LastOpenedFilePath = lastOpenedFile.String - } - - // Decrypt token - workspace.GitToken, err = db.decryptToken(encryptedToken) - if err != nil { - return nil, fmt.Errorf("failed to decrypt token: %w", err) - } - - workspaces = append(workspaces, workspace) - } - - if err = rows.Err(); err != nil { - return nil, fmt.Errorf("error iterating workspace rows: %w", err) + err = db.ScanStructs(rows, workspaces) + if err != nil { + return nil, fmt.Errorf("failed to scan workspaces: %w", err) } return workspaces, nil