Implement auth handler integration test

This commit is contained in:
2024-11-26 22:50:43 +01:00
parent e8868dde39
commit 4ddf1f570f
14 changed files with 499 additions and 15 deletions

View File

@@ -14,7 +14,7 @@
"go.lintTool": "golangci-lint", "go.lintTool": "golangci-lint",
"go.lintOnSave": "package", "go.lintOnSave": "package",
"go.formatTool": "goimports", "go.formatTool": "goimports",
"go.testFlags": ["-tags=test"], "go.testFlags": ["-tags=test,integration"],
"[go]": { "[go]": {
"editor.formatOnSave": true, "editor.formatOnSave": true,
"editor.codeActionsOnSave": { "editor.codeActionsOnSave": {
@@ -25,6 +25,6 @@
"gopls": { "gopls": {
"usePlaceholders": true, "usePlaceholders": true,
"staticcheck": true, "staticcheck": true,
"buildFlags": ["-tags", "test"] "buildFlags": ["-tags", "test,integration"]
} }
} }

View File

@@ -11,6 +11,7 @@ require (
github.com/golang-jwt/jwt/v5 v5.2.1 github.com/golang-jwt/jwt/v5 v5.2.1
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/mattn/go-sqlite3 v1.14.23 github.com/mattn/go-sqlite3 v1.14.23
github.com/stretchr/testify v1.9.0
github.com/unrolled/secure v1.17.0 github.com/unrolled/secure v1.17.0
golang.org/x/crypto v0.21.0 golang.org/x/crypto v0.21.0
) )
@@ -22,6 +23,7 @@ require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cloudflare/circl v1.3.7 // indirect github.com/cloudflare/circl v1.3.7 // indirect
github.com/cyphar/filepath-securejoin v0.2.4 // indirect github.com/cyphar/filepath-securejoin v0.2.4 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/emirpasic/gods v1.18.1 // indirect github.com/emirpasic/gods v1.18.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect
@@ -33,6 +35,7 @@ require (
github.com/kevinburke/ssh_config v1.2.0 // indirect github.com/kevinburke/ssh_config v1.2.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect
github.com/pjbgf/sha1cd v0.3.0 // indirect github.com/pjbgf/sha1cd v0.3.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect
github.com/skeema/knownhosts v1.2.2 // indirect github.com/skeema/knownhosts v1.2.2 // indirect
github.com/xanzy/ssh-agent v0.3.3 // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect
@@ -42,4 +45,5 @@ require (
golang.org/x/text v0.14.0 // indirect golang.org/x/text v0.14.0 // indirect
golang.org/x/tools v0.13.0 // indirect golang.org/x/tools v0.13.0 // indirect
gopkg.in/warnings.v0 v0.1.2 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
) )

View File

@@ -2,6 +2,8 @@
package auth package auth
import ( import (
"crypto/rand"
"encoding/hex"
"fmt" "fmt"
"time" "time"
@@ -87,11 +89,19 @@ func (s *jwtService) GenerateRefreshToken(userID int, role string) (string, erro
// Returns the signed token string or an error // Returns the signed token string or an error
func (s *jwtService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) { func (s *jwtService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) {
now := time.Now() now := time.Now()
// Add a random nonce to ensure uniqueness
nonce := make([]byte, 8)
if _, err := rand.Read(nonce); err != nil {
return "", fmt.Errorf("failed to generate nonce: %w", err)
}
claims := Claims{ claims := Claims{
RegisteredClaims: jwt.RegisteredClaims{ RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(expiry)), ExpiresAt: jwt.NewNumericDate(now.Add(expiry)),
IssuedAt: jwt.NewNumericDate(now), IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now), NotBefore: jwt.NewNumericDate(now),
ID: hex.EncodeToString(nonce),
}, },
UserID: userID, UserID: userID,
Role: role, Role: role,

View File

@@ -76,8 +76,8 @@ func (s *SessionService) CreateSession(userID int, role string) (*models.Session
// - string: a new access token // - string: a new access token
// - error: any error that occurred // - error: any error that occurred
func (s *SessionService) RefreshSession(refreshToken string) (string, error) { func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
// Get session from database // Get session from database first
_, err := s.db.GetSessionByRefreshToken(refreshToken) session, err := s.db.GetSessionByRefreshToken(refreshToken)
if err != nil { if err != nil {
return "", fmt.Errorf("invalid session: %w", err) return "", fmt.Errorf("invalid session: %w", err)
} }
@@ -88,6 +88,11 @@ func (s *SessionService) RefreshSession(refreshToken string) (string, error) {
return "", fmt.Errorf("invalid refresh token: %w", err) return "", fmt.Errorf("invalid refresh token: %w", err)
} }
// Double check that the claims match the session
if claims.UserID != session.UserID {
return "", fmt.Errorf("token does not match session")
}
// Generate a new access token // Generate a new access token
return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role) return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role)
} }

View File

@@ -25,9 +25,9 @@ func (db *database) CreateSession(session *models.Session) error {
func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) { func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) {
session := &models.Session{} session := &models.Session{}
err := db.QueryRow(` err := db.QueryRow(`
SELECT id, user_id, refresh_token, expires_at, created_at SELECT id, user_id, refresh_token, expires_at, created_at
FROM sessions FROM sessions
WHERE refresh_token = ? AND expires_at > ?`, WHERE refresh_token = ? AND expires_at > ?`,
refreshToken, time.Now(), refreshToken, time.Now(),
).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt) ).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt)

