diff --git a/server/internal/app/init.go b/server/internal/app/init.go index 923dae8..773f8c3 100644 --- a/server/internal/app/init.go +++ b/server/internal/app/init.go @@ -40,7 +40,7 @@ func initDatabase(cfg *Config, secretsService secrets.Service) (db.Database, err } // initAuth initializes JWT and session services -func initAuth(cfg *Config, database db.Database) (auth.JWTManager, *auth.SessionService, auth.CookieService, error) { +func initAuth(cfg *Config, database db.Database) (auth.JWTManager, auth.SessionManager, auth.CookieManager, error) { // Get or generate JWT signing key signingKey := cfg.JWTSigningKey if signingKey == "" { @@ -62,12 +62,12 @@ func initAuth(cfg *Config, database db.Database) (auth.JWTManager, *auth.Session } // Initialize session service - sessionService := auth.NewSessionService(database, jwtManager) + sessionManager := auth.NewSessionService(database, jwtManager) // Cookie service cookieService := auth.NewCookieService(cfg.IsDevelopment, cfg.Domain) - return jwtManager, sessionService, cookieService, nil + return jwtManager, sessionManager, cookieService, nil } // setupAdminUser creates the admin user if it doesn't exist diff --git a/server/internal/app/options.go b/server/internal/app/options.go index be5e3cc..21a8fa5 100644 --- a/server/internal/app/options.go +++ b/server/internal/app/options.go @@ -12,8 +12,8 @@ type Options struct { Database db.Database Storage storage.Manager JWTManager auth.JWTManager - SessionService *auth.SessionService - CookieService auth.CookieService + SessionManager auth.SessionManager + CookieService auth.CookieManager } // DefaultOptions creates server options with default configuration @@ -49,7 +49,7 @@ func DefaultOptions(cfg *Config) (*Options, error) { Database: database, Storage: storageManager, JWTManager: jwtManager, - SessionService: sessionService, + SessionManager: sessionService, CookieService: cookieService, }, nil } diff --git a/server/internal/app/routes.go b/server/internal/app/routes.go index ecd1a35..644dee6 100644 --- a/server/internal/app/routes.go +++ b/server/internal/app/routes.go @@ -48,7 +48,7 @@ func setupRouter(o Options) *chi.Mux { } // Initialize auth middleware and handler - authMiddleware := auth.NewMiddleware(o.JWTManager) + authMiddleware := auth.NewMiddleware(o.JWTManager, o.SessionManager, o.CookieService) handler := &handlers.Handler{ DB: o.Database, Storage: o.Storage, @@ -72,8 +72,8 @@ func setupRouter(o Options) *chi.Mux { // Public routes (no authentication required) r.Group(func(r chi.Router) { - r.Post("/auth/login", handler.Login(o.SessionService, o.CookieService)) - r.Post("/auth/refresh", handler.RefreshToken(o.SessionService, o.CookieService)) + r.Post("/auth/login", handler.Login(o.SessionManager, o.CookieService)) + r.Post("/auth/refresh", handler.RefreshToken(o.SessionManager, o.CookieService)) }) // Protected routes (authentication required) @@ -82,7 +82,7 @@ func setupRouter(o Options) *chi.Mux { r.Use(context.WithUserContextMiddleware) // Auth routes - r.Post("/auth/logout", handler.Logout(o.SessionService, o.CookieService)) + r.Post("/auth/logout", handler.Logout(o.SessionManager, o.CookieService)) r.Get("/auth/me", handler.GetCurrentUser()) // User profile routes diff --git a/server/internal/auth/cookies.go b/server/internal/auth/cookies.go index 7bed834..7249bad 100644 --- a/server/internal/auth/cookies.go +++ b/server/internal/auth/cookies.go @@ -5,8 +5,8 @@ import ( "net/http" ) -// CookieService interface defines methods for generating cookies -type CookieService interface { +// CookieManager interface defines methods for generating cookies +type CookieManager interface { GenerateAccessTokenCookie(token string) *http.Cookie GenerateRefreshTokenCookie(token string) *http.Cookie GenerateCSRFCookie(token string) *http.Cookie @@ -14,14 +14,14 @@ type CookieService interface { } // CookieService -type cookieService struct { +type cookieManager struct { Domain string Secure bool SameSite http.SameSite } // NewCookieService creates a new cookie service -func NewCookieService(isDevelopment bool, domain string) CookieService { +func NewCookieService(isDevelopment bool, domain string) CookieManager { secure := !isDevelopment var sameSite http.SameSite @@ -31,7 +31,7 @@ func NewCookieService(isDevelopment bool, domain string) CookieService { sameSite = http.SameSiteStrictMode } - return &cookieService{ + return &cookieManager{ Domain: domain, Secure: secure, SameSite: sameSite, @@ -39,7 +39,7 @@ func NewCookieService(isDevelopment bool, domain string) CookieService { } // GenerateAccessTokenCookie creates a new cookie for the access token -func (c *cookieService) GenerateAccessTokenCookie(token string) *http.Cookie { +func (c *cookieManager) GenerateAccessTokenCookie(token string) *http.Cookie { return &http.Cookie{ Name: "access_token", Value: token, @@ -52,7 +52,7 @@ func (c *cookieService) GenerateAccessTokenCookie(token string) *http.Cookie { } // GenerateRefreshTokenCookie creates a new cookie for the refresh token -func (c *cookieService) GenerateRefreshTokenCookie(token string) *http.Cookie { +func (c *cookieManager) GenerateRefreshTokenCookie(token string) *http.Cookie { return &http.Cookie{ Name: "refresh_token", Value: token, @@ -65,7 +65,7 @@ func (c *cookieService) GenerateRefreshTokenCookie(token string) *http.Cookie { } // GenerateCSRFCookie creates a new cookie for the CSRF token -func (c *cookieService) GenerateCSRFCookie(token string) *http.Cookie { +func (c *cookieManager) GenerateCSRFCookie(token string) *http.Cookie { return &http.Cookie{ Name: "csrf_token", Value: token, @@ -78,7 +78,7 @@ func (c *cookieService) GenerateCSRFCookie(token string) *http.Cookie { } // InvalidateCookie creates a new cookie with a MaxAge of -1 to invalidate the cookie -func (c *cookieService) InvalidateCookie(cookieType string) *http.Cookie { +func (c *cookieManager) InvalidateCookie(cookieType string) *http.Cookie { return &http.Cookie{ Name: cookieType, Value: "", diff --git a/server/internal/auth/middleware.go b/server/internal/auth/middleware.go index 7754f7b..6748fa3 100644 --- a/server/internal/auth/middleware.go +++ b/server/internal/auth/middleware.go @@ -9,13 +9,17 @@ import ( // Middleware handles JWT authentication for protected routes type Middleware struct { - jwtManager JWTManager + jwtManager JWTManager + sessionManager SessionManager + cookieManager CookieManager } // NewMiddleware creates a new authentication middleware -func NewMiddleware(jwtManager JWTManager) *Middleware { +func NewMiddleware(jwtManager JWTManager, sessionManager SessionManager, cookieManager CookieManager) *Middleware { return &Middleware{ - jwtManager: jwtManager, + jwtManager: jwtManager, + sessionManager: sessionManager, + cookieManager: cookieManager, } } @@ -42,6 +46,16 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler { return } + // Check if session is still valid in database + session, err := m.sessionManager.ValidateSession(claims.ID) + if err != nil || session == nil { + m.cookieManager.InvalidateCookie("access_token") + m.cookieManager.InvalidateCookie("refresh_token") + m.cookieManager.InvalidateCookie("csrf_token") + http.Error(w, "Session invalid or expired", http.StatusUnauthorized) + return + } + // Add CSRF check for non-GET requests if r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodOptions { csrfCookie, err := r.Cookie("csrf_token") diff --git a/server/internal/auth/session.go b/server/internal/auth/session.go index afaaf06..21d0090 100644 --- a/server/internal/auth/session.go +++ b/server/internal/auth/session.go @@ -9,22 +9,30 @@ import ( "github.com/google/uuid" ) -// SessionService manages user sessions in the database -type SessionService struct { +type SessionManager interface { + CreateSession(userID int, role string) (*models.Session, string, error) + RefreshSession(refreshToken string) (string, error) + ValidateSession(sessionID string) (*models.Session, error) + InvalidateSession(token string) error + CleanExpiredSessions() error +} + +// sessionManager manages user sessions in the database +type sessionManager struct { db db.SessionStore // Database store for sessions jwtManager JWTManager // JWT Manager for token operations } // NewSessionService creates a new session service with the given database and JWT manager -func NewSessionService(db db.SessionStore, jwtManager JWTManager) *SessionService { - return &SessionService{ +func NewSessionService(db db.SessionStore, jwtManager JWTManager) *sessionManager { + return &sessionManager{ db: db, jwtManager: jwtManager, } } // CreateSession creates a new user session for a user with the given userID and role -func (s *SessionService) CreateSession(userID int, role string) (*models.Session, string, error) { +func (s *sessionManager) CreateSession(userID int, role string) (*models.Session, string, error) { // Generate both access and refresh tokens accessToken, err := s.jwtManager.GenerateAccessToken(userID, role) if err != nil { @@ -60,7 +68,7 @@ func (s *SessionService) CreateSession(userID int, role string) (*models.Session } // RefreshSession creates a new access token using a refreshToken -func (s *SessionService) RefreshSession(refreshToken string) (string, error) { +func (s *sessionManager) RefreshSession(refreshToken string) (string, error) { // Get session from database first session, err := s.db.GetSessionByRefreshToken(refreshToken) if err != nil { @@ -82,8 +90,20 @@ func (s *SessionService) RefreshSession(refreshToken string) (string, error) { return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role) } +// ValidateSession checks if a session with the given sessionID is valid +func (s *sessionManager) ValidateSession(sessionID string) (*models.Session, error) { + + // Get the session from the database + session, err := s.db.GetSessionByID(sessionID) + if err != nil { + return nil, fmt.Errorf("failed to get session: %w", err) + } + + return session, nil +} + // InvalidateSession removes a session with the given sessionID from the database -func (s *SessionService) InvalidateSession(token string) error { +func (s *sessionManager) InvalidateSession(token string) error { // Parse the JWT to get the session info claims, err := s.jwtManager.ValidateToken(token) if err != nil { @@ -94,6 +114,6 @@ func (s *SessionService) InvalidateSession(token string) error { } // CleanExpiredSessions removes all expired sessions from the database -func (s *SessionService) CleanExpiredSessions() error { +func (s *sessionManager) CleanExpiredSessions() error { return s.db.CleanExpiredSessions() } diff --git a/server/internal/db/db.go b/server/internal/db/db.go index a54ba4a..a0c21eb 100644 --- a/server/internal/db/db.go +++ b/server/internal/db/db.go @@ -53,6 +53,7 @@ type WorkspaceStore interface { type SessionStore interface { CreateSession(session *models.Session) error GetSessionByRefreshToken(refreshToken string) (*models.Session, error) + GetSessionByID(sessionID string) (*models.Session, error) DeleteSession(sessionID string) error CleanExpiredSessions() error } diff --git a/server/internal/db/sessions.go b/server/internal/db/sessions.go index ecc4a72..79b9231 100644 --- a/server/internal/db/sessions.go +++ b/server/internal/db/sessions.go @@ -41,6 +41,26 @@ func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Sessi return session, nil } +// GetSessionByID retrieves a session by its ID +func (db *database) GetSessionByID(sessionID string) (*models.Session, error) { + session := &models.Session{} + err := db.QueryRow(` + SELECT id, user_id, refresh_token, expires_at, created_at + FROM sessions + WHERE id = ? AND expires_at > ?`, + sessionID, time.Now(), + ).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt) + + if err == sql.ErrNoRows { + return nil, fmt.Errorf("session not found") + } + if err != nil { + return nil, fmt.Errorf("failed to fetch session: %w", err) + } + + return session, nil +} + // DeleteSession removes a session from the database func (db *database) DeleteSession(sessionID string) error { result, err := db.Exec("DELETE FROM sessions WHERE id = ?", sessionID) diff --git a/server/internal/handlers/auth_handlers.go b/server/internal/handlers/auth_handlers.go index e59a98d..f55c650 100644 --- a/server/internal/handlers/auth_handlers.go +++ b/server/internal/handlers/auth_handlers.go @@ -39,7 +39,7 @@ type LoginResponse struct { // @Failure 401 {object} ErrorResponse "Invalid credentials" // @Failure 500 {object} ErrorResponse "Failed to create session" // @Router /auth/login [post] -func (h *Handler) Login(authService *auth.SessionService, cookieService auth.CookieService) http.HandlerFunc { +func (h *Handler) Login(authManager auth.SessionManager, cookieService auth.CookieManager) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var req LoginRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { @@ -68,7 +68,7 @@ func (h *Handler) Login(authService *auth.SessionService, cookieService auth.Coo } // Create session and generate tokens - session, accessToken, err := authService.CreateSession(user.ID, string(user.Role)) + session, accessToken, err := authManager.CreateSession(user.ID, string(user.Role)) if err != nil { respondError(w, "Failed to create session", http.StatusInternalServerError) return @@ -110,7 +110,7 @@ func (h *Handler) Login(authService *auth.SessionService, cookieService auth.Coo // @Failure 400 {object} ErrorResponse "Session ID required" // @Failure 500 {object} ErrorResponse "Failed to logout" // @Router /auth/logout [post] -func (h *Handler) Logout(authService *auth.SessionService, cookieService auth.CookieService) http.HandlerFunc { +func (h *Handler) Logout(authManager auth.SessionManager, cookieService auth.CookieManager) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Get session ID from cookie sessionCookie, err := r.Cookie("access_token") @@ -120,7 +120,7 @@ func (h *Handler) Logout(authService *auth.SessionService, cookieService auth.Co } // Invalidate the session in the database - if err := authService.InvalidateSession(sessionCookie.Value); err != nil { + if err := authManager.InvalidateSession(sessionCookie.Value); err != nil { respondError(w, "Failed to invalidate session", http.StatusInternalServerError) return } @@ -147,7 +147,7 @@ func (h *Handler) Logout(authService *auth.SessionService, cookieService auth.Co // @Failure 400 {object} ErrorResponse "Refresh token required" // @Failure 401 {object} ErrorResponse "Invalid refresh token" // @Router /auth/refresh [post] -func (h *Handler) RefreshToken(authService *auth.SessionService, cookieService auth.CookieService) http.HandlerFunc { +func (h *Handler) RefreshToken(authManager auth.SessionManager, cookieService auth.CookieManager) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { refreshCookie, err := r.Cookie("refresh_token") if err != nil { @@ -156,7 +156,7 @@ func (h *Handler) RefreshToken(authService *auth.SessionService, cookieService a } // Generate new access token - accessToken, err := authService.RefreshSession(refreshCookie.Value) + accessToken, err := authManager.RefreshSession(refreshCookie.Value) if err != nil { respondError(w, "Invalid refresh token", http.StatusUnauthorized) return diff --git a/server/internal/handlers/integration_test.go b/server/internal/handlers/integration_test.go index 047966f..4d8b6aa 100644 --- a/server/internal/handlers/integration_test.go +++ b/server/internal/handlers/integration_test.go @@ -24,17 +24,17 @@ import ( // testHarness encapsulates all the dependencies needed for testing type testHarness struct { - Server *app.Server - DB db.TestDatabase - Storage storage.Manager - JWTManager auth.JWTManager - SessionSvc *auth.SessionService - AdminUser *models.User - AdminToken string - RegularUser *models.User - RegularToken string - TempDirectory string - MockGit *MockGitClient + Server *app.Server + DB db.TestDatabase + Storage storage.Manager + JWTManager auth.JWTManager + SessionManager auth.SessionManager + AdminUser *models.User + AdminToken string + RegularUser *models.User + RegularToken string + TempDirectory string + MockGit *MockGitClient } // setupTestHarness creates a new test environment @@ -104,20 +104,20 @@ func setupTestHarness(t *testing.T) *testHarness { Database: database, Storage: storageSvc, JWTManager: jwtSvc, - SessionService: sessionSvc, + SessionManager: sessionSvc, } // Create server srv := app.NewServer(serverOpts) h := &testHarness{ - Server: srv, - DB: database, - Storage: storageSvc, - JWTManager: jwtSvc, - SessionSvc: sessionSvc, - TempDirectory: tempDir, - MockGit: mockGit, + Server: srv, + DB: database, + Storage: storageSvc, + JWTManager: jwtSvc, + SessionManager: sessionSvc, + TempDirectory: tempDir, + MockGit: mockGit, } // Create test users @@ -172,7 +172,7 @@ func (h *testHarness) createTestUser(t *testing.T, email, password string, role t.Fatalf("Failed to initialize user workspace: %v", err) } - session, accessToken, err := h.SessionSvc.CreateSession(user.ID, string(user.Role)) + session, accessToken, err := h.SessionManager.CreateSession(user.ID, string(user.Role)) if err != nil { t.Fatalf("Failed to create session: %v", err) }