Rework db package to make it testable

This commit is contained in:
2024-11-21 22:36:12 +01:00
parent 2faefb6db5
commit 807e96a76c
16 changed files with 274 additions and 161 deletions

View File

@@ -1,68 +0,0 @@
// Package db provides the database access layer for the application. It contains methods for interacting with the database, such as creating, updating, and deleting records.
package db
import "novamd/internal/models"
// UserStats represents system-wide statistics
type UserStats struct {
TotalUsers int `json:"totalUsers"`
TotalWorkspaces int `json:"totalWorkspaces"`
ActiveUsers int `json:"activeUsers"` // Users with activity in last 30 days
}
// GetAllUsers returns a list of all users in the system
func (db *DB) GetAllUsers() ([]*models.User, error) {
rows, err := db.Query(`
SELECT
id, email, display_name, role, created_at,
last_workspace_id
FROM users
ORDER BY id ASC`)
if err != nil {
return nil, err
}
defer rows.Close()
var users []*models.User
for rows.Next() {
user := &models.User{}
err := rows.Scan(
&user.ID, &user.Email, &user.DisplayName, &user.Role,
&user.CreatedAt, &user.LastWorkspaceID,
)
if err != nil {
return nil, err
}
users = append(users, user)
}
return users, nil
}
// GetSystemStats returns system-wide statistics
func (db *DB) GetSystemStats() (*UserStats, error) {
stats := &UserStats{}
// Get total users
err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&stats.TotalUsers)
if err != nil {
return nil, err
}
// Get total workspaces
err = db.QueryRow("SELECT COUNT(*) FROM workspaces").Scan(&stats.TotalWorkspaces)
if err != nil {
return nil, err
}
// Get active users (users with activity in last 30 days)
err = db.QueryRow(`
SELECT COUNT(DISTINCT user_id)
FROM sessions
WHERE created_at > datetime('now', '-30 days')`).
Scan(&stats.ActiveUsers)
if err != nil {
return nil, err
}
return stats, nil
}

View File

