Rework db package to make it testable

This commit is contained in:
2024-11-21 22:36:12 +01:00
parent 2faefb6db5
commit 807e96a76c
16 changed files with 274 additions and 161 deletions

View File

@@ -1,3 +1,4 @@
// Package main contains the main entry point for the application. It sets up the server, database, and other services, and starts the server.
package main package main
import ( import (
@@ -61,7 +62,7 @@ func main() {
authMiddleware := auth.NewMiddleware(jwtManager) authMiddleware := auth.NewMiddleware(jwtManager)
// Initialize session service // Initialize session service
sessionService := auth.NewSessionService(database.DB, jwtManager) sessionService := auth.NewSessionService(database, jwtManager)
// Set up router // Set up router
r := chi.NewRouter() r := chi.NewRouter()

View File

@@ -12,7 +12,7 @@ import (
) )
// SetupRoutes configures the API routes // SetupRoutes configures the API routes
func SetupRoutes(r chi.Router, db *db.DB, s storage.Manager, authMiddleware *auth.Middleware, sessionService *auth.SessionService) { func SetupRoutes(r chi.Router, db db.Database, s storage.Manager, authMiddleware *auth.Middleware, sessionService *auth.SessionService) {
handler := &handlers.Handler{ handler := &handlers.Handler{
DB: db, DB: db,

View File

@@ -1,25 +1,17 @@
package auth package auth
import ( import (
"database/sql"
"fmt" "fmt"
"novamd/internal/db"
"novamd/internal/models"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
) )
// Session represents a user session in the database
type Session struct {
ID string // Unique session identifier
UserID int // ID of the user this session belongs to
RefreshToken string // The refresh token associated with this session
ExpiresAt time.Time // When this session expires
CreatedAt time.Time // When this session was created
}
// SessionService manages user sessions in the database // SessionService manages user sessions in the database
type SessionService struct { type SessionService struct {
db *sql.DB // Database connection db db.SessionStore // Database store for sessions
jwtManager JWTManager // JWT Manager for token operations jwtManager JWTManager // JWT Manager for token operations
} }
@@ -27,7 +19,7 @@ type SessionService struct {
// Parameters: // Parameters:
// - db: database connection // - db: database connection
// - jwtManager: JWT service for token operations // - jwtManager: JWT service for token operations
func NewSessionService(db *sql.DB, jwtManager JWTManager) *SessionService { func NewSessionService(db db.SessionStore, jwtManager JWTManager) *SessionService {
return &SessionService{ return &SessionService{
db: db, db: db,
jwtManager: jwtManager, jwtManager: jwtManager,
@@ -42,7 +34,7 @@ func NewSessionService(db *sql.DB, jwtManager JWTManager) *SessionService {
// - session: the created session // - session: the created session
// - accessToken: a new access token // - accessToken: a new access token
// - error: any error that occurred // - error: any error that occurred
func (s *SessionService) CreateSession(userID int, role string) (*Session, string, error) { func (s *SessionService) CreateSession(userID int, role string) (*models.Session, string, error) {
// Generate both access and refresh tokens // Generate both access and refresh tokens
accessToken, err := s.jwtManager.GenerateAccessToken(userID, role) accessToken, err := s.jwtManager.GenerateAccessToken(userID, role)
if err != nil { if err != nil {
@@ -61,7 +53,7 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin
} }
// Create a new session record // Create a new session record
session := &Session{ session := &models.Session{
ID: uuid.New().String(), ID: uuid.New().String(),
UserID: userID, UserID: userID,
RefreshToken: refreshToken, RefreshToken: refreshToken,
@@ -69,14 +61,9 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
// Store the session in the database // Store the session
_, err = s.db.Exec(` if err := s.db.CreateSession(session); err != nil {
INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at) return nil, "", err
VALUES (?, ?, ?, ?, ?)`,
session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt,
)
if err != nil {
return nil, "", fmt.Errorf("failed to store session: %w", err)
} }
return session, accessToken, nil return session, accessToken, nil
@@ -89,28 +76,18 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin
// - string: a new access token // - string: a new access token
// - error: any error that occurred // - error: any error that occurred
func (s *SessionService) RefreshSession(refreshToken string) (string, error) { func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
// Get session from database
_, err := s.db.GetSessionByRefreshToken(refreshToken)
if err != nil {
return "", fmt.Errorf("invalid session: %w", err)
}
// Validate the refresh token // Validate the refresh token
claims, err := s.jwtManager.ValidateToken(refreshToken) claims, err := s.jwtManager.ValidateToken(refreshToken)
if err != nil { if err != nil {
return "", fmt.Errorf("invalid refresh token: %w", err) return "", fmt.Errorf("invalid refresh token: %w", err)
} }
// Check if the session exists and is not expired
var session Session
err = s.db.QueryRow(`
SELECT id, user_id, refresh_token, expires_at, created_at
FROM sessions
WHERE refresh_token = ? AND expires_at > ?`,
refreshToken, time.Now(),
).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt)
if err == sql.ErrNoRows {
return "", fmt.Errorf("session not found or expired")
}
if err != nil {
return "", fmt.Errorf("failed to fetch session: %w", err)
}
// Generate a new access token // Generate a new access token
return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role) return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role)
} }
@@ -121,20 +98,12 @@ func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
// Returns: // Returns:
// - error: any error that occurred // - error: any error that occurred
func (s *SessionService) InvalidateSession(sessionID string) error { func (s *SessionService) InvalidateSession(sessionID string) error {
_, err := s.db.Exec("DELETE FROM sessions WHERE id = ?", sessionID) return s.db.DeleteSession(sessionID)
if err != nil {
return fmt.Errorf("failed to invalidate session: %w", err)
}
return nil
} }
// CleanExpiredSessions removes all expired sessions from the database // CleanExpiredSessions removes all expired sessions from the database
// Returns: // Returns:
// - error: any error that occurred // - error: any error that occurred
func (s *SessionService) CleanExpiredSessions() error { func (s *SessionService) CleanExpiredSessions() error {
_, err := s.db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now()) return s.db.CleanExpiredSessions()
if err != nil {
return fmt.Errorf("failed to clean expired sessions: %w", err)
}
return nil
} }

View File

@@ -1,68 +0,0 @@
// Package db provides the database access layer for the application. It contains methods for interacting with the database, such as creating, updating, and deleting records.
package db
import "novamd/internal/models"
// UserStats represents system-wide statistics
type UserStats struct {
TotalUsers int `json:"totalUsers"`
TotalWorkspaces int `json:"totalWorkspaces"`
ActiveUsers int `json:"activeUsers"` // Users with activity in last 30 days
}
// GetAllUsers returns a list of all users in the system
func (db *DB) GetAllUsers() ([]*models.User, error) {
rows, err := db.Query(`
SELECT
id, email, display_name, role, created_at,
last_workspace_id
FROM users
ORDER BY id ASC`)
if err != nil {
return nil, err
}
defer rows.Close()
var users []*models.User
for rows.Next() {
user := &models.User{}
err := rows.Scan(
&user.ID, &user.Email, &user.DisplayName, &user.Role,
&user.CreatedAt, &user.LastWorkspaceID,
)
if err != nil {
return nil, err
}
users = append(users, user)
}
return users, nil
}
// GetSystemStats returns system-wide statistics
func (db *DB) GetSystemStats() (*UserStats, error) {
stats := &UserStats{}
// Get total users
err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&stats.TotalUsers)
if err != nil {
return nil, err
}
// Get total workspaces
err = db.QueryRow("SELECT COUNT(*) FROM workspaces").Scan(&stats.TotalWorkspaces)
if err != nil {
return nil, err
}
// Get active users (users with activity in last 30 days)
err = db.QueryRow(`
SELECT COUNT(DISTINCT user_id)
FROM sessions
WHERE created_at > datetime('now', '-30 days')`).
Scan(&stats.ActiveUsers)
if err != nil {
return nil, err
}
return stats, nil
}

View File

@@ -1,3 +1,4 @@
// Package db provides the database access layer for the application. It contains methods for interacting with the database, such as creating, updating, and deleting records.
package db package db
import ( import (
@@ -5,18 +6,75 @@ import (
"fmt" "fmt"
"novamd/internal/crypto" "novamd/internal/crypto"
"novamd/internal/models"
_ "github.com/mattn/go-sqlite3" // SQLite driver _ "github.com/mattn/go-sqlite3" // SQLite driver
) )
// DB represents the database connection // UserStore defines the methods for interacting with user data in the database
type DB struct { type UserStore interface {
CreateUser(user *models.User) (*models.User, error)
GetUserByEmail(email string) (*models.User, error)
GetUserByID(userID int) (*models.User, error)
GetAllUsers() ([]*models.User, error)
UpdateUser(user *models.User) error
DeleteUser(userID int) error
UpdateLastWorkspace(userID int, workspaceName string) error
GetLastWorkspaceName(userID int) (string, error)
CountAdminUsers() (int, error)
}
// WorkspaceStore defines the methods for interacting with workspace data in the database
type WorkspaceStore interface {
CreateWorkspace(workspace *models.Workspace) error
GetWorkspaceByID(workspaceID int) (*models.Workspace, error)
GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error)
GetWorkspacesByUserID(userID int) ([]*models.Workspace, error)
UpdateWorkspace(workspace *models.Workspace) error
DeleteWorkspace(workspaceID int) error
UpdateWorkspaceSettings(workspace *models.Workspace) error
DeleteWorkspaceTx(tx *sql.Tx, workspaceID int) error
UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error
UpdateLastOpenedFile(workspaceID int, filePath string) error
GetLastOpenedFile(workspaceID int) (string, error)
GetAllWorkspaces() ([]*models.Workspace, error)
}
// SessionStore defines the methods for interacting with jwt sessions in the database
type SessionStore interface {
CreateSession(session *models.Session) error
GetSessionByRefreshToken(refreshToken string) (*models.Session, error)
DeleteSession(sessionID string) error
CleanExpiredSessions() error
}
// SystemStore defines the methods for interacting with system settings and stats in the database
type SystemStore interface {
GetSystemStats() (*UserStats, error)
EnsureJWTSecret() (string, error)
GetSystemSetting(key string) (string, error)
SetSystemSetting(key, value string) error
}
// Database defines the methods for interacting with the database
type Database interface {
UserStore
WorkspaceStore
SessionStore
SystemStore
Begin() (*sql.Tx, error)
Close() error
Migrate() error
}
// database represents the database connection
type database struct {
*sql.DB *sql.DB
crypto *crypto.Crypto crypto *crypto.Crypto
} }
// Init initializes the database connection // Init initializes the database connection
func Init(dbPath string, encryptionKey string) (*DB, error) { func Init(dbPath string, encryptionKey string) (Database, error) {
db, err := sql.Open("sqlite3", dbPath) db, err := sql.Open("sqlite3", dbPath)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -32,7 +90,7 @@ func Init(dbPath string, encryptionKey string) (*DB, error) {
return nil, fmt.Errorf("failed to initialize encryption: %w", err) return nil, fmt.Errorf("failed to initialize encryption: %w", err)
} }
database := &DB{ database := &database{
DB: db, DB: db,
crypto: cryptoService, crypto: cryptoService,
} }
@@ -45,19 +103,19 @@ func Init(dbPath string, encryptionKey string) (*DB, error) {
} }
// Close closes the database connection // Close closes the database connection
func (db *DB) Close() error { func (db *database) Close() error {
return db.DB.Close() return db.DB.Close()
} }
// Helper methods for token encryption/decryption // Helper methods for token encryption/decryption
func (db *DB) encryptToken(token string) (string, error) { func (db *database) encryptToken(token string) (string, error) {
if token == "" { if token == "" {
return "", nil return "", nil
} }
return db.crypto.Encrypt(token) return db.crypto.Encrypt(token)
} }
func (db *DB) decryptToken(token string) (string, error) { func (db *database) decryptToken(token string) (string, error) {
if token == "" { if token == "" {
return "", nil return "", nil
} }

View File

@@ -83,7 +83,7 @@ var migrations = []Migration{
} }
// Migrate applies all database migrations // Migrate applies all database migrations
func (db *DB) Migrate() error { func (db *database) Migrate() error {
// Create migrations table if it doesn't exist // Create migrations table if it doesn't exist
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS migrations ( _, err := db.Exec(`CREATE TABLE IF NOT EXISTS migrations (
version INTEGER PRIMARY KEY version INTEGER PRIMARY KEY

View File

@@ -0,0 +1,70 @@
package db
import (
"database/sql"
"fmt"
"time"
"novamd/internal/models"
)
// CreateSession inserts a new session record into the database
func (db *database) CreateSession(session *models.Session) error {
_, err := db.Exec(`
INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at)
VALUES (?, ?, ?, ?, ?)`,
session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt,
)
if err != nil {
return fmt.Errorf("failed to store session: %w", err)
}
return nil
}
// GetSessionByRefreshToken retrieves a session by its refresh token
func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) {
session := &models.Session{}
err := db.QueryRow(`
SELECT id, user_id, refresh_token, expires_at, created_at
FROM sessions
WHERE refresh_token = ? AND expires_at > ?`,
refreshToken, time.Now(),
).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("session not found or expired")
}
if err != nil {
return nil, fmt.Errorf("failed to fetch session: %w", err)
}
return session, nil
}
// DeleteSession removes a session from the database
func (db *database) DeleteSession(sessionID string) error {
result, err := db.Exec("DELETE FROM sessions WHERE id = ?", sessionID)
if err != nil {
return fmt.Errorf("failed to delete session: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("session not found")
}
return nil
}
// CleanExpiredSessions removes all expired sessions from the database
func (db *database) CleanExpiredSessions() error {
_, err := db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now())
if err != nil {
return fmt.Errorf("failed to clean expired sessions: %w", err)
}
return nil
}

View File

@@ -11,9 +11,16 @@ const (
JWTSecretKey = "jwt_secret" JWTSecretKey = "jwt_secret"
) )
// UserStats represents system-wide statistics
type UserStats struct {
TotalUsers int `json:"totalUsers"`
TotalWorkspaces int `json:"totalWorkspaces"`
ActiveUsers int `json:"activeUsers"` // Users with activity in last 30 days
}
// EnsureJWTSecret makes sure a JWT signing secret exists in the database // EnsureJWTSecret makes sure a JWT signing secret exists in the database
// If no secret exists, it generates and stores a new one // If no secret exists, it generates and stores a new one
func (db *DB) EnsureJWTSecret() (string, error) { func (db *database) EnsureJWTSecret() (string, error) {
// First, try to get existing secret // First, try to get existing secret
secret, err := db.GetSystemSetting(JWTSecretKey) secret, err := db.GetSystemSetting(JWTSecretKey)
if err == nil { if err == nil {
@@ -36,7 +43,7 @@ func (db *DB) EnsureJWTSecret() (string, error) {
} }
// GetSystemSetting retrieves a system setting by key // GetSystemSetting retrieves a system setting by key
func (db *DB) GetSystemSetting(key string) (string, error) { func (db *database) GetSystemSetting(key string) (string, error) {
var value string var value string
err := db.QueryRow("SELECT value FROM system_settings WHERE key = ?", key).Scan(&value) err := db.QueryRow("SELECT value FROM system_settings WHERE key = ?", key).Scan(&value)
if err != nil { if err != nil {
@@ -46,7 +53,7 @@ func (db *DB) GetSystemSetting(key string) (string, error) {
} }
// SetSystemSetting stores or updates a system setting // SetSystemSetting stores or updates a system setting
func (db *DB) SetSystemSetting(key, value string) error { func (db *database) SetSystemSetting(key, value string) error {
_, err := db.Exec(` _, err := db.Exec(`
INSERT INTO system_settings (key, value) INSERT INTO system_settings (key, value)
VALUES (?, ?) VALUES (?, ?)
@@ -64,3 +71,32 @@ func generateRandomSecret(bytes int) (string, error) {
} }
return base64.StdEncoding.EncodeToString(b), nil return base64.StdEncoding.EncodeToString(b), nil
} }
// GetSystemStats returns system-wide statistics
func (db *database) GetSystemStats() (*UserStats, error) {
stats := &UserStats{}
// Get total users
err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&stats.TotalUsers)
if err != nil {
return nil, err
}
// Get total workspaces
err = db.QueryRow("SELECT COUNT(*) FROM workspaces").Scan(&stats.TotalWorkspaces)
if err != nil {
return nil, err
}
// Get active users (users with activity in last 30 days)
err = db.QueryRow(`
SELECT COUNT(DISTINCT user_id)
FROM sessions
WHERE created_at > datetime('now', '-30 days')`).
Scan(&stats.ActiveUsers)
if err != nil {
return nil, err
}
return stats, nil
}

View File

@@ -6,7 +6,7 @@ import (
) )
// CreateUser inserts a new user record into the database // CreateUser inserts a new user record into the database
func (db *DB) CreateUser(user *models.User) (*models.User, error) { func (db *database) CreateUser(user *models.User) (*models.User, error) {
tx, err := db.Begin() tx, err := db.Begin()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -62,7 +62,7 @@ func (db *DB) CreateUser(user *models.User) (*models.User, error) {
} }
// Helper function to create a workspace in a transaction // Helper function to create a workspace in a transaction
func (db *DB) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error { func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
result, err := tx.Exec(` result, err := tx.Exec(`
INSERT INTO workspaces ( INSERT INTO workspaces (
user_id, name, user_id, name,
@@ -87,7 +87,7 @@ func (db *DB) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
} }
// GetUserByID retrieves a user by ID // GetUserByID retrieves a user by ID
func (db *DB) GetUserByID(id int) (*models.User, error) { func (db *database) GetUserByID(id int) (*models.User, error) {
user := &models.User{} user := &models.User{}
err := db.QueryRow(` err := db.QueryRow(`
SELECT SELECT
@@ -104,7 +104,7 @@ func (db *DB) GetUserByID(id int) (*models.User, error) {
} }
// GetUserByEmail retrieves a user by email // GetUserByEmail retrieves a user by email
func (db *DB) GetUserByEmail(email string) (*models.User, error) { func (db *database) GetUserByEmail(email string) (*models.User, error) {
user := &models.User{} user := &models.User{}
err := db.QueryRow(` err := db.QueryRow(`
SELECT SELECT
@@ -122,7 +122,7 @@ func (db *DB) GetUserByEmail(email string) (*models.User, error) {
} }
// UpdateUser updates a user's information // UpdateUser updates a user's information
func (db *DB) UpdateUser(user *models.User) error { func (db *database) UpdateUser(user *models.User) error {
_, err := db.Exec(` _, err := db.Exec(`
UPDATE users UPDATE users
SET email = ?, display_name = ?, password_hash = ?, role = ?, last_workspace_id = ? SET email = ?, display_name = ?, password_hash = ?, role = ?, last_workspace_id = ?
@@ -131,8 +131,36 @@ func (db *DB) UpdateUser(user *models.User) error {
return err return err
} }
// GetAllUsers returns a list of all users in the system
func (db *database) GetAllUsers() ([]*models.User, error) {
rows, err := db.Query(`
SELECT
id, email, display_name, role, created_at,
last_workspace_id
FROM users
ORDER BY id ASC`)
if err != nil {
return nil, err
}
defer rows.Close()
var users []*models.User
for rows.Next() {
user := &models.User{}
err := rows.Scan(
&user.ID, &user.Email, &user.DisplayName, &user.Role,
&user.CreatedAt, &user.LastWorkspaceID,
)
if err != nil {
return nil, err
}
users = append(users, user)
}
return users, nil
}
// UpdateLastWorkspace updates the last workspace the user accessed // UpdateLastWorkspace updates the last workspace the user accessed
func (db *DB) UpdateLastWorkspace(userID int, workspaceName string) error { func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error {
tx, err := db.Begin() tx, err := db.Begin()
if err != nil { if err != nil {
return err return err
@@ -155,7 +183,7 @@ func (db *DB) UpdateLastWorkspace(userID int, workspaceName string) error {
} }
// DeleteUser deletes a user and all their workspaces // DeleteUser deletes a user and all their workspaces
func (db *DB) DeleteUser(id int) error { func (db *database) DeleteUser(id int) error {
tx, err := db.Begin() tx, err := db.Begin()
if err != nil { if err != nil {
return err return err
@@ -178,7 +206,7 @@ func (db *DB) DeleteUser(id int) error {
} }
// GetLastWorkspaceName returns the name of the last workspace the user accessed // GetLastWorkspaceName returns the name of the last workspace the user accessed
func (db *DB) GetLastWorkspaceName(userID int) (string, error) { func (db *database) GetLastWorkspaceName(userID int) (string, error) {
var workspaceName string var workspaceName string
err := db.QueryRow(` err := db.QueryRow(`
SELECT SELECT
@@ -189,3 +217,10 @@ func (db *DB) GetLastWorkspaceName(userID int) (string, error) {
Scan(&workspaceName) Scan(&workspaceName)
return workspaceName, err return workspaceName, err
} }
// CountAdminUsers returns the number of admin users in the system
func (db *database) CountAdminUsers() (int, error) {
var count int
err := db.QueryRow("SELECT COUNT(*) FROM users WHERE role = 'admin'").Scan(&count)
return count, err
}

View File

@@ -7,7 +7,7 @@ import (
) )
// CreateWorkspace inserts a new workspace record into the database // CreateWorkspace inserts a new workspace record into the database
func (db *DB) CreateWorkspace(workspace *models.Workspace) error { func (db *database) CreateWorkspace(workspace *models.Workspace) error {
// Set default settings if not provided // Set default settings if not provided
if workspace.Theme == "" { if workspace.Theme == "" {
workspace.GetDefaultSettings() workspace.GetDefaultSettings()
@@ -42,7 +42,7 @@ func (db *DB) CreateWorkspace(workspace *models.Workspace) error {
} }
// GetWorkspaceByID retrieves a workspace by its ID // GetWorkspaceByID retrieves a workspace by its ID
func (db *DB) GetWorkspaceByID(id int) (*models.Workspace, error) { func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) {
workspace := &models.Workspace{} workspace := &models.Workspace{}
var encryptedToken string var encryptedToken string
@@ -75,7 +75,7 @@ func (db *DB) GetWorkspaceByID(id int) (*models.Workspace, error) {
} }
// GetWorkspaceByName retrieves a workspace by its name and user ID // GetWorkspaceByName retrieves a workspace by its name and user ID
func (db *DB) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) { func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) {
workspace := &models.Workspace{} workspace := &models.Workspace{}
var encryptedToken string var encryptedToken string
@@ -108,7 +108,7 @@ func (db *DB) GetWorkspaceByName(userID int, workspaceName string) (*models.Work
} }
// UpdateWorkspace updates a workspace record in the database // UpdateWorkspace updates a workspace record in the database
func (db *DB) UpdateWorkspace(workspace *models.Workspace) error { func (db *database) UpdateWorkspace(workspace *models.Workspace) error {
// Encrypt token before storing // Encrypt token before storing
encryptedToken, err := db.encryptToken(workspace.GitToken) encryptedToken, err := db.encryptToken(workspace.GitToken)
if err != nil { if err != nil {
@@ -146,7 +146,7 @@ func (db *DB) UpdateWorkspace(workspace *models.Workspace) error {
} }
// GetWorkspacesByUserID retrieves all workspaces for a user // GetWorkspacesByUserID retrieves all workspaces for a user
func (db *DB) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) { func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
rows, err := db.Query(` rows, err := db.Query(`
SELECT SELECT
id, user_id, name, created_at, id, user_id, name, created_at,
@@ -189,7 +189,7 @@ func (db *DB) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
// UpdateWorkspaceSettings updates only the settings portion of a workspace // UpdateWorkspaceSettings updates only the settings portion of a workspace
// This is useful when you don't want to modify the name or other core workspace properties // This is useful when you don't want to modify the name or other core workspace properties
func (db *DB) UpdateWorkspaceSettings(workspace *models.Workspace) error { func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
_, err := db.Exec(` _, err := db.Exec(`
UPDATE workspaces UPDATE workspaces
SET SET
@@ -218,31 +218,31 @@ func (db *DB) UpdateWorkspaceSettings(workspace *models.Workspace) error {
} }
// DeleteWorkspace removes a workspace record from the database // DeleteWorkspace removes a workspace record from the database
func (db *DB) DeleteWorkspace(id int) error { func (db *database) DeleteWorkspace(id int) error {
_, err := db.Exec("DELETE FROM workspaces WHERE id = ?", id) _, err := db.Exec("DELETE FROM workspaces WHERE id = ?", id)
return err return err
} }
// DeleteWorkspaceTx removes a workspace record from the database within a transaction // DeleteWorkspaceTx removes a workspace record from the database within a transaction
func (db *DB) DeleteWorkspaceTx(tx *sql.Tx, id int) error { func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
_, err := tx.Exec("DELETE FROM workspaces WHERE id = ?", id) _, err := tx.Exec("DELETE FROM workspaces WHERE id = ?", id)
return err return err
} }
// UpdateLastWorkspaceTx sets the last workspace for a user in with a transaction // UpdateLastWorkspaceTx sets the last workspace for a user in with a transaction
func (db *DB) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error { func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error {
_, err := tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?", workspaceID, userID) _, err := tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?", workspaceID, userID)
return err return err
} }
// UpdateLastOpenedFile updates the last opened file path for a workspace // UpdateLastOpenedFile updates the last opened file path for a workspace
func (db *DB) UpdateLastOpenedFile(workspaceID int, filePath string) error { func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error {
_, err := db.Exec("UPDATE workspaces SET last_opened_file_path = ? WHERE id = ?", filePath, workspaceID) _, err := db.Exec("UPDATE workspaces SET last_opened_file_path = ? WHERE id = ?", filePath, workspaceID)
return err return err
} }
// GetLastOpenedFile retrieves the last opened file path for a workspace // GetLastOpenedFile retrieves the last opened file path for a workspace
func (db *DB) GetLastOpenedFile(workspaceID int) (string, error) { func (db *database) GetLastOpenedFile(workspaceID int) (string, error) {
var filePath sql.NullString var filePath sql.NullString
err := db.QueryRow("SELECT last_opened_file_path FROM workspaces WHERE id = ?", workspaceID).Scan(&filePath) err := db.QueryRow("SELECT last_opened_file_path FROM workspaces WHERE id = ?", workspaceID).Scan(&filePath)
if err != nil { if err != nil {
@@ -255,7 +255,7 @@ func (db *DB) GetLastOpenedFile(workspaceID int) (string, error) {
} }
// GetAllWorkspaces retrieves all workspaces in the database // GetAllWorkspaces retrieves all workspaces in the database
func (db *DB) GetAllWorkspaces() ([]*models.Workspace, error) { func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) {
rows, err := db.Query(` rows, err := db.Query(`
SELECT SELECT
id, user_id, name, created_at, id, user_id, name, created_at,

View File

@@ -19,7 +19,7 @@ type LoginResponse struct {
AccessToken string `json:"accessToken"` AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"` RefreshToken string `json:"refreshToken"`
User *models.User `json:"user"` User *models.User `json:"user"`
Session *auth.Session `json:"session"` Session *models.Session `json:"session"`
} }
type RefreshRequest struct { type RefreshRequest struct {

View File

@@ -9,12 +9,12 @@ import (
// Handler provides common functionality for all handlers // Handler provides common functionality for all handlers
type Handler struct { type Handler struct {
DB *db.DB DB db.Database
Storage storage.Manager Storage storage.Manager
} }
// NewHandler creates a new handler with the given dependencies // NewHandler creates a new handler with the given dependencies
func NewHandler(db *db.DB, s storage.Manager) *Handler { func NewHandler(db db.Database, s storage.Manager) *Handler {
return &Handler{ return &Handler{
DB: db, DB: db,
Storage: s, Storage: s,

View File

@@ -171,8 +171,7 @@ func (h *Handler) DeleteAccount() http.HandlerFunc {
// Prevent admin from deleting their own account if they're the last admin // Prevent admin from deleting their own account if they're the last admin
if user.Role == "admin" { if user.Role == "admin" {
// Count number of admin users // Count number of admin users
adminCount := 0 adminCount, err := h.DB.CountAdminUsers()
err := h.DB.QueryRow("SELECT COUNT(*) FROM users WHERE role = 'admin'").Scan(&adminCount)
if err != nil { if err != nil {
http.Error(w, "Failed to verify admin status", http.StatusInternalServerError) http.Error(w, "Failed to verify admin status", http.StatusInternalServerError)
return return

View File

@@ -29,7 +29,7 @@ func WithUserContext(next http.Handler) http.Handler {
} }
// Workspace context // Workspace context
func WithWorkspaceContext(db *db.DB) func(http.Handler) http.Handler { func WithWorkspaceContext(db db.Database) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, ok := httpcontext.GetRequestContext(w, r) ctx, ok := httpcontext.GetRequestContext(w, r)

View File

@@ -0,0 +1,13 @@
// Package models contains the data models used throughout the application. These models are used to represent data in the database, as well as to validate and serialize data in the application.
package models
import "time"
// Session represents a user session in the database
type Session struct {
ID string // Unique session identifier
UserID int // ID of the user this session belongs to
RefreshToken string // The refresh token associated with this session
ExpiresAt time.Time // When this session expires
CreatedAt time.Time // When this session was created
}

View File

@@ -13,11 +13,11 @@ import (
) )
type UserService struct { type UserService struct {
DB *db.DB DB db.Database
Storage storage.Manager Storage storage.Manager
} }
func NewUserService(database *db.DB, s storage.Manager) *UserService { func NewUserService(database db.Database, s storage.Manager) *UserService {
return &UserService{ return &UserService{
DB: database, DB: database,
Storage: s, Storage: s,