From 807e96a76c885a7717ae3e89245b20edc3236939 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Thu, 21 Nov 2024 22:36:12 +0100 Subject: [PATCH] Rework db package to make it testable --- server/cmd/server/main.go | 3 +- server/internal/api/routes.go | 2 +- server/internal/auth/session.go | 67 +++++------------ server/internal/db/admin.go | 68 ------------------ server/internal/db/db.go | 72 +++++++++++++++++-- server/internal/db/migrations.go | 2 +- server/internal/db/sessions.go | 70 ++++++++++++++++++ .../db/{system_settings.go => system.go} | 42 ++++++++++- server/internal/db/users.go | 51 ++++++++++--- server/internal/db/workspaces.go | 24 +++---- server/internal/handlers/auth_handlers.go | 8 +-- server/internal/handlers/handlers.go | 4 +- server/internal/handlers/user_handlers.go | 3 +- server/internal/middleware/context.go | 2 +- server/internal/models/session.go | 13 ++++ server/internal/user/user.go | 4 +- 16 files changed, 274 insertions(+), 161 deletions(-) delete mode 100644 server/internal/db/admin.go create mode 100644 server/internal/db/sessions.go rename server/internal/db/{system_settings.go => system.go} (57%) create mode 100644 server/internal/models/session.go diff --git a/server/cmd/server/main.go b/server/cmd/server/main.go index 951d8b8..7fbf006 100644 --- a/server/cmd/server/main.go +++ b/server/cmd/server/main.go @@ -1,3 +1,4 @@ +// Package main contains the main entry point for the application. It sets up the server, database, and other services, and starts the server. package main import ( @@ -61,7 +62,7 @@ func main() { authMiddleware := auth.NewMiddleware(jwtManager) // Initialize session service - sessionService := auth.NewSessionService(database.DB, jwtManager) + sessionService := auth.NewSessionService(database, jwtManager) // Set up router r := chi.NewRouter() diff --git a/server/internal/api/routes.go b/server/internal/api/routes.go index 71a873b..35751aa 100644 --- a/server/internal/api/routes.go +++ b/server/internal/api/routes.go @@ -12,7 +12,7 @@ import ( ) // SetupRoutes configures the API routes -func SetupRoutes(r chi.Router, db *db.DB, s storage.Manager, authMiddleware *auth.Middleware, sessionService *auth.SessionService) { +func SetupRoutes(r chi.Router, db db.Database, s storage.Manager, authMiddleware *auth.Middleware, sessionService *auth.SessionService) { handler := &handlers.Handler{ DB: db, diff --git a/server/internal/auth/session.go b/server/internal/auth/session.go index bc0d165..b6217d6 100644 --- a/server/internal/auth/session.go +++ b/server/internal/auth/session.go @@ -1,33 +1,25 @@ package auth import ( - "database/sql" "fmt" + "novamd/internal/db" + "novamd/internal/models" "time" "github.com/google/uuid" ) -// Session represents a user session in the database -type Session struct { - ID string // Unique session identifier - UserID int // ID of the user this session belongs to - RefreshToken string // The refresh token associated with this session - ExpiresAt time.Time // When this session expires - CreatedAt time.Time // When this session was created -} - // SessionService manages user sessions in the database type SessionService struct { - db *sql.DB // Database connection - jwtManager JWTManager // JWT Manager for token operations + db db.SessionStore // Database store for sessions + jwtManager JWTManager // JWT Manager for token operations } // NewSessionService creates a new session service // Parameters: // - db: database connection // - jwtManager: JWT service for token operations -func NewSessionService(db *sql.DB, jwtManager JWTManager) *SessionService { +func NewSessionService(db db.SessionStore, jwtManager JWTManager) *SessionService { return &SessionService{ db: db, jwtManager: jwtManager, @@ -42,7 +34,7 @@ func NewSessionService(db *sql.DB, jwtManager JWTManager) *SessionService { // - session: the created session // - accessToken: a new access token // - error: any error that occurred -func (s *SessionService) CreateSession(userID int, role string) (*Session, string, error) { +func (s *SessionService) CreateSession(userID int, role string) (*models.Session, string, error) { // Generate both access and refresh tokens accessToken, err := s.jwtManager.GenerateAccessToken(userID, role) if err != nil { @@ -61,7 +53,7 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin } // Create a new session record - session := &Session{ + session := &models.Session{ ID: uuid.New().String(), UserID: userID, RefreshToken: refreshToken, @@ -69,14 +61,9 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin CreatedAt: time.Now(), } - // Store the session in the database - _, err = s.db.Exec(` - INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at) - VALUES (?, ?, ?, ?, ?)`, - session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt, - ) - if err != nil { - return nil, "", fmt.Errorf("failed to store session: %w", err) + // Store the session + if err := s.db.CreateSession(session); err != nil { + return nil, "", err } return session, accessToken, nil @@ -89,28 +76,18 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin // - string: a new access token // - error: any error that occurred func (s *SessionService) RefreshSession(refreshToken string) (string, error) { + // Get session from database + _, err := s.db.GetSessionByRefreshToken(refreshToken) + if err != nil { + return "", fmt.Errorf("invalid session: %w", err) + } + // Validate the refresh token claims, err := s.jwtManager.ValidateToken(refreshToken) if err != nil { return "", fmt.Errorf("invalid refresh token: %w", err) } - // Check if the session exists and is not expired - var session Session - err = s.db.QueryRow(` - SELECT id, user_id, refresh_token, expires_at, created_at - FROM sessions - WHERE refresh_token = ? AND expires_at > ?`, - refreshToken, time.Now(), - ).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt) - - if err == sql.ErrNoRows { - return "", fmt.Errorf("session not found or expired") - } - if err != nil { - return "", fmt.Errorf("failed to fetch session: %w", err) - } - // Generate a new access token return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role) } @@ -121,20 +98,12 @@ func (s *SessionService) RefreshSession(refreshToken string) (string, error) { // Returns: // - error: any error that occurred func (s *SessionService) InvalidateSession(sessionID string) error { - _, err := s.db.Exec("DELETE FROM sessions WHERE id = ?", sessionID) - if err != nil { - return fmt.Errorf("failed to invalidate session: %w", err) - } - return nil + return s.db.DeleteSession(sessionID) } // CleanExpiredSessions removes all expired sessions from the database // Returns: // - error: any error that occurred func (s *SessionService) CleanExpiredSessions() error { - _, err := s.db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now()) - if err != nil { - return fmt.Errorf("failed to clean expired sessions: %w", err) - } - return nil + return s.db.CleanExpiredSessions() } diff --git a/server/internal/db/admin.go b/server/internal/db/admin.go deleted file mode 100644 index 15c8a84..0000000 --- a/server/internal/db/admin.go +++ /dev/null @@ -1,68 +0,0 @@ -// Package db provides the database access layer for the application. It contains methods for interacting with the database, such as creating, updating, and deleting records. -package db - -import "novamd/internal/models" - -// UserStats represents system-wide statistics -type UserStats struct { - TotalUsers int `json:"totalUsers"` - TotalWorkspaces int `json:"totalWorkspaces"` - ActiveUsers int `json:"activeUsers"` // Users with activity in last 30 days -} - -// GetAllUsers returns a list of all users in the system -func (db *DB) GetAllUsers() ([]*models.User, error) { - rows, err := db.Query(` - SELECT - id, email, display_name, role, created_at, - last_workspace_id - FROM users - ORDER BY id ASC`) - if err != nil { - return nil, err - } - defer rows.Close() - - var users []*models.User - for rows.Next() { - user := &models.User{} - err := rows.Scan( - &user.ID, &user.Email, &user.DisplayName, &user.Role, - &user.CreatedAt, &user.LastWorkspaceID, - ) - if err != nil { - return nil, err - } - users = append(users, user) - } - return users, nil -} - -// GetSystemStats returns system-wide statistics -func (db *DB) GetSystemStats() (*UserStats, error) { - stats := &UserStats{} - - // Get total users - err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&stats.TotalUsers) - if err != nil { - return nil, err - } - - // Get total workspaces - err = db.QueryRow("SELECT COUNT(*) FROM workspaces").Scan(&stats.TotalWorkspaces) - if err != nil { - return nil, err - } - - // Get active users (users with activity in last 30 days) - err = db.QueryRow(` - SELECT COUNT(DISTINCT user_id) - FROM sessions - WHERE created_at > datetime('now', '-30 days')`). - Scan(&stats.ActiveUsers) - if err != nil { - return nil, err - } - - return stats, nil -} diff --git a/server/internal/db/db.go b/server/internal/db/db.go index 8ca51c3..8cd04e9 100644 --- a/server/internal/db/db.go +++ b/server/internal/db/db.go @@ -1,3 +1,4 @@ +// Package db provides the database access layer for the application. It contains methods for interacting with the database, such as creating, updating, and deleting records. package db import ( @@ -5,18 +6,75 @@ import ( "fmt" "novamd/internal/crypto" + "novamd/internal/models" _ "github.com/mattn/go-sqlite3" // SQLite driver ) -// DB represents the database connection -type DB struct { +// UserStore defines the methods for interacting with user data in the database +type UserStore interface { + CreateUser(user *models.User) (*models.User, error) + GetUserByEmail(email string) (*models.User, error) + GetUserByID(userID int) (*models.User, error) + GetAllUsers() ([]*models.User, error) + UpdateUser(user *models.User) error + DeleteUser(userID int) error + UpdateLastWorkspace(userID int, workspaceName string) error + GetLastWorkspaceName(userID int) (string, error) + CountAdminUsers() (int, error) +} + +// WorkspaceStore defines the methods for interacting with workspace data in the database +type WorkspaceStore interface { + CreateWorkspace(workspace *models.Workspace) error + GetWorkspaceByID(workspaceID int) (*models.Workspace, error) + GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) + GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) + UpdateWorkspace(workspace *models.Workspace) error + DeleteWorkspace(workspaceID int) error + UpdateWorkspaceSettings(workspace *models.Workspace) error + DeleteWorkspaceTx(tx *sql.Tx, workspaceID int) error + UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error + UpdateLastOpenedFile(workspaceID int, filePath string) error + GetLastOpenedFile(workspaceID int) (string, error) + GetAllWorkspaces() ([]*models.Workspace, error) +} + +// SessionStore defines the methods for interacting with jwt sessions in the database +type SessionStore interface { + CreateSession(session *models.Session) error + GetSessionByRefreshToken(refreshToken string) (*models.Session, error) + DeleteSession(sessionID string) error + CleanExpiredSessions() error +} + +// SystemStore defines the methods for interacting with system settings and stats in the database +type SystemStore interface { + GetSystemStats() (*UserStats, error) + EnsureJWTSecret() (string, error) + GetSystemSetting(key string) (string, error) + SetSystemSetting(key, value string) error +} + +// Database defines the methods for interacting with the database +type Database interface { + UserStore + WorkspaceStore + SessionStore + SystemStore + Begin() (*sql.Tx, error) + Close() error + Migrate() error +} + +// database represents the database connection +type database struct { *sql.DB crypto *crypto.Crypto } // Init initializes the database connection -func Init(dbPath string, encryptionKey string) (*DB, error) { +func Init(dbPath string, encryptionKey string) (Database, error) { db, err := sql.Open("sqlite3", dbPath) if err != nil { return nil, err @@ -32,7 +90,7 @@ func Init(dbPath string, encryptionKey string) (*DB, error) { return nil, fmt.Errorf("failed to initialize encryption: %w", err) } - database := &DB{ + database := &database{ DB: db, crypto: cryptoService, } @@ -45,19 +103,19 @@ func Init(dbPath string, encryptionKey string) (*DB, error) { } // Close closes the database connection -func (db *DB) Close() error { +func (db *database) Close() error { return db.DB.Close() } // Helper methods for token encryption/decryption -func (db *DB) encryptToken(token string) (string, error) { +func (db *database) encryptToken(token string) (string, error) { if token == "" { return "", nil } return db.crypto.Encrypt(token) } -func (db *DB) decryptToken(token string) (string, error) { +func (db *database) decryptToken(token string) (string, error) { if token == "" { return "", nil } diff --git a/server/internal/db/migrations.go b/server/internal/db/migrations.go index 11cfc74..f59b20f 100644 --- a/server/internal/db/migrations.go +++ b/server/internal/db/migrations.go @@ -83,7 +83,7 @@ var migrations = []Migration{ } // Migrate applies all database migrations -func (db *DB) Migrate() error { +func (db *database) Migrate() error { // Create migrations table if it doesn't exist _, err := db.Exec(`CREATE TABLE IF NOT EXISTS migrations ( version INTEGER PRIMARY KEY diff --git a/server/internal/db/sessions.go b/server/internal/db/sessions.go new file mode 100644 index 0000000..596dc64 --- /dev/null +++ b/server/internal/db/sessions.go @@ -0,0 +1,70 @@ +package db + +import ( + "database/sql" + "fmt" + "time" + + "novamd/internal/models" +) + +// CreateSession inserts a new session record into the database +func (db *database) CreateSession(session *models.Session) error { + _, err := db.Exec(` + INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at) + VALUES (?, ?, ?, ?, ?)`, + session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt, + ) + if err != nil { + return fmt.Errorf("failed to store session: %w", err) + } + return nil +} + +// GetSessionByRefreshToken retrieves a session by its refresh token +func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) { + session := &models.Session{} + err := db.QueryRow(` + SELECT id, user_id, refresh_token, expires_at, created_at + FROM sessions + WHERE refresh_token = ? AND expires_at > ?`, + refreshToken, time.Now(), + ).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt) + + if err == sql.ErrNoRows { + return nil, fmt.Errorf("session not found or expired") + } + if err != nil { + return nil, fmt.Errorf("failed to fetch session: %w", err) + } + + return session, nil +} + +// DeleteSession removes a session from the database +func (db *database) DeleteSession(sessionID string) error { + result, err := db.Exec("DELETE FROM sessions WHERE id = ?", sessionID) + if err != nil { + return fmt.Errorf("failed to delete session: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rowsAffected == 0 { + return fmt.Errorf("session not found") + } + + return nil +} + +// CleanExpiredSessions removes all expired sessions from the database +func (db *database) CleanExpiredSessions() error { + _, err := db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now()) + if err != nil { + return fmt.Errorf("failed to clean expired sessions: %w", err) + } + return nil +} diff --git a/server/internal/db/system_settings.go b/server/internal/db/system.go similarity index 57% rename from server/internal/db/system_settings.go rename to server/internal/db/system.go index 76fdfa5..f954b34 100644 --- a/server/internal/db/system_settings.go +++ b/server/internal/db/system.go @@ -11,9 +11,16 @@ const ( JWTSecretKey = "jwt_secret" ) +// UserStats represents system-wide statistics +type UserStats struct { + TotalUsers int `json:"totalUsers"` + TotalWorkspaces int `json:"totalWorkspaces"` + ActiveUsers int `json:"activeUsers"` // Users with activity in last 30 days +} + // EnsureJWTSecret makes sure a JWT signing secret exists in the database // If no secret exists, it generates and stores a new one -func (db *DB) EnsureJWTSecret() (string, error) { +func (db *database) EnsureJWTSecret() (string, error) { // First, try to get existing secret secret, err := db.GetSystemSetting(JWTSecretKey) if err == nil { @@ -36,7 +43,7 @@ func (db *DB) EnsureJWTSecret() (string, error) { } // GetSystemSetting retrieves a system setting by key -func (db *DB) GetSystemSetting(key string) (string, error) { +func (db *database) GetSystemSetting(key string) (string, error) { var value string err := db.QueryRow("SELECT value FROM system_settings WHERE key = ?", key).Scan(&value) if err != nil { @@ -46,7 +53,7 @@ func (db *DB) GetSystemSetting(key string) (string, error) { } // SetSystemSetting stores or updates a system setting -func (db *DB) SetSystemSetting(key, value string) error { +func (db *database) SetSystemSetting(key, value string) error { _, err := db.Exec(` INSERT INTO system_settings (key, value) VALUES (?, ?) @@ -64,3 +71,32 @@ func generateRandomSecret(bytes int) (string, error) { } return base64.StdEncoding.EncodeToString(b), nil } + +// GetSystemStats returns system-wide statistics +func (db *database) GetSystemStats() (*UserStats, error) { + stats := &UserStats{} + + // Get total users + err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&stats.TotalUsers) + if err != nil { + return nil, err + } + + // Get total workspaces + err = db.QueryRow("SELECT COUNT(*) FROM workspaces").Scan(&stats.TotalWorkspaces) + if err != nil { + return nil, err + } + + // Get active users (users with activity in last 30 days) + err = db.QueryRow(` + SELECT COUNT(DISTINCT user_id) + FROM sessions + WHERE created_at > datetime('now', '-30 days')`). + Scan(&stats.ActiveUsers) + if err != nil { + return nil, err + } + + return stats, nil +} diff --git a/server/internal/db/users.go b/server/internal/db/users.go index 6a040fb..17cc374 100644 --- a/server/internal/db/users.go +++ b/server/internal/db/users.go @@ -6,7 +6,7 @@ import ( ) // CreateUser inserts a new user record into the database -func (db *DB) CreateUser(user *models.User) (*models.User, error) { +func (db *database) CreateUser(user *models.User) (*models.User, error) { tx, err := db.Begin() if err != nil { return nil, err @@ -62,7 +62,7 @@ func (db *DB) CreateUser(user *models.User) (*models.User, error) { } // Helper function to create a workspace in a transaction -func (db *DB) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error { +func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error { result, err := tx.Exec(` INSERT INTO workspaces ( user_id, name, @@ -87,7 +87,7 @@ func (db *DB) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error { } // GetUserByID retrieves a user by ID -func (db *DB) GetUserByID(id int) (*models.User, error) { +func (db *database) GetUserByID(id int) (*models.User, error) { user := &models.User{} err := db.QueryRow(` SELECT @@ -104,7 +104,7 @@ func (db *DB) GetUserByID(id int) (*models.User, error) { } // GetUserByEmail retrieves a user by email -func (db *DB) GetUserByEmail(email string) (*models.User, error) { +func (db *database) GetUserByEmail(email string) (*models.User, error) { user := &models.User{} err := db.QueryRow(` SELECT @@ -122,7 +122,7 @@ func (db *DB) GetUserByEmail(email string) (*models.User, error) { } // UpdateUser updates a user's information -func (db *DB) UpdateUser(user *models.User) error { +func (db *database) UpdateUser(user *models.User) error { _, err := db.Exec(` UPDATE users SET email = ?, display_name = ?, password_hash = ?, role = ?, last_workspace_id = ? @@ -131,8 +131,36 @@ func (db *DB) UpdateUser(user *models.User) error { return err } +// GetAllUsers returns a list of all users in the system +func (db *database) GetAllUsers() ([]*models.User, error) { + rows, err := db.Query(` + SELECT + id, email, display_name, role, created_at, + last_workspace_id + FROM users + ORDER BY id ASC`) + if err != nil { + return nil, err + } + defer rows.Close() + + var users []*models.User + for rows.Next() { + user := &models.User{} + err := rows.Scan( + &user.ID, &user.Email, &user.DisplayName, &user.Role, + &user.CreatedAt, &user.LastWorkspaceID, + ) + if err != nil { + return nil, err + } + users = append(users, user) + } + return users, nil +} + // UpdateLastWorkspace updates the last workspace the user accessed -func (db *DB) UpdateLastWorkspace(userID int, workspaceName string) error { +func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error { tx, err := db.Begin() if err != nil { return err @@ -155,7 +183,7 @@ func (db *DB) UpdateLastWorkspace(userID int, workspaceName string) error { } // DeleteUser deletes a user and all their workspaces -func (db *DB) DeleteUser(id int) error { +func (db *database) DeleteUser(id int) error { tx, err := db.Begin() if err != nil { return err @@ -178,7 +206,7 @@ func (db *DB) DeleteUser(id int) error { } // GetLastWorkspaceName returns the name of the last workspace the user accessed -func (db *DB) GetLastWorkspaceName(userID int) (string, error) { +func (db *database) GetLastWorkspaceName(userID int) (string, error) { var workspaceName string err := db.QueryRow(` SELECT @@ -189,3 +217,10 @@ func (db *DB) GetLastWorkspaceName(userID int) (string, error) { Scan(&workspaceName) return workspaceName, err } + +// CountAdminUsers returns the number of admin users in the system +func (db *database) CountAdminUsers() (int, error) { + var count int + err := db.QueryRow("SELECT COUNT(*) FROM users WHERE role = 'admin'").Scan(&count) + return count, err +} diff --git a/server/internal/db/workspaces.go b/server/internal/db/workspaces.go index 68be97a..ce39ce5 100644 --- a/server/internal/db/workspaces.go +++ b/server/internal/db/workspaces.go @@ -7,7 +7,7 @@ import ( ) // CreateWorkspace inserts a new workspace record into the database -func (db *DB) CreateWorkspace(workspace *models.Workspace) error { +func (db *database) CreateWorkspace(workspace *models.Workspace) error { // Set default settings if not provided if workspace.Theme == "" { workspace.GetDefaultSettings() @@ -42,7 +42,7 @@ func (db *DB) CreateWorkspace(workspace *models.Workspace) error { } // GetWorkspaceByID retrieves a workspace by its ID -func (db *DB) GetWorkspaceByID(id int) (*models.Workspace, error) { +func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) { workspace := &models.Workspace{} var encryptedToken string @@ -75,7 +75,7 @@ func (db *DB) GetWorkspaceByID(id int) (*models.Workspace, error) { } // GetWorkspaceByName retrieves a workspace by its name and user ID -func (db *DB) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) { +func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) { workspace := &models.Workspace{} var encryptedToken string @@ -108,7 +108,7 @@ func (db *DB) GetWorkspaceByName(userID int, workspaceName string) (*models.Work } // UpdateWorkspace updates a workspace record in the database -func (db *DB) UpdateWorkspace(workspace *models.Workspace) error { +func (db *database) UpdateWorkspace(workspace *models.Workspace) error { // Encrypt token before storing encryptedToken, err := db.encryptToken(workspace.GitToken) if err != nil { @@ -146,7 +146,7 @@ func (db *DB) UpdateWorkspace(workspace *models.Workspace) error { } // GetWorkspacesByUserID retrieves all workspaces for a user -func (db *DB) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) { +func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) { rows, err := db.Query(` SELECT id, user_id, name, created_at, @@ -189,7 +189,7 @@ func (db *DB) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) { // UpdateWorkspaceSettings updates only the settings portion of a workspace // This is useful when you don't want to modify the name or other core workspace properties -func (db *DB) UpdateWorkspaceSettings(workspace *models.Workspace) error { +func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error { _, err := db.Exec(` UPDATE workspaces SET @@ -218,31 +218,31 @@ func (db *DB) UpdateWorkspaceSettings(workspace *models.Workspace) error { } // DeleteWorkspace removes a workspace record from the database -func (db *DB) DeleteWorkspace(id int) error { +func (db *database) DeleteWorkspace(id int) error { _, err := db.Exec("DELETE FROM workspaces WHERE id = ?", id) return err } // DeleteWorkspaceTx removes a workspace record from the database within a transaction -func (db *DB) DeleteWorkspaceTx(tx *sql.Tx, id int) error { +func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error { _, err := tx.Exec("DELETE FROM workspaces WHERE id = ?", id) return err } // UpdateLastWorkspaceTx sets the last workspace for a user in with a transaction -func (db *DB) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error { +func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error { _, err := tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?", workspaceID, userID) return err } // UpdateLastOpenedFile updates the last opened file path for a workspace -func (db *DB) UpdateLastOpenedFile(workspaceID int, filePath string) error { +func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error { _, err := db.Exec("UPDATE workspaces SET last_opened_file_path = ? WHERE id = ?", filePath, workspaceID) return err } // GetLastOpenedFile retrieves the last opened file path for a workspace -func (db *DB) GetLastOpenedFile(workspaceID int) (string, error) { +func (db *database) GetLastOpenedFile(workspaceID int) (string, error) { var filePath sql.NullString err := db.QueryRow("SELECT last_opened_file_path FROM workspaces WHERE id = ?", workspaceID).Scan(&filePath) if err != nil { @@ -255,7 +255,7 @@ func (db *DB) GetLastOpenedFile(workspaceID int) (string, error) { } // GetAllWorkspaces retrieves all workspaces in the database -func (db *DB) GetAllWorkspaces() ([]*models.Workspace, error) { +func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) { rows, err := db.Query(` SELECT id, user_id, name, created_at, diff --git a/server/internal/handlers/auth_handlers.go b/server/internal/handlers/auth_handlers.go index 319b6d9..cce4c30 100644 --- a/server/internal/handlers/auth_handlers.go +++ b/server/internal/handlers/auth_handlers.go @@ -16,10 +16,10 @@ type LoginRequest struct { } type LoginResponse struct { - AccessToken string `json:"accessToken"` - RefreshToken string `json:"refreshToken"` - User *models.User `json:"user"` - Session *auth.Session `json:"session"` + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` + User *models.User `json:"user"` + Session *models.Session `json:"session"` } type RefreshRequest struct { diff --git a/server/internal/handlers/handlers.go b/server/internal/handlers/handlers.go index 743e40d..7af3611 100644 --- a/server/internal/handlers/handlers.go +++ b/server/internal/handlers/handlers.go @@ -9,12 +9,12 @@ import ( // Handler provides common functionality for all handlers type Handler struct { - DB *db.DB + DB db.Database Storage storage.Manager } // NewHandler creates a new handler with the given dependencies -func NewHandler(db *db.DB, s storage.Manager) *Handler { +func NewHandler(db db.Database, s storage.Manager) *Handler { return &Handler{ DB: db, Storage: s, diff --git a/server/internal/handlers/user_handlers.go b/server/internal/handlers/user_handlers.go index fa57020..013baf0 100644 --- a/server/internal/handlers/user_handlers.go +++ b/server/internal/handlers/user_handlers.go @@ -171,8 +171,7 @@ func (h *Handler) DeleteAccount() http.HandlerFunc { // Prevent admin from deleting their own account if they're the last admin if user.Role == "admin" { // Count number of admin users - adminCount := 0 - err := h.DB.QueryRow("SELECT COUNT(*) FROM users WHERE role = 'admin'").Scan(&adminCount) + adminCount, err := h.DB.CountAdminUsers() if err != nil { http.Error(w, "Failed to verify admin status", http.StatusInternalServerError) return diff --git a/server/internal/middleware/context.go b/server/internal/middleware/context.go index 95a72d4..288ed24 100644 --- a/server/internal/middleware/context.go +++ b/server/internal/middleware/context.go @@ -29,7 +29,7 @@ func WithUserContext(next http.Handler) http.Handler { } // Workspace context -func WithWorkspaceContext(db *db.DB) func(http.Handler) http.Handler { +func WithWorkspaceContext(db db.Database) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx, ok := httpcontext.GetRequestContext(w, r) diff --git a/server/internal/models/session.go b/server/internal/models/session.go new file mode 100644 index 0000000..d0b8119 --- /dev/null +++ b/server/internal/models/session.go @@ -0,0 +1,13 @@ +// Package models contains the data models used throughout the application. These models are used to represent data in the database, as well as to validate and serialize data in the application. +package models + +import "time" + +// Session represents a user session in the database +type Session struct { + ID string // Unique session identifier + UserID int // ID of the user this session belongs to + RefreshToken string // The refresh token associated with this session + ExpiresAt time.Time // When this session expires + CreatedAt time.Time // When this session was created +} diff --git a/server/internal/user/user.go b/server/internal/user/user.go index 638ec39..20383ca 100644 --- a/server/internal/user/user.go +++ b/server/internal/user/user.go @@ -13,11 +13,11 @@ import ( ) type UserService struct { - DB *db.DB + DB db.Database Storage storage.Manager } -func NewUserService(database *db.DB, s storage.Manager) *UserService { +func NewUserService(database db.Database, s storage.Manager) *UserService { return &UserService{ DB: database, Storage: s,