mirror of
https://github.com/lordmathis/lemma.git
synced 2025-11-06 07:54:22 +00:00
Implement auth package tests
This commit is contained in:
@@ -1,25 +1,12 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"novamd/internal/httpcontext"
|
||||
"novamd/internal/context"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
// UserContextKey is the key used to store user claims in the request context
|
||||
const UserContextKey contextKey = "user"
|
||||
|
||||
// UserClaims represents the user information stored in the request context
|
||||
type UserClaims struct {
|
||||
UserID int
|
||||
Role string
|
||||
}
|
||||
|
||||
// Middleware handles JWT authentication for protected routes
|
||||
type Middleware struct {
|
||||
jwtManager JWTManager
|
||||
@@ -70,14 +57,14 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
// Add user claims to request context
|
||||
ctx := context.WithValue(r.Context(), UserContextKey, UserClaims{
|
||||
UserID: claims.UserID,
|
||||
Role: claims.Role,
|
||||
})
|
||||
// Create handler context with user information
|
||||
hctx := &context.HandlerContext{
|
||||
UserID: claims.UserID,
|
||||
UserRole: claims.Role,
|
||||
}
|
||||
|
||||
// Call the next handler with the updated context
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
// Add context to request and continue
|
||||
next.ServeHTTP(w, context.WithHandlerContext(r, hctx))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -89,13 +76,12 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
|
||||
func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
claims, ok := r.Context().Value(UserContextKey).(UserClaims)
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if claims.Role != role && claims.Role != "admin" {
|
||||
if ctx.UserRole != role && ctx.UserRole != "admin" {
|
||||
http.Error(w, "Insufficient permissions", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
@@ -112,8 +98,7 @@ func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler {
|
||||
// - http.Handler: the handler function
|
||||
func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get our handler context
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -133,17 +118,3 @@ func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler {
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// GetUserFromContext retrieves user claims from the request context
|
||||
// Parameters:
|
||||
// - ctx: the request context
|
||||
// Returns:
|
||||
// - *UserClaims: the user claims
|
||||
// - error: any error that occurred
|
||||
func GetUserFromContext(ctx context.Context) (*UserClaims, error) {
|
||||
claims, ok := ctx.Value(UserContextKey).(UserClaims)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no user found in context")
|
||||
}
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"novamd/internal/auth"
|
||||
"novamd/internal/httpcontext"
|
||||
"novamd/internal/context"
|
||||
"novamd/internal/models"
|
||||
)
|
||||
|
||||
@@ -97,7 +95,7 @@ func TestAuthenticateMiddleware(t *testing.T) {
|
||||
|
||||
// Create test handler
|
||||
nextCalled := false
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
@@ -158,12 +156,15 @@ func TestRequireRole(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create context with user claims
|
||||
ctx := context.WithValue(context.Background(), auth.UserContextKey, auth.UserClaims{
|
||||
UserID: 1,
|
||||
Role: tc.userRole,
|
||||
})
|
||||
req := httptest.NewRequest("GET", "/test", nil).WithContext(ctx)
|
||||
// Create handler context with user info
|
||||
hctx := &context.HandlerContext{
|
||||
UserID: 1,
|
||||
UserRole: tc.userRole,
|
||||
}
|
||||
|
||||
// Create request with handler context
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req = context.WithHandlerContext(req, hctx)
|
||||
w := newMockResponseWriter()
|
||||
|
||||
// Create test handler
|
||||
@@ -201,13 +202,13 @@ func TestRequireWorkspaceAccess(t *testing.T) {
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupContext func() *httpcontext.HandlerContext
|
||||
setupContext func() *context.HandlerContext
|
||||
wantStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace owner access",
|
||||
setupContext: func() *httpcontext.HandlerContext {
|
||||
return &httpcontext.HandlerContext{
|
||||
setupContext: func() *context.HandlerContext {
|
||||
return &context.HandlerContext{
|
||||
UserID: 1,
|
||||
UserRole: "editor",
|
||||
Workspace: &models.Workspace{
|
||||
@@ -220,8 +221,8 @@ func TestRequireWorkspaceAccess(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "admin access to other's workspace",
|
||||
setupContext: func() *httpcontext.HandlerContext {
|
||||
return &httpcontext.HandlerContext{
|
||||
setupContext: func() *context.HandlerContext {
|
||||
return &context.HandlerContext{
|
||||
UserID: 2,
|
||||
UserRole: "admin",
|
||||
Workspace: &models.Workspace{
|
||||
@@ -234,8 +235,8 @@ func TestRequireWorkspaceAccess(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "unauthorized access attempt",
|
||||
setupContext: func() *httpcontext.HandlerContext {
|
||||
return &httpcontext.HandlerContext{
|
||||
setupContext: func() *context.HandlerContext {
|
||||
return &context.HandlerContext{
|
||||
UserID: 2,
|
||||
UserRole: "editor",
|
||||
Workspace: &models.Workspace{
|
||||
@@ -248,8 +249,8 @@ func TestRequireWorkspaceAccess(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "no workspace in context",
|
||||
setupContext: func() *httpcontext.HandlerContext {
|
||||
return &httpcontext.HandlerContext{
|
||||
setupContext: func() *context.HandlerContext {
|
||||
return &context.HandlerContext{
|
||||
UserID: 1,
|
||||
UserRole: "editor",
|
||||
Workspace: nil,
|
||||
@@ -262,8 +263,8 @@ func TestRequireWorkspaceAccess(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create request with context
|
||||
ctx := context.WithValue(context.Background(), httpcontext.HandlerContextKey, tc.setupContext())
|
||||
req := httptest.NewRequest("GET", "/test", nil).WithContext(ctx)
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req = context.WithHandlerContext(req, tc.setupContext())
|
||||
w := newMockResponseWriter()
|
||||
|
||||
// Create test handler
|
||||
@@ -291,72 +292,3 @@ func TestRequireWorkspaceAccess(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserFromContext(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupCtx func() context.Context
|
||||
wantUserID int
|
||||
wantRole string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "valid user context",
|
||||
setupCtx: func() context.Context {
|
||||
return context.WithValue(context.Background(), auth.UserContextKey, auth.UserClaims{
|
||||
UserID: 1,
|
||||
Role: "admin",
|
||||
})
|
||||
},
|
||||
wantUserID: 1,
|
||||
wantRole: "admin",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing user context",
|
||||
setupCtx: func() context.Context {
|
||||
return context.Background()
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "no user found in context",
|
||||
},
|
||||
{
|
||||
name: "invalid context value type",
|
||||
setupCtx: func() context.Context {
|
||||
return context.WithValue(context.Background(), auth.UserContextKey, "invalid")
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "no user found in context",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx := tc.setupCtx()
|
||||
claims, err := auth.GetUserFromContext(ctx)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
} else if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) {
|
||||
t.Errorf("error = %v, want error containing %v", err, tc.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if claims.UserID != tc.wantUserID {
|
||||
t.Errorf("UserID = %v, want %v", claims.UserID, tc.wantUserID)
|
||||
}
|
||||
|
||||
if claims.Role != tc.wantRole {
|
||||
t.Errorf("Role = %v, want %v", claims.Role, tc.wantRole)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user