mirror of
https://github.com/lordmathis/lemma.git
synced 2025-11-05 23:44:22 +00:00
Implement auth handler integration test
This commit is contained in:
4
.vscode/settings.json
vendored
4
.vscode/settings.json
vendored
@@ -14,7 +14,7 @@
|
||||
"go.lintTool": "golangci-lint",
|
||||
"go.lintOnSave": "package",
|
||||
"go.formatTool": "goimports",
|
||||
"go.testFlags": ["-tags=test"],
|
||||
"go.testFlags": ["-tags=test,integration"],
|
||||
"[go]": {
|
||||
"editor.formatOnSave": true,
|
||||
"editor.codeActionsOnSave": {
|
||||
@@ -25,6 +25,6 @@
|
||||
"gopls": {
|
||||
"usePlaceholders": true,
|
||||
"staticcheck": true,
|
||||
"buildFlags": ["-tags", "test"]
|
||||
"buildFlags": ["-tags", "test,integration"]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ require (
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/mattn/go-sqlite3 v1.14.23
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/unrolled/secure v1.17.0
|
||||
golang.org/x/crypto v0.21.0
|
||||
)
|
||||
@@ -22,6 +23,7 @@ require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/cloudflare/circl v1.3.7 // indirect
|
||||
github.com/cyphar/filepath-securejoin v0.2.4 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/emirpasic/gods v1.18.1 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect
|
||||
@@ -33,6 +35,7 @@ require (
|
||||
github.com/kevinburke/ssh_config v1.2.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/pjbgf/sha1cd v0.3.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect
|
||||
github.com/skeema/knownhosts v1.2.2 // indirect
|
||||
github.com/xanzy/ssh-agent v0.3.3 // indirect
|
||||
@@ -42,4 +45,5 @@ require (
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
golang.org/x/tools v0.13.0 // indirect
|
||||
gopkg.in/warnings.v0 v0.1.2 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -87,11 +89,19 @@ func (s *jwtService) GenerateRefreshToken(userID int, role string) (string, erro
|
||||
// Returns the signed token string or an error
|
||||
func (s *jwtService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) {
|
||||
now := time.Now()
|
||||
|
||||
// Add a random nonce to ensure uniqueness
|
||||
nonce := make([]byte, 8)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return "", fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
claims := Claims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(expiry)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
ID: hex.EncodeToString(nonce),
|
||||
},
|
||||
UserID: userID,
|
||||
Role: role,
|
||||
|
||||
@@ -76,8 +76,8 @@ func (s *SessionService) CreateSession(userID int, role string) (*models.Session
|
||||
// - string: a new access token
|
||||
// - error: any error that occurred
|
||||
func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
|
||||
// Get session from database
|
||||
_, err := s.db.GetSessionByRefreshToken(refreshToken)
|
||||
// Get session from database first
|
||||
session, err := s.db.GetSessionByRefreshToken(refreshToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid session: %w", err)
|
||||
}
|
||||
@@ -88,6 +88,11 @@ func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
|
||||
return "", fmt.Errorf("invalid refresh token: %w", err)
|
||||
}
|
||||
|
||||
// Double check that the claims match the session
|
||||
if claims.UserID != session.UserID {
|
||||
return "", fmt.Errorf("token does not match session")
|
||||
}
|
||||
|
||||
// Generate a new access token
|
||||
return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role)
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package handlers contains the request handlers for the api routes.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
|
||||
@@ -10,11 +10,13 @@ import (
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// LoginRequest represents a user login request
|
||||
type LoginRequest struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// LoginResponse represents a user login response
|
||||
type LoginResponse struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
@@ -22,10 +24,12 @@ type LoginResponse struct {
|
||||
Session *models.Session `json:"session"`
|
||||
}
|
||||
|
||||
// RefreshRequest represents a refresh token request
|
||||
type RefreshRequest struct {
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
|
||||
// RefreshResponse represents a refresh token response
|
||||
type RefreshResponse struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
}
|
||||
|
||||
232
server/internal/handlers/auth_handlers_integration_test.go
Normal file
232
server/internal/handlers/auth_handlers_integration_test.go
Normal file
@@ -0,0 +1,232 @@
|
||||
//go:build integration
|
||||
|
||||
package handlers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"novamd/internal/handlers"
|
||||
"novamd/internal/models"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAuthHandlers_Integration(t *testing.T) {
|
||||
h := setupTestHarness(t)
|
||||
defer h.teardown(t)
|
||||
|
||||
t.Run("login", func(t *testing.T) {
|
||||
t.Run("successful login - admin user", func(t *testing.T) {
|
||||
loginReq := handlers.LoginRequest{
|
||||
Email: "admin@test.com",
|
||||
Password: "admin123",
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var resp handlers.LoginResponse
|
||||
err := json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEmpty(t, resp.AccessToken)
|
||||
assert.NotEmpty(t, resp.RefreshToken)
|
||||
assert.NotNil(t, resp.User)
|
||||
assert.Equal(t, loginReq.Email, resp.User.Email)
|
||||
assert.Equal(t, models.RoleAdmin, resp.User.Role)
|
||||
})
|
||||
|
||||
t.Run("successful login - regular user", func(t *testing.T) {
|
||||
loginReq := handlers.LoginRequest{
|
||||
Email: "user@test.com",
|
||||
Password: "user123",
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var resp handlers.LoginResponse
|
||||
err := json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEmpty(t, resp.AccessToken)
|
||||
assert.NotEmpty(t, resp.RefreshToken)
|
||||
assert.NotNil(t, resp.User)
|
||||
assert.Equal(t, loginReq.Email, resp.User.Email)
|
||||
assert.Equal(t, models.RoleEditor, resp.User.Role)
|
||||
})
|
||||
|
||||
t.Run("login failures", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request handlers.LoginRequest
|
||||
wantCode int
|
||||
}{
|
||||
{
|
||||
name: "wrong password",
|
||||
request: handlers.LoginRequest{
|
||||
Email: "user@test.com",
|
||||
Password: "wrongpassword",
|
||||
},
|
||||
wantCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "non-existent user",
|
||||
request: handlers.LoginRequest{
|
||||
Email: "nonexistent@test.com",
|
||||
Password: "password123",
|
||||
},
|
||||
wantCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "empty email",
|
||||
request: handlers.LoginRequest{
|
||||
Email: "",
|
||||
Password: "password123",
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "empty password",
|
||||
request: handlers.LoginRequest{
|
||||
Email: "user@test.com",
|
||||
Password: "",
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", tt.request, "", nil)
|
||||
assert.Equal(t, tt.wantCode, rr.Code)
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("refresh token", func(t *testing.T) {
|
||||
t.Run("successful token refresh", func(t *testing.T) {
|
||||
// First login to get refresh token
|
||||
loginReq := handlers.LoginRequest{
|
||||
Email: "user@test.com",
|
||||
Password: "user123",
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var loginResp handlers.LoginResponse
|
||||
err := json.NewDecoder(rr.Body).Decode(&loginResp)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now try to refresh the token
|
||||
refreshReq := handlers.RefreshRequest{
|
||||
RefreshToken: loginResp.RefreshToken,
|
||||
}
|
||||
|
||||
rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/refresh", refreshReq, "", nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var refreshResp handlers.RefreshResponse
|
||||
err = json.NewDecoder(rr.Body).Decode(&refreshResp)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, refreshResp.AccessToken)
|
||||
})
|
||||
|
||||
t.Run("refresh failures", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request handlers.RefreshRequest
|
||||
wantCode int
|
||||
}{
|
||||
{
|
||||
name: "invalid refresh token",
|
||||
request: handlers.RefreshRequest{
|
||||
RefreshToken: "invalid-token",
|
||||
},
|
||||
wantCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "empty refresh token",
|
||||
request: handlers.RefreshRequest{
|
||||
RefreshToken: "",
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/refresh", tt.request, "", nil)
|
||||
assert.Equal(t, tt.wantCode, rr.Code)
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("logout", func(t *testing.T) {
|
||||
t.Run("successful logout", func(t *testing.T) {
|
||||
// First login to get session
|
||||
loginReq := handlers.LoginRequest{
|
||||
Email: "user@test.com",
|
||||
Password: "user123",
|
||||
}
|
||||
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var loginResp handlers.LoginResponse
|
||||
err := json.NewDecoder(rr.Body).Decode(&loginResp)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now logout using session ID from login response
|
||||
headers := map[string]string{
|
||||
"X-Session-ID": loginResp.Session.ID,
|
||||
}
|
||||
rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/logout", nil, loginResp.AccessToken, headers)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Try to use the refresh token - should fail
|
||||
refreshReq := handlers.RefreshRequest{
|
||||
RefreshToken: loginResp.RefreshToken,
|
||||
}
|
||||
|
||||
rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/refresh", refreshReq, "", nil)
|
||||
assert.Equal(t, http.StatusUnauthorized, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("logout without session ID", func(t *testing.T) {
|
||||
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/logout", nil, h.RegularToken, nil)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("get current user", func(t *testing.T) {
|
||||
t.Run("successful get current user", func(t *testing.T) {
|
||||
rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, h.RegularToken, nil)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var user models.User
|
||||
err := json.NewDecoder(rr.Body).Decode(&user)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, h.RegularUser.ID, user.ID)
|
||||
assert.Equal(t, h.RegularUser.Email, user.Email)
|
||||
assert.Equal(t, h.RegularUser.Role, user.Role)
|
||||
})
|
||||
|
||||
t.Run("get current user without token", func(t *testing.T) {
|
||||
rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, "", nil)
|
||||
assert.Equal(t, http.StatusUnauthorized, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("get current user with invalid token", func(t *testing.T) {
|
||||
rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, "invalid-token", nil)
|
||||
assert.Equal(t, http.StatusUnauthorized, rr.Code)
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
// ListFiles returns a list of all files in the workspace
|
||||
func (h *Handler) ListFiles() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
@@ -27,6 +28,7 @@ func (h *Handler) ListFiles() http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// LookupFileByName returns the paths of files with the given name
|
||||
func (h *Handler) LookupFileByName() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
@@ -50,6 +52,7 @@ func (h *Handler) LookupFileByName() http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// GetFileContent returns the content of a file
|
||||
func (h *Handler) GetFileContent() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
@@ -65,10 +68,15 @@ func (h *Handler) GetFileContent() http.HandlerFunc {
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.Write(content)
|
||||
_, err = w.Write(content)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to write response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SaveFile saves the content of a file
|
||||
func (h *Handler) SaveFile() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
@@ -93,6 +101,7 @@ func (h *Handler) SaveFile() http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteFile deletes a file
|
||||
func (h *Handler) DeleteFile() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
@@ -108,10 +117,15 @@ func (h *Handler) DeleteFile() http.HandlerFunc {
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("File deleted successfully"))
|
||||
_, err = w.Write([]byte("File deleted successfully"))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to write response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetLastOpenedFile returns the last opened file in the workspace
|
||||
func (h *Handler) GetLastOpenedFile() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
@@ -134,6 +148,7 @@ func (h *Handler) GetLastOpenedFile() http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateLastOpenedFile updates the last opened file in the workspace
|
||||
func (h *Handler) UpdateLastOpenedFile() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"novamd/internal/context"
|
||||
)
|
||||
|
||||
// StageCommitAndPush stages, commits, and pushes changes to the remote repository
|
||||
func (h *Handler) StageCommitAndPush() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
@@ -38,6 +39,7 @@ func (h *Handler) StageCommitAndPush() http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// PullChanges pulls changes from the remote repository
|
||||
func (h *Handler) PullChanges() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
|
||||
188
server/internal/handlers/integration_test.go
Normal file
188
server/internal/handlers/integration_test.go
Normal file
@@ -0,0 +1,188 @@
|
||||
//go:build integration
|
||||
|
||||
package handlers_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"novamd/internal/api"
|
||||
"novamd/internal/auth"
|
||||
"novamd/internal/db"
|
||||
"novamd/internal/handlers"
|
||||
"novamd/internal/models"
|
||||
"novamd/internal/secrets"
|
||||
"novamd/internal/storage"
|
||||
)
|
||||
|
||||
// testHarness encapsulates all the dependencies needed for testing
|
||||
type testHarness struct {
|
||||
DB db.TestDatabase
|
||||
Storage storage.Manager
|
||||
Router *chi.Mux
|
||||
Handler *handlers.Handler
|
||||
JWTManager auth.JWTManager
|
||||
SessionSvc *auth.SessionService
|
||||
AdminUser *models.User
|
||||
AdminToken string
|
||||
RegularUser *models.User
|
||||
RegularToken string
|
||||
TempDirectory string
|
||||
}
|
||||
|
||||
// setupTestHarness creates a new test environment
|
||||
func setupTestHarness(t *testing.T) *testHarness {
|
||||
t.Helper()
|
||||
|
||||
// Create temporary directory for test files
|
||||
tempDir, err := os.MkdirTemp("", "novamd-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
|
||||
// Initialize test database
|
||||
secretsSvc, err := secrets.NewService("YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY=") // test key
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize secrets service: %v", err)
|
||||
}
|
||||
|
||||
database, err := db.NewTestDB(":memory:", secretsSvc)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize test database: %v", err)
|
||||
}
|
||||
|
||||
if err := database.Migrate(); err != nil {
|
||||
t.Fatalf("Failed to run migrations: %v", err)
|
||||
}
|
||||
|
||||
// Initialize storage
|
||||
storageSvc := storage.NewService(tempDir)
|
||||
|
||||
// Initialize JWT service
|
||||
jwtSvc, err := auth.NewJWTService(auth.JWTConfig{
|
||||
SigningKey: "test-key",
|
||||
AccessTokenExpiry: 15 * time.Minute,
|
||||
RefreshTokenExpiry: 24 * time.Hour,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize JWT service: %v", err)
|
||||
}
|
||||
|
||||
// Initialize session service
|
||||
sessionSvc := auth.NewSessionService(database, jwtSvc)
|
||||
|
||||
// Create handler
|
||||
handler := &handlers.Handler{
|
||||
DB: database,
|
||||
Storage: storageSvc,
|
||||
}
|
||||
|
||||
// Set up router with middlewares
|
||||
router := chi.NewRouter()
|
||||
authMiddleware := auth.NewMiddleware(jwtSvc)
|
||||
router.Route("/api/v1", func(r chi.Router) {
|
||||
api.SetupRoutes(r, database, storageSvc, authMiddleware, sessionSvc)
|
||||
})
|
||||
|
||||
// Create test users
|
||||
adminUser, adminToken := createTestUser(t, database, sessionSvc, "admin@test.com", "admin123", models.RoleAdmin)
|
||||
regularUser, regularToken := createTestUser(t, database, sessionSvc, "user@test.com", "user123", models.RoleEditor)
|
||||
|
||||
return &testHarness{
|
||||
DB: database,
|
||||
Storage: storageSvc,
|
||||
Router: router,
|
||||
Handler: handler,
|
||||
JWTManager: jwtSvc,
|
||||
SessionSvc: sessionSvc,
|
||||
AdminUser: adminUser,
|
||||
AdminToken: adminToken,
|
||||
RegularUser: regularUser,
|
||||
RegularToken: regularToken,
|
||||
TempDirectory: tempDir,
|
||||
}
|
||||
}
|
||||
|
||||
// teardownTestHarness cleans up the test environment
|
||||
func (h *testHarness) teardown(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
if err := h.DB.Close(); err != nil {
|
||||
t.Errorf("Failed to close database: %v", err)
|
||||
}
|
||||
|
||||
if err := os.RemoveAll(h.TempDirectory); err != nil {
|
||||
t.Errorf("Failed to remove temp directory: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// createTestUser creates a test user and returns the user and access token
|
||||
func createTestUser(t *testing.T, db db.Database, sessionSvc *auth.SessionService, email, password string, role models.UserRole) (*models.User, string) {
|
||||
t.Helper()
|
||||
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to hash password: %v", err)
|
||||
}
|
||||
|
||||
user := &models.User{
|
||||
Email: email,
|
||||
DisplayName: "Test User",
|
||||
PasswordHash: string(hashedPassword),
|
||||
Role: role,
|
||||
}
|
||||
|
||||
user, err = db.CreateUser(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create user: %v", err)
|
||||
}
|
||||
|
||||
session, accessToken, err := sessionSvc.CreateSession(user.ID, string(user.Role))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session: %v", err)
|
||||
}
|
||||
|
||||
if session == nil || accessToken == "" {
|
||||
t.Fatal("Failed to get valid session or token")
|
||||
}
|
||||
|
||||
return user, accessToken
|
||||
}
|
||||
|
||||
// makeRequest is a helper function to make HTTP requests in tests
|
||||
func (h *testHarness) makeRequest(t *testing.T, method, path string, body interface{}, token string, headers map[string]string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
|
||||
var reqBody []byte
|
||||
var err error
|
||||
|
||||
if body != nil {
|
||||
reqBody, err = json.Marshal(body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(method, path, bytes.NewBuffer(reqBody))
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Add any additional headers
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
h.Router.ServeHTTP(rr, req)
|
||||
|
||||
return rr
|
||||
}
|
||||
@@ -12,12 +12,14 @@ type StaticHandler struct {
|
||||
staticPath string
|
||||
}
|
||||
|
||||
// NewStaticHandler creates a new StaticHandler with the given static path
|
||||
func NewStaticHandler(staticPath string) *StaticHandler {
|
||||
return &StaticHandler{
|
||||
staticPath: staticPath,
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP serves the static files
|
||||
func (h *StaticHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Get the requested path
|
||||
requestedPath := r.URL.Path
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// UpdateProfileRequest represents a user profile update request
|
||||
type UpdateProfileRequest struct {
|
||||
DisplayName string `json:"displayName"`
|
||||
Email string `json:"email"`
|
||||
@@ -16,10 +18,12 @@ type UpdateProfileRequest struct {
|
||||
NewPassword string `json:"newPassword"`
|
||||
}
|
||||
|
||||
// DeleteAccountRequest represents a user account deletion request
|
||||
type DeleteAccountRequest struct {
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// GetUser returns the current user's profile
|
||||
func (h *Handler) GetUser() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
@@ -64,7 +68,11 @@ func (h *Handler) UpdateProfile() http.HandlerFunc {
|
||||
http.Error(w, "Failed to start transaction", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
defer func() {
|
||||
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
|
||||
http.Error(w, "Failed to rollback transaction", http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
|
||||
// Handle password update if requested
|
||||
if req.NewPassword != "" {
|
||||
@@ -188,7 +196,11 @@ func (h *Handler) DeleteAccount() http.HandlerFunc {
|
||||
http.Error(w, "Failed to start transaction", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
defer func() {
|
||||
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
|
||||
http.Error(w, "Failed to rollback transaction", http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
|
||||
// Get user's workspaces for cleanup
|
||||
workspaces, err := h.DB.GetWorkspacesByUserID(ctx.UserID)
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"novamd/internal/context"
|
||||
"novamd/internal/models"
|
||||
)
|
||||
|
||||
// ListWorkspaces returns a list of all workspaces for the current user
|
||||
func (h *Handler) ListWorkspaces() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
@@ -26,6 +27,7 @@ func (h *Handler) ListWorkspaces() http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// CreateWorkspace creates a new workspace
|
||||
func (h *Handler) CreateWorkspace() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
@@ -54,6 +56,7 @@ func (h *Handler) CreateWorkspace() http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// GetWorkspace returns the current workspace
|
||||
func (h *Handler) GetWorkspace() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
@@ -81,6 +84,7 @@ func gitSettingsChanged(new, old *models.Workspace) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdateWorkspace updates the current workspace
|
||||
func (h *Handler) UpdateWorkspace() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
@@ -132,6 +136,7 @@ func (h *Handler) UpdateWorkspace() http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteWorkspace deletes the current workspace
|
||||
func (h *Handler) DeleteWorkspace() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
@@ -168,7 +173,11 @@ func (h *Handler) DeleteWorkspace() http.HandlerFunc {
|
||||
http.Error(w, "Failed to start transaction", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
defer func() {
|
||||
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
|
||||
http.Error(w, "Failed to rollback transaction", http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
|
||||
// Update last workspace ID first
|
||||
err = h.DB.UpdateLastWorkspaceTx(tx, ctx.UserID, nextWorkspaceID)
|
||||
@@ -195,6 +204,7 @@ func (h *Handler) DeleteWorkspace() http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// GetLastWorkspaceName returns the name of the last opened workspace
|
||||
func (h *Handler) GetLastWorkspaceName() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
@@ -212,6 +222,7 @@ func (h *Handler) GetLastWorkspaceName() http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateLastWorkspaceName updates the name of the last opened workspace
|
||||
func (h *Handler) UpdateLastWorkspaceName() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, ok := context.GetRequestContext(w, r)
|
||||
@@ -224,13 +235,11 @@ func (h *Handler) UpdateLastWorkspaceName() http.HandlerFunc {
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
fmt.Println(err)
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.DB.UpdateLastWorkspace(ctx.UserID, requestBody.WorkspaceName); err != nil {
|
||||
fmt.Println(err)
|
||||
http.Error(w, "Failed to update last workspace", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user