Add logging to auth package

This commit is contained in:
2024-12-15 18:03:04 +01:00
parent d6680d8e03
commit 3edce8a0b9
5 changed files with 172 additions and 13 deletions

View File

@@ -3,8 +3,22 @@ package auth
import ( import (
"net/http" "net/http"
"novamd/internal/logging"
) )
var logger logging.Logger
func getAuthLogger() logging.Logger {
if logger == nil {
logger = logging.WithGroup("auth")
}
return logger
}
func getCookieLogger() logging.Logger {
return getAuthLogger().WithGroup("cookie")
}
// CookieManager interface defines methods for generating cookies // CookieManager interface defines methods for generating cookies
type CookieManager interface { type CookieManager interface {
GenerateAccessTokenCookie(token string) *http.Cookie GenerateAccessTokenCookie(token string) *http.Cookie
@@ -22,6 +36,8 @@ type cookieManager struct {
// NewCookieService creates a new cookie service // NewCookieService creates a new cookie service
func NewCookieService(isDevelopment bool, domain string) CookieManager { func NewCookieService(isDevelopment bool, domain string) CookieManager {
log := getCookieLogger()
secure := !isDevelopment secure := !isDevelopment
var sameSite http.SameSite var sameSite http.SameSite
@@ -31,6 +47,11 @@ func NewCookieService(isDevelopment bool, domain string) CookieManager {
sameSite = http.SameSiteStrictMode sameSite = http.SameSiteStrictMode
} }
log.Debug("creating cookie service",
"secure", secure,
"sameSite", sameSite,
"domain", domain)
return &cookieManager{ return &cookieManager{
Domain: domain, Domain: domain,
Secure: secure, Secure: secure,
@@ -40,6 +61,12 @@ func NewCookieService(isDevelopment bool, domain string) CookieManager {
// GenerateAccessTokenCookie creates a new cookie for the access token // GenerateAccessTokenCookie creates a new cookie for the access token
func (c *cookieManager) GenerateAccessTokenCookie(token string) *http.Cookie { func (c *cookieManager) GenerateAccessTokenCookie(token string) *http.Cookie {
log := getCookieLogger()
log.Debug("generating access token cookie",
"secure", c.Secure,
"sameSite", c.SameSite,
"maxAge", 900)
return &http.Cookie{ return &http.Cookie{
Name: "access_token", Name: "access_token",
Value: token, Value: token,
@@ -53,6 +80,12 @@ func (c *cookieManager) GenerateAccessTokenCookie(token string) *http.Cookie {
// GenerateRefreshTokenCookie creates a new cookie for the refresh token // GenerateRefreshTokenCookie creates a new cookie for the refresh token
func (c *cookieManager) GenerateRefreshTokenCookie(token string) *http.Cookie { func (c *cookieManager) GenerateRefreshTokenCookie(token string) *http.Cookie {
log := getCookieLogger()
log.Debug("generating refresh token cookie",
"secure", c.Secure,
"sameSite", c.SameSite,
"maxAge", 604800)
return &http.Cookie{ return &http.Cookie{
Name: "refresh_token", Name: "refresh_token",
Value: token, Value: token,
@@ -66,6 +99,13 @@ func (c *cookieManager) GenerateRefreshTokenCookie(token string) *http.Cookie {
// GenerateCSRFCookie creates a new cookie for the CSRF token // GenerateCSRFCookie creates a new cookie for the CSRF token
func (c *cookieManager) GenerateCSRFCookie(token string) *http.Cookie { func (c *cookieManager) GenerateCSRFCookie(token string) *http.Cookie {
log := getCookieLogger()
log.Debug("generating CSRF cookie",
"secure", c.Secure,
"sameSite", c.SameSite,
"maxAge", 900,
"httpOnly", false)
return &http.Cookie{ return &http.Cookie{
Name: "csrf_token", Name: "csrf_token",
Value: token, Value: token,
@@ -79,6 +119,12 @@ func (c *cookieManager) GenerateCSRFCookie(token string) *http.Cookie {
// InvalidateCookie creates a new cookie with a MaxAge of -1 to invalidate the cookie // InvalidateCookie creates a new cookie with a MaxAge of -1 to invalidate the cookie
func (c *cookieManager) InvalidateCookie(cookieType string) *http.Cookie { func (c *cookieManager) InvalidateCookie(cookieType string) *http.Cookie {
log := getCookieLogger()
log.Debug("invalidating cookie",
"type", cookieType,
"secure", c.Secure,
"sameSite", c.SameSite)
return &http.Cookie{ return &http.Cookie{
Name: cookieType, Name: cookieType,
Value: "", Value: "",

View File

@@ -4,11 +4,16 @@ package auth
import ( import (
"crypto/rand" "crypto/rand"
"fmt" "fmt"
"novamd/internal/logging"
"time" "time"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
) )
func getJWTLogger() logging.Logger {
return getAuthLogger().WithGroup("jwt")
}
// TokenType represents the type of JWT token (access or refresh) // TokenType represents the type of JWT token (access or refresh)
type TokenType string type TokenType string
@@ -47,16 +52,26 @@ type jwtService struct {
// NewJWTService creates a new JWT service with the provided configuration // NewJWTService creates a new JWT service with the provided configuration
// Returns an error if the signing key is missing // Returns an error if the signing key is missing
func NewJWTService(config JWTConfig) (JWTManager, error) { func NewJWTService(config JWTConfig) (JWTManager, error) {
log := getJWTLogger()
if config.SigningKey == "" { if config.SigningKey == "" {
return nil, fmt.Errorf("signing key is required") return nil, fmt.Errorf("signing key is required")
} }
// Set default expiry times if not provided // Set default expiry times if not provided
if config.AccessTokenExpiry == 0 { if config.AccessTokenExpiry == 0 {
config.AccessTokenExpiry = 15 * time.Minute // Default to 15 minutes config.AccessTokenExpiry = 15 * time.Minute
log.Debug("using default access token expiry", "expiry", config.AccessTokenExpiry)
} }
if config.RefreshTokenExpiry == 0 { if config.RefreshTokenExpiry == 0 {
config.RefreshTokenExpiry = 7 * 24 * time.Hour // Default to 7 days config.RefreshTokenExpiry = 7 * 24 * time.Hour
log.Debug("using default refresh token expiry", "expiry", config.RefreshTokenExpiry)
} }
log.Info("initialized JWT service",
"accessExpiry", config.AccessTokenExpiry,
"refreshExpiry", config.RefreshTokenExpiry)
return &jwtService{config: config}, nil return &jwtService{config: config}, nil
} }
@@ -72,6 +87,7 @@ func (s *jwtService) GenerateRefreshToken(userID int, role, sessionID string) (s
// generateToken is an internal helper function that creates a new JWT token // generateToken is an internal helper function that creates a new JWT token
func (s *jwtService) generateToken(userID int, role string, sessionID string, tokenType TokenType, expiry time.Duration) (string, error) { func (s *jwtService) generateToken(userID int, role string, sessionID string, tokenType TokenType, expiry time.Duration) (string, error) {
log := getJWTLogger()
now := time.Now() now := time.Now()
// Add a random nonce to ensure uniqueness // Add a random nonce to ensure uniqueness
@@ -93,11 +109,24 @@ func (s *jwtService) generateToken(userID int, role string, sessionID string, to
} }
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(s.config.SigningKey)) signedToken, err := token.SignedString([]byte(s.config.SigningKey))
if err != nil {
return "", err
}
log.Debug("generated JWT token",
"userId", userID,
"role", role,
"tokenType", tokenType,
"expiresAt", claims.ExpiresAt)
return signedToken, nil
} }
// ValidateToken validates and parses a JWT token // ValidateToken validates and parses a JWT token
func (s *jwtService) ValidateToken(tokenString string) (*Claims, error) { func (s *jwtService) ValidateToken(tokenString string) (*Claims, error) {
log := getJWTLogger()
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
// Validate the signing method // Validate the signing method
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
@@ -110,9 +139,16 @@ func (s *jwtService) ValidateToken(tokenString string) (*Claims, error) {
return nil, fmt.Errorf("invalid token: %w", err) return nil, fmt.Errorf("invalid token: %w", err)
} }
if claims, ok := token.Claims.(*Claims); ok && token.Valid { claims, ok := token.Claims.(*Claims)
return claims, nil if !ok || !token.Valid {
return nil, fmt.Errorf("invalid token claims")
} }
return nil, fmt.Errorf("invalid token claims") log.Debug("token validated",
"userId", claims.UserID,
"role", claims.Role,
"tokenType", claims.Type,
"expiresAt", claims.ExpiresAt)
return claims, nil
} }

View File

@@ -8,8 +8,6 @@ import (
"novamd/internal/auth" "novamd/internal/auth"
) )
// jwt_test.go tests
func TestNewJWTService(t *testing.T) { func TestNewJWTService(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string

View File

@@ -3,10 +3,14 @@ package auth
import ( import (
"crypto/subtle" "crypto/subtle"
"net/http" "net/http"
"novamd/internal/context" "novamd/internal/context"
"novamd/internal/logging"
) )
func getMiddlewareLogger() logging.Logger {
return getAuthLogger().WithGroup("middleware")
}
// Middleware handles JWT authentication for protected routes // Middleware handles JWT authentication for protected routes
type Middleware struct { type Middleware struct {
jwtManager JWTManager jwtManager JWTManager
@@ -16,6 +20,9 @@ type Middleware struct {
// NewMiddleware creates a new authentication middleware // NewMiddleware creates a new authentication middleware
func NewMiddleware(jwtManager JWTManager, sessionManager SessionManager, cookieManager CookieManager) *Middleware { func NewMiddleware(jwtManager JWTManager, sessionManager SessionManager, cookieManager CookieManager) *Middleware {
log := getMiddlewareLogger()
log.Info("initialized auth middleware")
return &Middleware{ return &Middleware{
jwtManager: jwtManager, jwtManager: jwtManager,
sessionManager: sessionManager, sessionManager: sessionManager,
@@ -26,7 +33,9 @@ func NewMiddleware(jwtManager JWTManager, sessionManager SessionManager, cookieM
// Authenticate middleware validates JWT tokens and sets user information in context // Authenticate middleware validates JWT tokens and sets user information in context
func (m *Middleware) Authenticate(next http.Handler) http.Handler { func (m *Middleware) Authenticate(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Extract token from Authorization header log := getMiddlewareLogger()
// Extract token from cookie
cookie, err := r.Cookie("access_token") cookie, err := r.Cookie("access_token")
if err != nil { if err != nil {
http.Error(w, "Unauthorized", http.StatusUnauthorized) http.Error(w, "Unauthorized", http.StatusUnauthorized)
@@ -82,6 +91,12 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
UserRole: claims.Role, UserRole: claims.Role,
} }
log.Debug("authentication completed",
"userId", claims.UserID,
"role", claims.Role,
"method", r.Method,
"path", r.URL.Path)
// Add context to request and continue // Add context to request and continue
next.ServeHTTP(w, context.WithHandlerContext(r, hctx)) next.ServeHTTP(w, context.WithHandlerContext(r, hctx))
}) })
@@ -91,6 +106,8 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler { func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log := getMiddlewareLogger()
ctx, ok := context.GetRequestContext(w, r) ctx, ok := context.GetRequestContext(w, r)
if !ok { if !ok {
return return
@@ -101,6 +118,11 @@ func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler {
return return
} }
log.Debug("role requirement satisfied",
"requiredRole", role,
"userRole", ctx.UserRole,
"path", r.URL.Path)
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
} }
@@ -109,6 +131,8 @@ func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler {
// RequireWorkspaceAccess returns a middleware that ensures the user has access to the workspace // RequireWorkspaceAccess returns a middleware that ensures the user has access to the workspace
func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler { func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log := getMiddlewareLogger()
ctx, ok := context.GetRequestContext(w, r) ctx, ok := context.GetRequestContext(w, r)
if !ok { if !ok {
return return
@@ -126,6 +150,11 @@ func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler {
return return
} }
log.Debug("workspace access granted",
"userId", ctx.UserID,
"workspaceId", ctx.Workspace.ID,
"path", r.URL.Path)
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
} }

View File

@@ -3,12 +3,17 @@ package auth
import ( import (
"fmt" "fmt"
"novamd/internal/db" "novamd/internal/db"
"novamd/internal/logging"
"novamd/internal/models" "novamd/internal/models"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
) )
func getSessionLogger() logging.Logger {
return getAuthLogger().WithGroup("session")
}
// SessionManager is an interface for managing user sessions // SessionManager is an interface for managing user sessions
type SessionManager interface { type SessionManager interface {
CreateSession(userID int, role string) (*models.Session, string, error) CreateSession(userID int, role string) (*models.Session, string, error)
@@ -27,6 +32,9 @@ type sessionManager struct {
// NewSessionService creates a new session service with the given database and JWT manager // NewSessionService creates a new session service with the given database and JWT manager
// revive:disable:unexported-return // revive:disable:unexported-return
func NewSessionService(db db.SessionStore, jwtManager JWTManager) *sessionManager { func NewSessionService(db db.SessionStore, jwtManager JWTManager) *sessionManager {
log := getSessionLogger()
log.Info("initialized session manager")
return &sessionManager{ return &sessionManager{
db: db, db: db,
jwtManager: jwtManager, jwtManager: jwtManager,
@@ -35,6 +43,7 @@ func NewSessionService(db db.SessionStore, jwtManager JWTManager) *sessionManage
// CreateSession creates a new user session for a user with the given userID and role // CreateSession creates a new user session for a user with the given userID and role
func (s *sessionManager) CreateSession(userID int, role string) (*models.Session, string, error) { func (s *sessionManager) CreateSession(userID int, role string) (*models.Session, string, error) {
log := getSessionLogger()
// Generate a new session ID // Generate a new session ID
sessionID := uuid.New().String() sessionID := uuid.New().String()
@@ -70,11 +79,19 @@ func (s *sessionManager) CreateSession(userID int, role string) (*models.Session
return nil, "", err return nil, "", err
} }
log.Debug("created new session",
"userId", userID,
"role", role,
"sessionId", sessionID,
"expiresAt", claims.ExpiresAt.Time)
return session, accessToken, nil return session, accessToken, nil
} }
// RefreshSession creates a new access token using a refreshToken // RefreshSession creates a new access token using a refreshToken
func (s *sessionManager) RefreshSession(refreshToken string) (string, error) { func (s *sessionManager) RefreshSession(refreshToken string) (string, error) {
log := getSessionLogger()
// Get session from database first // Get session from database first
session, err := s.db.GetSessionByRefreshToken(refreshToken) session, err := s.db.GetSessionByRefreshToken(refreshToken)
if err != nil { if err != nil {
@@ -93,11 +110,22 @@ func (s *sessionManager) RefreshSession(refreshToken string) (string, error) {
} }
// Generate a new access token // Generate a new access token
return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role, session.ID) newToken, err := s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role, session.ID)
if err != nil {
return "", err
}
log.Debug("refreshed session",
"userId", claims.UserID,
"role", claims.Role,
"sessionId", session.ID)
return newToken, nil
} }
// ValidateSession checks if a session with the given sessionID is valid // ValidateSession checks if a session with the given sessionID is valid
func (s *sessionManager) ValidateSession(sessionID string) (*models.Session, error) { func (s *sessionManager) ValidateSession(sessionID string) (*models.Session, error) {
log := getSessionLogger()
// Get the session from the database // Get the session from the database
session, err := s.db.GetSessionByID(sessionID) session, err := s.db.GetSessionByID(sessionID)
@@ -105,21 +133,43 @@ func (s *sessionManager) ValidateSession(sessionID string) (*models.Session, err
return nil, fmt.Errorf("failed to get session: %w", err) return nil, fmt.Errorf("failed to get session: %w", err)
} }
log.Debug("validated session",
"sessionId", sessionID,
"userId", session.UserID,
"expiresAt", session.ExpiresAt)
return session, nil return session, nil
} }
// InvalidateSession removes a session with the given sessionID from the database // InvalidateSession removes a session with the given sessionID from the database
func (s *sessionManager) InvalidateSession(token string) error { func (s *sessionManager) InvalidateSession(token string) error {
log := getSessionLogger()
// Parse the JWT to get the session info // Parse the JWT to get the session info
claims, err := s.jwtManager.ValidateToken(token) claims, err := s.jwtManager.ValidateToken(token)
if err != nil { if err != nil {
return fmt.Errorf("invalid token: %w", err) return fmt.Errorf("invalid token: %w", err)
} }
return s.db.DeleteSession(claims.ID) if err := s.db.DeleteSession(claims.ID); err != nil {
return err
}
log.Debug("invalidated session",
"sessionId", claims.ID,
"userId", claims.UserID)
return nil
} }
// CleanExpiredSessions removes all expired sessions from the database // CleanExpiredSessions removes all expired sessions from the database
func (s *sessionManager) CleanExpiredSessions() error { func (s *sessionManager) CleanExpiredSessions() error {
return s.db.CleanExpiredSessions() log := getSessionLogger()
if err := s.db.CleanExpiredSessions(); err != nil {
return err
}
log.Info("cleaned expired sessions")
return nil
} }