mirror of
https://github.com/lordmathis/lemma.git
synced 2025-11-06 16:04:23 +00:00
181 lines
4.8 KiB
Go
181 lines
4.8 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"novamd/internal/db"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
)
|
|
|
|
// contextKey is an unexported string type used for the context keys defined in
// this package, so that they are distinct from plain string keys.
type contextKey string

// UserContextKey is the context key under which Authenticate stores the
// authenticated user's UserClaims.
const UserContextKey contextKey = "user"
|
|
|
|
// UserClaims is the user information carried in the request context once a
// request has been authenticated.
type UserClaims struct {
	// UserID is the user ID taken from the validated token.
	UserID int
	// Role is the role taken from the validated token; the authorization
	// middleware compares it against the required role and "admin".
	Role string
}
|
|
|
|
// Middleware handles JWT authentication for protected routes
|
|
type Middleware struct {
|
|
jwtService *JWTService
|
|
}
|
|
|
|
// NewMiddleware creates a new authentication middleware
|
|
func NewMiddleware(jwtService *JWTService) *Middleware {
|
|
return &Middleware{
|
|
jwtService: jwtService,
|
|
}
|
|
}
|
|
|
|
// Authenticate middleware validates JWT tokens and sets user information in context
|
|
func (m *Middleware) Authenticate(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Extract token from Authorization header
|
|
authHeader := r.Header.Get("Authorization")
|
|
if authHeader == "" {
|
|
http.Error(w, "Authorization header required", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Check Bearer token format
|
|
parts := strings.Split(authHeader, " ")
|
|
if len(parts) != 2 || parts[0] != "Bearer" {
|
|
http.Error(w, "Invalid authorization format", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Validate token
|
|
claims, err := m.jwtService.ValidateToken(parts[1])
|
|
if err != nil {
|
|
http.Error(w, "Invalid token", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Check token type
|
|
if claims.Type != AccessToken {
|
|
http.Error(w, "Invalid token type", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Add user claims to request context
|
|
ctx := context.WithValue(r.Context(), UserContextKey, UserClaims{
|
|
UserID: claims.UserID,
|
|
Role: claims.Role,
|
|
})
|
|
|
|
// Call the next handler with the updated context
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
// RequireRole returns a middleware that ensures the user has the required role
|
|
func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
claims, ok := r.Context().Value(UserContextKey).(UserClaims)
|
|
if !ok {
|
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
if claims.Role != role && claims.Role != "admin" {
|
|
http.Error(w, "Insufficient permissions", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// RequireResourceOwnership ensures users can only access their own resources
|
|
func (m *Middleware) RequireResourceOwnership(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Get requesting user from context (set by auth middleware)
|
|
claims, err := GetUserFromContext(r.Context())
|
|
if err != nil {
|
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Get requested user ID from URL
|
|
userIDStr := chi.URLParam(r, "userId")
|
|
requestedUserID, err := strconv.Atoi(userIDStr)
|
|
if err != nil {
|
|
http.Error(w, "Invalid user ID", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Allow if user is accessing their own resources or is an admin
|
|
if claims.UserID != requestedUserID && claims.Role != "admin" {
|
|
http.Error(w, "Forbidden", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// RequireWorkspaceOwnership ensures users can only access workspaces they own
|
|
type WorkspaceGetter interface {
|
|
GetWorkspaceByID(id int) (*Workspace, error)
|
|
}
|
|
|
|
// Workspace is the workspace value returned by WorkspaceGetter.
type Workspace struct {
	// ID identifies the workspace.
	ID int
	// UserID is presumably the ID of the user who owns the workspace
	// (TODO confirm against the code that populates it).
	UserID int
}
|
|
|
|
// RequireWorkspaceOwnership ensures users can only access workspaces they own
|
|
func (m *Middleware) RequireWorkspaceOwnership(db *db.DB) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Get requesting user from context
|
|
claims, err := GetUserFromContext(r.Context())
|
|
if err != nil {
|
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Get workspace ID from URL
|
|
workspaceID, err := strconv.Atoi(chi.URLParam(r, "workspaceId"))
|
|
if err != nil {
|
|
http.Error(w, "Invalid workspace ID", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Get workspace from database
|
|
workspace, err := db.GetWorkspaceByID(workspaceID)
|
|
if err != nil {
|
|
http.Error(w, "Workspace not found", http.StatusNotFound)
|
|
return
|
|
}
|
|
|
|
// Check if user owns the workspace or is admin
|
|
if workspace.UserID != claims.UserID && claims.Role != "admin" {
|
|
http.Error(w, "Forbidden", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// GetUserFromContext retrieves user claims from the request context
|
|
func GetUserFromContext(ctx context.Context) (*UserClaims, error) {
|
|
claims, ok := ctx.Value(UserContextKey).(UserClaims)
|
|
if !ok {
|
|
return nil, fmt.Errorf("no user found in context")
|
|
}
|
|
return &claims, nil
|
|
}
|