mirror of
https://github.com/lordmathis/lemma.git
synced 2025-11-06 07:54:22 +00:00
Migrate backend auth to cookies
This commit is contained in:
91
server/internal/auth/cookies.go
Normal file
91
server/internal/auth/cookies.go
Normal file
@@ -0,0 +1,91 @@
|
||||
// Package auth provides JWT token generation and validation
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// CookieService interface defines methods for generating cookies
|
||||
type CookieService interface {
|
||||
GenerateAccessTokenCookie(token string) *http.Cookie
|
||||
GenerateRefreshTokenCookie(token string) *http.Cookie
|
||||
GenerateCSRFCookie(token string) *http.Cookie
|
||||
InvalidateCookie(cookieType string) *http.Cookie
|
||||
}
|
||||
|
||||
// CookieService
|
||||
type cookieService struct {
|
||||
Domain string
|
||||
Secure bool
|
||||
SameSite http.SameSite
|
||||
}
|
||||
|
||||
// NewCookieService creates a new cookie service
|
||||
func NewCookieService(isDevelopment bool, domain string) CookieService {
|
||||
secure := !isDevelopment
|
||||
var sameSite http.SameSite
|
||||
|
||||
if isDevelopment {
|
||||
sameSite = http.SameSiteLaxMode
|
||||
} else {
|
||||
sameSite = http.SameSiteStrictMode
|
||||
}
|
||||
|
||||
return &cookieService{
|
||||
Domain: domain,
|
||||
Secure: secure,
|
||||
SameSite: sameSite,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateAccessTokenCookie creates a new cookie for the access token
|
||||
func (c *cookieService) GenerateAccessTokenCookie(token string) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: "access_token",
|
||||
Value: token,
|
||||
HttpOnly: true,
|
||||
Secure: c.Secure,
|
||||
SameSite: c.SameSite,
|
||||
Path: "/",
|
||||
MaxAge: 900, // 15 minutes
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateRefreshTokenCookie creates a new cookie for the refresh token
|
||||
func (c *cookieService) GenerateRefreshTokenCookie(token string) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: "refresh_token",
|
||||
Value: token,
|
||||
HttpOnly: true,
|
||||
Secure: c.Secure,
|
||||
SameSite: c.SameSite,
|
||||
Path: "/",
|
||||
MaxAge: 604800, // 7 days
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateCSRFCookie creates a new cookie for the CSRF token
|
||||
func (c *cookieService) GenerateCSRFCookie(token string) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: "csrf_token",
|
||||
Value: token,
|
||||
HttpOnly: false, // Frontend needs to read this
|
||||
Secure: c.Secure,
|
||||
SameSite: c.SameSite,
|
||||
Path: "/",
|
||||
MaxAge: 900,
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateCookie creates a new cookie with a MaxAge of -1 to invalidate the cookie
|
||||
func (c *cookieService) InvalidateCookie(cookieType string) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: cookieType,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: c.Secure,
|
||||
SameSite: c.SameSite,
|
||||
}
|
||||
}
|
||||
@@ -38,7 +38,6 @@ type JWTManager interface {
|
||||
GenerateAccessToken(userID int, role string) (string, error)
|
||||
GenerateRefreshToken(userID int, role string) (string, error)
|
||||
ValidateToken(tokenString string) (*Claims, error)
|
||||
RefreshAccessToken(refreshToken string) (string, error)
|
||||
}
|
||||
|
||||
// jwtService handles JWT token generation and validation
|
||||
@@ -118,17 +117,3 @@ func (s *jwtService) ValidateToken(tokenString string) (*Claims, error) {
|
||||
|
||||
return nil, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
|
||||
// RefreshAccessToken creates a new access token using a refreshToken
|
||||
func (s *jwtService) RefreshAccessToken(refreshToken string) (string, error) {
|
||||
claims, err := s.ValidateToken(refreshToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid refresh token: %w", err)
|
||||
}
|
||||
|
||||
if claims.Type != RefreshToken {
|
||||
return "", fmt.Errorf("invalid token type: expected refresh token")
|
||||
}
|
||||
|
||||
return s.GenerateAccessToken(claims.UserID, claims.Role)
|
||||
}
|
||||
|
||||
@@ -5,8 +5,6 @@ import (
|
||||
"time"
|
||||
|
||||
"novamd/internal/auth"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// jwt_test.go tests
|
||||
@@ -136,86 +134,3 @@ func TestGenerateAndValidateToken(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshAccessToken(t *testing.T) {
|
||||
config := auth.JWTConfig{
|
||||
SigningKey: "test-key",
|
||||
AccessTokenExpiry: 15 * time.Minute,
|
||||
RefreshTokenExpiry: 24 * time.Hour,
|
||||
}
|
||||
service, _ := auth.NewJWTService(config)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
userID int
|
||||
role string
|
||||
wantErr bool
|
||||
setupFunc func() string // Added setup function to handle custom token creation
|
||||
}{
|
||||
{
|
||||
name: "valid refresh token",
|
||||
userID: 1,
|
||||
role: "admin",
|
||||
wantErr: false,
|
||||
setupFunc: func() string {
|
||||
token, _ := service.GenerateRefreshToken(1, "admin")
|
||||
return token
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "expired refresh token",
|
||||
userID: 1,
|
||||
role: "admin",
|
||||
wantErr: true,
|
||||
setupFunc: func() string {
|
||||
// Create a token that's already expired
|
||||
claims := &auth.Claims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(-time.Hour)), // Expired 1 hour ago
|
||||
IssuedAt: jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)),
|
||||
NotBefore: jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)),
|
||||
},
|
||||
UserID: 1,
|
||||
Role: "admin",
|
||||
Type: auth.RefreshToken,
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, _ := token.SignedString([]byte(config.SigningKey))
|
||||
return tokenString
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
refreshToken := tc.setupFunc()
|
||||
newAccessToken, err := service.RefreshAccessToken(refreshToken)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
claims, err := service.ValidateToken(newAccessToken)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to validate new access token: %v", err)
|
||||
}
|
||||
|
||||
if claims.UserID != tc.userID {
|
||||
t.Errorf("userID = %v, want %v", claims.UserID, tc.userID)
|
||||
}
|
||||
if claims.Role != tc.role {
|
||||
t.Errorf("role = %v, want %v", claims.Role, tc.role)
|
||||
}
|
||||
if claims.Type != auth.AccessToken {
|
||||
t.Errorf("token type = %v, want %v", claims.Type, auth.AccessToken)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"novamd/internal/context"
|
||||
)
|
||||
@@ -23,21 +23,14 @@ func NewMiddleware(jwtManager JWTManager) *Middleware {
|
||||
func (m *Middleware) Authenticate(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Extract token from Authorization header
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
http.Error(w, "Authorization header required", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Check Bearer token format
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
http.Error(w, "Invalid authorization format", http.StatusUnauthorized)
|
||||
cookie, err := r.Cookie("access_token")
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate token
|
||||
claims, err := m.jwtManager.ValidateToken(parts[1])
|
||||
claims, err := m.jwtManager.ValidateToken(cookie.Value)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid token", http.StatusUnauthorized)
|
||||
return
|
||||
@@ -49,6 +42,26 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
// Add CSRF check for non-GET requests
|
||||
if r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodOptions {
|
||||
csrfCookie, err := r.Cookie("csrf_token")
|
||||
if err != nil {
|
||||
http.Error(w, "CSRF cookie not found", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
csrfHeader := r.Header.Get("X-CSRF-Token")
|
||||
if csrfHeader == "" {
|
||||
http.Error(w, "CSRF token header not found", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
if subtle.ConstantTimeCompare([]byte(csrfCookie.Value), []byte(csrfHeader)) != 1 {
|
||||
http.Error(w, "CSRF token mismatch", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Create handler context with user information
|
||||
hctx := &context.HandlerContext{
|
||||
UserID: claims.UserID,
|
||||
|
||||
@@ -83,8 +83,14 @@ func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
|
||||
}
|
||||
|
||||
// InvalidateSession removes a session with the given sessionID from the database
|
||||
func (s *SessionService) InvalidateSession(sessionID string) error {
|
||||
return s.db.DeleteSession(sessionID)
|
||||
func (s *SessionService) InvalidateSession(token string) error {
|
||||
// Parse the JWT to get the session info
|
||||
claims, err := s.jwtManager.ValidateToken(token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid token: %w", err)
|
||||
}
|
||||
|
||||
return s.db.DeleteSession(claims.ID)
|
||||
}
|
||||
|
||||
// CleanExpiredSessions removes all expired sessions from the database
|
||||
|
||||
Reference in New Issue
Block a user