Merge pull request #28 from lordmathis/feat/logging

Implement structured logging
This commit is contained in:
2024-12-19 23:45:28 +01:00
committed by GitHub
54 changed files with 1604 additions and 305 deletions

1
.gitignore vendored
View File

@@ -157,6 +157,7 @@ go.work.sum
# env file
.env
.env.dev
main
*.db

14
.vscode/launch.json vendored Normal file
View File

@@ -0,0 +1,14 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "Launch NovaMD Server",
"type": "go",
"request": "launch",
"mode": "auto",
"program": "${workspaceFolder}/server/cmd/server/main.go",
"cwd": "${workspaceFolder}",
"envFile": "${workspaceFolder}/server/.env"
}
]
}

View File

@@ -5,6 +5,7 @@ import (
"log"
"novamd/internal/app"
"novamd/internal/logging"
)
// @title NovaMD API
@@ -23,6 +24,10 @@ func main() {
log.Fatal("Failed to load configuration:", err)
}
// Setup logging
logging.Setup(cfg.LogLevel)
logging.Debug("Configuration loaded", "config", cfg.Redact())
// Initialize and start server
options, err := app.DefaultOptions(cfg)
if err != nil {
@@ -32,7 +37,7 @@ func main() {
server := app.NewServer(options)
defer func() {
if err := server.Close(); err != nil {
log.Println("Error closing server:", err)
logging.Error("Failed to close server:", err)
}
}()

View File

@@ -2,6 +2,7 @@ package app
import (
"fmt"
"novamd/internal/logging"
"novamd/internal/secrets"
"os"
"strconv"
@@ -25,6 +26,7 @@ type Config struct {
RateLimitRequests int
RateLimitWindow time.Duration
IsDevelopment bool
LogLevel logging.LogLevel
}
// DefaultConfig returns a new Config instance with default values
@@ -54,6 +56,16 @@ func (c *Config) validate() error {
return nil
}
// Redact redacts sensitive fields from a Config instance
func (c *Config) Redact() *Config {
redacted := *c
redacted.AdminPassword = "[REDACTED]"
redacted.AdminEmail = "[REDACTED]"
redacted.EncryptionKey = "[REDACTED]"
redacted.JWTSigningKey = "[REDACTED]"
return &redacted
}
// LoadConfig creates a new Config instance with values from environment variables
func LoadConfig() (*Config, error) {
config := DefaultConfig()
@@ -97,17 +109,29 @@ func LoadConfig() (*Config, error) {
// Configure rate limiting
if reqStr := os.Getenv("NOVAMD_RATE_LIMIT_REQUESTS"); reqStr != "" {
if parsed, err := strconv.Atoi(reqStr); err == nil {
parsed, err := strconv.Atoi(reqStr)
if err == nil {
config.RateLimitRequests = parsed
}
}
if windowStr := os.Getenv("NOVAMD_RATE_LIMIT_WINDOW"); windowStr != "" {
if parsed, err := time.ParseDuration(windowStr); err == nil {
parsed, err := time.ParseDuration(windowStr)
if err == nil {
config.RateLimitWindow = parsed
}
}
// Configure log level, if isDevelopment is set, default to debug
if logLevel := os.Getenv("NOVAMD_LOG_LEVEL"); logLevel != "" {
parsed := logging.ParseLogLevel(logLevel)
config.LogLevel = parsed
} else if config.IsDevelopment {
config.LogLevel = logging.DEBUG
} else {
config.LogLevel = logging.INFO
}
// Validate all settings
if err := config.validate(); err != nil {
return nil, err

View File

@@ -5,6 +5,8 @@ import (
"os"
"testing"
"time"
_ "novamd/internal/testenv"
)
func TestDefaultConfig(t *testing.T) {

View File

@@ -4,13 +4,13 @@ package app
import (
"database/sql"
"fmt"
"log"
"time"
"golang.org/x/crypto/bcrypt"
"novamd/internal/auth"
"novamd/internal/db"
"novamd/internal/logging"
"novamd/internal/models"
"novamd/internal/secrets"
"novamd/internal/storage"
@@ -18,6 +18,7 @@ import (
// initSecretsService initializes the secrets service
func initSecretsService(cfg *Config) (secrets.Service, error) {
logging.Debug("initializing secrets service")
secretsService, err := secrets.NewService(cfg.EncryptionKey)
if err != nil {
return nil, fmt.Errorf("failed to initialize secrets service: %w", err)
@@ -27,6 +28,8 @@ func initSecretsService(cfg *Config) (secrets.Service, error) {
// initDatabase initializes and migrates the database
func initDatabase(cfg *Config, secretsService secrets.Service) (db.Database, error) {
logging.Debug("initializing database", "path", cfg.DBPath)
database, err := db.Init(cfg.DBPath, secretsService)
if err != nil {
return nil, fmt.Errorf("failed to initialize database: %w", err)
@@ -41,9 +44,15 @@ func initDatabase(cfg *Config, secretsService secrets.Service) (db.Database, err
// initAuth initializes JWT and session services
func initAuth(cfg *Config, database db.Database) (auth.JWTManager, auth.SessionManager, auth.CookieManager, error) {
logging.Debug("initializing authentication services")
accessTokeExpiry := 15 * time.Minute
refreshTokenExpiry := 7 * 24 * time.Hour
// Get or generate JWT signing key
signingKey := cfg.JWTSigningKey
if signingKey == "" {
logging.Debug("no JWT signing key provided, generating new key")
var err error
signingKey, err = database.EnsureJWTSecret()
if err != nil {
@@ -51,20 +60,16 @@ func initAuth(cfg *Config, database db.Database) (auth.JWTManager, auth.SessionM
}
}
// Initialize JWT service
jwtManager, err := auth.NewJWTService(auth.JWTConfig{
SigningKey: signingKey,
AccessTokenExpiry: 15 * time.Minute,
RefreshTokenExpiry: 7 * 24 * time.Hour,
AccessTokenExpiry: accessTokeExpiry,
RefreshTokenExpiry: refreshTokenExpiry,
})
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to initialize JWT service: %w", err)
}
// Initialize session service
sessionManager := auth.NewSessionService(database, jwtManager)
// Cookie service
cookieService := auth.NewCookieService(cfg.IsDevelopment, cfg.Domain)
return jwtManager, sessionManager, cookieService, nil
@@ -72,26 +77,26 @@ func initAuth(cfg *Config, database db.Database) (auth.JWTManager, auth.SessionM
// setupAdminUser creates the admin user if it doesn't exist
func setupAdminUser(database db.Database, storageManager storage.Manager, cfg *Config) error {
adminEmail := cfg.AdminEmail
adminPassword := cfg.AdminPassword
// Check if admin user exists
adminUser, err := database.GetUserByEmail(adminEmail)
adminUser, err := database.GetUserByEmail(cfg.AdminEmail)
if err != nil && err != sql.ErrNoRows {
return fmt.Errorf("failed to check for existing admin user: %w", err)
}
if adminUser != nil {
return nil // Admin user already exists
} else if err != sql.ErrNoRows {
return err
logging.Debug("admin user already exists", "userId", adminUser.ID)
return nil
}
// Hash the password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(adminPassword), bcrypt.DefaultCost)
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(cfg.AdminPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash password: %w", err)
return fmt.Errorf("failed to hash admin password: %w", err)
}
// Create admin user
adminUser = &models.User{
Email: adminEmail,
Email: cfg.AdminEmail,
DisplayName: "Admin",
PasswordHash: string(hashedPassword),
Role: models.RoleAdmin,
@@ -102,13 +107,14 @@ func setupAdminUser(database db.Database, storageManager storage.Manager, cfg *C
return fmt.Errorf("failed to create admin user: %w", err)
}
// Initialize workspace directory
err = storageManager.InitializeUserWorkspace(createdUser.ID, createdUser.LastWorkspaceID)
if err != nil {
return fmt.Errorf("failed to initialize admin workspace: %w", err)
}
log.Printf("Created admin user with ID: %d and default workspace with ID: %d", createdUser.ID, createdUser.LastWorkspaceID)
logging.Info("admin user setup completed",
"userId", createdUser.ID,
"workspaceId", createdUser.LastWorkspaceID)
return nil
}

View File

@@ -3,6 +3,7 @@ package app
import (
"novamd/internal/auth"
"novamd/internal/db"
"novamd/internal/logging"
"novamd/internal/storage"
)
@@ -33,6 +34,9 @@ func DefaultOptions(cfg *Config) (*Options, error) {
// Initialize storage
storageManager := storage.NewService(cfg.WorkDir)
// Initialize logger
logging.Setup(cfg.LogLevel)
// Initialize auth services
jwtManager, sessionService, cookieService, err := initAuth(cfg, database)
if err != nil {

View File

@@ -4,6 +4,7 @@ import (
"novamd/internal/auth"
"novamd/internal/context"
"novamd/internal/handlers"
"novamd/internal/logging"
"time"
"github.com/go-chi/chi/v5"
@@ -19,6 +20,7 @@ import (
// setupRouter creates and configures the chi router with middleware and routes
func setupRouter(o Options) *chi.Mux {
logging.Debug("setting up router")
r := chi.NewRouter()
// Basic middleware

View File

@@ -1,8 +1,8 @@
package app
import (
"log"
"net/http"
"novamd/internal/logging"
"github.com/go-chi/chi/v5"
)
@@ -25,12 +25,13 @@ func NewServer(options *Options) *Server {
func (s *Server) Start() error {
// Start server
addr := ":" + s.options.Config.Port
log.Printf("Server starting on port %s", s.options.Config.Port)
logging.Info("starting server", "address", addr)
return http.ListenAndServe(addr, s.router)
}
// Close handles graceful shutdown of server dependencies
func (s *Server) Close() error {
logging.Info("shutting down server")
return s.options.Database.Close()
}

View File

@@ -3,8 +3,22 @@ package auth
import (
"net/http"
"novamd/internal/logging"
)
var logger logging.Logger
func getAuthLogger() logging.Logger {
if logger == nil {
logger = logging.WithGroup("auth")
}
return logger
}
func getCookieLogger() logging.Logger {
return getAuthLogger().WithGroup("cookie")
}
// CookieManager interface defines methods for generating cookies
type CookieManager interface {
GenerateAccessTokenCookie(token string) *http.Cookie
@@ -22,6 +36,8 @@ type cookieManager struct {
// NewCookieService creates a new cookie service
func NewCookieService(isDevelopment bool, domain string) CookieManager {
log := getCookieLogger()
secure := !isDevelopment
var sameSite http.SameSite
@@ -31,6 +47,11 @@ func NewCookieService(isDevelopment bool, domain string) CookieManager {
sameSite = http.SameSiteStrictMode
}
log.Debug("creating cookie service",
"secure", secure,
"sameSite", sameSite,
"domain", domain)
return &cookieManager{
Domain: domain,
Secure: secure,
@@ -40,6 +61,12 @@ func NewCookieService(isDevelopment bool, domain string) CookieManager {
// GenerateAccessTokenCookie creates a new cookie for the access token
func (c *cookieManager) GenerateAccessTokenCookie(token string) *http.Cookie {
log := getCookieLogger()
log.Debug("generating access token cookie",
"secure", c.Secure,
"sameSite", c.SameSite,
"maxAge", 900)
return &http.Cookie{
Name: "access_token",
Value: token,
@@ -53,6 +80,12 @@ func (c *cookieManager) GenerateAccessTokenCookie(token string) *http.Cookie {
// GenerateRefreshTokenCookie creates a new cookie for the refresh token
func (c *cookieManager) GenerateRefreshTokenCookie(token string) *http.Cookie {
log := getCookieLogger()
log.Debug("generating refresh token cookie",
"secure", c.Secure,
"sameSite", c.SameSite,
"maxAge", 604800)
return &http.Cookie{
Name: "refresh_token",
Value: token,
@@ -66,6 +99,13 @@ func (c *cookieManager) GenerateRefreshTokenCookie(token string) *http.Cookie {
// GenerateCSRFCookie creates a new cookie for the CSRF token
func (c *cookieManager) GenerateCSRFCookie(token string) *http.Cookie {
log := getCookieLogger()
log.Debug("generating CSRF cookie",
"secure", c.Secure,
"sameSite", c.SameSite,
"maxAge", 900,
"httpOnly", false)
return &http.Cookie{
Name: "csrf_token",
Value: token,
@@ -79,6 +119,12 @@ func (c *cookieManager) GenerateCSRFCookie(token string) *http.Cookie {
// InvalidateCookie creates a new cookie with a MaxAge of -1 to invalidate the cookie
func (c *cookieManager) InvalidateCookie(cookieType string) *http.Cookie {
log := getCookieLogger()
log.Debug("invalidating cookie",
"type", cookieType,
"secure", c.Secure,
"sameSite", c.SameSite)
return &http.Cookie{
Name: cookieType,
Value: "",

View File

@@ -4,11 +4,16 @@ package auth
import (
"crypto/rand"
"fmt"
"novamd/internal/logging"
"time"
"github.com/golang-jwt/jwt/v5"
)
func getJWTLogger() logging.Logger {
return getAuthLogger().WithGroup("jwt")
}
// TokenType represents the type of JWT token (access or refresh)
type TokenType string
@@ -50,13 +55,15 @@ func NewJWTService(config JWTConfig) (JWTManager, error) {
if config.SigningKey == "" {
return nil, fmt.Errorf("signing key is required")
}
// Set default expiry times if not provided
if config.AccessTokenExpiry == 0 {
config.AccessTokenExpiry = 15 * time.Minute // Default to 15 minutes
config.AccessTokenExpiry = 15 * time.Minute
}
if config.RefreshTokenExpiry == 0 {
config.RefreshTokenExpiry = 7 * 24 * time.Hour // Default to 7 days
config.RefreshTokenExpiry = 7 * 24 * time.Hour
}
return &jwtService{config: config}, nil
}
@@ -93,11 +100,18 @@ func (s *jwtService) generateToken(userID int, role string, sessionID string, to
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(s.config.SigningKey))
signedToken, err := token.SignedString([]byte(s.config.SigningKey))
if err != nil {
return "", err
}
return signedToken, nil
}
// ValidateToken validates and parses a JWT token
func (s *jwtService) ValidateToken(tokenString string) (*Claims, error) {
log := getJWTLogger()
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
// Validate the signing method
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
@@ -110,9 +124,16 @@ func (s *jwtService) ValidateToken(tokenString string) (*Claims, error) {
return nil, fmt.Errorf("invalid token: %w", err)
}
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
return claims, nil
claims, ok := token.Claims.(*Claims)
if !ok || !token.Valid {
return nil, fmt.Errorf("invalid token claims")
}
return nil, fmt.Errorf("invalid token claims")
log.Debug("token validated",
"userId", claims.UserID,
"role", claims.Role,
"tokenType", claims.Type,
"expiresAt", claims.ExpiresAt)
return claims, nil
}

View File

@@ -6,10 +6,9 @@ import (
"time"
"novamd/internal/auth"
_ "novamd/internal/testenv"
)
// jwt_test.go tests
func TestNewJWTService(t *testing.T) {
testCases := []struct {
name string

View File

@@ -3,10 +3,14 @@ package auth
import (
"crypto/subtle"
"net/http"
"novamd/internal/context"
"novamd/internal/logging"
)
func getMiddlewareLogger() logging.Logger {
return getAuthLogger().WithGroup("middleware")
}
// Middleware handles JWT authentication for protected routes
type Middleware struct {
jwtManager JWTManager
@@ -26,9 +30,15 @@ func NewMiddleware(jwtManager JWTManager, sessionManager SessionManager, cookieM
// Authenticate middleware validates JWT tokens and sets user information in context
func (m *Middleware) Authenticate(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Extract token from Authorization header
log := getMiddlewareLogger().With(
"handler", "Authenticate",
"clientIP", r.RemoteAddr,
)
// Extract token from cookie
cookie, err := r.Cookie("access_token")
if err != nil {
log.Warn("attempt to access protected route without token")
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
@@ -36,12 +46,14 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
// Validate token
claims, err := m.jwtManager.ValidateToken(cookie.Value)
if err != nil {
log.Warn("attempt to access protected route with invalid token", "error", err.Error())
http.Error(w, "Invalid token", http.StatusUnauthorized)
return
}
// Check token type
if claims.Type != AccessToken {
log.Warn("attempt to access protected route with invalid token type", "type", claims.Type)
http.Error(w, "Invalid token type", http.StatusUnauthorized)
return
}
@@ -49,6 +61,7 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
// Check if session is still valid in database
session, err := m.sessionManager.ValidateSession(claims.ID)
if err != nil || session == nil {
log.Warn("attempt to access protected route with invalid session", "error", err.Error())
m.cookieManager.InvalidateCookie("access_token")
m.cookieManager.InvalidateCookie("refresh_token")
m.cookieManager.InvalidateCookie("csrf_token")
@@ -60,17 +73,20 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
if r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodOptions {
csrfCookie, err := r.Cookie("csrf_token")
if err != nil {
log.Warn("attempt to access protected route without CSRF token", "error", err.Error())
http.Error(w, "CSRF cookie not found", http.StatusForbidden)
return
}
csrfHeader := r.Header.Get("X-CSRF-Token")
if csrfHeader == "" {
log.Warn("attempt to access protected route without CSRF header")
http.Error(w, "CSRF token header not found", http.StatusForbidden)
return
}
if subtle.ConstantTimeCompare([]byte(csrfCookie.Value), []byte(csrfHeader)) != 1 {
log.Warn("attempt to access protected route with invalid CSRF token")
http.Error(w, "CSRF token mismatch", http.StatusForbidden)
return
}
@@ -91,12 +107,19 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log := getMiddlewareLogger().With(
"handler", "RequireRole",
"requiredRole", role,
"clientIP", r.RemoteAddr,
)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
if ctx.UserRole != role && ctx.UserRole != "admin" {
log.Warn("attempt to access protected route without required role")
http.Error(w, "Insufficient permissions", http.StatusForbidden)
return
}
@@ -114,7 +137,13 @@ func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler {
return
}
// If no workspace in context, allow the request (might be a non-workspace endpoint)
log := getMiddlewareLogger().With(
"handler", "RequireWorkspaceAccess",
"clientIP", r.RemoteAddr,
"userId", ctx.UserID,
)
// If no workspace in context, allow the request
if ctx.Workspace == nil {
next.ServeHTTP(w, r)
return
@@ -122,6 +151,7 @@ func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler {
// Check if user has access (either owner or admin)
if ctx.Workspace.UserID != ctx.UserID && ctx.UserRole != "admin" {
log.Warn("attempt to access workspace without permission")
http.Error(w, "Not Found", http.StatusNotFound)
return
}

View File

@@ -11,6 +11,7 @@ import (
"novamd/internal/auth"
"novamd/internal/context"
"novamd/internal/models"
_ "novamd/internal/testenv"
)
// Mock SessionManager

View File

@@ -3,12 +3,17 @@ package auth
import (
"fmt"
"novamd/internal/db"
"novamd/internal/logging"
"novamd/internal/models"
"time"
"github.com/google/uuid"
)
func getSessionLogger() logging.Logger {
return getAuthLogger().WithGroup("session")
}
// SessionManager is an interface for managing user sessions
type SessionManager interface {
CreateSession(userID int, role string) (*models.Session, string, error)
@@ -35,6 +40,7 @@ func NewSessionService(db db.SessionStore, jwtManager JWTManager) *sessionManage
// CreateSession creates a new user session for a user with the given userID and role
func (s *sessionManager) CreateSession(userID int, role string) (*models.Session, string, error) {
log := getSessionLogger()
// Generate a new session ID
sessionID := uuid.New().String()
@@ -70,12 +76,18 @@ func (s *sessionManager) CreateSession(userID int, role string) (*models.Session
return nil, "", err
}
log.Debug("created new session",
"userId", userID,
"role", role,
"sessionId", sessionID,
"expiresAt", claims.ExpiresAt.Time)
return session, accessToken, nil
}
// RefreshSession creates a new access token using a refreshToken
func (s *sessionManager) RefreshSession(refreshToken string) (string, error) {
// Get session from database first
// Get session from database
session, err := s.db.GetSessionByRefreshToken(refreshToken)
if err != nil {
return "", fmt.Errorf("invalid session: %w", err)
@@ -87,17 +99,22 @@ func (s *sessionManager) RefreshSession(refreshToken string) (string, error) {
return "", fmt.Errorf("invalid refresh token: %w", err)
}
// Double check that the claims match the session
if claims.UserID != session.UserID {
return "", fmt.Errorf("token does not match session")
}
// Generate a new access token
return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role, session.ID)
newToken, err := s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role, session.ID)
if err != nil {
return "", err
}
return newToken, nil
}
// ValidateSession checks if a session with the given sessionID is valid
func (s *sessionManager) ValidateSession(sessionID string) (*models.Session, error) {
log := getSessionLogger()
// Get the session from the database
session, err := s.db.GetSessionByID(sessionID)
@@ -105,21 +122,43 @@ func (s *sessionManager) ValidateSession(sessionID string) (*models.Session, err
return nil, fmt.Errorf("failed to get session: %w", err)
}
log.Debug("validated session",
"sessionId", sessionID,
"userId", session.UserID,
"expiresAt", session.ExpiresAt)
return session, nil
}
// InvalidateSession removes a session with the given sessionID from the database
func (s *sessionManager) InvalidateSession(token string) error {
log := getSessionLogger()
// Parse the JWT to get the session info
claims, err := s.jwtManager.ValidateToken(token)
if err != nil {
return fmt.Errorf("invalid token: %w", err)
}
return s.db.DeleteSession(claims.ID)
if err := s.db.DeleteSession(claims.ID); err != nil {
return err
}
log.Debug("invalidated session",
"sessionId", claims.ID,
"userId", claims.UserID)
return nil
}
// CleanExpiredSessions removes all expired sessions from the database
func (s *sessionManager) CleanExpiredSessions() error {
return s.db.CleanExpiredSessions()
log := getSessionLogger()
if err := s.db.CleanExpiredSessions(); err != nil {
return err
}
log.Info("cleaned expired sessions")
return nil
}

View File

@@ -8,6 +8,7 @@ import (
"novamd/internal/auth"
"novamd/internal/models"
_ "novamd/internal/testenv"
)
// Mock SessionStore

View File

@@ -5,6 +5,7 @@ import (
"context"
"fmt"
"net/http"
"novamd/internal/logging"
"novamd/internal/models"
)
@@ -28,10 +29,22 @@ type HandlerContext struct {
Workspace *models.Workspace // Optional, only set for workspace routes
}
var logger logging.Logger
func getLogger() logging.Logger {
if logger == nil {
logger = logging.WithGroup("context")
}
return logger
}
// GetRequestContext retrieves the handler context from the request
func GetRequestContext(w http.ResponseWriter, r *http.Request) (*HandlerContext, bool) {
ctx := r.Context().Value(HandlerContextKey)
if ctx == nil {
getLogger().Error("missing handler context in request",
"path", r.URL.Path,
"method", r.Method)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return nil, false
}

View File

@@ -7,6 +7,7 @@ import (
"testing"
"novamd/internal/context"
_ "novamd/internal/testenv"
)
func TestGetRequestContext(t *testing.T) {

View File

@@ -10,9 +10,13 @@ import (
// WithUserContextMiddleware extracts user information from JWT claims
// and adds it to the request context
func WithUserContextMiddleware(next http.Handler) http.Handler {
log := getLogger()
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims, err := GetUserFromContext(r.Context())
if err != nil {
log.Error("failed to get user from context",
"error", err,
"path", r.URL.Path)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
@@ -30,6 +34,7 @@ func WithUserContextMiddleware(next http.Handler) http.Handler {
// WithWorkspaceContextMiddleware adds workspace information to the request context
func WithWorkspaceContextMiddleware(db db.WorkspaceReader) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
log := getLogger()
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, ok := GetRequestContext(w, r)
if !ok {
@@ -39,11 +44,15 @@ func WithWorkspaceContextMiddleware(db db.WorkspaceReader) func(http.Handler) ht
workspaceName := chi.URLParam(r, "workspaceName")
workspace, err := db.GetWorkspaceByName(ctx.UserID, workspaceName)
if err != nil {
http.Error(w, "Workspace not found", http.StatusNotFound)
log.Error("failed to get workspace",
"error", err,
"userID", ctx.UserID,
"workspace", workspaceName,
"path", r.URL.Path)
http.Error(w, "Failed to get workspace", http.StatusNotFound)
return
}
// Update existing context with workspace
ctx.Workspace = workspace
r = WithHandlerContext(r, ctx)
next.ServeHTTP(w, r)

View File

@@ -9,6 +9,7 @@ import (
"novamd/internal/context"
"novamd/internal/models"
_ "novamd/internal/testenv"
)
// MockDB implements the minimal database interface needed for testing
@@ -89,6 +90,10 @@ func TestWithUserContextMiddleware(t *testing.T) {
}
}
type contextKey string
const workspaceNameKey contextKey = "workspaceName"
func TestWithWorkspaceContextMiddleware(t *testing.T) {
tests := []struct {
name string
@@ -158,7 +163,7 @@ func TestWithWorkspaceContextMiddleware(t *testing.T) {
}
// Add workspace name to request context via chi URL params
req = req.WithContext(stdctx.WithValue(req.Context(), "workspaceName", tt.workspaceName))
req = req.WithContext(stdctx.WithValue(req.Context(), workspaceNameKey, tt.workspaceName))
nextCalled := false
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

View File

@@ -3,7 +3,9 @@ package db
import (
"database/sql"
"fmt"
"novamd/internal/logging"
"novamd/internal/models"
"novamd/internal/secrets"
@@ -77,6 +79,7 @@ type Database interface {
Migrate() error
}
// Verify that the database implements the required interfaces
var (
// Main Database interface
_ Database = (*database)(nil)
@@ -92,6 +95,15 @@ var (
_ WorkspaceWriter = (*database)(nil)
)
var logger logging.Logger
func getLogger() logging.Logger {
if logger == nil {
logger = logging.WithGroup("db")
}
return logger
}
// database represents the database connection
type database struct {
*sql.DB
@@ -100,19 +112,22 @@ type database struct {
// Init initializes the database connection
func Init(dbPath string, secretsService secrets.Service) (Database, error) {
log := getLogger()
db, err := sql.Open("sqlite3", dbPath)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to open database: %w", err)
}
if err := db.Ping(); err != nil {
return nil, err
return nil, fmt.Errorf("failed to ping database: %w", err)
}
// Enable foreign keys for this connection
if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
return nil, err
return nil, fmt.Errorf("failed to enable foreign keys: %w", err)
}
log.Debug("foreign keys enabled")
database := &database{
DB: db,
@@ -124,7 +139,13 @@ func Init(dbPath string, secretsService secrets.Service) (Database, error) {
// Close closes the database connection
func (db *database) Close() error {
return db.DB.Close()
log := getLogger()
log.Info("closing database connection")
if err := db.DB.Close(); err != nil {
return fmt.Errorf("failed to close database: %w", err)
}
return nil
}
// Helper methods for token encryption/decryption
@@ -132,12 +153,24 @@ func (db *database) encryptToken(token string) (string, error) {
if token == "" {
return "", nil
}
return db.secretsService.Encrypt(token)
encrypted, err := db.secretsService.Encrypt(token)
if err != nil {
return "", fmt.Errorf("failed to encrypt token: %w", err)
}
return encrypted, nil
}
func (db *database) decryptToken(token string) (string, error) {
if token == "" {
return "", nil
}
return db.secretsService.Decrypt(token)
decrypted, err := db.secretsService.Decrypt(token)
if err != nil {
return "", fmt.Errorf("failed to decrypt token: %w", err)
}
return decrypted, nil
}

View File

@@ -2,7 +2,6 @@ package db
import (
"fmt"
"log"
)
// Migration represents a database migration
@@ -79,56 +78,64 @@ var migrations = []Migration{
// Migrate applies all database migrations
func (db *database) Migrate() error {
log := getLogger().WithGroup("migrations")
log.Info("starting database migration")
// Create migrations table if it doesn't exist
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS migrations (
version INTEGER PRIMARY KEY
)`)
if err != nil {
return err
return fmt.Errorf("failed to create migrations table: %w", err)
}
// Get current version
var currentVersion int
err = db.QueryRow("SELECT COALESCE(MAX(version), 0) FROM migrations").Scan(&currentVersion)
if err != nil {
return err
return fmt.Errorf("failed to get current migration version: %w", err)
}
// Apply new migrations
for _, migration := range migrations {
if migration.Version > currentVersion {
log.Printf("Applying migration %d", migration.Version)
log := log.With("migration_version", migration.Version)
tx, err := db.Begin()
if err != nil {
return err
return fmt.Errorf("failed to begin transaction for migration %d: %w", migration.Version, err)
}
// Execute migration SQL
_, err = tx.Exec(migration.SQL)
if err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
return fmt.Errorf("migration %d failed: %v, rollback failed: %v", migration.Version, err, rbErr)
return fmt.Errorf("migration %d failed: %v, rollback failed: %v",
migration.Version, err, rbErr)
}
return fmt.Errorf("migration %d failed: %v", migration.Version, err)
return fmt.Errorf("migration %d failed: %w", migration.Version, err)
}
// Update migrations table
_, err = tx.Exec("INSERT INTO migrations (version) VALUES (?)", migration.Version)
if err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
return fmt.Errorf("failed to update migration version: %v, rollback failed: %v", err, rbErr)
return fmt.Errorf("failed to update migration version: %v, rollback failed: %v",
err, rbErr)
}
return fmt.Errorf("failed to update migration version: %v", err)
return fmt.Errorf("failed to update migration version: %w", err)
}
// Commit transaction
err = tx.Commit()
if err != nil {
return fmt.Errorf("failed to commit migration %d: %v", migration.Version, err)
return fmt.Errorf("failed to commit migration %d: %w", migration.Version, err)
}
currentVersion = migration.Version
log.Debug("migration applied", "new_version", currentVersion)
}
}
log.Printf("Database is at version %d", currentVersion)
log.Info("database migration completed", "final_version", currentVersion)
return nil
}

View File

@@ -5,6 +5,8 @@ import (
"novamd/internal/db"
_ "novamd/internal/testenv"
_ "github.com/mattn/go-sqlite3"
)

View File

@@ -18,6 +18,7 @@ func (db *database) CreateSession(session *models.Session) error {
if err != nil {
return fmt.Errorf("failed to store session: %w", err)
}
return nil
}
@@ -82,9 +83,17 @@ func (db *database) DeleteSession(sessionID string) error {
// CleanExpiredSessions removes all expired sessions from the database
func (db *database) CleanExpiredSessions() error {
_, err := db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now())
log := getLogger().WithGroup("sessions")
result, err := db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now())
if err != nil {
return fmt.Errorf("failed to clean expired sessions: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
log.Info("cleaned expired sessions", "sessions_removed", rowsAffected)
return nil
}

View File

@@ -7,6 +7,7 @@ import (
"novamd/internal/db"
"novamd/internal/models"
_ "novamd/internal/testenv"
"github.com/google/uuid"
)

View File

@@ -21,6 +21,8 @@ type UserStats struct {
// EnsureJWTSecret makes sure a JWT signing secret exists in the database
// If no secret exists, it generates and stores a new one
func (db *database) EnsureJWTSecret() (string, error) {
log := getLogger().WithGroup("system")
// First, try to get existing secret
secret, err := db.GetSystemSetting(JWTSecretKey)
if err == nil {
@@ -39,6 +41,8 @@ func (db *database) EnsureJWTSecret() (string, error) {
return "", fmt.Errorf("failed to store JWT secret: %w", err)
}
log.Info("new JWT secret generated and stored")
return newSecret, nil
}
@@ -49,6 +53,7 @@ func (db *database) GetSystemSetting(key string) (string, error) {
if err != nil {
return "", err
}
return value, nil
}
@@ -59,17 +64,27 @@ func (db *database) SetSystemSetting(key, value string) error {
VALUES (?, ?)
ON CONFLICT(key) DO UPDATE SET value = ?`,
key, value, value)
return err
if err != nil {
return fmt.Errorf("failed to store system setting: %w", err)
}
return nil
}
// generateRandomSecret generates a cryptographically secure random string
func generateRandomSecret(bytes int) (string, error) {
log := getLogger().WithGroup("system")
log.Debug("generating random secret", "bytes", bytes)
b := make([]byte, bytes)
_, err := rand.Read(b)
if err != nil {
return "", err
return "", fmt.Errorf("failed to generate random bytes: %w", err)
}
return base64.StdEncoding.EncodeToString(b), nil
secret := base64.StdEncoding.EncodeToString(b)
return secret, nil
}
// GetSystemStats returns system-wide statistics
@@ -79,13 +94,13 @@ func (db *database) GetSystemStats() (*UserStats, error) {
// Get total users
err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&stats.TotalUsers)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get total users count: %w", err)
}
// Get total workspaces
err = db.QueryRow("SELECT COUNT(*) FROM workspaces").Scan(&stats.TotalWorkspaces)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get total workspaces count: %w", err)
}
// Get active users (users with activity in last 30 days)
@@ -95,8 +110,7 @@ func (db *database) GetSystemStats() (*UserStats, error) {
WHERE created_at > datetime('now', '-30 days')`).
Scan(&stats.ActiveUsers)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get active users count: %w", err)
}
return stats, nil
}

View File

@@ -9,6 +9,7 @@ import (
"novamd/internal/db"
"novamd/internal/models"
_ "novamd/internal/testenv"
"github.com/google/uuid"
)

View File

@@ -2,14 +2,18 @@ package db
import (
"database/sql"
"fmt"
"novamd/internal/models"
)
// CreateUser inserts a new user record into the database
func (db *database) CreateUser(user *models.User) (*models.User, error) {
log := getLogger().WithGroup("users")
log.Debug("creating user", "email", user.Email)
tx, err := db.Begin()
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to begin transaction: %w", err)
}
defer tx.Rollback()
@@ -18,19 +22,19 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) {
VALUES (?, ?, ?, ?)`,
user.Email, user.DisplayName, user.PasswordHash, user.Role)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to insert user: %w", err)
}
userID, err := result.LastInsertId()
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get last insert ID: %w", err)
}
user.ID = int(userID)
// Retrieve the created_at timestamp
err = tx.QueryRow("SELECT created_at FROM users WHERE id = ?", user.ID).Scan(&user.CreatedAt)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get created timestamp: %w", err)
}
// Create default workspace with default settings
@@ -38,31 +42,34 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) {
UserID: user.ID,
Name: "Main",
}
defaultWorkspace.SetDefaultSettings() // Initialize default settings
defaultWorkspace.SetDefaultSettings()
// Create workspace with settings
err = db.createWorkspaceTx(tx, defaultWorkspace)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to create default workspace: %w", err)
}
// Update user's last workspace ID
_, err = tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?", defaultWorkspace.ID, user.ID)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to update last workspace ID: %w", err)
}
err = tx.Commit()
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to commit transaction: %w", err)
}
log.Debug("created user", "user_id", user.ID)
user.LastWorkspaceID = defaultWorkspace.ID
return user, nil
}
// Helper function to create a workspace in a transaction
func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
log := getLogger().WithGroup("users")
result, err := tx.Exec(`
INSERT INTO workspaces (
user_id, name,
@@ -78,17 +85,21 @@ func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) e
workspace.GitCommitName, workspace.GitCommitEmail,
)
if err != nil {
return err
return fmt.Errorf("failed to insert workspace: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return err
return fmt.Errorf("failed to get workspace ID: %w", err)
}
workspace.ID = int(id)
log.Debug("created user workspace",
"workspace_id", workspace.ID,
"user_id", workspace.UserID)
return nil
}
// GetUserByID retrieves a user by ID
func (db *database) GetUserByID(id int) (*models.User, error) {
user := &models.User{}
err := db.QueryRow(`
@@ -97,15 +108,18 @@ func (db *database) GetUserByID(id int) (*models.User, error) {
last_workspace_id
FROM users
WHERE id = ?`, id).
Scan(&user.ID, &user.Email, &user.DisplayName, &user.PasswordHash, &user.Role, &user.CreatedAt,
&user.LastWorkspaceID)
Scan(&user.ID, &user.Email, &user.DisplayName, &user.PasswordHash,
&user.Role, &user.CreatedAt, &user.LastWorkspaceID)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("user not found")
}
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to fetch user: %w", err)
}
return user, nil
}
// GetUserByEmail retrieves a user by email
func (db *database) GetUserByEmail(email string) (*models.User, error) {
user := &models.User{}
err := db.QueryRow(`
@@ -114,26 +128,43 @@ func (db *database) GetUserByEmail(email string) (*models.User, error) {
last_workspace_id
FROM users
WHERE email = ?`, email).
Scan(&user.ID, &user.Email, &user.DisplayName, &user.PasswordHash, &user.Role, &user.CreatedAt,
&user.LastWorkspaceID)
Scan(&user.ID, &user.Email, &user.DisplayName, &user.PasswordHash,
&user.Role, &user.CreatedAt, &user.LastWorkspaceID)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("user not found")
}
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to fetch user: %w", err)
}
return user, nil
}
// UpdateUser updates a user's information
func (db *database) UpdateUser(user *models.User) error {
_, err := db.Exec(`
result, err := db.Exec(`
UPDATE users
SET email = ?, display_name = ?, password_hash = ?, role = ?, last_workspace_id = ?
WHERE id = ?`,
user.Email, user.DisplayName, user.PasswordHash, user.Role, user.LastWorkspaceID, user.ID)
return err
user.Email, user.DisplayName, user.PasswordHash, user.Role,
user.LastWorkspaceID, user.ID)
if err != nil {
return fmt.Errorf("failed to update user: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("user not found")
}
return nil
}
// GetAllUsers returns a list of all users in the system
func (db *database) GetAllUsers() ([]*models.User, error) {
rows, err := db.Query(`
SELECT
@@ -142,7 +173,7 @@ func (db *database) GetAllUsers() ([]*models.User, error) {
FROM users
ORDER BY id ASC`)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to query users: %w", err)
}
defer rows.Close()
@@ -154,60 +185,74 @@ func (db *database) GetAllUsers() ([]*models.User, error) {
&user.CreatedAt, &user.LastWorkspaceID,
)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to scan user row: %w", err)
}
users = append(users, user)
}
return users, nil
}
// UpdateLastWorkspace updates the last workspace the user accessed
func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error {
tx, err := db.Begin()
if err != nil {
return err
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer tx.Rollback()
var workspaceID int
err = tx.QueryRow("SELECT id FROM workspaces WHERE user_id = ? AND name = ?", userID, workspaceName).Scan(&workspaceID)
err = tx.QueryRow("SELECT id FROM workspaces WHERE user_id = ? AND name = ?",
userID, workspaceName).Scan(&workspaceID)
if err != nil {
return err
return fmt.Errorf("failed to find workspace: %w", err)
}
_, err = tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?", workspaceID, userID)
_, err = tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?",
workspaceID, userID)
if err != nil {
return err
return fmt.Errorf("failed to update last workspace: %w", err)
}
return tx.Commit()
err = tx.Commit()
if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}
// DeleteUser deletes a user and all their workspaces
func (db *database) DeleteUser(id int) error {
log := getLogger().WithGroup("users")
log.Debug("deleting user", "user_id", id)
tx, err := db.Begin()
if err != nil {
return err
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer tx.Rollback()
// Delete all user's workspaces first
log.Debug("deleting user workspaces", "user_id", id)
_, err = tx.Exec("DELETE FROM workspaces WHERE user_id = ?", id)
if err != nil {
return err
return fmt.Errorf("failed to delete workspaces: %w", err)
}
// Delete the user
_, err = tx.Exec("DELETE FROM users WHERE id = ?", id)
if err != nil {
return err
return fmt.Errorf("failed to delete user: %w", err)
}
return tx.Commit()
err = tx.Commit()
if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
log.Debug("deleted user", "user_id", id)
return nil
}
// GetLastWorkspaceName returns the name of the last workspace the user accessed
func (db *database) GetLastWorkspaceName(userID int) (string, error) {
var workspaceName string
err := db.QueryRow(`
@@ -217,12 +262,24 @@ func (db *database) GetLastWorkspaceName(userID int) (string, error) {
JOIN users u ON u.last_workspace_id = w.id
WHERE u.id = ?`, userID).
Scan(&workspaceName)
return workspaceName, err
if err == sql.ErrNoRows {
return "", fmt.Errorf("no last workspace found")
}
if err != nil {
return "", fmt.Errorf("failed to fetch last workspace name: %w", err)
}
return workspaceName, nil
}
// CountAdminUsers returns the number of admin users in the system
func (db *database) CountAdminUsers() (int, error) {
var count int
err := db.QueryRow("SELECT COUNT(*) FROM users WHERE role = 'admin'").Scan(&count)
return count, err
if err != nil {
return 0, fmt.Errorf("failed to count admin users: %w", err)
}
return count, nil
}

View File

@@ -6,6 +6,7 @@ import (
"novamd/internal/db"
"novamd/internal/models"
_ "novamd/internal/testenv"
)
func TestUserOperations(t *testing.T) {

View File

@@ -8,6 +8,12 @@ import (
// CreateWorkspace inserts a new workspace record into the database
func (db *database) CreateWorkspace(workspace *models.Workspace) error {
log := getLogger().WithGroup("workspaces")
log.Debug("creating new workspace",
"user_id", workspace.UserID,
"name", workspace.Name,
"git_enabled", workspace.GitEnabled)
// Set default settings if not provided
if workspace.Theme == "" {
workspace.SetDefaultSettings()
@@ -31,14 +37,15 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error {
workspace.GitAutoCommit, workspace.GitCommitMsgTemplate, workspace.GitCommitName, workspace.GitCommitEmail,
)
if err != nil {
return err
return fmt.Errorf("failed to insert workspace: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return err
return fmt.Errorf("failed to get workspace ID: %w", err)
}
workspace.ID = int(id)
return nil
}
@@ -61,10 +68,15 @@ func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) {
&workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt,
&workspace.Theme, &workspace.AutoSave, &workspace.ShowHiddenFiles,
&workspace.GitEnabled, &workspace.GitURL, &workspace.GitUser, &encryptedToken,
&workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate, &workspace.GitCommitName, &workspace.GitCommitEmail,
&workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate,
&workspace.GitCommitName, &workspace.GitCommitEmail,
)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("workspace not found")
}
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to fetch workspace: %w", err)
}
// Decrypt token
@@ -98,8 +110,12 @@ func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*model
&workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate,
&workspace.GitCommitName, &workspace.GitCommitEmail,
)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("workspace not found")
}
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to fetch workspace: %w", err)
}
// Decrypt token
@@ -150,7 +166,11 @@ func (db *database) UpdateWorkspace(workspace *models.Workspace) error {
workspace.ID,
workspace.UserID,
)
return err
if err != nil {
return fmt.Errorf("failed to update workspace: %w", err)
}
return nil
}
// GetWorkspacesByUserID retrieves all workspaces for a user
@@ -167,7 +187,7 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro
userID,
)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to query workspaces: %w", err)
}
defer rows.Close()
@@ -183,7 +203,7 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro
&workspace.GitCommitName, &workspace.GitCommitEmail,
)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to scan workspace row: %w", err)
}
// Decrypt token
@@ -194,11 +214,15 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro
workspaces = append(workspaces, workspace)
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating workspace rows: %w", err)
}
return workspaces, nil
}
// UpdateWorkspaceSettings updates only the settings portion of a workspace
// This is useful when you don't want to modify the name or other core workspace properties
func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
_, err := db.Exec(`
UPDATE workspaces
@@ -228,43 +252,88 @@ func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
workspace.GitCommitEmail,
workspace.ID,
)
return err
if err != nil {
return fmt.Errorf("failed to update workspace settings: %w", err)
}
return nil
}
// DeleteWorkspace removes a workspace record from the database
func (db *database) DeleteWorkspace(id int) error {
log := getLogger().WithGroup("workspaces")
_, err := db.Exec("DELETE FROM workspaces WHERE id = ?", id)
return err
if err != nil {
return fmt.Errorf("failed to delete workspace: %w", err)
}
log.Debug("workspace deleted", "workspace_id", id)
return nil
}
// DeleteWorkspaceTx removes a workspace record from the database within a transaction
func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
_, err := tx.Exec("DELETE FROM workspaces WHERE id = ?", id)
return err
log := getLogger().WithGroup("workspaces")
result, err := tx.Exec("DELETE FROM workspaces WHERE id = ?", id)
if err != nil {
return fmt.Errorf("failed to delete workspace in transaction: %w", err)
}
_, err = result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected in transaction: %w", err)
}
log.Debug("workspace deleted",
"workspace_id", id)
return nil
}
// UpdateLastWorkspaceTx sets the last workspace for a user in with a transaction
// UpdateLastWorkspaceTx sets the last workspace for a user in a transaction
func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error {
_, err := tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?", workspaceID, userID)
return err
result, err := tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?",
workspaceID, userID)
if err != nil {
return fmt.Errorf("failed to update last workspace in transaction: %w", err)
}
_, err = result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected in transaction: %w", err)
}
return nil
}
// UpdateLastOpenedFile updates the last opened file path for a workspace
func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error {
_, err := db.Exec("UPDATE workspaces SET last_opened_file_path = ? WHERE id = ?", filePath, workspaceID)
return err
_, err := db.Exec("UPDATE workspaces SET last_opened_file_path = ? WHERE id = ?",
filePath, workspaceID)
if err != nil {
return fmt.Errorf("failed to update last opened file: %w", err)
}
return nil
}
// GetLastOpenedFile retrieves the last opened file path for a workspace
func (db *database) GetLastOpenedFile(workspaceID int) (string, error) {
var filePath sql.NullString
err := db.QueryRow("SELECT last_opened_file_path FROM workspaces WHERE id = ?", workspaceID).Scan(&filePath)
if err != nil {
return "", err
err := db.QueryRow("SELECT last_opened_file_path FROM workspaces WHERE id = ?",
workspaceID).Scan(&filePath)
if err == sql.ErrNoRows {
return "", fmt.Errorf("workspace not found")
}
if err != nil {
return "", fmt.Errorf("failed to fetch last opened file: %w", err)
}
if !filePath.Valid {
return "", nil
}
return filePath.String, nil
}
@@ -280,7 +349,7 @@ func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) {
FROM workspaces`,
)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to query workspaces: %w", err)
}
defer rows.Close()
@@ -296,7 +365,7 @@ func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) {
&workspace.GitCommitName, &workspace.GitCommitEmail,
)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to scan workspace row: %w", err)
}
// Decrypt token
@@ -307,5 +376,10 @@ func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) {
workspaces = append(workspaces, workspace)
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating workspace rows: %w", err)
}
return workspaces, nil
}

View File

@@ -1,12 +1,12 @@
package db_test
import (
"database/sql"
"strings"
"testing"
"novamd/internal/db"
"novamd/internal/models"
_ "novamd/internal/testenv"
)
func TestWorkspaceOperations(t *testing.T) {
@@ -385,8 +385,8 @@ func TestWorkspaceOperations(t *testing.T) {
// Verify workspace is gone
_, err = database.GetWorkspaceByID(workspace.ID)
if err != sql.ErrNoRows {
t.Errorf("expected sql.ErrNoRows, got %v", err)
if !strings.Contains(err.Error(), "workspace not found") {
t.Errorf("expected workspace not found, got %v", err)
}
})
}

View File

@@ -7,6 +7,8 @@ import (
"path/filepath"
"time"
"novamd/internal/logging"
"github.com/go-git/go-git/v5"
"github.com/go-git/go-git/v5/plumbing"
"github.com/go-git/go-git/v5/plumbing/object"
@@ -46,6 +48,15 @@ type client struct {
repo *git.Repository
}
var logger logging.Logger
func getLogger() logging.Logger {
if logger == nil {
logger = logging.WithGroup("git")
}
return logger
}
// New creates a new git Client instance
func New(url, username, token, workDir, commitName, commitEmail string) Client {
return &client{
@@ -62,6 +73,11 @@ func New(url, username, token, workDir, commitName, commitEmail string) Client {
// Clone clones the Git repository to the local directory
func (c *client) Clone() error {
log := getLogger()
log.Info("cloning git repository",
"url", c.URL,
"workDir", c.WorkDir)
auth := &http.BasicAuth{
Username: c.Username,
Password: c.Token,
@@ -73,7 +89,6 @@ func (c *client) Clone() error {
Auth: auth,
Progress: os.Stdout,
})
if err != nil {
return fmt.Errorf("failed to clone repository: %w", err)
}
@@ -83,6 +98,10 @@ func (c *client) Clone() error {
// Pull pulls the latest changes from the remote repository
func (c *client) Pull() error {
log := getLogger().With(
"workDir", c.WorkDir,
)
if c.repo == nil {
return fmt.Errorf("repository not initialized")
}
@@ -101,16 +120,25 @@ func (c *client) Pull() error {
Auth: auth,
Progress: os.Stdout,
})
if err != nil && err != git.NoErrAlreadyUpToDate {
return fmt.Errorf("failed to pull changes: %w", err)
}
if err == git.NoErrAlreadyUpToDate {
log.Debug("repository already up to date")
} else {
log.Debug("pulled latest changes")
}
return nil
}
// Commit commits the changes in the repository with the given message
func (c *client) Commit(message string) (CommitHash, error) {
log := getLogger().With(
"workDir", c.WorkDir,
)
if c.repo == nil {
return CommitHash(plumbing.ZeroHash), fmt.Errorf("repository not initialized")
}
@@ -136,11 +164,16 @@ func (c *client) Commit(message string) (CommitHash, error) {
return CommitHash(plumbing.ZeroHash), fmt.Errorf("failed to commit changes: %w", err)
}
log.Debug("changes committed")
return CommitHash(hash), nil
}
// Push pushes the changes to the remote repository
func (c *client) Push() error {
log := getLogger().With(
"workDir", c.WorkDir,
)
if c.repo == nil {
return fmt.Errorf("repository not initialized")
}
@@ -154,17 +187,30 @@ func (c *client) Push() error {
Auth: auth,
Progress: os.Stdout,
})
if err != nil && err != git.NoErrAlreadyUpToDate {
return fmt.Errorf("failed to push changes: %w", err)
}
if err == git.NoErrAlreadyUpToDate {
log.Debug("remote already up to date",
"workDir", c.WorkDir)
} else {
log.Debug("pushed repository changes",
"workDir", c.WorkDir)
}
return nil
}
// EnsureRepo ensures the local repository is cloned and up-to-date
func (c *client) EnsureRepo() error {
log := getLogger().With(
"workDir", c.WorkDir,
)
log.Debug("ensuring repository exists and is up to date")
if _, err := os.Stat(filepath.Join(c.WorkDir, ".git")); os.IsNotExist(err) {
log.Info("repository not found, initiating clone")
return c.Clone()
}

View File

@@ -6,6 +6,7 @@ import (
"net/http"
"novamd/internal/context"
"novamd/internal/db"
"novamd/internal/logging"
"novamd/internal/models"
"novamd/internal/storage"
"strconv"
@@ -47,6 +48,10 @@ type SystemStats struct {
*storage.FileCountStats
}
func getAdminLogger() logging.Logger {
return getHandlersLogger().WithGroup("admin")
}
// AdminListUsers godoc
// @Summary List all users
// @Description Returns the list of all users
@@ -58,9 +63,22 @@ type SystemStats struct {
// @Failure 500 {object} ErrorResponse "Failed to list users"
// @Router /admin/users [get]
func (h *Handler) AdminListUsers() http.HandlerFunc {
return func(w http.ResponseWriter, _ *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
log := getAdminLogger().With(
"handler", "AdminListUsers",
"adminID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
users, err := h.DB.GetAllUsers()
if err != nil {
log.Error("failed to fetch users from database",
"error", err.Error(),
)
respondError(w, "Failed to list users", http.StatusInternalServerError)
return
}
@@ -89,39 +107,63 @@ func (h *Handler) AdminListUsers() http.HandlerFunc {
// @Router /admin/users [post]
func (h *Handler) AdminCreateUser() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
log := getAdminLogger().With(
"handler", "AdminCreateUser",
"adminID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
var req CreateUserRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
log.Debug("failed to decode request body",
"error", err.Error(),
)
respondError(w, "Invalid request body", http.StatusBadRequest)
return
}
// Validate request
// Validation logging
if req.Email == "" || req.Password == "" || req.Role == "" {
log.Debug("missing required fields",
"hasEmail", req.Email != "",
"hasPassword", req.Password != "",
"hasRole", req.Role != "",
)
respondError(w, "Email, password, and role are required", http.StatusBadRequest)
return
}
// Check if email already exists
// Email existence check
existingUser, err := h.DB.GetUserByEmail(req.Email)
if err == nil && existingUser != nil {
log.Warn("attempted to create user with existing email",
"email", req.Email,
)
respondError(w, "Email already exists", http.StatusConflict)
return
}
// Check if password is long enough
if len(req.Password) < 8 {
log.Debug("password too short",
"passwordLength", len(req.Password),
)
respondError(w, "Password must be at least 8 characters", http.StatusBadRequest)
return
}
// Hash password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
if err != nil {
log.Error("failed to hash password",
"error", err.Error(),
)
respondError(w, "Failed to hash password", http.StatusInternalServerError)
return
}
// Create user
user := &models.User{
Email: req.Email,
DisplayName: req.DisplayName,
@@ -131,16 +173,30 @@ func (h *Handler) AdminCreateUser() http.HandlerFunc {
insertedUser, err := h.DB.CreateUser(user)
if err != nil {
log.Error("failed to create user in database",
"error", err.Error(),
"email", req.Email,
"role", req.Role,
)
respondError(w, "Failed to create user", http.StatusInternalServerError)
return
}
// Initialize user workspace
if err := h.Storage.InitializeUserWorkspace(insertedUser.ID, insertedUser.LastWorkspaceID); err != nil {
log.Error("failed to initialize user workspace",
"error", err.Error(),
"userID", insertedUser.ID,
"workspaceID", insertedUser.LastWorkspaceID,
)
respondError(w, "Failed to initialize user workspace", http.StatusInternalServerError)
return
}
log.Info("user created",
"newUserID", insertedUser.ID,
"email", insertedUser.Email,
"role", insertedUser.Role,
)
respondJSON(w, insertedUser)
}
}
@@ -159,14 +215,32 @@ func (h *Handler) AdminCreateUser() http.HandlerFunc {
// @Router /admin/users/{userId} [get]
func (h *Handler) AdminGetUser() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
log := getAdminLogger().With(
"handler", "AdminGetUser",
"adminID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
userID, err := strconv.Atoi(chi.URLParam(r, "userId"))
if err != nil {
log.Debug("invalid user ID format",
"userIDParam", chi.URLParam(r, "userId"),
"error", err.Error(),
)
respondError(w, "Invalid user ID", http.StatusBadRequest)
return
}
user, err := h.DB.GetUserByID(userID)
if err != nil {
log.Debug("user not found",
"targetUserID", userID,
"error", err.Error(),
)
respondError(w, "User not found", http.StatusNotFound)
return
}
@@ -194,49 +268,86 @@ func (h *Handler) AdminGetUser() http.HandlerFunc {
// @Router /admin/users/{userId} [put]
func (h *Handler) AdminUpdateUser() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
log := getAdminLogger().With(
"handler", "AdminUpdateUser",
"adminID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
userID, err := strconv.Atoi(chi.URLParam(r, "userId"))
if err != nil {
log.Debug("invalid user ID format",
"userIDParam", chi.URLParam(r, "userId"),
"error", err.Error(),
)
respondError(w, "Invalid user ID", http.StatusBadRequest)
return
}
// Get existing user
user, err := h.DB.GetUserByID(userID)
if err != nil {
log.Debug("user not found",
"targetUserID", userID,
"error", err.Error(),
)
respondError(w, "User not found", http.StatusNotFound)
return
}
var req UpdateUserRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
log.Debug("failed to decode request body",
"error", err.Error(),
)
respondError(w, "Invalid request body", http.StatusBadRequest)
return
}
// Update fields if provided
// Track what's being updated for logging
updates := make(map[string]interface{})
if req.Email != "" {
user.Email = req.Email
updates["email"] = req.Email
}
if req.DisplayName != "" {
user.DisplayName = req.DisplayName
updates["displayName"] = req.DisplayName
}
if req.Role != "" {
user.Role = req.Role
updates["role"] = req.Role
}
if req.Password != "" {
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
if err != nil {
log.Error("failed to hash password",
"error", err.Error(),
)
respondError(w, "Failed to hash password", http.StatusInternalServerError)
return
}
user.PasswordHash = string(hashedPassword)
updates["passwordUpdated"] = true
}
if err := h.DB.UpdateUser(user); err != nil {
log.Error("failed to update user in database",
"error", err.Error(),
"targetUserID", userID,
)
respondError(w, "Failed to update user", http.StatusInternalServerError)
return
}
log.Debug("user updated",
"targetUserID", userID,
"updates", updates,
)
respondJSON(w, user)
}
}
@@ -261,37 +372,61 @@ func (h *Handler) AdminDeleteUser() http.HandlerFunc {
if !ok {
return
}
log := getAdminLogger().With(
"handler", "AdminDeleteUser",
"adminID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
userID, err := strconv.Atoi(chi.URLParam(r, "userId"))
if err != nil {
log.Debug("invalid user ID format",
"userIDParam", chi.URLParam(r, "userId"),
"error", err.Error(),
)
respondError(w, "Invalid user ID", http.StatusBadRequest)
return
}
// Prevent admin from deleting themselves
if userID == ctx.UserID {
log.Warn("admin attempted to delete own account")
respondError(w, "Cannot delete your own account", http.StatusBadRequest)
return
}
// Get user before deletion to check role
user, err := h.DB.GetUserByID(userID)
if err != nil {
log.Debug("user not found",
"targetUserID", userID,
"error", err.Error(),
)
respondError(w, "User not found", http.StatusNotFound)
return
}
// Prevent deletion of other admin users
if user.Role == models.RoleAdmin && ctx.UserID != userID {
log.Warn("attempted to delete another admin user",
"targetUserID", userID,
"targetUserEmail", user.Email,
)
respondError(w, "Cannot delete other admin users", http.StatusForbidden)
return
}
if err := h.DB.DeleteUser(userID); err != nil {
log.Error("failed to delete user from database",
"error", err.Error(),
"targetUserID", userID,
)
respondError(w, "Failed to delete user", http.StatusInternalServerError)
return
}
log.Info("user deleted",
"targetUserID", userID,
"targetUserEmail", user.Email,
"targetUserRole", user.Role,
)
w.WriteHeader(http.StatusNoContent)
}
}
@@ -309,9 +444,22 @@ func (h *Handler) AdminDeleteUser() http.HandlerFunc {
// @Failure 500 {object} ErrorResponse "Failed to get file stats"
// @Router /admin/workspaces [get]
func (h *Handler) AdminListWorkspaces() http.HandlerFunc {
return func(w http.ResponseWriter, _ *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
log := getAdminLogger().With(
"handler", "AdminListWorkspaces",
"adminID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
workspaces, err := h.DB.GetAllWorkspaces()
if err != nil {
log.Error("failed to fetch workspaces from database",
"error", err.Error(),
)
respondError(w, "Failed to list workspaces", http.StatusInternalServerError)
return
}
@@ -319,11 +467,15 @@ func (h *Handler) AdminListWorkspaces() http.HandlerFunc {
workspacesStats := make([]*WorkspaceStats, 0, len(workspaces))
for _, ws := range workspaces {
workspaceData := &WorkspaceStats{}
user, err := h.DB.GetUserByID(ws.UserID)
if err != nil {
log.Error("failed to fetch user for workspace",
"error", err.Error(),
"workspaceID", ws.ID,
"userID", ws.UserID,
)
respondError(w, "Failed to get user", http.StatusInternalServerError)
return
}
@@ -336,12 +488,16 @@ func (h *Handler) AdminListWorkspaces() http.HandlerFunc {
fileStats, err := h.Storage.GetFileStats(ws.UserID, ws.ID)
if err != nil {
log.Error("failed to fetch file stats for workspace",
"error", err.Error(),
"workspaceID", ws.ID,
"userID", ws.UserID,
)
respondError(w, "Failed to get file stats", http.StatusInternalServerError)
return
}
workspaceData.FileCountStats = fileStats
workspacesStats = append(workspacesStats, workspaceData)
}
@@ -361,15 +517,31 @@ func (h *Handler) AdminListWorkspaces() http.HandlerFunc {
// @Failure 500 {object} ErrorResponse "Failed to get file stats"
// @Router /admin/stats [get]
func (h *Handler) AdminGetSystemStats() http.HandlerFunc {
return func(w http.ResponseWriter, _ *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
log := getAdminLogger().With(
"handler", "AdminGetSystemStats",
"adminID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
userStats, err := h.DB.GetSystemStats()
if err != nil {
log.Error("failed to fetch user statistics",
"error", err.Error(),
)
respondError(w, "Failed to get user stats", http.StatusInternalServerError)
return
}
fileStats, err := h.Storage.GetTotalFileStats()
if err != nil {
log.Error("failed to fetch file statistics",
"error", err.Error(),
)
respondError(w, "Failed to get file stats", http.StatusInternalServerError)
return
}

View File

@@ -7,6 +7,7 @@ import (
"net/http"
"novamd/internal/auth"
"novamd/internal/context"
"novamd/internal/logging"
"novamd/internal/models"
"time"
@@ -26,6 +27,10 @@ type LoginResponse struct {
ExpiresAt time.Time `json:"expiresAt,omitempty"`
}
func getAuthLogger() logging.Logger {
return getHandlersLogger().WithGroup("auth")
}
// Login godoc
// @Summary Login
// @Description Logs in a user and returns a session with access and refresh tokens
@@ -43,62 +48,88 @@ type LoginResponse struct {
// @Router /auth/login [post]
func (h *Handler) Login(authManager auth.SessionManager, cookieService auth.CookieManager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
log := getAuthLogger().With(
"handler", "Login",
"clientIP", r.RemoteAddr,
)
var req LoginRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
log.Debug("failed to decode request body",
"error", err.Error(),
)
respondError(w, "Invalid request body", http.StatusBadRequest)
return
}
// Validate request
if req.Email == "" || req.Password == "" {
log.Debug("missing required fields",
"hasEmail", req.Email != "",
"hasPassword", req.Password != "",
)
respondError(w, "Email and password are required", http.StatusBadRequest)
return
}
// Get user from database
user, err := h.DB.GetUserByEmail(req.Email)
if err != nil {
log.Debug("user not found",
"email", req.Email,
"error", err.Error(),
)
respondError(w, "Invalid credentials", http.StatusUnauthorized)
return
}
// Verify password
err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password))
if err != nil {
log.Warn("invalid password attempt",
"userID", user.ID,
"email", user.Email,
)
respondError(w, "Invalid credentials", http.StatusUnauthorized)
return
}
// Create session and generate tokens
session, accessToken, err := authManager.CreateSession(user.ID, string(user.Role))
if err != nil {
log.Error("failed to create session",
"error", err.Error(),
"userID", user.ID,
)
respondError(w, "Failed to create session", http.StatusInternalServerError)
return
}
// Generate CSRF token
csrfToken := make([]byte, 32)
if _, err := rand.Read(csrfToken); err != nil {
log.Error("failed to generate CSRF token",
"error", err.Error(),
"userID", user.ID,
)
respondError(w, "Failed to generate CSRF token", http.StatusInternalServerError)
return
}
csrfTokenString := hex.EncodeToString(csrfToken)
// Set cookies
http.SetCookie(w, cookieService.GenerateAccessTokenCookie(accessToken))
http.SetCookie(w, cookieService.GenerateRefreshTokenCookie(session.RefreshToken))
http.SetCookie(w, cookieService.GenerateCSRFCookie(csrfTokenString))
// Send CSRF token in header for initial setup
w.Header().Set("X-CSRF-Token", csrfTokenString)
// Only send user info in response, not tokens
response := LoginResponse{
User: user,
SessionID: session.ID,
ExpiresAt: session.ExpiresAt,
}
log.Debug("user logged in",
"userID", user.ID,
"email", user.Email,
"role", user.Role,
"sessionID", session.ID,
)
respondJSON(w, response)
}
}
@@ -114,24 +145,41 @@ func (h *Handler) Login(authManager auth.SessionManager, cookieService auth.Cook
// @Router /auth/logout [post]
func (h *Handler) Logout(authManager auth.SessionManager, cookieService auth.CookieManager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Get session ID from cookie
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
log := getAuthLogger().With(
"handler", "Logout",
"userID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
sessionCookie, err := r.Cookie("access_token")
if err != nil {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
log.Debug("missing access token cookie",
"error", err.Error(),
)
respondError(w, "Access token required", http.StatusBadRequest)
return
}
// Invalidate the session in the database
if err := authManager.InvalidateSession(sessionCookie.Value); err != nil {
log.Error("failed to invalidate session",
"error", err.Error(),
"sessionID", sessionCookie.Value,
)
respondError(w, "Failed to invalidate session", http.StatusInternalServerError)
return
}
// Clear cookies
http.SetCookie(w, cookieService.InvalidateCookie("access_token"))
http.SetCookie(w, cookieService.InvalidateCookie("refresh_token"))
http.SetCookie(w, cookieService.InvalidateCookie("csrf_token"))
log.Info("user logged out successfully",
"sessionID", sessionCookie.Value,
)
w.WriteHeader(http.StatusNoContent)
}
}
@@ -151,22 +199,34 @@ func (h *Handler) Logout(authManager auth.SessionManager, cookieService auth.Coo
// @Router /auth/refresh [post]
func (h *Handler) RefreshToken(authManager auth.SessionManager, cookieService auth.CookieManager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
log := getAuthLogger().With(
"handler", "RefreshToken",
"clientIP", r.RemoteAddr,
)
refreshCookie, err := r.Cookie("refresh_token")
if err != nil {
log.Debug("missing refresh token cookie",
"error", err.Error(),
)
respondError(w, "Refresh token required", http.StatusBadRequest)
return
}
// Generate new access token
accessToken, err := authManager.RefreshSession(refreshCookie.Value)
if err != nil {
log.Error("failed to refresh session",
"error", err.Error(),
)
respondError(w, "Invalid refresh token", http.StatusUnauthorized)
return
}
// Generate new CSRF token
csrfToken := make([]byte, 32)
if _, err := rand.Read(csrfToken); err != nil {
log.Error("failed to generate CSRF token",
"error", err.Error(),
)
respondError(w, "Failed to generate CSRF token", http.StatusInternalServerError)
return
}
@@ -196,10 +256,17 @@ func (h *Handler) GetCurrentUser() http.HandlerFunc {
if !ok {
return
}
log := getAuthLogger().With(
"handler", "GetCurrentUser",
"userID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
// Get user from database
user, err := h.DB.GetUserByID(ctx.UserID)
if err != nil {
log.Error("failed to fetch user",
"error", err.Error(),
)
respondError(w, "User not found", http.StatusNotFound)
return
}

View File

@@ -8,6 +8,7 @@ import (
"time"
"novamd/internal/context"
"novamd/internal/logging"
"novamd/internal/storage"
"github.com/go-chi/chi/v5"
@@ -35,6 +36,10 @@ type UpdateLastOpenedFileRequest struct {
FilePath string `json:"filePath"`
}
func getFilesLogger() logging.Logger {
return getHandlersLogger().WithGroup("files")
}
// ListFiles godoc
// @Summary List files
// @Description Lists all files in the user's workspace
@@ -52,9 +57,18 @@ func (h *Handler) ListFiles() http.HandlerFunc {
if !ok {
return
}
log := getFilesLogger().With(
"handler", "ListFiles",
"userID", ctx.UserID,
"workspaceID", ctx.Workspace.ID,
"clientIP", r.RemoteAddr,
)
files, err := h.Storage.ListFilesRecursively(ctx.UserID, ctx.Workspace.ID)
if err != nil {
log.Error("failed to list files in workspace",
"error", err.Error(),
)
respondError(w, "Failed to list files", http.StatusInternalServerError)
return
}
@@ -82,15 +96,32 @@ func (h *Handler) LookupFileByName() http.HandlerFunc {
if !ok {
return
}
log := getFilesLogger().With(
"handler", "LookupFileByName",
"userID", ctx.UserID,
"workspaceID", ctx.Workspace.ID,
"clientIP", r.RemoteAddr,
)
filename := r.URL.Query().Get("filename")
if filename == "" {
log.Debug("missing filename parameter")
respondError(w, "Filename is required", http.StatusBadRequest)
return
}
filePaths, err := h.Storage.FindFileByName(ctx.UserID, ctx.Workspace.ID, filename)
if err != nil {
if !os.IsNotExist(err) {
log.Error("failed to lookup file",
"filename", filename,
"error", err.Error(),
)
} else {
log.Debug("file not found",
"filename", filename,
)
}
respondError(w, "File not found", http.StatusNotFound)
return
}
@@ -120,21 +151,37 @@ func (h *Handler) GetFileContent() http.HandlerFunc {
if !ok {
return
}
log := getFilesLogger().With(
"handler", "GetFileContent",
"userID", ctx.UserID,
"workspaceID", ctx.Workspace.ID,
"clientIP", r.RemoteAddr,
)
filePath := chi.URLParam(r, "*")
content, err := h.Storage.GetFileContent(ctx.UserID, ctx.Workspace.ID, filePath)
if err != nil {
if storage.IsPathValidationError(err) {
log.Error("invalid file path attempted",
"filePath", filePath,
"error", err.Error(),
)
respondError(w, "Invalid file path", http.StatusBadRequest)
return
}
if os.IsNotExist(err) {
log.Debug("file not found",
"filePath", filePath,
)
respondError(w, "File not found", http.StatusNotFound)
return
}
log.Error("failed to read file content",
"filePath", filePath,
"error", err.Error(),
)
respondError(w, "Failed to read file", http.StatusInternalServerError)
return
}
@@ -142,6 +189,10 @@ func (h *Handler) GetFileContent() http.HandlerFunc {
w.Header().Set("Content-Type", "text/plain")
_, err = w.Write(content)
if err != nil {
log.Error("failed to write response",
"filePath", filePath,
"error", err.Error(),
)
respondError(w, "Failed to write response", http.StatusInternalServerError)
return
}
@@ -169,10 +220,20 @@ func (h *Handler) SaveFile() http.HandlerFunc {
if !ok {
return
}
log := getFilesLogger().With(
"handler", "SaveFile",
"userID", ctx.UserID,
"workspaceID", ctx.Workspace.ID,
"clientIP", r.RemoteAddr,
)
filePath := chi.URLParam(r, "*")
content, err := io.ReadAll(r.Body)
if err != nil {
log.Error("failed to read request body",
"filePath", filePath,
"error", err.Error(),
)
respondError(w, "Failed to read request body", http.StatusBadRequest)
return
}
@@ -180,10 +241,19 @@ func (h *Handler) SaveFile() http.HandlerFunc {
err = h.Storage.SaveFile(ctx.UserID, ctx.Workspace.ID, filePath, content)
if err != nil {
if storage.IsPathValidationError(err) {
log.Error("invalid file path attempted",
"filePath", filePath,
"error", err.Error(),
)
respondError(w, "Invalid file path", http.StatusBadRequest)
return
}
log.Error("failed to save file",
"filePath", filePath,
"contentSize", len(content),
"error", err.Error(),
)
respondError(w, "Failed to save file", http.StatusInternalServerError)
return
}
@@ -194,7 +264,6 @@ func (h *Handler) SaveFile() http.HandlerFunc {
UpdatedAt: time.Now().UTC(),
}
w.WriteHeader(http.StatusOK)
respondJSON(w, response)
}
}
@@ -211,7 +280,6 @@ func (h *Handler) SaveFile() http.HandlerFunc {
// @Failure 400 {object} ErrorResponse "Invalid file path"
// @Failure 404 {object} ErrorResponse "File not found"
// @Failure 500 {object} ErrorResponse "Failed to delete file"
// @Failure 500 {object} ErrorResponse "Failed to write response"
// @Router /workspaces/{workspace_name}/files/{file_path} [delete]
func (h *Handler) DeleteFile() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
@@ -219,20 +287,37 @@ func (h *Handler) DeleteFile() http.HandlerFunc {
if !ok {
return
}
log := getFilesLogger().With(
"handler", "DeleteFile",
"userID", ctx.UserID,
"workspaceID", ctx.Workspace.ID,
"clientIP", r.RemoteAddr,
)
filePath := chi.URLParam(r, "*")
err := h.Storage.DeleteFile(ctx.UserID, ctx.Workspace.ID, filePath)
if err != nil {
if storage.IsPathValidationError(err) {
log.Error("invalid file path attempted",
"filePath", filePath,
"error", err.Error(),
)
respondError(w, "Invalid file path", http.StatusBadRequest)
return
}
if os.IsNotExist(err) {
log.Debug("file not found",
"filePath", filePath,
)
respondError(w, "File not found", http.StatusNotFound)
return
}
log.Error("failed to delete file",
"filePath", filePath,
"error", err.Error(),
)
respondError(w, "Failed to delete file", http.StatusInternalServerError)
return
}
@@ -259,14 +344,27 @@ func (h *Handler) GetLastOpenedFile() http.HandlerFunc {
if !ok {
return
}
log := getFilesLogger().With(
"handler", "GetLastOpenedFile",
"userID", ctx.UserID,
"workspaceID", ctx.Workspace.ID,
"clientIP", r.RemoteAddr,
)
filePath, err := h.DB.GetLastOpenedFile(ctx.Workspace.ID)
if err != nil {
log.Error("failed to get last opened file from database",
"error", err.Error(),
)
respondError(w, "Failed to get last opened file", http.StatusInternalServerError)
return
}
if _, err := h.Storage.ValidatePath(ctx.UserID, ctx.Workspace.ID, filePath); err != nil {
log.Error("invalid file path stored",
"filePath", filePath,
"error", err.Error(),
)
respondError(w, "Invalid file path", http.StatusBadRequest)
return
}
@@ -297,10 +395,18 @@ func (h *Handler) UpdateLastOpenedFile() http.HandlerFunc {
if !ok {
return
}
log := getFilesLogger().With(
"handler", "UpdateLastOpenedFile",
"userID", ctx.UserID,
"workspaceID", ctx.Workspace.ID,
"clientIP", r.RemoteAddr,
)
var requestBody UpdateLastOpenedFileRequest
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
log.Error("failed to decode request body",
"error", err.Error(),
)
respondError(w, "Invalid request body", http.StatusBadRequest)
return
}
@@ -310,21 +416,36 @@ func (h *Handler) UpdateLastOpenedFile() http.HandlerFunc {
_, err := h.Storage.GetFileContent(ctx.UserID, ctx.Workspace.ID, requestBody.FilePath)
if err != nil {
if storage.IsPathValidationError(err) {
log.Error("invalid file path attempted",
"filePath", requestBody.FilePath,
"error", err.Error(),
)
respondError(w, "Invalid file path", http.StatusBadRequest)
return
}
if os.IsNotExist(err) {
log.Debug("file not found",
"filePath", requestBody.FilePath,
)
respondError(w, "File not found", http.StatusNotFound)
return
}
log.Error("failed to validate file path",
"filePath", requestBody.FilePath,
"error", err.Error(),
)
respondError(w, "Failed to update last opened file", http.StatusInternalServerError)
return
}
}
if err := h.DB.UpdateLastOpenedFile(ctx.Workspace.ID, requestBody.FilePath); err != nil {
log.Error("failed to update last opened file in database",
"filePath", requestBody.FilePath,
"error", err.Error(),
)
respondError(w, "Failed to update last opened file", http.StatusInternalServerError)
return
}

View File

@@ -3,8 +3,8 @@ package handlers
import (
"encoding/json"
"net/http"
"novamd/internal/context"
"novamd/internal/logging"
)
// CommitRequest represents a request to commit changes
@@ -22,6 +22,10 @@ type PullResponse struct {
Message string `json:"message" example:"Pulled changes from remote"`
}
func getGitLogger() logging.Logger {
return getHandlersLogger().WithGroup("git")
}
// StageCommitAndPush godoc
// @Summary Stage, commit, and push changes
// @Description Stages, commits, and pushes changes to the remote repository
@@ -42,21 +46,34 @@ func (h *Handler) StageCommitAndPush() http.HandlerFunc {
if !ok {
return
}
log := getGitLogger().With(
"handler", "StageCommitAndPush",
"userID", ctx.UserID,
"workspaceID", ctx.Workspace.ID,
"clientIP", r.RemoteAddr,
)
var requestBody CommitRequest
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
log.Error("failed to decode request body",
"error", err.Error(),
)
respondError(w, "Invalid request body", http.StatusBadRequest)
return
}
if requestBody.Message == "" {
log.Debug("empty commit message provided")
respondError(w, "Commit message is required", http.StatusBadRequest)
return
}
hash, err := h.Storage.StageCommitAndPush(ctx.UserID, ctx.Workspace.ID, requestBody.Message)
if err != nil {
log.Error("failed to perform git operations",
"error", err.Error(),
"commitMessage", requestBody.Message,
)
respondError(w, "Failed to stage, commit, and push changes: "+err.Error(), http.StatusInternalServerError)
return
}
@@ -82,9 +99,18 @@ func (h *Handler) PullChanges() http.HandlerFunc {
if !ok {
return
}
log := getGitLogger().With(
"handler", "PullChanges",
"userID", ctx.UserID,
"workspaceID", ctx.Workspace.ID,
"clientIP", r.RemoteAddr,
)
err := h.Storage.Pull(ctx.UserID, ctx.Workspace.ID)
if err != nil {
log.Error("failed to pull changes from remote",
"error", err.Error(),
)
respondError(w, "Failed to pull changes: "+err.Error(), http.StatusInternalServerError)
return
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"net/http"
"novamd/internal/db"
"novamd/internal/logging"
"novamd/internal/storage"
)
@@ -18,6 +19,15 @@ type Handler struct {
Storage storage.Manager
}
var logger logging.Logger
func getHandlersLogger() logging.Logger {
if logger == nil {
logger = logging.WithGroup("handlers")
}
return logger
}
// NewHandler creates a new handler with the given dependencies
func NewHandler(db db.Database, s storage.Manager) *Handler {
return &Handler{

View File

@@ -21,6 +21,8 @@ import (
"novamd/internal/models"
"novamd/internal/secrets"
"novamd/internal/storage"
_ "novamd/internal/testenv"
)
// testHarness encapsulates all the dependencies needed for testing

View File

@@ -2,6 +2,7 @@ package handlers
import (
"net/http"
"novamd/internal/logging"
"os"
"path/filepath"
"strings"
@@ -19,8 +20,19 @@ func NewStaticHandler(staticPath string) *StaticHandler {
}
}
func getStaticLogger() logging.Logger {
return logging.WithGroup("static")
}
// ServeHTTP serves the static files
func (h *StaticHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log := getStaticLogger().With(
"handler", "ServeHTTP",
"clientIP", r.RemoteAddr,
"method", r.Method,
"url", r.URL.Path,
)
// Get the requested path
requestedPath := r.URL.Path
fullPath := filepath.Join(h.staticPath, requestedPath)
@@ -28,6 +40,10 @@ func (h *StaticHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Security check to prevent directory traversal
if !strings.HasPrefix(cleanPath, h.staticPath) {
log.Warn("directory traversal attempt detected",
"requestedPath", requestedPath,
"cleanPath", cleanPath,
)
respondError(w, "Invalid path", http.StatusBadRequest)
return
}
@@ -40,6 +56,21 @@ func (h *StaticHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Check if file exists (not counting .gz files)
stat, err := os.Stat(cleanPath)
if err != nil || stat.IsDir() {
if os.IsNotExist(err) {
log.Debug("file not found, serving index.html",
"requestedPath", requestedPath,
)
} else if stat != nil && stat.IsDir() {
log.Debug("directory requested, serving index.html",
"requestedPath", requestedPath,
)
} else {
log.Error("error checking file status",
"requestedPath", requestedPath,
"error", err.Error(),
)
}
// Serve index.html for SPA routing
indexPath := filepath.Join(h.staticPath, "index.html")
http.ServeFile(w, r, indexPath)
@@ -53,15 +84,16 @@ func (h *StaticHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Encoding", "gzip")
// Set proper content type based on original file
contentType := "application/octet-stream"
switch filepath.Ext(cleanPath) {
case ".js":
w.Header().Set("Content-Type", "application/javascript")
contentType = "application/javascript"
case ".css":
w.Header().Set("Content-Type", "text/css")
contentType = "text/css"
case ".html":
w.Header().Set("Content-Type", "text/html")
contentType = "text/html"
}
w.Header().Set("Content-Type", contentType)
http.ServeFile(w, r, gzPath)
return
}

View File

@@ -5,6 +5,7 @@ import (
"net/http"
"novamd/internal/context"
"novamd/internal/logging"
"golang.org/x/crypto/bcrypt"
)
@@ -22,6 +23,10 @@ type DeleteAccountRequest struct {
Password string `json:"password"`
}
func getProfileLogger() logging.Logger {
return getHandlersLogger().WithGroup("profile")
}
// UpdateProfile godoc
// @Summary Update profile
// @Description Updates the user's profile
@@ -48,9 +53,17 @@ func (h *Handler) UpdateProfile() http.HandlerFunc {
if !ok {
return
}
log := getProfileLogger().With(
"handler", "UpdateProfile",
"userID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
var req UpdateProfileRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
log.Debug("failed to decode request body",
"error", err.Error(),
)
respondError(w, "Invalid request body", http.StatusBadRequest)
return
}
@@ -58,76 +71,94 @@ func (h *Handler) UpdateProfile() http.HandlerFunc {
// Get current user
user, err := h.DB.GetUserByID(ctx.UserID)
if err != nil {
log.Error("failed to fetch user from database",
"error", err.Error(),
)
respondError(w, "User not found", http.StatusNotFound)
return
}
// Track what's being updated for logging
updates := make(map[string]bool)
// Handle password update if requested
if req.NewPassword != "" {
// Current password must be provided to change password
if req.CurrentPassword == "" {
log.Debug("password change attempted without current password")
respondError(w, "Current password is required to change password", http.StatusBadRequest)
return
}
// Verify current password
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.CurrentPassword)); err != nil {
log.Warn("incorrect password provided for password change")
respondError(w, "Current password is incorrect", http.StatusUnauthorized)
return
}
// Validate new password
if len(req.NewPassword) < 8 {
log.Debug("password change rejected - too short",
"passwordLength", len(req.NewPassword),
)
respondError(w, "New password must be at least 8 characters long", http.StatusBadRequest)
return
}
// Hash new password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
if err != nil {
log.Error("failed to hash new password",
"error", err.Error(),
)
respondError(w, "Failed to process new password", http.StatusInternalServerError)
return
}
user.PasswordHash = string(hashedPassword)
updates["passwordChanged"] = true
}
// Handle email update if requested
if req.Email != "" && req.Email != user.Email {
// Check if email change requires password verification
if req.CurrentPassword == "" {
log.Warn("attempted email change without current password")
respondError(w, "Current password is required to change email", http.StatusBadRequest)
return
}
// Verify current password if not already verified for password change
if req.NewPassword == "" {
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.CurrentPassword)); err != nil {
log.Warn("incorrect password provided for email change")
respondError(w, "Current password is incorrect", http.StatusUnauthorized)
return
}
}
// Check if new email is already in use
existingUser, err := h.DB.GetUserByEmail(req.Email)
if err == nil && existingUser.ID != user.ID {
log.Debug("email change rejected - already in use",
"requestedEmail", req.Email,
)
respondError(w, "Email already in use", http.StatusConflict)
return
}
user.Email = req.Email
updates["emailChanged"] = true
}
// Update display name if provided (no password required)
// Update display name if provided
if req.DisplayName != "" {
user.DisplayName = req.DisplayName
updates["displayNameChanged"] = true
}
// Update user in database
if err := h.DB.UpdateUser(user); err != nil {
log.Error("failed to update user in database",
"error", err.Error(),
"updates", updates,
)
respondError(w, "Failed to update profile", http.StatusInternalServerError)
return
}
// Return updated user data
respondJSON(w, user)
}
}
@@ -155,9 +186,17 @@ func (h *Handler) DeleteAccount() http.HandlerFunc {
if !ok {
return
}
log := getProfileLogger().With(
"handler", "DeleteAccount",
"userID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
var req DeleteAccountRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
log.Debug("failed to decode request body",
"error", err.Error(),
)
respondError(w, "Invalid request body", http.StatusBadRequest)
return
}
@@ -165,25 +204,32 @@ func (h *Handler) DeleteAccount() http.HandlerFunc {
// Get current user
user, err := h.DB.GetUserByID(ctx.UserID)
if err != nil {
log.Error("failed to fetch user from database",
"error", err.Error(),
)
respondError(w, "User not found", http.StatusNotFound)
return
}
// Verify password
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil {
respondError(w, "Password is incorrect", http.StatusUnauthorized)
log.Warn("incorrect password provided for account deletion")
respondError(w, "Incorrect password", http.StatusUnauthorized)
return
}
// Prevent admin from deleting their own account if they're the last admin
if user.Role == "admin" {
// Count number of admin users
adminCount, err := h.DB.CountAdminUsers()
if err != nil {
respondError(w, "Failed to verify admin status", http.StatusInternalServerError)
log.Error("failed to count admin users",
"error", err.Error(),
)
respondError(w, "Failed to get admin count", http.StatusInternalServerError)
return
}
if adminCount <= 1 {
log.Warn("attempted to delete last admin account")
respondError(w, "Cannot delete the last admin account", http.StatusForbidden)
return
}
@@ -192,6 +238,9 @@ func (h *Handler) DeleteAccount() http.HandlerFunc {
// Get user's workspaces for cleanup
workspaces, err := h.DB.GetWorkspacesByUserID(ctx.UserID)
if err != nil {
log.Error("failed to fetch user workspaces",
"error", err.Error(),
)
respondError(w, "Failed to get user workspaces", http.StatusInternalServerError)
return
}
@@ -199,17 +248,31 @@ func (h *Handler) DeleteAccount() http.HandlerFunc {
// Delete workspace directories
for _, workspace := range workspaces {
if err := h.Storage.DeleteUserWorkspace(ctx.UserID, workspace.ID); err != nil {
log.Error("failed to delete workspace directory",
"error", err.Error(),
"workspaceID", workspace.ID,
)
respondError(w, "Failed to delete workspace files", http.StatusInternalServerError)
return
}
log.Debug("workspace deleted",
"workspaceID", workspace.ID,
)
}
// Delete user from database (this will cascade delete workspaces and sessions)
// Delete user from database
if err := h.DB.DeleteUser(ctx.UserID); err != nil {
log.Error("failed to delete user from database",
"error", err.Error(),
)
respondError(w, "Failed to delete account", http.StatusInternalServerError)
return
}
log.Info("user account deleted",
"email", user.Email,
"role", user.Role,
)
w.WriteHeader(http.StatusNoContent)
}
}

View File

@@ -6,6 +6,7 @@ import (
"net/http"
"novamd/internal/context"
"novamd/internal/logging"
"novamd/internal/models"
)
@@ -19,6 +20,10 @@ type LastWorkspaceNameResponse struct {
LastWorkspaceName string `json:"lastWorkspaceName"`
}
func getWorkspaceLogger() logging.Logger {
return getHandlersLogger().WithGroup("workspace")
}
// ListWorkspaces godoc
// @Summary List workspaces
// @Description Lists all workspaces for the current user
@@ -35,9 +40,17 @@ func (h *Handler) ListWorkspaces() http.HandlerFunc {
if !ok {
return
}
log := getWorkspaceLogger().With(
"handler", "ListWorkspaces",
"userID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
workspaces, err := h.DB.GetWorkspacesByUserID(ctx.UserID)
if err != nil {
log.Error("failed to fetch workspaces from database",
"error", err.Error(),
)
respondError(w, "Failed to list workspaces", http.StatusInternalServerError)
return
}
@@ -68,25 +81,44 @@ func (h *Handler) CreateWorkspace() http.HandlerFunc {
if !ok {
return
}
log := getWorkspaceLogger().With(
"handler", "CreateWorkspace",
"userID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
var workspace models.Workspace
if err := json.NewDecoder(r.Body).Decode(&workspace); err != nil {
log.Debug("invalid request body received",
"error", err.Error(),
)
respondError(w, "Invalid request body", http.StatusBadRequest)
return
}
if err := workspace.ValidateGitSettings(); err != nil {
log.Debug("invalid git settings provided",
"error", err.Error(),
)
respondError(w, "Invalid workspace", http.StatusBadRequest)
return
}
workspace.UserID = ctx.UserID
if err := h.DB.CreateWorkspace(&workspace); err != nil {
log.Error("failed to create workspace in database",
"error", err.Error(),
"workspaceName", workspace.Name,
)
respondError(w, "Failed to create workspace", http.StatusInternalServerError)
return
}
if err := h.Storage.InitializeUserWorkspace(workspace.UserID, workspace.ID); err != nil {
log.Error("failed to initialize workspace directory",
"error", err.Error(),
"workspaceID", workspace.ID,
)
respondError(w, "Failed to initialize workspace directory", http.StatusInternalServerError)
return
}
@@ -101,11 +133,20 @@ func (h *Handler) CreateWorkspace() http.HandlerFunc {
workspace.GitCommitName,
workspace.GitCommitEmail,
); err != nil {
log.Error("failed to setup git repository",
"error", err.Error(),
"workspaceID", workspace.ID,
)
respondError(w, "Failed to setup git repo: "+err.Error(), http.StatusInternalServerError)
return
}
}
log.Info("workspace created",
"workspaceID", workspace.ID,
"workspaceName", workspace.Name,
"gitEnabled", workspace.GitEnabled,
)
respondJSON(w, workspace)
}
}
@@ -171,9 +212,18 @@ func (h *Handler) UpdateWorkspace() http.HandlerFunc {
if !ok {
return
}
log := getWorkspaceLogger().With(
"handler", "UpdateWorkspace",
"userID", ctx.UserID,
"workspaceID", ctx.Workspace.ID,
"clientIP", r.RemoteAddr,
)
var workspace models.Workspace
if err := json.NewDecoder(r.Body).Decode(&workspace); err != nil {
log.Debug("invalid request body received",
"error", err.Error(),
)
respondError(w, "Invalid request body", http.StatusBadRequest)
return
}
@@ -184,12 +234,23 @@ func (h *Handler) UpdateWorkspace() http.HandlerFunc {
// Validate the workspace
if err := workspace.Validate(); err != nil {
log.Debug("invalid workspace configuration",
"error", err.Error(),
)
respondError(w, err.Error(), http.StatusBadRequest)
return
}
// Track what's changed for logging
changes := map[string]bool{
"gitSettings": gitSettingsChanged(&workspace, ctx.Workspace),
"name": workspace.Name != ctx.Workspace.Name,
"theme": workspace.Theme != ctx.Workspace.Theme,
"autoSave": workspace.AutoSave != ctx.Workspace.AutoSave,
}
// Handle Git repository setup/teardown if Git settings changed
if gitSettingsChanged(&workspace, ctx.Workspace) {
if changes["gitSettings"] {
if workspace.GitEnabled {
if err := h.Storage.SetupGitRepo(
ctx.UserID,
@@ -200,16 +261,21 @@ func (h *Handler) UpdateWorkspace() http.HandlerFunc {
workspace.GitCommitName,
workspace.GitCommitEmail,
); err != nil {
log.Error("failed to setup git repository",
"error", err.Error(),
)
respondError(w, "Failed to setup git repo: "+err.Error(), http.StatusInternalServerError)
return
}
} else {
h.Storage.DisableGitRepo(ctx.UserID, ctx.Workspace.ID)
}
}
if err := h.DB.UpdateWorkspace(&workspace); err != nil {
log.Error("failed to update workspace in database",
"error", err.Error(),
)
respondError(w, "Failed to update workspace", http.StatusInternalServerError)
return
}
@@ -241,15 +307,25 @@ func (h *Handler) DeleteWorkspace() http.HandlerFunc {
if !ok {
return
}
log := getWorkspaceLogger().With(
"handler", "DeleteWorkspace",
"userID", ctx.UserID,
"workspaceID", ctx.Workspace.ID,
"clientIP", r.RemoteAddr,
)
// Check if this is the user's last workspace
workspaces, err := h.DB.GetWorkspacesByUserID(ctx.UserID)
if err != nil {
log.Error("failed to fetch workspaces from database",
"error", err.Error(),
)
respondError(w, "Failed to get workspaces", http.StatusInternalServerError)
return
}
if len(workspaces) <= 1 {
log.Debug("attempted to delete last workspace")
respondError(w, "Cannot delete the last workspace", http.StatusBadRequest)
return
}
@@ -265,14 +341,19 @@ func (h *Handler) DeleteWorkspace() http.HandlerFunc {
}
}
// Start transaction
tx, err := h.DB.Begin()
if err != nil {
log.Error("failed to start database transaction",
"error", err.Error(),
)
respondError(w, "Failed to start transaction", http.StatusInternalServerError)
return
}
defer func() {
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
log.Error("failed to rollback transaction",
"error", err.Error(),
)
respondError(w, "Failed to rollback transaction", http.StatusInternalServerError)
}
}()
@@ -280,6 +361,10 @@ func (h *Handler) DeleteWorkspace() http.HandlerFunc {
// Update last workspace ID first
err = h.DB.UpdateLastWorkspaceTx(tx, ctx.UserID, nextWorkspaceID)
if err != nil {
log.Error("failed to update last workspace reference",
"error", err.Error(),
"nextWorkspaceID", nextWorkspaceID,
)
respondError(w, "Failed to update last workspace", http.StatusInternalServerError)
return
}
@@ -287,16 +372,27 @@ func (h *Handler) DeleteWorkspace() http.HandlerFunc {
// Delete the workspace
err = h.DB.DeleteWorkspaceTx(tx, ctx.Workspace.ID)
if err != nil {
log.Error("failed to delete workspace from database",
"error", err.Error(),
)
respondError(w, "Failed to delete workspace", http.StatusInternalServerError)
return
}
// Commit transaction
if err = tx.Commit(); err != nil {
log.Error("failed to commit transaction",
"error", err.Error(),
)
respondError(w, "Failed to commit transaction", http.StatusInternalServerError)
return
}
log.Info("workspace deleted",
"workspaceName", ctx.Workspace.Name,
"nextWorkspaceName", nextWorkspaceName,
)
// Return the next workspace ID in the response so frontend knows where to redirect
respondJSON(w, &DeleteWorkspaceResponse{NextWorkspaceName: nextWorkspaceName})
}
@@ -318,9 +414,17 @@ func (h *Handler) GetLastWorkspaceName() http.HandlerFunc {
if !ok {
return
}
log := getWorkspaceLogger().With(
"handler", "GetLastWorkspaceName",
"userID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
workspaceName, err := h.DB.GetLastWorkspaceName(ctx.UserID)
if err != nil {
log.Error("failed to fetch last workspace name",
"error", err.Error(),
)
respondError(w, "Failed to get last workspace", http.StatusInternalServerError)
return
}
@@ -347,17 +451,29 @@ func (h *Handler) UpdateLastWorkspaceName() http.HandlerFunc {
if !ok {
return
}
log := getWorkspaceLogger().With(
"handler", "UpdateLastWorkspaceName",
"userID", ctx.UserID,
"clientIP", r.RemoteAddr,
)
var requestBody struct {
WorkspaceName string `json:"workspaceName"`
}
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
log.Debug("invalid request body received",
"error", err.Error(),
)
respondError(w, "Invalid request body", http.StatusBadRequest)
return
}
if err := h.DB.UpdateLastWorkspace(ctx.UserID, requestBody.WorkspaceName); err != nil {
log.Error("failed to update last workspace",
"error", err.Error(),
"workspaceName", requestBody.WorkspaceName,
)
respondError(w, "Failed to update last workspace", http.StatusInternalServerError)
return
}

View File

@@ -0,0 +1,116 @@
// Package logging provides a simple logging interface for the server.
package logging
import (
"log/slog"
"os"
)
// Logger represents the interface for logging operations
type Logger interface {
Debug(msg string, args ...any)
Info(msg string, args ...any)
Warn(msg string, args ...any)
Error(msg string, args ...any)
WithGroup(name string) Logger
With(args ...any) Logger
}
// Implementation of the Logger interface using slog
type logger struct {
logger *slog.Logger
}
// Logger is the global logger instance
var defaultLogger Logger
// LogLevel represents the log level
type LogLevel slog.Level
// Log levels
const (
DEBUG LogLevel = LogLevel(slog.LevelDebug)
INFO LogLevel = LogLevel(slog.LevelInfo)
WARN LogLevel = LogLevel(slog.LevelWarn)
ERROR LogLevel = LogLevel(slog.LevelError)
)
// Setup initializes the logger with the given minimum log level
func Setup(minLevel LogLevel) {
opts := &slog.HandlerOptions{
Level: slog.Level(minLevel),
}
defaultLogger = &logger{
logger: slog.New(slog.NewTextHandler(os.Stdout, opts)),
}
}
// ParseLogLevel converts a string to a LogLevel
func ParseLogLevel(level string) LogLevel {
switch level {
case "debug":
return DEBUG
case "warn":
return WARN
case "error":
return ERROR
default:
return INFO
}
}
// Implementation of Logger interface methods
func (l *logger) Debug(msg string, args ...any) {
l.logger.Debug(msg, args...)
}
func (l *logger) Info(msg string, args ...any) {
l.logger.Info(msg, args...)
}
func (l *logger) Warn(msg string, args ...any) {
l.logger.Warn(msg, args...)
}
func (l *logger) Error(msg string, args ...any) {
l.logger.Error(msg, args...)
}
func (l *logger) WithGroup(name string) Logger {
return &logger{logger: l.logger.WithGroup(name)}
}
func (l *logger) With(args ...any) Logger {
return &logger{logger: l.logger.With(args...)}
}
// Debug logs a debug message
func Debug(msg string, args ...any) {
defaultLogger.Debug(msg, args...)
}
// Info logs an info message
func Info(msg string, args ...any) {
defaultLogger.Info(msg, args...)
}
// Warn logs a warning message
func Warn(msg string, args ...any) {
defaultLogger.Warn(msg, args...)
}
// Error logs an error message
func Error(msg string, args ...any) {
defaultLogger.Error(msg, args...)
}
// WithGroup adds a group to the logger context
func WithGroup(name string) Logger {
return defaultLogger.WithGroup(name)
}
// With adds key-value pairs to the logger context
func With(args ...any) Logger {
return defaultLogger.With(args...)
}

View File

@@ -8,6 +8,8 @@ import (
"encoding/base64"
"fmt"
"io"
"novamd/internal/logging"
)
// Service is an interface for encrypting and decrypting strings
@@ -20,6 +22,15 @@ type encryptor struct {
gcm cipher.AEAD
}
var logger logging.Logger
func getLogger() logging.Logger {
if logger == nil {
logger = logging.WithGroup("secrets")
}
return logger
}
// ValidateKey checks if the provided base64-encoded key is suitable for AES-256
func ValidateKey(key string) error {
_, err := decodeAndValidateKey(key)
@@ -73,7 +84,10 @@ func NewService(key string) (Service, error) {
// Encrypt encrypts the plaintext using AES-256-GCM
func (e *encryptor) Encrypt(plaintext string) (string, error) {
log := getLogger()
if plaintext == "" {
log.Debug("empty plaintext provided, returning empty string")
return "", nil
}
@@ -83,12 +97,18 @@ func (e *encryptor) Encrypt(plaintext string) (string, error) {
}
ciphertext := e.gcm.Seal(nonce, nonce, []byte(plaintext), nil)
return base64.StdEncoding.EncodeToString(ciphertext), nil
encoded := base64.StdEncoding.EncodeToString(ciphertext)
log.Debug("data encrypted", "inputLength", len(plaintext), "outputLength", len(encoded))
return encoded, nil
}
// Decrypt decrypts the ciphertext using AES-256-GCM
func (e *encryptor) Decrypt(ciphertext string) (string, error) {
log := getLogger()
if ciphertext == "" {
log.Debug("empty ciphertext provided, returning empty string")
return "", nil
}
@@ -108,5 +128,6 @@ func (e *encryptor) Decrypt(ciphertext string) (string, error) {
return "", err
}
log.Debug("data decrypted", "inputLength", len(ciphertext), "outputLength", len(plaintext))
return string(plaintext), nil
}

View File

@@ -6,6 +6,7 @@ import (
"testing"
"novamd/internal/secrets"
_ "novamd/internal/testenv"
)
func TestValidateKey(t *testing.T) {

View File

@@ -1,5 +1,4 @@
// storage/errors.go
// Package storage provides functionalities to interact with the storage system (filesystem).
package storage
import (

View File

@@ -1,5 +1,3 @@
// Package storage provides functionalities to interact with the file system,
// including listing files, finding files by name, getting file content, saving files, and deleting files.
package storage
import (
@@ -33,7 +31,12 @@ type FileNode struct {
// Workspace is identified by the given userID and workspaceID.
func (s *Service) ListFilesRecursively(userID, workspaceID int) ([]FileNode, error) {
workspacePath := s.GetWorkspacePath(userID, workspaceID)
return s.walkDirectory(workspacePath, "")
nodes, err := s.walkDirectory(workspacePath, "")
if err != nil {
return nil, err
}
return nodes, nil
}
// walkDirectory recursively walks the directory and returns a list of files and directories.
@@ -147,6 +150,8 @@ func (s *Service) GetFileContent(userID, workspaceID int, filePath string) ([]by
// SaveFile writes the content to the file at the given filePath.
// Path must be a relative path within the workspace directory given by userID and workspaceID.
func (s *Service) SaveFile(userID, workspaceID int, filePath string, content []byte) error {
log := getLogger()
fullPath, err := s.ValidatePath(userID, workspaceID, filePath)
if err != nil {
return err
@@ -157,17 +162,36 @@ func (s *Service) SaveFile(userID, workspaceID int, filePath string, content []b
return err
}
return s.fs.WriteFile(fullPath, content, 0644)
if err := s.fs.WriteFile(fullPath, content, 0644); err != nil {
return err
}
log.Debug("file saved",
"userID", userID,
"workspaceID", workspaceID,
"path", filePath,
"size", len(content))
return nil
}
// DeleteFile deletes the file at the given filePath.
// Path must be a relative path within the workspace directory given by userID and workspaceID.
func (s *Service) DeleteFile(userID, workspaceID int, filePath string) error {
log := getLogger()
fullPath, err := s.ValidatePath(userID, workspaceID, filePath)
if err != nil {
return err
}
return s.fs.Remove(fullPath)
if err := s.fs.Remove(fullPath); err != nil {
return err
}
log.Debug("file deleted",
"userID", userID,
"workspaceID", workspaceID,
"path", filePath)
return nil
}
// FileCountStats holds statistics about files in a workspace
@@ -186,13 +210,22 @@ func (s *Service) GetFileStats(userID, workspaceID int) (*FileCountStats, error)
return nil, fmt.Errorf("workspace directory does not exist")
}
return s.countFilesInPath(workspacePath)
stats, err := s.countFilesInPath(workspacePath)
if err != nil {
return nil, err
}
return stats, nil
}
// GetTotalFileStats returns the total file statistics for the storage.
func (s *Service) GetTotalFileStats() (*FileCountStats, error) {
return s.countFilesInPath(s.RootDir)
stats, err := s.countFilesInPath(s.RootDir)
if err != nil {
return nil, err
}
return stats, nil
}
// countFilesInPath counts the total number of files and the total size of files in the given directory.

View File

@@ -5,6 +5,8 @@ import (
"novamd/internal/storage"
"path/filepath"
"testing"
_ "novamd/internal/testenv"
)
// TestFileNode ensures FileNode structs are created correctly

View File

@@ -2,6 +2,7 @@ package storage
import (
"io/fs"
"novamd/internal/logging"
"os"
)
@@ -17,6 +18,15 @@ type fileSystem interface {
IsNotExist(err error) bool
}
var logger logging.Logger
func getLogger() logging.Logger {
if logger == nil {
logger = logging.WithGroup("storage")
}
return logger
}
// osFS implements the FileSystem interface using the real filesystem.
type osFS struct{}

View File

@@ -5,6 +5,8 @@ import (
"io/fs"
"path/filepath"
"time"
_ "novamd/internal/testenv"
)
type mockDirEntry struct {

View File

@@ -17,15 +17,23 @@ type RepositoryManager interface {
// The repository is cloned from the given gitURL using the given gitUser and gitToken.
func (s *Service) SetupGitRepo(userID, workspaceID int, gitURL, gitUser, gitToken, commitName, commitEmail string) error {
workspacePath := s.GetWorkspacePath(userID, workspaceID)
if _, ok := s.GitRepos[userID]; !ok {
s.GitRepos[userID] = make(map[int]git.Client)
}
s.GitRepos[userID][workspaceID] = s.newGitClient(gitURL, gitUser, gitToken, workspacePath, commitName, commitEmail)
return s.GitRepos[userID][workspaceID].EnsureRepo()
}
// DisableGitRepo disables the Git repository for the given userID and workspaceID.
func (s *Service) DisableGitRepo(userID, workspaceID int) {
log := getLogger().WithGroup("git")
log.Debug("disabling git repository",
"userID", userID,
"workspaceID", workspaceID)
if userRepos, ok := s.GitRepos[userID]; ok {
delete(userRepos, workspaceID)
if len(userRepos) == 0 {
@@ -47,8 +55,11 @@ func (s *Service) StageCommitAndPush(userID, workspaceID int, message string) (g
return git.CommitHash{}, err
}
err = repo.Push()
if err = repo.Push(); err != nil {
return hash, err
}
return hash, nil
}
// Pull pulls the changes from the remote Git repository.
@@ -59,7 +70,12 @@ func (s *Service) Pull(userID, workspaceID int) error {
return fmt.Errorf("git settings not configured for this workspace")
}
return repo.Pull()
err := repo.Pull()
if err != nil {
return err
}
return nil
}
// getGitRepo returns the Git repository for the given user and workspace IDs.

View File

@@ -6,6 +6,7 @@ import (
"novamd/internal/git"
"novamd/internal/storage"
_ "novamd/internal/testenv"
)
// MockGitClient implements git.Client interface for testing

View File

@@ -43,6 +43,11 @@ func (s *Service) GetWorkspacePath(userID, workspaceID int) string {
// InitializeUserWorkspace creates the workspace directory for the given userID and workspaceID.
func (s *Service) InitializeUserWorkspace(userID, workspaceID int) error {
log := getLogger()
log.Debug("initializing workspace directory",
"userID", userID,
"workspaceID", workspaceID)
workspacePath := s.GetWorkspacePath(userID, workspaceID)
err := s.fs.MkdirAll(workspacePath, 0755)
if err != nil {
@@ -54,6 +59,11 @@ func (s *Service) InitializeUserWorkspace(userID, workspaceID int) error {
// DeleteUserWorkspace deletes the workspace directory for the given userID and workspaceID.
func (s *Service) DeleteUserWorkspace(userID, workspaceID int) error {
log := getLogger()
log.Debug("deleting workspace directory",
"userID", userID,
"workspaceID", workspaceID)
workspacePath := s.GetWorkspacePath(userID, workspaceID)
err := s.fs.RemoveAll(workspacePath)
if err != nil {

View File

@@ -7,6 +7,7 @@ import (
"testing"
"novamd/internal/storage"
_ "novamd/internal/testenv"
)
func TestValidatePath(t *testing.T) {

View File

@@ -0,0 +1,9 @@
// Package testenv provides a setup for testing the application.
package testenv
import "novamd/internal/logging"
func init() {
// Initialize the logger
logging.Setup(logging.ERROR)
}