View File

@@ -1,3 +1,4 @@
// Package handlers contains the request handlers for the api routes.
package handlers package handlers
import ( import (

View File

@@ -10,11 +10,13 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
// LoginRequest represents a user login request
type LoginRequest struct { type LoginRequest struct {
Email string `json:"email"` Email string `json:"email"`
Password string `json:"password"` Password string `json:"password"`
} }
// LoginResponse represents a user login response
type LoginResponse struct { type LoginResponse struct {
AccessToken string `json:"accessToken"` AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"` RefreshToken string `json:"refreshToken"`
@@ -22,10 +24,12 @@ type LoginResponse struct {
Session *models.Session `json:"session"` Session *models.Session `json:"session"`
} }
// RefreshRequest represents a refresh token request
type RefreshRequest struct { type RefreshRequest struct {
RefreshToken string `json:"refreshToken"` RefreshToken string `json:"refreshToken"`
} }
// RefreshResponse represents a refresh token response
type RefreshResponse struct { type RefreshResponse struct {
AccessToken string `json:"accessToken"` AccessToken string `json:"accessToken"`
} }

View File

@@ -0,0 +1,232 @@
//go:build integration
package handlers_test
import (
"encoding/json"
"net/http"
"testing"
"novamd/internal/handlers"
"novamd/internal/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAuthHandlers_Integration(t *testing.T) {
h := setupTestHarness(t)
defer h.teardown(t)
t.Run("login", func(t *testing.T) {
t.Run("successful login - admin user", func(t *testing.T) {
loginReq := handlers.LoginRequest{
Email: "admin@test.com",
Password: "admin123",
}
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil)
require.Equal(t, http.StatusOK, rr.Code)
var resp handlers.LoginResponse
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.NotEmpty(t, resp.AccessToken)
assert.NotEmpty(t, resp.RefreshToken)
assert.NotNil(t, resp.User)
assert.Equal(t, loginReq.Email, resp.User.Email)
assert.Equal(t, models.RoleAdmin, resp.User.Role)
})
t.Run("successful login - regular user", func(t *testing.T) {
loginReq := handlers.LoginRequest{
Email: "user@test.com",
Password: "user123",
}
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil)
require.Equal(t, http.StatusOK, rr.Code)
var resp handlers.LoginResponse
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.NotEmpty(t, resp.AccessToken)
assert.NotEmpty(t, resp.RefreshToken)
assert.NotNil(t, resp.User)
assert.Equal(t, loginReq.Email, resp.User.Email)
assert.Equal(t, models.RoleEditor, resp.User.Role)
})
t.Run("login failures", func(t *testing.T) {
tests := []struct {
name string
request handlers.LoginRequest
wantCode int
}{
{
name: "wrong password",
request: handlers.LoginRequest{
Email: "user@test.com",
Password: "wrongpassword",
},
wantCode: http.StatusUnauthorized,
},
{
name: "non-existent user",
request: handlers.LoginRequest{
Email: "nonexistent@test.com",
Password: "password123",
},
wantCode: http.StatusUnauthorized,
},
{
name: "empty email",
request: handlers.LoginRequest{
Email: "",
Password: "password123",
},
wantCode: http.StatusBadRequest,
},
{
name: "empty password",
request: handlers.LoginRequest{
Email: "user@test.com",
Password: "",
},
wantCode: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", tt.request, "", nil)
assert.Equal(t, tt.wantCode, rr.Code)
})
}
})
})
t.Run("refresh token", func(t *testing.T) {
t.Run("successful token refresh", func(t *testing.T) {
// First login to get refresh token
loginReq := handlers.LoginRequest{
Email: "user@test.com",
Password: "user123",
}
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil)
require.Equal(t, http.StatusOK, rr.Code)
var loginResp handlers.LoginResponse
err := json.NewDecoder(rr.Body).Decode(&loginResp)
require.NoError(t, err)
// Now try to refresh the token
refreshReq := handlers.RefreshRequest{
RefreshToken: loginResp.RefreshToken,
}
rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/refresh", refreshReq, "", nil)
require.Equal(t, http.StatusOK, rr.Code)
var refreshResp handlers.RefreshResponse
err = json.NewDecoder(rr.Body).Decode(&refreshResp)
require.NoError(t, err)
assert.NotEmpty(t, refreshResp.AccessToken)
})
t.Run("refresh failures", func(t *testing.T) {
tests := []struct {
name string
request handlers.RefreshRequest
wantCode int
}{
{
name: "invalid refresh token",
request: handlers.RefreshRequest{
RefreshToken: "invalid-token",
},
wantCode: http.StatusUnauthorized,
},
{
name: "empty refresh token",
request: handlers.RefreshRequest{
RefreshToken: "",
},
wantCode: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/refresh", tt.request, "", nil)
assert.Equal(t, tt.wantCode, rr.Code)
})
}
})
})
t.Run("logout", func(t *testing.T) {
t.Run("successful logout", func(t *testing.T) {
// First login to get session
loginReq := handlers.LoginRequest{
Email: "user@test.com",
Password: "user123",
}
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil)
require.Equal(t, http.StatusOK, rr.Code)
var loginResp handlers.LoginResponse
err := json.NewDecoder(rr.Body).Decode(&loginResp)
require.NoError(t, err)
// Now logout using session ID from login response
headers := map[string]string{
"X-Session-ID": loginResp.Session.ID,
}
rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/logout", nil, loginResp.AccessToken, headers)
require.Equal(t, http.StatusOK, rr.Code)
// Try to use the refresh token - should fail
refreshReq := handlers.RefreshRequest{
RefreshToken: loginResp.RefreshToken,
}
rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/refresh", refreshReq, "", nil)
assert.Equal(t, http.StatusUnauthorized, rr.Code)
})
t.Run("logout without session ID", func(t *testing.T) {
rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/logout", nil, h.RegularToken, nil)
assert.Equal(t, http.StatusBadRequest, rr.Code)
})
})
t.Run("get current user", func(t *testing.T) {
t.Run("successful get current user", func(t *testing.T) {
rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, h.RegularToken, nil)
require.Equal(t, http.StatusOK, rr.Code)
var user models.User
err := json.NewDecoder(rr.Body).Decode(&user)
require.NoError(t, err)
assert.Equal(t, h.RegularUser.ID, user.ID)
assert.Equal(t, h.RegularUser.Email, user.Email)
assert.Equal(t, h.RegularUser.Role, user.Role)
})
t.Run("get current user without token", func(t *testing.T) {
rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, "", nil)
assert.Equal(t, http.StatusUnauthorized, rr.Code)
})
t.Run("get current user with invalid token", func(t *testing.T) {
rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, "invalid-token", nil)
assert.Equal(t, http.StatusUnauthorized, rr.Code)
})
})
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
) )
// ListFiles returns a list of all files in the workspace
func (h *Handler) ListFiles() http.HandlerFunc { func (h *Handler) ListFiles() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r) ctx, ok := context.GetRequestContext(w, r)
@@ -27,6 +28,7 @@ func (h *Handler) ListFiles() http.HandlerFunc {
} }
} }
// LookupFileByName returns the paths of files with the given name
func (h *Handler) LookupFileByName() http.HandlerFunc { func (h *Handler) LookupFileByName() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r) ctx, ok := context.GetRequestContext(w, r)
@@ -50,6 +52,7 @@ func (h *Handler) LookupFileByName() http.HandlerFunc {
} }
} }
// GetFileContent returns the content of a file
func (h *Handler) GetFileContent() http.HandlerFunc { func (h *Handler) GetFileContent() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r) ctx, ok := context.GetRequestContext(w, r)
@@ -65,10 +68,15 @@ func (h *Handler) GetFileContent() http.HandlerFunc {
} }
w.Header().Set("Content-Type", "text/plain") w.Header().Set("Content-Type", "text/plain")
w.Write(content) _, err = w.Write(content)
if err != nil {
http.Error(w, "Failed to write response", http.StatusInternalServerError)
return
}
} }
} }
// SaveFile saves the content of a file
func (h *Handler) SaveFile() http.HandlerFunc { func (h *Handler) SaveFile() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r) ctx, ok := context.GetRequestContext(w, r)
@@ -93,6 +101,7 @@ func (h *Handler) SaveFile() http.HandlerFunc {
} }
} }
// DeleteFile deletes a file
func (h *Handler) DeleteFile() http.HandlerFunc { func (h *Handler) DeleteFile() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r) ctx, ok := context.GetRequestContext(w, r)
@@ -108,10 +117,15 @@ func (h *Handler) DeleteFile() http.HandlerFunc {
} }
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte("File deleted successfully")) _, err = w.Write([]byte("File deleted successfully"))
if err != nil {
http.Error(w, "Failed to write response", http.StatusInternalServerError)
return
}
} }
} }
// GetLastOpenedFile returns the last opened file in the workspace
func (h *Handler) GetLastOpenedFile() http.HandlerFunc { func (h *Handler) GetLastOpenedFile() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r) ctx, ok := context.GetRequestContext(w, r)
@@ -134,6 +148,7 @@ func (h *Handler) GetLastOpenedFile() http.HandlerFunc {
} }
} }
// UpdateLastOpenedFile updates the last opened file in the workspace
func (h *Handler) UpdateLastOpenedFile() http.HandlerFunc { func (h *Handler) UpdateLastOpenedFile() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r) ctx, ok := context.GetRequestContext(w, r)

