Merge pull request #28 from lordmathis/feat/logging

Implement structured logging
This commit is contained in:
2024-12-19 23:45:28 +01:00
committed by GitHub
54 changed files with 1604 additions and 305 deletions

1
.gitignore vendored
View File

@@ -157,6 +157,7 @@ go.work.sum
# env file # env file
.env .env
.env.dev
main main
*.db *.db

14
.vscode/launch.json vendored Normal file
View File

@@ -0,0 +1,14 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "Launch NovaMD Server",
"type": "go",
"request": "launch",
"mode": "auto",
"program": "${workspaceFolder}/server/cmd/server/main.go",
"cwd": "${workspaceFolder}",
"envFile": "${workspaceFolder}/server/.env"
}
]
}

View File

@@ -5,6 +5,7 @@ import (
"log" "log"
"novamd/internal/app" "novamd/internal/app"
"novamd/internal/logging"
) )
// @title NovaMD API // @title NovaMD API
@@ -23,6 +24,10 @@ func main() {
log.Fatal("Failed to load configuration:", err) log.Fatal("Failed to load configuration:", err)
} }
// Setup logging
logging.Setup(cfg.LogLevel)
logging.Debug("Configuration loaded", "config", cfg.Redact())
// Initialize and start server // Initialize and start server
options, err := app.DefaultOptions(cfg) options, err := app.DefaultOptions(cfg)
if err != nil { if err != nil {
@@ -32,7 +37,7 @@ func main() {
server := app.NewServer(options) server := app.NewServer(options)
defer func() { defer func() {
if err := server.Close(); err != nil { if err := server.Close(); err != nil {
log.Println("Error closing server:", err) logging.Error("Failed to close server:", err)
} }
}() }()

View File

@@ -2,6 +2,7 @@ package app
import ( import (
"fmt" "fmt"
"novamd/internal/logging"
"novamd/internal/secrets" "novamd/internal/secrets"
"os" "os"
"strconv" "strconv"
@@ -25,6 +26,7 @@ type Config struct {
RateLimitRequests int RateLimitRequests int
RateLimitWindow time.Duration RateLimitWindow time.Duration
IsDevelopment bool IsDevelopment bool
LogLevel logging.LogLevel
} }
// DefaultConfig returns a new Config instance with default values // DefaultConfig returns a new Config instance with default values
@@ -54,6 +56,16 @@ func (c *Config) validate() error {
return nil return nil
} }
// Redact redacts sensitive fields from a Config instance
func (c *Config) Redact() *Config {
redacted := *c
redacted.AdminPassword = "[REDACTED]"
redacted.AdminEmail = "[REDACTED]"
redacted.EncryptionKey = "[REDACTED]"
redacted.JWTSigningKey = "[REDACTED]"
return &redacted
}
// LoadConfig creates a new Config instance with values from environment variables // LoadConfig creates a new Config instance with values from environment variables
func LoadConfig() (*Config, error) { func LoadConfig() (*Config, error) {
config := DefaultConfig() config := DefaultConfig()
@@ -97,17 +109,29 @@ func LoadConfig() (*Config, error) {
// Configure rate limiting // Configure rate limiting
if reqStr := os.Getenv("NOVAMD_RATE_LIMIT_REQUESTS"); reqStr != "" { if reqStr := os.Getenv("NOVAMD_RATE_LIMIT_REQUESTS"); reqStr != "" {
if parsed, err := strconv.Atoi(reqStr); err == nil { parsed, err := strconv.Atoi(reqStr)
if err == nil {
config.RateLimitRequests = parsed config.RateLimitRequests = parsed
} }
} }
if windowStr := os.Getenv("NOVAMD_RATE_LIMIT_WINDOW"); windowStr != "" { if windowStr := os.Getenv("NOVAMD_RATE_LIMIT_WINDOW"); windowStr != "" {
if parsed, err := time.ParseDuration(windowStr); err == nil { parsed, err := time.ParseDuration(windowStr)
if err == nil {
config.RateLimitWindow = parsed config.RateLimitWindow = parsed
} }
} }
// Configure log level, if isDevelopment is set, default to debug
if logLevel := os.Getenv("NOVAMD_LOG_LEVEL"); logLevel != "" {
parsed := logging.ParseLogLevel(logLevel)
config.LogLevel = parsed
} else if config.IsDevelopment {
config.LogLevel = logging.DEBUG
} else {
config.LogLevel = logging.INFO
}
// Validate all settings // Validate all settings
if err := config.validate(); err != nil { if err := config.validate(); err != nil {
return nil, err return nil, err

View File

@@ -5,6 +5,8 @@ import (
"os" "os"
"testing" "testing"
"time" "time"
_ "novamd/internal/testenv"
) )
func TestDefaultConfig(t *testing.T) { func TestDefaultConfig(t *testing.T) {

View File

@@ -4,13 +4,13 @@ package app
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"log"
"time" "time"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"novamd/internal/auth" "novamd/internal/auth"
"novamd/internal/db" "novamd/internal/db"
"novamd/internal/logging"
"novamd/internal/models" "novamd/internal/models"
"novamd/internal/secrets" "novamd/internal/secrets"
"novamd/internal/storage" "novamd/internal/storage"
@@ -18,6 +18,7 @@ import (
// initSecretsService initializes the secrets service // initSecretsService initializes the secrets service
func initSecretsService(cfg *Config) (secrets.Service, error) { func initSecretsService(cfg *Config) (secrets.Service, error) {
logging.Debug("initializing secrets service")
secretsService, err := secrets.NewService(cfg.EncryptionKey) secretsService, err := secrets.NewService(cfg.EncryptionKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize secrets service: %w", err) return nil, fmt.Errorf("failed to initialize secrets service: %w", err)
@@ -27,6 +28,8 @@ func initSecretsService(cfg *Config) (secrets.Service, error) {
// initDatabase initializes and migrates the database // initDatabase initializes and migrates the database
func initDatabase(cfg *Config, secretsService secrets.Service) (db.Database, error) { func initDatabase(cfg *Config, secretsService secrets.Service) (db.Database, error) {
logging.Debug("initializing database", "path", cfg.DBPath)
database, err := db.Init(cfg.DBPath, secretsService) database, err := db.Init(cfg.DBPath, secretsService)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize database: %w", err) return nil, fmt.Errorf("failed to initialize database: %w", err)
@@ -41,9 +44,15 @@ func initDatabase(cfg *Config, secretsService secrets.Service) (db.Database, err
// initAuth initializes JWT and session services // initAuth initializes JWT and session services
func initAuth(cfg *Config, database db.Database) (auth.JWTManager, auth.SessionManager, auth.CookieManager, error) { func initAuth(cfg *Config, database db.Database) (auth.JWTManager, auth.SessionManager, auth.CookieManager, error) {
logging.Debug("initializing authentication services")
accessTokeExpiry := 15 * time.Minute
refreshTokenExpiry := 7 * 24 * time.Hour
// Get or generate JWT signing key // Get or generate JWT signing key
signingKey := cfg.JWTSigningKey signingKey := cfg.JWTSigningKey
if signingKey == "" { if signingKey == "" {
logging.Debug("no JWT signing key provided, generating new key")
var err error var err error
signingKey, err = database.EnsureJWTSecret() signingKey, err = database.EnsureJWTSecret()
if err != nil { if err != nil {
@@ -51,20 +60,16 @@ func initAuth(cfg *Config, database db.Database) (auth.JWTManager, auth.SessionM
} }
} }
// Initialize JWT service
jwtManager, err := auth.NewJWTService(auth.JWTConfig{ jwtManager, err := auth.NewJWTService(auth.JWTConfig{
SigningKey: signingKey, SigningKey: signingKey,
AccessTokenExpiry: 15 * time.Minute, AccessTokenExpiry: accessTokeExpiry,
RefreshTokenExpiry: 7 * 24 * time.Hour, RefreshTokenExpiry: refreshTokenExpiry,
}) })
if err != nil { if err != nil {
return nil, nil, nil, fmt.Errorf("failed to initialize JWT service: %w", err) return nil, nil, nil, fmt.Errorf("failed to initialize JWT service: %w", err)
} }
// Initialize session service
sessionManager := auth.NewSessionService(database, jwtManager) sessionManager := auth.NewSessionService(database, jwtManager)
// Cookie service
cookieService := auth.NewCookieService(cfg.IsDevelopment, cfg.Domain) cookieService := auth.NewCookieService(cfg.IsDevelopment, cfg.Domain)
return jwtManager, sessionManager, cookieService, nil return jwtManager, sessionManager, cookieService, nil
@@ -72,26 +77,26 @@ func initAuth(cfg *Config, database db.Database) (auth.JWTManager, auth.SessionM
// setupAdminUser creates the admin user if it doesn't exist // setupAdminUser creates the admin user if it doesn't exist
func setupAdminUser(database db.Database, storageManager storage.Manager, cfg *Config) error { func setupAdminUser(database db.Database, storageManager storage.Manager, cfg *Config) error {
adminEmail := cfg.AdminEmail
adminPassword := cfg.AdminPassword
// Check if admin user exists // Check if admin user exists
adminUser, err := database.GetUserByEmail(adminEmail) adminUser, err := database.GetUserByEmail(cfg.AdminEmail)
if err != nil && err != sql.ErrNoRows {
return fmt.Errorf("failed to check for existing admin user: %w", err)
}
if adminUser != nil { if adminUser != nil {
return nil // Admin user already exists logging.Debug("admin user already exists", "userId", adminUser.ID)
} else if err != sql.ErrNoRows { return nil
return err
} }
// Hash the password // Hash the password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(adminPassword), bcrypt.DefaultCost) hashedPassword, err := bcrypt.GenerateFromPassword([]byte(cfg.AdminPassword), bcrypt.DefaultCost)
if err != nil { if err != nil {
return fmt.Errorf("failed to hash password: %w", err) return fmt.Errorf("failed to hash admin password: %w", err)
} }
// Create admin user // Create admin user
adminUser = &models.User{ adminUser = &models.User{
Email: adminEmail, Email: cfg.AdminEmail,
DisplayName: "Admin", DisplayName: "Admin",
PasswordHash: string(hashedPassword), PasswordHash: string(hashedPassword),
Role: models.RoleAdmin, Role: models.RoleAdmin,
@@ -102,13 +107,14 @@ func setupAdminUser(database db.Database, storageManager storage.Manager, cfg *C
return fmt.Errorf("failed to create admin user: %w", err) return fmt.Errorf("failed to create admin user: %w", err)
} }
// Initialize workspace directory
err = storageManager.InitializeUserWorkspace(createdUser.ID, createdUser.LastWorkspaceID) err = storageManager.InitializeUserWorkspace(createdUser.ID, createdUser.LastWorkspaceID)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize admin workspace: %w", err) return fmt.Errorf("failed to initialize admin workspace: %w", err)
} }
log.Printf("Created admin user with ID: %d and default workspace with ID: %d", createdUser.ID, createdUser.LastWorkspaceID) logging.Info("admin user setup completed",
"userId", createdUser.ID,
"workspaceId", createdUser.LastWorkspaceID)
return nil return nil
} }

View File

@@ -3,6 +3,7 @@ package app
import ( import (
"novamd/internal/auth" "novamd/internal/auth"
"novamd/internal/db" "novamd/internal/db"
"novamd/internal/logging"
"novamd/internal/storage" "novamd/internal/storage"
) )
@@ -33,6 +34,9 @@ func DefaultOptions(cfg *Config) (*Options, error) {
// Initialize storage // Initialize storage
storageManager := storage.NewService(cfg.WorkDir) storageManager := storage.NewService(cfg.WorkDir)
// Initialize logger
logging.Setup(cfg.LogLevel)
// Initialize auth services // Initialize auth services
jwtManager, sessionService, cookieService, err := initAuth(cfg, database) jwtManager, sessionService, cookieService, err := initAuth(cfg, database)
if err != nil { if err != nil {

View File

@@ -4,6 +4,7 @@ import (
"novamd/internal/auth" "novamd/internal/auth"
"novamd/internal/context" "novamd/internal/context"
"novamd/internal/handlers" "novamd/internal/handlers"
"novamd/internal/logging"
"time" "time"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
@@ -19,6 +20,7 @@ import (
// setupRouter creates and configures the chi router with middleware and routes // setupRouter creates and configures the chi router with middleware and routes
func setupRouter(o Options) *chi.Mux { func setupRouter(o Options) *chi.Mux {
logging.Debug("setting up router")
r := chi.NewRouter() r := chi.NewRouter()
// Basic middleware // Basic middleware

View File

@@ -1,8 +1,8 @@
package app package app
import ( import (
"log"
"net/http" "net/http"
"novamd/internal/logging"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
) )
@@ -25,12 +25,13 @@ func NewServer(options *Options) *Server {
func (s *Server) Start() error { func (s *Server) Start() error {
// Start server // Start server
addr := ":" + s.options.Config.Port addr := ":" + s.options.Config.Port
log.Printf("Server starting on port %s", s.options.Config.Port) logging.Info("starting server", "address", addr)
return http.ListenAndServe(addr, s.router) return http.ListenAndServe(addr, s.router)
} }
// Close handles graceful shutdown of server dependencies // Close handles graceful shutdown of server dependencies
func (s *Server) Close() error { func (s *Server) Close() error {
logging.Info("shutting down server")
return s.options.Database.Close() return s.options.Database.Close()
} }

View File

@@ -3,8 +3,22 @@ package auth
import ( import (
"net/http" "net/http"
"novamd/internal/logging"
) )
var logger logging.Logger
func getAuthLogger() logging.Logger {
if logger == nil {
logger = logging.WithGroup("auth")
}
return logger
}
func getCookieLogger() logging.Logger {
return getAuthLogger().WithGroup("cookie")
}
// CookieManager interface defines methods for generating cookies // CookieManager interface defines methods for generating cookies
type CookieManager interface { type CookieManager interface {
GenerateAccessTokenCookie(token string) *http.Cookie GenerateAccessTokenCookie(token string) *http.Cookie
@@ -22,6 +36,8 @@ type cookieManager struct {
// NewCookieService creates a new cookie service // NewCookieService creates a new cookie service
func NewCookieService(isDevelopment bool, domain string) CookieManager { func NewCookieService(isDevelopment bool, domain string) CookieManager {
log := getCookieLogger()
secure := !isDevelopment secure := !isDevelopment
var sameSite http.SameSite var sameSite http.SameSite
@@ -31,6 +47,11 @@ func NewCookieService(isDevelopment bool, domain string) CookieManager {
sameSite = http.SameSiteStrictMode sameSite = http.SameSiteStrictMode
} }
log.Debug("creating cookie service",
"secure", secure,
"sameSite", sameSite,
"domain", domain)
return &cookieManager{ return &cookieManager{
Domain: domain, Domain: domain,
Secure: secure, Secure: secure,
@@ -40,6 +61,12 @@ func NewCookieService(isDevelopment bool, domain string) CookieManager {
// GenerateAccessTokenCookie creates a new cookie for the access token // GenerateAccessTokenCookie creates a new cookie for the access token
func (c *cookieManager) GenerateAccessTokenCookie(token string) *http.Cookie { func (c *cookieManager) GenerateAccessTokenCookie(token string) *http.Cookie {
log := getCookieLogger()
log.Debug("generating access token cookie",
"secure", c.Secure,
"sameSite", c.SameSite,
"maxAge", 900)
return &http.Cookie{ return &http.Cookie{
Name: "access_token", Name: "access_token",
Value: token, Value: token,
@@ -53,6 +80,12 @@ func (c *cookieManager) GenerateAccessTokenCookie(token string) *http.Cookie {
// GenerateRefreshTokenCookie creates a new cookie for the refresh token // GenerateRefreshTokenCookie creates a new cookie for the refresh token
func (c *cookieManager) GenerateRefreshTokenCookie(token string) *http.Cookie { func (c *cookieManager) GenerateRefreshTokenCookie(token string) *http.Cookie {
log := getCookieLogger()
log.Debug("generating refresh token cookie",
"secure", c.Secure,
"sameSite", c.SameSite,
"maxAge", 604800)
return &http.Cookie{ return &http.Cookie{
Name: "refresh_token", Name: "refresh_token",
Value: token, Value: token,
@@ -66,6 +99,13 @@ func (c *cookieManager) GenerateRefreshTokenCookie(token string) *http.Cookie {
// GenerateCSRFCookie creates a new cookie for the CSRF token // GenerateCSRFCookie creates a new cookie for the CSRF token
func (c *cookieManager) GenerateCSRFCookie(token string) *http.Cookie { func (c *cookieManager) GenerateCSRFCookie(token string) *http.Cookie {
log := getCookieLogger()
log.Debug("generating CSRF cookie",
"secure", c.Secure,
"sameSite", c.SameSite,
"maxAge", 900,
"httpOnly", false)
return &http.Cookie{ return &http.Cookie{
Name: "csrf_token", Name: "csrf_token",
Value: token, Value: token,
@@ -79,6 +119,12 @@ func (c *cookieManager) GenerateCSRFCookie(token string) *http.Cookie {
// InvalidateCookie creates a new cookie with a MaxAge of -1 to invalidate the cookie // InvalidateCookie creates a new cookie with a MaxAge of -1 to invalidate the cookie
func (c *cookieManager) InvalidateCookie(cookieType string) *http.Cookie { func (c *cookieManager) InvalidateCookie(cookieType string) *http.Cookie {
log := getCookieLogger()
log.Debug("invalidating cookie",
"type", cookieType,
"secure", c.Secure,
"sameSite", c.SameSite)
return &http.Cookie{ return &http.Cookie{
Name: cookieType, Name: cookieType,
Value: "", Value: "",

View File

@@ -4,11 +4,16 @@ package auth
import ( import (
"crypto/rand" "crypto/rand"
"fmt" "fmt"
"novamd/internal/logging"
"time" "time"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
) )
func getJWTLogger() logging.Logger {
return getAuthLogger().WithGroup("jwt")
}
// TokenType represents the type of JWT token (access or refresh) // TokenType represents the type of JWT token (access or refresh)
type TokenType string type TokenType string
@@ -50,13 +55,15 @@ func NewJWTService(config JWTConfig) (JWTManager, error) {
if config.SigningKey == "" { if config.SigningKey == "" {
return nil, fmt.Errorf("signing key is required") return nil, fmt.Errorf("signing key is required")
} }
// Set default expiry times if not provided // Set default expiry times if not provided
if config.AccessTokenExpiry == 0 { if config.AccessTokenExpiry == 0 {
config.AccessTokenExpiry = 15 * time.Minute // Default to 15 minutes config.AccessTokenExpiry = 15 * time.Minute
} }
if config.RefreshTokenExpiry == 0 { if config.RefreshTokenExpiry == 0 {
config.RefreshTokenExpiry = 7 * 24 * time.Hour // Default to 7 days config.RefreshTokenExpiry = 7 * 24 * time.Hour
} }
return &jwtService{config: config}, nil return &jwtService{config: config}, nil
} }
@@ -93,11 +100,18 @@ func (s *jwtService) generateToken(userID int, role string, sessionID string, to
} }
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(s.config.SigningKey)) signedToken, err := token.SignedString([]byte(s.config.SigningKey))
if err != nil {
return "", err
}
return signedToken, nil
} }
// ValidateToken validates and parses a JWT token // ValidateToken validates and parses a JWT token
func (s *jwtService) ValidateToken(tokenString string) (*Claims, error) { func (s *jwtService) ValidateToken(tokenString string) (*Claims, error) {
log := getJWTLogger()
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
// Validate the signing method // Validate the signing method
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
@@ -110,9 +124,16 @@ func (s *jwtService) ValidateToken(tokenString string) (*Claims, error) {
return nil, fmt.Errorf("invalid token: %w", err) return nil, fmt.Errorf("invalid token: %w", err)
} }
if claims, ok := token.Claims.(*Claims); ok && token.Valid { claims, ok := token.Claims.(*Claims)
return claims, nil if !ok || !token.Valid {
return nil, fmt.Errorf("invalid token claims")
} }
return nil, fmt.Errorf("invalid token claims") log.Debug("token validated",
"userId", claims.UserID,
"role", claims.Role,
"tokenType", claims.Type,
"expiresAt", claims.ExpiresAt)
return claims, nil
} }

View File

@@ -6,10 +6,9 @@ import (
"time" "time"
"novamd/internal/auth" "novamd/internal/auth"
_ "novamd/internal/testenv"
) )
// jwt_test.go tests
func TestNewJWTService(t *testing.T) { func TestNewJWTService(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string

View File

@@ -3,10 +3,14 @@ package auth
import ( import (
"crypto/subtle" "crypto/subtle"
"net/http" "net/http"
"novamd/internal/context" "novamd/internal/context"
"novamd/internal/logging"
) )
func getMiddlewareLogger() logging.Logger {
return getAuthLogger().WithGroup("middleware")
}
// Middleware handles JWT authentication for protected routes // Middleware handles JWT authentication for protected routes
type Middleware struct { type Middleware struct {
jwtManager JWTManager jwtManager JWTManager
@@ -26,9 +30,15 @@ func NewMiddleware(jwtManager JWTManager, sessionManager SessionManager, cookieM
// Authenticate middleware validates JWT tokens and sets user information in context // Authenticate middleware validates JWT tokens and sets user information in context
func (m *Middleware) Authenticate(next http.Handler) http.Handler { func (m *Middleware) Authenticate(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Extract token from Authorization header log := getMiddlewareLogger().With(
"handler", "Authenticate",
"clientIP", r.RemoteAddr,
)
// Extract token from cookie
cookie, err := r.Cookie("access_token") cookie, err := r.Cookie("access_token")
if err != nil { if err != nil {
log.Warn("attempt to access protected route without token")
http.Error(w, "Unauthorized", http.StatusUnauthorized) http.Error(w, "Unauthorized", http.StatusUnauthorized)
return return
} }
@@ -36,12 +46,14 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
// Validate token // Validate token
claims, err := m.jwtManager.ValidateToken(cookie.Value) claims, err := m.jwtManager.ValidateToken(cookie.Value)
if err != nil { if err != nil {
log.Warn("attempt to access protected route with invalid token", "error", err.Error())
http.Error(w, "Invalid token", http.StatusUnauthorized) http.Error(w, "Invalid token", http.StatusUnauthorized)
return return
} }
// Check token type // Check token type
if claims.Type != AccessToken { if claims.Type != AccessToken {
log.Warn("attempt to access protected route with invalid token type", "type", claims.Type)
http.Error(w, "Invalid token type", http.StatusUnauthorized) http.Error(w, "Invalid token type", http.StatusUnauthorized)
return return
} }
@@ -49,6 +61,7 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
// Check if session is still valid in database // Check if session is still valid in database
session, err := m.sessionManager.ValidateSession(claims.ID) session, err := m.sessionManager.ValidateSession(claims.ID)
if err != nil || session == nil { if err != nil || session == nil {
log.Warn("attempt to access protected route with invalid session", "error", err.Error())
m.cookieManager.InvalidateCookie("access_token") m.cookieManager.InvalidateCookie("access_token")
m.cookieManager.InvalidateCookie("refresh_token") m.cookieManager.InvalidateCookie("refresh_token")
m.cookieManager.InvalidateCookie("csrf_token") m.cookieManager.InvalidateCookie("csrf_token")
@@ -60,17 +73,20 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
if r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodOptions { if r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodOptions {
csrfCookie, err := r.Cookie("csrf_token") csrfCookie, err := r.Cookie("csrf_token")
if err != nil { if err != nil {
log.Warn("attempt to access protected route without CSRF token", "error", err.Error())
http.Error(w, "CSRF cookie not found", http.StatusForbidden) http.Error(w, "CSRF cookie not found", http.StatusForbidden)
return return
} }
csrfHeader := r.Header.Get("X-CSRF-Token") csrfHeader := r.Header.Get("X-CSRF-Token")
if csrfHeader == "" { if csrfHeader == "" {
log.Warn("attempt to access protected route without CSRF header")
http.Error(w, "CSRF token header not found", http.StatusForbidden) http.Error(w, "CSRF token header not found", http.StatusForbidden)
return return
} }
if subtle.ConstantTimeCompare([]byte(csrfCookie.Value), []byte(csrfHeader)) != 1 { if subtle.ConstantTimeCompare([]byte(csrfCookie.Value), []byte(csrfHeader)) != 1 {
log.Warn("attempt to access protected route with invalid CSRF token")
http.Error(w, "CSRF token mismatch", http.StatusForbidden) http.Error(w, "CSRF token mismatch", http.StatusForbidden)
return return
} }
@@ -91,12 +107,19 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler { func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log := getMiddlewareLogger().With(
"handler", "RequireRole",
"requiredRole", role,
"clientIP", r.RemoteAddr,
)
ctx, ok := context.GetRequestContext(w, r) ctx, ok := context.GetRequestContext(w, r)
if !ok { if !ok {
return return
} }
if ctx.UserRole != role && ctx.UserRole != "admin" { if ctx.UserRole != role && ctx.UserRole != "admin" {
log.Warn("attempt to access protected route without required role")
http.Error(w, "Insufficient permissions", http.StatusForbidden) http.Error(w, "Insufficient permissions", http.StatusForbidden)
return return
} }
@@ -114,7 +137,13 @@ func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler {
return return
} }
// If no workspace in context, allow the request (might be a non-workspace endpoint) log := getMiddlewareLogger().With(
"handler", "RequireWorkspaceAccess",
"clientIP", r.RemoteAddr,
"userId", ctx.UserID,
)
// If no workspace in context, allow the request
if ctx.Workspace == nil { if ctx.Workspace == nil {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return
@@ -122,6 +151,7 @@ func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler {
// Check if user has access (either owner or admin) // Check if user has access (either owner or admin)
if ctx.Workspace.UserID != ctx.UserID && ctx.UserRole != "admin" { if ctx.Workspace.UserID != ctx.UserID && ctx.UserRole != "admin" {
log.Warn("attempt to access workspace without permission")
http.Error(w, "Not Found", http.StatusNotFound) http.Error(w, "Not Found", http.StatusNotFound)
return return
} }

View File

@@ -11,6 +11,7 @@ import (
"novamd/internal/auth" "novamd/internal/auth"
"novamd/internal/context" "novamd/internal/context"
"novamd/internal/models" "novamd/internal/models"
_ "novamd/internal/testenv"
) )
// Mock SessionManager // Mock SessionManager

View File

@@ -3,12 +3,17 @@ package auth
import ( import (
"fmt" "fmt"
"novamd/internal/db" "novamd/internal/db"
"novamd/internal/logging"
"novamd/internal/models" "novamd/internal/models"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
) )
func getSessionLogger() logging.Logger {
return getAuthLogger().WithGroup("session")
}
// SessionManager is an interface for managing user sessions // SessionManager is an interface for managing user sessions
type SessionManager interface { type SessionManager interface {
CreateSession(userID int, role string) (*models.Session, string, error) CreateSession(userID int, role string) (*models.Session, string, error)
@@ -35,6 +40,7 @@ func NewSessionService(db db.SessionStore, jwtManager JWTManager) *sessionManage
// CreateSession creates a new user session for a user with the given userID and role // CreateSession creates a new user session for a user with the given userID and role
func (s *sessionManager) CreateSession(userID int, role string) (*models.Session, string, error) { func (s *sessionManager) CreateSession(userID int, role string) (*models.Session, string, error) {
log := getSessionLogger()
// Generate a new session ID // Generate a new session ID
sessionID := uuid.New().String() sessionID := uuid.New().String()
@@ -70,12 +76,18 @@ func (s *sessionManager) CreateSession(userID int, role string) (*models.Session
return nil, "", err return nil, "", err
} }
log.Debug("created new session",
"userId", userID,
"role", role,
"sessionId", sessionID,
"expiresAt", claims.ExpiresAt.Time)
return session, accessToken, nil return session, accessToken, nil
} }
// RefreshSession creates a new access token using a refreshToken // RefreshSession creates a new access token using a refreshToken
func (s *sessionManager) RefreshSession(refreshToken string) (string, error) { func (s *sessionManager) RefreshSession(refreshToken string) (string, error) {
// Get session from database first // Get session from database
session, err := s.db.GetSessionByRefreshToken(refreshToken) session, err := s.db.GetSessionByRefreshToken(refreshToken)
if err != nil { if err != nil {
return "", fmt.Errorf("invalid session: %w", err) return "", fmt.Errorf("invalid session: %w", err)
@@ -87,17 +99,22 @@ func (s *sessionManager) RefreshSession(refreshToken string) (string, error) {
return "", fmt.Errorf("invalid refresh token: %w", err) return "", fmt.Errorf("invalid refresh token: %w", err)
} }
// Double check that the claims match the session
if claims.UserID != session.UserID { if claims.UserID != session.UserID {
return "", fmt.Errorf("token does not match session") return "", fmt.Errorf("token does not match session")
} }
// Generate a new access token // Generate a new access token
return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role, session.ID) newToken, err := s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role, session.ID)
if err != nil {
return "", err
}
return newToken, nil
} }
// ValidateSession checks if a session with the given sessionID is valid // ValidateSession checks if a session with the given sessionID is valid
func (s *sessionManager) ValidateSession(sessionID string) (*models.Session, error) { func (s *sessionManager) ValidateSession(sessionID string) (*models.Session, error) {
log := getSessionLogger()
// Get the session from the database // Get the session from the database
session, err := s.db.GetSessionByID(sessionID) session, err := s.db.GetSessionByID(sessionID)
@@ -105,21 +122,43 @@ func (s *sessionManager) ValidateSession(sessionID string) (*models.Session, err
return nil, fmt.Errorf("failed to get session: %w", err) return nil, fmt.Errorf("failed to get session: %w", err)
} }
log.Debug("validated session",
"sessionId", sessionID,
"userId", session.UserID,
"expiresAt", session.ExpiresAt)
return session, nil return session, nil
} }
// InvalidateSession removes a session with the given sessionID from the database // InvalidateSession removes a session with the given sessionID from the database
func (s *sessionManager) InvalidateSession(token string) error { func (s *sessionManager) InvalidateSession(token string) error {
log := getSessionLogger()
// Parse the JWT to get the session info // Parse the JWT to get the session info
claims, err := s.jwtManager.ValidateToken(token) claims, err := s.jwtManager.ValidateToken(token)
if err != nil { if err != nil {
return fmt.Errorf("invalid token: %w", err) return fmt.Errorf("invalid token: %w", err)
} }
return s.db.DeleteSession(claims.ID) if err := s.db.DeleteSession(claims.ID); err != nil {
return err
}
log.Debug("invalidated session",
"sessionId", claims.ID,
"userId", claims.UserID)
return nil
} }
// CleanExpiredSessions removes all expired sessions from the database // CleanExpiredSessions removes all expired sessions from the database
func (s *sessionManager) CleanExpiredSessions() error { func (s *sessionManager) CleanExpiredSessions() error {
return s.db.CleanExpiredSessions() log := getSessionLogger()
if err := s.db.CleanExpiredSessions(); err != nil {
return err
}
log.Info("cleaned expired sessions")
return nil
} }

View File

@@ -8,6 +8,7 @@ import (
"novamd/internal/auth" "novamd/internal/auth"
"novamd/internal/models" "novamd/internal/models"
_ "novamd/internal/testenv"
) )
// Mock SessionStore // Mock SessionStore

View File

@@ -5,6 +5,7 @@ import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"novamd/internal/logging"
"novamd/internal/models" "novamd/internal/models"
) )
@@ -28,10 +29,22 @@ type HandlerContext struct {
Workspace *models.Workspace // Optional, only set for workspace routes Workspace *models.Workspace // Optional, only set for workspace routes
} }
var logger logging.Logger
func getLogger() logging.Logger {
if logger == nil {
logger = logging.WithGroup("context")
}
return logger
}
// GetRequestContext retrieves the handler context from the request // GetRequestContext retrieves the handler context from the request
func GetRequestContext(w http.ResponseWriter, r *http.Request) (*HandlerContext, bool) { func GetRequestContext(w http.ResponseWriter, r *http.Request) (*HandlerContext, bool) {
ctx := r.Context().Value(HandlerContextKey) ctx := r.Context().Value(HandlerContextKey)
if ctx == nil { if ctx == nil {
getLogger().Error("missing handler context in request",
"path", r.URL.Path,
"method", r.Method)
http.Error(w, "Internal server error", http.StatusInternalServerError) http.Error(w, "Internal server error", http.StatusInternalServerError)
return nil, false return nil, false
} }

View File

@@ -7,6 +7,7 @@ import (
"testing" "testing"
"novamd/internal/context" "novamd/internal/context"
_ "novamd/internal/testenv"
) )
func TestGetRequestContext(t *testing.T) { func TestGetRequestContext(t *testing.T) {

View File

@@ -10,9 +10,13 @@ import (
// WithUserContextMiddleware extracts user information from JWT claims // WithUserContextMiddleware extracts user information from JWT claims
// and adds it to the request context // and adds it to the request context
func WithUserContextMiddleware(next http.Handler) http.Handler { func WithUserContextMiddleware(next http.Handler) http.Handler {
log := getLogger()
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims, err := GetUserFromContext(r.Context()) claims, err := GetUserFromContext(r.Context())
if err != nil { if err != nil {
log.Error("failed to get user from context",
"error", err,
"path", r.URL.Path)
http.Error(w, "Unauthorized", http.StatusUnauthorized) http.Error(w, "Unauthorized", http.StatusUnauthorized)
return return
} }
@@ -30,6 +34,7 @@ func WithUserContextMiddleware(next http.Handler) http.Handler {
// WithWorkspaceContextMiddleware adds workspace information to the request context // WithWorkspaceContextMiddleware adds workspace information to the request context
func WithWorkspaceContextMiddleware(db db.WorkspaceReader) func(http.Handler) http.Handler { func WithWorkspaceContextMiddleware(db db.WorkspaceReader) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
log := getLogger()
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, ok := GetRequestContext(w, r) ctx, ok := GetRequestContext(w, r)
if !ok { if !ok {
@@ -39,11 +44,15 @@ func WithWorkspaceContextMiddleware(db db.WorkspaceReader) func(http.Handler) ht
workspaceName := chi.URLParam(r, "workspaceName") workspaceName := chi.URLParam(r, "workspaceName")
workspace, err := db.GetWorkspaceByName(ctx.UserID, workspaceName) workspace, err := db.GetWorkspaceByName(ctx.UserID, workspaceName)
if err != nil { if err != nil {
http.Error(w, "Workspace not found", http.StatusNotFound) log.Error("failed to get workspace",
"error", err,
"userID", ctx.UserID,
"workspace", workspaceName,
"path", r.URL.Path)
http.Error(w, "Failed to get workspace", http.StatusNotFound)
return return
} }
// Update existing context with workspace
ctx.Workspace = workspace ctx.Workspace = workspace
r = WithHandlerContext(r, ctx) r = WithHandlerContext(r, ctx)
next.ServeHTTP(w, r) next.ServeHTTP(w, r)

View File

@@ -9,6 +9,7 @@ import (
"novamd/internal/context" "novamd/internal/context"
"novamd/internal/models" "novamd/internal/models"
_ "novamd/internal/testenv"
) )
// MockDB implements the minimal database interface needed for testing // MockDB implements the minimal database interface needed for testing
@@ -89,6 +90,10 @@ func TestWithUserContextMiddleware(t *testing.T) {
} }
} }
type contextKey string
const workspaceNameKey contextKey = "workspaceName"
func TestWithWorkspaceContextMiddleware(t *testing.T) { func TestWithWorkspaceContextMiddleware(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -158,7 +163,7 @@ func TestWithWorkspaceContextMiddleware(t *testing.T) {
} }
// Add workspace name to request context via chi URL params // Add workspace name to request context via chi URL params
req = req.WithContext(stdctx.WithValue(req.Context(), "workspaceName", tt.workspaceName)) req = req.WithContext(stdctx.WithValue(req.Context(), workspaceNameKey, tt.workspaceName))
nextCalled := false nextCalled := false
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

View File

@@ -3,7 +3,9 @@ package db
import ( import (
"database/sql" "database/sql"
"fmt"
"novamd/internal/logging"
"novamd/internal/models" "novamd/internal/models"
"novamd/internal/secrets" "novamd/internal/secrets"
@@ -77,6 +79,7 @@ type Database interface {
Migrate() error Migrate() error
} }
// Verify that the database implements the required interfaces
var ( var (
// Main Database interface // Main Database interface
_ Database = (*database)(nil) _ Database = (*database)(nil)
@@ -92,6 +95,15 @@ var (
_ WorkspaceWriter = (*database)(nil) _ WorkspaceWriter = (*database)(nil)
) )
var logger logging.Logger
func getLogger() logging.Logger {
if logger == nil {
logger = logging.WithGroup("db")
}
return logger
}
// database represents the database connection // database represents the database connection
type database struct { type database struct {
*sql.DB *sql.DB
@@ -100,19 +112,22 @@ type database struct {
// Init initializes the database connection // Init initializes the database connection
func Init(dbPath string, secretsService secrets.Service) (Database, error) { func Init(dbPath string, secretsService secrets.Service) (Database, error) {
log := getLogger()
db, err := sql.Open("sqlite3", dbPath) db, err := sql.Open("sqlite3", dbPath)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to open database: %w", err)
} }
if err := db.Ping(); err != nil { if err := db.Ping(); err != nil {
return nil, err return nil, fmt.Errorf("failed to ping database: %w", err)
} }
// Enable foreign keys for this connection // Enable foreign keys for this connection
if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil { if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
return nil, err return nil, fmt.Errorf("failed to enable foreign keys: %w", err)
} }
log.Debug("foreign keys enabled")
database := &database{ database := &database{
DB: db, DB: db,
@@ -124,7 +139,13 @@ func Init(dbPath string, secretsService secrets.Service) (Database, error) {
// Close closes the database connection // Close closes the database connection
func (db *database) Close() error { func (db *database) Close() error {
return db.DB.Close() log := getLogger()
log.Info("closing database connection")
if err := db.DB.Close(); err != nil {
return fmt.Errorf("failed to close database: %w", err)
}
return nil
} }
// Helper methods for token encryption/decryption // Helper methods for token encryption/decryption
@@ -132,12 +153,24 @@ func (db *database) encryptToken(token string) (string, error) {
if token == "" { if token == "" {
return "", nil return "", nil
} }
return db.secretsService.Encrypt(token)
encrypted, err := db.secretsService.Encrypt(token)
if err != nil {
return "", fmt.Errorf("failed to encrypt token: %w", err)
}
return encrypted, nil
} }
func (db *database) decryptToken(token string) (string, error) { func (db *database) decryptToken(token string) (string, error) {
if token == "" { if token == "" {
return "", nil return "", nil
} }
return db.secretsService.Decrypt(token)
decrypted, err := db.secretsService.Decrypt(token)
if err != nil {
return "", fmt.Errorf("failed to decrypt token: %w", err)
}
return decrypted, nil
} }

View File

@@ -2,7 +2,6 @@ package db
import ( import (
"fmt" "fmt"
"log"
) )
// Migration represents a database migration // Migration represents a database migration
@@ -79,56 +78,64 @@ var migrations = []Migration{
// Migrate applies all database migrations // Migrate applies all database migrations
func (db *database) Migrate() error { func (db *database) Migrate() error {
log := getLogger().WithGroup("migrations")
log.Info("starting database migration")
// Create migrations table if it doesn't exist // Create migrations table if it doesn't exist
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS migrations ( _, err := db.Exec(`CREATE TABLE IF NOT EXISTS migrations (
version INTEGER PRIMARY KEY version INTEGER PRIMARY KEY
)`) )`)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to create migrations table: %w", err)
} }
// Get current version // Get current version
var currentVersion int var currentVersion int
err = db.QueryRow("SELECT COALESCE(MAX(version), 0) FROM migrations").Scan(&currentVersion) err = db.QueryRow("SELECT COALESCE(MAX(version), 0) FROM migrations").Scan(&currentVersion)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to get current migration version: %w", err)
} }
// Apply new migrations // Apply new migrations
for _, migration := range migrations { for _, migration := range migrations {
if migration.Version > currentVersion { if migration.Version > currentVersion {
log.Printf("Applying migration %d", migration.Version) log := log.With("migration_version", migration.Version)
tx, err := db.Begin() tx, err := db.Begin()
if err != nil { if err != nil {
return err return fmt.Errorf("failed to begin transaction for migration %d: %w", migration.Version, err)
} }
// Execute migration SQL
_, err = tx.Exec(migration.SQL) _, err = tx.Exec(migration.SQL)
if err != nil { if err != nil {
if rbErr := tx.Rollback(); rbErr != nil { if rbErr := tx.Rollback(); rbErr != nil {
return fmt.Errorf("migration %d failed: %v, rollback failed: %v", migration.Version, err, rbErr) return fmt.Errorf("migration %d failed: %v, rollback failed: %v",
migration.Version, err, rbErr)
} }
return fmt.Errorf("migration %d failed: %v", migration.Version, err) return fmt.Errorf("migration %d failed: %w", migration.Version, err)
} }
// Update migrations table
_, err = tx.Exec("INSERT INTO migrations (version) VALUES (?)", migration.Version) _, err = tx.Exec("INSERT INTO migrations (version) VALUES (?)", migration.Version)
if err != nil { if err != nil {
if rbErr := tx.Rollback(); rbErr != nil { if rbErr := tx.Rollback(); rbErr != nil {
return fmt.Errorf("failed to update migration version: %v, rollback failed: %v", err, rbErr) return fmt.Errorf("failed to update migration version: %v, rollback failed: %v",
err, rbErr)
} }
return fmt.Errorf("failed to update migration version: %v", err) return fmt.Errorf("failed to update migration version: %w", err)
} }
// Commit transaction
err = tx.Commit() err = tx.Commit()
if err != nil { if err != nil {
return fmt.Errorf("failed to commit migration %d: %v", migration.Version, err) return fmt.Errorf("failed to commit migration %d: %w", migration.Version, err)
} }
currentVersion = migration.Version currentVersion = migration.Version
log.Debug("migration applied", "new_version", currentVersion)
} }
} }
log.Printf("Database is at version %d", currentVersion) log.Info("database migration completed", "final_version", currentVersion)
return nil return nil
} }

View File

@@ -5,6 +5,8 @@ import (
"novamd/internal/db" "novamd/internal/db"
_ "novamd/internal/testenv"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )

View File

@@ -11,13 +11,14 @@ import (
// CreateSession inserts a new session record into the database // CreateSession inserts a new session record into the database
func (db *database) CreateSession(session *models.Session) error { func (db *database) CreateSession(session *models.Session) error {
_, err := db.Exec(` _, err := db.Exec(`
INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at) INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at)
VALUES (?, ?, ?, ?, ?)`, VALUES (?, ?, ?, ?, ?)`,
session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt, session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt,
) )
if err != nil { if err != nil {
return fmt.Errorf("failed to store session: %w", err) return fmt.Errorf("failed to store session: %w", err)
} }
return nil return nil
} }
@@ -45,9 +46,9 @@ func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Sessi
func (db *database) GetSessionByID(sessionID string) (*models.Session, error) { func (db *database) GetSessionByID(sessionID string) (*models.Session, error) {
session := &models.Session{} session := &models.Session{}
err := db.QueryRow(` err := db.QueryRow(`
SELECT id, user_id, refresh_token, expires_at, created_at SELECT id, user_id, refresh_token, expires_at, created_at
FROM sessions FROM sessions
WHERE id = ? AND expires_at > ?`, WHERE id = ? AND expires_at > ?`,
sessionID, time.Now(), sessionID, time.Now(),
).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt) ).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt)
@@ -82,9 +83,17 @@ func (db *database) DeleteSession(sessionID string) error {
// CleanExpiredSessions removes all expired sessions from the database // CleanExpiredSessions removes all expired sessions from the database
func (db *database) CleanExpiredSessions() error { func (db *database) CleanExpiredSessions() error {
_, err := db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now()) log := getLogger().WithGroup("sessions")
result, err := db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now())
if err != nil { if err != nil {
return fmt.Errorf("failed to clean expired sessions: %w", err) return fmt.Errorf("failed to clean expired sessions: %w", err)
} }
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
log.Info("cleaned expired sessions", "sessions_removed", rowsAffected)
return nil return nil
} }

View File

@@ -7,6 +7,7 @@ import (
"novamd/internal/db" "novamd/internal/db"
"novamd/internal/models" "novamd/internal/models"
_ "novamd/internal/testenv"
"github.com/google/uuid" "github.com/google/uuid"
) )

View File

@@ -21,6 +21,8 @@ type UserStats struct {
// EnsureJWTSecret makes sure a JWT signing secret exists in the database // EnsureJWTSecret makes sure a JWT signing secret exists in the database
// If no secret exists, it generates and stores a new one // If no secret exists, it generates and stores a new one
func (db *database) EnsureJWTSecret() (string, error) { func (db *database) EnsureJWTSecret() (string, error) {
log := getLogger().WithGroup("system")
// First, try to get existing secret // First, try to get existing secret
secret, err := db.GetSystemSetting(JWTSecretKey) secret, err := db.GetSystemSetting(JWTSecretKey)
if err == nil { if err == nil {
@@ -39,6 +41,8 @@ func (db *database) EnsureJWTSecret() (string, error) {
return "", fmt.Errorf("failed to store JWT secret: %w", err) return "", fmt.Errorf("failed to store JWT secret: %w", err)
} }
log.Info("new JWT secret generated and stored")
return newSecret, nil return newSecret, nil
} }
@@ -49,27 +53,38 @@ func (db *database) GetSystemSetting(key string) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
return value, nil return value, nil
} }
// SetSystemSetting stores or updates a system setting // SetSystemSetting stores or updates a system setting
func (db *database) SetSystemSetting(key, value string) error { func (db *database) SetSystemSetting(key, value string) error {
_, err := db.Exec(` _, err := db.Exec(`
INSERT INTO system_settings (key, value) INSERT INTO system_settings (key, value)
VALUES (?, ?) VALUES (?, ?)
ON CONFLICT(key) DO UPDATE SET value = ?`, ON CONFLICT(key) DO UPDATE SET value = ?`,
key, value, value) key, value, value)
return err
if err != nil {
return fmt.Errorf("failed to store system setting: %w", err)
}
return nil
} }
// generateRandomSecret generates a cryptographically secure random string // generateRandomSecret generates a cryptographically secure random string
func generateRandomSecret(bytes int) (string, error) { func generateRandomSecret(bytes int) (string, error) {
log := getLogger().WithGroup("system")
log.Debug("generating random secret", "bytes", bytes)
b := make([]byte, bytes) b := make([]byte, bytes)
_, err := rand.Read(b) _, err := rand.Read(b)
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("failed to generate random bytes: %w", err)
} }
return base64.StdEncoding.EncodeToString(b), nil
secret := base64.StdEncoding.EncodeToString(b)
return secret, nil
} }
// GetSystemStats returns system-wide statistics // GetSystemStats returns system-wide statistics
@@ -79,24 +94,23 @@ func (db *database) GetSystemStats() (*UserStats, error) {
// Get total users // Get total users
err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&stats.TotalUsers) err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&stats.TotalUsers)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to get total users count: %w", err)
} }
// Get total workspaces // Get total workspaces
err = db.QueryRow("SELECT COUNT(*) FROM workspaces").Scan(&stats.TotalWorkspaces) err = db.QueryRow("SELECT COUNT(*) FROM workspaces").Scan(&stats.TotalWorkspaces)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to get total workspaces count: %w", err)
} }
// Get active users (users with activity in last 30 days) // Get active users (users with activity in last 30 days)
err = db.QueryRow(` err = db.QueryRow(`
SELECT COUNT(DISTINCT user_id) SELECT COUNT(DISTINCT user_id)
FROM sessions FROM sessions
WHERE created_at > datetime('now', '-30 days')`). WHERE created_at > datetime('now', '-30 days')`).
Scan(&stats.ActiveUsers) Scan(&stats.ActiveUsers)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to get active users count: %w", err)
} }
return stats, nil return stats, nil
} }

View File

@@ -9,6 +9,7 @@ import (
"novamd/internal/db" "novamd/internal/db"
"novamd/internal/models" "novamd/internal/models"
_ "novamd/internal/testenv"
"github.com/google/uuid" "github.com/google/uuid"
) )

View File

@@ -2,35 +2,39 @@ package db
import ( import (
"database/sql" "database/sql"
"fmt"
"novamd/internal/models" "novamd/internal/models"
) )
// CreateUser inserts a new user record into the database // CreateUser inserts a new user record into the database
func (db *database) CreateUser(user *models.User) (*models.User, error) { func (db *database) CreateUser(user *models.User) (*models.User, error) {
log := getLogger().WithGroup("users")
log.Debug("creating user", "email", user.Email)
tx, err := db.Begin() tx, err := db.Begin()
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to begin transaction: %w", err)
} }
defer tx.Rollback() defer tx.Rollback()
result, err := tx.Exec(` result, err := tx.Exec(`
INSERT INTO users (email, display_name, password_hash, role) INSERT INTO users (email, display_name, password_hash, role)
VALUES (?, ?, ?, ?)`, VALUES (?, ?, ?, ?)`,
user.Email, user.DisplayName, user.PasswordHash, user.Role) user.Email, user.DisplayName, user.PasswordHash, user.Role)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to insert user: %w", err)
} }
userID, err := result.LastInsertId() userID, err := result.LastInsertId()
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to get last insert ID: %w", err)
} }
user.ID = int(userID) user.ID = int(userID)
// Retrieve the created_at timestamp // Retrieve the created_at timestamp
err = tx.QueryRow("SELECT created_at FROM users WHERE id = ?", user.ID).Scan(&user.CreatedAt) err = tx.QueryRow("SELECT created_at FROM users WHERE id = ?", user.ID).Scan(&user.CreatedAt)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to get created timestamp: %w", err)
} }
// Create default workspace with default settings // Create default workspace with default settings
@@ -38,39 +42,42 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) {
UserID: user.ID, UserID: user.ID,
Name: "Main", Name: "Main",
} }
defaultWorkspace.SetDefaultSettings() // Initialize default settings defaultWorkspace.SetDefaultSettings()
// Create workspace with settings // Create workspace with settings
err = db.createWorkspaceTx(tx, defaultWorkspace) err = db.createWorkspaceTx(tx, defaultWorkspace)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to create default workspace: %w", err)
} }
// Update user's last workspace ID // Update user's last workspace ID
_, err = tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?", defaultWorkspace.ID, user.ID) _, err = tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?", defaultWorkspace.ID, user.ID)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to update last workspace ID: %w", err)
} }
err = tx.Commit() err = tx.Commit()
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to commit transaction: %w", err)
} }
log.Debug("created user", "user_id", user.ID)
user.LastWorkspaceID = defaultWorkspace.ID user.LastWorkspaceID = defaultWorkspace.ID
return user, nil return user, nil
} }
// Helper function to create a workspace in a transaction // Helper function to create a workspace in a transaction
func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error { func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
log := getLogger().WithGroup("users")
result, err := tx.Exec(` result, err := tx.Exec(`
INSERT INTO workspaces ( INSERT INTO workspaces (
user_id, name, user_id, name,
theme, auto_save, show_hidden_files, theme, auto_save, show_hidden_files,
git_enabled, git_url, git_user, git_token, git_enabled, git_url, git_user, git_token,
git_auto_commit, git_commit_msg_template, git_auto_commit, git_commit_msg_template,
git_commit_name, git_commit_email git_commit_name, git_commit_email
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
workspace.UserID, workspace.Name, workspace.UserID, workspace.Name,
workspace.Theme, workspace.AutoSave, workspace.ShowHiddenFiles, workspace.Theme, workspace.AutoSave, workspace.ShowHiddenFiles,
workspace.GitEnabled, workspace.GitURL, workspace.GitUser, workspace.GitToken, workspace.GitEnabled, workspace.GitURL, workspace.GitUser, workspace.GitToken,
@@ -78,17 +85,21 @@ func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) e
workspace.GitCommitName, workspace.GitCommitEmail, workspace.GitCommitName, workspace.GitCommitEmail,
) )
if err != nil { if err != nil {
return err return fmt.Errorf("failed to insert workspace: %w", err)
} }
id, err := result.LastInsertId() id, err := result.LastInsertId()
if err != nil { if err != nil {
return err return fmt.Errorf("failed to get workspace ID: %w", err)
} }
workspace.ID = int(id) workspace.ID = int(id)
log.Debug("created user workspace",
"workspace_id", workspace.ID,
"user_id", workspace.UserID)
return nil return nil
} }
// GetUserByID retrieves a user by ID
func (db *database) GetUserByID(id int) (*models.User, error) { func (db *database) GetUserByID(id int) (*models.User, error) {
user := &models.User{} user := &models.User{}
err := db.QueryRow(` err := db.QueryRow(`
@@ -97,15 +108,18 @@ func (db *database) GetUserByID(id int) (*models.User, error) {
last_workspace_id last_workspace_id
FROM users FROM users
WHERE id = ?`, id). WHERE id = ?`, id).
Scan(&user.ID, &user.Email, &user.DisplayName, &user.PasswordHash, &user.Role, &user.CreatedAt, Scan(&user.ID, &user.Email, &user.DisplayName, &user.PasswordHash,
&user.LastWorkspaceID) &user.Role, &user.CreatedAt, &user.LastWorkspaceID)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("user not found")
}
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to fetch user: %w", err)
} }
return user, nil return user, nil
} }
// GetUserByEmail retrieves a user by email
func (db *database) GetUserByEmail(email string) (*models.User, error) { func (db *database) GetUserByEmail(email string) (*models.User, error) {
user := &models.User{} user := &models.User{}
err := db.QueryRow(` err := db.QueryRow(`
@@ -114,35 +128,52 @@ func (db *database) GetUserByEmail(email string) (*models.User, error) {
last_workspace_id last_workspace_id
FROM users FROM users
WHERE email = ?`, email). WHERE email = ?`, email).
Scan(&user.ID, &user.Email, &user.DisplayName, &user.PasswordHash, &user.Role, &user.CreatedAt, Scan(&user.ID, &user.Email, &user.DisplayName, &user.PasswordHash,
&user.LastWorkspaceID) &user.Role, &user.CreatedAt, &user.LastWorkspaceID)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("user not found")
}
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to fetch user: %w", err)
} }
return user, nil return user, nil
} }
// UpdateUser updates a user's information
func (db *database) UpdateUser(user *models.User) error { func (db *database) UpdateUser(user *models.User) error {
_, err := db.Exec(` result, err := db.Exec(`
UPDATE users UPDATE users
SET email = ?, display_name = ?, password_hash = ?, role = ?, last_workspace_id = ? SET email = ?, display_name = ?, password_hash = ?, role = ?, last_workspace_id = ?
WHERE id = ?`, WHERE id = ?`,
user.Email, user.DisplayName, user.PasswordHash, user.Role, user.LastWorkspaceID, user.ID) user.Email, user.DisplayName, user.PasswordHash, user.Role,
return err user.LastWorkspaceID, user.ID)
if err != nil {
return fmt.Errorf("failed to update user: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("user not found")
}
return nil
} }
// GetAllUsers returns a list of all users in the system
func (db *database) GetAllUsers() ([]*models.User, error) { func (db *database) GetAllUsers() ([]*models.User, error) {
rows, err := db.Query(` rows, err := db.Query(`
SELECT SELECT
id, email, display_name, role, created_at, id, email, display_name, role, created_at,
last_workspace_id last_workspace_id
FROM users FROM users
ORDER BY id ASC`) ORDER BY id ASC`)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to query users: %w", err)
} }
defer rows.Close() defer rows.Close()
@@ -154,60 +185,74 @@ func (db *database) GetAllUsers() ([]*models.User, error) {
&user.CreatedAt, &user.LastWorkspaceID, &user.CreatedAt, &user.LastWorkspaceID,
) )
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to scan user row: %w", err)
} }
users = append(users, user) users = append(users, user)
} }
return users, nil return users, nil
} }
// UpdateLastWorkspace updates the last workspace the user accessed
func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error { func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error {
tx, err := db.Begin() tx, err := db.Begin()
if err != nil { if err != nil {
return err return fmt.Errorf("failed to begin transaction: %w", err)
} }
defer tx.Rollback() defer tx.Rollback()
var workspaceID int var workspaceID int
err = tx.QueryRow("SELECT id FROM workspaces WHERE user_id = ? AND name = ?",
err = tx.QueryRow("SELECT id FROM workspaces WHERE user_id = ? AND name = ?", userID, workspaceName).Scan(&workspaceID) userID, workspaceName).Scan(&workspaceID)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to find workspace: %w", err)
} }
_, err = tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?", workspaceID, userID) _, err = tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?",
workspaceID, userID)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to update last workspace: %w", err)
} }
return tx.Commit() err = tx.Commit()
if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
} }
// DeleteUser deletes a user and all their workspaces
func (db *database) DeleteUser(id int) error { func (db *database) DeleteUser(id int) error {
log := getLogger().WithGroup("users")
log.Debug("deleting user", "user_id", id)
tx, err := db.Begin() tx, err := db.Begin()
if err != nil { if err != nil {
return err return fmt.Errorf("failed to begin transaction: %w", err)
} }
defer tx.Rollback() defer tx.Rollback()
// Delete all user's workspaces first // Delete all user's workspaces first
log.Debug("deleting user workspaces", "user_id", id)
_, err = tx.Exec("DELETE FROM workspaces WHERE user_id = ?", id) _, err = tx.Exec("DELETE FROM workspaces WHERE user_id = ?", id)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to delete workspaces: %w", err)
} }
// Delete the user // Delete the user
_, err = tx.Exec("DELETE FROM users WHERE id = ?", id) _, err = tx.Exec("DELETE FROM users WHERE id = ?", id)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to delete user: %w", err)
} }
return tx.Commit() err = tx.Commit()
if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
log.Debug("deleted user", "user_id", id)
return nil
} }
// GetLastWorkspaceName returns the name of the last workspace the user accessed
func (db *database) GetLastWorkspaceName(userID int) (string, error) { func (db *database) GetLastWorkspaceName(userID int) (string, error) {
var workspaceName string var workspaceName string
err := db.QueryRow(` err := db.QueryRow(`
@@ -217,12 +262,24 @@ func (db *database) GetLastWorkspaceName(userID int) (string, error) {
JOIN users u ON u.last_workspace_id = w.id JOIN users u ON u.last_workspace_id = w.id
WHERE u.id = ?`, userID). WHERE u.id = ?`, userID).
Scan(&workspaceName) Scan(&workspaceName)
return workspaceName, err
if err == sql.ErrNoRows {
return "", fmt.Errorf("no last workspace found")
}
if err != nil {
return "", fmt.Errorf("failed to fetch last workspace name: %w", err)
}
return workspaceName, nil
} }
// CountAdminUsers returns the number of admin users in the system // CountAdminUsers returns the number of admin users in the system
func (db *database) CountAdminUsers() (int, error) { func (db *database) CountAdminUsers() (int, error) {
var count int var count int
err := db.QueryRow("SELECT COUNT(*) FROM users WHERE role = 'admin'").Scan(&count) err := db.QueryRow("SELECT COUNT(*) FROM users WHERE role = 'admin'").Scan(&count)
return count, err if err != nil {
return 0, fmt.Errorf("failed to count admin users: %w", err)
}
return count, nil
} }

View File

@@ -6,6 +6,7 @@ import (
"novamd/internal/db" "novamd/internal/db"
"novamd/internal/models" "novamd/internal/models"
_ "novamd/internal/testenv"
) )
func TestUserOperations(t *testing.T) { func TestUserOperations(t *testing.T) {

View File

@@ -8,6 +8,12 @@ import (
// CreateWorkspace inserts a new workspace record into the database // CreateWorkspace inserts a new workspace record into the database
func (db *database) CreateWorkspace(workspace *models.Workspace) error { func (db *database) CreateWorkspace(workspace *models.Workspace) error {
log := getLogger().WithGroup("workspaces")
log.Debug("creating new workspace",
"user_id", workspace.UserID,
"name", workspace.Name,
"git_enabled", workspace.GitEnabled)
// Set default settings if not provided // Set default settings if not provided
if workspace.Theme == "" { if workspace.Theme == "" {
workspace.SetDefaultSettings() workspace.SetDefaultSettings()
@@ -20,25 +26,26 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error {
} }
result, err := db.Exec(` result, err := db.Exec(`
INSERT INTO workspaces ( INSERT INTO workspaces (
user_id, name, theme, auto_save, show_hidden_files, user_id, name, theme, auto_save, show_hidden_files,
git_enabled, git_url, git_user, git_token, git_enabled, git_url, git_user, git_token,
git_auto_commit, git_commit_msg_template, git_auto_commit, git_commit_msg_template,
git_commit_name, git_commit_email git_commit_name, git_commit_email
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
workspace.UserID, workspace.Name, workspace.Theme, workspace.AutoSave, workspace.ShowHiddenFiles, workspace.UserID, workspace.Name, workspace.Theme, workspace.AutoSave, workspace.ShowHiddenFiles,
workspace.GitEnabled, workspace.GitURL, workspace.GitUser, encryptedToken, workspace.GitEnabled, workspace.GitURL, workspace.GitUser, encryptedToken,
workspace.GitAutoCommit, workspace.GitCommitMsgTemplate, workspace.GitCommitName, workspace.GitCommitEmail, workspace.GitAutoCommit, workspace.GitCommitMsgTemplate, workspace.GitCommitName, workspace.GitCommitEmail,
) )
if err != nil { if err != nil {
return err return fmt.Errorf("failed to insert workspace: %w", err)
} }
id, err := result.LastInsertId() id, err := result.LastInsertId()
if err != nil { if err != nil {
return err return fmt.Errorf("failed to get workspace ID: %w", err)
} }
workspace.ID = int(id) workspace.ID = int(id)
return nil return nil
} }
@@ -48,23 +55,28 @@ func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) {
var encryptedToken string var encryptedToken string
err := db.QueryRow(` err := db.QueryRow(`
SELECT SELECT
id, user_id, name, created_at, id, user_id, name, created_at,
theme, auto_save, show_hidden_files, theme, auto_save, show_hidden_files,
git_enabled, git_url, git_user, git_token, git_enabled, git_url, git_user, git_token,
git_auto_commit, git_commit_msg_template, git_auto_commit, git_commit_msg_template,
git_commit_name, git_commit_email git_commit_name, git_commit_email
FROM workspaces FROM workspaces
WHERE id = ?`, WHERE id = ?`,
id, id,
).Scan( ).Scan(
&workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt, &workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt,
&workspace.Theme, &workspace.AutoSave, &workspace.ShowHiddenFiles, &workspace.Theme, &workspace.AutoSave, &workspace.ShowHiddenFiles,
&workspace.GitEnabled, &workspace.GitURL, &workspace.GitUser, &encryptedToken, &workspace.GitEnabled, &workspace.GitURL, &workspace.GitUser, &encryptedToken,
&workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate, &workspace.GitCommitName, &workspace.GitCommitEmail, &workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate,
&workspace.GitCommitName, &workspace.GitCommitEmail,
) )
if err == sql.ErrNoRows {
return nil, fmt.Errorf("workspace not found")
}
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to fetch workspace: %w", err)
} }
// Decrypt token // Decrypt token
@@ -82,14 +94,14 @@ func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*model
var encryptedToken string var encryptedToken string
err := db.QueryRow(` err := db.QueryRow(`
SELECT SELECT
id, user_id, name, created_at, id, user_id, name, created_at,
theme, auto_save, show_hidden_files, theme, auto_save, show_hidden_files,
git_enabled, git_url, git_user, git_token, git_enabled, git_url, git_user, git_token,
git_auto_commit, git_commit_msg_template, git_auto_commit, git_commit_msg_template,
git_commit_name, git_commit_email git_commit_name, git_commit_email
FROM workspaces FROM workspaces
WHERE user_id = ? AND name = ?`, WHERE user_id = ? AND name = ?`,
userID, workspaceName, userID, workspaceName,
).Scan( ).Scan(
&workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt, &workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt,
@@ -98,8 +110,12 @@ func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*model
&workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate, &workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate,
&workspace.GitCommitName, &workspace.GitCommitEmail, &workspace.GitCommitName, &workspace.GitCommitEmail,
) )
if err == sql.ErrNoRows {
return nil, fmt.Errorf("workspace not found")
}
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to fetch workspace: %w", err)
} }
// Decrypt token // Decrypt token
@@ -120,21 +136,21 @@ func (db *database) UpdateWorkspace(workspace *models.Workspace) error {
} }
_, err = db.Exec(` _, err = db.Exec(`
UPDATE workspaces UPDATE workspaces
SET SET
name = ?, name = ?,
theme = ?, theme = ?,
auto_save = ?, auto_save = ?,
show_hidden_files = ?, show_hidden_files = ?,
git_enabled = ?, git_enabled = ?,
git_url = ?, git_url = ?,
git_user = ?, git_user = ?,
git_token = ?, git_token = ?,
git_auto_commit = ?, git_auto_commit = ?,
git_commit_msg_template = ?, git_commit_msg_template = ?,
git_commit_name = ?, git_commit_name = ?,
git_commit_email = ? git_commit_email = ?
WHERE id = ? AND user_id = ?`, WHERE id = ? AND user_id = ?`,
workspace.Name, workspace.Name,
workspace.Theme, workspace.Theme,
workspace.AutoSave, workspace.AutoSave,
@@ -150,24 +166,28 @@ func (db *database) UpdateWorkspace(workspace *models.Workspace) error {
workspace.ID, workspace.ID,
workspace.UserID, workspace.UserID,
) )
return err if err != nil {
return fmt.Errorf("failed to update workspace: %w", err)
}
return nil
} }
// GetWorkspacesByUserID retrieves all workspaces for a user // GetWorkspacesByUserID retrieves all workspaces for a user
func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) { func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
rows, err := db.Query(` rows, err := db.Query(`
SELECT SELECT
id, user_id, name, created_at, id, user_id, name, created_at,
theme, auto_save, show_hidden_files, theme, auto_save, show_hidden_files,
git_enabled, git_url, git_user, git_token, git_enabled, git_url, git_user, git_token,
git_auto_commit, git_commit_msg_template, git_auto_commit, git_commit_msg_template,
git_commit_name, git_commit_email git_commit_name, git_commit_email
FROM workspaces FROM workspaces
WHERE user_id = ?`, WHERE user_id = ?`,
userID, userID,
) )
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to query workspaces: %w", err)
} }
defer rows.Close() defer rows.Close()
@@ -183,7 +203,7 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro
&workspace.GitCommitName, &workspace.GitCommitEmail, &workspace.GitCommitName, &workspace.GitCommitEmail,
) )
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to scan workspace row: %w", err)
} }
// Decrypt token // Decrypt token
@@ -194,27 +214,31 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro
workspaces = append(workspaces, workspace) workspaces = append(workspaces, workspace)
} }
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating workspace rows: %w", err)
}
return workspaces, nil return workspaces, nil
} }
// UpdateWorkspaceSettings updates only the settings portion of a workspace // UpdateWorkspaceSettings updates only the settings portion of a workspace
// This is useful when you don't want to modify the name or other core workspace properties
func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error { func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
_, err := db.Exec(` _, err := db.Exec(`
UPDATE workspaces UPDATE workspaces
SET SET
theme = ?, theme = ?,
auto_save = ?, auto_save = ?,
show_hidden_files = ?, show_hidden_files = ?,
git_enabled = ?, git_enabled = ?,
git_url = ?, git_url = ?,
git_user = ?, git_user = ?,
git_token = ?, git_token = ?,
git_auto_commit = ?, git_auto_commit = ?,
git_commit_msg_template = ?, git_commit_msg_template = ?,
git_commit_name = ?, git_commit_name = ?,
git_commit_email = ? git_commit_email = ?
WHERE id = ?`, WHERE id = ?`,
workspace.Theme, workspace.Theme,
workspace.AutoSave, workspace.AutoSave,
workspace.ShowHiddenFiles, workspace.ShowHiddenFiles,
@@ -228,59 +252,104 @@ func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
workspace.GitCommitEmail, workspace.GitCommitEmail,
workspace.ID, workspace.ID,
) )
return err if err != nil {
return fmt.Errorf("failed to update workspace settings: %w", err)
}
return nil
} }
// DeleteWorkspace removes a workspace record from the database // DeleteWorkspace removes a workspace record from the database
func (db *database) DeleteWorkspace(id int) error { func (db *database) DeleteWorkspace(id int) error {
log := getLogger().WithGroup("workspaces")
_, err := db.Exec("DELETE FROM workspaces WHERE id = ?", id) _, err := db.Exec("DELETE FROM workspaces WHERE id = ?", id)
return err if err != nil {
return fmt.Errorf("failed to delete workspace: %w", err)
}
log.Debug("workspace deleted", "workspace_id", id)
return nil
} }
// DeleteWorkspaceTx removes a workspace record from the database within a transaction // DeleteWorkspaceTx removes a workspace record from the database within a transaction
func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error { func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
_, err := tx.Exec("DELETE FROM workspaces WHERE id = ?", id) log := getLogger().WithGroup("workspaces")
return err result, err := tx.Exec("DELETE FROM workspaces WHERE id = ?", id)
if err != nil {
return fmt.Errorf("failed to delete workspace in transaction: %w", err)
}
_, err = result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected in transaction: %w", err)
}
log.Debug("workspace deleted",
"workspace_id", id)
return nil
} }
// UpdateLastWorkspaceTx sets the last workspace for a user in with a transaction // UpdateLastWorkspaceTx sets the last workspace for a user in a transaction
func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error { func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error {
_, err := tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?", workspaceID, userID) result, err := tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?",
return err workspaceID, userID)
if err != nil {
return fmt.Errorf("failed to update last workspace in transaction: %w", err)
}
_, err = result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected in transaction: %w", err)
}
return nil
} }
// UpdateLastOpenedFile updates the last opened file path for a workspace // UpdateLastOpenedFile updates the last opened file path for a workspace
func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error { func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error {
_, err := db.Exec("UPDATE workspaces SET last_opened_file_path = ? WHERE id = ?", filePath, workspaceID) _, err := db.Exec("UPDATE workspaces SET last_opened_file_path = ? WHERE id = ?",
return err filePath, workspaceID)
if err != nil {
return fmt.Errorf("failed to update last opened file: %w", err)
}
return nil
} }
// GetLastOpenedFile retrieves the last opened file path for a workspace // GetLastOpenedFile retrieves the last opened file path for a workspace
func (db *database) GetLastOpenedFile(workspaceID int) (string, error) { func (db *database) GetLastOpenedFile(workspaceID int) (string, error) {
var filePath sql.NullString var filePath sql.NullString
err := db.QueryRow("SELECT last_opened_file_path FROM workspaces WHERE id = ?", workspaceID).Scan(&filePath) err := db.QueryRow("SELECT last_opened_file_path FROM workspaces WHERE id = ?",
if err != nil { workspaceID).Scan(&filePath)
return "", err
if err == sql.ErrNoRows {
return "", fmt.Errorf("workspace not found")
} }
if err != nil {
return "", fmt.Errorf("failed to fetch last opened file: %w", err)
}
if !filePath.Valid { if !filePath.Valid {
return "", nil return "", nil
} }
return filePath.String, nil return filePath.String, nil
} }
// GetAllWorkspaces retrieves all workspaces in the database // GetAllWorkspaces retrieves all workspaces in the database
func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) { func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) {
rows, err := db.Query(` rows, err := db.Query(`
SELECT SELECT
id, user_id, name, created_at, id, user_id, name, created_at,
theme, auto_save, show_hidden_files, theme, auto_save, show_hidden_files,
git_enabled, git_url, git_user, git_token, git_enabled, git_url, git_user, git_token,
git_auto_commit, git_commit_msg_template, git_auto_commit, git_commit_msg_template,
git_commit_name, git_commit_email git_commit_name, git_commit_email
FROM workspaces`, FROM workspaces`,
) )
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to query workspaces: %w", err)
} }
defer rows.Close() defer rows.Close()
@@ -296,7 +365,7 @@ func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) {
&workspace.GitCommitName, &workspace.GitCommitEmail, &workspace.GitCommitName, &workspace.GitCommitEmail,
) )
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to scan workspace row: %w", err)
} }
// Decrypt token // Decrypt token
@@ -307,5 +376,10 @@ func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) {
workspaces = append(workspaces, workspace) workspaces = append(workspaces, workspace)
} }
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating workspace rows: %w", err)
}
return workspaces, nil return workspaces, nil
} }

View File

@@ -1,12 +1,12 @@
package db_test package db_test
import ( import (
"database/sql"
"strings" "strings"
"testing" "testing"
"novamd/internal/db" "novamd/internal/db"
"novamd/internal/models" "novamd/internal/models"
_ "novamd/internal/testenv"
) )
func TestWorkspaceOperations(t *testing.T) { func TestWorkspaceOperations(t *testing.T) {
@@ -385,8 +385,8 @@ func TestWorkspaceOperations(t *testing.T) {
// Verify workspace is gone // Verify workspace is gone
_, err = database.GetWorkspaceByID(workspace.ID) _, err = database.GetWorkspaceByID(workspace.ID)
if err != sql.ErrNoRows { if !strings.Contains(err.Error(), "workspace not found") {
t.Errorf("expected sql.ErrNoRows, got %v", err) t.Errorf("expected workspace not found, got %v", err)
} }
}) })
} }

View File

@@ -7,6 +7,8 @@ import (
"path/filepath" "path/filepath"
"time" "time"
"novamd/internal/logging"
"github.com/go-git/go-git/v5" "github.com/go-git/go-git/v5"
"github.com/go-git/go-git/v5/plumbing" "github.com/go-git/go-git/v5/plumbing"
"github.com/go-git/go-git/v5/plumbing/object" "github.com/go-git/go-git/v5/plumbing/object"
@@ -46,6 +48,15 @@ type client struct {
repo *git.Repository repo *git.Repository
} }
var logger logging.Logger
func getLogger() logging.Logger {
if logger == nil {
logger = logging.WithGroup("git")
}
return logger
}
// New creates a new git Client instance // New creates a new git Client instance
func New(url, username, token, workDir, commitName, commitEmail string) Client { func New(url, username, token, workDir, commitName, commitEmail string) Client {
return &client{ return &client{
@@ -62,6 +73,11 @@ func New(url, username, token, workDir, commitName, commitEmail string) Client {
// Clone clones the Git repository to the local directory // Clone clones the Git repository to the local directory
func (c *client) Clone() error { func (c *client) Clone() error {
log := getLogger()
log.Info("cloning git repository",
"url", c.URL,
"workDir", c.WorkDir)
auth := &http.BasicAuth{ auth := &http.BasicAuth{
Username: c.Username, Username: c.Username,
Password: c.Token, Password: c.Token,
@@ -73,7 +89,6 @@ func (c *client) Clone() error {
Auth: auth, Auth: auth,
Progress: os.Stdout, Progress: os.Stdout,
}) })
if err != nil { if err != nil {
return fmt.Errorf("failed to clone repository: %w", err) return fmt.Errorf("failed to clone repository: %w", err)
} }
@@ -83,6 +98,10 @@ func (c *client) Clone() error {
// Pull pulls the latest changes from the remote repository // Pull pulls the latest changes from the remote repository
func (c *client) Pull() error { func (c *client) Pull() error {
log := getLogger().With(
"workDir", c.WorkDir,
)
if c.repo == nil { if c.repo == nil {
return fmt.Errorf("repository not initialized") return fmt.Errorf("repository not initialized")
} }
@@ -101,16 +120,25 @@ func (c *client) Pull() error {
Auth: auth, Auth: auth,
Progress: os.Stdout, Progress: os.Stdout,
}) })
if err != nil && err != git.NoErrAlreadyUpToDate { if err != nil && err != git.NoErrAlreadyUpToDate {
return fmt.Errorf("failed to pull changes: %w", err) return fmt.Errorf("failed to pull changes: %w", err)
} }
if err == git.NoErrAlreadyUpToDate {
log.Debug("repository already up to date")
} else {
log.Debug("pulled latest changes")
}
return nil return nil
} }
// Commit commits the changes in the repository with the given message // Commit commits the changes in the repository with the given message
func (c *client) Commit(message string) (CommitHash, error) { func (c *client) Commit(message string) (CommitHash, error) {
log := getLogger().With(
"workDir", c.WorkDir,
)
if c.repo == nil { if c.repo == nil {
return CommitHash(plumbing.ZeroHash), fmt.Errorf("repository not initialized") return CommitHash(plumbing.ZeroHash), fmt.Errorf("repository not initialized")
} }
@@ -136,11 +164,16 @@ func (c *client) Commit(message string) (CommitHash, error) {
return CommitHash(plumbing.ZeroHash), fmt.Errorf("failed to commit changes: %w", err) return CommitHash(plumbing.ZeroHash), fmt.Errorf("failed to commit changes: %w", err)
} }
log.Debug("changes committed")
return CommitHash(hash), nil return CommitHash(hash), nil
} }
// Push pushes the changes to the remote repository // Push pushes the changes to the remote repository
func (c *client) Push() error { func (c *client) Push() error {
log := getLogger().With(
"workDir", c.WorkDir,
)
if c.repo == nil { if c.repo == nil {
return fmt.Errorf("repository not initialized") return fmt.Errorf("repository not initialized")
} }
@@ -154,17 +187,30 @@ func (c *client) Push() error {
Auth: auth, Auth: auth,
Progress: os.Stdout, Progress: os.Stdout,
}) })
if err != nil && err != git.NoErrAlreadyUpToDate { if err != nil && err != git.NoErrAlreadyUpToDate {
return fmt.Errorf("failed to push changes: %w", err) return fmt.Errorf("failed to push changes: %w", err)
} }
if err == git.NoErrAlreadyUpToDate {
log.Debug("remote already up to date",
"workDir", c.WorkDir)
} else {
log.Debug("pushed repository changes",
"workDir", c.WorkDir)
}
return nil return nil
} }
// EnsureRepo ensures the local repository is cloned and up-to-date // EnsureRepo ensures the local repository is cloned and up-to-date
func (c *client) EnsureRepo() error { func (c *client) EnsureRepo() error {
log := getLogger().With(
"workDir", c.WorkDir,
)
log.Debug("ensuring repository exists and is up to date")
if _, err := os.Stat(filepath.Join(c.WorkDir, ".git")); os.IsNotExist(err) { if _, err := os.Stat(filepath.Join(c.WorkDir, ".git")); os.IsNotExist(err) {
log.Info("repository not found, initiating clone")
return c.Clone() return c.Clone()
} }

View File

@@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"novamd/internal/context" "novamd/internal/context"
"novamd/internal/db" "novamd/internal/db"
"novamd/internal/logging"
"novamd/internal/models" "novamd/internal/models"
"novamd/internal/storage" "novamd/internal/storage"
"strconv" "strconv"
@@ -47,6 +48,10 @@ type SystemStats struct {
*storage.FileCountStats *storage.FileCountStats
} }
func getAdminLogger() logging.Logger {
return getHandlersLogger().WithGroup("admin")
}
// AdminListUsers godoc // AdminListUsers godoc
// @Summary List all users // @Summary List all users
// @Description Returns the list of all users // @Description Returns the list of all users
@@ -58,9 +63,22 @@ type SystemStats struct {
// @Failure 500 {object} ErrorResponse "Failed to list users" // @Failure 500 {object} ErrorResponse "Failed to list users"
// @Router /admin/users [get] // @Router /admin/users [get]
func (h *Handler) AdminListUsers() http.HandlerFunc { func (h *Handler) AdminListUsers() http.HandlerFunc {
return func(w http.ResponseWriter, _ *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
log := getAdminLogger().With(
"handler", "AdminListUsers",
"adminID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
users, err := h.DB.GetAllUsers() users, err := h.DB.GetAllUsers()
if err != nil { if err != nil {
log.Error("failed to fetch users from database",
"error", err.Error(),
)
respondError(w, "Failed to list users", http.StatusInternalServerError) respondError(w, "Failed to list users", http.StatusInternalServerError)
return return
} }
@@ -89,39 +107,63 @@ func (h *Handler) AdminListUsers() http.HandlerFunc {
// @Router /admin/users [post] // @Router /admin/users [post]
func (h *Handler) AdminCreateUser() http.HandlerFunc { func (h *Handler) AdminCreateUser() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
log := getAdminLogger().With(
"handler", "AdminCreateUser",
"adminID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
var req CreateUserRequest var req CreateUserRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
log.Debug("failed to decode request body",
"error", err.Error(),
)
respondError(w, "Invalid request body", http.StatusBadRequest) respondError(w, "Invalid request body", http.StatusBadRequest)
return return
} }
// Validate request // Validation logging
if req.Email == "" || req.Password == "" || req.Role == "" { if req.Email == "" || req.Password == "" || req.Role == "" {
log.Debug("missing required fields",
"hasEmail", req.Email != "",
"hasPassword", req.Password != "",
"hasRole", req.Role != "",
)
respondError(w, "Email, password, and role are required", http.StatusBadRequest) respondError(w, "Email, password, and role are required", http.StatusBadRequest)
return return
} }
// Check if email already exists // Email existence check
existingUser, err := h.DB.GetUserByEmail(req.Email) existingUser, err := h.DB.GetUserByEmail(req.Email)
if err == nil && existingUser != nil { if err == nil && existingUser != nil {
log.Warn("attempted to create user with existing email",
"email", req.Email,
)
respondError(w, "Email already exists", http.StatusConflict) respondError(w, "Email already exists", http.StatusConflict)
return return
} }
// Check if password is long enough
if len(req.Password) < 8 { if len(req.Password) < 8 {
log.Debug("password too short",
"passwordLength", len(req.Password),
)
respondError(w, "Password must be at least 8 characters", http.StatusBadRequest) respondError(w, "Password must be at least 8 characters", http.StatusBadRequest)
return return
} }
// Hash password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost) hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
if err != nil { if err != nil {
log.Error("failed to hash password",
"error", err.Error(),
)
respondError(w, "Failed to hash password", http.StatusInternalServerError) respondError(w, "Failed to hash password", http.StatusInternalServerError)
return return
} }
// Create user
user := &models.User{ user := &models.User{
Email: req.Email, Email: req.Email,
DisplayName: req.DisplayName, DisplayName: req.DisplayName,
@@ -131,16 +173,30 @@ func (h *Handler) AdminCreateUser() http.HandlerFunc {
insertedUser, err := h.DB.CreateUser(user) insertedUser, err := h.DB.CreateUser(user)
if err != nil { if err != nil {
log.Error("failed to create user in database",
"error", err.Error(),
"email", req.Email,
"role", req.Role,
)
respondError(w, "Failed to create user", http.StatusInternalServerError) respondError(w, "Failed to create user", http.StatusInternalServerError)
return return
} }
// Initialize user workspace
if err := h.Storage.InitializeUserWorkspace(insertedUser.ID, insertedUser.LastWorkspaceID); err != nil { if err := h.Storage.InitializeUserWorkspace(insertedUser.ID, insertedUser.LastWorkspaceID); err != nil {
log.Error("failed to initialize user workspace",
"error", err.Error(),
"userID", insertedUser.ID,
"workspaceID", insertedUser.LastWorkspaceID,
)
respondError(w, "Failed to initialize user workspace", http.StatusInternalServerError) respondError(w, "Failed to initialize user workspace", http.StatusInternalServerError)
return return
} }
log.Info("user created",
"newUserID", insertedUser.ID,
"email", insertedUser.Email,
"role", insertedUser.Role,
)
respondJSON(w, insertedUser) respondJSON(w, insertedUser)
} }
} }
@@ -159,14 +215,32 @@ func (h *Handler) AdminCreateUser() http.HandlerFunc {
// @Router /admin/users/{userId} [get] // @Router /admin/users/{userId} [get]
func (h *Handler) AdminGetUser() http.HandlerFunc { func (h *Handler) AdminGetUser() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
log := getAdminLogger().With(
"handler", "AdminGetUser",
"adminID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
userID, err := strconv.Atoi(chi.URLParam(r, "userId")) userID, err := strconv.Atoi(chi.URLParam(r, "userId"))
if err != nil { if err != nil {
log.Debug("invalid user ID format",
"userIDParam", chi.URLParam(r, "userId"),
"error", err.Error(),
)
respondError(w, "Invalid user ID", http.StatusBadRequest) respondError(w, "Invalid user ID", http.StatusBadRequest)
return return
} }
user, err := h.DB.GetUserByID(userID) user, err := h.DB.GetUserByID(userID)
if err != nil { if err != nil {
log.Debug("user not found",
"targetUserID", userID,
"error", err.Error(),
)
respondError(w, "User not found", http.StatusNotFound) respondError(w, "User not found", http.StatusNotFound)
return return
} }
@@ -194,49 +268,86 @@ func (h *Handler) AdminGetUser() http.HandlerFunc {
// @Router /admin/users/{userId} [put] // @Router /admin/users/{userId} [put]
func (h *Handler) AdminUpdateUser() http.HandlerFunc { func (h *Handler) AdminUpdateUser() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
log := getAdminLogger().With(
"handler", "AdminUpdateUser",
"adminID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
userID, err := strconv.Atoi(chi.URLParam(r, "userId")) userID, err := strconv.Atoi(chi.URLParam(r, "userId"))
if err != nil { if err != nil {
log.Debug("invalid user ID format",
"userIDParam", chi.URLParam(r, "userId"),
"error", err.Error(),
)
respondError(w, "Invalid user ID", http.StatusBadRequest) respondError(w, "Invalid user ID", http.StatusBadRequest)
return return
} }
// Get existing user
user, err := h.DB.GetUserByID(userID) user, err := h.DB.GetUserByID(userID)
if err != nil { if err != nil {
log.Debug("user not found",
"targetUserID", userID,
"error", err.Error(),
)
respondError(w, "User not found", http.StatusNotFound) respondError(w, "User not found", http.StatusNotFound)
return return
} }
var req UpdateUserRequest var req UpdateUserRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
log.Debug("failed to decode request body",
"error", err.Error(),
)
respondError(w, "Invalid request body", http.StatusBadRequest) respondError(w, "Invalid request body", http.StatusBadRequest)
return return
} }
// Update fields if provided // Track what's being updated for logging
updates := make(map[string]interface{})
if req.Email != "" { if req.Email != "" {
user.Email = req.Email user.Email = req.Email
updates["email"] = req.Email
} }
if req.DisplayName != "" { if req.DisplayName != "" {
user.DisplayName = req.DisplayName user.DisplayName = req.DisplayName
updates["displayName"] = req.DisplayName
} }
if req.Role != "" { if req.Role != "" {
user.Role = req.Role user.Role = req.Role
updates["role"] = req.Role
} }
if req.Password != "" { if req.Password != "" {
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost) hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
if err != nil { if err != nil {
log.Error("failed to hash password",
"error", err.Error(),
)
respondError(w, "Failed to hash password", http.StatusInternalServerError) respondError(w, "Failed to hash password", http.StatusInternalServerError)
return return
} }
user.PasswordHash = string(hashedPassword) user.PasswordHash = string(hashedPassword)
updates["passwordUpdated"] = true
} }
if err := h.DB.UpdateUser(user); err != nil { if err := h.DB.UpdateUser(user); err != nil {
log.Error("failed to update user in database",
"error", err.Error(),
"targetUserID", userID,
)
respondError(w, "Failed to update user", http.StatusInternalServerError) respondError(w, "Failed to update user", http.StatusInternalServerError)
return return
} }
log.Debug("user updated",
"targetUserID", userID,
"updates", updates,
)
respondJSON(w, user) respondJSON(w, user)
} }
} }
@@ -261,37 +372,61 @@ func (h *Handler) AdminDeleteUser() http.HandlerFunc {
if !ok { if !ok {
return return
} }
log := getAdminLogger().With(
"handler", "AdminDeleteUser",
"adminID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
userID, err := strconv.Atoi(chi.URLParam(r, "userId")) userID, err := strconv.Atoi(chi.URLParam(r, "userId"))
if err != nil { if err != nil {
log.Debug("invalid user ID format",
"userIDParam", chi.URLParam(r, "userId"),
"error", err.Error(),
)
respondError(w, "Invalid user ID", http.StatusBadRequest) respondError(w, "Invalid user ID", http.StatusBadRequest)
return return
} }
// Prevent admin from deleting themselves
if userID == ctx.UserID { if userID == ctx.UserID {
log.Warn("admin attempted to delete own account")
respondError(w, "Cannot delete your own account", http.StatusBadRequest) respondError(w, "Cannot delete your own account", http.StatusBadRequest)
return return
} }
// Get user before deletion to check role
user, err := h.DB.GetUserByID(userID) user, err := h.DB.GetUserByID(userID)
if err != nil { if err != nil {
log.Debug("user not found",
"targetUserID", userID,
"error", err.Error(),
)
respondError(w, "User not found", http.StatusNotFound) respondError(w, "User not found", http.StatusNotFound)
return return
} }
// Prevent deletion of other admin users
if user.Role == models.RoleAdmin && ctx.UserID != userID { if user.Role == models.RoleAdmin && ctx.UserID != userID {
log.Warn("attempted to delete another admin user",
"targetUserID", userID,
"targetUserEmail", user.Email,
)
respondError(w, "Cannot delete other admin users", http.StatusForbidden) respondError(w, "Cannot delete other admin users", http.StatusForbidden)
return return
} }
if err := h.DB.DeleteUser(userID); err != nil { if err := h.DB.DeleteUser(userID); err != nil {
log.Error("failed to delete user from database",
"error", err.Error(),
"targetUserID", userID,
)
respondError(w, "Failed to delete user", http.StatusInternalServerError) respondError(w, "Failed to delete user", http.StatusInternalServerError)
return return
} }
log.Info("user deleted",
"targetUserID", userID,
"targetUserEmail", user.Email,
"targetUserRole", user.Role,
)
w.WriteHeader(http.StatusNoContent) w.WriteHeader(http.StatusNoContent)
} }
} }
@@ -309,9 +444,22 @@ func (h *Handler) AdminDeleteUser() http.HandlerFunc {
// @Failure 500 {object} ErrorResponse "Failed to get file stats" // @Failure 500 {object} ErrorResponse "Failed to get file stats"
// @Router /admin/workspaces [get] // @Router /admin/workspaces [get]
func (h *Handler) AdminListWorkspaces() http.HandlerFunc { func (h *Handler) AdminListWorkspaces() http.HandlerFunc {
return func(w http.ResponseWriter, _ *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
log := getAdminLogger().With(
"handler", "AdminListWorkspaces",
"adminID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
workspaces, err := h.DB.GetAllWorkspaces() workspaces, err := h.DB.GetAllWorkspaces()
if err != nil { if err != nil {
log.Error("failed to fetch workspaces from database",
"error", err.Error(),
)
respondError(w, "Failed to list workspaces", http.StatusInternalServerError) respondError(w, "Failed to list workspaces", http.StatusInternalServerError)
return return
} }
@@ -319,11 +467,15 @@ func (h *Handler) AdminListWorkspaces() http.HandlerFunc {
workspacesStats := make([]*WorkspaceStats, 0, len(workspaces)) workspacesStats := make([]*WorkspaceStats, 0, len(workspaces))
for _, ws := range workspaces { for _, ws := range workspaces {
workspaceData := &WorkspaceStats{} workspaceData := &WorkspaceStats{}
user, err := h.DB.GetUserByID(ws.UserID) user, err := h.DB.GetUserByID(ws.UserID)
if err != nil { if err != nil {
log.Error("failed to fetch user for workspace",
"error", err.Error(),
"workspaceID", ws.ID,
"userID", ws.UserID,
)
respondError(w, "Failed to get user", http.StatusInternalServerError) respondError(w, "Failed to get user", http.StatusInternalServerError)
return return
} }
@@ -336,12 +488,16 @@ func (h *Handler) AdminListWorkspaces() http.HandlerFunc {
fileStats, err := h.Storage.GetFileStats(ws.UserID, ws.ID) fileStats, err := h.Storage.GetFileStats(ws.UserID, ws.ID)
if err != nil { if err != nil {
log.Error("failed to fetch file stats for workspace",
"error", err.Error(),
"workspaceID", ws.ID,
"userID", ws.UserID,
)
respondError(w, "Failed to get file stats", http.StatusInternalServerError) respondError(w, "Failed to get file stats", http.StatusInternalServerError)
return return
} }
workspaceData.FileCountStats = fileStats workspaceData.FileCountStats = fileStats
workspacesStats = append(workspacesStats, workspaceData) workspacesStats = append(workspacesStats, workspaceData)
} }
@@ -361,15 +517,31 @@ func (h *Handler) AdminListWorkspaces() http.HandlerFunc {
// @Failure 500 {object} ErrorResponse "Failed to get file stats" // @Failure 500 {object} ErrorResponse "Failed to get file stats"
// @Router /admin/stats [get] // @Router /admin/stats [get]
func (h *Handler) AdminGetSystemStats() http.HandlerFunc { func (h *Handler) AdminGetSystemStats() http.HandlerFunc {
return func(w http.ResponseWriter, _ *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
log := getAdminLogger().With(
"handler", "AdminGetSystemStats",
"adminID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
userStats, err := h.DB.GetSystemStats() userStats, err := h.DB.GetSystemStats()
if err != nil { if err != nil {
log.Error("failed to fetch user statistics",
"error", err.Error(),
)
respondError(w, "Failed to get user stats", http.StatusInternalServerError) respondError(w, "Failed to get user stats", http.StatusInternalServerError)
return return
} }
fileStats, err := h.Storage.GetTotalFileStats() fileStats, err := h.Storage.GetTotalFileStats()
if err != nil { if err != nil {
log.Error("failed to fetch file statistics",
"error", err.Error(),
)
respondError(w, "Failed to get file stats", http.StatusInternalServerError) respondError(w, "Failed to get file stats", http.StatusInternalServerError)
return return
} }

View File

@@ -7,6 +7,7 @@ import (
"net/http" "net/http"
"novamd/internal/auth" "novamd/internal/auth"
"novamd/internal/context" "novamd/internal/context"
"novamd/internal/logging"
"novamd/internal/models" "novamd/internal/models"
"time" "time"
@@ -26,6 +27,10 @@ type LoginResponse struct {
ExpiresAt time.Time `json:"expiresAt,omitempty"` ExpiresAt time.Time `json:"expiresAt,omitempty"`
} }
func getAuthLogger() logging.Logger {
return getHandlersLogger().WithGroup("auth")
}
// Login godoc // Login godoc
// @Summary Login // @Summary Login
// @Description Logs in a user and returns a session with access and refresh tokens // @Description Logs in a user and returns a session with access and refresh tokens
@@ -43,62 +48,88 @@ type LoginResponse struct {
// @Router /auth/login [post] // @Router /auth/login [post]
func (h *Handler) Login(authManager auth.SessionManager, cookieService auth.CookieManager) http.HandlerFunc { func (h *Handler) Login(authManager auth.SessionManager, cookieService auth.CookieManager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
log := getAuthLogger().With(
"handler", "Login",
"clientIP", r.RemoteAddr,
)
var req LoginRequest var req LoginRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
log.Debug("failed to decode request body",
"error", err.Error(),
)
respondError(w, "Invalid request body", http.StatusBadRequest) respondError(w, "Invalid request body", http.StatusBadRequest)
return return
} }
// Validate request
if req.Email == "" || req.Password == "" { if req.Email == "" || req.Password == "" {
log.Debug("missing required fields",
"hasEmail", req.Email != "",
"hasPassword", req.Password != "",
)
respondError(w, "Email and password are required", http.StatusBadRequest) respondError(w, "Email and password are required", http.StatusBadRequest)
return return
} }
// Get user from database
user, err := h.DB.GetUserByEmail(req.Email) user, err := h.DB.GetUserByEmail(req.Email)
if err != nil { if err != nil {
log.Debug("user not found",
"email", req.Email,
"error", err.Error(),
)
respondError(w, "Invalid credentials", http.StatusUnauthorized) respondError(w, "Invalid credentials", http.StatusUnauthorized)
return return
} }
// Verify password
err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)) err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password))
if err != nil { if err != nil {
log.Warn("invalid password attempt",
"userID", user.ID,
"email", user.Email,
)
respondError(w, "Invalid credentials", http.StatusUnauthorized) respondError(w, "Invalid credentials", http.StatusUnauthorized)
return return
} }
// Create session and generate tokens
session, accessToken, err := authManager.CreateSession(user.ID, string(user.Role)) session, accessToken, err := authManager.CreateSession(user.ID, string(user.Role))
if err != nil { if err != nil {
log.Error("failed to create session",
"error", err.Error(),
"userID", user.ID,
)
respondError(w, "Failed to create session", http.StatusInternalServerError) respondError(w, "Failed to create session", http.StatusInternalServerError)
return return
} }
// Generate CSRF token
csrfToken := make([]byte, 32) csrfToken := make([]byte, 32)
if _, err := rand.Read(csrfToken); err != nil { if _, err := rand.Read(csrfToken); err != nil {
log.Error("failed to generate CSRF token",
"error", err.Error(),
"userID", user.ID,
)
respondError(w, "Failed to generate CSRF token", http.StatusInternalServerError) respondError(w, "Failed to generate CSRF token", http.StatusInternalServerError)
return return
} }
csrfTokenString := hex.EncodeToString(csrfToken) csrfTokenString := hex.EncodeToString(csrfToken)
// Set cookies
http.SetCookie(w, cookieService.GenerateAccessTokenCookie(accessToken)) http.SetCookie(w, cookieService.GenerateAccessTokenCookie(accessToken))
http.SetCookie(w, cookieService.GenerateRefreshTokenCookie(session.RefreshToken)) http.SetCookie(w, cookieService.GenerateRefreshTokenCookie(session.RefreshToken))
http.SetCookie(w, cookieService.GenerateCSRFCookie(csrfTokenString)) http.SetCookie(w, cookieService.GenerateCSRFCookie(csrfTokenString))
// Send CSRF token in header for initial setup
w.Header().Set("X-CSRF-Token", csrfTokenString) w.Header().Set("X-CSRF-Token", csrfTokenString)
// Only send user info in response, not tokens
response := LoginResponse{ response := LoginResponse{
User: user, User: user,
SessionID: session.ID, SessionID: session.ID,
ExpiresAt: session.ExpiresAt, ExpiresAt: session.ExpiresAt,
} }
log.Debug("user logged in",
"userID", user.ID,
"email", user.Email,
"role", user.Role,
"sessionID", session.ID,
)
respondJSON(w, response) respondJSON(w, response)
} }
} }
@@ -114,24 +145,41 @@ func (h *Handler) Login(authManager auth.SessionManager, cookieService auth.Cook
// @Router /auth/logout [post] // @Router /auth/logout [post]
func (h *Handler) Logout(authManager auth.SessionManager, cookieService auth.CookieManager) http.HandlerFunc { func (h *Handler) Logout(authManager auth.SessionManager, cookieService auth.CookieManager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
// Get session ID from cookie ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
log := getAuthLogger().With(
"handler", "Logout",
"userID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
sessionCookie, err := r.Cookie("access_token") sessionCookie, err := r.Cookie("access_token")
if err != nil { if err != nil {
http.Error(w, "Unauthorized", http.StatusUnauthorized) log.Debug("missing access token cookie",
"error", err.Error(),
)
respondError(w, "Access token required", http.StatusBadRequest)
return return
} }
// Invalidate the session in the database
if err := authManager.InvalidateSession(sessionCookie.Value); err != nil { if err := authManager.InvalidateSession(sessionCookie.Value); err != nil {
log.Error("failed to invalidate session",
"error", err.Error(),
"sessionID", sessionCookie.Value,
)
respondError(w, "Failed to invalidate session", http.StatusInternalServerError) respondError(w, "Failed to invalidate session", http.StatusInternalServerError)
return return
} }
// Clear cookies
http.SetCookie(w, cookieService.InvalidateCookie("access_token")) http.SetCookie(w, cookieService.InvalidateCookie("access_token"))
http.SetCookie(w, cookieService.InvalidateCookie("refresh_token")) http.SetCookie(w, cookieService.InvalidateCookie("refresh_token"))
http.SetCookie(w, cookieService.InvalidateCookie("csrf_token")) http.SetCookie(w, cookieService.InvalidateCookie("csrf_token"))
log.Info("user logged out successfully",
"sessionID", sessionCookie.Value,
)
w.WriteHeader(http.StatusNoContent) w.WriteHeader(http.StatusNoContent)
} }
} }
@@ -151,22 +199,34 @@ func (h *Handler) Logout(authManager auth.SessionManager, cookieService auth.Coo
// @Router /auth/refresh [post] // @Router /auth/refresh [post]
func (h *Handler) RefreshToken(authManager auth.SessionManager, cookieService auth.CookieManager) http.HandlerFunc { func (h *Handler) RefreshToken(authManager auth.SessionManager, cookieService auth.CookieManager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
log := getAuthLogger().With(
"handler", "RefreshToken",
"clientIP", r.RemoteAddr,
)
refreshCookie, err := r.Cookie("refresh_token") refreshCookie, err := r.Cookie("refresh_token")
if err != nil { if err != nil {
log.Debug("missing refresh token cookie",
"error", err.Error(),
)
respondError(w, "Refresh token required", http.StatusBadRequest) respondError(w, "Refresh token required", http.StatusBadRequest)
return return
} }
// Generate new access token
accessToken, err := authManager.RefreshSession(refreshCookie.Value) accessToken, err := authManager.RefreshSession(refreshCookie.Value)
if err != nil { if err != nil {
log.Error("failed to refresh session",
"error", err.Error(),
)
respondError(w, "Invalid refresh token", http.StatusUnauthorized) respondError(w, "Invalid refresh token", http.StatusUnauthorized)
return return
} }
// Generate new CSRF token
csrfToken := make([]byte, 32) csrfToken := make([]byte, 32)
if _, err := rand.Read(csrfToken); err != nil { if _, err := rand.Read(csrfToken); err != nil {
log.Error("failed to generate CSRF token",
"error", err.Error(),
)
respondError(w, "Failed to generate CSRF token", http.StatusInternalServerError) respondError(w, "Failed to generate CSRF token", http.StatusInternalServerError)
return return
} }
@@ -196,10 +256,17 @@ func (h *Handler) GetCurrentUser() http.HandlerFunc {
if !ok { if !ok {
return return
} }
log := getAuthLogger().With(
"handler", "GetCurrentUser",
"userID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
// Get user from database
user, err := h.DB.GetUserByID(ctx.UserID) user, err := h.DB.GetUserByID(ctx.UserID)
if err != nil { if err != nil {
log.Error("failed to fetch user",
"error", err.Error(),
)
respondError(w, "User not found", http.StatusNotFound) respondError(w, "User not found", http.StatusNotFound)
return return
} }

View File

@@ -8,6 +8,7 @@ import (
"time" "time"
"novamd/internal/context" "novamd/internal/context"
"novamd/internal/logging"
"novamd/internal/storage" "novamd/internal/storage"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
@@ -35,6 +36,10 @@ type UpdateLastOpenedFileRequest struct {
FilePath string `json:"filePath"` FilePath string `json:"filePath"`
} }
func getFilesLogger() logging.Logger {
return getHandlersLogger().WithGroup("files")
}
// ListFiles godoc // ListFiles godoc
// @Summary List files // @Summary List files
// @Description Lists all files in the user's workspace // @Description Lists all files in the user's workspace
@@ -52,9 +57,18 @@ func (h *Handler) ListFiles() http.HandlerFunc {
if !ok { if !ok {
return return
} }
log := getFilesLogger().With(
"handler", "ListFiles",
"userID", ctx.UserID,
"workspaceID", ctx.Workspace.ID,
"clientIP", r.RemoteAddr,
)
files, err := h.Storage.ListFilesRecursively(ctx.UserID, ctx.Workspace.ID) files, err := h.Storage.ListFilesRecursively(ctx.UserID, ctx.Workspace.ID)
if err != nil { if err != nil {
log.Error("failed to list files in workspace",
"error", err.Error(),
)
respondError(w, "Failed to list files", http.StatusInternalServerError) respondError(w, "Failed to list files", http.StatusInternalServerError)
return return
} }
@@ -82,15 +96,32 @@ func (h *Handler) LookupFileByName() http.HandlerFunc {
if !ok { if !ok {
return return
} }
log := getFilesLogger().With(
"handler", "LookupFileByName",
"userID", ctx.UserID,
"workspaceID", ctx.Workspace.ID,
"clientIP", r.RemoteAddr,
)
filename := r.URL.Query().Get("filename") filename := r.URL.Query().Get("filename")
if filename == "" { if filename == "" {
log.Debug("missing filename parameter")
respondError(w, "Filename is required", http.StatusBadRequest) respondError(w, "Filename is required", http.StatusBadRequest)
return return
} }
filePaths, err := h.Storage.FindFileByName(ctx.UserID, ctx.Workspace.ID, filename) filePaths, err := h.Storage.FindFileByName(ctx.UserID, ctx.Workspace.ID, filename)
if err != nil { if err != nil {
if !os.IsNotExist(err) {
log.Error("failed to lookup file",
"filename", filename,
"error", err.Error(),
)
} else {
log.Debug("file not found",
"filename", filename,
)
}
respondError(w, "File not found", http.StatusNotFound) respondError(w, "File not found", http.StatusNotFound)
return return
} }
@@ -120,21 +151,37 @@ func (h *Handler) GetFileContent() http.HandlerFunc {
if !ok { if !ok {
return return
} }
log := getFilesLogger().With(
"handler", "GetFileContent",
"userID", ctx.UserID,
"workspaceID", ctx.Workspace.ID,
"clientIP", r.RemoteAddr,
)
filePath := chi.URLParam(r, "*") filePath := chi.URLParam(r, "*")
content, err := h.Storage.GetFileContent(ctx.UserID, ctx.Workspace.ID, filePath) content, err := h.Storage.GetFileContent(ctx.UserID, ctx.Workspace.ID, filePath)
if err != nil { if err != nil {
if storage.IsPathValidationError(err) { if storage.IsPathValidationError(err) {
log.Error("invalid file path attempted",
"filePath", filePath,
"error", err.Error(),
)
respondError(w, "Invalid file path", http.StatusBadRequest) respondError(w, "Invalid file path", http.StatusBadRequest)
return return
} }
if os.IsNotExist(err) { if os.IsNotExist(err) {
log.Debug("file not found",
"filePath", filePath,
)
respondError(w, "File not found", http.StatusNotFound) respondError(w, "File not found", http.StatusNotFound)
return return
} }
log.Error("failed to read file content",
"filePath", filePath,
"error", err.Error(),
)
respondError(w, "Failed to read file", http.StatusInternalServerError) respondError(w, "Failed to read file", http.StatusInternalServerError)
return return
} }
@@ -142,6 +189,10 @@ func (h *Handler) GetFileContent() http.HandlerFunc {
w.Header().Set("Content-Type", "text/plain") w.Header().Set("Content-Type", "text/plain")
_, err = w.Write(content) _, err = w.Write(content)
if err != nil { if err != nil {
log.Error("failed to write response",
"filePath", filePath,
"error", err.Error(),
)
respondError(w, "Failed to write response", http.StatusInternalServerError) respondError(w, "Failed to write response", http.StatusInternalServerError)
return return
} }
@@ -169,10 +220,20 @@ func (h *Handler) SaveFile() http.HandlerFunc {
if !ok { if !ok {
return return
} }
log := getFilesLogger().With(
"handler", "SaveFile",
"userID", ctx.UserID,
"workspaceID", ctx.Workspace.ID,
"clientIP", r.RemoteAddr,
)
filePath := chi.URLParam(r, "*") filePath := chi.URLParam(r, "*")
content, err := io.ReadAll(r.Body) content, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
log.Error("failed to read request body",
"filePath", filePath,
"error", err.Error(),
)
respondError(w, "Failed to read request body", http.StatusBadRequest) respondError(w, "Failed to read request body", http.StatusBadRequest)
return return
} }
@@ -180,10 +241,19 @@ func (h *Handler) SaveFile() http.HandlerFunc {
err = h.Storage.SaveFile(ctx.UserID, ctx.Workspace.ID, filePath, content) err = h.Storage.SaveFile(ctx.UserID, ctx.Workspace.ID, filePath, content)
if err != nil { if err != nil {
if storage.IsPathValidationError(err) { if storage.IsPathValidationError(err) {
log.Error("invalid file path attempted",
"filePath", filePath,
"error", err.Error(),
)
respondError(w, "Invalid file path", http.StatusBadRequest) respondError(w, "Invalid file path", http.StatusBadRequest)
return return
} }
log.Error("failed to save file",
"filePath", filePath,
"contentSize", len(content),
"error", err.Error(),
)
respondError(w, "Failed to save file", http.StatusInternalServerError) respondError(w, "Failed to save file", http.StatusInternalServerError)
return return
} }
@@ -194,7 +264,6 @@ func (h *Handler) SaveFile() http.HandlerFunc {
UpdatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(),
} }
w.WriteHeader(http.StatusOK)
respondJSON(w, response) respondJSON(w, response)
} }
} }
@@ -211,7 +280,6 @@ func (h *Handler) SaveFile() http.HandlerFunc {
// @Failure 400 {object} ErrorResponse "Invalid file path" // @Failure 400 {object} ErrorResponse "Invalid file path"
// @Failure 404 {object} ErrorResponse "File not found" // @Failure 404 {object} ErrorResponse "File not found"
// @Failure 500 {object} ErrorResponse "Failed to delete file" // @Failure 500 {object} ErrorResponse "Failed to delete file"
// @Failure 500 {object} ErrorResponse "Failed to write response"
// @Router /workspaces/{workspace_name}/files/{file_path} [delete] // @Router /workspaces/{workspace_name}/files/{file_path} [delete]
func (h *Handler) DeleteFile() http.HandlerFunc { func (h *Handler) DeleteFile() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
@@ -219,20 +287,37 @@ func (h *Handler) DeleteFile() http.HandlerFunc {
if !ok { if !ok {
return return
} }
log := getFilesLogger().With(
"handler", "DeleteFile",
"userID", ctx.UserID,
"workspaceID", ctx.Workspace.ID,
"clientIP", r.RemoteAddr,
)
filePath := chi.URLParam(r, "*") filePath := chi.URLParam(r, "*")
err := h.Storage.DeleteFile(ctx.UserID, ctx.Workspace.ID, filePath) err := h.Storage.DeleteFile(ctx.UserID, ctx.Workspace.ID, filePath)
if err != nil { if err != nil {
if storage.IsPathValidationError(err) { if storage.IsPathValidationError(err) {
log.Error("invalid file path attempted",
"filePath", filePath,
"error", err.Error(),
)
respondError(w, "Invalid file path", http.StatusBadRequest) respondError(w, "Invalid file path", http.StatusBadRequest)
return return
} }
if os.IsNotExist(err) { if os.IsNotExist(err) {
log.Debug("file not found",
"filePath", filePath,
)
respondError(w, "File not found", http.StatusNotFound) respondError(w, "File not found", http.StatusNotFound)
return return
} }
log.Error("failed to delete file",
"filePath", filePath,
"error", err.Error(),
)
respondError(w, "Failed to delete file", http.StatusInternalServerError) respondError(w, "Failed to delete file", http.StatusInternalServerError)
return return
} }
@@ -259,14 +344,27 @@ func (h *Handler) GetLastOpenedFile() http.HandlerFunc {
if !ok { if !ok {
return return
} }
log := getFilesLogger().With(
"handler", "GetLastOpenedFile",
"userID", ctx.UserID,
"workspaceID", ctx.Workspace.ID,
"clientIP", r.RemoteAddr,
)
filePath, err := h.DB.GetLastOpenedFile(ctx.Workspace.ID) filePath, err := h.DB.GetLastOpenedFile(ctx.Workspace.ID)
if err != nil { if err != nil {
log.Error("failed to get last opened file from database",
"error", err.Error(),
)
respondError(w, "Failed to get last opened file", http.StatusInternalServerError) respondError(w, "Failed to get last opened file", http.StatusInternalServerError)
return return
} }
if _, err := h.Storage.ValidatePath(ctx.UserID, ctx.Workspace.ID, filePath); err != nil { if _, err := h.Storage.ValidatePath(ctx.UserID, ctx.Workspace.ID, filePath); err != nil {
log.Error("invalid file path stored",
"filePath", filePath,
"error", err.Error(),
)
respondError(w, "Invalid file path", http.StatusBadRequest) respondError(w, "Invalid file path", http.StatusBadRequest)
return return
} }
@@ -297,10 +395,18 @@ func (h *Handler) UpdateLastOpenedFile() http.HandlerFunc {
if !ok { if !ok {
return return
} }
log := getFilesLogger().With(
"handler", "UpdateLastOpenedFile",
"userID", ctx.UserID,
"workspaceID", ctx.Workspace.ID,
"clientIP", r.RemoteAddr,
)
var requestBody UpdateLastOpenedFileRequest var requestBody UpdateLastOpenedFileRequest
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
log.Error("failed to decode request body",
"error", err.Error(),
)
respondError(w, "Invalid request body", http.StatusBadRequest) respondError(w, "Invalid request body", http.StatusBadRequest)
return return
} }
@@ -310,21 +416,36 @@ func (h *Handler) UpdateLastOpenedFile() http.HandlerFunc {
_, err := h.Storage.GetFileContent(ctx.UserID, ctx.Workspace.ID, requestBody.FilePath) _, err := h.Storage.GetFileContent(ctx.UserID, ctx.Workspace.ID, requestBody.FilePath)
if err != nil { if err != nil {
if storage.IsPathValidationError(err) { if storage.IsPathValidationError(err) {
log.Error("invalid file path attempted",
"filePath", requestBody.FilePath,
"error", err.Error(),
)
respondError(w, "Invalid file path", http.StatusBadRequest) respondError(w, "Invalid file path", http.StatusBadRequest)
return return
} }
if os.IsNotExist(err) { if os.IsNotExist(err) {
log.Debug("file not found",
"filePath", requestBody.FilePath,
)
respondError(w, "File not found", http.StatusNotFound) respondError(w, "File not found", http.StatusNotFound)
return return
} }
log.Error("failed to validate file path",
"filePath", requestBody.FilePath,
"error", err.Error(),
)
respondError(w, "Failed to update last opened file", http.StatusInternalServerError) respondError(w, "Failed to update last opened file", http.StatusInternalServerError)
return return
} }
} }
if err := h.DB.UpdateLastOpenedFile(ctx.Workspace.ID, requestBody.FilePath); err != nil { if err := h.DB.UpdateLastOpenedFile(ctx.Workspace.ID, requestBody.FilePath); err != nil {
log.Error("failed to update last opened file in database",
"filePath", requestBody.FilePath,
"error", err.Error(),
)
respondError(w, "Failed to update last opened file", http.StatusInternalServerError) respondError(w, "Failed to update last opened file", http.StatusInternalServerError)
return return
} }

View File

@@ -3,8 +3,8 @@ package handlers
import ( import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"novamd/internal/context" "novamd/internal/context"
"novamd/internal/logging"
) )
// CommitRequest represents a request to commit changes // CommitRequest represents a request to commit changes
@@ -22,6 +22,10 @@ type PullResponse struct {
Message string `json:"message" example:"Pulled changes from remote"` Message string `json:"message" example:"Pulled changes from remote"`
} }
func getGitLogger() logging.Logger {
return getHandlersLogger().WithGroup("git")
}
// StageCommitAndPush godoc // StageCommitAndPush godoc
// @Summary Stage, commit, and push changes // @Summary Stage, commit, and push changes
// @Description Stages, commits, and pushes changes to the remote repository // @Description Stages, commits, and pushes changes to the remote repository
@@ -42,21 +46,34 @@ func (h *Handler) StageCommitAndPush() http.HandlerFunc {
if !ok { if !ok {
return return
} }
log := getGitLogger().With(
"handler", "StageCommitAndPush",
"userID", ctx.UserID,
"workspaceID", ctx.Workspace.ID,
"clientIP", r.RemoteAddr,
)
var requestBody CommitRequest var requestBody CommitRequest
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
log.Error("failed to decode request body",
"error", err.Error(),
)
respondError(w, "Invalid request body", http.StatusBadRequest) respondError(w, "Invalid request body", http.StatusBadRequest)
return return
} }
if requestBody.Message == "" { if requestBody.Message == "" {
log.Debug("empty commit message provided")
respondError(w, "Commit message is required", http.StatusBadRequest) respondError(w, "Commit message is required", http.StatusBadRequest)
return return
} }
hash, err := h.Storage.StageCommitAndPush(ctx.UserID, ctx.Workspace.ID, requestBody.Message) hash, err := h.Storage.StageCommitAndPush(ctx.UserID, ctx.Workspace.ID, requestBody.Message)
if err != nil { if err != nil {
log.Error("failed to perform git operations",
"error", err.Error(),
"commitMessage", requestBody.Message,
)
respondError(w, "Failed to stage, commit, and push changes: "+err.Error(), http.StatusInternalServerError) respondError(w, "Failed to stage, commit, and push changes: "+err.Error(), http.StatusInternalServerError)
return return
} }
@@ -82,9 +99,18 @@ func (h *Handler) PullChanges() http.HandlerFunc {
if !ok { if !ok {
return return
} }
log := getGitLogger().With(
"handler", "PullChanges",
"userID", ctx.UserID,
"workspaceID", ctx.Workspace.ID,
"clientIP", r.RemoteAddr,
)
err := h.Storage.Pull(ctx.UserID, ctx.Workspace.ID) err := h.Storage.Pull(ctx.UserID, ctx.Workspace.ID)
if err != nil { if err != nil {
log.Error("failed to pull changes from remote",
"error", err.Error(),
)
respondError(w, "Failed to pull changes: "+err.Error(), http.StatusInternalServerError) respondError(w, "Failed to pull changes: "+err.Error(), http.StatusInternalServerError)
return return
} }

View File

@@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"novamd/internal/db" "novamd/internal/db"
"novamd/internal/logging"
"novamd/internal/storage" "novamd/internal/storage"
) )
@@ -18,6 +19,15 @@ type Handler struct {
Storage storage.Manager Storage storage.Manager
} }
var logger logging.Logger
func getHandlersLogger() logging.Logger {
if logger == nil {
logger = logging.WithGroup("handlers")
}
return logger
}
// NewHandler creates a new handler with the given dependencies // NewHandler creates a new handler with the given dependencies
func NewHandler(db db.Database, s storage.Manager) *Handler { func NewHandler(db db.Database, s storage.Manager) *Handler {
return &Handler{ return &Handler{

View File

@@ -21,6 +21,8 @@ import (
"novamd/internal/models" "novamd/internal/models"
"novamd/internal/secrets" "novamd/internal/secrets"
"novamd/internal/storage" "novamd/internal/storage"
_ "novamd/internal/testenv"
) )
// testHarness encapsulates all the dependencies needed for testing // testHarness encapsulates all the dependencies needed for testing

View File

@@ -2,6 +2,7 @@ package handlers
import ( import (
"net/http" "net/http"
"novamd/internal/logging"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@@ -19,8 +20,19 @@ func NewStaticHandler(staticPath string) *StaticHandler {
} }
} }
func getStaticLogger() logging.Logger {
return logging.WithGroup("static")
}
// ServeHTTP serves the static files // ServeHTTP serves the static files
func (h *StaticHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *StaticHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log := getStaticLogger().With(
"handler", "ServeHTTP",
"clientIP", r.RemoteAddr,
"method", r.Method,
"url", r.URL.Path,
)
// Get the requested path // Get the requested path
requestedPath := r.URL.Path requestedPath := r.URL.Path
fullPath := filepath.Join(h.staticPath, requestedPath) fullPath := filepath.Join(h.staticPath, requestedPath)
@@ -28,6 +40,10 @@ func (h *StaticHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Security check to prevent directory traversal // Security check to prevent directory traversal
if !strings.HasPrefix(cleanPath, h.staticPath) { if !strings.HasPrefix(cleanPath, h.staticPath) {
log.Warn("directory traversal attempt detected",
"requestedPath", requestedPath,
"cleanPath", cleanPath,
)
respondError(w, "Invalid path", http.StatusBadRequest) respondError(w, "Invalid path", http.StatusBadRequest)
return return
} }
@@ -40,6 +56,21 @@ func (h *StaticHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Check if file exists (not counting .gz files) // Check if file exists (not counting .gz files)
stat, err := os.Stat(cleanPath) stat, err := os.Stat(cleanPath)
if err != nil || stat.IsDir() { if err != nil || stat.IsDir() {
if os.IsNotExist(err) {
log.Debug("file not found, serving index.html",
"requestedPath", requestedPath,
)
} else if stat != nil && stat.IsDir() {
log.Debug("directory requested, serving index.html",
"requestedPath", requestedPath,
)
} else {
log.Error("error checking file status",
"requestedPath", requestedPath,
"error", err.Error(),
)
}
// Serve index.html for SPA routing // Serve index.html for SPA routing
indexPath := filepath.Join(h.staticPath, "index.html") indexPath := filepath.Join(h.staticPath, "index.html")
http.ServeFile(w, r, indexPath) http.ServeFile(w, r, indexPath)
@@ -53,15 +84,16 @@ func (h *StaticHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Encoding", "gzip") w.Header().Set("Content-Encoding", "gzip")
// Set proper content type based on original file // Set proper content type based on original file
contentType := "application/octet-stream"
switch filepath.Ext(cleanPath) { switch filepath.Ext(cleanPath) {
case ".js": case ".js":
w.Header().Set("Content-Type", "application/javascript") contentType = "application/javascript"
case ".css": case ".css":
w.Header().Set("Content-Type", "text/css") contentType = "text/css"
case ".html": case ".html":
w.Header().Set("Content-Type", "text/html") contentType = "text/html"
} }
w.Header().Set("Content-Type", contentType)
http.ServeFile(w, r, gzPath) http.ServeFile(w, r, gzPath)
return return
} }

View File

@@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"novamd/internal/context" "novamd/internal/context"
"novamd/internal/logging"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@@ -22,6 +23,10 @@ type DeleteAccountRequest struct {
Password string `json:"password"` Password string `json:"password"`
} }
func getProfileLogger() logging.Logger {
return getHandlersLogger().WithGroup("profile")
}
// UpdateProfile godoc // UpdateProfile godoc
// @Summary Update profile // @Summary Update profile
// @Description Updates the user's profile // @Description Updates the user's profile
@@ -48,9 +53,17 @@ func (h *Handler) UpdateProfile() http.HandlerFunc {
if !ok { if !ok {
return return
} }
log := getProfileLogger().With(
"handler", "UpdateProfile",
"userID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
var req UpdateProfileRequest var req UpdateProfileRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
log.Debug("failed to decode request body",
"error", err.Error(),
)
respondError(w, "Invalid request body", http.StatusBadRequest) respondError(w, "Invalid request body", http.StatusBadRequest)
return return
} }
@@ -58,76 +71,94 @@ func (h *Handler) UpdateProfile() http.HandlerFunc {
// Get current user // Get current user
user, err := h.DB.GetUserByID(ctx.UserID) user, err := h.DB.GetUserByID(ctx.UserID)
if err != nil { if err != nil {
log.Error("failed to fetch user from database",
"error", err.Error(),
)
respondError(w, "User not found", http.StatusNotFound) respondError(w, "User not found", http.StatusNotFound)
return return
} }
// Track what's being updated for logging
updates := make(map[string]bool)
// Handle password update if requested // Handle password update if requested
if req.NewPassword != "" { if req.NewPassword != "" {
// Current password must be provided to change password
if req.CurrentPassword == "" { if req.CurrentPassword == "" {
log.Debug("password change attempted without current password")
respondError(w, "Current password is required to change password", http.StatusBadRequest) respondError(w, "Current password is required to change password", http.StatusBadRequest)
return return
} }
// Verify current password
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.CurrentPassword)); err != nil { if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.CurrentPassword)); err != nil {
log.Warn("incorrect password provided for password change")
respondError(w, "Current password is incorrect", http.StatusUnauthorized) respondError(w, "Current password is incorrect", http.StatusUnauthorized)
return return
} }
// Validate new password
if len(req.NewPassword) < 8 { if len(req.NewPassword) < 8 {
log.Debug("password change rejected - too short",
"passwordLength", len(req.NewPassword),
)
respondError(w, "New password must be at least 8 characters long", http.StatusBadRequest) respondError(w, "New password must be at least 8 characters long", http.StatusBadRequest)
return return
} }
// Hash new password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost) hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
if err != nil { if err != nil {
log.Error("failed to hash new password",
"error", err.Error(),
)
respondError(w, "Failed to process new password", http.StatusInternalServerError) respondError(w, "Failed to process new password", http.StatusInternalServerError)
return return
} }
user.PasswordHash = string(hashedPassword) user.PasswordHash = string(hashedPassword)
updates["passwordChanged"] = true
} }
// Handle email update if requested // Handle email update if requested
if req.Email != "" && req.Email != user.Email { if req.Email != "" && req.Email != user.Email {
// Check if email change requires password verification
if req.CurrentPassword == "" { if req.CurrentPassword == "" {
log.Warn("attempted email change without current password")
respondError(w, "Current password is required to change email", http.StatusBadRequest) respondError(w, "Current password is required to change email", http.StatusBadRequest)
return return
} }
// Verify current password if not already verified for password change
if req.NewPassword == "" { if req.NewPassword == "" {
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.CurrentPassword)); err != nil { if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.CurrentPassword)); err != nil {
log.Warn("incorrect password provided for email change")
respondError(w, "Current password is incorrect", http.StatusUnauthorized) respondError(w, "Current password is incorrect", http.StatusUnauthorized)
return return
} }
} }
// Check if new email is already in use
existingUser, err := h.DB.GetUserByEmail(req.Email) existingUser, err := h.DB.GetUserByEmail(req.Email)
if err == nil && existingUser.ID != user.ID { if err == nil && existingUser.ID != user.ID {
log.Debug("email change rejected - already in use",
"requestedEmail", req.Email,
)
respondError(w, "Email already in use", http.StatusConflict) respondError(w, "Email already in use", http.StatusConflict)
return return
} }
user.Email = req.Email user.Email = req.Email
updates["emailChanged"] = true
} }
// Update display name if provided (no password required) // Update display name if provided
if req.DisplayName != "" { if req.DisplayName != "" {
user.DisplayName = req.DisplayName user.DisplayName = req.DisplayName
updates["displayNameChanged"] = true
} }
// Update user in database // Update user in database
if err := h.DB.UpdateUser(user); err != nil { if err := h.DB.UpdateUser(user); err != nil {
log.Error("failed to update user in database",
"error", err.Error(),
"updates", updates,
)
respondError(w, "Failed to update profile", http.StatusInternalServerError) respondError(w, "Failed to update profile", http.StatusInternalServerError)
return return
} }
// Return updated user data
respondJSON(w, user) respondJSON(w, user)
} }
} }
@@ -155,9 +186,17 @@ func (h *Handler) DeleteAccount() http.HandlerFunc {
if !ok { if !ok {
return return
} }
log := getProfileLogger().With(
"handler", "DeleteAccount",
"userID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
var req DeleteAccountRequest var req DeleteAccountRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
log.Debug("failed to decode request body",
"error", err.Error(),
)
respondError(w, "Invalid request body", http.StatusBadRequest) respondError(w, "Invalid request body", http.StatusBadRequest)
return return
} }
@@ -165,25 +204,32 @@ func (h *Handler) DeleteAccount() http.HandlerFunc {
// Get current user // Get current user
user, err := h.DB.GetUserByID(ctx.UserID) user, err := h.DB.GetUserByID(ctx.UserID)
if err != nil { if err != nil {
log.Error("failed to fetch user from database",
"error", err.Error(),
)
respondError(w, "User not found", http.StatusNotFound) respondError(w, "User not found", http.StatusNotFound)
return return
} }
// Verify password // Verify password
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil { if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil {
respondError(w, "Password is incorrect", http.StatusUnauthorized) log.Warn("incorrect password provided for account deletion")
respondError(w, "Incorrect password", http.StatusUnauthorized)
return return
} }
// Prevent admin from deleting their own account if they're the last admin // Prevent admin from deleting their own account if they're the last admin
if user.Role == "admin" { if user.Role == "admin" {
// Count number of admin users
adminCount, err := h.DB.CountAdminUsers() adminCount, err := h.DB.CountAdminUsers()
if err != nil { if err != nil {
respondError(w, "Failed to verify admin status", http.StatusInternalServerError) log.Error("failed to count admin users",
"error", err.Error(),
)
respondError(w, "Failed to get admin count", http.StatusInternalServerError)
return return
} }
if adminCount <= 1 { if adminCount <= 1 {
log.Warn("attempted to delete last admin account")
respondError(w, "Cannot delete the last admin account", http.StatusForbidden) respondError(w, "Cannot delete the last admin account", http.StatusForbidden)
return return
} }
@@ -192,6 +238,9 @@ func (h *Handler) DeleteAccount() http.HandlerFunc {
// Get user's workspaces for cleanup // Get user's workspaces for cleanup
workspaces, err := h.DB.GetWorkspacesByUserID(ctx.UserID) workspaces, err := h.DB.GetWorkspacesByUserID(ctx.UserID)
if err != nil { if err != nil {
log.Error("failed to fetch user workspaces",
"error", err.Error(),
)
respondError(w, "Failed to get user workspaces", http.StatusInternalServerError) respondError(w, "Failed to get user workspaces", http.StatusInternalServerError)
return return
} }
@@ -199,17 +248,31 @@ func (h *Handler) DeleteAccount() http.HandlerFunc {
// Delete workspace directories // Delete workspace directories
for _, workspace := range workspaces { for _, workspace := range workspaces {
if err := h.Storage.DeleteUserWorkspace(ctx.UserID, workspace.ID); err != nil { if err := h.Storage.DeleteUserWorkspace(ctx.UserID, workspace.ID); err != nil {
log.Error("failed to delete workspace directory",
"error", err.Error(),
"workspaceID", workspace.ID,
)
respondError(w, "Failed to delete workspace files", http.StatusInternalServerError) respondError(w, "Failed to delete workspace files", http.StatusInternalServerError)
return return
} }
log.Debug("workspace deleted",
"workspaceID", workspace.ID,
)
} }
// Delete user from database (this will cascade delete workspaces and sessions) // Delete user from database
if err := h.DB.DeleteUser(ctx.UserID); err != nil { if err := h.DB.DeleteUser(ctx.UserID); err != nil {
log.Error("failed to delete user from database",
"error", err.Error(),
)
respondError(w, "Failed to delete account", http.StatusInternalServerError) respondError(w, "Failed to delete account", http.StatusInternalServerError)
return return
} }
log.Info("user account deleted",
"email", user.Email,
"role", user.Role,
)
w.WriteHeader(http.StatusNoContent) w.WriteHeader(http.StatusNoContent)
} }
} }

View File

@@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"novamd/internal/context" "novamd/internal/context"
"novamd/internal/logging"
"novamd/internal/models" "novamd/internal/models"
) )
@@ -19,6 +20,10 @@ type LastWorkspaceNameResponse struct {
LastWorkspaceName string `json:"lastWorkspaceName"` LastWorkspaceName string `json:"lastWorkspaceName"`
} }
func getWorkspaceLogger() logging.Logger {
return getHandlersLogger().WithGroup("workspace")
}
// ListWorkspaces godoc // ListWorkspaces godoc
// @Summary List workspaces // @Summary List workspaces
// @Description Lists all workspaces for the current user // @Description Lists all workspaces for the current user
@@ -35,9 +40,17 @@ func (h *Handler) ListWorkspaces() http.HandlerFunc {
if !ok { if !ok {
return return
} }
log := getWorkspaceLogger().With(
"handler", "ListWorkspaces",
"userID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
workspaces, err := h.DB.GetWorkspacesByUserID(ctx.UserID) workspaces, err := h.DB.GetWorkspacesByUserID(ctx.UserID)
if err != nil { if err != nil {
log.Error("failed to fetch workspaces from database",
"error", err.Error(),
)
respondError(w, "Failed to list workspaces", http.StatusInternalServerError) respondError(w, "Failed to list workspaces", http.StatusInternalServerError)
return return
} }
@@ -68,25 +81,44 @@ func (h *Handler) CreateWorkspace() http.HandlerFunc {
if !ok { if !ok {
return return
} }
log := getWorkspaceLogger().With(
"handler", "CreateWorkspace",
"userID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
var workspace models.Workspace var workspace models.Workspace
if err := json.NewDecoder(r.Body).Decode(&workspace); err != nil { if err := json.NewDecoder(r.Body).Decode(&workspace); err != nil {
log.Debug("invalid request body received",
"error", err.Error(),
)
respondError(w, "Invalid request body", http.StatusBadRequest) respondError(w, "Invalid request body", http.StatusBadRequest)
return return
} }
if err := workspace.ValidateGitSettings(); err != nil { if err := workspace.ValidateGitSettings(); err != nil {
log.Debug("invalid git settings provided",
"error", err.Error(),
)
respondError(w, "Invalid workspace", http.StatusBadRequest) respondError(w, "Invalid workspace", http.StatusBadRequest)
return return
} }
workspace.UserID = ctx.UserID workspace.UserID = ctx.UserID
if err := h.DB.CreateWorkspace(&workspace); err != nil { if err := h.DB.CreateWorkspace(&workspace); err != nil {
log.Error("failed to create workspace in database",
"error", err.Error(),
"workspaceName", workspace.Name,
)
respondError(w, "Failed to create workspace", http.StatusInternalServerError) respondError(w, "Failed to create workspace", http.StatusInternalServerError)
return return
} }
if err := h.Storage.InitializeUserWorkspace(workspace.UserID, workspace.ID); err != nil { if err := h.Storage.InitializeUserWorkspace(workspace.UserID, workspace.ID); err != nil {
log.Error("failed to initialize workspace directory",
"error", err.Error(),
"workspaceID", workspace.ID,
)
respondError(w, "Failed to initialize workspace directory", http.StatusInternalServerError) respondError(w, "Failed to initialize workspace directory", http.StatusInternalServerError)
return return
} }
@@ -101,11 +133,20 @@ func (h *Handler) CreateWorkspace() http.HandlerFunc {
workspace.GitCommitName, workspace.GitCommitName,
workspace.GitCommitEmail, workspace.GitCommitEmail,
); err != nil { ); err != nil {
log.Error("failed to setup git repository",
"error", err.Error(),
"workspaceID", workspace.ID,
)
respondError(w, "Failed to setup git repo: "+err.Error(), http.StatusInternalServerError) respondError(w, "Failed to setup git repo: "+err.Error(), http.StatusInternalServerError)
return return
} }
} }
log.Info("workspace created",
"workspaceID", workspace.ID,
"workspaceName", workspace.Name,
"gitEnabled", workspace.GitEnabled,
)
respondJSON(w, workspace) respondJSON(w, workspace)
} }
} }
@@ -171,9 +212,18 @@ func (h *Handler) UpdateWorkspace() http.HandlerFunc {
if !ok { if !ok {
return return
} }
log := getWorkspaceLogger().With(
"handler", "UpdateWorkspace",
"userID", ctx.UserID,
"workspaceID", ctx.Workspace.ID,
"clientIP", r.RemoteAddr,
)
var workspace models.Workspace var workspace models.Workspace
if err := json.NewDecoder(r.Body).Decode(&workspace); err != nil { if err := json.NewDecoder(r.Body).Decode(&workspace); err != nil {
log.Debug("invalid request body received",
"error", err.Error(),
)
respondError(w, "Invalid request body", http.StatusBadRequest) respondError(w, "Invalid request body", http.StatusBadRequest)
return return
} }
@@ -184,12 +234,23 @@ func (h *Handler) UpdateWorkspace() http.HandlerFunc {
// Validate the workspace // Validate the workspace
if err := workspace.Validate(); err != nil { if err := workspace.Validate(); err != nil {
log.Debug("invalid workspace configuration",
"error", err.Error(),
)
respondError(w, err.Error(), http.StatusBadRequest) respondError(w, err.Error(), http.StatusBadRequest)
return return
} }
// Track what's changed for logging
changes := map[string]bool{
"gitSettings": gitSettingsChanged(&workspace, ctx.Workspace),
"name": workspace.Name != ctx.Workspace.Name,
"theme": workspace.Theme != ctx.Workspace.Theme,
"autoSave": workspace.AutoSave != ctx.Workspace.AutoSave,
}
// Handle Git repository setup/teardown if Git settings changed // Handle Git repository setup/teardown if Git settings changed
if gitSettingsChanged(&workspace, ctx.Workspace) { if changes["gitSettings"] {
if workspace.GitEnabled { if workspace.GitEnabled {
if err := h.Storage.SetupGitRepo( if err := h.Storage.SetupGitRepo(
ctx.UserID, ctx.UserID,
@@ -200,16 +261,21 @@ func (h *Handler) UpdateWorkspace() http.HandlerFunc {
workspace.GitCommitName, workspace.GitCommitName,
workspace.GitCommitEmail, workspace.GitCommitEmail,
); err != nil { ); err != nil {
log.Error("failed to setup git repository",
"error", err.Error(),
)
respondError(w, "Failed to setup git repo: "+err.Error(), http.StatusInternalServerError) respondError(w, "Failed to setup git repo: "+err.Error(), http.StatusInternalServerError)
return return
} }
} else { } else {
h.Storage.DisableGitRepo(ctx.UserID, ctx.Workspace.ID) h.Storage.DisableGitRepo(ctx.UserID, ctx.Workspace.ID)
} }
} }
if err := h.DB.UpdateWorkspace(&workspace); err != nil { if err := h.DB.UpdateWorkspace(&workspace); err != nil {
log.Error("failed to update workspace in database",
"error", err.Error(),
)
respondError(w, "Failed to update workspace", http.StatusInternalServerError) respondError(w, "Failed to update workspace", http.StatusInternalServerError)
return return
} }
@@ -241,15 +307,25 @@ func (h *Handler) DeleteWorkspace() http.HandlerFunc {
if !ok { if !ok {
return return
} }
log := getWorkspaceLogger().With(
"handler", "DeleteWorkspace",
"userID", ctx.UserID,
"workspaceID", ctx.Workspace.ID,
"clientIP", r.RemoteAddr,
)
// Check if this is the user's last workspace // Check if this is the user's last workspace
workspaces, err := h.DB.GetWorkspacesByUserID(ctx.UserID) workspaces, err := h.DB.GetWorkspacesByUserID(ctx.UserID)
if err != nil { if err != nil {
log.Error("failed to fetch workspaces from database",
"error", err.Error(),
)
respondError(w, "Failed to get workspaces", http.StatusInternalServerError) respondError(w, "Failed to get workspaces", http.StatusInternalServerError)
return return
} }
if len(workspaces) <= 1 { if len(workspaces) <= 1 {
log.Debug("attempted to delete last workspace")
respondError(w, "Cannot delete the last workspace", http.StatusBadRequest) respondError(w, "Cannot delete the last workspace", http.StatusBadRequest)
return return
} }
@@ -265,14 +341,19 @@ func (h *Handler) DeleteWorkspace() http.HandlerFunc {
} }
} }
// Start transaction
tx, err := h.DB.Begin() tx, err := h.DB.Begin()
if err != nil { if err != nil {
log.Error("failed to start database transaction",
"error", err.Error(),
)
respondError(w, "Failed to start transaction", http.StatusInternalServerError) respondError(w, "Failed to start transaction", http.StatusInternalServerError)
return return
} }
defer func() { defer func() {
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone { if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
log.Error("failed to rollback transaction",
"error", err.Error(),
)
respondError(w, "Failed to rollback transaction", http.StatusInternalServerError) respondError(w, "Failed to rollback transaction", http.StatusInternalServerError)
} }
}() }()
@@ -280,6 +361,10 @@ func (h *Handler) DeleteWorkspace() http.HandlerFunc {
// Update last workspace ID first // Update last workspace ID first
err = h.DB.UpdateLastWorkspaceTx(tx, ctx.UserID, nextWorkspaceID) err = h.DB.UpdateLastWorkspaceTx(tx, ctx.UserID, nextWorkspaceID)
if err != nil { if err != nil {
log.Error("failed to update last workspace reference",
"error", err.Error(),
"nextWorkspaceID", nextWorkspaceID,
)
respondError(w, "Failed to update last workspace", http.StatusInternalServerError) respondError(w, "Failed to update last workspace", http.StatusInternalServerError)
return return
} }
@@ -287,16 +372,27 @@ func (h *Handler) DeleteWorkspace() http.HandlerFunc {
// Delete the workspace // Delete the workspace
err = h.DB.DeleteWorkspaceTx(tx, ctx.Workspace.ID) err = h.DB.DeleteWorkspaceTx(tx, ctx.Workspace.ID)
if err != nil { if err != nil {
log.Error("failed to delete workspace from database",
"error", err.Error(),
)
respondError(w, "Failed to delete workspace", http.StatusInternalServerError) respondError(w, "Failed to delete workspace", http.StatusInternalServerError)
return return
} }
// Commit transaction // Commit transaction
if err = tx.Commit(); err != nil { if err = tx.Commit(); err != nil {
log.Error("failed to commit transaction",
"error", err.Error(),
)
respondError(w, "Failed to commit transaction", http.StatusInternalServerError) respondError(w, "Failed to commit transaction", http.StatusInternalServerError)
return return
} }
log.Info("workspace deleted",
"workspaceName", ctx.Workspace.Name,
"nextWorkspaceName", nextWorkspaceName,
)
// Return the next workspace ID in the response so frontend knows where to redirect // Return the next workspace ID in the response so frontend knows where to redirect
respondJSON(w, &DeleteWorkspaceResponse{NextWorkspaceName: nextWorkspaceName}) respondJSON(w, &DeleteWorkspaceResponse{NextWorkspaceName: nextWorkspaceName})
} }
@@ -318,9 +414,17 @@ func (h *Handler) GetLastWorkspaceName() http.HandlerFunc {
if !ok { if !ok {
return return
} }
log := getWorkspaceLogger().With(
"handler", "GetLastWorkspaceName",
"userID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
workspaceName, err := h.DB.GetLastWorkspaceName(ctx.UserID) workspaceName, err := h.DB.GetLastWorkspaceName(ctx.UserID)
if err != nil { if err != nil {
log.Error("failed to fetch last workspace name",
"error", err.Error(),
)
respondError(w, "Failed to get last workspace", http.StatusInternalServerError) respondError(w, "Failed to get last workspace", http.StatusInternalServerError)
return return
} }
@@ -347,17 +451,29 @@ func (h *Handler) UpdateLastWorkspaceName() http.HandlerFunc {
if !ok { if !ok {
return return
} }
log := getWorkspaceLogger().With(
"handler", "UpdateLastWorkspaceName",
"userID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
var requestBody struct { var requestBody struct {
WorkspaceName string `json:"workspaceName"` WorkspaceName string `json:"workspaceName"`
} }
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
log.Debug("invalid request body received",
"error", err.Error(),
)
respondError(w, "Invalid request body", http.StatusBadRequest) respondError(w, "Invalid request body", http.StatusBadRequest)
return return
} }
if err := h.DB.UpdateLastWorkspace(ctx.UserID, requestBody.WorkspaceName); err != nil { if err := h.DB.UpdateLastWorkspace(ctx.UserID, requestBody.WorkspaceName); err != nil {
log.Error("failed to update last workspace",
"error", err.Error(),
"workspaceName", requestBody.WorkspaceName,
)
respondError(w, "Failed to update last workspace", http.StatusInternalServerError) respondError(w, "Failed to update last workspace", http.StatusInternalServerError)
return return
} }

View File

@@ -0,0 +1,116 @@
// Package logging provides a simple logging interface for the server.
package logging
import (
"log/slog"
"os"
)
// Logger represents the interface for logging operations
type Logger interface {
Debug(msg string, args ...any)
Info(msg string, args ...any)
Warn(msg string, args ...any)
Error(msg string, args ...any)
WithGroup(name string) Logger
With(args ...any) Logger
}
// Implementation of the Logger interface using slog
type logger struct {
logger *slog.Logger
}
// Logger is the global logger instance
var defaultLogger Logger
// LogLevel represents the log level
type LogLevel slog.Level
// Log levels
const (
DEBUG LogLevel = LogLevel(slog.LevelDebug)
INFO LogLevel = LogLevel(slog.LevelInfo)
WARN LogLevel = LogLevel(slog.LevelWarn)
ERROR LogLevel = LogLevel(slog.LevelError)
)
// Setup initializes the logger with the given minimum log level
func Setup(minLevel LogLevel) {
opts := &slog.HandlerOptions{
Level: slog.Level(minLevel),
}
defaultLogger = &logger{
logger: slog.New(slog.NewTextHandler(os.Stdout, opts)),
}
}
// ParseLogLevel converts a string to a LogLevel
func ParseLogLevel(level string) LogLevel {
switch level {
case "debug":
return DEBUG
case "warn":
return WARN
case "error":
return ERROR
default:
return INFO
}
}
// Implementation of Logger interface methods
func (l *logger) Debug(msg string, args ...any) {
l.logger.Debug(msg, args...)
}
func (l *logger) Info(msg string, args ...any) {
l.logger.Info(msg, args...)
}
func (l *logger) Warn(msg string, args ...any) {
l.logger.Warn(msg, args...)
}
func (l *logger) Error(msg string, args ...any) {
l.logger.Error(msg, args...)
}
func (l *logger) WithGroup(name string) Logger {
return &logger{logger: l.logger.WithGroup(name)}
}
func (l *logger) With(args ...any) Logger {
return &logger{logger: l.logger.With(args...)}
}
// Debug logs a debug message
func Debug(msg string, args ...any) {
defaultLogger.Debug(msg, args...)
}
// Info logs an info message
func Info(msg string, args ...any) {
defaultLogger.Info(msg, args...)
}
// Warn logs a warning message
func Warn(msg string, args ...any) {
defaultLogger.Warn(msg, args...)
}
// Error logs an error message
func Error(msg string, args ...any) {
defaultLogger.Error(msg, args...)
}
// WithGroup adds a group to the logger context
func WithGroup(name string) Logger {
return defaultLogger.WithGroup(name)
}
// With adds key-value pairs to the logger context
func With(args ...any) Logger {
return defaultLogger.With(args...)
}

View File

@@ -8,6 +8,8 @@ import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"io" "io"
"novamd/internal/logging"
) )
// Service is an interface for encrypting and decrypting strings // Service is an interface for encrypting and decrypting strings
@@ -20,6 +22,15 @@ type encryptor struct {
gcm cipher.AEAD gcm cipher.AEAD
} }
var logger logging.Logger
func getLogger() logging.Logger {
if logger == nil {
logger = logging.WithGroup("secrets")
}
return logger
}
// ValidateKey checks if the provided base64-encoded key is suitable for AES-256 // ValidateKey checks if the provided base64-encoded key is suitable for AES-256
func ValidateKey(key string) error { func ValidateKey(key string) error {
_, err := decodeAndValidateKey(key) _, err := decodeAndValidateKey(key)
@@ -73,7 +84,10 @@ func NewService(key string) (Service, error) {
// Encrypt encrypts the plaintext using AES-256-GCM // Encrypt encrypts the plaintext using AES-256-GCM
func (e *encryptor) Encrypt(plaintext string) (string, error) { func (e *encryptor) Encrypt(plaintext string) (string, error) {
log := getLogger()
if plaintext == "" { if plaintext == "" {
log.Debug("empty plaintext provided, returning empty string")
return "", nil return "", nil
} }
@@ -83,12 +97,18 @@ func (e *encryptor) Encrypt(plaintext string) (string, error) {
} }
ciphertext := e.gcm.Seal(nonce, nonce, []byte(plaintext), nil) ciphertext := e.gcm.Seal(nonce, nonce, []byte(plaintext), nil)
return base64.StdEncoding.EncodeToString(ciphertext), nil encoded := base64.StdEncoding.EncodeToString(ciphertext)
log.Debug("data encrypted", "inputLength", len(plaintext), "outputLength", len(encoded))
return encoded, nil
} }
// Decrypt decrypts the ciphertext using AES-256-GCM // Decrypt decrypts the ciphertext using AES-256-GCM
func (e *encryptor) Decrypt(ciphertext string) (string, error) { func (e *encryptor) Decrypt(ciphertext string) (string, error) {
log := getLogger()
if ciphertext == "" { if ciphertext == "" {
log.Debug("empty ciphertext provided, returning empty string")
return "", nil return "", nil
} }
@@ -108,5 +128,6 @@ func (e *encryptor) Decrypt(ciphertext string) (string, error) {
return "", err return "", err
} }
log.Debug("data decrypted", "inputLength", len(ciphertext), "outputLength", len(plaintext))
return string(plaintext), nil return string(plaintext), nil
} }

View File

@@ -6,6 +6,7 @@ import (
"testing" "testing"
"novamd/internal/secrets" "novamd/internal/secrets"
_ "novamd/internal/testenv"
) )
func TestValidateKey(t *testing.T) { func TestValidateKey(t *testing.T) {

View File

@@ -1,5 +1,4 @@
// storage/errors.go // Package storage provides functionalities to interact with the storage system (filesystem).
package storage package storage
import ( import (

View File

@@ -1,5 +1,3 @@
// Package storage provides functionalities to interact with the file system,
// including listing files, finding files by name, getting file content, saving files, and deleting files.
package storage package storage
import ( import (
@@ -33,7 +31,12 @@ type FileNode struct {
// Workspace is identified by the given userID and workspaceID. // Workspace is identified by the given userID and workspaceID.
func (s *Service) ListFilesRecursively(userID, workspaceID int) ([]FileNode, error) { func (s *Service) ListFilesRecursively(userID, workspaceID int) ([]FileNode, error) {
workspacePath := s.GetWorkspacePath(userID, workspaceID) workspacePath := s.GetWorkspacePath(userID, workspaceID)
return s.walkDirectory(workspacePath, "") nodes, err := s.walkDirectory(workspacePath, "")
if err != nil {
return nil, err
}
return nodes, nil
} }
// walkDirectory recursively walks the directory and returns a list of files and directories. // walkDirectory recursively walks the directory and returns a list of files and directories.
@@ -147,6 +150,8 @@ func (s *Service) GetFileContent(userID, workspaceID int, filePath string) ([]by
// SaveFile writes the content to the file at the given filePath. // SaveFile writes the content to the file at the given filePath.
// Path must be a relative path within the workspace directory given by userID and workspaceID. // Path must be a relative path within the workspace directory given by userID and workspaceID.
func (s *Service) SaveFile(userID, workspaceID int, filePath string, content []byte) error { func (s *Service) SaveFile(userID, workspaceID int, filePath string, content []byte) error {
log := getLogger()
fullPath, err := s.ValidatePath(userID, workspaceID, filePath) fullPath, err := s.ValidatePath(userID, workspaceID, filePath)
if err != nil { if err != nil {
return err return err
@@ -157,17 +162,36 @@ func (s *Service) SaveFile(userID, workspaceID int, filePath string, content []b
return err return err
} }
return s.fs.WriteFile(fullPath, content, 0644) if err := s.fs.WriteFile(fullPath, content, 0644); err != nil {
return err
}
log.Debug("file saved",
"userID", userID,
"workspaceID", workspaceID,
"path", filePath,
"size", len(content))
return nil
} }
// DeleteFile deletes the file at the given filePath. // DeleteFile deletes the file at the given filePath.
// Path must be a relative path within the workspace directory given by userID and workspaceID. // Path must be a relative path within the workspace directory given by userID and workspaceID.
func (s *Service) DeleteFile(userID, workspaceID int, filePath string) error { func (s *Service) DeleteFile(userID, workspaceID int, filePath string) error {
log := getLogger()
fullPath, err := s.ValidatePath(userID, workspaceID, filePath) fullPath, err := s.ValidatePath(userID, workspaceID, filePath)
if err != nil { if err != nil {
return err return err
} }
return s.fs.Remove(fullPath)
if err := s.fs.Remove(fullPath); err != nil {
return err
}
log.Debug("file deleted",
"userID", userID,
"workspaceID", workspaceID,
"path", filePath)
return nil
} }
// FileCountStats holds statistics about files in a workspace // FileCountStats holds statistics about files in a workspace
@@ -186,13 +210,22 @@ func (s *Service) GetFileStats(userID, workspaceID int) (*FileCountStats, error)
return nil, fmt.Errorf("workspace directory does not exist") return nil, fmt.Errorf("workspace directory does not exist")
} }
return s.countFilesInPath(workspacePath) stats, err := s.countFilesInPath(workspacePath)
if err != nil {
return nil, err
}
return stats, nil
} }
// GetTotalFileStats returns the total file statistics for the storage. // GetTotalFileStats returns the total file statistics for the storage.
func (s *Service) GetTotalFileStats() (*FileCountStats, error) { func (s *Service) GetTotalFileStats() (*FileCountStats, error) {
return s.countFilesInPath(s.RootDir) stats, err := s.countFilesInPath(s.RootDir)
if err != nil {
return nil, err
}
return stats, nil
} }
// countFilesInPath counts the total number of files and the total size of files in the given directory. // countFilesInPath counts the total number of files and the total size of files in the given directory.

View File

@@ -5,6 +5,8 @@ import (
"novamd/internal/storage" "novamd/internal/storage"
"path/filepath" "path/filepath"
"testing" "testing"
_ "novamd/internal/testenv"
) )
// TestFileNode ensures FileNode structs are created correctly // TestFileNode ensures FileNode structs are created correctly

View File

@@ -2,6 +2,7 @@ package storage
import ( import (
"io/fs" "io/fs"
"novamd/internal/logging"
"os" "os"
) )
@@ -17,6 +18,15 @@ type fileSystem interface {
IsNotExist(err error) bool IsNotExist(err error) bool
} }
var logger logging.Logger
func getLogger() logging.Logger {
if logger == nil {
logger = logging.WithGroup("storage")
}
return logger
}
// osFS implements the FileSystem interface using the real filesystem. // osFS implements the FileSystem interface using the real filesystem.
type osFS struct{} type osFS struct{}

View File

@@ -5,6 +5,8 @@ import (
"io/fs" "io/fs"
"path/filepath" "path/filepath"
"time" "time"
_ "novamd/internal/testenv"
) )
type mockDirEntry struct { type mockDirEntry struct {

View File

@@ -17,15 +17,23 @@ type RepositoryManager interface {
// The repository is cloned from the given gitURL using the given gitUser and gitToken. // The repository is cloned from the given gitURL using the given gitUser and gitToken.
func (s *Service) SetupGitRepo(userID, workspaceID int, gitURL, gitUser, gitToken, commitName, commitEmail string) error { func (s *Service) SetupGitRepo(userID, workspaceID int, gitURL, gitUser, gitToken, commitName, commitEmail string) error {
workspacePath := s.GetWorkspacePath(userID, workspaceID) workspacePath := s.GetWorkspacePath(userID, workspaceID)
if _, ok := s.GitRepos[userID]; !ok { if _, ok := s.GitRepos[userID]; !ok {
s.GitRepos[userID] = make(map[int]git.Client) s.GitRepos[userID] = make(map[int]git.Client)
} }
s.GitRepos[userID][workspaceID] = s.newGitClient(gitURL, gitUser, gitToken, workspacePath, commitName, commitEmail) s.GitRepos[userID][workspaceID] = s.newGitClient(gitURL, gitUser, gitToken, workspacePath, commitName, commitEmail)
return s.GitRepos[userID][workspaceID].EnsureRepo() return s.GitRepos[userID][workspaceID].EnsureRepo()
} }
// DisableGitRepo disables the Git repository for the given userID and workspaceID. // DisableGitRepo disables the Git repository for the given userID and workspaceID.
func (s *Service) DisableGitRepo(userID, workspaceID int) { func (s *Service) DisableGitRepo(userID, workspaceID int) {
log := getLogger().WithGroup("git")
log.Debug("disabling git repository",
"userID", userID,
"workspaceID", workspaceID)
if userRepos, ok := s.GitRepos[userID]; ok { if userRepos, ok := s.GitRepos[userID]; ok {
delete(userRepos, workspaceID) delete(userRepos, workspaceID)
if len(userRepos) == 0 { if len(userRepos) == 0 {
@@ -47,8 +55,11 @@ func (s *Service) StageCommitAndPush(userID, workspaceID int, message string) (g
return git.CommitHash{}, err return git.CommitHash{}, err
} }
err = repo.Push() if err = repo.Push(); err != nil {
return hash, err return hash, err
}
return hash, nil
} }
// Pull pulls the changes from the remote Git repository. // Pull pulls the changes from the remote Git repository.
@@ -59,7 +70,12 @@ func (s *Service) Pull(userID, workspaceID int) error {
return fmt.Errorf("git settings not configured for this workspace") return fmt.Errorf("git settings not configured for this workspace")
} }
return repo.Pull() err := repo.Pull()
if err != nil {
return err
}
return nil
} }
// getGitRepo returns the Git repository for the given user and workspace IDs. // getGitRepo returns the Git repository for the given user and workspace IDs.

View File

@@ -6,6 +6,7 @@ import (
"novamd/internal/git" "novamd/internal/git"
"novamd/internal/storage" "novamd/internal/storage"
_ "novamd/internal/testenv"
) )
// MockGitClient implements git.Client interface for testing // MockGitClient implements git.Client interface for testing

View File

@@ -43,6 +43,11 @@ func (s *Service) GetWorkspacePath(userID, workspaceID int) string {
// InitializeUserWorkspace creates the workspace directory for the given userID and workspaceID. // InitializeUserWorkspace creates the workspace directory for the given userID and workspaceID.
func (s *Service) InitializeUserWorkspace(userID, workspaceID int) error { func (s *Service) InitializeUserWorkspace(userID, workspaceID int) error {
log := getLogger()
log.Debug("initializing workspace directory",
"userID", userID,
"workspaceID", workspaceID)
workspacePath := s.GetWorkspacePath(userID, workspaceID) workspacePath := s.GetWorkspacePath(userID, workspaceID)
err := s.fs.MkdirAll(workspacePath, 0755) err := s.fs.MkdirAll(workspacePath, 0755)
if err != nil { if err != nil {
@@ -54,6 +59,11 @@ func (s *Service) InitializeUserWorkspace(userID, workspaceID int) error {
// DeleteUserWorkspace deletes the workspace directory for the given userID and workspaceID. // DeleteUserWorkspace deletes the workspace directory for the given userID and workspaceID.
func (s *Service) DeleteUserWorkspace(userID, workspaceID int) error { func (s *Service) DeleteUserWorkspace(userID, workspaceID int) error {
log := getLogger()
log.Debug("deleting workspace directory",
"userID", userID,
"workspaceID", workspaceID)
workspacePath := s.GetWorkspacePath(userID, workspaceID) workspacePath := s.GetWorkspacePath(userID, workspaceID)
err := s.fs.RemoveAll(workspacePath) err := s.fs.RemoveAll(workspacePath)
if err != nil { if err != nil {

View File

@@ -7,6 +7,7 @@ import (
"testing" "testing"
"novamd/internal/storage" "novamd/internal/storage"
_ "novamd/internal/testenv"
) )
func TestValidatePath(t *testing.T) { func TestValidatePath(t *testing.T) {

View File

@@ -0,0 +1,9 @@
// Package testenv provides a setup for testing the application.
package testenv
import "novamd/internal/logging"
func init() {
// Initialize the logger
logging.Setup(logging.ERROR)
}