mirror of
https://github.com/lordmathis/lemma.git
synced 2025-11-06 07:54:22 +00:00
Implement JWTManager interface
This commit is contained in:
@@ -48,7 +48,7 @@ func main() {
|
|||||||
s := storage.NewService(cfg.WorkDir)
|
s := storage.NewService(cfg.WorkDir)
|
||||||
|
|
||||||
// Initialize JWT service
|
// Initialize JWT service
|
||||||
jwtService, err := auth.NewJWTService(auth.JWTConfig{
|
jwtManager, err := auth.NewJWTService(auth.JWTConfig{
|
||||||
SigningKey: signingKey,
|
SigningKey: signingKey,
|
||||||
AccessTokenExpiry: 15 * time.Minute,
|
AccessTokenExpiry: 15 * time.Minute,
|
||||||
RefreshTokenExpiry: 7 * 24 * time.Hour,
|
RefreshTokenExpiry: 7 * 24 * time.Hour,
|
||||||
@@ -58,10 +58,10 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Initialize auth middleware
|
// Initialize auth middleware
|
||||||
authMiddleware := auth.NewMiddleware(jwtService)
|
authMiddleware := auth.NewMiddleware(jwtManager)
|
||||||
|
|
||||||
// Initialize session service
|
// Initialize session service
|
||||||
sessionService := auth.NewSessionService(database.DB, jwtService)
|
sessionService := auth.NewSessionService(database.DB, jwtManager)
|
||||||
|
|
||||||
// Set up router
|
// Set up router
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
// Package auth provides JWT token generation and validation
|
||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -30,14 +31,22 @@ type JWTConfig struct {
|
|||||||
RefreshTokenExpiry time.Duration // How long refresh tokens are valid
|
RefreshTokenExpiry time.Duration // How long refresh tokens are valid
|
||||||
}
|
}
|
||||||
|
|
||||||
// JWTService handles JWT token generation and validation
|
// JWTManager defines the interface for managing JWT tokens
|
||||||
type JWTService struct {
|
type JWTManager interface {
|
||||||
|
GenerateAccessToken(userID int, role string) (string, error)
|
||||||
|
GenerateRefreshToken(userID int, role string) (string, error)
|
||||||
|
ValidateToken(tokenString string) (*Claims, error)
|
||||||
|
RefreshAccessToken(refreshToken string) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// jwtService handles JWT token generation and validation
|
||||||
|
type jwtService struct {
|
||||||
config JWTConfig
|
config JWTConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewJWTService creates a new JWT service with the provided configuration
|
// NewJWTService creates a new JWT service with the provided configuration
|
||||||
// Returns an error if the signing key is missing
|
// Returns an error if the signing key is missing
|
||||||
func NewJWTService(config JWTConfig) (*JWTService, error) {
|
func NewJWTService(config JWTConfig) (JWTManager, error) {
|
||||||
if config.SigningKey == "" {
|
if config.SigningKey == "" {
|
||||||
return nil, fmt.Errorf("signing key is required")
|
return nil, fmt.Errorf("signing key is required")
|
||||||
}
|
}
|
||||||
@@ -48,7 +57,7 @@ func NewJWTService(config JWTConfig) (*JWTService, error) {
|
|||||||
if config.RefreshTokenExpiry == 0 {
|
if config.RefreshTokenExpiry == 0 {
|
||||||
config.RefreshTokenExpiry = 7 * 24 * time.Hour // Default to 7 days
|
config.RefreshTokenExpiry = 7 * 24 * time.Hour // Default to 7 days
|
||||||
}
|
}
|
||||||
return &JWTService{config: config}, nil
|
return &jwtService{config: config}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateAccessToken creates a new access token for a user
|
// GenerateAccessToken creates a new access token for a user
|
||||||
@@ -56,7 +65,7 @@ func NewJWTService(config JWTConfig) (*JWTService, error) {
|
|||||||
// - userID: the ID of the user
|
// - userID: the ID of the user
|
||||||
// - role: the role of the user
|
// - role: the role of the user
|
||||||
// Returns the signed token string or an error
|
// Returns the signed token string or an error
|
||||||
func (s *JWTService) GenerateAccessToken(userID int, role string) (string, error) {
|
func (s *jwtService) GenerateAccessToken(userID int, role string) (string, error) {
|
||||||
return s.generateToken(userID, role, AccessToken, s.config.AccessTokenExpiry)
|
return s.generateToken(userID, role, AccessToken, s.config.AccessTokenExpiry)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,7 +74,7 @@ func (s *JWTService) GenerateAccessToken(userID int, role string) (string, error
|
|||||||
// - userID: the ID of the user
|
// - userID: the ID of the user
|
||||||
// - role: the role of the user
|
// - role: the role of the user
|
||||||
// Returns the signed token string or an error
|
// Returns the signed token string or an error
|
||||||
func (s *JWTService) GenerateRefreshToken(userID int, role string) (string, error) {
|
func (s *jwtService) GenerateRefreshToken(userID int, role string) (string, error) {
|
||||||
return s.generateToken(userID, role, RefreshToken, s.config.RefreshTokenExpiry)
|
return s.generateToken(userID, role, RefreshToken, s.config.RefreshTokenExpiry)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,7 +85,7 @@ func (s *JWTService) GenerateRefreshToken(userID int, role string) (string, erro
|
|||||||
// - tokenType: the type of token (access or refresh)
|
// - tokenType: the type of token (access or refresh)
|
||||||
// - expiry: how long the token should be valid
|
// - expiry: how long the token should be valid
|
||||||
// Returns the signed token string or an error
|
// Returns the signed token string or an error
|
||||||
func (s *JWTService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) {
|
func (s *jwtService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
claims := Claims{
|
claims := Claims{
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
@@ -97,7 +106,7 @@ func (s *JWTService) generateToken(userID int, role string, tokenType TokenType,
|
|||||||
// Parameters:
|
// Parameters:
|
||||||
// - tokenString: the token to validate
|
// - tokenString: the token to validate
|
||||||
// Returns the token claims if valid, or an error if invalid
|
// Returns the token claims if valid, or an error if invalid
|
||||||
func (s *JWTService) ValidateToken(tokenString string) (*Claims, error) {
|
func (s *jwtService) ValidateToken(tokenString string) (*Claims, error) {
|
||||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||||
// Validate the signing method
|
// Validate the signing method
|
||||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||||
@@ -121,7 +130,7 @@ func (s *JWTService) ValidateToken(tokenString string) (*Claims, error) {
|
|||||||
// Parameters:
|
// Parameters:
|
||||||
// - refreshToken: the refresh token to use
|
// - refreshToken: the refresh token to use
|
||||||
// Returns a new access token if the refresh token is valid, or an error
|
// Returns a new access token if the refresh token is valid, or an error
|
||||||
func (s *JWTService) RefreshAccessToken(refreshToken string) (string, error) {
|
func (s *jwtService) RefreshAccessToken(refreshToken string) (string, error) {
|
||||||
claims, err := s.ValidateToken(refreshToken)
|
claims, err := s.ValidateToken(refreshToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("invalid refresh token: %w", err)
|
return "", fmt.Errorf("invalid refresh token: %w", err)
|
||||||
|
|||||||
@@ -11,9 +11,8 @@ import (
|
|||||||
|
|
||||||
type contextKey string
|
type contextKey string
|
||||||
|
|
||||||
const (
|
// UserContextKey is the key used to store user claims in the request context
|
||||||
UserContextKey contextKey = "user"
|
const UserContextKey contextKey = "user"
|
||||||
)
|
|
||||||
|
|
||||||
// UserClaims represents the user information stored in the request context
|
// UserClaims represents the user information stored in the request context
|
||||||
type UserClaims struct {
|
type UserClaims struct {
|
||||||
@@ -23,17 +22,25 @@ type UserClaims struct {
|
|||||||
|
|
||||||
// Middleware handles JWT authentication for protected routes
|
// Middleware handles JWT authentication for protected routes
|
||||||
type Middleware struct {
|
type Middleware struct {
|
||||||
jwtService *JWTService
|
jwtManager JWTManager
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMiddleware creates a new authentication middleware
|
// NewMiddleware creates a new authentication middleware
|
||||||
func NewMiddleware(jwtService *JWTService) *Middleware {
|
// Parameters:
|
||||||
|
// - jwtManager: the JWT manager to use for token operations
|
||||||
|
// Returns:
|
||||||
|
// - *Middleware: the new middleware instance
|
||||||
|
func NewMiddleware(jwtManager JWTManager) *Middleware {
|
||||||
return &Middleware{
|
return &Middleware{
|
||||||
jwtService: jwtService,
|
jwtManager: jwtManager,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authenticate middleware validates JWT tokens and sets user information in context
|
// Authenticate middleware validates JWT tokens and sets user information in context
|
||||||
|
// Parameters:
|
||||||
|
// - next: the next handler to call
|
||||||
|
// Returns:
|
||||||
|
// - http.Handler: the handler function
|
||||||
func (m *Middleware) Authenticate(next http.Handler) http.Handler {
|
func (m *Middleware) Authenticate(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Extract token from Authorization header
|
// Extract token from Authorization header
|
||||||
@@ -51,7 +58,7 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Validate token
|
// Validate token
|
||||||
claims, err := m.jwtService.ValidateToken(parts[1])
|
claims, err := m.jwtManager.ValidateToken(parts[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "Invalid token", http.StatusUnauthorized)
|
http.Error(w, "Invalid token", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
@@ -75,6 +82,10 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RequireRole returns a middleware that ensures the user has the required role
|
// RequireRole returns a middleware that ensures the user has the required role
|
||||||
|
// Parameters:
|
||||||
|
// - role: the required role
|
||||||
|
// Returns:
|
||||||
|
// - func(http.Handler) http.Handler: the middleware function
|
||||||
func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler {
|
func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -94,6 +105,11 @@ func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RequireWorkspaceAccess returns a middleware that ensures the user has access to the workspace
|
||||||
|
// Parameters:
|
||||||
|
// - next: the next handler to call
|
||||||
|
// Returns:
|
||||||
|
// - http.Handler: the handler function
|
||||||
func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler {
|
func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Get our handler context
|
// Get our handler context
|
||||||
@@ -119,6 +135,11 @@ func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetUserFromContext retrieves user claims from the request context
|
// GetUserFromContext retrieves user claims from the request context
|
||||||
|
// Parameters:
|
||||||
|
// - ctx: the request context
|
||||||
|
// Returns:
|
||||||
|
// - *UserClaims: the user claims
|
||||||
|
// - error: any error that occurred
|
||||||
func GetUserFromContext(ctx context.Context) (*UserClaims, error) {
|
func GetUserFromContext(ctx context.Context) (*UserClaims, error) {
|
||||||
claims, ok := ctx.Value(UserContextKey).(UserClaims)
|
claims, ok := ctx.Value(UserContextKey).(UserClaims)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|||||||
@@ -20,17 +20,17 @@ type Session struct {
|
|||||||
// SessionService manages user sessions in the database
|
// SessionService manages user sessions in the database
|
||||||
type SessionService struct {
|
type SessionService struct {
|
||||||
db *sql.DB // Database connection
|
db *sql.DB // Database connection
|
||||||
jwtService *JWTService // JWT service for token operations
|
jwtManager JWTManager // JWT Manager for token operations
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSessionService creates a new session service
|
// NewSessionService creates a new session service
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - db: database connection
|
// - db: database connection
|
||||||
// - jwtService: JWT service for token operations
|
// - jwtManager: JWT service for token operations
|
||||||
func NewSessionService(db *sql.DB, jwtService *JWTService) *SessionService {
|
func NewSessionService(db *sql.DB, jwtManager JWTManager) *SessionService {
|
||||||
return &SessionService{
|
return &SessionService{
|
||||||
db: db,
|
db: db,
|
||||||
jwtService: jwtService,
|
jwtManager: jwtManager,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,18 +44,18 @@ func NewSessionService(db *sql.DB, jwtService *JWTService) *SessionService {
|
|||||||
// - error: any error that occurred
|
// - error: any error that occurred
|
||||||
func (s *SessionService) CreateSession(userID int, role string) (*Session, string, error) {
|
func (s *SessionService) CreateSession(userID int, role string) (*Session, string, error) {
|
||||||
// Generate both access and refresh tokens
|
// Generate both access and refresh tokens
|
||||||
accessToken, err := s.jwtService.GenerateAccessToken(userID, role)
|
accessToken, err := s.jwtManager.GenerateAccessToken(userID, role)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("failed to generate access token: %w", err)
|
return nil, "", fmt.Errorf("failed to generate access token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
refreshToken, err := s.jwtService.GenerateRefreshToken(userID, role)
|
refreshToken, err := s.jwtManager.GenerateRefreshToken(userID, role)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("failed to generate refresh token: %w", err)
|
return nil, "", fmt.Errorf("failed to generate refresh token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate the refresh token to get its expiry time
|
// Validate the refresh token to get its expiry time
|
||||||
claims, err := s.jwtService.ValidateToken(refreshToken)
|
claims, err := s.jwtManager.ValidateToken(refreshToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("failed to validate refresh token: %w", err)
|
return nil, "", fmt.Errorf("failed to validate refresh token: %w", err)
|
||||||
}
|
}
|
||||||
@@ -90,7 +90,7 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin
|
|||||||
// - error: any error that occurred
|
// - error: any error that occurred
|
||||||
func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
|
func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
|
||||||
// Validate the refresh token
|
// Validate the refresh token
|
||||||
claims, err := s.jwtService.ValidateToken(refreshToken)
|
claims, err := s.jwtManager.ValidateToken(refreshToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("invalid refresh token: %w", err)
|
return "", fmt.Errorf("invalid refresh token: %w", err)
|
||||||
}
|
}
|
||||||
@@ -112,7 +112,7 @@ func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Generate a new access token
|
// Generate a new access token
|
||||||
return s.jwtService.GenerateAccessToken(claims.UserID, claims.Role)
|
return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role)
|
||||||
}
|
}
|
||||||
|
|
||||||
// InvalidateSession removes a session from the database
|
// InvalidateSession removes a session from the database
|
||||||
|
|||||||
Reference in New Issue
Block a user