Implement JWTManager interface

This commit is contained in:
2024-11-21 21:25:29 +01:00
parent 435dce89d9
commit 2faefb6db5
4 changed files with 59 additions and 29 deletions

View File

@@ -48,7 +48,7 @@ func main() {
s := storage.NewService(cfg.WorkDir) s := storage.NewService(cfg.WorkDir)
// Initialize JWT service // Initialize JWT service
jwtService, err := auth.NewJWTService(auth.JWTConfig{ jwtManager, err := auth.NewJWTService(auth.JWTConfig{
SigningKey: signingKey, SigningKey: signingKey,
AccessTokenExpiry: 15 * time.Minute, AccessTokenExpiry: 15 * time.Minute,
RefreshTokenExpiry: 7 * 24 * time.Hour, RefreshTokenExpiry: 7 * 24 * time.Hour,
@@ -58,10 +58,10 @@ func main() {
} }
// Initialize auth middleware // Initialize auth middleware
authMiddleware := auth.NewMiddleware(jwtService) authMiddleware := auth.NewMiddleware(jwtManager)
// Initialize session service // Initialize session service
sessionService := auth.NewSessionService(database.DB, jwtService) sessionService := auth.NewSessionService(database.DB, jwtManager)
// Set up router // Set up router
r := chi.NewRouter() r := chi.NewRouter()

View File

@@ -1,3 +1,4 @@
// Package auth provides JWT token generation and validation
package auth package auth
import ( import (
@@ -30,14 +31,22 @@ type JWTConfig struct {
RefreshTokenExpiry time.Duration // How long refresh tokens are valid RefreshTokenExpiry time.Duration // How long refresh tokens are valid
} }
// JWTService handles JWT token generation and validation // JWTManager defines the interface for managing JWT tokens
type JWTService struct { type JWTManager interface {
GenerateAccessToken(userID int, role string) (string, error)
GenerateRefreshToken(userID int, role string) (string, error)
ValidateToken(tokenString string) (*Claims, error)
RefreshAccessToken(refreshToken string) (string, error)
}
// jwtService handles JWT token generation and validation
type jwtService struct {
config JWTConfig config JWTConfig
} }
// NewJWTService creates a new JWT service with the provided configuration // NewJWTService creates a new JWT service with the provided configuration
// Returns an error if the signing key is missing // Returns an error if the signing key is missing
func NewJWTService(config JWTConfig) (*JWTService, error) { func NewJWTService(config JWTConfig) (JWTManager, error) {
if config.SigningKey == "" { if config.SigningKey == "" {
return nil, fmt.Errorf("signing key is required") return nil, fmt.Errorf("signing key is required")
} }
@@ -48,7 +57,7 @@ func NewJWTService(config JWTConfig) (*JWTService, error) {
if config.RefreshTokenExpiry == 0 { if config.RefreshTokenExpiry == 0 {
config.RefreshTokenExpiry = 7 * 24 * time.Hour // Default to 7 days config.RefreshTokenExpiry = 7 * 24 * time.Hour // Default to 7 days
} }
return &JWTService{config: config}, nil return &jwtService{config: config}, nil
} }
// GenerateAccessToken creates a new access token for a user // GenerateAccessToken creates a new access token for a user
@@ -56,7 +65,7 @@ func NewJWTService(config JWTConfig) (*JWTService, error) {
// - userID: the ID of the user // - userID: the ID of the user
// - role: the role of the user // - role: the role of the user
// Returns the signed token string or an error // Returns the signed token string or an error
func (s *JWTService) GenerateAccessToken(userID int, role string) (string, error) { func (s *jwtService) GenerateAccessToken(userID int, role string) (string, error) {
return s.generateToken(userID, role, AccessToken, s.config.AccessTokenExpiry) return s.generateToken(userID, role, AccessToken, s.config.AccessTokenExpiry)
} }
@@ -65,7 +74,7 @@ func (s *JWTService) GenerateAccessToken(userID int, role string) (string, error
// - userID: the ID of the user // - userID: the ID of the user
// - role: the role of the user // - role: the role of the user
// Returns the signed token string or an error // Returns the signed token string or an error
func (s *JWTService) GenerateRefreshToken(userID int, role string) (string, error) { func (s *jwtService) GenerateRefreshToken(userID int, role string) (string, error) {
return s.generateToken(userID, role, RefreshToken, s.config.RefreshTokenExpiry) return s.generateToken(userID, role, RefreshToken, s.config.RefreshTokenExpiry)
} }
@@ -76,7 +85,7 @@ func (s *JWTService) GenerateRefreshToken(userID int, role string) (string, erro
// - tokenType: the type of token (access or refresh) // - tokenType: the type of token (access or refresh)
// - expiry: how long the token should be valid // - expiry: how long the token should be valid
// Returns the signed token string or an error // Returns the signed token string or an error
func (s *JWTService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) { func (s *jwtService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) {
now := time.Now() now := time.Now()
claims := Claims{ claims := Claims{
RegisteredClaims: jwt.RegisteredClaims{ RegisteredClaims: jwt.RegisteredClaims{
@@ -97,7 +106,7 @@ func (s *JWTService) generateToken(userID int, role string, tokenType TokenType,
// Parameters: // Parameters:
// - tokenString: the token to validate // - tokenString: the token to validate
// Returns the token claims if valid, or an error if invalid // Returns the token claims if valid, or an error if invalid
func (s *JWTService) ValidateToken(tokenString string) (*Claims, error) { func (s *jwtService) ValidateToken(tokenString string) (*Claims, error) {
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
// Validate the signing method // Validate the signing method
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
@@ -121,7 +130,7 @@ func (s *JWTService) ValidateToken(tokenString string) (*Claims, error) {
// Parameters: // Parameters:
// - refreshToken: the refresh token to use // - refreshToken: the refresh token to use
// Returns a new access token if the refresh token is valid, or an error // Returns a new access token if the refresh token is valid, or an error
func (s *JWTService) RefreshAccessToken(refreshToken string) (string, error) { func (s *jwtService) RefreshAccessToken(refreshToken string) (string, error) {
claims, err := s.ValidateToken(refreshToken) claims, err := s.ValidateToken(refreshToken)
if err != nil { if err != nil {
return "", fmt.Errorf("invalid refresh token: %w", err) return "", fmt.Errorf("invalid refresh token: %w", err)

View File

@@ -11,9 +11,8 @@ import (
type contextKey string type contextKey string
const ( // UserContextKey is the key used to store user claims in the request context
UserContextKey contextKey = "user" const UserContextKey contextKey = "user"
)
// UserClaims represents the user information stored in the request context // UserClaims represents the user information stored in the request context
type UserClaims struct { type UserClaims struct {
@@ -23,17 +22,25 @@ type UserClaims struct {
// Middleware handles JWT authentication for protected routes // Middleware handles JWT authentication for protected routes
type Middleware struct { type Middleware struct {
jwtService *JWTService jwtManager JWTManager
} }
// NewMiddleware creates a new authentication middleware // NewMiddleware creates a new authentication middleware
func NewMiddleware(jwtService *JWTService) *Middleware { // Parameters:
// - jwtManager: the JWT manager to use for token operations
// Returns:
// - *Middleware: the new middleware instance
func NewMiddleware(jwtManager JWTManager) *Middleware {
return &Middleware{ return &Middleware{
jwtService: jwtService, jwtManager: jwtManager,
} }
} }
// Authenticate middleware validates JWT tokens and sets user information in context // Authenticate middleware validates JWT tokens and sets user information in context
// Parameters:
// - next: the next handler to call
// Returns:
// - http.Handler: the handler function
func (m *Middleware) Authenticate(next http.Handler) http.Handler { func (m *Middleware) Authenticate(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Extract token from Authorization header // Extract token from Authorization header
@@ -51,7 +58,7 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
} }
// Validate token // Validate token
claims, err := m.jwtService.ValidateToken(parts[1]) claims, err := m.jwtManager.ValidateToken(parts[1])
if err != nil { if err != nil {
http.Error(w, "Invalid token", http.StatusUnauthorized) http.Error(w, "Invalid token", http.StatusUnauthorized)
return return
@@ -75,6 +82,10 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
} }
// RequireRole returns a middleware that ensures the user has the required role // RequireRole returns a middleware that ensures the user has the required role
// Parameters:
// - role: the required role
// Returns:
// - func(http.Handler) http.Handler: the middleware function
func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler { func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -94,6 +105,11 @@ func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler {
} }
} }
// RequireWorkspaceAccess returns a middleware that ensures the user has access to the workspace
// Parameters:
// - next: the next handler to call
// Returns:
// - http.Handler: the handler function
func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler { func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get our handler context // Get our handler context
@@ -119,6 +135,11 @@ func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler {
} }
// GetUserFromContext retrieves user claims from the request context // GetUserFromContext retrieves user claims from the request context
// Parameters:
// - ctx: the request context
// Returns:
// - *UserClaims: the user claims
// - error: any error that occurred
func GetUserFromContext(ctx context.Context) (*UserClaims, error) { func GetUserFromContext(ctx context.Context) (*UserClaims, error) {
claims, ok := ctx.Value(UserContextKey).(UserClaims) claims, ok := ctx.Value(UserContextKey).(UserClaims)
if !ok { if !ok {

View File

@@ -19,18 +19,18 @@ type Session struct {
// SessionService manages user sessions in the database // SessionService manages user sessions in the database
type SessionService struct { type SessionService struct {
db *sql.DB // Database connection db *sql.DB // Database connection
jwtService *JWTService // JWT service for token operations jwtManager JWTManager // JWT Manager for token operations
} }
// NewSessionService creates a new session service // NewSessionService creates a new session service
// Parameters: // Parameters:
// - db: database connection // - db: database connection
// - jwtService: JWT service for token operations // - jwtManager: JWT service for token operations
func NewSessionService(db *sql.DB, jwtService *JWTService) *SessionService { func NewSessionService(db *sql.DB, jwtManager JWTManager) *SessionService {
return &SessionService{ return &SessionService{
db: db, db: db,
jwtService: jwtService, jwtManager: jwtManager,
} }
} }
@@ -44,18 +44,18 @@ func NewSessionService(db *sql.DB, jwtService *JWTService) *SessionService {
// - error: any error that occurred // - error: any error that occurred
func (s *SessionService) CreateSession(userID int, role string) (*Session, string, error) { func (s *SessionService) CreateSession(userID int, role string) (*Session, string, error) {
// Generate both access and refresh tokens // Generate both access and refresh tokens
accessToken, err := s.jwtService.GenerateAccessToken(userID, role) accessToken, err := s.jwtManager.GenerateAccessToken(userID, role)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("failed to generate access token: %w", err) return nil, "", fmt.Errorf("failed to generate access token: %w", err)
} }
refreshToken, err := s.jwtService.GenerateRefreshToken(userID, role) refreshToken, err := s.jwtManager.GenerateRefreshToken(userID, role)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("failed to generate refresh token: %w", err) return nil, "", fmt.Errorf("failed to generate refresh token: %w", err)
} }
// Validate the refresh token to get its expiry time // Validate the refresh token to get its expiry time
claims, err := s.jwtService.ValidateToken(refreshToken) claims, err := s.jwtManager.ValidateToken(refreshToken)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("failed to validate refresh token: %w", err) return nil, "", fmt.Errorf("failed to validate refresh token: %w", err)
} }
@@ -90,7 +90,7 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin
// - error: any error that occurred // - error: any error that occurred
func (s *SessionService) RefreshSession(refreshToken string) (string, error) { func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
// Validate the refresh token // Validate the refresh token
claims, err := s.jwtService.ValidateToken(refreshToken) claims, err := s.jwtManager.ValidateToken(refreshToken)
if err != nil { if err != nil {
return "", fmt.Errorf("invalid refresh token: %w", err) return "", fmt.Errorf("invalid refresh token: %w", err)
} }
@@ -112,7 +112,7 @@ func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
} }
// Generate a new access token // Generate a new access token
return s.jwtService.GenerateAccessToken(claims.UserID, claims.Role) return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role)
} }
// InvalidateSession removes a session from the database // InvalidateSession removes a session from the database