Rework db package to make it testable

This commit is contained in:
2024-11-21 22:36:12 +01:00
parent 2faefb6db5
commit 807e96a76c
16 changed files with 274 additions and 161 deletions

View File

@@ -12,7 +12,7 @@ import (
)
// SetupRoutes configures the API routes
func SetupRoutes(r chi.Router, db *db.DB, s storage.Manager, authMiddleware *auth.Middleware, sessionService *auth.SessionService) {
func SetupRoutes(r chi.Router, db db.Database, s storage.Manager, authMiddleware *auth.Middleware, sessionService *auth.SessionService) {
handler := &handlers.Handler{
DB: db,

View File

@@ -1,33 +1,25 @@
package auth
import (
"database/sql"
"fmt"
"novamd/internal/db"
"novamd/internal/models"
"time"
"github.com/google/uuid"
)
// Session represents a user session in the database
type Session struct {
ID string // Unique session identifier
UserID int // ID of the user this session belongs to
RefreshToken string // The refresh token associated with this session
ExpiresAt time.Time // When this session expires
CreatedAt time.Time // When this session was created
}
// SessionService manages user sessions in the database
type SessionService struct {
db *sql.DB // Database connection
jwtManager JWTManager // JWT Manager for token operations
db db.SessionStore // Database store for sessions
jwtManager JWTManager // JWT Manager for token operations
}
// NewSessionService creates a new session service
// Parameters:
// - db: database connection
// - jwtManager: JWT service for token operations
func NewSessionService(db *sql.DB, jwtManager JWTManager) *SessionService {
func NewSessionService(db db.SessionStore, jwtManager JWTManager) *SessionService {
return &SessionService{
db: db,
jwtManager: jwtManager,
@@ -42,7 +34,7 @@ func NewSessionService(db *sql.DB, jwtManager JWTManager) *SessionService {
// - session: the created session
// - accessToken: a new access token
// - error: any error that occurred
func (s *SessionService) CreateSession(userID int, role string) (*Session, string, error) {
func (s *SessionService) CreateSession(userID int, role string) (*models.Session, string, error) {
// Generate both access and refresh tokens
accessToken, err := s.jwtManager.GenerateAccessToken(userID, role)
if err != nil {
@@ -61,7 +53,7 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin
}
// Create a new session record
session := &Session{
session := &models.Session{
ID: uuid.New().String(),
UserID: userID,
RefreshToken: refreshToken,
@@ -69,14 +61,9 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin
CreatedAt: time.Now(),
}
// Store the session in the database
_, err = s.db.Exec(`
INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at)
VALUES (?, ?, ?, ?, ?)`,
session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt,
)
if err != nil {
return nil, "", fmt.Errorf("failed to store session: %w", err)
// Store the session
if err := s.db.CreateSession(session); err != nil {
return nil, "", err
}
return session, accessToken, nil
@@ -89,28 +76,18 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin
// - string: a new access token
// - error: any error that occurred
func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
// Get session from database
_, err := s.db.GetSessionByRefreshToken(refreshToken)
if err != nil {
return "", fmt.Errorf("invalid session: %w", err)
}
// Validate the refresh token
claims, err := s.jwtManager.ValidateToken(refreshToken)
if err != nil {
return "", fmt.Errorf("invalid refresh token: %w", err)
}
// Check if the session exists and is not expired
var session Session
err = s.db.QueryRow(`
SELECT id, user_id, refresh_token, expires_at, created_at
FROM sessions
WHERE refresh_token = ? AND expires_at > ?`,
refreshToken, time.Now(),
).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt)
if err == sql.ErrNoRows {
return "", fmt.Errorf("session not found or expired")
}
if err != nil {
return "", fmt.Errorf("failed to fetch session: %w", err)
}
// Generate a new access token
return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role)
}
@@ -121,20 +98,12 @@ func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
// Returns:
// - error: any error that occurred
func (s *SessionService) InvalidateSession(sessionID string) error {
_, err := s.db.Exec("DELETE FROM sessions WHERE id = ?", sessionID)
if err != nil {
return fmt.Errorf("failed to invalidate session: %w", err)
}
return nil
return s.db.DeleteSession(sessionID)
}
// CleanExpiredSessions removes all expired sessions from the database
// Returns:
// - error: any error that occurred
func (s *SessionService) CleanExpiredSessions() error {
_, err := s.db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now())
if err != nil {
return fmt.Errorf("failed to clean expired sessions: %w", err)
}
return nil
return s.db.CleanExpiredSessions()
}

View File

@@ -1,68 +0,0 @@
// Package db provides the database access layer for the application. It contains methods for interacting with the database, such as creating, updating, and deleting records.
package db
import "novamd/internal/models"
// UserStats represents system-wide statistics
type UserStats struct {
TotalUsers int `json:"totalUsers"`
TotalWorkspaces int `json:"totalWorkspaces"`
ActiveUsers int `json:"activeUsers"` // Users with activity in last 30 days
}
// GetAllUsers returns a list of all users in the system
func (db *DB) GetAllUsers() ([]*models.User, error) {
rows, err := db.Query(`
SELECT
id, email, display_name, role, created_at,
last_workspace_id
FROM users
ORDER BY id ASC`)
if err != nil {
return nil, err
}
defer rows.Close()
var users []*models.User
for rows.Next() {
user := &models.User{}
err := rows.Scan(
&user.ID, &user.Email, &user.DisplayName, &user.Role,
&user.CreatedAt, &user.LastWorkspaceID,
)
if err != nil {
return nil, err
}
users = append(users, user)
}
return users, nil
}
// GetSystemStats returns system-wide statistics
func (db *DB) GetSystemStats() (*UserStats, error) {
stats := &UserStats{}
// Get total users
err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&stats.TotalUsers)
if err != nil {
return nil, err
}
// Get total workspaces
err = db.QueryRow("SELECT COUNT(*) FROM workspaces").Scan(&stats.TotalWorkspaces)
if err != nil {
return nil, err
}
// Get active users (users with activity in last 30 days)
err = db.QueryRow(`
SELECT COUNT(DISTINCT user_id)
FROM sessions
WHERE created_at > datetime('now', '-30 days')`).
Scan(&stats.ActiveUsers)
if err != nil {
return nil, err
}
return stats, nil
}

View File

@@ -1,3 +1,4 @@
// Package db provides the database access layer for the application. It contains methods for interacting with the database, such as creating, updating, and deleting records.
package db
import (
@@ -5,18 +6,75 @@ import (
"fmt"
"novamd/internal/crypto"
"novamd/internal/models"
_ "github.com/mattn/go-sqlite3" // SQLite driver
)
// DB represents the database connection
type DB struct {
// UserStore defines the methods for interacting with user data in the database
type UserStore interface {
CreateUser(user *models.User) (*models.User, error)
GetUserByEmail(email string) (*models.User, error)
GetUserByID(userID int) (*models.User, error)
GetAllUsers() ([]*models.User, error)
UpdateUser(user *models.User) error
DeleteUser(userID int) error
UpdateLastWorkspace(userID int, workspaceName string) error
GetLastWorkspaceName(userID int) (string, error)
CountAdminUsers() (int, error)
}
// WorkspaceStore defines the methods for interacting with workspace data in the database
type WorkspaceStore interface {
CreateWorkspace(workspace *models.Workspace) error
GetWorkspaceByID(workspaceID int) (*models.Workspace, error)
GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error)
GetWorkspacesByUserID(userID int) ([]*models.Workspace, error)
UpdateWorkspace(workspace *models.Workspace) error
DeleteWorkspace(workspaceID int) error
UpdateWorkspaceSettings(workspace *models.Workspace) error
DeleteWorkspaceTx(tx *sql.Tx, workspaceID int) error
UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error
UpdateLastOpenedFile(workspaceID int, filePath string) error
GetLastOpenedFile(workspaceID int) (string, error)
GetAllWorkspaces() ([]*models.Workspace, error)
}
// SessionStore defines the methods for interacting with jwt sessions in the database
type SessionStore interface {
CreateSession(session *models.Session) error
GetSessionByRefreshToken(refreshToken string) (*models.Session, error)
DeleteSession(sessionID string) error
CleanExpiredSessions() error
}
// SystemStore defines the methods for interacting with system settings and stats in the database
type SystemStore interface {
GetSystemStats() (*UserStats, error)
EnsureJWTSecret() (string, error)
GetSystemSetting(key string) (string, error)
SetSystemSetting(key, value string) error
}
// Database defines the methods for interacting with the database
type Database interface {
UserStore
WorkspaceStore
SessionStore
SystemStore
Begin() (*sql.Tx, error)
Close() error
Migrate() error
}
// database represents the database connection
type database struct {
*sql.DB
crypto *crypto.Crypto
}
// Init initializes the database connection
func Init(dbPath string, encryptionKey string) (*DB, error) {
func Init(dbPath string, encryptionKey string) (Database, error) {
db, err := sql.Open("sqlite3", dbPath)
if err != nil {
return nil, err
@@ -32,7 +90,7 @@ func Init(dbPath string, encryptionKey string) (*DB, error) {
return nil, fmt.Errorf("failed to initialize encryption: %w", err)
}
database := &DB{
database := &database{
DB: db,
crypto: cryptoService,
}
@@ -45,19 +103,19 @@ func Init(dbPath string, encryptionKey string) (*DB, error) {
}
// Close closes the database connection
func (db *DB) Close() error {
func (db *database) Close() error {
return db.DB.Close()
}
// Helper methods for token encryption/decryption
func (db *DB) encryptToken(token string) (string, error) {
func (db *database) encryptToken(token string) (string, error) {
if token == "" {
return "", nil
}
return db.crypto.Encrypt(token)
}
func (db *DB) decryptToken(token string) (string, error) {
func (db *database) decryptToken(token string) (string, error) {
if token == "" {
return "", nil
}

View File

@@ -83,7 +83,7 @@ var migrations = []Migration{
}
// Migrate applies all database migrations
func (db *DB) Migrate() error {
func (db *database) Migrate() error {
// Create migrations table if it doesn't exist
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS migrations (
version INTEGER PRIMARY KEY

View File

@@ -0,0 +1,70 @@
package db
import (
"database/sql"
"fmt"
"time"
"novamd/internal/models"
)
// CreateSession inserts a new session record into the database
func (db *database) CreateSession(session *models.Session) error {
_, err := db.Exec(`
INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at)
VALUES (?, ?, ?, ?, ?)`,
session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt,
)
if err != nil {
return fmt.Errorf("failed to store session: %w", err)
}
return nil
}
// GetSessionByRefreshToken retrieves a session by its refresh token
func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) {
session := &models.Session{}
err := db.QueryRow(`
SELECT id, user_id, refresh_token, expires_at, created_at
FROM sessions
WHERE refresh_token = ? AND expires_at > ?`,
refreshToken, time.Now(),
).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("session not found or expired")
}
if err != nil {
return nil, fmt.Errorf("failed to fetch session: %w", err)
}
return session, nil
}
// DeleteSession removes a session from the database
func (db *database) DeleteSession(sessionID string) error {
result, err := db.Exec("DELETE FROM sessions WHERE id = ?", sessionID)
if err != nil {
return fmt.Errorf("failed to delete session: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("session not found")
}
return nil
}
// CleanExpiredSessions removes all expired sessions from the database
func (db *database) CleanExpiredSessions() error {
_, err := db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now())
if err != nil {
return fmt.Errorf("failed to clean expired sessions: %w", err)
}
return nil
}

View File

@@ -11,9 +11,16 @@ const (
JWTSecretKey = "jwt_secret"
)
// UserStats represents system-wide statistics
type UserStats struct {
TotalUsers int `json:"totalUsers"`
TotalWorkspaces int `json:"totalWorkspaces"`
ActiveUsers int `json:"activeUsers"` // Users with activity in last 30 days
}
// EnsureJWTSecret makes sure a JWT signing secret exists in the database
// If no secret exists, it generates and stores a new one
func (db *DB) EnsureJWTSecret() (string, error) {
func (db *database) EnsureJWTSecret() (string, error) {
// First, try to get existing secret
secret, err := db.GetSystemSetting(JWTSecretKey)
if err == nil {
@@ -36,7 +43,7 @@ func (db *DB) EnsureJWTSecret() (string, error) {
}
// GetSystemSetting retrieves a system setting by key
func (db *DB) GetSystemSetting(key string) (string, error) {
func (db *database) GetSystemSetting(key string) (string, error) {
var value string
err := db.QueryRow("SELECT value FROM system_settings WHERE key = ?", key).Scan(&value)
if err != nil {
@@ -46,7 +53,7 @@ func (db *DB) GetSystemSetting(key string) (string, error) {
}
// SetSystemSetting stores or updates a system setting
func (db *DB) SetSystemSetting(key, value string) error {
func (db *database) SetSystemSetting(key, value string) error {
_, err := db.Exec(`
INSERT INTO system_settings (key, value)
VALUES (?, ?)
@@ -64,3 +71,32 @@ func generateRandomSecret(bytes int) (string, error) {
}
return base64.StdEncoding.EncodeToString(b), nil
}
// GetSystemStats returns system-wide statistics
func (db *database) GetSystemStats() (*UserStats, error) {
stats := &UserStats{}
// Get total users
err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&stats.TotalUsers)
if err != nil {
return nil, err
}
// Get total workspaces
err = db.QueryRow("SELECT COUNT(*) FROM workspaces").Scan(&stats.TotalWorkspaces)
if err != nil {
return nil, err
}
// Get active users (users with activity in last 30 days)
err = db.QueryRow(`
SELECT COUNT(DISTINCT user_id)
FROM sessions
WHERE created_at > datetime('now', '-30 days')`).
Scan(&stats.ActiveUsers)
if err != nil {
return nil, err
}
return stats, nil
}

View File

@@ -6,7 +6,7 @@ import (
)
// CreateUser inserts a new user record into the database
func (db *DB) CreateUser(user *models.User) (*models.User, error) {
func (db *database) CreateUser(user *models.User) (*models.User, error) {
tx, err := db.Begin()
if err != nil {
return nil, err
@@ -62,7 +62,7 @@ func (db *DB) CreateUser(user *models.User) (*models.User, error) {
}
// Helper function to create a workspace in a transaction
func (db *DB) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
result, err := tx.Exec(`
INSERT INTO workspaces (
user_id, name,
@@ -87,7 +87,7 @@ func (db *DB) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
}
// GetUserByID retrieves a user by ID
func (db *DB) GetUserByID(id int) (*models.User, error) {
func (db *database) GetUserByID(id int) (*models.User, error) {
user := &models.User{}
err := db.QueryRow(`
SELECT
@@ -104,7 +104,7 @@ func (db *DB) GetUserByID(id int) (*models.User, error) {
}
// GetUserByEmail retrieves a user by email
func (db *DB) GetUserByEmail(email string) (*models.User, error) {
func (db *database) GetUserByEmail(email string) (*models.User, error) {
user := &models.User{}
err := db.QueryRow(`
SELECT
@@ -122,7 +122,7 @@ func (db *DB) GetUserByEmail(email string) (*models.User, error) {
}
// UpdateUser updates a user's information
func (db *DB) UpdateUser(user *models.User) error {
func (db *database) UpdateUser(user *models.User) error {
_, err := db.Exec(`
UPDATE users
SET email = ?, display_name = ?, password_hash = ?, role = ?, last_workspace_id = ?
@@ -131,8 +131,36 @@ func (db *DB) UpdateUser(user *models.User) error {
return err
}
// GetAllUsers returns a list of all users in the system
func (db *database) GetAllUsers() ([]*models.User, error) {
rows, err := db.Query(`
SELECT
id, email, display_name, role, created_at,
last_workspace_id
FROM users
ORDER BY id ASC`)
if err != nil {
return nil, err
}
defer rows.Close()
var users []*models.User
for rows.Next() {
user := &models.User{}
err := rows.Scan(
&user.ID, &user.Email, &user.DisplayName, &user.Role,
&user.CreatedAt, &user.LastWorkspaceID,
)
if err != nil {
return nil, err
}
users = append(users, user)
}
return users, nil
}
// UpdateLastWorkspace updates the last workspace the user accessed
func (db *DB) UpdateLastWorkspace(userID int, workspaceName string) error {
func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error {
tx, err := db.Begin()
if err != nil {
return err
@@ -155,7 +183,7 @@ func (db *DB) UpdateLastWorkspace(userID int, workspaceName string) error {
}
// DeleteUser deletes a user and all their workspaces
func (db *DB) DeleteUser(id int) error {
func (db *database) DeleteUser(id int) error {
tx, err := db.Begin()
if err != nil {
return err
@@ -178,7 +206,7 @@ func (db *DB) DeleteUser(id int) error {
}
// GetLastWorkspaceName returns the name of the last workspace the user accessed
func (db *DB) GetLastWorkspaceName(userID int) (string, error) {
func (db *database) GetLastWorkspaceName(userID int) (string, error) {
var workspaceName string
err := db.QueryRow(`
SELECT
@@ -189,3 +217,10 @@ func (db *DB) GetLastWorkspaceName(userID int) (string, error) {
Scan(&workspaceName)
return workspaceName, err
}
// CountAdminUsers returns the number of admin users in the system
func (db *database) CountAdminUsers() (int, error) {
var count int
err := db.QueryRow("SELECT COUNT(*) FROM users WHERE role = 'admin'").Scan(&count)
return count, err
}

View File

@@ -7,7 +7,7 @@ import (
)
// CreateWorkspace inserts a new workspace record into the database
func (db *DB) CreateWorkspace(workspace *models.Workspace) error {
func (db *database) CreateWorkspace(workspace *models.Workspace) error {
// Set default settings if not provided
if workspace.Theme == "" {
workspace.GetDefaultSettings()
@@ -42,7 +42,7 @@ func (db *DB) CreateWorkspace(workspace *models.Workspace) error {
}
// GetWorkspaceByID retrieves a workspace by its ID
func (db *DB) GetWorkspaceByID(id int) (*models.Workspace, error) {
func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) {
workspace := &models.Workspace{}
var encryptedToken string
@@ -75,7 +75,7 @@ func (db *DB) GetWorkspaceByID(id int) (*models.Workspace, error) {
}
// GetWorkspaceByName retrieves a workspace by its name and user ID
func (db *DB) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) {
func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) {
workspace := &models.Workspace{}
var encryptedToken string
@@ -108,7 +108,7 @@ func (db *DB) GetWorkspaceByName(userID int, workspaceName string) (*models.Work
}
// UpdateWorkspace updates a workspace record in the database
func (db *DB) UpdateWorkspace(workspace *models.Workspace) error {
func (db *database) UpdateWorkspace(workspace *models.Workspace) error {
// Encrypt token before storing
encryptedToken, err := db.encryptToken(workspace.GitToken)
if err != nil {
@@ -146,7 +146,7 @@ func (db *DB) UpdateWorkspace(workspace *models.Workspace) error {
}
// GetWorkspacesByUserID retrieves all workspaces for a user
func (db *DB) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
rows, err := db.Query(`
SELECT
id, user_id, name, created_at,
@@ -189,7 +189,7 @@ func (db *DB) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
// UpdateWorkspaceSettings updates only the settings portion of a workspace
// This is useful when you don't want to modify the name or other core workspace properties
func (db *DB) UpdateWorkspaceSettings(workspace *models.Workspace) error {
func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
_, err := db.Exec(`
UPDATE workspaces
SET
@@ -218,31 +218,31 @@ func (db *DB) UpdateWorkspaceSettings(workspace *models.Workspace) error {
}
// DeleteWorkspace removes a workspace record from the database
func (db *DB) DeleteWorkspace(id int) error {
func (db *database) DeleteWorkspace(id int) error {
_, err := db.Exec("DELETE FROM workspaces WHERE id = ?", id)
return err
}
// DeleteWorkspaceTx removes a workspace record from the database within a transaction
func (db *DB) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
_, err := tx.Exec("DELETE FROM workspaces WHERE id = ?", id)
return err
}
// UpdateLastWorkspaceTx sets the last workspace for a user in with a transaction
func (db *DB) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error {
func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error {
_, err := tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?", workspaceID, userID)
return err
}
// UpdateLastOpenedFile updates the last opened file path for a workspace
func (db *DB) UpdateLastOpenedFile(workspaceID int, filePath string) error {
func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error {
_, err := db.Exec("UPDATE workspaces SET last_opened_file_path = ? WHERE id = ?", filePath, workspaceID)
return err
}
// GetLastOpenedFile retrieves the last opened file path for a workspace
func (db *DB) GetLastOpenedFile(workspaceID int) (string, error) {
func (db *database) GetLastOpenedFile(workspaceID int) (string, error) {
var filePath sql.NullString
err := db.QueryRow("SELECT last_opened_file_path FROM workspaces WHERE id = ?", workspaceID).Scan(&filePath)
if err != nil {
@@ -255,7 +255,7 @@ func (db *DB) GetLastOpenedFile(workspaceID int) (string, error) {
}
// GetAllWorkspaces retrieves all workspaces in the database
func (db *DB) GetAllWorkspaces() ([]*models.Workspace, error) {
func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) {
rows, err := db.Query(`
SELECT
id, user_id, name, created_at,

View File

@@ -16,10 +16,10 @@ type LoginRequest struct {
}
type LoginResponse struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
User *models.User `json:"user"`
Session *auth.Session `json:"session"`
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
User *models.User `json:"user"`
Session *models.Session `json:"session"`
}
type RefreshRequest struct {

View File

@@ -9,12 +9,12 @@ import (
// Handler provides common functionality for all handlers
type Handler struct {
DB *db.DB
DB db.Database
Storage storage.Manager
}
// NewHandler creates a new handler with the given dependencies
func NewHandler(db *db.DB, s storage.Manager) *Handler {
func NewHandler(db db.Database, s storage.Manager) *Handler {
return &Handler{
DB: db,
Storage: s,

View File

@@ -171,8 +171,7 @@ func (h *Handler) DeleteAccount() http.HandlerFunc {
// Prevent admin from deleting their own account if they're the last admin
if user.Role == "admin" {
// Count number of admin users
adminCount := 0
err := h.DB.QueryRow("SELECT COUNT(*) FROM users WHERE role = 'admin'").Scan(&adminCount)
adminCount, err := h.DB.CountAdminUsers()
if err != nil {
http.Error(w, "Failed to verify admin status", http.StatusInternalServerError)
return

View File

@@ -29,7 +29,7 @@ func WithUserContext(next http.Handler) http.Handler {
}
// Workspace context
func WithWorkspaceContext(db *db.DB) func(http.Handler) http.Handler {
func WithWorkspaceContext(db db.Database) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, ok := httpcontext.GetRequestContext(w, r)

View File

@@ -0,0 +1,13 @@
// Package models contains the data models used throughout the application. These models are used to represent data in the database, as well as to validate and serialize data in the application.
package models
import "time"
// Session represents a user session in the database
type Session struct {
ID string // Unique session identifier
UserID int // ID of the user this session belongs to
RefreshToken string // The refresh token associated with this session
ExpiresAt time.Time // When this session expires
CreatedAt time.Time // When this session was created
}

View File

@@ -13,11 +13,11 @@ import (
)
type UserService struct {
DB *db.DB
DB db.Database
Storage storage.Manager
}
func NewUserService(database *db.DB, s storage.Manager) *UserService {
func NewUserService(database db.Database, s storage.Manager) *UserService {
return &UserService{
DB: database,
Storage: s,