Merge pull request #20 from LordMathis/chore/backend-test

Implement backend tests
This commit is contained in:
2024-11-30 11:51:35 +01:00
committed by GitHub
72 changed files with 7476 additions and 897 deletions

34
.github/workflows/go-test.yml vendored Normal file
View File

@@ -0,0 +1,34 @@
name: Go Tests
on:
push:
branches:
- "*"
pull_request:
branches:
- main
jobs:
test:
name: Run Tests
runs-on: ubuntu-latest
defaults:
run:
working-directory: ./server
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.23"
cache: true
- name: Run Tests
run: go test -tags=test,integration ./... -v
- name: Run Tests with Race Detector
run: go test -tags=test,integration -race ./... -v

View File

@@ -14,6 +14,7 @@
"go.lintTool": "golangci-lint",
"go.lintOnSave": "package",
"go.formatTool": "goimports",
"go.testFlags": ["-tags=test,integration"],
"[go]": {
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
@@ -23,6 +24,7 @@
},
"gopls": {
"usePlaceholders": true,
"staticcheck": true
"staticcheck": true,
"buildFlags": ["-tags", "test,integration"]
}
}

View File

@@ -1,3 +1,4 @@
// Package main contains the main entry point for the application. It sets up the server, database, and other services, and starts the server.
package main
import (
@@ -17,8 +18,9 @@ import (
"novamd/internal/auth"
"novamd/internal/config"
"novamd/internal/db"
"novamd/internal/filesystem"
"novamd/internal/handlers"
"novamd/internal/secrets"
"novamd/internal/storage"
)
func main() {
@@ -28,12 +30,26 @@ func main() {
log.Fatal("Failed to load configuration:", err)
}
// Initialize secrets service
secretsService, err := secrets.NewService(cfg.EncryptionKey)
if err != nil {
log.Fatal("Failed to initialize secrets service:", err)
}
// Initialize database
database, err := db.Init(cfg.DBPath, cfg.EncryptionKey)
database, err := db.Init(cfg.DBPath, secretsService)
if err != nil {
log.Fatal(err)
}
defer database.Close()
err = database.Migrate()
if err != nil {
log.Fatal("Failed to apply database migrations:", err)
}
defer func() {
if err := database.Close(); err != nil {
log.Printf("Error closing database: %v", err)
}
}()
// Get or generate JWT signing key
signingKey := cfg.JWTSigningKey
@@ -45,10 +61,10 @@ func main() {
}
// Initialize filesystem
fs := filesystem.New(cfg.WorkDir)
s := storage.NewService(cfg.WorkDir)
// Initialize JWT service
jwtService, err := auth.NewJWTService(auth.JWTConfig{
jwtManager, err := auth.NewJWTService(auth.JWTConfig{
SigningKey: signingKey,
AccessTokenExpiry: 15 * time.Minute,
RefreshTokenExpiry: 7 * 24 * time.Hour,
@@ -58,10 +74,10 @@ func main() {
}
// Initialize auth middleware
authMiddleware := auth.NewMiddleware(jwtService)
authMiddleware := auth.NewMiddleware(jwtManager)
// Initialize session service
sessionService := auth.NewSessionService(database.DB, jwtService)
sessionService := auth.NewSessionService(database, jwtManager)
// Set up router
r := chi.NewRouter()
@@ -95,7 +111,7 @@ func main() {
// Set up routes
r.Route("/api/v1", func(r chi.Router) {
r.Use(httprate.LimitByIP(cfg.RateLimitRequests, cfg.RateLimitWindow))
api.SetupRoutes(r, database, fs, authMiddleware, sessionService)
api.SetupRoutes(r, database, s, authMiddleware, sessionService)
})
// Handle all other routes with static file server

50
server/gendocs.sh Executable file
View File

@@ -0,0 +1,50 @@
#!/bin/bash
set -euo pipefail
# Function to generate anchor from package path
generate_anchor() {
echo "$1" | tr '/' '-'
}
# Create documentation file
echo "# NovaMD Package Documentation
Generated documentation for all packages in the NovaMD project.
## Table of Contents
" > documentation.md
# Find all directories containing .go files (excluding test files)
# Sort them for consistent output
PACKAGES=$(find . -type f -name "*.go" ! -name "*_test.go" -exec dirname {} \; | sort -u | grep -v "/\.")
# Generate table of contents
for PKG in $PACKAGES; do
# Strip leading ./
PKG_PATH=${PKG#./}
# Skip if empty
[ -z "$PKG_PATH" ] && continue
ANCHOR=$(generate_anchor "$PKG_PATH")
echo "- [$PKG_PATH](#$ANCHOR)" >> documentation.md
done
echo "" >> documentation.md
# Generate documentation for each package
for PKG in $PACKAGES; do
# Strip leading ./
PKG_PATH=${PKG#./}
# Skip if empty
[ -z "$PKG_PATH" ] && continue
echo "## $PKG_PATH" >> documentation.md
echo "" >> documentation.md
echo '```go' >> documentation.md
go doc -all "./$PKG_PATH" >> documentation.md
echo '```' >> documentation.md
echo "" >> documentation.md
done
echo "Documentation generated in documentation.md"

View File

@@ -11,6 +11,7 @@ require (
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/google/uuid v1.6.0
github.com/mattn/go-sqlite3 v1.14.23
github.com/stretchr/testify v1.9.0
github.com/unrolled/secure v1.17.0
golang.org/x/crypto v0.21.0
)
@@ -22,6 +23,7 @@ require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cloudflare/circl v1.3.7 // indirect
github.com/cyphar/filepath-securejoin v0.2.4 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/emirpasic/gods v1.18.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect
@@ -33,6 +35,7 @@ require (
github.com/kevinburke/ssh_config v1.2.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/pjbgf/sha1cd v0.3.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect
github.com/skeema/knownhosts v1.2.2 // indirect
github.com/xanzy/ssh-agent v0.3.3 // indirect
@@ -42,4 +45,5 @@ require (
golang.org/x/text v0.14.0 // indirect
golang.org/x/tools v0.13.0 // indirect
gopkg.in/warnings.v0 v0.1.2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -3,20 +3,20 @@ package api
import (
"novamd/internal/auth"
"novamd/internal/context"
"novamd/internal/db"
"novamd/internal/filesystem"
"novamd/internal/handlers"
"novamd/internal/middleware"
"novamd/internal/storage"
"github.com/go-chi/chi/v5"
)
// SetupRoutes configures the API routes
func SetupRoutes(r chi.Router, db *db.DB, fs *filesystem.FileSystem, authMiddleware *auth.Middleware, sessionService *auth.SessionService) {
func SetupRoutes(r chi.Router, db db.Database, s storage.Manager, authMiddleware *auth.Middleware, sessionService *auth.SessionService) {
handler := &handlers.Handler{
DB: db,
FS: fs,
DB: db,
Storage: s,
}
// Public routes (no authentication required)
@@ -29,7 +29,7 @@ func SetupRoutes(r chi.Router, db *db.DB, fs *filesystem.FileSystem, authMiddlew
r.Group(func(r chi.Router) {
// Apply authentication middleware to all routes in this group
r.Use(authMiddleware.Authenticate)
r.Use(middleware.WithUserContext)
r.Use(context.WithUserContextMiddleware)
// Auth routes
r.Post("/auth/logout", handler.Logout(sessionService))
@@ -67,7 +67,7 @@ func SetupRoutes(r chi.Router, db *db.DB, fs *filesystem.FileSystem, authMiddlew
// Single workspace routes
r.Route("/{workspaceName}", func(r chi.Router) {
r.Use(middleware.WithWorkspaceContext(db))
r.Use(context.WithWorkspaceContextMiddleware(db))
r.Use(authMiddleware.RequireWorkspaceAccess)
r.Get("/", handler.GetWorkspace())

View File

@@ -1,6 +1,9 @@
// Package auth provides JWT token generation and validation
package auth
import (
"crypto/rand"
"encoding/hex"
"fmt"
"time"
@@ -30,14 +33,22 @@ type JWTConfig struct {
RefreshTokenExpiry time.Duration // How long refresh tokens are valid
}
// JWTService handles JWT token generation and validation
type JWTService struct {
// JWTManager defines the interface for managing JWT tokens
type JWTManager interface {
GenerateAccessToken(userID int, role string) (string, error)
GenerateRefreshToken(userID int, role string) (string, error)
ValidateToken(tokenString string) (*Claims, error)
RefreshAccessToken(refreshToken string) (string, error)
}
// jwtService handles JWT token generation and validation
type jwtService struct {
config JWTConfig
}
// NewJWTService creates a new JWT service with the provided configuration
// Returns an error if the signing key is missing
func NewJWTService(config JWTConfig) (*JWTService, error) {
func NewJWTService(config JWTConfig) (JWTManager, error) {
if config.SigningKey == "" {
return nil, fmt.Errorf("signing key is required")
}
@@ -48,41 +59,35 @@ func NewJWTService(config JWTConfig) (*JWTService, error) {
if config.RefreshTokenExpiry == 0 {
config.RefreshTokenExpiry = 7 * 24 * time.Hour // Default to 7 days
}
return &JWTService{config: config}, nil
return &jwtService{config: config}, nil
}
// GenerateAccessToken creates a new access token for a user
// Parameters:
// - userID: the ID of the user
// - role: the role of the user
// Returns the signed token string or an error
func (s *JWTService) GenerateAccessToken(userID int, role string) (string, error) {
// GenerateAccessToken creates a new access token for a user with the given userID and role
func (s *jwtService) GenerateAccessToken(userID int, role string) (string, error) {
return s.generateToken(userID, role, AccessToken, s.config.AccessTokenExpiry)
}
// GenerateRefreshToken creates a new refresh token for a user
// Parameters:
// - userID: the ID of the user
// - role: the role of the user
// Returns the signed token string or an error
func (s *JWTService) GenerateRefreshToken(userID int, role string) (string, error) {
// GenerateRefreshToken creates a new refresh token for a user with the given userID and role
func (s *jwtService) GenerateRefreshToken(userID int, role string) (string, error) {
return s.generateToken(userID, role, RefreshToken, s.config.RefreshTokenExpiry)
}
// generateToken is an internal helper function that creates a new JWT token
// Parameters:
// - userID: the ID of the user
// - role: the role of the user
// - tokenType: the type of token (access or refresh)
// - expiry: how long the token should be valid
// Returns the signed token string or an error
func (s *JWTService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) {
func (s *jwtService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) {
now := time.Now()
// Add a random nonce to ensure uniqueness
nonce := make([]byte, 8)
if _, err := rand.Read(nonce); err != nil {
return "", fmt.Errorf("failed to generate nonce: %w", err)
}
claims := Claims{
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(expiry)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
ID: hex.EncodeToString(nonce),
},
UserID: userID,
Role: role,
@@ -94,10 +99,7 @@ func (s *JWTService) generateToken(userID int, role string, tokenType TokenType,
}
// ValidateToken validates and parses a JWT token
// Parameters:
// - tokenString: the token to validate
// Returns the token claims if valid, or an error if invalid
func (s *JWTService) ValidateToken(tokenString string) (*Claims, error) {
func (s *jwtService) ValidateToken(tokenString string) (*Claims, error) {
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
// Validate the signing method
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
@@ -117,11 +119,8 @@ func (s *JWTService) ValidateToken(tokenString string) (*Claims, error) {
return nil, fmt.Errorf("invalid token claims")
}
// RefreshAccessToken creates a new access token using a refresh token
// Parameters:
// - refreshToken: the refresh token to use
// Returns a new access token if the refresh token is valid, or an error
func (s *JWTService) RefreshAccessToken(refreshToken string) (string, error) {
// RefreshAccessToken creates a new access token using a refreshToken
func (s *jwtService) RefreshAccessToken(refreshToken string) (string, error) {
claims, err := s.ValidateToken(refreshToken)
if err != nil {
return "", fmt.Errorf("invalid refresh token: %w", err)

View File

@@ -0,0 +1,221 @@
package auth_test
import (
"testing"
"time"
"novamd/internal/auth"
"github.com/golang-jwt/jwt/v5"
)
// jwt_test.go tests
func TestNewJWTService(t *testing.T) {
testCases := []struct {
name string
config auth.JWTConfig
wantErr bool
}{
{
name: "valid configuration",
config: auth.JWTConfig{
SigningKey: "test-key",
AccessTokenExpiry: 15 * time.Minute,
RefreshTokenExpiry: 24 * time.Hour,
},
wantErr: false,
},
{
name: "missing signing key",
config: auth.JWTConfig{
AccessTokenExpiry: 15 * time.Minute,
RefreshTokenExpiry: 24 * time.Hour,
},
wantErr: true,
},
{
name: "zero expiry times",
config: auth.JWTConfig{
SigningKey: "test-key",
},
wantErr: false, // Should use default values
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
service, err := auth.NewJWTService(tc.config)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if service == nil {
t.Error("expected service, got nil")
}
})
}
}
func TestGenerateAndValidateToken(t *testing.T) {
config := auth.JWTConfig{
SigningKey: "test-key",
AccessTokenExpiry: 15 * time.Minute,
RefreshTokenExpiry: 24 * time.Hour,
}
service, _ := auth.NewJWTService(config)
testCases := []struct {
name string
userID int
role string
tokenType auth.TokenType
wantErr bool
}{
{
name: "valid access token",
userID: 1,
role: "admin",
tokenType: auth.AccessToken,
wantErr: false,
},
{
name: "valid refresh token",
userID: 1,
role: "editor",
tokenType: auth.RefreshToken,
wantErr: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var token string
var err error
// Generate token based on type
if tc.tokenType == auth.AccessToken {
token, err = service.GenerateAccessToken(tc.userID, tc.role)
} else {
token, err = service.GenerateRefreshToken(tc.userID, tc.role)
}
if err != nil {
t.Fatalf("failed to generate token: %v", err)
}
// Validate token
claims, err := service.ValidateToken(token)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
// Verify claims
if claims.UserID != tc.userID {
t.Errorf("userID = %v, want %v", claims.UserID, tc.userID)
}
if claims.Role != tc.role {
t.Errorf("role = %v, want %v", claims.Role, tc.role)
}
if claims.Type != tc.tokenType {
t.Errorf("type = %v, want %v", claims.Type, tc.tokenType)
}
})
}
}
func TestRefreshAccessToken(t *testing.T) {
config := auth.JWTConfig{
SigningKey: "test-key",
AccessTokenExpiry: 15 * time.Minute,
RefreshTokenExpiry: 24 * time.Hour,
}
service, _ := auth.NewJWTService(config)
testCases := []struct {
name string
userID int
role string
wantErr bool
setupFunc func() string // Added setup function to handle custom token creation
}{
{
name: "valid refresh token",
userID: 1,
role: "admin",
wantErr: false,
setupFunc: func() string {
token, _ := service.GenerateRefreshToken(1, "admin")
return token
},
},
{
name: "expired refresh token",
userID: 1,
role: "admin",
wantErr: true,
setupFunc: func() string {
// Create a token that's already expired
claims := &auth.Claims{
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(-time.Hour)), // Expired 1 hour ago
IssuedAt: jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)),
NotBefore: jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)),
},
UserID: 1,
Role: "admin",
Type: auth.RefreshToken,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, _ := token.SignedString([]byte(config.SigningKey))
return tokenString
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
refreshToken := tc.setupFunc()
newAccessToken, err := service.RefreshAccessToken(refreshToken)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
claims, err := service.ValidateToken(newAccessToken)
if err != nil {
t.Fatalf("failed to validate new access token: %v", err)
}
if claims.UserID != tc.userID {
t.Errorf("userID = %v, want %v", claims.UserID, tc.userID)
}
if claims.Role != tc.role {
t.Errorf("role = %v, want %v", claims.Role, tc.role)
}
if claims.Type != auth.AccessToken {
t.Errorf("token type = %v, want %v", claims.Type, auth.AccessToken)
}
})
}
}

View File

@@ -1,35 +1,21 @@
package auth
import (
"context"
"fmt"
"net/http"
"strings"
"novamd/internal/httpcontext"
"novamd/internal/context"
)
type contextKey string
const (
UserContextKey contextKey = "user"
)
// UserClaims represents the user information stored in the request context
type UserClaims struct {
UserID int
Role string
}
// Middleware handles JWT authentication for protected routes
type Middleware struct {
jwtService *JWTService
jwtManager JWTManager
}
// NewMiddleware creates a new authentication middleware
func NewMiddleware(jwtService *JWTService) *Middleware {
func NewMiddleware(jwtManager JWTManager) *Middleware {
return &Middleware{
jwtService: jwtService,
jwtManager: jwtManager,
}
}
@@ -51,7 +37,7 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
}
// Validate token
claims, err := m.jwtService.ValidateToken(parts[1])
claims, err := m.jwtManager.ValidateToken(parts[1])
if err != nil {
http.Error(w, "Invalid token", http.StatusUnauthorized)
return
@@ -63,14 +49,14 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
return
}
// Add user claims to request context
ctx := context.WithValue(r.Context(), UserContextKey, UserClaims{
UserID: claims.UserID,
Role: claims.Role,
})
// Create handler context with user information
hctx := &context.HandlerContext{
UserID: claims.UserID,
UserRole: claims.Role,
}
// Call the next handler with the updated context
next.ServeHTTP(w, r.WithContext(ctx))
// Add context to request and continue
next.ServeHTTP(w, context.WithHandlerContext(r, hctx))
})
}
@@ -78,13 +64,12 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler {
func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims, ok := r.Context().Value(UserContextKey).(UserClaims)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
if claims.Role != role && claims.Role != "admin" {
if ctx.UserRole != role && ctx.UserRole != "admin" {
http.Error(w, "Insufficient permissions", http.StatusForbidden)
return
}
@@ -94,10 +79,10 @@ func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler {
}
}
// RequireWorkspaceAccess returns a middleware that ensures the user has access to the workspace
func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get our handler context
ctx, ok := httpcontext.GetRequestContext(w, r)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
@@ -117,12 +102,3 @@ func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler {
next.ServeHTTP(w, r)
})
}
// GetUserFromContext retrieves user claims from the request context
func GetUserFromContext(ctx context.Context) (*UserClaims, error) {
claims, ok := ctx.Value(UserContextKey).(UserClaims)
if !ok {
return nil, fmt.Errorf("no user found in context")
}
return &claims, nil
}

View File

@@ -0,0 +1,294 @@
package auth_test
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"novamd/internal/auth"
"novamd/internal/context"
"novamd/internal/models"
)
// Complete mockResponseWriter implementation
type mockResponseWriter struct {
headers http.Header
statusCode int
written []byte
}
func newMockResponseWriter() *mockResponseWriter {
return &mockResponseWriter{
headers: make(http.Header),
}
}
func (m *mockResponseWriter) Header() http.Header {
return m.headers
}
func (m *mockResponseWriter) Write(b []byte) (int, error) {
m.written = b
return len(b), nil
}
func (m *mockResponseWriter) WriteHeader(statusCode int) {
m.statusCode = statusCode
}
func TestAuthenticateMiddleware(t *testing.T) {
config := auth.JWTConfig{
SigningKey: "test-key",
AccessTokenExpiry: 15 * time.Minute,
RefreshTokenExpiry: 24 * time.Hour,
}
jwtService, _ := auth.NewJWTService(config)
middleware := auth.NewMiddleware(jwtService)
testCases := []struct {
name string
setupAuth func() string
wantStatusCode int
}{
{
name: "valid token",
setupAuth: func() string {
token, _ := jwtService.GenerateAccessToken(1, "admin")
return token
},
wantStatusCode: http.StatusOK,
},
{
name: "missing auth header",
setupAuth: func() string {
return ""
},
wantStatusCode: http.StatusUnauthorized,
},
{
name: "invalid auth format",
setupAuth: func() string {
return "InvalidFormat token"
},
wantStatusCode: http.StatusUnauthorized,
},
{
name: "invalid token",
setupAuth: func() string {
return "Bearer invalid.token.here"
},
wantStatusCode: http.StatusUnauthorized,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create test request
req := httptest.NewRequest("GET", "/test", nil)
if token := tc.setupAuth(); token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
// Create response recorder
w := newMockResponseWriter()
// Create test handler
nextCalled := false
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
})
// Execute middleware
middleware.Authenticate(next).ServeHTTP(w, req)
// Check status code
if w.statusCode != tc.wantStatusCode {
t.Errorf("status code = %v, want %v", w.statusCode, tc.wantStatusCode)
}
// Check if next handler was called when expected
if tc.wantStatusCode == http.StatusOK && !nextCalled {
t.Error("next handler was not called")
}
if tc.wantStatusCode != http.StatusOK && nextCalled {
t.Error("next handler was called when it shouldn't have been")
}
})
}
}
func TestRequireRole(t *testing.T) {
config := auth.JWTConfig{
SigningKey: "test-key",
AccessTokenExpiry: 15 * time.Minute,
RefreshTokenExpiry: 24 * time.Hour,
}
jwtService, _ := auth.NewJWTService(config)
middleware := auth.NewMiddleware(jwtService)
testCases := []struct {
name string
userRole string
requiredRole string
wantStatusCode int
}{
{
name: "matching role",
userRole: "admin",
requiredRole: "admin",
wantStatusCode: http.StatusOK,
},
{
name: "admin accessing other role",
userRole: "admin",
requiredRole: "editor",
wantStatusCode: http.StatusOK,
},
{
name: "insufficient role",
userRole: "editor",
requiredRole: "admin",
wantStatusCode: http.StatusForbidden,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create handler context with user info
hctx := &context.HandlerContext{
UserID: 1,
UserRole: tc.userRole,
}
// Create request with handler context
req := httptest.NewRequest("GET", "/test", nil)
req = context.WithHandlerContext(req, hctx)
w := newMockResponseWriter()
// Create test handler
nextCalled := false
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
})
// Execute middleware
middleware.RequireRole(tc.requiredRole)(next).ServeHTTP(w, req)
// Check status code
if w.statusCode != tc.wantStatusCode {
t.Errorf("status code = %v, want %v", w.statusCode, tc.wantStatusCode)
}
// Check if next handler was called when expected
if tc.wantStatusCode == http.StatusOK && !nextCalled {
t.Error("next handler was not called")
}
if tc.wantStatusCode != http.StatusOK && nextCalled {
t.Error("next handler was called when it shouldn't have been")
}
})
}
}
func TestRequireWorkspaceAccess(t *testing.T) {
config := auth.JWTConfig{
SigningKey: "test-key",
}
jwtService, _ := auth.NewJWTService(config)
middleware := auth.NewMiddleware(jwtService)
testCases := []struct {
name string
setupContext func() *context.HandlerContext
wantStatusCode int
}{
{
name: "workspace owner access",
setupContext: func() *context.HandlerContext {
return &context.HandlerContext{
UserID: 1,
UserRole: "editor",
Workspace: &models.Workspace{
ID: 1,
UserID: 1, // Same as context UserID
},
}
},
wantStatusCode: http.StatusOK,
},
{
name: "admin access to other's workspace",
setupContext: func() *context.HandlerContext {
return &context.HandlerContext{
UserID: 2,
UserRole: "admin",
Workspace: &models.Workspace{
ID: 1,
UserID: 1, // Different from context UserID
},
}
},
wantStatusCode: http.StatusOK,
},
{
name: "unauthorized access attempt",
setupContext: func() *context.HandlerContext {
return &context.HandlerContext{
UserID: 2,
UserRole: "editor",
Workspace: &models.Workspace{
ID: 1,
UserID: 1, // Different from context UserID
},
}
},
wantStatusCode: http.StatusNotFound,
},
{
name: "no workspace in context",
setupContext: func() *context.HandlerContext {
return &context.HandlerContext{
UserID: 1,
UserRole: "editor",
Workspace: nil,
}
},
wantStatusCode: http.StatusOK,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create request with context
req := httptest.NewRequest("GET", "/test", nil)
req = context.WithHandlerContext(req, tc.setupContext())
w := newMockResponseWriter()
// Create test handler
nextCalled := false
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
})
// Execute middleware
middleware.RequireWorkspaceAccess(next).ServeHTTP(w, req)
// Check status code
if w.statusCode != tc.wantStatusCode {
t.Errorf("status code = %v, want %v", w.statusCode, tc.wantStatusCode)
}
// Check if next handler was called when expected
if tc.wantStatusCode == http.StatusOK && !nextCalled {
t.Error("next handler was not called")
}
if tc.wantStatusCode != http.StatusOK && nextCalled {
t.Error("next handler was called when it shouldn't have been")
}
})
}
}

View File

@@ -1,67 +1,49 @@
package auth
import (
"database/sql"
"fmt"
"novamd/internal/db"
"novamd/internal/models"
"time"
"github.com/google/uuid"
)
// Session represents a user session in the database
type Session struct {
ID string // Unique session identifier
UserID int // ID of the user this session belongs to
RefreshToken string // The refresh token associated with this session
ExpiresAt time.Time // When this session expires
CreatedAt time.Time // When this session was created
}
// SessionService manages user sessions in the database
type SessionService struct {
db *sql.DB // Database connection
jwtService *JWTService // JWT service for token operations
db db.SessionStore // Database store for sessions
jwtManager JWTManager // JWT Manager for token operations
}
// NewSessionService creates a new session service
// Parameters:
// - db: database connection
// - jwtService: JWT service for token operations
func NewSessionService(db *sql.DB, jwtService *JWTService) *SessionService {
// NewSessionService creates a new session service with the given database and JWT manager
func NewSessionService(db db.SessionStore, jwtManager JWTManager) *SessionService {
return &SessionService{
db: db,
jwtService: jwtService,
jwtManager: jwtManager,
}
}
// CreateSession creates a new user session
// Parameters:
// - userID: the ID of the user
// - role: the role of the user
// Returns:
// - session: the created session
// - accessToken: a new access token
// - error: any error that occurred
func (s *SessionService) CreateSession(userID int, role string) (*Session, string, error) {
// CreateSession creates a new user session for a user with the given userID and role
func (s *SessionService) CreateSession(userID int, role string) (*models.Session, string, error) {
// Generate both access and refresh tokens
accessToken, err := s.jwtService.GenerateAccessToken(userID, role)
accessToken, err := s.jwtManager.GenerateAccessToken(userID, role)
if err != nil {
return nil, "", fmt.Errorf("failed to generate access token: %w", err)
}
refreshToken, err := s.jwtService.GenerateRefreshToken(userID, role)
refreshToken, err := s.jwtManager.GenerateRefreshToken(userID, role)
if err != nil {
return nil, "", fmt.Errorf("failed to generate refresh token: %w", err)
}
// Validate the refresh token to get its expiry time
claims, err := s.jwtService.ValidateToken(refreshToken)
claims, err := s.jwtManager.ValidateToken(refreshToken)
if err != nil {
return nil, "", fmt.Errorf("failed to validate refresh token: %w", err)
}
// Create a new session record
session := &Session{
session := &models.Session{
ID: uuid.New().String(),
UserID: userID,
RefreshToken: refreshToken,
@@ -69,72 +51,43 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin
CreatedAt: time.Now(),
}
// Store the session in the database
_, err = s.db.Exec(`
INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at)
VALUES (?, ?, ?, ?, ?)`,
session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt,
)
if err != nil {
return nil, "", fmt.Errorf("failed to store session: %w", err)
// Store the session
if err := s.db.CreateSession(session); err != nil {
return nil, "", err
}
return session, accessToken, nil
}
// RefreshSession creates a new access token using a refresh token
// Parameters:
// - refreshToken: the refresh token to use
// Returns:
// - string: a new access token
// - error: any error that occurred
// RefreshSession creates a new access token using a refreshToken
func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
// Get session from database first
session, err := s.db.GetSessionByRefreshToken(refreshToken)
if err != nil {
return "", fmt.Errorf("invalid session: %w", err)
}
// Validate the refresh token
claims, err := s.jwtService.ValidateToken(refreshToken)
claims, err := s.jwtManager.ValidateToken(refreshToken)
if err != nil {
return "", fmt.Errorf("invalid refresh token: %w", err)
}
// Check if the session exists and is not expired
var session Session
err = s.db.QueryRow(`
SELECT id, user_id, refresh_token, expires_at, created_at
FROM sessions
WHERE refresh_token = ? AND expires_at > ?`,
refreshToken, time.Now(),
).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt)
if err == sql.ErrNoRows {
return "", fmt.Errorf("session not found or expired")
}
if err != nil {
return "", fmt.Errorf("failed to fetch session: %w", err)
// Double check that the claims match the session
if claims.UserID != session.UserID {
return "", fmt.Errorf("token does not match session")
}
// Generate a new access token
return s.jwtService.GenerateAccessToken(claims.UserID, claims.Role)
return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role)
}
// InvalidateSession removes a session from the database
// Parameters:
// - sessionID: the ID of the session to invalidate
// Returns:
// - error: any error that occurred
// InvalidateSession removes a session with the given sessionID from the database
func (s *SessionService) InvalidateSession(sessionID string) error {
_, err := s.db.Exec("DELETE FROM sessions WHERE id = ?", sessionID)
if err != nil {
return fmt.Errorf("failed to invalidate session: %w", err)
}
return nil
return s.db.DeleteSession(sessionID)
}
// CleanExpiredSessions removes all expired sessions from the database
// Returns:
// - error: any error that occurred
func (s *SessionService) CleanExpiredSessions() error {
_, err := s.db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now())
if err != nil {
return fmt.Errorf("failed to clean expired sessions: %w", err)
}
return nil
return s.db.CleanExpiredSessions()
}

View File

@@ -0,0 +1,304 @@
package auth_test
import (
"errors"
"strings"
"testing"
"time"
"novamd/internal/auth"
"novamd/internal/models"
)
// Mock SessionStore
type mockSessionStore struct {
sessions map[string]*models.Session
sessionsByToken map[string]*models.Session // Added index by refresh token
}
func newMockSessionStore() *mockSessionStore {
return &mockSessionStore{
sessions: make(map[string]*models.Session),
sessionsByToken: make(map[string]*models.Session),
}
}
func (m *mockSessionStore) CreateSession(session *models.Session) error {
m.sessions[session.ID] = session
m.sessionsByToken[session.RefreshToken] = session
return nil
}
func (m *mockSessionStore) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) {
session, exists := m.sessionsByToken[refreshToken]
if !exists {
return nil, errors.New("session not found")
}
if session.ExpiresAt.Before(time.Now()) {
return nil, errors.New("session expired")
}
return session, nil
}
func (m *mockSessionStore) DeleteSession(sessionID string) error {
session, exists := m.sessions[sessionID]
if !exists {
return errors.New("session not found")
}
delete(m.sessionsByToken, session.RefreshToken)
delete(m.sessions, sessionID)
return nil
}
func (m *mockSessionStore) CleanExpiredSessions() error {
for id, session := range m.sessions {
if session.ExpiresAt.Before(time.Now()) {
delete(m.sessionsByToken, session.RefreshToken)
delete(m.sessions, id)
}
}
return nil
}
func TestCreateSession(t *testing.T) {
config := auth.JWTConfig{
SigningKey: "test-key",
AccessTokenExpiry: 15 * time.Minute,
RefreshTokenExpiry: 24 * time.Hour,
}
jwtService, _ := auth.NewJWTService(config)
mockDB := newMockSessionStore()
sessionService := auth.NewSessionService(mockDB, jwtService)
testCases := []struct {
name string
userID int
role string
wantErr bool
}{
{
name: "successful session creation",
userID: 1,
role: "admin",
wantErr: false,
},
{
name: "another successful session",
userID: 2,
role: "editor",
wantErr: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
session, accessToken, err := sessionService.CreateSession(tc.userID, tc.role)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
// Verify session
if session.UserID != tc.userID {
t.Errorf("userID = %v, want %v", session.UserID, tc.userID)
}
// Verify the session was stored
storedSession, exists := mockDB.sessions[session.ID]
if !exists {
t.Error("session was not stored in database")
}
if storedSession.RefreshToken != session.RefreshToken {
t.Error("stored refresh token doesn't match")
}
// Verify access token
claims, err := jwtService.ValidateToken(accessToken)
if err != nil {
t.Errorf("failed to validate access token: %v", err)
return
}
if claims.UserID != tc.userID {
t.Errorf("access token userID = %v, want %v", claims.UserID, tc.userID)
}
if claims.Role != tc.role {
t.Errorf("access token role = %v, want %v", claims.Role, tc.role)
}
if claims.Type != auth.AccessToken {
t.Errorf("token type = %v, want access token", claims.Type)
}
})
}
}
func TestRefreshSession(t *testing.T) {
config := auth.JWTConfig{
SigningKey: "test-key",
AccessTokenExpiry: 15 * time.Minute,
RefreshTokenExpiry: 24 * time.Hour,
}
jwtService, _ := auth.NewJWTService(config)
mockDB := newMockSessionStore()
sessionService := auth.NewSessionService(mockDB, jwtService)
testCases := []struct {
name string
setupSession func() string
wantErr bool
errorContains string
}{
{
name: "valid refresh token",
setupSession: func() string {
token, _ := jwtService.GenerateRefreshToken(1, "admin")
session := &models.Session{
ID: "test-session-1",
UserID: 1,
RefreshToken: token,
ExpiresAt: time.Now().Add(24 * time.Hour),
CreatedAt: time.Now(),
}
if err := mockDB.CreateSession(session); err != nil {
t.Fatalf("failed to create session: %v", err)
}
return token
},
wantErr: false,
},
{
name: "expired refresh token",
setupSession: func() string {
token, _ := jwtService.GenerateRefreshToken(1, "admin")
session := &models.Session{
ID: "test-session-2",
UserID: 1,
RefreshToken: token,
ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired
CreatedAt: time.Now().Add(-2 * time.Hour),
}
if err := mockDB.CreateSession(session); err != nil {
t.Fatalf("failed to create session: %v", err)
}
return token
},
wantErr: true,
errorContains: "session expired",
},
{
name: "non-existent refresh token",
setupSession: func() string {
return "non-existent-token"
},
wantErr: true,
errorContains: "session not found",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
refreshToken := tc.setupSession()
newAccessToken, err := sessionService.RefreshSession(refreshToken)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
} else if tc.errorContains != "" && !strings.Contains(err.Error(), tc.errorContains) {
t.Errorf("error = %v, want error containing %v", err, tc.errorContains)
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
// Verify new access token
claims, err := jwtService.ValidateToken(newAccessToken)
if err != nil {
t.Errorf("failed to validate new access token: %v", err)
return
}
if claims.Type != auth.AccessToken {
t.Errorf("token type = %v, want access token", claims.Type)
}
})
}
}
func TestInvalidateSession(t *testing.T) {
config := auth.JWTConfig{
SigningKey: "test-key",
AccessTokenExpiry: 15 * time.Minute,
RefreshTokenExpiry: 24 * time.Hour,
}
jwtService, _ := auth.NewJWTService(config)
mockDB := newMockSessionStore()
sessionService := auth.NewSessionService(mockDB, jwtService)
testCases := []struct {
name string
setupSession func() string
wantErr bool
errorContains string
}{
{
name: "valid session invalidation",
setupSession: func() string {
session := &models.Session{
ID: "test-session-1",
UserID: 1,
RefreshToken: "valid-token",
ExpiresAt: time.Now().Add(24 * time.Hour),
CreatedAt: time.Now(),
}
if err := mockDB.CreateSession(session); err != nil {
t.Fatalf("failed to create session: %v", err)
}
return session.ID
},
wantErr: false,
},
{
name: "non-existent session",
setupSession: func() string {
return "non-existent-session-id"
},
wantErr: true,
errorContains: "session not found",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
sessionID := tc.setupSession()
err := sessionService.InvalidateSession(sessionID)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
} else if !strings.Contains(err.Error(), tc.errorContains) {
t.Errorf("error = %v, want error containing %v", err, tc.errorContains)
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
// Verify session was removed
if _, exists := mockDB.sessions[sessionID]; exists {
t.Error("session still exists after invalidation")
}
})
}
}

View File

@@ -1,16 +1,16 @@
// Package config provides the configuration for the application
package config
import (
"fmt"
"novamd/internal/secrets"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"novamd/internal/crypto"
)
// Config holds the configuration for the application
type Config struct {
DBPath string
WorkDir string
@@ -27,6 +27,7 @@ type Config struct {
IsDevelopment bool
}
// DefaultConfig returns a new Config instance with default values
func DefaultConfig() *Config {
return &Config{
DBPath: "./novamd.db",
@@ -39,13 +40,14 @@ func DefaultConfig() *Config {
}
}
// Validate checks if the configuration is valid
func (c *Config) Validate() error {
if c.AdminEmail == "" || c.AdminPassword == "" {
return fmt.Errorf("NOVAMD_ADMIN_EMAIL and NOVAMD_ADMIN_PASSWORD must be set")
}
// Validate encryption key
if err := crypto.ValidateKey(c.EncryptionKey); err != nil {
if err := secrets.ValidateKey(c.EncryptionKey); err != nil {
return fmt.Errorf("invalid NOVAMD_ENCRYPTION_KEY: %w", err)
}
@@ -63,16 +65,10 @@ func Load() (*Config, error) {
if dbPath := os.Getenv("NOVAMD_DB_PATH"); dbPath != "" {
config.DBPath = dbPath
}
if err := ensureDir(filepath.Dir(config.DBPath)); err != nil {
return nil, fmt.Errorf("failed to create database directory: %w", err)
}
if workDir := os.Getenv("NOVAMD_WORKDIR"); workDir != "" {
config.WorkDir = workDir
}
if err := ensureDir(config.WorkDir); err != nil {
return nil, fmt.Errorf("failed to create work directory: %w", err)
}
if staticPath := os.Getenv("NOVAMD_STATIC_PATH"); staticPath != "" {
config.StaticPath = staticPath
@@ -115,10 +111,3 @@ func Load() (*Config, error) {
return config, nil
}
func ensureDir(dir string) error {
if dir == "" {
return nil
}
return os.MkdirAll(dir, 0755)
}

View File

@@ -0,0 +1,215 @@
package config_test
import (
"os"
"testing"
"time"
"novamd/internal/config"
)
func TestDefaultConfig(t *testing.T) {
cfg := config.DefaultConfig()
tests := []struct {
name string
got interface{}
expected interface{}
}{
{"DBPath", cfg.DBPath, "./novamd.db"},
{"WorkDir", cfg.WorkDir, "./data"},
{"StaticPath", cfg.StaticPath, "../app/dist"},
{"Port", cfg.Port, "8080"},
{"RateLimitRequests", cfg.RateLimitRequests, 100},
{"RateLimitWindow", cfg.RateLimitWindow, time.Minute * 15},
{"IsDevelopment", cfg.IsDevelopment, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.got != tt.expected {
t.Errorf("DefaultConfig().%s = %v, want %v", tt.name, tt.got, tt.expected)
}
})
}
}
// setEnv is a helper function to set environment variables and check for errors
func setEnv(t *testing.T, key, value string) {
if err := os.Setenv(key, value); err != nil {
t.Fatalf("Failed to set environment variable %s: %v", key, err)
}
}
func TestLoad(t *testing.T) {
// Helper function to reset environment variables
cleanup := func() {
envVars := []string{
"NOVAMD_ENV",
"NOVAMD_DB_PATH",
"NOVAMD_WORKDIR",
"NOVAMD_STATIC_PATH",
"NOVAMD_PORT",
"NOVAMD_APP_URL",
"NOVAMD_CORS_ORIGINS",
"NOVAMD_ADMIN_EMAIL",
"NOVAMD_ADMIN_PASSWORD",
"NOVAMD_ENCRYPTION_KEY",
"NOVAMD_JWT_SIGNING_KEY",
"NOVAMD_RATE_LIMIT_REQUESTS",
"NOVAMD_RATE_LIMIT_WINDOW",
}
for _, env := range envVars {
if err := os.Unsetenv(env); err != nil {
t.Fatalf("Failed to unset environment variable %s: %v", env, err)
}
}
}
t.Run("load with defaults", func(t *testing.T) {
cleanup()
defer cleanup()
// Set required env vars
setEnv(t, "NOVAMD_ADMIN_EMAIL", "admin@example.com")
setEnv(t, "NOVAMD_ADMIN_PASSWORD", "password123")
setEnv(t, "NOVAMD_ENCRYPTION_KEY", "YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY=") // 32 bytes base64 encoded
cfg, err := config.Load()
if err != nil {
t.Fatalf("Load() error = %v", err)
}
if cfg.DBPath != "./novamd.db" {
t.Errorf("default DBPath = %v, want %v", cfg.DBPath, "./novamd.db")
}
})
t.Run("load with custom values", func(t *testing.T) {
cleanup()
defer cleanup()
// Set all environment variables
envs := map[string]string{
"NOVAMD_ENV": "development",
"NOVAMD_DB_PATH": "/custom/db/path.db",
"NOVAMD_WORKDIR": "/custom/work/dir",
"NOVAMD_STATIC_PATH": "/custom/static/path",
"NOVAMD_PORT": "3000",
"NOVAMD_APP_URL": "http://localhost:3000",
"NOVAMD_CORS_ORIGINS": "http://localhost:3000,http://localhost:3001",
"NOVAMD_ADMIN_EMAIL": "admin@example.com",
"NOVAMD_ADMIN_PASSWORD": "password123",
"NOVAMD_ENCRYPTION_KEY": "YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY=",
"NOVAMD_JWT_SIGNING_KEY": "secret-key",
"NOVAMD_RATE_LIMIT_REQUESTS": "200",
"NOVAMD_RATE_LIMIT_WINDOW": "30m",
}
for k, v := range envs {
setEnv(t, k, v)
}
cfg, err := config.Load()
if err != nil {
t.Fatalf("Load() error = %v", err)
}
tests := []struct {
name string
got interface{}
expected interface{}
}{
{"IsDevelopment", cfg.IsDevelopment, true},
{"DBPath", cfg.DBPath, "/custom/db/path.db"},
{"WorkDir", cfg.WorkDir, "/custom/work/dir"},
{"StaticPath", cfg.StaticPath, "/custom/static/path"},
{"Port", cfg.Port, "3000"},
{"AppURL", cfg.AppURL, "http://localhost:3000"},
{"AdminEmail", cfg.AdminEmail, "admin@example.com"},
{"AdminPassword", cfg.AdminPassword, "password123"},
{"JWTSigningKey", cfg.JWTSigningKey, "secret-key"},
{"RateLimitRequests", cfg.RateLimitRequests, 200},
{"RateLimitWindow", cfg.RateLimitWindow, 30 * time.Minute},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.got != tt.expected {
t.Errorf("%s = %v, want %v", tt.name, tt.got, tt.expected)
}
})
}
// Test CORS origins separately as it's a slice
expectedOrigins := []string{"http://localhost:3000", "http://localhost:3001"}
if len(cfg.CORSOrigins) != len(expectedOrigins) {
t.Errorf("CORSOrigins length = %v, want %v", len(cfg.CORSOrigins), len(expectedOrigins))
}
for i, origin := range cfg.CORSOrigins {
if origin != expectedOrigins[i] {
t.Errorf("CORSOrigins[%d] = %v, want %v", i, origin, expectedOrigins[i])
}
}
})
t.Run("validation failures", func(t *testing.T) {
testCases := []struct {
name string
setupEnv func(*testing.T)
expectedError string
}{
{
name: "missing admin email",
setupEnv: func(t *testing.T) {
cleanup()
setEnv(t, "NOVAMD_ADMIN_PASSWORD", "password123")
setEnv(t, "NOVAMD_ENCRYPTION_KEY", "YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY=")
},
expectedError: "NOVAMD_ADMIN_EMAIL and NOVAMD_ADMIN_PASSWORD must be set",
},
{
name: "missing admin password",
setupEnv: func(t *testing.T) {
cleanup()
setEnv(t, "NOVAMD_ADMIN_EMAIL", "admin@example.com")
setEnv(t, "NOVAMD_ENCRYPTION_KEY", "YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY=")
},
expectedError: "NOVAMD_ADMIN_EMAIL and NOVAMD_ADMIN_PASSWORD must be set",
},
{
name: "missing encryption key",
setupEnv: func(t *testing.T) {
cleanup()
setEnv(t, "NOVAMD_ADMIN_EMAIL", "admin@example.com")
setEnv(t, "NOVAMD_ADMIN_PASSWORD", "password123")
},
expectedError: "invalid NOVAMD_ENCRYPTION_KEY: encryption key is required",
},
{
name: "invalid encryption key",
setupEnv: func(t *testing.T) {
cleanup()
setEnv(t, "NOVAMD_ADMIN_EMAIL", "admin@example.com")
setEnv(t, "NOVAMD_ADMIN_PASSWORD", "password123")
setEnv(t, "NOVAMD_ENCRYPTION_KEY", "invalid-key")
},
expectedError: "invalid NOVAMD_ENCRYPTION_KEY: invalid base64 encoding: illegal base64 data at input byte 7",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tc.setupEnv(t)
_, err := config.Load()
if err == nil {
t.Error("expected error, got nil")
return
}
if err.Error() != tc.expectedError {
t.Errorf("error = %v, want error containing %v", err, tc.expectedError)
}
})
}
})
}

View File

@@ -0,0 +1,62 @@
// Package context provides functions for managing request context
package context
import (
"context"
"fmt"
"net/http"
"novamd/internal/models"
)
type contextKey string
const (
// HandlerContextKey is the key used to store handler context in the request context
HandlerContextKey contextKey = "handlerContext"
)
// UserClaims represents user information from authentication
type UserClaims struct {
UserID int
Role string
}
// HandlerContext holds the request-specific data available to all handlers
type HandlerContext struct {
UserID int
UserRole string
Workspace *models.Workspace // Optional, only set for workspace routes
}
// GetRequestContext retrieves the handler context from the request
func GetRequestContext(w http.ResponseWriter, r *http.Request) (*HandlerContext, bool) {
ctx := r.Context().Value(HandlerContextKey)
if ctx == nil {
http.Error(w, "Internal server error", http.StatusInternalServerError)
return nil, false
}
return ctx.(*HandlerContext), true
}
// WithHandlerContext adds handler context to the request
func WithHandlerContext(r *http.Request, hctx *HandlerContext) *http.Request {
return r.WithContext(context.WithValue(r.Context(), HandlerContextKey, hctx))
}
// GetUserFromContext retrieves user claims from the context
func GetUserFromContext(ctx context.Context) (*UserClaims, error) {
val := ctx.Value(HandlerContextKey)
if val == nil {
return nil, fmt.Errorf("no user found in context")
}
hctx, ok := val.(*HandlerContext)
if !ok {
return nil, fmt.Errorf("invalid context type")
}
return &UserClaims{
UserID: hctx.UserID,
Role: hctx.UserRole,
}, nil
}

View File

@@ -0,0 +1,139 @@
package context_test
import (
stdctx "context"
"net/http"
"net/http/httptest"
"testing"
"novamd/internal/context"
)
func TestGetRequestContext(t *testing.T) {
tests := []struct {
name string
setupCtx func() *context.HandlerContext
wantStatus int
wantOK bool
}{
{
name: "valid context",
setupCtx: func() *context.HandlerContext {
return &context.HandlerContext{
UserID: 1,
UserRole: "admin",
}
},
wantStatus: http.StatusOK,
wantOK: true,
},
{
name: "missing context",
setupCtx: func() *context.HandlerContext {
return nil
},
wantStatus: http.StatusInternalServerError,
wantOK: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create test request
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
if ctx := tt.setupCtx(); ctx != nil {
req = context.WithHandlerContext(req, ctx)
}
gotCtx, ok := context.GetRequestContext(w, req)
if ok != tt.wantOK {
t.Errorf("GetRequestContext() ok = %v, want %v", ok, tt.wantOK)
}
if !tt.wantOK {
if w.Code != tt.wantStatus {
t.Errorf("GetRequestContext() status = %v, want %v", w.Code, tt.wantStatus)
}
return
}
if gotCtx.UserID != tt.setupCtx().UserID {
t.Errorf("GetRequestContext() UserID = %v, want %v", gotCtx.UserID, tt.setupCtx().UserID)
}
if gotCtx.UserRole != tt.setupCtx().UserRole {
t.Errorf("GetRequestContext() UserRole = %v, want %v", gotCtx.UserRole, tt.setupCtx().UserRole)
}
})
}
}
func TestGetUserFromContext(t *testing.T) {
tests := []struct {
name string
setupCtx func() stdctx.Context
wantUser *context.UserClaims
wantError bool
}{
{
name: "valid user context",
setupCtx: func() stdctx.Context {
return stdctx.WithValue(stdctx.Background(), context.HandlerContextKey, &context.HandlerContext{
UserID: 1,
UserRole: "admin",
})
},
wantUser: &context.UserClaims{
UserID: 1,
Role: "admin",
},
wantError: false,
},
{
name: "missing context",
setupCtx: func() stdctx.Context {
return stdctx.Background()
},
wantUser: nil,
wantError: true,
},
{
name: "invalid context type",
setupCtx: func() stdctx.Context {
return stdctx.WithValue(stdctx.Background(), context.HandlerContextKey, "invalid")
},
wantUser: nil,
wantError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := tt.setupCtx()
gotUser, err := context.GetUserFromContext(ctx)
if tt.wantError {
if err == nil {
t.Error("GetUserFromContext() error = nil, want error")
}
return
}
if err != nil {
t.Errorf("GetUserFromContext() unexpected error = %v", err)
return
}
if gotUser.UserID != tt.wantUser.UserID {
t.Errorf("GetUserFromContext() UserID = %v, want %v", gotUser.UserID, tt.wantUser.UserID)
}
if gotUser.Role != tt.wantUser.Role {
t.Errorf("GetUserFromContext() Role = %v, want %v", gotUser.Role, tt.wantUser.Role)
}
})
}
}

View File

@@ -1,38 +1,37 @@
package middleware
package context
import (
"net/http"
"novamd/internal/auth"
"novamd/internal/db"
"novamd/internal/httpcontext"
"github.com/go-chi/chi/v5"
)
// User ID and User Role context
func WithUserContext(next http.Handler) http.Handler {
// WithUserContextMiddleware extracts user information from JWT claims
// and adds it to the request context
func WithUserContextMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims, err := auth.GetUserFromContext(r.Context())
claims, err := GetUserFromContext(r.Context())
if err != nil {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
hctx := &httpcontext.HandlerContext{
hctx := &HandlerContext{
UserID: claims.UserID,
UserRole: claims.Role,
}
r = httpcontext.WithHandlerContext(r, hctx)
r = WithHandlerContext(r, hctx)
next.ServeHTTP(w, r)
})
}
// Workspace context
func WithWorkspaceContext(db *db.DB) func(http.Handler) http.Handler {
// WithWorkspaceContextMiddleware adds workspace information to the request context
func WithWorkspaceContextMiddleware(db db.WorkspaceReader) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, ok := httpcontext.GetRequestContext(w, r)
ctx, ok := GetRequestContext(w, r)
if !ok {
return
}
@@ -46,7 +45,7 @@ func WithWorkspaceContext(db *db.DB) func(http.Handler) http.Handler {
// Update existing context with workspace
ctx.Workspace = workspace
r = httpcontext.WithHandlerContext(r, ctx)
r = WithHandlerContext(r, ctx)
next.ServeHTTP(w, r)
})
}

View File

@@ -0,0 +1,197 @@
package context_test
import (
stdctx "context"
"database/sql"
"net/http"
"net/http/httptest"
"testing"
"novamd/internal/context"
"novamd/internal/models"
)
// MockDB implements the minimal database interface needed for testing
type MockDB struct {
GetWorkspaceByNameFunc func(userID int, workspaceName string) (*models.Workspace, error)
}
func (m *MockDB) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) {
return m.GetWorkspaceByNameFunc(userID, workspaceName)
}
func (m *MockDB) GetWorkspaceByID(_ int) (*models.Workspace, error) {
return nil, nil
}
func (m *MockDB) GetWorkspacesByUserID(_ int) ([]*models.Workspace, error) {
return nil, nil
}
func (m *MockDB) GetAllWorkspaces() ([]*models.Workspace, error) {
return nil, nil
}
func TestWithUserContextMiddleware(t *testing.T) {
tests := []struct {
name string
setupCtx func() *context.HandlerContext
wantStatus int
wantNext bool
}{
{
name: "valid user context",
setupCtx: func() *context.HandlerContext {
return &context.HandlerContext{
UserID: 1,
UserRole: "admin",
}
},
wantStatus: http.StatusOK,
wantNext: true,
},
{
name: "missing user context",
setupCtx: func() *context.HandlerContext {
return nil
},
wantStatus: http.StatusUnauthorized,
wantNext: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
if ctx := tt.setupCtx(); ctx != nil {
req = context.WithHandlerContext(req, ctx)
}
nextCalled := false
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
})
middleware := context.WithUserContextMiddleware(next)
middleware.ServeHTTP(w, req)
if nextCalled != tt.wantNext {
t.Errorf("WithUserContextMiddleware() next called = %v, want %v", nextCalled, tt.wantNext)
}
if w.Code != tt.wantStatus {
t.Errorf("WithUserContextMiddleware() status = %v, want %v", w.Code, tt.wantStatus)
}
})
}
}
func TestWithWorkspaceContextMiddleware(t *testing.T) {
tests := []struct {
name string
setupCtx func() *context.HandlerContext
workspaceName string
mockWorkspace *models.Workspace
mockError error
wantStatus int
wantNext bool
}{
{
name: "valid workspace context",
setupCtx: func() *context.HandlerContext {
return &context.HandlerContext{
UserID: 1,
UserRole: "admin",
}
},
workspaceName: "test-workspace",
mockWorkspace: &models.Workspace{
ID: 1,
UserID: 1,
Name: "test-workspace",
},
mockError: nil,
wantStatus: http.StatusOK,
wantNext: true,
},
{
name: "workspace not found",
setupCtx: func() *context.HandlerContext {
return &context.HandlerContext{
UserID: 1,
UserRole: "admin",
}
},
workspaceName: "nonexistent",
mockWorkspace: nil,
mockError: sql.ErrNoRows,
wantStatus: http.StatusNotFound,
wantNext: false,
},
{
name: "missing user context",
setupCtx: func() *context.HandlerContext { return nil },
workspaceName: "test-workspace",
mockWorkspace: nil,
mockError: nil,
wantStatus: http.StatusInternalServerError,
wantNext: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockDB := &MockDB{
GetWorkspaceByNameFunc: func(_ int, _ string) (*models.Workspace, error) {
return tt.mockWorkspace, tt.mockError
},
}
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
if ctx := tt.setupCtx(); ctx != nil {
req = context.WithHandlerContext(req, ctx)
}
// Add workspace name to request context via chi URL params
req = req.WithContext(stdctx.WithValue(req.Context(), "workspaceName", tt.workspaceName))
nextCalled := false
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
// Verify workspace was added to context
if tt.mockWorkspace != nil {
ctx, ok := context.GetRequestContext(w, r)
if !ok {
t.Error("Failed to get request context in next handler")
return
}
if ctx.Workspace == nil {
t.Error("Workspace not set in context")
return
}
if ctx.Workspace.ID != tt.mockWorkspace.ID {
t.Errorf("Workspace ID = %v, want %v", ctx.Workspace.ID, tt.mockWorkspace.ID)
}
}
})
middleware := context.WithWorkspaceContextMiddleware(mockDB)(next)
middleware.ServeHTTP(w, req)
if nextCalled != tt.wantNext {
t.Errorf("WithWorkspaceContextMiddleware() next called = %v, want %v", nextCalled, tt.wantNext)
}
if w.Code != tt.wantStatus {
t.Errorf("WithWorkspaceContextMiddleware() status = %v, want %v", w.Code, tt.wantStatus)
}
})
}
}