@@ -1,3 +1,4 @@
// Package db provides the database access layer for the application. It contains methods for interacting with the database, such as creating, updating, and deleting records.
package db
import (
@@ -5,18 +6,75 @@ import (
"fmt"
"novamd/internal/crypto"
"novamd/internal/models"
_ "github.com/mattn/go-sqlite3" // SQLite driver
)
// DB represents the database connection
type DB struct {
// UserStore defines the methods for interacting with user data in the database
type UserStore interface {
CreateUser(user *models.User) (*models.User, error)
GetUserByEmail(email string) (*models.User, error)
GetUserByID(userID int) (*models.User, error)
GetAllUsers() ([]*models.User, error)
UpdateUser(user *models.User) error
DeleteUser(userID int) error
UpdateLastWorkspace(userID int, workspaceName string) error
GetLastWorkspaceName(userID int) (string, error)
CountAdminUsers() (int, error)
}
// WorkspaceStore defines the methods for interacting with workspace data in the database
type WorkspaceStore interface {
CreateWorkspace(workspace *models.Workspace) error
GetWorkspaceByID(workspaceID int) (*models.Workspace, error)
GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error)
GetWorkspacesByUserID(userID int) ([]*models.Workspace, error)
UpdateWorkspace(workspace *models.Workspace) error
DeleteWorkspace(workspaceID int) error
UpdateWorkspaceSettings(workspace *models.Workspace) error
DeleteWorkspaceTx(tx *sql.Tx, workspaceID int) error
UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error
UpdateLastOpenedFile(workspaceID int, filePath string) error
GetLastOpenedFile(workspaceID int) (string, error)
GetAllWorkspaces() ([]*models.Workspace, error)
}
// SessionStore defines the methods for interacting with jwt sessions in the database
type SessionStore interface {
CreateSession(session *models.Session) error
GetSessionByRefreshToken(refreshToken string) (*models.Session, error)
DeleteSession(sessionID string) error
CleanExpiredSessions() error
}
// SystemStore defines the methods for interacting with system settings and stats in the database
type SystemStore interface {
GetSystemStats() (*UserStats, error)
EnsureJWTSecret() (string, error)
GetSystemSetting(key string) (string, error)
SetSystemSetting(key, value string) error
}
// Database defines the methods for interacting with the database
type Database interface {
UserStore
WorkspaceStore
SessionStore
SystemStore
Begin() (*sql.Tx, error)
Close() error
Migrate() error
}
// database represents the database connection
type database struct {
*sql.DB
crypto *crypto.Crypto
}
// Init initializes the database connection
func Init(dbPath string, encryptionKey string) (*DB, error) {
func Init(dbPath string, encryptionKey string) (Database, error) {
db, err := sql.Open("sqlite3", dbPath)
if err != nil {
return nil, err
@@ -32,7 +90,7 @@ func Init(dbPath string, encryptionKey string) (*DB, error) {
return nil, fmt.Errorf("failed to initialize encryption: %w", err)
}
database := &DB{
database := &database{
DB: db,
crypto: cryptoService,
}
@@ -45,19 +103,19 @@ func Init(dbPath string, encryptionKey string) (*DB, error) {
}
// Close closes the database connection
func (db *DB) Close() error {
func (db *database) Close() error {
return db.DB.Close()
}
// Helper methods for token encryption/decryption
func (db *DB) encryptToken(token string) (string, error) {
func (db *database) encryptToken(token string) (string, error) {
if token == "" {
return "", nil
}
return db.crypto.Encrypt(token)
}
func (db *DB) decryptToken(token string) (string, error) {
func (db *database) decryptToken(token string) (string, error) {
if token == "" {
return "", nil
}

View File

@@ -83,7 +83,7 @@ var migrations = []Migration{
}
// Migrate applies all database migrations
func (db *DB) Migrate() error {
func (db *database) Migrate() error {
// Create migrations table if it doesn't exist
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS migrations (
version INTEGER PRIMARY KEY

View File

@@ -0,0 +1,70 @@
package db
import (
"database/sql"
"fmt"
"time"
"novamd/internal/models"
)
// CreateSession inserts a new session record into the database
func (db *database) CreateSession(session *models.Session) error {
_, err := db.Exec(`
INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at)
VALUES (?, ?, ?, ?, ?)`,
session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt,
)
if err != nil {
return fmt.Errorf("failed to store session: %w", err)
}
return nil
}
// GetSessionByRefreshToken retrieves a session by its refresh token
func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) {
session := &models.Session{}
err := db.QueryRow(`
SELECT id, user_id, refresh_token, expires_at, created_at
FROM sessions
WHERE refresh_token = ? AND expires_at > ?`,
refreshToken, time.Now(),
).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("session not found or expired")
}
if err != nil {
return nil, fmt.Errorf("failed to fetch session: %w", err)
}
return session, nil
}
// DeleteSession removes a session from the database
func (db *database) DeleteSession(sessionID string) error {
result, err := db.Exec("DELETE FROM sessions WHERE id = ?", sessionID)
if err != nil {
return fmt.Errorf("failed to delete session: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("session not found")
}
return nil
}
// CleanExpiredSessions removes all expired sessions from the database
func (db *database) CleanExpiredSessions() error {
_, err := db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now())
if err != nil {
return fmt.Errorf("failed to clean expired sessions: %w", err)
}
return nil
}

View File

@@ -11,9 +11,16 @@ const (
JWTSecretKey = "jwt_secret"
)
// UserStats represents system-wide statistics
type UserStats struct {
TotalUsers int `json:"totalUsers"`
TotalWorkspaces int `json:"totalWorkspaces"`
ActiveUsers int `json:"activeUsers"` // Users with activity in last 30 days
}
// EnsureJWTSecret makes sure a JWT signing secret exists in the database
// If no secret exists, it generates and stores a new one
func (db *DB) EnsureJWTSecret() (string, error) {
func (db *database) EnsureJWTSecret() (string, error) {
// First, try to get existing secret
secret, err := db.GetSystemSetting(JWTSecretKey)
if err == nil {
@@ -36,7 +43,7 @@ func (db *DB) EnsureJWTSecret() (string, error) {
}
// GetSystemSetting retrieves a system setting by key
func (db *DB) GetSystemSetting(key string) (string, error) {
func (db *database) GetSystemSetting(key string) (string, error) {
var value string
err := db.QueryRow("SELECT value FROM system_settings WHERE key = ?", key).Scan(&value)
if err != nil {
@@ -46,7 +53,7 @@ func (db *DB) GetSystemSetting(key string) (string, error) {
}
// SetSystemSetting stores or updates a system setting
func (db *DB) SetSystemSetting(key, value string) error {
func (db *database) SetSystemSetting(key, value string) error {
_, err := db.Exec(`
INSERT INTO system_settings (key, value)
VALUES (?, ?)
@@ -64,3 +71,32 @@ func generateRandomSecret(bytes int) (string, error) {
}
return base64.StdEncoding.EncodeToString(b), nil
}
// GetSystemStats returns system-wide statistics
func (db *database) GetSystemStats() (*UserStats, error) {
stats := &UserStats{}
// Get total users
err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&stats.TotalUsers)
if err != nil {
return nil, err
}
// Get total workspaces
err = db.QueryRow("SELECT COUNT(*) FROM workspaces").Scan(&stats.TotalWorkspaces)
if err != nil {
return nil, err
}
// Get active users (users with activity in last 30 days)
err = db.QueryRow(`
SELECT COUNT(DISTINCT user_id)
FROM sessions
WHERE created_at > datetime('now', '-30 days')`).
Scan(&stats.ActiveUsers)
if err != nil {
return nil, err
}
return stats, nil
}

View File

@@ -6,7 +6,7 @@ import (
)
// CreateUser inserts a new user record into the database
func (db *DB) CreateUser(user *models.User) (*models.User, error) {
func (db *database) CreateUser(user *models.User) (*models.User, error) {
tx, err := db.Begin()
if err != nil {
return nil, err
@@ -62,7 +62,7 @@ func (db *DB) CreateUser(user *models.User) (*models.User, error) {
}
// Helper function to create a workspace in a transaction
func (db *DB) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
result, err := tx.Exec(`
INSERT INTO workspaces (
user_id, name,
@@ -87,7 +87,7 @@ func (db *DB) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
}
// GetUserByID retrieves a user by ID
func (db *DB) GetUserByID(id int) (*models.User, error) {
func (db *database) GetUserByID(id int) (*models.User, error) {
user := &models.User{}
err := db.QueryRow(`
SELECT
@@ -104,7 +104,7 @@ func (db *DB) GetUserByID(id int) (*models.User, error) {
}
// GetUserByEmail retrieves a user by email
func (db *DB) GetUserByEmail(email string) (*models.User, error) {
func (db *database) GetUserByEmail(email string) (*models.User, error) {
user := &models.User{}
err := db.QueryRow(`
SELECT
@@ -122,7 +122,7 @@ func (db *DB) GetUserByEmail(email string) (*models.User, error) {
}
// UpdateUser updates a user's information
func (db *DB) UpdateUser(user *models.User) error {
func (db *database) UpdateUser(user *models.User) error {
_, err := db.Exec(`
UPDATE users
SET email = ?, display_name = ?, password_hash = ?, role = ?, last_workspace_id = ?
@@ -131,8 +131,36 @@ func (db *DB) UpdateUser(user *models.User) error {
return err
}
// GetAllUsers returns a list of all users in the system
func (db *database) GetAllUsers() ([]*models.User, error) {
rows, err := db.Query(`
SELECT
id, email, display_name, role, created_at,
last_workspace_id
FROM users
ORDER BY id ASC`)
if err != nil {
return nil, err
}
defer rows.Close()
var users []*models.User
for rows.Next() {
user := &models.User{}
err := rows.Scan(
&user.ID, &user.Email, &user.DisplayName, &user.Role,
&user.CreatedAt, &user.LastWorkspaceID,
)
if err != nil {
return nil, err
}
users = append(users, user)
}
return users, nil
}
// UpdateLastWorkspace updates the last workspace the user accessed
func (db *DB) UpdateLastWorkspace(userID int, workspaceName string) error {
func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error {
tx, err := db.Begin()
if err != nil {
return err
@@ -155,7 +183,7 @@ func (db *DB) UpdateLastWorkspace(userID int, workspaceName string) error {
}
// DeleteUser deletes a user and all their workspaces
func (db *DB) DeleteUser(id int) error {
func (db *database) DeleteUser(id int) error {
tx, err := db.Begin()
if err != nil {
return err
@@ -178,7 +206,7 @@ func (db *DB) DeleteUser(id int) error {
}
// GetLastWorkspaceName returns the name of the last workspace the user accessed
func (db *DB) GetLastWorkspaceName(userID int) (string, error) {
func (db *database) GetLastWorkspaceName(userID int) (string, error) {
var workspaceName string
err := db.QueryRow(`
SELECT
@@ -189,3 +217,10 @@ func (db *DB) GetLastWorkspaceName(userID int) (string, error) {
Scan(&workspaceName)
return workspaceName, err
}
// CountAdminUsers returns the number of admin users in the system
func (db *database) CountAdminUsers() (int, error) {
var count int
err := db.QueryRow("SELECT COUNT(*) FROM users WHERE role = 'admin'").Scan(&count)
return count, err
}

View File

@@ -7,7 +7,7 @@ import (
)
// CreateWorkspace inserts a new workspace record into the database
func (db *DB) CreateWorkspace(workspace *models.Workspace) error {
func (db *database) CreateWorkspace(workspace *models.Workspace) error {
// Set default settings if not provided
if workspace.Theme == "" {
workspace.GetDefaultSettings()
@@ -42,7 +42,7 @@ func (db *DB) CreateWorkspace(workspace *models.Workspace) error {
}
// GetWorkspaceByID retrieves a workspace by its ID
func (db *DB) GetWorkspaceByID(id int) (*models.Workspace, error) {
func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) {
workspace := &models.Workspace{}
var encryptedToken string
@@ -75,7 +75,7 @@ func (db *DB) GetWorkspaceByID(id int) (*models.Workspace, error) {
}
// GetWorkspaceByName retrieves a workspace by its name and user ID
func (db *DB) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) {
func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) {
workspace := &models.Workspace{}
var encryptedToken string
@@ -108,7 +108,7 @@ func (db *DB) GetWorkspaceByName(userID int, workspaceName string) (*models.Work
}
// UpdateWorkspace updates a workspace record in the database
func (db *DB) UpdateWorkspace(workspace *models.Workspace) error {
func (db *database) UpdateWorkspace(workspace *models.Workspace) error {
// Encrypt token before storing
encryptedToken, err := db.encryptToken(workspace.GitToken)
if err != nil {
@@ -146,7 +146,7 @@ func (db *DB) UpdateWorkspace(workspace *models.Workspace) error {
}
// GetWorkspacesByUserID retrieves all workspaces for a user
func (db *DB) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
rows, err := db.Query(`
SELECT
id, user_id, name, created_at,
@@ -189,7 +189,7 @@ func (db *DB) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
// UpdateWorkspaceSettings updates only the settings portion of a workspace
// This is useful when you don't want to modify the name or other core workspace properties
func (db *DB) UpdateWorkspaceSettings(workspace *models.Workspace) error {
func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
_, err := db.Exec(`
UPDATE workspaces
SET
@@ -218,31 +218,31 @@ func (db *DB) UpdateWorkspaceSettings(workspace *models.Workspace) error {
}
// DeleteWorkspace removes a workspace record from the database
func (db *DB) DeleteWorkspace(id int) error {
func (db *database) DeleteWorkspace(id int) error {
_, err := db.Exec("DELETE FROM workspaces WHERE id = ?", id)
return err
}
// DeleteWorkspaceTx removes a workspace record from the database within a transaction
func (db *DB) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
_, err := tx.Exec("DELETE FROM workspaces WHERE id = ?", id)
return err
}
// UpdateLastWorkspaceTx sets the last workspace for a user in with a transaction
func (db *DB) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error {
func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error {
_, err := tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?", workspaceID, userID)
return err
}
// UpdateLastOpenedFile updates the last opened file path for a workspace
func (db *DB) UpdateLastOpenedFile(workspaceID int, filePath string) error {
func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error {
_, err := db.Exec("UPDATE workspaces SET last_opened_file_path = ? WHERE id = ?", filePath, workspaceID)
return err
}
// GetLastOpenedFile retrieves the last opened file path for a workspace
func (db *DB) GetLastOpenedFile(workspaceID int) (string, error) {
func (db *database) GetLastOpenedFile(workspaceID int) (string, error) {
var filePath sql.NullString
err := db.QueryRow("SELECT last_opened_file_path FROM workspaces WHERE id = ?", workspaceID).Scan(&filePath)
if err != nil {
@@ -255,7 +255,7 @@ func (db *DB) GetLastOpenedFile(workspaceID int) (string, error) {
}
// GetAllWorkspaces retrieves all workspaces in the database
func (db *DB) GetAllWorkspaces() ([]*models.Workspace, error) {
func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) {
rows, err := db.Query(`
SELECT
id, user_id, name, created_at,