From b3ec4e136ce4c6db3d6640964b1e43337eab34ce Mon Sep 17 00:00:00 2001 From: LordMathis Date: Fri, 22 Nov 2024 23:17:59 +0100 Subject: [PATCH] Implement auth tests --- server/internal/auth/jwt_test.go | 221 +++++++++++++++ server/internal/auth/middleware_test.go | 362 ++++++++++++++++++++++++ server/internal/auth/session_test.go | 298 +++++++++++++++++++ 3 files changed, 881 insertions(+) create mode 100644 server/internal/auth/jwt_test.go create mode 100644 server/internal/auth/middleware_test.go create mode 100644 server/internal/auth/session_test.go diff --git a/server/internal/auth/jwt_test.go b/server/internal/auth/jwt_test.go new file mode 100644 index 0000000..61aa3ad --- /dev/null +++ b/server/internal/auth/jwt_test.go @@ -0,0 +1,221 @@ +package auth_test + +import ( + "testing" + "time" + + "novamd/internal/auth" + + "github.com/golang-jwt/jwt/v5" +) + +// jwt_test.go tests + +func TestNewJWTService(t *testing.T) { + testCases := []struct { + name string + config auth.JWTConfig + wantErr bool + }{ + { + name: "valid configuration", + config: auth.JWTConfig{ + SigningKey: "test-key", + AccessTokenExpiry: 15 * time.Minute, + RefreshTokenExpiry: 24 * time.Hour, + }, + wantErr: false, + }, + { + name: "missing signing key", + config: auth.JWTConfig{ + AccessTokenExpiry: 15 * time.Minute, + RefreshTokenExpiry: 24 * time.Hour, + }, + wantErr: true, + }, + { + name: "zero expiry times", + config: auth.JWTConfig{ + SigningKey: "test-key", + }, + wantErr: false, // Should use default values + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + service, err := auth.NewJWTService(tc.config) + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + return + } + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if service == nil { + t.Error("expected service, got nil") + } + }) + } +} + +func TestGenerateAndValidateToken(t *testing.T) { + config := auth.JWTConfig{ + SigningKey: "test-key", + AccessTokenExpiry: 15 * time.Minute, + RefreshTokenExpiry: 24 * time.Hour, + } + service, _ := auth.NewJWTService(config) + + testCases := []struct { + name string + userID int + role string + tokenType auth.TokenType + wantErr bool + }{ + { + name: "valid access token", + userID: 1, + role: "admin", + tokenType: auth.AccessToken, + wantErr: false, + }, + { + name: "valid refresh token", + userID: 1, + role: "editor", + tokenType: auth.RefreshToken, + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var token string + var err error + + // Generate token based on type + if tc.tokenType == auth.AccessToken { + token, err = service.GenerateAccessToken(tc.userID, tc.role) + } else { + token, err = service.GenerateRefreshToken(tc.userID, tc.role) + } + + if err != nil { + t.Fatalf("failed to generate token: %v", err) + } + + // Validate token + claims, err := service.ValidateToken(token) + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + // Verify claims + if claims.UserID != tc.userID { + t.Errorf("userID = %v, want %v", claims.UserID, tc.userID) + } + if claims.Role != tc.role { + t.Errorf("role = %v, want %v", claims.Role, tc.role) + } + if claims.Type != tc.tokenType { + t.Errorf("type = %v, want %v", claims.Type, tc.tokenType) + } + }) + } +} + +func TestRefreshAccessToken(t *testing.T) { + config := auth.JWTConfig{ + SigningKey: "test-key", + AccessTokenExpiry: 15 * time.Minute, + RefreshTokenExpiry: 24 * time.Hour, + } + service, _ := auth.NewJWTService(config) + + testCases := []struct { + name string + userID int + role string + wantErr bool + setupFunc func() string // Added setup function to handle custom token creation + }{ + { + name: "valid refresh token", + userID: 1, + role: "admin", + wantErr: false, + setupFunc: func() string { + token, _ := service.GenerateRefreshToken(1, "admin") + return token + }, + }, + { + name: "expired refresh token", + userID: 1, + role: "admin", + wantErr: true, + setupFunc: func() string { + // Create a token that's already expired + claims := &auth.Claims{ + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(-time.Hour)), // Expired 1 hour ago + IssuedAt: jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)), + NotBefore: jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)), + }, + UserID: 1, + Role: "admin", + Type: auth.RefreshToken, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, _ := token.SignedString([]byte(config.SigningKey)) + return tokenString + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + refreshToken := tc.setupFunc() + newAccessToken, err := service.RefreshAccessToken(refreshToken) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + claims, err := service.ValidateToken(newAccessToken) + if err != nil { + t.Fatalf("failed to validate new access token: %v", err) + } + + if claims.UserID != tc.userID { + t.Errorf("userID = %v, want %v", claims.UserID, tc.userID) + } + if claims.Role != tc.role { + t.Errorf("role = %v, want %v", claims.Role, tc.role) + } + if claims.Type != auth.AccessToken { + t.Errorf("token type = %v, want %v", claims.Type, auth.AccessToken) + } + }) + } +} diff --git a/server/internal/auth/middleware_test.go b/server/internal/auth/middleware_test.go new file mode 100644 index 0000000..153bf33 --- /dev/null +++ b/server/internal/auth/middleware_test.go @@ -0,0 +1,362 @@ +package auth_test + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "novamd/internal/auth" + "novamd/internal/httpcontext" + "novamd/internal/models" +) + +// Complete mockResponseWriter implementation +type mockResponseWriter struct { + headers http.Header + statusCode int + written []byte +} + +func newMockResponseWriter() *mockResponseWriter { + return &mockResponseWriter{ + headers: make(http.Header), + } +} + +func (m *mockResponseWriter) Header() http.Header { + return m.headers +} + +func (m *mockResponseWriter) Write(b []byte) (int, error) { + m.written = b + return len(b), nil +} + +func (m *mockResponseWriter) WriteHeader(statusCode int) { + m.statusCode = statusCode +} + +func TestAuthenticateMiddleware(t *testing.T) { + config := auth.JWTConfig{ + SigningKey: "test-key", + AccessTokenExpiry: 15 * time.Minute, + RefreshTokenExpiry: 24 * time.Hour, + } + jwtService, _ := auth.NewJWTService(config) + middleware := auth.NewMiddleware(jwtService) + + testCases := []struct { + name string + setupAuth func() string + wantStatusCode int + }{ + { + name: "valid token", + setupAuth: func() string { + token, _ := jwtService.GenerateAccessToken(1, "admin") + return token + }, + wantStatusCode: http.StatusOK, + }, + { + name: "missing auth header", + setupAuth: func() string { + return "" + }, + wantStatusCode: http.StatusUnauthorized, + }, + { + name: "invalid auth format", + setupAuth: func() string { + return "InvalidFormat token" + }, + wantStatusCode: http.StatusUnauthorized, + }, + { + name: "invalid token", + setupAuth: func() string { + return "Bearer invalid.token.here" + }, + wantStatusCode: http.StatusUnauthorized, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create test request + req := httptest.NewRequest("GET", "/test", nil) + if token := tc.setupAuth(); token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + + // Create response recorder + w := newMockResponseWriter() + + // Create test handler + nextCalled := false + next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + }) + + // Execute middleware + middleware.Authenticate(next).ServeHTTP(w, req) + + // Check status code + if w.statusCode != tc.wantStatusCode { + t.Errorf("status code = %v, want %v", w.statusCode, tc.wantStatusCode) + } + + // Check if next handler was called when expected + if tc.wantStatusCode == http.StatusOK && !nextCalled { + t.Error("next handler was not called") + } + if tc.wantStatusCode != http.StatusOK && nextCalled { + t.Error("next handler was called when it shouldn't have been") + } + }) + } +} + +func TestRequireRole(t *testing.T) { + config := auth.JWTConfig{ + SigningKey: "test-key", + AccessTokenExpiry: 15 * time.Minute, + RefreshTokenExpiry: 24 * time.Hour, + } + jwtService, _ := auth.NewJWTService(config) + middleware := auth.NewMiddleware(jwtService) + + testCases := []struct { + name string + userRole string + requiredRole string + wantStatusCode int + }{ + { + name: "matching role", + userRole: "admin", + requiredRole: "admin", + wantStatusCode: http.StatusOK, + }, + { + name: "admin accessing other role", + userRole: "admin", + requiredRole: "editor", + wantStatusCode: http.StatusOK, + }, + { + name: "insufficient role", + userRole: "editor", + requiredRole: "admin", + wantStatusCode: http.StatusForbidden, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create context with user claims + ctx := context.WithValue(context.Background(), auth.UserContextKey, auth.UserClaims{ + UserID: 1, + Role: tc.userRole, + }) + req := httptest.NewRequest("GET", "/test", nil).WithContext(ctx) + w := newMockResponseWriter() + + // Create test handler + nextCalled := false + next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + }) + + // Execute middleware + middleware.RequireRole(tc.requiredRole)(next).ServeHTTP(w, req) + + // Check status code + if w.statusCode != tc.wantStatusCode { + t.Errorf("status code = %v, want %v", w.statusCode, tc.wantStatusCode) + } + + // Check if next handler was called when expected + if tc.wantStatusCode == http.StatusOK && !nextCalled { + t.Error("next handler was not called") + } + if tc.wantStatusCode != http.StatusOK && nextCalled { + t.Error("next handler was called when it shouldn't have been") + } + }) + } +} + +func TestRequireWorkspaceAccess(t *testing.T) { + config := auth.JWTConfig{ + SigningKey: "test-key", + } + jwtService, _ := auth.NewJWTService(config) + middleware := auth.NewMiddleware(jwtService) + + testCases := []struct { + name string + setupContext func() *httpcontext.HandlerContext + wantStatusCode int + }{ + { + name: "workspace owner access", + setupContext: func() *httpcontext.HandlerContext { + return &httpcontext.HandlerContext{ + UserID: 1, + UserRole: "editor", + Workspace: &models.Workspace{ + ID: 1, + UserID: 1, // Same as context UserID + }, + } + }, + wantStatusCode: http.StatusOK, + }, + { + name: "admin access to other's workspace", + setupContext: func() *httpcontext.HandlerContext { + return &httpcontext.HandlerContext{ + UserID: 2, + UserRole: "admin", + Workspace: &models.Workspace{ + ID: 1, + UserID: 1, // Different from context UserID + }, + } + }, + wantStatusCode: http.StatusOK, + }, + { + name: "unauthorized access attempt", + setupContext: func() *httpcontext.HandlerContext { + return &httpcontext.HandlerContext{ + UserID: 2, + UserRole: "editor", + Workspace: &models.Workspace{ + ID: 1, + UserID: 1, // Different from context UserID + }, + } + }, + wantStatusCode: http.StatusNotFound, + }, + { + name: "no workspace in context", + setupContext: func() *httpcontext.HandlerContext { + return &httpcontext.HandlerContext{ + UserID: 1, + UserRole: "editor", + Workspace: nil, + } + }, + wantStatusCode: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create request with context + ctx := context.WithValue(context.Background(), httpcontext.HandlerContextKey, tc.setupContext()) + req := httptest.NewRequest("GET", "/test", nil).WithContext(ctx) + w := newMockResponseWriter() + + // Create test handler + nextCalled := false + next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + }) + + // Execute middleware + middleware.RequireWorkspaceAccess(next).ServeHTTP(w, req) + + // Check status code + if w.statusCode != tc.wantStatusCode { + t.Errorf("status code = %v, want %v", w.statusCode, tc.wantStatusCode) + } + + // Check if next handler was called when expected + if tc.wantStatusCode == http.StatusOK && !nextCalled { + t.Error("next handler was not called") + } + if tc.wantStatusCode != http.StatusOK && nextCalled { + t.Error("next handler was called when it shouldn't have been") + } + }) + } +} + +func TestGetUserFromContext(t *testing.T) { + testCases := []struct { + name string + setupCtx func() context.Context + wantUserID int + wantRole string + wantErr bool + errContains string + }{ + { + name: "valid user context", + setupCtx: func() context.Context { + return context.WithValue(context.Background(), auth.UserContextKey, auth.UserClaims{ + UserID: 1, + Role: "admin", + }) + }, + wantUserID: 1, + wantRole: "admin", + wantErr: false, + }, + { + name: "missing user context", + setupCtx: func() context.Context { + return context.Background() + }, + wantErr: true, + errContains: "no user found in context", + }, + { + name: "invalid context value type", + setupCtx: func() context.Context { + return context.WithValue(context.Background(), auth.UserContextKey, "invalid") + }, + wantErr: true, + errContains: "no user found in context", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := tc.setupCtx() + claims, err := auth.GetUserFromContext(ctx) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } else if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("error = %v, want error containing %v", err, tc.errContains) + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if claims.UserID != tc.wantUserID { + t.Errorf("UserID = %v, want %v", claims.UserID, tc.wantUserID) + } + + if claims.Role != tc.wantRole { + t.Errorf("Role = %v, want %v", claims.Role, tc.wantRole) + } + }) + } +} diff --git a/server/internal/auth/session_test.go b/server/internal/auth/session_test.go new file mode 100644 index 0000000..00c7494 --- /dev/null +++ b/server/internal/auth/session_test.go @@ -0,0 +1,298 @@ +package auth_test + +import ( + "errors" + "strings" + "testing" + "time" + + "novamd/internal/auth" + "novamd/internal/models" +) + +// Mock SessionStore +type mockSessionStore struct { + sessions map[string]*models.Session + sessionsByToken map[string]*models.Session // Added index by refresh token +} + +func newMockSessionStore() *mockSessionStore { + return &mockSessionStore{ + sessions: make(map[string]*models.Session), + sessionsByToken: make(map[string]*models.Session), + } +} + +func (m *mockSessionStore) CreateSession(session *models.Session) error { + m.sessions[session.ID] = session + m.sessionsByToken[session.RefreshToken] = session + return nil +} + +func (m *mockSessionStore) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) { + session, exists := m.sessionsByToken[refreshToken] + if !exists { + return nil, errors.New("session not found") + } + if session.ExpiresAt.Before(time.Now()) { + return nil, errors.New("session expired") + } + return session, nil +} + +func (m *mockSessionStore) DeleteSession(sessionID string) error { + session, exists := m.sessions[sessionID] + if !exists { + return errors.New("session not found") + } + delete(m.sessionsByToken, session.RefreshToken) + delete(m.sessions, sessionID) + return nil +} + +func (m *mockSessionStore) CleanExpiredSessions() error { + for id, session := range m.sessions { + if session.ExpiresAt.Before(time.Now()) { + delete(m.sessionsByToken, session.RefreshToken) + delete(m.sessions, id) + } + } + return nil +} + +func TestCreateSession(t *testing.T) { + config := auth.JWTConfig{ + SigningKey: "test-key", + AccessTokenExpiry: 15 * time.Minute, + RefreshTokenExpiry: 24 * time.Hour, + } + jwtService, _ := auth.NewJWTService(config) + mockDB := newMockSessionStore() + sessionService := auth.NewSessionService(mockDB, jwtService) + + testCases := []struct { + name string + userID int + role string + wantErr bool + }{ + { + name: "successful session creation", + userID: 1, + role: "admin", + wantErr: false, + }, + { + name: "another successful session", + userID: 2, + role: "editor", + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + session, accessToken, err := sessionService.CreateSession(tc.userID, tc.role) + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + // Verify session + if session.UserID != tc.userID { + t.Errorf("userID = %v, want %v", session.UserID, tc.userID) + } + + // Verify the session was stored + storedSession, exists := mockDB.sessions[session.ID] + if !exists { + t.Error("session was not stored in database") + } + if storedSession.RefreshToken != session.RefreshToken { + t.Error("stored refresh token doesn't match") + } + + // Verify access token + claims, err := jwtService.ValidateToken(accessToken) + if err != nil { + t.Errorf("failed to validate access token: %v", err) + return + } + if claims.UserID != tc.userID { + t.Errorf("access token userID = %v, want %v", claims.UserID, tc.userID) + } + if claims.Role != tc.role { + t.Errorf("access token role = %v, want %v", claims.Role, tc.role) + } + if claims.Type != auth.AccessToken { + t.Errorf("token type = %v, want access token", claims.Type) + } + }) + } +} + +func TestRefreshSession(t *testing.T) { + config := auth.JWTConfig{ + SigningKey: "test-key", + AccessTokenExpiry: 15 * time.Minute, + RefreshTokenExpiry: 24 * time.Hour, + } + jwtService, _ := auth.NewJWTService(config) + mockDB := newMockSessionStore() + sessionService := auth.NewSessionService(mockDB, jwtService) + + testCases := []struct { + name string + setupSession func() string + wantErr bool + errorContains string + }{ + { + name: "valid refresh token", + setupSession: func() string { + token, _ := jwtService.GenerateRefreshToken(1, "admin") + session := &models.Session{ + ID: "test-session-1", + UserID: 1, + RefreshToken: token, + ExpiresAt: time.Now().Add(24 * time.Hour), + CreatedAt: time.Now(), + } + mockDB.CreateSession(session) + return token + }, + wantErr: false, + }, + { + name: "expired refresh token", + setupSession: func() string { + token, _ := jwtService.GenerateRefreshToken(1, "admin") + session := &models.Session{ + ID: "test-session-2", + UserID: 1, + RefreshToken: token, + ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired + CreatedAt: time.Now().Add(-2 * time.Hour), + } + mockDB.CreateSession(session) + return token + }, + wantErr: true, + errorContains: "session expired", + }, + { + name: "non-existent refresh token", + setupSession: func() string { + return "non-existent-token" + }, + wantErr: true, + errorContains: "session not found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + refreshToken := tc.setupSession() + newAccessToken, err := sessionService.RefreshSession(refreshToken) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } else if tc.errorContains != "" && !strings.Contains(err.Error(), tc.errorContains) { + t.Errorf("error = %v, want error containing %v", err, tc.errorContains) + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + // Verify new access token + claims, err := jwtService.ValidateToken(newAccessToken) + if err != nil { + t.Errorf("failed to validate new access token: %v", err) + return + } + if claims.Type != auth.AccessToken { + t.Errorf("token type = %v, want access token", claims.Type) + } + }) + } +} + +func TestInvalidateSession(t *testing.T) { + config := auth.JWTConfig{ + SigningKey: "test-key", + AccessTokenExpiry: 15 * time.Minute, + RefreshTokenExpiry: 24 * time.Hour, + } + jwtService, _ := auth.NewJWTService(config) + mockDB := newMockSessionStore() + sessionService := auth.NewSessionService(mockDB, jwtService) + + testCases := []struct { + name string + setupSession func() string + wantErr bool + errorContains string + }{ + { + name: "valid session invalidation", + setupSession: func() string { + session := &models.Session{ + ID: "test-session-1", + UserID: 1, + RefreshToken: "valid-token", + ExpiresAt: time.Now().Add(24 * time.Hour), + CreatedAt: time.Now(), + } + mockDB.CreateSession(session) + return session.ID + }, + wantErr: false, + }, + { + name: "non-existent session", + setupSession: func() string { + return "non-existent-session-id" + }, + wantErr: true, + errorContains: "session not found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + sessionID := tc.setupSession() + err := sessionService.InvalidateSession(sessionID) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), tc.errorContains) { + t.Errorf("error = %v, want error containing %v", err, tc.errorContains) + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + // Verify session was removed + if _, exists := mockDB.sessions[sessionID]; exists { + t.Error("session still exists after invalidation") + } + }) + } +}