View File

@@ -1,114 +0,0 @@
package crypto
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
)
var (
ErrKeyRequired = fmt.Errorf("encryption key is required")
ErrInvalidKeySize = fmt.Errorf("encryption key must be 32 bytes (256 bits) when decoded")
)
type Crypto struct {
key []byte
}
// ValidateKey checks if the provided key is suitable for AES-256
func ValidateKey(key string) error {
if key == "" {
return ErrKeyRequired
}
// Attempt to decode base64
keyBytes, err := base64.StdEncoding.DecodeString(key)
if err != nil {
return fmt.Errorf("invalid base64 encoding: %w", err)
}
if len(keyBytes) != 32 {
return fmt.Errorf("%w: got %d bytes", ErrInvalidKeySize, len(keyBytes))
}
// Verify the key can be used for AES
_, err = aes.NewCipher(keyBytes)
if err != nil {
return fmt.Errorf("invalid encryption key: %w", err)
}
return nil
}
// New creates a new Crypto instance with the provided base64-encoded key
func New(key string) (*Crypto, error) {
if err := ValidateKey(key); err != nil {
return nil, err
}
keyBytes, _ := base64.StdEncoding.DecodeString(key)
return &Crypto{key: keyBytes}, nil
}
// Encrypt encrypts the plaintext using AES-256-GCM
func (c *Crypto) Encrypt(plaintext string) (string, error) {
if plaintext == "" {
return "", nil
}
block, err := aes.NewCipher(c.key)
if err != nil {
return "", err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", err
}
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// Decrypt decrypts the ciphertext using AES-256-GCM
func (c *Crypto) Decrypt(ciphertext string) (string, error) {
if ciphertext == "" {
return "", nil
}
data, err := base64.StdEncoding.DecodeString(ciphertext)
if err != nil {
return "", err
}
block, err := aes.NewCipher(c.key)
if err != nil {
return "", err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return "", fmt.Errorf("ciphertext too short")
}
nonce, ciphertextBytes := data[:nonceSize], data[nonceSize:]
plaintext, err := gcm.Open(nil, nonce, ciphertextBytes, nil)
if err != nil {
return "", err
}
return string(plaintext), nil
}

View File

@@ -1,68 +0,0 @@
// Package db provides the database access layer for the application. It contains methods for interacting with the database, such as creating, updating, and deleting records.
package db
import "novamd/internal/models"
// UserStats represents system-wide statistics
type UserStats struct {
TotalUsers int `json:"totalUsers"`
TotalWorkspaces int `json:"totalWorkspaces"`
ActiveUsers int `json:"activeUsers"` // Users with activity in last 30 days
}
// GetAllUsers returns a list of all users in the system
func (db *DB) GetAllUsers() ([]*models.User, error) {
rows, err := db.Query(`
SELECT
id, email, display_name, role, created_at,
last_workspace_id
FROM users
ORDER BY id ASC`)
if err != nil {
return nil, err
}
defer rows.Close()
var users []*models.User
for rows.Next() {
user := &models.User{}
err := rows.Scan(
&user.ID, &user.Email, &user.DisplayName, &user.Role,
&user.CreatedAt, &user.LastWorkspaceID,
)
if err != nil {
return nil, err
}
users = append(users, user)
}
return users, nil
}
// GetSystemStats returns system-wide statistics
func (db *DB) GetSystemStats() (*UserStats, error) {
stats := &UserStats{}
// Get total users
err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&stats.TotalUsers)
if err != nil {
return nil, err
}
// Get total workspaces
err = db.QueryRow("SELECT COUNT(*) FROM workspaces").Scan(&stats.TotalWorkspaces)
if err != nil {
return nil, err
}
// Get active users (users with activity in last 30 days)
err = db.QueryRow(`
SELECT COUNT(DISTINCT user_id)
FROM sessions
WHERE created_at > datetime('now', '-30 days')`).
Scan(&stats.ActiveUsers)
if err != nil {
return nil, err
}
return stats, nil
}

View File