View File

@@ -7,6 +7,7 @@ import (
"novamd/internal/context" "novamd/internal/context"
) )
// StageCommitAndPush stages, commits, and pushes changes to the remote repository
func (h *Handler) StageCommitAndPush() http.HandlerFunc { func (h *Handler) StageCommitAndPush() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r) ctx, ok := context.GetRequestContext(w, r)
@@ -38,6 +39,7 @@ func (h *Handler) StageCommitAndPush() http.HandlerFunc {
} }
} }
// PullChanges pulls changes from the remote repository
func (h *Handler) PullChanges() http.HandlerFunc { func (h *Handler) PullChanges() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r) ctx, ok := context.GetRequestContext(w, r)

View File

@@ -0,0 +1,188 @@
//go:build integration
package handlers_test
import (
"bytes"
"encoding/json"
"net/http/httptest"
"os"
"testing"
"time"
"github.com/go-chi/chi/v5"
"golang.org/x/crypto/bcrypt"
"novamd/internal/api"
"novamd/internal/auth"
"novamd/internal/db"
"novamd/internal/handlers"
"novamd/internal/models"
"novamd/internal/secrets"
"novamd/internal/storage"
)
// testHarness encapsulates all the dependencies needed for testing
type testHarness struct {
DB db.TestDatabase
Storage storage.Manager
Router *chi.Mux
Handler *handlers.Handler
JWTManager auth.JWTManager
SessionSvc *auth.SessionService
AdminUser *models.User
AdminToken string
RegularUser *models.User
RegularToken string
TempDirectory string
}
// setupTestHarness creates a new test environment
func setupTestHarness(t *testing.T) *testHarness {
t.Helper()
// Create temporary directory for test files
tempDir, err := os.MkdirTemp("", "novamd-test-*")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
// Initialize test database
secretsSvc, err := secrets.NewService("YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY=") // test key
if err != nil {
t.Fatalf("Failed to initialize secrets service: %v", err)
}
database, err := db.NewTestDB(":memory:", secretsSvc)
if err != nil {
t.Fatalf("Failed to initialize test database: %v", err)
}
if err := database.Migrate(); err != nil {
t.Fatalf("Failed to run migrations: %v", err)
}
// Initialize storage
storageSvc := storage.NewService(tempDir)
// Initialize JWT service
jwtSvc, err := auth.NewJWTService(auth.JWTConfig{
SigningKey: "test-key",
AccessTokenExpiry: 15 * time.Minute,
RefreshTokenExpiry: 24 * time.Hour,
})
if err != nil {
t.Fatalf("Failed to initialize JWT service: %v", err)
}
// Initialize session service
sessionSvc := auth.NewSessionService(database, jwtSvc)
// Create handler
handler := &handlers.Handler{
DB: database,
Storage: storageSvc,
}
// Set up router with middlewares
router := chi.NewRouter()
authMiddleware := auth.NewMiddleware(jwtSvc)
router.Route("/api/v1", func(r chi.Router) {
api.SetupRoutes(r, database, storageSvc, authMiddleware, sessionSvc)
})
// Create test users
adminUser, adminToken := createTestUser(t, database, sessionSvc, "admin@test.com", "admin123", models.RoleAdmin)
regularUser, regularToken := createTestUser(t, database, sessionSvc, "user@test.com", "user123", models.RoleEditor)
return &testHarness{
DB: database,
Storage: storageSvc,
Router: router,
Handler: handler,
JWTManager: jwtSvc,
SessionSvc: sessionSvc,
AdminUser: adminUser,
AdminToken: adminToken,
RegularUser: regularUser,
RegularToken: regularToken,
TempDirectory: tempDir,
}
}
// teardownTestHarness cleans up the test environment
func (h *testHarness) teardown(t *testing.T) {
t.Helper()
if err := h.DB.Close(); err != nil {
t.Errorf("Failed to close database: %v", err)
}
if err := os.RemoveAll(h.TempDirectory); err != nil {
t.Errorf("Failed to remove temp directory: %v", err)
}
}
// createTestUser creates a test user and returns the user and access token
func createTestUser(t *testing.T, db db.Database, sessionSvc *auth.SessionService, email, password string, role models.UserRole) (*models.User, string) {
t.Helper()
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
t.Fatalf("Failed to hash password: %v", err)
}
user := &models.User{
Email: email,
DisplayName: "Test User",
PasswordHash: string(hashedPassword),
Role: role,
}
user, err = db.CreateUser(user)
if err != nil {
t.Fatalf("Failed to create user: %v", err)
}
session, accessToken, err := sessionSvc.CreateSession(user.ID, string(user.Role))
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}
if session == nil || accessToken == "" {
t.Fatal("Failed to get valid session or token")
}
return user, accessToken
}
// makeRequest is a helper function to make HTTP requests in tests
func (h *testHarness) makeRequest(t *testing.T, method, path string, body interface{}, token string, headers map[string]string) *httptest.ResponseRecorder {
t.Helper()
var reqBody []byte
var err error
if body != nil {
reqBody, err = json.Marshal(body)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
}
req := httptest.NewRequest(method, path, bytes.NewBuffer(reqBody))
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
req.Header.Set("Content-Type", "application/json")
// Add any additional headers
for k, v := range headers {
req.Header.Set(k, v)
}
rr := httptest.NewRecorder()
h.Router.ServeHTTP(rr, req)
return rr
}

