Update session and cookie managers

This commit is contained in:
2024-12-07 21:19:02 +01:00
parent de9e9102db
commit 8a4508e29f
10 changed files with 111 additions and 56 deletions

View File

@@ -40,7 +40,7 @@ func initDatabase(cfg *Config, secretsService secrets.Service) (db.Database, err
} }
// initAuth initializes JWT and session services // initAuth initializes JWT and session services
func initAuth(cfg *Config, database db.Database) (auth.JWTManager, *auth.SessionService, auth.CookieService, error) { func initAuth(cfg *Config, database db.Database) (auth.JWTManager, auth.SessionManager, auth.CookieManager, error) {
// Get or generate JWT signing key // Get or generate JWT signing key
signingKey := cfg.JWTSigningKey signingKey := cfg.JWTSigningKey
if signingKey == "" { if signingKey == "" {
@@ -62,12 +62,12 @@ func initAuth(cfg *Config, database db.Database) (auth.JWTManager, *auth.Session
} }
// Initialize session service // Initialize session service
sessionService := auth.NewSessionService(database, jwtManager) sessionManager := auth.NewSessionService(database, jwtManager)
// Cookie service // Cookie service
cookieService := auth.NewCookieService(cfg.IsDevelopment, cfg.Domain) cookieService := auth.NewCookieService(cfg.IsDevelopment, cfg.Domain)
return jwtManager, sessionService, cookieService, nil return jwtManager, sessionManager, cookieService, nil
} }
// setupAdminUser creates the admin user if it doesn't exist // setupAdminUser creates the admin user if it doesn't exist

View File

@@ -12,8 +12,8 @@ type Options struct {
Database db.Database Database db.Database
Storage storage.Manager Storage storage.Manager
JWTManager auth.JWTManager JWTManager auth.JWTManager
SessionService *auth.SessionService SessionManager auth.SessionManager
CookieService auth.CookieService CookieService auth.CookieManager
} }
// DefaultOptions creates server options with default configuration // DefaultOptions creates server options with default configuration
@@ -49,7 +49,7 @@ func DefaultOptions(cfg *Config) (*Options, error) {
Database: database, Database: database,
Storage: storageManager, Storage: storageManager,
JWTManager: jwtManager, JWTManager: jwtManager,
SessionService: sessionService, SessionManager: sessionService,
CookieService: cookieService, CookieService: cookieService,
}, nil }, nil
} }

View File