@@ -1,22 +1,104 @@
// Package db provides the database access layer for the application. It contains methods for interacting with the database, such as creating, updating, and deleting records.
package db
import (
"database/sql"
"fmt"
"novamd/internal/crypto"
"novamd/internal/models"
"novamd/internal/secrets"
_ "github.com/mattn/go-sqlite3" // SQLite driver
)
// DB represents the database connection
type DB struct {
// UserStore defines the methods for interacting with user data in the database
type UserStore interface {
CreateUser(user *models.User) (*models.User, error)
GetUserByEmail(email string) (*models.User, error)
GetUserByID(userID int) (*models.User, error)
GetAllUsers() ([]*models.User, error)
UpdateUser(user *models.User) error
DeleteUser(userID int) error
UpdateLastWorkspace(userID int, workspaceName string) error
GetLastWorkspaceName(userID int) (string, error)
CountAdminUsers() (int, error)
}
// WorkspaceReader defines the methods for reading workspace data from the database
type WorkspaceReader interface {
GetWorkspaceByID(workspaceID int) (*models.Workspace, error)
GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error)
GetWorkspacesByUserID(userID int) ([]*models.Workspace, error)
GetAllWorkspaces() ([]*models.Workspace, error)
}
// WorkspaceWriter defines the methods for writing workspace data to the database
type WorkspaceWriter interface {
CreateWorkspace(workspace *models.Workspace) error
UpdateWorkspace(workspace *models.Workspace) error
DeleteWorkspace(workspaceID int) error
UpdateWorkspaceSettings(workspace *models.Workspace) error
DeleteWorkspaceTx(tx *sql.Tx, workspaceID int) error
UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error
UpdateLastOpenedFile(workspaceID int, filePath string) error
GetLastOpenedFile(workspaceID int) (string, error)
}
// WorkspaceStore defines the methods for interacting with workspace data in the database
type WorkspaceStore interface {
WorkspaceReader
WorkspaceWriter
}
// SessionStore defines the methods for interacting with jwt sessions in the database
type SessionStore interface {
CreateSession(session *models.Session) error
GetSessionByRefreshToken(refreshToken string) (*models.Session, error)
DeleteSession(sessionID string) error
CleanExpiredSessions() error
}
// SystemStore defines the methods for interacting with system settings and stats in the database
type SystemStore interface {
GetSystemStats() (*UserStats, error)
EnsureJWTSecret() (string, error)
GetSystemSetting(key string) (string, error)
SetSystemSetting(key, value string) error
}
// Database defines the methods for interacting with the database
type Database interface {
UserStore
WorkspaceStore
SessionStore
SystemStore
Begin() (*sql.Tx, error)
Close() error
Migrate() error
}
var (
// Main Database interface
_ Database = (*database)(nil)
// Component interfaces
_ UserStore = (*database)(nil)
_ WorkspaceStore = (*database)(nil)
_ SessionStore = (*database)(nil)
_ SystemStore = (*database)(nil)
// Sub-interfaces
_ WorkspaceReader = (*database)(nil)
_ WorkspaceWriter = (*database)(nil)
)
// database represents the database connection
type database struct {
*sql.DB
crypto *crypto.Crypto
secretsService secrets.Service
}
// Init initializes the database connection
func Init(dbPath string, encryptionKey string) (*DB, error) {
func Init(dbPath string, secretsService secrets.Service) (Database, error) {
db, err := sql.Open("sqlite3", dbPath)
if err != nil {
return nil, err
@@ -26,40 +108,35 @@ func Init(dbPath string, encryptionKey string) (*DB, error) {
return nil, err
}
// Initialize crypto service
cryptoService, err := crypto.New(encryptionKey)
if err != nil {
return nil, fmt.Errorf("failed to initialize encryption: %w", err)
}
database := &DB{
DB: db,
crypto: cryptoService,
}
if err := database.Migrate(); err != nil {
// Enable foreign keys for this connection
if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
return nil, err
}
database := &database{
DB: db,
secretsService: secretsService,
}
return database, nil
}
// Close closes the database connection
func (db *DB) Close() error {
func (db *database) Close() error {
return db.DB.Close()
}
// Helper methods for token encryption/decryption
func (db *DB) encryptToken(token string) (string, error) {
func (db *database) encryptToken(token string) (string, error) {
if token == "" {
return "", nil
}
return db.crypto.Encrypt(token)
return db.secretsService.Encrypt(token)
}
func (db *DB) decryptToken(token string) (string, error) {
func (db *database) decryptToken(token string) (string, error) {
if token == "" {
return "", nil
}
return db.crypto.Decrypt(token)
return db.secretsService.Decrypt(token)
}

View File

@@ -83,7 +83,7 @@ var migrations = []Migration{
}
// Migrate applies all database migrations
func (db *DB) Migrate() error {
func (db *database) Migrate() error {
// Create migrations table if it doesn't exist
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS migrations (
version INTEGER PRIMARY KEY

View File

@@ -0,0 +1,151 @@
package db_test
import (
"testing"
"novamd/internal/db"
_ "github.com/mattn/go-sqlite3"
)
func TestMigrate(t *testing.T) {
database, err := db.NewTestDB(":memory:", &mockSecrets{})
if err != nil {
t.Fatalf("failed to initialize database: %v", err)
}
defer database.Close()
t.Run("migrations are applied in order", func(t *testing.T) {
if err := database.Migrate(); err != nil {
t.Fatalf("failed to run initial migrations: %v", err)
}
// Check migration version
var version int
err := database.TestDB().QueryRow("SELECT MAX(version) FROM migrations").Scan(&version)
if err != nil {
t.Fatalf("failed to get migration version: %v", err)
}
if version != 2 { // Current number of migrations in production code
t.Errorf("expected migration version 2, got %d", version)
}
// Verify number of migration entries matches versions applied
var count int
err = database.TestDB().QueryRow("SELECT COUNT(*) FROM migrations").Scan(&count)
if err != nil {
t.Fatalf("failed to count migrations: %v", err)
}
if count != 2 {
t.Errorf("expected 2 migration entries, got %d", count)
}
})
t.Run("migrations create expected schema", func(t *testing.T) {
// Verify tables exist
tables := []string{"users", "workspaces", "sessions", "system_settings", "migrations"}
for _, table := range tables {
if !tableExists(t, database, table) {
t.Errorf("table %q does not exist", table)
}
}
// Verify indexes
indexes := []struct {
table string
name string
}{
{"sessions", "idx_sessions_user_id"},
{"sessions", "idx_sessions_expires_at"},
{"sessions", "idx_sessions_refresh_token"},
}
for _, idx := range indexes {
if !indexExists(t, database, idx.table, idx.name) {
t.Errorf("index %q on table %q does not exist", idx.name, idx.table)
}
}
})
t.Run("migrations are idempotent", func(t *testing.T) {
// Run migrations again
if err := database.Migrate(); err != nil {
t.Fatalf("failed to re-run migrations: %v", err)
}
// Verify migration count hasn't changed
var count int
err = database.TestDB().QueryRow("SELECT COUNT(*) FROM migrations").Scan(&count)
if err != nil {
t.Fatalf("failed to count migrations: %v", err)
}
if count != 2 {
t.Errorf("expected 2 migration entries, got %d", count)
}
})
t.Run("rollback on migration failure", func(t *testing.T) {
// Create a test table that would conflict with a failing migration
_, err := database.TestDB().Exec("CREATE TABLE test_rollback (id INTEGER PRIMARY KEY)")
if err != nil {
t.Fatalf("failed to create test table: %v", err)
}
// Start transaction
tx, err := database.Begin()
if err != nil {
t.Fatalf("failed to start transaction: %v", err)
}
// Try operations that should fail and rollback
_, err = tx.Exec(`
CREATE TABLE test_rollback (id INTEGER PRIMARY KEY);
INSERT INTO nonexistent_table VALUES (1);
`)
if err == nil {
tx.Rollback()
t.Fatal("expected migration to fail")
}
tx.Rollback()
// Verify the migration version hasn't changed
var version int
err = database.TestDB().QueryRow("SELECT MAX(version) FROM migrations").Scan(&version)
if err != nil {
t.Fatalf("failed to get migration version: %v", err)
}
if version != 2 {
t.Errorf("expected migration version to remain at 2, got %d", version)
}
})
}
func tableExists(t *testing.T, database db.TestDatabase, tableName string) bool {
t.Helper()
var name string
err := database.TestDB().QueryRow(`
SELECT name FROM sqlite_master
WHERE type='table' AND name=?`,
tableName,
).Scan(&name)
return err == nil
}
func indexExists(t *testing.T, database db.TestDatabase, tableName, indexName string) bool {
t.Helper()
var name string
err := database.TestDB().QueryRow(`
SELECT name FROM sqlite_master
WHERE type='index' AND tbl_name=? AND name=?`,
tableName, indexName,
).Scan(&name)
return err == nil
}

View File

@@ -0,0 +1,6 @@
package db_test
type mockSecrets struct{}
func (m *mockSecrets) Encrypt(s string) (string, error) { return s, nil }
func (m *mockSecrets) Decrypt(s string) (string, error) { return s, nil }

View File

@@ -0,0 +1,70 @@
package db
import (
"database/sql"
"fmt"
"time"
"novamd/internal/models"
)
// CreateSession inserts a new session record into the database
func (db *database) CreateSession(session *models.Session) error {
_, err := db.Exec(`
INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at)
VALUES (?, ?, ?, ?, ?)`,
session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt,
)
if err != nil {
return fmt.Errorf("failed to store session: %w", err)
}
return nil
}
// GetSessionByRefreshToken retrieves a session by its refresh token
func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) {
session := &models.Session{}
err := db.QueryRow(`
SELECT id, user_id, refresh_token, expires_at, created_at
FROM sessions
WHERE refresh_token = ? AND expires_at > ?`,
refreshToken, time.Now(),
).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("session not found or expired")
}
if err != nil {
return nil, fmt.Errorf("failed to fetch session: %w", err)
}
return session, nil
}
// DeleteSession removes a session from the database
func (db *database) DeleteSession(sessionID string) error {
result, err := db.Exec("DELETE FROM sessions WHERE id = ?", sessionID)
if err != nil {
return fmt.Errorf("failed to delete session: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("session not found")
}
return nil
}
// CleanExpiredSessions removes all expired sessions from the database
func (db *database) CleanExpiredSessions() error {
_, err := db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now())
if err != nil {
return fmt.Errorf("failed to clean expired sessions: %w", err)
}
return nil
}

View File

@@ -0,0 +1,294 @@
package db_test
import (
"strings"
"testing"
"time"
"novamd/internal/db"
"novamd/internal/models"
"github.com/google/uuid"
)
func TestSessionOperations(t *testing.T) {
database, err := db.NewTestDB(":memory:", &mockSecrets{})
if err != nil {
t.Fatalf("failed to create test database: %v", err)
}
defer database.Close()
if err := database.Migrate(); err != nil {
t.Fatalf("failed to run migrations: %v", err)
}
// Create a test user first since sessions need a valid user ID
user, err := database.CreateUser(&models.User{
Email: "test@example.com",
DisplayName: "Test User",
PasswordHash: "hash",
Role: "editor",
})
if err != nil {
t.Fatalf("failed to create test user: %v", err)
}
t.Run("CreateSession", func(t *testing.T) {
testCases := []struct {
name string
session *models.Session
wantErr bool
errContains string
}{
{
name: "valid session",
session: &models.Session{
ID: uuid.New().String(),
UserID: user.ID,
RefreshToken: "valid-token",
ExpiresAt: time.Now().Add(24 * time.Hour),
CreatedAt: time.Now(),
},
wantErr: false,
},
{
name: "invalid user ID",
session: &models.Session{
ID: uuid.New().String(),
UserID: 99999, // Non-existent user ID
RefreshToken: "invalid-user-token",
ExpiresAt: time.Now().Add(24 * time.Hour),
CreatedAt: time.Now(),
},
wantErr: true,
errContains: "FOREIGN KEY constraint failed",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := database.CreateSession(tc.session)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
} else if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) {
t.Errorf("error = %v, want error containing %v", err, tc.errContains)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Verify session was stored
stored, err := database.GetSessionByRefreshToken(tc.session.RefreshToken)
if err != nil {
t.Fatalf("failed to retrieve stored session: %v", err)
}
// Compare fields
if stored.ID != tc.session.ID {
t.Errorf("ID = %v, want %v", stored.ID, tc.session.ID)
}
if stored.UserID != tc.session.UserID {
t.Errorf("UserID = %v, want %v", stored.UserID, tc.session.UserID)
}
if stored.RefreshToken != tc.session.RefreshToken {
t.Errorf("RefreshToken = %v, want %v", stored.RefreshToken, tc.session.RefreshToken)
}
// Compare times within a reasonable threshold
if diff := stored.ExpiresAt.Sub(tc.session.ExpiresAt); diff > time.Second || diff < -time.Second {
t.Errorf("ExpiresAt differs by %v, want difference less than 1s", diff)
}
})
}
})
t.Run("GetSessionByRefreshToken", func(t *testing.T) {
// Create test sessions
validSession := &models.Session{
ID: uuid.New().String(),
UserID: user.ID,
RefreshToken: "valid-get-token",
ExpiresAt: time.Now().Add(24 * time.Hour),
CreatedAt: time.Now(),
}
expiredSession := &models.Session{
ID: uuid.New().String(),
UserID: user.ID,
RefreshToken: "expired-token",
ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired
CreatedAt: time.Now().Add(-2 * time.Hour),
}
if err := database.CreateSession(validSession); err != nil {
t.Fatalf("failed to create valid session: %v", err)
}
if err := database.CreateSession(expiredSession); err != nil {
t.Fatalf("failed to create expired session: %v", err)
}
testCases := []struct {
name string
refreshToken string
wantErr bool
errContains string
}{
{
name: "valid token",
refreshToken: "valid-get-token",
wantErr: false,
},
{
name: "expired token",
refreshToken: "expired-token",
wantErr: true,
errContains: "session not found or expired",
},
{
name: "non-existent token",
refreshToken: "nonexistent-token",
wantErr: true,
errContains: "session not found",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
session, err := database.GetSessionByRefreshToken(tc.refreshToken)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
} else if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) {
t.Errorf("error = %v, want error containing %v", err, tc.errContains)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if session.RefreshToken != tc.refreshToken {
t.Errorf("RefreshToken = %v, want %v", session.RefreshToken, tc.refreshToken)
}
})
}
})
t.Run("DeleteSession", func(t *testing.T) {
session := &models.Session{
ID: uuid.New().String(),
UserID: user.ID,
RefreshToken: "delete-token",
ExpiresAt: time.Now().Add(24 * time.Hour),
CreatedAt: time.Now(),
}
if err := database.CreateSession(session); err != nil {
t.Fatalf("failed to create session: %v", err)
}
testCases := []struct {
name string
sessionID string
wantErr bool
errContains string
}{
{
name: "valid session ID",
sessionID: session.ID,
wantErr: false,
},
{
name: "non-existent session ID",
sessionID: "nonexistent-id",
wantErr: true,
errContains: "session not found",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := database.DeleteSession(tc.sessionID)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
} else if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) {
t.Errorf("error = %v, want error containing %v", err, tc.errContains)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Verify session was deleted
_, err = database.GetSessionByRefreshToken(session.RefreshToken)
if err == nil {
t.Error("session still exists after deletion")
}
})
}
})
t.Run("CleanExpiredSessions", func(t *testing.T) {
// Create a mix of valid and expired sessions
sessions := []*models.Session{
{
ID: uuid.New().String(),
UserID: user.ID,
RefreshToken: "valid-clean-token",
ExpiresAt: time.Now().Add(24 * time.Hour),
CreatedAt: time.Now(),
},
{
ID: uuid.New().String(),
UserID: user.ID,
RefreshToken: "expired-clean-token-1",
ExpiresAt: time.Now().Add(-1 * time.Hour),
CreatedAt: time.Now().Add(-2 * time.Hour),
},
{
ID: uuid.New().String(),
UserID: user.ID,
RefreshToken: "expired-clean-token-2",
ExpiresAt: time.Now().Add(-2 * time.Hour),
CreatedAt: time.Now().Add(-3 * time.Hour),
},
}
for _, s := range sessions {
if err := database.CreateSession(s); err != nil {
t.Fatalf("failed to create session: %v", err)
}
}
// Clean expired sessions
if err := database.CleanExpiredSessions(); err != nil {
t.Fatalf("failed to clean expired sessions: %v", err)
}
// Verify valid session still exists
validSession, err := database.GetSessionByRefreshToken("valid-clean-token")
if err != nil {
t.Errorf("valid session was unexpectedly deleted: %v", err)
}
if validSession == nil {
t.Error("valid session was unexpectedly deleted")
}
// Verify expired sessions were deleted
expiredTokens := []string{"expired-clean-token-1", "expired-clean-token-2"}
for _, token := range expiredTokens {
if _, err := database.GetSessionByRefreshToken(token); err == nil {
t.Errorf("expired session with token %s still exists", token)
}
}
})
}

View File

