mirror of
https://github.com/lordmathis/lemma.git
synced 2025-11-06 07:54:22 +00:00
Update auth test
This commit is contained in:
@@ -3,6 +3,7 @@ package auth_test
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -11,6 +12,42 @@ import (
|
||||
"novamd/internal/models"
|
||||
)
|
||||
|
||||
// Mock SessionManager
|
||||
type mockSessionManager struct {
|
||||
sessions map[string]*models.Session
|
||||
}
|
||||
|
||||
func newMockSessionManager() *mockSessionManager {
|
||||
return &mockSessionManager{
|
||||
sessions: make(map[string]*models.Session),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockSessionManager) CreateSession(userID int, role string) (*models.Session, string, error) {
|
||||
return nil, "", nil // Not needed for these tests
|
||||
}
|
||||
|
||||
func (m *mockSessionManager) RefreshSession(refreshToken string) (string, error) {
|
||||
return "", nil // Not needed for these tests
|
||||
}
|
||||
|
||||
func (m *mockSessionManager) ValidateSession(sessionID string) (*models.Session, error) {
|
||||
session, exists := m.sessions[sessionID]
|
||||
if !exists {
|
||||
return nil, nil
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (m *mockSessionManager) InvalidateSession(token string) error {
|
||||
delete(m.sessions, token)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSessionManager) CleanExpiredSessions() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Complete mockResponseWriter implementation
|
||||
type mockResponseWriter struct {
|
||||
headers http.Header
|
||||
@@ -44,62 +81,122 @@ func TestAuthenticateMiddleware(t *testing.T) {
|
||||
RefreshTokenExpiry: 24 * time.Hour,
|
||||
}
|
||||
jwtService, _ := auth.NewJWTService(config)
|
||||
middleware := auth.NewMiddleware(jwtService)
|
||||
sessionManager := newMockSessionManager()
|
||||
cookieManager := auth.NewCookieService(true, "localhost")
|
||||
middleware := auth.NewMiddleware(jwtService, sessionManager, cookieManager)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupAuth func() string
|
||||
setupRequest func() *http.Request
|
||||
setupSession func(sessionID string)
|
||||
method string
|
||||
wantStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "valid token",
|
||||
setupAuth: func() string {
|
||||
name: "valid token with valid session",
|
||||
setupRequest: func() *http.Request {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
token, _ := jwtService.GenerateAccessToken(1, "admin")
|
||||
return token
|
||||
cookie := cookieManager.GenerateAccessTokenCookie(token)
|
||||
req.AddCookie(cookie)
|
||||
return req
|
||||
},
|
||||
setupSession: func(sessionID string) {
|
||||
sessionManager.sessions[sessionID] = &models.Session{
|
||||
ID: sessionID,
|
||||
UserID: 1,
|
||||
ExpiresAt: time.Now().Add(15 * time.Minute),
|
||||
}
|
||||
},
|
||||
method: "GET",
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "missing auth header",
|
||||
setupAuth: func() string {
|
||||
return ""
|
||||
name: "valid token but invalid session",
|
||||
setupRequest: func() *http.Request {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
token, _ := jwtService.GenerateAccessToken(1, "admin")
|
||||
cookie := cookieManager.GenerateAccessTokenCookie(token)
|
||||
req.AddCookie(cookie)
|
||||
return req
|
||||
},
|
||||
setupSession: func(sessionID string) {}, // No session setup
|
||||
method: "GET",
|
||||
wantStatusCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "invalid auth format",
|
||||
setupAuth: func() string {
|
||||
return "InvalidFormat token"
|
||||
name: "missing auth cookie",
|
||||
setupRequest: func() *http.Request {
|
||||
return httptest.NewRequest("GET", "/test", nil)
|
||||
},
|
||||
setupSession: func(sessionID string) {},
|
||||
method: "GET",
|
||||
wantStatusCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "invalid token",
|
||||
setupAuth: func() string {
|
||||
return "Bearer invalid.token.here"
|
||||
name: "POST request without CSRF token",
|
||||
setupRequest: func() *http.Request {
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
token, _ := jwtService.GenerateAccessToken(1, "admin")
|
||||
cookie := cookieManager.GenerateAccessTokenCookie(token)
|
||||
req.AddCookie(cookie)
|
||||
return req
|
||||
},
|
||||
wantStatusCode: http.StatusUnauthorized,
|
||||
setupSession: func(sessionID string) {
|
||||
sessionManager.sessions[sessionID] = &models.Session{
|
||||
ID: sessionID,
|
||||
UserID: 1,
|
||||
ExpiresAt: time.Now().Add(15 * time.Minute),
|
||||
}
|
||||
},
|
||||
method: "POST",
|
||||
wantStatusCode: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "POST request with valid CSRF token",
|
||||
setupRequest: func() *http.Request {
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
token, _ := jwtService.GenerateAccessToken(1, "admin")
|
||||
cookie := cookieManager.GenerateAccessTokenCookie(token)
|
||||
req.AddCookie(cookie)
|
||||
|
||||
csrfToken := "test-csrf-token"
|
||||
csrfCookie := cookieManager.GenerateCSRFCookie(csrfToken)
|
||||
req.AddCookie(csrfCookie)
|
||||
req.Header.Set("X-CSRF-Token", csrfToken)
|
||||
return req
|
||||
},
|
||||
setupSession: func(sessionID string) {
|
||||
sessionManager.sessions[sessionID] = &models.Session{
|
||||
ID: sessionID,
|
||||
UserID: 1,
|
||||
ExpiresAt: time.Now().Add(15 * time.Minute),
|
||||
}
|
||||
},
|
||||
method: "POST",
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create test request
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
if token := tc.setupAuth(); token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
// Create response recorder
|
||||
req := tc.setupRequest()
|
||||
w := newMockResponseWriter()
|
||||
|
||||
// Create test handler
|
||||
nextCalled := false
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// If we have a valid token, set up the session
|
||||
if cookie, err := req.Cookie("access_token"); err == nil {
|
||||
if claims, err := jwtService.ValidateToken(cookie.Value); err == nil {
|
||||
tc.setupSession(claims.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// Execute middleware
|
||||
middleware.Authenticate(next).ServeHTTP(w, req)
|
||||
|
||||
@@ -115,6 +212,15 @@ func TestAuthenticateMiddleware(t *testing.T) {
|
||||
if tc.wantStatusCode != http.StatusOK && nextCalled {
|
||||
t.Error("next handler was called when it shouldn't have been")
|
||||
}
|
||||
|
||||
// For unauthorized responses, check if cookies were invalidated
|
||||
if w.statusCode == http.StatusUnauthorized {
|
||||
for _, cookie := range w.Header()["Set-Cookie"] {
|
||||
if strings.Contains(cookie, "Max-Age=0") {
|
||||
t.Error("cookies were not properly invalidated")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -126,7 +232,7 @@ func TestRequireRole(t *testing.T) {
|
||||
RefreshTokenExpiry: 24 * time.Hour,
|
||||
}
|
||||
jwtService, _ := auth.NewJWTService(config)
|
||||
middleware := auth.NewMiddleware(jwtService)
|
||||
middleware := auth.NewMiddleware(jwtService, &mockSessionManager{}, auth.NewCookieService(true, "localhost"))
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -198,7 +304,7 @@ func TestRequireWorkspaceAccess(t *testing.T) {
|
||||
SigningKey: "test-key",
|
||||
}
|
||||
jwtService, _ := auth.NewJWTService(config)
|
||||
middleware := auth.NewMiddleware(jwtService)
|
||||
middleware := auth.NewMiddleware(jwtService, &mockSessionManager{}, auth.NewCookieService(true, "localhost"))
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
// Mock SessionStore
|
||||
type mockSessionStore struct {
|
||||
sessions map[string]*models.Session
|
||||
sessionsByToken map[string]*models.Session // Added index by refresh token
|
||||
sessionsByToken map[string]*models.Session
|
||||
}
|
||||
|
||||
func newMockSessionStore() *mockSessionStore {
|
||||
@@ -29,6 +29,17 @@ func (m *mockSessionStore) CreateSession(session *models.Session) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSessionStore) GetSessionByID(sessionID string) (*models.Session, error) {
|
||||
session, exists := m.sessions[sessionID]
|
||||
if !exists {
|
||||
return nil, errors.New("session not found")
|
||||
}
|
||||
if session.ExpiresAt.Before(time.Now()) {
|
||||
return nil, errors.New("session expired")
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (m *mockSessionStore) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) {
|
||||
session, exists := m.sessionsByToken[refreshToken]
|
||||
if !exists {
|
||||
@@ -111,9 +122,9 @@ func TestCreateSession(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify the session was stored
|
||||
storedSession, exists := mockDB.sessions[session.ID]
|
||||
if !exists {
|
||||
t.Error("session was not stored in database")
|
||||
storedSession, err := mockDB.GetSessionByID(session.ID)
|
||||
if err != nil {
|
||||
t.Errorf("failed to get stored session: %v", err)
|
||||
}
|
||||
if storedSession.RefreshToken != session.RefreshToken {
|
||||
t.Error("stored refresh token doesn't match")
|
||||
@@ -138,6 +149,97 @@ func TestCreateSession(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSession(t *testing.T) {
|
||||
config := auth.JWTConfig{
|
||||
SigningKey: "test-key",
|
||||
AccessTokenExpiry: 15 * time.Minute,
|
||||
RefreshTokenExpiry: 24 * time.Hour,
|
||||
}
|
||||
jwtService, _ := auth.NewJWTService(config)
|
||||
mockDB := newMockSessionStore()
|
||||
sessionService := auth.NewSessionService(mockDB, jwtService)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupSession func() string
|
||||
wantErr bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "valid session",
|
||||
setupSession: func() string {
|
||||
session := &models.Session{
|
||||
ID: "test-session-1",
|
||||
UserID: 1,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := mockDB.CreateSession(session); err != nil {
|
||||
t.Fatalf("failed to create session: %v", err)
|
||||
}
|
||||
|
||||
return session.ID
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "expired session",
|
||||
setupSession: func() string {
|
||||
session := &models.Session{
|
||||
ID: "test-session-2",
|
||||
UserID: 1,
|
||||
ExpiresAt: time.Now().Add(-1 * time.Hour),
|
||||
CreatedAt: time.Now().Add(-2 * time.Hour),
|
||||
}
|
||||
if err := mockDB.CreateSession(session); err != nil {
|
||||
t.Fatalf("failed to create session: %v", err)
|
||||
}
|
||||
return session.ID
|
||||
},
|
||||
wantErr: true,
|
||||
errorContains: "session expired",
|
||||
},
|
||||
{
|
||||
name: "non-existent session",
|
||||
setupSession: func() string {
|
||||
return "non-existent-session-id"
|
||||
},
|
||||
wantErr: true,
|
||||
errorContains: "session not found",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
sessionID := tc.setupSession()
|
||||
session, err := sessionService.ValidateSession(sessionID)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
} else if tc.errorContains != "" && !strings.Contains(err.Error(), tc.errorContains) {
|
||||
t.Errorf("error = %v, want error containing %v", err, tc.errorContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if session == nil {
|
||||
t.Error("expected session, got nil")
|
||||
return
|
||||
}
|
||||
|
||||
if session.ID != sessionID {
|
||||
t.Errorf("session ID = %v, want %v", session.ID, sessionID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshSession(t *testing.T) {
|
||||
config := auth.JWTConfig{
|
||||
SigningKey: "test-key",
|
||||
@@ -180,7 +282,7 @@ func TestRefreshSession(t *testing.T) {
|
||||
ID: "test-session-2",
|
||||
UserID: 1,
|
||||
RefreshToken: token,
|
||||
ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired
|
||||
ExpiresAt: time.Now().Add(-1 * time.Hour),
|
||||
CreatedAt: time.Now().Add(-2 * time.Hour),
|
||||
}
|
||||
if err := mockDB.CreateSession(session); err != nil {
|
||||
@@ -233,7 +335,7 @@ func TestRefreshSession(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidateSession(t *testing.T) {
|
||||
func TestCleanExpiredSessions(t *testing.T) {
|
||||
config := auth.JWTConfig{
|
||||
SigningKey: "test-key",
|
||||
AccessTokenExpiry: 15 * time.Minute,
|
||||
@@ -243,62 +345,40 @@ func TestInvalidateSession(t *testing.T) {
|
||||
mockDB := newMockSessionStore()
|
||||
sessionService := auth.NewSessionService(mockDB, jwtService)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupSession func() string
|
||||
wantErr bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "valid session invalidation",
|
||||
setupSession: func() string {
|
||||
session := &models.Session{
|
||||
ID: "test-session-1",
|
||||
UserID: 1,
|
||||
RefreshToken: "valid-token",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := mockDB.CreateSession(session); err != nil {
|
||||
t.Fatalf("failed to create session: %v", err)
|
||||
}
|
||||
return session.ID
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent session",
|
||||
setupSession: func() string {
|
||||
return "non-existent-session-id"
|
||||
},
|
||||
wantErr: true,
|
||||
errorContains: "session not found",
|
||||
},
|
||||
// Create test sessions
|
||||
validSession := &models.Session{
|
||||
ID: "valid-session",
|
||||
UserID: 1,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := mockDB.CreateSession(validSession); err != nil {
|
||||
t.Fatalf("failed to create valid session: %v", err)
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
sessionID := tc.setupSession()
|
||||
err := sessionService.InvalidateSession(sessionID)
|
||||
expiredSession := &models.Session{
|
||||
ID: "expired-session",
|
||||
UserID: 2,
|
||||
ExpiresAt: time.Now().Add(-1 * time.Hour),
|
||||
CreatedAt: time.Now().Add(-2 * time.Hour),
|
||||
}
|
||||
if err := mockDB.CreateSession(expiredSession); err != nil {
|
||||
t.Fatalf("failed to create expired session: %v", err)
|
||||
}
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
} else if !strings.Contains(err.Error(), tc.errorContains) {
|
||||
t.Errorf("error = %v, want error containing %v", err, tc.errorContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
// Clean expired sessions
|
||||
err := sessionService.CleanExpiredSessions()
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error cleaning sessions: %v", err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
// Verify valid session still exists
|
||||
if _, err := mockDB.GetSessionByID(validSession.ID); err != nil {
|
||||
t.Error("valid session was incorrectly removed")
|
||||
}
|
||||
|
||||
// Verify session was removed
|
||||
if _, exists := mockDB.sessions[sessionID]; exists {
|
||||
t.Error("session still exists after invalidation")
|
||||
}
|
||||
})
|
||||
// Verify expired session was removed
|
||||
if _, err := mockDB.GetSessionByID(expiredSession.ID); err == nil {
|
||||
t.Error("expired session was not removed")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user