mirror of
https://github.com/lordmathis/lemma.git
synced 2025-11-05 23:44:22 +00:00
Implement auth tests
This commit is contained in:
221
server/internal/auth/jwt_test.go
Normal file
221
server/internal/auth/jwt_test.go
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
package auth_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"novamd/internal/auth"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
// jwt_test.go tests
|
||||||
|
|
||||||
|
func TestNewJWTService(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
config auth.JWTConfig
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid configuration",
|
||||||
|
config: auth.JWTConfig{
|
||||||
|
SigningKey: "test-key",
|
||||||
|
AccessTokenExpiry: 15 * time.Minute,
|
||||||
|
RefreshTokenExpiry: 24 * time.Hour,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing signing key",
|
||||||
|
config: auth.JWTConfig{
|
||||||
|
AccessTokenExpiry: 15 * time.Minute,
|
||||||
|
RefreshTokenExpiry: 24 * time.Hour,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero expiry times",
|
||||||
|
config: auth.JWTConfig{
|
||||||
|
SigningKey: "test-key",
|
||||||
|
},
|
||||||
|
wantErr: false, // Should use default values
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
service, err := auth.NewJWTService(tc.config)
|
||||||
|
if tc.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error, got nil")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if service == nil {
|
||||||
|
t.Error("expected service, got nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateAndValidateToken(t *testing.T) {
|
||||||
|
config := auth.JWTConfig{
|
||||||
|
SigningKey: "test-key",
|
||||||
|
AccessTokenExpiry: 15 * time.Minute,
|
||||||
|
RefreshTokenExpiry: 24 * time.Hour,
|
||||||
|
}
|
||||||
|
service, _ := auth.NewJWTService(config)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
userID int
|
||||||
|
role string
|
||||||
|
tokenType auth.TokenType
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid access token",
|
||||||
|
userID: 1,
|
||||||
|
role: "admin",
|
||||||
|
tokenType: auth.AccessToken,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid refresh token",
|
||||||
|
userID: 1,
|
||||||
|
role: "editor",
|
||||||
|
tokenType: auth.RefreshToken,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
var token string
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Generate token based on type
|
||||||
|
if tc.tokenType == auth.AccessToken {
|
||||||
|
token, err = service.GenerateAccessToken(tc.userID, tc.role)
|
||||||
|
} else {
|
||||||
|
token, err = service.GenerateRefreshToken(tc.userID, tc.role)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to generate token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate token
|
||||||
|
claims, err := service.ValidateToken(token)
|
||||||
|
if tc.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error, got nil")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify claims
|
||||||
|
if claims.UserID != tc.userID {
|
||||||
|
t.Errorf("userID = %v, want %v", claims.UserID, tc.userID)
|
||||||
|
}
|
||||||
|
if claims.Role != tc.role {
|
||||||
|
t.Errorf("role = %v, want %v", claims.Role, tc.role)
|
||||||
|
}
|
||||||
|
if claims.Type != tc.tokenType {
|
||||||
|
t.Errorf("type = %v, want %v", claims.Type, tc.tokenType)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshAccessToken(t *testing.T) {
|
||||||
|
config := auth.JWTConfig{
|
||||||
|
SigningKey: "test-key",
|
||||||
|
AccessTokenExpiry: 15 * time.Minute,
|
||||||
|
RefreshTokenExpiry: 24 * time.Hour,
|
||||||
|
}
|
||||||
|
service, _ := auth.NewJWTService(config)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
userID int
|
||||||
|
role string
|
||||||
|
wantErr bool
|
||||||
|
setupFunc func() string // Added setup function to handle custom token creation
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid refresh token",
|
||||||
|
userID: 1,
|
||||||
|
role: "admin",
|
||||||
|
wantErr: false,
|
||||||
|
setupFunc: func() string {
|
||||||
|
token, _ := service.GenerateRefreshToken(1, "admin")
|
||||||
|
return token
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "expired refresh token",
|
||||||
|
userID: 1,
|
||||||
|
role: "admin",
|
||||||
|
wantErr: true,
|
||||||
|
setupFunc: func() string {
|
||||||
|
// Create a token that's already expired
|
||||||
|
claims := &auth.Claims{
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(-time.Hour)), // Expired 1 hour ago
|
||||||
|
IssuedAt: jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)),
|
||||||
|
NotBefore: jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)),
|
||||||
|
},
|
||||||
|
UserID: 1,
|
||||||
|
Role: "admin",
|
||||||
|
Type: auth.RefreshToken,
|
||||||
|
}
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
tokenString, _ := token.SignedString([]byte(config.SigningKey))
|
||||||
|
return tokenString
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
refreshToken := tc.setupFunc()
|
||||||
|
newAccessToken, err := service.RefreshAccessToken(refreshToken)
|
||||||
|
|
||||||
|
if tc.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error, got nil")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, err := service.ValidateToken(newAccessToken)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to validate new access token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims.UserID != tc.userID {
|
||||||
|
t.Errorf("userID = %v, want %v", claims.UserID, tc.userID)
|
||||||
|
}
|
||||||
|
if claims.Role != tc.role {
|
||||||
|
t.Errorf("role = %v, want %v", claims.Role, tc.role)
|
||||||
|
}
|
||||||
|
if claims.Type != auth.AccessToken {
|
||||||
|
t.Errorf("token type = %v, want %v", claims.Type, auth.AccessToken)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
362
server/internal/auth/middleware_test.go
Normal file
362
server/internal/auth/middleware_test.go
Normal file
@@ -0,0 +1,362 @@
|
|||||||
|
package auth_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"novamd/internal/auth"
|
||||||
|
"novamd/internal/httpcontext"
|
||||||
|
"novamd/internal/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Complete mockResponseWriter implementation
|
||||||
|
type mockResponseWriter struct {
|
||||||
|
headers http.Header
|
||||||
|
statusCode int
|
||||||
|
written []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockResponseWriter() *mockResponseWriter {
|
||||||
|
return &mockResponseWriter{
|
||||||
|
headers: make(http.Header),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockResponseWriter) Header() http.Header {
|
||||||
|
return m.headers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockResponseWriter) Write(b []byte) (int, error) {
|
||||||
|
m.written = b
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockResponseWriter) WriteHeader(statusCode int) {
|
||||||
|
m.statusCode = statusCode
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthenticateMiddleware(t *testing.T) {
|
||||||
|
config := auth.JWTConfig{
|
||||||
|
SigningKey: "test-key",
|
||||||
|
AccessTokenExpiry: 15 * time.Minute,
|
||||||
|
RefreshTokenExpiry: 24 * time.Hour,
|
||||||
|
}
|
||||||
|
jwtService, _ := auth.NewJWTService(config)
|
||||||
|
middleware := auth.NewMiddleware(jwtService)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
setupAuth func() string
|
||||||
|
wantStatusCode int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid token",
|
||||||
|
setupAuth: func() string {
|
||||||
|
token, _ := jwtService.GenerateAccessToken(1, "admin")
|
||||||
|
return token
|
||||||
|
},
|
||||||
|
wantStatusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing auth header",
|
||||||
|
setupAuth: func() string {
|
||||||
|
return ""
|
||||||
|
},
|
||||||
|
wantStatusCode: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid auth format",
|
||||||
|
setupAuth: func() string {
|
||||||
|
return "InvalidFormat token"
|
||||||
|
},
|
||||||
|
wantStatusCode: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid token",
|
||||||
|
setupAuth: func() string {
|
||||||
|
return "Bearer invalid.token.here"
|
||||||
|
},
|
||||||
|
wantStatusCode: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// Create test request
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
if token := tc.setupAuth(); token != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create response recorder
|
||||||
|
w := newMockResponseWriter()
|
||||||
|
|
||||||
|
// Create test handler
|
||||||
|
nextCalled := false
|
||||||
|
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
nextCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Execute middleware
|
||||||
|
middleware.Authenticate(next).ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// Check status code
|
||||||
|
if w.statusCode != tc.wantStatusCode {
|
||||||
|
t.Errorf("status code = %v, want %v", w.statusCode, tc.wantStatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if next handler was called when expected
|
||||||
|
if tc.wantStatusCode == http.StatusOK && !nextCalled {
|
||||||
|
t.Error("next handler was not called")
|
||||||
|
}
|
||||||
|
if tc.wantStatusCode != http.StatusOK && nextCalled {
|
||||||
|
t.Error("next handler was called when it shouldn't have been")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireRole(t *testing.T) {
|
||||||
|
config := auth.JWTConfig{
|
||||||
|
SigningKey: "test-key",
|
||||||
|
AccessTokenExpiry: 15 * time.Minute,
|
||||||
|
RefreshTokenExpiry: 24 * time.Hour,
|
||||||
|
}
|
||||||
|
jwtService, _ := auth.NewJWTService(config)
|
||||||
|
middleware := auth.NewMiddleware(jwtService)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
userRole string
|
||||||
|
requiredRole string
|
||||||
|
wantStatusCode int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "matching role",
|
||||||
|
userRole: "admin",
|
||||||
|
requiredRole: "admin",
|
||||||
|
wantStatusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "admin accessing other role",
|
||||||
|
userRole: "admin",
|
||||||
|
requiredRole: "editor",
|
||||||
|
wantStatusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "insufficient role",
|
||||||
|
userRole: "editor",
|
||||||
|
requiredRole: "admin",
|
||||||
|
wantStatusCode: http.StatusForbidden,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// Create context with user claims
|
||||||
|
ctx := context.WithValue(context.Background(), auth.UserContextKey, auth.UserClaims{
|
||||||
|
UserID: 1,
|
||||||
|
Role: tc.userRole,
|
||||||
|
})
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil).WithContext(ctx)
|
||||||
|
w := newMockResponseWriter()
|
||||||
|
|
||||||
|
// Create test handler
|
||||||
|
nextCalled := false
|
||||||
|
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
nextCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Execute middleware
|
||||||
|
middleware.RequireRole(tc.requiredRole)(next).ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// Check status code
|
||||||
|
if w.statusCode != tc.wantStatusCode {
|
||||||
|
t.Errorf("status code = %v, want %v", w.statusCode, tc.wantStatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if next handler was called when expected
|
||||||
|
if tc.wantStatusCode == http.StatusOK && !nextCalled {
|
||||||
|
t.Error("next handler was not called")
|
||||||
|
}
|
||||||
|
if tc.wantStatusCode != http.StatusOK && nextCalled {
|
||||||
|
t.Error("next handler was called when it shouldn't have been")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireWorkspaceAccess(t *testing.T) {
|
||||||
|
config := auth.JWTConfig{
|
||||||
|
SigningKey: "test-key",
|
||||||
|
}
|
||||||
|
jwtService, _ := auth.NewJWTService(config)
|
||||||
|
middleware := auth.NewMiddleware(jwtService)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
setupContext func() *httpcontext.HandlerContext
|
||||||
|
wantStatusCode int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "workspace owner access",
|
||||||
|
setupContext: func() *httpcontext.HandlerContext {
|
||||||
|
return &httpcontext.HandlerContext{
|
||||||
|
UserID: 1,
|
||||||
|
UserRole: "editor",
|
||||||
|
Workspace: &models.Workspace{
|
||||||
|
ID: 1,
|
||||||
|
UserID: 1, // Same as context UserID
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantStatusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "admin access to other's workspace",
|
||||||
|
setupContext: func() *httpcontext.HandlerContext {
|
||||||
|
return &httpcontext.HandlerContext{
|
||||||
|
UserID: 2,
|
||||||
|
UserRole: "admin",
|
||||||
|
Workspace: &models.Workspace{
|
||||||
|
ID: 1,
|
||||||
|
UserID: 1, // Different from context UserID
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantStatusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized access attempt",
|
||||||
|
setupContext: func() *httpcontext.HandlerContext {
|
||||||
|
return &httpcontext.HandlerContext{
|
||||||
|
UserID: 2,
|
||||||
|
UserRole: "editor",
|
||||||
|
Workspace: &models.Workspace{
|
||||||
|
ID: 1,
|
||||||
|
UserID: 1, // Different from context UserID
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantStatusCode: http.StatusNotFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no workspace in context",
|
||||||
|
setupContext: func() *httpcontext.HandlerContext {
|
||||||
|
return &httpcontext.HandlerContext{
|
||||||
|
UserID: 1,
|
||||||
|
UserRole: "editor",
|
||||||
|
Workspace: nil,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantStatusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// Create request with context
|
||||||
|
ctx := context.WithValue(context.Background(), httpcontext.HandlerContextKey, tc.setupContext())
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil).WithContext(ctx)
|
||||||
|
w := newMockResponseWriter()
|
||||||
|
|
||||||
|
// Create test handler
|
||||||
|
nextCalled := false
|
||||||
|
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
nextCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Execute middleware
|
||||||
|
middleware.RequireWorkspaceAccess(next).ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// Check status code
|
||||||
|
if w.statusCode != tc.wantStatusCode {
|
||||||
|
t.Errorf("status code = %v, want %v", w.statusCode, tc.wantStatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if next handler was called when expected
|
||||||
|
if tc.wantStatusCode == http.StatusOK && !nextCalled {
|
||||||
|
t.Error("next handler was not called")
|
||||||
|
}
|
||||||
|
if tc.wantStatusCode != http.StatusOK && nextCalled {
|
||||||
|
t.Error("next handler was called when it shouldn't have been")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserFromContext(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
setupCtx func() context.Context
|
||||||
|
wantUserID int
|
||||||
|
wantRole string
|
||||||
|
wantErr bool
|
||||||
|
errContains string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid user context",
|
||||||
|
setupCtx: func() context.Context {
|
||||||
|
return context.WithValue(context.Background(), auth.UserContextKey, auth.UserClaims{
|
||||||
|
UserID: 1,
|
||||||
|
Role: "admin",
|
||||||
|
})
|
||||||
|
},
|
||||||
|
wantUserID: 1,
|
||||||
|
wantRole: "admin",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing user context",
|
||||||
|
setupCtx: func() context.Context {
|
||||||
|
return context.Background()
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errContains: "no user found in context",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid context value type",
|
||||||
|
setupCtx: func() context.Context {
|
||||||
|
return context.WithValue(context.Background(), auth.UserContextKey, "invalid")
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errContains: "no user found in context",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
ctx := tc.setupCtx()
|
||||||
|
claims, err := auth.GetUserFromContext(ctx)
|
||||||
|
|
||||||
|
if tc.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error, got nil")
|
||||||
|
} else if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) {
|
||||||
|
t.Errorf("error = %v, want error containing %v", err, tc.errContains)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims.UserID != tc.wantUserID {
|
||||||
|
t.Errorf("UserID = %v, want %v", claims.UserID, tc.wantUserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims.Role != tc.wantRole {
|
||||||
|
t.Errorf("Role = %v, want %v", claims.Role, tc.wantRole)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
298
server/internal/auth/session_test.go
Normal file
298
server/internal/auth/session_test.go
Normal file
@@ -0,0 +1,298 @@
|
|||||||
|
package auth_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"novamd/internal/auth"
|
||||||
|
"novamd/internal/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Mock SessionStore
|
||||||
|
type mockSessionStore struct {
|
||||||
|
sessions map[string]*models.Session
|
||||||
|
sessionsByToken map[string]*models.Session // Added index by refresh token
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockSessionStore() *mockSessionStore {
|
||||||
|
return &mockSessionStore{
|
||||||
|
sessions: make(map[string]*models.Session),
|
||||||
|
sessionsByToken: make(map[string]*models.Session),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSessionStore) CreateSession(session *models.Session) error {
|
||||||
|
m.sessions[session.ID] = session
|
||||||
|
m.sessionsByToken[session.RefreshToken] = session
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSessionStore) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) {
|
||||||
|
session, exists := m.sessionsByToken[refreshToken]
|
||||||
|
if !exists {
|
||||||
|
return nil, errors.New("session not found")
|
||||||
|
}
|
||||||
|
if session.ExpiresAt.Before(time.Now()) {
|
||||||
|
return nil, errors.New("session expired")
|
||||||
|
}
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSessionStore) DeleteSession(sessionID string) error {
|
||||||
|
session, exists := m.sessions[sessionID]
|
||||||
|
if !exists {
|
||||||
|
return errors.New("session not found")
|
||||||
|
}
|
||||||
|
delete(m.sessionsByToken, session.RefreshToken)
|
||||||
|
delete(m.sessions, sessionID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSessionStore) CleanExpiredSessions() error {
|
||||||
|
for id, session := range m.sessions {
|
||||||
|
if session.ExpiresAt.Before(time.Now()) {
|
||||||
|
delete(m.sessionsByToken, session.RefreshToken)
|
||||||
|
delete(m.sessions, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateSession(t *testing.T) {
|
||||||
|
config := auth.JWTConfig{
|
||||||
|
SigningKey: "test-key",
|
||||||
|
AccessTokenExpiry: 15 * time.Minute,
|
||||||
|
RefreshTokenExpiry: 24 * time.Hour,
|
||||||
|
}
|
||||||
|
jwtService, _ := auth.NewJWTService(config)
|
||||||
|
mockDB := newMockSessionStore()
|
||||||
|
sessionService := auth.NewSessionService(mockDB, jwtService)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
userID int
|
||||||
|
role string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "successful session creation",
|
||||||
|
userID: 1,
|
||||||
|
role: "admin",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "another successful session",
|
||||||
|
userID: 2,
|
||||||
|
role: "editor",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
session, accessToken, err := sessionService.CreateSession(tc.userID, tc.role)
|
||||||
|
if tc.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error, got nil")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify session
|
||||||
|
if session.UserID != tc.userID {
|
||||||
|
t.Errorf("userID = %v, want %v", session.UserID, tc.userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the session was stored
|
||||||
|
storedSession, exists := mockDB.sessions[session.ID]
|
||||||
|
if !exists {
|
||||||
|
t.Error("session was not stored in database")
|
||||||
|
}
|
||||||
|
if storedSession.RefreshToken != session.RefreshToken {
|
||||||
|
t.Error("stored refresh token doesn't match")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify access token
|
||||||
|
claims, err := jwtService.ValidateToken(accessToken)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to validate access token: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if claims.UserID != tc.userID {
|
||||||
|
t.Errorf("access token userID = %v, want %v", claims.UserID, tc.userID)
|
||||||
|
}
|
||||||
|
if claims.Role != tc.role {
|
||||||
|
t.Errorf("access token role = %v, want %v", claims.Role, tc.role)
|
||||||
|
}
|
||||||
|
if claims.Type != auth.AccessToken {
|
||||||
|
t.Errorf("token type = %v, want access token", claims.Type)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshSession(t *testing.T) {
|
||||||
|
config := auth.JWTConfig{
|
||||||
|
SigningKey: "test-key",
|
||||||
|
AccessTokenExpiry: 15 * time.Minute,
|
||||||
|
RefreshTokenExpiry: 24 * time.Hour,
|
||||||
|
}
|
||||||
|
jwtService, _ := auth.NewJWTService(config)
|
||||||
|
mockDB := newMockSessionStore()
|
||||||
|
sessionService := auth.NewSessionService(mockDB, jwtService)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
setupSession func() string
|
||||||
|
wantErr bool
|
||||||
|
errorContains string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid refresh token",
|
||||||
|
setupSession: func() string {
|
||||||
|
token, _ := jwtService.GenerateRefreshToken(1, "admin")
|
||||||
|
session := &models.Session{
|
||||||
|
ID: "test-session-1",
|
||||||
|
UserID: 1,
|
||||||
|
RefreshToken: token,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
mockDB.CreateSession(session)
|
||||||
|
return token
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "expired refresh token",
|
||||||
|
setupSession: func() string {
|
||||||
|
token, _ := jwtService.GenerateRefreshToken(1, "admin")
|
||||||
|
session := &models.Session{
|
||||||
|
ID: "test-session-2",
|
||||||
|
UserID: 1,
|
||||||
|
RefreshToken: token,
|
||||||
|
ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired
|
||||||
|
CreatedAt: time.Now().Add(-2 * time.Hour),
|
||||||
|
}
|
||||||
|
mockDB.CreateSession(session)
|
||||||
|
return token
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errorContains: "session expired",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-existent refresh token",
|
||||||
|
setupSession: func() string {
|
||||||
|
return "non-existent-token"
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errorContains: "session not found",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
refreshToken := tc.setupSession()
|
||||||
|
newAccessToken, err := sessionService.RefreshSession(refreshToken)
|
||||||
|
|
||||||
|
if tc.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error, got nil")
|
||||||
|
} else if tc.errorContains != "" && !strings.Contains(err.Error(), tc.errorContains) {
|
||||||
|
t.Errorf("error = %v, want error containing %v", err, tc.errorContains)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify new access token
|
||||||
|
claims, err := jwtService.ValidateToken(newAccessToken)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to validate new access token: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if claims.Type != auth.AccessToken {
|
||||||
|
t.Errorf("token type = %v, want access token", claims.Type)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInvalidateSession(t *testing.T) {
|
||||||
|
config := auth.JWTConfig{
|
||||||
|
SigningKey: "test-key",
|
||||||
|
AccessTokenExpiry: 15 * time.Minute,
|
||||||
|
RefreshTokenExpiry: 24 * time.Hour,
|
||||||
|
}
|
||||||
|
jwtService, _ := auth.NewJWTService(config)
|
||||||
|
mockDB := newMockSessionStore()
|
||||||
|
sessionService := auth.NewSessionService(mockDB, jwtService)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
setupSession func() string
|
||||||
|
wantErr bool
|
||||||
|
errorContains string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid session invalidation",
|
||||||
|
setupSession: func() string {
|
||||||
|
session := &models.Session{
|
||||||
|
ID: "test-session-1",
|
||||||
|
UserID: 1,
|
||||||
|
RefreshToken: "valid-token",
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
mockDB.CreateSession(session)
|
||||||
|
return session.ID
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-existent session",
|
||||||
|
setupSession: func() string {
|
||||||
|
return "non-existent-session-id"
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errorContains: "session not found",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
sessionID := tc.setupSession()
|
||||||
|
err := sessionService.InvalidateSession(sessionID)
|
||||||
|
|
||||||
|
if tc.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error, got nil")
|
||||||
|
} else if !strings.Contains(err.Error(), tc.errorContains) {
|
||||||
|
t.Errorf("error = %v, want error containing %v", err, tc.errorContains)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify session was removed
|
||||||
|
if _, exists := mockDB.sessions[sessionID]; exists {
|
||||||
|
t.Error("session still exists after invalidation")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user