View File

@@ -12,12 +12,14 @@ type StaticHandler struct {
staticPath string staticPath string
} }
// NewStaticHandler creates a new StaticHandler with the given static path
func NewStaticHandler(staticPath string) *StaticHandler { func NewStaticHandler(staticPath string) *StaticHandler {
return &StaticHandler{ return &StaticHandler{
staticPath: staticPath, staticPath: staticPath,
} }
} }
// ServeHTTP serves the static files
func (h *StaticHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *StaticHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Get the requested path // Get the requested path
requestedPath := r.URL.Path requestedPath := r.URL.Path

View File

@@ -1,6 +1,7 @@
package handlers package handlers
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"net/http" "net/http"
@@ -9,6 +10,7 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
// UpdateProfileRequest represents a user profile update request
type UpdateProfileRequest struct { type UpdateProfileRequest struct {
DisplayName string `json:"displayName"` DisplayName string `json:"displayName"`
Email string `json:"email"` Email string `json:"email"`
@@ -16,10 +18,12 @@ type UpdateProfileRequest struct {
NewPassword string `json:"newPassword"` NewPassword string `json:"newPassword"`
} }
// DeleteAccountRequest represents a user account deletion request
type DeleteAccountRequest struct { type DeleteAccountRequest struct {
Password string `json:"password"` Password string `json:"password"`
} }
// GetUser returns the current user's profile
func (h *Handler) GetUser() http.HandlerFunc { func (h *Handler) GetUser() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r) ctx, ok := context.GetRequestContext(w, r)
@@ -64,7 +68,11 @@ func (h *Handler) UpdateProfile() http.HandlerFunc {
http.Error(w, "Failed to start transaction", http.StatusInternalServerError) http.Error(w, "Failed to start transaction", http.StatusInternalServerError)
return return
} }
defer tx.Rollback() defer func() {
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
http.Error(w, "Failed to rollback transaction", http.StatusInternalServerError)
}
}()
// Handle password update if requested // Handle password update if requested
if req.NewPassword != "" { if req.NewPassword != "" {
@@ -188,7 +196,11 @@ func (h *Handler) DeleteAccount() http.HandlerFunc {
http.Error(w, "Failed to start transaction", http.StatusInternalServerError) http.Error(w, "Failed to start transaction", http.StatusInternalServerError)
return return
} }
defer tx.Rollback() defer func() {
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
http.Error(w, "Failed to rollback transaction", http.StatusInternalServerError)
}
}()
// Get user's workspaces for cleanup // Get user's workspaces for cleanup
workspaces, err := h.DB.GetWorkspacesByUserID(ctx.UserID) workspaces, err := h.DB.GetWorkspacesByUserID(ctx.UserID)

