mirror of
https://github.com/lordmathis/lemma.git
synced 2025-11-06 07:54:22 +00:00
Rework db package to make it testable
This commit is contained in:
@@ -12,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
// SetupRoutes configures the API routes
|
||||
func SetupRoutes(r chi.Router, db *db.DB, s storage.Manager, authMiddleware *auth.Middleware, sessionService *auth.SessionService) {
|
||||
func SetupRoutes(r chi.Router, db db.Database, s storage.Manager, authMiddleware *auth.Middleware, sessionService *auth.SessionService) {
|
||||
|
||||
handler := &handlers.Handler{
|
||||
DB: db,
|
||||
|
||||
@@ -1,33 +1,25 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"novamd/internal/db"
|
||||
"novamd/internal/models"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Session represents a user session in the database
|
||||
type Session struct {
|
||||
ID string // Unique session identifier
|
||||
UserID int // ID of the user this session belongs to
|
||||
RefreshToken string // The refresh token associated with this session
|
||||
ExpiresAt time.Time // When this session expires
|
||||
CreatedAt time.Time // When this session was created
|
||||
}
|
||||
|
||||
// SessionService manages user sessions in the database
|
||||
type SessionService struct {
|
||||
db *sql.DB // Database connection
|
||||
jwtManager JWTManager // JWT Manager for token operations
|
||||
db db.SessionStore // Database store for sessions
|
||||
jwtManager JWTManager // JWT Manager for token operations
|
||||
}
|
||||
|
||||
// NewSessionService creates a new session service
|
||||
// Parameters:
|
||||
// - db: database connection
|
||||
// - jwtManager: JWT service for token operations
|
||||
func NewSessionService(db *sql.DB, jwtManager JWTManager) *SessionService {
|
||||
func NewSessionService(db db.SessionStore, jwtManager JWTManager) *SessionService {
|
||||
return &SessionService{
|
||||
db: db,
|
||||
jwtManager: jwtManager,
|
||||
@@ -42,7 +34,7 @@ func NewSessionService(db *sql.DB, jwtManager JWTManager) *SessionService {
|
||||
// - session: the created session
|
||||
// - accessToken: a new access token
|
||||
// - error: any error that occurred
|
||||
func (s *SessionService) CreateSession(userID int, role string) (*Session, string, error) {
|
||||
func (s *SessionService) CreateSession(userID int, role string) (*models.Session, string, error) {
|
||||
// Generate both access and refresh tokens
|
||||
accessToken, err := s.jwtManager.GenerateAccessToken(userID, role)
|
||||
if err != nil {
|
||||
@@ -61,7 +53,7 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin
|
||||
}
|
||||
|
||||
// Create a new session record
|
||||
session := &Session{
|
||||
session := &models.Session{
|
||||
ID: uuid.New().String(),
|
||||
UserID: userID,
|
||||
RefreshToken: refreshToken,
|
||||
@@ -69,14 +61,9 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Store the session in the database
|
||||
_, err = s.db.Exec(`
|
||||
INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to store session: %w", err)
|
||||
// Store the session
|
||||
if err := s.db.CreateSession(session); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return session, accessToken, nil
|
||||
@@ -89,28 +76,18 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin
|
||||
// - string: a new access token
|
||||
// - error: any error that occurred
|
||||
func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
|
||||
// Get session from database
|
||||
_, err := s.db.GetSessionByRefreshToken(refreshToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid session: %w", err)
|
||||
}
|
||||
|
||||
// Validate the refresh token
|
||||
claims, err := s.jwtManager.ValidateToken(refreshToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid refresh token: %w", err)
|
||||
}
|
||||
|
||||
// Check if the session exists and is not expired
|
||||
var session Session
|
||||
err = s.db.QueryRow(`
|
||||
SELECT id, user_id, refresh_token, expires_at, created_at
|
||||
FROM sessions
|
||||
WHERE refresh_token = ? AND expires_at > ?`,
|
||||
refreshToken, time.Now(),
|
||||
).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return "", fmt.Errorf("session not found or expired")
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to fetch session: %w", err)
|
||||
}
|
||||
|
||||
// Generate a new access token
|
||||
return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role)
|
||||
}
|
||||
@@ -121,20 +98,12 @@ func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
|
||||
// Returns:
|
||||
// - error: any error that occurred
|
||||
func (s *SessionService) InvalidateSession(sessionID string) error {
|
||||
_, err := s.db.Exec("DELETE FROM sessions WHERE id = ?", sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to invalidate session: %w", err)
|
||||
}
|
||||
return nil
|
||||
return s.db.DeleteSession(sessionID)
|
||||
}
|
||||
|
||||
// CleanExpiredSessions removes all expired sessions from the database
|
||||
// Returns:
|
||||
// - error: any error that occurred
|
||||
func (s *SessionService) CleanExpiredSessions() error {
|
||||
_, err := s.db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to clean expired sessions: %w", err)
|
||||
}
|
||||
return nil
|
||||
return s.db.CleanExpiredSessions()
|
||||
}
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
// Package db provides the database access layer for the application. It contains methods for interacting with the database, such as creating, updating, and deleting records.
|
||||
package db
|
||||
|
||||
import "novamd/internal/models"
|
||||
|
||||
// UserStats represents system-wide statistics
|
||||
type UserStats struct {
|
||||
TotalUsers int `json:"totalUsers"`
|
||||
TotalWorkspaces int `json:"totalWorkspaces"`
|
||||
ActiveUsers int `json:"activeUsers"` // Users with activity in last 30 days
|
||||
}
|
||||
|
||||
// GetAllUsers returns a list of all users in the system
|
||||
func (db *DB) GetAllUsers() ([]*models.User, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT
|
||||
id, email, display_name, role, created_at,
|
||||
last_workspace_id
|
||||
FROM users
|
||||
ORDER BY id ASC`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var users []*models.User
|
||||
for rows.Next() {
|
||||
user := &models.User{}
|
||||
err := rows.Scan(
|
||||
&user.ID, &user.Email, &user.DisplayName, &user.Role,
|
||||
&user.CreatedAt, &user.LastWorkspaceID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users = append(users, user)
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetSystemStats returns system-wide statistics
|
||||
func (db *DB) GetSystemStats() (*UserStats, error) {
|
||||
stats := &UserStats{}
|
||||
|
||||
// Get total users
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&stats.TotalUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get total workspaces
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM workspaces").Scan(&stats.TotalWorkspaces)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get active users (users with activity in last 30 days)
|
||||
err = db.QueryRow(`
|
||||
SELECT COUNT(DISTINCT user_id)
|
||||
FROM sessions
|
||||
WHERE created_at > datetime('now', '-30 days')`).
|
||||
Scan(&stats.ActiveUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package db provides the database access layer for the application. It contains methods for interacting with the database, such as creating, updating, and deleting records.
|
||||
package db
|
||||
|
||||
import (
|
||||
@@ -5,18 +6,75 @@ import (
|
||||
"fmt"
|
||||
|
||||
"novamd/internal/crypto"
|
||||
"novamd/internal/models"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
)
|
||||
|
||||
// DB represents the database connection
|
||||
type DB struct {
|
||||
// UserStore defines the methods for interacting with user data in the database
|
||||
type UserStore interface {
|
||||
CreateUser(user *models.User) (*models.User, error)
|
||||
GetUserByEmail(email string) (*models.User, error)
|
||||
GetUserByID(userID int) (*models.User, error)
|
||||
GetAllUsers() ([]*models.User, error)
|
||||
UpdateUser(user *models.User) error
|
||||
DeleteUser(userID int) error
|
||||
UpdateLastWorkspace(userID int, workspaceName string) error
|
||||
GetLastWorkspaceName(userID int) (string, error)
|
||||
CountAdminUsers() (int, error)
|
||||
}
|
||||
|
||||
// WorkspaceStore defines the methods for interacting with workspace data in the database
|
||||
type WorkspaceStore interface {
|
||||
CreateWorkspace(workspace *models.Workspace) error
|
||||
GetWorkspaceByID(workspaceID int) (*models.Workspace, error)
|
||||
GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error)
|
||||
GetWorkspacesByUserID(userID int) ([]*models.Workspace, error)
|
||||
UpdateWorkspace(workspace *models.Workspace) error
|
||||
DeleteWorkspace(workspaceID int) error
|
||||
UpdateWorkspaceSettings(workspace *models.Workspace) error
|
||||
DeleteWorkspaceTx(tx *sql.Tx, workspaceID int) error
|
||||
UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error
|
||||
UpdateLastOpenedFile(workspaceID int, filePath string) error
|
||||
GetLastOpenedFile(workspaceID int) (string, error)
|
||||
GetAllWorkspaces() ([]*models.Workspace, error)
|
||||
}
|
||||
|
||||
// SessionStore defines the methods for interacting with jwt sessions in the database
|
||||
type SessionStore interface {
|
||||
CreateSession(session *models.Session) error
|
||||
GetSessionByRefreshToken(refreshToken string) (*models.Session, error)
|
||||
DeleteSession(sessionID string) error
|
||||
CleanExpiredSessions() error
|
||||
}
|
||||
|
||||
// SystemStore defines the methods for interacting with system settings and stats in the database
|
||||
type SystemStore interface {
|
||||
GetSystemStats() (*UserStats, error)
|
||||
EnsureJWTSecret() (string, error)
|
||||
GetSystemSetting(key string) (string, error)
|
||||
SetSystemSetting(key, value string) error
|
||||
}
|
||||
|
||||
// Database defines the methods for interacting with the database
|
||||
type Database interface {
|
||||
UserStore
|
||||
WorkspaceStore
|
||||
SessionStore
|
||||
SystemStore
|
||||
Begin() (*sql.Tx, error)
|
||||
Close() error
|
||||
Migrate() error
|
||||
}
|
||||
|
||||
// database represents the database connection
|
||||
type database struct {
|
||||
*sql.DB
|
||||
crypto *crypto.Crypto
|
||||
}
|
||||
|
||||
// Init initializes the database connection
|
||||
func Init(dbPath string, encryptionKey string) (*DB, error) {
|
||||
func Init(dbPath string, encryptionKey string) (Database, error) {
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -32,7 +90,7 @@ func Init(dbPath string, encryptionKey string) (*DB, error) {
|
||||
return nil, fmt.Errorf("failed to initialize encryption: %w", err)
|
||||
}
|
||||
|
||||
database := &DB{
|
||||
database := &database{
|
||||
DB: db,
|
||||
crypto: cryptoService,
|
||||
}
|
||||
@@ -45,19 +103,19 @@ func Init(dbPath string, encryptionKey string) (*DB, error) {
|
||||
}
|
||||
|
||||
// Close closes the database connection
|
||||
func (db *DB) Close() error {
|
||||
func (db *database) Close() error {
|
||||
return db.DB.Close()
|
||||
}
|
||||
|
||||
// Helper methods for token encryption/decryption
|
||||
func (db *DB) encryptToken(token string) (string, error) {
|
||||
func (db *database) encryptToken(token string) (string, error) {
|
||||
if token == "" {
|
||||
return "", nil
|
||||
}
|
||||
return db.crypto.Encrypt(token)
|
||||
}
|
||||
|
||||
func (db *DB) decryptToken(token string) (string, error) {
|
||||
func (db *database) decryptToken(token string) (string, error) {
|
||||
if token == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
@@ -83,7 +83,7 @@ var migrations = []Migration{
|
||||
}
|
||||
|
||||
// Migrate applies all database migrations
|
||||
func (db *DB) Migrate() error {
|
||||
func (db *database) Migrate() error {
|
||||
// Create migrations table if it doesn't exist
|
||||
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS migrations (
|
||||
version INTEGER PRIMARY KEY
|
||||
|
||||
70
server/internal/db/sessions.go
Normal file
70
server/internal/db/sessions.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"novamd/internal/models"
|
||||
)
|
||||
|
||||
// CreateSession inserts a new session record into the database
|
||||
func (db *database) CreateSession(session *models.Session) error {
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSessionByRefreshToken retrieves a session by its refresh token
|
||||
func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) {
|
||||
session := &models.Session{}
|
||||
err := db.QueryRow(`
|
||||
SELECT id, user_id, refresh_token, expires_at, created_at
|
||||
FROM sessions
|
||||
WHERE refresh_token = ? AND expires_at > ?`,
|
||||
refreshToken, time.Now(),
|
||||
).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("session not found or expired")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch session: %w", err)
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// DeleteSession removes a session from the database
|
||||
func (db *database) DeleteSession(sessionID string) error {
|
||||
result, err := db.Exec("DELETE FROM sessions WHERE id = ?", sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete session: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return fmt.Errorf("session not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanExpiredSessions removes all expired sessions from the database
|
||||
func (db *database) CleanExpiredSessions() error {
|
||||
_, err := db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to clean expired sessions: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -11,9 +11,16 @@ const (
|
||||
JWTSecretKey = "jwt_secret"
|
||||
)
|
||||
|
||||
// UserStats represents system-wide statistics
|
||||
type UserStats struct {
|
||||
TotalUsers int `json:"totalUsers"`
|
||||
TotalWorkspaces int `json:"totalWorkspaces"`
|
||||
ActiveUsers int `json:"activeUsers"` // Users with activity in last 30 days
|
||||
}
|
||||
|
||||
// EnsureJWTSecret makes sure a JWT signing secret exists in the database
|
||||
// If no secret exists, it generates and stores a new one
|
||||
func (db *DB) EnsureJWTSecret() (string, error) {
|
||||
func (db *database) EnsureJWTSecret() (string, error) {
|
||||
// First, try to get existing secret
|
||||
secret, err := db.GetSystemSetting(JWTSecretKey)
|
||||
if err == nil {
|
||||
@@ -36,7 +43,7 @@ func (db *DB) EnsureJWTSecret() (string, error) {
|
||||
}
|
||||
|
||||
// GetSystemSetting retrieves a system setting by key
|
||||
func (db *DB) GetSystemSetting(key string) (string, error) {
|
||||
func (db *database) GetSystemSetting(key string) (string, error) {
|
||||
var value string
|
||||
err := db.QueryRow("SELECT value FROM system_settings WHERE key = ?", key).Scan(&value)
|
||||
if err != nil {
|
||||
@@ -46,7 +53,7 @@ func (db *DB) GetSystemSetting(key string) (string, error) {
|
||||
}
|
||||
|
||||
// SetSystemSetting stores or updates a system setting
|
||||
func (db *DB) SetSystemSetting(key, value string) error {
|
||||
func (db *database) SetSystemSetting(key, value string) error {
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO system_settings (key, value)
|
||||
VALUES (?, ?)
|
||||
@@ -64,3 +71,32 @@ func generateRandomSecret(bytes int) (string, error) {
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// GetSystemStats returns system-wide statistics
|
||||
func (db *database) GetSystemStats() (*UserStats, error) {
|
||||
stats := &UserStats{}
|
||||
|
||||
// Get total users
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&stats.TotalUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get total workspaces
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM workspaces").Scan(&stats.TotalWorkspaces)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get active users (users with activity in last 30 days)
|
||||
err = db.QueryRow(`
|
||||
SELECT COUNT(DISTINCT user_id)
|
||||
FROM sessions
|
||||
WHERE created_at > datetime('now', '-30 days')`).
|
||||
Scan(&stats.ActiveUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
// CreateUser inserts a new user record into the database
|
||||
func (db *DB) CreateUser(user *models.User) (*models.User, error) {
|
||||
func (db *database) CreateUser(user *models.User) (*models.User, error) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -62,7 +62,7 @@ func (db *DB) CreateUser(user *models.User) (*models.User, error) {
|
||||
}
|
||||
|
||||
// Helper function to create a workspace in a transaction
|
||||
func (db *DB) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
|
||||
func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
|
||||
result, err := tx.Exec(`
|
||||
INSERT INTO workspaces (
|
||||
user_id, name,
|
||||
@@ -87,7 +87,7 @@ func (db *DB) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
|
||||
}
|
||||
|
||||
// GetUserByID retrieves a user by ID
|
||||
func (db *DB) GetUserByID(id int) (*models.User, error) {
|
||||
func (db *database) GetUserByID(id int) (*models.User, error) {
|
||||
user := &models.User{}
|
||||
err := db.QueryRow(`
|
||||
SELECT
|
||||
@@ -104,7 +104,7 @@ func (db *DB) GetUserByID(id int) (*models.User, error) {
|
||||
}
|
||||
|
||||
// GetUserByEmail retrieves a user by email
|
||||
func (db *DB) GetUserByEmail(email string) (*models.User, error) {
|
||||
func (db *database) GetUserByEmail(email string) (*models.User, error) {
|
||||
user := &models.User{}
|
||||
err := db.QueryRow(`
|
||||
SELECT
|
||||
@@ -122,7 +122,7 @@ func (db *DB) GetUserByEmail(email string) (*models.User, error) {
|
||||
}
|
||||
|
||||
// UpdateUser updates a user's information
|
||||
func (db *DB) UpdateUser(user *models.User) error {
|
||||
func (db *database) UpdateUser(user *models.User) error {
|
||||
_, err := db.Exec(`
|
||||
UPDATE users
|
||||
SET email = ?, display_name = ?, password_hash = ?, role = ?, last_workspace_id = ?
|
||||
@@ -131,8 +131,36 @@ func (db *DB) UpdateUser(user *models.User) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// GetAllUsers returns a list of all users in the system
|
||||
func (db *database) GetAllUsers() ([]*models.User, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT
|
||||
id, email, display_name, role, created_at,
|
||||
last_workspace_id
|
||||
FROM users
|
||||
ORDER BY id ASC`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var users []*models.User
|
||||
for rows.Next() {
|
||||
user := &models.User{}
|
||||
err := rows.Scan(
|
||||
&user.ID, &user.Email, &user.DisplayName, &user.Role,
|
||||
&user.CreatedAt, &user.LastWorkspaceID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users = append(users, user)
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// UpdateLastWorkspace updates the last workspace the user accessed
|
||||
func (db *DB) UpdateLastWorkspace(userID int, workspaceName string) error {
|
||||
func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -155,7 +183,7 @@ func (db *DB) UpdateLastWorkspace(userID int, workspaceName string) error {
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user and all their workspaces
|
||||
func (db *DB) DeleteUser(id int) error {
|
||||
func (db *database) DeleteUser(id int) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -178,7 +206,7 @@ func (db *DB) DeleteUser(id int) error {
|
||||
}
|
||||
|
||||
// GetLastWorkspaceName returns the name of the last workspace the user accessed
|
||||
func (db *DB) GetLastWorkspaceName(userID int) (string, error) {
|
||||
func (db *database) GetLastWorkspaceName(userID int) (string, error) {
|
||||
var workspaceName string
|
||||
err := db.QueryRow(`
|
||||
SELECT
|
||||
@@ -189,3 +217,10 @@ func (db *DB) GetLastWorkspaceName(userID int) (string, error) {
|
||||
Scan(&workspaceName)
|
||||
return workspaceName, err
|
||||
}
|
||||
|
||||
// CountAdminUsers returns the number of admin users in the system
|
||||
func (db *database) CountAdminUsers() (int, error) {
|
||||
var count int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM users WHERE role = 'admin'").Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
// CreateWorkspace inserts a new workspace record into the database
|
||||
func (db *DB) CreateWorkspace(workspace *models.Workspace) error {
|
||||
func (db *database) CreateWorkspace(workspace *models.Workspace) error {
|
||||
// Set default settings if not provided
|
||||
if workspace.Theme == "" {
|
||||
workspace.GetDefaultSettings()
|
||||
@@ -42,7 +42,7 @@ func (db *DB) CreateWorkspace(workspace *models.Workspace) error {
|
||||
}
|
||||
|
||||
// GetWorkspaceByID retrieves a workspace by its ID
|
||||
func (db *DB) GetWorkspaceByID(id int) (*models.Workspace, error) {
|
||||
func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) {
|
||||
workspace := &models.Workspace{}
|
||||
var encryptedToken string
|
||||
|
||||
@@ -75,7 +75,7 @@ func (db *DB) GetWorkspaceByID(id int) (*models.Workspace, error) {
|
||||
}
|
||||
|
||||
// GetWorkspaceByName retrieves a workspace by its name and user ID
|
||||
func (db *DB) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) {
|
||||
func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) {
|
||||
workspace := &models.Workspace{}
|
||||
var encryptedToken string
|
||||
|
||||
@@ -108,7 +108,7 @@ func (db *DB) GetWorkspaceByName(userID int, workspaceName string) (*models.Work
|
||||
}
|
||||
|
||||
// UpdateWorkspace updates a workspace record in the database
|
||||
func (db *DB) UpdateWorkspace(workspace *models.Workspace) error {
|
||||
func (db *database) UpdateWorkspace(workspace *models.Workspace) error {
|
||||
// Encrypt token before storing
|
||||
encryptedToken, err := db.encryptToken(workspace.GitToken)
|
||||
if err != nil {
|
||||
@@ -146,7 +146,7 @@ func (db *DB) UpdateWorkspace(workspace *models.Workspace) error {
|
||||
}
|
||||
|
||||
// GetWorkspacesByUserID retrieves all workspaces for a user
|
||||
func (db *DB) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
|
||||
func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT
|
||||
id, user_id, name, created_at,
|
||||
@@ -189,7 +189,7 @@ func (db *DB) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
|
||||
|
||||
// UpdateWorkspaceSettings updates only the settings portion of a workspace
|
||||
// This is useful when you don't want to modify the name or other core workspace properties
|
||||
func (db *DB) UpdateWorkspaceSettings(workspace *models.Workspace) error {
|
||||
func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
|
||||
_, err := db.Exec(`
|
||||
UPDATE workspaces
|
||||
SET
|
||||
@@ -218,31 +218,31 @@ func (db *DB) UpdateWorkspaceSettings(workspace *models.Workspace) error {
|
||||
}
|
||||
|
||||
// DeleteWorkspace removes a workspace record from the database
|
||||
func (db *DB) DeleteWorkspace(id int) error {
|
||||
func (db *database) DeleteWorkspace(id int) error {
|
||||
_, err := db.Exec("DELETE FROM workspaces WHERE id = ?", id)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteWorkspaceTx removes a workspace record from the database within a transaction
|
||||
func (db *DB) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
|
||||
func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
|
||||
_, err := tx.Exec("DELETE FROM workspaces WHERE id = ?", id)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateLastWorkspaceTx sets the last workspace for a user in with a transaction
|
||||
func (db *DB) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error {
|
||||
func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error {
|
||||
_, err := tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?", workspaceID, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateLastOpenedFile updates the last opened file path for a workspace
|
||||
func (db *DB) UpdateLastOpenedFile(workspaceID int, filePath string) error {
|
||||
func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error {
|
||||
_, err := db.Exec("UPDATE workspaces SET last_opened_file_path = ? WHERE id = ?", filePath, workspaceID)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetLastOpenedFile retrieves the last opened file path for a workspace
|
||||
func (db *DB) GetLastOpenedFile(workspaceID int) (string, error) {
|
||||
func (db *database) GetLastOpenedFile(workspaceID int) (string, error) {
|
||||
var filePath sql.NullString
|
||||
err := db.QueryRow("SELECT last_opened_file_path FROM workspaces WHERE id = ?", workspaceID).Scan(&filePath)
|
||||
if err != nil {
|
||||
@@ -255,7 +255,7 @@ func (db *DB) GetLastOpenedFile(workspaceID int) (string, error) {
|
||||
}
|
||||
|
||||
// GetAllWorkspaces retrieves all workspaces in the database
|
||||
func (db *DB) GetAllWorkspaces() ([]*models.Workspace, error) {
|
||||
func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT
|
||||
id, user_id, name, created_at,
|
||||
|
||||
@@ -16,10 +16,10 @@ type LoginRequest struct {
|
||||
}
|
||||
|
||||
type LoginResponse struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
User *models.User `json:"user"`
|
||||
Session *auth.Session `json:"session"`
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
User *models.User `json:"user"`
|
||||
Session *models.Session `json:"session"`
|
||||
}
|
||||
|
||||
type RefreshRequest struct {
|
||||
|
||||
@@ -9,12 +9,12 @@ import (
|
||||
|
||||
// Handler provides common functionality for all handlers
|
||||
type Handler struct {
|
||||
DB *db.DB
|
||||
DB db.Database
|
||||
Storage storage.Manager
|
||||
}
|
||||
|
||||
// NewHandler creates a new handler with the given dependencies
|
||||
func NewHandler(db *db.DB, s storage.Manager) *Handler {
|
||||
func NewHandler(db db.Database, s storage.Manager) *Handler {
|
||||
return &Handler{
|
||||
DB: db,
|
||||
Storage: s,
|
||||
|
||||
@@ -171,8 +171,7 @@ func (h *Handler) DeleteAccount() http.HandlerFunc {
|
||||
// Prevent admin from deleting their own account if they're the last admin
|
||||
if user.Role == "admin" {
|
||||
// Count number of admin users
|
||||
adminCount := 0
|
||||
err := h.DB.QueryRow("SELECT COUNT(*) FROM users WHERE role = 'admin'").Scan(&adminCount)
|
||||
adminCount, err := h.DB.CountAdminUsers()
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to verify admin status", http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
@@ -29,7 +29,7 @@ func WithUserContext(next http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
// Workspace context
|
||||
func WithWorkspaceContext(db *db.DB) func(http.Handler) http.Handler {
|
||||
func WithWorkspaceContext(db db.Database) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
|
||||
13
server/internal/models/session.go
Normal file
13
server/internal/models/session.go
Normal file
@@ -0,0 +1,13 @@
|
||||
// Package models contains the data models used throughout the application. These models are used to represent data in the database, as well as to validate and serialize data in the application.
|
||||
package models
|
||||
|
||||
import "time"
|
||||
|
||||
// Session represents a user session in the database
|
||||
type Session struct {
|
||||
ID string // Unique session identifier
|
||||
UserID int // ID of the user this session belongs to
|
||||
RefreshToken string // The refresh token associated with this session
|
||||
ExpiresAt time.Time // When this session expires
|
||||
CreatedAt time.Time // When this session was created
|
||||
}
|
||||
@@ -13,11 +13,11 @@ import (
|
||||
)
|
||||
|
||||
type UserService struct {
|
||||
DB *db.DB
|
||||
DB db.Database
|
||||
Storage storage.Manager
|
||||
}
|
||||
|
||||
func NewUserService(database *db.DB, s storage.Manager) *UserService {
|
||||
func NewUserService(database db.Database, s storage.Manager) *UserService {
|
||||
return &UserService{
|
||||
DB: database,
|
||||
Storage: s,
|
||||
|
||||
Reference in New Issue
Block a user