diff --git a/.vscode/settings.json b/.vscode/settings.json index 34f5162..52c38cb 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -14,7 +14,7 @@ "go.lintTool": "golangci-lint", "go.lintOnSave": "package", "go.formatTool": "goimports", - "go.testFlags": ["-tags=test"], + "go.testFlags": ["-tags=test,integration"], "[go]": { "editor.formatOnSave": true, "editor.codeActionsOnSave": { @@ -25,6 +25,6 @@ "gopls": { "usePlaceholders": true, "staticcheck": true, - "buildFlags": ["-tags", "test"] + "buildFlags": ["-tags", "test,integration"] } } diff --git a/server/go.mod b/server/go.mod index 0af2cf4..90ae840 100644 --- a/server/go.mod +++ b/server/go.mod @@ -11,6 +11,7 @@ require ( github.com/golang-jwt/jwt/v5 v5.2.1 github.com/google/uuid v1.6.0 github.com/mattn/go-sqlite3 v1.14.23 + github.com/stretchr/testify v1.9.0 github.com/unrolled/secure v1.17.0 golang.org/x/crypto v0.21.0 ) @@ -22,6 +23,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudflare/circl v1.3.7 // indirect github.com/cyphar/filepath-securejoin v0.2.4 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/emirpasic/gods v1.18.1 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect @@ -33,6 +35,7 @@ require ( github.com/kevinburke/ssh_config v1.2.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/pjbgf/sha1cd v0.3.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect github.com/skeema/knownhosts v1.2.2 // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect @@ -42,4 +45,5 @@ require ( golang.org/x/text v0.14.0 // indirect golang.org/x/tools v0.13.0 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/server/internal/auth/jwt.go b/server/internal/auth/jwt.go index 9b65db8..f9b4b23 100644 --- a/server/internal/auth/jwt.go +++ b/server/internal/auth/jwt.go @@ -2,6 +2,8 @@ package auth import ( + "crypto/rand" + "encoding/hex" "fmt" "time" @@ -87,11 +89,19 @@ func (s *jwtService) GenerateRefreshToken(userID int, role string) (string, erro // Returns the signed token string or an error func (s *jwtService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) { now := time.Now() + + // Add a random nonce to ensure uniqueness + nonce := make([]byte, 8) + if _, err := rand.Read(nonce); err != nil { + return "", fmt.Errorf("failed to generate nonce: %w", err) + } + claims := Claims{ RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(now.Add(expiry)), IssuedAt: jwt.NewNumericDate(now), NotBefore: jwt.NewNumericDate(now), + ID: hex.EncodeToString(nonce), }, UserID: userID, Role: role, diff --git a/server/internal/auth/session.go b/server/internal/auth/session.go index b6217d6..872b2c1 100644 --- a/server/internal/auth/session.go +++ b/server/internal/auth/session.go @@ -76,8 +76,8 @@ func (s *SessionService) CreateSession(userID int, role string) (*models.Session // - string: a new access token // - error: any error that occurred func (s *SessionService) RefreshSession(refreshToken string) (string, error) { - // Get session from database - _, err := s.db.GetSessionByRefreshToken(refreshToken) + // Get session from database first + session, err := s.db.GetSessionByRefreshToken(refreshToken) if err != nil { return "", fmt.Errorf("invalid session: %w", err) } @@ -88,6 +88,11 @@ func (s *SessionService) RefreshSession(refreshToken string) (string, error) { return "", fmt.Errorf("invalid refresh token: %w", err) } + // Double check that the claims match the session + if claims.UserID != session.UserID { + return "", fmt.Errorf("token does not match session") + } + // Generate a new access token return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role) } diff --git a/server/internal/db/sessions.go b/server/internal/db/sessions.go index 596dc64..ecc4a72 100644 --- a/server/internal/db/sessions.go +++ b/server/internal/db/sessions.go @@ -25,9 +25,9 @@ func (db *database) CreateSession(session *models.Session) error { func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) { session := &models.Session{} err := db.QueryRow(` - SELECT id, user_id, refresh_token, expires_at, created_at - FROM sessions - WHERE refresh_token = ? AND expires_at > ?`, + SELECT id, user_id, refresh_token, expires_at, created_at + FROM sessions + WHERE refresh_token = ? AND expires_at > ?`, refreshToken, time.Now(), ).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt) diff --git a/server/internal/handlers/admin_handlers.go b/server/internal/handlers/admin_handlers.go index 5ed8d42..0ae33e2 100644 --- a/server/internal/handlers/admin_handlers.go +++ b/server/internal/handlers/admin_handlers.go @@ -1,3 +1,4 @@ +// Package handlers contains the request handlers for the api routes. package handlers import ( diff --git a/server/internal/handlers/auth_handlers.go b/server/internal/handlers/auth_handlers.go index d6aec83..a1bc4b7 100644 --- a/server/internal/handlers/auth_handlers.go +++ b/server/internal/handlers/auth_handlers.go @@ -10,11 +10,13 @@ import ( "golang.org/x/crypto/bcrypt" ) +// LoginRequest represents a user login request type LoginRequest struct { Email string `json:"email"` Password string `json:"password"` } +// LoginResponse represents a user login response type LoginResponse struct { AccessToken string `json:"accessToken"` RefreshToken string `json:"refreshToken"` @@ -22,10 +24,12 @@ type LoginResponse struct { Session *models.Session `json:"session"` } +// RefreshRequest represents a refresh token request type RefreshRequest struct { RefreshToken string `json:"refreshToken"` } +// RefreshResponse represents a refresh token response type RefreshResponse struct { AccessToken string `json:"accessToken"` } diff --git a/server/internal/handlers/auth_handlers_integration_test.go b/server/internal/handlers/auth_handlers_integration_test.go new file mode 100644 index 0000000..d917e32 --- /dev/null +++ b/server/internal/handlers/auth_handlers_integration_test.go @@ -0,0 +1,232 @@ +//go:build integration + +package handlers_test + +import ( + "encoding/json" + "net/http" + "testing" + + "novamd/internal/handlers" + "novamd/internal/models" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAuthHandlers_Integration(t *testing.T) { + h := setupTestHarness(t) + defer h.teardown(t) + + t.Run("login", func(t *testing.T) { + t.Run("successful login - admin user", func(t *testing.T) { + loginReq := handlers.LoginRequest{ + Email: "admin@test.com", + Password: "admin123", + } + + rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil) + require.Equal(t, http.StatusOK, rr.Code) + + var resp handlers.LoginResponse + err := json.NewDecoder(rr.Body).Decode(&resp) + require.NoError(t, err) + + assert.NotEmpty(t, resp.AccessToken) + assert.NotEmpty(t, resp.RefreshToken) + assert.NotNil(t, resp.User) + assert.Equal(t, loginReq.Email, resp.User.Email) + assert.Equal(t, models.RoleAdmin, resp.User.Role) + }) + + t.Run("successful login - regular user", func(t *testing.T) { + loginReq := handlers.LoginRequest{ + Email: "user@test.com", + Password: "user123", + } + + rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil) + require.Equal(t, http.StatusOK, rr.Code) + + var resp handlers.LoginResponse + err := json.NewDecoder(rr.Body).Decode(&resp) + require.NoError(t, err) + + assert.NotEmpty(t, resp.AccessToken) + assert.NotEmpty(t, resp.RefreshToken) + assert.NotNil(t, resp.User) + assert.Equal(t, loginReq.Email, resp.User.Email) + assert.Equal(t, models.RoleEditor, resp.User.Role) + }) + + t.Run("login failures", func(t *testing.T) { + tests := []struct { + name string + request handlers.LoginRequest + wantCode int + }{ + { + name: "wrong password", + request: handlers.LoginRequest{ + Email: "user@test.com", + Password: "wrongpassword", + }, + wantCode: http.StatusUnauthorized, + }, + { + name: "non-existent user", + request: handlers.LoginRequest{ + Email: "nonexistent@test.com", + Password: "password123", + }, + wantCode: http.StatusUnauthorized, + }, + { + name: "empty email", + request: handlers.LoginRequest{ + Email: "", + Password: "password123", + }, + wantCode: http.StatusBadRequest, + }, + { + name: "empty password", + request: handlers.LoginRequest{ + Email: "user@test.com", + Password: "", + }, + wantCode: http.StatusBadRequest, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", tt.request, "", nil) + assert.Equal(t, tt.wantCode, rr.Code) + }) + } + }) + }) + + t.Run("refresh token", func(t *testing.T) { + t.Run("successful token refresh", func(t *testing.T) { + // First login to get refresh token + loginReq := handlers.LoginRequest{ + Email: "user@test.com", + Password: "user123", + } + + rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil) + require.Equal(t, http.StatusOK, rr.Code) + + var loginResp handlers.LoginResponse + err := json.NewDecoder(rr.Body).Decode(&loginResp) + require.NoError(t, err) + + // Now try to refresh the token + refreshReq := handlers.RefreshRequest{ + RefreshToken: loginResp.RefreshToken, + } + + rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/refresh", refreshReq, "", nil) + require.Equal(t, http.StatusOK, rr.Code) + + var refreshResp handlers.RefreshResponse + err = json.NewDecoder(rr.Body).Decode(&refreshResp) + require.NoError(t, err) + assert.NotEmpty(t, refreshResp.AccessToken) + }) + + t.Run("refresh failures", func(t *testing.T) { + tests := []struct { + name string + request handlers.RefreshRequest + wantCode int + }{ + { + name: "invalid refresh token", + request: handlers.RefreshRequest{ + RefreshToken: "invalid-token", + }, + wantCode: http.StatusUnauthorized, + }, + { + name: "empty refresh token", + request: handlers.RefreshRequest{ + RefreshToken: "", + }, + wantCode: http.StatusBadRequest, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/refresh", tt.request, "", nil) + assert.Equal(t, tt.wantCode, rr.Code) + }) + } + }) + }) + + t.Run("logout", func(t *testing.T) { + t.Run("successful logout", func(t *testing.T) { + // First login to get session + loginReq := handlers.LoginRequest{ + Email: "user@test.com", + Password: "user123", + } + + rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil) + require.Equal(t, http.StatusOK, rr.Code) + + var loginResp handlers.LoginResponse + err := json.NewDecoder(rr.Body).Decode(&loginResp) + require.NoError(t, err) + + // Now logout using session ID from login response + headers := map[string]string{ + "X-Session-ID": loginResp.Session.ID, + } + rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/logout", nil, loginResp.AccessToken, headers) + require.Equal(t, http.StatusOK, rr.Code) + + // Try to use the refresh token - should fail + refreshReq := handlers.RefreshRequest{ + RefreshToken: loginResp.RefreshToken, + } + + rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/refresh", refreshReq, "", nil) + assert.Equal(t, http.StatusUnauthorized, rr.Code) + }) + + t.Run("logout without session ID", func(t *testing.T) { + rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/logout", nil, h.RegularToken, nil) + assert.Equal(t, http.StatusBadRequest, rr.Code) + }) + }) + + t.Run("get current user", func(t *testing.T) { + t.Run("successful get current user", func(t *testing.T) { + rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var user models.User + err := json.NewDecoder(rr.Body).Decode(&user) + require.NoError(t, err) + + assert.Equal(t, h.RegularUser.ID, user.ID) + assert.Equal(t, h.RegularUser.Email, user.Email) + assert.Equal(t, h.RegularUser.Role, user.Role) + }) + + t.Run("get current user without token", func(t *testing.T) { + rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, "", nil) + assert.Equal(t, http.StatusUnauthorized, rr.Code) + }) + + t.Run("get current user with invalid token", func(t *testing.T) { + rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, "invalid-token", nil) + assert.Equal(t, http.StatusUnauthorized, rr.Code) + }) + }) +} diff --git a/server/internal/handlers/file_handlers.go b/server/internal/handlers/file_handlers.go index 9f75eb2..3f53c17 100644 --- a/server/internal/handlers/file_handlers.go +++ b/server/internal/handlers/file_handlers.go @@ -10,6 +10,7 @@ import ( "github.com/go-chi/chi/v5" ) +// ListFiles returns a list of all files in the workspace func (h *Handler) ListFiles() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -27,6 +28,7 @@ func (h *Handler) ListFiles() http.HandlerFunc { } } +// LookupFileByName returns the paths of files with the given name func (h *Handler) LookupFileByName() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -50,6 +52,7 @@ func (h *Handler) LookupFileByName() http.HandlerFunc { } } +// GetFileContent returns the content of a file func (h *Handler) GetFileContent() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -65,10 +68,15 @@ func (h *Handler) GetFileContent() http.HandlerFunc { } w.Header().Set("Content-Type", "text/plain") - w.Write(content) + _, err = w.Write(content) + if err != nil { + http.Error(w, "Failed to write response", http.StatusInternalServerError) + return + } } } +// SaveFile saves the content of a file func (h *Handler) SaveFile() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -93,6 +101,7 @@ func (h *Handler) SaveFile() http.HandlerFunc { } } +// DeleteFile deletes a file func (h *Handler) DeleteFile() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -108,10 +117,15 @@ func (h *Handler) DeleteFile() http.HandlerFunc { } w.WriteHeader(http.StatusOK) - w.Write([]byte("File deleted successfully")) + _, err = w.Write([]byte("File deleted successfully")) + if err != nil { + http.Error(w, "Failed to write response", http.StatusInternalServerError) + return + } } } +// GetLastOpenedFile returns the last opened file in the workspace func (h *Handler) GetLastOpenedFile() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -134,6 +148,7 @@ func (h *Handler) GetLastOpenedFile() http.HandlerFunc { } } +// UpdateLastOpenedFile updates the last opened file in the workspace func (h *Handler) UpdateLastOpenedFile() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) diff --git a/server/internal/handlers/git_handlers.go b/server/internal/handlers/git_handlers.go index ea6eeff..f34b12a 100644 --- a/server/internal/handlers/git_handlers.go +++ b/server/internal/handlers/git_handlers.go @@ -7,6 +7,7 @@ import ( "novamd/internal/context" ) +// StageCommitAndPush stages, commits, and pushes changes to the remote repository func (h *Handler) StageCommitAndPush() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -38,6 +39,7 @@ func (h *Handler) StageCommitAndPush() http.HandlerFunc { } } +// PullChanges pulls changes from the remote repository func (h *Handler) PullChanges() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) diff --git a/server/internal/handlers/integration_test.go b/server/internal/handlers/integration_test.go new file mode 100644 index 0000000..4201fdd --- /dev/null +++ b/server/internal/handlers/integration_test.go @@ -0,0 +1,188 @@ +//go:build integration + +package handlers_test + +import ( + "bytes" + "encoding/json" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "golang.org/x/crypto/bcrypt" + + "novamd/internal/api" + "novamd/internal/auth" + "novamd/internal/db" + "novamd/internal/handlers" + "novamd/internal/models" + "novamd/internal/secrets" + "novamd/internal/storage" +) + +// testHarness encapsulates all the dependencies needed for testing +type testHarness struct { + DB db.TestDatabase + Storage storage.Manager + Router *chi.Mux + Handler *handlers.Handler + JWTManager auth.JWTManager + SessionSvc *auth.SessionService + AdminUser *models.User + AdminToken string + RegularUser *models.User + RegularToken string + TempDirectory string +} + +// setupTestHarness creates a new test environment +func setupTestHarness(t *testing.T) *testHarness { + t.Helper() + + // Create temporary directory for test files + tempDir, err := os.MkdirTemp("", "novamd-test-*") + if err != nil { + t.Fatalf("Failed to create temp directory: %v", err) + } + + // Initialize test database + secretsSvc, err := secrets.NewService("YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY=") // test key + if err != nil { + t.Fatalf("Failed to initialize secrets service: %v", err) + } + + database, err := db.NewTestDB(":memory:", secretsSvc) + if err != nil { + t.Fatalf("Failed to initialize test database: %v", err) + } + + if err := database.Migrate(); err != nil { + t.Fatalf("Failed to run migrations: %v", err) + } + + // Initialize storage + storageSvc := storage.NewService(tempDir) + + // Initialize JWT service + jwtSvc, err := auth.NewJWTService(auth.JWTConfig{ + SigningKey: "test-key", + AccessTokenExpiry: 15 * time.Minute, + RefreshTokenExpiry: 24 * time.Hour, + }) + if err != nil { + t.Fatalf("Failed to initialize JWT service: %v", err) + } + + // Initialize session service + sessionSvc := auth.NewSessionService(database, jwtSvc) + + // Create handler + handler := &handlers.Handler{ + DB: database, + Storage: storageSvc, + } + + // Set up router with middlewares + router := chi.NewRouter() + authMiddleware := auth.NewMiddleware(jwtSvc) + router.Route("/api/v1", func(r chi.Router) { + api.SetupRoutes(r, database, storageSvc, authMiddleware, sessionSvc) + }) + + // Create test users + adminUser, adminToken := createTestUser(t, database, sessionSvc, "admin@test.com", "admin123", models.RoleAdmin) + regularUser, regularToken := createTestUser(t, database, sessionSvc, "user@test.com", "user123", models.RoleEditor) + + return &testHarness{ + DB: database, + Storage: storageSvc, + Router: router, + Handler: handler, + JWTManager: jwtSvc, + SessionSvc: sessionSvc, + AdminUser: adminUser, + AdminToken: adminToken, + RegularUser: regularUser, + RegularToken: regularToken, + TempDirectory: tempDir, + } +} + +// teardownTestHarness cleans up the test environment +func (h *testHarness) teardown(t *testing.T) { + t.Helper() + + if err := h.DB.Close(); err != nil { + t.Errorf("Failed to close database: %v", err) + } + + if err := os.RemoveAll(h.TempDirectory); err != nil { + t.Errorf("Failed to remove temp directory: %v", err) + } +} + +// createTestUser creates a test user and returns the user and access token +func createTestUser(t *testing.T, db db.Database, sessionSvc *auth.SessionService, email, password string, role models.UserRole) (*models.User, string) { + t.Helper() + + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + t.Fatalf("Failed to hash password: %v", err) + } + + user := &models.User{ + Email: email, + DisplayName: "Test User", + PasswordHash: string(hashedPassword), + Role: role, + } + + user, err = db.CreateUser(user) + if err != nil { + t.Fatalf("Failed to create user: %v", err) + } + + session, accessToken, err := sessionSvc.CreateSession(user.ID, string(user.Role)) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + if session == nil || accessToken == "" { + t.Fatal("Failed to get valid session or token") + } + + return user, accessToken +} + +// makeRequest is a helper function to make HTTP requests in tests +func (h *testHarness) makeRequest(t *testing.T, method, path string, body interface{}, token string, headers map[string]string) *httptest.ResponseRecorder { + t.Helper() + + var reqBody []byte + var err error + + if body != nil { + reqBody, err = json.Marshal(body) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + } + + req := httptest.NewRequest(method, path, bytes.NewBuffer(reqBody)) + if token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + req.Header.Set("Content-Type", "application/json") + + // Add any additional headers + for k, v := range headers { + req.Header.Set(k, v) + } + + rr := httptest.NewRecorder() + h.Router.ServeHTTP(rr, req) + + return rr +} diff --git a/server/internal/handlers/static_handler.go b/server/internal/handlers/static_handler.go index 8dfb710..8360b3c 100644 --- a/server/internal/handlers/static_handler.go +++ b/server/internal/handlers/static_handler.go @@ -12,12 +12,14 @@ type StaticHandler struct { staticPath string } +// NewStaticHandler creates a new StaticHandler with the given static path func NewStaticHandler(staticPath string) *StaticHandler { return &StaticHandler{ staticPath: staticPath, } } +// ServeHTTP serves the static files func (h *StaticHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Get the requested path requestedPath := r.URL.Path diff --git a/server/internal/handlers/user_handlers.go b/server/internal/handlers/user_handlers.go index 678c8e5..0249327 100644 --- a/server/internal/handlers/user_handlers.go +++ b/server/internal/handlers/user_handlers.go @@ -1,6 +1,7 @@ package handlers import ( + "database/sql" "encoding/json" "net/http" @@ -9,6 +10,7 @@ import ( "golang.org/x/crypto/bcrypt" ) +// UpdateProfileRequest represents a user profile update request type UpdateProfileRequest struct { DisplayName string `json:"displayName"` Email string `json:"email"` @@ -16,10 +18,12 @@ type UpdateProfileRequest struct { NewPassword string `json:"newPassword"` } +// DeleteAccountRequest represents a user account deletion request type DeleteAccountRequest struct { Password string `json:"password"` } +// GetUser returns the current user's profile func (h *Handler) GetUser() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -64,7 +68,11 @@ func (h *Handler) UpdateProfile() http.HandlerFunc { http.Error(w, "Failed to start transaction", http.StatusInternalServerError) return } - defer tx.Rollback() + defer func() { + if err := tx.Rollback(); err != nil && err != sql.ErrTxDone { + http.Error(w, "Failed to rollback transaction", http.StatusInternalServerError) + } + }() // Handle password update if requested if req.NewPassword != "" { @@ -188,7 +196,11 @@ func (h *Handler) DeleteAccount() http.HandlerFunc { http.Error(w, "Failed to start transaction", http.StatusInternalServerError) return } - defer tx.Rollback() + defer func() { + if err := tx.Rollback(); err != nil && err != sql.ErrTxDone { + http.Error(w, "Failed to rollback transaction", http.StatusInternalServerError) + } + }() // Get user's workspaces for cleanup workspaces, err := h.DB.GetWorkspacesByUserID(ctx.UserID) diff --git a/server/internal/handlers/workspace_handlers.go b/server/internal/handlers/workspace_handlers.go index 8dee442..10ae3cf 100644 --- a/server/internal/handlers/workspace_handlers.go +++ b/server/internal/handlers/workspace_handlers.go @@ -1,14 +1,15 @@ package handlers import ( + "database/sql" "encoding/json" - "fmt" "net/http" "novamd/internal/context" "novamd/internal/models" ) +// ListWorkspaces returns a list of all workspaces for the current user func (h *Handler) ListWorkspaces() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -26,6 +27,7 @@ func (h *Handler) ListWorkspaces() http.HandlerFunc { } } +// CreateWorkspace creates a new workspace func (h *Handler) CreateWorkspace() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -54,6 +56,7 @@ func (h *Handler) CreateWorkspace() http.HandlerFunc { } } +// GetWorkspace returns the current workspace func (h *Handler) GetWorkspace() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -81,6 +84,7 @@ func gitSettingsChanged(new, old *models.Workspace) bool { return false } +// UpdateWorkspace updates the current workspace func (h *Handler) UpdateWorkspace() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -132,6 +136,7 @@ func (h *Handler) UpdateWorkspace() http.HandlerFunc { } } +// DeleteWorkspace deletes the current workspace func (h *Handler) DeleteWorkspace() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -168,7 +173,11 @@ func (h *Handler) DeleteWorkspace() http.HandlerFunc { http.Error(w, "Failed to start transaction", http.StatusInternalServerError) return } - defer tx.Rollback() + defer func() { + if err := tx.Rollback(); err != nil && err != sql.ErrTxDone { + http.Error(w, "Failed to rollback transaction", http.StatusInternalServerError) + } + }() // Update last workspace ID first err = h.DB.UpdateLastWorkspaceTx(tx, ctx.UserID, nextWorkspaceID) @@ -195,6 +204,7 @@ func (h *Handler) DeleteWorkspace() http.HandlerFunc { } } +// GetLastWorkspaceName returns the name of the last opened workspace func (h *Handler) GetLastWorkspaceName() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -212,6 +222,7 @@ func (h *Handler) GetLastWorkspaceName() http.HandlerFunc { } } +// UpdateLastWorkspaceName updates the name of the last opened workspace func (h *Handler) UpdateLastWorkspaceName() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -224,13 +235,11 @@ func (h *Handler) UpdateLastWorkspaceName() http.HandlerFunc { } if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { - fmt.Println(err) http.Error(w, "Invalid request body", http.StatusBadRequest) return } if err := h.DB.UpdateLastWorkspace(ctx.UserID, requestBody.WorkspaceName); err != nil { - fmt.Println(err) http.Error(w, "Failed to update last workspace", http.StatusInternalServerError) return }