View File

@@ -1,14 +1,15 @@
package handlers package handlers
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"novamd/internal/context" "novamd/internal/context"
"novamd/internal/models" "novamd/internal/models"
) )
// ListWorkspaces returns a list of all workspaces for the current user
func (h *Handler) ListWorkspaces() http.HandlerFunc { func (h *Handler) ListWorkspaces() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r) ctx, ok := context.GetRequestContext(w, r)
@@ -26,6 +27,7 @@ func (h *Handler) ListWorkspaces() http.HandlerFunc {
} }
} }
// CreateWorkspace creates a new workspace
func (h *Handler) CreateWorkspace() http.HandlerFunc { func (h *Handler) CreateWorkspace() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r) ctx, ok := context.GetRequestContext(w, r)
@@ -54,6 +56,7 @@ func (h *Handler) CreateWorkspace() http.HandlerFunc {
} }
} }
// GetWorkspace returns the current workspace
func (h *Handler) GetWorkspace() http.HandlerFunc { func (h *Handler) GetWorkspace() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r) ctx, ok := context.GetRequestContext(w, r)
@@ -81,6 +84,7 @@ func gitSettingsChanged(new, old *models.Workspace) bool {
return false return false
} }
// UpdateWorkspace updates the current workspace
func (h *Handler) UpdateWorkspace() http.HandlerFunc { func (h *Handler) UpdateWorkspace() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r) ctx, ok := context.GetRequestContext(w, r)
@@ -132,6 +136,7 @@ func (h *Handler) UpdateWorkspace() http.HandlerFunc {
} }
} }
// DeleteWorkspace deletes the current workspace
func (h *Handler) DeleteWorkspace() http.HandlerFunc { func (h *Handler) DeleteWorkspace() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r) ctx, ok := context.GetRequestContext(w, r)
@@ -168,7 +173,11 @@ func (h *Handler) DeleteWorkspace() http.HandlerFunc {
http.Error(w, "Failed to start transaction", http.StatusInternalServerError) http.Error(w, "Failed to start transaction", http.StatusInternalServerError)
return return
} }
defer tx.Rollback() defer func() {
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
http.Error(w, "Failed to rollback transaction", http.StatusInternalServerError)
}
}()
// Update last workspace ID first // Update last workspace ID first
err = h.DB.UpdateLastWorkspaceTx(tx, ctx.UserID, nextWorkspaceID) err = h.DB.UpdateLastWorkspaceTx(tx, ctx.UserID, nextWorkspaceID)
@@ -195,6 +204,7 @@ func (h *Handler) DeleteWorkspace() http.HandlerFunc {
} }
} }
// GetLastWorkspaceName returns the name of the last opened workspace
func (h *Handler) GetLastWorkspaceName() http.HandlerFunc { func (h *Handler) GetLastWorkspaceName() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r) ctx, ok := context.GetRequestContext(w, r)
@@ -212,6 +222,7 @@ func (h *Handler) GetLastWorkspaceName() http.HandlerFunc {
} }
} }
// UpdateLastWorkspaceName updates the name of the last opened workspace
func (h *Handler) UpdateLastWorkspaceName() http.HandlerFunc { func (h *Handler) UpdateLastWorkspaceName() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx, ok := context.GetRequestContext(w, r) ctx, ok := context.GetRequestContext(w, r)
@@ -224,13 +235,11 @@ func (h *Handler) UpdateLastWorkspaceName() http.HandlerFunc {
} }
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
fmt.Println(err)
http.Error(w, "Invalid request body", http.StatusBadRequest) http.Error(w, "Invalid request body", http.StatusBadRequest)
return return
} }
if err := h.DB.UpdateLastWorkspace(ctx.UserID, requestBody.WorkspaceName); err != nil { if err := h.DB.UpdateLastWorkspace(ctx.UserID, requestBody.WorkspaceName); err != nil {
fmt.Println(err)
http.Error(w, "Failed to update last workspace", http.StatusInternalServerError) http.Error(w, "Failed to update last workspace", http.StatusInternalServerError)
return return
} }