From 2faefb6db52bd0e6c903658c5d6e0566b90a4f63 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Thu, 21 Nov 2024 21:25:29 +0100 Subject: [PATCH] Implement JWTManager interface --- server/cmd/server/main.go | 6 ++--- server/internal/auth/jwt.go | 27 +++++++++++++++-------- server/internal/auth/middleware.go | 35 ++++++++++++++++++++++++------ server/internal/auth/session.go | 20 ++++++++--------- 4 files changed, 59 insertions(+), 29 deletions(-) diff --git a/server/cmd/server/main.go b/server/cmd/server/main.go index c6204ba..951d8b8 100644 --- a/server/cmd/server/main.go +++ b/server/cmd/server/main.go @@ -48,7 +48,7 @@ func main() { s := storage.NewService(cfg.WorkDir) // Initialize JWT service - jwtService, err := auth.NewJWTService(auth.JWTConfig{ + jwtManager, err := auth.NewJWTService(auth.JWTConfig{ SigningKey: signingKey, AccessTokenExpiry: 15 * time.Minute, RefreshTokenExpiry: 7 * 24 * time.Hour, @@ -58,10 +58,10 @@ func main() { } // Initialize auth middleware - authMiddleware := auth.NewMiddleware(jwtService) + authMiddleware := auth.NewMiddleware(jwtManager) // Initialize session service - sessionService := auth.NewSessionService(database.DB, jwtService) + sessionService := auth.NewSessionService(database.DB, jwtManager) // Set up router r := chi.NewRouter() diff --git a/server/internal/auth/jwt.go b/server/internal/auth/jwt.go index b1c0480..9b65db8 100644 --- a/server/internal/auth/jwt.go +++ b/server/internal/auth/jwt.go @@ -1,3 +1,4 @@ +// Package auth provides JWT token generation and validation package auth import ( @@ -30,14 +31,22 @@ type JWTConfig struct { RefreshTokenExpiry time.Duration // How long refresh tokens are valid } -// JWTService handles JWT token generation and validation -type JWTService struct { +// JWTManager defines the interface for managing JWT tokens +type JWTManager interface { + GenerateAccessToken(userID int, role string) (string, error) + GenerateRefreshToken(userID int, role string) (string, error) + ValidateToken(tokenString string) (*Claims, error) + RefreshAccessToken(refreshToken string) (string, error) +} + +// jwtService handles JWT token generation and validation +type jwtService struct { config JWTConfig } // NewJWTService creates a new JWT service with the provided configuration // Returns an error if the signing key is missing -func NewJWTService(config JWTConfig) (*JWTService, error) { +func NewJWTService(config JWTConfig) (JWTManager, error) { if config.SigningKey == "" { return nil, fmt.Errorf("signing key is required") } @@ -48,7 +57,7 @@ func NewJWTService(config JWTConfig) (*JWTService, error) { if config.RefreshTokenExpiry == 0 { config.RefreshTokenExpiry = 7 * 24 * time.Hour // Default to 7 days } - return &JWTService{config: config}, nil + return &jwtService{config: config}, nil } // GenerateAccessToken creates a new access token for a user @@ -56,7 +65,7 @@ func NewJWTService(config JWTConfig) (*JWTService, error) { // - userID: the ID of the user // - role: the role of the user // Returns the signed token string or an error -func (s *JWTService) GenerateAccessToken(userID int, role string) (string, error) { +func (s *jwtService) GenerateAccessToken(userID int, role string) (string, error) { return s.generateToken(userID, role, AccessToken, s.config.AccessTokenExpiry) } @@ -65,7 +74,7 @@ func (s *JWTService) GenerateAccessToken(userID int, role string) (string, error // - userID: the ID of the user // - role: the role of the user // Returns the signed token string or an error -func (s *JWTService) GenerateRefreshToken(userID int, role string) (string, error) { +func (s *jwtService) GenerateRefreshToken(userID int, role string) (string, error) { return s.generateToken(userID, role, RefreshToken, s.config.RefreshTokenExpiry) } @@ -76,7 +85,7 @@ func (s *JWTService) GenerateRefreshToken(userID int, role string) (string, erro // - tokenType: the type of token (access or refresh) // - expiry: how long the token should be valid // Returns the signed token string or an error -func (s *JWTService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) { +func (s *jwtService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) { now := time.Now() claims := Claims{ RegisteredClaims: jwt.RegisteredClaims{ @@ -97,7 +106,7 @@ func (s *JWTService) generateToken(userID int, role string, tokenType TokenType, // Parameters: // - tokenString: the token to validate // Returns the token claims if valid, or an error if invalid -func (s *JWTService) ValidateToken(tokenString string) (*Claims, error) { +func (s *jwtService) ValidateToken(tokenString string) (*Claims, error) { token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { // Validate the signing method if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { @@ -121,7 +130,7 @@ func (s *JWTService) ValidateToken(tokenString string) (*Claims, error) { // Parameters: // - refreshToken: the refresh token to use // Returns a new access token if the refresh token is valid, or an error -func (s *JWTService) RefreshAccessToken(refreshToken string) (string, error) { +func (s *jwtService) RefreshAccessToken(refreshToken string) (string, error) { claims, err := s.ValidateToken(refreshToken) if err != nil { return "", fmt.Errorf("invalid refresh token: %w", err) diff --git a/server/internal/auth/middleware.go b/server/internal/auth/middleware.go index da2713d..864ebd3 100644 --- a/server/internal/auth/middleware.go +++ b/server/internal/auth/middleware.go @@ -11,9 +11,8 @@ import ( type contextKey string -const ( - UserContextKey contextKey = "user" -) +// UserContextKey is the key used to store user claims in the request context +const UserContextKey contextKey = "user" // UserClaims represents the user information stored in the request context type UserClaims struct { @@ -23,17 +22,25 @@ type UserClaims struct { // Middleware handles JWT authentication for protected routes type Middleware struct { - jwtService *JWTService + jwtManager JWTManager } // NewMiddleware creates a new authentication middleware -func NewMiddleware(jwtService *JWTService) *Middleware { +// Parameters: +// - jwtManager: the JWT manager to use for token operations +// Returns: +// - *Middleware: the new middleware instance +func NewMiddleware(jwtManager JWTManager) *Middleware { return &Middleware{ - jwtService: jwtService, + jwtManager: jwtManager, } } // Authenticate middleware validates JWT tokens and sets user information in context +// Parameters: +// - next: the next handler to call +// Returns: +// - http.Handler: the handler function func (m *Middleware) Authenticate(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Extract token from Authorization header @@ -51,7 +58,7 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler { } // Validate token - claims, err := m.jwtService.ValidateToken(parts[1]) + claims, err := m.jwtManager.ValidateToken(parts[1]) if err != nil { http.Error(w, "Invalid token", http.StatusUnauthorized) return @@ -75,6 +82,10 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler { } // RequireRole returns a middleware that ensures the user has the required role +// Parameters: +// - role: the required role +// Returns: +// - func(http.Handler) http.Handler: the middleware function func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -94,6 +105,11 @@ func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler { } } +// RequireWorkspaceAccess returns a middleware that ensures the user has access to the workspace +// Parameters: +// - next: the next handler to call +// Returns: +// - http.Handler: the handler function func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Get our handler context @@ -119,6 +135,11 @@ func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler { } // GetUserFromContext retrieves user claims from the request context +// Parameters: +// - ctx: the request context +// Returns: +// - *UserClaims: the user claims +// - error: any error that occurred func GetUserFromContext(ctx context.Context) (*UserClaims, error) { claims, ok := ctx.Value(UserContextKey).(UserClaims) if !ok { diff --git a/server/internal/auth/session.go b/server/internal/auth/session.go index 8168ccc..bc0d165 100644 --- a/server/internal/auth/session.go +++ b/server/internal/auth/session.go @@ -19,18 +19,18 @@ type Session struct { // SessionService manages user sessions in the database type SessionService struct { - db *sql.DB // Database connection - jwtService *JWTService // JWT service for token operations + db *sql.DB // Database connection + jwtManager JWTManager // JWT Manager for token operations } // NewSessionService creates a new session service // Parameters: // - db: database connection -// - jwtService: JWT service for token operations -func NewSessionService(db *sql.DB, jwtService *JWTService) *SessionService { +// - jwtManager: JWT service for token operations +func NewSessionService(db *sql.DB, jwtManager JWTManager) *SessionService { return &SessionService{ db: db, - jwtService: jwtService, + jwtManager: jwtManager, } } @@ -44,18 +44,18 @@ func NewSessionService(db *sql.DB, jwtService *JWTService) *SessionService { // - error: any error that occurred func (s *SessionService) CreateSession(userID int, role string) (*Session, string, error) { // Generate both access and refresh tokens - accessToken, err := s.jwtService.GenerateAccessToken(userID, role) + accessToken, err := s.jwtManager.GenerateAccessToken(userID, role) if err != nil { return nil, "", fmt.Errorf("failed to generate access token: %w", err) } - refreshToken, err := s.jwtService.GenerateRefreshToken(userID, role) + refreshToken, err := s.jwtManager.GenerateRefreshToken(userID, role) if err != nil { return nil, "", fmt.Errorf("failed to generate refresh token: %w", err) } // Validate the refresh token to get its expiry time - claims, err := s.jwtService.ValidateToken(refreshToken) + claims, err := s.jwtManager.ValidateToken(refreshToken) if err != nil { return nil, "", fmt.Errorf("failed to validate refresh token: %w", err) } @@ -90,7 +90,7 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin // - error: any error that occurred func (s *SessionService) RefreshSession(refreshToken string) (string, error) { // Validate the refresh token - claims, err := s.jwtService.ValidateToken(refreshToken) + claims, err := s.jwtManager.ValidateToken(refreshToken) if err != nil { return "", fmt.Errorf("invalid refresh token: %w", err) } @@ -112,7 +112,7 @@ func (s *SessionService) RefreshSession(refreshToken string) (string, error) { } // Generate a new access token - return s.jwtService.GenerateAccessToken(claims.UserID, claims.Role) + return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role) } // InvalidateSession removes a session from the database