Files
lemma/server/internal/auth/middleware_test.go
2024-12-07 21:41:37 +01:00

401 lines
10 KiB
Go

package auth_test
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"novamd/internal/auth"
"novamd/internal/context"
"novamd/internal/models"
)
// mockSessionManager is an in-memory session manager for the middleware
// tests. Sessions are stored in a map keyed by session ID; tests populate the
// map directly.
type mockSessionManager struct {
	sessions map[string]*models.Session
}

// newMockSessionManager returns a mock whose session table is empty but
// already allocated, so tests can insert sessions right away.
func newMockSessionManager() *mockSessionManager {
	sm := new(mockSessionManager)
	sm.sessions = map[string]*models.Session{}
	return sm
}

// CreateSession is a stub: these tests never create sessions through the
// manager, so it returns zero values and a nil error.
func (sm *mockSessionManager) CreateSession(userID int, role string) (*models.Session, string, error) {
	return nil, "", nil
}

// RefreshSession is a stub: it returns an empty token and a nil error.
func (sm *mockSessionManager) RefreshSession(refreshToken string) (string, error) {
	return "", nil
}

// ValidateSession returns the session stored under sessionID. An unknown ID
// yields a nil session together with a nil error.
func (sm *mockSessionManager) ValidateSession(sessionID string) (*models.Session, error) {
	if found, ok := sm.sessions[sessionID]; ok {
		return found, nil
	}
	return nil, nil
}

// InvalidateSession removes the entry stored under token, if any. It never
// fails.
func (sm *mockSessionManager) InvalidateSession(token string) error {
	delete(sm.sessions, token)
	return nil
}

// CleanExpiredSessions is a no-op that always reports success.
func (sm *mockSessionManager) CleanExpiredSessions() error {
	return nil
}
// mockResponseWriter is a minimal http.ResponseWriter that records what a
// handler under test produced: response headers, the status code and the body
// bytes.
//
// A statusCode of 0 means that no status has been written yet.
type mockResponseWriter struct {
	headers    http.Header
	statusCode int
	written    []byte
}

// newMockResponseWriter returns a recorder with an empty header map and no
// status or body recorded.
func newMockResponseWriter() *mockResponseWriter {
	return &mockResponseWriter{
		headers: make(http.Header),
	}
}

// Header returns the mutable response header map. The map is allocated on
// demand so that a zero-value mockResponseWriter is usable as well.
func (m *mockResponseWriter) Header() http.Header {
	if m.headers == nil {
		m.headers = make(http.Header)
	}
	return m.headers
}

// Write appends a copy of b to the recorded body and reports len(b) bytes
// written with a nil error.
//
// As with a real http.ResponseWriter, writing a body before any WriteHeader
// call implies a 200 OK status. The bytes are copied rather than aliased so
// the recording stays intact if the caller reuses its buffer, and successive
// writes accumulate instead of replacing each other.
func (m *mockResponseWriter) Write(b []byte) (int, error) {
	if m.statusCode == 0 {
		m.statusCode = http.StatusOK
	}
	m.written = append(m.written, b...)
	return len(b), nil
}