@@ -11,9 +11,16 @@ const (
JWTSecretKey = "jwt_secret"
)
// UserStats represents system-wide statistics
type UserStats struct {
TotalUsers int `json:"totalUsers"`
TotalWorkspaces int `json:"totalWorkspaces"`
ActiveUsers int `json:"activeUsers"` // Users with activity in last 30 days
}
// EnsureJWTSecret makes sure a JWT signing secret exists in the database
// If no secret exists, it generates and stores a new one
func (db *DB) EnsureJWTSecret() (string, error) {
func (db *database) EnsureJWTSecret() (string, error) {
// First, try to get existing secret
secret, err := db.GetSystemSetting(JWTSecretKey)
if err == nil {
@@ -36,7 +43,7 @@ func (db *DB) EnsureJWTSecret() (string, error) {
}
// GetSystemSetting retrieves a system setting by key
func (db *DB) GetSystemSetting(key string) (string, error) {
func (db *database) GetSystemSetting(key string) (string, error) {
var value string
err := db.QueryRow("SELECT value FROM system_settings WHERE key = ?", key).Scan(&value)
if err != nil {
@@ -46,7 +53,7 @@ func (db *DB) GetSystemSetting(key string) (string, error) {
}
// SetSystemSetting stores or updates a system setting
func (db *DB) SetSystemSetting(key, value string) error {
func (db *database) SetSystemSetting(key, value string) error {
_, err := db.Exec(`
INSERT INTO system_settings (key, value)
VALUES (?, ?)
@@ -64,3 +71,32 @@ func generateRandomSecret(bytes int) (string, error) {
}
return base64.StdEncoding.EncodeToString(b), nil
}
// GetSystemStats returns system-wide statistics
func (db *database) GetSystemStats() (*UserStats, error) {
stats := &UserStats{}
// Get total users
err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&stats.TotalUsers)
if err != nil {
return nil, err
}
// Get total workspaces
err = db.QueryRow("SELECT COUNT(*) FROM workspaces").Scan(&stats.TotalWorkspaces)
if err != nil {
return nil, err
}
// Get active users (users with activity in last 30 days)
err = db.QueryRow(`
SELECT COUNT(DISTINCT user_id)
FROM sessions
WHERE created_at > datetime('now', '-30 days')`).
Scan(&stats.ActiveUsers)
if err != nil {
return nil, err
}
return stats, nil
}

View File

@@ -0,0 +1,213 @@
package db_test
import (
"encoding/base64"
"fmt"
"strings"
"testing"
"time"
"novamd/internal/db"
"novamd/internal/models"
"github.com/google/uuid"
)
func TestSystemOperations(t *testing.T) {
database, err := db.NewTestDB(":memory:", &mockSecrets{})
if err != nil {
t.Fatalf("failed to create test database: %v", err)
}
defer database.Close()
if err := database.Migrate(); err != nil {
t.Fatalf("failed to run migrations: %v", err)
}
t.Run("GetSystemSettings", func(t *testing.T) {
t.Run("non-existent setting", func(t *testing.T) {
_, err := database.GetSystemSetting("nonexistent-key")
if err == nil {
t.Error("expected error for non-existent key, got nil")
}
})
t.Run("existing setting", func(t *testing.T) {
// First set a value
err := database.SetSystemSetting("test-key", "test-value")
if err != nil {
t.Fatalf("failed to set system setting: %v", err)
}
// Then get it back
value, err := database.GetSystemSetting("test-key")
if err != nil {
t.Fatalf("failed to get system setting: %v", err)
}
if value != "test-value" {
t.Errorf("got value %q, want %q", value, "test-value")
}
})
})
t.Run("SetSystemSettings", func(t *testing.T) {
testCases := []struct {
name string
key string
value string
wantErr bool
errContains string
}{
{
name: "new setting",
key: "new-key",
value: "new-value",
wantErr: false,
},
{
name: "update existing setting",
key: "update-key",
value: "original-value",
wantErr: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := database.SetSystemSetting(tc.key, tc.value)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
} else if !strings.Contains(err.Error(), tc.errContains) {
t.Errorf("error = %v, want error containing %v", err, tc.errContains)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Verify the setting was stored
stored, err := database.GetSystemSetting(tc.key)
if err != nil {
t.Fatalf("failed to retrieve stored setting: %v", err)
}
if stored != tc.value {
t.Errorf("got value %q, want %q", stored, tc.value)
}
// For the update case, test updating the value
if tc.name == "update existing setting" {
newValue := "updated-value"
err := database.SetSystemSetting(tc.key, newValue)
if err != nil {
t.Fatalf("failed to update setting: %v", err)
}
stored, err := database.GetSystemSetting(tc.key)
if err != nil {
t.Fatalf("failed to retrieve updated setting: %v", err)
}
if stored != newValue {
t.Errorf("got updated value %q, want %q", stored, newValue)
}
}
})
}
})
t.Run("EnsureJWTSecret", func(t *testing.T) {
// First call should generate a new secret
secret1, err := database.EnsureJWTSecret()
if err != nil {
t.Fatalf("failed to ensure JWT secret: %v", err)
}
// Verify the secret is a valid base64-encoded string of sufficient length
decoded, err := base64.StdEncoding.DecodeString(secret1)
if err != nil {
t.Errorf("secret is not valid base64: %v", err)
}
if len(decoded) < 32 {
t.Errorf("secret length = %d, want >= 32", len(decoded))
}
// Second call should return the same secret
secret2, err := database.EnsureJWTSecret()
if err != nil {
t.Fatalf("failed to get existing JWT secret: %v", err)
}
if secret2 != secret1 {
t.Errorf("got different secret on second call")
}
})
t.Run("GetSystemStats", func(t *testing.T) {
// Create some test users and sessions
users := []*models.User{
{
Email: "user1@test.com",
DisplayName: "User 1",
PasswordHash: "hash1",
Role: "editor",
},
{
Email: "user2@test.com",
DisplayName: "User 2",
PasswordHash: "hash2",
Role: "viewer",
},
}
for _, u := range users {
createdUser, err := database.CreateUser(u)
if err != nil {
t.Fatalf("failed to create test user: %v", err)
}
// Create multiple workspaces per user
// Each user has one default workspace
for i := 0; i < 2; i++ {
workspace := &models.Workspace{
UserID: createdUser.ID,
Name: fmt.Sprintf("Workspace %d", i),
}
if err := database.CreateWorkspace(workspace); err != nil {
t.Fatalf("failed to create test workspace: %v", err)
}
}
// Create an active session for the first user
if createdUser.Email == "user1@test.com" {
session := &models.Session{
ID: uuid.New().String(),
UserID: createdUser.ID,
RefreshToken: "test-token",
ExpiresAt: time.Now().Add(24 * time.Hour),
CreatedAt: time.Now(),
}
if err := database.CreateSession(session); err != nil {
t.Fatalf("failed to create test session: %v", err)
}
}
}
stats, err := database.GetSystemStats()
if err != nil {
t.Fatalf("failed to get system stats: %v", err)
}
// Verify stats
if stats.TotalUsers != 2 {
t.Errorf("TotalUsers = %d, want 2", stats.TotalUsers)
}
if stats.TotalWorkspaces != 6 { // 2 + 1 default workspace per user
t.Errorf("TotalWorkspaces = %d, want 6", stats.TotalWorkspaces)
}
if stats.ActiveUsers != 1 { // Only user1 has an active session
t.Errorf("ActiveUsers = %d, want 1", stats.ActiveUsers)
}
})
}

View File

@@ -0,0 +1,30 @@
//go:build test
package db
import (
"database/sql"
"novamd/internal/secrets"
)
type TestDatabase interface {
Database
TestDB() *sql.DB
}
func NewTestDB(dbPath string, secretsService secrets.Service) (TestDatabase, error) {
db, err := Init(dbPath, secretsService)
if err != nil {
return nil, err
}
return &testDatabase{db.(*database)}, nil
}
type testDatabase struct {
*database
}
func (td *testDatabase) TestDB() *sql.DB {
return td.DB
}

View File

@@ -6,7 +6,7 @@ import (
)
// CreateUser inserts a new user record into the database
func (db *DB) CreateUser(user *models.User) (*models.User, error) {
func (db *database) CreateUser(user *models.User) (*models.User, error) {
tx, err := db.Begin()
if err != nil {
return nil, err
@@ -38,7 +38,7 @@ func (db *DB) CreateUser(user *models.User) (*models.User, error) {
UserID: user.ID,
Name: "Main",
}
defaultWorkspace.GetDefaultSettings() // Initialize default settings
defaultWorkspace.SetDefaultSettings() // Initialize default settings
// Create workspace with settings
err = db.createWorkspaceTx(tx, defaultWorkspace)
@@ -62,14 +62,14 @@ func (db *DB) CreateUser(user *models.User) (*models.User, error) {
}
// Helper function to create a workspace in a transaction
func (db *DB) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
result, err := tx.Exec(`
INSERT INTO workspaces (
user_id, name,
theme, auto_save, show_hidden_files,
git_enabled, git_url, git_user, git_token,
git_auto_commit, git_commit_msg_template
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
workspace.UserID, workspace.Name,
workspace.Theme, workspace.AutoSave, workspace.ShowHiddenFiles,
workspace.GitEnabled, workspace.GitURL, workspace.GitUser, workspace.GitToken,
@@ -87,7 +87,7 @@ func (db *DB) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
}
// GetUserByID retrieves a user by ID
func (db *DB) GetUserByID(id int) (*models.User, error) {
func (db *database) GetUserByID(id int) (*models.User, error) {
user := &models.User{}
err := db.QueryRow(`
SELECT
@@ -104,7 +104,7 @@ func (db *DB) GetUserByID(id int) (*models.User, error) {
}
// GetUserByEmail retrieves a user by email
func (db *DB) GetUserByEmail(email string) (*models.User, error) {
func (db *database) GetUserByEmail(email string) (*models.User, error) {
user := &models.User{}
err := db.QueryRow(`
SELECT
@@ -122,7 +122,7 @@ func (db *DB) GetUserByEmail(email string) (*models.User, error) {
}
// UpdateUser updates a user's information
func (db *DB) UpdateUser(user *models.User) error {
func (db *database) UpdateUser(user *models.User) error {
_, err := db.Exec(`
UPDATE users
SET email = ?, display_name = ?, password_hash = ?, role = ?, last_workspace_id = ?
@@ -131,8 +131,36 @@ func (db *DB) UpdateUser(user *models.User) error {
return err
}
// GetAllUsers returns a list of all users in the system
func (db *database) GetAllUsers() ([]*models.User, error) {
rows, err := db.Query(`
SELECT
id, email, display_name, role, created_at,
last_workspace_id
FROM users
ORDER BY id ASC`)
if err != nil {
return nil, err
}
defer rows.Close()
var users []*models.User
for rows.Next() {
user := &models.User{}
err := rows.Scan(
&user.ID, &user.Email, &user.DisplayName, &user.Role,
&user.CreatedAt, &user.LastWorkspaceID,
)
if err != nil {
return nil, err
}
users = append(users, user)
}
return users, nil
}
// UpdateLastWorkspace updates the last workspace the user accessed
func (db *DB) UpdateLastWorkspace(userID int, workspaceName string) error {
func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error {
tx, err := db.Begin()
if err != nil {
return err
@@ -155,7 +183,7 @@ func (db *DB) UpdateLastWorkspace(userID int, workspaceName string) error {
}
// DeleteUser deletes a user and all their workspaces
func (db *DB) DeleteUser(id int) error {
func (db *database) DeleteUser(id int) error {
tx, err := db.Begin()
if err != nil {
return err
@@ -178,7 +206,7 @@ func (db *DB) DeleteUser(id int) error {
}
// GetLastWorkspaceName returns the name of the last workspace the user accessed
func (db *DB) GetLastWorkspaceName(userID int) (string, error) {
func (db *database) GetLastWorkspaceName(userID int) (string, error) {
var workspaceName string
err := db.QueryRow(`
SELECT
@@ -189,3 +217,10 @@ func (db *DB) GetLastWorkspaceName(userID int) (string, error) {
Scan(&workspaceName)
return workspaceName, err
}
// CountAdminUsers returns the number of admin users in the system
func (db *database) CountAdminUsers() (int, error) {
var count int
err := db.QueryRow("SELECT COUNT(*) FROM users WHERE role = 'admin'").Scan(&count)
return count, err
}

View File

@@ -0,0 +1,413 @@
package db_test
import (
"strings"
"testing"
"novamd/internal/db"
"novamd/internal/models"
)
func TestUserOperations(t *testing.T) {
database, err := db.NewTestDB(":memory:", &mockSecrets{})
if err != nil {
t.Fatalf("failed to create test database: %v", err)
}
defer database.Close()
if err := database.Migrate(); err != nil {
t.Fatalf("failed to run migrations: %v", err)
}
t.Run("CreateUser", func(t *testing.T) {
testCases := []struct {
name string
user *models.User
wantErr bool
errContains string
}{
{
name: "valid user",
user: &models.User{
Email: "test@example.com",
DisplayName: "Test User",
PasswordHash: "hashed_password",
Role: models.RoleEditor,
},
wantErr: false,
},
{
name: "duplicate email",
user: &models.User{
Email: "test@example.com", // Same as above
DisplayName: "Another User",
PasswordHash: "different_hash",
Role: models.RoleViewer,
},
wantErr: true,
errContains: "UNIQUE constraint failed",
},
{
name: "invalid role",
user: &models.User{
Email: "invalid@example.com",
DisplayName: "Invalid Role User",
PasswordHash: "hash",
Role: "invalid_role",
},
wantErr: true,
errContains: "CHECK constraint failed",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
user, err := database.CreateUser(tc.user)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
} else if !strings.Contains(err.Error(), tc.errContains) {
t.Errorf("error = %v, want error containing %v", err, tc.errContains)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Verify user was created properly
if user.ID == 0 {
t.Error("expected non-zero user ID")
}
if user.Email != tc.user.Email {
t.Errorf("Email = %v, want %v", user.Email, tc.user.Email)
}
if user.DisplayName != tc.user.DisplayName {
t.Errorf("DisplayName = %v, want %v", user.DisplayName, tc.user.DisplayName)
}
if user.Role != tc.user.Role {
t.Errorf("Role = %v, want %v", user.Role, tc.user.Role)
}
if user.CreatedAt.IsZero() {
t.Error("CreatedAt should not be zero")
}
if user.LastWorkspaceID == 0 {
t.Error("expected non-zero LastWorkspaceID (default workspace)")
}
})
}
})
t.Run("GetUserByID", func(t *testing.T) {
// Create a test user first
createdUser, err := database.CreateUser(&models.User{
Email: "getbyid@example.com",
DisplayName: "Get By ID User",
PasswordHash: "hash",
Role: models.RoleEditor,
})
if err != nil {
t.Fatalf("failed to create test user: %v", err)
}
testCases := []struct {
name string
userID int
wantErr bool
}{
{
name: "existing user",
userID: createdUser.ID,
wantErr: false,
},
{
name: "non-existent user",
userID: 99999,
wantErr: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
user, err := database.GetUserByID(tc.userID)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if user.ID != tc.userID {
t.Errorf("ID = %v, want %v", user.ID, tc.userID)
}
})
}
})
t.Run("GetUserByEmail", func(t *testing.T) {
// Create a test user first
createdUser, err := database.CreateUser(&models.User{
Email: "getbyemail@example.com",
DisplayName: "Get By Email User",
PasswordHash: "hash",
Role: models.RoleEditor,
})
if err != nil {
t.Fatalf("failed to create test user: %v", err)
}
testCases := []struct {
name string
email string
wantErr bool
}{
{
name: "existing user",
email: createdUser.Email,
wantErr: false,
},
{
name: "non-existent user",
email: "nonexistent@example.com",
wantErr: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
user, err := database.GetUserByEmail(tc.email)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if user.Email != tc.email {
t.Errorf("Email = %v, want %v", user.Email, tc.email)
}
})
}
})
t.Run("UpdateUser", func(t *testing.T) {
// Create a test user first
user, err := database.CreateUser(&models.User{
Email: "update@example.com",
DisplayName: "Original Name",
PasswordHash: "original_hash",
Role: models.RoleEditor,
})
if err != nil {
t.Fatalf("failed to create test user: %v", err)
}
// Update user details
user.DisplayName = "Updated Name"
user.PasswordHash = "new_hash"
user.Role = models.RoleAdmin
if err := database.UpdateUser(user); err != nil {
t.Fatalf("failed to update user: %v", err)
}
// Verify updates
updated, err := database.GetUserByID(user.ID)
if err != nil {
t.Fatalf("failed to get updated user: %v", err)
}
if updated.DisplayName != "Updated Name" {
t.Errorf("DisplayName = %v, want %v", updated.DisplayName, "Updated Name")
}
if updated.PasswordHash != "new_hash" {
t.Errorf("PasswordHash = %v, want %v", updated.PasswordHash, "new_hash")
}
if updated.Role != models.RoleAdmin {
t.Errorf("Role = %v, want %v", updated.Role, models.RoleAdmin)
}
})
t.Run("GetAllUsers", func(t *testing.T) {
// Create several test users
testUsers := []*models.User{
{
Email: "user1@example.com",
DisplayName: "User One",
PasswordHash: "hash1",
Role: models.RoleEditor,
},
{
Email: "user2@example.com",
DisplayName: "User Two",
PasswordHash: "hash2",
Role: models.RoleViewer,
},
}
for _, u := range testUsers {
_, err := database.CreateUser(u)
if err != nil {
t.Fatalf("failed to create test user: %v", err)
}
}
// Get all users
users, err := database.GetAllUsers()
if err != nil {
t.Fatalf("failed to get all users: %v", err)
}
// We should have at least as many users as we just created
// (there might be more from previous tests)
if len(users) < len(testUsers) {
t.Errorf("got %d users, want at least %d", len(users), len(testUsers))
}
// Verify each test user exists in the result
for _, expected := range testUsers {
found := false
for _, u := range users {
if u.Email == expected.Email {
found = true
if u.DisplayName != expected.DisplayName {
t.Errorf("DisplayName = %v, want %v", u.DisplayName, expected.DisplayName)
}
if u.Role != expected.Role {
t.Errorf("Role = %v, want %v", u.Role, expected.Role)
}
break
}
}
if !found {
t.Errorf("user with email %s not found in results", expected.Email)
}
}
})
t.Run("UpdateLastWorkspace", func(t *testing.T) {
// Create a test user with multiple workspaces
user, err := database.CreateUser(&models.User{
Email: "workspace@example.com",
DisplayName: "Workspace User",
PasswordHash: "hash",
Role: models.RoleEditor,
})
if err != nil {
t.Fatalf("failed to create test user: %v", err)
}
// Create additional workspace
workspace := &models.Workspace{
UserID: user.ID,
Name: "Second Workspace",
}
if err := database.CreateWorkspace(workspace); err != nil {
t.Fatalf("failed to create additional workspace: %v", err)
}
// Update last workspace
err = database.UpdateLastWorkspace(user.ID, workspace.Name)
if err != nil {
t.Fatalf("failed to update last workspace: %v", err)
}
// Verify update
lastWorkspace, err := database.GetLastWorkspaceName(user.ID)
if err != nil {
t.Fatalf("failed to get last workspace: %v", err)
}
if lastWorkspace != workspace.Name {
t.Errorf("LastWorkspace = %v, want %v", lastWorkspace, workspace.Name)
}
})
t.Run("DeleteUser", func(t *testing.T) {
// Create a test user
user, err := database.CreateUser(&models.User{
Email: "delete@example.com",
DisplayName: "Delete User",
PasswordHash: "hash",
Role: models.RoleEditor,
})
if err != nil {
t.Fatalf("failed to create test user: %v", err)
}
// Delete the user
if err := database.DeleteUser(user.ID); err != nil {
t.Fatalf("failed to delete user: %v", err)
}
// Verify user is gone
_, err = database.GetUserByID(user.ID)
if err == nil {
t.Error("expected error getting deleted user, got nil")
}
// Verify workspaces are gone
workspaces, err := database.GetWorkspacesByUserID(user.ID)
if err != nil {
t.Fatalf("unexpected error checking workspaces: %v", err)
}
if len(workspaces) > 0 {
t.Error("expected no workspaces for deleted user")
}
})
t.Run("CountAdminUsers", func(t *testing.T) {
// Create users with different roles
testUsers := []*models.User{
{
Email: "admin1@example.com",
DisplayName: "Admin One",
PasswordHash: "hash1",
Role: models.RoleAdmin,
},
{
Email: "admin2@example.com",
DisplayName: "Admin Two",
PasswordHash: "hash2",
Role: models.RoleAdmin,
},
{
Email: "editor@example.com",
DisplayName: "Editor",
PasswordHash: "hash3",
Role: models.RoleEditor,
},
}
for _, u := range testUsers {
_, err := database.CreateUser(u)
if err != nil {
t.Fatalf("failed to create test user: %v", err)
}
}
// Count admin users
count, err := database.CountAdminUsers()
if err != nil {
t.Fatalf("failed to count admin users: %v", err)
}
// We should have at least 2 admin users (from our test cases)
// There might be more from previous tests
if count < 2 {
t.Errorf("AdminCount = %d, want at least 2", count)
}
})
}

View File

@@ -7,10 +7,10 @@ import (
)
// CreateWorkspace inserts a new workspace record into the database
func (db *DB) CreateWorkspace(workspace *models.Workspace) error {
func (db *database) CreateWorkspace(workspace *models.Workspace) error {
// Set default settings if not provided
if workspace.Theme == "" {
workspace.GetDefaultSettings()
workspace.SetDefaultSettings()
}
// Encrypt token if present
@@ -24,7 +24,7 @@ func (db *DB) CreateWorkspace(workspace *models.Workspace) error {
user_id, name, theme, auto_save, show_hidden_files,
git_enabled, git_url, git_user, git_token,
git_auto_commit, git_commit_msg_template
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
workspace.UserID, workspace.Name, workspace.Theme, workspace.AutoSave, workspace.ShowHiddenFiles,
workspace.GitEnabled, workspace.GitURL, workspace.GitUser, encryptedToken,
workspace.GitAutoCommit, workspace.GitCommitMsgTemplate,
@@ -42,7 +42,7 @@ func (db *DB) CreateWorkspace(workspace *models.Workspace) error {
}
// GetWorkspaceByID retrieves a workspace by its ID
func (db *DB) GetWorkspaceByID(id int) (*models.Workspace, error) {
func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) {
workspace := &models.Workspace{}
var encryptedToken string
@@ -75,7 +75,7 @@ func (db *DB) GetWorkspaceByID(id int) (*models.Workspace, error) {
}
// GetWorkspaceByName retrieves a workspace by its name and user ID
func (db *DB) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) {
func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) {
workspace := &models.Workspace{}
var encryptedToken string
@@ -108,7 +108,7 @@ func (db *DB) GetWorkspaceByName(userID int, workspaceName string) (*models.Work
}
// UpdateWorkspace updates a workspace record in the database
func (db *DB) UpdateWorkspace(workspace *models.Workspace) error {
func (db *database) UpdateWorkspace(workspace *models.Workspace) error {
// Encrypt token before storing
encryptedToken, err := db.encryptToken(workspace.GitToken)
if err != nil {
@@ -146,7 +146,7 @@ func (db *DB) UpdateWorkspace(workspace *models.Workspace) error {
}
// GetWorkspacesByUserID retrieves all workspaces for a user
func (db *DB) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
rows, err := db.Query(`
SELECT
id, user_id, name, created_at,
@@ -189,7 +189,7 @@ func (db *DB) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
// UpdateWorkspaceSettings updates only the settings portion of a workspace
// This is useful when you don't want to modify the name or other core workspace properties
func (db *DB) UpdateWorkspaceSettings(workspace *models.Workspace) error {
func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
_, err := db.Exec(`
UPDATE workspaces
SET
@@ -218,31 +218,31 @@ func (db *DB) UpdateWorkspaceSettings(workspace *models.Workspace) error {
}
// DeleteWorkspace removes a workspace record from the database
func (db *DB) DeleteWorkspace(id int) error {
func (db *database) DeleteWorkspace(id int) error {
_, err := db.Exec("DELETE FROM workspaces WHERE id = ?", id)
return err
}
// DeleteWorkspaceTx removes a workspace record from the database within a transaction
func (db *DB) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
_, err := tx.Exec("DELETE FROM workspaces WHERE id = ?", id)
return err
}
// UpdateLastWorkspaceTx sets the last workspace for a user in with a transaction
func (db *DB) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error {
func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error {
_, err := tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?", workspaceID, userID)
return err
}
// UpdateLastOpenedFile updates the last opened file path for a workspace
func (db *DB) UpdateLastOpenedFile(workspaceID int, filePath string) error {
func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error {
_, err := db.Exec("UPDATE workspaces SET last_opened_file_path = ? WHERE id = ?", filePath, workspaceID)
return err
}
// GetLastOpenedFile retrieves the last opened file path for a workspace
func (db *DB) GetLastOpenedFile(workspaceID int) (string, error) {
func (db *database) GetLastOpenedFile(workspaceID int) (string, error) {
var filePath sql.NullString
err := db.QueryRow("SELECT last_opened_file_path FROM workspaces WHERE id = ?", workspaceID).Scan(&filePath)
if err != nil {
@@ -255,7 +255,7 @@ func (db *DB) GetLastOpenedFile(workspaceID int) (string, error) {
}
// GetAllWorkspaces retrieves all workspaces in the database
func (db *DB) GetAllWorkspaces() ([]*models.Workspace, error) {
func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) {
rows, err := db.Query(`
SELECT
id, user_id, name, created_at,

View File

@@ -0,0 +1,430 @@
package db_test
import (
"database/sql"
"strings"
"testing"
"novamd/internal/db"
"novamd/internal/models"
)
func TestWorkspaceOperations(t *testing.T) {
database, err := db.NewTestDB(":memory:", &mockSecrets{})
if err != nil {
t.Fatalf("failed to create test database: %v", err)
}
defer database.Close()
if err := database.Migrate(); err != nil {
t.Fatalf("failed to run migrations: %v", err)
}
// Create a test user first
user, err := database.CreateUser(&models.User{
Email: "test@example.com",
DisplayName: "Test User",
PasswordHash: "hash",
Role: models.RoleEditor,
})
if err != nil {
t.Fatalf("failed to create test user: %v", err)
}
t.Run("CreateWorkspace", func(t *testing.T) {
testCases := []struct {
name string
workspace *models.Workspace
wantErr bool
errContains string
}{
{
name: "valid workspace",
workspace: &models.Workspace{
UserID: user.ID,
Name: "Test Workspace",
},
wantErr: false,
},
{
name: "non-existent user",
workspace: &models.Workspace{
UserID: 99999,
Name: "Invalid User",
},
wantErr: true,
errContains: "FOREIGN KEY constraint failed",
},
{
name: "with git settings",
workspace: &models.Workspace{
UserID: user.ID,
Name: "Git Workspace",
Theme: "dark",
AutoSave: true,
ShowHiddenFiles: true,
GitEnabled: true,
GitURL: "https://github.com/user/repo",
GitUser: "username",
GitToken: "secret-token",
GitAutoCommit: true,
GitCommitMsgTemplate: "${action} ${filename}",
},
wantErr: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.workspace.Theme == "" {
tc.workspace.SetDefaultSettings()
}
err := database.CreateWorkspace(tc.workspace)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
} else if !strings.Contains(err.Error(), tc.errContains) {
t.Errorf("error = %v, want error containing %v", err, tc.errContains)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Verify workspace was created properly
if tc.workspace.ID == 0 {
t.Error("expected non-zero workspace ID")
}
// Retrieve and verify workspace
stored, err := database.GetWorkspaceByID(tc.workspace.ID)
if err != nil {
t.Fatalf("failed to retrieve workspace: %v", err)
}
verifyWorkspace(t, stored, tc.workspace)
})
}
})
t.Run("GetWorkspaceByID", func(t *testing.T) {
// Create a test workspace first
workspace := &models.Workspace{
UserID: user.ID,
Name: "Get By ID Workspace",
}
workspace.SetDefaultSettings()
if err := database.CreateWorkspace(workspace); err != nil {
t.Fatalf("failed to create test workspace: %v", err)
}
testCases := []struct {
name string
workspaceID int
wantErr bool
}{
{
name: "existing workspace",
workspaceID: workspace.ID,
wantErr: false,
},
{
name: "non-existent workspace",
workspaceID: 99999,
wantErr: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result, err := database.GetWorkspaceByID(tc.workspaceID)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.ID != tc.workspaceID {
t.Errorf("ID = %v, want %v", result.ID, tc.workspaceID)
}
})
}
})
t.Run("GetWorkspaceByName", func(t *testing.T) {
// Create a test workspace first
workspace := &models.Workspace{
UserID: user.ID,
Name: "Get By Name Workspace",
}
workspace.SetDefaultSettings()
if err := database.CreateWorkspace(workspace); err != nil {
t.Fatalf("failed to create test workspace: %v", err)
}
testCases := []struct {
name string
userID int
workspaceName string
wantErr bool
}{
{
name: "existing workspace",
userID: user.ID,
workspaceName: workspace.Name,
wantErr: false,
},
{
name: "wrong user ID",
userID: 99999,
workspaceName: workspace.Name,
wantErr: true,
},
{
name: "non-existent workspace",
userID: user.ID,
workspaceName: "Non-existent",
wantErr: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result, err := database.GetWorkspaceByName(tc.userID, tc.workspaceName)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Name != tc.workspaceName {
t.Errorf("Name = %v, want %v", result.Name, tc.workspaceName)
}
if result.UserID != tc.userID {
t.Errorf("UserID = %v, want %v", result.UserID, tc.userID)
}
})
}
})
t.Run("UpdateWorkspace", func(t *testing.T) {
// Create a test workspace first
workspace := &models.Workspace{
UserID: user.ID,
Name: "Update Workspace",
}
workspace.SetDefaultSettings()
if err := database.CreateWorkspace(workspace); err != nil {
t.Fatalf("failed to create test workspace: %v", err)
}
// Update workspace settings
workspace.Theme = "dark"
workspace.AutoSave = true
workspace.ShowHiddenFiles = true
workspace.GitEnabled = true
workspace.GitURL = "https://github.com/user/repo"
workspace.GitUser = "username"
workspace.GitToken = "new-token"
workspace.GitAutoCommit = true
workspace.GitCommitMsgTemplate = "custom ${filename}"
if err := database.UpdateWorkspace(workspace); err != nil {
t.Fatalf("failed to update workspace: %v", err)
}
// Verify updates
updated, err := database.GetWorkspaceByID(workspace.ID)
if err != nil {
t.Fatalf("failed to get updated workspace: %v", err)
}
verifyWorkspace(t, updated, workspace)
})
t.Run("GetWorkspacesByUserID", func(t *testing.T) {
// Create several test workspaces
testWorkspaces := []*models.Workspace{
{
UserID: user.ID,
Name: "User Workspace 1",
},
{
UserID: user.ID,
Name: "User Workspace 2",
},
}
for _, w := range testWorkspaces {
w.SetDefaultSettings()
if err := database.CreateWorkspace(w); err != nil {
t.Fatalf("failed to create test workspace: %v", err)
}
}
// Get all workspaces for user
workspaces, err := database.GetWorkspacesByUserID(user.ID)
if err != nil {
t.Fatalf("failed to get workspaces: %v", err)
}
// We should have at least as many workspaces as we just created
// (there might be more from previous tests)
if len(workspaces) < len(testWorkspaces) {
t.Errorf("got %d workspaces, want at least %d", len(workspaces), len(testWorkspaces))
}
// Verify each test workspace exists in the result
for _, expected := range testWorkspaces {
found := false
for _, w := range workspaces {
if w.Name == expected.Name {
found = true
if w.UserID != expected.UserID {
t.Errorf("UserID = %v, want %v", w.UserID, expected.UserID)
}
break
}
}
if !found {
t.Errorf("workspace %s not found in results", expected.Name)
}
}
})
t.Run("UpdateLastOpenedFile", func(t *testing.T) {
// Create a test workspace
workspace := &models.Workspace{
UserID: user.ID,
Name: "Last File Workspace",
}
workspace.SetDefaultSettings()
if err := database.CreateWorkspace(workspace); err != nil {
t.Fatalf("failed to create test workspace: %v", err)
}
testCases := []struct {
name string
filePath string
wantErr bool
}{
{
name: "valid file path",
filePath: "docs/test.md",
wantErr: false,
},
{
name: "empty file path",
filePath: "",
wantErr: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := database.UpdateLastOpenedFile(workspace.ID, tc.filePath)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Verify update
path, err := database.GetLastOpenedFile(workspace.ID)
if err != nil {
t.Fatalf("failed to get last opened file: %v", err)
}
if path != tc.filePath {
t.Errorf("LastOpenedFile = %v, want %v", path, tc.filePath)
}
})
}
})
t.Run("DeleteWorkspace", func(t *testing.T) {
// Create a test workspace
workspace := &models.Workspace{
UserID: user.ID,
Name: "Delete Workspace",
}
workspace.SetDefaultSettings()
if err := database.CreateWorkspace(workspace); err != nil {
t.Fatalf("failed to create test workspace: %v", err)
}
// Delete the workspace
if err := database.DeleteWorkspace(workspace.ID); err != nil {
t.Fatalf("failed to delete workspace: %v", err)
}
// Verify workspace is gone
_, err = database.GetWorkspaceByID(workspace.ID)
if err != sql.ErrNoRows {
t.Errorf("expected sql.ErrNoRows, got %v", err)
}
})
}
// Helper function to verify workspace fields
func verifyWorkspace(t *testing.T, actual, expected *models.Workspace) {
t.Helper()
if actual.Name != expected.Name {
t.Errorf("Name = %v, want %v", actual.Name, expected.Name)
}
if actual.UserID != expected.UserID {
t.Errorf("UserID = %v, want %v", actual.UserID, expected.UserID)
}
if actual.Theme != expected.Theme {
t.Errorf("Theme = %v, want %v", actual.Theme, expected.Theme)
}
if actual.AutoSave != expected.AutoSave {
t.Errorf("AutoSave = %v, want %v", actual.AutoSave, expected.AutoSave)
}
if actual.ShowHiddenFiles != expected.ShowHiddenFiles {
t.Errorf("ShowHiddenFiles = %v, want %v", actual.ShowHiddenFiles, expected.ShowHiddenFiles)
}
if actual.GitEnabled != expected.GitEnabled {
t.Errorf("GitEnabled = %v, want %v", actual.GitEnabled, expected.GitEnabled)
}
if actual.GitURL != expected.GitURL {
t.Errorf("GitURL = %v, want %v", actual.GitURL, expected.GitURL)
}
if actual.GitUser != expected.GitUser {
t.Errorf("GitUser = %v, want %v", actual.GitUser, expected.GitUser)
}
if actual.GitToken != expected.GitToken {
t.Errorf("GitToken = %v, want %v", actual.GitToken, expected.GitToken)
}
if actual.GitAutoCommit != expected.GitAutoCommit {
t.Errorf("GitAutoCommit = %v, want %v", actual.GitAutoCommit, expected.GitAutoCommit)
}
if actual.GitCommitMsgTemplate != expected.GitCommitMsgTemplate {
t.Errorf("GitCommitMsgTemplate = %v, want %v", actual.GitCommitMsgTemplate, expected.GitCommitMsgTemplate)
}
if actual.CreatedAt.IsZero() {
t.Error("CreatedAt should not be zero")
}
}

View File

@@ -1,40 +0,0 @@
package filesystem
import (
"fmt"
"novamd/internal/gitutils"
"path/filepath"
"strings"
)
// FileSystem represents the file system structure.
type FileSystem struct {
RootDir string
GitRepos map[int]map[int]*gitutils.GitRepo // map[userID]map[workspaceID]*gitutils.GitRepo
}
// New creates a new FileSystem instance.
func New(rootDir string) *FileSystem {
return &FileSystem{
RootDir: rootDir,
GitRepos: make(map[int]map[int]*gitutils.GitRepo),
}
}
// ValidatePath validates the given path and returns the cleaned path if it is valid.
func (fs *FileSystem) ValidatePath(userID, workspaceID int, path string) (string, error) {
workspacePath := fs.GetWorkspacePath(userID, workspaceID)
fullPath := filepath.Join(workspacePath, path)
cleanPath := filepath.Clean(fullPath)
if !strings.HasPrefix(cleanPath, workspacePath) {
return "", fmt.Errorf("invalid path: outside of workspace")
}
return cleanPath, nil
}
// GetTotalFileStats returns the total file statistics for the file system.
func (fs *FileSystem) GetTotalFileStats() (*FileCountStats, error) {
return fs.countFilesInPath(fs.RootDir)
}

View File

@@ -1,60 +0,0 @@
package filesystem
import (
"fmt"
"novamd/internal/gitutils"
)
// SetupGitRepo sets up a Git repository for the given user and workspace IDs.
func (fs *FileSystem) SetupGitRepo(userID, workspaceID int, gitURL, gitUser, gitToken string) error {
workspacePath := fs.GetWorkspacePath(userID, workspaceID)
if _, ok := fs.GitRepos[userID]; !ok {
fs.GitRepos[userID] = make(map[int]*gitutils.GitRepo)
}
fs.GitRepos[userID][workspaceID] = gitutils.New(gitURL, gitUser, gitToken, workspacePath)
return fs.GitRepos[userID][workspaceID].EnsureRepo()
}
// DisableGitRepo disables the Git repository for the given user and workspace IDs.
func (fs *FileSystem) DisableGitRepo(userID, workspaceID int) {
if userRepos, ok := fs.GitRepos[userID]; ok {
delete(userRepos, workspaceID)
if len(userRepos) == 0 {
delete(fs.GitRepos, userID)
}
}
}
// StageCommitAndPush stages, commits, and pushes the changes to the Git repository.
func (fs *FileSystem) StageCommitAndPush(userID, workspaceID int, message string) error {
repo, ok := fs.getGitRepo(userID, workspaceID)
if !ok {
return fmt.Errorf("git settings not configured for this workspace")
}
if err := repo.Commit(message); err != nil {
return err
}
return repo.Push()
}
// Pull pulls the changes from the remote Git repository.
func (fs *FileSystem) Pull(userID, workspaceID int) error {
repo, ok := fs.getGitRepo(userID, workspaceID)
if !ok {
return fmt.Errorf("git settings not configured for this workspace")
}
return repo.Pull()
}
// getGitRepo returns the Git repository for the given user and workspace IDs.
func (fs *FileSystem) getGitRepo(userID, workspaceID int) (*gitutils.GitRepo, bool) {
userRepos, ok := fs.GitRepos[userID]
if !ok {
return nil, false
}
repo, ok := userRepos[workspaceID]
return repo, ok
}

View File

@@ -1,40 +0,0 @@
package filesystem
import (
"fmt"
"os"
"path/filepath"
)
// GetWorkspacePath returns the path to the workspace directory for the given user and workspace IDs.
func (fs *FileSystem) GetWorkspacePath(userID, workspaceID int) string {
return filepath.Join(fs.RootDir, fmt.Sprintf("%d", userID), fmt.Sprintf("%d", workspaceID))
}
// InitializeUserWorkspace creates the workspace directory for the given user and workspace IDs.
func (fs *FileSystem) InitializeUserWorkspace(userID, workspaceID int) error {
workspacePath := fs.GetWorkspacePath(userID, workspaceID)
err := os.MkdirAll(workspacePath, 0755)
if err != nil {
return fmt.Errorf("failed to create workspace directory: %w", err)
}
return nil
}
// DeleteUserWorkspace deletes the workspace directory for the given user and workspace IDs.
func (fs *FileSystem) DeleteUserWorkspace(userID, workspaceID int) error {
workspacePath := fs.GetWorkspacePath(userID, workspaceID)
err := os.RemoveAll(workspacePath)
if err != nil {
return fmt.Errorf("failed to delete workspace directory: %w", err)
}
return nil
}
// CreateWorkspaceDirectory creates the workspace directory for the given user and workspace IDs.
func (fs *FileSystem) CreateWorkspaceDirectory(userID, workspaceID int) error {
dir := fs.GetWorkspacePath(userID, workspaceID)
return os.MkdirAll(dir, 0755)
}

View File

@@ -0,0 +1,157 @@
// Package git provides functionalities to interact with Git repositories, including cloning, pulling, committing, and pushing changes.
package git
import (
"fmt"
"os"
"path/filepath"
"github.com/go-git/go-git/v5"
"github.com/go-git/go-git/v5/plumbing/transport/http"
)
// Config holds the configuration for a Git client
type Config struct {
URL string
Username string
Token string
WorkDir string
}
// Client defines the interface for Git operations
type Client interface {
Clone() error
Pull() error
Commit(message string) error
Push() error
EnsureRepo() error
}
// client implements the Client interface
type client struct {
Config
repo *git.Repository
}
// New creates a new git Client instance
func New(url, username, token, workDir string) Client {
return &client{
Config: Config{
URL: url,
Username: username,
Token: token,
WorkDir: workDir,
},
}
}
// Clone clones the Git repository to the local directory
func (c *client) Clone() error {
auth := &http.BasicAuth{
Username: c.Username,
Password: c.Token,
}
var err error
c.repo, err = git.PlainClone(c.WorkDir, false, &git.CloneOptions{
URL: c.URL,
Auth: auth,
Progress: os.Stdout,
})
if err != nil {
return fmt.Errorf("failed to clone repository: %w", err)
}
return nil
}
// Pull pulls the latest changes from the remote repository
func (c *client) Pull() error {
if c.repo == nil {
return fmt.Errorf("repository not initialized")
}
w, err := c.repo.Worktree()
if err != nil {
return fmt.Errorf("failed to get worktree: %w", err)
}
auth := &http.BasicAuth{
Username: c.Username,
Password: c.Token,
}
err = w.Pull(&git.PullOptions{
Auth: auth,
Progress: os.Stdout,
})
if err != nil && err != git.NoErrAlreadyUpToDate {
return fmt.Errorf("failed to pull changes: %w", err)
}
return nil
}
// Commit commits the changes in the repository with the given message
func (c *client) Commit(message string) error {
if c.repo == nil {
return fmt.Errorf("repository not initialized")
}
w, err := c.repo.Worktree()
if err != nil {
return fmt.Errorf("failed to get worktree: %w", err)
}
_, err = w.Add(".")
if err != nil {
return fmt.Errorf("failed to add changes: %w", err)
}
_, err = w.Commit(message, &git.CommitOptions{})
if err != nil {
return fmt.Errorf("failed to commit changes: %w", err)
}
return nil
}
// Push pushes the changes to the remote repository
func (c *client) Push() error {
if c.repo == nil {
return fmt.Errorf("repository not initialized")
}
auth := &http.BasicAuth{
Username: c.Username,
Password: c.Token,
}
err := c.repo.Push(&git.PushOptions{
Auth: auth,
Progress: os.Stdout,
})
if err != nil && err != git.NoErrAlreadyUpToDate {
return fmt.Errorf("failed to push changes: %w", err)
}
return nil
}
// EnsureRepo ensures the local repository is cloned and up-to-date
func (c *client) EnsureRepo() error {
if _, err := os.Stat(filepath.Join(c.WorkDir, ".git")); os.IsNotExist(err) {
return c.Clone()
}
var err error
c.repo, err = git.PlainOpen(c.WorkDir)
if err != nil {
return fmt.Errorf("failed to open existing repository: %w", err)
}
return c.Pull()
}

View File

@@ -1,133 +0,0 @@
package gitutils
import (
"fmt"
"os"
"path/filepath"
"github.com/go-git/go-git/v5"
"github.com/go-git/go-git/v5/plumbing/transport/http"
)
type GitRepo struct {
URL string
Username string
Token string
WorkDir string
repo *git.Repository
}
func New(url, username, token, workDir string) *GitRepo {
return &GitRepo{
URL: url,
Username: username,
Token: token,
WorkDir: workDir,
}
}
func (g *GitRepo) Clone() error {
auth := &http.BasicAuth{
Username: g.Username,
Password: g.Token,
}
var err error
g.repo, err = git.PlainClone(g.WorkDir, false, &git.CloneOptions{
URL: g.URL,
Auth: auth,
Progress: os.Stdout,
})
if err != nil {
return fmt.Errorf("failed to clone repository: %w", err)
}
return nil
}
func (g *GitRepo) Pull() error {
if g.repo == nil {
return fmt.Errorf("repository not initialized")
}
w, err := g.repo.Worktree()
if err != nil {
return fmt.Errorf("failed to get worktree: %w", err)
}
auth := &http.BasicAuth{
Username: g.Username,
Password: g.Token,
}
err = w.Pull(&git.PullOptions{
Auth: auth,
Progress: os.Stdout,
})
if err != nil && err != git.NoErrAlreadyUpToDate {
return fmt.Errorf("failed to pull changes: %w", err)
}
return nil
}
func (g *GitRepo) Commit(message string) error {
if g.repo == nil {
return fmt.Errorf("repository not initialized")
}
w, err := g.repo.Worktree()
if err != nil {
return fmt.Errorf("failed to get worktree: %w", err)
}
_, err = w.Add(".")
if err != nil {
return fmt.Errorf("failed to add changes: %w", err)
}
_, err = w.Commit(message, &git.CommitOptions{})
if err != nil {
return fmt.Errorf("failed to commit changes: %w", err)
}
return nil
}
func (g *GitRepo) Push() error {
if g.repo == nil {
return fmt.Errorf("repository not initialized")
}
auth := &http.BasicAuth{
Username: g.Username,
Password: g.Token,
}
err := g.repo.Push(&git.PushOptions{
Auth: auth,
Progress: os.Stdout,
})
if err != nil && err != git.NoErrAlreadyUpToDate {
return fmt.Errorf("failed to push changes: %w", err)
}
return nil
}
func (g *GitRepo) EnsureRepo() error {
if _, err := os.Stat(filepath.Join(g.WorkDir, ".git")); os.IsNotExist(err) {
return g.Clone()
}
var err error
g.repo, err = git.PlainOpen(g.WorkDir)
if err != nil {
return fmt.Errorf("failed to open existing repository: %w", err)
}
return g.Pull()
}

View File

@@ -1,12 +1,13 @@
// Package handlers contains the request handlers for the api routes.
package handlers
import (
"encoding/json"
"net/http"
"novamd/internal/context"
"novamd/internal/db"
"novamd/internal/filesystem"
"novamd/internal/httpcontext"
"novamd/internal/models"
"novamd/internal/storage"
"strconv"
"time"
@@ -14,14 +15,16 @@ import (
"golang.org/x/crypto/bcrypt"
)
type createUserRequest struct {
// CreateUserRequest holds the request fields for creating a new user
type CreateUserRequest struct {
Email string `json:"email"`
DisplayName string `json:"displayName"`
Password string `json:"password"`
Role models.UserRole `json:"role"`
}
type updateUserRequest struct {
// UpdateUserRequest holds the request fields for updating a user
type UpdateUserRequest struct {
Email string `json:"email,omitempty"`
DisplayName string `json:"displayName,omitempty"`
Password string `json:"password,omitempty"`
@@ -44,7 +47,7 @@ func (h *Handler) AdminListUsers() http.HandlerFunc {
// AdminCreateUser creates a new user
func (h *Handler) AdminCreateUser() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var req createUserRequest
var req CreateUserRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
@@ -91,7 +94,7 @@ func (h *Handler) AdminCreateUser() http.HandlerFunc {
}
// Initialize user workspace
if err := h.FS.InitializeUserWorkspace(insertedUser.ID, insertedUser.LastWorkspaceID); err != nil {
if err := h.Storage.InitializeUserWorkspace(insertedUser.ID, insertedUser.LastWorkspaceID); err != nil {
http.Error(w, "Failed to initialize user workspace", http.StatusInternalServerError)
return
}
@@ -135,7 +138,7 @@ func (h *Handler) AdminUpdateUser() http.HandlerFunc {
return
}
var req updateUserRequest
var req UpdateUserRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
@@ -172,7 +175,7 @@ func (h *Handler) AdminUpdateUser() http.HandlerFunc {
// AdminDeleteUser deletes a specific user
func (h *Handler) AdminDeleteUser() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := httpcontext.GetRequestContext(w, r)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
@@ -218,7 +221,7 @@ type WorkspaceStats struct {
WorkspaceID int `json:"workspaceID"`
WorkspaceName string `json:"workspaceName"`
WorkspaceCreatedAt time.Time `json:"workspaceCreatedAt"`
*filesystem.FileCountStats
*storage.FileCountStats
}
// AdminListWorkspaces returns a list of all workspaces and their stats
@@ -248,7 +251,7 @@ func (h *Handler) AdminListWorkspaces() http.HandlerFunc {
workspaceData.WorkspaceName = ws.Name
workspaceData.WorkspaceCreatedAt = ws.CreatedAt
fileStats, err := h.FS.GetFileStats(ws.UserID, ws.ID)
fileStats, err := h.Storage.GetFileStats(ws.UserID, ws.ID)
if err != nil {
http.Error(w, "Failed to get file stats", http.StatusInternalServerError)
return
@@ -266,7 +269,7 @@ func (h *Handler) AdminListWorkspaces() http.HandlerFunc {
// SystemStats holds system-wide statistics
type SystemStats struct {
*db.UserStats
*filesystem.FileCountStats
*storage.FileCountStats
}
// AdminGetSystemStats returns system-wide statistics for admins
@@ -278,7 +281,7 @@ func (h *Handler) AdminGetSystemStats() http.HandlerFunc {
return
}
fileStats, err := h.FS.GetTotalFileStats()
fileStats, err := h.Storage.GetTotalFileStats()
if err != nil {
http.Error(w, "Failed to get file stats", http.StatusInternalServerError)
return

View File

@@ -0,0 +1,243 @@
//go:build integration
package handlers_test
import (
"encoding/json"
"fmt"
"net/http"
"testing"
"novamd/internal/handlers"
"novamd/internal/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Helper function to check if a user exists in a slice of users
func containsUser(users []*models.User, searchUser *models.User) bool {
for _, u := range users {
if u.ID == searchUser.ID &&
u.Email == searchUser.Email &&
u.DisplayName == searchUser.DisplayName &&
u.Role == searchUser.Role {
return true
}
}
return false
}
func TestAdminHandlers_Integration(t *testing.T) {
h := setupTestHarness(t)
defer h.teardown(t)
t.Run("user management", func(t *testing.T) {
t.Run("list users", func(t *testing.T) {
// Test with admin token
rr := h.makeRequest(t, http.MethodGet, "/api/v1/admin/users", nil, h.AdminToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var users []*models.User
err := json.NewDecoder(rr.Body).Decode(&users)
require.NoError(t, err)
// Should have at least our admin and regular test users
assert.GreaterOrEqual(t, len(users), 2)
assert.True(t, containsUser(users, h.AdminUser), "Admin user not found in users list")
assert.True(t, containsUser(users, h.RegularUser), "Regular user not found in users list")
// Test with non-admin token
rr = h.makeRequest(t, http.MethodGet, "/api/v1/admin/users", nil, h.RegularToken, nil)
assert.Equal(t, http.StatusForbidden, rr.Code)
// Test without token
rr = h.makeRequest(t, http.MethodGet, "/api/v1/admin/users", nil, "", nil)
assert.Equal(t, http.StatusUnauthorized, rr.Code)
})
t.Run("create user", func(t *testing.T) {
createReq := handlers.CreateUserRequest{
Email: "newuser@test.com",
DisplayName: "New User",
Password: "password123",
Role: models.RoleEditor,
}
// Test with admin token
rr := h.makeRequest(t, http.MethodPost, "/api/v1/admin/users", createReq, h.AdminToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var createdUser models.User
err := json.NewDecoder(rr.Body).Decode(&createdUser)
require.NoError(t, err)
assert.Equal(t, createReq.Email, createdUser.Email)
assert.Equal(t, createReq.DisplayName, createdUser.DisplayName)
assert.Equal(t, createReq.Role, createdUser.Role)
assert.NotZero(t, createdUser.LastWorkspaceID)
// Test duplicate email
rr = h.makeRequest(t, http.MethodPost, "/api/v1/admin/users", createReq, h.AdminToken, nil)
assert.Equal(t, http.StatusConflict, rr.Code)
// Test invalid request (missing required fields)
invalidReq := handlers.CreateUserRequest{
Email: "invalid@test.com",
// Missing password and role
}
rr = h.makeRequest(t, http.MethodPost, "/api/v1/admin/users", invalidReq, h.AdminToken, nil)
assert.Equal(t, http.StatusBadRequest, rr.Code)
// Test with non-admin token
rr = h.makeRequest(t, http.MethodPost, "/api/v1/admin/users", createReq, h.RegularToken, nil)
assert.Equal(t, http.StatusForbidden, rr.Code)
})
t.Run("get user", func(t *testing.T) {
path := fmt.Sprintf("/api/v1/admin/users/%d", h.RegularUser.ID)
// Test with admin token
rr := h.makeRequest(t, http.MethodGet, path, nil, h.AdminToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var user models.User
err := json.NewDecoder(rr.Body).Decode(&user)
require.NoError(t, err)
assert.Equal(t, h.RegularUser.ID, user.ID)
// Test non-existent user
rr = h.makeRequest(t, http.MethodGet, "/api/v1/admin/users/999999", nil, h.AdminToken, nil)
assert.Equal(t, http.StatusNotFound, rr.Code)
// Test with non-admin token
rr = h.makeRequest(t, http.MethodGet, path, nil, h.RegularToken, nil)
assert.Equal(t, http.StatusForbidden, rr.Code)
})
t.Run("update user", func(t *testing.T) {
path := fmt.Sprintf("/api/v1/admin/users/%d", h.RegularUser.ID)
updateReq := handlers.UpdateUserRequest{
DisplayName: "Updated Name",
Role: models.RoleViewer,
}
// Test with admin token
rr := h.makeRequest(t, http.MethodPut, path, updateReq, h.AdminToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var updatedUser models.User
err := json.NewDecoder(rr.Body).Decode(&updatedUser)
require.NoError(t, err)
assert.Equal(t, updateReq.DisplayName, updatedUser.DisplayName)
assert.Equal(t, updateReq.Role, updatedUser.Role)
// Test with non-admin token
rr = h.makeRequest(t, http.MethodPut, path, updateReq, h.RegularToken, nil)
assert.Equal(t, http.StatusForbidden, rr.Code)
})
t.Run("delete user", func(t *testing.T) {
// Create a user to delete
createReq := handlers.CreateUserRequest{
Email: "todelete@test.com",
DisplayName: "To Delete",
Password: "password123",
Role: models.RoleEditor,
}
rr := h.makeRequest(t, http.MethodPost, "/api/v1/admin/users", createReq, h.AdminToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var createdUser models.User
err := json.NewDecoder(rr.Body).Decode(&createdUser)
require.NoError(t, err)
path := fmt.Sprintf("/api/v1/admin/users/%d", createdUser.ID)
// Test deleting own account (should fail)
adminPath := fmt.Sprintf("/api/v1/admin/users/%d", h.AdminUser.ID)
rr = h.makeRequest(t, http.MethodDelete, adminPath, nil, h.AdminToken, nil)
assert.Equal(t, http.StatusBadRequest, rr.Code)
// Test with admin token
rr = h.makeRequest(t, http.MethodDelete, path, nil, h.AdminToken, nil)
assert.Equal(t, http.StatusNoContent, rr.Code)
// Verify user is deleted
rr = h.makeRequest(t, http.MethodGet, path, nil, h.AdminToken, nil)
assert.Equal(t, http.StatusNotFound, rr.Code)
// Test with non-admin token
rr = h.makeRequest(t, http.MethodDelete, path, nil, h.RegularToken, nil)
assert.Equal(t, http.StatusForbidden, rr.Code)
})
})
t.Run("workspace management", func(t *testing.T) {
t.Run("list workspaces", func(t *testing.T) {
// Create a test workspace first
workspace := &models.Workspace{
UserID: h.RegularUser.ID,
Name: "Test Workspace",
}
rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
// Test with admin token
rr = h.makeRequest(t, http.MethodGet, "/api/v1/admin/workspaces", nil, h.AdminToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var workspaces []*handlers.WorkspaceStats
err := json.NewDecoder(rr.Body).Decode(&workspaces)
require.NoError(t, err)
// Should have at least the default workspaces for admin and regular users
assert.NotEmpty(t, workspaces)
// Verify workspace stats fields
for _, ws := range workspaces {
assert.NotZero(t, ws.UserID)
assert.NotEmpty(t, ws.UserEmail)
assert.NotZero(t, ws.WorkspaceID)
assert.NotEmpty(t, ws.WorkspaceName)
assert.NotZero(t, ws.WorkspaceCreatedAt)
assert.GreaterOrEqual(t, ws.TotalFiles, 0)
assert.GreaterOrEqual(t, ws.TotalSize, int64(0))
}
// Test with non-admin token
rr = h.makeRequest(t, http.MethodGet, "/api/v1/admin/workspaces", nil, h.RegularToken, nil)
assert.Equal(t, http.StatusForbidden, rr.Code)
})
})
t.Run("system stats", func(t *testing.T) {
// Create some test data
workspace := &models.Workspace{
UserID: h.RegularUser.ID,
Name: "Stats Test Workspace",
}
rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
// Test with admin token
rr = h.makeRequest(t, http.MethodGet, "/api/v1/admin/stats", nil, h.AdminToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var stats handlers.SystemStats
err := json.NewDecoder(rr.Body).Decode(&stats)
require.NoError(t, err)
// Verify stats fields
assert.GreaterOrEqual(t, stats.TotalUsers, 2) // At least admin and regular user
assert.GreaterOrEqual(t, stats.TotalWorkspaces, 2) // At least default workspaces
assert.GreaterOrEqual(t, stats.ActiveUsers, 2) // Our test users should be active
assert.GreaterOrEqual(t, stats.TotalFiles, 0)
assert.GreaterOrEqual(t, stats.TotalSize, int64(0))
// Test with non-admin token
rr = h.makeRequest(t, http.MethodGet, "/api/v1/admin/stats", nil, h.RegularToken, nil)
assert.Equal(t, http.StatusForbidden, rr.Code)
})
}

View File

@@ -4,28 +4,32 @@ import (
"encoding/json"
"net/http"
"novamd/internal/auth"
"novamd/internal/httpcontext"
"novamd/internal/context"
"novamd/internal/models"
"golang.org/x/crypto/bcrypt"
)
// LoginRequest represents a user login request
type LoginRequest struct {
Email string `json:"email"`
Password string `json:"password"`
}
// LoginResponse represents a user login response
type LoginResponse struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
User *models.User `json:"user"`
Session *auth.Session `json:"session"`
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
User *models.User `json:"user"`
Session *models.Session `json:"session"`
}
// RefreshRequest represents a refresh token request
type RefreshRequest struct {
RefreshToken string `json:"refreshToken"`
}
// RefreshResponse represents a refresh token response
type RefreshResponse struct {
AccessToken string `json:"accessToken"`
}
@@ -129,7 +133,7 @@ func (h *Handler) RefreshToken(authService *auth.SessionService) http.HandlerFun
// GetCurrentUser returns the currently authenticated user
func (h *Handler) GetCurrentUser() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := httpcontext.GetRequestContext(w, r)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}

View File

@@ -0,0 +1,232 @@
//go:build integration
package handlers_test
import (
"encoding/json"
"net/http"
"testing"
"novamd/internal/handlers"
"novamd/internal/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAuthHandlers_Integration(t *testing.T) {
h := setupTestHarness(t)
defer h.teardown(t)
t.Run("login", func(t *testing.T) {
t.Run("successful login - admin user", func(t *testing.T) {
loginReq := handlers.LoginRequest{
Email: "admin@test.com",
Password: "admin123",
}
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil)
require.Equal(t, http.StatusOK, rr.Code)
var resp handlers.LoginResponse
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.NotEmpty(t, resp.AccessToken)
assert.NotEmpty(t, resp.RefreshToken)
assert.NotNil(t, resp.User)
assert.Equal(t, loginReq.Email, resp.User.Email)
assert.Equal(t, models.RoleAdmin, resp.User.Role)
})
t.Run("successful login - regular user", func(t *testing.T) {
loginReq := handlers.LoginRequest{
Email: "user@test.com",
Password: "user123",
}
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil)
require.Equal(t, http.StatusOK, rr.Code)
var resp handlers.LoginResponse
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.NotEmpty(t, resp.AccessToken)
assert.NotEmpty(t, resp.RefreshToken)
assert.NotNil(t, resp.User)
assert.Equal(t, loginReq.Email, resp.User.Email)
assert.Equal(t, models.RoleEditor, resp.User.Role)
})
t.Run("login failures", func(t *testing.T) {
tests := []struct {
name string
request handlers.LoginRequest
wantCode int
}{
{
name: "wrong password",
request: handlers.LoginRequest{
Email: "user@test.com",
Password: "wrongpassword",
},
wantCode: http.StatusUnauthorized,
},
{
name: "non-existent user",
request: handlers.LoginRequest{
Email: "nonexistent@test.com",
Password: "password123",
},
wantCode: http.StatusUnauthorized,
},
{
name: "empty email",
request: handlers.LoginRequest{
Email: "",
Password: "password123",
},
wantCode: http.StatusBadRequest,
},
{
name: "empty password",
request: handlers.LoginRequest{
Email: "user@test.com",
Password: "",
},
wantCode: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", tt.request, "", nil)
assert.Equal(t, tt.wantCode, rr.Code)
})
}
})
})
t.Run("refresh token", func(t *testing.T) {
t.Run("successful token refresh", func(t *testing.T) {
// First login to get refresh token
loginReq := handlers.LoginRequest{
Email: "user@test.com",
Password: "user123",
}
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil)
require.Equal(t, http.StatusOK, rr.Code)
var loginResp handlers.LoginResponse
err := json.NewDecoder(rr.Body).Decode(&loginResp)
require.NoError(t, err)
// Now try to refresh the token
refreshReq := handlers.RefreshRequest{
RefreshToken: loginResp.RefreshToken,
}
rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/refresh", refreshReq, "", nil)
require.Equal(t, http.StatusOK, rr.Code)
var refreshResp handlers.RefreshResponse
err = json.NewDecoder(rr.Body).Decode(&refreshResp)
require.NoError(t, err)
assert.NotEmpty(t, refreshResp.AccessToken)
})
t.Run("refresh failures", func(t *testing.T) {
tests := []struct {
name string
request handlers.RefreshRequest
wantCode int
}{
{
name: "invalid refresh token",
request: handlers.RefreshRequest{
RefreshToken: "invalid-token",
},
wantCode: http.StatusUnauthorized,
},
{
name: "empty refresh token",
request: handlers.RefreshRequest{
RefreshToken: "",
},
wantCode: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/refresh", tt.request, "", nil)
assert.Equal(t, tt.wantCode, rr.Code)
})
}
})
})
t.Run("logout", func(t *testing.T) {
t.Run("successful logout", func(t *testing.T) {
// First login to get session
loginReq := handlers.LoginRequest{
Email: "user@test.com",
Password: "user123",
}
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil)
require.Equal(t, http.StatusOK, rr.Code)
var loginResp handlers.LoginResponse
err := json.NewDecoder(rr.Body).Decode(&loginResp)
require.NoError(t, err)
// Now logout using session ID from login response
headers := map[string]string{
"X-Session-ID": loginResp.Session.ID,
}
rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/logout", nil, loginResp.AccessToken, headers)
require.Equal(t, http.StatusOK, rr.Code)
// Try to use the refresh token - should fail
refreshReq := handlers.RefreshRequest{
RefreshToken: loginResp.RefreshToken,
}
rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/refresh", refreshReq, "", nil)
assert.Equal(t, http.StatusUnauthorized, rr.Code)
})
t.Run("logout without session ID", func(t *testing.T) {
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/logout", nil, h.RegularToken, nil)
assert.Equal(t, http.StatusBadRequest, rr.Code)
})
})
t.Run("get current user", func(t *testing.T) {
t.Run("successful get current user", func(t *testing.T) {
rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var user models.User
err := json.NewDecoder(rr.Body).Decode(&user)
require.NoError(t, err)
assert.Equal(t, h.RegularUser.ID, user.ID)
assert.Equal(t, h.RegularUser.Email, user.Email)
assert.Equal(t, h.RegularUser.Role, user.Role)
})
t.Run("get current user without token", func(t *testing.T) {
rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, "", nil)
assert.Equal(t, http.StatusUnauthorized, rr.Code)
})
t.Run("get current user with invalid token", func(t *testing.T) {
rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, "invalid-token", nil)
assert.Equal(t, http.StatusUnauthorized, rr.Code)
})
})
}

View File

@@ -4,20 +4,23 @@ import (
"encoding/json"
"io"
"net/http"
"os"
"novamd/internal/httpcontext"
"novamd/internal/context"
"novamd/internal/storage"
"github.com/go-chi/chi/v5"
)
// ListFiles returns a list of all files in the workspace
func (h *Handler) ListFiles() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := httpcontext.GetRequestContext(w, r)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
files, err := h.FS.ListFilesRecursively(ctx.UserID, ctx.Workspace.ID)
files, err := h.Storage.ListFilesRecursively(ctx.UserID, ctx.Workspace.ID)
if err != nil {
http.Error(w, "Failed to list files", http.StatusInternalServerError)
return
@@ -27,9 +30,10 @@ func (h *Handler) ListFiles() http.HandlerFunc {
}
}
// LookupFileByName returns the paths of files with the given name
func (h *Handler) LookupFileByName() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := httpcontext.GetRequestContext(w, r)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
@@ -40,7 +44,7 @@ func (h *Handler) LookupFileByName() http.HandlerFunc {
return
}
filePaths, err := h.FS.FindFileByName(ctx.UserID, ctx.Workspace.ID, filename)
filePaths, err := h.Storage.FindFileByName(ctx.UserID, ctx.Workspace.ID, filename)
if err != nil {
http.Error(w, "File not found", http.StatusNotFound)
return
@@ -50,28 +54,45 @@ func (h *Handler) LookupFileByName() http.HandlerFunc {
}
}
// GetFileContent returns the content of a file
func (h *Handler) GetFileContent() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := httpcontext.GetRequestContext(w, r)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
filePath := chi.URLParam(r, "*")
content, err := h.FS.GetFileContent(ctx.UserID, ctx.Workspace.ID, filePath)
content, err := h.Storage.GetFileContent(ctx.UserID, ctx.Workspace.ID, filePath)
if err != nil {
http.Error(w, "Failed to read file", http.StatusNotFound)
if storage.IsPathValidationError(err) {
http.Error(w, "Invalid file path", http.StatusBadRequest)
return
}
if os.IsNotExist(err) {
http.Error(w, "Failed to read file", http.StatusNotFound)
return
}
http.Error(w, "Failed to read file", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/plain")
w.Write(content)
_, err = w.Write(content)
if err != nil {
http.Error(w, "Failed to write response", http.StatusInternalServerError)
return
}
}
}
// SaveFile saves the content of a file
func (h *Handler) SaveFile() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := httpcontext.GetRequestContext(w, r)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
@@ -83,8 +104,13 @@ func (h *Handler) SaveFile() http.HandlerFunc {
return
}
err = h.FS.SaveFile(ctx.UserID, ctx.Workspace.ID, filePath, content)
err = h.Storage.SaveFile(ctx.UserID, ctx.Workspace.ID, filePath, content)
if err != nil {
if storage.IsPathValidationError(err) {
http.Error(w, "Invalid file path", http.StatusBadRequest)
return
}
http.Error(w, "Failed to save file", http.StatusInternalServerError)
return
}
@@ -93,28 +119,44 @@ func (h *Handler) SaveFile() http.HandlerFunc {
}
}
// DeleteFile deletes a file
func (h *Handler) DeleteFile() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := httpcontext.GetRequestContext(w, r)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
filePath := chi.URLParam(r, "*")
err := h.FS.DeleteFile(ctx.UserID, ctx.Workspace.ID, filePath)
err := h.Storage.DeleteFile(ctx.UserID, ctx.Workspace.ID, filePath)
if err != nil {
if storage.IsPathValidationError(err) {
http.Error(w, "Invalid file path", http.StatusBadRequest)
return
}
if os.IsNotExist(err) {
http.Error(w, "File not found", http.StatusNotFound)
return
}
http.Error(w, "Failed to delete file", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("File deleted successfully"))
_, err = w.Write([]byte("File deleted successfully"))
if err != nil {
http.Error(w, "Failed to write response", http.StatusInternalServerError)
return
}
}
}
// GetLastOpenedFile returns the last opened file in the workspace
func (h *Handler) GetLastOpenedFile() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := httpcontext.GetRequestContext(w, r)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
@@ -125,7 +167,7 @@ func (h *Handler) GetLastOpenedFile() http.HandlerFunc {
return
}
if _, err := h.FS.ValidatePath(ctx.UserID, ctx.Workspace.ID, filePath); err != nil {
if _, err := h.Storage.ValidatePath(ctx.UserID, ctx.Workspace.ID, filePath); err != nil {
http.Error(w, "Invalid file path", http.StatusBadRequest)
return
}
@@ -134,9 +176,10 @@ func (h *Handler) GetLastOpenedFile() http.HandlerFunc {
}
}
// UpdateLastOpenedFile updates the last opened file in the workspace
func (h *Handler) UpdateLastOpenedFile() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := httpcontext.GetRequestContext(w, r)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
@@ -150,10 +193,21 @@ func (h *Handler) UpdateLastOpenedFile() http.HandlerFunc {
return
}
// Validate the file path exists in the workspace
// Validate the file path in the workspace
if requestBody.FilePath != "" {
if _, err := h.FS.ValidatePath(ctx.UserID, ctx.Workspace.ID, requestBody.FilePath); err != nil {
http.Error(w, "Invalid file path", http.StatusBadRequest)
_, err := h.Storage.GetFileContent(ctx.UserID, ctx.Workspace.ID, requestBody.FilePath)
if err != nil {
if storage.IsPathValidationError(err) {
http.Error(w, "Invalid file path", http.StatusBadRequest)
return
}
if os.IsNotExist(err) {
http.Error(w, "File not found", http.StatusNotFound)
return
}
http.Error(w, "Failed to update file", http.StatusInternalServerError)
return
}
}

View File

@@ -0,0 +1,239 @@
//go:build integration
package handlers_test
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"testing"
"novamd/internal/models"
"novamd/internal/storage"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestFileHandlers_Integration(t *testing.T) {
h := setupTestHarness(t)
defer h.teardown(t)
t.Run("file operations", func(t *testing.T) {
// Setup: Create a workspace first
workspace := &models.Workspace{
UserID: h.RegularUser.ID,
Name: "File Test Workspace",
}
rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
err := json.NewDecoder(rr.Body).Decode(workspace)
require.NoError(t, err)
// Construct base URL for file operations
baseURL := fmt.Sprintf("/api/v1/workspaces/%s/files", url.PathEscape(workspace.Name))
t.Run("list empty directory", func(t *testing.T) {
rr := h.makeRequest(t, http.MethodGet, baseURL, nil, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var files []storage.FileNode
err := json.NewDecoder(rr.Body).Decode(&files)
require.NoError(t, err)
assert.Empty(t, files, "Expected empty directory")
})
t.Run("save and get file", func(t *testing.T) {
content := "Test content for file operations"
filePath := "test.md"
// Save file
rr := h.makeRequestRaw(t, http.MethodPost, baseURL+"/"+filePath, strings.NewReader(content), h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
// Get file content
rr = h.makeRequest(t, http.MethodGet, baseURL+"/"+filePath, nil, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, content, rr.Body.String())
// List directory should now show the file
rr = h.makeRequest(t, http.MethodGet, baseURL, nil, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var files []storage.FileNode
err := json.NewDecoder(rr.Body).Decode(&files)
require.NoError(t, err)
assert.Len(t, files, 1)
assert.Equal(t, filePath, files[0].Name)
})
t.Run("save and list nested files", func(t *testing.T) {
files := map[string]string{
"docs/readme.md": "README content",
"docs/api/endpoints.md": "API documentation",
"notes/meeting-notes.md": "Meeting notes content",
"notes/todo.md": "TODO list",
}
// Create all files
for path, content := range files {
rr := h.makeRequest(t, http.MethodPost, baseURL+"/"+path, content, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
}
// List all files
rr := h.makeRequest(t, http.MethodGet, baseURL, nil, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var fileNodes []storage.FileNode
err := json.NewDecoder(rr.Body).Decode(&fileNodes)
require.NoError(t, err)
// We should have 3 root items: docs/, notes/, and test.md
assert.Len(t, fileNodes, 3)
// Verify directory structure
var docsDir, notesDir *storage.FileNode
for i := range fileNodes {
switch fileNodes[i].Name {
case "docs":
docsDir = &fileNodes[i]
case "notes":
notesDir = &fileNodes[i]
}
}
require.NotNil(t, docsDir)
require.NotNil(t, notesDir)
assert.Len(t, docsDir.Children, 2) // readme.md and api/
assert.Len(t, notesDir.Children, 2) // meeting-notes.md and todo.md
})
t.Run("lookup file by name", func(t *testing.T) {
// Look up a file that exists in multiple locations
filename := "readme.md"
dupContent := "Another readme"
rr := h.makeRequest(t, http.MethodPost, baseURL+"/projects/"+filename, dupContent, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
// Search for the file
rr = h.makeRequest(t, http.MethodGet, baseURL+"/lookup?filename="+filename, nil, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var response struct {
Paths []string `json:"paths"`
}
err := json.NewDecoder(rr.Body).Decode(&response)
require.NoError(t, err)
assert.Len(t, response.Paths, 2)
// Search for non-existent file
rr = h.makeRequest(t, http.MethodGet, baseURL+"/lookup?filename=nonexistent.md", nil, h.RegularToken, nil)
assert.Equal(t, http.StatusNotFound, rr.Code)
})
t.Run("delete file", func(t *testing.T) {
filePath := "to-delete.md"
content := "This file will be deleted"
// Create file
rr := h.makeRequest(t, http.MethodPost, baseURL+"/"+filePath, content, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
// Delete file
rr = h.makeRequest(t, http.MethodDelete, baseURL+"/"+filePath, nil, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
// Verify file is gone
rr = h.makeRequest(t, http.MethodGet, baseURL+"/"+filePath, nil, h.RegularToken, nil)
assert.Equal(t, http.StatusNotFound, rr.Code)
})
t.Run("last opened file", func(t *testing.T) {
// Initially should be empty
rr := h.makeRequest(t, http.MethodGet, baseURL+"/last", nil, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var response struct {
LastOpenedFilePath string `json:"lastOpenedFilePath"`
}
err := json.NewDecoder(rr.Body).Decode(&response)
require.NoError(t, err)
assert.Empty(t, response.LastOpenedFilePath)
// Update last opened file
updateReq := struct {
FilePath string `json:"filePath"`
}{
FilePath: "docs/readme.md",
}
rr = h.makeRequest(t, http.MethodPut, baseURL+"/last", updateReq, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
// Verify update
rr = h.makeRequest(t, http.MethodGet, baseURL+"/last", nil, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
err = json.NewDecoder(rr.Body).Decode(&response)
require.NoError(t, err)
assert.Equal(t, updateReq.FilePath, response.LastOpenedFilePath)
// Test invalid file path
updateReq.FilePath = "nonexistent.md"
rr = h.makeRequest(t, http.MethodPut, baseURL+"/last", updateReq, h.RegularToken, nil)
assert.Equal(t, http.StatusNotFound, rr.Code)
})
t.Run("unauthorized access", func(t *testing.T) {
tests := []struct {
name string
method string
path string
body interface{}
}{
{"list files", http.MethodGet, baseURL, nil},
{"get file", http.MethodGet, baseURL + "/test.md", nil},
{"save file", http.MethodPost, baseURL + "/test.md", "content"},
{"delete file", http.MethodDelete, baseURL + "/test.md", nil},
{"get last file", http.MethodGet, baseURL + "/last", nil},
{"update last file", http.MethodPut, baseURL + "/last", struct{ FilePath string }{"test.md"}},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Test without token
rr := h.makeRequest(t, tc.method, tc.path, tc.body, "", nil)
assert.Equal(t, http.StatusUnauthorized, rr.Code)
// Test with wrong user's token
rr = h.makeRequest(t, tc.method, tc.path, tc.body, h.AdminToken, nil)
assert.Equal(t, http.StatusNotFound, rr.Code)
})
}
})
t.Run("path traversal attempts", func(t *testing.T) {
maliciousPaths := []string{
"../../../etc/passwd",
"./../../secret.txt",
"/etc/shadow",
"test/../../../etc/passwd",
}
for _, path := range maliciousPaths {
t.Run(path, func(t *testing.T) {
// Try to read
rr := h.makeRequest(t, http.MethodGet, baseURL+"/"+path, nil, h.RegularToken, nil)
assert.Equal(t, http.StatusBadRequest, rr.Code)
// Try to write
rr = h.makeRequest(t, http.MethodPost, baseURL+"/"+path, "malicious content", h.RegularToken, nil)
assert.Equal(t, http.StatusBadRequest, rr.Code)
})
}
})
})
}

View File

@@ -4,12 +4,13 @@ import (
"encoding/json"
"net/http"
"novamd/internal/httpcontext"
"novamd/internal/context"
)
// StageCommitAndPush stages, commits, and pushes changes to the remote repository
func (h *Handler) StageCommitAndPush() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := httpcontext.GetRequestContext(w, r)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
@@ -28,7 +29,7 @@ func (h *Handler) StageCommitAndPush() http.HandlerFunc {
return
}
err := h.FS.StageCommitAndPush(ctx.UserID, ctx.Workspace.ID, requestBody.Message)
err := h.Storage.StageCommitAndPush(ctx.UserID, ctx.Workspace.ID, requestBody.Message)
if err != nil {
http.Error(w, "Failed to stage, commit, and push changes: "+err.Error(), http.StatusInternalServerError)
return
@@ -38,14 +39,15 @@ func (h *Handler) StageCommitAndPush() http.HandlerFunc {
}
}
// PullChanges pulls changes from the remote repository
func (h *Handler) PullChanges() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := httpcontext.GetRequestContext(w, r)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
err := h.FS.Pull(ctx.UserID, ctx.Workspace.ID)
err := h.Storage.Pull(ctx.UserID, ctx.Workspace.ID)
if err != nil {
http.Error(w, "Failed to pull changes: "+err.Error(), http.StatusInternalServerError)
return

View File

@@ -0,0 +1,181 @@
//go:build integration
package handlers_test
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"testing"
"novamd/internal/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGitHandlers_Integration(t *testing.T) {
h := setupTestHarness(t)
defer h.teardown(t)
t.Run("git operations", func(t *testing.T) {
// Setup: Create a workspace with Git enabled
workspace := &models.Workspace{
UserID: h.RegularUser.ID,
Name: "Git Test Workspace",
GitEnabled: true,
GitURL: "https://github.com/test/repo.git",
GitUser: "testuser",
GitToken: "testtoken",
GitAutoCommit: true,
GitCommitMsgTemplate: "Update: {{message}}",
}
rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
err := json.NewDecoder(rr.Body).Decode(workspace)
require.NoError(t, err)
// Construct base URL for Git operations
baseURL := "/api/v1/workspaces/" + url.PathEscape(workspace.Name) + "/git"
t.Run("stage, commit and push", func(t *testing.T) {
h.MockGit.Reset()
t.Run("successful commit", func(t *testing.T) {
commitMsg := "Test commit message"
requestBody := map[string]string{
"message": commitMsg,
}
rr := h.makeRequest(t, http.MethodPost, baseURL+"/commit", requestBody, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var response map[string]string
err := json.NewDecoder(rr.Body).Decode(&response)
require.NoError(t, err)
assert.Contains(t, response["message"], "successfully")
// Verify mock was called correctly
assert.Equal(t, 1, h.MockGit.GetCommitCount(), "Commit should be called once")
assert.Equal(t, 1, h.MockGit.GetPushCount(), "Push should be called once")
assert.Equal(t, commitMsg, h.MockGit.GetLastCommitMessage(), "Commit message should match")
})
t.Run("empty commit message", func(t *testing.T) {
h.MockGit.Reset()
requestBody := map[string]string{
"message": "",
}
rr := h.makeRequest(t, http.MethodPost, baseURL+"/commit", requestBody, h.RegularToken, nil)
assert.Equal(t, http.StatusBadRequest, rr.Code)
assert.Equal(t, 0, h.MockGit.GetCommitCount(), "Commit should not be called")
})
t.Run("git error", func(t *testing.T) {
h.MockGit.Reset()
h.MockGit.SetError(fmt.Errorf("mock git error"))
requestBody := map[string]string{
"message": "Test message",
}
rr := h.makeRequest(t, http.MethodPost, baseURL+"/commit", requestBody, h.RegularToken, nil)
assert.Equal(t, http.StatusInternalServerError, rr.Code)
h.MockGit.SetError(nil) // Reset error state
})
})
t.Run("pull changes", func(t *testing.T) {
h.MockGit.Reset()
t.Run("successful pull", func(t *testing.T) {
rr := h.makeRequest(t, http.MethodPost, baseURL+"/pull", nil, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var response map[string]string
err := json.NewDecoder(rr.Body).Decode(&response)
require.NoError(t, err)
assert.Contains(t, response["message"], "Pulled changes")
assert.Equal(t, 1, h.MockGit.GetPullCount(), "Pull should be called once")
})
t.Run("git error", func(t *testing.T) {
h.MockGit.Reset()
h.MockGit.SetError(fmt.Errorf("mock git error"))
rr := h.makeRequest(t, http.MethodPost, baseURL+"/pull", nil, h.RegularToken, nil)
assert.Equal(t, http.StatusInternalServerError, rr.Code)
h.MockGit.SetError(nil) // Reset error state
})
})
t.Run("unauthorized access", func(t *testing.T) {
h.MockGit.Reset()
tests := []struct {
name string
method string
path string
body interface{}
}{
{
name: "commit without token",
method: http.MethodPost,
path: baseURL + "/commit",
body: map[string]string{"message": "test"},
},
{
name: "pull without token",
method: http.MethodPost,
path: baseURL + "/pull",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Test without token
rr := h.makeRequest(t, tc.method, tc.path, tc.body, "", nil)
assert.Equal(t, http.StatusUnauthorized, rr.Code)
// Test with wrong user's token
rr = h.makeRequest(t, tc.method, tc.path, tc.body, h.AdminToken, nil)
assert.Equal(t, http.StatusNotFound, rr.Code)
})
}
})
t.Run("workspace without git", func(t *testing.T) {
h.MockGit.Reset()
// Create a workspace without Git enabled
nonGitWorkspace := &models.Workspace{
UserID: h.RegularUser.ID,
Name: "Non-Git Workspace",
}
rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", nonGitWorkspace, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
err := json.NewDecoder(rr.Body).Decode(nonGitWorkspace)
require.NoError(t, err)
nonGitBaseURL := "/api/v1/workspaces/" + url.PathEscape(nonGitWorkspace.Name) + "/git"
// Try to commit
commitMsg := map[string]string{"message": "test"}
rr = h.makeRequest(t, http.MethodPost, nonGitBaseURL+"/commit", commitMsg, h.RegularToken, nil)
assert.Equal(t, http.StatusInternalServerError, rr.Code)
// Try to pull
rr = h.makeRequest(t, http.MethodPost, nonGitBaseURL+"/pull", nil, h.RegularToken, nil)
assert.Equal(t, http.StatusInternalServerError, rr.Code)
})
})
}

View File

@@ -4,20 +4,20 @@ import (
"encoding/json"
"net/http"
"novamd/internal/db"
"novamd/internal/filesystem"
"novamd/internal/storage"
)
// Handler provides common functionality for all handlers
type Handler struct {
DB *db.DB
FS *filesystem.FileSystem
DB db.Database
Storage storage.Manager
}
// NewHandler creates a new handler with the given dependencies
func NewHandler(db *db.DB, fs *filesystem.FileSystem) *Handler {
func NewHandler(db db.Database, s storage.Manager) *Handler {
return &Handler{
DB: db,
FS: fs,
DB: db,
Storage: s,
}
}

View File

@@ -0,0 +1,229 @@
//go:build integration
package handlers_test
import (
"bytes"
"encoding/json"
"io"
"net/http/httptest"
"os"
"testing"
"time"
"github.com/go-chi/chi/v5"
"golang.org/x/crypto/bcrypt"
"novamd/internal/api"
"novamd/internal/auth"
"novamd/internal/db"
"novamd/internal/git"
"novamd/internal/handlers"
"novamd/internal/models"
"novamd/internal/secrets"
"novamd/internal/storage"
)
// testHarness encapsulates all the dependencies needed for testing
type testHarness struct {
DB db.TestDatabase
Storage storage.Manager
Router *chi.Mux
Handler *handlers.Handler
JWTManager auth.JWTManager
SessionSvc *auth.SessionService
AdminUser *models.User
AdminToken string
RegularUser *models.User
RegularToken string
TempDirectory string
MockGit *MockGitClient
}
// setupTestHarness creates a new test environment
func setupTestHarness(t *testing.T) *testHarness {
t.Helper()
// Create temporary directory for test files
tempDir, err := os.MkdirTemp("", "novamd-test-*")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
// Initialize test database
secretsSvc, err := secrets.NewService("YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY=") // test key
if err != nil {
t.Fatalf("Failed to initialize secrets service: %v", err)
}
database, err := db.NewTestDB(":memory:", secretsSvc)
if err != nil {
t.Fatalf("Failed to initialize test database: %v", err)
}
if err := database.Migrate(); err != nil {
t.Fatalf("Failed to run migrations: %v", err)
}
// Create mock git client
mockGit := NewMockGitClient(false)
// Create storage with mock git client
storageOpts := storage.Options{
NewGitClient: func(url, user, token, path string) git.Client {
return mockGit
},
}
storageSvc := storage.NewServiceWithOptions(tempDir, storageOpts)
// Initialize JWT service
jwtSvc, err := auth.NewJWTService(auth.JWTConfig{
SigningKey: "test-key",
AccessTokenExpiry: 15 * time.Minute,
RefreshTokenExpiry: 24 * time.Hour,
})
if err != nil {
t.Fatalf("Failed to initialize JWT service: %v", err)
}
// Initialize session service
sessionSvc := auth.NewSessionService(database, jwtSvc)
// Create handler
handler := &handlers.Handler{
DB: database,
Storage: storageSvc,
}
// Set up router with middlewares
router := chi.NewRouter()
authMiddleware := auth.NewMiddleware(jwtSvc)
router.Route("/api/v1", func(r chi.Router) {
api.SetupRoutes(r, database, storageSvc, authMiddleware, sessionSvc)
})
h := &testHarness{
DB: database,
Storage: storageSvc,
Router: router,
Handler: handler,
JWTManager: jwtSvc,
SessionSvc: sessionSvc,
TempDirectory: tempDir,
MockGit: mockGit,
}
// Create test users
adminUser, adminToken := h.createTestUser(t, database, sessionSvc, "admin@test.com", "admin123", models.RoleAdmin)
regularUser, regularToken := h.createTestUser(t, database, sessionSvc, "user@test.com", "user123", models.RoleEditor)
h.AdminUser = adminUser
h.AdminToken = adminToken
h.RegularUser = regularUser
h.RegularToken = regularToken
return h
}
// teardownTestHarness cleans up the test environment
func (h *testHarness) teardown(t *testing.T) {
t.Helper()
if err := h.DB.Close(); err != nil {
t.Errorf("Failed to close database: %v", err)
}
if err := os.RemoveAll(h.TempDirectory); err != nil {
t.Errorf("Failed to remove temp directory: %v", err)
}
}
// createTestUser creates a test user and returns the user and access token
func (h *testHarness) createTestUser(t *testing.T, db db.Database, sessionSvc *auth.SessionService, email, password string, role models.UserRole) (*models.User, string) {
t.Helper()
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
t.Fatalf("Failed to hash password: %v", err)
}
user := &models.User{
Email: email,
DisplayName: "Test User",
PasswordHash: string(hashedPassword),
Role: role,
}
user, err = db.CreateUser(user)
if err != nil {
t.Fatalf("Failed to create user: %v", err)
}
// Initialize the default workspace directory in storage
err = h.Storage.InitializeUserWorkspace(user.ID, user.LastWorkspaceID)
if err != nil {
t.Fatalf("Failed to initialize user workspace: %v", err)
}
session, accessToken, err := sessionSvc.CreateSession(user.ID, string(user.Role))
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}
if session == nil || accessToken == "" {
t.Fatal("Failed to get valid session or token")
}
return user, accessToken
}
// makeRequest is a helper function to make HTTP requests in tests
func (h *testHarness) makeRequest(t *testing.T, method, path string, body interface{}, token string, headers map[string]string) *httptest.ResponseRecorder {
t.Helper()
var reqBody []byte
var err error
if body != nil {
reqBody, err = json.Marshal(body)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
}
req := httptest.NewRequest(method, path, bytes.NewBuffer(reqBody))
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
req.Header.Set("Content-Type", "application/json")
// Add any additional headers
for k, v := range headers {
req.Header.Set(k, v)
}
rr := httptest.NewRecorder()
h.Router.ServeHTTP(rr, req)
return rr
}
// makeRequestRaw is a helper function to make HTTP requests with raw body content
func (h *testHarness) makeRequestRaw(t *testing.T, method, path string, body io.Reader, token string, headers map[string]string) *httptest.ResponseRecorder {
t.Helper()
req := httptest.NewRequest(method, path, body)
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
// Add any additional headers
for k, v := range headers {
req.Header.Set(k, v)
}
rr := httptest.NewRecorder()
h.Router.ServeHTTP(rr, req)
return rr
}

View File

@@ -0,0 +1,123 @@
//go:build integration
package handlers_test
import (
"fmt"
)
// MockGitClient implements the git.Client interface for testing
type MockGitClient struct {
initialized bool
cloned bool
lastCommitMsg string
error error
pullCount int
commitCount int
pushCount int
cloneCount int
ensureCount int
}
// NewMockGitClient creates a new mock git client
func NewMockGitClient(shouldError bool) *MockGitClient {
var err error
if shouldError {
err = fmt.Errorf("mock git error")
}
return &MockGitClient{
error: err,
}
}
// Clone implements git.Client
func (m *MockGitClient) Clone() error {
if m.error != nil {
return m.error
}
m.cloneCount++
m.cloned = true
return nil
}
// Pull implements git.Client
func (m *MockGitClient) Pull() error {
if m.error != nil {
return m.error
}
m.pullCount++
return nil
}
// Commit implements git.Client
func (m *MockGitClient) Commit(message string) error {
if m.error != nil {
return m.error
}
m.commitCount++
m.lastCommitMsg = message
return nil
}
// Push implements git.Client
func (m *MockGitClient) Push() error {
if m.error != nil {
return m.error
}
m.pushCount++
return nil
}
// EnsureRepo implements git.Client
func (m *MockGitClient) EnsureRepo() error {
if m.error != nil {
return m.error
}
m.ensureCount++
m.initialized = true
return nil
}
// Helper methods for tests
func (m *MockGitClient) GetCommitCount() int {
return m.commitCount
}
func (m *MockGitClient) GetPushCount() int {
return m.pushCount
}
func (m *MockGitClient) GetPullCount() int {
return m.pullCount
}
func (m *MockGitClient) GetLastCommitMessage() string {
return m.lastCommitMsg
}
func (m *MockGitClient) IsInitialized() bool {
return m.initialized
}
func (m *MockGitClient) IsCloned() bool {
return m.cloned
}
// Reset resets all counters and states
func (m *MockGitClient) Reset() {
m.initialized = false
m.cloned = false
m.lastCommitMsg = ""
m.pullCount = 0
m.commitCount = 0
m.pushCount = 0
m.cloneCount = 0
m.ensureCount = 0
}
// SetError sets the error state
func (m *MockGitClient) SetError(err error) {
m.error = err
}

View File

@@ -12,12 +12,14 @@ type StaticHandler struct {
staticPath string
}
// NewStaticHandler creates a new StaticHandler with the given static path
func NewStaticHandler(staticPath string) *StaticHandler {
return &StaticHandler{
staticPath: staticPath,
}
}
// ServeHTTP serves the static files
func (h *StaticHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Get the requested path
requestedPath := r.URL.Path

View File

@@ -0,0 +1,145 @@
//go:build integration
package handlers_test
import (
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"novamd/internal/handlers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestStaticHandler_Integration(t *testing.T) {
// Create temporary directory for test static files
tempDir, err := os.MkdirTemp("", "novamd-static-test-*")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
// Create test files
files := map[string][]byte{
"index.html": []byte("<html><body>Index</body></html>"),
"assets/style.css": []byte("body { color: blue; }"),
"assets/style.css.gz": []byte("gzipped css content"),
"assets/script.js": []byte("console.log('test');"),
"assets/script.js.gz": []byte("gzipped js content"),
"subdir/page.html": []byte("<html><body>Page</body></html>"),
"subdir/page.html.gz": []byte("gzipped html content"),
}
for path, content := range files {
fullPath := filepath.Join(tempDir, path)
err := os.MkdirAll(filepath.Dir(fullPath), 0755)
require.NoError(t, err)
err = os.WriteFile(fullPath, content, 0644)
require.NoError(t, err)
}
// Create static handler
handler := handlers.NewStaticHandler(tempDir)
tests := []struct {
name string
path string
acceptEncoding string
wantStatus int
wantBody []byte
wantType string
wantEncoding string
wantCacheHeader string
}{
{
name: "serve index.html",
path: "/",
wantStatus: http.StatusOK,
wantBody: []byte("<html><body>Index</body></html>"),
wantType: "text/html; charset=utf-8",
},
{
name: "serve CSS with gzip support",
path: "/assets/style.css",
acceptEncoding: "gzip",
wantStatus: http.StatusOK,
wantBody: []byte("gzipped css content"),
wantType: "text/css",
wantEncoding: "gzip",
wantCacheHeader: "public, max-age=31536000",
},
{
name: "serve JS with gzip support",
path: "/assets/script.js",
acceptEncoding: "gzip",
wantStatus: http.StatusOK,
wantBody: []byte("gzipped js content"),
wantType: "application/javascript",
wantEncoding: "gzip",
wantCacheHeader: "public, max-age=31536000",
},
{
name: "serve CSS without gzip",
path: "/assets/style.css",
wantStatus: http.StatusOK,
wantBody: []byte("body { color: blue; }"),
wantType: "text/css; charset=utf-8",
wantCacheHeader: "public, max-age=31536000",
},
{
name: "SPA routing - nonexistent path",
path: "/nonexistent",
wantStatus: http.StatusOK,
wantBody: []byte("<html><body>Index</body></html>"),
wantType: "text/html; charset=utf-8",
},
{
name: "SPA routing - deep path",
path: "/some/deep/path",
wantStatus: http.StatusOK,
wantBody: []byte("<html><body>Index</body></html>"),
wantType: "text/html; charset=utf-8",
},
{
name: "block directory traversal",
path: "/../../../etc/passwd",
wantStatus: http.StatusBadRequest,
},
{
name: "nonexistent file in assets",
path: "/assets/nonexistent.js",
wantStatus: http.StatusOK, // Should serve index.html
wantBody: []byte("<html><body>Index</body></html>"),
wantType: "text/html; charset=utf-8",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", tc.path, nil)
if tc.acceptEncoding != "" {
req.Header.Set("Accept-Encoding", tc.acceptEncoding)
}
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, tc.wantStatus, w.Code)
if tc.wantStatus == http.StatusOK {
assert.Equal(t, tc.wantBody, w.Body.Bytes())
assert.Equal(t, tc.wantType, w.Header().Get("Content-Type"))
if tc.wantEncoding != "" {
assert.Equal(t, tc.wantEncoding, w.Header().Get("Content-Encoding"))
}
if tc.wantCacheHeader != "" {
assert.Equal(t, tc.wantCacheHeader, w.Header().Get("Cache-Control"))
}
}
})
}
}

View File

@@ -4,11 +4,12 @@ import (
"encoding/json"
"net/http"
"novamd/internal/httpcontext"
"novamd/internal/context"
"golang.org/x/crypto/bcrypt"
)
// UpdateProfileRequest represents a user profile update request
type UpdateProfileRequest struct {
DisplayName string `json:"displayName"`
Email string `json:"email"`
@@ -16,31 +17,15 @@ type UpdateProfileRequest struct {
NewPassword string `json:"newPassword"`
}
// DeleteAccountRequest represents a user account deletion request
type DeleteAccountRequest struct {
Password string `json:"password"`
}
func (h *Handler) GetUser() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := httpcontext.GetRequestContext(w, r)
if !ok {
return
}
user, err := h.DB.GetUserByID(ctx.UserID)
if err != nil {
http.Error(w, "Failed to get user", http.StatusInternalServerError)
return
}
respondJSON(w, user)
}
}
// UpdateProfile updates the current user's profile
func (h *Handler) UpdateProfile() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := httpcontext.GetRequestContext(w, r)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
@@ -58,14 +43,6 @@ func (h *Handler) UpdateProfile() http.HandlerFunc {
return
}
// Start transaction for atomic updates
tx, err := h.DB.Begin()
if err != nil {
http.Error(w, "Failed to start transaction", http.StatusInternalServerError)
return
}
defer tx.Rollback()
// Handle password update if requested
if req.NewPassword != "" {
// Current password must be provided to change password
@@ -131,11 +108,6 @@ func (h *Handler) UpdateProfile() http.HandlerFunc {
return
}
if err := tx.Commit(); err != nil {
http.Error(w, "Failed to commit changes", http.StatusInternalServerError)
return
}
// Return updated user data
respondJSON(w, user)
}
@@ -144,7 +116,7 @@ func (h *Handler) UpdateProfile() http.HandlerFunc {
// DeleteAccount handles user account deletion
func (h *Handler) DeleteAccount() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := httpcontext.GetRequestContext(w, r)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
@@ -171,8 +143,7 @@ func (h *Handler) DeleteAccount() http.HandlerFunc {
// Prevent admin from deleting their own account if they're the last admin
if user.Role == "admin" {
// Count number of admin users
adminCount := 0
err := h.DB.QueryRow("SELECT COUNT(*) FROM users WHERE role = 'admin'").Scan(&adminCount)
adminCount, err := h.DB.CountAdminUsers()
if err != nil {
http.Error(w, "Failed to verify admin status", http.StatusInternalServerError)
return
@@ -183,14 +154,6 @@ func (h *Handler) DeleteAccount() http.HandlerFunc {
}
}
// Start transaction for consistent deletion
tx, err := h.DB.Begin()
if err != nil {
http.Error(w, "Failed to start transaction", http.StatusInternalServerError)
return
}
defer tx.Rollback()
// Get user's workspaces for cleanup
workspaces, err := h.DB.GetWorkspacesByUserID(ctx.UserID)
if err != nil {
@@ -200,7 +163,7 @@ func (h *Handler) DeleteAccount() http.HandlerFunc {
// Delete workspace directories
for _, workspace := range workspaces {
if err := h.FS.DeleteUserWorkspace(ctx.UserID, workspace.ID); err != nil {
if err := h.Storage.DeleteUserWorkspace(ctx.UserID, workspace.ID); err != nil {
http.Error(w, "Failed to delete workspace files", http.StatusInternalServerError)
return
}
@@ -212,11 +175,6 @@ func (h *Handler) DeleteAccount() http.HandlerFunc {
return
}
if err := tx.Commit(); err != nil {
http.Error(w, "Failed to commit transaction", http.StatusInternalServerError)
return
}
respondJSON(w, map[string]string{"message": "Account deleted successfully"})
}
}

View File

@@ -0,0 +1,219 @@
//go:build integration
package handlers_test
import (
"encoding/json"
"net/http"
"testing"
"novamd/internal/handlers"
"novamd/internal/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestUserHandlers_Integration(t *testing.T) {
h := setupTestHarness(t)
defer h.teardown(t)
currentEmail := h.RegularUser.Email
currentPassword := "user123"
t.Run("get current user", func(t *testing.T) {
t.Run("successful get", func(t *testing.T) {
rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var user models.User
err := json.NewDecoder(rr.Body).Decode(&user)
require.NoError(t, err)
assert.Equal(t, h.RegularUser.ID, user.ID)
assert.Equal(t, h.RegularUser.Email, user.Email)
assert.Equal(t, h.RegularUser.DisplayName, user.DisplayName)
assert.Equal(t, h.RegularUser.Role, user.Role)
assert.Empty(t, user.PasswordHash, "Password hash should not be included in response")
})
t.Run("unauthorized", func(t *testing.T) {
rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, "", nil)
assert.Equal(t, http.StatusUnauthorized, rr.Code)
})
})
t.Run("update profile", func(t *testing.T) {
t.Run("update display name only", func(t *testing.T) {
updateReq := handlers.UpdateProfileRequest{
DisplayName: "Updated Name",
}
rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var user models.User
err := json.NewDecoder(rr.Body).Decode(&user)
require.NoError(t, err)
assert.Equal(t, updateReq.DisplayName, user.DisplayName)
})
t.Run("update email", func(t *testing.T) {
updateReq := handlers.UpdateProfileRequest{
Email: "newemail@test.com",
CurrentPassword: currentPassword,
}
rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var user models.User
err := json.NewDecoder(rr.Body).Decode(&user)
require.NoError(t, err)
assert.Equal(t, updateReq.Email, user.Email)
currentEmail = updateReq.Email
})
t.Run("update email without password", func(t *testing.T) {
updateReq := handlers.UpdateProfileRequest{
Email: "anotheremail@test.com",
}
rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil)
assert.Equal(t, http.StatusBadRequest, rr.Code)
})
t.Run("update email with wrong password", func(t *testing.T) {
updateReq := handlers.UpdateProfileRequest{
Email: "wrongpass@test.com",
CurrentPassword: "wrongpassword",
}
rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil)
assert.Equal(t, http.StatusUnauthorized, rr.Code)
})
t.Run("update password", func(t *testing.T) {
updateReq := handlers.UpdateProfileRequest{
CurrentPassword: currentPassword,
NewPassword: "newpassword123",
}
rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
// Verify can login with new password
loginReq := handlers.LoginRequest{
Email: currentEmail,
Password: "newpassword123",
}
rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil)
assert.Equal(t, http.StatusOK, rr.Code)
currentPassword = updateReq.NewPassword
})
t.Run("update password without current password", func(t *testing.T) {
updateReq := handlers.UpdateProfileRequest{
NewPassword: "newpass123",
}
rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil)
assert.Equal(t, http.StatusBadRequest, rr.Code)
})
t.Run("update password with wrong current password", func(t *testing.T) {
updateReq := handlers.UpdateProfileRequest{
CurrentPassword: "wrongpassword",
NewPassword: "newpass123",
}
rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil)
assert.Equal(t, http.StatusUnauthorized, rr.Code)
})
t.Run("update with short password", func(t *testing.T) {
updateReq := handlers.UpdateProfileRequest{
CurrentPassword: currentPassword,
NewPassword: "short",
}
rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil)
assert.Equal(t, http.StatusBadRequest, rr.Code)
})
t.Run("duplicate email", func(t *testing.T) {
updateReq := handlers.UpdateProfileRequest{
Email: h.AdminUser.Email,
CurrentPassword: currentPassword,
}
rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil)
assert.Equal(t, http.StatusConflict, rr.Code)
})
})
t.Run("delete account", func(t *testing.T) {
// Create a new user that we can delete
createReq := handlers.CreateUserRequest{
Email: "todelete@test.com",
DisplayName: "To Delete",
Password: "password123",
Role: models.RoleEditor,
}
rr := h.makeRequest(t, http.MethodPost, "/api/v1/admin/users", createReq, h.AdminToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var newUser models.User
err := json.NewDecoder(rr.Body).Decode(&newUser)
require.NoError(t, err)
// Get token for new user
loginReq := handlers.LoginRequest{
Email: createReq.Email,
Password: createReq.Password,
}
rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil)
require.Equal(t, http.StatusOK, rr.Code)
var loginResp handlers.LoginResponse
err = json.NewDecoder(rr.Body).Decode(&loginResp)
require.NoError(t, err)
userToken := loginResp.AccessToken
t.Run("successful delete", func(t *testing.T) {
deleteReq := handlers.DeleteAccountRequest{
Password: createReq.Password,
}
rr := h.makeRequest(t, http.MethodDelete, "/api/v1/profile", deleteReq, userToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
// Verify user is deleted
rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil)
assert.Equal(t, http.StatusUnauthorized, rr.Code)
})
t.Run("delete with wrong password", func(t *testing.T) {
deleteReq := handlers.DeleteAccountRequest{
Password: "wrongpassword",
}
rr := h.makeRequest(t, http.MethodDelete, "/api/v1/profile", deleteReq, h.RegularToken, nil)
assert.Equal(t, http.StatusUnauthorized, rr.Code)
})
t.Run("prevent last admin deletion", func(t *testing.T) {
deleteReq := handlers.DeleteAccountRequest{
Password: "admin123", // Admin password from test harness
}
rr := h.makeRequest(t, http.MethodDelete, "/api/v1/profile", deleteReq, h.AdminToken, nil)
assert.Equal(t, http.StatusForbidden, rr.Code)
})
})
}

View File

@@ -1,17 +1,18 @@
package handlers
import (
"database/sql"
"encoding/json"
"fmt"
"net/http"
"novamd/internal/httpcontext"
"novamd/internal/context"
"novamd/internal/models"
)
// ListWorkspaces returns a list of all workspaces for the current user
func (h *Handler) ListWorkspaces() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := httpcontext.GetRequestContext(w, r)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
@@ -26,9 +27,10 @@ func (h *Handler) ListWorkspaces() http.HandlerFunc {
}
}
// CreateWorkspace creates a new workspace
func (h *Handler) CreateWorkspace() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := httpcontext.GetRequestContext(w, r)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
@@ -39,24 +41,43 @@ func (h *Handler) CreateWorkspace() http.HandlerFunc {
return
}
if err := workspace.ValidateGitSettings(); err != nil {
http.Error(w, "Invalid workspace", http.StatusBadRequest)
return
}
workspace.UserID = ctx.UserID
if err := h.DB.CreateWorkspace(&workspace); err != nil {
http.Error(w, "Failed to create workspace", http.StatusInternalServerError)
return
}
if err := h.FS.InitializeUserWorkspace(workspace.UserID, workspace.ID); err != nil {
if err := h.Storage.InitializeUserWorkspace(workspace.UserID, workspace.ID); err != nil {
http.Error(w, "Failed to initialize workspace directory", http.StatusInternalServerError)
return
}
if workspace.GitEnabled {
if err := h.Storage.SetupGitRepo(
ctx.UserID,
workspace.ID,
workspace.GitURL,
workspace.GitUser,
workspace.GitToken,
); err != nil {
http.Error(w, "Failed to setup git repo: "+err.Error(), http.StatusInternalServerError)
return
}
}
respondJSON(w, workspace)
}
}
// GetWorkspace returns the current workspace
func (h *Handler) GetWorkspace() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := httpcontext.GetRequestContext(w, r)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
@@ -81,9 +102,10 @@ func gitSettingsChanged(new, old *models.Workspace) bool {
return false
}
// UpdateWorkspace updates the current workspace
func (h *Handler) UpdateWorkspace() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := httpcontext.GetRequestContext(w, r)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
@@ -107,7 +129,7 @@ func (h *Handler) UpdateWorkspace() http.HandlerFunc {
// Handle Git repository setup/teardown if Git settings changed
if gitSettingsChanged(&workspace, ctx.Workspace) {
if workspace.GitEnabled {
if err := h.FS.SetupGitRepo(
if err := h.Storage.SetupGitRepo(
ctx.UserID,
ctx.Workspace.ID,
workspace.GitURL,
@@ -119,7 +141,7 @@ func (h *Handler) UpdateWorkspace() http.HandlerFunc {
}
} else {
h.FS.DisableGitRepo(ctx.UserID, ctx.Workspace.ID)
h.Storage.DisableGitRepo(ctx.UserID, ctx.Workspace.ID)
}
}
@@ -132,9 +154,10 @@ func (h *Handler) UpdateWorkspace() http.HandlerFunc {
}
}
// DeleteWorkspace deletes the current workspace
func (h *Handler) DeleteWorkspace() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := httpcontext.GetRequestContext(w, r)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
@@ -168,7 +191,11 @@ func (h *Handler) DeleteWorkspace() http.HandlerFunc {
http.Error(w, "Failed to start transaction", http.StatusInternalServerError)
return
}
defer tx.Rollback()
defer func() {
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
http.Error(w, "Failed to rollback transaction", http.StatusInternalServerError)
}
}()
// Update last workspace ID first
err = h.DB.UpdateLastWorkspaceTx(tx, ctx.UserID, nextWorkspaceID)
@@ -195,9 +222,10 @@ func (h *Handler) DeleteWorkspace() http.HandlerFunc {
}
}
// GetLastWorkspaceName returns the name of the last opened workspace
func (h *Handler) GetLastWorkspaceName() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := httpcontext.GetRequestContext(w, r)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
@@ -212,9 +240,10 @@ func (h *Handler) GetLastWorkspaceName() http.HandlerFunc {
}
}
// UpdateLastWorkspaceName updates the name of the last opened workspace
func (h *Handler) UpdateLastWorkspaceName() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := httpcontext.GetRequestContext(w, r)
ctx, ok := context.GetRequestContext(w, r)
if !ok {
return
}
@@ -224,13 +253,11 @@ func (h *Handler) UpdateLastWorkspaceName() http.HandlerFunc {
}
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
fmt.Println(err)
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
if err := h.DB.UpdateLastWorkspace(ctx.UserID, requestBody.WorkspaceName); err != nil {
fmt.Println(err)
http.Error(w, "Failed to update last workspace", http.StatusInternalServerError)
return
}

View File

@@ -0,0 +1,297 @@
//go:build integration
package handlers_test
import (
"encoding/json"
"net/http"
"net/url"
"testing"
"novamd/internal/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestWorkspaceHandlers_Integration(t *testing.T) {
h := setupTestHarness(t)
defer h.teardown(t)
t.Run("list workspaces", func(t *testing.T) {
t.Run("successful list", func(t *testing.T) {
rr := h.makeRequest(t, http.MethodGet, "/api/v1/workspaces", nil, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var workspaces []*models.Workspace
err := json.NewDecoder(rr.Body).Decode(&workspaces)
require.NoError(t, err)
assert.NotEmpty(t, workspaces, "User should have at least one default workspace")
})
t.Run("unauthorized", func(t *testing.T) {
rr := h.makeRequest(t, http.MethodGet, "/api/v1/workspaces", nil, "", nil)
assert.Equal(t, http.StatusUnauthorized, rr.Code)
})
})
t.Run("create workspace", func(t *testing.T) {
t.Run("successful create", func(t *testing.T) {
workspace := &models.Workspace{
Name: "Test Workspace",
}
rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var created models.Workspace
err := json.NewDecoder(rr.Body).Decode(&created)
require.NoError(t, err)
assert.Equal(t, workspace.Name, created.Name)
assert.Equal(t, h.RegularUser.ID, created.UserID)
assert.NotZero(t, created.ID)
})
t.Run("create with git settings", func(t *testing.T) {
workspace := &models.Workspace{
Name: "Git Workspace",
GitEnabled: true,
GitURL: "https://github.com/test/repo.git",
GitUser: "testuser",
GitToken: "testtoken",
GitAutoCommit: true,
}
rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var created models.Workspace
err := json.NewDecoder(rr.Body).Decode(&created)
require.NoError(t, err)
assert.Equal(t, workspace.GitEnabled, created.GitEnabled)
assert.Equal(t, workspace.GitURL, created.GitURL)
assert.Equal(t, workspace.GitUser, created.GitUser)
assert.Equal(t, workspace.GitToken, created.GitToken)
assert.Equal(t, workspace.GitAutoCommit, created.GitAutoCommit)
})
t.Run("invalid workspace", func(t *testing.T) {
workspace := &models.Workspace{
Name: "", // Empty name
GitEnabled: true,
// Missing required Git settings
}
rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil)
assert.Equal(t, http.StatusBadRequest, rr.Code)
})
})
// Create a workspace for the remaining tests
workspace := &models.Workspace{
Name: "Test Workspace Operations",
}
rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
err := json.NewDecoder(rr.Body).Decode(workspace)
require.NoError(t, err)
escapedName := url.PathEscape(workspace.Name)
baseURL := "/api/v1/workspaces/" + escapedName
t.Run("get workspace", func(t *testing.T) {
t.Run("successful get", func(t *testing.T) {
rr := h.makeRequest(t, http.MethodGet, baseURL, nil, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var got models.Workspace
err := json.NewDecoder(rr.Body).Decode(&got)
require.NoError(t, err)
assert.Equal(t, workspace.ID, got.ID)
assert.Equal(t, workspace.Name, got.Name)
})
t.Run("nonexistent workspace", func(t *testing.T) {
rr := h.makeRequest(t, http.MethodGet, "/api/v1/workspaces/nonexistent", nil, h.RegularToken, nil)
assert.Equal(t, http.StatusNotFound, rr.Code)
})
t.Run("unauthorized access", func(t *testing.T) {
// Try accessing with another user's token
rr := h.makeRequest(t, http.MethodGet, baseURL, nil, h.AdminToken, nil)
assert.Equal(t, http.StatusNotFound, rr.Code)
})
})
t.Run("update workspace", func(t *testing.T) {
t.Run("update name", func(t *testing.T) {
workspace.Name = "Updated Workspace"
rr := h.makeRequest(t, http.MethodPut, baseURL, workspace, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var updated models.Workspace
err := json.NewDecoder(rr.Body).Decode(&updated)
require.NoError(t, err)
assert.Equal(t, workspace.Name, updated.Name)
// Update baseURL for remaining tests
escapedName = url.PathEscape(workspace.Name)
baseURL = "/api/v1/workspaces/" + escapedName
})
t.Run("update settings", func(t *testing.T) {
update := &models.Workspace{
Name: workspace.Name,
Theme: "dark",
AutoSave: true,
ShowHiddenFiles: true,
}
rr := h.makeRequest(t, http.MethodPut, baseURL, update, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var updated models.Workspace
err := json.NewDecoder(rr.Body).Decode(&updated)
require.NoError(t, err)
assert.Equal(t, update.Theme, updated.Theme)
assert.Equal(t, update.AutoSave, updated.AutoSave)
assert.Equal(t, update.ShowHiddenFiles, updated.ShowHiddenFiles)
})
t.Run("enable git", func(t *testing.T) {
update := &models.Workspace{
Name: workspace.Name,
Theme: "dark",
GitEnabled: true,
GitURL: "https://github.com/test/repo.git",
GitUser: "testuser",
GitToken: "testtoken",
GitAutoCommit: true,
}
rr := h.makeRequest(t, http.MethodPut, baseURL, update, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var updated models.Workspace
err := json.NewDecoder(rr.Body).Decode(&updated)
require.NoError(t, err)
assert.Equal(t, update.GitEnabled, updated.GitEnabled)
assert.Equal(t, update.GitURL, updated.GitURL)
assert.Equal(t, update.GitUser, updated.GitUser)
assert.Equal(t, update.GitToken, updated.GitToken)
// Mock should have been called to setup git
assert.True(t, h.MockGit.IsInitialized())
})
t.Run("invalid git settings", func(t *testing.T) {
update := &models.Workspace{
Name: workspace.Name,
GitEnabled: true,
// Missing required Git settings
}
rr := h.makeRequest(t, http.MethodPut, baseURL, update, h.RegularToken, nil)
assert.Equal(t, http.StatusBadRequest, rr.Code)
})
})
t.Run("last workspace", func(t *testing.T) {
t.Run("get last workspace", func(t *testing.T) {
rr := h.makeRequest(t, http.MethodGet, "/api/v1/workspaces/last", nil, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var response struct {
LastWorkspaceName string `json:"lastWorkspaceName"`
}
err := json.NewDecoder(rr.Body).Decode(&response)
require.NoError(t, err)
assert.NotEmpty(t, response.LastWorkspaceName)
})
t.Run("update last workspace", func(t *testing.T) {
req := struct {
WorkspaceName string `json:"workspaceName"`
}{
WorkspaceName: workspace.Name,
}
rr := h.makeRequest(t, http.MethodPut, "/api/v1/workspaces/last", req, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
// Verify the update
rr = h.makeRequest(t, http.MethodGet, "/api/v1/workspaces/last", nil, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var response struct {
LastWorkspaceName string `json:"lastWorkspaceName"`
}
err := json.NewDecoder(rr.Body).Decode(&response)
require.NoError(t, err)
assert.Equal(t, workspace.Name, response.LastWorkspaceName)
})
})
t.Run("delete workspace", func(t *testing.T) {
// Get current workspaces to know how many we have
rr := h.makeRequest(t, http.MethodGet, "/api/v1/workspaces", nil, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var existingWorkspaces []*models.Workspace
err := json.NewDecoder(rr.Body).Decode(&existingWorkspaces)
require.NoError(t, err)
// Create a new workspace we can safely delete
newWorkspace := &models.Workspace{
Name: "Workspace To Delete",
}
rr = h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", newWorkspace, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
err = json.NewDecoder(rr.Body).Decode(newWorkspace)
require.NoError(t, err)
t.Run("successful delete", func(t *testing.T) {
rr := h.makeRequest(t, http.MethodDelete, "/api/v1/workspaces/"+url.PathEscape(newWorkspace.Name), nil, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var response struct {
NextWorkspaceName string `json:"nextWorkspaceName"`
}
err := json.NewDecoder(rr.Body).Decode(&response)
require.NoError(t, err)
assert.NotEmpty(t, response.NextWorkspaceName)
// Verify workspace is deleted
rr = h.makeRequest(t, http.MethodGet, "/api/v1/workspaces/"+url.PathEscape(newWorkspace.Name), nil, h.RegularToken, nil)
assert.Equal(t, http.StatusNotFound, rr.Code)
})
t.Run("prevent deleting last workspace", func(t *testing.T) {
// Delete all but one workspace
for i := 0; i < len(existingWorkspaces)-1; i++ {
ws := existingWorkspaces[i]
rr := h.makeRequest(t, http.MethodDelete, "/api/v1/workspaces/"+url.PathEscape(ws.Name), nil, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
}
// Try to delete the last remaining workspace
lastWs := existingWorkspaces[len(existingWorkspaces)-1]
rr := h.makeRequest(t, http.MethodDelete, "/api/v1/workspaces/"+url.PathEscape(lastWs.Name), nil, h.RegularToken, nil)
assert.Equal(t, http.StatusBadRequest, rr.Code)
})
t.Run("unauthorized deletion", func(t *testing.T) {
// Create a workspace to attempt unauthorized deletion
workspace := &models.Workspace{
Name: "Unauthorized Delete Test",
}
rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
// Try to delete with wrong user's token
rr = h.makeRequest(t, http.MethodDelete, "/api/v1/workspaces/"+url.PathEscape(workspace.Name), nil, h.AdminToken, nil)
assert.Equal(t, http.StatusNotFound, rr.Code)
})
})
}

View File

@@ -1,31 +0,0 @@
package httpcontext
import (
"context"
"net/http"
"novamd/internal/models"
)
// HandlerContext holds the request-specific data available to all handlers
type HandlerContext struct {
UserID int
UserRole string
Workspace *models.Workspace
}
type contextKey string
const HandlerContextKey contextKey = "handlerContext"
func GetRequestContext(w http.ResponseWriter, r *http.Request) (*HandlerContext, bool) {
ctx := r.Context().Value(HandlerContextKey)
if ctx == nil {
http.Error(w, "Internal server error", http.StatusInternalServerError)
return nil, false
}
return ctx.(*HandlerContext), true
}
func WithHandlerContext(r *http.Request, hctx *HandlerContext) *http.Request {
return r.WithContext(context.WithValue(r.Context(), HandlerContextKey, hctx))
}

View File

@@ -0,0 +1,13 @@
// Package models contains the data models used throughout the application. These models are used to represent data in the database, as well as to validate and serialize data in the application.
package models
import "time"
// Session represents a user session in the database
type Session struct {
ID string // Unique session identifier
UserID int // ID of the user this session belongs to
RefreshToken string // The refresh token associated with this session
ExpiresAt time.Time // When this session expires
CreatedAt time.Time // When this session was created
}

View File

@@ -8,14 +8,17 @@ import (
var validate = validator.New()
// UserRole represents the role of a user in the system
type UserRole string
// User roles
const (
RoleAdmin UserRole = "admin"
RoleEditor UserRole = "editor"
RoleViewer UserRole = "viewer"
)
// User represents a user in the system
type User struct {
ID int `json:"id" validate:"required,min=1"`
Email string `json:"email" validate:"required,email"`
@@ -26,6 +29,7 @@ type User struct {
LastWorkspaceID int `json:"lastWorkspaceId"`
}
// Validate validates the user struct
func (u *User) Validate() error {
return validate.Struct(u)
}

View File

@@ -4,6 +4,7 @@ import (
"time"
)
// Workspace represents a user's workspace in the system
type Workspace struct {
ID int `json:"id" validate:"required,min=1"`
UserID int `json:"userId" validate:"required,min=1"`
@@ -23,18 +24,30 @@ type Workspace struct {
GitCommitMsgTemplate string `json:"gitCommitMsgTemplate"`
}
// Validate validates the workspace struct
func (w *Workspace) Validate() error {
return validate.Struct(w)
}
func (w *Workspace) GetDefaultSettings() {
w.Theme = "light"
w.AutoSave = false
w.ShowHiddenFiles = false
w.GitEnabled = false
w.GitURL = ""
w.GitUser = ""
w.GitToken = ""
w.GitAutoCommit = false
w.GitCommitMsgTemplate = "${action} ${filename}"
// ValidateGitSettings validates the git settings if git is enabled
func (w *Workspace) ValidateGitSettings() error {
return validate.StructExcept(w, "ID", "UserID", "Theme")
}
// SetDefaultSettings sets the default settings for the workspace
func (w *Workspace) SetDefaultSettings() {
if w.Theme == "" {
w.Theme = "light"
}
w.AutoSave = w.AutoSave || false
w.ShowHiddenFiles = w.ShowHiddenFiles || false
w.GitEnabled = w.GitEnabled || false
w.GitAutoCommit = w.GitEnabled && (w.GitAutoCommit || false)
if w.GitCommitMsgTemplate == "" {
w.GitCommitMsgTemplate = "${action} ${filename}"
}
}

View File

@@ -0,0 +1,112 @@
// Package secrets provides an Encryptor interface for encrypting and decrypting strings using AES-256-GCM.
package secrets
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
)
// Service is an interface for encrypting and decrypting strings
type Service interface {
Encrypt(plaintext string) (string, error)
Decrypt(ciphertext string) (string, error)
}
type encryptor struct {
gcm cipher.AEAD
}
// ValidateKey checks if the provided base64-encoded key is suitable for AES-256
func ValidateKey(key string) error {
_, err := decodeAndValidateKey(key)
return err
}
// decodeAndValidateKey validates and decodes the base64-encoded key
// Returns the decoded key bytes if valid
func decodeAndValidateKey(key string) ([]byte, error) {
if key == "" {
return nil, fmt.Errorf("encryption key is required")
}
keyBytes, err := base64.StdEncoding.DecodeString(key)
if err != nil {
return nil, fmt.Errorf("invalid base64 encoding: %w", err)
}
if len(keyBytes) != 32 {
return nil, fmt.Errorf("encryption key must be 32 bytes (256 bits): got %d bytes", len(keyBytes))
}
// Verify the key can be used for AES
_, err = aes.NewCipher(keyBytes)
if err != nil {
return nil, fmt.Errorf("invalid encryption key: %w", err)
}
return keyBytes, nil
}
// NewService creates a new Encryptor instance with the provided base64-encoded key
func NewService(key string) (Service, error) {
keyBytes, err := decodeAndValidateKey(key)
if err != nil {
return nil, err
}
block, err := aes.NewCipher(keyBytes)
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("failed to create GCM: %w", err)
}
return &encryptor{gcm: gcm}, nil
}
// Encrypt encrypts the plaintext using AES-256-GCM
func (e *encryptor) Encrypt(plaintext string) (string, error) {
if plaintext == "" {
return "", nil
}
nonce := make([]byte, e.gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", err
}
ciphertext := e.gcm.Seal(nonce, nonce, []byte(plaintext), nil)
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// Decrypt decrypts the ciphertext using AES-256-GCM
func (e *encryptor) Decrypt(ciphertext string) (string, error) {
if ciphertext == "" {
return "", nil
}
data, err := base64.StdEncoding.DecodeString(ciphertext)
if err != nil {
return "", fmt.Errorf("invalid base64 encoding: %w", err)
}
nonceSize := e.gcm.NonceSize()
if len(data) < nonceSize {
return "", fmt.Errorf("invalid ciphertext: too short")
}
nonce, ciphertextBytes := data[:nonceSize], data[nonceSize:]
plaintext, err := e.gcm.Open(nil, nonce, ciphertextBytes, nil)
if err != nil {
return "", err
}
return string(plaintext), nil
}

View File

@@ -0,0 +1,257 @@
package secrets_test
import (
"encoding/base64"
"strings"
"testing"
"novamd/internal/secrets"
)
func TestValidateKey(t *testing.T) {
testCases := []struct {
name string
key string
wantErr bool
errContains string
}{
{
name: "valid 32-byte base64 key",
key: base64.StdEncoding.EncodeToString(make([]byte, 32)),
wantErr: false,
},
{
name: "empty key",
key: "",
wantErr: true,
errContains: "encryption key is required",
},
{
name: "invalid base64",
key: "not-base64!@#$",
wantErr: true,
errContains: "invalid base64 encoding",
},
{
name: "wrong key size (16 bytes)",
key: base64.StdEncoding.EncodeToString(make([]byte, 16)),
wantErr: true,
errContains: "encryption key must be 32 bytes",
},
{
name: "wrong key size (64 bytes)",
key: base64.StdEncoding.EncodeToString(make([]byte, 64)),
wantErr: true,
errContains: "encryption key must be 32 bytes",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := secrets.ValidateKey(tc.key)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
return
}
if !strings.Contains(err.Error(), tc.errContains) {
t.Errorf("error = %v, want error containing %q", err, tc.errContains)
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
}
}
func TestNew(t *testing.T) {
testCases := []struct {
name string
key string
wantErr bool
errContains string
}{
{
name: "valid key",
key: base64.StdEncoding.EncodeToString(make([]byte, 32)),
wantErr: false,
},
{
name: "empty key",
key: "",
wantErr: true,
errContains: "encryption key is required",
},
{
name: "invalid key",
key: "invalid",
wantErr: true,
errContains: "invalid base64 encoding",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e, err := secrets.NewService(tc.key)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
return
}
if !strings.Contains(err.Error(), tc.errContains) {
t.Errorf("error = %v, want error containing %q", err, tc.errContains)
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if e == nil {
t.Error("expected Encryptor instance, got nil")
}
})
}
}
func TestEncryptDecrypt(t *testing.T) {
// Generate a valid key for testing
key := base64.StdEncoding.EncodeToString(make([]byte, 32))
e, err := secrets.NewService(key)
if err != nil {
t.Fatalf("failed to create Encryptor instance: %v", err)
}
testCases := []struct {
name string
plaintext string
wantErr bool
}{
{
name: "normal text",
plaintext: "Hello, World!",
wantErr: false,
},
{
name: "empty string",
plaintext: "",
wantErr: false,
},
{
name: "long text",
plaintext: strings.Repeat("Long text with lots of content. ", 100),
wantErr: false,
},
{
name: "special characters",
plaintext: "!@#$%^&*()_+-=[]{}|;:,.<>?",
wantErr: false,
},
{
name: "unicode characters",
plaintext: "Hello, 世界! नमस्ते",
wantErr: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Test encryption
ciphertext, err := e.Encrypt(tc.plaintext)
if tc.wantErr {
if err == nil {
t.Error("expected encryption error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected encryption error: %v", err)
}
// Verify ciphertext is different from plaintext
if tc.plaintext != "" && ciphertext == tc.plaintext {
t.Error("ciphertext matches plaintext")
}
// Test decryption
decrypted, err := e.Decrypt(ciphertext)
if err != nil {
t.Fatalf("unexpected decryption error: %v", err)
}
// Verify decrypted text matches original
if decrypted != tc.plaintext {
t.Errorf("decrypted text = %q, want %q", decrypted, tc.plaintext)
}
})
}
}
func TestDecryptInvalidCiphertext(t *testing.T) {
key := base64.StdEncoding.EncodeToString(make([]byte, 32))
e, err := secrets.NewService(key)
if err != nil {
t.Fatalf("failed to create Encryptor instance: %v", err)
}
testCases := []struct {
name string
ciphertext string
wantErr bool
errContains string
}{
{
name: "empty ciphertext",
ciphertext: "",
wantErr: false,
},
{
name: "invalid base64",
ciphertext: "not-base64!@#$",
wantErr: true,
errContains: "invalid base64 encoding",
},
{
name: "invalid ciphertext (too short)",
ciphertext: base64.StdEncoding.EncodeToString(make([]byte, 10)),
wantErr: true,
errContains: "invalid ciphertext: too short",
},
{
name: "tampered ciphertext",
ciphertext: base64.StdEncoding.EncodeToString(make([]byte, 50)),
wantErr: true,
errContains: "message authentication failed",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
decrypted, err := e.Decrypt(tc.ciphertext)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
return
}
if !strings.Contains(err.Error(), tc.errContains) {
t.Errorf("error = %v, want error containing %q", err, tc.errContains)
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if decrypted != "" {
t.Errorf("expected empty string, got %q", decrypted)
}
})
}
}

View File

@@ -0,0 +1,24 @@
// storage/errors.go
package storage
import (
"errors"
"fmt"
)
// PathValidationError represents a path validation error (e.g., path traversal attempt)
type PathValidationError struct {
Path string
Message string
}
func (e *PathValidationError) Error() string {
return fmt.Sprintf("%s: %s", e.Message, e.Path)
}
// IsPathValidationError checks if the error is a PathValidationError
func IsPathValidationError(err error) bool {
var pathErr *PathValidationError
return err != nil && errors.As(err, &pathErr)
}

View File

@@ -1,6 +1,6 @@
// Package filesystem provides functionalities to interact with the file system,
// Package storage provides functionalities to interact with the file system,
// including listing files, finding files by name, getting file content, saving files, and deleting files.
package filesystem
package storage
import (
"fmt"
@@ -10,7 +10,18 @@ import (
"strings"
)
// FileNode represents a file or directory in the file system.
// FileManager provides functionalities to interact with files in the storage.
type FileManager interface {
ListFilesRecursively(userID, workspaceID int) ([]FileNode, error)
FindFileByName(userID, workspaceID int, filename string) ([]string, error)
GetFileContent(userID, workspaceID int, filePath string) ([]byte, error)
SaveFile(userID, workspaceID int, filePath string, content []byte) error
DeleteFile(userID, workspaceID int, filePath string) error
GetFileStats(userID, workspaceID int) (*FileCountStats, error)
GetTotalFileStats() (*FileCountStats, error)
}
// FileNode represents a file or directory in the storage.
type FileNode struct {
ID string `json:"id"`
Name string `json:"name"`
@@ -19,13 +30,15 @@ type FileNode struct {
}
// ListFilesRecursively returns a list of all files in the workspace directory and its subdirectories.
func (fs *FileSystem) ListFilesRecursively(userID, workspaceID int) ([]FileNode, error) {
workspacePath := fs.GetWorkspacePath(userID, workspaceID)
return fs.walkDirectory(workspacePath, "")
// Workspace is identified by the given userID and workspaceID.
func (s *Service) ListFilesRecursively(userID, workspaceID int) ([]FileNode, error) {
workspacePath := s.GetWorkspacePath(userID, workspaceID)
return s.walkDirectory(workspacePath, "")
}
func (fs *FileSystem) walkDirectory(dir, prefix string) ([]FileNode, error) {
entries, err := os.ReadDir(dir)
// walkDirectory recursively walks the directory and returns a list of files and directories.
func (s *Service) walkDirectory(dir, prefix string) ([]FileNode, error) {
entries, err := s.fs.ReadDir(dir)
if err != nil {
return nil, err
}
@@ -57,7 +70,7 @@ func (fs *FileSystem) walkDirectory(dir, prefix string) ([]FileNode, error) {
path := filepath.Join(prefix, name)
fullPath := filepath.Join(dir, name)
children, err := fs.walkDirectory(fullPath, path)
children, err := s.walkDirectory(fullPath, path)
if err != nil {
return nil, err
}
@@ -88,9 +101,11 @@ func (fs *FileSystem) walkDirectory(dir, prefix string) ([]FileNode, error) {
}
// FindFileByName returns a list of file paths that match the given filename.
func (fs *FileSystem) FindFileByName(userID, workspaceID int, filename string) ([]string, error) {
// Files are searched recursively in the workspace directory and its subdirectories.
// Workspace is identified by the given userID and workspaceID.
func (s *Service) FindFileByName(userID, workspaceID int, filename string) ([]string, error) {
var foundPaths []string
workspacePath := fs.GetWorkspacePath(userID, workspaceID)
workspacePath := s.GetWorkspacePath(userID, workspaceID)
err := filepath.Walk(workspacePath, func(path string, info os.FileInfo, err error) error {
if err != nil {
@@ -119,37 +134,40 @@ func (fs *FileSystem) FindFileByName(userID, workspaceID int, filename string) (
return foundPaths, nil
}
// GetFileContent returns the content of the file at the given path.
func (fs *FileSystem) GetFileContent(userID, workspaceID int, filePath string) ([]byte, error) {
fullPath, err := fs.ValidatePath(userID, workspaceID, filePath)
// GetFileContent returns the content of the file at the given filePath.
// Path must be a relative path within the workspace directory given by userID and workspaceID.
func (s *Service) GetFileContent(userID, workspaceID int, filePath string) ([]byte, error) {
fullPath, err := s.ValidatePath(userID, workspaceID, filePath)
if err != nil {
return nil, err
}
return os.ReadFile(fullPath)
return s.fs.ReadFile(fullPath)
}
// SaveFile writes the content to the file at the given path.
func (fs *FileSystem) SaveFile(userID, workspaceID int, filePath string, content []byte) error {
fullPath, err := fs.ValidatePath(userID, workspaceID, filePath)
// SaveFile writes the content to the file at the given filePath.
// Path must be a relative path within the workspace directory given by userID and workspaceID.
func (s *Service) SaveFile(userID, workspaceID int, filePath string, content []byte) error {
fullPath, err := s.ValidatePath(userID, workspaceID, filePath)
if err != nil {
return err
}
dir := filepath.Dir(fullPath)
if err := os.MkdirAll(dir, 0755); err != nil {
if err := s.fs.MkdirAll(dir, 0755); err != nil {
return err
}
return os.WriteFile(fullPath, content, 0644)
return s.fs.WriteFile(fullPath, content, 0644)
}
// DeleteFile deletes the file at the given path.
func (fs *FileSystem) DeleteFile(userID, workspaceID int, filePath string) error {
fullPath, err := fs.ValidatePath(userID, workspaceID, filePath)
// DeleteFile deletes the file at the given filePath.
// Path must be a relative path within the workspace directory given by userID and workspaceID.
func (s *Service) DeleteFile(userID, workspaceID int, filePath string) error {
fullPath, err := s.ValidatePath(userID, workspaceID, filePath)
if err != nil {
return err
}
return os.Remove(fullPath)
return s.fs.Remove(fullPath)
}
// FileCountStats holds statistics about files in a workspace
@@ -159,25 +177,26 @@ type FileCountStats struct {
}
// GetFileStats returns the total number of files and related statistics in a workspace
// Parameters:
// - userID: the ID of the user who owns the workspace
// - workspaceID: the ID of the workspace to count files in
// Returns:
// - result: statistics about the files in the workspace
// - error: any error that occurred during counting
func (fs *FileSystem) GetFileStats(userID, workspaceID int) (*FileCountStats, error) {
workspacePath := fs.GetWorkspacePath(userID, workspaceID)
// Workspace is identified by the given userID and workspaceID
func (s *Service) GetFileStats(userID, workspaceID int) (*FileCountStats, error) {
workspacePath := s.GetWorkspacePath(userID, workspaceID)
// Check if workspace exists
if _, err := os.Stat(workspacePath); os.IsNotExist(err) {
if _, err := s.fs.Stat(workspacePath); s.fs.IsNotExist(err) {
return nil, fmt.Errorf("workspace directory does not exist")
}
return fs.countFilesInPath(workspacePath)
return s.countFilesInPath(workspacePath)
}
func (fs *FileSystem) countFilesInPath(directoryPath string) (*FileCountStats, error) {
// GetTotalFileStats returns the total file statistics for the storage.
func (s *Service) GetTotalFileStats() (*FileCountStats, error) {
return s.countFilesInPath(s.RootDir)
}
// countFilesInPath counts the total number of files and the total size of files in the given directory.
func (s *Service) countFilesInPath(directoryPath string) (*FileCountStats, error) {
result := &FileCountStats{}
err := filepath.WalkDir(directoryPath, func(path string, d os.DirEntry, err error) error {

View File

@@ -0,0 +1,407 @@
package storage_test
import (
"io/fs"
"novamd/internal/storage"
"path/filepath"
"testing"
)
// TestFileNode ensures FileNode structs are created correctly
func TestFileNode(t *testing.T) {
testCases := []struct {
name string // name of the test case
node storage.FileNode
want storage.FileNode
}{
{
name: "file without children",
node: storage.FileNode{
ID: "test.md",
Name: "test.md",
Path: "test.md",
},
want: storage.FileNode{
ID: "test.md",
Name: "test.md",
Path: "test.md",
},
},
{
name: "directory with children",
node: storage.FileNode{
ID: "dir",
Name: "dir",
Path: "dir",
Children: []storage.FileNode{
{
ID: "dir/file1.md",
Name: "file1.md",
Path: "dir/file1.md",
},
},
},
want: storage.FileNode{
ID: "dir",
Name: "dir",
Path: "dir",
Children: []storage.FileNode{
{
ID: "dir/file1.md",
Name: "file1.md",
Path: "dir/file1.md",
},
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got := tc.node // Now we're testing the actual node structure
if got.ID != tc.want.ID {
t.Errorf("ID = %v, want %v", got.ID, tc.want.ID)
}
if got.Name != tc.want.Name {
t.Errorf("Name = %v, want %v", got.Name, tc.want.Name)
}
if got.Path != tc.want.Path {
t.Errorf("Path = %v, want %v", got.Path, tc.want.Path)
}
if len(got.Children) != len(tc.want.Children) {
t.Errorf("len(Children) = %v, want %v", len(got.Children), len(tc.want.Children))
}
// Add deep comparison of children if they exist
if len(got.Children) > 0 {
for i := range got.Children {
if got.Children[i].ID != tc.want.Children[i].ID {
t.Errorf("Children[%d].ID = %v, want %v", i, got.Children[i].ID, tc.want.Children[i].ID)
}
if got.Children[i].Name != tc.want.Children[i].Name {
t.Errorf("Children[%d].Name = %v, want %v", i, got.Children[i].Name, tc.want.Children[i].Name)
}
if got.Children[i].Path != tc.want.Children[i].Path {
t.Errorf("Children[%d].Path = %v, want %v", i, got.Children[i].Path, tc.want.Children[i].Path)
}
}
}
})
}
}
func TestListFilesRecursively(t *testing.T) {
mockFS := NewMockFS()
s := storage.NewServiceWithOptions("test-root", storage.Options{
Fs: mockFS,
NewGitClient: nil,
})
t.Run("empty directory", func(t *testing.T) {
mockFS.ReadDirReturns = map[string]struct {
entries []fs.DirEntry
err error
}{
"test-root/1/1": {
entries: []fs.DirEntry{},
err: nil,
},
}
files, err := s.ListFilesRecursively(1, 1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(files) != 0 {
t.Errorf("expected empty file list, got %v", files)
}
})
t.Run("directory with files", func(t *testing.T) {
mockFS.ReadDirReturns = map[string]struct {
entries []fs.DirEntry
err error
}{
"test-root/1/1": {
entries: []fs.DirEntry{
NewMockDirEntry("file1.md", false),
NewMockDirEntry("file2.md", false),
},
err: nil,
},
}
files, err := s.ListFilesRecursively(1, 1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(files) != 2 {
t.Errorf("expected 2 files, got %d", len(files))
}
})
t.Run("nested directories", func(t *testing.T) {
mockFS.ReadDirReturns = map[string]struct {
entries []fs.DirEntry
err error
}{
"test-root/1/1": {
entries: []fs.DirEntry{
NewMockDirEntry("dir1", true),
NewMockDirEntry("file1.md", false),
},
err: nil,
},
"test-root/1/1/dir1": {
entries: []fs.DirEntry{
NewMockDirEntry("file2.md", false),
},
err: nil,
},
}
files, err := s.ListFilesRecursively(1, 1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(files) != 2 { // dir1 and file1.md
t.Errorf("expected 2 entries at root, got %d", len(files))
}
// Find directory and check its children
var dirFound bool
for _, f := range files {
if f.Name == "dir1" {
dirFound = true
if len(f.Children) != 1 {
t.Errorf("expected 1 child in dir1, got %d", len(f.Children))
}
}
}
if !dirFound {
t.Error("directory 'dir1' not found in results")
}
})
}
func TestGetFileContent(t *testing.T) {
mockFS := NewMockFS()
s := storage.NewServiceWithOptions("test-root", storage.Options{
Fs: mockFS,
NewGitClient: nil,
})
testCases := []struct {
name string
userID int
workspaceID int
filePath string
mockData []byte
mockErr error
wantErr bool
}{
{
name: "successful read",
userID: 1,
workspaceID: 1,
filePath: "test.md",
mockData: []byte("test content"),
mockErr: nil,
wantErr: false,
},
{
name: "file not found",
userID: 1,
workspaceID: 1,
filePath: "nonexistent.md",
mockData: nil,
mockErr: fs.ErrNotExist,
wantErr: true,
},
{
name: "invalid path",
userID: 1,
workspaceID: 1,
filePath: "../../../etc/passwd",
mockData: nil,
mockErr: nil,
wantErr: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
expectedPath := filepath.Join("test-root", "1", "1", tc.filePath)
mockFS.ReadFileReturns[expectedPath] = struct {
data []byte
err error
}{tc.mockData, tc.mockErr}
content, err := s.GetFileContent(tc.userID, tc.workspaceID, tc.filePath)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(content) != string(tc.mockData) {
t.Errorf("content = %q, want %q", content, tc.mockData)
}
if mockFS.ReadCalls[expectedPath] != 1 {
t.Errorf("expected 1 read call for %s, got %d", expectedPath, mockFS.ReadCalls[expectedPath])
}
})
}
}
func TestSaveFile(t *testing.T) {
mockFS := NewMockFS()
s := storage.NewServiceWithOptions("test-root", storage.Options{
Fs: mockFS,
NewGitClient: nil,
})
testCases := []struct {
name string
userID int
workspaceID int
filePath string
content []byte
mockErr error
wantErr bool
}{
{
name: "successful save",
userID: 1,
workspaceID: 1,
filePath: "test.md",
content: []byte("test content"),
mockErr: nil,
wantErr: false,
},
{
name: "invalid path",
userID: 1,
workspaceID: 1,
filePath: "../../../etc/passwd",
content: []byte("test content"),
mockErr: nil,
wantErr: true,
},
{
name: "write error",
userID: 1,
workspaceID: 1,
filePath: "test.md",
content: []byte("test content"),
mockErr: fs.ErrPermission,
wantErr: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mockFS.WriteFileError = tc.mockErr
err := s.SaveFile(tc.userID, tc.workspaceID, tc.filePath, tc.content)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
expectedPath := filepath.Join("test-root", "1", "1", tc.filePath)
if content, ok := mockFS.WriteCalls[expectedPath]; ok {
if string(content) != string(tc.content) {
t.Errorf("written content = %q, want %q", content, tc.content)
}
} else {
t.Error("expected write call not made")
}
})
}
}
func TestDeleteFile(t *testing.T) {
mockFS := NewMockFS()
s := storage.NewServiceWithOptions("test-root", storage.Options{
Fs: mockFS,
NewGitClient: nil,
})
testCases := []struct {
name string
userID int
workspaceID int
filePath string
mockErr error
wantErr bool
}{
{
name: "successful delete",
userID: 1,
workspaceID: 1,
filePath: "test.md",
mockErr: nil,
wantErr: false,
},
{
name: "invalid path",
userID: 1,
workspaceID: 1,
filePath: "../../../etc/passwd",
mockErr: nil,
wantErr: true,
},
{
name: "file not found",
userID: 1,
workspaceID: 1,
filePath: "nonexistent.md",
mockErr: fs.ErrNotExist,
wantErr: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mockFS.RemoveError = tc.mockErr
err := s.DeleteFile(tc.userID, tc.workspaceID, tc.filePath)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
expectedPath := filepath.Join("test-root", "1", "1", tc.filePath)
found := false
for _, p := range mockFS.RemoveCalls {
if p == expectedPath {
found = true
break
}
}
if !found {
t.Error("expected delete call not made")
}
})
}
}

View File

@@ -0,0 +1,47 @@
package storage
import (
"io/fs"
"os"
)
// fileSystem defines the interface for filesystem operations
type fileSystem interface {
ReadFile(path string) ([]byte, error)
WriteFile(path string, data []byte, perm fs.FileMode) error
Remove(path string) error
MkdirAll(path string, perm fs.FileMode) error
RemoveAll(path string) error
ReadDir(path string) ([]fs.DirEntry, error)
Stat(path string) (fs.FileInfo, error)
IsNotExist(err error) bool
}
// osFS implements the FileSystem interface using the real filesystem.
type osFS struct{}
// ReadFile reads the file at the given path.
func (f *osFS) ReadFile(path string) ([]byte, error) { return os.ReadFile(path) }
// WriteFile writes the given data to the file at the given path.
func (f *osFS) WriteFile(path string, data []byte, perm fs.FileMode) error {
return os.WriteFile(path, data, perm)
}
// Remove deletes the file at the given path.
func (f *osFS) Remove(path string) error { return os.Remove(path) }
// MkdirAll creates the directory at the given path and any necessary parents.
func (f *osFS) MkdirAll(path string, perm fs.FileMode) error { return os.MkdirAll(path, perm) }
// RemoveAll removes the file or directory at the given path.
func (f *osFS) RemoveAll(path string) error { return os.RemoveAll(path) }
// ReadDir reads the directory at the given path.
func (f *osFS) ReadDir(path string) ([]fs.DirEntry, error) { return os.ReadDir(path) }
// Stat returns the FileInfo for the file at the given path.
func (f *osFS) Stat(path string) (fs.FileInfo, error) { return os.Stat(path) }
// IsNotExist returns true if the error is a "file does not exist" error.
func (f *osFS) IsNotExist(err error) bool { return os.IsNotExist(err) }

View File

@@ -0,0 +1,125 @@
package storage_test
import (
"errors"
"io/fs"
"path/filepath"
"time"
)
type mockDirEntry struct {
name string
isDir bool
}
func (m *mockDirEntry) Name() string { return m.name }
func (m *mockDirEntry) IsDir() bool { return m.isDir }
func (m *mockDirEntry) Type() fs.FileMode { return fs.ModeDir }
func (m *mockDirEntry) Info() (fs.FileInfo, error) { return nil, nil }
func NewMockDirEntry(name string, isDir bool) fs.DirEntry {
return &mockDirEntry{name: name, isDir: isDir}
}
// Extend mockFS to support directory operations
type MockDirInfo struct {
name string
size int64
mode fs.FileMode
modTime time.Time
isDir bool
}
func (m MockDirInfo) Name() string { return m.name }
func (m MockDirInfo) Size() int64 { return m.size }
func (m MockDirInfo) Mode() fs.FileMode { return m.mode }
func (m MockDirInfo) ModTime() time.Time { return m.modTime }
func (m MockDirInfo) IsDir() bool { return m.isDir }
func (m MockDirInfo) Sys() interface{} { return nil }
type mockFS struct {
// Record operations for verification
ReadCalls map[string]int
WriteCalls map[string][]byte
RemoveCalls []string
MkdirCalls []string
// Configure test behavior
ReadFileReturns map[string]struct {
data []byte
err error
}
ReadDirReturns map[string]struct {
entries []fs.DirEntry
err error
}
WriteFileError error
RemoveError error
MkdirError error
StatError error
}
func NewMockFS() *mockFS {
return &mockFS{
ReadCalls: make(map[string]int),
WriteCalls: make(map[string][]byte),
RemoveCalls: make([]string, 0),
MkdirCalls: make([]string, 0),
ReadFileReturns: make(map[string]struct {
data []byte
err error
}),
}
}
func (m *mockFS) ReadFile(path string) ([]byte, error) {
m.ReadCalls[path]++
if ret, ok := m.ReadFileReturns[path]; ok {
return ret.data, ret.err
}
return nil, errors.New("file not found")
}
func (m *mockFS) WriteFile(path string, data []byte, _ fs.FileMode) error {
m.WriteCalls[path] = data
return m.WriteFileError
}
func (m *mockFS) Remove(path string) error {
m.RemoveCalls = append(m.RemoveCalls, path)
return m.RemoveError
}
func (m *mockFS) MkdirAll(path string, _ fs.FileMode) error {
m.MkdirCalls = append(m.MkdirCalls, path)
return m.MkdirError
}
func (m *mockFS) Stat(path string) (fs.FileInfo, error) {
if m.StatError != nil {
return nil, m.StatError
}
return MockDirInfo{
name: filepath.Base(path),
size: 1024,
mode: 0644,
modTime: time.Now(),
isDir: false,
}, nil
}
func (m *mockFS) ReadDir(path string) ([]fs.DirEntry, error) {
if ret, ok := m.ReadDirReturns[path]; ok {
return ret.entries, ret.err
}
return nil, fs.ErrNotExist
}
func (m *mockFS) RemoveAll(path string) error {
m.RemoveCalls = append(m.RemoveCalls, path)
return m.RemoveError
}
func (m *mockFS) IsNotExist(err error) bool {
return err == fs.ErrNotExist
}

View File

@@ -0,0 +1,71 @@
package storage
import (
"fmt"
"novamd/internal/git"
)
// RepositoryManager defines the interface for managing Git repositories.
type RepositoryManager interface {
SetupGitRepo(userID, workspaceID int, gitURL, gitUser, gitToken string) error
DisableGitRepo(userID, workspaceID int)
StageCommitAndPush(userID, workspaceID int, message string) error
Pull(userID, workspaceID int) error
}
// SetupGitRepo sets up a Git repository for the given userID and workspaceID.
// The repository is cloned from the given gitURL using the given gitUser and gitToken.
func (s *Service) SetupGitRepo(userID, workspaceID int, gitURL, gitUser, gitToken string) error {
workspacePath := s.GetWorkspacePath(userID, workspaceID)
if _, ok := s.GitRepos[userID]; !ok {
s.GitRepos[userID] = make(map[int]git.Client)
}
s.GitRepos[userID][workspaceID] = s.newGitClient(gitURL, gitUser, gitToken, workspacePath)
return s.GitRepos[userID][workspaceID].EnsureRepo()
}
// DisableGitRepo disables the Git repository for the given userID and workspaceID.
func (s *Service) DisableGitRepo(userID, workspaceID int) {
if userRepos, ok := s.GitRepos[userID]; ok {
delete(userRepos, workspaceID)
if len(userRepos) == 0 {
delete(s.GitRepos, userID)
}
}
}
// StageCommitAndPush stages, commit with the message, and pushes the changes to the Git repository.
// The git repository belongs to the given userID and is associated with the given workspaceID.
func (s *Service) StageCommitAndPush(userID, workspaceID int, message string) error {
repo, ok := s.getGitRepo(userID, workspaceID)
if !ok {
return fmt.Errorf("git settings not configured for this workspace")
}
if err := repo.Commit(message); err != nil {
return err
}
return repo.Push()
}
// Pull pulls the changes from the remote Git repository.
// The git repository belongs to the given userID and is associated with the given workspaceID.
func (s *Service) Pull(userID, workspaceID int) error {
repo, ok := s.getGitRepo(userID, workspaceID)
if !ok {
return fmt.Errorf("git settings not configured for this workspace")
}
return repo.Pull()
}
// getGitRepo returns the Git repository for the given user and workspace IDs.
func (s *Service) getGitRepo(userID, workspaceID int) (git.Client, bool) {
userRepos, ok := s.GitRepos[userID]
if !ok {
return nil, false
}
repo, ok := userRepos[workspaceID]
return repo, ok
}

View File

@@ -0,0 +1,258 @@
package storage_test
import (
"errors"
"testing"
"novamd/internal/git"
"novamd/internal/storage"
)
// MockGitClient implements git.Client interface for testing
type MockGitClient struct {
CloneCalled bool
PullCalled bool
CommitCalled bool
PushCalled bool
EnsureCalled bool
CommitMessage string
ReturnError error
}
func (m *MockGitClient) Clone() error {
m.CloneCalled = true
return m.ReturnError
}
func (m *MockGitClient) Pull() error {
m.PullCalled = true
return m.ReturnError
}
func (m *MockGitClient) Commit(message string) error {
m.CommitCalled = true
m.CommitMessage = message
return m.ReturnError
}
func (m *MockGitClient) Push() error {
m.PushCalled = true
return m.ReturnError
}
func (m *MockGitClient) EnsureRepo() error {
m.EnsureCalled = true
return m.ReturnError
}
func TestSetupGitRepo(t *testing.T) {
mockFS := NewMockFS()
testCases := []struct {
name string
userID int
workspaceID int
gitURL string
gitUser string
gitToken string
mockErr error
wantErr bool
}{
{
name: "successful setup",
userID: 1,
workspaceID: 1,
gitURL: "https://github.com/user/repo",
gitUser: "user",
gitToken: "token",
mockErr: nil,
wantErr: false,
},
{
name: "git initialization error",
userID: 1,
workspaceID: 2,
gitURL: "https://github.com/user/repo",
gitUser: "user",
gitToken: "token",
mockErr: errors.New("git initialization failed"),
wantErr: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create a mock client with the desired error behavior
mockClient := &MockGitClient{ReturnError: tc.mockErr}
// Create a client factory that returns our configured mock
mockClientFactory := func(_, _, _, _ string) git.Client {
return mockClient
}
s := storage.NewServiceWithOptions("test-root", storage.Options{
Fs: mockFS,
NewGitClient: mockClientFactory,
})
// Setup the git repo
err := s.SetupGitRepo(tc.userID, tc.workspaceID, tc.gitURL, tc.gitUser, tc.gitToken)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Check if client was stored correctly
client, ok := s.GitRepos[tc.userID][tc.workspaceID]
if !ok {
t.Fatal("git client was not stored in service")
}
if !mockClient.EnsureCalled {
t.Error("EnsureRepo was not called")
}
// Verify it's our mock client
if client != mockClient {
t.Error("stored client is not our mock client")
}
})
}
}
func TestGitOperations(t *testing.T) {
mockFS := NewMockFS()
s := storage.NewServiceWithOptions("test-root", storage.Options{
Fs: mockFS,
NewGitClient: func(_, _, _, _ string) git.Client { return &MockGitClient{} },
})
t.Run("operations on non-configured workspace", func(t *testing.T) {
err := s.StageCommitAndPush(1, 1, "test commit")
if err == nil {
t.Error("expected error for non-configured workspace, got nil")
}
err = s.Pull(1, 1)
if err == nil {
t.Error("expected error for non-configured workspace, got nil")
}
})
t.Run("successful operations", func(t *testing.T) {
// Initialize GitRepos map
s.GitRepos = make(map[int]map[int]git.Client)
s.GitRepos[1] = make(map[int]git.Client)
mockClient := &MockGitClient{}
s.GitRepos[1][1] = mockClient
// Test commit and push
err := s.StageCommitAndPush(1, 1, "test commit")
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if !mockClient.CommitCalled {
t.Error("Commit was not called")
}
if mockClient.CommitMessage != "test commit" {
t.Errorf("Commit message = %q, want %q", mockClient.CommitMessage, "test commit")
}
if !mockClient.PushCalled {
t.Error("Push was not called")
}
// Test pull
err = s.Pull(1, 1)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if !mockClient.PullCalled {
t.Error("Pull was not called")
}
})
t.Run("operation errors", func(t *testing.T) {
// Initialize GitRepos map with error-returning client
s.GitRepos = make(map[int]map[int]git.Client)
s.GitRepos[1] = make(map[int]git.Client)
mockClient := &MockGitClient{ReturnError: errors.New("git operation failed")}
s.GitRepos[1][1] = mockClient
// Test commit error
err := s.StageCommitAndPush(1, 1, "test commit")
if err == nil {
t.Error("expected error for commit, got nil")
}
// Test pull error
err = s.Pull(1, 1)
if err == nil {
t.Error("expected error for pull, got nil")
}
})
}
func TestDisableGitRepo(t *testing.T) {
mockFS := NewMockFS()
s := storage.NewServiceWithOptions("test-root", storage.Options{
Fs: mockFS,
NewGitClient: func(_, _, _, _ string) git.Client { return &MockGitClient{} },
})
testCases := []struct {
name string
userID int
workspaceID int
setupRepo bool
}{
{
name: "disable existing repo",
userID: 1,
workspaceID: 1,
setupRepo: true,
},
{
name: "disable non-existent repo",
userID: 2,
workspaceID: 1,
setupRepo: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Reset GitRepos for each test
s.GitRepos = make(map[int]map[int]git.Client)
if tc.setupRepo {
// Setup initial repo
s.GitRepos[tc.userID] = make(map[int]git.Client)
s.GitRepos[tc.userID][tc.workspaceID] = &MockGitClient{}
}
// Disable the repo
s.DisableGitRepo(tc.userID, tc.workspaceID)
// Verify repo was removed
if userRepos, exists := s.GitRepos[tc.userID]; exists {
if _, repoExists := userRepos[tc.workspaceID]; repoExists {
t.Error("git repo still exists after disable")
}
}
// If this was the user's last repo, verify user entry was cleaned up
if tc.setupRepo {
if len(s.GitRepos[tc.userID]) > 0 {
t.Error("user's git repos map not cleaned up when last repo removed")
}
}
})
}
}

View File

@@ -0,0 +1,52 @@
package storage
import (
"novamd/internal/git"
)
// Manager interface combines all storage interfaces.
type Manager interface {
FileManager
WorkspaceManager
RepositoryManager
}
// Service represents the file system structure.
type Service struct {
fs fileSystem
newGitClient func(url, user, token, path string) git.Client
RootDir string
GitRepos map[int]map[int]git.Client // map[userID]map[workspaceID]*git.Client
}
// Options represents the options for the storage service.
type Options struct {
Fs fileSystem
NewGitClient func(url, user, token, path string) git.Client
}
// NewService creates a new Storage instance with the default options and the given rootDir root directory.
func NewService(rootDir string) *Service {
return NewServiceWithOptions(rootDir, Options{
Fs: &osFS{},
NewGitClient: git.New,
})
}
// NewServiceWithOptions creates a new Storage instance with the given options and the given rootDir root directory.
func NewServiceWithOptions(rootDir string, options Options) *Service {
if options.Fs == nil {
options.Fs = &osFS{}
}
if options.NewGitClient == nil {
options.NewGitClient = git.New
}
return &Service{
fs: options.Fs,
newGitClient: options.NewGitClient,
RootDir: rootDir,
GitRepos: make(map[int]map[int]git.Client),
}
}

View File

@@ -0,0 +1,64 @@
package storage
import (
"fmt"
"path/filepath"
"strings"
)
// WorkspaceManager provides functionalities to interact with workspaces in the storage.
type WorkspaceManager interface {
ValidatePath(userID, workspaceID int, path string) (string, error)
GetWorkspacePath(userID, workspaceID int) string
InitializeUserWorkspace(userID, workspaceID int) error
DeleteUserWorkspace(userID, workspaceID int) error
}
// ValidatePath validates the if the given path is valid within the workspace directory.
// Workspace directory is defined as the directory for the given userID and workspaceID.
func (s *Service) ValidatePath(userID, workspaceID int, path string) (string, error) {
workspacePath := s.GetWorkspacePath(userID, workspaceID)
// First check if the path is absolute
if filepath.IsAbs(path) {
return "", &PathValidationError{Path: path, Message: "absolute paths not allowed"}
}
// Join and clean the path
fullPath := filepath.Join(workspacePath, path)
cleanPath := filepath.Clean(fullPath)
// Verify the path is still within the workspace
if !strings.HasPrefix(cleanPath, workspacePath) {
return "", &PathValidationError{Path: path, Message: "path traversal attempt"}
}
return cleanPath, nil
}
// GetWorkspacePath returns the path to the workspace directory for the given userID and workspaceID.
func (s *Service) GetWorkspacePath(userID, workspaceID int) string {
return filepath.Join(s.RootDir, fmt.Sprintf("%d", userID), fmt.Sprintf("%d", workspaceID))
}
// InitializeUserWorkspace creates the workspace directory for the given userID and workspaceID.
func (s *Service) InitializeUserWorkspace(userID, workspaceID int) error {
workspacePath := s.GetWorkspacePath(userID, workspaceID)
err := s.fs.MkdirAll(workspacePath, 0755)
if err != nil {
return fmt.Errorf("failed to create workspace directory: %w", err)
}
return nil
}
// DeleteUserWorkspace deletes the workspace directory for the given userID and workspaceID.
func (s *Service) DeleteUserWorkspace(userID, workspaceID int) error {
workspacePath := s.GetWorkspacePath(userID, workspaceID)
err := s.fs.RemoveAll(workspacePath)
if err != nil {
return fmt.Errorf("failed to delete workspace directory: %w", err)
}
return nil
}

View File

@@ -0,0 +1,275 @@
package storage_test
import (
"errors"
"path/filepath"
"strings"
"testing"
"novamd/internal/storage"
)
func TestValidatePath(t *testing.T) {
mockFS := NewMockFS()
s := storage.NewServiceWithOptions("test-root", storage.Options{
Fs: mockFS,
NewGitClient: nil,
})
testCases := []struct {
name string
userID int
workspaceID int
path string
want string
wantErr bool
errContains string
}{
{
name: "valid path",
userID: 1,
workspaceID: 1,
path: "notes/test.md",
want: filepath.Join("test-root", "1", "1", "notes", "test.md"),
wantErr: false,
},
{
name: "valid path with dot",
userID: 1,
workspaceID: 1,
path: "./notes/test.md",
want: filepath.Join("test-root", "1", "1", "notes", "test.md"),
wantErr: false,
},
{
name: "path with parent directory traversal",
userID: 1,
workspaceID: 1,
path: "../../../etc/passwd",
want: "",
wantErr: true,
errContains: "path traversal attempt",
},
{
name: "absolute path attempt",
userID: 1,
workspaceID: 1,
path: "/etc/passwd",
want: "",
wantErr: true,
errContains: "absolute paths not allowed",
},
{
name: "empty path",
userID: 1,
workspaceID: 1,
path: "",
want: filepath.Join("test-root", "1", "1"),
wantErr: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, err := s.ValidatePath(tc.userID, tc.workspaceID, tc.path)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
return
}
if !strings.Contains(err.Error(), tc.errContains) {
t.Errorf("error = %v, want error containing %q", err, tc.errContains)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != tc.want {
t.Errorf("ValidatePath() = %v, want %v", got, tc.want)
}
})
}
}
func TestGetWorkspacePath(t *testing.T) {
mockFS := NewMockFS()
s := storage.NewServiceWithOptions("test-root", storage.Options{
Fs: mockFS,
NewGitClient: nil,
})
testCases := []struct {
name string
userID int
workspaceID int
want string
}{
{
name: "standard workspace path",
userID: 1,
workspaceID: 1,
want: filepath.Join("test-root", "1", "1"),
},
{
name: "different user and workspace IDs",
userID: 2,
workspaceID: 3,
want: filepath.Join("test-root", "2", "3"),
},
{
name: "zero IDs",
userID: 0,
workspaceID: 0,
want: filepath.Join("test-root", "0", "0"),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got := s.GetWorkspacePath(tc.userID, tc.workspaceID)
if got != tc.want {
t.Errorf("GetWorkspacePath() = %v, want %v", got, tc.want)
}
})
}
}
func TestInitializeUserWorkspace(t *testing.T) {
mockFS := NewMockFS()
s := storage.NewServiceWithOptions("test-root", storage.Options{
Fs: mockFS,
NewGitClient: nil,
})
testCases := []struct {
name string
userID int
workspaceID int
mockErr error
wantErr bool
errContains string
}{
{
name: "successful initialization",
userID: 1,
workspaceID: 1,
mockErr: nil,
wantErr: false,
},
{
name: "mkdir error",
userID: 1,
workspaceID: 1,
mockErr: errors.New("permission denied"),
wantErr: true,
errContains: "failed to create workspace directory",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mockFS.MkdirError = tc.mockErr
err := s.InitializeUserWorkspace(tc.userID, tc.workspaceID)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
return
}
if !strings.Contains(err.Error(), tc.errContains) {
t.Errorf("error = %v, want error containing %q", err, tc.errContains)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Verify the correct directory was created
expectedPath := filepath.Join("test-root", "1", "1")
dirCreated := false
for _, path := range mockFS.MkdirCalls {
if path == expectedPath {
dirCreated = true
break
}
}
if !dirCreated {
t.Errorf("directory %s was not created", expectedPath)
}
})
}
}
func TestDeleteUserWorkspace(t *testing.T) {
mockFS := NewMockFS()
s := storage.NewServiceWithOptions("test-root", storage.Options{
Fs: mockFS,
NewGitClient: nil,
})
testCases := []struct {
name string
userID int
workspaceID int
mockErr error
wantErr bool
errContains string
}{
{
name: "successful deletion",
userID: 1,
workspaceID: 1,
mockErr: nil,
wantErr: false,
},
{
name: "removal error",
userID: 1,
workspaceID: 1,
mockErr: errors.New("permission denied"),
wantErr: true,
errContains: "failed to delete workspace directory",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mockFS.RemoveError = tc.mockErr
err := s.DeleteUserWorkspace(tc.userID, tc.workspaceID)
if tc.wantErr {
if err == nil {
t.Error("expected error, got nil")
return
}
if !strings.Contains(err.Error(), tc.errContains) {
t.Errorf("error = %v, want error containing %q", err, tc.errContains)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Verify the correct directory was deleted
expectedPath := filepath.Join("test-root", "1", "1")
dirDeleted := false
for _, path := range mockFS.RemoveCalls {
if path == expectedPath {
dirDeleted = true
break
}
}
if !dirDeleted {
t.Errorf("directory %s was not deleted", expectedPath)
}
})
}
}

View File

@@ -8,19 +8,19 @@ import (
"golang.org/x/crypto/bcrypt"
"novamd/internal/db"
"novamd/internal/filesystem"
"novamd/internal/models"
"novamd/internal/storage"
)
type UserService struct {
DB *db.DB
FS *filesystem.FileSystem
DB db.Database
Storage storage.Manager
}
func NewUserService(database *db.DB, fs *filesystem.FileSystem) *UserService {
func NewUserService(database db.Database, s storage.Manager) *UserService {
return &UserService{
DB: database,
FS: fs,
DB: database,
Storage: s,
}
}
@@ -53,7 +53,7 @@ func (s *UserService) SetupAdminUser(adminEmail, adminPassword string) (*models.
}
// Initialize workspace directory
err = s.FS.InitializeUserWorkspace(createdUser.ID, createdUser.LastWorkspaceID)
err = s.Storage.InitializeUserWorkspace(createdUser.ID, createdUser.LastWorkspaceID)
if err != nil {
return nil, fmt.Errorf("failed to initialize admin workspace: %w", err)
}