Implement auth package tests

This commit is contained in:
2024-11-23 00:29:26 +01:00
parent b3ec4e136c
commit ebdd7bd741
12 changed files with 136 additions and 203 deletions

View File

@@ -1,25 +1,12 @@
package auth
import (
"context"
"fmt"
"net/http"
"strings"
"novamd/internal/httpcontext"
"novamd/internal/context"
)
type contextKey string
// UserContextKey is the key used to store user claims in the request context
const UserContextKey contextKey = "user"
// UserClaims represents the user information stored in the request context
type UserClaims struct {
UserID int
Role string
}
// Middleware handles JWT authentication for protected routes
type Middleware struct {
jwtManager JWTManager
@@ -70,14 +57,14 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
return
}
// Add user claims to request context
ctx := context.WithValue(r.Context(), UserContextKey, UserClaims{
UserID: claims.UserID,
Role: claims.Role,
})
// Create handler context with user information
hctx := &context.HandlerContext{
UserID: claims.UserID,
UserRole: claims.Role,
}
// Call the next handler with the updated context
next.ServeHTTP(w, r.WithContext(ctx))
// Add context to request and continue
next.ServeHTTP(w, context.WithHandlerContext(r, hctx))
})
}
@@ -89,13 +76,12 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims, ok := r.Context().Value(UserContextKey).(UserClaims)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
if claims.Role != role && claims.Role != "admin" {
if ctx.UserRole != role && ctx.UserRole != "admin" {
http.Error(w, "Insufficient permissions", http.StatusForbidden)
return
}
@@ -112,8 +98,7 @@ func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler {
// - http.Handler: the handler function
func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get our handler context
ctx, ok := httpcontext.GetRequestContext(w, r)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
@@ -133,17 +118,3 @@ func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler {
next.ServeHTTP(w, r)
})
}
// GetUserFromContext retrieves user claims from the request context
// Parameters:
// - ctx: the request context
// Returns:
// - *UserClaims: the user claims
// - error: any error that occurred
func GetUserFromContext(ctx context.Context) (*UserClaims, error) {
claims, ok := ctx.Value(UserContextKey).(UserClaims)
if !ok {
return nil, fmt.Errorf("no user found in context")
}
return &claims, nil
}