// WriteHeader records statusCode as the response status. Only the first
// status sticks; later calls are ignored, mirroring net/http where a second
// WriteHeader has no effect on the response.
func (m *mockResponseWriter) WriteHeader(statusCode int) {
	if m.statusCode != 0 {
		return
	}
	m.statusCode = statusCode
}
// TestAuthenticateMiddleware runs Middleware.Authenticate against a table of
// requests and checks, for each one, the recorded status code and whether the
// wrapped handler was reached. For 401 responses it additionally checks that
// every Set-Cookie header that was emitted expires its cookie.
func TestAuthenticateMiddleware(t *testing.T) {
	config := auth.JWTConfig{
		SigningKey:         "test-key",
		AccessTokenExpiry:  15 * time.Minute,
		RefreshTokenExpiry: 24 * time.Hour,
	}
	jwtService, err := auth.NewJWTService(config)
	if err != nil {
		t.Fatalf("NewJWTService() error = %v", err)
	}
	sessionManager := newMockSessionManager()
	cookieManager := auth.NewCookieService(true, "localhost")
	middleware := auth.NewMiddleware(jwtService, sessionManager, cookieManager)

	// authedRequest builds a request for /test that carries a freshly signed
	// access-token cookie for user 1 with the "admin" role. It takes the
	// subtest's t so that a signing failure aborts the right test.
	authedRequest := func(t *testing.T, method string) *http.Request {
		t.Helper()
		token, err := jwtService.GenerateAccessToken(1, "admin")
		if err != nil {
			t.Fatalf("GenerateAccessToken() error = %v", err)
		}
		req := httptest.NewRequest(method, "/test", nil)
		req.AddCookie(cookieManager.GenerateAccessTokenCookie(token))
		return req
	}

	// storeSession registers a live session under sessionID in the mock.
	storeSession := func(sessionID string) {
		sessionManager.sessions[sessionID] = &models.Session{
			ID:        sessionID,
			UserID:    1,
			ExpiresAt: time.Now().Add(15 * time.Minute),
		}
	}

	// noSession leaves the mock without a session for the token.
	noSession := func(string) {}

	testCases := []struct {
		name           string
		setupRequest   func(t *testing.T, method string) *http.Request
		setupSession   func(sessionID string)
		method         string
		wantStatusCode int
	}{
		{
			name:           "valid token with valid session",
			setupRequest:   authedRequest,
			setupSession:   storeSession,
			method:         http.MethodGet,
			wantStatusCode: http.StatusOK,
		},
		{
			name:           "valid token but invalid session",
			setupRequest:   authedRequest,
			setupSession:   noSession,
			method:         http.MethodGet,
			wantStatusCode: http.StatusUnauthorized,
		},
		{
			name: "missing auth cookie",
			setupRequest: func(_ *testing.T, method string) *http.Request {
				return httptest.NewRequest(method, "/test", nil)
			},
			setupSession:   noSession,
			method:         http.MethodGet,
			wantStatusCode: http.StatusUnauthorized,
		},
		{
			name:           "POST request without CSRF token",
			setupRequest:   authedRequest,
			setupSession:   storeSession,
			method:         http.MethodPost,
			wantStatusCode: http.StatusForbidden,
		},
		{
			name: "POST request with valid CSRF token",
			setupRequest: func(t *testing.T, method string) *http.Request {
				req := authedRequest(t, method)
				// The same token goes into both the cookie and the header.
				csrfToken := "test-csrf-token"
				req.AddCookie(cookieManager.GenerateCSRFCookie(csrfToken))
				req.Header.Set("X-CSRF-Token", csrfToken)
				return req
			},
			setupSession:   storeSession,
			method:         http.MethodPost,
			wantStatusCode: http.StatusOK,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Start every case from an empty session table so that a session
			// stored by an earlier case cannot satisfy this one.
			sessionManager.sessions = make(map[string]*models.Session)

			req := tc.setupRequest(t, tc.method)
			w := newMockResponseWriter()

			// Terminal handler: records that it ran and answers 200.
			nextCalled := false
			next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				nextCalled = true
				w.WriteHeader(http.StatusOK)
			})

			// The session key is the ID claim of the token, which is only
			// known once the token has been generated, so the session is
			// registered here rather than in the table.
			if cookie, err := req.Cookie("access_token"); err == nil {
				if claims, err := jwtService.ValidateToken(cookie.Value); err == nil {
					tc.setupSession(claims.ID)
				}
			}

			middleware.Authenticate(next).ServeHTTP(w, req)

			if w.statusCode != tc.wantStatusCode {
				t.Errorf("status code = %v, want %v", w.statusCode, tc.wantStatusCode)
			}

			// The wrapped handler must run exactly when 200 is expected.
			if tc.wantStatusCode == http.StatusOK && !nextCalled {
				t.Error("next handler was not called")
			}
			if tc.wantStatusCode != http.StatusOK && nextCalled {
				t.Error("next handler was called when it shouldn't have been")
			}

			// On 401, any cookie the middleware sends back must be an
			// expiring one. net/http serialises a negative MaxAge as
			// "Max-Age=0", so a Set-Cookie header WITHOUT that attribute is
			// the failure. Responses that set no cookies pass this check.
			if w.statusCode == http.StatusUnauthorized {
				for _, cookie := range w.Header()["Set-Cookie"] {
					if !strings.Contains(cookie, "Max-Age=0") {
						t.Error("cookies were not properly invalidated")
					}
				}
			}
		})
	}
}
// TestRequireRole checks Middleware.RequireRole against requests whose
// handler context carries a given user role. Per the table, a matching role
// and an "admin" user requesting a different role both reach the wrapped
// handler (200), while an "editor" requesting "admin" is rejected with 403.
func TestRequireRole(t *testing.T) {
	config := auth.JWTConfig{
		SigningKey:         "test-key",
		AccessTokenExpiry:  15 * time.Minute,
		RefreshTokenExpiry: 24 * time.Hour,
	}
	jwtService, err := auth.NewJWTService(config)
	if err != nil {
		t.Fatalf("NewJWTService() error = %v", err)
	}
	// Use the constructor, like TestAuthenticateMiddleware does, so the mock
	// always has an allocated session table.
	middleware := auth.NewMiddleware(jwtService, newMockSessionManager(), auth.NewCookieService(true, "localhost"))

	testCases := []struct {
		name           string
		userRole       string
		requiredRole   string
		wantStatusCode int
	}{
		{
			name:           "matching role",
			userRole:       "admin",
			requiredRole:   "admin",
			wantStatusCode: http.StatusOK,
		},
		{
			name:           "admin accessing other role",
			userRole:       "admin",
			requiredRole:   "editor",
			wantStatusCode: http.StatusOK,
		},
		{
			name:           "insufficient role",
			userRole:       "editor",
			requiredRole:   "admin",
			wantStatusCode: http.StatusForbidden,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Attach a handler context describing the calling user.
			hctx := &context.HandlerContext{
				UserID:   1,
				UserRole: tc.userRole,
			}
			req := httptest.NewRequest(http.MethodGet, "/test", nil)
			req = context.WithHandlerContext(req, hctx)
			w := newMockResponseWriter()

			// Terminal handler: records that it ran and answers 200.
			nextCalled := false
			next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
				nextCalled = true
				w.WriteHeader(http.StatusOK)
			})

			middleware.RequireRole(tc.requiredRole)(next).ServeHTTP(w, req)

			if w.statusCode != tc.wantStatusCode {
				t.Errorf("status code = %v, want %v", w.statusCode, tc.wantStatusCode)
			}

			// The wrapped handler must run exactly when 200 is expected.
			if tc.wantStatusCode == http.StatusOK && !nextCalled {
				t.Error("next handler was not called")
			}
			if tc.wantStatusCode != http.StatusOK && nextCalled {
				t.Error("next handler was called when it shouldn't have been")
			}
		})
	}
}
// TestRequireWorkspaceAccess checks Middleware.RequireWorkspaceAccess against
// handler contexts that pair a user with a workspace. Per the table, the
// workspace owner, an "admin" acting on someone else's workspace and a
// request without any workspace all reach the wrapped handler (200); a
// non-admin acting on someone else's workspace gets 404.
func TestRequireWorkspaceAccess(t *testing.T) {
	// Same JWT configuration as the other middleware tests in this file.
	config := auth.JWTConfig{
		SigningKey:         "test-key",
		AccessTokenExpiry:  15 * time.Minute,
		RefreshTokenExpiry: 24 * time.Hour,
	}
	jwtService, err := auth.NewJWTService(config)
	if err != nil {
		t.Fatalf("NewJWTService() error = %v", err)
	}
	// Use the constructor, like TestAuthenticateMiddleware does, so the mock
	// always has an allocated session table.
	middleware := auth.NewMiddleware(jwtService, newMockSessionManager(), auth.NewCookieService(true, "localhost"))

	testCases := []struct {
		name           string
		setupContext   func() *context.HandlerContext
		wantStatusCode int
	}{
		{
			name: "workspace owner access",
			setupContext: func() *context.HandlerContext {
				return &context.HandlerContext{
					UserID:   1,
					UserRole: "editor",
					Workspace: &models.Workspace{
						ID:     1,
						UserID: 1, // Same as context UserID
					},
				}
			},
			wantStatusCode: http.StatusOK,
		},
		{
			name: "admin access to other's workspace",
			setupContext: func() *context.HandlerContext {
				return &context.HandlerContext{
					UserID:   2,
					UserRole: "admin",
					Workspace: &models.Workspace{
						ID:     1,
						UserID: 1, // Different from context UserID
					},
				}
			},
			wantStatusCode: http.StatusOK,
		},
		{
			name: "unauthorized access attempt",
			setupContext: func() *context.HandlerContext {
				return &context.HandlerContext{
					UserID:   2,
					UserRole: "editor",
					Workspace: &models.Workspace{
						ID:     1,
						UserID: 1, // Different from context UserID
					},
				}
			},
			wantStatusCode: http.StatusNotFound,
		},
		{
			name: "no workspace in context",
			setupContext: func() *context.HandlerContext {
				return &context.HandlerContext{
					UserID:    1,
					UserRole:  "editor",
					Workspace: nil,
				}
			},
			wantStatusCode: http.StatusOK,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Attach the case's handler context to a plain GET request.
			req := httptest.NewRequest(http.MethodGet, "/test", nil)
			req = context.WithHandlerContext(req, tc.setupContext())
			w := newMockResponseWriter()

			// Terminal handler: records that it ran and answers 200.
			nextCalled := false
			next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
				nextCalled = true
				w.WriteHeader(http.StatusOK)
			})

			middleware.RequireWorkspaceAccess(next).ServeHTTP(w, req)

			if w.statusCode != tc.wantStatusCode {
				t.Errorf("status code = %v, want %v", w.statusCode, tc.wantStatusCode)
			}

			// The wrapped handler must run exactly when 200 is expected.
			if tc.wantStatusCode == http.StatusOK && !nextCalled {
				t.Error("next handler was not called")
			}
			if tc.wantStatusCode != http.StatusOK && nextCalled {
				t.Error("next handler was called when it shouldn't have been")
			}
		})
	}
}