Update auth test

This commit is contained in:
2024-12-07 21:41:37 +01:00
parent 8a4508e29f
commit ad4af2f82d
2 changed files with 269 additions and 83 deletions

View File

@@ -3,6 +3,7 @@ package auth_test
import ( import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"time" "time"
@@ -11,6 +12,42 @@ import (
"novamd/internal/models" "novamd/internal/models"
) )
// Mock SessionManager
type mockSessionManager struct {
sessions map[string]*models.Session
}
func newMockSessionManager() *mockSessionManager {
return &mockSessionManager{
sessions: make(map[string]*models.Session),
}
}
func (m *mockSessionManager) CreateSession(userID int, role string) (*models.Session, string, error) {
return nil, "", nil // Not needed for these tests
}
func (m *mockSessionManager) RefreshSession(refreshToken string) (string, error) {
return "", nil // Not needed for these tests
}
func (m *mockSessionManager) ValidateSession(sessionID string) (*models.Session, error) {
session, exists := m.sessions[sessionID]
if !exists {
return nil, nil
}
return session, nil
}
func (m *mockSessionManager) InvalidateSession(token string) error {
delete(m.sessions, token)
return nil
}
func (m *mockSessionManager) CleanExpiredSessions() error {
return nil
}
// Complete mockResponseWriter implementation // Complete mockResponseWriter implementation
type mockResponseWriter struct { type mockResponseWriter struct {
headers http.Header headers http.Header
@@ -44,62 +81,122 @@ func TestAuthenticateMiddleware(t *testing.T) {
RefreshTokenExpiry: 24 * time.Hour, RefreshTokenExpiry: 24 * time.Hour,
} }
jwtService, _ := auth.NewJWTService(config) jwtService, _ := auth.NewJWTService(config)
middleware := auth.NewMiddleware(jwtService) sessionManager := newMockSessionManager()
cookieManager := auth.NewCookieService(true, "localhost")
middleware := auth.NewMiddleware(jwtService, sessionManager, cookieManager)
testCases := []struct { testCases := []struct {
name string name string
setupAuth func() string setupRequest func() *http.Request
setupSession func(sessionID string)
method string
wantStatusCode int wantStatusCode int
}{ }{
{ {
name: "valid token", name: "valid token with valid session",
setupAuth: func() string { setupRequest: func() *http.Request {
req := httptest.NewRequest("GET", "/test", nil)
token, _ := jwtService.GenerateAccessToken(1, "admin") token, _ := jwtService.GenerateAccessToken(1, "admin")
return token cookie := cookieManager.GenerateAccessTokenCookie(token)
req.AddCookie(cookie)
return req
}, },
setupSession: func(sessionID string) {
sessionManager.sessions[sessionID] = &models.Session{
ID: sessionID,
UserID: 1,
ExpiresAt: time.Now().Add(15 * time.Minute),
}
},
method: "GET",
wantStatusCode: http.StatusOK, wantStatusCode: http.StatusOK,
}, },
{ {
name: "missing auth header", name: "valid token but invalid session",
setupAuth: func() string { setupRequest: func() *http.Request {
return "" req := httptest.NewRequest("GET", "/test", nil)
token, _ := jwtService.GenerateAccessToken(1, "admin")
cookie := cookieManager.GenerateAccessTokenCookie(token)
req.AddCookie(cookie)
return req
}, },
setupSession: func(sessionID string) {}, // No session setup
method: "GET",
wantStatusCode: http.StatusUnauthorized, wantStatusCode: http.StatusUnauthorized,
}, },
{ {
name: "invalid auth format", name: "missing auth cookie",
setupAuth: func() string { setupRequest: func() *http.Request {
return "InvalidFormat token" return httptest.NewRequest("GET", "/test", nil)
}, },
setupSession: func(sessionID string) {},
method: "GET",
wantStatusCode: http.StatusUnauthorized, wantStatusCode: http.StatusUnauthorized,
}, },
{ {
name: "invalid token", name: "POST request without CSRF token",
setupAuth: func() string { setupRequest: func() *http.Request {
return "Bearer invalid.token.here" req := httptest.NewRequest("POST", "/test", nil)
token, _ := jwtService.GenerateAccessToken(1, "admin")
cookie := cookieManager.GenerateAccessTokenCookie(token)
req.AddCookie(cookie)
return req
}, },
wantStatusCode: http.StatusUnauthorized, setupSession: func(sessionID string) {
sessionManager.sessions[sessionID] = &models.Session{
ID: sessionID,
UserID: 1,
ExpiresAt: time.Now().Add(15 * time.Minute),
}
},
method: "POST",
wantStatusCode: http.StatusForbidden,
},
{
name: "POST request with valid CSRF token",
setupRequest: func() *http.Request {
req := httptest.NewRequest("POST", "/test", nil)
token, _ := jwtService.GenerateAccessToken(1, "admin")
cookie := cookieManager.GenerateAccessTokenCookie(token)
req.AddCookie(cookie)
csrfToken := "test-csrf-token"
csrfCookie := cookieManager.GenerateCSRFCookie(csrfToken)
req.AddCookie(csrfCookie)
req.Header.Set("X-CSRF-Token", csrfToken)
return req
},
setupSession: func(sessionID string) {
sessionManager.sessions[sessionID] = &models.Session{
ID: sessionID,
UserID: 1,
ExpiresAt: time.Now().Add(15 * time.Minute),
}
},
method: "POST",
wantStatusCode: http.StatusOK,
}, },
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
// Create test request req := tc.setupRequest()
req := httptest.NewRequest("GET", "/test", nil)
if token := tc.setupAuth(); token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
// Create response recorder
w := newMockResponseWriter() w := newMockResponseWriter()
// Create test handler // Create test handler
nextCalled := false nextCalled := false
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true nextCalled = true
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}) })
// If we have a valid token, set up the session
if cookie, err := req.Cookie("access_token"); err == nil {
if claims, err := jwtService.ValidateToken(cookie.Value); err == nil {
tc.setupSession(claims.ID)
}
}
// Execute middleware // Execute middleware
middleware.Authenticate(next).ServeHTTP(w, req) middleware.Authenticate(next).ServeHTTP(w, req)
@@ -115,6 +212,15 @@ func TestAuthenticateMiddleware(t *testing.T) {
if tc.wantStatusCode != http.StatusOK && nextCalled { if tc.wantStatusCode != http.StatusOK && nextCalled {
t.Error("next handler was called when it shouldn't have been") t.Error("next handler was called when it shouldn't have been")
} }
// For unauthorized responses, check if cookies were invalidated
if w.statusCode == http.StatusUnauthorized {
for _, cookie := range w.Header()["Set-Cookie"] {
if strings.Contains(cookie, "Max-Age=0") {
t.Error("cookies were not properly invalidated")
}
}
}
}) })
} }
} }
@@ -126,7 +232,7 @@ func TestRequireRole(t *testing.T) {
RefreshTokenExpiry: 24 * time.Hour, RefreshTokenExpiry: 24 * time.Hour,
} }
jwtService, _ := auth.NewJWTService(config) jwtService, _ := auth.NewJWTService(config)
middleware := auth.NewMiddleware(jwtService) middleware := auth.NewMiddleware(jwtService, &mockSessionManager{}, auth.NewCookieService(true, "localhost"))
testCases := []struct { testCases := []struct {
name string name string
@@ -198,7 +304,7 @@ func TestRequireWorkspaceAccess(t *testing.T) {
SigningKey: "test-key", SigningKey: "test-key",
} }
jwtService, _ := auth.NewJWTService(config) jwtService, _ := auth.NewJWTService(config)
middleware := auth.NewMiddleware(jwtService) middleware := auth.NewMiddleware(jwtService, &mockSessionManager{}, auth.NewCookieService(true, "localhost"))
testCases := []struct { testCases := []struct {
name string name string

View File

@@ -13,7 +13,7 @@ import (
// Mock SessionStore // Mock SessionStore
type mockSessionStore struct { type mockSessionStore struct {
sessions map[string]*models.Session sessions map[string]*models.Session
sessionsByToken map[string]*models.Session // Added index by refresh token sessionsByToken map[string]*models.Session
} }
func newMockSessionStore() *mockSessionStore { func newMockSessionStore() *mockSessionStore {
@@ -29,6 +29,17 @@ func (m *mockSessionStore) CreateSession(session *models.Session) error {
return nil return nil
} }
func (m *mockSessionStore) GetSessionByID(sessionID string) (*models.Session, error) {
session, exists := m.sessions[sessionID]
if !exists {
return nil, errors.New("session not found")
}
if session.ExpiresAt.Before(time.Now()) {
return nil, errors.New("session expired")
}
return session, nil
}
func (m *mockSessionStore) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) { func (m *mockSessionStore) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) {
session, exists := m.sessionsByToken[refreshToken] session, exists := m.sessionsByToken[refreshToken]
if !exists { if !exists {
@@ -111,9 +122,9 @@ func TestCreateSession(t *testing.T) {
} }
// Verify the session was stored // Verify the session was stored
storedSession, exists := mockDB.sessions[session.ID] storedSession, err := mockDB.GetSessionByID(session.ID)
if !exists { if err != nil {
t.Error("session was not stored in database") t.Errorf("failed to get stored session: %v", err)
} }
if storedSession.RefreshToken != session.RefreshToken { if storedSession.RefreshToken != session.RefreshToken {
t.Error("stored refresh token doesn't match") t.Error("stored refresh token doesn't match")
@@ -138,6 +149,97 @@ func TestCreateSession(t *testing.T) {
} }
} }
func TestValidateSession(t *testing.T) {
config := auth.JWTConfig{
SigningKey: "test-key",
AccessTokenExpiry: 15 * time.Minute,
RefreshTokenExpiry: 24 * time.Hour,
}
jwtService, _ := auth.NewJWTService(config)
mockDB := newMockSessionStore()
sessionService := auth.NewSessionService(mockDB, jwtService)
testCases := []struct {
name string
setupSession func() string
wantErr bool
errorContains string
}{
{
name: "valid session",
setupSession: func() string {
session := &models.Session{
ID: "test-session-1",
UserID: 1,
ExpiresAt: time.Now().Add(24 * time.Hour),
CreatedAt: time.Now(),
}
if err := mockDB.CreateSession(session); err != nil {
t.Fatalf("failed to create session: %v", err)
}
return session.ID
},
wantErr: false,
},
{
name: "expired session",
setupSession: func() string {
session := &models.Session{
ID: "test-session-2",
UserID: 1,
ExpiresAt: time.Now().Add(-1 * time.Hour),
CreatedAt: time.Now().Add(-2 * time.Hour),
}
if err := mockDB.CreateSession(session); err != nil {
t.Fatalf("failed to create session: %v", err)
}
return session.ID
},
wantErr: true,
errorContains: "session expired",
},
{
name: "non-existent session",
setupSession: func() string {
return "non-existent-session-id"
},
wantErr: true,
errorContains: "session not found",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
sessionID := tc.setupSession()
session, err := sessionService.ValidateSession(sessionID)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
} else if tc.errorContains != "" && !strings.Contains(err.Error(), tc.errorContains) {
t.Errorf("error = %v, want error containing %v", err, tc.errorContains)
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if session == nil {
t.Error("expected session, got nil")
return
}
if session.ID != sessionID {
t.Errorf("session ID = %v, want %v", session.ID, sessionID)
}
})
}
}
func TestRefreshSession(t *testing.T) { func TestRefreshSession(t *testing.T) {
config := auth.JWTConfig{ config := auth.JWTConfig{
SigningKey: "test-key", SigningKey: "test-key",
@@ -180,7 +282,7 @@ func TestRefreshSession(t *testing.T) {
ID: "test-session-2", ID: "test-session-2",
UserID: 1, UserID: 1,
RefreshToken: token, RefreshToken: token,
ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired ExpiresAt: time.Now().Add(-1 * time.Hour),
CreatedAt: time.Now().Add(-2 * time.Hour), CreatedAt: time.Now().Add(-2 * time.Hour),
} }
if err := mockDB.CreateSession(session); err != nil { if err := mockDB.CreateSession(session); err != nil {
@@ -233,7 +335,7 @@ func TestRefreshSession(t *testing.T) {
} }
} }
func TestInvalidateSession(t *testing.T) { func TestCleanExpiredSessions(t *testing.T) {
config := auth.JWTConfig{ config := auth.JWTConfig{
SigningKey: "test-key", SigningKey: "test-key",
AccessTokenExpiry: 15 * time.Minute, AccessTokenExpiry: 15 * time.Minute,
@@ -243,62 +345,40 @@ func TestInvalidateSession(t *testing.T) {
mockDB := newMockSessionStore() mockDB := newMockSessionStore()
sessionService := auth.NewSessionService(mockDB, jwtService) sessionService := auth.NewSessionService(mockDB, jwtService)
testCases := []struct { // Create test sessions
name string validSession := &models.Session{
setupSession func() string ID: "valid-session",
wantErr bool UserID: 1,
errorContains string ExpiresAt: time.Now().Add(24 * time.Hour),
}{ CreatedAt: time.Now(),
{ }
name: "valid session invalidation", if err := mockDB.CreateSession(validSession); err != nil {
setupSession: func() string { t.Fatalf("failed to create valid session: %v", err)
session := &models.Session{
ID: "test-session-1",
UserID: 1,
RefreshToken: "valid-token",
ExpiresAt: time.Now().Add(24 * time.Hour),
CreatedAt: time.Now(),
}
if err := mockDB.CreateSession(session); err != nil {
t.Fatalf("failed to create session: %v", err)
}
return session.ID
},
wantErr: false,
},
{
name: "non-existent session",
setupSession: func() string {
return "non-existent-session-id"
},
wantErr: true,
errorContains: "session not found",
},
} }
for _, tc := range testCases { expiredSession := &models.Session{
t.Run(tc.name, func(t *testing.T) { ID: "expired-session",
sessionID := tc.setupSession() UserID: 2,
err := sessionService.InvalidateSession(sessionID) ExpiresAt: time.Now().Add(-1 * time.Hour),
CreatedAt: time.Now().Add(-2 * time.Hour),
}
if err := mockDB.CreateSession(expiredSession); err != nil {
t.Fatalf("failed to create expired session: %v", err)
}
if tc.wantErr { // Clean expired sessions
if err == nil { err := sessionService.CleanExpiredSessions()
t.Error("expected error, got nil") if err != nil {
} else if !strings.Contains(err.Error(), tc.errorContains) { t.Errorf("unexpected error cleaning sessions: %v", err)
t.Errorf("error = %v, want error containing %v", err, tc.errorContains) }
}
return
}
if err != nil { // Verify valid session still exists
t.Errorf("unexpected error: %v", err) if _, err := mockDB.GetSessionByID(validSession.ID); err != nil {
return t.Error("valid session was incorrectly removed")
} }
// Verify session was removed // Verify expired session was removed
if _, exists := mockDB.sessions[sessionID]; exists { if _, err := mockDB.GetSessionByID(expiredSession.ID); err == nil {
t.Error("session still exists after invalidation") t.Error("expired session was not removed")
}
})
} }
} }