diff --git a/backend/internal/auth/jwt.go b/backend/internal/auth/jwt.go new file mode 100644 index 0000000..b1c0480 --- /dev/null +++ b/backend/internal/auth/jwt.go @@ -0,0 +1,135 @@ +package auth + +import ( + "fmt" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +// TokenType represents the type of JWT token (access or refresh) +type TokenType string + +const ( + AccessToken TokenType = "access" // AccessToken - Short-lived token for API access + RefreshToken TokenType = "refresh" // RefreshToken - Long-lived token for obtaining new access tokens +) + +// Claims represents the custom claims we store in JWT tokens +type Claims struct { + jwt.RegisteredClaims // Embedded standard JWT claims + UserID int `json:"uid"` // User identifier + Role string `json:"role"` // User role (admin, editor, viewer) + Type TokenType `json:"type"` // Token type (access or refresh) +} + +// JWTConfig holds the configuration for the JWT service +type JWTConfig struct { + SigningKey string // Secret key used to sign tokens + AccessTokenExpiry time.Duration // How long access tokens are valid + RefreshTokenExpiry time.Duration // How long refresh tokens are valid +} + +// JWTService handles JWT token generation and validation +type JWTService struct { + config JWTConfig +} + +// NewJWTService creates a new JWT service with the provided configuration +// Returns an error if the signing key is missing +func NewJWTService(config JWTConfig) (*JWTService, error) { + if config.SigningKey == "" { + return nil, fmt.Errorf("signing key is required") + } + // Set default expiry times if not provided + if config.AccessTokenExpiry == 0 { + config.AccessTokenExpiry = 15 * time.Minute // Default to 15 minutes + } + if config.RefreshTokenExpiry == 0 { + config.RefreshTokenExpiry = 7 * 24 * time.Hour // Default to 7 days + } + return &JWTService{config: config}, nil +} + +// GenerateAccessToken creates a new access token for a user +// Parameters: +// - userID: the ID of the user +// - role: the role of the user +// Returns the signed token string or an error +func (s *JWTService) GenerateAccessToken(userID int, role string) (string, error) { + return s.generateToken(userID, role, AccessToken, s.config.AccessTokenExpiry) +} + +// GenerateRefreshToken creates a new refresh token for a user +// Parameters: +// - userID: the ID of the user +// - role: the role of the user +// Returns the signed token string or an error +func (s *JWTService) GenerateRefreshToken(userID int, role string) (string, error) { + return s.generateToken(userID, role, RefreshToken, s.config.RefreshTokenExpiry) +} + +// generateToken is an internal helper function that creates a new JWT token +// Parameters: +// - userID: the ID of the user +// - role: the role of the user +// - tokenType: the type of token (access or refresh) +// - expiry: how long the token should be valid +// Returns the signed token string or an error +func (s *JWTService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) { + now := time.Now() + claims := Claims{ + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(now.Add(expiry)), + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now), + }, + UserID: userID, + Role: role, + Type: tokenType, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString([]byte(s.config.SigningKey)) +} + +// ValidateToken validates and parses a JWT token +// Parameters: +// - tokenString: the token to validate +// Returns the token claims if valid, or an error if invalid +func (s *JWTService) ValidateToken(tokenString string) (*Claims, error) { + token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { + // Validate the signing method + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(s.config.SigningKey), nil + }) + + if err != nil { + return nil, fmt.Errorf("invalid token: %w", err) + } + + if claims, ok := token.Claims.(*Claims); ok && token.Valid { + return claims, nil + } + + return nil, fmt.Errorf("invalid token claims") +} + +// RefreshAccessToken creates a new access token using a refresh token +// Parameters: +// - refreshToken: the refresh token to use +// Returns a new access token if the refresh token is valid, or an error +func (s *JWTService) RefreshAccessToken(refreshToken string) (string, error) { + claims, err := s.ValidateToken(refreshToken) + if err != nil { + return "", fmt.Errorf("invalid refresh token: %w", err) + } + + if claims.Type != RefreshToken { + return "", fmt.Errorf("invalid token type: expected refresh token") + } + + return s.GenerateAccessToken(claims.UserID, claims.Role) +} diff --git a/backend/internal/auth/middleware.go b/backend/internal/auth/middleware.go new file mode 100644 index 0000000..3bac176 --- /dev/null +++ b/backend/internal/auth/middleware.go @@ -0,0 +1,102 @@ +package auth + +import ( + "context" + "fmt" + "net/http" + "strings" +) + +type contextKey string + +const ( + UserContextKey contextKey = "user" +) + +// UserClaims represents the user information stored in the request context +type UserClaims struct { + UserID int + Role string +} + +// Middleware handles JWT authentication for protected routes +type Middleware struct { + jwtService *JWTService +} + +// NewMiddleware creates a new authentication middleware +func NewMiddleware(jwtService *JWTService) *Middleware { + return &Middleware{ + jwtService: jwtService, + } +} + +// Authenticate middleware validates JWT tokens and sets user information in context +func (m *Middleware) Authenticate(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Extract token from Authorization header + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + http.Error(w, "Authorization header required", http.StatusUnauthorized) + return + } + + // Check Bearer token format + parts := strings.Split(authHeader, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + http.Error(w, "Invalid authorization format", http.StatusUnauthorized) + return + } + + // Validate token + claims, err := m.jwtService.ValidateToken(parts[1]) + if err != nil { + http.Error(w, "Invalid token", http.StatusUnauthorized) + return + } + + // Check token type + if claims.Type != AccessToken { + http.Error(w, "Invalid token type", http.StatusUnauthorized) + return + } + + // Add user claims to request context + ctx := context.WithValue(r.Context(), UserContextKey, UserClaims{ + UserID: claims.UserID, + Role: claims.Role, + }) + + // Call the next handler with the updated context + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// RequireRole returns a middleware that ensures the user has the required role +func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + claims, ok := r.Context().Value(UserContextKey).(UserClaims) + if !ok { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + if claims.Role != role && claims.Role != "admin" { + http.Error(w, "Insufficient permissions", http.StatusForbidden) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// GetUserFromContext retrieves user claims from the request context +func GetUserFromContext(ctx context.Context) (*UserClaims, error) { + claims, ok := ctx.Value(UserContextKey).(UserClaims) + if !ok { + return nil, fmt.Errorf("no user found in context") + } + return &claims, nil +} diff --git a/backend/internal/auth/session.go b/backend/internal/auth/session.go new file mode 100644 index 0000000..8168ccc --- /dev/null +++ b/backend/internal/auth/session.go @@ -0,0 +1,140 @@ +package auth + +import ( + "database/sql" + "fmt" + "time" + + "github.com/google/uuid" +) + +// Session represents a user session in the database +type Session struct { + ID string // Unique session identifier + UserID int // ID of the user this session belongs to + RefreshToken string // The refresh token associated with this session + ExpiresAt time.Time // When this session expires + CreatedAt time.Time // When this session was created +} + +// SessionService manages user sessions in the database +type SessionService struct { + db *sql.DB // Database connection + jwtService *JWTService // JWT service for token operations +} + +// NewSessionService creates a new session service +// Parameters: +// - db: database connection +// - jwtService: JWT service for token operations +func NewSessionService(db *sql.DB, jwtService *JWTService) *SessionService { + return &SessionService{ + db: db, + jwtService: jwtService, + } +} + +// CreateSession creates a new user session +// Parameters: +// - userID: the ID of the user +// - role: the role of the user +// Returns: +// - session: the created session +// - accessToken: a new access token +// - error: any error that occurred +func (s *SessionService) CreateSession(userID int, role string) (*Session, string, error) { + // Generate both access and refresh tokens + accessToken, err := s.jwtService.GenerateAccessToken(userID, role) + if err != nil { + return nil, "", fmt.Errorf("failed to generate access token: %w", err) + } + + refreshToken, err := s.jwtService.GenerateRefreshToken(userID, role) + if err != nil { + return nil, "", fmt.Errorf("failed to generate refresh token: %w", err) + } + + // Validate the refresh token to get its expiry time + claims, err := s.jwtService.ValidateToken(refreshToken) + if err != nil { + return nil, "", fmt.Errorf("failed to validate refresh token: %w", err) + } + + // Create a new session record + session := &Session{ + ID: uuid.New().String(), + UserID: userID, + RefreshToken: refreshToken, + ExpiresAt: claims.ExpiresAt.Time, + CreatedAt: time.Now(), + } + + // Store the session in the database + _, err = s.db.Exec(` + INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at) + VALUES (?, ?, ?, ?, ?)`, + session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt, + ) + if err != nil { + return nil, "", fmt.Errorf("failed to store session: %w", err) + } + + return session, accessToken, nil +} + +// RefreshSession creates a new access token using a refresh token +// Parameters: +// - refreshToken: the refresh token to use +// Returns: +// - string: a new access token +// - error: any error that occurred +func (s *SessionService) RefreshSession(refreshToken string) (string, error) { + // Validate the refresh token + claims, err := s.jwtService.ValidateToken(refreshToken) + if err != nil { + return "", fmt.Errorf("invalid refresh token: %w", err) + } + + // Check if the session exists and is not expired + var session Session + err = s.db.QueryRow(` + SELECT id, user_id, refresh_token, expires_at, created_at + FROM sessions + WHERE refresh_token = ? AND expires_at > ?`, + refreshToken, time.Now(), + ).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt) + + if err == sql.ErrNoRows { + return "", fmt.Errorf("session not found or expired") + } + if err != nil { + return "", fmt.Errorf("failed to fetch session: %w", err) + } + + // Generate a new access token + return s.jwtService.GenerateAccessToken(claims.UserID, claims.Role) +} + +// InvalidateSession removes a session from the database +// Parameters: +// - sessionID: the ID of the session to invalidate +// Returns: +// - error: any error that occurred +func (s *SessionService) InvalidateSession(sessionID string) error { + _, err := s.db.Exec("DELETE FROM sessions WHERE id = ?", sessionID) + if err != nil { + return fmt.Errorf("failed to invalidate session: %w", err) + } + return nil +} + +// CleanExpiredSessions removes all expired sessions from the database +// Returns: +// - error: any error that occurred +func (s *SessionService) CleanExpiredSessions() error { + _, err := s.db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now()) + if err != nil { + return fmt.Errorf("failed to clean expired sessions: %w", err) + } + return nil +}