@@ -48,7 +48,7 @@ func setupRouter(o Options) *chi.Mux {
} }
// Initialize auth middleware and handler // Initialize auth middleware and handler
authMiddleware := auth.NewMiddleware(o.JWTManager) authMiddleware := auth.NewMiddleware(o.JWTManager, o.SessionManager, o.CookieService)
handler := &handlers.Handler{ handler := &handlers.Handler{
DB: o.Database, DB: o.Database,
Storage: o.Storage, Storage: o.Storage,
@@ -72,8 +72,8 @@ func setupRouter(o Options) *chi.Mux {
// Public routes (no authentication required) // Public routes (no authentication required)
r.Group(func(r chi.Router) { r.Group(func(r chi.Router) {
r.Post("/auth/login", handler.Login(o.SessionService, o.CookieService)) r.Post("/auth/login", handler.Login(o.SessionManager, o.CookieService))
r.Post("/auth/refresh", handler.RefreshToken(o.SessionService, o.CookieService)) r.Post("/auth/refresh", handler.RefreshToken(o.SessionManager, o.CookieService))
}) })
// Protected routes (authentication required) // Protected routes (authentication required)
@@ -82,7 +82,7 @@ func setupRouter(o Options) *chi.Mux {
r.Use(context.WithUserContextMiddleware) r.Use(context.WithUserContextMiddleware)
// Auth routes // Auth routes
r.Post("/auth/logout", handler.Logout(o.SessionService, o.CookieService)) r.Post("/auth/logout", handler.Logout(o.SessionManager, o.CookieService))
r.Get("/auth/me", handler.GetCurrentUser()) r.Get("/auth/me", handler.GetCurrentUser())
// User profile routes // User profile routes

View File

@@ -5,8 +5,8 @@ import (
"net/http" "net/http"
) )
// CookieService interface defines methods for generating cookies // CookieManager interface defines methods for generating cookies
type CookieService interface { type CookieManager interface {
GenerateAccessTokenCookie(token string) *http.Cookie GenerateAccessTokenCookie(token string) *http.Cookie
GenerateRefreshTokenCookie(token string) *http.Cookie GenerateRefreshTokenCookie(token string) *http.Cookie
GenerateCSRFCookie(token string) *http.Cookie GenerateCSRFCookie(token string) *http.Cookie
@@ -14,14 +14,14 @@ type CookieService interface {
} }
// CookieService // CookieService
type cookieService struct { type cookieManager struct {
Domain string Domain string
Secure bool Secure bool
SameSite http.SameSite SameSite http.SameSite
} }
// NewCookieService creates a new cookie service // NewCookieService creates a new cookie service
func NewCookieService(isDevelopment bool, domain string) CookieService { func NewCookieService(isDevelopment bool, domain string) CookieManager {
secure := !isDevelopment secure := !isDevelopment
var sameSite http.SameSite var sameSite http.SameSite
@@ -31,7 +31,7 @@ func NewCookieService(isDevelopment bool, domain string) CookieService {
sameSite = http.SameSiteStrictMode sameSite = http.SameSiteStrictMode
} }
return &cookieService{ return &cookieManager{
Domain: domain, Domain: domain,
Secure: secure, Secure: secure,
SameSite: sameSite, SameSite: sameSite,
@@ -39,7 +39,7 @@ func NewCookieService(isDevelopment bool, domain string) CookieService {
} }
// GenerateAccessTokenCookie creates a new cookie for the access token // GenerateAccessTokenCookie creates a new cookie for the access token
func (c *cookieService) GenerateAccessTokenCookie(token string) *http.Cookie { func (c *cookieManager) GenerateAccessTokenCookie(token string) *http.Cookie {
return &http.Cookie{ return &http.Cookie{
Name: "access_token", Name: "access_token",
Value: token, Value: token,
@@ -52,7 +52,7 @@ func (c *cookieService) GenerateAccessTokenCookie(token string) *http.Cookie {
} }
// GenerateRefreshTokenCookie creates a new cookie for the refresh token // GenerateRefreshTokenCookie creates a new cookie for the refresh token
func (c *cookieService) GenerateRefreshTokenCookie(token string) *http.Cookie { func (c *cookieManager) GenerateRefreshTokenCookie(token string) *http.Cookie {
return &http.Cookie{ return &http.Cookie{
Name: "refresh_token", Name: "refresh_token",
Value: token, Value: token,
@@ -65,7 +65,7 @@ func (c *cookieService) GenerateRefreshTokenCookie(token string) *http.Cookie {
} }
// GenerateCSRFCookie creates a new cookie for the CSRF token // GenerateCSRFCookie creates a new cookie for the CSRF token
func (c *cookieService) GenerateCSRFCookie(token string) *http.Cookie { func (c *cookieManager) GenerateCSRFCookie(token string) *http.Cookie {
return &http.Cookie{ return &http.Cookie{
Name: "csrf_token", Name: "csrf_token",
Value: token, Value: token,
@@ -78,7 +78,7 @@ func (c *cookieService) GenerateCSRFCookie(token string) *http.Cookie {
} }
// InvalidateCookie creates a new cookie with a MaxAge of -1 to invalidate the cookie // InvalidateCookie creates a new cookie with a MaxAge of -1 to invalidate the cookie
func (c *cookieService) InvalidateCookie(cookieType string) *http.Cookie { func (c *cookieManager) InvalidateCookie(cookieType string) *http.Cookie {
return &http.Cookie{ return &http.Cookie{
Name: cookieType, Name: cookieType,
Value: "", Value: "",

View File

@@ -10,12 +10,16 @@ import (
// Middleware handles JWT authentication for protected routes // Middleware handles JWT authentication for protected routes
type Middleware struct { type Middleware struct {
jwtManager JWTManager jwtManager JWTManager
sessionManager SessionManager
cookieManager CookieManager
} }
// NewMiddleware creates a new authentication middleware // NewMiddleware creates a new authentication middleware
func NewMiddleware(jwtManager JWTManager) *Middleware { func NewMiddleware(jwtManager JWTManager, sessionManager SessionManager, cookieManager CookieManager) *Middleware {
return &Middleware{ return &Middleware{
jwtManager: jwtManager, jwtManager: jwtManager,
sessionManager: sessionManager,
cookieManager: cookieManager,
} }
} }
@@ -42,6 +46,16 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
return return
} }
// Check if session is still valid in database
session, err := m.sessionManager.ValidateSession(claims.ID)
if err != nil || session == nil {
m.cookieManager.InvalidateCookie("access_token")
m.cookieManager.InvalidateCookie("refresh_token")
m.cookieManager.InvalidateCookie("csrf_token")
http.Error(w, "Session invalid or expired", http.StatusUnauthorized)
return
}
// Add CSRF check for non-GET requests // Add CSRF check for non-GET requests
if r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodOptions { if r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodOptions {
csrfCookie, err := r.Cookie("csrf_token") csrfCookie, err := r.Cookie("csrf_token")

View File

@@ -9,22 +9,30 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
// SessionService manages user sessions in the database type SessionManager interface {
type SessionService struct { CreateSession(userID int, role string) (*models.Session, string, error)
RefreshSession(refreshToken string) (string, error)
ValidateSession(sessionID string) (*models.Session, error)
InvalidateSession(token string) error
CleanExpiredSessions() error
}
// sessionManager manages user sessions in the database
type sessionManager struct {
db db.SessionStore // Database store for sessions db db.SessionStore // Database store for sessions
jwtManager JWTManager // JWT Manager for token operations jwtManager JWTManager // JWT Manager for token operations
} }
// NewSessionService creates a new session service with the given database and JWT manager // NewSessionService creates a new session service with the given database and JWT manager
func NewSessionService(db db.SessionStore, jwtManager JWTManager) *SessionService { func NewSessionService(db db.SessionStore, jwtManager JWTManager) *sessionManager {
return &SessionService{ return &sessionManager{
db: db, db: db,
jwtManager: jwtManager, jwtManager: jwtManager,
} }
} }
// CreateSession creates a new user session for a user with the given userID and role // CreateSession creates a new user session for a user with the given userID and role
func (s *SessionService) CreateSession(userID int, role string) (*models.Session, string, error) { func (s *sessionManager) CreateSession(userID int, role string) (*models.Session, string, error) {
// Generate both access and refresh tokens // Generate both access and refresh tokens
accessToken, err := s.jwtManager.GenerateAccessToken(userID, role) accessToken, err := s.jwtManager.GenerateAccessToken(userID, role)
if err != nil { if err != nil {
@@ -60,7 +68,7 @@ func (s *SessionService) CreateSession(userID int, role string) (*models.Session
} }
// RefreshSession creates a new access token using a refreshToken // RefreshSession creates a new access token using a refreshToken
func (s *SessionService) RefreshSession(refreshToken string) (string, error) { func (s *sessionManager) RefreshSession(refreshToken string) (string, error) {
// Get session from database first // Get session from database first
session, err := s.db.GetSessionByRefreshToken(refreshToken) session, err := s.db.GetSessionByRefreshToken(refreshToken)
if err != nil { if err != nil {
@@ -82,8 +90,20 @@ func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role) return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role)
} }
// ValidateSession checks if a session with the given sessionID is valid
func (s *sessionManager) ValidateSession(sessionID string) (*models.Session, error) {
// Get the session from the database
session, err := s.db.GetSessionByID(sessionID)
if err != nil {
return nil, fmt.Errorf("failed to get session: %w", err)
}
return session, nil
}
// InvalidateSession removes a session with the given sessionID from the database // InvalidateSession removes a session with the given sessionID from the database
func (s *SessionService) InvalidateSession(token string) error { func (s *sessionManager) InvalidateSession(token string) error {
// Parse the JWT to get the session info // Parse the JWT to get the session info
claims, err := s.jwtManager.ValidateToken(token) claims, err := s.jwtManager.ValidateToken(token)
if err != nil { if err != nil {
@@ -94,6 +114,6 @@ func (s *SessionService) InvalidateSession(token string) error {
} }
// CleanExpiredSessions removes all expired sessions from the database // CleanExpiredSessions removes all expired sessions from the database
func (s *SessionService) CleanExpiredSessions() error { func (s *sessionManager) CleanExpiredSessions() error {
return s.db.CleanExpiredSessions() return s.db.CleanExpiredSessions()
} }

View File

@@ -53,6 +53,7 @@ type WorkspaceStore interface {
type SessionStore interface { type SessionStore interface {
CreateSession(session *models.Session) error CreateSession(session *models.Session) error
GetSessionByRefreshToken(refreshToken string) (*models.Session, error) GetSessionByRefreshToken(refreshToken string) (*models.Session, error)
GetSessionByID(sessionID string) (*models.Session, error)
DeleteSession(sessionID string) error DeleteSession(sessionID string) error
CleanExpiredSessions() error CleanExpiredSessions() error
} }

View File

@@ -41,6 +41,26 @@ func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Sessi
return session, nil return session, nil
} }
// GetSessionByID retrieves a session by its ID
func (db *database) GetSessionByID(sessionID string) (*models.Session, error) {
session := &models.Session{}
err := db.QueryRow(`
SELECT id, user_id, refresh_token, expires_at, created_at
FROM sessions
WHERE id = ? AND expires_at > ?`,
sessionID, time.Now(),
).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("session not found")
}
if err != nil {
return nil, fmt.Errorf("failed to fetch session: %w", err)
}
return session, nil
}
// DeleteSession removes a session from the database // DeleteSession removes a session from the database
func (db *database) DeleteSession(sessionID string) error { func (db *database) DeleteSession(sessionID string) error {
result, err := db.Exec("DELETE FROM sessions WHERE id = ?", sessionID) result, err := db.Exec("DELETE FROM sessions WHERE id = ?", sessionID)

View File

@@ -39,7 +39,7 @@ type LoginResponse struct {
// @Failure 401 {object} ErrorResponse "Invalid credentials" // @Failure 401 {object} ErrorResponse "Invalid credentials"
// @Failure 500 {object} ErrorResponse "Failed to create session" // @Failure 500 {object} ErrorResponse "Failed to create session"
// @Router /auth/login [post] // @Router /auth/login [post]
func (h *Handler) Login(authService *auth.SessionService, cookieService auth.CookieService) http.HandlerFunc { func (h *Handler) Login(authManager auth.SessionManager, cookieService auth.CookieManager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var req LoginRequest var req LoginRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
@@ -68,7 +68,7 @@ func (h *Handler) Login(authService *auth.SessionService, cookieService auth.Coo
} }
// Create session and generate tokens // Create session and generate tokens
session, accessToken, err := authService.CreateSession(user.ID, string(user.Role)) session, accessToken, err := authManager.CreateSession(user.ID, string(user.Role))
if err != nil { if err != nil {
respondError(w, "Failed to create session", http.StatusInternalServerError) respondError(w, "Failed to create session", http.StatusInternalServerError)
return return
@@ -110,7 +110,7 @@ func (h *Handler) Login(authService *auth.SessionService, cookieService auth.Coo
// @Failure 400 {object} ErrorResponse "Session ID required" // @Failure 400 {object} ErrorResponse "Session ID required"
// @Failure 500 {object} ErrorResponse "Failed to logout" // @Failure 500 {object} ErrorResponse "Failed to logout"
// @Router /auth/logout [post] // @Router /auth/logout [post]
func (h *Handler) Logout(authService *auth.SessionService, cookieService auth.CookieService) http.HandlerFunc { func (h *Handler) Logout(authManager auth.SessionManager, cookieService auth.CookieManager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
// Get session ID from cookie // Get session ID from cookie
sessionCookie, err := r.Cookie("access_token") sessionCookie, err := r.Cookie("access_token")
@@ -120,7 +120,7 @@ func (h *Handler) Logout(authService *auth.SessionService, cookieService auth.Co
} }
// Invalidate the session in the database // Invalidate the session in the database
if err := authService.InvalidateSession(sessionCookie.Value); err != nil { if err := authManager.InvalidateSession(sessionCookie.Value); err != nil {
respondError(w, "Failed to invalidate session", http.StatusInternalServerError) respondError(w, "Failed to invalidate session", http.StatusInternalServerError)
return return
} }
@@ -147,7 +147,7 @@ func (h *Handler) Logout(authService *auth.SessionService, cookieService auth.Co
// @Failure 400 {object} ErrorResponse "Refresh token required" // @Failure 400 {object} ErrorResponse "Refresh token required"
// @Failure 401 {object} ErrorResponse "Invalid refresh token" // @Failure 401 {object} ErrorResponse "Invalid refresh token"
// @Router /auth/refresh [post] // @Router /auth/refresh [post]
func (h *Handler) RefreshToken(authService *auth.SessionService, cookieService auth.CookieService) http.HandlerFunc { func (h *Handler) RefreshToken(authManager auth.SessionManager, cookieService auth.CookieManager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
refreshCookie, err := r.Cookie("refresh_token") refreshCookie, err := r.Cookie("refresh_token")
if err != nil { if err != nil {
@@ -156,7 +156,7 @@ func (h *Handler) RefreshToken(authService *auth.SessionService, cookieService a
} }
// Generate new access token // Generate new access token
accessToken, err := authService.RefreshSession(refreshCookie.Value) accessToken, err := authManager.RefreshSession(refreshCookie.Value)
if err != nil { if err != nil {
respondError(w, "Invalid refresh token", http.StatusUnauthorized) respondError(w, "Invalid refresh token", http.StatusUnauthorized)
return return

View File

@@ -28,7 +28,7 @@ type testHarness struct {
DB db.TestDatabase DB db.TestDatabase
Storage storage.Manager Storage storage.Manager
JWTManager auth.JWTManager JWTManager auth.JWTManager
SessionSvc *auth.SessionService SessionManager auth.SessionManager
AdminUser *models.User AdminUser *models.User
AdminToken string AdminToken string
RegularUser *models.User RegularUser *models.User
@@ -104,7 +104,7 @@ func setupTestHarness(t *testing.T) *testHarness {
Database: database, Database: database,
Storage: storageSvc, Storage: storageSvc,
JWTManager: jwtSvc, JWTManager: jwtSvc,
SessionService: sessionSvc, SessionManager: sessionSvc,
} }
// Create server // Create server
@@ -115,7 +115,7 @@ func setupTestHarness(t *testing.T) *testHarness {
DB: database, DB: database,
Storage: storageSvc, Storage: storageSvc,
JWTManager: jwtSvc, JWTManager: jwtSvc,
SessionSvc: sessionSvc, SessionManager: sessionSvc,
TempDirectory: tempDir, TempDirectory: tempDir,
MockGit: mockGit, MockGit: mockGit,
} }
@@ -172,7 +172,7 @@ func (h *testHarness) createTestUser(t *testing.T, email, password string, role
t.Fatalf("Failed to initialize user workspace: %v", err) t.Fatalf("Failed to initialize user workspace: %v", err)
} }
session, accessToken, err := h.SessionSvc.CreateSession(user.ID, string(user.Role)) session, accessToken, err := h.SessionManager.CreateSession(user.ID, string(user.Role))
if err != nil { if err != nil {
t.Fatalf("Failed to create session: %v", err) t.Fatalf("Failed to create session: %v", err)
} }