mirror of
https://github.com/lordmathis/lemma.git
synced 2025-11-05 23:44:22 +00:00
Implement jwt auth backend
This commit is contained in:
135
backend/internal/auth/jwt.go
Normal file
135
backend/internal/auth/jwt.go
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TokenType represents the type of JWT token (access or refresh)
|
||||||
|
type TokenType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
AccessToken TokenType = "access" // AccessToken - Short-lived token for API access
|
||||||
|
RefreshToken TokenType = "refresh" // RefreshToken - Long-lived token for obtaining new access tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
// Claims represents the custom claims we store in JWT tokens
|
||||||
|
type Claims struct {
|
||||||
|
jwt.RegisteredClaims // Embedded standard JWT claims
|
||||||
|
UserID int `json:"uid"` // User identifier
|
||||||
|
Role string `json:"role"` // User role (admin, editor, viewer)
|
||||||
|
Type TokenType `json:"type"` // Token type (access or refresh)
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTConfig holds the configuration for the JWT service
|
||||||
|
type JWTConfig struct {
|
||||||
|
SigningKey string // Secret key used to sign tokens
|
||||||
|
AccessTokenExpiry time.Duration // How long access tokens are valid
|
||||||
|
RefreshTokenExpiry time.Duration // How long refresh tokens are valid
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTService handles JWT token generation and validation
|
||||||
|
type JWTService struct {
|
||||||
|
config JWTConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewJWTService creates a new JWT service with the provided configuration
|
||||||
|
// Returns an error if the signing key is missing
|
||||||
|
func NewJWTService(config JWTConfig) (*JWTService, error) {
|
||||||
|
if config.SigningKey == "" {
|
||||||
|
return nil, fmt.Errorf("signing key is required")
|
||||||
|
}
|
||||||
|
// Set default expiry times if not provided
|
||||||
|
if config.AccessTokenExpiry == 0 {
|
||||||
|
config.AccessTokenExpiry = 15 * time.Minute // Default to 15 minutes
|
||||||
|
}
|
||||||
|
if config.RefreshTokenExpiry == 0 {
|
||||||
|
config.RefreshTokenExpiry = 7 * 24 * time.Hour // Default to 7 days
|
||||||
|
}
|
||||||
|
return &JWTService{config: config}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateAccessToken creates a new access token for a user
|
||||||
|
// Parameters:
|
||||||
|
// - userID: the ID of the user
|
||||||
|
// - role: the role of the user
|
||||||
|
// Returns the signed token string or an error
|
||||||
|
func (s *JWTService) GenerateAccessToken(userID int, role string) (string, error) {
|
||||||
|
return s.generateToken(userID, role, AccessToken, s.config.AccessTokenExpiry)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateRefreshToken creates a new refresh token for a user
|
||||||
|
// Parameters:
|
||||||
|
// - userID: the ID of the user
|
||||||
|
// - role: the role of the user
|
||||||
|
// Returns the signed token string or an error
|
||||||
|
func (s *JWTService) GenerateRefreshToken(userID int, role string) (string, error) {
|
||||||
|
return s.generateToken(userID, role, RefreshToken, s.config.RefreshTokenExpiry)
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateToken is an internal helper function that creates a new JWT token
|
||||||
|
// Parameters:
|
||||||
|
// - userID: the ID of the user
|
||||||
|
// - role: the role of the user
|
||||||
|
// - tokenType: the type of token (access or refresh)
|
||||||
|
// - expiry: how long the token should be valid
|
||||||
|
// Returns the signed token string or an error
|
||||||
|
func (s *JWTService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) {
|
||||||
|
now := time.Now()
|
||||||
|
claims := Claims{
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
ExpiresAt: jwt.NewNumericDate(now.Add(expiry)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
|
NotBefore: jwt.NewNumericDate(now),
|
||||||
|
},
|
||||||
|
UserID: userID,
|
||||||
|
Role: role,
|
||||||
|
Type: tokenType,
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
return token.SignedString([]byte(s.config.SigningKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateToken validates and parses a JWT token
|
||||||
|
// Parameters:
|
||||||
|
// - tokenString: the token to validate
|
||||||
|
// Returns the token claims if valid, or an error if invalid
|
||||||
|
func (s *JWTService) ValidateToken(tokenString string) (*Claims, error) {
|
||||||
|
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
// Validate the signing method
|
||||||
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||||
|
}
|
||||||
|
return []byte(s.config.SigningKey), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("invalid token claims")
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshAccessToken creates a new access token using a refresh token
|
||||||
|
// Parameters:
|
||||||
|
// - refreshToken: the refresh token to use
|
||||||
|
// Returns a new access token if the refresh token is valid, or an error
|
||||||
|
func (s *JWTService) RefreshAccessToken(refreshToken string) (string, error) {
|
||||||
|
claims, err := s.ValidateToken(refreshToken)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("invalid refresh token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims.Type != RefreshToken {
|
||||||
|
return "", fmt.Errorf("invalid token type: expected refresh token")
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.GenerateAccessToken(claims.UserID, claims.Role)
|
||||||
|
}
|
||||||
102
backend/internal/auth/middleware.go
Normal file
102
backend/internal/auth/middleware.go
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type contextKey string
|
||||||
|
|
||||||
|
const (
|
||||||
|
UserContextKey contextKey = "user"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserClaims represents the user information stored in the request context
|
||||||
|
type UserClaims struct {
|
||||||
|
UserID int
|
||||||
|
Role string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Middleware handles JWT authentication for protected routes
|
||||||
|
type Middleware struct {
|
||||||
|
jwtService *JWTService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMiddleware creates a new authentication middleware
|
||||||
|
func NewMiddleware(jwtService *JWTService) *Middleware {
|
||||||
|
return &Middleware{
|
||||||
|
jwtService: jwtService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authenticate middleware validates JWT tokens and sets user information in context
|
||||||
|
func (m *Middleware) Authenticate(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Extract token from Authorization header
|
||||||
|
authHeader := r.Header.Get("Authorization")
|
||||||
|
if authHeader == "" {
|
||||||
|
http.Error(w, "Authorization header required", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check Bearer token format
|
||||||
|
parts := strings.Split(authHeader, " ")
|
||||||
|
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||||
|
http.Error(w, "Invalid authorization format", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate token
|
||||||
|
claims, err := m.jwtService.ValidateToken(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Invalid token", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check token type
|
||||||
|
if claims.Type != AccessToken {
|
||||||
|
http.Error(w, "Invalid token type", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add user claims to request context
|
||||||
|
ctx := context.WithValue(r.Context(), UserContextKey, UserClaims{
|
||||||
|
UserID: claims.UserID,
|
||||||
|
Role: claims.Role,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Call the next handler with the updated context
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequireRole returns a middleware that ensures the user has the required role
|
||||||
|
func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
claims, ok := r.Context().Value(UserContextKey).(UserClaims)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims.Role != role && claims.Role != "admin" {
|
||||||
|
http.Error(w, "Insufficient permissions", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserFromContext retrieves user claims from the request context
|
||||||
|
func GetUserFromContext(ctx context.Context) (*UserClaims, error) {
|
||||||
|
claims, ok := ctx.Value(UserContextKey).(UserClaims)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("no user found in context")
|
||||||
|
}
|
||||||
|
return &claims, nil
|
||||||
|
}
|
||||||
140
backend/internal/auth/session.go
Normal file
140
backend/internal/auth/session.go
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Session represents a user session in the database
|
||||||
|
type Session struct {
|
||||||
|
ID string // Unique session identifier
|
||||||
|
UserID int // ID of the user this session belongs to
|
||||||
|
RefreshToken string // The refresh token associated with this session
|
||||||
|
ExpiresAt time.Time // When this session expires
|
||||||
|
CreatedAt time.Time // When this session was created
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionService manages user sessions in the database
|
||||||
|
type SessionService struct {
|
||||||
|
db *sql.DB // Database connection
|
||||||
|
jwtService *JWTService // JWT service for token operations
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSessionService creates a new session service
|
||||||
|
// Parameters:
|
||||||
|
// - db: database connection
|
||||||
|
// - jwtService: JWT service for token operations
|
||||||
|
func NewSessionService(db *sql.DB, jwtService *JWTService) *SessionService {
|
||||||
|
return &SessionService{
|
||||||
|
db: db,
|
||||||
|
jwtService: jwtService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateSession creates a new user session
|
||||||
|
// Parameters:
|
||||||
|
// - userID: the ID of the user
|
||||||
|
// - role: the role of the user
|
||||||
|
// Returns:
|
||||||
|
// - session: the created session
|
||||||
|
// - accessToken: a new access token
|
||||||
|
// - error: any error that occurred
|
||||||
|
func (s *SessionService) CreateSession(userID int, role string) (*Session, string, error) {
|
||||||
|
// Generate both access and refresh tokens
|
||||||
|
accessToken, err := s.jwtService.GenerateAccessToken(userID, role)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("failed to generate access token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshToken, err := s.jwtService.GenerateRefreshToken(userID, role)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("failed to generate refresh token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate the refresh token to get its expiry time
|
||||||
|
claims, err := s.jwtService.ValidateToken(refreshToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("failed to validate refresh token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new session record
|
||||||
|
session := &Session{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
UserID: userID,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
ExpiresAt: claims.ExpiresAt.Time,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the session in the database
|
||||||
|
_, err = s.db.Exec(`
|
||||||
|
INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?)`,
|
||||||
|
session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("failed to store session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return session, accessToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshSession creates a new access token using a refresh token
|
||||||
|
// Parameters:
|
||||||
|
// - refreshToken: the refresh token to use
|
||||||
|
// Returns:
|
||||||
|
// - string: a new access token
|
||||||
|
// - error: any error that occurred
|
||||||
|
func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
|
||||||
|
// Validate the refresh token
|
||||||
|
claims, err := s.jwtService.ValidateToken(refreshToken)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("invalid refresh token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the session exists and is not expired
|
||||||
|
var session Session
|
||||||
|
err = s.db.QueryRow(`
|
||||||
|
SELECT id, user_id, refresh_token, expires_at, created_at
|
||||||
|
FROM sessions
|
||||||
|
WHERE refresh_token = ? AND expires_at > ?`,
|
||||||
|
refreshToken, time.Now(),
|
||||||
|
).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt)
|
||||||
|
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return "", fmt.Errorf("session not found or expired")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to fetch session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a new access token
|
||||||
|
return s.jwtService.GenerateAccessToken(claims.UserID, claims.Role)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateSession removes a session from the database
|
||||||
|
// Parameters:
|
||||||
|
// - sessionID: the ID of the session to invalidate
|
||||||
|
// Returns:
|
||||||
|
// - error: any error that occurred
|
||||||
|
func (s *SessionService) InvalidateSession(sessionID string) error {
|
||||||
|
_, err := s.db.Exec("DELETE FROM sessions WHERE id = ?", sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to invalidate session: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanExpiredSessions removes all expired sessions from the database
|
||||||
|
// Returns:
|
||||||
|
// - error: any error that occurred
|
||||||
|
func (s *SessionService) CleanExpiredSessions() error {
|
||||||
|
_, err := s.db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to clean expired sessions: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user