mirror of
https://github.com/lordmathis/lemma.git
synced 2025-11-05 23:44:22 +00:00
Rework db package to make it testable
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
// Package main contains the main entry point for the application. It sets up the server, database, and other services, and starts the server.
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -61,7 +62,7 @@ func main() {
|
|||||||
authMiddleware := auth.NewMiddleware(jwtManager)
|
authMiddleware := auth.NewMiddleware(jwtManager)
|
||||||
|
|
||||||
// Initialize session service
|
// Initialize session service
|
||||||
sessionService := auth.NewSessionService(database.DB, jwtManager)
|
sessionService := auth.NewSessionService(database, jwtManager)
|
||||||
|
|
||||||
// Set up router
|
// Set up router
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// SetupRoutes configures the API routes
|
// SetupRoutes configures the API routes
|
||||||
func SetupRoutes(r chi.Router, db *db.DB, s storage.Manager, authMiddleware *auth.Middleware, sessionService *auth.SessionService) {
|
func SetupRoutes(r chi.Router, db db.Database, s storage.Manager, authMiddleware *auth.Middleware, sessionService *auth.SessionService) {
|
||||||
|
|
||||||
handler := &handlers.Handler{
|
handler := &handlers.Handler{
|
||||||
DB: db,
|
DB: db,
|
||||||
|
|||||||
@@ -1,25 +1,17 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"novamd/internal/db"
|
||||||
|
"novamd/internal/models"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Session represents a user session in the database
|
|
||||||
type Session struct {
|
|
||||||
ID string // Unique session identifier
|
|
||||||
UserID int // ID of the user this session belongs to
|
|
||||||
RefreshToken string // The refresh token associated with this session
|
|
||||||
ExpiresAt time.Time // When this session expires
|
|
||||||
CreatedAt time.Time // When this session was created
|
|
||||||
}
|
|
||||||
|
|
||||||
// SessionService manages user sessions in the database
|
// SessionService manages user sessions in the database
|
||||||
type SessionService struct {
|
type SessionService struct {
|
||||||
db *sql.DB // Database connection
|
db db.SessionStore // Database store for sessions
|
||||||
jwtManager JWTManager // JWT Manager for token operations
|
jwtManager JWTManager // JWT Manager for token operations
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -27,7 +19,7 @@ type SessionService struct {
|
|||||||
// Parameters:
|
// Parameters:
|
||||||
// - db: database connection
|
// - db: database connection
|
||||||
// - jwtManager: JWT service for token operations
|
// - jwtManager: JWT service for token operations
|
||||||
func NewSessionService(db *sql.DB, jwtManager JWTManager) *SessionService {
|
func NewSessionService(db db.SessionStore, jwtManager JWTManager) *SessionService {
|
||||||
return &SessionService{
|
return &SessionService{
|
||||||
db: db,
|
db: db,
|
||||||
jwtManager: jwtManager,
|
jwtManager: jwtManager,
|
||||||
@@ -42,7 +34,7 @@ func NewSessionService(db *sql.DB, jwtManager JWTManager) *SessionService {
|
|||||||
// - session: the created session
|
// - session: the created session
|
||||||
// - accessToken: a new access token
|
// - accessToken: a new access token
|
||||||
// - error: any error that occurred
|
// - error: any error that occurred
|
||||||
func (s *SessionService) CreateSession(userID int, role string) (*Session, string, error) {
|
func (s *SessionService) CreateSession(userID int, role string) (*models.Session, string, error) {
|
||||||
// Generate both access and refresh tokens
|
// Generate both access and refresh tokens
|
||||||
accessToken, err := s.jwtManager.GenerateAccessToken(userID, role)
|
accessToken, err := s.jwtManager.GenerateAccessToken(userID, role)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -61,7 +53,7 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create a new session record
|
// Create a new session record
|
||||||
session := &Session{
|
session := &models.Session{
|
||||||
ID: uuid.New().String(),
|
ID: uuid.New().String(),
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: refreshToken,
|
||||||
@@ -69,14 +61,9 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin
|
|||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store the session in the database
|
// Store the session
|
||||||
_, err = s.db.Exec(`
|
if err := s.db.CreateSession(session); err != nil {
|
||||||
INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at)
|
return nil, "", err
|
||||||
VALUES (?, ?, ?, ?, ?)`,
|
|
||||||
session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", fmt.Errorf("failed to store session: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return session, accessToken, nil
|
return session, accessToken, nil
|
||||||
@@ -89,28 +76,18 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin
|
|||||||
// - string: a new access token
|
// - string: a new access token
|
||||||
// - error: any error that occurred
|
// - error: any error that occurred
|
||||||
func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
|
func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
|
||||||
|
// Get session from database
|
||||||
|
_, err := s.db.GetSessionByRefreshToken(refreshToken)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("invalid session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Validate the refresh token
|
// Validate the refresh token
|
||||||
claims, err := s.jwtManager.ValidateToken(refreshToken)
|
claims, err := s.jwtManager.ValidateToken(refreshToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("invalid refresh token: %w", err)
|
return "", fmt.Errorf("invalid refresh token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the session exists and is not expired
|
|
||||||
var session Session
|
|
||||||
err = s.db.QueryRow(`
|
|
||||||
SELECT id, user_id, refresh_token, expires_at, created_at
|
|
||||||
FROM sessions
|
|
||||||
WHERE refresh_token = ? AND expires_at > ?`,
|
|
||||||
refreshToken, time.Now(),
|
|
||||||
).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt)
|
|
||||||
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return "", fmt.Errorf("session not found or expired")
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to fetch session: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate a new access token
|
// Generate a new access token
|
||||||
return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role)
|
return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role)
|
||||||
}
|
}
|
||||||
@@ -121,20 +98,12 @@ func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
|
|||||||
// Returns:
|
// Returns:
|
||||||
// - error: any error that occurred
|
// - error: any error that occurred
|
||||||
func (s *SessionService) InvalidateSession(sessionID string) error {
|
func (s *SessionService) InvalidateSession(sessionID string) error {
|
||||||
_, err := s.db.Exec("DELETE FROM sessions WHERE id = ?", sessionID)
|
return s.db.DeleteSession(sessionID)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to invalidate session: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CleanExpiredSessions removes all expired sessions from the database
|
// CleanExpiredSessions removes all expired sessions from the database
|
||||||
// Returns:
|
// Returns:
|
||||||
// - error: any error that occurred
|
// - error: any error that occurred
|
||||||
func (s *SessionService) CleanExpiredSessions() error {
|
func (s *SessionService) CleanExpiredSessions() error {
|
||||||
_, err := s.db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now())
|
return s.db.CleanExpiredSessions()
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to clean expired sessions: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,68 +0,0 @@
|
|||||||
// Package db provides the database access layer for the application. It contains methods for interacting with the database, such as creating, updating, and deleting records.
|
|
||||||
package db
|
|
||||||
|
|
||||||
import "novamd/internal/models"
|
|
||||||
|
|
||||||
// UserStats represents system-wide statistics
|
|
||||||
type UserStats struct {
|
|
||||||
TotalUsers int `json:"totalUsers"`
|
|
||||||
TotalWorkspaces int `json:"totalWorkspaces"`
|
|
||||||
ActiveUsers int `json:"activeUsers"` // Users with activity in last 30 days
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAllUsers returns a list of all users in the system
|
|
||||||
func (db *DB) GetAllUsers() ([]*models.User, error) {
|
|
||||||
rows, err := db.Query(`
|
|
||||||
SELECT
|
|
||||||
id, email, display_name, role, created_at,
|
|
||||||
last_workspace_id
|
|
||||||
FROM users
|
|
||||||
ORDER BY id ASC`)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var users []*models.User
|
|
||||||
for rows.Next() {
|
|
||||||
user := &models.User{}
|
|
||||||
err := rows.Scan(
|
|
||||||
&user.ID, &user.Email, &user.DisplayName, &user.Role,
|
|
||||||
&user.CreatedAt, &user.LastWorkspaceID,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
users = append(users, user)
|
|
||||||
}
|
|
||||||
return users, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSystemStats returns system-wide statistics
|
|
||||||
func (db *DB) GetSystemStats() (*UserStats, error) {
|
|
||||||
stats := &UserStats{}
|
|
||||||
|
|
||||||
// Get total users
|
|
||||||
err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&stats.TotalUsers)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get total workspaces
|
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM workspaces").Scan(&stats.TotalWorkspaces)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get active users (users with activity in last 30 days)
|
|
||||||
err = db.QueryRow(`
|
|
||||||
SELECT COUNT(DISTINCT user_id)
|
|
||||||
FROM sessions
|
|
||||||
WHERE created_at > datetime('now', '-30 days')`).
|
|
||||||
Scan(&stats.ActiveUsers)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return stats, nil
|
|
||||||
}
|
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
// Package db provides the database access layer for the application. It contains methods for interacting with the database, such as creating, updating, and deleting records.
|
||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -5,18 +6,75 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"novamd/internal/crypto"
|
"novamd/internal/crypto"
|
||||||
|
"novamd/internal/models"
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||||
)
|
)
|
||||||
|
|
||||||
// DB represents the database connection
|
// UserStore defines the methods for interacting with user data in the database
|
||||||
type DB struct {
|
type UserStore interface {
|
||||||
|
CreateUser(user *models.User) (*models.User, error)
|
||||||
|
GetUserByEmail(email string) (*models.User, error)
|
||||||
|
GetUserByID(userID int) (*models.User, error)
|
||||||
|
GetAllUsers() ([]*models.User, error)
|
||||||
|
UpdateUser(user *models.User) error
|
||||||
|
DeleteUser(userID int) error
|
||||||
|
UpdateLastWorkspace(userID int, workspaceName string) error
|
||||||
|
GetLastWorkspaceName(userID int) (string, error)
|
||||||
|
CountAdminUsers() (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WorkspaceStore defines the methods for interacting with workspace data in the database
|
||||||
|
type WorkspaceStore interface {
|
||||||
|
CreateWorkspace(workspace *models.Workspace) error
|
||||||
|
GetWorkspaceByID(workspaceID int) (*models.Workspace, error)
|
||||||
|
GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error)
|
||||||
|
GetWorkspacesByUserID(userID int) ([]*models.Workspace, error)
|
||||||
|
UpdateWorkspace(workspace *models.Workspace) error
|
||||||
|
DeleteWorkspace(workspaceID int) error
|
||||||
|
UpdateWorkspaceSettings(workspace *models.Workspace) error
|
||||||
|
DeleteWorkspaceTx(tx *sql.Tx, workspaceID int) error
|
||||||
|
UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error
|
||||||
|
UpdateLastOpenedFile(workspaceID int, filePath string) error
|
||||||
|
GetLastOpenedFile(workspaceID int) (string, error)
|
||||||
|
GetAllWorkspaces() ([]*models.Workspace, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionStore defines the methods for interacting with jwt sessions in the database
|
||||||
|
type SessionStore interface {
|
||||||
|
CreateSession(session *models.Session) error
|
||||||
|
GetSessionByRefreshToken(refreshToken string) (*models.Session, error)
|
||||||
|
DeleteSession(sessionID string) error
|
||||||
|
CleanExpiredSessions() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// SystemStore defines the methods for interacting with system settings and stats in the database
|
||||||
|
type SystemStore interface {
|
||||||
|
GetSystemStats() (*UserStats, error)
|
||||||
|
EnsureJWTSecret() (string, error)
|
||||||
|
GetSystemSetting(key string) (string, error)
|
||||||
|
SetSystemSetting(key, value string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Database defines the methods for interacting with the database
|
||||||
|
type Database interface {
|
||||||
|
UserStore
|
||||||
|
WorkspaceStore
|
||||||
|
SessionStore
|
||||||
|
SystemStore
|
||||||
|
Begin() (*sql.Tx, error)
|
||||||
|
Close() error
|
||||||
|
Migrate() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// database represents the database connection
|
||||||
|
type database struct {
|
||||||
*sql.DB
|
*sql.DB
|
||||||
crypto *crypto.Crypto
|
crypto *crypto.Crypto
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init initializes the database connection
|
// Init initializes the database connection
|
||||||
func Init(dbPath string, encryptionKey string) (*DB, error) {
|
func Init(dbPath string, encryptionKey string) (Database, error) {
|
||||||
db, err := sql.Open("sqlite3", dbPath)
|
db, err := sql.Open("sqlite3", dbPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -32,7 +90,7 @@ func Init(dbPath string, encryptionKey string) (*DB, error) {
|
|||||||
return nil, fmt.Errorf("failed to initialize encryption: %w", err)
|
return nil, fmt.Errorf("failed to initialize encryption: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
database := &DB{
|
database := &database{
|
||||||
DB: db,
|
DB: db,
|
||||||
crypto: cryptoService,
|
crypto: cryptoService,
|
||||||
}
|
}
|
||||||
@@ -45,19 +103,19 @@ func Init(dbPath string, encryptionKey string) (*DB, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the database connection
|
// Close closes the database connection
|
||||||
func (db *DB) Close() error {
|
func (db *database) Close() error {
|
||||||
return db.DB.Close()
|
return db.DB.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper methods for token encryption/decryption
|
// Helper methods for token encryption/decryption
|
||||||
func (db *DB) encryptToken(token string) (string, error) {
|
func (db *database) encryptToken(token string) (string, error) {
|
||||||
if token == "" {
|
if token == "" {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
return db.crypto.Encrypt(token)
|
return db.crypto.Encrypt(token)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) decryptToken(token string) (string, error) {
|
func (db *database) decryptToken(token string) (string, error) {
|
||||||
if token == "" {
|
if token == "" {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ var migrations = []Migration{
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Migrate applies all database migrations
|
// Migrate applies all database migrations
|
||||||
func (db *DB) Migrate() error {
|
func (db *database) Migrate() error {
|
||||||
// Create migrations table if it doesn't exist
|
// Create migrations table if it doesn't exist
|
||||||
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS migrations (
|
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS migrations (
|
||||||
version INTEGER PRIMARY KEY
|
version INTEGER PRIMARY KEY
|
||||||
|
|||||||
70
server/internal/db/sessions.go
Normal file
70
server/internal/db/sessions.go
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"novamd/internal/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CreateSession inserts a new session record into the database
|
||||||
|
func (db *database) CreateSession(session *models.Session) error {
|
||||||
|
_, err := db.Exec(`
|
||||||
|
INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?)`,
|
||||||
|
session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to store session: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSessionByRefreshToken retrieves a session by its refresh token
|
||||||
|
func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) {
|
||||||
|
session := &models.Session{}
|
||||||
|
err := db.QueryRow(`
|
||||||
|
SELECT id, user_id, refresh_token, expires_at, created_at
|
||||||
|
FROM sessions
|
||||||
|
WHERE refresh_token = ? AND expires_at > ?`,
|
||||||
|
refreshToken, time.Now(),
|
||||||
|
).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt)
|
||||||
|
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, fmt.Errorf("session not found or expired")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to fetch session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteSession removes a session from the database
|
||||||
|
func (db *database) DeleteSession(sessionID string) error {
|
||||||
|
result, err := db.Exec("DELETE FROM sessions WHERE id = ?", sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rowsAffected, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rowsAffected == 0 {
|
||||||
|
return fmt.Errorf("session not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanExpiredSessions removes all expired sessions from the database
|
||||||
|
func (db *database) CleanExpiredSessions() error {
|
||||||
|
_, err := db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to clean expired sessions: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -11,9 +11,16 @@ const (
|
|||||||
JWTSecretKey = "jwt_secret"
|
JWTSecretKey = "jwt_secret"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// UserStats represents system-wide statistics
|
||||||
|
type UserStats struct {
|
||||||
|
TotalUsers int `json:"totalUsers"`
|
||||||
|
TotalWorkspaces int `json:"totalWorkspaces"`
|
||||||
|
ActiveUsers int `json:"activeUsers"` // Users with activity in last 30 days
|
||||||
|
}
|
||||||
|
|
||||||
// EnsureJWTSecret makes sure a JWT signing secret exists in the database
|
// EnsureJWTSecret makes sure a JWT signing secret exists in the database
|
||||||
// If no secret exists, it generates and stores a new one
|
// If no secret exists, it generates and stores a new one
|
||||||
func (db *DB) EnsureJWTSecret() (string, error) {
|
func (db *database) EnsureJWTSecret() (string, error) {
|
||||||
// First, try to get existing secret
|
// First, try to get existing secret
|
||||||
secret, err := db.GetSystemSetting(JWTSecretKey)
|
secret, err := db.GetSystemSetting(JWTSecretKey)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -36,7 +43,7 @@ func (db *DB) EnsureJWTSecret() (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetSystemSetting retrieves a system setting by key
|
// GetSystemSetting retrieves a system setting by key
|
||||||
func (db *DB) GetSystemSetting(key string) (string, error) {
|
func (db *database) GetSystemSetting(key string) (string, error) {
|
||||||
var value string
|
var value string
|
||||||
err := db.QueryRow("SELECT value FROM system_settings WHERE key = ?", key).Scan(&value)
|
err := db.QueryRow("SELECT value FROM system_settings WHERE key = ?", key).Scan(&value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -46,7 +53,7 @@ func (db *DB) GetSystemSetting(key string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SetSystemSetting stores or updates a system setting
|
// SetSystemSetting stores or updates a system setting
|
||||||
func (db *DB) SetSystemSetting(key, value string) error {
|
func (db *database) SetSystemSetting(key, value string) error {
|
||||||
_, err := db.Exec(`
|
_, err := db.Exec(`
|
||||||
INSERT INTO system_settings (key, value)
|
INSERT INTO system_settings (key, value)
|
||||||
VALUES (?, ?)
|
VALUES (?, ?)
|
||||||
@@ -64,3 +71,32 @@ func generateRandomSecret(bytes int) (string, error) {
|
|||||||
}
|
}
|
||||||
return base64.StdEncoding.EncodeToString(b), nil
|
return base64.StdEncoding.EncodeToString(b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSystemStats returns system-wide statistics
|
||||||
|
func (db *database) GetSystemStats() (*UserStats, error) {
|
||||||
|
stats := &UserStats{}
|
||||||
|
|
||||||
|
// Get total users
|
||||||
|
err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&stats.TotalUsers)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get total workspaces
|
||||||
|
err = db.QueryRow("SELECT COUNT(*) FROM workspaces").Scan(&stats.TotalWorkspaces)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get active users (users with activity in last 30 days)
|
||||||
|
err = db.QueryRow(`
|
||||||
|
SELECT COUNT(DISTINCT user_id)
|
||||||
|
FROM sessions
|
||||||
|
WHERE created_at > datetime('now', '-30 days')`).
|
||||||
|
Scan(&stats.ActiveUsers)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// CreateUser inserts a new user record into the database
|
// CreateUser inserts a new user record into the database
|
||||||
func (db *DB) CreateUser(user *models.User) (*models.User, error) {
|
func (db *database) CreateUser(user *models.User) (*models.User, error) {
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -62,7 +62,7 @@ func (db *DB) CreateUser(user *models.User) (*models.User, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to create a workspace in a transaction
|
// Helper function to create a workspace in a transaction
|
||||||
func (db *DB) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
|
func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
|
||||||
result, err := tx.Exec(`
|
result, err := tx.Exec(`
|
||||||
INSERT INTO workspaces (
|
INSERT INTO workspaces (
|
||||||
user_id, name,
|
user_id, name,
|
||||||
@@ -87,7 +87,7 @@ func (db *DB) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetUserByID retrieves a user by ID
|
// GetUserByID retrieves a user by ID
|
||||||
func (db *DB) GetUserByID(id int) (*models.User, error) {
|
func (db *database) GetUserByID(id int) (*models.User, error) {
|
||||||
user := &models.User{}
|
user := &models.User{}
|
||||||
err := db.QueryRow(`
|
err := db.QueryRow(`
|
||||||
SELECT
|
SELECT
|
||||||
@@ -104,7 +104,7 @@ func (db *DB) GetUserByID(id int) (*models.User, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetUserByEmail retrieves a user by email
|
// GetUserByEmail retrieves a user by email
|
||||||
func (db *DB) GetUserByEmail(email string) (*models.User, error) {
|
func (db *database) GetUserByEmail(email string) (*models.User, error) {
|
||||||
user := &models.User{}
|
user := &models.User{}
|
||||||
err := db.QueryRow(`
|
err := db.QueryRow(`
|
||||||
SELECT
|
SELECT
|
||||||
@@ -122,7 +122,7 @@ func (db *DB) GetUserByEmail(email string) (*models.User, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateUser updates a user's information
|
// UpdateUser updates a user's information
|
||||||
func (db *DB) UpdateUser(user *models.User) error {
|
func (db *database) UpdateUser(user *models.User) error {
|
||||||
_, err := db.Exec(`
|
_, err := db.Exec(`
|
||||||
UPDATE users
|
UPDATE users
|
||||||
SET email = ?, display_name = ?, password_hash = ?, role = ?, last_workspace_id = ?
|
SET email = ?, display_name = ?, password_hash = ?, role = ?, last_workspace_id = ?
|
||||||
@@ -131,8 +131,36 @@ func (db *DB) UpdateUser(user *models.User) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAllUsers returns a list of all users in the system
|
||||||
|
func (db *database) GetAllUsers() ([]*models.User, error) {
|
||||||
|
rows, err := db.Query(`
|
||||||
|
SELECT
|
||||||
|
id, email, display_name, role, created_at,
|
||||||
|
last_workspace_id
|
||||||
|
FROM users
|
||||||
|
ORDER BY id ASC`)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var users []*models.User
|
||||||
|
for rows.Next() {
|
||||||
|
user := &models.User{}
|
||||||
|
err := rows.Scan(
|
||||||
|
&user.ID, &user.Email, &user.DisplayName, &user.Role,
|
||||||
|
&user.CreatedAt, &user.LastWorkspaceID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
users = append(users, user)
|
||||||
|
}
|
||||||
|
return users, nil
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateLastWorkspace updates the last workspace the user accessed
|
// UpdateLastWorkspace updates the last workspace the user accessed
|
||||||
func (db *DB) UpdateLastWorkspace(userID int, workspaceName string) error {
|
func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error {
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -155,7 +183,7 @@ func (db *DB) UpdateLastWorkspace(userID int, workspaceName string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeleteUser deletes a user and all their workspaces
|
// DeleteUser deletes a user and all their workspaces
|
||||||
func (db *DB) DeleteUser(id int) error {
|
func (db *database) DeleteUser(id int) error {
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -178,7 +206,7 @@ func (db *DB) DeleteUser(id int) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetLastWorkspaceName returns the name of the last workspace the user accessed
|
// GetLastWorkspaceName returns the name of the last workspace the user accessed
|
||||||
func (db *DB) GetLastWorkspaceName(userID int) (string, error) {
|
func (db *database) GetLastWorkspaceName(userID int) (string, error) {
|
||||||
var workspaceName string
|
var workspaceName string
|
||||||
err := db.QueryRow(`
|
err := db.QueryRow(`
|
||||||
SELECT
|
SELECT
|
||||||
@@ -189,3 +217,10 @@ func (db *DB) GetLastWorkspaceName(userID int) (string, error) {
|
|||||||
Scan(&workspaceName)
|
Scan(&workspaceName)
|
||||||
return workspaceName, err
|
return workspaceName, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CountAdminUsers returns the number of admin users in the system
|
||||||
|
func (db *database) CountAdminUsers() (int, error) {
|
||||||
|
var count int
|
||||||
|
err := db.QueryRow("SELECT COUNT(*) FROM users WHERE role = 'admin'").Scan(&count)
|
||||||
|
return count, err
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// CreateWorkspace inserts a new workspace record into the database
|
// CreateWorkspace inserts a new workspace record into the database
|
||||||
func (db *DB) CreateWorkspace(workspace *models.Workspace) error {
|
func (db *database) CreateWorkspace(workspace *models.Workspace) error {
|
||||||
// Set default settings if not provided
|
// Set default settings if not provided
|
||||||
if workspace.Theme == "" {
|
if workspace.Theme == "" {
|
||||||
workspace.GetDefaultSettings()
|
workspace.GetDefaultSettings()
|
||||||
@@ -42,7 +42,7 @@ func (db *DB) CreateWorkspace(workspace *models.Workspace) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetWorkspaceByID retrieves a workspace by its ID
|
// GetWorkspaceByID retrieves a workspace by its ID
|
||||||
func (db *DB) GetWorkspaceByID(id int) (*models.Workspace, error) {
|
func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) {
|
||||||
workspace := &models.Workspace{}
|
workspace := &models.Workspace{}
|
||||||
var encryptedToken string
|
var encryptedToken string
|
||||||
|
|
||||||
@@ -75,7 +75,7 @@ func (db *DB) GetWorkspaceByID(id int) (*models.Workspace, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetWorkspaceByName retrieves a workspace by its name and user ID
|
// GetWorkspaceByName retrieves a workspace by its name and user ID
|
||||||
func (db *DB) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) {
|
func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) {
|
||||||
workspace := &models.Workspace{}
|
workspace := &models.Workspace{}
|
||||||
var encryptedToken string
|
var encryptedToken string
|
||||||
|
|
||||||
@@ -108,7 +108,7 @@ func (db *DB) GetWorkspaceByName(userID int, workspaceName string) (*models.Work
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateWorkspace updates a workspace record in the database
|
// UpdateWorkspace updates a workspace record in the database
|
||||||
func (db *DB) UpdateWorkspace(workspace *models.Workspace) error {
|
func (db *database) UpdateWorkspace(workspace *models.Workspace) error {
|
||||||
// Encrypt token before storing
|
// Encrypt token before storing
|
||||||
encryptedToken, err := db.encryptToken(workspace.GitToken)
|
encryptedToken, err := db.encryptToken(workspace.GitToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -146,7 +146,7 @@ func (db *DB) UpdateWorkspace(workspace *models.Workspace) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetWorkspacesByUserID retrieves all workspaces for a user
|
// GetWorkspacesByUserID retrieves all workspaces for a user
|
||||||
func (db *DB) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
|
func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
|
||||||
rows, err := db.Query(`
|
rows, err := db.Query(`
|
||||||
SELECT
|
SELECT
|
||||||
id, user_id, name, created_at,
|
id, user_id, name, created_at,
|
||||||
@@ -189,7 +189,7 @@ func (db *DB) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
|
|||||||
|
|
||||||
// UpdateWorkspaceSettings updates only the settings portion of a workspace
|
// UpdateWorkspaceSettings updates only the settings portion of a workspace
|
||||||
// This is useful when you don't want to modify the name or other core workspace properties
|
// This is useful when you don't want to modify the name or other core workspace properties
|
||||||
func (db *DB) UpdateWorkspaceSettings(workspace *models.Workspace) error {
|
func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
|
||||||
_, err := db.Exec(`
|
_, err := db.Exec(`
|
||||||
UPDATE workspaces
|
UPDATE workspaces
|
||||||
SET
|
SET
|
||||||
@@ -218,31 +218,31 @@ func (db *DB) UpdateWorkspaceSettings(workspace *models.Workspace) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeleteWorkspace removes a workspace record from the database
|
// DeleteWorkspace removes a workspace record from the database
|
||||||
func (db *DB) DeleteWorkspace(id int) error {
|
func (db *database) DeleteWorkspace(id int) error {
|
||||||
_, err := db.Exec("DELETE FROM workspaces WHERE id = ?", id)
|
_, err := db.Exec("DELETE FROM workspaces WHERE id = ?", id)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteWorkspaceTx removes a workspace record from the database within a transaction
|
// DeleteWorkspaceTx removes a workspace record from the database within a transaction
|
||||||
func (db *DB) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
|
func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
|
||||||
_, err := tx.Exec("DELETE FROM workspaces WHERE id = ?", id)
|
_, err := tx.Exec("DELETE FROM workspaces WHERE id = ?", id)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateLastWorkspaceTx sets the last workspace for a user in with a transaction
|
// UpdateLastWorkspaceTx sets the last workspace for a user in with a transaction
|
||||||
func (db *DB) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error {
|
func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error {
|
||||||
_, err := tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?", workspaceID, userID)
|
_, err := tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?", workspaceID, userID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateLastOpenedFile updates the last opened file path for a workspace
|
// UpdateLastOpenedFile updates the last opened file path for a workspace
|
||||||
func (db *DB) UpdateLastOpenedFile(workspaceID int, filePath string) error {
|
func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error {
|
||||||
_, err := db.Exec("UPDATE workspaces SET last_opened_file_path = ? WHERE id = ?", filePath, workspaceID)
|
_, err := db.Exec("UPDATE workspaces SET last_opened_file_path = ? WHERE id = ?", filePath, workspaceID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLastOpenedFile retrieves the last opened file path for a workspace
|
// GetLastOpenedFile retrieves the last opened file path for a workspace
|
||||||
func (db *DB) GetLastOpenedFile(workspaceID int) (string, error) {
|
func (db *database) GetLastOpenedFile(workspaceID int) (string, error) {
|
||||||
var filePath sql.NullString
|
var filePath sql.NullString
|
||||||
err := db.QueryRow("SELECT last_opened_file_path FROM workspaces WHERE id = ?", workspaceID).Scan(&filePath)
|
err := db.QueryRow("SELECT last_opened_file_path FROM workspaces WHERE id = ?", workspaceID).Scan(&filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -255,7 +255,7 @@ func (db *DB) GetLastOpenedFile(workspaceID int) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetAllWorkspaces retrieves all workspaces in the database
|
// GetAllWorkspaces retrieves all workspaces in the database
|
||||||
func (db *DB) GetAllWorkspaces() ([]*models.Workspace, error) {
|
func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) {
|
||||||
rows, err := db.Query(`
|
rows, err := db.Query(`
|
||||||
SELECT
|
SELECT
|
||||||
id, user_id, name, created_at,
|
id, user_id, name, created_at,
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ type LoginResponse struct {
|
|||||||
AccessToken string `json:"accessToken"`
|
AccessToken string `json:"accessToken"`
|
||||||
RefreshToken string `json:"refreshToken"`
|
RefreshToken string `json:"refreshToken"`
|
||||||
User *models.User `json:"user"`
|
User *models.User `json:"user"`
|
||||||
Session *auth.Session `json:"session"`
|
Session *models.Session `json:"session"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type RefreshRequest struct {
|
type RefreshRequest struct {
|
||||||
|
|||||||
@@ -9,12 +9,12 @@ import (
|
|||||||
|
|
||||||
// Handler provides common functionality for all handlers
|
// Handler provides common functionality for all handlers
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
DB *db.DB
|
DB db.Database
|
||||||
Storage storage.Manager
|
Storage storage.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHandler creates a new handler with the given dependencies
|
// NewHandler creates a new handler with the given dependencies
|
||||||
func NewHandler(db *db.DB, s storage.Manager) *Handler {
|
func NewHandler(db db.Database, s storage.Manager) *Handler {
|
||||||
return &Handler{
|
return &Handler{
|
||||||
DB: db,
|
DB: db,
|
||||||
Storage: s,
|
Storage: s,
|
||||||
|
|||||||
@@ -171,8 +171,7 @@ func (h *Handler) DeleteAccount() http.HandlerFunc {
|
|||||||
// Prevent admin from deleting their own account if they're the last admin
|
// Prevent admin from deleting their own account if they're the last admin
|
||||||
if user.Role == "admin" {
|
if user.Role == "admin" {
|
||||||
// Count number of admin users
|
// Count number of admin users
|
||||||
adminCount := 0
|
adminCount, err := h.DB.CountAdminUsers()
|
||||||
err := h.DB.QueryRow("SELECT COUNT(*) FROM users WHERE role = 'admin'").Scan(&adminCount)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "Failed to verify admin status", http.StatusInternalServerError)
|
http.Error(w, "Failed to verify admin status", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ func WithUserContext(next http.Handler) http.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Workspace context
|
// Workspace context
|
||||||
func WithWorkspaceContext(db *db.DB) func(http.Handler) http.Handler {
|
func WithWorkspaceContext(db db.Database) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||||
|
|||||||
13
server/internal/models/session.go
Normal file
13
server/internal/models/session.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
// Package models contains the data models used throughout the application. These models are used to represent data in the database, as well as to validate and serialize data in the application.
|
||||||
|
package models
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// Session represents a user session in the database
|
||||||
|
type Session struct {
|
||||||
|
ID string // Unique session identifier
|
||||||
|
UserID int // ID of the user this session belongs to
|
||||||
|
RefreshToken string // The refresh token associated with this session
|
||||||
|
ExpiresAt time.Time // When this session expires
|
||||||
|
CreatedAt time.Time // When this session was created
|
||||||
|
}
|
||||||
@@ -13,11 +13,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type UserService struct {
|
type UserService struct {
|
||||||
DB *db.DB
|
DB db.Database
|
||||||
Storage storage.Manager
|
Storage storage.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserService(database *db.DB, s storage.Manager) *UserService {
|
func NewUserService(database db.Database, s storage.Manager) *UserService {
|
||||||
return &UserService{
|
return &UserService{
|
||||||
DB: database,
|
DB: database,
|
||||||
Storage: s,
|
Storage: s,
|
||||||
|
|||||||
Reference in New Issue
Block a user