mirror of
https://github.com/lordmathis/lemma.git
synced 2025-11-07 16:34:26 +00:00
Merge pull request #20 from LordMathis/chore/backend-test
Implement backend tests
This commit is contained in:
34
.github/workflows/go-test.yml
vendored
Normal file
34
.github/workflows/go-test.yml
vendored
Normal file
@@ -0,0 +1,34 @@
|
||||
name: Go Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- "*"
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Run Tests
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ./server
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23"
|
||||
cache: true
|
||||
|
||||
- name: Run Tests
|
||||
run: go test -tags=test,integration ./... -v
|
||||
|
||||
- name: Run Tests with Race Detector
|
||||
run: go test -tags=test,integration -race ./... -v
|
||||
4
.vscode/settings.json
vendored
4
.vscode/settings.json
vendored
@@ -14,6 +14,7 @@
|
||||
"go.lintTool": "golangci-lint",
|
||||
"go.lintOnSave": "package",
|
||||
"go.formatTool": "goimports",
|
||||
"go.testFlags": ["-tags=test,integration"],
|
||||
"[go]": {
|
||||
"editor.formatOnSave": true,
|
||||
"editor.codeActionsOnSave": {
|
||||
@@ -23,6 +24,7 @@
|
||||
},
|
||||
"gopls": {
|
||||
"usePlaceholders": true,
|
||||
"staticcheck": true
|
||||
"staticcheck": true,
|
||||
"buildFlags": ["-tags", "test,integration"]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package main contains the main entry point for the application. It sets up the server, database, and other services, and starts the server.
|
||||
package main
|
||||
|
||||
import (
|
||||
@@ -17,8 +18,9 @@ import (
|
||||
"novamd/internal/auth"
|
||||
"novamd/internal/config"
|
||||
"novamd/internal/db"
|
||||
"novamd/internal/filesystem"
|
||||
"novamd/internal/handlers"
|
||||
"novamd/internal/secrets"
|
||||
"novamd/internal/storage"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -28,12 +30,26 @@ func main() {
|
||||
log.Fatal("Failed to load configuration:", err)
|
||||
}
|
||||
|
||||
// Initialize secrets service
|
||||
secretsService, err := secrets.NewService(cfg.EncryptionKey)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to initialize secrets service:", err)
|
||||
}
|
||||
|
||||
// Initialize database
|
||||
database, err := db.Init(cfg.DBPath, cfg.EncryptionKey)
|
||||
database, err := db.Init(cfg.DBPath, secretsService)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer database.Close()
|
||||
err = database.Migrate()
|
||||
if err != nil {
|
||||
log.Fatal("Failed to apply database migrations:", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := database.Close(); err != nil {
|
||||
log.Printf("Error closing database: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Get or generate JWT signing key
|
||||
signingKey := cfg.JWTSigningKey
|
||||
@@ -45,10 +61,10 @@ func main() {
|
||||
}
|
||||
|
||||
// Initialize filesystem
|
||||
fs := filesystem.New(cfg.WorkDir)
|
||||
s := storage.NewService(cfg.WorkDir)
|
||||
|
||||
// Initialize JWT service
|
||||
jwtService, err := auth.NewJWTService(auth.JWTConfig{
|
||||
jwtManager, err := auth.NewJWTService(auth.JWTConfig{
|
||||
SigningKey: signingKey,
|
||||
AccessTokenExpiry: 15 * time.Minute,
|
||||
RefreshTokenExpiry: 7 * 24 * time.Hour,
|
||||
@@ -58,10 +74,10 @@ func main() {
|
||||
}
|
||||
|
||||
// Initialize auth middleware
|
||||
authMiddleware := auth.NewMiddleware(jwtService)
|
||||
authMiddleware := auth.NewMiddleware(jwtManager)
|
||||
|
||||
// Initialize session service
|
||||
sessionService := auth.NewSessionService(database.DB, jwtService)
|
||||
sessionService := auth.NewSessionService(database, jwtManager)
|
||||
|
||||
// Set up router
|
||||
r := chi.NewRouter()
|
||||
@@ -95,7 +111,7 @@ func main() {
|
||||
// Set up routes
|
||||
r.Route("/api/v1", func(r chi.Router) {
|
||||
r.Use(httprate.LimitByIP(cfg.RateLimitRequests, cfg.RateLimitWindow))
|
||||
api.SetupRoutes(r, database, fs, authMiddleware, sessionService)
|
||||
api.SetupRoutes(r, database, s, authMiddleware, sessionService)
|
||||
})
|
||||
|
||||
// Handle all other routes with static file server
|
||||
|
||||
50
server/gendocs.sh
Executable file
50
server/gendocs.sh
Executable file
@@ -0,0 +1,50 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Function to generate anchor from package path
|
||||
generate_anchor() {
|
||||
echo "$1" | tr '/' '-'
|
||||
}
|
||||
|
||||
# Create documentation file
|
||||
echo "# NovaMD Package Documentation
|
||||
|
||||
Generated documentation for all packages in the NovaMD project.
|
||||
|
||||
## Table of Contents
|
||||
" > documentation.md
|
||||
|
||||
# Find all directories containing .go files (excluding test files)
|
||||
# Sort them for consistent output
|
||||
PACKAGES=$(find . -type f -name "*.go" ! -name "*_test.go" -exec dirname {} \; | sort -u | grep -v "/\.")
|
||||
|
||||
# Generate table of contents
|
||||
for PKG in $PACKAGES; do
|
||||
# Strip leading ./
|
||||
PKG_PATH=${PKG#./}
|
||||
# Skip if empty
|
||||
[ -z "$PKG_PATH" ] && continue
|
||||
|
||||
ANCHOR=$(generate_anchor "$PKG_PATH")
|
||||
echo "- [$PKG_PATH](#$ANCHOR)" >> documentation.md
|
||||
done
|
||||
|
||||
echo "" >> documentation.md
|
||||
|
||||
# Generate documentation for each package
|
||||
for PKG in $PACKAGES; do
|
||||
# Strip leading ./
|
||||
PKG_PATH=${PKG#./}
|
||||
# Skip if empty
|
||||
[ -z "$PKG_PATH" ] && continue
|
||||
|
||||
echo "## $PKG_PATH" >> documentation.md
|
||||
echo "" >> documentation.md
|
||||
echo '```go' >> documentation.md
|
||||
go doc -all "./$PKG_PATH" >> documentation.md
|
||||
echo '```' >> documentation.md
|
||||
echo "" >> documentation.md
|
||||
done
|
||||
|
||||
echo "Documentation generated in documentation.md"
|
||||
@@ -11,6 +11,7 @@ require (
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/mattn/go-sqlite3 v1.14.23
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/unrolled/secure v1.17.0
|
||||
golang.org/x/crypto v0.21.0
|
||||
)
|
||||
@@ -22,6 +23,7 @@ require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/cloudflare/circl v1.3.7 // indirect
|
||||
github.com/cyphar/filepath-securejoin v0.2.4 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/emirpasic/gods v1.18.1 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect
|
||||
@@ -33,6 +35,7 @@ require (
|
||||
github.com/kevinburke/ssh_config v1.2.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/pjbgf/sha1cd v0.3.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect
|
||||
github.com/skeema/knownhosts v1.2.2 // indirect
|
||||
github.com/xanzy/ssh-agent v0.3.3 // indirect
|
||||
@@ -42,4 +45,5 @@ require (
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
golang.org/x/tools v0.13.0 // indirect
|
||||
gopkg.in/warnings.v0 v0.1.2 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -3,20 +3,20 @@ package api
|
||||
|
||||
import (
|
||||
"novamd/internal/auth"
|
||||
"novamd/internal/context"
|
||||
"novamd/internal/db"
|
||||
"novamd/internal/filesystem"
|
||||
"novamd/internal/handlers"
|
||||
"novamd/internal/middleware"
|
||||
"novamd/internal/storage"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
// SetupRoutes configures the API routes
|
||||
func SetupRoutes(r chi.Router, db *db.DB, fs *filesystem.FileSystem, authMiddleware *auth.Middleware, sessionService *auth.SessionService) {
|
||||
func SetupRoutes(r chi.Router, db db.Database, s storage.Manager, authMiddleware *auth.Middleware, sessionService *auth.SessionService) {
|
||||
|
||||
handler := &handlers.Handler{
|
||||
DB: db,
|
||||
FS: fs,
|
||||
DB: db,
|
||||
Storage: s,
|
||||
}
|
||||
|
||||
// Public routes (no authentication required)
|
||||
@@ -29,7 +29,7 @@ func SetupRoutes(r chi.Router, db *db.DB, fs *filesystem.FileSystem, authMiddlew
|
||||
r.Group(func(r chi.Router) {
|
||||
// Apply authentication middleware to all routes in this group
|
||||
r.Use(authMiddleware.Authenticate)
|
||||
r.Use(middleware.WithUserContext)
|
||||
r.Use(context.WithUserContextMiddleware)
|
||||
|
||||
// Auth routes
|
||||
r.Post("/auth/logout", handler.Logout(sessionService))
|
||||
@@ -67,7 +67,7 @@ func SetupRoutes(r chi.Router, db *db.DB, fs *filesystem.FileSystem, authMiddlew
|
||||
|
||||
// Single workspace routes
|
||||
r.Route("/{workspaceName}", func(r chi.Router) {
|
||||
r.Use(middleware.WithWorkspaceContext(db))
|
||||
r.Use(context.WithWorkspaceContextMiddleware(db))
|
||||
r.Use(authMiddleware.RequireWorkspaceAccess)
|
||||
|
||||
r.Get("/", handler.GetWorkspace())
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
// Package auth provides JWT token generation and validation
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -30,14 +33,22 @@ type JWTConfig struct {
|
||||
RefreshTokenExpiry time.Duration // How long refresh tokens are valid
|
||||
}
|
||||
|
||||
// JWTService handles JWT token generation and validation
|
||||
type JWTService struct {
|
||||
// JWTManager defines the interface for managing JWT tokens
|
||||
type JWTManager interface {
|
||||
GenerateAccessToken(userID int, role string) (string, error)
|
||||
GenerateRefreshToken(userID int, role string) (string, error)
|
||||
ValidateToken(tokenString string) (*Claims, error)
|
||||
RefreshAccessToken(refreshToken string) (string, error)
|
||||
}
|
||||
|
||||
// jwtService handles JWT token generation and validation
|
||||
type jwtService struct {
|
||||
config JWTConfig
|
||||
}
|
||||
|
||||
// NewJWTService creates a new JWT service with the provided configuration
|
||||
// Returns an error if the signing key is missing
|
||||
func NewJWTService(config JWTConfig) (*JWTService, error) {
|
||||
func NewJWTService(config JWTConfig) (JWTManager, error) {
|
||||
if config.SigningKey == "" {
|
||||
return nil, fmt.Errorf("signing key is required")
|
||||
}
|
||||
@@ -48,41 +59,35 @@ func NewJWTService(config JWTConfig) (*JWTService, error) {
|
||||
if config.RefreshTokenExpiry == 0 {
|
||||
config.RefreshTokenExpiry = 7 * 24 * time.Hour // Default to 7 days
|
||||
}
|
||||
return &JWTService{config: config}, nil
|
||||
return &jwtService{config: config}, nil
|
||||
}
|
||||
|
||||
// GenerateAccessToken creates a new access token for a user
|
||||
// Parameters:
|
||||
// - userID: the ID of the user
|
||||
// - role: the role of the user
|
||||
// Returns the signed token string or an error
|
||||
func (s *JWTService) GenerateAccessToken(userID int, role string) (string, error) {
|
||||
// GenerateAccessToken creates a new access token for a user with the given userID and role
|
||||
func (s *jwtService) GenerateAccessToken(userID int, role string) (string, error) {
|
||||
return s.generateToken(userID, role, AccessToken, s.config.AccessTokenExpiry)
|
||||
}
|
||||
|
||||
// GenerateRefreshToken creates a new refresh token for a user
|
||||
// Parameters:
|
||||
// - userID: the ID of the user
|
||||
// - role: the role of the user
|
||||
// Returns the signed token string or an error
|
||||
func (s *JWTService) GenerateRefreshToken(userID int, role string) (string, error) {
|
||||
// GenerateRefreshToken creates a new refresh token for a user with the given userID and role
|
||||
func (s *jwtService) GenerateRefreshToken(userID int, role string) (string, error) {
|
||||
return s.generateToken(userID, role, RefreshToken, s.config.RefreshTokenExpiry)
|
||||
}
|
||||
|
||||
// generateToken is an internal helper function that creates a new JWT token
|
||||
// Parameters:
|
||||
// - userID: the ID of the user
|
||||
// - role: the role of the user
|
||||
// - tokenType: the type of token (access or refresh)
|
||||
// - expiry: how long the token should be valid
|
||||
// Returns the signed token string or an error
|
||||
func (s *JWTService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) {
|
||||
func (s *jwtService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) {
|
||||
now := time.Now()
|
||||
|
||||
// Add a random nonce to ensure uniqueness
|
||||
nonce := make([]byte, 8)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return "", fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
claims := Claims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(expiry)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
ID: hex.EncodeToString(nonce),
|
||||
},
|
||||
UserID: userID,
|
||||
Role: role,
|
||||
@@ -94,10 +99,7 @@ func (s *JWTService) generateToken(userID int, role string, tokenType TokenType,
|
||||
}
|
||||
|
||||
// ValidateToken validates and parses a JWT token
|
||||
// Parameters:
|
||||
// - tokenString: the token to validate
|
||||
// Returns the token claims if valid, or an error if invalid
|
||||
func (s *JWTService) ValidateToken(tokenString string) (*Claims, error) {
|
||||
func (s *jwtService) ValidateToken(tokenString string) (*Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
// Validate the signing method
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
@@ -117,11 +119,8 @@ func (s *JWTService) ValidateToken(tokenString string) (*Claims, error) {
|
||||
return nil, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
|
||||
// RefreshAccessToken creates a new access token using a refresh token
|
||||
// Parameters:
|
||||
// - refreshToken: the refresh token to use
|
||||
// Returns a new access token if the refresh token is valid, or an error
|
||||
func (s *JWTService) RefreshAccessToken(refreshToken string) (string, error) {
|
||||
// RefreshAccessToken creates a new access token using a refreshToken
|
||||
func (s *jwtService) RefreshAccessToken(refreshToken string) (string, error) {
|
||||
claims, err := s.ValidateToken(refreshToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid refresh token: %w", err)
|
||||
|
||||
221
server/internal/auth/jwt_test.go
Normal file
221
server/internal/auth/jwt_test.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"novamd/internal/auth"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// jwt_test.go tests
|
||||
|
||||
func TestNewJWTService(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
config auth.JWTConfig
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid configuration",
|
||||
config: auth.JWTConfig{
|
||||
SigningKey: "test-key",
|
||||
AccessTokenExpiry: 15 * time.Minute,
|
||||
RefreshTokenExpiry: 24 * time.Hour,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing signing key",
|
||||
config: auth.JWTConfig{
|
||||
AccessTokenExpiry: 15 * time.Minute,
|
||||
RefreshTokenExpiry: 24 * time.Hour,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "zero expiry times",
|
||||
config: auth.JWTConfig{
|
||||
SigningKey: "test-key",
|
||||
},
|
||||
wantErr: false, // Should use default values
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
service, err := auth.NewJWTService(tc.config)
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if service == nil {
|
||||
t.Error("expected service, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateAndValidateToken(t *testing.T) {
|
||||
config := auth.JWTConfig{
|
||||
SigningKey: "test-key",
|
||||
AccessTokenExpiry: 15 * time.Minute,
|
||||
RefreshTokenExpiry: 24 * time.Hour,
|
||||
}
|
||||
service, _ := auth.NewJWTService(config)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
userID int
|
||||
role string
|
||||
tokenType auth.TokenType
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid access token",
|
||||
userID: 1,
|
||||
role: "admin",
|
||||
tokenType: auth.AccessToken,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid refresh token",
|
||||
userID: 1,
|
||||
role: "editor",
|
||||
tokenType: auth.RefreshToken,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var token string
|
||||
var err error
|
||||
|
||||
// Generate token based on type
|
||||
if tc.tokenType == auth.AccessToken {
|
||||
token, err = service.GenerateAccessToken(tc.userID, tc.role)
|
||||
} else {
|
||||
token, err = service.GenerateRefreshToken(tc.userID, tc.role)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate token: %v", err)
|
||||
}
|
||||
|
||||
// Validate token
|
||||
claims, err := service.ValidateToken(token)
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify claims
|
||||
if claims.UserID != tc.userID {
|
||||
t.Errorf("userID = %v, want %v", claims.UserID, tc.userID)
|
||||
}
|
||||
if claims.Role != tc.role {
|
||||
t.Errorf("role = %v, want %v", claims.Role, tc.role)
|
||||
}
|
||||
if claims.Type != tc.tokenType {
|
||||
t.Errorf("type = %v, want %v", claims.Type, tc.tokenType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshAccessToken(t *testing.T) {
|
||||
config := auth.JWTConfig{
|
||||
SigningKey: "test-key",
|
||||
AccessTokenExpiry: 15 * time.Minute,
|
||||
RefreshTokenExpiry: 24 * time.Hour,
|
||||
}
|
||||
service, _ := auth.NewJWTService(config)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
userID int
|
||||
role string
|
||||
wantErr bool
|
||||
setupFunc func() string // Added setup function to handle custom token creation
|
||||
}{
|
||||
{
|
||||
name: "valid refresh token",
|
||||
userID: 1,
|
||||
role: "admin",
|
||||
wantErr: false,
|
||||
setupFunc: func() string {
|
||||
token, _ := service.GenerateRefreshToken(1, "admin")
|
||||
return token
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "expired refresh token",
|
||||
userID: 1,
|
||||
role: "admin",
|
||||
wantErr: true,
|
||||
setupFunc: func() string {
|
||||
// Create a token that's already expired
|
||||
claims := &auth.Claims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(-time.Hour)), // Expired 1 hour ago
|
||||
IssuedAt: jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)),
|
||||
NotBefore: jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)),
|
||||
},
|
||||
UserID: 1,
|
||||
Role: "admin",
|
||||
Type: auth.RefreshToken,
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, _ := token.SignedString([]byte(config.SigningKey))
|
||||
return tokenString
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
refreshToken := tc.setupFunc()
|
||||
newAccessToken, err := service.RefreshAccessToken(refreshToken)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
claims, err := service.ValidateToken(newAccessToken)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to validate new access token: %v", err)
|
||||
}
|
||||
|
||||
if claims.UserID != tc.userID {
|
||||
t.Errorf("userID = %v, want %v", claims.UserID, tc.userID)
|
||||
}
|
||||
if claims.Role != tc.role {
|
||||
t.Errorf("role = %v, want %v", claims.Role, tc.role)
|
||||
}
|
||||
if claims.Type != auth.AccessToken {
|
||||
t.Errorf("token type = %v, want %v", claims.Type, auth.AccessToken)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,35 +1,21 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"novamd/internal/httpcontext"
|
||||
"novamd/internal/context"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
UserContextKey contextKey = "user"
|
||||
)
|
||||
|
||||
// UserClaims represents the user information stored in the request context
|
||||
type UserClaims struct {
|
||||
UserID int
|
||||
Role string
|
||||
}
|
||||
|
||||
// Middleware handles JWT authentication for protected routes
|
||||
type Middleware struct {
|
||||
jwtService *JWTService
|
||||
jwtManager JWTManager
|
||||
}
|
||||
|
||||
// NewMiddleware creates a new authentication middleware
|
||||
func NewMiddleware(jwtService *JWTService) *Middleware {
|
||||
func NewMiddleware(jwtManager JWTManager) *Middleware {
|
||||
return &Middleware{
|
||||
jwtService: jwtService,
|
||||
jwtManager: jwtManager,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,7 +37,7 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
// Validate token
|
||||
claims, err := m.jwtService.ValidateToken(parts[1])
|
||||
claims, err := m.jwtManager.ValidateToken(parts[1])
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid token", http.StatusUnauthorized)
|
||||
return
|
||||
@@ -63,14 +49,14 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
// Add user claims to request context
|
||||
ctx := context.WithValue(r.Context(), UserContextKey, UserClaims{
|
||||
UserID: claims.UserID,
|
||||
Role: claims.Role,
|
||||
})
|
||||
// Create handler context with user information
|
||||
hctx := &context.HandlerContext{
|
||||
UserID: claims.UserID,
|
||||
UserRole: claims.Role,
|
||||
}
|
||||
|
||||
// Call the next handler with the updated context
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
// Add context to request and continue
|
||||
next.ServeHTTP(w, context.WithHandlerContext(r, hctx))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -78,13 +64,12 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
|
||||
func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
claims, ok := r.Context().Value(UserContextKey).(UserClaims)
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if claims.Role != role && claims.Role != "admin" {
|
||||
if ctx.UserRole != role && ctx.UserRole != "admin" {
|
||||
http.Error(w, "Insufficient permissions", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
@@ -94,10 +79,10 @@ func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// RequireWorkspaceAccess returns a middleware that ensures the user has access to the workspace
|
||||
func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get our handler context
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -117,12 +102,3 @@ func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler {
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// GetUserFromContext retrieves user claims from the request context
|
||||
func GetUserFromContext(ctx context.Context) (*UserClaims, error) {
|
||||
claims, ok := ctx.Value(UserContextKey).(UserClaims)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no user found in context")
|
||||
}
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
294
server/internal/auth/middleware_test.go
Normal file
294
server/internal/auth/middleware_test.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"novamd/internal/auth"
|
||||
"novamd/internal/context"
|
||||
"novamd/internal/models"
|
||||
)
|
||||
|
||||
// Complete mockResponseWriter implementation
|
||||
type mockResponseWriter struct {
|
||||
headers http.Header
|
||||
statusCode int
|
||||
written []byte
|
||||
}
|
||||
|
||||
func newMockResponseWriter() *mockResponseWriter {
|
||||
return &mockResponseWriter{
|
||||
headers: make(http.Header),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockResponseWriter) Header() http.Header {
|
||||
return m.headers
|
||||
}
|
||||
|
||||
func (m *mockResponseWriter) Write(b []byte) (int, error) {
|
||||
m.written = b
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (m *mockResponseWriter) WriteHeader(statusCode int) {
|
||||
m.statusCode = statusCode
|
||||
}
|
||||
|
||||
func TestAuthenticateMiddleware(t *testing.T) {
|
||||
config := auth.JWTConfig{
|
||||
SigningKey: "test-key",
|
||||
AccessTokenExpiry: 15 * time.Minute,
|
||||
RefreshTokenExpiry: 24 * time.Hour,
|
||||
}
|
||||
jwtService, _ := auth.NewJWTService(config)
|
||||
middleware := auth.NewMiddleware(jwtService)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupAuth func() string
|
||||
wantStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "valid token",
|
||||
setupAuth: func() string {
|
||||
token, _ := jwtService.GenerateAccessToken(1, "admin")
|
||||
return token
|
||||
},
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "missing auth header",
|
||||
setupAuth: func() string {
|
||||
return ""
|
||||
},
|
||||
wantStatusCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "invalid auth format",
|
||||
setupAuth: func() string {
|
||||
return "InvalidFormat token"
|
||||
},
|
||||
wantStatusCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "invalid token",
|
||||
setupAuth: func() string {
|
||||
return "Bearer invalid.token.here"
|
||||
},
|
||||
wantStatusCode: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create test request
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
if token := tc.setupAuth(); token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
// Create response recorder
|
||||
w := newMockResponseWriter()
|
||||
|
||||
// Create test handler
|
||||
nextCalled := false
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
nextCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Execute middleware
|
||||
middleware.Authenticate(next).ServeHTTP(w, req)
|
||||
|
||||
// Check status code
|
||||
if w.statusCode != tc.wantStatusCode {
|
||||
t.Errorf("status code = %v, want %v", w.statusCode, tc.wantStatusCode)
|
||||
}
|
||||
|
||||
// Check if next handler was called when expected
|
||||
if tc.wantStatusCode == http.StatusOK && !nextCalled {
|
||||
t.Error("next handler was not called")
|
||||
}
|
||||
if tc.wantStatusCode != http.StatusOK && nextCalled {
|
||||
t.Error("next handler was called when it shouldn't have been")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireRole(t *testing.T) {
|
||||
config := auth.JWTConfig{
|
||||
SigningKey: "test-key",
|
||||
AccessTokenExpiry: 15 * time.Minute,
|
||||
RefreshTokenExpiry: 24 * time.Hour,
|
||||
}
|
||||
jwtService, _ := auth.NewJWTService(config)
|
||||
middleware := auth.NewMiddleware(jwtService)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
userRole string
|
||||
requiredRole string
|
||||
wantStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "matching role",
|
||||
userRole: "admin",
|
||||
requiredRole: "admin",
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "admin accessing other role",
|
||||
userRole: "admin",
|
||||
requiredRole: "editor",
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "insufficient role",
|
||||
userRole: "editor",
|
||||
requiredRole: "admin",
|
||||
wantStatusCode: http.StatusForbidden,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create handler context with user info
|
||||
hctx := &context.HandlerContext{
|
||||
UserID: 1,
|
||||
UserRole: tc.userRole,
|
||||
}
|
||||
|
||||
// Create request with handler context
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req = context.WithHandlerContext(req, hctx)
|
||||
w := newMockResponseWriter()
|
||||
|
||||
// Create test handler
|
||||
nextCalled := false
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
nextCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Execute middleware
|
||||
middleware.RequireRole(tc.requiredRole)(next).ServeHTTP(w, req)
|
||||
|
||||
// Check status code
|
||||
if w.statusCode != tc.wantStatusCode {
|
||||
t.Errorf("status code = %v, want %v", w.statusCode, tc.wantStatusCode)
|
||||
}
|
||||
|
||||
// Check if next handler was called when expected
|
||||
if tc.wantStatusCode == http.StatusOK && !nextCalled {
|
||||
t.Error("next handler was not called")
|
||||
}
|
||||
if tc.wantStatusCode != http.StatusOK && nextCalled {
|
||||
t.Error("next handler was called when it shouldn't have been")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireWorkspaceAccess(t *testing.T) {
|
||||
config := auth.JWTConfig{
|
||||
SigningKey: "test-key",
|
||||
}
|
||||
jwtService, _ := auth.NewJWTService(config)
|
||||
middleware := auth.NewMiddleware(jwtService)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupContext func() *context.HandlerContext
|
||||
wantStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace owner access",
|
||||
setupContext: func() *context.HandlerContext {
|
||||
return &context.HandlerContext{
|
||||
UserID: 1,
|
||||
UserRole: "editor",
|
||||
Workspace: &models.Workspace{
|
||||
ID: 1,
|
||||
UserID: 1, // Same as context UserID
|
||||
},
|
||||
}
|
||||
},
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "admin access to other's workspace",
|
||||
setupContext: func() *context.HandlerContext {
|
||||
return &context.HandlerContext{
|
||||
UserID: 2,
|
||||
UserRole: "admin",
|
||||
Workspace: &models.Workspace{
|
||||
ID: 1,
|
||||
UserID: 1, // Different from context UserID
|
||||
},
|
||||
}
|
||||
},
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "unauthorized access attempt",
|
||||
setupContext: func() *context.HandlerContext {
|
||||
return &context.HandlerContext{
|
||||
UserID: 2,
|
||||
UserRole: "editor",
|
||||
Workspace: &models.Workspace{
|
||||
ID: 1,
|
||||
UserID: 1, // Different from context UserID
|
||||
},
|
||||
}
|
||||
},
|
||||
wantStatusCode: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "no workspace in context",
|
||||
setupContext: func() *context.HandlerContext {
|
||||
return &context.HandlerContext{
|
||||
UserID: 1,
|
||||
UserRole: "editor",
|
||||
Workspace: nil,
|
||||
}
|
||||
},
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create request with context
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req = context.WithHandlerContext(req, tc.setupContext())
|
||||
w := newMockResponseWriter()
|
||||
|
||||
// Create test handler
|
||||
nextCalled := false
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
nextCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Execute middleware
|
||||
middleware.RequireWorkspaceAccess(next).ServeHTTP(w, req)
|
||||
|
||||
// Check status code
|
||||
if w.statusCode != tc.wantStatusCode {
|
||||
t.Errorf("status code = %v, want %v", w.statusCode, tc.wantStatusCode)
|
||||
}
|
||||
|
||||
// Check if next handler was called when expected
|
||||
if tc.wantStatusCode == http.StatusOK && !nextCalled {
|
||||
t.Error("next handler was not called")
|
||||
}
|
||||
if tc.wantStatusCode != http.StatusOK && nextCalled {
|
||||
t.Error("next handler was called when it shouldn't have been")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,67 +1,49 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"novamd/internal/db"
|
||||
"novamd/internal/models"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Session represents a user session in the database
|
||||
type Session struct {
|
||||
ID string // Unique session identifier
|
||||
UserID int // ID of the user this session belongs to
|
||||
RefreshToken string // The refresh token associated with this session
|
||||
ExpiresAt time.Time // When this session expires
|
||||
CreatedAt time.Time // When this session was created
|
||||
}
|
||||
|
||||
// SessionService manages user sessions in the database
|
||||
type SessionService struct {
|
||||
db *sql.DB // Database connection
|
||||
jwtService *JWTService // JWT service for token operations
|
||||
db db.SessionStore // Database store for sessions
|
||||
jwtManager JWTManager // JWT Manager for token operations
|
||||
}
|
||||
|
||||
// NewSessionService creates a new session service
|
||||
// Parameters:
|
||||
// - db: database connection
|
||||
// - jwtService: JWT service for token operations
|
||||
func NewSessionService(db *sql.DB, jwtService *JWTService) *SessionService {
|
||||
// NewSessionService creates a new session service with the given database and JWT manager
|
||||
func NewSessionService(db db.SessionStore, jwtManager JWTManager) *SessionService {
|
||||
return &SessionService{
|
||||
db: db,
|
||||
jwtService: jwtService,
|
||||
jwtManager: jwtManager,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateSession creates a new user session
|
||||
// Parameters:
|
||||
// - userID: the ID of the user
|
||||
// - role: the role of the user
|
||||
// Returns:
|
||||
// - session: the created session
|
||||
// - accessToken: a new access token
|
||||
// - error: any error that occurred
|
||||
func (s *SessionService) CreateSession(userID int, role string) (*Session, string, error) {
|
||||
// CreateSession creates a new user session for a user with the given userID and role
|
||||
func (s *SessionService) CreateSession(userID int, role string) (*models.Session, string, error) {
|
||||
// Generate both access and refresh tokens
|
||||
accessToken, err := s.jwtService.GenerateAccessToken(userID, role)
|
||||
accessToken, err := s.jwtManager.GenerateAccessToken(userID, role)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to generate access token: %w", err)
|
||||
}
|
||||
|
||||
refreshToken, err := s.jwtService.GenerateRefreshToken(userID, role)
|
||||
refreshToken, err := s.jwtManager.GenerateRefreshToken(userID, role)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to generate refresh token: %w", err)
|
||||
}
|
||||
|
||||
// Validate the refresh token to get its expiry time
|
||||
claims, err := s.jwtService.ValidateToken(refreshToken)
|
||||
claims, err := s.jwtManager.ValidateToken(refreshToken)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to validate refresh token: %w", err)
|
||||
}
|
||||
|
||||
// Create a new session record
|
||||
session := &Session{
|
||||
session := &models.Session{
|
||||
ID: uuid.New().String(),
|
||||
UserID: userID,
|
||||
RefreshToken: refreshToken,
|
||||
@@ -69,72 +51,43 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Store the session in the database
|
||||
_, err = s.db.Exec(`
|
||||
INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to store session: %w", err)
|
||||
// Store the session
|
||||
if err := s.db.CreateSession(session); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return session, accessToken, nil
|
||||
}
|
||||
|
||||
// RefreshSession creates a new access token using a refresh token
|
||||
// Parameters:
|
||||
// - refreshToken: the refresh token to use
|
||||
// Returns:
|
||||
// - string: a new access token
|
||||
// - error: any error that occurred
|
||||
// RefreshSession creates a new access token using a refreshToken
|
||||
func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
|
||||
// Get session from database first
|
||||
session, err := s.db.GetSessionByRefreshToken(refreshToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid session: %w", err)
|
||||
}
|
||||
|
||||
// Validate the refresh token
|
||||
claims, err := s.jwtService.ValidateToken(refreshToken)
|
||||
claims, err := s.jwtManager.ValidateToken(refreshToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid refresh token: %w", err)
|
||||
}
|
||||
|
||||
// Check if the session exists and is not expired
|
||||
var session Session
|
||||
err = s.db.QueryRow(`
|
||||
SELECT id, user_id, refresh_token, expires_at, created_at
|
||||
FROM sessions
|
||||
WHERE refresh_token = ? AND expires_at > ?`,
|
||||
refreshToken, time.Now(),
|
||||
).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return "", fmt.Errorf("session not found or expired")
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to fetch session: %w", err)
|
||||
// Double check that the claims match the session
|
||||
if claims.UserID != session.UserID {
|
||||
return "", fmt.Errorf("token does not match session")
|
||||
}
|
||||
|
||||
// Generate a new access token
|
||||
return s.jwtService.GenerateAccessToken(claims.UserID, claims.Role)
|
||||
return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role)
|
||||
}
|
||||
|
||||
// InvalidateSession removes a session from the database
|
||||
// Parameters:
|
||||
// - sessionID: the ID of the session to invalidate
|
||||
// Returns:
|
||||
// - error: any error that occurred
|
||||
// InvalidateSession removes a session with the given sessionID from the database
|
||||
func (s *SessionService) InvalidateSession(sessionID string) error {
|
||||
_, err := s.db.Exec("DELETE FROM sessions WHERE id = ?", sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to invalidate session: %w", err)
|
||||
}
|
||||
return nil
|
||||
return s.db.DeleteSession(sessionID)
|
||||
}
|
||||
|
||||
// CleanExpiredSessions removes all expired sessions from the database
|
||||
// Returns:
|
||||
// - error: any error that occurred
|
||||
func (s *SessionService) CleanExpiredSessions() error {
|
||||
_, err := s.db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to clean expired sessions: %w", err)
|
||||
}
|
||||
return nil
|
||||
return s.db.CleanExpiredSessions()
|
||||
}
|
||||
|
||||
304
server/internal/auth/session_test.go
Normal file
304
server/internal/auth/session_test.go
Normal file
@@ -0,0 +1,304 @@
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"novamd/internal/auth"
|
||||
"novamd/internal/models"
|
||||
)
|
||||
|
||||
// Mock SessionStore
|
||||
type mockSessionStore struct {
|
||||
sessions map[string]*models.Session
|
||||
sessionsByToken map[string]*models.Session // Added index by refresh token
|
||||
}
|
||||
|
||||
func newMockSessionStore() *mockSessionStore {
|
||||
return &mockSessionStore{
|
||||
sessions: make(map[string]*models.Session),
|
||||
sessionsByToken: make(map[string]*models.Session),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockSessionStore) CreateSession(session *models.Session) error {
|
||||
m.sessions[session.ID] = session
|
||||
m.sessionsByToken[session.RefreshToken] = session
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSessionStore) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) {
|
||||
session, exists := m.sessionsByToken[refreshToken]
|
||||
if !exists {
|
||||
return nil, errors.New("session not found")
|
||||
}
|
||||
if session.ExpiresAt.Before(time.Now()) {
|
||||
return nil, errors.New("session expired")
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (m *mockSessionStore) DeleteSession(sessionID string) error {
|
||||
session, exists := m.sessions[sessionID]
|
||||
if !exists {
|
||||
return errors.New("session not found")
|
||||
}
|
||||
delete(m.sessionsByToken, session.RefreshToken)
|
||||
delete(m.sessions, sessionID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSessionStore) CleanExpiredSessions() error {
|
||||
for id, session := range m.sessions {
|
||||
if session.ExpiresAt.Before(time.Now()) {
|
||||
delete(m.sessionsByToken, session.RefreshToken)
|
||||
delete(m.sessions, id)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestCreateSession(t *testing.T) {
|
||||
config := auth.JWTConfig{
|
||||
SigningKey: "test-key",
|
||||
AccessTokenExpiry: 15 * time.Minute,
|
||||
RefreshTokenExpiry: 24 * time.Hour,
|
||||
}
|
||||
jwtService, _ := auth.NewJWTService(config)
|
||||
mockDB := newMockSessionStore()
|
||||
sessionService := auth.NewSessionService(mockDB, jwtService)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
userID int
|
||||
role string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "successful session creation",
|
||||
userID: 1,
|
||||
role: "admin",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "another successful session",
|
||||
userID: 2,
|
||||
role: "editor",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, accessToken, err := sessionService.CreateSession(tc.userID, tc.role)
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify session
|
||||
if session.UserID != tc.userID {
|
||||
t.Errorf("userID = %v, want %v", session.UserID, tc.userID)
|
||||
}
|
||||
|
||||
// Verify the session was stored
|
||||
storedSession, exists := mockDB.sessions[session.ID]
|
||||
if !exists {
|
||||
t.Error("session was not stored in database")
|
||||
}
|
||||
if storedSession.RefreshToken != session.RefreshToken {
|
||||
t.Error("stored refresh token doesn't match")
|
||||
}
|
||||
|
||||
// Verify access token
|
||||
claims, err := jwtService.ValidateToken(accessToken)
|
||||
if err != nil {
|
||||
t.Errorf("failed to validate access token: %v", err)
|
||||
return
|
||||
}
|
||||
if claims.UserID != tc.userID {
|
||||
t.Errorf("access token userID = %v, want %v", claims.UserID, tc.userID)
|
||||
}
|
||||
if claims.Role != tc.role {
|
||||
t.Errorf("access token role = %v, want %v", claims.Role, tc.role)
|
||||
}
|
||||
if claims.Type != auth.AccessToken {
|
||||
t.Errorf("token type = %v, want access token", claims.Type)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshSession(t *testing.T) {
|
||||
config := auth.JWTConfig{
|
||||
SigningKey: "test-key",
|
||||
AccessTokenExpiry: 15 * time.Minute,
|
||||
RefreshTokenExpiry: 24 * time.Hour,
|
||||
}
|
||||
jwtService, _ := auth.NewJWTService(config)
|
||||
mockDB := newMockSessionStore()
|
||||
sessionService := auth.NewSessionService(mockDB, jwtService)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupSession func() string
|
||||
wantErr bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "valid refresh token",
|
||||
setupSession: func() string {
|
||||
token, _ := jwtService.GenerateRefreshToken(1, "admin")
|
||||
session := &models.Session{
|
||||
ID: "test-session-1",
|
||||
UserID: 1,
|
||||
RefreshToken: token,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := mockDB.CreateSession(session); err != nil {
|
||||
t.Fatalf("failed to create session: %v", err)
|
||||
}
|
||||
return token
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "expired refresh token",
|
||||
setupSession: func() string {
|
||||
token, _ := jwtService.GenerateRefreshToken(1, "admin")
|
||||
session := &models.Session{
|
||||
ID: "test-session-2",
|
||||
UserID: 1,
|
||||
RefreshToken: token,
|
||||
ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired
|
||||
CreatedAt: time.Now().Add(-2 * time.Hour),
|
||||
}
|
||||
if err := mockDB.CreateSession(session); err != nil {
|
||||
t.Fatalf("failed to create session: %v", err)
|
||||
}
|
||||
return token
|
||||
},
|
||||
wantErr: true,
|
||||
errorContains: "session expired",
|
||||
},
|
||||
{
|
||||
name: "non-existent refresh token",
|
||||
setupSession: func() string {
|
||||
return "non-existent-token"
|
||||
},
|
||||
wantErr: true,
|
||||
errorContains: "session not found",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
refreshToken := tc.setupSession()
|
||||
newAccessToken, err := sessionService.RefreshSession(refreshToken)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
} else if tc.errorContains != "" && !strings.Contains(err.Error(), tc.errorContains) {
|
||||
t.Errorf("error = %v, want error containing %v", err, tc.errorContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify new access token
|
||||
claims, err := jwtService.ValidateToken(newAccessToken)
|
||||
if err != nil {
|
||||
t.Errorf("failed to validate new access token: %v", err)
|
||||
return
|
||||
}
|
||||
if claims.Type != auth.AccessToken {
|
||||
t.Errorf("token type = %v, want access token", claims.Type)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidateSession(t *testing.T) {
|
||||
config := auth.JWTConfig{
|
||||
SigningKey: "test-key",
|
||||
AccessTokenExpiry: 15 * time.Minute,
|
||||
RefreshTokenExpiry: 24 * time.Hour,
|
||||
}
|
||||
jwtService, _ := auth.NewJWTService(config)
|
||||
mockDB := newMockSessionStore()
|
||||
sessionService := auth.NewSessionService(mockDB, jwtService)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupSession func() string
|
||||
wantErr bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "valid session invalidation",
|
||||
setupSession: func() string {
|
||||
session := &models.Session{
|
||||
ID: "test-session-1",
|
||||
UserID: 1,
|
||||
RefreshToken: "valid-token",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := mockDB.CreateSession(session); err != nil {
|
||||
t.Fatalf("failed to create session: %v", err)
|
||||
}
|
||||
return session.ID
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent session",
|
||||
setupSession: func() string {
|
||||
return "non-existent-session-id"
|
||||
},
|
||||
wantErr: true,
|
||||
errorContains: "session not found",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
sessionID := tc.setupSession()
|
||||
err := sessionService.InvalidateSession(sessionID)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
} else if !strings.Contains(err.Error(), tc.errorContains) {
|
||||
t.Errorf("error = %v, want error containing %v", err, tc.errorContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify session was removed
|
||||
if _, exists := mockDB.sessions[sessionID]; exists {
|
||||
t.Error("session still exists after invalidation")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,16 +1,16 @@
|
||||
// Package config provides the configuration for the application
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"novamd/internal/secrets"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"novamd/internal/crypto"
|
||||
)
|
||||
|
||||
// Config holds the configuration for the application
|
||||
type Config struct {
|
||||
DBPath string
|
||||
WorkDir string
|
||||
@@ -27,6 +27,7 @@ type Config struct {
|
||||
IsDevelopment bool
|
||||
}
|
||||
|
||||
// DefaultConfig returns a new Config instance with default values
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
DBPath: "./novamd.db",
|
||||
@@ -39,13 +40,14 @@ func DefaultConfig() *Config {
|
||||
}
|
||||
}
|
||||
|
||||
// Validate checks if the configuration is valid
|
||||
func (c *Config) Validate() error {
|
||||
if c.AdminEmail == "" || c.AdminPassword == "" {
|
||||
return fmt.Errorf("NOVAMD_ADMIN_EMAIL and NOVAMD_ADMIN_PASSWORD must be set")
|
||||
}
|
||||
|
||||
// Validate encryption key
|
||||
if err := crypto.ValidateKey(c.EncryptionKey); err != nil {
|
||||
if err := secrets.ValidateKey(c.EncryptionKey); err != nil {
|
||||
return fmt.Errorf("invalid NOVAMD_ENCRYPTION_KEY: %w", err)
|
||||
}
|
||||
|
||||
@@ -63,16 +65,10 @@ func Load() (*Config, error) {
|
||||
if dbPath := os.Getenv("NOVAMD_DB_PATH"); dbPath != "" {
|
||||
config.DBPath = dbPath
|
||||
}
|
||||
if err := ensureDir(filepath.Dir(config.DBPath)); err != nil {
|
||||
return nil, fmt.Errorf("failed to create database directory: %w", err)
|
||||
}
|
||||
|
||||
if workDir := os.Getenv("NOVAMD_WORKDIR"); workDir != "" {
|
||||
config.WorkDir = workDir
|
||||
}
|
||||
if err := ensureDir(config.WorkDir); err != nil {
|
||||
return nil, fmt.Errorf("failed to create work directory: %w", err)
|
||||
}
|
||||
|
||||
if staticPath := os.Getenv("NOVAMD_STATIC_PATH"); staticPath != "" {
|
||||
config.StaticPath = staticPath
|
||||
@@ -115,10 +111,3 @@ func Load() (*Config, error) {
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func ensureDir(dir string) error {
|
||||
if dir == "" {
|
||||
return nil
|
||||
}
|
||||
return os.MkdirAll(dir, 0755)
|
||||
}
|
||||
|
||||
215
server/internal/config/config_test.go
Normal file
215
server/internal/config/config_test.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"novamd/internal/config"
|
||||
)
|
||||
|
||||
func TestDefaultConfig(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
got interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{"DBPath", cfg.DBPath, "./novamd.db"},
|
||||
{"WorkDir", cfg.WorkDir, "./data"},
|
||||
{"StaticPath", cfg.StaticPath, "../app/dist"},
|
||||
{"Port", cfg.Port, "8080"},
|
||||
{"RateLimitRequests", cfg.RateLimitRequests, 100},
|
||||
{"RateLimitWindow", cfg.RateLimitWindow, time.Minute * 15},
|
||||
{"IsDevelopment", cfg.IsDevelopment, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.got != tt.expected {
|
||||
t.Errorf("DefaultConfig().%s = %v, want %v", tt.name, tt.got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// setEnv is a helper function to set environment variables and check for errors
|
||||
func setEnv(t *testing.T, key, value string) {
|
||||
if err := os.Setenv(key, value); err != nil {
|
||||
t.Fatalf("Failed to set environment variable %s: %v", key, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
// Helper function to reset environment variables
|
||||
cleanup := func() {
|
||||
envVars := []string{
|
||||
"NOVAMD_ENV",
|
||||
"NOVAMD_DB_PATH",
|
||||
"NOVAMD_WORKDIR",
|
||||
"NOVAMD_STATIC_PATH",
|
||||
"NOVAMD_PORT",
|
||||
"NOVAMD_APP_URL",
|
||||
"NOVAMD_CORS_ORIGINS",
|
||||
"NOVAMD_ADMIN_EMAIL",
|
||||
"NOVAMD_ADMIN_PASSWORD",
|
||||
"NOVAMD_ENCRYPTION_KEY",
|
||||
"NOVAMD_JWT_SIGNING_KEY",
|
||||
"NOVAMD_RATE_LIMIT_REQUESTS",
|
||||
"NOVAMD_RATE_LIMIT_WINDOW",
|
||||
}
|
||||
for _, env := range envVars {
|
||||
if err := os.Unsetenv(env); err != nil {
|
||||
t.Fatalf("Failed to unset environment variable %s: %v", env, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("load with defaults", func(t *testing.T) {
|
||||
cleanup()
|
||||
defer cleanup()
|
||||
|
||||
// Set required env vars
|
||||
setEnv(t, "NOVAMD_ADMIN_EMAIL", "admin@example.com")
|
||||
setEnv(t, "NOVAMD_ADMIN_PASSWORD", "password123")
|
||||
setEnv(t, "NOVAMD_ENCRYPTION_KEY", "YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY=") // 32 bytes base64 encoded
|
||||
|
||||
cfg, err := config.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error = %v", err)
|
||||
}
|
||||
|
||||
if cfg.DBPath != "./novamd.db" {
|
||||
t.Errorf("default DBPath = %v, want %v", cfg.DBPath, "./novamd.db")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("load with custom values", func(t *testing.T) {
|
||||
cleanup()
|
||||
defer cleanup()
|
||||
|
||||
// Set all environment variables
|
||||
envs := map[string]string{
|
||||
"NOVAMD_ENV": "development",
|
||||
"NOVAMD_DB_PATH": "/custom/db/path.db",
|
||||
"NOVAMD_WORKDIR": "/custom/work/dir",
|
||||
"NOVAMD_STATIC_PATH": "/custom/static/path",
|
||||
"NOVAMD_PORT": "3000",
|
||||
"NOVAMD_APP_URL": "http://localhost:3000",
|
||||
"NOVAMD_CORS_ORIGINS": "http://localhost:3000,http://localhost:3001",
|
||||
"NOVAMD_ADMIN_EMAIL": "admin@example.com",
|
||||
"NOVAMD_ADMIN_PASSWORD": "password123",
|
||||
"NOVAMD_ENCRYPTION_KEY": "YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY=",
|
||||
"NOVAMD_JWT_SIGNING_KEY": "secret-key",
|
||||
"NOVAMD_RATE_LIMIT_REQUESTS": "200",
|
||||
"NOVAMD_RATE_LIMIT_WINDOW": "30m",
|
||||
}
|
||||
|
||||
for k, v := range envs {
|
||||
setEnv(t, k, v)
|
||||
}
|
||||
|
||||
cfg, err := config.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error = %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
got interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{"IsDevelopment", cfg.IsDevelopment, true},
|
||||
{"DBPath", cfg.DBPath, "/custom/db/path.db"},
|
||||
{"WorkDir", cfg.WorkDir, "/custom/work/dir"},
|
||||
{"StaticPath", cfg.StaticPath, "/custom/static/path"},
|
||||
{"Port", cfg.Port, "3000"},
|
||||
{"AppURL", cfg.AppURL, "http://localhost:3000"},
|
||||
{"AdminEmail", cfg.AdminEmail, "admin@example.com"},
|
||||
{"AdminPassword", cfg.AdminPassword, "password123"},
|
||||
{"JWTSigningKey", cfg.JWTSigningKey, "secret-key"},
|
||||
{"RateLimitRequests", cfg.RateLimitRequests, 200},
|
||||
{"RateLimitWindow", cfg.RateLimitWindow, 30 * time.Minute},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.got != tt.expected {
|
||||
t.Errorf("%s = %v, want %v", tt.name, tt.got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test CORS origins separately as it's a slice
|
||||
expectedOrigins := []string{"http://localhost:3000", "http://localhost:3001"}
|
||||
if len(cfg.CORSOrigins) != len(expectedOrigins) {
|
||||
t.Errorf("CORSOrigins length = %v, want %v", len(cfg.CORSOrigins), len(expectedOrigins))
|
||||
}
|
||||
for i, origin := range cfg.CORSOrigins {
|
||||
if origin != expectedOrigins[i] {
|
||||
t.Errorf("CORSOrigins[%d] = %v, want %v", i, origin, expectedOrigins[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("validation failures", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupEnv func(*testing.T)
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "missing admin email",
|
||||
setupEnv: func(t *testing.T) {
|
||||
cleanup()
|
||||
setEnv(t, "NOVAMD_ADMIN_PASSWORD", "password123")
|
||||
setEnv(t, "NOVAMD_ENCRYPTION_KEY", "YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY=")
|
||||
},
|
||||
expectedError: "NOVAMD_ADMIN_EMAIL and NOVAMD_ADMIN_PASSWORD must be set",
|
||||
},
|
||||
{
|
||||
name: "missing admin password",
|
||||
setupEnv: func(t *testing.T) {
|
||||
cleanup()
|
||||
setEnv(t, "NOVAMD_ADMIN_EMAIL", "admin@example.com")
|
||||
setEnv(t, "NOVAMD_ENCRYPTION_KEY", "YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY=")
|
||||
},
|
||||
expectedError: "NOVAMD_ADMIN_EMAIL and NOVAMD_ADMIN_PASSWORD must be set",
|
||||
},
|
||||
{
|
||||
name: "missing encryption key",
|
||||
setupEnv: func(t *testing.T) {
|
||||
cleanup()
|
||||
setEnv(t, "NOVAMD_ADMIN_EMAIL", "admin@example.com")
|
||||
setEnv(t, "NOVAMD_ADMIN_PASSWORD", "password123")
|
||||
},
|
||||
expectedError: "invalid NOVAMD_ENCRYPTION_KEY: encryption key is required",
|
||||
},
|
||||
{
|
||||
name: "invalid encryption key",
|
||||
setupEnv: func(t *testing.T) {
|
||||
cleanup()
|
||||
setEnv(t, "NOVAMD_ADMIN_EMAIL", "admin@example.com")
|
||||
setEnv(t, "NOVAMD_ADMIN_PASSWORD", "password123")
|
||||
setEnv(t, "NOVAMD_ENCRYPTION_KEY", "invalid-key")
|
||||
},
|
||||
expectedError: "invalid NOVAMD_ENCRYPTION_KEY: invalid base64 encoding: illegal base64 data at input byte 7",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tc.setupEnv(t)
|
||||
_, err := config.Load()
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
return
|
||||
}
|
||||
if err.Error() != tc.expectedError {
|
||||
t.Errorf("error = %v, want error containing %v", err, tc.expectedError)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
62
server/internal/context/context.go
Normal file
62
server/internal/context/context.go
Normal file
@@ -0,0 +1,62 @@
|
||||
// Package context provides functions for managing request context
|
||||
package context
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"novamd/internal/models"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
// HandlerContextKey is the key used to store handler context in the request context
|
||||
HandlerContextKey contextKey = "handlerContext"
|
||||
)
|
||||
|
||||
// UserClaims represents user information from authentication
|
||||
type UserClaims struct {
|
||||
UserID int
|
||||
Role string
|
||||
}
|
||||
|
||||
// HandlerContext holds the request-specific data available to all handlers
|
||||
type HandlerContext struct {
|
||||
UserID int
|
||||
UserRole string
|
||||
Workspace *models.Workspace // Optional, only set for workspace routes
|
||||
}
|
||||
|
||||
// GetRequestContext retrieves the handler context from the request
|
||||
func GetRequestContext(w http.ResponseWriter, r *http.Request) (*HandlerContext, bool) {
|
||||
ctx := r.Context().Value(HandlerContextKey)
|
||||
if ctx == nil {
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return nil, false
|
||||
}
|
||||
return ctx.(*HandlerContext), true
|
||||
}
|
||||
|
||||
// WithHandlerContext adds handler context to the request
|
||||
func WithHandlerContext(r *http.Request, hctx *HandlerContext) *http.Request {
|
||||
return r.WithContext(context.WithValue(r.Context(), HandlerContextKey, hctx))
|
||||
}
|
||||
|
||||
// GetUserFromContext retrieves user claims from the context
|
||||
func GetUserFromContext(ctx context.Context) (*UserClaims, error) {
|
||||
val := ctx.Value(HandlerContextKey)
|
||||
if val == nil {
|
||||
return nil, fmt.Errorf("no user found in context")
|
||||
}
|
||||
|
||||
hctx, ok := val.(*HandlerContext)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid context type")
|
||||
}
|
||||
|
||||
return &UserClaims{
|
||||
UserID: hctx.UserID,
|
||||
Role: hctx.UserRole,
|
||||
}, nil
|
||||
}
|
||||
139
server/internal/context/context_test.go
Normal file
139
server/internal/context/context_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package context_test
|
||||
|
||||
import (
|
||||
stdctx "context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"novamd/internal/context"
|
||||
)
|
||||
|
||||
func TestGetRequestContext(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupCtx func() *context.HandlerContext
|
||||
wantStatus int
|
||||
wantOK bool
|
||||
}{
|
||||
{
|
||||
name: "valid context",
|
||||
setupCtx: func() *context.HandlerContext {
|
||||
return &context.HandlerContext{
|
||||
UserID: 1,
|
||||
UserRole: "admin",
|
||||
}
|
||||
},
|
||||
wantStatus: http.StatusOK,
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "missing context",
|
||||
setupCtx: func() *context.HandlerContext {
|
||||
return nil
|
||||
},
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
wantOK: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create test request
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if ctx := tt.setupCtx(); ctx != nil {
|
||||
req = context.WithHandlerContext(req, ctx)
|
||||
}
|
||||
|
||||
gotCtx, ok := context.GetRequestContext(w, req)
|
||||
|
||||
if ok != tt.wantOK {
|
||||
t.Errorf("GetRequestContext() ok = %v, want %v", ok, tt.wantOK)
|
||||
}
|
||||
|
||||
if !tt.wantOK {
|
||||
if w.Code != tt.wantStatus {
|
||||
t.Errorf("GetRequestContext() status = %v, want %v", w.Code, tt.wantStatus)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if gotCtx.UserID != tt.setupCtx().UserID {
|
||||
t.Errorf("GetRequestContext() UserID = %v, want %v", gotCtx.UserID, tt.setupCtx().UserID)
|
||||
}
|
||||
|
||||
if gotCtx.UserRole != tt.setupCtx().UserRole {
|
||||
t.Errorf("GetRequestContext() UserRole = %v, want %v", gotCtx.UserRole, tt.setupCtx().UserRole)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserFromContext(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupCtx func() stdctx.Context
|
||||
wantUser *context.UserClaims
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "valid user context",
|
||||
setupCtx: func() stdctx.Context {
|
||||
return stdctx.WithValue(stdctx.Background(), context.HandlerContextKey, &context.HandlerContext{
|
||||
UserID: 1,
|
||||
UserRole: "admin",
|
||||
})
|
||||
},
|
||||
wantUser: &context.UserClaims{
|
||||
UserID: 1,
|
||||
Role: "admin",
|
||||
},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "missing context",
|
||||
setupCtx: func() stdctx.Context {
|
||||
return stdctx.Background()
|
||||
},
|
||||
wantUser: nil,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid context type",
|
||||
setupCtx: func() stdctx.Context {
|
||||
return stdctx.WithValue(stdctx.Background(), context.HandlerContextKey, "invalid")
|
||||
},
|
||||
wantUser: nil,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := tt.setupCtx()
|
||||
gotUser, err := context.GetUserFromContext(ctx)
|
||||
|
||||
if tt.wantError {
|
||||
if err == nil {
|
||||
t.Error("GetUserFromContext() error = nil, want error")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("GetUserFromContext() unexpected error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if gotUser.UserID != tt.wantUser.UserID {
|
||||
t.Errorf("GetUserFromContext() UserID = %v, want %v", gotUser.UserID, tt.wantUser.UserID)
|
||||
}
|
||||
|
||||
if gotUser.Role != tt.wantUser.Role {
|
||||
t.Errorf("GetUserFromContext() Role = %v, want %v", gotUser.Role, tt.wantUser.Role)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,38 +1,37 @@
|
||||
package middleware
|
||||
package context
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"novamd/internal/auth"
|
||||
"novamd/internal/db"
|
||||
"novamd/internal/httpcontext"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
// User ID and User Role context
|
||||
func WithUserContext(next http.Handler) http.Handler {
|
||||
// WithUserContextMiddleware extracts user information from JWT claims
|
||||
// and adds it to the request context
|
||||
func WithUserContextMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
claims, err := auth.GetUserFromContext(r.Context())
|
||||
claims, err := GetUserFromContext(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
hctx := &httpcontext.HandlerContext{
|
||||
hctx := &HandlerContext{
|
||||
UserID: claims.UserID,
|
||||
UserRole: claims.Role,
|
||||
}
|
||||
|
||||
r = httpcontext.WithHandlerContext(r, hctx)
|
||||
r = WithHandlerContext(r, hctx)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Workspace context
|
||||
func WithWorkspaceContext(db *db.DB) func(http.Handler) http.Handler {
|
||||
// WithWorkspaceContextMiddleware adds workspace information to the request context
|
||||
func WithWorkspaceContextMiddleware(db db.WorkspaceReader) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
ctx, ok := GetRequestContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -46,7 +45,7 @@ func WithWorkspaceContext(db *db.DB) func(http.Handler) http.Handler {
|
||||
|
||||
// Update existing context with workspace
|
||||
ctx.Workspace = workspace
|
||||
r = httpcontext.WithHandlerContext(r, ctx)
|
||||
r = WithHandlerContext(r, ctx)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
197
server/internal/context/middleware_test.go
Normal file
197
server/internal/context/middleware_test.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package context_test
|
||||
|
||||
import (
|
||||
stdctx "context"
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"novamd/internal/context"
|
||||
"novamd/internal/models"
|
||||
)
|
||||
|
||||
// MockDB implements the minimal database interface needed for testing
|
||||
type MockDB struct {
|
||||
GetWorkspaceByNameFunc func(userID int, workspaceName string) (*models.Workspace, error)
|
||||
}
|
||||
|
||||
func (m *MockDB) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) {
|
||||
return m.GetWorkspaceByNameFunc(userID, workspaceName)
|
||||
}
|
||||
|
||||
func (m *MockDB) GetWorkspaceByID(_ int) (*models.Workspace, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockDB) GetWorkspacesByUserID(_ int) ([]*models.Workspace, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockDB) GetAllWorkspaces() ([]*models.Workspace, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestWithUserContextMiddleware(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupCtx func() *context.HandlerContext
|
||||
wantStatus int
|
||||
wantNext bool
|
||||
}{
|
||||
{
|
||||
name: "valid user context",
|
||||
setupCtx: func() *context.HandlerContext {
|
||||
return &context.HandlerContext{
|
||||
UserID: 1,
|
||||
UserRole: "admin",
|
||||
}
|
||||
},
|
||||
wantStatus: http.StatusOK,
|
||||
wantNext: true,
|
||||
},
|
||||
{
|
||||
name: "missing user context",
|
||||
setupCtx: func() *context.HandlerContext {
|
||||
return nil
|
||||
},
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
wantNext: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if ctx := tt.setupCtx(); ctx != nil {
|
||||
req = context.WithHandlerContext(req, ctx)
|
||||
}
|
||||
|
||||
nextCalled := false
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
nextCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
middleware := context.WithUserContextMiddleware(next)
|
||||
middleware.ServeHTTP(w, req)
|
||||
|
||||
if nextCalled != tt.wantNext {
|
||||
t.Errorf("WithUserContextMiddleware() next called = %v, want %v", nextCalled, tt.wantNext)
|
||||
}
|
||||
|
||||
if w.Code != tt.wantStatus {
|
||||
t.Errorf("WithUserContextMiddleware() status = %v, want %v", w.Code, tt.wantStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithWorkspaceContextMiddleware(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupCtx func() *context.HandlerContext
|
||||
workspaceName string
|
||||
mockWorkspace *models.Workspace
|
||||
mockError error
|
||||
wantStatus int
|
||||
wantNext bool
|
||||
}{
|
||||
{
|
||||
name: "valid workspace context",
|
||||
setupCtx: func() *context.HandlerContext {
|
||||
return &context.HandlerContext{
|
||||
UserID: 1,
|
||||
UserRole: "admin",
|
||||
}
|
||||
},
|
||||
workspaceName: "test-workspace",
|
||||
mockWorkspace: &models.Workspace{
|
||||
ID: 1,
|
||||
UserID: 1,
|
||||
Name: "test-workspace",
|
||||
},
|
||||
mockError: nil,
|
||||
wantStatus: http.StatusOK,
|
||||
wantNext: true,
|
||||
},
|
||||
{
|
||||
name: "workspace not found",
|
||||
setupCtx: func() *context.HandlerContext {
|
||||
return &context.HandlerContext{
|
||||
UserID: 1,
|
||||
UserRole: "admin",
|
||||
}
|
||||
},
|
||||
workspaceName: "nonexistent",
|
||||
mockWorkspace: nil,
|
||||
mockError: sql.ErrNoRows,
|
||||
wantStatus: http.StatusNotFound,
|
||||
wantNext: false,
|
||||
},
|
||||
{
|
||||
name: "missing user context",
|
||||
setupCtx: func() *context.HandlerContext { return nil },
|
||||
workspaceName: "test-workspace",
|
||||
mockWorkspace: nil,
|
||||
mockError: nil,
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
wantNext: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockDB := &MockDB{
|
||||
GetWorkspaceByNameFunc: func(_ int, _ string) (*models.Workspace, error) {
|
||||
return tt.mockWorkspace, tt.mockError
|
||||
},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if ctx := tt.setupCtx(); ctx != nil {
|
||||
req = context.WithHandlerContext(req, ctx)
|
||||
}
|
||||
|
||||
// Add workspace name to request context via chi URL params
|
||||
req = req.WithContext(stdctx.WithValue(req.Context(), "workspaceName", tt.workspaceName))
|
||||
|
||||
nextCalled := false
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
// Verify workspace was added to context
|
||||
if tt.mockWorkspace != nil {
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
t.Error("Failed to get request context in next handler")
|
||||
return
|
||||
}
|
||||
if ctx.Workspace == nil {
|
||||
t.Error("Workspace not set in context")
|
||||
return
|
||||
}
|
||||
if ctx.Workspace.ID != tt.mockWorkspace.ID {
|
||||
t.Errorf("Workspace ID = %v, want %v", ctx.Workspace.ID, tt.mockWorkspace.ID)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
middleware := context.WithWorkspaceContextMiddleware(mockDB)(next)
|
||||
middleware.ServeHTTP(w, req)
|
||||
|
||||
if nextCalled != tt.wantNext {
|
||||
t.Errorf("WithWorkspaceContextMiddleware() next called = %v, want %v", nextCalled, tt.wantNext)
|
||||
}
|
||||
|
||||
if w.Code != tt.wantStatus {
|
||||
t.Errorf("WithWorkspaceContextMiddleware() status = %v, want %v", w.Code, tt.wantStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,114 +0,0 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrKeyRequired = fmt.Errorf("encryption key is required")
|
||||
ErrInvalidKeySize = fmt.Errorf("encryption key must be 32 bytes (256 bits) when decoded")
|
||||
)
|
||||
|
||||
type Crypto struct {
|
||||
key []byte
|
||||
}
|
||||
|
||||
// ValidateKey checks if the provided key is suitable for AES-256
|
||||
func ValidateKey(key string) error {
|
||||
if key == "" {
|
||||
return ErrKeyRequired
|
||||
}
|
||||
|
||||
// Attempt to decode base64
|
||||
keyBytes, err := base64.StdEncoding.DecodeString(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid base64 encoding: %w", err)
|
||||
}
|
||||
|
||||
if len(keyBytes) != 32 {
|
||||
return fmt.Errorf("%w: got %d bytes", ErrInvalidKeySize, len(keyBytes))
|
||||
}
|
||||
|
||||
// Verify the key can be used for AES
|
||||
_, err = aes.NewCipher(keyBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid encryption key: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// New creates a new Crypto instance with the provided base64-encoded key
|
||||
func New(key string) (*Crypto, error) {
|
||||
if err := ValidateKey(key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keyBytes, _ := base64.StdEncoding.DecodeString(key)
|
||||
return &Crypto{key: keyBytes}, nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts the plaintext using AES-256-GCM
|
||||
func (c *Crypto) Encrypt(plaintext string) (string, error) {
|
||||
if plaintext == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(c.key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts the ciphertext using AES-256-GCM
|
||||
func (c *Crypto) Decrypt(ciphertext string) (string, error) {
|
||||
if ciphertext == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
data, err := base64.StdEncoding.DecodeString(ciphertext)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(c.key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return "", fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
nonce, ciphertextBytes := data[:nonceSize], data[nonceSize:]
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertextBytes, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
@@ -1,68 +0,0 @@
|
||||
// Package db provides the database access layer for the application. It contains methods for interacting with the database, such as creating, updating, and deleting records.
|
||||
package db
|
||||
|
||||
import "novamd/internal/models"
|
||||
|
||||
// UserStats represents system-wide statistics
|
||||
type UserStats struct {
|
||||
TotalUsers int `json:"totalUsers"`
|
||||
TotalWorkspaces int `json:"totalWorkspaces"`
|
||||
ActiveUsers int `json:"activeUsers"` // Users with activity in last 30 days
|
||||
}
|
||||
|
||||
// GetAllUsers returns a list of all users in the system
|
||||
func (db *DB) GetAllUsers() ([]*models.User, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT
|
||||
id, email, display_name, role, created_at,
|
||||
last_workspace_id
|
||||
FROM users
|
||||
ORDER BY id ASC`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var users []*models.User
|
||||
for rows.Next() {
|
||||
user := &models.User{}
|
||||
err := rows.Scan(
|
||||
&user.ID, &user.Email, &user.DisplayName, &user.Role,
|
||||
&user.CreatedAt, &user.LastWorkspaceID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users = append(users, user)
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetSystemStats returns system-wide statistics
|
||||
func (db *DB) GetSystemStats() (*UserStats, error) {
|
||||
stats := &UserStats{}
|
||||
|
||||
// Get total users
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&stats.TotalUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get total workspaces
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM workspaces").Scan(&stats.TotalWorkspaces)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get active users (users with activity in last 30 days)
|
||||
err = db.QueryRow(`
|
||||
SELECT COUNT(DISTINCT user_id)
|
||||
FROM sessions
|
||||
WHERE created_at > datetime('now', '-30 days')`).
|
||||
Scan(&stats.ActiveUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
@@ -1,22 +1,104 @@
|
||||
// Package db provides the database access layer for the application. It contains methods for interacting with the database, such as creating, updating, and deleting records.
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"novamd/internal/crypto"
|
||||
"novamd/internal/models"
|
||||
"novamd/internal/secrets"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
)
|
||||
|
||||
// DB represents the database connection
|
||||
type DB struct {
|
||||
// UserStore defines the methods for interacting with user data in the database
|
||||
type UserStore interface {
|
||||
CreateUser(user *models.User) (*models.User, error)
|
||||
GetUserByEmail(email string) (*models.User, error)
|
||||
GetUserByID(userID int) (*models.User, error)
|
||||
GetAllUsers() ([]*models.User, error)
|
||||
UpdateUser(user *models.User) error
|
||||
DeleteUser(userID int) error
|
||||
UpdateLastWorkspace(userID int, workspaceName string) error
|
||||
GetLastWorkspaceName(userID int) (string, error)
|
||||
CountAdminUsers() (int, error)
|
||||
}
|
||||
|
||||
// WorkspaceReader defines the methods for reading workspace data from the database
|
||||
type WorkspaceReader interface {
|
||||
GetWorkspaceByID(workspaceID int) (*models.Workspace, error)
|
||||
GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error)
|
||||
GetWorkspacesByUserID(userID int) ([]*models.Workspace, error)
|
||||
GetAllWorkspaces() ([]*models.Workspace, error)
|
||||
}
|
||||
|
||||
// WorkspaceWriter defines the methods for writing workspace data to the database
|
||||
type WorkspaceWriter interface {
|
||||
CreateWorkspace(workspace *models.Workspace) error
|
||||
UpdateWorkspace(workspace *models.Workspace) error
|
||||
DeleteWorkspace(workspaceID int) error
|
||||
UpdateWorkspaceSettings(workspace *models.Workspace) error
|
||||
DeleteWorkspaceTx(tx *sql.Tx, workspaceID int) error
|
||||
UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error
|
||||
UpdateLastOpenedFile(workspaceID int, filePath string) error
|
||||
GetLastOpenedFile(workspaceID int) (string, error)
|
||||
}
|
||||
|
||||
// WorkspaceStore defines the methods for interacting with workspace data in the database
|
||||
type WorkspaceStore interface {
|
||||
WorkspaceReader
|
||||
WorkspaceWriter
|
||||
}
|
||||
|
||||
// SessionStore defines the methods for interacting with jwt sessions in the database
|
||||
type SessionStore interface {
|
||||
CreateSession(session *models.Session) error
|
||||
GetSessionByRefreshToken(refreshToken string) (*models.Session, error)
|
||||
DeleteSession(sessionID string) error
|
||||
CleanExpiredSessions() error
|
||||
}
|
||||
|
||||
// SystemStore defines the methods for interacting with system settings and stats in the database
|
||||
type SystemStore interface {
|
||||
GetSystemStats() (*UserStats, error)
|
||||
EnsureJWTSecret() (string, error)
|
||||
GetSystemSetting(key string) (string, error)
|
||||
SetSystemSetting(key, value string) error
|
||||
}
|
||||
|
||||
// Database defines the methods for interacting with the database
|
||||
type Database interface {
|
||||
UserStore
|
||||
WorkspaceStore
|
||||
SessionStore
|
||||
SystemStore
|
||||
Begin() (*sql.Tx, error)
|
||||
Close() error
|
||||
Migrate() error
|
||||
}
|
||||
|
||||
var (
|
||||
// Main Database interface
|
||||
_ Database = (*database)(nil)
|
||||
|
||||
// Component interfaces
|
||||
_ UserStore = (*database)(nil)
|
||||
_ WorkspaceStore = (*database)(nil)
|
||||
_ SessionStore = (*database)(nil)
|
||||
_ SystemStore = (*database)(nil)
|
||||
|
||||
// Sub-interfaces
|
||||
_ WorkspaceReader = (*database)(nil)
|
||||
_ WorkspaceWriter = (*database)(nil)
|
||||
)
|
||||
|
||||
// database represents the database connection
|
||||
type database struct {
|
||||
*sql.DB
|
||||
crypto *crypto.Crypto
|
||||
secretsService secrets.Service
|
||||
}
|
||||
|
||||
// Init initializes the database connection
|
||||
func Init(dbPath string, encryptionKey string) (*DB, error) {
|
||||
func Init(dbPath string, secretsService secrets.Service) (Database, error) {
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -26,40 +108,35 @@ func Init(dbPath string, encryptionKey string) (*DB, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Initialize crypto service
|
||||
cryptoService, err := crypto.New(encryptionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize encryption: %w", err)
|
||||
}
|
||||
|
||||
database := &DB{
|
||||
DB: db,
|
||||
crypto: cryptoService,
|
||||
}
|
||||
|
||||
if err := database.Migrate(); err != nil {
|
||||
// Enable foreign keys for this connection
|
||||
if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
database := &database{
|
||||
DB: db,
|
||||
secretsService: secretsService,
|
||||
}
|
||||
|
||||
return database, nil
|
||||
}
|
||||
|
||||
// Close closes the database connection
|
||||
func (db *DB) Close() error {
|
||||
func (db *database) Close() error {
|
||||
return db.DB.Close()
|
||||
}
|
||||
|
||||
// Helper methods for token encryption/decryption
|
||||
func (db *DB) encryptToken(token string) (string, error) {
|
||||
func (db *database) encryptToken(token string) (string, error) {
|
||||
if token == "" {
|
||||
return "", nil
|
||||
}
|
||||
return db.crypto.Encrypt(token)
|
||||
return db.secretsService.Encrypt(token)
|
||||
}
|
||||
|
||||
func (db *DB) decryptToken(token string) (string, error) {
|
||||
func (db *database) decryptToken(token string) (string, error) {
|
||||
if token == "" {
|
||||
return "", nil
|
||||
}
|
||||
return db.crypto.Decrypt(token)
|
||||
return db.secretsService.Decrypt(token)
|
||||
}
|
||||
|
||||
@@ -83,7 +83,7 @@ var migrations = []Migration{
|
||||
}
|
||||
|
||||
// Migrate applies all database migrations
|
||||
func (db *DB) Migrate() error {
|
||||
func (db *database) Migrate() error {
|
||||
// Create migrations table if it doesn't exist
|
||||
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS migrations (
|
||||
version INTEGER PRIMARY KEY
|
||||
|
||||
151
server/internal/db/migrations_test.go
Normal file
151
server/internal/db/migrations_test.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"novamd/internal/db"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
database, err := db.NewTestDB(":memory:", &mockSecrets{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to initialize database: %v", err)
|
||||
}
|
||||
defer database.Close()
|
||||
|
||||
t.Run("migrations are applied in order", func(t *testing.T) {
|
||||
if err := database.Migrate(); err != nil {
|
||||
t.Fatalf("failed to run initial migrations: %v", err)
|
||||
}
|
||||
|
||||
// Check migration version
|
||||
var version int
|
||||
err := database.TestDB().QueryRow("SELECT MAX(version) FROM migrations").Scan(&version)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get migration version: %v", err)
|
||||
}
|
||||
|
||||
if version != 2 { // Current number of migrations in production code
|
||||
t.Errorf("expected migration version 2, got %d", version)
|
||||
}
|
||||
|
||||
// Verify number of migration entries matches versions applied
|
||||
var count int
|
||||
err = database.TestDB().QueryRow("SELECT COUNT(*) FROM migrations").Scan(&count)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to count migrations: %v", err)
|
||||
}
|
||||
|
||||
if count != 2 {
|
||||
t.Errorf("expected 2 migration entries, got %d", count)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("migrations create expected schema", func(t *testing.T) {
|
||||
// Verify tables exist
|
||||
tables := []string{"users", "workspaces", "sessions", "system_settings", "migrations"}
|
||||
for _, table := range tables {
|
||||
if !tableExists(t, database, table) {
|
||||
t.Errorf("table %q does not exist", table)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify indexes
|
||||
indexes := []struct {
|
||||
table string
|
||||
name string
|
||||
}{
|
||||
{"sessions", "idx_sessions_user_id"},
|
||||
{"sessions", "idx_sessions_expires_at"},
|
||||
{"sessions", "idx_sessions_refresh_token"},
|
||||
}
|
||||
|
||||
for _, idx := range indexes {
|
||||
if !indexExists(t, database, idx.table, idx.name) {
|
||||
t.Errorf("index %q on table %q does not exist", idx.name, idx.table)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("migrations are idempotent", func(t *testing.T) {
|
||||
// Run migrations again
|
||||
if err := database.Migrate(); err != nil {
|
||||
t.Fatalf("failed to re-run migrations: %v", err)
|
||||
}
|
||||
|
||||
// Verify migration count hasn't changed
|
||||
var count int
|
||||
err = database.TestDB().QueryRow("SELECT COUNT(*) FROM migrations").Scan(&count)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to count migrations: %v", err)
|
||||
}
|
||||
|
||||
if count != 2 {
|
||||
t.Errorf("expected 2 migration entries, got %d", count)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rollback on migration failure", func(t *testing.T) {
|
||||
// Create a test table that would conflict with a failing migration
|
||||
_, err := database.TestDB().Exec("CREATE TABLE test_rollback (id INTEGER PRIMARY KEY)")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test table: %v", err)
|
||||
}
|
||||
|
||||
// Start transaction
|
||||
tx, err := database.Begin()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start transaction: %v", err)
|
||||
}
|
||||
|
||||
// Try operations that should fail and rollback
|
||||
_, err = tx.Exec(`
|
||||
CREATE TABLE test_rollback (id INTEGER PRIMARY KEY);
|
||||
INSERT INTO nonexistent_table VALUES (1);
|
||||
`)
|
||||
if err == nil {
|
||||
tx.Rollback()
|
||||
t.Fatal("expected migration to fail")
|
||||
}
|
||||
tx.Rollback()
|
||||
|
||||
// Verify the migration version hasn't changed
|
||||
var version int
|
||||
err = database.TestDB().QueryRow("SELECT MAX(version) FROM migrations").Scan(&version)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get migration version: %v", err)
|
||||
}
|
||||
|
||||
if version != 2 {
|
||||
t.Errorf("expected migration version to remain at 2, got %d", version)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func tableExists(t *testing.T, database db.TestDatabase, tableName string) bool {
|
||||
t.Helper()
|
||||
|
||||
var name string
|
||||
err := database.TestDB().QueryRow(`
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='table' AND name=?`,
|
||||
tableName,
|
||||
).Scan(&name)
|
||||
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func indexExists(t *testing.T, database db.TestDatabase, tableName, indexName string) bool {
|
||||
t.Helper()
|
||||
|
||||
var name string
|
||||
err := database.TestDB().QueryRow(`
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='index' AND tbl_name=? AND name=?`,
|
||||
tableName, indexName,
|
||||
).Scan(&name)
|
||||
|
||||
return err == nil
|
||||
}
|
||||
6
server/internal/db/mock_secrets_test.go
Normal file
6
server/internal/db/mock_secrets_test.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package db_test
|
||||
|
||||
type mockSecrets struct{}
|
||||
|
||||
func (m *mockSecrets) Encrypt(s string) (string, error) { return s, nil }
|
||||
func (m *mockSecrets) Decrypt(s string) (string, error) { return s, nil }
|
||||
70
server/internal/db/sessions.go
Normal file
70
server/internal/db/sessions.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"novamd/internal/models"
|
||||
)
|
||||
|
||||
// CreateSession inserts a new session record into the database
|
||||
func (db *database) CreateSession(session *models.Session) error {
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSessionByRefreshToken retrieves a session by its refresh token
|
||||
func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) {
|
||||
session := &models.Session{}
|
||||
err := db.QueryRow(`
|
||||
SELECT id, user_id, refresh_token, expires_at, created_at
|
||||
FROM sessions
|
||||
WHERE refresh_token = ? AND expires_at > ?`,
|
||||
refreshToken, time.Now(),
|
||||
).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("session not found or expired")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch session: %w", err)
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// DeleteSession removes a session from the database
|
||||
func (db *database) DeleteSession(sessionID string) error {
|
||||
result, err := db.Exec("DELETE FROM sessions WHERE id = ?", sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete session: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return fmt.Errorf("session not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanExpiredSessions removes all expired sessions from the database
|
||||
func (db *database) CleanExpiredSessions() error {
|
||||
_, err := db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to clean expired sessions: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
294
server/internal/db/sessions_test.go
Normal file
294
server/internal/db/sessions_test.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"novamd/internal/db"
|
||||
"novamd/internal/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func TestSessionOperations(t *testing.T) {
|
||||
database, err := db.NewTestDB(":memory:", &mockSecrets{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test database: %v", err)
|
||||
}
|
||||
defer database.Close()
|
||||
|
||||
if err := database.Migrate(); err != nil {
|
||||
t.Fatalf("failed to run migrations: %v", err)
|
||||
}
|
||||
|
||||
// Create a test user first since sessions need a valid user ID
|
||||
user, err := database.CreateUser(&models.User{
|
||||
Email: "test@example.com",
|
||||
DisplayName: "Test User",
|
||||
PasswordHash: "hash",
|
||||
Role: "editor",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test user: %v", err)
|
||||
}
|
||||
|
||||
t.Run("CreateSession", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
session *models.Session
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "valid session",
|
||||
session: &models.Session{
|
||||
ID: uuid.New().String(),
|
||||
UserID: user.ID,
|
||||
RefreshToken: "valid-token",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid user ID",
|
||||
session: &models.Session{
|
||||
ID: uuid.New().String(),
|
||||
UserID: 99999, // Non-existent user ID
|
||||
RefreshToken: "invalid-user-token",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "FOREIGN KEY constraint failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := database.CreateSession(tc.session)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
} else if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) {
|
||||
t.Errorf("error = %v, want error containing %v", err, tc.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify session was stored
|
||||
stored, err := database.GetSessionByRefreshToken(tc.session.RefreshToken)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to retrieve stored session: %v", err)
|
||||
}
|
||||
|
||||
// Compare fields
|
||||
if stored.ID != tc.session.ID {
|
||||
t.Errorf("ID = %v, want %v", stored.ID, tc.session.ID)
|
||||
}
|
||||
if stored.UserID != tc.session.UserID {
|
||||
t.Errorf("UserID = %v, want %v", stored.UserID, tc.session.UserID)
|
||||
}
|
||||
if stored.RefreshToken != tc.session.RefreshToken {
|
||||
t.Errorf("RefreshToken = %v, want %v", stored.RefreshToken, tc.session.RefreshToken)
|
||||
}
|
||||
// Compare times within a reasonable threshold
|
||||
if diff := stored.ExpiresAt.Sub(tc.session.ExpiresAt); diff > time.Second || diff < -time.Second {
|
||||
t.Errorf("ExpiresAt differs by %v, want difference less than 1s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetSessionByRefreshToken", func(t *testing.T) {
|
||||
// Create test sessions
|
||||
validSession := &models.Session{
|
||||
ID: uuid.New().String(),
|
||||
UserID: user.ID,
|
||||
RefreshToken: "valid-get-token",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
expiredSession := &models.Session{
|
||||
ID: uuid.New().String(),
|
||||
UserID: user.ID,
|
||||
RefreshToken: "expired-token",
|
||||
ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired
|
||||
CreatedAt: time.Now().Add(-2 * time.Hour),
|
||||
}
|
||||
|
||||
if err := database.CreateSession(validSession); err != nil {
|
||||
t.Fatalf("failed to create valid session: %v", err)
|
||||
}
|
||||
if err := database.CreateSession(expiredSession); err != nil {
|
||||
t.Fatalf("failed to create expired session: %v", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
refreshToken string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "valid token",
|
||||
refreshToken: "valid-get-token",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "expired token",
|
||||
refreshToken: "expired-token",
|
||||
wantErr: true,
|
||||
errContains: "session not found or expired",
|
||||
},
|
||||
{
|
||||
name: "non-existent token",
|
||||
refreshToken: "nonexistent-token",
|
||||
wantErr: true,
|
||||
errContains: "session not found",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, err := database.GetSessionByRefreshToken(tc.refreshToken)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
} else if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) {
|
||||
t.Errorf("error = %v, want error containing %v", err, tc.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if session.RefreshToken != tc.refreshToken {
|
||||
t.Errorf("RefreshToken = %v, want %v", session.RefreshToken, tc.refreshToken)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DeleteSession", func(t *testing.T) {
|
||||
session := &models.Session{
|
||||
ID: uuid.New().String(),
|
||||
UserID: user.ID,
|
||||
RefreshToken: "delete-token",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := database.CreateSession(session); err != nil {
|
||||
t.Fatalf("failed to create session: %v", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
sessionID string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "valid session ID",
|
||||
sessionID: session.ID,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent session ID",
|
||||
sessionID: "nonexistent-id",
|
||||
wantErr: true,
|
||||
errContains: "session not found",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := database.DeleteSession(tc.sessionID)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
} else if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) {
|
||||
t.Errorf("error = %v, want error containing %v", err, tc.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify session was deleted
|
||||
_, err = database.GetSessionByRefreshToken(session.RefreshToken)
|
||||
if err == nil {
|
||||
t.Error("session still exists after deletion")
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CleanExpiredSessions", func(t *testing.T) {
|
||||
// Create a mix of valid and expired sessions
|
||||
sessions := []*models.Session{
|
||||
{
|
||||
ID: uuid.New().String(),
|
||||
UserID: user.ID,
|
||||
RefreshToken: "valid-clean-token",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
},
|
||||
{
|
||||
ID: uuid.New().String(),
|
||||
UserID: user.ID,
|
||||
RefreshToken: "expired-clean-token-1",
|
||||
ExpiresAt: time.Now().Add(-1 * time.Hour),
|
||||
CreatedAt: time.Now().Add(-2 * time.Hour),
|
||||
},
|
||||
{
|
||||
ID: uuid.New().String(),
|
||||
UserID: user.ID,
|
||||
RefreshToken: "expired-clean-token-2",
|
||||
ExpiresAt: time.Now().Add(-2 * time.Hour),
|
||||
CreatedAt: time.Now().Add(-3 * time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range sessions {
|
||||
if err := database.CreateSession(s); err != nil {
|
||||
t.Fatalf("failed to create session: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Clean expired sessions
|
||||
if err := database.CleanExpiredSessions(); err != nil {
|
||||
t.Fatalf("failed to clean expired sessions: %v", err)
|
||||
}
|
||||
|
||||
// Verify valid session still exists
|
||||
validSession, err := database.GetSessionByRefreshToken("valid-clean-token")
|
||||
if err != nil {
|
||||
t.Errorf("valid session was unexpectedly deleted: %v", err)
|
||||
}
|
||||
if validSession == nil {
|
||||
t.Error("valid session was unexpectedly deleted")
|
||||
}
|
||||
|
||||
// Verify expired sessions were deleted
|
||||
expiredTokens := []string{"expired-clean-token-1", "expired-clean-token-2"}
|
||||
for _, token := range expiredTokens {
|
||||
if _, err := database.GetSessionByRefreshToken(token); err == nil {
|
||||
t.Errorf("expired session with token %s still exists", token)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -11,9 +11,16 @@ const (
|
||||
JWTSecretKey = "jwt_secret"
|
||||
)
|
||||
|
||||
// UserStats represents system-wide statistics
|
||||
type UserStats struct {
|
||||
TotalUsers int `json:"totalUsers"`
|
||||
TotalWorkspaces int `json:"totalWorkspaces"`
|
||||
ActiveUsers int `json:"activeUsers"` // Users with activity in last 30 days
|
||||
}
|
||||
|
||||
// EnsureJWTSecret makes sure a JWT signing secret exists in the database
|
||||
// If no secret exists, it generates and stores a new one
|
||||
func (db *DB) EnsureJWTSecret() (string, error) {
|
||||
func (db *database) EnsureJWTSecret() (string, error) {
|
||||
// First, try to get existing secret
|
||||
secret, err := db.GetSystemSetting(JWTSecretKey)
|
||||
if err == nil {
|
||||
@@ -36,7 +43,7 @@ func (db *DB) EnsureJWTSecret() (string, error) {
|
||||
}
|
||||
|
||||
// GetSystemSetting retrieves a system setting by key
|
||||
func (db *DB) GetSystemSetting(key string) (string, error) {
|
||||
func (db *database) GetSystemSetting(key string) (string, error) {
|
||||
var value string
|
||||
err := db.QueryRow("SELECT value FROM system_settings WHERE key = ?", key).Scan(&value)
|
||||
if err != nil {
|
||||
@@ -46,7 +53,7 @@ func (db *DB) GetSystemSetting(key string) (string, error) {
|
||||
}
|
||||
|
||||
// SetSystemSetting stores or updates a system setting
|
||||
func (db *DB) SetSystemSetting(key, value string) error {
|
||||
func (db *database) SetSystemSetting(key, value string) error {
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO system_settings (key, value)
|
||||
VALUES (?, ?)
|
||||
@@ -64,3 +71,32 @@ func generateRandomSecret(bytes int) (string, error) {
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// GetSystemStats returns system-wide statistics
|
||||
func (db *database) GetSystemStats() (*UserStats, error) {
|
||||
stats := &UserStats{}
|
||||
|
||||
// Get total users
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&stats.TotalUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get total workspaces
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM workspaces").Scan(&stats.TotalWorkspaces)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get active users (users with activity in last 30 days)
|
||||
err = db.QueryRow(`
|
||||
SELECT COUNT(DISTINCT user_id)
|
||||
FROM sessions
|
||||
WHERE created_at > datetime('now', '-30 days')`).
|
||||
Scan(&stats.ActiveUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
213
server/internal/db/system_test.go
Normal file
213
server/internal/db/system_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"novamd/internal/db"
|
||||
"novamd/internal/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func TestSystemOperations(t *testing.T) {
|
||||
database, err := db.NewTestDB(":memory:", &mockSecrets{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test database: %v", err)
|
||||
}
|
||||
defer database.Close()
|
||||
|
||||
if err := database.Migrate(); err != nil {
|
||||
t.Fatalf("failed to run migrations: %v", err)
|
||||
}
|
||||
|
||||
t.Run("GetSystemSettings", func(t *testing.T) {
|
||||
t.Run("non-existent setting", func(t *testing.T) {
|
||||
_, err := database.GetSystemSetting("nonexistent-key")
|
||||
if err == nil {
|
||||
t.Error("expected error for non-existent key, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("existing setting", func(t *testing.T) {
|
||||
// First set a value
|
||||
err := database.SetSystemSetting("test-key", "test-value")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to set system setting: %v", err)
|
||||
}
|
||||
|
||||
// Then get it back
|
||||
value, err := database.GetSystemSetting("test-key")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get system setting: %v", err)
|
||||
}
|
||||
|
||||
if value != "test-value" {
|
||||
t.Errorf("got value %q, want %q", value, "test-value")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("SetSystemSettings", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
key string
|
||||
value string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "new setting",
|
||||
key: "new-key",
|
||||
value: "new-value",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "update existing setting",
|
||||
key: "update-key",
|
||||
value: "original-value",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := database.SetSystemSetting(tc.key, tc.value)
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
} else if !strings.Contains(err.Error(), tc.errContains) {
|
||||
t.Errorf("error = %v, want error containing %v", err, tc.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify the setting was stored
|
||||
stored, err := database.GetSystemSetting(tc.key)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to retrieve stored setting: %v", err)
|
||||
}
|
||||
if stored != tc.value {
|
||||
t.Errorf("got value %q, want %q", stored, tc.value)
|
||||
}
|
||||
|
||||
// For the update case, test updating the value
|
||||
if tc.name == "update existing setting" {
|
||||
newValue := "updated-value"
|
||||
err := database.SetSystemSetting(tc.key, newValue)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to update setting: %v", err)
|
||||
}
|
||||
|
||||
stored, err := database.GetSystemSetting(tc.key)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to retrieve updated setting: %v", err)
|
||||
}
|
||||
if stored != newValue {
|
||||
t.Errorf("got updated value %q, want %q", stored, newValue)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EnsureJWTSecret", func(t *testing.T) {
|
||||
// First call should generate a new secret
|
||||
secret1, err := database.EnsureJWTSecret()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to ensure JWT secret: %v", err)
|
||||
}
|
||||
|
||||
// Verify the secret is a valid base64-encoded string of sufficient length
|
||||
decoded, err := base64.StdEncoding.DecodeString(secret1)
|
||||
if err != nil {
|
||||
t.Errorf("secret is not valid base64: %v", err)
|
||||
}
|
||||
if len(decoded) < 32 {
|
||||
t.Errorf("secret length = %d, want >= 32", len(decoded))
|
||||
}
|
||||
|
||||
// Second call should return the same secret
|
||||
secret2, err := database.EnsureJWTSecret()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get existing JWT secret: %v", err)
|
||||
}
|
||||
|
||||
if secret2 != secret1 {
|
||||
t.Errorf("got different secret on second call")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetSystemStats", func(t *testing.T) {
|
||||
// Create some test users and sessions
|
||||
users := []*models.User{
|
||||
{
|
||||
Email: "user1@test.com",
|
||||
DisplayName: "User 1",
|
||||
PasswordHash: "hash1",
|
||||
Role: "editor",
|
||||
},
|
||||
{
|
||||
Email: "user2@test.com",
|
||||
DisplayName: "User 2",
|
||||
PasswordHash: "hash2",
|
||||
Role: "viewer",
|
||||
},
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
createdUser, err := database.CreateUser(u)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test user: %v", err)
|
||||
}
|
||||
|
||||
// Create multiple workspaces per user
|
||||
// Each user has one default workspace
|
||||
for i := 0; i < 2; i++ {
|
||||
workspace := &models.Workspace{
|
||||
UserID: createdUser.ID,
|
||||
Name: fmt.Sprintf("Workspace %d", i),
|
||||
}
|
||||
if err := database.CreateWorkspace(workspace); err != nil {
|
||||
t.Fatalf("failed to create test workspace: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create an active session for the first user
|
||||
if createdUser.Email == "user1@test.com" {
|
||||
session := &models.Session{
|
||||
ID: uuid.New().String(),
|
||||
UserID: createdUser.ID,
|
||||
RefreshToken: "test-token",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := database.CreateSession(session); err != nil {
|
||||
t.Fatalf("failed to create test session: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := database.GetSystemStats()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get system stats: %v", err)
|
||||
}
|
||||
|
||||
// Verify stats
|
||||
if stats.TotalUsers != 2 {
|
||||
t.Errorf("TotalUsers = %d, want 2", stats.TotalUsers)
|
||||
}
|
||||
if stats.TotalWorkspaces != 6 { // 2 + 1 default workspace per user
|
||||
t.Errorf("TotalWorkspaces = %d, want 6", stats.TotalWorkspaces)
|
||||
}
|
||||
if stats.ActiveUsers != 1 { // Only user1 has an active session
|
||||
t.Errorf("ActiveUsers = %d, want 1", stats.ActiveUsers)
|
||||
}
|
||||
})
|
||||
}
|
||||
30
server/internal/db/testdb.go
Normal file
30
server/internal/db/testdb.go
Normal file
@@ -0,0 +1,30 @@
|
||||
//go:build test
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"novamd/internal/secrets"
|
||||
)
|
||||
|
||||
type TestDatabase interface {
|
||||
Database
|
||||
TestDB() *sql.DB
|
||||
}
|
||||
|
||||
func NewTestDB(dbPath string, secretsService secrets.Service) (TestDatabase, error) {
|
||||
db, err := Init(dbPath, secretsService)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &testDatabase{db.(*database)}, nil
|
||||
}
|
||||
|
||||
type testDatabase struct {
|
||||
*database
|
||||
}
|
||||
|
||||
func (td *testDatabase) TestDB() *sql.DB {
|
||||
return td.DB
|
||||
}
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
// CreateUser inserts a new user record into the database
|
||||
func (db *DB) CreateUser(user *models.User) (*models.User, error) {
|
||||
func (db *database) CreateUser(user *models.User) (*models.User, error) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -38,7 +38,7 @@ func (db *DB) CreateUser(user *models.User) (*models.User, error) {
|
||||
UserID: user.ID,
|
||||
Name: "Main",
|
||||
}
|
||||
defaultWorkspace.GetDefaultSettings() // Initialize default settings
|
||||
defaultWorkspace.SetDefaultSettings() // Initialize default settings
|
||||
|
||||
// Create workspace with settings
|
||||
err = db.createWorkspaceTx(tx, defaultWorkspace)
|
||||
@@ -62,14 +62,14 @@ func (db *DB) CreateUser(user *models.User) (*models.User, error) {
|
||||
}
|
||||
|
||||
// Helper function to create a workspace in a transaction
|
||||
func (db *DB) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
|
||||
func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
|
||||
result, err := tx.Exec(`
|
||||
INSERT INTO workspaces (
|
||||
user_id, name,
|
||||
theme, auto_save, show_hidden_files,
|
||||
git_enabled, git_url, git_user, git_token,
|
||||
git_auto_commit, git_commit_msg_template
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
workspace.UserID, workspace.Name,
|
||||
workspace.Theme, workspace.AutoSave, workspace.ShowHiddenFiles,
|
||||
workspace.GitEnabled, workspace.GitURL, workspace.GitUser, workspace.GitToken,
|
||||
@@ -87,7 +87,7 @@ func (db *DB) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
|
||||
}
|
||||
|
||||
// GetUserByID retrieves a user by ID
|
||||
func (db *DB) GetUserByID(id int) (*models.User, error) {
|
||||
func (db *database) GetUserByID(id int) (*models.User, error) {
|
||||
user := &models.User{}
|
||||
err := db.QueryRow(`
|
||||
SELECT
|
||||
@@ -104,7 +104,7 @@ func (db *DB) GetUserByID(id int) (*models.User, error) {
|
||||
}
|
||||
|
||||
// GetUserByEmail retrieves a user by email
|
||||
func (db *DB) GetUserByEmail(email string) (*models.User, error) {
|
||||
func (db *database) GetUserByEmail(email string) (*models.User, error) {
|
||||
user := &models.User{}
|
||||
err := db.QueryRow(`
|
||||
SELECT
|
||||
@@ -122,7 +122,7 @@ func (db *DB) GetUserByEmail(email string) (*models.User, error) {
|
||||
}
|
||||
|
||||
// UpdateUser updates a user's information
|
||||
func (db *DB) UpdateUser(user *models.User) error {
|
||||
func (db *database) UpdateUser(user *models.User) error {
|
||||
_, err := db.Exec(`
|
||||
UPDATE users
|
||||
SET email = ?, display_name = ?, password_hash = ?, role = ?, last_workspace_id = ?
|
||||
@@ -131,8 +131,36 @@ func (db *DB) UpdateUser(user *models.User) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// GetAllUsers returns a list of all users in the system
|
||||
func (db *database) GetAllUsers() ([]*models.User, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT
|
||||
id, email, display_name, role, created_at,
|
||||
last_workspace_id
|
||||
FROM users
|
||||
ORDER BY id ASC`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var users []*models.User
|
||||
for rows.Next() {
|
||||
user := &models.User{}
|
||||
err := rows.Scan(
|
||||
&user.ID, &user.Email, &user.DisplayName, &user.Role,
|
||||
&user.CreatedAt, &user.LastWorkspaceID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users = append(users, user)
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// UpdateLastWorkspace updates the last workspace the user accessed
|
||||
func (db *DB) UpdateLastWorkspace(userID int, workspaceName string) error {
|
||||
func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -155,7 +183,7 @@ func (db *DB) UpdateLastWorkspace(userID int, workspaceName string) error {
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user and all their workspaces
|
||||
func (db *DB) DeleteUser(id int) error {
|
||||
func (db *database) DeleteUser(id int) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -178,7 +206,7 @@ func (db *DB) DeleteUser(id int) error {
|
||||
}
|
||||
|
||||
// GetLastWorkspaceName returns the name of the last workspace the user accessed
|
||||
func (db *DB) GetLastWorkspaceName(userID int) (string, error) {
|
||||
func (db *database) GetLastWorkspaceName(userID int) (string, error) {
|
||||
var workspaceName string
|
||||
err := db.QueryRow(`
|
||||
SELECT
|
||||
@@ -189,3 +217,10 @@ func (db *DB) GetLastWorkspaceName(userID int) (string, error) {
|
||||
Scan(&workspaceName)
|
||||
return workspaceName, err
|
||||
}
|
||||
|
||||
// CountAdminUsers returns the number of admin users in the system
|
||||
func (db *database) CountAdminUsers() (int, error) {
|
||||
var count int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM users WHERE role = 'admin'").Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
413
server/internal/db/users_test.go
Normal file
413
server/internal/db/users_test.go
Normal file
@@ -0,0 +1,413 @@
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"novamd/internal/db"
|
||||
"novamd/internal/models"
|
||||
)
|
||||
|
||||
func TestUserOperations(t *testing.T) {
|
||||
database, err := db.NewTestDB(":memory:", &mockSecrets{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test database: %v", err)
|
||||
}
|
||||
defer database.Close()
|
||||
|
||||
if err := database.Migrate(); err != nil {
|
||||
t.Fatalf("failed to run migrations: %v", err)
|
||||
}
|
||||
|
||||
t.Run("CreateUser", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
user *models.User
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "valid user",
|
||||
user: &models.User{
|
||||
Email: "test@example.com",
|
||||
DisplayName: "Test User",
|
||||
PasswordHash: "hashed_password",
|
||||
Role: models.RoleEditor,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "duplicate email",
|
||||
user: &models.User{
|
||||
Email: "test@example.com", // Same as above
|
||||
DisplayName: "Another User",
|
||||
PasswordHash: "different_hash",
|
||||
Role: models.RoleViewer,
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "UNIQUE constraint failed",
|
||||
},
|
||||
{
|
||||
name: "invalid role",
|
||||
user: &models.User{
|
||||
Email: "invalid@example.com",
|
||||
DisplayName: "Invalid Role User",
|
||||
PasswordHash: "hash",
|
||||
Role: "invalid_role",
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "CHECK constraint failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
user, err := database.CreateUser(tc.user)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
} else if !strings.Contains(err.Error(), tc.errContains) {
|
||||
t.Errorf("error = %v, want error containing %v", err, tc.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify user was created properly
|
||||
if user.ID == 0 {
|
||||
t.Error("expected non-zero user ID")
|
||||
}
|
||||
if user.Email != tc.user.Email {
|
||||
t.Errorf("Email = %v, want %v", user.Email, tc.user.Email)
|
||||
}
|
||||
if user.DisplayName != tc.user.DisplayName {
|
||||
t.Errorf("DisplayName = %v, want %v", user.DisplayName, tc.user.DisplayName)
|
||||
}
|
||||
if user.Role != tc.user.Role {
|
||||
t.Errorf("Role = %v, want %v", user.Role, tc.user.Role)
|
||||
}
|
||||
if user.CreatedAt.IsZero() {
|
||||
t.Error("CreatedAt should not be zero")
|
||||
}
|
||||
if user.LastWorkspaceID == 0 {
|
||||
t.Error("expected non-zero LastWorkspaceID (default workspace)")
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetUserByID", func(t *testing.T) {
|
||||
// Create a test user first
|
||||
createdUser, err := database.CreateUser(&models.User{
|
||||
Email: "getbyid@example.com",
|
||||
DisplayName: "Get By ID User",
|
||||
PasswordHash: "hash",
|
||||
Role: models.RoleEditor,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test user: %v", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
userID int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "existing user",
|
||||
userID: createdUser.ID,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent user",
|
||||
userID: 99999,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
user, err := database.GetUserByID(tc.userID)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if user.ID != tc.userID {
|
||||
t.Errorf("ID = %v, want %v", user.ID, tc.userID)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetUserByEmail", func(t *testing.T) {
|
||||
// Create a test user first
|
||||
createdUser, err := database.CreateUser(&models.User{
|
||||
Email: "getbyemail@example.com",
|
||||
DisplayName: "Get By Email User",
|
||||
PasswordHash: "hash",
|
||||
Role: models.RoleEditor,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test user: %v", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
email string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "existing user",
|
||||
email: createdUser.Email,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent user",
|
||||
email: "nonexistent@example.com",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
user, err := database.GetUserByEmail(tc.email)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if user.Email != tc.email {
|
||||
t.Errorf("Email = %v, want %v", user.Email, tc.email)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UpdateUser", func(t *testing.T) {
|
||||
// Create a test user first
|
||||
user, err := database.CreateUser(&models.User{
|
||||
Email: "update@example.com",
|
||||
DisplayName: "Original Name",
|
||||
PasswordHash: "original_hash",
|
||||
Role: models.RoleEditor,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test user: %v", err)
|
||||
}
|
||||
|
||||
// Update user details
|
||||
user.DisplayName = "Updated Name"
|
||||
user.PasswordHash = "new_hash"
|
||||
user.Role = models.RoleAdmin
|
||||
|
||||
if err := database.UpdateUser(user); err != nil {
|
||||
t.Fatalf("failed to update user: %v", err)
|
||||
}
|
||||
|
||||
// Verify updates
|
||||
updated, err := database.GetUserByID(user.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get updated user: %v", err)
|
||||
}
|
||||
|
||||
if updated.DisplayName != "Updated Name" {
|
||||
t.Errorf("DisplayName = %v, want %v", updated.DisplayName, "Updated Name")
|
||||
}
|
||||
if updated.PasswordHash != "new_hash" {
|
||||
t.Errorf("PasswordHash = %v, want %v", updated.PasswordHash, "new_hash")
|
||||
}
|
||||
if updated.Role != models.RoleAdmin {
|
||||
t.Errorf("Role = %v, want %v", updated.Role, models.RoleAdmin)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetAllUsers", func(t *testing.T) {
|
||||
// Create several test users
|
||||
testUsers := []*models.User{
|
||||
{
|
||||
Email: "user1@example.com",
|
||||
DisplayName: "User One",
|
||||
PasswordHash: "hash1",
|
||||
Role: models.RoleEditor,
|
||||
},
|
||||
{
|
||||
Email: "user2@example.com",
|
||||
DisplayName: "User Two",
|
||||
PasswordHash: "hash2",
|
||||
Role: models.RoleViewer,
|
||||
},
|
||||
}
|
||||
|
||||
for _, u := range testUsers {
|
||||
_, err := database.CreateUser(u)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get all users
|
||||
users, err := database.GetAllUsers()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get all users: %v", err)
|
||||
}
|
||||
|
||||
// We should have at least as many users as we just created
|
||||
// (there might be more from previous tests)
|
||||
if len(users) < len(testUsers) {
|
||||
t.Errorf("got %d users, want at least %d", len(users), len(testUsers))
|
||||
}
|
||||
|
||||
// Verify each test user exists in the result
|
||||
for _, expected := range testUsers {
|
||||
found := false
|
||||
for _, u := range users {
|
||||
if u.Email == expected.Email {
|
||||
found = true
|
||||
if u.DisplayName != expected.DisplayName {
|
||||
t.Errorf("DisplayName = %v, want %v", u.DisplayName, expected.DisplayName)
|
||||
}
|
||||
if u.Role != expected.Role {
|
||||
t.Errorf("Role = %v, want %v", u.Role, expected.Role)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("user with email %s not found in results", expected.Email)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UpdateLastWorkspace", func(t *testing.T) {
|
||||
// Create a test user with multiple workspaces
|
||||
user, err := database.CreateUser(&models.User{
|
||||
Email: "workspace@example.com",
|
||||
DisplayName: "Workspace User",
|
||||
PasswordHash: "hash",
|
||||
Role: models.RoleEditor,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test user: %v", err)
|
||||
}
|
||||
|
||||
// Create additional workspace
|
||||
workspace := &models.Workspace{
|
||||
UserID: user.ID,
|
||||
Name: "Second Workspace",
|
||||
}
|
||||
if err := database.CreateWorkspace(workspace); err != nil {
|
||||
t.Fatalf("failed to create additional workspace: %v", err)
|
||||
}
|
||||
|
||||
// Update last workspace
|
||||
err = database.UpdateLastWorkspace(user.ID, workspace.Name)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to update last workspace: %v", err)
|
||||
}
|
||||
|
||||
// Verify update
|
||||
lastWorkspace, err := database.GetLastWorkspaceName(user.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get last workspace: %v", err)
|
||||
}
|
||||
|
||||
if lastWorkspace != workspace.Name {
|
||||
t.Errorf("LastWorkspace = %v, want %v", lastWorkspace, workspace.Name)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DeleteUser", func(t *testing.T) {
|
||||
// Create a test user
|
||||
user, err := database.CreateUser(&models.User{
|
||||
Email: "delete@example.com",
|
||||
DisplayName: "Delete User",
|
||||
PasswordHash: "hash",
|
||||
Role: models.RoleEditor,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test user: %v", err)
|
||||
}
|
||||
|
||||
// Delete the user
|
||||
if err := database.DeleteUser(user.ID); err != nil {
|
||||
t.Fatalf("failed to delete user: %v", err)
|
||||
}
|
||||
|
||||
// Verify user is gone
|
||||
_, err = database.GetUserByID(user.ID)
|
||||
if err == nil {
|
||||
t.Error("expected error getting deleted user, got nil")
|
||||
}
|
||||
|
||||
// Verify workspaces are gone
|
||||
workspaces, err := database.GetWorkspacesByUserID(user.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error checking workspaces: %v", err)
|
||||
}
|
||||
if len(workspaces) > 0 {
|
||||
t.Error("expected no workspaces for deleted user")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CountAdminUsers", func(t *testing.T) {
|
||||
// Create users with different roles
|
||||
testUsers := []*models.User{
|
||||
{
|
||||
Email: "admin1@example.com",
|
||||
DisplayName: "Admin One",
|
||||
PasswordHash: "hash1",
|
||||
Role: models.RoleAdmin,
|
||||
},
|
||||
{
|
||||
Email: "admin2@example.com",
|
||||
DisplayName: "Admin Two",
|
||||
PasswordHash: "hash2",
|
||||
Role: models.RoleAdmin,
|
||||
},
|
||||
{
|
||||
Email: "editor@example.com",
|
||||
DisplayName: "Editor",
|
||||
PasswordHash: "hash3",
|
||||
Role: models.RoleEditor,
|
||||
},
|
||||
}
|
||||
|
||||
for _, u := range testUsers {
|
||||
_, err := database.CreateUser(u)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Count admin users
|
||||
count, err := database.CountAdminUsers()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to count admin users: %v", err)
|
||||
}
|
||||
|
||||
// We should have at least 2 admin users (from our test cases)
|
||||
// There might be more from previous tests
|
||||
if count < 2 {
|
||||
t.Errorf("AdminCount = %d, want at least 2", count)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
)
|
||||
|
||||
// CreateWorkspace inserts a new workspace record into the database
|
||||
func (db *DB) CreateWorkspace(workspace *models.Workspace) error {
|
||||
func (db *database) CreateWorkspace(workspace *models.Workspace) error {
|
||||
// Set default settings if not provided
|
||||
if workspace.Theme == "" {
|
||||
workspace.GetDefaultSettings()
|
||||
workspace.SetDefaultSettings()
|
||||
}
|
||||
|
||||
// Encrypt token if present
|
||||
@@ -24,7 +24,7 @@ func (db *DB) CreateWorkspace(workspace *models.Workspace) error {
|
||||
user_id, name, theme, auto_save, show_hidden_files,
|
||||
git_enabled, git_url, git_user, git_token,
|
||||
git_auto_commit, git_commit_msg_template
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
workspace.UserID, workspace.Name, workspace.Theme, workspace.AutoSave, workspace.ShowHiddenFiles,
|
||||
workspace.GitEnabled, workspace.GitURL, workspace.GitUser, encryptedToken,
|
||||
workspace.GitAutoCommit, workspace.GitCommitMsgTemplate,
|
||||
@@ -42,7 +42,7 @@ func (db *DB) CreateWorkspace(workspace *models.Workspace) error {
|
||||
}
|
||||
|
||||
// GetWorkspaceByID retrieves a workspace by its ID
|
||||
func (db *DB) GetWorkspaceByID(id int) (*models.Workspace, error) {
|
||||
func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) {
|
||||
workspace := &models.Workspace{}
|
||||
var encryptedToken string
|
||||
|
||||
@@ -75,7 +75,7 @@ func (db *DB) GetWorkspaceByID(id int) (*models.Workspace, error) {
|
||||
}
|
||||
|
||||
// GetWorkspaceByName retrieves a workspace by its name and user ID
|
||||
func (db *DB) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) {
|
||||
func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) {
|
||||
workspace := &models.Workspace{}
|
||||
var encryptedToken string
|
||||
|
||||
@@ -108,7 +108,7 @@ func (db *DB) GetWorkspaceByName(userID int, workspaceName string) (*models.Work
|
||||
}
|
||||
|
||||
// UpdateWorkspace updates a workspace record in the database
|
||||
func (db *DB) UpdateWorkspace(workspace *models.Workspace) error {
|
||||
func (db *database) UpdateWorkspace(workspace *models.Workspace) error {
|
||||
// Encrypt token before storing
|
||||
encryptedToken, err := db.encryptToken(workspace.GitToken)
|
||||
if err != nil {
|
||||
@@ -146,7 +146,7 @@ func (db *DB) UpdateWorkspace(workspace *models.Workspace) error {
|
||||
}
|
||||
|
||||
// GetWorkspacesByUserID retrieves all workspaces for a user
|
||||
func (db *DB) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
|
||||
func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT
|
||||
id, user_id, name, created_at,
|
||||
@@ -189,7 +189,7 @@ func (db *DB) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
|
||||
|
||||
// UpdateWorkspaceSettings updates only the settings portion of a workspace
|
||||
// This is useful when you don't want to modify the name or other core workspace properties
|
||||
func (db *DB) UpdateWorkspaceSettings(workspace *models.Workspace) error {
|
||||
func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
|
||||
_, err := db.Exec(`
|
||||
UPDATE workspaces
|
||||
SET
|
||||
@@ -218,31 +218,31 @@ func (db *DB) UpdateWorkspaceSettings(workspace *models.Workspace) error {
|
||||
}
|
||||
|
||||
// DeleteWorkspace removes a workspace record from the database
|
||||
func (db *DB) DeleteWorkspace(id int) error {
|
||||
func (db *database) DeleteWorkspace(id int) error {
|
||||
_, err := db.Exec("DELETE FROM workspaces WHERE id = ?", id)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteWorkspaceTx removes a workspace record from the database within a transaction
|
||||
func (db *DB) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
|
||||
func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
|
||||
_, err := tx.Exec("DELETE FROM workspaces WHERE id = ?", id)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateLastWorkspaceTx sets the last workspace for a user in with a transaction
|
||||
func (db *DB) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error {
|
||||
func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error {
|
||||
_, err := tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?", workspaceID, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateLastOpenedFile updates the last opened file path for a workspace
|
||||
func (db *DB) UpdateLastOpenedFile(workspaceID int, filePath string) error {
|
||||
func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error {
|
||||
_, err := db.Exec("UPDATE workspaces SET last_opened_file_path = ? WHERE id = ?", filePath, workspaceID)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetLastOpenedFile retrieves the last opened file path for a workspace
|
||||
func (db *DB) GetLastOpenedFile(workspaceID int) (string, error) {
|
||||
func (db *database) GetLastOpenedFile(workspaceID int) (string, error) {
|
||||
var filePath sql.NullString
|
||||
err := db.QueryRow("SELECT last_opened_file_path FROM workspaces WHERE id = ?", workspaceID).Scan(&filePath)
|
||||
if err != nil {
|
||||
@@ -255,7 +255,7 @@ func (db *DB) GetLastOpenedFile(workspaceID int) (string, error) {
|
||||
}
|
||||
|
||||
// GetAllWorkspaces retrieves all workspaces in the database
|
||||
func (db *DB) GetAllWorkspaces() ([]*models.Workspace, error) {
|
||||
func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT
|
||||
id, user_id, name, created_at,
|
||||
|
||||
430
server/internal/db/workspaces_test.go
Normal file
430
server/internal/db/workspaces_test.go
Normal file
@@ -0,0 +1,430 @@
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"novamd/internal/db"
|
||||
"novamd/internal/models"
|
||||
)
|
||||
|
||||
func TestWorkspaceOperations(t *testing.T) {
|
||||
database, err := db.NewTestDB(":memory:", &mockSecrets{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test database: %v", err)
|
||||
}
|
||||
defer database.Close()
|
||||
|
||||
if err := database.Migrate(); err != nil {
|
||||
t.Fatalf("failed to run migrations: %v", err)
|
||||
}
|
||||
|
||||
// Create a test user first
|
||||
user, err := database.CreateUser(&models.User{
|
||||
Email: "test@example.com",
|
||||
DisplayName: "Test User",
|
||||
PasswordHash: "hash",
|
||||
Role: models.RoleEditor,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test user: %v", err)
|
||||
}
|
||||
|
||||
t.Run("CreateWorkspace", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
workspace *models.Workspace
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "valid workspace",
|
||||
workspace: &models.Workspace{
|
||||
UserID: user.ID,
|
||||
Name: "Test Workspace",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent user",
|
||||
workspace: &models.Workspace{
|
||||
UserID: 99999,
|
||||
Name: "Invalid User",
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "FOREIGN KEY constraint failed",
|
||||
},
|
||||
{
|
||||
name: "with git settings",
|
||||
workspace: &models.Workspace{
|
||||
UserID: user.ID,
|
||||
Name: "Git Workspace",
|
||||
Theme: "dark",
|
||||
AutoSave: true,
|
||||
ShowHiddenFiles: true,
|
||||
GitEnabled: true,
|
||||
GitURL: "https://github.com/user/repo",
|
||||
GitUser: "username",
|
||||
GitToken: "secret-token",
|
||||
GitAutoCommit: true,
|
||||
GitCommitMsgTemplate: "${action} ${filename}",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.workspace.Theme == "" {
|
||||
tc.workspace.SetDefaultSettings()
|
||||
}
|
||||
|
||||
err := database.CreateWorkspace(tc.workspace)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
} else if !strings.Contains(err.Error(), tc.errContains) {
|
||||
t.Errorf("error = %v, want error containing %v", err, tc.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify workspace was created properly
|
||||
if tc.workspace.ID == 0 {
|
||||
t.Error("expected non-zero workspace ID")
|
||||
}
|
||||
|
||||
// Retrieve and verify workspace
|
||||
stored, err := database.GetWorkspaceByID(tc.workspace.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to retrieve workspace: %v", err)
|
||||
}
|
||||
|
||||
verifyWorkspace(t, stored, tc.workspace)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetWorkspaceByID", func(t *testing.T) {
|
||||
// Create a test workspace first
|
||||
workspace := &models.Workspace{
|
||||
UserID: user.ID,
|
||||
Name: "Get By ID Workspace",
|
||||
}
|
||||
workspace.SetDefaultSettings()
|
||||
if err := database.CreateWorkspace(workspace); err != nil {
|
||||
t.Fatalf("failed to create test workspace: %v", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
workspaceID int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "existing workspace",
|
||||
workspaceID: workspace.ID,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent workspace",
|
||||
workspaceID: 99999,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result, err := database.GetWorkspaceByID(tc.workspaceID)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.ID != tc.workspaceID {
|
||||
t.Errorf("ID = %v, want %v", result.ID, tc.workspaceID)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetWorkspaceByName", func(t *testing.T) {
|
||||
// Create a test workspace first
|
||||
workspace := &models.Workspace{
|
||||
UserID: user.ID,
|
||||
Name: "Get By Name Workspace",
|
||||
}
|
||||
workspace.SetDefaultSettings()
|
||||
if err := database.CreateWorkspace(workspace); err != nil {
|
||||
t.Fatalf("failed to create test workspace: %v", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
userID int
|
||||
workspaceName string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "existing workspace",
|
||||
userID: user.ID,
|
||||
workspaceName: workspace.Name,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "wrong user ID",
|
||||
userID: 99999,
|
||||
workspaceName: workspace.Name,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "non-existent workspace",
|
||||
userID: user.ID,
|
||||
workspaceName: "Non-existent",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result, err := database.GetWorkspaceByName(tc.userID, tc.workspaceName)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Name != tc.workspaceName {
|
||||
t.Errorf("Name = %v, want %v", result.Name, tc.workspaceName)
|
||||
}
|
||||
if result.UserID != tc.userID {
|
||||
t.Errorf("UserID = %v, want %v", result.UserID, tc.userID)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UpdateWorkspace", func(t *testing.T) {
|
||||
// Create a test workspace first
|
||||
workspace := &models.Workspace{
|
||||
UserID: user.ID,
|
||||
Name: "Update Workspace",
|
||||
}
|
||||
workspace.SetDefaultSettings()
|
||||
if err := database.CreateWorkspace(workspace); err != nil {
|
||||
t.Fatalf("failed to create test workspace: %v", err)
|
||||
}
|
||||
|
||||
// Update workspace settings
|
||||
workspace.Theme = "dark"
|
||||
workspace.AutoSave = true
|
||||
workspace.ShowHiddenFiles = true
|
||||
workspace.GitEnabled = true
|
||||
workspace.GitURL = "https://github.com/user/repo"
|
||||
workspace.GitUser = "username"
|
||||
workspace.GitToken = "new-token"
|
||||
workspace.GitAutoCommit = true
|
||||
workspace.GitCommitMsgTemplate = "custom ${filename}"
|
||||
|
||||
if err := database.UpdateWorkspace(workspace); err != nil {
|
||||
t.Fatalf("failed to update workspace: %v", err)
|
||||
}
|
||||
|
||||
// Verify updates
|
||||
updated, err := database.GetWorkspaceByID(workspace.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get updated workspace: %v", err)
|
||||
}
|
||||
|
||||
verifyWorkspace(t, updated, workspace)
|
||||
})
|
||||
|
||||
t.Run("GetWorkspacesByUserID", func(t *testing.T) {
|
||||
// Create several test workspaces
|
||||
testWorkspaces := []*models.Workspace{
|
||||
{
|
||||
UserID: user.ID,
|
||||
Name: "User Workspace 1",
|
||||
},
|
||||
{
|
||||
UserID: user.ID,
|
||||
Name: "User Workspace 2",
|
||||
},
|
||||
}
|
||||
|
||||
for _, w := range testWorkspaces {
|
||||
w.SetDefaultSettings()
|
||||
if err := database.CreateWorkspace(w); err != nil {
|
||||
t.Fatalf("failed to create test workspace: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get all workspaces for user
|
||||
workspaces, err := database.GetWorkspacesByUserID(user.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get workspaces: %v", err)
|
||||
}
|
||||
|
||||
// We should have at least as many workspaces as we just created
|
||||
// (there might be more from previous tests)
|
||||
if len(workspaces) < len(testWorkspaces) {
|
||||
t.Errorf("got %d workspaces, want at least %d", len(workspaces), len(testWorkspaces))
|
||||
}
|
||||
|
||||
// Verify each test workspace exists in the result
|
||||
for _, expected := range testWorkspaces {
|
||||
found := false
|
||||
for _, w := range workspaces {
|
||||
if w.Name == expected.Name {
|
||||
found = true
|
||||
if w.UserID != expected.UserID {
|
||||
t.Errorf("UserID = %v, want %v", w.UserID, expected.UserID)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("workspace %s not found in results", expected.Name)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UpdateLastOpenedFile", func(t *testing.T) {
|
||||
// Create a test workspace
|
||||
workspace := &models.Workspace{
|
||||
UserID: user.ID,
|
||||
Name: "Last File Workspace",
|
||||
}
|
||||
workspace.SetDefaultSettings()
|
||||
if err := database.CreateWorkspace(workspace); err != nil {
|
||||
t.Fatalf("failed to create test workspace: %v", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
filePath string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid file path",
|
||||
filePath: "docs/test.md",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty file path",
|
||||
filePath: "",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := database.UpdateLastOpenedFile(workspace.ID, tc.filePath)
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify update
|
||||
path, err := database.GetLastOpenedFile(workspace.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get last opened file: %v", err)
|
||||
}
|
||||
|
||||
if path != tc.filePath {
|
||||
t.Errorf("LastOpenedFile = %v, want %v", path, tc.filePath)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DeleteWorkspace", func(t *testing.T) {
|
||||
// Create a test workspace
|
||||
workspace := &models.Workspace{
|
||||
UserID: user.ID,
|
||||
Name: "Delete Workspace",
|
||||
}
|
||||
workspace.SetDefaultSettings()
|
||||
if err := database.CreateWorkspace(workspace); err != nil {
|
||||
t.Fatalf("failed to create test workspace: %v", err)
|
||||
}
|
||||
|
||||
// Delete the workspace
|
||||
if err := database.DeleteWorkspace(workspace.ID); err != nil {
|
||||
t.Fatalf("failed to delete workspace: %v", err)
|
||||
}
|
||||
|
||||
// Verify workspace is gone
|
||||
_, err = database.GetWorkspaceByID(workspace.ID)
|
||||
if err != sql.ErrNoRows {
|
||||
t.Errorf("expected sql.ErrNoRows, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to verify workspace fields
|
||||
func verifyWorkspace(t *testing.T, actual, expected *models.Workspace) {
|
||||
t.Helper()
|
||||
|
||||
if actual.Name != expected.Name {
|
||||
t.Errorf("Name = %v, want %v", actual.Name, expected.Name)
|
||||
}
|
||||
if actual.UserID != expected.UserID {
|
||||
t.Errorf("UserID = %v, want %v", actual.UserID, expected.UserID)
|
||||
}
|
||||
if actual.Theme != expected.Theme {
|
||||
t.Errorf("Theme = %v, want %v", actual.Theme, expected.Theme)
|
||||
}
|
||||
if actual.AutoSave != expected.AutoSave {
|
||||
t.Errorf("AutoSave = %v, want %v", actual.AutoSave, expected.AutoSave)
|
||||
}
|
||||
if actual.ShowHiddenFiles != expected.ShowHiddenFiles {
|
||||
t.Errorf("ShowHiddenFiles = %v, want %v", actual.ShowHiddenFiles, expected.ShowHiddenFiles)
|
||||
}
|
||||
if actual.GitEnabled != expected.GitEnabled {
|
||||
t.Errorf("GitEnabled = %v, want %v", actual.GitEnabled, expected.GitEnabled)
|
||||
}
|
||||
if actual.GitURL != expected.GitURL {
|
||||
t.Errorf("GitURL = %v, want %v", actual.GitURL, expected.GitURL)
|
||||
}
|
||||
if actual.GitUser != expected.GitUser {
|
||||
t.Errorf("GitUser = %v, want %v", actual.GitUser, expected.GitUser)
|
||||
}
|
||||
if actual.GitToken != expected.GitToken {
|
||||
t.Errorf("GitToken = %v, want %v", actual.GitToken, expected.GitToken)
|
||||
}
|
||||
if actual.GitAutoCommit != expected.GitAutoCommit {
|
||||
t.Errorf("GitAutoCommit = %v, want %v", actual.GitAutoCommit, expected.GitAutoCommit)
|
||||
}
|
||||
if actual.GitCommitMsgTemplate != expected.GitCommitMsgTemplate {
|
||||
t.Errorf("GitCommitMsgTemplate = %v, want %v", actual.GitCommitMsgTemplate, expected.GitCommitMsgTemplate)
|
||||
}
|
||||
if actual.CreatedAt.IsZero() {
|
||||
t.Error("CreatedAt should not be zero")
|
||||
}
|
||||
}
|
||||
@@ -1,40 +0,0 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"novamd/internal/gitutils"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// FileSystem represents the file system structure.
|
||||
type FileSystem struct {
|
||||
RootDir string
|
||||
GitRepos map[int]map[int]*gitutils.GitRepo // map[userID]map[workspaceID]*gitutils.GitRepo
|
||||
}
|
||||
|
||||
// New creates a new FileSystem instance.
|
||||
func New(rootDir string) *FileSystem {
|
||||
return &FileSystem{
|
||||
RootDir: rootDir,
|
||||
GitRepos: make(map[int]map[int]*gitutils.GitRepo),
|
||||
}
|
||||
}
|
||||
|
||||
// ValidatePath validates the given path and returns the cleaned path if it is valid.
|
||||
func (fs *FileSystem) ValidatePath(userID, workspaceID int, path string) (string, error) {
|
||||
workspacePath := fs.GetWorkspacePath(userID, workspaceID)
|
||||
fullPath := filepath.Join(workspacePath, path)
|
||||
cleanPath := filepath.Clean(fullPath)
|
||||
|
||||
if !strings.HasPrefix(cleanPath, workspacePath) {
|
||||
return "", fmt.Errorf("invalid path: outside of workspace")
|
||||
}
|
||||
|
||||
return cleanPath, nil
|
||||
}
|
||||
|
||||
// GetTotalFileStats returns the total file statistics for the file system.
|
||||
func (fs *FileSystem) GetTotalFileStats() (*FileCountStats, error) {
|
||||
return fs.countFilesInPath(fs.RootDir)
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"novamd/internal/gitutils"
|
||||
)
|
||||
|
||||
// SetupGitRepo sets up a Git repository for the given user and workspace IDs.
|
||||
func (fs *FileSystem) SetupGitRepo(userID, workspaceID int, gitURL, gitUser, gitToken string) error {
|
||||
workspacePath := fs.GetWorkspacePath(userID, workspaceID)
|
||||
if _, ok := fs.GitRepos[userID]; !ok {
|
||||
fs.GitRepos[userID] = make(map[int]*gitutils.GitRepo)
|
||||
}
|
||||
fs.GitRepos[userID][workspaceID] = gitutils.New(gitURL, gitUser, gitToken, workspacePath)
|
||||
return fs.GitRepos[userID][workspaceID].EnsureRepo()
|
||||
}
|
||||
|
||||
// DisableGitRepo disables the Git repository for the given user and workspace IDs.
|
||||
func (fs *FileSystem) DisableGitRepo(userID, workspaceID int) {
|
||||
if userRepos, ok := fs.GitRepos[userID]; ok {
|
||||
delete(userRepos, workspaceID)
|
||||
if len(userRepos) == 0 {
|
||||
delete(fs.GitRepos, userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StageCommitAndPush stages, commits, and pushes the changes to the Git repository.
|
||||
func (fs *FileSystem) StageCommitAndPush(userID, workspaceID int, message string) error {
|
||||
repo, ok := fs.getGitRepo(userID, workspaceID)
|
||||
if !ok {
|
||||
return fmt.Errorf("git settings not configured for this workspace")
|
||||
}
|
||||
|
||||
if err := repo.Commit(message); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return repo.Push()
|
||||
}
|
||||
|
||||
// Pull pulls the changes from the remote Git repository.
|
||||
func (fs *FileSystem) Pull(userID, workspaceID int) error {
|
||||
repo, ok := fs.getGitRepo(userID, workspaceID)
|
||||
if !ok {
|
||||
return fmt.Errorf("git settings not configured for this workspace")
|
||||
}
|
||||
|
||||
return repo.Pull()
|
||||
}
|
||||
|
||||
// getGitRepo returns the Git repository for the given user and workspace IDs.
|
||||
func (fs *FileSystem) getGitRepo(userID, workspaceID int) (*gitutils.GitRepo, bool) {
|
||||
userRepos, ok := fs.GitRepos[userID]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
repo, ok := userRepos[workspaceID]
|
||||
return repo, ok
|
||||
}
|
||||
@@ -1,40 +0,0 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// GetWorkspacePath returns the path to the workspace directory for the given user and workspace IDs.
|
||||
func (fs *FileSystem) GetWorkspacePath(userID, workspaceID int) string {
|
||||
return filepath.Join(fs.RootDir, fmt.Sprintf("%d", userID), fmt.Sprintf("%d", workspaceID))
|
||||
}
|
||||
|
||||
// InitializeUserWorkspace creates the workspace directory for the given user and workspace IDs.
|
||||
func (fs *FileSystem) InitializeUserWorkspace(userID, workspaceID int) error {
|
||||
workspacePath := fs.GetWorkspacePath(userID, workspaceID)
|
||||
err := os.MkdirAll(workspacePath, 0755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create workspace directory: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteUserWorkspace deletes the workspace directory for the given user and workspace IDs.
|
||||
func (fs *FileSystem) DeleteUserWorkspace(userID, workspaceID int) error {
|
||||
workspacePath := fs.GetWorkspacePath(userID, workspaceID)
|
||||
err := os.RemoveAll(workspacePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete workspace directory: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateWorkspaceDirectory creates the workspace directory for the given user and workspace IDs.
|
||||
func (fs *FileSystem) CreateWorkspaceDirectory(userID, workspaceID int) error {
|
||||
dir := fs.GetWorkspacePath(userID, workspaceID)
|
||||
return os.MkdirAll(dir, 0755)
|
||||
}
|
||||
157
server/internal/git/client.go
Normal file
157
server/internal/git/client.go
Normal file
@@ -0,0 +1,157 @@
|
||||
// Package git provides functionalities to interact with Git repositories, including cloning, pulling, committing, and pushing changes.
|
||||
package git
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-git/go-git/v5"
|
||||
"github.com/go-git/go-git/v5/plumbing/transport/http"
|
||||
)
|
||||
|
||||
// Config holds the configuration for a Git client
|
||||
type Config struct {
|
||||
URL string
|
||||
Username string
|
||||
Token string
|
||||
WorkDir string
|
||||
}
|
||||
|
||||
// Client defines the interface for Git operations
|
||||
type Client interface {
|
||||
Clone() error
|
||||
Pull() error
|
||||
Commit(message string) error
|
||||
Push() error
|
||||
EnsureRepo() error
|
||||
}
|
||||
|
||||
// client implements the Client interface
|
||||
type client struct {
|
||||
Config
|
||||
repo *git.Repository
|
||||
}
|
||||
|
||||
// New creates a new git Client instance
|
||||
func New(url, username, token, workDir string) Client {
|
||||
return &client{
|
||||
Config: Config{
|
||||
URL: url,
|
||||
Username: username,
|
||||
Token: token,
|
||||
WorkDir: workDir,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Clone clones the Git repository to the local directory
|
||||
func (c *client) Clone() error {
|
||||
auth := &http.BasicAuth{
|
||||
Username: c.Username,
|
||||
Password: c.Token,
|
||||
}
|
||||
|
||||
var err error
|
||||
c.repo, err = git.PlainClone(c.WorkDir, false, &git.CloneOptions{
|
||||
URL: c.URL,
|
||||
Auth: auth,
|
||||
Progress: os.Stdout,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to clone repository: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pull pulls the latest changes from the remote repository
|
||||
func (c *client) Pull() error {
|
||||
if c.repo == nil {
|
||||
return fmt.Errorf("repository not initialized")
|
||||
}
|
||||
|
||||
w, err := c.repo.Worktree()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get worktree: %w", err)
|
||||
}
|
||||
|
||||
auth := &http.BasicAuth{
|
||||
Username: c.Username,
|
||||
Password: c.Token,
|
||||
}
|
||||
|
||||
err = w.Pull(&git.PullOptions{
|
||||
Auth: auth,
|
||||
Progress: os.Stdout,
|
||||
})
|
||||
|
||||
if err != nil && err != git.NoErrAlreadyUpToDate {
|
||||
return fmt.Errorf("failed to pull changes: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Commit commits the changes in the repository with the given message
|
||||
func (c *client) Commit(message string) error {
|
||||
if c.repo == nil {
|
||||
return fmt.Errorf("repository not initialized")
|
||||
}
|
||||
|
||||
w, err := c.repo.Worktree()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get worktree: %w", err)
|
||||
}
|
||||
|
||||
_, err = w.Add(".")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add changes: %w", err)
|
||||
}
|
||||
|
||||
_, err = w.Commit(message, &git.CommitOptions{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to commit changes: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Push pushes the changes to the remote repository
|
||||
func (c *client) Push() error {
|
||||
if c.repo == nil {
|
||||
return fmt.Errorf("repository not initialized")
|
||||
}
|
||||
|
||||
auth := &http.BasicAuth{
|
||||
Username: c.Username,
|
||||
Password: c.Token,
|
||||
}
|
||||
|
||||
err := c.repo.Push(&git.PushOptions{
|
||||
Auth: auth,
|
||||
Progress: os.Stdout,
|
||||
})
|
||||
|
||||
if err != nil && err != git.NoErrAlreadyUpToDate {
|
||||
return fmt.Errorf("failed to push changes: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnsureRepo ensures the local repository is cloned and up-to-date
|
||||
func (c *client) EnsureRepo() error {
|
||||
if _, err := os.Stat(filepath.Join(c.WorkDir, ".git")); os.IsNotExist(err) {
|
||||
return c.Clone()
|
||||
}
|
||||
|
||||
var err error
|
||||
c.repo, err = git.PlainOpen(c.WorkDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open existing repository: %w", err)
|
||||
}
|
||||
|
||||
return c.Pull()
|
||||
}
|
||||
@@ -1,133 +0,0 @@
|
||||
package gitutils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-git/go-git/v5"
|
||||
"github.com/go-git/go-git/v5/plumbing/transport/http"
|
||||
)
|
||||
|
||||
type GitRepo struct {
|
||||
URL string
|
||||
Username string
|
||||
Token string
|
||||
WorkDir string
|
||||
repo *git.Repository
|
||||
}
|
||||
|
||||
func New(url, username, token, workDir string) *GitRepo {
|
||||
return &GitRepo{
|
||||
URL: url,
|
||||
Username: username,
|
||||
Token: token,
|
||||
WorkDir: workDir,
|
||||
}
|
||||
}
|
||||
|
||||
func (g *GitRepo) Clone() error {
|
||||
auth := &http.BasicAuth{
|
||||
Username: g.Username,
|
||||
Password: g.Token,
|
||||
}
|
||||
|
||||
var err error
|
||||
g.repo, err = git.PlainClone(g.WorkDir, false, &git.CloneOptions{
|
||||
URL: g.URL,
|
||||
Auth: auth,
|
||||
Progress: os.Stdout,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to clone repository: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *GitRepo) Pull() error {
|
||||
if g.repo == nil {
|
||||
return fmt.Errorf("repository not initialized")
|
||||
}
|
||||
|
||||
w, err := g.repo.Worktree()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get worktree: %w", err)
|
||||
}
|
||||
|
||||
auth := &http.BasicAuth{
|
||||
Username: g.Username,
|
||||
Password: g.Token,
|
||||
}
|
||||
|
||||
err = w.Pull(&git.PullOptions{
|
||||
Auth: auth,
|
||||
Progress: os.Stdout,
|
||||
})
|
||||
|
||||
if err != nil && err != git.NoErrAlreadyUpToDate {
|
||||
return fmt.Errorf("failed to pull changes: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *GitRepo) Commit(message string) error {
|
||||
if g.repo == nil {
|
||||
return fmt.Errorf("repository not initialized")
|
||||
}
|
||||
|
||||
w, err := g.repo.Worktree()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get worktree: %w", err)
|
||||
}
|
||||
|
||||
_, err = w.Add(".")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add changes: %w", err)
|
||||
}
|
||||
|
||||
_, err = w.Commit(message, &git.CommitOptions{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to commit changes: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *GitRepo) Push() error {
|
||||
if g.repo == nil {
|
||||
return fmt.Errorf("repository not initialized")
|
||||
}
|
||||
|
||||
auth := &http.BasicAuth{
|
||||
Username: g.Username,
|
||||
Password: g.Token,
|
||||
}
|
||||
|
||||
err := g.repo.Push(&git.PushOptions{
|
||||
Auth: auth,
|
||||
Progress: os.Stdout,
|
||||
})
|
||||
|
||||
if err != nil && err != git.NoErrAlreadyUpToDate {
|
||||
return fmt.Errorf("failed to push changes: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *GitRepo) EnsureRepo() error {
|
||||
if _, err := os.Stat(filepath.Join(g.WorkDir, ".git")); os.IsNotExist(err) {
|
||||
return g.Clone()
|
||||
}
|
||||
|
||||
var err error
|
||||
g.repo, err = git.PlainOpen(g.WorkDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open existing repository: %w", err)
|
||||
}
|
||||
|
||||
return g.Pull()
|
||||
}
|
||||
@@ -1,12 +1,13 @@
|
||||
// Package handlers contains the request handlers for the api routes.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"novamd/internal/context"
|
||||
"novamd/internal/db"
|
||||
"novamd/internal/filesystem"
|
||||
"novamd/internal/httpcontext"
|
||||
"novamd/internal/models"
|
||||
"novamd/internal/storage"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -14,14 +15,16 @@ import (
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type createUserRequest struct {
|
||||
// CreateUserRequest holds the request fields for creating a new user
|
||||
type CreateUserRequest struct {
|
||||
Email string `json:"email"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Password string `json:"password"`
|
||||
Role models.UserRole `json:"role"`
|
||||
}
|
||||
|
||||
type updateUserRequest struct {
|
||||
// UpdateUserRequest holds the request fields for updating a user
|
||||
type UpdateUserRequest struct {
|
||||
Email string `json:"email,omitempty"`
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
@@ -44,7 +47,7 @@ func (h *Handler) AdminListUsers() http.HandlerFunc {
|
||||
// AdminCreateUser creates a new user
|
||||
func (h *Handler) AdminCreateUser() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var req createUserRequest
|
||||
var req CreateUserRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
@@ -91,7 +94,7 @@ func (h *Handler) AdminCreateUser() http.HandlerFunc {
|
||||
}
|
||||
|
||||
// Initialize user workspace
|
||||
if err := h.FS.InitializeUserWorkspace(insertedUser.ID, insertedUser.LastWorkspaceID); err != nil {
|
||||
if err := h.Storage.InitializeUserWorkspace(insertedUser.ID, insertedUser.LastWorkspaceID); err != nil {
|
||||
http.Error(w, "Failed to initialize user workspace", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -135,7 +138,7 @@ func (h *Handler) AdminUpdateUser() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
var req updateUserRequest
|
||||
var req UpdateUserRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
@@ -172,7 +175,7 @@ func (h *Handler) AdminUpdateUser() http.HandlerFunc {
|
||||
// AdminDeleteUser deletes a specific user
|
||||
func (h *Handler) AdminDeleteUser() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -218,7 +221,7 @@ type WorkspaceStats struct {
|
||||
WorkspaceID int `json:"workspaceID"`
|
||||
WorkspaceName string `json:"workspaceName"`
|
||||
WorkspaceCreatedAt time.Time `json:"workspaceCreatedAt"`
|
||||
*filesystem.FileCountStats
|
||||
*storage.FileCountStats
|
||||
}
|
||||
|
||||
// AdminListWorkspaces returns a list of all workspaces and their stats
|
||||
@@ -248,7 +251,7 @@ func (h *Handler) AdminListWorkspaces() http.HandlerFunc {
|
||||
workspaceData.WorkspaceName = ws.Name
|
||||
workspaceData.WorkspaceCreatedAt = ws.CreatedAt
|
||||
|
||||
fileStats, err := h.FS.GetFileStats(ws.UserID, ws.ID)
|
||||
fileStats, err := h.Storage.GetFileStats(ws.UserID, ws.ID)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to get file stats", http.StatusInternalServerError)
|
||||
return
|
||||
@@ -266,7 +269,7 @@ func (h *Handler) AdminListWorkspaces() http.HandlerFunc {
|
||||
// SystemStats holds system-wide statistics
|
||||
type SystemStats struct {
|
||||
*db.UserStats
|
||||
*filesystem.FileCountStats
|
||||
*storage.FileCountStats
|
||||
}
|
||||
|
||||
// AdminGetSystemStats returns system-wide statistics for admins
|
||||
@@ -278,7 +281,7 @@ func (h *Handler) AdminGetSystemStats() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
fileStats, err := h.FS.GetTotalFileStats()
|
||||
fileStats, err := h.Storage.GetTotalFileStats()
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to get file stats", http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
243
server/internal/handlers/admin_handlers_integration_test.go
Normal file
243
server/internal/handlers/admin_handlers_integration_test.go
Normal file
@@ -0,0 +1,243 @@
|
||||
//go:build integration
|
||||
|
||||
package handlers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"novamd/internal/handlers"
|
||||
"novamd/internal/models"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Helper function to check if a user exists in a slice of users
|
||||
func containsUser(users []*models.User, searchUser *models.User) bool {
|
||||
for _, u := range users {
|
||||
if u.ID == searchUser.ID &&
|
||||
u.Email == searchUser.Email &&
|
||||
u.DisplayName == searchUser.DisplayName &&
|
||||
u.Role == searchUser.Role {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestAdminHandlers_Integration(t *testing.T) {
|
||||
h := setupTestHarness(t)
|
||||
defer h.teardown(t)
|
||||
|
||||
t.Run("user management", func(t *testing.T) {
|
||||
t.Run("list users", func(t *testing.T) {
|
||||
// Test with admin token
|
||||
rr := h.makeRequest(t, http.MethodGet, "/api/v1/admin/users", nil, h.AdminToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var users []*models.User
|
||||
err := json.NewDecoder(rr.Body).Decode(&users)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should have at least our admin and regular test users
|
||||
assert.GreaterOrEqual(t, len(users), 2)
|
||||
assert.True(t, containsUser(users, h.AdminUser), "Admin user not found in users list")
|
||||
assert.True(t, containsUser(users, h.RegularUser), "Regular user not found in users list")
|
||||
|
||||
// Test with non-admin token
|
||||
rr = h.makeRequest(t, http.MethodGet, "/api/v1/admin/users", nil, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusForbidden, rr.Code)
|
||||
|
||||
// Test without token
|
||||
rr = h.makeRequest(t, http.MethodGet, "/api/v1/admin/users", nil, "", nil)
|
||||
assert.Equal(t, http.StatusUnauthorized, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("create user", func(t *testing.T) {
|
||||
createReq := handlers.CreateUserRequest{
|
||||
Email: "newuser@test.com",
|
||||
DisplayName: "New User",
|
||||
Password: "password123",
|
||||
Role: models.RoleEditor,
|
||||
}
|
||||
|
||||
// Test with admin token
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/admin/users", createReq, h.AdminToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var createdUser models.User
|
||||
err := json.NewDecoder(rr.Body).Decode(&createdUser)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, createReq.Email, createdUser.Email)
|
||||
assert.Equal(t, createReq.DisplayName, createdUser.DisplayName)
|
||||
assert.Equal(t, createReq.Role, createdUser.Role)
|
||||
assert.NotZero(t, createdUser.LastWorkspaceID)
|
||||
|
||||
// Test duplicate email
|
||||
rr = h.makeRequest(t, http.MethodPost, "/api/v1/admin/users", createReq, h.AdminToken, nil)
|
||||
assert.Equal(t, http.StatusConflict, rr.Code)
|
||||
|
||||
// Test invalid request (missing required fields)
|
||||
invalidReq := handlers.CreateUserRequest{
|
||||
Email: "invalid@test.com",
|
||||
// Missing password and role
|
||||
}
|
||||
rr = h.makeRequest(t, http.MethodPost, "/api/v1/admin/users", invalidReq, h.AdminToken, nil)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
|
||||
// Test with non-admin token
|
||||
rr = h.makeRequest(t, http.MethodPost, "/api/v1/admin/users", createReq, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusForbidden, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("get user", func(t *testing.T) {
|
||||
path := fmt.Sprintf("/api/v1/admin/users/%d", h.RegularUser.ID)
|
||||
|
||||
// Test with admin token
|
||||
rr := h.makeRequest(t, http.MethodGet, path, nil, h.AdminToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var user models.User
|
||||
err := json.NewDecoder(rr.Body).Decode(&user)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, h.RegularUser.ID, user.ID)
|
||||
|
||||
// Test non-existent user
|
||||
rr = h.makeRequest(t, http.MethodGet, "/api/v1/admin/users/999999", nil, h.AdminToken, nil)
|
||||
assert.Equal(t, http.StatusNotFound, rr.Code)
|
||||
|
||||
// Test with non-admin token
|
||||
rr = h.makeRequest(t, http.MethodGet, path, nil, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusForbidden, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("update user", func(t *testing.T) {
|
||||
path := fmt.Sprintf("/api/v1/admin/users/%d", h.RegularUser.ID)
|
||||
updateReq := handlers.UpdateUserRequest{
|
||||
DisplayName: "Updated Name",
|
||||
Role: models.RoleViewer,
|
||||
}
|
||||
|
||||
// Test with admin token
|
||||
rr := h.makeRequest(t, http.MethodPut, path, updateReq, h.AdminToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var updatedUser models.User
|
||||
err := json.NewDecoder(rr.Body).Decode(&updatedUser)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, updateReq.DisplayName, updatedUser.DisplayName)
|
||||
assert.Equal(t, updateReq.Role, updatedUser.Role)
|
||||
|
||||
// Test with non-admin token
|
||||
rr = h.makeRequest(t, http.MethodPut, path, updateReq, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusForbidden, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("delete user", func(t *testing.T) {
|
||||
// Create a user to delete
|
||||
createReq := handlers.CreateUserRequest{
|
||||
Email: "todelete@test.com",
|
||||
DisplayName: "To Delete",
|
||||
Password: "password123",
|
||||
Role: models.RoleEditor,
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/admin/users", createReq, h.AdminToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var createdUser models.User
|
||||
err := json.NewDecoder(rr.Body).Decode(&createdUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
path := fmt.Sprintf("/api/v1/admin/users/%d", createdUser.ID)
|
||||
|
||||
// Test deleting own account (should fail)
|
||||
adminPath := fmt.Sprintf("/api/v1/admin/users/%d", h.AdminUser.ID)
|
||||
rr = h.makeRequest(t, http.MethodDelete, adminPath, nil, h.AdminToken, nil)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
|
||||
// Test with admin token
|
||||
rr = h.makeRequest(t, http.MethodDelete, path, nil, h.AdminToken, nil)
|
||||
assert.Equal(t, http.StatusNoContent, rr.Code)
|
||||
|
||||
// Verify user is deleted
|
||||
rr = h.makeRequest(t, http.MethodGet, path, nil, h.AdminToken, nil)
|
||||
assert.Equal(t, http.StatusNotFound, rr.Code)
|
||||
|
||||
// Test with non-admin token
|
||||
rr = h.makeRequest(t, http.MethodDelete, path, nil, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusForbidden, rr.Code)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("workspace management", func(t *testing.T) {
|
||||
t.Run("list workspaces", func(t *testing.T) {
|
||||
// Create a test workspace first
|
||||
workspace := &models.Workspace{
|
||||
UserID: h.RegularUser.ID,
|
||||
Name: "Test Workspace",
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Test with admin token
|
||||
rr = h.makeRequest(t, http.MethodGet, "/api/v1/admin/workspaces", nil, h.AdminToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var workspaces []*handlers.WorkspaceStats
|
||||
err := json.NewDecoder(rr.Body).Decode(&workspaces)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should have at least the default workspaces for admin and regular users
|
||||
assert.NotEmpty(t, workspaces)
|
||||
|
||||
// Verify workspace stats fields
|
||||
for _, ws := range workspaces {
|
||||
assert.NotZero(t, ws.UserID)
|
||||
assert.NotEmpty(t, ws.UserEmail)
|
||||
assert.NotZero(t, ws.WorkspaceID)
|
||||
assert.NotEmpty(t, ws.WorkspaceName)
|
||||
assert.NotZero(t, ws.WorkspaceCreatedAt)
|
||||
assert.GreaterOrEqual(t, ws.TotalFiles, 0)
|
||||
assert.GreaterOrEqual(t, ws.TotalSize, int64(0))
|
||||
}
|
||||
|
||||
// Test with non-admin token
|
||||
rr = h.makeRequest(t, http.MethodGet, "/api/v1/admin/workspaces", nil, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusForbidden, rr.Code)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("system stats", func(t *testing.T) {
|
||||
// Create some test data
|
||||
workspace := &models.Workspace{
|
||||
UserID: h.RegularUser.ID,
|
||||
Name: "Stats Test Workspace",
|
||||
}
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Test with admin token
|
||||
rr = h.makeRequest(t, http.MethodGet, "/api/v1/admin/stats", nil, h.AdminToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var stats handlers.SystemStats
|
||||
err := json.NewDecoder(rr.Body).Decode(&stats)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify stats fields
|
||||
assert.GreaterOrEqual(t, stats.TotalUsers, 2) // At least admin and regular user
|
||||
assert.GreaterOrEqual(t, stats.TotalWorkspaces, 2) // At least default workspaces
|
||||
assert.GreaterOrEqual(t, stats.ActiveUsers, 2) // Our test users should be active
|
||||
assert.GreaterOrEqual(t, stats.TotalFiles, 0)
|
||||
assert.GreaterOrEqual(t, stats.TotalSize, int64(0))
|
||||
|
||||
// Test with non-admin token
|
||||
rr = h.makeRequest(t, http.MethodGet, "/api/v1/admin/stats", nil, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusForbidden, rr.Code)
|
||||
})
|
||||
}
|
||||
@@ -4,28 +4,32 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"novamd/internal/auth"
|
||||
"novamd/internal/httpcontext"
|
||||
"novamd/internal/context"
|
||||
"novamd/internal/models"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// LoginRequest represents a user login request
|
||||
type LoginRequest struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// LoginResponse represents a user login response
|
||||
type LoginResponse struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
User *models.User `json:"user"`
|
||||
Session *auth.Session `json:"session"`
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
User *models.User `json:"user"`
|
||||
Session *models.Session `json:"session"`
|
||||
}
|
||||
|
||||
// RefreshRequest represents a refresh token request
|
||||
type RefreshRequest struct {
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
|
||||
// RefreshResponse represents a refresh token response
|
||||
type RefreshResponse struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
}
|
||||
@@ -129,7 +133,7 @@ func (h *Handler) RefreshToken(authService *auth.SessionService) http.HandlerFun
|
||||
// GetCurrentUser returns the currently authenticated user
|
||||
func (h *Handler) GetCurrentUser() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
232
server/internal/handlers/auth_handlers_integration_test.go
Normal file
232
server/internal/handlers/auth_handlers_integration_test.go
Normal file
@@ -0,0 +1,232 @@
|
||||
//go:build integration
|
||||
|
||||
package handlers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"novamd/internal/handlers"
|
||||
"novamd/internal/models"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAuthHandlers_Integration(t *testing.T) {
|
||||
h := setupTestHarness(t)
|
||||
defer h.teardown(t)
|
||||
|
||||
t.Run("login", func(t *testing.T) {
|
||||
t.Run("successful login - admin user", func(t *testing.T) {
|
||||
loginReq := handlers.LoginRequest{
|
||||
Email: "admin@test.com",
|
||||
Password: "admin123",
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var resp handlers.LoginResponse
|
||||
err := json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEmpty(t, resp.AccessToken)
|
||||
assert.NotEmpty(t, resp.RefreshToken)
|
||||
assert.NotNil(t, resp.User)
|
||||
assert.Equal(t, loginReq.Email, resp.User.Email)
|
||||
assert.Equal(t, models.RoleAdmin, resp.User.Role)
|
||||
})
|
||||
|
||||
t.Run("successful login - regular user", func(t *testing.T) {
|
||||
loginReq := handlers.LoginRequest{
|
||||
Email: "user@test.com",
|
||||
Password: "user123",
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var resp handlers.LoginResponse
|
||||
err := json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEmpty(t, resp.AccessToken)
|
||||
assert.NotEmpty(t, resp.RefreshToken)
|
||||
assert.NotNil(t, resp.User)
|
||||
assert.Equal(t, loginReq.Email, resp.User.Email)
|
||||
assert.Equal(t, models.RoleEditor, resp.User.Role)
|
||||
})
|
||||
|
||||
t.Run("login failures", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request handlers.LoginRequest
|
||||
wantCode int
|
||||
}{
|
||||
{
|
||||
name: "wrong password",
|
||||
request: handlers.LoginRequest{
|
||||
Email: "user@test.com",
|
||||
Password: "wrongpassword",
|
||||
},
|
||||
wantCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "non-existent user",
|
||||
request: handlers.LoginRequest{
|
||||
Email: "nonexistent@test.com",
|
||||
Password: "password123",
|
||||
},
|
||||
wantCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "empty email",
|
||||
request: handlers.LoginRequest{
|
||||
Email: "",
|
||||
Password: "password123",
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "empty password",
|
||||
request: handlers.LoginRequest{
|
||||
Email: "user@test.com",
|
||||
Password: "",
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", tt.request, "", nil)
|
||||
assert.Equal(t, tt.wantCode, rr.Code)
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("refresh token", func(t *testing.T) {
|
||||
t.Run("successful token refresh", func(t *testing.T) {
|
||||
// First login to get refresh token
|
||||
loginReq := handlers.LoginRequest{
|
||||
Email: "user@test.com",
|
||||
Password: "user123",
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var loginResp handlers.LoginResponse
|
||||
err := json.NewDecoder(rr.Body).Decode(&loginResp)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now try to refresh the token
|
||||
refreshReq := handlers.RefreshRequest{
|
||||
RefreshToken: loginResp.RefreshToken,
|
||||
}
|
||||
|
||||
rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/refresh", refreshReq, "", nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var refreshResp handlers.RefreshResponse
|
||||
err = json.NewDecoder(rr.Body).Decode(&refreshResp)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, refreshResp.AccessToken)
|
||||
})
|
||||
|
||||
t.Run("refresh failures", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request handlers.RefreshRequest
|
||||
wantCode int
|
||||
}{
|
||||
{
|
||||
name: "invalid refresh token",
|
||||
request: handlers.RefreshRequest{
|
||||
RefreshToken: "invalid-token",
|
||||
},
|
||||
wantCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "empty refresh token",
|
||||
request: handlers.RefreshRequest{
|
||||
RefreshToken: "",
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/refresh", tt.request, "", nil)
|
||||
assert.Equal(t, tt.wantCode, rr.Code)
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("logout", func(t *testing.T) {
|
||||
t.Run("successful logout", func(t *testing.T) {
|
||||
// First login to get session
|
||||
loginReq := handlers.LoginRequest{
|
||||
Email: "user@test.com",
|
||||
Password: "user123",
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var loginResp handlers.LoginResponse
|
||||
err := json.NewDecoder(rr.Body).Decode(&loginResp)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now logout using session ID from login response
|
||||
headers := map[string]string{
|
||||
"X-Session-ID": loginResp.Session.ID,
|
||||
}
|
||||
rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/logout", nil, loginResp.AccessToken, headers)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Try to use the refresh token - should fail
|
||||
refreshReq := handlers.RefreshRequest{
|
||||
RefreshToken: loginResp.RefreshToken,
|
||||
}
|
||||
|
||||
rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/refresh", refreshReq, "", nil)
|
||||
assert.Equal(t, http.StatusUnauthorized, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("logout without session ID", func(t *testing.T) {
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/logout", nil, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("get current user", func(t *testing.T) {
|
||||
t.Run("successful get current user", func(t *testing.T) {
|
||||
rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var user models.User
|
||||
err := json.NewDecoder(rr.Body).Decode(&user)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, h.RegularUser.ID, user.ID)
|
||||
assert.Equal(t, h.RegularUser.Email, user.Email)
|
||||
assert.Equal(t, h.RegularUser.Role, user.Role)
|
||||
})
|
||||
|
||||
t.Run("get current user without token", func(t *testing.T) {
|
||||
rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, "", nil)
|
||||
assert.Equal(t, http.StatusUnauthorized, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("get current user with invalid token", func(t *testing.T) {
|
||||
rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, "invalid-token", nil)
|
||||
assert.Equal(t, http.StatusUnauthorized, rr.Code)
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -4,20 +4,23 @@ import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"novamd/internal/httpcontext"
|
||||
"novamd/internal/context"
|
||||
"novamd/internal/storage"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
// ListFiles returns a list of all files in the workspace
|
||||
func (h *Handler) ListFiles() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
files, err := h.FS.ListFilesRecursively(ctx.UserID, ctx.Workspace.ID)
|
||||
files, err := h.Storage.ListFilesRecursively(ctx.UserID, ctx.Workspace.ID)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to list files", http.StatusInternalServerError)
|
||||
return
|
||||
@@ -27,9 +30,10 @@ func (h *Handler) ListFiles() http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// LookupFileByName returns the paths of files with the given name
|
||||
func (h *Handler) LookupFileByName() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -40,7 +44,7 @@ func (h *Handler) LookupFileByName() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
filePaths, err := h.FS.FindFileByName(ctx.UserID, ctx.Workspace.ID, filename)
|
||||
filePaths, err := h.Storage.FindFileByName(ctx.UserID, ctx.Workspace.ID, filename)
|
||||
if err != nil {
|
||||
http.Error(w, "File not found", http.StatusNotFound)
|
||||
return
|
||||
@@ -50,28 +54,45 @@ func (h *Handler) LookupFileByName() http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// GetFileContent returns the content of a file
|
||||
func (h *Handler) GetFileContent() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
filePath := chi.URLParam(r, "*")
|
||||
content, err := h.FS.GetFileContent(ctx.UserID, ctx.Workspace.ID, filePath)
|
||||
content, err := h.Storage.GetFileContent(ctx.UserID, ctx.Workspace.ID, filePath)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read file", http.StatusNotFound)
|
||||
|
||||
if storage.IsPathValidationError(err) {
|
||||
http.Error(w, "Invalid file path", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if os.IsNotExist(err) {
|
||||
http.Error(w, "Failed to read file", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "Failed to read file", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.Write(content)
|
||||
_, err = w.Write(content)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to write response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SaveFile saves the content of a file
|
||||
func (h *Handler) SaveFile() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -83,8 +104,13 @@ func (h *Handler) SaveFile() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.FS.SaveFile(ctx.UserID, ctx.Workspace.ID, filePath, content)
|
||||
err = h.Storage.SaveFile(ctx.UserID, ctx.Workspace.ID, filePath, content)
|
||||
if err != nil {
|
||||
if storage.IsPathValidationError(err) {
|
||||
http.Error(w, "Invalid file path", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "Failed to save file", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -93,28 +119,44 @@ func (h *Handler) SaveFile() http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteFile deletes a file
|
||||
func (h *Handler) DeleteFile() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
filePath := chi.URLParam(r, "*")
|
||||
err := h.FS.DeleteFile(ctx.UserID, ctx.Workspace.ID, filePath)
|
||||
err := h.Storage.DeleteFile(ctx.UserID, ctx.Workspace.ID, filePath)
|
||||
if err != nil {
|
||||
if storage.IsPathValidationError(err) {
|
||||
http.Error(w, "Invalid file path", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if os.IsNotExist(err) {
|
||||
http.Error(w, "File not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "Failed to delete file", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("File deleted successfully"))
|
||||
_, err = w.Write([]byte("File deleted successfully"))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to write response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetLastOpenedFile returns the last opened file in the workspace
|
||||
func (h *Handler) GetLastOpenedFile() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -125,7 +167,7 @@ func (h *Handler) GetLastOpenedFile() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := h.FS.ValidatePath(ctx.UserID, ctx.Workspace.ID, filePath); err != nil {
|
||||
if _, err := h.Storage.ValidatePath(ctx.UserID, ctx.Workspace.ID, filePath); err != nil {
|
||||
http.Error(w, "Invalid file path", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
@@ -134,9 +176,10 @@ func (h *Handler) GetLastOpenedFile() http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateLastOpenedFile updates the last opened file in the workspace
|
||||
func (h *Handler) UpdateLastOpenedFile() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -150,10 +193,21 @@ func (h *Handler) UpdateLastOpenedFile() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// Validate the file path exists in the workspace
|
||||
// Validate the file path in the workspace
|
||||
if requestBody.FilePath != "" {
|
||||
if _, err := h.FS.ValidatePath(ctx.UserID, ctx.Workspace.ID, requestBody.FilePath); err != nil {
|
||||
http.Error(w, "Invalid file path", http.StatusBadRequest)
|
||||
_, err := h.Storage.GetFileContent(ctx.UserID, ctx.Workspace.ID, requestBody.FilePath)
|
||||
if err != nil {
|
||||
if storage.IsPathValidationError(err) {
|
||||
http.Error(w, "Invalid file path", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if os.IsNotExist(err) {
|
||||
http.Error(w, "File not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "Failed to update file", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
239
server/internal/handlers/file_handlers_integration_test.go
Normal file
239
server/internal/handlers/file_handlers_integration_test.go
Normal file
@@ -0,0 +1,239 @@
|
||||
//go:build integration
|
||||
|
||||
package handlers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"novamd/internal/models"
|
||||
"novamd/internal/storage"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFileHandlers_Integration(t *testing.T) {
|
||||
h := setupTestHarness(t)
|
||||
defer h.teardown(t)
|
||||
|
||||
t.Run("file operations", func(t *testing.T) {
|
||||
// Setup: Create a workspace first
|
||||
workspace := &models.Workspace{
|
||||
UserID: h.RegularUser.ID,
|
||||
Name: "File Test Workspace",
|
||||
}
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
err := json.NewDecoder(rr.Body).Decode(workspace)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Construct base URL for file operations
|
||||
baseURL := fmt.Sprintf("/api/v1/workspaces/%s/files", url.PathEscape(workspace.Name))
|
||||
|
||||
t.Run("list empty directory", func(t *testing.T) {
|
||||
rr := h.makeRequest(t, http.MethodGet, baseURL, nil, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var files []storage.FileNode
|
||||
err := json.NewDecoder(rr.Body).Decode(&files)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, files, "Expected empty directory")
|
||||
})
|
||||
|
||||
t.Run("save and get file", func(t *testing.T) {
|
||||
content := "Test content for file operations"
|
||||
filePath := "test.md"
|
||||
|
||||
// Save file
|
||||
rr := h.makeRequestRaw(t, http.MethodPost, baseURL+"/"+filePath, strings.NewReader(content), h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Get file content
|
||||
rr = h.makeRequest(t, http.MethodGet, baseURL+"/"+filePath, nil, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, content, rr.Body.String())
|
||||
|
||||
// List directory should now show the file
|
||||
rr = h.makeRequest(t, http.MethodGet, baseURL, nil, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var files []storage.FileNode
|
||||
err := json.NewDecoder(rr.Body).Decode(&files)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, files, 1)
|
||||
assert.Equal(t, filePath, files[0].Name)
|
||||
})
|
||||
|
||||
t.Run("save and list nested files", func(t *testing.T) {
|
||||
files := map[string]string{
|
||||
"docs/readme.md": "README content",
|
||||
"docs/api/endpoints.md": "API documentation",
|
||||
"notes/meeting-notes.md": "Meeting notes content",
|
||||
"notes/todo.md": "TODO list",
|
||||
}
|
||||
|
||||
// Create all files
|
||||
for path, content := range files {
|
||||
rr := h.makeRequest(t, http.MethodPost, baseURL+"/"+path, content, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
}
|
||||
|
||||
// List all files
|
||||
rr := h.makeRequest(t, http.MethodGet, baseURL, nil, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var fileNodes []storage.FileNode
|
||||
err := json.NewDecoder(rr.Body).Decode(&fileNodes)
|
||||
require.NoError(t, err)
|
||||
|
||||
// We should have 3 root items: docs/, notes/, and test.md
|
||||
assert.Len(t, fileNodes, 3)
|
||||
|
||||
// Verify directory structure
|
||||
var docsDir, notesDir *storage.FileNode
|
||||
for i := range fileNodes {
|
||||
switch fileNodes[i].Name {
|
||||
case "docs":
|
||||
docsDir = &fileNodes[i]
|
||||
case "notes":
|
||||
notesDir = &fileNodes[i]
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, docsDir)
|
||||
require.NotNil(t, notesDir)
|
||||
assert.Len(t, docsDir.Children, 2) // readme.md and api/
|
||||
assert.Len(t, notesDir.Children, 2) // meeting-notes.md and todo.md
|
||||
})
|
||||
|
||||
t.Run("lookup file by name", func(t *testing.T) {
|
||||
// Look up a file that exists in multiple locations
|
||||
filename := "readme.md"
|
||||
dupContent := "Another readme"
|
||||
rr := h.makeRequest(t, http.MethodPost, baseURL+"/projects/"+filename, dupContent, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Search for the file
|
||||
rr = h.makeRequest(t, http.MethodGet, baseURL+"/lookup?filename="+filename, nil, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var response struct {
|
||||
Paths []string `json:"paths"`
|
||||
}
|
||||
err := json.NewDecoder(rr.Body).Decode(&response)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, response.Paths, 2)
|
||||
|
||||
// Search for non-existent file
|
||||
rr = h.makeRequest(t, http.MethodGet, baseURL+"/lookup?filename=nonexistent.md", nil, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusNotFound, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("delete file", func(t *testing.T) {
|
||||
filePath := "to-delete.md"
|
||||
content := "This file will be deleted"
|
||||
|
||||
// Create file
|
||||
rr := h.makeRequest(t, http.MethodPost, baseURL+"/"+filePath, content, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Delete file
|
||||
rr = h.makeRequest(t, http.MethodDelete, baseURL+"/"+filePath, nil, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Verify file is gone
|
||||
rr = h.makeRequest(t, http.MethodGet, baseURL+"/"+filePath, nil, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusNotFound, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("last opened file", func(t *testing.T) {
|
||||
// Initially should be empty
|
||||
rr := h.makeRequest(t, http.MethodGet, baseURL+"/last", nil, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var response struct {
|
||||
LastOpenedFilePath string `json:"lastOpenedFilePath"`
|
||||
}
|
||||
err := json.NewDecoder(rr.Body).Decode(&response)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, response.LastOpenedFilePath)
|
||||
|
||||
// Update last opened file
|
||||
updateReq := struct {
|
||||
FilePath string `json:"filePath"`
|
||||
}{
|
||||
FilePath: "docs/readme.md",
|
||||
}
|
||||
rr = h.makeRequest(t, http.MethodPut, baseURL+"/last", updateReq, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Verify update
|
||||
rr = h.makeRequest(t, http.MethodGet, baseURL+"/last", nil, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
err = json.NewDecoder(rr.Body).Decode(&response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, updateReq.FilePath, response.LastOpenedFilePath)
|
||||
|
||||
// Test invalid file path
|
||||
updateReq.FilePath = "nonexistent.md"
|
||||
rr = h.makeRequest(t, http.MethodPut, baseURL+"/last", updateReq, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusNotFound, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("unauthorized access", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
body interface{}
|
||||
}{
|
||||
{"list files", http.MethodGet, baseURL, nil},
|
||||
{"get file", http.MethodGet, baseURL + "/test.md", nil},
|
||||
{"save file", http.MethodPost, baseURL + "/test.md", "content"},
|
||||
{"delete file", http.MethodDelete, baseURL + "/test.md", nil},
|
||||
{"get last file", http.MethodGet, baseURL + "/last", nil},
|
||||
{"update last file", http.MethodPut, baseURL + "/last", struct{ FilePath string }{"test.md"}},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Test without token
|
||||
rr := h.makeRequest(t, tc.method, tc.path, tc.body, "", nil)
|
||||
assert.Equal(t, http.StatusUnauthorized, rr.Code)
|
||||
|
||||
// Test with wrong user's token
|
||||
rr = h.makeRequest(t, tc.method, tc.path, tc.body, h.AdminToken, nil)
|
||||
assert.Equal(t, http.StatusNotFound, rr.Code)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("path traversal attempts", func(t *testing.T) {
|
||||
maliciousPaths := []string{
|
||||
"../../../etc/passwd",
|
||||
"./../../secret.txt",
|
||||
"/etc/shadow",
|
||||
"test/../../../etc/passwd",
|
||||
}
|
||||
|
||||
for _, path := range maliciousPaths {
|
||||
t.Run(path, func(t *testing.T) {
|
||||
// Try to read
|
||||
rr := h.makeRequest(t, http.MethodGet, baseURL+"/"+path, nil, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
|
||||
// Try to write
|
||||
rr = h.makeRequest(t, http.MethodPost, baseURL+"/"+path, "malicious content", h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -4,12 +4,13 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"novamd/internal/httpcontext"
|
||||
"novamd/internal/context"
|
||||
)
|
||||
|
||||
// StageCommitAndPush stages, commits, and pushes changes to the remote repository
|
||||
func (h *Handler) StageCommitAndPush() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -28,7 +29,7 @@ func (h *Handler) StageCommitAndPush() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.FS.StageCommitAndPush(ctx.UserID, ctx.Workspace.ID, requestBody.Message)
|
||||
err := h.Storage.StageCommitAndPush(ctx.UserID, ctx.Workspace.ID, requestBody.Message)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to stage, commit, and push changes: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
@@ -38,14 +39,15 @@ func (h *Handler) StageCommitAndPush() http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// PullChanges pulls changes from the remote repository
|
||||
func (h *Handler) PullChanges() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.FS.Pull(ctx.UserID, ctx.Workspace.ID)
|
||||
err := h.Storage.Pull(ctx.UserID, ctx.Workspace.ID)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to pull changes: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
181
server/internal/handlers/git_handlers_integration_test.go
Normal file
181
server/internal/handlers/git_handlers_integration_test.go
Normal file
@@ -0,0 +1,181 @@
|
||||
//go:build integration
|
||||
|
||||
package handlers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"novamd/internal/models"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGitHandlers_Integration(t *testing.T) {
|
||||
h := setupTestHarness(t)
|
||||
defer h.teardown(t)
|
||||
|
||||
t.Run("git operations", func(t *testing.T) {
|
||||
// Setup: Create a workspace with Git enabled
|
||||
workspace := &models.Workspace{
|
||||
UserID: h.RegularUser.ID,
|
||||
Name: "Git Test Workspace",
|
||||
GitEnabled: true,
|
||||
GitURL: "https://github.com/test/repo.git",
|
||||
GitUser: "testuser",
|
||||
GitToken: "testtoken",
|
||||
GitAutoCommit: true,
|
||||
GitCommitMsgTemplate: "Update: {{message}}",
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
err := json.NewDecoder(rr.Body).Decode(workspace)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Construct base URL for Git operations
|
||||
baseURL := "/api/v1/workspaces/" + url.PathEscape(workspace.Name) + "/git"
|
||||
|
||||
t.Run("stage, commit and push", func(t *testing.T) {
|
||||
h.MockGit.Reset()
|
||||
|
||||
t.Run("successful commit", func(t *testing.T) {
|
||||
commitMsg := "Test commit message"
|
||||
requestBody := map[string]string{
|
||||
"message": commitMsg,
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPost, baseURL+"/commit", requestBody, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(rr.Body).Decode(&response)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response["message"], "successfully")
|
||||
|
||||
// Verify mock was called correctly
|
||||
assert.Equal(t, 1, h.MockGit.GetCommitCount(), "Commit should be called once")
|
||||
assert.Equal(t, 1, h.MockGit.GetPushCount(), "Push should be called once")
|
||||
assert.Equal(t, commitMsg, h.MockGit.GetLastCommitMessage(), "Commit message should match")
|
||||
})
|
||||
|
||||
t.Run("empty commit message", func(t *testing.T) {
|
||||
h.MockGit.Reset()
|
||||
requestBody := map[string]string{
|
||||
"message": "",
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPost, baseURL+"/commit", requestBody, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
assert.Equal(t, 0, h.MockGit.GetCommitCount(), "Commit should not be called")
|
||||
})
|
||||
|
||||
t.Run("git error", func(t *testing.T) {
|
||||
h.MockGit.Reset()
|
||||
h.MockGit.SetError(fmt.Errorf("mock git error"))
|
||||
|
||||
requestBody := map[string]string{
|
||||
"message": "Test message",
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPost, baseURL+"/commit", requestBody, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
||||
|
||||
h.MockGit.SetError(nil) // Reset error state
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("pull changes", func(t *testing.T) {
|
||||
h.MockGit.Reset()
|
||||
|
||||
t.Run("successful pull", func(t *testing.T) {
|
||||
rr := h.makeRequest(t, http.MethodPost, baseURL+"/pull", nil, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(rr.Body).Decode(&response)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response["message"], "Pulled changes")
|
||||
|
||||
assert.Equal(t, 1, h.MockGit.GetPullCount(), "Pull should be called once")
|
||||
})
|
||||
|
||||
t.Run("git error", func(t *testing.T) {
|
||||
h.MockGit.Reset()
|
||||
h.MockGit.SetError(fmt.Errorf("mock git error"))
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPost, baseURL+"/pull", nil, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
||||
|
||||
h.MockGit.SetError(nil) // Reset error state
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("unauthorized access", func(t *testing.T) {
|
||||
h.MockGit.Reset()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
body interface{}
|
||||
}{
|
||||
{
|
||||
name: "commit without token",
|
||||
method: http.MethodPost,
|
||||
path: baseURL + "/commit",
|
||||
body: map[string]string{"message": "test"},
|
||||
},
|
||||
{
|
||||
name: "pull without token",
|
||||
method: http.MethodPost,
|
||||
path: baseURL + "/pull",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Test without token
|
||||
rr := h.makeRequest(t, tc.method, tc.path, tc.body, "", nil)
|
||||
assert.Equal(t, http.StatusUnauthorized, rr.Code)
|
||||
|
||||
// Test with wrong user's token
|
||||
rr = h.makeRequest(t, tc.method, tc.path, tc.body, h.AdminToken, nil)
|
||||
assert.Equal(t, http.StatusNotFound, rr.Code)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("workspace without git", func(t *testing.T) {
|
||||
h.MockGit.Reset()
|
||||
|
||||
// Create a workspace without Git enabled
|
||||
nonGitWorkspace := &models.Workspace{
|
||||
UserID: h.RegularUser.ID,
|
||||
Name: "Non-Git Workspace",
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", nonGitWorkspace, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
err := json.NewDecoder(rr.Body).Decode(nonGitWorkspace)
|
||||
require.NoError(t, err)
|
||||
|
||||
nonGitBaseURL := "/api/v1/workspaces/" + url.PathEscape(nonGitWorkspace.Name) + "/git"
|
||||
|
||||
// Try to commit
|
||||
commitMsg := map[string]string{"message": "test"}
|
||||
rr = h.makeRequest(t, http.MethodPost, nonGitBaseURL+"/commit", commitMsg, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
||||
|
||||
// Try to pull
|
||||
rr = h.makeRequest(t, http.MethodPost, nonGitBaseURL+"/pull", nil, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -4,20 +4,20 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"novamd/internal/db"
|
||||
"novamd/internal/filesystem"
|
||||
"novamd/internal/storage"
|
||||
)
|
||||
|
||||
// Handler provides common functionality for all handlers
|
||||
type Handler struct {
|
||||
DB *db.DB
|
||||
FS *filesystem.FileSystem
|
||||
DB db.Database
|
||||
Storage storage.Manager
|
||||
}
|
||||
|
||||
// NewHandler creates a new handler with the given dependencies
|
||||
func NewHandler(db *db.DB, fs *filesystem.FileSystem) *Handler {
|
||||
func NewHandler(db db.Database, s storage.Manager) *Handler {
|
||||
return &Handler{
|
||||
DB: db,
|
||||
FS: fs,
|
||||
DB: db,
|
||||
Storage: s,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
229
server/internal/handlers/integration_test.go
Normal file
229
server/internal/handlers/integration_test.go
Normal file
@@ -0,0 +1,229 @@
|
||||
//go:build integration
|
||||
|
||||
package handlers_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"novamd/internal/api"
|
||||
"novamd/internal/auth"
|
||||
"novamd/internal/db"
|
||||
"novamd/internal/git"
|
||||
"novamd/internal/handlers"
|
||||
"novamd/internal/models"
|
||||
"novamd/internal/secrets"
|
||||
"novamd/internal/storage"
|
||||
)
|
||||
|
||||
// testHarness encapsulates all the dependencies needed for testing
|
||||
type testHarness struct {
|
||||
DB db.TestDatabase
|
||||
Storage storage.Manager
|
||||
Router *chi.Mux
|
||||
Handler *handlers.Handler
|
||||
JWTManager auth.JWTManager
|
||||
SessionSvc *auth.SessionService
|
||||
AdminUser *models.User
|
||||
AdminToken string
|
||||
RegularUser *models.User
|
||||
RegularToken string
|
||||
TempDirectory string
|
||||
MockGit *MockGitClient
|
||||
}
|
||||
|
||||
// setupTestHarness creates a new test environment
|
||||
func setupTestHarness(t *testing.T) *testHarness {
|
||||
t.Helper()
|
||||
|
||||
// Create temporary directory for test files
|
||||
tempDir, err := os.MkdirTemp("", "novamd-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
|
||||
// Initialize test database
|
||||
secretsSvc, err := secrets.NewService("YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY=") // test key
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize secrets service: %v", err)
|
||||
}
|
||||
|
||||
database, err := db.NewTestDB(":memory:", secretsSvc)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize test database: %v", err)
|
||||
}
|
||||
|
||||
if err := database.Migrate(); err != nil {
|
||||
t.Fatalf("Failed to run migrations: %v", err)
|
||||
}
|
||||
|
||||
// Create mock git client
|
||||
mockGit := NewMockGitClient(false)
|
||||
|
||||
// Create storage with mock git client
|
||||
storageOpts := storage.Options{
|
||||
NewGitClient: func(url, user, token, path string) git.Client {
|
||||
return mockGit
|
||||
},
|
||||
}
|
||||
storageSvc := storage.NewServiceWithOptions(tempDir, storageOpts)
|
||||
|
||||
// Initialize JWT service
|
||||
jwtSvc, err := auth.NewJWTService(auth.JWTConfig{
|
||||
SigningKey: "test-key",
|
||||
AccessTokenExpiry: 15 * time.Minute,
|
||||
RefreshTokenExpiry: 24 * time.Hour,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize JWT service: %v", err)
|
||||
}
|
||||
|
||||
// Initialize session service
|
||||
sessionSvc := auth.NewSessionService(database, jwtSvc)
|
||||
|
||||
// Create handler
|
||||
handler := &handlers.Handler{
|
||||
DB: database,
|
||||
Storage: storageSvc,
|
||||
}
|
||||
|
||||
// Set up router with middlewares
|
||||
router := chi.NewRouter()
|
||||
authMiddleware := auth.NewMiddleware(jwtSvc)
|
||||
router.Route("/api/v1", func(r chi.Router) {
|
||||
api.SetupRoutes(r, database, storageSvc, authMiddleware, sessionSvc)
|
||||
})
|
||||
|
||||
h := &testHarness{
|
||||
DB: database,
|
||||
Storage: storageSvc,
|
||||
Router: router,
|
||||
Handler: handler,
|
||||
JWTManager: jwtSvc,
|
||||
SessionSvc: sessionSvc,
|
||||
TempDirectory: tempDir,
|
||||
MockGit: mockGit,
|
||||
}
|
||||
|
||||
// Create test users
|
||||
adminUser, adminToken := h.createTestUser(t, database, sessionSvc, "admin@test.com", "admin123", models.RoleAdmin)
|
||||
regularUser, regularToken := h.createTestUser(t, database, sessionSvc, "user@test.com", "user123", models.RoleEditor)
|
||||
|
||||
h.AdminUser = adminUser
|
||||
h.AdminToken = adminToken
|
||||
h.RegularUser = regularUser
|
||||
h.RegularToken = regularToken
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// teardownTestHarness cleans up the test environment
|
||||
func (h *testHarness) teardown(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
if err := h.DB.Close(); err != nil {
|
||||
t.Errorf("Failed to close database: %v", err)
|
||||
}
|
||||
|
||||
if err := os.RemoveAll(h.TempDirectory); err != nil {
|
||||
t.Errorf("Failed to remove temp directory: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// createTestUser creates a test user and returns the user and access token
|
||||
func (h *testHarness) createTestUser(t *testing.T, db db.Database, sessionSvc *auth.SessionService, email, password string, role models.UserRole) (*models.User, string) {
|
||||
t.Helper()
|
||||
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to hash password: %v", err)
|
||||
}
|
||||
|
||||
user := &models.User{
|
||||
Email: email,
|
||||
DisplayName: "Test User",
|
||||
PasswordHash: string(hashedPassword),
|
||||
Role: role,
|
||||
}
|
||||
|
||||
user, err = db.CreateUser(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create user: %v", err)
|
||||
}
|
||||
|
||||
// Initialize the default workspace directory in storage
|
||||
err = h.Storage.InitializeUserWorkspace(user.ID, user.LastWorkspaceID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize user workspace: %v", err)
|
||||
}
|
||||
|
||||
session, accessToken, err := sessionSvc.CreateSession(user.ID, string(user.Role))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session: %v", err)
|
||||
}
|
||||
|
||||
if session == nil || accessToken == "" {
|
||||
t.Fatal("Failed to get valid session or token")
|
||||
}
|
||||
|
||||
return user, accessToken
|
||||
}
|
||||
|
||||
// makeRequest is a helper function to make HTTP requests in tests
|
||||
func (h *testHarness) makeRequest(t *testing.T, method, path string, body interface{}, token string, headers map[string]string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
|
||||
var reqBody []byte
|
||||
var err error
|
||||
|
||||
if body != nil {
|
||||
reqBody, err = json.Marshal(body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(method, path, bytes.NewBuffer(reqBody))
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Add any additional headers
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
h.Router.ServeHTTP(rr, req)
|
||||
|
||||
return rr
|
||||
}
|
||||
|
||||
// makeRequestRaw is a helper function to make HTTP requests with raw body content
|
||||
func (h *testHarness) makeRequestRaw(t *testing.T, method, path string, body io.Reader, token string, headers map[string]string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
|
||||
req := httptest.NewRequest(method, path, body)
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
// Add any additional headers
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
h.Router.ServeHTTP(rr, req)
|
||||
|
||||
return rr
|
||||
}
|
||||
123
server/internal/handlers/mock_git_test.go
Normal file
123
server/internal/handlers/mock_git_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
//go:build integration
|
||||
|
||||
package handlers_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// MockGitClient implements the git.Client interface for testing
|
||||
type MockGitClient struct {
|
||||
initialized bool
|
||||
cloned bool
|
||||
lastCommitMsg string
|
||||
error error
|
||||
|
||||
pullCount int
|
||||
commitCount int
|
||||
pushCount int
|
||||
cloneCount int
|
||||
ensureCount int
|
||||
}
|
||||
|
||||
// NewMockGitClient creates a new mock git client
|
||||
func NewMockGitClient(shouldError bool) *MockGitClient {
|
||||
var err error
|
||||
if shouldError {
|
||||
err = fmt.Errorf("mock git error")
|
||||
}
|
||||
return &MockGitClient{
|
||||
error: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Clone implements git.Client
|
||||
func (m *MockGitClient) Clone() error {
|
||||
if m.error != nil {
|
||||
return m.error
|
||||
}
|
||||
m.cloneCount++
|
||||
m.cloned = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pull implements git.Client
|
||||
func (m *MockGitClient) Pull() error {
|
||||
if m.error != nil {
|
||||
return m.error
|
||||
}
|
||||
m.pullCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
// Commit implements git.Client
|
||||
func (m *MockGitClient) Commit(message string) error {
|
||||
if m.error != nil {
|
||||
return m.error
|
||||
}
|
||||
m.commitCount++
|
||||
m.lastCommitMsg = message
|
||||
return nil
|
||||
}
|
||||
|
||||
// Push implements git.Client
|
||||
func (m *MockGitClient) Push() error {
|
||||
if m.error != nil {
|
||||
return m.error
|
||||
}
|
||||
m.pushCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnsureRepo implements git.Client
|
||||
func (m *MockGitClient) EnsureRepo() error {
|
||||
if m.error != nil {
|
||||
return m.error
|
||||
}
|
||||
m.ensureCount++
|
||||
m.initialized = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper methods for tests
|
||||
|
||||
func (m *MockGitClient) GetCommitCount() int {
|
||||
return m.commitCount
|
||||
}
|
||||
|
||||
func (m *MockGitClient) GetPushCount() int {
|
||||
return m.pushCount
|
||||
}
|
||||
|
||||
func (m *MockGitClient) GetPullCount() int {
|
||||
return m.pullCount
|
||||
}
|
||||
|
||||
func (m *MockGitClient) GetLastCommitMessage() string {
|
||||
return m.lastCommitMsg
|
||||
}
|
||||
|
||||
func (m *MockGitClient) IsInitialized() bool {
|
||||
return m.initialized
|
||||
}
|
||||
|
||||
func (m *MockGitClient) IsCloned() bool {
|
||||
return m.cloned
|
||||
}
|
||||
|
||||
// Reset resets all counters and states
|
||||
func (m *MockGitClient) Reset() {
|
||||
m.initialized = false
|
||||
m.cloned = false
|
||||
m.lastCommitMsg = ""
|
||||
m.pullCount = 0
|
||||
m.commitCount = 0
|
||||
m.pushCount = 0
|
||||
m.cloneCount = 0
|
||||
m.ensureCount = 0
|
||||
}
|
||||
|
||||
// SetError sets the error state
|
||||
func (m *MockGitClient) SetError(err error) {
|
||||
m.error = err
|
||||
}
|
||||
@@ -12,12 +12,14 @@ type StaticHandler struct {
|
||||
staticPath string
|
||||
}
|
||||
|
||||
// NewStaticHandler creates a new StaticHandler with the given static path
|
||||
func NewStaticHandler(staticPath string) *StaticHandler {
|
||||
return &StaticHandler{
|
||||
staticPath: staticPath,
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP serves the static files
|
||||
func (h *StaticHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Get the requested path
|
||||
requestedPath := r.URL.Path
|
||||
|
||||
145
server/internal/handlers/static_handler_integration_test.go
Normal file
145
server/internal/handlers/static_handler_integration_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
//go:build integration
|
||||
|
||||
package handlers_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"novamd/internal/handlers"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStaticHandler_Integration(t *testing.T) {
|
||||
// Create temporary directory for test static files
|
||||
tempDir, err := os.MkdirTemp("", "novamd-static-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create test files
|
||||
files := map[string][]byte{
|
||||
"index.html": []byte("<html><body>Index</body></html>"),
|
||||
"assets/style.css": []byte("body { color: blue; }"),
|
||||
"assets/style.css.gz": []byte("gzipped css content"),
|
||||
"assets/script.js": []byte("console.log('test');"),
|
||||
"assets/script.js.gz": []byte("gzipped js content"),
|
||||
"subdir/page.html": []byte("<html><body>Page</body></html>"),
|
||||
"subdir/page.html.gz": []byte("gzipped html content"),
|
||||
}
|
||||
|
||||
for path, content := range files {
|
||||
fullPath := filepath.Join(tempDir, path)
|
||||
err := os.MkdirAll(filepath.Dir(fullPath), 0755)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(fullPath, content, 0644)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create static handler
|
||||
handler := handlers.NewStaticHandler(tempDir)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
acceptEncoding string
|
||||
wantStatus int
|
||||
wantBody []byte
|
||||
wantType string
|
||||
wantEncoding string
|
||||
wantCacheHeader string
|
||||
}{
|
||||
{
|
||||
name: "serve index.html",
|
||||
path: "/",
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: []byte("<html><body>Index</body></html>"),
|
||||
wantType: "text/html; charset=utf-8",
|
||||
},
|
||||
{
|
||||
name: "serve CSS with gzip support",
|
||||
path: "/assets/style.css",
|
||||
acceptEncoding: "gzip",
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: []byte("gzipped css content"),
|
||||
wantType: "text/css",
|
||||
wantEncoding: "gzip",
|
||||
wantCacheHeader: "public, max-age=31536000",
|
||||
},
|
||||
{
|
||||
name: "serve JS with gzip support",
|
||||
path: "/assets/script.js",
|
||||
acceptEncoding: "gzip",
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: []byte("gzipped js content"),
|
||||
wantType: "application/javascript",
|
||||
wantEncoding: "gzip",
|
||||
wantCacheHeader: "public, max-age=31536000",
|
||||
},
|
||||
{
|
||||
name: "serve CSS without gzip",
|
||||
path: "/assets/style.css",
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: []byte("body { color: blue; }"),
|
||||
wantType: "text/css; charset=utf-8",
|
||||
wantCacheHeader: "public, max-age=31536000",
|
||||
},
|
||||
{
|
||||
name: "SPA routing - nonexistent path",
|
||||
path: "/nonexistent",
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: []byte("<html><body>Index</body></html>"),
|
||||
wantType: "text/html; charset=utf-8",
|
||||
},
|
||||
{
|
||||
name: "SPA routing - deep path",
|
||||
path: "/some/deep/path",
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: []byte("<html><body>Index</body></html>"),
|
||||
wantType: "text/html; charset=utf-8",
|
||||
},
|
||||
{
|
||||
name: "block directory traversal",
|
||||
path: "/../../../etc/passwd",
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "nonexistent file in assets",
|
||||
path: "/assets/nonexistent.js",
|
||||
wantStatus: http.StatusOK, // Should serve index.html
|
||||
wantBody: []byte("<html><body>Index</body></html>"),
|
||||
wantType: "text/html; charset=utf-8",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", tc.path, nil)
|
||||
if tc.acceptEncoding != "" {
|
||||
req.Header.Set("Accept-Encoding", tc.acceptEncoding)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, tc.wantStatus, w.Code)
|
||||
|
||||
if tc.wantStatus == http.StatusOK {
|
||||
assert.Equal(t, tc.wantBody, w.Body.Bytes())
|
||||
assert.Equal(t, tc.wantType, w.Header().Get("Content-Type"))
|
||||
|
||||
if tc.wantEncoding != "" {
|
||||
assert.Equal(t, tc.wantEncoding, w.Header().Get("Content-Encoding"))
|
||||
}
|
||||
|
||||
if tc.wantCacheHeader != "" {
|
||||
assert.Equal(t, tc.wantCacheHeader, w.Header().Get("Cache-Control"))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,11 +4,12 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"novamd/internal/httpcontext"
|
||||
"novamd/internal/context"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// UpdateProfileRequest represents a user profile update request
|
||||
type UpdateProfileRequest struct {
|
||||
DisplayName string `json:"displayName"`
|
||||
Email string `json:"email"`
|
||||
@@ -16,31 +17,15 @@ type UpdateProfileRequest struct {
|
||||
NewPassword string `json:"newPassword"`
|
||||
}
|
||||
|
||||
// DeleteAccountRequest represents a user account deletion request
|
||||
type DeleteAccountRequest struct {
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
func (h *Handler) GetUser() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.DB.GetUserByID(ctx.UserID)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to get user", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
respondJSON(w, user)
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateProfile updates the current user's profile
|
||||
func (h *Handler) UpdateProfile() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -58,14 +43,6 @@ func (h *Handler) UpdateProfile() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// Start transaction for atomic updates
|
||||
tx, err := h.DB.Begin()
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to start transaction", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Handle password update if requested
|
||||
if req.NewPassword != "" {
|
||||
// Current password must be provided to change password
|
||||
@@ -131,11 +108,6 @@ func (h *Handler) UpdateProfile() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
http.Error(w, "Failed to commit changes", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Return updated user data
|
||||
respondJSON(w, user)
|
||||
}
|
||||
@@ -144,7 +116,7 @@ func (h *Handler) UpdateProfile() http.HandlerFunc {
|
||||
// DeleteAccount handles user account deletion
|
||||
func (h *Handler) DeleteAccount() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -171,8 +143,7 @@ func (h *Handler) DeleteAccount() http.HandlerFunc {
|
||||
// Prevent admin from deleting their own account if they're the last admin
|
||||
if user.Role == "admin" {
|
||||
// Count number of admin users
|
||||
adminCount := 0
|
||||
err := h.DB.QueryRow("SELECT COUNT(*) FROM users WHERE role = 'admin'").Scan(&adminCount)
|
||||
adminCount, err := h.DB.CountAdminUsers()
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to verify admin status", http.StatusInternalServerError)
|
||||
return
|
||||
@@ -183,14 +154,6 @@ func (h *Handler) DeleteAccount() http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// Start transaction for consistent deletion
|
||||
tx, err := h.DB.Begin()
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to start transaction", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Get user's workspaces for cleanup
|
||||
workspaces, err := h.DB.GetWorkspacesByUserID(ctx.UserID)
|
||||
if err != nil {
|
||||
@@ -200,7 +163,7 @@ func (h *Handler) DeleteAccount() http.HandlerFunc {
|
||||
|
||||
// Delete workspace directories
|
||||
for _, workspace := range workspaces {
|
||||
if err := h.FS.DeleteUserWorkspace(ctx.UserID, workspace.ID); err != nil {
|
||||
if err := h.Storage.DeleteUserWorkspace(ctx.UserID, workspace.ID); err != nil {
|
||||
http.Error(w, "Failed to delete workspace files", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -212,11 +175,6 @@ func (h *Handler) DeleteAccount() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
http.Error(w, "Failed to commit transaction", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
respondJSON(w, map[string]string{"message": "Account deleted successfully"})
|
||||
}
|
||||
}
|
||||
|
||||
219
server/internal/handlers/user_handlers_integration_test.go
Normal file
219
server/internal/handlers/user_handlers_integration_test.go
Normal file
@@ -0,0 +1,219 @@
|
||||
//go:build integration
|
||||
|
||||
package handlers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"novamd/internal/handlers"
|
||||
"novamd/internal/models"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUserHandlers_Integration(t *testing.T) {
|
||||
h := setupTestHarness(t)
|
||||
defer h.teardown(t)
|
||||
|
||||
currentEmail := h.RegularUser.Email
|
||||
currentPassword := "user123"
|
||||
|
||||
t.Run("get current user", func(t *testing.T) {
|
||||
t.Run("successful get", func(t *testing.T) {
|
||||
rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var user models.User
|
||||
err := json.NewDecoder(rr.Body).Decode(&user)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, h.RegularUser.ID, user.ID)
|
||||
assert.Equal(t, h.RegularUser.Email, user.Email)
|
||||
assert.Equal(t, h.RegularUser.DisplayName, user.DisplayName)
|
||||
assert.Equal(t, h.RegularUser.Role, user.Role)
|
||||
assert.Empty(t, user.PasswordHash, "Password hash should not be included in response")
|
||||
})
|
||||
|
||||
t.Run("unauthorized", func(t *testing.T) {
|
||||
rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, "", nil)
|
||||
assert.Equal(t, http.StatusUnauthorized, rr.Code)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("update profile", func(t *testing.T) {
|
||||
t.Run("update display name only", func(t *testing.T) {
|
||||
updateReq := handlers.UpdateProfileRequest{
|
||||
DisplayName: "Updated Name",
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var user models.User
|
||||
err := json.NewDecoder(rr.Body).Decode(&user)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, updateReq.DisplayName, user.DisplayName)
|
||||
})
|
||||
|
||||
t.Run("update email", func(t *testing.T) {
|
||||
updateReq := handlers.UpdateProfileRequest{
|
||||
Email: "newemail@test.com",
|
||||
CurrentPassword: currentPassword,
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var user models.User
|
||||
err := json.NewDecoder(rr.Body).Decode(&user)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, updateReq.Email, user.Email)
|
||||
|
||||
currentEmail = updateReq.Email
|
||||
})
|
||||
|
||||
t.Run("update email without password", func(t *testing.T) {
|
||||
updateReq := handlers.UpdateProfileRequest{
|
||||
Email: "anotheremail@test.com",
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("update email with wrong password", func(t *testing.T) {
|
||||
updateReq := handlers.UpdateProfileRequest{
|
||||
Email: "wrongpass@test.com",
|
||||
CurrentPassword: "wrongpassword",
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusUnauthorized, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("update password", func(t *testing.T) {
|
||||
updateReq := handlers.UpdateProfileRequest{
|
||||
CurrentPassword: currentPassword,
|
||||
NewPassword: "newpassword123",
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Verify can login with new password
|
||||
loginReq := handlers.LoginRequest{
|
||||
Email: currentEmail,
|
||||
Password: "newpassword123",
|
||||
}
|
||||
|
||||
rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil)
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
currentPassword = updateReq.NewPassword
|
||||
})
|
||||
|
||||
t.Run("update password without current password", func(t *testing.T) {
|
||||
updateReq := handlers.UpdateProfileRequest{
|
||||
NewPassword: "newpass123",
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("update password with wrong current password", func(t *testing.T) {
|
||||
updateReq := handlers.UpdateProfileRequest{
|
||||
CurrentPassword: "wrongpassword",
|
||||
NewPassword: "newpass123",
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusUnauthorized, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("update with short password", func(t *testing.T) {
|
||||
updateReq := handlers.UpdateProfileRequest{
|
||||
CurrentPassword: currentPassword,
|
||||
NewPassword: "short",
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("duplicate email", func(t *testing.T) {
|
||||
updateReq := handlers.UpdateProfileRequest{
|
||||
Email: h.AdminUser.Email,
|
||||
CurrentPassword: currentPassword,
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusConflict, rr.Code)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("delete account", func(t *testing.T) {
|
||||
// Create a new user that we can delete
|
||||
createReq := handlers.CreateUserRequest{
|
||||
Email: "todelete@test.com",
|
||||
DisplayName: "To Delete",
|
||||
Password: "password123",
|
||||
Role: models.RoleEditor,
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/admin/users", createReq, h.AdminToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var newUser models.User
|
||||
err := json.NewDecoder(rr.Body).Decode(&newUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get token for new user
|
||||
loginReq := handlers.LoginRequest{
|
||||
Email: createReq.Email,
|
||||
Password: createReq.Password,
|
||||
}
|
||||
|
||||
rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var loginResp handlers.LoginResponse
|
||||
err = json.NewDecoder(rr.Body).Decode(&loginResp)
|
||||
require.NoError(t, err)
|
||||
userToken := loginResp.AccessToken
|
||||
|
||||
t.Run("successful delete", func(t *testing.T) {
|
||||
deleteReq := handlers.DeleteAccountRequest{
|
||||
Password: createReq.Password,
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodDelete, "/api/v1/profile", deleteReq, userToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Verify user is deleted
|
||||
rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil)
|
||||
assert.Equal(t, http.StatusUnauthorized, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("delete with wrong password", func(t *testing.T) {
|
||||
deleteReq := handlers.DeleteAccountRequest{
|
||||
Password: "wrongpassword",
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodDelete, "/api/v1/profile", deleteReq, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusUnauthorized, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("prevent last admin deletion", func(t *testing.T) {
|
||||
deleteReq := handlers.DeleteAccountRequest{
|
||||
Password: "admin123", // Admin password from test harness
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodDelete, "/api/v1/profile", deleteReq, h.AdminToken, nil)
|
||||
assert.Equal(t, http.StatusForbidden, rr.Code)
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -1,17 +1,18 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"novamd/internal/httpcontext"
|
||||
"novamd/internal/context"
|
||||
"novamd/internal/models"
|
||||
)
|
||||
|
||||
// ListWorkspaces returns a list of all workspaces for the current user
|
||||
func (h *Handler) ListWorkspaces() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -26,9 +27,10 @@ func (h *Handler) ListWorkspaces() http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// CreateWorkspace creates a new workspace
|
||||
func (h *Handler) CreateWorkspace() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -39,24 +41,43 @@ func (h *Handler) CreateWorkspace() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
if err := workspace.ValidateGitSettings(); err != nil {
|
||||
http.Error(w, "Invalid workspace", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
workspace.UserID = ctx.UserID
|
||||
if err := h.DB.CreateWorkspace(&workspace); err != nil {
|
||||
http.Error(w, "Failed to create workspace", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.FS.InitializeUserWorkspace(workspace.UserID, workspace.ID); err != nil {
|
||||
if err := h.Storage.InitializeUserWorkspace(workspace.UserID, workspace.ID); err != nil {
|
||||
http.Error(w, "Failed to initialize workspace directory", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if workspace.GitEnabled {
|
||||
if err := h.Storage.SetupGitRepo(
|
||||
ctx.UserID,
|
||||
workspace.ID,
|
||||
workspace.GitURL,
|
||||
workspace.GitUser,
|
||||
workspace.GitToken,
|
||||
); err != nil {
|
||||
http.Error(w, "Failed to setup git repo: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
respondJSON(w, workspace)
|
||||
}
|
||||
}
|
||||
|
||||
// GetWorkspace returns the current workspace
|
||||
func (h *Handler) GetWorkspace() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -81,9 +102,10 @@ func gitSettingsChanged(new, old *models.Workspace) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdateWorkspace updates the current workspace
|
||||
func (h *Handler) UpdateWorkspace() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -107,7 +129,7 @@ func (h *Handler) UpdateWorkspace() http.HandlerFunc {
|
||||
// Handle Git repository setup/teardown if Git settings changed
|
||||
if gitSettingsChanged(&workspace, ctx.Workspace) {
|
||||
if workspace.GitEnabled {
|
||||
if err := h.FS.SetupGitRepo(
|
||||
if err := h.Storage.SetupGitRepo(
|
||||
ctx.UserID,
|
||||
ctx.Workspace.ID,
|
||||
workspace.GitURL,
|
||||
@@ -119,7 +141,7 @@ func (h *Handler) UpdateWorkspace() http.HandlerFunc {
|
||||
}
|
||||
|
||||
} else {
|
||||
h.FS.DisableGitRepo(ctx.UserID, ctx.Workspace.ID)
|
||||
h.Storage.DisableGitRepo(ctx.UserID, ctx.Workspace.ID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,9 +154,10 @@ func (h *Handler) UpdateWorkspace() http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteWorkspace deletes the current workspace
|
||||
func (h *Handler) DeleteWorkspace() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -168,7 +191,11 @@ func (h *Handler) DeleteWorkspace() http.HandlerFunc {
|
||||
http.Error(w, "Failed to start transaction", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
defer func() {
|
||||
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
|
||||
http.Error(w, "Failed to rollback transaction", http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
|
||||
// Update last workspace ID first
|
||||
err = h.DB.UpdateLastWorkspaceTx(tx, ctx.UserID, nextWorkspaceID)
|
||||
@@ -195,9 +222,10 @@ func (h *Handler) DeleteWorkspace() http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// GetLastWorkspaceName returns the name of the last opened workspace
|
||||
func (h *Handler) GetLastWorkspaceName() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -212,9 +240,10 @@ func (h *Handler) GetLastWorkspaceName() http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateLastWorkspaceName updates the name of the last opened workspace
|
||||
func (h *Handler) UpdateLastWorkspaceName() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := httpcontext.GetRequestContext(w, r)
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -224,13 +253,11 @@ func (h *Handler) UpdateLastWorkspaceName() http.HandlerFunc {
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
fmt.Println(err)
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.DB.UpdateLastWorkspace(ctx.UserID, requestBody.WorkspaceName); err != nil {
|
||||
fmt.Println(err)
|
||||
http.Error(w, "Failed to update last workspace", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
297
server/internal/handlers/workspace_handlers_integration_test.go
Normal file
297
server/internal/handlers/workspace_handlers_integration_test.go
Normal file
@@ -0,0 +1,297 @@
|
||||
//go:build integration
|
||||
|
||||
package handlers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"novamd/internal/models"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWorkspaceHandlers_Integration(t *testing.T) {
|
||||
h := setupTestHarness(t)
|
||||
defer h.teardown(t)
|
||||
|
||||
t.Run("list workspaces", func(t *testing.T) {
|
||||
t.Run("successful list", func(t *testing.T) {
|
||||
rr := h.makeRequest(t, http.MethodGet, "/api/v1/workspaces", nil, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var workspaces []*models.Workspace
|
||||
err := json.NewDecoder(rr.Body).Decode(&workspaces)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, workspaces, "User should have at least one default workspace")
|
||||
})
|
||||
|
||||
t.Run("unauthorized", func(t *testing.T) {
|
||||
rr := h.makeRequest(t, http.MethodGet, "/api/v1/workspaces", nil, "", nil)
|
||||
assert.Equal(t, http.StatusUnauthorized, rr.Code)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("create workspace", func(t *testing.T) {
|
||||
t.Run("successful create", func(t *testing.T) {
|
||||
workspace := &models.Workspace{
|
||||
Name: "Test Workspace",
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var created models.Workspace
|
||||
err := json.NewDecoder(rr.Body).Decode(&created)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, workspace.Name, created.Name)
|
||||
assert.Equal(t, h.RegularUser.ID, created.UserID)
|
||||
assert.NotZero(t, created.ID)
|
||||
})
|
||||
|
||||
t.Run("create with git settings", func(t *testing.T) {
|
||||
workspace := &models.Workspace{
|
||||
Name: "Git Workspace",
|
||||
GitEnabled: true,
|
||||
GitURL: "https://github.com/test/repo.git",
|
||||
GitUser: "testuser",
|
||||
GitToken: "testtoken",
|
||||
GitAutoCommit: true,
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var created models.Workspace
|
||||
err := json.NewDecoder(rr.Body).Decode(&created)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, workspace.GitEnabled, created.GitEnabled)
|
||||
assert.Equal(t, workspace.GitURL, created.GitURL)
|
||||
assert.Equal(t, workspace.GitUser, created.GitUser)
|
||||
assert.Equal(t, workspace.GitToken, created.GitToken)
|
||||
assert.Equal(t, workspace.GitAutoCommit, created.GitAutoCommit)
|
||||
})
|
||||
|
||||
t.Run("invalid workspace", func(t *testing.T) {
|
||||
workspace := &models.Workspace{
|
||||
Name: "", // Empty name
|
||||
GitEnabled: true,
|
||||
// Missing required Git settings
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
})
|
||||
})
|
||||
|
||||
// Create a workspace for the remaining tests
|
||||
workspace := &models.Workspace{
|
||||
Name: "Test Workspace Operations",
|
||||
}
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
err := json.NewDecoder(rr.Body).Decode(workspace)
|
||||
require.NoError(t, err)
|
||||
|
||||
escapedName := url.PathEscape(workspace.Name)
|
||||
baseURL := "/api/v1/workspaces/" + escapedName
|
||||
|
||||
t.Run("get workspace", func(t *testing.T) {
|
||||
t.Run("successful get", func(t *testing.T) {
|
||||
rr := h.makeRequest(t, http.MethodGet, baseURL, nil, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var got models.Workspace
|
||||
err := json.NewDecoder(rr.Body).Decode(&got)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, workspace.ID, got.ID)
|
||||
assert.Equal(t, workspace.Name, got.Name)
|
||||
})
|
||||
|
||||
t.Run("nonexistent workspace", func(t *testing.T) {
|
||||
rr := h.makeRequest(t, http.MethodGet, "/api/v1/workspaces/nonexistent", nil, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusNotFound, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("unauthorized access", func(t *testing.T) {
|
||||
// Try accessing with another user's token
|
||||
rr := h.makeRequest(t, http.MethodGet, baseURL, nil, h.AdminToken, nil)
|
||||
assert.Equal(t, http.StatusNotFound, rr.Code)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("update workspace", func(t *testing.T) {
|
||||
t.Run("update name", func(t *testing.T) {
|
||||
workspace.Name = "Updated Workspace"
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPut, baseURL, workspace, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var updated models.Workspace
|
||||
err := json.NewDecoder(rr.Body).Decode(&updated)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, workspace.Name, updated.Name)
|
||||
|
||||
// Update baseURL for remaining tests
|
||||
escapedName = url.PathEscape(workspace.Name)
|
||||
baseURL = "/api/v1/workspaces/" + escapedName
|
||||
})
|
||||
|
||||
t.Run("update settings", func(t *testing.T) {
|
||||
update := &models.Workspace{
|
||||
Name: workspace.Name,
|
||||
Theme: "dark",
|
||||
AutoSave: true,
|
||||
ShowHiddenFiles: true,
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPut, baseURL, update, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var updated models.Workspace
|
||||
err := json.NewDecoder(rr.Body).Decode(&updated)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, update.Theme, updated.Theme)
|
||||
assert.Equal(t, update.AutoSave, updated.AutoSave)
|
||||
assert.Equal(t, update.ShowHiddenFiles, updated.ShowHiddenFiles)
|
||||
})
|
||||
|
||||
t.Run("enable git", func(t *testing.T) {
|
||||
update := &models.Workspace{
|
||||
Name: workspace.Name,
|
||||
Theme: "dark",
|
||||
GitEnabled: true,
|
||||
GitURL: "https://github.com/test/repo.git",
|
||||
GitUser: "testuser",
|
||||
GitToken: "testtoken",
|
||||
GitAutoCommit: true,
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPut, baseURL, update, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var updated models.Workspace
|
||||
err := json.NewDecoder(rr.Body).Decode(&updated)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, update.GitEnabled, updated.GitEnabled)
|
||||
assert.Equal(t, update.GitURL, updated.GitURL)
|
||||
assert.Equal(t, update.GitUser, updated.GitUser)
|
||||
assert.Equal(t, update.GitToken, updated.GitToken)
|
||||
|
||||
// Mock should have been called to setup git
|
||||
assert.True(t, h.MockGit.IsInitialized())
|
||||
})
|
||||
|
||||
t.Run("invalid git settings", func(t *testing.T) {
|
||||
update := &models.Workspace{
|
||||
Name: workspace.Name,
|
||||
GitEnabled: true,
|
||||
// Missing required Git settings
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPut, baseURL, update, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("last workspace", func(t *testing.T) {
|
||||
t.Run("get last workspace", func(t *testing.T) {
|
||||
rr := h.makeRequest(t, http.MethodGet, "/api/v1/workspaces/last", nil, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var response struct {
|
||||
LastWorkspaceName string `json:"lastWorkspaceName"`
|
||||
}
|
||||
err := json.NewDecoder(rr.Body).Decode(&response)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, response.LastWorkspaceName)
|
||||
})
|
||||
|
||||
t.Run("update last workspace", func(t *testing.T) {
|
||||
req := struct {
|
||||
WorkspaceName string `json:"workspaceName"`
|
||||
}{
|
||||
WorkspaceName: workspace.Name,
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPut, "/api/v1/workspaces/last", req, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Verify the update
|
||||
rr = h.makeRequest(t, http.MethodGet, "/api/v1/workspaces/last", nil, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var response struct {
|
||||
LastWorkspaceName string `json:"lastWorkspaceName"`
|
||||
}
|
||||
err := json.NewDecoder(rr.Body).Decode(&response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, workspace.Name, response.LastWorkspaceName)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("delete workspace", func(t *testing.T) {
|
||||
// Get current workspaces to know how many we have
|
||||
rr := h.makeRequest(t, http.MethodGet, "/api/v1/workspaces", nil, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var existingWorkspaces []*models.Workspace
|
||||
err := json.NewDecoder(rr.Body).Decode(&existingWorkspaces)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a new workspace we can safely delete
|
||||
newWorkspace := &models.Workspace{
|
||||
Name: "Workspace To Delete",
|
||||
}
|
||||
rr = h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", newWorkspace, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
err = json.NewDecoder(rr.Body).Decode(newWorkspace)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("successful delete", func(t *testing.T) {
|
||||
rr := h.makeRequest(t, http.MethodDelete, "/api/v1/workspaces/"+url.PathEscape(newWorkspace.Name), nil, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var response struct {
|
||||
NextWorkspaceName string `json:"nextWorkspaceName"`
|
||||
}
|
||||
err := json.NewDecoder(rr.Body).Decode(&response)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, response.NextWorkspaceName)
|
||||
|
||||
// Verify workspace is deleted
|
||||
rr = h.makeRequest(t, http.MethodGet, "/api/v1/workspaces/"+url.PathEscape(newWorkspace.Name), nil, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusNotFound, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("prevent deleting last workspace", func(t *testing.T) {
|
||||
// Delete all but one workspace
|
||||
for i := 0; i < len(existingWorkspaces)-1; i++ {
|
||||
ws := existingWorkspaces[i]
|
||||
rr := h.makeRequest(t, http.MethodDelete, "/api/v1/workspaces/"+url.PathEscape(ws.Name), nil, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
}
|
||||
|
||||
// Try to delete the last remaining workspace
|
||||
lastWs := existingWorkspaces[len(existingWorkspaces)-1]
|
||||
rr := h.makeRequest(t, http.MethodDelete, "/api/v1/workspaces/"+url.PathEscape(lastWs.Name), nil, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("unauthorized deletion", func(t *testing.T) {
|
||||
// Create a workspace to attempt unauthorized deletion
|
||||
workspace := &models.Workspace{
|
||||
Name: "Unauthorized Delete Test",
|
||||
}
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Try to delete with wrong user's token
|
||||
rr = h.makeRequest(t, http.MethodDelete, "/api/v1/workspaces/"+url.PathEscape(workspace.Name), nil, h.AdminToken, nil)
|
||||
assert.Equal(t, http.StatusNotFound, rr.Code)
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
package httpcontext
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"novamd/internal/models"
|
||||
)
|
||||
|
||||
// HandlerContext holds the request-specific data available to all handlers
|
||||
type HandlerContext struct {
|
||||
UserID int
|
||||
UserRole string
|
||||
Workspace *models.Workspace
|
||||
}
|
||||
|
||||
type contextKey string
|
||||
|
||||
const HandlerContextKey contextKey = "handlerContext"
|
||||
|
||||
func GetRequestContext(w http.ResponseWriter, r *http.Request) (*HandlerContext, bool) {
|
||||
ctx := r.Context().Value(HandlerContextKey)
|
||||
if ctx == nil {
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return nil, false
|
||||
}
|
||||
return ctx.(*HandlerContext), true
|
||||
}
|
||||
|
||||
func WithHandlerContext(r *http.Request, hctx *HandlerContext) *http.Request {
|
||||
return r.WithContext(context.WithValue(r.Context(), HandlerContextKey, hctx))
|
||||
}
|
||||
13
server/internal/models/session.go
Normal file
13
server/internal/models/session.go
Normal file
@@ -0,0 +1,13 @@
|
||||
// Package models contains the data models used throughout the application. These models are used to represent data in the database, as well as to validate and serialize data in the application.
|
||||
package models
|
||||
|
||||
import "time"
|
||||
|
||||
// Session represents a user session in the database
|
||||
type Session struct {
|
||||
ID string // Unique session identifier
|
||||
UserID int // ID of the user this session belongs to
|
||||
RefreshToken string // The refresh token associated with this session
|
||||
ExpiresAt time.Time // When this session expires
|
||||
CreatedAt time.Time // When this session was created
|
||||
}
|
||||
@@ -8,14 +8,17 @@ import (
|
||||
|
||||
var validate = validator.New()
|
||||
|
||||
// UserRole represents the role of a user in the system
|
||||
type UserRole string
|
||||
|
||||
// User roles
|
||||
const (
|
||||
RoleAdmin UserRole = "admin"
|
||||
RoleEditor UserRole = "editor"
|
||||
RoleViewer UserRole = "viewer"
|
||||
)
|
||||
|
||||
// User represents a user in the system
|
||||
type User struct {
|
||||
ID int `json:"id" validate:"required,min=1"`
|
||||
Email string `json:"email" validate:"required,email"`
|
||||
@@ -26,6 +29,7 @@ type User struct {
|
||||
LastWorkspaceID int `json:"lastWorkspaceId"`
|
||||
}
|
||||
|
||||
// Validate validates the user struct
|
||||
func (u *User) Validate() error {
|
||||
return validate.Struct(u)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Workspace represents a user's workspace in the system
|
||||
type Workspace struct {
|
||||
ID int `json:"id" validate:"required,min=1"`
|
||||
UserID int `json:"userId" validate:"required,min=1"`
|
||||
@@ -23,18 +24,30 @@ type Workspace struct {
|
||||
GitCommitMsgTemplate string `json:"gitCommitMsgTemplate"`
|
||||
}
|
||||
|
||||
// Validate validates the workspace struct
|
||||
func (w *Workspace) Validate() error {
|
||||
return validate.Struct(w)
|
||||
}
|
||||
|
||||
func (w *Workspace) GetDefaultSettings() {
|
||||
w.Theme = "light"
|
||||
w.AutoSave = false
|
||||
w.ShowHiddenFiles = false
|
||||
w.GitEnabled = false
|
||||
w.GitURL = ""
|
||||
w.GitUser = ""
|
||||
w.GitToken = ""
|
||||
w.GitAutoCommit = false
|
||||
w.GitCommitMsgTemplate = "${action} ${filename}"
|
||||
// ValidateGitSettings validates the git settings if git is enabled
|
||||
func (w *Workspace) ValidateGitSettings() error {
|
||||
return validate.StructExcept(w, "ID", "UserID", "Theme")
|
||||
}
|
||||
|
||||
// SetDefaultSettings sets the default settings for the workspace
|
||||
func (w *Workspace) SetDefaultSettings() {
|
||||
|
||||
if w.Theme == "" {
|
||||
w.Theme = "light"
|
||||
}
|
||||
|
||||
w.AutoSave = w.AutoSave || false
|
||||
w.ShowHiddenFiles = w.ShowHiddenFiles || false
|
||||
w.GitEnabled = w.GitEnabled || false
|
||||
|
||||
w.GitAutoCommit = w.GitEnabled && (w.GitAutoCommit || false)
|
||||
|
||||
if w.GitCommitMsgTemplate == "" {
|
||||
w.GitCommitMsgTemplate = "${action} ${filename}"
|
||||
}
|
||||
}
|
||||
|
||||
112
server/internal/secrets/secrets.go
Normal file
112
server/internal/secrets/secrets.go
Normal file
@@ -0,0 +1,112 @@
|
||||
// Package secrets provides an Encryptor interface for encrypting and decrypting strings using AES-256-GCM.
|
||||
package secrets
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Service is an interface for encrypting and decrypting strings
|
||||
type Service interface {
|
||||
Encrypt(plaintext string) (string, error)
|
||||
Decrypt(ciphertext string) (string, error)
|
||||
}
|
||||
|
||||
type encryptor struct {
|
||||
gcm cipher.AEAD
|
||||
}
|
||||
|
||||
// ValidateKey checks if the provided base64-encoded key is suitable for AES-256
|
||||
func ValidateKey(key string) error {
|
||||
_, err := decodeAndValidateKey(key)
|
||||
return err
|
||||
}
|
||||
|
||||
// decodeAndValidateKey validates and decodes the base64-encoded key
|
||||
// Returns the decoded key bytes if valid
|
||||
func decodeAndValidateKey(key string) ([]byte, error) {
|
||||
if key == "" {
|
||||
return nil, fmt.Errorf("encryption key is required")
|
||||
}
|
||||
|
||||
keyBytes, err := base64.StdEncoding.DecodeString(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid base64 encoding: %w", err)
|
||||
}
|
||||
|
||||
if len(keyBytes) != 32 {
|
||||
return nil, fmt.Errorf("encryption key must be 32 bytes (256 bits): got %d bytes", len(keyBytes))
|
||||
}
|
||||
|
||||
// Verify the key can be used for AES
|
||||
_, err = aes.NewCipher(keyBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid encryption key: %w", err)
|
||||
}
|
||||
|
||||
return keyBytes, nil
|
||||
}
|
||||
|
||||
// NewService creates a new Encryptor instance with the provided base64-encoded key
|
||||
func NewService(key string) (Service, error) {
|
||||
keyBytes, err := decodeAndValidateKey(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(keyBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GCM: %w", err)
|
||||
}
|
||||
|
||||
return &encryptor{gcm: gcm}, nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts the plaintext using AES-256-GCM
|
||||
func (e *encryptor) Encrypt(plaintext string) (string, error) {
|
||||
if plaintext == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
nonce := make([]byte, e.gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ciphertext := e.gcm.Seal(nonce, nonce, []byte(plaintext), nil)
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts the ciphertext using AES-256-GCM
|
||||
func (e *encryptor) Decrypt(ciphertext string) (string, error) {
|
||||
if ciphertext == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
data, err := base64.StdEncoding.DecodeString(ciphertext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid base64 encoding: %w", err)
|
||||
}
|
||||
|
||||
nonceSize := e.gcm.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return "", fmt.Errorf("invalid ciphertext: too short")
|
||||
}
|
||||
|
||||
nonce, ciphertextBytes := data[:nonceSize], data[nonceSize:]
|
||||
plaintext, err := e.gcm.Open(nil, nonce, ciphertextBytes, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
257
server/internal/secrets/secrets_test.go
Normal file
257
server/internal/secrets/secrets_test.go
Normal file
@@ -0,0 +1,257 @@
|
||||
package secrets_test
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"novamd/internal/secrets"
|
||||
)
|
||||
|
||||
func TestValidateKey(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
key string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "valid 32-byte base64 key",
|
||||
key: base64.StdEncoding.EncodeToString(make([]byte, 32)),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty key",
|
||||
key: "",
|
||||
wantErr: true,
|
||||
errContains: "encryption key is required",
|
||||
},
|
||||
{
|
||||
name: "invalid base64",
|
||||
key: "not-base64!@#$",
|
||||
wantErr: true,
|
||||
errContains: "invalid base64 encoding",
|
||||
},
|
||||
{
|
||||
name: "wrong key size (16 bytes)",
|
||||
key: base64.StdEncoding.EncodeToString(make([]byte, 16)),
|
||||
wantErr: true,
|
||||
errContains: "encryption key must be 32 bytes",
|
||||
},
|
||||
{
|
||||
name: "wrong key size (64 bytes)",
|
||||
key: base64.StdEncoding.EncodeToString(make([]byte, 64)),
|
||||
wantErr: true,
|
||||
errContains: "encryption key must be 32 bytes",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := secrets.ValidateKey(tc.key)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tc.errContains) {
|
||||
t.Errorf("error = %v, want error containing %q", err, tc.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
key string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "valid key",
|
||||
key: base64.StdEncoding.EncodeToString(make([]byte, 32)),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty key",
|
||||
key: "",
|
||||
wantErr: true,
|
||||
errContains: "encryption key is required",
|
||||
},
|
||||
{
|
||||
name: "invalid key",
|
||||
key: "invalid",
|
||||
wantErr: true,
|
||||
errContains: "invalid base64 encoding",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e, err := secrets.NewService(tc.key)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tc.errContains) {
|
||||
t.Errorf("error = %v, want error containing %q", err, tc.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if e == nil {
|
||||
t.Error("expected Encryptor instance, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
// Generate a valid key for testing
|
||||
key := base64.StdEncoding.EncodeToString(make([]byte, 32))
|
||||
e, err := secrets.NewService(key)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create Encryptor instance: %v", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
plaintext string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "normal text",
|
||||
plaintext: "Hello, World!",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
plaintext: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "long text",
|
||||
plaintext: strings.Repeat("Long text with lots of content. ", 100),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "special characters",
|
||||
plaintext: "!@#$%^&*()_+-=[]{}|;:,.<>?",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "unicode characters",
|
||||
plaintext: "Hello, 世界! नमस्ते",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Test encryption
|
||||
ciphertext, err := e.Encrypt(tc.plaintext)
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected encryption error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected encryption error: %v", err)
|
||||
}
|
||||
|
||||
// Verify ciphertext is different from plaintext
|
||||
if tc.plaintext != "" && ciphertext == tc.plaintext {
|
||||
t.Error("ciphertext matches plaintext")
|
||||
}
|
||||
|
||||
// Test decryption
|
||||
decrypted, err := e.Decrypt(ciphertext)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected decryption error: %v", err)
|
||||
}
|
||||
|
||||
// Verify decrypted text matches original
|
||||
if decrypted != tc.plaintext {
|
||||
t.Errorf("decrypted text = %q, want %q", decrypted, tc.plaintext)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptInvalidCiphertext(t *testing.T) {
|
||||
key := base64.StdEncoding.EncodeToString(make([]byte, 32))
|
||||
e, err := secrets.NewService(key)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create Encryptor instance: %v", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
ciphertext string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "empty ciphertext",
|
||||
ciphertext: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid base64",
|
||||
ciphertext: "not-base64!@#$",
|
||||
wantErr: true,
|
||||
errContains: "invalid base64 encoding",
|
||||
},
|
||||
{
|
||||
name: "invalid ciphertext (too short)",
|
||||
ciphertext: base64.StdEncoding.EncodeToString(make([]byte, 10)),
|
||||
wantErr: true,
|
||||
errContains: "invalid ciphertext: too short",
|
||||
},
|
||||
{
|
||||
name: "tampered ciphertext",
|
||||
ciphertext: base64.StdEncoding.EncodeToString(make([]byte, 50)),
|
||||
wantErr: true,
|
||||
errContains: "message authentication failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
decrypted, err := e.Decrypt(tc.ciphertext)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tc.errContains) {
|
||||
t.Errorf("error = %v, want error containing %q", err, tc.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if decrypted != "" {
|
||||
t.Errorf("expected empty string, got %q", decrypted)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
24
server/internal/storage/errors.go
Normal file
24
server/internal/storage/errors.go
Normal file
@@ -0,0 +1,24 @@
|
||||
// storage/errors.go
|
||||
|
||||
package storage
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// PathValidationError represents a path validation error (e.g., path traversal attempt)
|
||||
type PathValidationError struct {
|
||||
Path string
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *PathValidationError) Error() string {
|
||||
return fmt.Sprintf("%s: %s", e.Message, e.Path)
|
||||
}
|
||||
|
||||
// IsPathValidationError checks if the error is a PathValidationError
|
||||
func IsPathValidationError(err error) bool {
|
||||
var pathErr *PathValidationError
|
||||
return err != nil && errors.As(err, &pathErr)
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
// Package filesystem provides functionalities to interact with the file system,
|
||||
// Package storage provides functionalities to interact with the file system,
|
||||
// including listing files, finding files by name, getting file content, saving files, and deleting files.
|
||||
package filesystem
|
||||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -10,7 +10,18 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// FileNode represents a file or directory in the file system.
|
||||
// FileManager provides functionalities to interact with files in the storage.
|
||||
type FileManager interface {
|
||||
ListFilesRecursively(userID, workspaceID int) ([]FileNode, error)
|
||||
FindFileByName(userID, workspaceID int, filename string) ([]string, error)
|
||||
GetFileContent(userID, workspaceID int, filePath string) ([]byte, error)
|
||||
SaveFile(userID, workspaceID int, filePath string, content []byte) error
|
||||
DeleteFile(userID, workspaceID int, filePath string) error
|
||||
GetFileStats(userID, workspaceID int) (*FileCountStats, error)
|
||||
GetTotalFileStats() (*FileCountStats, error)
|
||||
}
|
||||
|
||||
// FileNode represents a file or directory in the storage.
|
||||
type FileNode struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
@@ -19,13 +30,15 @@ type FileNode struct {
|
||||
}
|
||||
|
||||
// ListFilesRecursively returns a list of all files in the workspace directory and its subdirectories.
|
||||
func (fs *FileSystem) ListFilesRecursively(userID, workspaceID int) ([]FileNode, error) {
|
||||
workspacePath := fs.GetWorkspacePath(userID, workspaceID)
|
||||
return fs.walkDirectory(workspacePath, "")
|
||||
// Workspace is identified by the given userID and workspaceID.
|
||||
func (s *Service) ListFilesRecursively(userID, workspaceID int) ([]FileNode, error) {
|
||||
workspacePath := s.GetWorkspacePath(userID, workspaceID)
|
||||
return s.walkDirectory(workspacePath, "")
|
||||
}
|
||||
|
||||
func (fs *FileSystem) walkDirectory(dir, prefix string) ([]FileNode, error) {
|
||||
entries, err := os.ReadDir(dir)
|
||||
// walkDirectory recursively walks the directory and returns a list of files and directories.
|
||||
func (s *Service) walkDirectory(dir, prefix string) ([]FileNode, error) {
|
||||
entries, err := s.fs.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -57,7 +70,7 @@ func (fs *FileSystem) walkDirectory(dir, prefix string) ([]FileNode, error) {
|
||||
path := filepath.Join(prefix, name)
|
||||
fullPath := filepath.Join(dir, name)
|
||||
|
||||
children, err := fs.walkDirectory(fullPath, path)
|
||||
children, err := s.walkDirectory(fullPath, path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -88,9 +101,11 @@ func (fs *FileSystem) walkDirectory(dir, prefix string) ([]FileNode, error) {
|
||||
}
|
||||
|
||||
// FindFileByName returns a list of file paths that match the given filename.
|
||||
func (fs *FileSystem) FindFileByName(userID, workspaceID int, filename string) ([]string, error) {
|
||||
// Files are searched recursively in the workspace directory and its subdirectories.
|
||||
// Workspace is identified by the given userID and workspaceID.
|
||||
func (s *Service) FindFileByName(userID, workspaceID int, filename string) ([]string, error) {
|
||||
var foundPaths []string
|
||||
workspacePath := fs.GetWorkspacePath(userID, workspaceID)
|
||||
workspacePath := s.GetWorkspacePath(userID, workspaceID)
|
||||
|
||||
err := filepath.Walk(workspacePath, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
@@ -119,37 +134,40 @@ func (fs *FileSystem) FindFileByName(userID, workspaceID int, filename string) (
|
||||
return foundPaths, nil
|
||||
}
|
||||
|
||||
// GetFileContent returns the content of the file at the given path.
|
||||
func (fs *FileSystem) GetFileContent(userID, workspaceID int, filePath string) ([]byte, error) {
|
||||
fullPath, err := fs.ValidatePath(userID, workspaceID, filePath)
|
||||
// GetFileContent returns the content of the file at the given filePath.
|
||||
// Path must be a relative path within the workspace directory given by userID and workspaceID.
|
||||
func (s *Service) GetFileContent(userID, workspaceID int, filePath string) ([]byte, error) {
|
||||
fullPath, err := s.ValidatePath(userID, workspaceID, filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return os.ReadFile(fullPath)
|
||||
return s.fs.ReadFile(fullPath)
|
||||
}
|
||||
|
||||
// SaveFile writes the content to the file at the given path.
|
||||
func (fs *FileSystem) SaveFile(userID, workspaceID int, filePath string, content []byte) error {
|
||||
fullPath, err := fs.ValidatePath(userID, workspaceID, filePath)
|
||||
// SaveFile writes the content to the file at the given filePath.
|
||||
// Path must be a relative path within the workspace directory given by userID and workspaceID.
|
||||
func (s *Service) SaveFile(userID, workspaceID int, filePath string, content []byte) error {
|
||||
fullPath, err := s.ValidatePath(userID, workspaceID, filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dir := filepath.Dir(fullPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
if err := s.fs.MkdirAll(dir, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(fullPath, content, 0644)
|
||||
return s.fs.WriteFile(fullPath, content, 0644)
|
||||
}
|
||||
|
||||
// DeleteFile deletes the file at the given path.
|
||||
func (fs *FileSystem) DeleteFile(userID, workspaceID int, filePath string) error {
|
||||
fullPath, err := fs.ValidatePath(userID, workspaceID, filePath)
|
||||
// DeleteFile deletes the file at the given filePath.
|
||||
// Path must be a relative path within the workspace directory given by userID and workspaceID.
|
||||
func (s *Service) DeleteFile(userID, workspaceID int, filePath string) error {
|
||||
fullPath, err := s.ValidatePath(userID, workspaceID, filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.Remove(fullPath)
|
||||
return s.fs.Remove(fullPath)
|
||||
}
|
||||
|
||||
// FileCountStats holds statistics about files in a workspace
|
||||
@@ -159,25 +177,26 @@ type FileCountStats struct {
|
||||
}
|
||||
|
||||
// GetFileStats returns the total number of files and related statistics in a workspace
|
||||
// Parameters:
|
||||
// - userID: the ID of the user who owns the workspace
|
||||
// - workspaceID: the ID of the workspace to count files in
|
||||
// Returns:
|
||||
// - result: statistics about the files in the workspace
|
||||
// - error: any error that occurred during counting
|
||||
func (fs *FileSystem) GetFileStats(userID, workspaceID int) (*FileCountStats, error) {
|
||||
workspacePath := fs.GetWorkspacePath(userID, workspaceID)
|
||||
// Workspace is identified by the given userID and workspaceID
|
||||
func (s *Service) GetFileStats(userID, workspaceID int) (*FileCountStats, error) {
|
||||
workspacePath := s.GetWorkspacePath(userID, workspaceID)
|
||||
|
||||
// Check if workspace exists
|
||||
if _, err := os.Stat(workspacePath); os.IsNotExist(err) {
|
||||
if _, err := s.fs.Stat(workspacePath); s.fs.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("workspace directory does not exist")
|
||||
}
|
||||
|
||||
return fs.countFilesInPath(workspacePath)
|
||||
return s.countFilesInPath(workspacePath)
|
||||
|
||||
}
|
||||
|
||||
func (fs *FileSystem) countFilesInPath(directoryPath string) (*FileCountStats, error) {
|
||||
// GetTotalFileStats returns the total file statistics for the storage.
|
||||
func (s *Service) GetTotalFileStats() (*FileCountStats, error) {
|
||||
return s.countFilesInPath(s.RootDir)
|
||||
}
|
||||
|
||||
// countFilesInPath counts the total number of files and the total size of files in the given directory.
|
||||
func (s *Service) countFilesInPath(directoryPath string) (*FileCountStats, error) {
|
||||
result := &FileCountStats{}
|
||||
|
||||
err := filepath.WalkDir(directoryPath, func(path string, d os.DirEntry, err error) error {
|
||||
407
server/internal/storage/files_test.go
Normal file
407
server/internal/storage/files_test.go
Normal file
@@ -0,0 +1,407 @@
|
||||
package storage_test
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"novamd/internal/storage"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestFileNode ensures FileNode structs are created correctly
|
||||
func TestFileNode(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string // name of the test case
|
||||
node storage.FileNode
|
||||
want storage.FileNode
|
||||
}{
|
||||
{
|
||||
name: "file without children",
|
||||
node: storage.FileNode{
|
||||
ID: "test.md",
|
||||
Name: "test.md",
|
||||
Path: "test.md",
|
||||
},
|
||||
want: storage.FileNode{
|
||||
ID: "test.md",
|
||||
Name: "test.md",
|
||||
Path: "test.md",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "directory with children",
|
||||
node: storage.FileNode{
|
||||
ID: "dir",
|
||||
Name: "dir",
|
||||
Path: "dir",
|
||||
Children: []storage.FileNode{
|
||||
{
|
||||
ID: "dir/file1.md",
|
||||
Name: "file1.md",
|
||||
Path: "dir/file1.md",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: storage.FileNode{
|
||||
ID: "dir",
|
||||
Name: "dir",
|
||||
Path: "dir",
|
||||
Children: []storage.FileNode{
|
||||
{
|
||||
ID: "dir/file1.md",
|
||||
Name: "file1.md",
|
||||
Path: "dir/file1.md",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := tc.node // Now we're testing the actual node structure
|
||||
|
||||
if got.ID != tc.want.ID {
|
||||
t.Errorf("ID = %v, want %v", got.ID, tc.want.ID)
|
||||
}
|
||||
if got.Name != tc.want.Name {
|
||||
t.Errorf("Name = %v, want %v", got.Name, tc.want.Name)
|
||||
}
|
||||
if got.Path != tc.want.Path {
|
||||
t.Errorf("Path = %v, want %v", got.Path, tc.want.Path)
|
||||
}
|
||||
if len(got.Children) != len(tc.want.Children) {
|
||||
t.Errorf("len(Children) = %v, want %v", len(got.Children), len(tc.want.Children))
|
||||
}
|
||||
// Add deep comparison of children if they exist
|
||||
if len(got.Children) > 0 {
|
||||
for i := range got.Children {
|
||||
if got.Children[i].ID != tc.want.Children[i].ID {
|
||||
t.Errorf("Children[%d].ID = %v, want %v", i, got.Children[i].ID, tc.want.Children[i].ID)
|
||||
}
|
||||
if got.Children[i].Name != tc.want.Children[i].Name {
|
||||
t.Errorf("Children[%d].Name = %v, want %v", i, got.Children[i].Name, tc.want.Children[i].Name)
|
||||
}
|
||||
if got.Children[i].Path != tc.want.Children[i].Path {
|
||||
t.Errorf("Children[%d].Path = %v, want %v", i, got.Children[i].Path, tc.want.Children[i].Path)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListFilesRecursively(t *testing.T) {
|
||||
mockFS := NewMockFS()
|
||||
s := storage.NewServiceWithOptions("test-root", storage.Options{
|
||||
Fs: mockFS,
|
||||
NewGitClient: nil,
|
||||
})
|
||||
|
||||
t.Run("empty directory", func(t *testing.T) {
|
||||
mockFS.ReadDirReturns = map[string]struct {
|
||||
entries []fs.DirEntry
|
||||
err error
|
||||
}{
|
||||
"test-root/1/1": {
|
||||
entries: []fs.DirEntry{},
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
files, err := s.ListFilesRecursively(1, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(files) != 0 {
|
||||
t.Errorf("expected empty file list, got %v", files)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("directory with files", func(t *testing.T) {
|
||||
mockFS.ReadDirReturns = map[string]struct {
|
||||
entries []fs.DirEntry
|
||||
err error
|
||||
}{
|
||||
"test-root/1/1": {
|
||||
entries: []fs.DirEntry{
|
||||
NewMockDirEntry("file1.md", false),
|
||||
NewMockDirEntry("file2.md", false),
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
files, err := s.ListFilesRecursively(1, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(files) != 2 {
|
||||
t.Errorf("expected 2 files, got %d", len(files))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nested directories", func(t *testing.T) {
|
||||
mockFS.ReadDirReturns = map[string]struct {
|
||||
entries []fs.DirEntry
|
||||
err error
|
||||
}{
|
||||
"test-root/1/1": {
|
||||
entries: []fs.DirEntry{
|
||||
NewMockDirEntry("dir1", true),
|
||||
NewMockDirEntry("file1.md", false),
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
"test-root/1/1/dir1": {
|
||||
entries: []fs.DirEntry{
|
||||
NewMockDirEntry("file2.md", false),
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
files, err := s.ListFilesRecursively(1, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(files) != 2 { // dir1 and file1.md
|
||||
t.Errorf("expected 2 entries at root, got %d", len(files))
|
||||
}
|
||||
|
||||
// Find directory and check its children
|
||||
var dirFound bool
|
||||
for _, f := range files {
|
||||
if f.Name == "dir1" {
|
||||
dirFound = true
|
||||
if len(f.Children) != 1 {
|
||||
t.Errorf("expected 1 child in dir1, got %d", len(f.Children))
|
||||
}
|
||||
}
|
||||
}
|
||||
if !dirFound {
|
||||
t.Error("directory 'dir1' not found in results")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetFileContent(t *testing.T) {
|
||||
mockFS := NewMockFS()
|
||||
s := storage.NewServiceWithOptions("test-root", storage.Options{
|
||||
Fs: mockFS,
|
||||
NewGitClient: nil,
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
userID int
|
||||
workspaceID int
|
||||
filePath string
|
||||
mockData []byte
|
||||
mockErr error
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "successful read",
|
||||
userID: 1,
|
||||
workspaceID: 1,
|
||||
filePath: "test.md",
|
||||
mockData: []byte("test content"),
|
||||
mockErr: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "file not found",
|
||||
userID: 1,
|
||||
workspaceID: 1,
|
||||
filePath: "nonexistent.md",
|
||||
mockData: nil,
|
||||
mockErr: fs.ErrNotExist,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid path",
|
||||
userID: 1,
|
||||
workspaceID: 1,
|
||||
filePath: "../../../etc/passwd",
|
||||
mockData: nil,
|
||||
mockErr: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
expectedPath := filepath.Join("test-root", "1", "1", tc.filePath)
|
||||
mockFS.ReadFileReturns[expectedPath] = struct {
|
||||
data []byte
|
||||
err error
|
||||
}{tc.mockData, tc.mockErr}
|
||||
|
||||
content, err := s.GetFileContent(tc.userID, tc.workspaceID, tc.filePath)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if string(content) != string(tc.mockData) {
|
||||
t.Errorf("content = %q, want %q", content, tc.mockData)
|
||||
}
|
||||
|
||||
if mockFS.ReadCalls[expectedPath] != 1 {
|
||||
t.Errorf("expected 1 read call for %s, got %d", expectedPath, mockFS.ReadCalls[expectedPath])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveFile(t *testing.T) {
|
||||
mockFS := NewMockFS()
|
||||
s := storage.NewServiceWithOptions("test-root", storage.Options{
|
||||
Fs: mockFS,
|
||||
NewGitClient: nil,
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
userID int
|
||||
workspaceID int
|
||||
filePath string
|
||||
content []byte
|
||||
mockErr error
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "successful save",
|
||||
userID: 1,
|
||||
workspaceID: 1,
|
||||
filePath: "test.md",
|
||||
content: []byte("test content"),
|
||||
mockErr: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid path",
|
||||
userID: 1,
|
||||
workspaceID: 1,
|
||||
filePath: "../../../etc/passwd",
|
||||
content: []byte("test content"),
|
||||
mockErr: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "write error",
|
||||
userID: 1,
|
||||
workspaceID: 1,
|
||||
filePath: "test.md",
|
||||
content: []byte("test content"),
|
||||
mockErr: fs.ErrPermission,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mockFS.WriteFileError = tc.mockErr
|
||||
err := s.SaveFile(tc.userID, tc.workspaceID, tc.filePath, tc.content)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
expectedPath := filepath.Join("test-root", "1", "1", tc.filePath)
|
||||
if content, ok := mockFS.WriteCalls[expectedPath]; ok {
|
||||
if string(content) != string(tc.content) {
|
||||
t.Errorf("written content = %q, want %q", content, tc.content)
|
||||
}
|
||||
} else {
|
||||
t.Error("expected write call not made")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteFile(t *testing.T) {
|
||||
mockFS := NewMockFS()
|
||||
s := storage.NewServiceWithOptions("test-root", storage.Options{
|
||||
Fs: mockFS,
|
||||
NewGitClient: nil,
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
userID int
|
||||
workspaceID int
|
||||
filePath string
|
||||
mockErr error
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "successful delete",
|
||||
userID: 1,
|
||||
workspaceID: 1,
|
||||
filePath: "test.md",
|
||||
mockErr: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid path",
|
||||
userID: 1,
|
||||
workspaceID: 1,
|
||||
filePath: "../../../etc/passwd",
|
||||
mockErr: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "file not found",
|
||||
userID: 1,
|
||||
workspaceID: 1,
|
||||
filePath: "nonexistent.md",
|
||||
mockErr: fs.ErrNotExist,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mockFS.RemoveError = tc.mockErr
|
||||
err := s.DeleteFile(tc.userID, tc.workspaceID, tc.filePath)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
expectedPath := filepath.Join("test-root", "1", "1", tc.filePath)
|
||||
found := false
|
||||
for _, p := range mockFS.RemoveCalls {
|
||||
if p == expectedPath {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected delete call not made")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
47
server/internal/storage/filesystem.go
Normal file
47
server/internal/storage/filesystem.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"os"
|
||||
)
|
||||
|
||||
// fileSystem defines the interface for filesystem operations
|
||||
type fileSystem interface {
|
||||
ReadFile(path string) ([]byte, error)
|
||||
WriteFile(path string, data []byte, perm fs.FileMode) error
|
||||
Remove(path string) error
|
||||
MkdirAll(path string, perm fs.FileMode) error
|
||||
RemoveAll(path string) error
|
||||
ReadDir(path string) ([]fs.DirEntry, error)
|
||||
Stat(path string) (fs.FileInfo, error)
|
||||
IsNotExist(err error) bool
|
||||
}
|
||||
|
||||
// osFS implements the FileSystem interface using the real filesystem.
|
||||
type osFS struct{}
|
||||
|
||||
// ReadFile reads the file at the given path.
|
||||
func (f *osFS) ReadFile(path string) ([]byte, error) { return os.ReadFile(path) }
|
||||
|
||||
// WriteFile writes the given data to the file at the given path.
|
||||
func (f *osFS) WriteFile(path string, data []byte, perm fs.FileMode) error {
|
||||
return os.WriteFile(path, data, perm)
|
||||
}
|
||||
|
||||
// Remove deletes the file at the given path.
|
||||
func (f *osFS) Remove(path string) error { return os.Remove(path) }
|
||||
|
||||
// MkdirAll creates the directory at the given path and any necessary parents.
|
||||
func (f *osFS) MkdirAll(path string, perm fs.FileMode) error { return os.MkdirAll(path, perm) }
|
||||
|
||||
// RemoveAll removes the file or directory at the given path.
|
||||
func (f *osFS) RemoveAll(path string) error { return os.RemoveAll(path) }
|
||||
|
||||
// ReadDir reads the directory at the given path.
|
||||
func (f *osFS) ReadDir(path string) ([]fs.DirEntry, error) { return os.ReadDir(path) }
|
||||
|
||||
// Stat returns the FileInfo for the file at the given path.
|
||||
func (f *osFS) Stat(path string) (fs.FileInfo, error) { return os.Stat(path) }
|
||||
|
||||
// IsNotExist returns true if the error is a "file does not exist" error.
|
||||
func (f *osFS) IsNotExist(err error) bool { return os.IsNotExist(err) }
|
||||
125
server/internal/storage/filesystem_test.go
Normal file
125
server/internal/storage/filesystem_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package storage_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io/fs"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
type mockDirEntry struct {
|
||||
name string
|
||||
isDir bool
|
||||
}
|
||||
|
||||
func (m *mockDirEntry) Name() string { return m.name }
|
||||
func (m *mockDirEntry) IsDir() bool { return m.isDir }
|
||||
func (m *mockDirEntry) Type() fs.FileMode { return fs.ModeDir }
|
||||
func (m *mockDirEntry) Info() (fs.FileInfo, error) { return nil, nil }
|
||||
|
||||
func NewMockDirEntry(name string, isDir bool) fs.DirEntry {
|
||||
return &mockDirEntry{name: name, isDir: isDir}
|
||||
}
|
||||
|
||||
// Extend mockFS to support directory operations
|
||||
type MockDirInfo struct {
|
||||
name string
|
||||
size int64
|
||||
mode fs.FileMode
|
||||
modTime time.Time
|
||||
isDir bool
|
||||
}
|
||||
|
||||
func (m MockDirInfo) Name() string { return m.name }
|
||||
func (m MockDirInfo) Size() int64 { return m.size }
|
||||
func (m MockDirInfo) Mode() fs.FileMode { return m.mode }
|
||||
func (m MockDirInfo) ModTime() time.Time { return m.modTime }
|
||||
func (m MockDirInfo) IsDir() bool { return m.isDir }
|
||||
func (m MockDirInfo) Sys() interface{} { return nil }
|
||||
|
||||
type mockFS struct {
|
||||
// Record operations for verification
|
||||
ReadCalls map[string]int
|
||||
WriteCalls map[string][]byte
|
||||
RemoveCalls []string
|
||||
MkdirCalls []string
|
||||
|
||||
// Configure test behavior
|
||||
ReadFileReturns map[string]struct {
|
||||
data []byte
|
||||
err error
|
||||
}
|
||||
ReadDirReturns map[string]struct {
|
||||
entries []fs.DirEntry
|
||||
err error
|
||||
}
|
||||
WriteFileError error
|
||||
RemoveError error
|
||||
MkdirError error
|
||||
StatError error
|
||||
}
|
||||
|
||||
func NewMockFS() *mockFS {
|
||||
return &mockFS{
|
||||
ReadCalls: make(map[string]int),
|
||||
WriteCalls: make(map[string][]byte),
|
||||
RemoveCalls: make([]string, 0),
|
||||
MkdirCalls: make([]string, 0),
|
||||
ReadFileReturns: make(map[string]struct {
|
||||
data []byte
|
||||
err error
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockFS) ReadFile(path string) ([]byte, error) {
|
||||
m.ReadCalls[path]++
|
||||
if ret, ok := m.ReadFileReturns[path]; ok {
|
||||
return ret.data, ret.err
|
||||
}
|
||||
return nil, errors.New("file not found")
|
||||
}
|
||||
|
||||
func (m *mockFS) WriteFile(path string, data []byte, _ fs.FileMode) error {
|
||||
m.WriteCalls[path] = data
|
||||
return m.WriteFileError
|
||||
}
|
||||
|
||||
func (m *mockFS) Remove(path string) error {
|
||||
m.RemoveCalls = append(m.RemoveCalls, path)
|
||||
return m.RemoveError
|
||||
}
|
||||
|
||||
func (m *mockFS) MkdirAll(path string, _ fs.FileMode) error {
|
||||
m.MkdirCalls = append(m.MkdirCalls, path)
|
||||
return m.MkdirError
|
||||
}
|
||||
|
||||
func (m *mockFS) Stat(path string) (fs.FileInfo, error) {
|
||||
if m.StatError != nil {
|
||||
return nil, m.StatError
|
||||
}
|
||||
return MockDirInfo{
|
||||
name: filepath.Base(path),
|
||||
size: 1024,
|
||||
mode: 0644,
|
||||
modTime: time.Now(),
|
||||
isDir: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockFS) ReadDir(path string) ([]fs.DirEntry, error) {
|
||||
if ret, ok := m.ReadDirReturns[path]; ok {
|
||||
return ret.entries, ret.err
|
||||
}
|
||||
return nil, fs.ErrNotExist
|
||||
}
|
||||
|
||||
func (m *mockFS) RemoveAll(path string) error {
|
||||
m.RemoveCalls = append(m.RemoveCalls, path)
|
||||
return m.RemoveError
|
||||
}
|
||||
|
||||
func (m *mockFS) IsNotExist(err error) bool {
|
||||
return err == fs.ErrNotExist
|
||||
}
|
||||
71
server/internal/storage/git.go
Normal file
71
server/internal/storage/git.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"novamd/internal/git"
|
||||
)
|
||||
|
||||
// RepositoryManager defines the interface for managing Git repositories.
|
||||
type RepositoryManager interface {
|
||||
SetupGitRepo(userID, workspaceID int, gitURL, gitUser, gitToken string) error
|
||||
DisableGitRepo(userID, workspaceID int)
|
||||
StageCommitAndPush(userID, workspaceID int, message string) error
|
||||
Pull(userID, workspaceID int) error
|
||||
}
|
||||
|
||||
// SetupGitRepo sets up a Git repository for the given userID and workspaceID.
|
||||
// The repository is cloned from the given gitURL using the given gitUser and gitToken.
|
||||
func (s *Service) SetupGitRepo(userID, workspaceID int, gitURL, gitUser, gitToken string) error {
|
||||
workspacePath := s.GetWorkspacePath(userID, workspaceID)
|
||||
if _, ok := s.GitRepos[userID]; !ok {
|
||||
s.GitRepos[userID] = make(map[int]git.Client)
|
||||
}
|
||||
s.GitRepos[userID][workspaceID] = s.newGitClient(gitURL, gitUser, gitToken, workspacePath)
|
||||
return s.GitRepos[userID][workspaceID].EnsureRepo()
|
||||
}
|
||||
|
||||
// DisableGitRepo disables the Git repository for the given userID and workspaceID.
|
||||
func (s *Service) DisableGitRepo(userID, workspaceID int) {
|
||||
if userRepos, ok := s.GitRepos[userID]; ok {
|
||||
delete(userRepos, workspaceID)
|
||||
if len(userRepos) == 0 {
|
||||
delete(s.GitRepos, userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StageCommitAndPush stages, commit with the message, and pushes the changes to the Git repository.
|
||||
// The git repository belongs to the given userID and is associated with the given workspaceID.
|
||||
func (s *Service) StageCommitAndPush(userID, workspaceID int, message string) error {
|
||||
repo, ok := s.getGitRepo(userID, workspaceID)
|
||||
if !ok {
|
||||
return fmt.Errorf("git settings not configured for this workspace")
|
||||
}
|
||||
|
||||
if err := repo.Commit(message); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return repo.Push()
|
||||
}
|
||||
|
||||
// Pull pulls the changes from the remote Git repository.
|
||||
// The git repository belongs to the given userID and is associated with the given workspaceID.
|
||||
func (s *Service) Pull(userID, workspaceID int) error {
|
||||
repo, ok := s.getGitRepo(userID, workspaceID)
|
||||
if !ok {
|
||||
return fmt.Errorf("git settings not configured for this workspace")
|
||||
}
|
||||
|
||||
return repo.Pull()
|
||||
}
|
||||
|
||||
// getGitRepo returns the Git repository for the given user and workspace IDs.
|
||||
func (s *Service) getGitRepo(userID, workspaceID int) (git.Client, bool) {
|
||||
userRepos, ok := s.GitRepos[userID]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
repo, ok := userRepos[workspaceID]
|
||||
return repo, ok
|
||||
}
|
||||
258
server/internal/storage/git_test.go
Normal file
258
server/internal/storage/git_test.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package storage_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"novamd/internal/git"
|
||||
"novamd/internal/storage"
|
||||
)
|
||||
|
||||
// MockGitClient implements git.Client interface for testing
|
||||
type MockGitClient struct {
|
||||
CloneCalled bool
|
||||
PullCalled bool
|
||||
CommitCalled bool
|
||||
PushCalled bool
|
||||
EnsureCalled bool
|
||||
CommitMessage string
|
||||
ReturnError error
|
||||
}
|
||||
|
||||
func (m *MockGitClient) Clone() error {
|
||||
m.CloneCalled = true
|
||||
return m.ReturnError
|
||||
}
|
||||
|
||||
func (m *MockGitClient) Pull() error {
|
||||
m.PullCalled = true
|
||||
return m.ReturnError
|
||||
}
|
||||
|
||||
func (m *MockGitClient) Commit(message string) error {
|
||||
m.CommitCalled = true
|
||||
m.CommitMessage = message
|
||||
return m.ReturnError
|
||||
}
|
||||
|
||||
func (m *MockGitClient) Push() error {
|
||||
m.PushCalled = true
|
||||
return m.ReturnError
|
||||
}
|
||||
|
||||
func (m *MockGitClient) EnsureRepo() error {
|
||||
m.EnsureCalled = true
|
||||
return m.ReturnError
|
||||
}
|
||||
|
||||
func TestSetupGitRepo(t *testing.T) {
|
||||
mockFS := NewMockFS()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
userID int
|
||||
workspaceID int
|
||||
gitURL string
|
||||
gitUser string
|
||||
gitToken string
|
||||
mockErr error
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "successful setup",
|
||||
userID: 1,
|
||||
workspaceID: 1,
|
||||
gitURL: "https://github.com/user/repo",
|
||||
gitUser: "user",
|
||||
gitToken: "token",
|
||||
mockErr: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "git initialization error",
|
||||
userID: 1,
|
||||
workspaceID: 2,
|
||||
gitURL: "https://github.com/user/repo",
|
||||
gitUser: "user",
|
||||
gitToken: "token",
|
||||
mockErr: errors.New("git initialization failed"),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create a mock client with the desired error behavior
|
||||
mockClient := &MockGitClient{ReturnError: tc.mockErr}
|
||||
|
||||
// Create a client factory that returns our configured mock
|
||||
mockClientFactory := func(_, _, _, _ string) git.Client {
|
||||
return mockClient
|
||||
}
|
||||
|
||||
s := storage.NewServiceWithOptions("test-root", storage.Options{
|
||||
Fs: mockFS,
|
||||
NewGitClient: mockClientFactory,
|
||||
})
|
||||
|
||||
// Setup the git repo
|
||||
err := s.SetupGitRepo(tc.userID, tc.workspaceID, tc.gitURL, tc.gitUser, tc.gitToken)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Check if client was stored correctly
|
||||
client, ok := s.GitRepos[tc.userID][tc.workspaceID]
|
||||
if !ok {
|
||||
t.Fatal("git client was not stored in service")
|
||||
}
|
||||
|
||||
if !mockClient.EnsureCalled {
|
||||
t.Error("EnsureRepo was not called")
|
||||
}
|
||||
|
||||
// Verify it's our mock client
|
||||
if client != mockClient {
|
||||
t.Error("stored client is not our mock client")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitOperations(t *testing.T) {
|
||||
mockFS := NewMockFS()
|
||||
s := storage.NewServiceWithOptions("test-root", storage.Options{
|
||||
Fs: mockFS,
|
||||
NewGitClient: func(_, _, _, _ string) git.Client { return &MockGitClient{} },
|
||||
})
|
||||
|
||||
t.Run("operations on non-configured workspace", func(t *testing.T) {
|
||||
err := s.StageCommitAndPush(1, 1, "test commit")
|
||||
if err == nil {
|
||||
t.Error("expected error for non-configured workspace, got nil")
|
||||
}
|
||||
|
||||
err = s.Pull(1, 1)
|
||||
if err == nil {
|
||||
t.Error("expected error for non-configured workspace, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("successful operations", func(t *testing.T) {
|
||||
// Initialize GitRepos map
|
||||
s.GitRepos = make(map[int]map[int]git.Client)
|
||||
s.GitRepos[1] = make(map[int]git.Client)
|
||||
mockClient := &MockGitClient{}
|
||||
s.GitRepos[1][1] = mockClient
|
||||
|
||||
// Test commit and push
|
||||
err := s.StageCommitAndPush(1, 1, "test commit")
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if !mockClient.CommitCalled {
|
||||
t.Error("Commit was not called")
|
||||
}
|
||||
if mockClient.CommitMessage != "test commit" {
|
||||
t.Errorf("Commit message = %q, want %q", mockClient.CommitMessage, "test commit")
|
||||
}
|
||||
if !mockClient.PushCalled {
|
||||
t.Error("Push was not called")
|
||||
}
|
||||
|
||||
// Test pull
|
||||
err = s.Pull(1, 1)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if !mockClient.PullCalled {
|
||||
t.Error("Pull was not called")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("operation errors", func(t *testing.T) {
|
||||
// Initialize GitRepos map with error-returning client
|
||||
s.GitRepos = make(map[int]map[int]git.Client)
|
||||
s.GitRepos[1] = make(map[int]git.Client)
|
||||
mockClient := &MockGitClient{ReturnError: errors.New("git operation failed")}
|
||||
s.GitRepos[1][1] = mockClient
|
||||
|
||||
// Test commit error
|
||||
err := s.StageCommitAndPush(1, 1, "test commit")
|
||||
if err == nil {
|
||||
t.Error("expected error for commit, got nil")
|
||||
}
|
||||
|
||||
// Test pull error
|
||||
err = s.Pull(1, 1)
|
||||
if err == nil {
|
||||
t.Error("expected error for pull, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDisableGitRepo(t *testing.T) {
|
||||
mockFS := NewMockFS()
|
||||
s := storage.NewServiceWithOptions("test-root", storage.Options{
|
||||
Fs: mockFS,
|
||||
NewGitClient: func(_, _, _, _ string) git.Client { return &MockGitClient{} },
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
userID int
|
||||
workspaceID int
|
||||
setupRepo bool
|
||||
}{
|
||||
{
|
||||
name: "disable existing repo",
|
||||
userID: 1,
|
||||
workspaceID: 1,
|
||||
setupRepo: true,
|
||||
},
|
||||
{
|
||||
name: "disable non-existent repo",
|
||||
userID: 2,
|
||||
workspaceID: 1,
|
||||
setupRepo: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Reset GitRepos for each test
|
||||
s.GitRepos = make(map[int]map[int]git.Client)
|
||||
|
||||
if tc.setupRepo {
|
||||
// Setup initial repo
|
||||
s.GitRepos[tc.userID] = make(map[int]git.Client)
|
||||
s.GitRepos[tc.userID][tc.workspaceID] = &MockGitClient{}
|
||||
}
|
||||
|
||||
// Disable the repo
|
||||
s.DisableGitRepo(tc.userID, tc.workspaceID)
|
||||
|
||||
// Verify repo was removed
|
||||
if userRepos, exists := s.GitRepos[tc.userID]; exists {
|
||||
if _, repoExists := userRepos[tc.workspaceID]; repoExists {
|
||||
t.Error("git repo still exists after disable")
|
||||
}
|
||||
}
|
||||
|
||||
// If this was the user's last repo, verify user entry was cleaned up
|
||||
if tc.setupRepo {
|
||||
if len(s.GitRepos[tc.userID]) > 0 {
|
||||
t.Error("user's git repos map not cleaned up when last repo removed")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
52
server/internal/storage/service.go
Normal file
52
server/internal/storage/service.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"novamd/internal/git"
|
||||
)
|
||||
|
||||
// Manager interface combines all storage interfaces.
|
||||
type Manager interface {
|
||||
FileManager
|
||||
WorkspaceManager
|
||||
RepositoryManager
|
||||
}
|
||||
|
||||
// Service represents the file system structure.
|
||||
type Service struct {
|
||||
fs fileSystem
|
||||
newGitClient func(url, user, token, path string) git.Client
|
||||
RootDir string
|
||||
GitRepos map[int]map[int]git.Client // map[userID]map[workspaceID]*git.Client
|
||||
}
|
||||
|
||||
// Options represents the options for the storage service.
|
||||
type Options struct {
|
||||
Fs fileSystem
|
||||
NewGitClient func(url, user, token, path string) git.Client
|
||||
}
|
||||
|
||||
// NewService creates a new Storage instance with the default options and the given rootDir root directory.
|
||||
func NewService(rootDir string) *Service {
|
||||
return NewServiceWithOptions(rootDir, Options{
|
||||
Fs: &osFS{},
|
||||
NewGitClient: git.New,
|
||||
})
|
||||
}
|
||||
|
||||
// NewServiceWithOptions creates a new Storage instance with the given options and the given rootDir root directory.
|
||||
func NewServiceWithOptions(rootDir string, options Options) *Service {
|
||||
if options.Fs == nil {
|
||||
options.Fs = &osFS{}
|
||||
}
|
||||
|
||||
if options.NewGitClient == nil {
|
||||
options.NewGitClient = git.New
|
||||
}
|
||||
|
||||
return &Service{
|
||||
fs: options.Fs,
|
||||
newGitClient: options.NewGitClient,
|
||||
RootDir: rootDir,
|
||||
GitRepos: make(map[int]map[int]git.Client),
|
||||
}
|
||||
}
|
||||
64
server/internal/storage/workspace.go
Normal file
64
server/internal/storage/workspace.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// WorkspaceManager provides functionalities to interact with workspaces in the storage.
|
||||
type WorkspaceManager interface {
|
||||
ValidatePath(userID, workspaceID int, path string) (string, error)
|
||||
GetWorkspacePath(userID, workspaceID int) string
|
||||
InitializeUserWorkspace(userID, workspaceID int) error
|
||||
DeleteUserWorkspace(userID, workspaceID int) error
|
||||
}
|
||||
|
||||
// ValidatePath validates the if the given path is valid within the workspace directory.
|
||||
// Workspace directory is defined as the directory for the given userID and workspaceID.
|
||||
func (s *Service) ValidatePath(userID, workspaceID int, path string) (string, error) {
|
||||
workspacePath := s.GetWorkspacePath(userID, workspaceID)
|
||||
|
||||
// First check if the path is absolute
|
||||
if filepath.IsAbs(path) {
|
||||
return "", &PathValidationError{Path: path, Message: "absolute paths not allowed"}
|
||||
}
|
||||
|
||||
// Join and clean the path
|
||||
fullPath := filepath.Join(workspacePath, path)
|
||||
cleanPath := filepath.Clean(fullPath)
|
||||
|
||||
// Verify the path is still within the workspace
|
||||
if !strings.HasPrefix(cleanPath, workspacePath) {
|
||||
return "", &PathValidationError{Path: path, Message: "path traversal attempt"}
|
||||
}
|
||||
|
||||
return cleanPath, nil
|
||||
}
|
||||
|
||||
// GetWorkspacePath returns the path to the workspace directory for the given userID and workspaceID.
|
||||
func (s *Service) GetWorkspacePath(userID, workspaceID int) string {
|
||||
return filepath.Join(s.RootDir, fmt.Sprintf("%d", userID), fmt.Sprintf("%d", workspaceID))
|
||||
}
|
||||
|
||||
// InitializeUserWorkspace creates the workspace directory for the given userID and workspaceID.
|
||||
func (s *Service) InitializeUserWorkspace(userID, workspaceID int) error {
|
||||
workspacePath := s.GetWorkspacePath(userID, workspaceID)
|
||||
err := s.fs.MkdirAll(workspacePath, 0755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create workspace directory: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteUserWorkspace deletes the workspace directory for the given userID and workspaceID.
|
||||
func (s *Service) DeleteUserWorkspace(userID, workspaceID int) error {
|
||||
workspacePath := s.GetWorkspacePath(userID, workspaceID)
|
||||
err := s.fs.RemoveAll(workspacePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete workspace directory: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
275
server/internal/storage/workspace_test.go
Normal file
275
server/internal/storage/workspace_test.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package storage_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"novamd/internal/storage"
|
||||
)
|
||||
|
||||
func TestValidatePath(t *testing.T) {
|
||||
mockFS := NewMockFS()
|
||||
s := storage.NewServiceWithOptions("test-root", storage.Options{
|
||||
Fs: mockFS,
|
||||
NewGitClient: nil,
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
userID int
|
||||
workspaceID int
|
||||
path string
|
||||
want string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "valid path",
|
||||
userID: 1,
|
||||
workspaceID: 1,
|
||||
path: "notes/test.md",
|
||||
want: filepath.Join("test-root", "1", "1", "notes", "test.md"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid path with dot",
|
||||
userID: 1,
|
||||
workspaceID: 1,
|
||||
path: "./notes/test.md",
|
||||
want: filepath.Join("test-root", "1", "1", "notes", "test.md"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "path with parent directory traversal",
|
||||
userID: 1,
|
||||
workspaceID: 1,
|
||||
path: "../../../etc/passwd",
|
||||
want: "",
|
||||
wantErr: true,
|
||||
errContains: "path traversal attempt",
|
||||
},
|
||||
{
|
||||
name: "absolute path attempt",
|
||||
userID: 1,
|
||||
workspaceID: 1,
|
||||
path: "/etc/passwd",
|
||||
want: "",
|
||||
wantErr: true,
|
||||
errContains: "absolute paths not allowed",
|
||||
},
|
||||
{
|
||||
name: "empty path",
|
||||
userID: 1,
|
||||
workspaceID: 1,
|
||||
path: "",
|
||||
want: filepath.Join("test-root", "1", "1"),
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := s.ValidatePath(tc.userID, tc.workspaceID, tc.path)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tc.errContains) {
|
||||
t.Errorf("error = %v, want error containing %q", err, tc.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if got != tc.want {
|
||||
t.Errorf("ValidatePath() = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetWorkspacePath(t *testing.T) {
|
||||
mockFS := NewMockFS()
|
||||
s := storage.NewServiceWithOptions("test-root", storage.Options{
|
||||
Fs: mockFS,
|
||||
NewGitClient: nil,
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
userID int
|
||||
workspaceID int
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "standard workspace path",
|
||||
userID: 1,
|
||||
workspaceID: 1,
|
||||
want: filepath.Join("test-root", "1", "1"),
|
||||
},
|
||||
{
|
||||
name: "different user and workspace IDs",
|
||||
userID: 2,
|
||||
workspaceID: 3,
|
||||
want: filepath.Join("test-root", "2", "3"),
|
||||
},
|
||||
{
|
||||
name: "zero IDs",
|
||||
userID: 0,
|
||||
workspaceID: 0,
|
||||
want: filepath.Join("test-root", "0", "0"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := s.GetWorkspacePath(tc.userID, tc.workspaceID)
|
||||
if got != tc.want {
|
||||
t.Errorf("GetWorkspacePath() = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitializeUserWorkspace(t *testing.T) {
|
||||
mockFS := NewMockFS()
|
||||
s := storage.NewServiceWithOptions("test-root", storage.Options{
|
||||
Fs: mockFS,
|
||||
NewGitClient: nil,
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
userID int
|
||||
workspaceID int
|
||||
mockErr error
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "successful initialization",
|
||||
userID: 1,
|
||||
workspaceID: 1,
|
||||
mockErr: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "mkdir error",
|
||||
userID: 1,
|
||||
workspaceID: 1,
|
||||
mockErr: errors.New("permission denied"),
|
||||
wantErr: true,
|
||||
errContains: "failed to create workspace directory",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mockFS.MkdirError = tc.mockErr
|
||||
err := s.InitializeUserWorkspace(tc.userID, tc.workspaceID)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tc.errContains) {
|
||||
t.Errorf("error = %v, want error containing %q", err, tc.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify the correct directory was created
|
||||
expectedPath := filepath.Join("test-root", "1", "1")
|
||||
dirCreated := false
|
||||
for _, path := range mockFS.MkdirCalls {
|
||||
if path == expectedPath {
|
||||
dirCreated = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !dirCreated {
|
||||
t.Errorf("directory %s was not created", expectedPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteUserWorkspace(t *testing.T) {
|
||||
mockFS := NewMockFS()
|
||||
s := storage.NewServiceWithOptions("test-root", storage.Options{
|
||||
Fs: mockFS,
|
||||
NewGitClient: nil,
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
userID int
|
||||
workspaceID int
|
||||
mockErr error
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "successful deletion",
|
||||
userID: 1,
|
||||
workspaceID: 1,
|
||||
mockErr: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "removal error",
|
||||
userID: 1,
|
||||
workspaceID: 1,
|
||||
mockErr: errors.New("permission denied"),
|
||||
wantErr: true,
|
||||
errContains: "failed to delete workspace directory",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mockFS.RemoveError = tc.mockErr
|
||||
err := s.DeleteUserWorkspace(tc.userID, tc.workspaceID)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tc.errContains) {
|
||||
t.Errorf("error = %v, want error containing %q", err, tc.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify the correct directory was deleted
|
||||
expectedPath := filepath.Join("test-root", "1", "1")
|
||||
dirDeleted := false
|
||||
for _, path := range mockFS.RemoveCalls {
|
||||
if path == expectedPath {
|
||||
dirDeleted = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !dirDeleted {
|
||||
t.Errorf("directory %s was not deleted", expectedPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -8,19 +8,19 @@ import (
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"novamd/internal/db"
|
||||
"novamd/internal/filesystem"
|
||||
"novamd/internal/models"
|
||||
"novamd/internal/storage"
|
||||
)
|
||||
|
||||
type UserService struct {
|
||||
DB *db.DB
|
||||
FS *filesystem.FileSystem
|
||||
DB db.Database
|
||||
Storage storage.Manager
|
||||
}
|
||||
|
||||
func NewUserService(database *db.DB, fs *filesystem.FileSystem) *UserService {
|
||||
func NewUserService(database db.Database, s storage.Manager) *UserService {
|
||||
return &UserService{
|
||||
DB: database,
|
||||
FS: fs,
|
||||
DB: database,
|
||||
Storage: s,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ func (s *UserService) SetupAdminUser(adminEmail, adminPassword string) (*models.
|
||||
}
|
||||
|
||||
// Initialize workspace directory
|
||||
err = s.FS.InitializeUserWorkspace(createdUser.ID, createdUser.LastWorkspaceID)
|
||||
err = s.Storage.InitializeUserWorkspace(createdUser.ID, createdUser.LastWorkspaceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize admin workspace: %w", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user