From 93963b1867c561f31158eaccdf0e3034e7cbe434 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Wed, 13 Nov 2024 22:31:04 +0100 Subject: [PATCH 01/38] Refactor filesystem to make it testable --- server/internal/api/routes.go | 2 +- server/internal/filesystem/files.go | 106 ++++++++++++++++------- server/internal/filesystem/filesystem.go | 81 ++++++++++++++--- server/internal/filesystem/git.go | 52 +++++++---- server/internal/filesystem/workspace.go | 38 ++++---- server/internal/handlers/handlers.go | 4 +- server/internal/user/user.go | 4 +- 7 files changed, 210 insertions(+), 77 deletions(-) diff --git a/server/internal/api/routes.go b/server/internal/api/routes.go index 9b37986..af4df5b 100644 --- a/server/internal/api/routes.go +++ b/server/internal/api/routes.go @@ -12,7 +12,7 @@ import ( ) // SetupRoutes configures the API routes -func SetupRoutes(r chi.Router, db *db.DB, fs *filesystem.FileSystem, authMiddleware *auth.Middleware, sessionService *auth.SessionService) { +func SetupRoutes(r chi.Router, db *db.DB, fs *filesystem.Storage, authMiddleware *auth.Middleware, sessionService *auth.SessionService) { handler := &handlers.Handler{ DB: db, diff --git a/server/internal/filesystem/files.go b/server/internal/filesystem/files.go index 9b876d8..eb5ab16 100644 --- a/server/internal/filesystem/files.go +++ b/server/internal/filesystem/files.go @@ -10,22 +10,29 @@ import ( "strings" ) -// FileNode represents a file or directory in the file system. -type FileNode struct { - ID string `json:"id"` - Name string `json:"name"` - Path string `json:"path"` - Children []FileNode `json:"children,omitempty"` +// StorageNode represents a file or directory in the storage. +type StorageNode struct { + ID string `json:"id"` + Name string `json:"name"` + Path string `json:"path"` + Children []StorageNode `json:"children,omitempty"` } // ListFilesRecursively returns a list of all files in the workspace directory and its subdirectories. -func (fs *FileSystem) ListFilesRecursively(userID, workspaceID int) ([]FileNode, error) { - workspacePath := fs.GetWorkspacePath(userID, workspaceID) - return fs.walkDirectory(workspacePath, "") +// Parameters: +// - userID: the ID of the user who owns the workspace +// - workspaceID: the ID of the workspace to list files in +// Returns: +// - nodes: a list of files and directories in the workspace +// - error: any error that occurred during listing +func (s *Storage) ListFilesRecursively(userID, workspaceID int) ([]StorageNode, error) { + workspacePath := s.GetWorkspacePath(userID, workspaceID) + return s.walkDirectory(workspacePath, "") } -func (fs *FileSystem) walkDirectory(dir, prefix string) ([]FileNode, error) { - entries, err := os.ReadDir(dir) +// walkDirectory recursively walks the directory and returns a list of files and directories. +func (s *Storage) walkDirectory(dir, prefix string) ([]StorageNode, error) { + entries, err := s.fs.ReadDir(dir) if err != nil { return nil, err } @@ -49,7 +56,7 @@ func (fs *FileSystem) walkDirectory(dir, prefix string) ([]FileNode, error) { }) // Create combined slice with directories first, then files - nodes := make([]FileNode, 0, len(entries)) + nodes := make([]StorageNode, 0, len(entries)) // Add directories first for _, entry := range dirs { @@ -57,12 +64,12 @@ func (fs *FileSystem) walkDirectory(dir, prefix string) ([]FileNode, error) { path := filepath.Join(prefix, name) fullPath := filepath.Join(dir, name) - children, err := fs.walkDirectory(fullPath, path) + children, err := s.walkDirectory(fullPath, path) if err != nil { return nil, err } - node := FileNode{ + node := StorageNode{ ID: path, Name: name, Path: path, @@ -76,7 +83,7 @@ func (fs *FileSystem) walkDirectory(dir, prefix string) ([]FileNode, error) { name := entry.Name() path := filepath.Join(prefix, name) - node := FileNode{ + node := StorageNode{ ID: path, Name: name, Path: path, @@ -88,9 +95,16 @@ func (fs *FileSystem) walkDirectory(dir, prefix string) ([]FileNode, error) { } // FindFileByName returns a list of file paths that match the given filename. -func (fs *FileSystem) FindFileByName(userID, workspaceID int, filename string) ([]string, error) { +// Parameters: +// - userID: the ID of the user who owns the workspace +// - workspaceID: the ID of the workspace to search for the file +// - filename: the name of the file to search for +// Returns: +// - foundPaths: a list of file paths that match the filename +// - error: any error that occurred during the search +func (s *Storage) FindFileByName(userID, workspaceID int, filename string) ([]string, error) { var foundPaths []string - workspacePath := fs.GetWorkspacePath(userID, workspaceID) + workspacePath := s.GetWorkspacePath(userID, workspaceID) err := filepath.Walk(workspacePath, func(path string, info os.FileInfo, err error) error { if err != nil { @@ -120,36 +134,56 @@ func (fs *FileSystem) FindFileByName(userID, workspaceID int, filename string) ( } // GetFileContent returns the content of the file at the given path. -func (fs *FileSystem) GetFileContent(userID, workspaceID int, filePath string) ([]byte, error) { - fullPath, err := fs.ValidatePath(userID, workspaceID, filePath) +// Parameters: +// - userID: the ID of the user who owns the workspace +// - workspaceID: the ID of the workspace to get the file from +// - filePath: the path of the file to get +// Returns: +// - content: the content of the file +// - error: any error that occurred during reading +func (s *Storage) GetFileContent(userID, workspaceID int, filePath string) ([]byte, error) { + fullPath, err := s.ValidatePath(userID, workspaceID, filePath) if err != nil { return nil, err } - return os.ReadFile(fullPath) + return s.fs.ReadFile(fullPath) } // SaveFile writes the content to the file at the given path. -func (fs *FileSystem) SaveFile(userID, workspaceID int, filePath string, content []byte) error { - fullPath, err := fs.ValidatePath(userID, workspaceID, filePath) +// Parameters: +// - userID: the ID of the user who owns the workspace +// - workspaceID: the ID of the workspace to save the file to +// - filePath: the path of the file to save +// - content: the content to write to the file +// Returns: +// - error: any error that occurred during saving +func (s *Storage) SaveFile(userID, workspaceID int, filePath string, content []byte) error { + fullPath, err := s.ValidatePath(userID, workspaceID, filePath) if err != nil { return err } dir := filepath.Dir(fullPath) - if err := os.MkdirAll(dir, 0755); err != nil { + if err := s.fs.MkdirAll(dir, 0755); err != nil { return err } - return os.WriteFile(fullPath, content, 0644) + return s.fs.WriteFile(fullPath, content, 0644) } // DeleteFile deletes the file at the given path. -func (fs *FileSystem) DeleteFile(userID, workspaceID int, filePath string) error { - fullPath, err := fs.ValidatePath(userID, workspaceID, filePath) +// Parameters: +// - userID: the ID of the user who owns the workspace +// - workspaceID: the ID of the workspace to delete the file from +// - filePath: the path of the file to delete +// Returns: +// - error: any error that occurred during deletion +func (s *Storage) DeleteFile(userID, workspaceID int, filePath string) error { + fullPath, err := s.ValidatePath(userID, workspaceID, filePath) if err != nil { return err } - return os.Remove(fullPath) + return s.fs.Remove(fullPath) } // FileCountStats holds statistics about files in a workspace @@ -165,19 +199,27 @@ type FileCountStats struct { // Returns: // - result: statistics about the files in the workspace // - error: any error that occurred during counting -func (fs *FileSystem) GetFileStats(userID, workspaceID int) (*FileCountStats, error) { - workspacePath := fs.GetWorkspacePath(userID, workspaceID) +func (s *Storage) GetFileStats(userID, workspaceID int) (*FileCountStats, error) { + workspacePath := s.GetWorkspacePath(userID, workspaceID) // Check if workspace exists - if _, err := os.Stat(workspacePath); os.IsNotExist(err) { + if _, err := s.fs.Stat(workspacePath); s.fs.IsNotExist(err) { return nil, fmt.Errorf("workspace directory does not exist") } - return fs.countFilesInPath(workspacePath) + return s.countFilesInPath(workspacePath) } -func (fs *FileSystem) countFilesInPath(directoryPath string) (*FileCountStats, error) { +// GetTotalFileStats returns the total file statistics for the storage. +// Returns: +// - result: statistics about the files in the storage +func (s *Storage) GetTotalFileStats() (*FileCountStats, error) { + return s.countFilesInPath(s.RootDir) +} + +// countFilesInPath counts the total number of files and the total size of files in the given directory. +func (s *Storage) countFilesInPath(directoryPath string) (*FileCountStats, error) { result := &FileCountStats{} err := filepath.WalkDir(directoryPath, func(path string, d os.DirEntry, err error) error { diff --git a/server/internal/filesystem/filesystem.go b/server/internal/filesystem/filesystem.go index 35f5637..a08a678 100644 --- a/server/internal/filesystem/filesystem.go +++ b/server/internal/filesystem/filesystem.go @@ -2,28 +2,65 @@ package filesystem import ( "fmt" + "io/fs" "novamd/internal/gitutils" + "os" "path/filepath" "strings" ) -// FileSystem represents the file system structure. -type FileSystem struct { +// fileSystem defines the interface for filesystem operations +type fileSystem interface { + ReadFile(path string) ([]byte, error) + WriteFile(path string, data []byte, perm fs.FileMode) error + Remove(path string) error + MkdirAll(path string, perm fs.FileMode) error + RemoveAll(path string) error + ReadDir(path string) ([]fs.DirEntry, error) + Stat(path string) (fs.FileInfo, error) + IsNotExist(err error) bool +} + +// Storage represents the file system structure. +type Storage struct { + fs fileSystem RootDir string GitRepos map[int]map[int]*gitutils.GitRepo // map[userID]map[workspaceID]*gitutils.GitRepo } -// New creates a new FileSystem instance. -func New(rootDir string) *FileSystem { - return &FileSystem{ +// New creates a new Storage instance. +// Parameters: +// - rootDir: the root directory for the storage +// Returns: +// - result: the new Storage instance +func New(rootDir string) *Storage { + return NewWithFS(rootDir, &osFS{}) +} + +// NewWithFS creates a new Storage instance with the given filesystem. +// Parameters: +// - rootDir: the root directory for the storage +// - fs: the filesystem implementation to use +// Returns: +// - result: the new Storage instance +func NewWithFS(rootDir string, fs fileSystem) *Storage { + return &Storage{ + fs: fs, RootDir: rootDir, GitRepos: make(map[int]map[int]*gitutils.GitRepo), } } // ValidatePath validates the given path and returns the cleaned path if it is valid. -func (fs *FileSystem) ValidatePath(userID, workspaceID int, path string) (string, error) { - workspacePath := fs.GetWorkspacePath(userID, workspaceID) +// Parameters: +// - userID: the ID of the user who owns the workspace +// - workspaceID: the ID of the workspace to validate the path for +// - path: the path to validate +// Returns: +// - result: the cleaned path if it is valid +// - error: any error that occurred during validation +func (s *Storage) ValidatePath(userID, workspaceID int, path string) (string, error) { + workspacePath := s.GetWorkspacePath(userID, workspaceID) fullPath := filepath.Join(workspacePath, path) cleanPath := filepath.Clean(fullPath) @@ -34,7 +71,31 @@ func (fs *FileSystem) ValidatePath(userID, workspaceID int, path string) (string return cleanPath, nil } -// GetTotalFileStats returns the total file statistics for the file system. -func (fs *FileSystem) GetTotalFileStats() (*FileCountStats, error) { - return fs.countFilesInPath(fs.RootDir) +// osFS implements the FileSystem interface using the real filesystem. +type osFS struct{} + +// ReadFile reads the file at the given path. +func (f *osFS) ReadFile(path string) ([]byte, error) { return os.ReadFile(path) } + +// WriteFile writes the given data to the file at the given path. +func (f *osFS) WriteFile(path string, data []byte, perm fs.FileMode) error { + return os.WriteFile(path, data, perm) } + +// Remove deletes the file at the given path. +func (f *osFS) Remove(path string) error { return os.Remove(path) } + +// MkdirAll creates the directory at the given path and any necessary parents. +func (f *osFS) MkdirAll(path string, perm fs.FileMode) error { return os.MkdirAll(path, perm) } + +// RemoveAll removes the file or directory at the given path. +func (f *osFS) RemoveAll(path string) error { return os.RemoveAll(path) } + +// ReadDir reads the directory at the given path. +func (f *osFS) ReadDir(path string) ([]fs.DirEntry, error) { return os.ReadDir(path) } + +// Stat returns the FileInfo for the file at the given path. +func (f *osFS) Stat(path string) (fs.FileInfo, error) { return os.Stat(path) } + +// IsNotExist returns true if the error is a "file does not exist" error. +func (f *osFS) IsNotExist(err error) bool { return os.IsNotExist(err) } diff --git a/server/internal/filesystem/git.go b/server/internal/filesystem/git.go index de83ad9..f70d7ad 100644 --- a/server/internal/filesystem/git.go +++ b/server/internal/filesystem/git.go @@ -6,28 +6,45 @@ import ( ) // SetupGitRepo sets up a Git repository for the given user and workspace IDs. -func (fs *FileSystem) SetupGitRepo(userID, workspaceID int, gitURL, gitUser, gitToken string) error { - workspacePath := fs.GetWorkspacePath(userID, workspaceID) - if _, ok := fs.GitRepos[userID]; !ok { - fs.GitRepos[userID] = make(map[int]*gitutils.GitRepo) +// Parameters: +// - userID: the ID of the user who owns the workspace +// - workspaceID: the ID of the workspace to set up the Git repository for +// - gitURL: the URL of the Git repository +// - gitUser: the username for the Git repository +// - gitToken: the access token for the Git repository +// Returns: +// - error: any error that occurred during setup +func (s *Storage) SetupGitRepo(userID, workspaceID int, gitURL, gitUser, gitToken string) error { + workspacePath := s.GetWorkspacePath(userID, workspaceID) + if _, ok := s.GitRepos[userID]; !ok { + s.GitRepos[userID] = make(map[int]*gitutils.GitRepo) } - fs.GitRepos[userID][workspaceID] = gitutils.New(gitURL, gitUser, gitToken, workspacePath) - return fs.GitRepos[userID][workspaceID].EnsureRepo() + s.GitRepos[userID][workspaceID] = gitutils.New(gitURL, gitUser, gitToken, workspacePath) + return s.GitRepos[userID][workspaceID].EnsureRepo() } // DisableGitRepo disables the Git repository for the given user and workspace IDs. -func (fs *FileSystem) DisableGitRepo(userID, workspaceID int) { - if userRepos, ok := fs.GitRepos[userID]; ok { +// Parameters: +// - userID: the ID of the user who owns the workspace +// - workspaceID: the ID of the workspace to disable the Git repository for +func (s *Storage) DisableGitRepo(userID, workspaceID int) { + if userRepos, ok := s.GitRepos[userID]; ok { delete(userRepos, workspaceID) if len(userRepos) == 0 { - delete(fs.GitRepos, userID) + delete(s.GitRepos, userID) } } } // StageCommitAndPush stages, commits, and pushes the changes to the Git repository. -func (fs *FileSystem) StageCommitAndPush(userID, workspaceID int, message string) error { - repo, ok := fs.getGitRepo(userID, workspaceID) +// Parameters: +// - userID: the ID of the user who owns the workspace +// - workspaceID: the ID of the workspace to commit and push +// - message: the commit message +// Returns: +// - error: any error that occurred during the operation +func (s *Storage) StageCommitAndPush(userID, workspaceID int, message string) error { + repo, ok := s.getGitRepo(userID, workspaceID) if !ok { return fmt.Errorf("git settings not configured for this workspace") } @@ -40,8 +57,13 @@ func (fs *FileSystem) StageCommitAndPush(userID, workspaceID int, message string } // Pull pulls the changes from the remote Git repository. -func (fs *FileSystem) Pull(userID, workspaceID int) error { - repo, ok := fs.getGitRepo(userID, workspaceID) +// Parameters: +// - userID: the ID of the user who owns the workspace +// - workspaceID: the ID of the workspace to pull changes for +// Returns: +// - error: any error that occurred during the operation +func (s *Storage) Pull(userID, workspaceID int) error { + repo, ok := s.getGitRepo(userID, workspaceID) if !ok { return fmt.Errorf("git settings not configured for this workspace") } @@ -50,8 +72,8 @@ func (fs *FileSystem) Pull(userID, workspaceID int) error { } // getGitRepo returns the Git repository for the given user and workspace IDs. -func (fs *FileSystem) getGitRepo(userID, workspaceID int) (*gitutils.GitRepo, bool) { - userRepos, ok := fs.GitRepos[userID] +func (s *Storage) getGitRepo(userID, workspaceID int) (*gitutils.GitRepo, bool) { + userRepos, ok := s.GitRepos[userID] if !ok { return nil, false } diff --git a/server/internal/filesystem/workspace.go b/server/internal/filesystem/workspace.go index 122c816..891eaae 100644 --- a/server/internal/filesystem/workspace.go +++ b/server/internal/filesystem/workspace.go @@ -2,19 +2,28 @@ package filesystem import ( "fmt" - "os" "path/filepath" ) // GetWorkspacePath returns the path to the workspace directory for the given user and workspace IDs. -func (fs *FileSystem) GetWorkspacePath(userID, workspaceID int) string { - return filepath.Join(fs.RootDir, fmt.Sprintf("%d", userID), fmt.Sprintf("%d", workspaceID)) +// Parameters: +// - userID: the ID of the user who owns the workspace +// - workspaceID: the ID of the workspace +// Returns: +// - result: the path to the workspace directory +func (s *Storage) GetWorkspacePath(userID, workspaceID int) string { + return filepath.Join(s.RootDir, fmt.Sprintf("%d", userID), fmt.Sprintf("%d", workspaceID)) } // InitializeUserWorkspace creates the workspace directory for the given user and workspace IDs. -func (fs *FileSystem) InitializeUserWorkspace(userID, workspaceID int) error { - workspacePath := fs.GetWorkspacePath(userID, workspaceID) - err := os.MkdirAll(workspacePath, 0755) +// Parameters: +// - userID: the ID of the user who owns the workspace +// - workspaceID: the ID of the workspace to initialize +// Returns: +// - error: any error that occurred during the operation +func (s *Storage) InitializeUserWorkspace(userID, workspaceID int) error { + workspacePath := s.GetWorkspacePath(userID, workspaceID) + err := s.fs.MkdirAll(workspacePath, 0755) if err != nil { return fmt.Errorf("failed to create workspace directory: %w", err) } @@ -23,18 +32,17 @@ func (fs *FileSystem) InitializeUserWorkspace(userID, workspaceID int) error { } // DeleteUserWorkspace deletes the workspace directory for the given user and workspace IDs. -func (fs *FileSystem) DeleteUserWorkspace(userID, workspaceID int) error { - workspacePath := fs.GetWorkspacePath(userID, workspaceID) - err := os.RemoveAll(workspacePath) +// Parameters: +// - userID: the ID of the user who owns the workspace +// - workspaceID: the ID of the workspace to delete +// Returns: +// - error: any error that occurred during the operation +func (s *Storage) DeleteUserWorkspace(userID, workspaceID int) error { + workspacePath := s.GetWorkspacePath(userID, workspaceID) + err := s.fs.RemoveAll(workspacePath) if err != nil { return fmt.Errorf("failed to delete workspace directory: %w", err) } return nil } - -// CreateWorkspaceDirectory creates the workspace directory for the given user and workspace IDs. -func (fs *FileSystem) CreateWorkspaceDirectory(userID, workspaceID int) error { - dir := fs.GetWorkspacePath(userID, workspaceID) - return os.MkdirAll(dir, 0755) -} diff --git a/server/internal/handlers/handlers.go b/server/internal/handlers/handlers.go index a9d4e75..c96ccac 100644 --- a/server/internal/handlers/handlers.go +++ b/server/internal/handlers/handlers.go @@ -10,11 +10,11 @@ import ( // Handler provides common functionality for all handlers type Handler struct { DB *db.DB - FS *filesystem.FileSystem + FS *filesystem.Storage } // NewHandler creates a new handler with the given dependencies -func NewHandler(db *db.DB, fs *filesystem.FileSystem) *Handler { +func NewHandler(db *db.DB, fs *filesystem.Storage) *Handler { return &Handler{ DB: db, FS: fs, diff --git a/server/internal/user/user.go b/server/internal/user/user.go index fb83cc1..530833e 100644 --- a/server/internal/user/user.go +++ b/server/internal/user/user.go @@ -14,10 +14,10 @@ import ( type UserService struct { DB *db.DB - FS *filesystem.FileSystem + FS *filesystem.Storage } -func NewUserService(database *db.DB, fs *filesystem.FileSystem) *UserService { +func NewUserService(database *db.DB, fs *filesystem.Storage) *UserService { return &UserService{ DB: database, FS: fs, From 6a9461d9280a6d7ff9c62a5cf275d133ab71b748 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Wed, 13 Nov 2024 22:32:43 +0100 Subject: [PATCH 02/38] Rename fs variable --- server/internal/api/routes.go | 4 ++-- server/internal/handlers/admin_handlers.go | 6 +++--- server/internal/handlers/file_handlers.go | 14 +++++++------- server/internal/handlers/git_handlers.go | 4 ++-- server/internal/handlers/handlers.go | 4 ++-- server/internal/handlers/user_handlers.go | 2 +- server/internal/handlers/workspace_handlers.go | 6 +++--- 7 files changed, 20 insertions(+), 20 deletions(-) diff --git a/server/internal/api/routes.go b/server/internal/api/routes.go index af4df5b..ce8d4c3 100644 --- a/server/internal/api/routes.go +++ b/server/internal/api/routes.go @@ -12,11 +12,11 @@ import ( ) // SetupRoutes configures the API routes -func SetupRoutes(r chi.Router, db *db.DB, fs *filesystem.Storage, authMiddleware *auth.Middleware, sessionService *auth.SessionService) { +func SetupRoutes(r chi.Router, db *db.DB, s *filesystem.Storage, authMiddleware *auth.Middleware, sessionService *auth.SessionService) { handler := &handlers.Handler{ DB: db, - FS: fs, + S: s, } // Public routes (no authentication required) diff --git a/server/internal/handlers/admin_handlers.go b/server/internal/handlers/admin_handlers.go index 3a212bc..f861e5a 100644 --- a/server/internal/handlers/admin_handlers.go +++ b/server/internal/handlers/admin_handlers.go @@ -91,7 +91,7 @@ func (h *Handler) AdminCreateUser() http.HandlerFunc { } // Initialize user workspace - if err := h.FS.InitializeUserWorkspace(insertedUser.ID, insertedUser.LastWorkspaceID); err != nil { + if err := h.S.InitializeUserWorkspace(insertedUser.ID, insertedUser.LastWorkspaceID); err != nil { http.Error(w, "Failed to initialize user workspace", http.StatusInternalServerError) return } @@ -248,7 +248,7 @@ func (h *Handler) AdminListWorkspaces() http.HandlerFunc { workspaceData.WorkspaceName = ws.Name workspaceData.WorkspaceCreatedAt = ws.CreatedAt - fileStats, err := h.FS.GetFileStats(ws.UserID, ws.ID) + fileStats, err := h.S.GetFileStats(ws.UserID, ws.ID) if err != nil { http.Error(w, "Failed to get file stats", http.StatusInternalServerError) return @@ -278,7 +278,7 @@ func (h *Handler) AdminGetSystemStats() http.HandlerFunc { return } - fileStats, err := h.FS.GetTotalFileStats() + fileStats, err := h.S.GetTotalFileStats() if err != nil { http.Error(w, "Failed to get file stats", http.StatusInternalServerError) return diff --git a/server/internal/handlers/file_handlers.go b/server/internal/handlers/file_handlers.go index 4af815a..3e23772 100644 --- a/server/internal/handlers/file_handlers.go +++ b/server/internal/handlers/file_handlers.go @@ -17,7 +17,7 @@ func (h *Handler) ListFiles() http.HandlerFunc { return } - files, err := h.FS.ListFilesRecursively(ctx.UserID, ctx.Workspace.ID) + files, err := h.S.ListFilesRecursively(ctx.UserID, ctx.Workspace.ID) if err != nil { http.Error(w, "Failed to list files", http.StatusInternalServerError) return @@ -40,7 +40,7 @@ func (h *Handler) LookupFileByName() http.HandlerFunc { return } - filePaths, err := h.FS.FindFileByName(ctx.UserID, ctx.Workspace.ID, filename) + filePaths, err := h.S.FindFileByName(ctx.UserID, ctx.Workspace.ID, filename) if err != nil { http.Error(w, "File not found", http.StatusNotFound) return @@ -58,7 +58,7 @@ func (h *Handler) GetFileContent() http.HandlerFunc { } filePath := chi.URLParam(r, "*") - content, err := h.FS.GetFileContent(ctx.UserID, ctx.Workspace.ID, filePath) + content, err := h.S.GetFileContent(ctx.UserID, ctx.Workspace.ID, filePath) if err != nil { http.Error(w, "Failed to read file", http.StatusNotFound) return @@ -83,7 +83,7 @@ func (h *Handler) SaveFile() http.HandlerFunc { return } - err = h.FS.SaveFile(ctx.UserID, ctx.Workspace.ID, filePath, content) + err = h.S.SaveFile(ctx.UserID, ctx.Workspace.ID, filePath, content) if err != nil { http.Error(w, "Failed to save file", http.StatusInternalServerError) return @@ -101,7 +101,7 @@ func (h *Handler) DeleteFile() http.HandlerFunc { } filePath := chi.URLParam(r, "*") - err := h.FS.DeleteFile(ctx.UserID, ctx.Workspace.ID, filePath) + err := h.S.DeleteFile(ctx.UserID, ctx.Workspace.ID, filePath) if err != nil { http.Error(w, "Failed to delete file", http.StatusInternalServerError) return @@ -125,7 +125,7 @@ func (h *Handler) GetLastOpenedFile() http.HandlerFunc { return } - if _, err := h.FS.ValidatePath(ctx.UserID, ctx.Workspace.ID, filePath); err != nil { + if _, err := h.S.ValidatePath(ctx.UserID, ctx.Workspace.ID, filePath); err != nil { http.Error(w, "Invalid file path", http.StatusBadRequest) return } @@ -152,7 +152,7 @@ func (h *Handler) UpdateLastOpenedFile() http.HandlerFunc { // Validate the file path exists in the workspace if requestBody.FilePath != "" { - if _, err := h.FS.ValidatePath(ctx.UserID, ctx.Workspace.ID, requestBody.FilePath); err != nil { + if _, err := h.S.ValidatePath(ctx.UserID, ctx.Workspace.ID, requestBody.FilePath); err != nil { http.Error(w, "Invalid file path", http.StatusBadRequest) return } diff --git a/server/internal/handlers/git_handlers.go b/server/internal/handlers/git_handlers.go index 61f7ba4..2836d35 100644 --- a/server/internal/handlers/git_handlers.go +++ b/server/internal/handlers/git_handlers.go @@ -28,7 +28,7 @@ func (h *Handler) StageCommitAndPush() http.HandlerFunc { return } - err := h.FS.StageCommitAndPush(ctx.UserID, ctx.Workspace.ID, requestBody.Message) + err := h.S.StageCommitAndPush(ctx.UserID, ctx.Workspace.ID, requestBody.Message) if err != nil { http.Error(w, "Failed to stage, commit, and push changes: "+err.Error(), http.StatusInternalServerError) return @@ -45,7 +45,7 @@ func (h *Handler) PullChanges() http.HandlerFunc { return } - err := h.FS.Pull(ctx.UserID, ctx.Workspace.ID) + err := h.S.Pull(ctx.UserID, ctx.Workspace.ID) if err != nil { http.Error(w, "Failed to pull changes: "+err.Error(), http.StatusInternalServerError) return diff --git a/server/internal/handlers/handlers.go b/server/internal/handlers/handlers.go index c96ccac..d45f721 100644 --- a/server/internal/handlers/handlers.go +++ b/server/internal/handlers/handlers.go @@ -10,14 +10,14 @@ import ( // Handler provides common functionality for all handlers type Handler struct { DB *db.DB - FS *filesystem.Storage + S *filesystem.Storage } // NewHandler creates a new handler with the given dependencies func NewHandler(db *db.DB, fs *filesystem.Storage) *Handler { return &Handler{ DB: db, - FS: fs, + S: fs, } } diff --git a/server/internal/handlers/user_handlers.go b/server/internal/handlers/user_handlers.go index 0a3148f..a949f11 100644 --- a/server/internal/handlers/user_handlers.go +++ b/server/internal/handlers/user_handlers.go @@ -200,7 +200,7 @@ func (h *Handler) DeleteAccount() http.HandlerFunc { // Delete workspace directories for _, workspace := range workspaces { - if err := h.FS.DeleteUserWorkspace(ctx.UserID, workspace.ID); err != nil { + if err := h.S.DeleteUserWorkspace(ctx.UserID, workspace.ID); err != nil { http.Error(w, "Failed to delete workspace files", http.StatusInternalServerError) return } diff --git a/server/internal/handlers/workspace_handlers.go b/server/internal/handlers/workspace_handlers.go index bbb3c59..7c0fea4 100644 --- a/server/internal/handlers/workspace_handlers.go +++ b/server/internal/handlers/workspace_handlers.go @@ -45,7 +45,7 @@ func (h *Handler) CreateWorkspace() http.HandlerFunc { return } - if err := h.FS.InitializeUserWorkspace(workspace.UserID, workspace.ID); err != nil { + if err := h.S.InitializeUserWorkspace(workspace.UserID, workspace.ID); err != nil { http.Error(w, "Failed to initialize workspace directory", http.StatusInternalServerError) return } @@ -107,7 +107,7 @@ func (h *Handler) UpdateWorkspace() http.HandlerFunc { // Handle Git repository setup/teardown if Git settings changed if gitSettingsChanged(&workspace, ctx.Workspace) { if workspace.GitEnabled { - if err := h.FS.SetupGitRepo( + if err := h.S.SetupGitRepo( ctx.UserID, ctx.Workspace.ID, workspace.GitURL, @@ -119,7 +119,7 @@ func (h *Handler) UpdateWorkspace() http.HandlerFunc { } } else { - h.FS.DisableGitRepo(ctx.UserID, ctx.Workspace.ID) + h.S.DisableGitRepo(ctx.UserID, ctx.Workspace.ID) } } From 5311d2e14424eb8e5c23d3f06a533f4f14e34130 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Wed, 13 Nov 2024 22:34:11 +0100 Subject: [PATCH 03/38] Move storage to separate file --- server/internal/filesystem/filesystem.go | 54 ---------------------- server/internal/filesystem/storage.go | 58 ++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 54 deletions(-) create mode 100644 server/internal/filesystem/storage.go diff --git a/server/internal/filesystem/filesystem.go b/server/internal/filesystem/filesystem.go index a08a678..5cafc06 100644 --- a/server/internal/filesystem/filesystem.go +++ b/server/internal/filesystem/filesystem.go @@ -1,12 +1,8 @@ package filesystem import ( - "fmt" "io/fs" - "novamd/internal/gitutils" "os" - "path/filepath" - "strings" ) // fileSystem defines the interface for filesystem operations @@ -21,56 +17,6 @@ type fileSystem interface { IsNotExist(err error) bool } -// Storage represents the file system structure. -type Storage struct { - fs fileSystem - RootDir string - GitRepos map[int]map[int]*gitutils.GitRepo // map[userID]map[workspaceID]*gitutils.GitRepo -} - -// New creates a new Storage instance. -// Parameters: -// - rootDir: the root directory for the storage -// Returns: -// - result: the new Storage instance -func New(rootDir string) *Storage { - return NewWithFS(rootDir, &osFS{}) -} - -// NewWithFS creates a new Storage instance with the given filesystem. -// Parameters: -// - rootDir: the root directory for the storage -// - fs: the filesystem implementation to use -// Returns: -// - result: the new Storage instance -func NewWithFS(rootDir string, fs fileSystem) *Storage { - return &Storage{ - fs: fs, - RootDir: rootDir, - GitRepos: make(map[int]map[int]*gitutils.GitRepo), - } -} - -// ValidatePath validates the given path and returns the cleaned path if it is valid. -// Parameters: -// - userID: the ID of the user who owns the workspace -// - workspaceID: the ID of the workspace to validate the path for -// - path: the path to validate -// Returns: -// - result: the cleaned path if it is valid -// - error: any error that occurred during validation -func (s *Storage) ValidatePath(userID, workspaceID int, path string) (string, error) { - workspacePath := s.GetWorkspacePath(userID, workspaceID) - fullPath := filepath.Join(workspacePath, path) - cleanPath := filepath.Clean(fullPath) - - if !strings.HasPrefix(cleanPath, workspacePath) { - return "", fmt.Errorf("invalid path: outside of workspace") - } - - return cleanPath, nil -} - // osFS implements the FileSystem interface using the real filesystem. type osFS struct{} diff --git a/server/internal/filesystem/storage.go b/server/internal/filesystem/storage.go new file mode 100644 index 0000000..6c8d098 --- /dev/null +++ b/server/internal/filesystem/storage.go @@ -0,0 +1,58 @@ +package filesystem + +import ( + "fmt" + "novamd/internal/gitutils" + "path/filepath" + "strings" +) + +// Storage represents the file system structure. +type Storage struct { + fs fileSystem + RootDir string + GitRepos map[int]map[int]*gitutils.GitRepo // map[userID]map[workspaceID]*gitutils.GitRepo +} + +// New creates a new Storage instance. +// Parameters: +// - rootDir: the root directory for the storage +// Returns: +// - result: the new Storage instance +func New(rootDir string) *Storage { + return NewWithFS(rootDir, &osFS{}) +} + +// NewWithFS creates a new Storage instance with the given filesystem. +// Parameters: +// - rootDir: the root directory for the storage +// - fs: the filesystem implementation to use +// Returns: +// - result: the new Storage instance +func NewWithFS(rootDir string, fs fileSystem) *Storage { + return &Storage{ + fs: fs, + RootDir: rootDir, + GitRepos: make(map[int]map[int]*gitutils.GitRepo), + } +} + +// ValidatePath validates the given path and returns the cleaned path if it is valid. +// Parameters: +// - userID: the ID of the user who owns the workspace +// - workspaceID: the ID of the workspace to validate the path for +// - path: the path to validate +// Returns: +// - result: the cleaned path if it is valid +// - error: any error that occurred during validation +func (s *Storage) ValidatePath(userID, workspaceID int, path string) (string, error) { + workspacePath := s.GetWorkspacePath(userID, workspaceID) + fullPath := filepath.Join(workspacePath, path) + cleanPath := filepath.Clean(fullPath) + + if !strings.HasPrefix(cleanPath, workspacePath) { + return "", fmt.Errorf("invalid path: outside of workspace") + } + + return cleanPath, nil +} From e4510298ed697f8d1cc44f7c05cc3bbe09ccbb80 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Thu, 14 Nov 2024 21:13:45 +0100 Subject: [PATCH 04/38] Rename filesystem interfaces and structs --- server/cmd/server/main.go | 6 +- server/internal/api/routes.go | 8 +-- server/internal/filesystem/storage.go | 58 ------------------- server/internal/handlers/admin_handlers.go | 12 ++-- server/internal/handlers/file_handlers.go | 14 ++--- server/internal/handlers/git_handlers.go | 4 +- server/internal/handlers/handlers.go | 12 ++-- server/internal/handlers/user_handlers.go | 2 +- .../internal/handlers/workspace_handlers.go | 6 +- .../internal/{filesystem => storage}/files.go | 51 +++++++++------- .../{filesystem => storage}/filesystem.go | 2 +- server/internal/storage/filesystem_test.go | 46 +++++++++++++++ .../internal/{filesystem => storage}/git.go | 20 +++++-- server/internal/storage/service.go | 42 ++++++++++++++ .../{filesystem => storage}/workspace.go | 37 ++++++++++-- server/internal/user/user.go | 14 ++--- 16 files changed, 206 insertions(+), 128 deletions(-) delete mode 100644 server/internal/filesystem/storage.go rename server/internal/{filesystem => storage}/files.go (80%) rename server/internal/{filesystem => storage}/filesystem.go (98%) create mode 100644 server/internal/storage/filesystem_test.go rename server/internal/{filesystem => storage}/git.go (78%) create mode 100644 server/internal/storage/service.go rename server/internal/{filesystem => storage}/workspace.go (51%) diff --git a/server/cmd/server/main.go b/server/cmd/server/main.go index 96cdda5..c6204ba 100644 --- a/server/cmd/server/main.go +++ b/server/cmd/server/main.go @@ -17,8 +17,8 @@ import ( "novamd/internal/auth" "novamd/internal/config" "novamd/internal/db" - "novamd/internal/filesystem" "novamd/internal/handlers" + "novamd/internal/storage" ) func main() { @@ -45,7 +45,7 @@ func main() { } // Initialize filesystem - fs := filesystem.New(cfg.WorkDir) + s := storage.NewService(cfg.WorkDir) // Initialize JWT service jwtService, err := auth.NewJWTService(auth.JWTConfig{ @@ -95,7 +95,7 @@ func main() { // Set up routes r.Route("/api/v1", func(r chi.Router) { r.Use(httprate.LimitByIP(cfg.RateLimitRequests, cfg.RateLimitWindow)) - api.SetupRoutes(r, database, fs, authMiddleware, sessionService) + api.SetupRoutes(r, database, s, authMiddleware, sessionService) }) // Handle all other routes with static file server diff --git a/server/internal/api/routes.go b/server/internal/api/routes.go index ce8d4c3..71a873b 100644 --- a/server/internal/api/routes.go +++ b/server/internal/api/routes.go @@ -4,19 +4,19 @@ package api import ( "novamd/internal/auth" "novamd/internal/db" - "novamd/internal/filesystem" "novamd/internal/handlers" "novamd/internal/middleware" + "novamd/internal/storage" "github.com/go-chi/chi/v5" ) // SetupRoutes configures the API routes -func SetupRoutes(r chi.Router, db *db.DB, s *filesystem.Storage, authMiddleware *auth.Middleware, sessionService *auth.SessionService) { +func SetupRoutes(r chi.Router, db *db.DB, s storage.Manager, authMiddleware *auth.Middleware, sessionService *auth.SessionService) { handler := &handlers.Handler{ - DB: db, - S: s, + DB: db, + Storage: s, } // Public routes (no authentication required) diff --git a/server/internal/filesystem/storage.go b/server/internal/filesystem/storage.go deleted file mode 100644 index 6c8d098..0000000 --- a/server/internal/filesystem/storage.go +++ /dev/null @@ -1,58 +0,0 @@ -package filesystem - -import ( - "fmt" - "novamd/internal/gitutils" - "path/filepath" - "strings" -) - -// Storage represents the file system structure. -type Storage struct { - fs fileSystem - RootDir string - GitRepos map[int]map[int]*gitutils.GitRepo // map[userID]map[workspaceID]*gitutils.GitRepo -} - -// New creates a new Storage instance. -// Parameters: -// - rootDir: the root directory for the storage -// Returns: -// - result: the new Storage instance -func New(rootDir string) *Storage { - return NewWithFS(rootDir, &osFS{}) -} - -// NewWithFS creates a new Storage instance with the given filesystem. -// Parameters: -// - rootDir: the root directory for the storage -// - fs: the filesystem implementation to use -// Returns: -// - result: the new Storage instance -func NewWithFS(rootDir string, fs fileSystem) *Storage { - return &Storage{ - fs: fs, - RootDir: rootDir, - GitRepos: make(map[int]map[int]*gitutils.GitRepo), - } -} - -// ValidatePath validates the given path and returns the cleaned path if it is valid. -// Parameters: -// - userID: the ID of the user who owns the workspace -// - workspaceID: the ID of the workspace to validate the path for -// - path: the path to validate -// Returns: -// - result: the cleaned path if it is valid -// - error: any error that occurred during validation -func (s *Storage) ValidatePath(userID, workspaceID int, path string) (string, error) { - workspacePath := s.GetWorkspacePath(userID, workspaceID) - fullPath := filepath.Join(workspacePath, path) - cleanPath := filepath.Clean(fullPath) - - if !strings.HasPrefix(cleanPath, workspacePath) { - return "", fmt.Errorf("invalid path: outside of workspace") - } - - return cleanPath, nil -} diff --git a/server/internal/handlers/admin_handlers.go b/server/internal/handlers/admin_handlers.go index f861e5a..059c8cb 100644 --- a/server/internal/handlers/admin_handlers.go +++ b/server/internal/handlers/admin_handlers.go @@ -4,9 +4,9 @@ import ( "encoding/json" "net/http" "novamd/internal/db" - "novamd/internal/filesystem" "novamd/internal/httpcontext" "novamd/internal/models" + "novamd/internal/storage" "strconv" "time" @@ -91,7 +91,7 @@ func (h *Handler) AdminCreateUser() http.HandlerFunc { } // Initialize user workspace - if err := h.S.InitializeUserWorkspace(insertedUser.ID, insertedUser.LastWorkspaceID); err != nil { + if err := h.Storage.InitializeUserWorkspace(insertedUser.ID, insertedUser.LastWorkspaceID); err != nil { http.Error(w, "Failed to initialize user workspace", http.StatusInternalServerError) return } @@ -218,7 +218,7 @@ type WorkspaceStats struct { WorkspaceID int `json:"workspaceID"` WorkspaceName string `json:"workspaceName"` WorkspaceCreatedAt time.Time `json:"workspaceCreatedAt"` - *filesystem.FileCountStats + *storage.FileCountStats } // AdminListWorkspaces returns a list of all workspaces and their stats @@ -248,7 +248,7 @@ func (h *Handler) AdminListWorkspaces() http.HandlerFunc { workspaceData.WorkspaceName = ws.Name workspaceData.WorkspaceCreatedAt = ws.CreatedAt - fileStats, err := h.S.GetFileStats(ws.UserID, ws.ID) + fileStats, err := h.Storage.GetFileStats(ws.UserID, ws.ID) if err != nil { http.Error(w, "Failed to get file stats", http.StatusInternalServerError) return @@ -266,7 +266,7 @@ func (h *Handler) AdminListWorkspaces() http.HandlerFunc { // SystemStats holds system-wide statistics type SystemStats struct { *db.UserStats - *filesystem.FileCountStats + *storage.FileCountStats } // AdminGetSystemStats returns system-wide statistics for admins @@ -278,7 +278,7 @@ func (h *Handler) AdminGetSystemStats() http.HandlerFunc { return } - fileStats, err := h.S.GetTotalFileStats() + fileStats, err := h.Storage.GetTotalFileStats() if err != nil { http.Error(w, "Failed to get file stats", http.StatusInternalServerError) return diff --git a/server/internal/handlers/file_handlers.go b/server/internal/handlers/file_handlers.go index 3e23772..d970fa1 100644 --- a/server/internal/handlers/file_handlers.go +++ b/server/internal/handlers/file_handlers.go @@ -17,7 +17,7 @@ func (h *Handler) ListFiles() http.HandlerFunc { return } - files, err := h.S.ListFilesRecursively(ctx.UserID, ctx.Workspace.ID) + files, err := h.Storage.ListFilesRecursively(ctx.UserID, ctx.Workspace.ID) if err != nil { http.Error(w, "Failed to list files", http.StatusInternalServerError) return @@ -40,7 +40,7 @@ func (h *Handler) LookupFileByName() http.HandlerFunc { return } - filePaths, err := h.S.FindFileByName(ctx.UserID, ctx.Workspace.ID, filename) + filePaths, err := h.Storage.FindFileByName(ctx.UserID, ctx.Workspace.ID, filename) if err != nil { http.Error(w, "File not found", http.StatusNotFound) return @@ -58,7 +58,7 @@ func (h *Handler) GetFileContent() http.HandlerFunc { } filePath := chi.URLParam(r, "*") - content, err := h.S.GetFileContent(ctx.UserID, ctx.Workspace.ID, filePath) + content, err := h.Storage.GetFileContent(ctx.UserID, ctx.Workspace.ID, filePath) if err != nil { http.Error(w, "Failed to read file", http.StatusNotFound) return @@ -83,7 +83,7 @@ func (h *Handler) SaveFile() http.HandlerFunc { return } - err = h.S.SaveFile(ctx.UserID, ctx.Workspace.ID, filePath, content) + err = h.Storage.SaveFile(ctx.UserID, ctx.Workspace.ID, filePath, content) if err != nil { http.Error(w, "Failed to save file", http.StatusInternalServerError) return @@ -101,7 +101,7 @@ func (h *Handler) DeleteFile() http.HandlerFunc { } filePath := chi.URLParam(r, "*") - err := h.S.DeleteFile(ctx.UserID, ctx.Workspace.ID, filePath) + err := h.Storage.DeleteFile(ctx.UserID, ctx.Workspace.ID, filePath) if err != nil { http.Error(w, "Failed to delete file", http.StatusInternalServerError) return @@ -125,7 +125,7 @@ func (h *Handler) GetLastOpenedFile() http.HandlerFunc { return } - if _, err := h.S.ValidatePath(ctx.UserID, ctx.Workspace.ID, filePath); err != nil { + if _, err := h.Storage.ValidatePath(ctx.UserID, ctx.Workspace.ID, filePath); err != nil { http.Error(w, "Invalid file path", http.StatusBadRequest) return } @@ -152,7 +152,7 @@ func (h *Handler) UpdateLastOpenedFile() http.HandlerFunc { // Validate the file path exists in the workspace if requestBody.FilePath != "" { - if _, err := h.S.ValidatePath(ctx.UserID, ctx.Workspace.ID, requestBody.FilePath); err != nil { + if _, err := h.Storage.ValidatePath(ctx.UserID, ctx.Workspace.ID, requestBody.FilePath); err != nil { http.Error(w, "Invalid file path", http.StatusBadRequest) return } diff --git a/server/internal/handlers/git_handlers.go b/server/internal/handlers/git_handlers.go index 2836d35..f8ee589 100644 --- a/server/internal/handlers/git_handlers.go +++ b/server/internal/handlers/git_handlers.go @@ -28,7 +28,7 @@ func (h *Handler) StageCommitAndPush() http.HandlerFunc { return } - err := h.S.StageCommitAndPush(ctx.UserID, ctx.Workspace.ID, requestBody.Message) + err := h.Storage.StageCommitAndPush(ctx.UserID, ctx.Workspace.ID, requestBody.Message) if err != nil { http.Error(w, "Failed to stage, commit, and push changes: "+err.Error(), http.StatusInternalServerError) return @@ -45,7 +45,7 @@ func (h *Handler) PullChanges() http.HandlerFunc { return } - err := h.S.Pull(ctx.UserID, ctx.Workspace.ID) + err := h.Storage.Pull(ctx.UserID, ctx.Workspace.ID) if err != nil { http.Error(w, "Failed to pull changes: "+err.Error(), http.StatusInternalServerError) return diff --git a/server/internal/handlers/handlers.go b/server/internal/handlers/handlers.go index d45f721..743e40d 100644 --- a/server/internal/handlers/handlers.go +++ b/server/internal/handlers/handlers.go @@ -4,20 +4,20 @@ import ( "encoding/json" "net/http" "novamd/internal/db" - "novamd/internal/filesystem" + "novamd/internal/storage" ) // Handler provides common functionality for all handlers type Handler struct { - DB *db.DB - S *filesystem.Storage + DB *db.DB + Storage storage.Manager } // NewHandler creates a new handler with the given dependencies -func NewHandler(db *db.DB, fs *filesystem.Storage) *Handler { +func NewHandler(db *db.DB, s storage.Manager) *Handler { return &Handler{ - DB: db, - S: fs, + DB: db, + Storage: s, } } diff --git a/server/internal/handlers/user_handlers.go b/server/internal/handlers/user_handlers.go index a949f11..fa57020 100644 --- a/server/internal/handlers/user_handlers.go +++ b/server/internal/handlers/user_handlers.go @@ -200,7 +200,7 @@ func (h *Handler) DeleteAccount() http.HandlerFunc { // Delete workspace directories for _, workspace := range workspaces { - if err := h.S.DeleteUserWorkspace(ctx.UserID, workspace.ID); err != nil { + if err := h.Storage.DeleteUserWorkspace(ctx.UserID, workspace.ID); err != nil { http.Error(w, "Failed to delete workspace files", http.StatusInternalServerError) return } diff --git a/server/internal/handlers/workspace_handlers.go b/server/internal/handlers/workspace_handlers.go index 7c0fea4..0f2a012 100644 --- a/server/internal/handlers/workspace_handlers.go +++ b/server/internal/handlers/workspace_handlers.go @@ -45,7 +45,7 @@ func (h *Handler) CreateWorkspace() http.HandlerFunc { return } - if err := h.S.InitializeUserWorkspace(workspace.UserID, workspace.ID); err != nil { + if err := h.Storage.InitializeUserWorkspace(workspace.UserID, workspace.ID); err != nil { http.Error(w, "Failed to initialize workspace directory", http.StatusInternalServerError) return } @@ -107,7 +107,7 @@ func (h *Handler) UpdateWorkspace() http.HandlerFunc { // Handle Git repository setup/teardown if Git settings changed if gitSettingsChanged(&workspace, ctx.Workspace) { if workspace.GitEnabled { - if err := h.S.SetupGitRepo( + if err := h.Storage.SetupGitRepo( ctx.UserID, ctx.Workspace.ID, workspace.GitURL, @@ -119,7 +119,7 @@ func (h *Handler) UpdateWorkspace() http.HandlerFunc { } } else { - h.S.DisableGitRepo(ctx.UserID, ctx.Workspace.ID) + h.Storage.DisableGitRepo(ctx.UserID, ctx.Workspace.ID) } } diff --git a/server/internal/filesystem/files.go b/server/internal/storage/files.go similarity index 80% rename from server/internal/filesystem/files.go rename to server/internal/storage/files.go index eb5ab16..136cc84 100644 --- a/server/internal/filesystem/files.go +++ b/server/internal/storage/files.go @@ -1,6 +1,6 @@ -// Package filesystem provides functionalities to interact with the file system, +// Package storage provides functionalities to interact with the file system, // including listing files, finding files by name, getting file content, saving files, and deleting files. -package filesystem +package storage import ( "fmt" @@ -10,12 +10,23 @@ import ( "strings" ) -// StorageNode represents a file or directory in the storage. -type StorageNode struct { - ID string `json:"id"` - Name string `json:"name"` - Path string `json:"path"` - Children []StorageNode `json:"children,omitempty"` +// FileManager provides functionalities to interact with files in the storage. +type FileManager interface { + ListFilesRecursively(userID, workspaceID int) ([]FileNode, error) + FindFileByName(userID, workspaceID int, filename string) ([]string, error) + GetFileContent(userID, workspaceID int, filePath string) ([]byte, error) + SaveFile(userID, workspaceID int, filePath string, content []byte) error + DeleteFile(userID, workspaceID int, filePath string) error + GetFileStats(userID, workspaceID int) (*FileCountStats, error) + GetTotalFileStats() (*FileCountStats, error) +} + +// FileNode represents a file or directory in the storage. +type FileNode struct { + ID string `json:"id"` + Name string `json:"name"` + Path string `json:"path"` + Children []FileNode `json:"children,omitempty"` } // ListFilesRecursively returns a list of all files in the workspace directory and its subdirectories. @@ -25,13 +36,13 @@ type StorageNode struct { // Returns: // - nodes: a list of files and directories in the workspace // - error: any error that occurred during listing -func (s *Storage) ListFilesRecursively(userID, workspaceID int) ([]StorageNode, error) { +func (s *Service) ListFilesRecursively(userID, workspaceID int) ([]FileNode, error) { workspacePath := s.GetWorkspacePath(userID, workspaceID) return s.walkDirectory(workspacePath, "") } // walkDirectory recursively walks the directory and returns a list of files and directories. -func (s *Storage) walkDirectory(dir, prefix string) ([]StorageNode, error) { +func (s *Service) walkDirectory(dir, prefix string) ([]FileNode, error) { entries, err := s.fs.ReadDir(dir) if err != nil { return nil, err @@ -56,7 +67,7 @@ func (s *Storage) walkDirectory(dir, prefix string) ([]StorageNode, error) { }) // Create combined slice with directories first, then files - nodes := make([]StorageNode, 0, len(entries)) + nodes := make([]FileNode, 0, len(entries)) // Add directories first for _, entry := range dirs { @@ -69,7 +80,7 @@ func (s *Storage) walkDirectory(dir, prefix string) ([]StorageNode, error) { return nil, err } - node := StorageNode{ + node := FileNode{ ID: path, Name: name, Path: path, @@ -83,7 +94,7 @@ func (s *Storage) walkDirectory(dir, prefix string) ([]StorageNode, error) { name := entry.Name() path := filepath.Join(prefix, name) - node := StorageNode{ + node := FileNode{ ID: path, Name: name, Path: path, @@ -102,7 +113,7 @@ func (s *Storage) walkDirectory(dir, prefix string) ([]StorageNode, error) { // Returns: // - foundPaths: a list of file paths that match the filename // - error: any error that occurred during the search -func (s *Storage) FindFileByName(userID, workspaceID int, filename string) ([]string, error) { +func (s *Service) FindFileByName(userID, workspaceID int, filename string) ([]string, error) { var foundPaths []string workspacePath := s.GetWorkspacePath(userID, workspaceID) @@ -141,7 +152,7 @@ func (s *Storage) FindFileByName(userID, workspaceID int, filename string) ([]st // Returns: // - content: the content of the file // - error: any error that occurred during reading -func (s *Storage) GetFileContent(userID, workspaceID int, filePath string) ([]byte, error) { +func (s *Service) GetFileContent(userID, workspaceID int, filePath string) ([]byte, error) { fullPath, err := s.ValidatePath(userID, workspaceID, filePath) if err != nil { return nil, err @@ -157,7 +168,7 @@ func (s *Storage) GetFileContent(userID, workspaceID int, filePath string) ([]by // - content: the content to write to the file // Returns: // - error: any error that occurred during saving -func (s *Storage) SaveFile(userID, workspaceID int, filePath string, content []byte) error { +func (s *Service) SaveFile(userID, workspaceID int, filePath string, content []byte) error { fullPath, err := s.ValidatePath(userID, workspaceID, filePath) if err != nil { return err @@ -178,7 +189,7 @@ func (s *Storage) SaveFile(userID, workspaceID int, filePath string, content []b // - filePath: the path of the file to delete // Returns: // - error: any error that occurred during deletion -func (s *Storage) DeleteFile(userID, workspaceID int, filePath string) error { +func (s *Service) DeleteFile(userID, workspaceID int, filePath string) error { fullPath, err := s.ValidatePath(userID, workspaceID, filePath) if err != nil { return err @@ -199,7 +210,7 @@ type FileCountStats struct { // Returns: // - result: statistics about the files in the workspace // - error: any error that occurred during counting -func (s *Storage) GetFileStats(userID, workspaceID int) (*FileCountStats, error) { +func (s *Service) GetFileStats(userID, workspaceID int) (*FileCountStats, error) { workspacePath := s.GetWorkspacePath(userID, workspaceID) // Check if workspace exists @@ -214,12 +225,12 @@ func (s *Storage) GetFileStats(userID, workspaceID int) (*FileCountStats, error) // GetTotalFileStats returns the total file statistics for the storage. // Returns: // - result: statistics about the files in the storage -func (s *Storage) GetTotalFileStats() (*FileCountStats, error) { +func (s *Service) GetTotalFileStats() (*FileCountStats, error) { return s.countFilesInPath(s.RootDir) } // countFilesInPath counts the total number of files and the total size of files in the given directory. -func (s *Storage) countFilesInPath(directoryPath string) (*FileCountStats, error) { +func (s *Service) countFilesInPath(directoryPath string) (*FileCountStats, error) { result := &FileCountStats{} err := filepath.WalkDir(directoryPath, func(path string, d os.DirEntry, err error) error { diff --git a/server/internal/filesystem/filesystem.go b/server/internal/storage/filesystem.go similarity index 98% rename from server/internal/filesystem/filesystem.go rename to server/internal/storage/filesystem.go index 5cafc06..f5ca0b9 100644 --- a/server/internal/filesystem/filesystem.go +++ b/server/internal/storage/filesystem.go @@ -1,4 +1,4 @@ -package filesystem +package storage import ( "io/fs" diff --git a/server/internal/storage/filesystem_test.go b/server/internal/storage/filesystem_test.go new file mode 100644 index 0000000..956599b --- /dev/null +++ b/server/internal/storage/filesystem_test.go @@ -0,0 +1,46 @@ +package storage_test + +import ( + "io/fs" + "os" + "testing/fstest" +) + +// mapFS adapts testing.MapFS to implement our fileSystem interface +type mapFS struct { + fstest.MapFS +} + +func NewMapFS() *mapFS { + return &mapFS{ + MapFS: make(fstest.MapFS), + } +} + +// Only implement the methods that MapFS doesn't already provide +func (m *mapFS) WriteFile(path string, data []byte, perm fs.FileMode) error { + m.MapFS[path] = &fstest.MapFile{ + Data: data, + Mode: perm, + } + return nil +} + +func (m *mapFS) Remove(path string) error { + delete(m.MapFS, path) + return nil +} + +func (m *mapFS) MkdirAll(_ string, _ fs.FileMode) error { + // For MapFS, we don't actually need to create directories + return nil +} + +func (m *mapFS) RemoveAll(path string) error { + delete(m.MapFS, path) + return nil +} + +func (m *mapFS) IsNotExist(err error) bool { + return os.IsNotExist(err) +} diff --git a/server/internal/filesystem/git.go b/server/internal/storage/git.go similarity index 78% rename from server/internal/filesystem/git.go rename to server/internal/storage/git.go index f70d7ad..cebd1b0 100644 --- a/server/internal/filesystem/git.go +++ b/server/internal/storage/git.go @@ -1,10 +1,18 @@ -package filesystem +package storage import ( "fmt" "novamd/internal/gitutils" ) +// RepositoryManager defines the interface for managing Git repositories. +type RepositoryManager interface { + SetupGitRepo(userID, workspaceID int, gitURL, gitUser, gitToken string) error + DisableGitRepo(userID, workspaceID int) + StageCommitAndPush(userID, workspaceID int, message string) error + Pull(userID, workspaceID int) error +} + // SetupGitRepo sets up a Git repository for the given user and workspace IDs. // Parameters: // - userID: the ID of the user who owns the workspace @@ -14,7 +22,7 @@ import ( // - gitToken: the access token for the Git repository // Returns: // - error: any error that occurred during setup -func (s *Storage) SetupGitRepo(userID, workspaceID int, gitURL, gitUser, gitToken string) error { +func (s *Service) SetupGitRepo(userID, workspaceID int, gitURL, gitUser, gitToken string) error { workspacePath := s.GetWorkspacePath(userID, workspaceID) if _, ok := s.GitRepos[userID]; !ok { s.GitRepos[userID] = make(map[int]*gitutils.GitRepo) @@ -27,7 +35,7 @@ func (s *Storage) SetupGitRepo(userID, workspaceID int, gitURL, gitUser, gitToke // Parameters: // - userID: the ID of the user who owns the workspace // - workspaceID: the ID of the workspace to disable the Git repository for -func (s *Storage) DisableGitRepo(userID, workspaceID int) { +func (s *Service) DisableGitRepo(userID, workspaceID int) { if userRepos, ok := s.GitRepos[userID]; ok { delete(userRepos, workspaceID) if len(userRepos) == 0 { @@ -43,7 +51,7 @@ func (s *Storage) DisableGitRepo(userID, workspaceID int) { // - message: the commit message // Returns: // - error: any error that occurred during the operation -func (s *Storage) StageCommitAndPush(userID, workspaceID int, message string) error { +func (s *Service) StageCommitAndPush(userID, workspaceID int, message string) error { repo, ok := s.getGitRepo(userID, workspaceID) if !ok { return fmt.Errorf("git settings not configured for this workspace") @@ -62,7 +70,7 @@ func (s *Storage) StageCommitAndPush(userID, workspaceID int, message string) er // - workspaceID: the ID of the workspace to pull changes for // Returns: // - error: any error that occurred during the operation -func (s *Storage) Pull(userID, workspaceID int) error { +func (s *Service) Pull(userID, workspaceID int) error { repo, ok := s.getGitRepo(userID, workspaceID) if !ok { return fmt.Errorf("git settings not configured for this workspace") @@ -72,7 +80,7 @@ func (s *Storage) Pull(userID, workspaceID int) error { } // getGitRepo returns the Git repository for the given user and workspace IDs. -func (s *Storage) getGitRepo(userID, workspaceID int) (*gitutils.GitRepo, bool) { +func (s *Service) getGitRepo(userID, workspaceID int) (*gitutils.GitRepo, bool) { userRepos, ok := s.GitRepos[userID] if !ok { return nil, false diff --git a/server/internal/storage/service.go b/server/internal/storage/service.go new file mode 100644 index 0000000..5b57b4b --- /dev/null +++ b/server/internal/storage/service.go @@ -0,0 +1,42 @@ +package storage + +import ( + "novamd/internal/gitutils" +) + +// Manager interface combines all storage interfaces. +type Manager interface { + FileManager + WorkspaceManager + RepositoryManager +} + +// Service represents the file system structure. +type Service struct { + fs fileSystem + RootDir string + GitRepos map[int]map[int]*gitutils.GitRepo // map[userID]map[workspaceID]*gitutils.GitRepo +} + +// NewService creates a new Storage instance. +// Parameters: +// - rootDir: the root directory for the storage +// Returns: +// - result: the new Storage instance +func NewService(rootDir string) *Service { + return NewServiceWithFS(rootDir, &osFS{}) +} + +// NewServiceWithFS creates a new Storage instance with the given filesystem. +// Parameters: +// - rootDir: the root directory for the storage +// - fs: the filesystem implementation to use +// Returns: +// - result: the new Storage instance +func NewServiceWithFS(rootDir string, fs fileSystem) *Service { + return &Service{ + fs: fs, + RootDir: rootDir, + GitRepos: make(map[int]map[int]*gitutils.GitRepo), + } +} diff --git a/server/internal/filesystem/workspace.go b/server/internal/storage/workspace.go similarity index 51% rename from server/internal/filesystem/workspace.go rename to server/internal/storage/workspace.go index 891eaae..b72155f 100644 --- a/server/internal/filesystem/workspace.go +++ b/server/internal/storage/workspace.go @@ -1,17 +1,46 @@ -package filesystem +package storage import ( "fmt" "path/filepath" + "strings" ) +// WorkspaceManager provides functionalities to interact with workspaces in the storage. +type WorkspaceManager interface { + ValidatePath(userID, workspaceID int, path string) (string, error) + GetWorkspacePath(userID, workspaceID int) string + InitializeUserWorkspace(userID, workspaceID int) error + DeleteUserWorkspace(userID, workspaceID int) error +} + +// ValidatePath validates the given path and returns the cleaned path if it is valid. +// Parameters: +// - userID: the ID of the user who owns the workspace +// - workspaceID: the ID of the workspace to validate the path for +// - path: the path to validate +// Returns: +// - result: the cleaned path if it is valid +// - error: any error that occurred during validation +func (s *Service) ValidatePath(userID, workspaceID int, path string) (string, error) { + workspacePath := s.GetWorkspacePath(userID, workspaceID) + fullPath := filepath.Join(workspacePath, path) + cleanPath := filepath.Clean(fullPath) + + if !strings.HasPrefix(cleanPath, workspacePath) { + return "", fmt.Errorf("invalid path: outside of workspace") + } + + return cleanPath, nil +} + // GetWorkspacePath returns the path to the workspace directory for the given user and workspace IDs. // Parameters: // - userID: the ID of the user who owns the workspace // - workspaceID: the ID of the workspace // Returns: // - result: the path to the workspace directory -func (s *Storage) GetWorkspacePath(userID, workspaceID int) string { +func (s *Service) GetWorkspacePath(userID, workspaceID int) string { return filepath.Join(s.RootDir, fmt.Sprintf("%d", userID), fmt.Sprintf("%d", workspaceID)) } @@ -21,7 +50,7 @@ func (s *Storage) GetWorkspacePath(userID, workspaceID int) string { // - workspaceID: the ID of the workspace to initialize // Returns: // - error: any error that occurred during the operation -func (s *Storage) InitializeUserWorkspace(userID, workspaceID int) error { +func (s *Service) InitializeUserWorkspace(userID, workspaceID int) error { workspacePath := s.GetWorkspacePath(userID, workspaceID) err := s.fs.MkdirAll(workspacePath, 0755) if err != nil { @@ -37,7 +66,7 @@ func (s *Storage) InitializeUserWorkspace(userID, workspaceID int) error { // - workspaceID: the ID of the workspace to delete // Returns: // - error: any error that occurred during the operation -func (s *Storage) DeleteUserWorkspace(userID, workspaceID int) error { +func (s *Service) DeleteUserWorkspace(userID, workspaceID int) error { workspacePath := s.GetWorkspacePath(userID, workspaceID) err := s.fs.RemoveAll(workspacePath) if err != nil { diff --git a/server/internal/user/user.go b/server/internal/user/user.go index 530833e..638ec39 100644 --- a/server/internal/user/user.go +++ b/server/internal/user/user.go @@ -8,19 +8,19 @@ import ( "golang.org/x/crypto/bcrypt" "novamd/internal/db" - "novamd/internal/filesystem" "novamd/internal/models" + "novamd/internal/storage" ) type UserService struct { - DB *db.DB - FS *filesystem.Storage + DB *db.DB + Storage storage.Manager } -func NewUserService(database *db.DB, fs *filesystem.Storage) *UserService { +func NewUserService(database *db.DB, s storage.Manager) *UserService { return &UserService{ - DB: database, - FS: fs, + DB: database, + Storage: s, } } @@ -53,7 +53,7 @@ func (s *UserService) SetupAdminUser(adminEmail, adminPassword string) (*models. } // Initialize workspace directory - err = s.FS.InitializeUserWorkspace(createdUser.ID, createdUser.LastWorkspaceID) + err = s.Storage.InitializeUserWorkspace(createdUser.ID, createdUser.LastWorkspaceID) if err != nil { return nil, fmt.Errorf("failed to initialize admin workspace: %w", err) } From 408746187e15c30d4c2db25ecd440ab015637448 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Thu, 14 Nov 2024 22:11:40 +0100 Subject: [PATCH 05/38] Implement test list files --- server/go.mod | 4 ++ server/internal/storage/files_test.go | 60 ++++++++++++++++++++++ server/internal/storage/filesystem_test.go | 18 +++---- server/internal/storage/service_test.go | 12 +++++ server/internal/testutils/assertions.go | 13 +++++ 5 files changed, 98 insertions(+), 9 deletions(-) create mode 100644 server/internal/storage/files_test.go create mode 100644 server/internal/storage/service_test.go create mode 100644 server/internal/testutils/assertions.go diff --git a/server/go.mod b/server/go.mod index 0af2cf4..90ae840 100644 --- a/server/go.mod +++ b/server/go.mod @@ -11,6 +11,7 @@ require ( github.com/golang-jwt/jwt/v5 v5.2.1 github.com/google/uuid v1.6.0 github.com/mattn/go-sqlite3 v1.14.23 + github.com/stretchr/testify v1.9.0 github.com/unrolled/secure v1.17.0 golang.org/x/crypto v0.21.0 ) @@ -22,6 +23,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudflare/circl v1.3.7 // indirect github.com/cyphar/filepath-securejoin v0.2.4 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/emirpasic/gods v1.18.1 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect @@ -33,6 +35,7 @@ require ( github.com/kevinburke/ssh_config v1.2.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/pjbgf/sha1cd v0.3.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect github.com/skeema/knownhosts v1.2.2 // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect @@ -42,4 +45,5 @@ require ( golang.org/x/text v0.14.0 // indirect golang.org/x/tools v0.13.0 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/server/internal/storage/files_test.go b/server/internal/storage/files_test.go new file mode 100644 index 0000000..39e0c6f --- /dev/null +++ b/server/internal/storage/files_test.go @@ -0,0 +1,60 @@ +// Package storage_test provides tests for the storage package. +package storage_test + +import ( + "novamd/internal/storage" + "novamd/internal/testutils" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestListFilesRecursively(t *testing.T) { + tests := []testutils.TestCase{ + { + Name: "empty workspace returns empty list", + Setup: func(t *testing.T, fixtures any) { + fs := fixtures.(*MapFS) + require.NoError(t, fs.MkdirAll("/test/root/1/1", 0755)) + }, + Fixtures: NewMapFS(), + Validate: func(t *testing.T, result any, err error) { + require.NoError(t, err) + files := result.([]storage.FileNode) + assert.Empty(t, files) + }, + }, + { + Name: "lists files and directories correctly", + Setup: func(t *testing.T, fixtures any) { + fs := fixtures.(*MapFS) + err := fs.WriteFile("/test/root/1/1/file1.md", []byte("content1"), 0644) + require.NoError(t, err, "Failed to write file1.md") + + err = fs.WriteFile("/test/root/1/1/dir/file2.md", []byte("content2"), 0644) + require.NoError(t, err, "Failed to write file2.md") + }, + Fixtures: NewMapFS(), + Validate: func(t *testing.T, result any, err error) { + require.NoError(t, err) + files := result.([]storage.FileNode) + require.Len(t, files, 2) + assert.Equal(t, "dir", files[0].Name) + assert.Equal(t, "file1.md", files[1].Name) + assert.Len(t, files[0].Children, 1) + assert.Equal(t, "file2.md", files[0].Children[0].Name) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.Name, func(t *testing.T) { + fs := tc.Fixtures.(*MapFS) + srv := storage.NewServiceWithFS("/test/root", fs) + tc.Setup(t, tc.Fixtures) + files, err := srv.ListFilesRecursively(1, 1) + tc.Validate(t, files, err) + }) + } +} diff --git a/server/internal/storage/filesystem_test.go b/server/internal/storage/filesystem_test.go index 956599b..3e51dca 100644 --- a/server/internal/storage/filesystem_test.go +++ b/server/internal/storage/filesystem_test.go @@ -6,19 +6,19 @@ import ( "testing/fstest" ) -// mapFS adapts testing.MapFS to implement our fileSystem interface -type mapFS struct { +// MapFS adapts testing.MapFS to implement our fileSystem interface +type MapFS struct { fstest.MapFS } -func NewMapFS() *mapFS { - return &mapFS{ +func NewMapFS() *MapFS { + return &MapFS{ MapFS: make(fstest.MapFS), } } // Only implement the methods that MapFS doesn't already provide -func (m *mapFS) WriteFile(path string, data []byte, perm fs.FileMode) error { +func (m *MapFS) WriteFile(path string, data []byte, perm fs.FileMode) error { m.MapFS[path] = &fstest.MapFile{ Data: data, Mode: perm, @@ -26,21 +26,21 @@ func (m *mapFS) WriteFile(path string, data []byte, perm fs.FileMode) error { return nil } -func (m *mapFS) Remove(path string) error { +func (m *MapFS) Remove(path string) error { delete(m.MapFS, path) return nil } -func (m *mapFS) MkdirAll(_ string, _ fs.FileMode) error { +func (m *MapFS) MkdirAll(_ string, _ fs.FileMode) error { // For MapFS, we don't actually need to create directories return nil } -func (m *mapFS) RemoveAll(path string) error { +func (m *MapFS) RemoveAll(path string) error { delete(m.MapFS, path) return nil } -func (m *mapFS) IsNotExist(err error) bool { +func (m *MapFS) IsNotExist(err error) bool { return os.IsNotExist(err) } diff --git a/server/internal/storage/service_test.go b/server/internal/storage/service_test.go new file mode 100644 index 0000000..696f90d --- /dev/null +++ b/server/internal/storage/service_test.go @@ -0,0 +1,12 @@ +package storage_test + +import ( + "novamd/internal/storage" + "testing" +) + +func SetupTestService(t *testing.T) (*storage.Service, *MapFS) { + fs := &MapFS{} + srv := storage.NewServiceWithFS("/test/root", fs) + return srv, fs +} diff --git a/server/internal/testutils/assertions.go b/server/internal/testutils/assertions.go new file mode 100644 index 0000000..5bbfa23 --- /dev/null +++ b/server/internal/testutils/assertions.go @@ -0,0 +1,13 @@ +package testutils + +import ( + "testing" +) + +// TestCase defines a generic test case structure that can be used across packages +type TestCase struct { + Name string + Setup func(t *testing.T, fixtures any) + Fixtures any + Validate func(t *testing.T, result any, err error) +} From 2fe642ac619bf8f327593715f24f3399489cd85a Mon Sep 17 00:00:00 2001 From: LordMathis Date: Tue, 19 Nov 2024 21:43:52 +0100 Subject: [PATCH 06/38] Rework mock filesystem --- server/go.mod | 4 - server/internal/storage/filesystem_test.go | 129 +++++++++++++++++---- server/internal/storage/service_test.go | 12 -- 3 files changed, 104 insertions(+), 41 deletions(-) delete mode 100644 server/internal/storage/service_test.go diff --git a/server/go.mod b/server/go.mod index 90ae840..0af2cf4 100644 --- a/server/go.mod +++ b/server/go.mod @@ -11,7 +11,6 @@ require ( github.com/golang-jwt/jwt/v5 v5.2.1 github.com/google/uuid v1.6.0 github.com/mattn/go-sqlite3 v1.14.23 - github.com/stretchr/testify v1.9.0 github.com/unrolled/secure v1.17.0 golang.org/x/crypto v0.21.0 ) @@ -23,7 +22,6 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudflare/circl v1.3.7 // indirect github.com/cyphar/filepath-securejoin v0.2.4 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect github.com/emirpasic/gods v1.18.1 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect @@ -35,7 +33,6 @@ require ( github.com/kevinburke/ssh_config v1.2.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/pjbgf/sha1cd v0.3.0 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect github.com/skeema/knownhosts v1.2.2 // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect @@ -45,5 +42,4 @@ require ( golang.org/x/text v0.14.0 // indirect golang.org/x/tools v0.13.0 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/server/internal/storage/filesystem_test.go b/server/internal/storage/filesystem_test.go index 3e51dca..238fc52 100644 --- a/server/internal/storage/filesystem_test.go +++ b/server/internal/storage/filesystem_test.go @@ -1,46 +1,125 @@ package storage_test import ( + "errors" "io/fs" - "os" - "testing/fstest" + "path/filepath" + "time" ) -// MapFS adapts testing.MapFS to implement our fileSystem interface -type MapFS struct { - fstest.MapFS +type mockDirEntry struct { + name string + isDir bool } -func NewMapFS() *MapFS { - return &MapFS{ - MapFS: make(fstest.MapFS), +func (m *mockDirEntry) Name() string { return m.name } +func (m *mockDirEntry) IsDir() bool { return m.isDir } +func (m *mockDirEntry) Type() fs.FileMode { return fs.ModeDir } +func (m *mockDirEntry) Info() (fs.FileInfo, error) { return nil, nil } + +func NewMockDirEntry(name string, isDir bool) fs.DirEntry { + return &mockDirEntry{name: name, isDir: isDir} +} + +// Extend mockFS to support directory operations +type MockDirInfo struct { + name string + size int64 + mode fs.FileMode + modTime time.Time + isDir bool +} + +func (m MockDirInfo) Name() string { return m.name } +func (m MockDirInfo) Size() int64 { return m.size } +func (m MockDirInfo) Mode() fs.FileMode { return m.mode } +func (m MockDirInfo) ModTime() time.Time { return m.modTime } +func (m MockDirInfo) IsDir() bool { return m.isDir } +func (m MockDirInfo) Sys() interface{} { return nil } + +type mockFS struct { + // Record operations for verification + ReadCalls map[string]int + WriteCalls map[string][]byte + RemoveCalls []string + MkdirCalls []string + + // Configure test behavior + ReadFileReturns map[string]struct { + data []byte + err error + } + ReadDirReturns map[string]struct { + entries []fs.DirEntry + err error + } + WriteFileError error + RemoveError error + MkdirError error + StatError error +} + +func NewMockFS() *mockFS { + return &mockFS{ + ReadCalls: make(map[string]int), + WriteCalls: make(map[string][]byte), + RemoveCalls: make([]string, 0), + MkdirCalls: make([]string, 0), + ReadFileReturns: make(map[string]struct { + data []byte + err error + }), } } -// Only implement the methods that MapFS doesn't already provide -func (m *MapFS) WriteFile(path string, data []byte, perm fs.FileMode) error { - m.MapFS[path] = &fstest.MapFile{ - Data: data, - Mode: perm, +func (m *mockFS) ReadFile(path string) ([]byte, error) { + m.ReadCalls[path]++ + if ret, ok := m.ReadFileReturns[path]; ok { + return ret.data, ret.err } - return nil + return nil, errors.New("file not found") } -func (m *MapFS) Remove(path string) error { - delete(m.MapFS, path) - return nil +func (m *mockFS) WriteFile(path string, data []byte, _ fs.FileMode) error { + m.WriteCalls[path] = data + return m.WriteFileError } -func (m *MapFS) MkdirAll(_ string, _ fs.FileMode) error { - // For MapFS, we don't actually need to create directories - return nil +func (m *mockFS) Remove(path string) error { + m.RemoveCalls = append(m.RemoveCalls, path) + return m.RemoveError } -func (m *MapFS) RemoveAll(path string) error { - delete(m.MapFS, path) - return nil +func (m *mockFS) MkdirAll(path string, _ fs.FileMode) error { + m.MkdirCalls = append(m.MkdirCalls, path) + return m.MkdirError } -func (m *MapFS) IsNotExist(err error) bool { - return os.IsNotExist(err) +func (m *mockFS) Stat(path string) (fs.FileInfo, error) { + if m.StatError != nil { + return nil, m.StatError + } + return MockDirInfo{ + name: filepath.Base(path), + size: 1024, + mode: 0644, + modTime: time.Now(), + isDir: false, + }, nil +} + +func (m *mockFS) ReadDir(path string) ([]fs.DirEntry, error) { + if ret, ok := m.ReadDirReturns[path]; ok { + return ret.entries, ret.err + } + return nil, fs.ErrNotExist +} + +func (m *mockFS) RemoveAll(path string) error { + m.RemoveCalls = append(m.RemoveCalls, path) + return m.RemoveError +} + +func (m *mockFS) IsNotExist(err error) bool { + return err == fs.ErrNotExist } diff --git a/server/internal/storage/service_test.go b/server/internal/storage/service_test.go deleted file mode 100644 index 696f90d..0000000 --- a/server/internal/storage/service_test.go +++ /dev/null @@ -1,12 +0,0 @@ -package storage_test - -import ( - "novamd/internal/storage" - "testing" -) - -func SetupTestService(t *testing.T) (*storage.Service, *MapFS) { - fs := &MapFS{} - srv := storage.NewServiceWithFS("/test/root", fs) - return srv, fs -} From de2c9a6d0ce486d806a54503cec38b57df8001ad Mon Sep 17 00:00:00 2001 From: LordMathis Date: Tue, 19 Nov 2024 21:44:06 +0100 Subject: [PATCH 07/38] Implement files test --- server/internal/storage/files_test.go | 415 +++++++++++++++++++++++--- 1 file changed, 375 insertions(+), 40 deletions(-) diff --git a/server/internal/storage/files_test.go b/server/internal/storage/files_test.go index 39e0c6f..55e1f44 100644 --- a/server/internal/storage/files_test.go +++ b/server/internal/storage/files_test.go @@ -1,60 +1,395 @@ -// Package storage_test provides tests for the storage package. package storage_test import ( + "io/fs" "novamd/internal/storage" - "novamd/internal/testutils" + "path/filepath" "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -func TestListFilesRecursively(t *testing.T) { - tests := []testutils.TestCase{ +// TestFileNode ensures FileNode structs are created correctly +func TestFileNode(t *testing.T) { + testCases := []struct { + name string // name of the test case + node storage.FileNode + want storage.FileNode + }{ { - Name: "empty workspace returns empty list", - Setup: func(t *testing.T, fixtures any) { - fs := fixtures.(*MapFS) - require.NoError(t, fs.MkdirAll("/test/root/1/1", 0755)) + name: "file without children", + node: storage.FileNode{ + ID: "test.md", + Name: "test.md", + Path: "test.md", }, - Fixtures: NewMapFS(), - Validate: func(t *testing.T, result any, err error) { - require.NoError(t, err) - files := result.([]storage.FileNode) - assert.Empty(t, files) + want: storage.FileNode{ + ID: "test.md", + Name: "test.md", + Path: "test.md", }, }, { - Name: "lists files and directories correctly", - Setup: func(t *testing.T, fixtures any) { - fs := fixtures.(*MapFS) - err := fs.WriteFile("/test/root/1/1/file1.md", []byte("content1"), 0644) - require.NoError(t, err, "Failed to write file1.md") - - err = fs.WriteFile("/test/root/1/1/dir/file2.md", []byte("content2"), 0644) - require.NoError(t, err, "Failed to write file2.md") + name: "directory with children", + node: storage.FileNode{ + ID: "dir", + Name: "dir", + Path: "dir", + Children: []storage.FileNode{ + { + ID: "dir/file1.md", + Name: "file1.md", + Path: "dir/file1.md", + }, + }, }, - Fixtures: NewMapFS(), - Validate: func(t *testing.T, result any, err error) { - require.NoError(t, err) - files := result.([]storage.FileNode) - require.Len(t, files, 2) - assert.Equal(t, "dir", files[0].Name) - assert.Equal(t, "file1.md", files[1].Name) - assert.Len(t, files[0].Children, 1) - assert.Equal(t, "file2.md", files[0].Children[0].Name) + want: storage.FileNode{ + ID: "dir", + Name: "dir", + Path: "dir", + Children: []storage.FileNode{ + { + ID: "dir/file1.md", + Name: "file1.md", + Path: "dir/file1.md", + }, + }, }, }, } - for _, tc := range tests { - t.Run(tc.Name, func(t *testing.T) { - fs := tc.Fixtures.(*MapFS) - srv := storage.NewServiceWithFS("/test/root", fs) - tc.Setup(t, tc.Fixtures) - files, err := srv.ListFilesRecursively(1, 1) - tc.Validate(t, files, err) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := tc.node // Now we're testing the actual node structure + + if got.ID != tc.want.ID { + t.Errorf("ID = %v, want %v", got.ID, tc.want.ID) + } + if got.Name != tc.want.Name { + t.Errorf("Name = %v, want %v", got.Name, tc.want.Name) + } + if got.Path != tc.want.Path { + t.Errorf("Path = %v, want %v", got.Path, tc.want.Path) + } + if len(got.Children) != len(tc.want.Children) { + t.Errorf("len(Children) = %v, want %v", len(got.Children), len(tc.want.Children)) + } + // Add deep comparison of children if they exist + if len(got.Children) > 0 { + for i := range got.Children { + if got.Children[i].ID != tc.want.Children[i].ID { + t.Errorf("Children[%d].ID = %v, want %v", i, got.Children[i].ID, tc.want.Children[i].ID) + } + if got.Children[i].Name != tc.want.Children[i].Name { + t.Errorf("Children[%d].Name = %v, want %v", i, got.Children[i].Name, tc.want.Children[i].Name) + } + if got.Children[i].Path != tc.want.Children[i].Path { + t.Errorf("Children[%d].Path = %v, want %v", i, got.Children[i].Path, tc.want.Children[i].Path) + } + } + } + }) + } +} + +func TestListFilesRecursively(t *testing.T) { + mockFS := NewMockFS() + s := storage.NewServiceWithFS("test-root", mockFS) + + t.Run("empty directory", func(t *testing.T) { + mockFS.ReadDirReturns = map[string]struct { + entries []fs.DirEntry + err error + }{ + "test-root/1/1": { + entries: []fs.DirEntry{}, + err: nil, + }, + } + + files, err := s.ListFilesRecursively(1, 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(files) != 0 { + t.Errorf("expected empty file list, got %v", files) + } + }) + + t.Run("directory with files", func(t *testing.T) { + mockFS.ReadDirReturns = map[string]struct { + entries []fs.DirEntry + err error + }{ + "test-root/1/1": { + entries: []fs.DirEntry{ + NewMockDirEntry("file1.md", false), + NewMockDirEntry("file2.md", false), + }, + err: nil, + }, + } + + files, err := s.ListFilesRecursively(1, 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(files) != 2 { + t.Errorf("expected 2 files, got %d", len(files)) + } + }) + + t.Run("nested directories", func(t *testing.T) { + mockFS.ReadDirReturns = map[string]struct { + entries []fs.DirEntry + err error + }{ + "test-root/1/1": { + entries: []fs.DirEntry{ + NewMockDirEntry("dir1", true), + NewMockDirEntry("file1.md", false), + }, + err: nil, + }, + "test-root/1/1/dir1": { + entries: []fs.DirEntry{ + NewMockDirEntry("file2.md", false), + }, + err: nil, + }, + } + + files, err := s.ListFilesRecursively(1, 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(files) != 2 { // dir1 and file1.md + t.Errorf("expected 2 entries at root, got %d", len(files)) + } + + // Find directory and check its children + var dirFound bool + for _, f := range files { + if f.Name == "dir1" { + dirFound = true + if len(f.Children) != 1 { + t.Errorf("expected 1 child in dir1, got %d", len(f.Children)) + } + } + } + if !dirFound { + t.Error("directory 'dir1' not found in results") + } + }) +} + +func TestGetFileContent(t *testing.T) { + mockFS := NewMockFS() + s := storage.NewServiceWithFS("test-root", mockFS) + + testCases := []struct { + name string + userID int + workspaceID int + filePath string + mockData []byte + mockErr error + wantErr bool + }{ + { + name: "successful read", + userID: 1, + workspaceID: 1, + filePath: "test.md", + mockData: []byte("test content"), + mockErr: nil, + wantErr: false, + }, + { + name: "file not found", + userID: 1, + workspaceID: 1, + filePath: "nonexistent.md", + mockData: nil, + mockErr: fs.ErrNotExist, + wantErr: true, + }, + { + name: "invalid path", + userID: 1, + workspaceID: 1, + filePath: "../../../etc/passwd", + mockData: nil, + mockErr: nil, + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + expectedPath := filepath.Join("test-root", "1", "1", tc.filePath) + mockFS.ReadFileReturns[expectedPath] = struct { + data []byte + err error + }{tc.mockData, tc.mockErr} + + content, err := s.GetFileContent(tc.userID, tc.workspaceID, tc.filePath) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if string(content) != string(tc.mockData) { + t.Errorf("content = %q, want %q", content, tc.mockData) + } + + if mockFS.ReadCalls[expectedPath] != 1 { + t.Errorf("expected 1 read call for %s, got %d", expectedPath, mockFS.ReadCalls[expectedPath]) + } + }) + } +} + +func TestSaveFile(t *testing.T) { + mockFS := NewMockFS() + s := storage.NewServiceWithFS("test-root", mockFS) + + testCases := []struct { + name string + userID int + workspaceID int + filePath string + content []byte + mockErr error + wantErr bool + }{ + { + name: "successful save", + userID: 1, + workspaceID: 1, + filePath: "test.md", + content: []byte("test content"), + mockErr: nil, + wantErr: false, + }, + { + name: "invalid path", + userID: 1, + workspaceID: 1, + filePath: "../../../etc/passwd", + content: []byte("test content"), + mockErr: nil, + wantErr: true, + }, + { + name: "write error", + userID: 1, + workspaceID: 1, + filePath: "test.md", + content: []byte("test content"), + mockErr: fs.ErrPermission, + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockFS.WriteFileError = tc.mockErr + err := s.SaveFile(tc.userID, tc.workspaceID, tc.filePath, tc.content) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expectedPath := filepath.Join("test-root", "1", "1", tc.filePath) + if content, ok := mockFS.WriteCalls[expectedPath]; ok { + if string(content) != string(tc.content) { + t.Errorf("written content = %q, want %q", content, tc.content) + } + } else { + t.Error("expected write call not made") + } + }) + } +} + +func TestDeleteFile(t *testing.T) { + mockFS := NewMockFS() + s := storage.NewServiceWithFS("test-root", mockFS) + + testCases := []struct { + name string + userID int + workspaceID int + filePath string + mockErr error + wantErr bool + }{ + { + name: "successful delete", + userID: 1, + workspaceID: 1, + filePath: "test.md", + mockErr: nil, + wantErr: false, + }, + { + name: "invalid path", + userID: 1, + workspaceID: 1, + filePath: "../../../etc/passwd", + mockErr: nil, + wantErr: true, + }, + { + name: "file not found", + userID: 1, + workspaceID: 1, + filePath: "nonexistent.md", + mockErr: fs.ErrNotExist, + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockFS.RemoveError = tc.mockErr + err := s.DeleteFile(tc.userID, tc.workspaceID, tc.filePath) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expectedPath := filepath.Join("test-root", "1", "1", tc.filePath) + found := false + for _, p := range mockFS.RemoveCalls { + if p == expectedPath { + found = true + break + } + } + if !found { + t.Error("expected delete call not made") + } }) } } From 53e52bfdb5165535be9c8d7a9349ed490a10ca29 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Tue, 19 Nov 2024 22:17:00 +0100 Subject: [PATCH 08/38] Test workspace --- server/internal/storage/workspace.go | 8 + server/internal/storage/workspace_test.go | 263 ++++++++++++++++++++++ 2 files changed, 271 insertions(+) create mode 100644 server/internal/storage/workspace_test.go diff --git a/server/internal/storage/workspace.go b/server/internal/storage/workspace.go index b72155f..7a264ad 100644 --- a/server/internal/storage/workspace.go +++ b/server/internal/storage/workspace.go @@ -24,9 +24,17 @@ type WorkspaceManager interface { // - error: any error that occurred during validation func (s *Service) ValidatePath(userID, workspaceID int, path string) (string, error) { workspacePath := s.GetWorkspacePath(userID, workspaceID) + + // First check if the path is absolute + if filepath.IsAbs(path) { + return "", fmt.Errorf("invalid path: absolute paths not allowed") + } + + // Join and clean the path fullPath := filepath.Join(workspacePath, path) cleanPath := filepath.Clean(fullPath) + // Verify the path is still within the workspace if !strings.HasPrefix(cleanPath, workspacePath) { return "", fmt.Errorf("invalid path: outside of workspace") } diff --git a/server/internal/storage/workspace_test.go b/server/internal/storage/workspace_test.go new file mode 100644 index 0000000..752ef18 --- /dev/null +++ b/server/internal/storage/workspace_test.go @@ -0,0 +1,263 @@ +package storage_test + +import ( + "errors" + "path/filepath" + "strings" + "testing" + + "novamd/internal/storage" +) + +func TestValidatePath(t *testing.T) { + mockFS := NewMockFS() + s := storage.NewServiceWithFS("test-root", mockFS) + + testCases := []struct { + name string + userID int + workspaceID int + path string + want string + wantErr bool + errContains string + }{ + { + name: "valid path", + userID: 1, + workspaceID: 1, + path: "notes/test.md", + want: filepath.Join("test-root", "1", "1", "notes", "test.md"), + wantErr: false, + }, + { + name: "valid path with dot", + userID: 1, + workspaceID: 1, + path: "./notes/test.md", + want: filepath.Join("test-root", "1", "1", "notes", "test.md"), + wantErr: false, + }, + { + name: "path with parent directory traversal", + userID: 1, + workspaceID: 1, + path: "../../../etc/passwd", + want: "", + wantErr: true, + errContains: "outside of workspace", + }, + { + name: "absolute path attempt", + userID: 1, + workspaceID: 1, + path: "/etc/passwd", + want: "", + wantErr: true, + errContains: "absolute paths not allowed", + }, + { + name: "empty path", + userID: 1, + workspaceID: 1, + path: "", + want: filepath.Join("test-root", "1", "1"), + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := s.ValidatePath(tc.userID, tc.workspaceID, tc.path) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + return + } + if !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("error = %v, want error containing %q", err, tc.errContains) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if got != tc.want { + t.Errorf("ValidatePath() = %v, want %v", got, tc.want) + } + }) + } +} + +func TestGetWorkspacePath(t *testing.T) { + mockFS := NewMockFS() + s := storage.NewServiceWithFS("test-root", mockFS) + + testCases := []struct { + name string + userID int + workspaceID int + want string + }{ + { + name: "standard workspace path", + userID: 1, + workspaceID: 1, + want: filepath.Join("test-root", "1", "1"), + }, + { + name: "different user and workspace IDs", + userID: 2, + workspaceID: 3, + want: filepath.Join("test-root", "2", "3"), + }, + { + name: "zero IDs", + userID: 0, + workspaceID: 0, + want: filepath.Join("test-root", "0", "0"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := s.GetWorkspacePath(tc.userID, tc.workspaceID) + if got != tc.want { + t.Errorf("GetWorkspacePath() = %v, want %v", got, tc.want) + } + }) + } +} + +func TestInitializeUserWorkspace(t *testing.T) { + mockFS := NewMockFS() + s := storage.NewServiceWithFS("test-root", mockFS) + + testCases := []struct { + name string + userID int + workspaceID int + mockErr error + wantErr bool + errContains string + }{ + { + name: "successful initialization", + userID: 1, + workspaceID: 1, + mockErr: nil, + wantErr: false, + }, + { + name: "mkdir error", + userID: 1, + workspaceID: 1, + mockErr: errors.New("permission denied"), + wantErr: true, + errContains: "failed to create workspace directory", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockFS.MkdirError = tc.mockErr + err := s.InitializeUserWorkspace(tc.userID, tc.workspaceID) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + return + } + if !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("error = %v, want error containing %q", err, tc.errContains) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify the correct directory was created + expectedPath := filepath.Join("test-root", "1", "1") + dirCreated := false + for _, path := range mockFS.MkdirCalls { + if path == expectedPath { + dirCreated = true + break + } + } + if !dirCreated { + t.Errorf("directory %s was not created", expectedPath) + } + }) + } +} + +func TestDeleteUserWorkspace(t *testing.T) { + mockFS := NewMockFS() + s := storage.NewServiceWithFS("test-root", mockFS) + + testCases := []struct { + name string + userID int + workspaceID int + mockErr error + wantErr bool + errContains string + }{ + { + name: "successful deletion", + userID: 1, + workspaceID: 1, + mockErr: nil, + wantErr: false, + }, + { + name: "removal error", + userID: 1, + workspaceID: 1, + mockErr: errors.New("permission denied"), + wantErr: true, + errContains: "failed to delete workspace directory", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockFS.RemoveError = tc.mockErr + err := s.DeleteUserWorkspace(tc.userID, tc.workspaceID) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + return + } + if !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("error = %v, want error containing %q", err, tc.errContains) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify the correct directory was deleted + expectedPath := filepath.Join("test-root", "1", "1") + dirDeleted := false + for _, path := range mockFS.RemoveCalls { + if path == expectedPath { + dirDeleted = true + break + } + } + if !dirDeleted { + t.Errorf("directory %s was not deleted", expectedPath) + } + }) + } +} From 7396b57a5df06f6982673e99dda11d5046398a71 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Tue, 19 Nov 2024 22:43:24 +0100 Subject: [PATCH 09/38] Rework gitutils package to make it testable --- server/internal/git/client.go | 176 +++++++++++++++++++++++++++++ server/internal/gitutils/git.go | 133 ---------------------- server/internal/storage/git.go | 8 +- server/internal/storage/service.go | 6 +- 4 files changed, 183 insertions(+), 140 deletions(-) create mode 100644 server/internal/git/client.go delete mode 100644 server/internal/gitutils/git.go diff --git a/server/internal/git/client.go b/server/internal/git/client.go new file mode 100644 index 0000000..82790d5 --- /dev/null +++ b/server/internal/git/client.go @@ -0,0 +1,176 @@ +// Package git provides functionalities to interact with Git repositories, including cloning, pulling, committing, and pushing changes. +package git + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/go-git/go-git/v5" + "github.com/go-git/go-git/v5/plumbing/transport/http" +) + +// Config holds the configuration for a Git client +type Config struct { + URL string + Username string + Token string + WorkDir string +} + +// Client defines the interface for Git operations +type Client interface { + Clone() error + Pull() error + Commit(message string) error + Push() error + EnsureRepo() error +} + +// client implements the Client interface +type client struct { + Config + repo *git.Repository +} + +// New creates a new Client instance +// Parameters: +// - url: the URL of the Git repository +// - username: the username for the Git repository +// - token: the access token for the Git repository +// - workDir: the local directory to clone the repository to +// Returns: +// - Client: the Git client +func New(url, username, token, workDir string) Client { + return &client{ + Config: Config{ + URL: url, + Username: username, + Token: token, + WorkDir: workDir, + }, + } +} + +// Clone clones the Git repository to the local directory +// Returns: +// - error: any error that occurred during cloning +func (c *client) Clone() error { + auth := &http.BasicAuth{ + Username: c.Username, + Password: c.Token, + } + + var err error + c.repo, err = git.PlainClone(c.WorkDir, false, &git.CloneOptions{ + URL: c.URL, + Auth: auth, + Progress: os.Stdout, + }) + + if err != nil { + return fmt.Errorf("failed to clone repository: %w", err) + } + + return nil +} + +// Pull pulls the latest changes from the remote repository +// Returns: +// - error: any error that occurred during pulling +func (c *client) Pull() error { + if c.repo == nil { + return fmt.Errorf("repository not initialized") + } + + w, err := c.repo.Worktree() + if err != nil { + return fmt.Errorf("failed to get worktree: %w", err) + } + + auth := &http.BasicAuth{ + Username: c.Username, + Password: c.Token, + } + + err = w.Pull(&git.PullOptions{ + Auth: auth, + Progress: os.Stdout, + }) + + if err != nil && err != git.NoErrAlreadyUpToDate { + return fmt.Errorf("failed to pull changes: %w", err) + } + + return nil +} + +// Commit commits the changes in the repository +// Parameters: +// - message: the commit message +// Returns: +// - error: any error that occurred during committing +func (c *client) Commit(message string) error { + if c.repo == nil { + return fmt.Errorf("repository not initialized") + } + + w, err := c.repo.Worktree() + if err != nil { + return fmt.Errorf("failed to get worktree: %w", err) + } + + _, err = w.Add(".") + if err != nil { + return fmt.Errorf("failed to add changes: %w", err) + } + + _, err = w.Commit(message, &git.CommitOptions{}) + if err != nil { + return fmt.Errorf("failed to commit changes: %w", err) + } + + return nil +} + +// Push pushes the changes to the remote repository +// Returns: +// - error: any error that occurred during pushing +func (c *client) Push() error { + if c.repo == nil { + return fmt.Errorf("repository not initialized") + } + + auth := &http.BasicAuth{ + Username: c.Username, + Password: c.Token, + } + + err := c.repo.Push(&git.PushOptions{ + Auth: auth, + Progress: os.Stdout, + }) + + if err != nil && err != git.NoErrAlreadyUpToDate { + return fmt.Errorf("failed to push changes: %w", err) + } + + return nil +} + +// EnsureRepo ensures the local repository is up-to-date +// Returns: +// - error: any error that occurred during the operation +func (c *client) EnsureRepo() error { + if _, err := os.Stat(filepath.Join(c.WorkDir, ".git")); os.IsNotExist(err) { + return c.Clone() + } + + var err error + c.repo, err = git.PlainOpen(c.WorkDir) + if err != nil { + return fmt.Errorf("failed to open existing repository: %w", err) + } + + return c.Pull() +} diff --git a/server/internal/gitutils/git.go b/server/internal/gitutils/git.go deleted file mode 100644 index 1ee3468..0000000 --- a/server/internal/gitutils/git.go +++ /dev/null @@ -1,133 +0,0 @@ -package gitutils - -import ( - "fmt" - "os" - "path/filepath" - - "github.com/go-git/go-git/v5" - "github.com/go-git/go-git/v5/plumbing/transport/http" -) - -type GitRepo struct { - URL string - Username string - Token string - WorkDir string - repo *git.Repository -} - -func New(url, username, token, workDir string) *GitRepo { - return &GitRepo{ - URL: url, - Username: username, - Token: token, - WorkDir: workDir, - } -} - -func (g *GitRepo) Clone() error { - auth := &http.BasicAuth{ - Username: g.Username, - Password: g.Token, - } - - var err error - g.repo, err = git.PlainClone(g.WorkDir, false, &git.CloneOptions{ - URL: g.URL, - Auth: auth, - Progress: os.Stdout, - }) - - if err != nil { - return fmt.Errorf("failed to clone repository: %w", err) - } - - return nil -} - -func (g *GitRepo) Pull() error { - if g.repo == nil { - return fmt.Errorf("repository not initialized") - } - - w, err := g.repo.Worktree() - if err != nil { - return fmt.Errorf("failed to get worktree: %w", err) - } - - auth := &http.BasicAuth{ - Username: g.Username, - Password: g.Token, - } - - err = w.Pull(&git.PullOptions{ - Auth: auth, - Progress: os.Stdout, - }) - - if err != nil && err != git.NoErrAlreadyUpToDate { - return fmt.Errorf("failed to pull changes: %w", err) - } - - return nil -} - -func (g *GitRepo) Commit(message string) error { - if g.repo == nil { - return fmt.Errorf("repository not initialized") - } - - w, err := g.repo.Worktree() - if err != nil { - return fmt.Errorf("failed to get worktree: %w", err) - } - - _, err = w.Add(".") - if err != nil { - return fmt.Errorf("failed to add changes: %w", err) - } - - _, err = w.Commit(message, &git.CommitOptions{}) - if err != nil { - return fmt.Errorf("failed to commit changes: %w", err) - } - - return nil -} - -func (g *GitRepo) Push() error { - if g.repo == nil { - return fmt.Errorf("repository not initialized") - } - - auth := &http.BasicAuth{ - Username: g.Username, - Password: g.Token, - } - - err := g.repo.Push(&git.PushOptions{ - Auth: auth, - Progress: os.Stdout, - }) - - if err != nil && err != git.NoErrAlreadyUpToDate { - return fmt.Errorf("failed to push changes: %w", err) - } - - return nil -} - -func (g *GitRepo) EnsureRepo() error { - if _, err := os.Stat(filepath.Join(g.WorkDir, ".git")); os.IsNotExist(err) { - return g.Clone() - } - - var err error - g.repo, err = git.PlainOpen(g.WorkDir) - if err != nil { - return fmt.Errorf("failed to open existing repository: %w", err) - } - - return g.Pull() -} diff --git a/server/internal/storage/git.go b/server/internal/storage/git.go index cebd1b0..a626c92 100644 --- a/server/internal/storage/git.go +++ b/server/internal/storage/git.go @@ -2,7 +2,7 @@ package storage import ( "fmt" - "novamd/internal/gitutils" + "novamd/internal/git" ) // RepositoryManager defines the interface for managing Git repositories. @@ -25,9 +25,9 @@ type RepositoryManager interface { func (s *Service) SetupGitRepo(userID, workspaceID int, gitURL, gitUser, gitToken string) error { workspacePath := s.GetWorkspacePath(userID, workspaceID) if _, ok := s.GitRepos[userID]; !ok { - s.GitRepos[userID] = make(map[int]*gitutils.GitRepo) + s.GitRepos[userID] = make(map[int]git.Client) } - s.GitRepos[userID][workspaceID] = gitutils.New(gitURL, gitUser, gitToken, workspacePath) + s.GitRepos[userID][workspaceID] = git.New(gitURL, gitUser, gitToken, workspacePath) return s.GitRepos[userID][workspaceID].EnsureRepo() } @@ -80,7 +80,7 @@ func (s *Service) Pull(userID, workspaceID int) error { } // getGitRepo returns the Git repository for the given user and workspace IDs. -func (s *Service) getGitRepo(userID, workspaceID int) (*gitutils.GitRepo, bool) { +func (s *Service) getGitRepo(userID, workspaceID int) (git.Client, bool) { userRepos, ok := s.GitRepos[userID] if !ok { return nil, false diff --git a/server/internal/storage/service.go b/server/internal/storage/service.go index 5b57b4b..985b66f 100644 --- a/server/internal/storage/service.go +++ b/server/internal/storage/service.go @@ -1,7 +1,7 @@ package storage import ( - "novamd/internal/gitutils" + "novamd/internal/git" ) // Manager interface combines all storage interfaces. @@ -15,7 +15,7 @@ type Manager interface { type Service struct { fs fileSystem RootDir string - GitRepos map[int]map[int]*gitutils.GitRepo // map[userID]map[workspaceID]*gitutils.GitRepo + GitRepos map[int]map[int]git.Client // map[userID]map[workspaceID]*git.Client } // NewService creates a new Storage instance. @@ -37,6 +37,6 @@ func NewServiceWithFS(rootDir string, fs fileSystem) *Service { return &Service{ fs: fs, RootDir: rootDir, - GitRepos: make(map[int]map[int]*gitutils.GitRepo), + GitRepos: make(map[int]map[int]git.Client), } } From 6cb5aec3726d58bd56b3f14dabbe5c82a41925cd Mon Sep 17 00:00:00 2001 From: LordMathis Date: Wed, 20 Nov 2024 22:06:38 +0100 Subject: [PATCH 10/38] Implement storage git tests --- server/internal/storage/files_test.go | 20 +- server/internal/storage/git.go | 2 +- server/internal/storage/git_test.go | 258 ++++++++++++++++++++++ server/internal/storage/service.go | 31 ++- server/internal/storage/workspace_test.go | 20 +- 5 files changed, 312 insertions(+), 19 deletions(-) create mode 100644 server/internal/storage/git_test.go diff --git a/server/internal/storage/files_test.go b/server/internal/storage/files_test.go index 55e1f44..d5bf20b 100644 --- a/server/internal/storage/files_test.go +++ b/server/internal/storage/files_test.go @@ -92,7 +92,10 @@ func TestFileNode(t *testing.T) { func TestListFilesRecursively(t *testing.T) { mockFS := NewMockFS() - s := storage.NewServiceWithFS("test-root", mockFS) + s := storage.NewServiceWithOptions("test-root", storage.Options{ + Fs: mockFS, + NewGitClient: nil, + }) t.Run("empty directory", func(t *testing.T) { mockFS.ReadDirReturns = map[string]struct { @@ -183,7 +186,10 @@ func TestListFilesRecursively(t *testing.T) { func TestGetFileContent(t *testing.T) { mockFS := NewMockFS() - s := storage.NewServiceWithFS("test-root", mockFS) + s := storage.NewServiceWithOptions("test-root", storage.Options{ + Fs: mockFS, + NewGitClient: nil, + }) testCases := []struct { name string @@ -257,7 +263,10 @@ func TestGetFileContent(t *testing.T) { func TestSaveFile(t *testing.T) { mockFS := NewMockFS() - s := storage.NewServiceWithFS("test-root", mockFS) + s := storage.NewServiceWithOptions("test-root", storage.Options{ + Fs: mockFS, + NewGitClient: nil, + }) testCases := []struct { name string @@ -327,7 +336,10 @@ func TestSaveFile(t *testing.T) { func TestDeleteFile(t *testing.T) { mockFS := NewMockFS() - s := storage.NewServiceWithFS("test-root", mockFS) + s := storage.NewServiceWithOptions("test-root", storage.Options{ + Fs: mockFS, + NewGitClient: nil, + }) testCases := []struct { name string diff --git a/server/internal/storage/git.go b/server/internal/storage/git.go index a626c92..09d3b0f 100644 --- a/server/internal/storage/git.go +++ b/server/internal/storage/git.go @@ -27,7 +27,7 @@ func (s *Service) SetupGitRepo(userID, workspaceID int, gitURL, gitUser, gitToke if _, ok := s.GitRepos[userID]; !ok { s.GitRepos[userID] = make(map[int]git.Client) } - s.GitRepos[userID][workspaceID] = git.New(gitURL, gitUser, gitToken, workspacePath) + s.GitRepos[userID][workspaceID] = s.newGitClient(gitURL, gitUser, gitToken, workspacePath) return s.GitRepos[userID][workspaceID].EnsureRepo() } diff --git a/server/internal/storage/git_test.go b/server/internal/storage/git_test.go new file mode 100644 index 0000000..b18ce35 --- /dev/null +++ b/server/internal/storage/git_test.go @@ -0,0 +1,258 @@ +package storage_test + +import ( + "errors" + "testing" + + "novamd/internal/git" + "novamd/internal/storage" +) + +// MockGitClient implements git.Client interface for testing +type MockGitClient struct { + CloneCalled bool + PullCalled bool + CommitCalled bool + PushCalled bool + EnsureCalled bool + CommitMessage string + ReturnError error +} + +func (m *MockGitClient) Clone() error { + m.CloneCalled = true + return m.ReturnError +} + +func (m *MockGitClient) Pull() error { + m.PullCalled = true + return m.ReturnError +} + +func (m *MockGitClient) Commit(message string) error { + m.CommitCalled = true + m.CommitMessage = message + return m.ReturnError +} + +func (m *MockGitClient) Push() error { + m.PushCalled = true + return m.ReturnError +} + +func (m *MockGitClient) EnsureRepo() error { + m.EnsureCalled = true + return m.ReturnError +} + +func TestSetupGitRepo(t *testing.T) { + mockFS := NewMockFS() + + testCases := []struct { + name string + userID int + workspaceID int + gitURL string + gitUser string + gitToken string + mockErr error + wantErr bool + }{ + { + name: "successful setup", + userID: 1, + workspaceID: 1, + gitURL: "https://github.com/user/repo", + gitUser: "user", + gitToken: "token", + mockErr: nil, + wantErr: false, + }, + { + name: "git initialization error", + userID: 1, + workspaceID: 2, + gitURL: "https://github.com/user/repo", + gitUser: "user", + gitToken: "token", + mockErr: errors.New("git initialization failed"), + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create a mock client with the desired error behavior + mockClient := &MockGitClient{ReturnError: tc.mockErr} + + // Create a client factory that returns our configured mock + mockClientFactory := func(_, _, _, _ string) git.Client { + return mockClient + } + + s := storage.NewServiceWithOptions("test-root", storage.Options{ + Fs: mockFS, + NewGitClient: mockClientFactory, + }) + + // Setup the git repo + err := s.SetupGitRepo(tc.userID, tc.workspaceID, tc.gitURL, tc.gitUser, tc.gitToken) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Check if client was stored correctly + client, ok := s.GitRepos[tc.userID][tc.workspaceID] + if !ok { + t.Fatal("git client was not stored in service") + } + + if !mockClient.EnsureCalled { + t.Error("EnsureRepo was not called") + } + + // Verify it's our mock client + if client != mockClient { + t.Error("stored client is not our mock client") + } + }) + } +} + +func TestGitOperations(t *testing.T) { + mockFS := NewMockFS() + s := storage.NewServiceWithOptions("test-root", storage.Options{ + Fs: mockFS, + NewGitClient: func(_, _, _, _ string) git.Client { return &MockGitClient{} }, + }) + + t.Run("operations on non-configured workspace", func(t *testing.T) { + err := s.StageCommitAndPush(1, 1, "test commit") + if err == nil { + t.Error("expected error for non-configured workspace, got nil") + } + + err = s.Pull(1, 1) + if err == nil { + t.Error("expected error for non-configured workspace, got nil") + } + }) + + t.Run("successful operations", func(t *testing.T) { + // Initialize GitRepos map + s.GitRepos = make(map[int]map[int]git.Client) + s.GitRepos[1] = make(map[int]git.Client) + mockClient := &MockGitClient{} + s.GitRepos[1][1] = mockClient + + // Test commit and push + err := s.StageCommitAndPush(1, 1, "test commit") + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if !mockClient.CommitCalled { + t.Error("Commit was not called") + } + if mockClient.CommitMessage != "test commit" { + t.Errorf("Commit message = %q, want %q", mockClient.CommitMessage, "test commit") + } + if !mockClient.PushCalled { + t.Error("Push was not called") + } + + // Test pull + err = s.Pull(1, 1) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if !mockClient.PullCalled { + t.Error("Pull was not called") + } + }) + + t.Run("operation errors", func(t *testing.T) { + // Initialize GitRepos map with error-returning client + s.GitRepos = make(map[int]map[int]git.Client) + s.GitRepos[1] = make(map[int]git.Client) + mockClient := &MockGitClient{ReturnError: errors.New("git operation failed")} + s.GitRepos[1][1] = mockClient + + // Test commit error + err := s.StageCommitAndPush(1, 1, "test commit") + if err == nil { + t.Error("expected error for commit, got nil") + } + + // Test pull error + err = s.Pull(1, 1) + if err == nil { + t.Error("expected error for pull, got nil") + } + }) +} + +func TestDisableGitRepo(t *testing.T) { + mockFS := NewMockFS() + s := storage.NewServiceWithOptions("test-root", storage.Options{ + Fs: mockFS, + NewGitClient: func(_, _, _, _ string) git.Client { return &MockGitClient{} }, + }) + + testCases := []struct { + name string + userID int + workspaceID int + setupRepo bool + }{ + { + name: "disable existing repo", + userID: 1, + workspaceID: 1, + setupRepo: true, + }, + { + name: "disable non-existent repo", + userID: 2, + workspaceID: 1, + setupRepo: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Reset GitRepos for each test + s.GitRepos = make(map[int]map[int]git.Client) + + if tc.setupRepo { + // Setup initial repo + s.GitRepos[tc.userID] = make(map[int]git.Client) + s.GitRepos[tc.userID][tc.workspaceID] = &MockGitClient{} + } + + // Disable the repo + s.DisableGitRepo(tc.userID, tc.workspaceID) + + // Verify repo was removed + if userRepos, exists := s.GitRepos[tc.userID]; exists { + if _, repoExists := userRepos[tc.workspaceID]; repoExists { + t.Error("git repo still exists after disable") + } + } + + // If this was the user's last repo, verify user entry was cleaned up + if tc.setupRepo { + if len(s.GitRepos[tc.userID]) > 0 { + t.Error("user's git repos map not cleaned up when last repo removed") + } + } + }) + } +} diff --git a/server/internal/storage/service.go b/server/internal/storage/service.go index 985b66f..758e4d2 100644 --- a/server/internal/storage/service.go +++ b/server/internal/storage/service.go @@ -13,9 +13,16 @@ type Manager interface { // Service represents the file system structure. type Service struct { - fs fileSystem - RootDir string - GitRepos map[int]map[int]git.Client // map[userID]map[workspaceID]*git.Client + fs fileSystem + newGitClient func(url, user, token, path string) git.Client + RootDir string + GitRepos map[int]map[int]git.Client // map[userID]map[workspaceID]*git.Client +} + +// Options represents the options for the storage service. +type Options struct { + Fs fileSystem + NewGitClient func(url, user, token, path string) git.Client } // NewService creates a new Storage instance. @@ -24,19 +31,23 @@ type Service struct { // Returns: // - result: the new Storage instance func NewService(rootDir string) *Service { - return NewServiceWithFS(rootDir, &osFS{}) + return NewServiceWithOptions(rootDir, Options{ + Fs: &osFS{}, + NewGitClient: git.New, + }) } -// NewServiceWithFS creates a new Storage instance with the given filesystem. +// NewServiceWithOptions creates a new Storage instance with the given options. // Parameters: // - rootDir: the root directory for the storage -// - fs: the filesystem implementation to use +// - opts: the options for the storage service // Returns: // - result: the new Storage instance -func NewServiceWithFS(rootDir string, fs fileSystem) *Service { +func NewServiceWithOptions(rootDir string, opts Options) *Service { return &Service{ - fs: fs, - RootDir: rootDir, - GitRepos: make(map[int]map[int]git.Client), + fs: opts.Fs, + newGitClient: opts.NewGitClient, + RootDir: rootDir, + GitRepos: make(map[int]map[int]git.Client), } } diff --git a/server/internal/storage/workspace_test.go b/server/internal/storage/workspace_test.go index 752ef18..f6b1607 100644 --- a/server/internal/storage/workspace_test.go +++ b/server/internal/storage/workspace_test.go @@ -11,7 +11,10 @@ import ( func TestValidatePath(t *testing.T) { mockFS := NewMockFS() - s := storage.NewServiceWithFS("test-root", mockFS) + s := storage.NewServiceWithOptions("test-root", storage.Options{ + Fs: mockFS, + NewGitClient: nil, + }) testCases := []struct { name string @@ -94,7 +97,10 @@ func TestValidatePath(t *testing.T) { func TestGetWorkspacePath(t *testing.T) { mockFS := NewMockFS() - s := storage.NewServiceWithFS("test-root", mockFS) + s := storage.NewServiceWithOptions("test-root", storage.Options{ + Fs: mockFS, + NewGitClient: nil, + }) testCases := []struct { name string @@ -134,7 +140,10 @@ func TestGetWorkspacePath(t *testing.T) { func TestInitializeUserWorkspace(t *testing.T) { mockFS := NewMockFS() - s := storage.NewServiceWithFS("test-root", mockFS) + s := storage.NewServiceWithOptions("test-root", storage.Options{ + Fs: mockFS, + NewGitClient: nil, + }) testCases := []struct { name string @@ -199,7 +208,10 @@ func TestInitializeUserWorkspace(t *testing.T) { func TestDeleteUserWorkspace(t *testing.T) { mockFS := NewMockFS() - s := storage.NewServiceWithFS("test-root", mockFS) + s := storage.NewServiceWithOptions("test-root", storage.Options{ + Fs: mockFS, + NewGitClient: nil, + }) testCases := []struct { name string From 435dce89d99dd0b3119b643a4b765840283387c5 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Thu, 21 Nov 2024 19:42:50 +0100 Subject: [PATCH 11/38] Add go test workflow --- .github/workflows/go-test.yml | 39 +++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 .github/workflows/go-test.yml diff --git a/.github/workflows/go-test.yml b/.github/workflows/go-test.yml new file mode 100644 index 0000000..108f629 --- /dev/null +++ b/.github/workflows/go-test.yml @@ -0,0 +1,39 @@ +name: Go Tests + +on: + push: + branches: + - "*" + pull_request: + branches: + - main + +jobs: + test: + name: Run Tests + runs-on: ubuntu-latest + + defaults: + run: + working-directory: ./server + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: "1.23" + cache: true + + # - name: Install dependencies + # run: | + # sudo apt-get update + # sudo apt-get install -y gcc + + - name: Run Tests + run: go test ./... -v + + - name: Run Tests with Race Detector + run: go test -race ./... -v From 2faefb6db52bd0e6c903658c5d6e0566b90a4f63 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Thu, 21 Nov 2024 21:25:29 +0100 Subject: [PATCH 12/38] Implement JWTManager interface --- server/cmd/server/main.go | 6 ++--- server/internal/auth/jwt.go | 27 +++++++++++++++-------- server/internal/auth/middleware.go | 35 ++++++++++++++++++++++++------ server/internal/auth/session.go | 20 ++++++++--------- 4 files changed, 59 insertions(+), 29 deletions(-) diff --git a/server/cmd/server/main.go b/server/cmd/server/main.go index c6204ba..951d8b8 100644 --- a/server/cmd/server/main.go +++ b/server/cmd/server/main.go @@ -48,7 +48,7 @@ func main() { s := storage.NewService(cfg.WorkDir) // Initialize JWT service - jwtService, err := auth.NewJWTService(auth.JWTConfig{ + jwtManager, err := auth.NewJWTService(auth.JWTConfig{ SigningKey: signingKey, AccessTokenExpiry: 15 * time.Minute, RefreshTokenExpiry: 7 * 24 * time.Hour, @@ -58,10 +58,10 @@ func main() { } // Initialize auth middleware - authMiddleware := auth.NewMiddleware(jwtService) + authMiddleware := auth.NewMiddleware(jwtManager) // Initialize session service - sessionService := auth.NewSessionService(database.DB, jwtService) + sessionService := auth.NewSessionService(database.DB, jwtManager) // Set up router r := chi.NewRouter() diff --git a/server/internal/auth/jwt.go b/server/internal/auth/jwt.go index b1c0480..9b65db8 100644 --- a/server/internal/auth/jwt.go +++ b/server/internal/auth/jwt.go @@ -1,3 +1,4 @@ +// Package auth provides JWT token generation and validation package auth import ( @@ -30,14 +31,22 @@ type JWTConfig struct { RefreshTokenExpiry time.Duration // How long refresh tokens are valid } -// JWTService handles JWT token generation and validation -type JWTService struct { +// JWTManager defines the interface for managing JWT tokens +type JWTManager interface { + GenerateAccessToken(userID int, role string) (string, error) + GenerateRefreshToken(userID int, role string) (string, error) + ValidateToken(tokenString string) (*Claims, error) + RefreshAccessToken(refreshToken string) (string, error) +} + +// jwtService handles JWT token generation and validation +type jwtService struct { config JWTConfig } // NewJWTService creates a new JWT service with the provided configuration // Returns an error if the signing key is missing -func NewJWTService(config JWTConfig) (*JWTService, error) { +func NewJWTService(config JWTConfig) (JWTManager, error) { if config.SigningKey == "" { return nil, fmt.Errorf("signing key is required") } @@ -48,7 +57,7 @@ func NewJWTService(config JWTConfig) (*JWTService, error) { if config.RefreshTokenExpiry == 0 { config.RefreshTokenExpiry = 7 * 24 * time.Hour // Default to 7 days } - return &JWTService{config: config}, nil + return &jwtService{config: config}, nil } // GenerateAccessToken creates a new access token for a user @@ -56,7 +65,7 @@ func NewJWTService(config JWTConfig) (*JWTService, error) { // - userID: the ID of the user // - role: the role of the user // Returns the signed token string or an error -func (s *JWTService) GenerateAccessToken(userID int, role string) (string, error) { +func (s *jwtService) GenerateAccessToken(userID int, role string) (string, error) { return s.generateToken(userID, role, AccessToken, s.config.AccessTokenExpiry) } @@ -65,7 +74,7 @@ func (s *JWTService) GenerateAccessToken(userID int, role string) (string, error // - userID: the ID of the user // - role: the role of the user // Returns the signed token string or an error -func (s *JWTService) GenerateRefreshToken(userID int, role string) (string, error) { +func (s *jwtService) GenerateRefreshToken(userID int, role string) (string, error) { return s.generateToken(userID, role, RefreshToken, s.config.RefreshTokenExpiry) } @@ -76,7 +85,7 @@ func (s *JWTService) GenerateRefreshToken(userID int, role string) (string, erro // - tokenType: the type of token (access or refresh) // - expiry: how long the token should be valid // Returns the signed token string or an error -func (s *JWTService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) { +func (s *jwtService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) { now := time.Now() claims := Claims{ RegisteredClaims: jwt.RegisteredClaims{ @@ -97,7 +106,7 @@ func (s *JWTService) generateToken(userID int, role string, tokenType TokenType, // Parameters: // - tokenString: the token to validate // Returns the token claims if valid, or an error if invalid -func (s *JWTService) ValidateToken(tokenString string) (*Claims, error) { +func (s *jwtService) ValidateToken(tokenString string) (*Claims, error) { token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { // Validate the signing method if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { @@ -121,7 +130,7 @@ func (s *JWTService) ValidateToken(tokenString string) (*Claims, error) { // Parameters: // - refreshToken: the refresh token to use // Returns a new access token if the refresh token is valid, or an error -func (s *JWTService) RefreshAccessToken(refreshToken string) (string, error) { +func (s *jwtService) RefreshAccessToken(refreshToken string) (string, error) { claims, err := s.ValidateToken(refreshToken) if err != nil { return "", fmt.Errorf("invalid refresh token: %w", err) diff --git a/server/internal/auth/middleware.go b/server/internal/auth/middleware.go index da2713d..864ebd3 100644 --- a/server/internal/auth/middleware.go +++ b/server/internal/auth/middleware.go @@ -11,9 +11,8 @@ import ( type contextKey string -const ( - UserContextKey contextKey = "user" -) +// UserContextKey is the key used to store user claims in the request context +const UserContextKey contextKey = "user" // UserClaims represents the user information stored in the request context type UserClaims struct { @@ -23,17 +22,25 @@ type UserClaims struct { // Middleware handles JWT authentication for protected routes type Middleware struct { - jwtService *JWTService + jwtManager JWTManager } // NewMiddleware creates a new authentication middleware -func NewMiddleware(jwtService *JWTService) *Middleware { +// Parameters: +// - jwtManager: the JWT manager to use for token operations +// Returns: +// - *Middleware: the new middleware instance +func NewMiddleware(jwtManager JWTManager) *Middleware { return &Middleware{ - jwtService: jwtService, + jwtManager: jwtManager, } } // Authenticate middleware validates JWT tokens and sets user information in context +// Parameters: +// - next: the next handler to call +// Returns: +// - http.Handler: the handler function func (m *Middleware) Authenticate(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Extract token from Authorization header @@ -51,7 +58,7 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler { } // Validate token - claims, err := m.jwtService.ValidateToken(parts[1]) + claims, err := m.jwtManager.ValidateToken(parts[1]) if err != nil { http.Error(w, "Invalid token", http.StatusUnauthorized) return @@ -75,6 +82,10 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler { } // RequireRole returns a middleware that ensures the user has the required role +// Parameters: +// - role: the required role +// Returns: +// - func(http.Handler) http.Handler: the middleware function func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -94,6 +105,11 @@ func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler { } } +// RequireWorkspaceAccess returns a middleware that ensures the user has access to the workspace +// Parameters: +// - next: the next handler to call +// Returns: +// - http.Handler: the handler function func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Get our handler context @@ -119,6 +135,11 @@ func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler { } // GetUserFromContext retrieves user claims from the request context +// Parameters: +// - ctx: the request context +// Returns: +// - *UserClaims: the user claims +// - error: any error that occurred func GetUserFromContext(ctx context.Context) (*UserClaims, error) { claims, ok := ctx.Value(UserContextKey).(UserClaims) if !ok { diff --git a/server/internal/auth/session.go b/server/internal/auth/session.go index 8168ccc..bc0d165 100644 --- a/server/internal/auth/session.go +++ b/server/internal/auth/session.go @@ -19,18 +19,18 @@ type Session struct { // SessionService manages user sessions in the database type SessionService struct { - db *sql.DB // Database connection - jwtService *JWTService // JWT service for token operations + db *sql.DB // Database connection + jwtManager JWTManager // JWT Manager for token operations } // NewSessionService creates a new session service // Parameters: // - db: database connection -// - jwtService: JWT service for token operations -func NewSessionService(db *sql.DB, jwtService *JWTService) *SessionService { +// - jwtManager: JWT service for token operations +func NewSessionService(db *sql.DB, jwtManager JWTManager) *SessionService { return &SessionService{ db: db, - jwtService: jwtService, + jwtManager: jwtManager, } } @@ -44,18 +44,18 @@ func NewSessionService(db *sql.DB, jwtService *JWTService) *SessionService { // - error: any error that occurred func (s *SessionService) CreateSession(userID int, role string) (*Session, string, error) { // Generate both access and refresh tokens - accessToken, err := s.jwtService.GenerateAccessToken(userID, role) + accessToken, err := s.jwtManager.GenerateAccessToken(userID, role) if err != nil { return nil, "", fmt.Errorf("failed to generate access token: %w", err) } - refreshToken, err := s.jwtService.GenerateRefreshToken(userID, role) + refreshToken, err := s.jwtManager.GenerateRefreshToken(userID, role) if err != nil { return nil, "", fmt.Errorf("failed to generate refresh token: %w", err) } // Validate the refresh token to get its expiry time - claims, err := s.jwtService.ValidateToken(refreshToken) + claims, err := s.jwtManager.ValidateToken(refreshToken) if err != nil { return nil, "", fmt.Errorf("failed to validate refresh token: %w", err) } @@ -90,7 +90,7 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin // - error: any error that occurred func (s *SessionService) RefreshSession(refreshToken string) (string, error) { // Validate the refresh token - claims, err := s.jwtService.ValidateToken(refreshToken) + claims, err := s.jwtManager.ValidateToken(refreshToken) if err != nil { return "", fmt.Errorf("invalid refresh token: %w", err) } @@ -112,7 +112,7 @@ func (s *SessionService) RefreshSession(refreshToken string) (string, error) { } // Generate a new access token - return s.jwtService.GenerateAccessToken(claims.UserID, claims.Role) + return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role) } // InvalidateSession removes a session from the database From 807e96a76c885a7717ae3e89245b20edc3236939 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Thu, 21 Nov 2024 22:36:12 +0100 Subject: [PATCH 13/38] Rework db package to make it testable --- server/cmd/server/main.go | 3 +- server/internal/api/routes.go | 2 +- server/internal/auth/session.go | 67 +++++------------ server/internal/db/admin.go | 68 ------------------ server/internal/db/db.go | 72 +++++++++++++++++-- server/internal/db/migrations.go | 2 +- server/internal/db/sessions.go | 70 ++++++++++++++++++ .../db/{system_settings.go => system.go} | 42 ++++++++++- server/internal/db/users.go | 51 ++++++++++--- server/internal/db/workspaces.go | 24 +++---- server/internal/handlers/auth_handlers.go | 8 +-- server/internal/handlers/handlers.go | 4 +- server/internal/handlers/user_handlers.go | 3 +- server/internal/middleware/context.go | 2 +- server/internal/models/session.go | 13 ++++ server/internal/user/user.go | 4 +- 16 files changed, 274 insertions(+), 161 deletions(-) delete mode 100644 server/internal/db/admin.go create mode 100644 server/internal/db/sessions.go rename server/internal/db/{system_settings.go => system.go} (57%) create mode 100644 server/internal/models/session.go diff --git a/server/cmd/server/main.go b/server/cmd/server/main.go index 951d8b8..7fbf006 100644 --- a/server/cmd/server/main.go +++ b/server/cmd/server/main.go @@ -1,3 +1,4 @@ +// Package main contains the main entry point for the application. It sets up the server, database, and other services, and starts the server. package main import ( @@ -61,7 +62,7 @@ func main() { authMiddleware := auth.NewMiddleware(jwtManager) // Initialize session service - sessionService := auth.NewSessionService(database.DB, jwtManager) + sessionService := auth.NewSessionService(database, jwtManager) // Set up router r := chi.NewRouter() diff --git a/server/internal/api/routes.go b/server/internal/api/routes.go index 71a873b..35751aa 100644 --- a/server/internal/api/routes.go +++ b/server/internal/api/routes.go @@ -12,7 +12,7 @@ import ( ) // SetupRoutes configures the API routes -func SetupRoutes(r chi.Router, db *db.DB, s storage.Manager, authMiddleware *auth.Middleware, sessionService *auth.SessionService) { +func SetupRoutes(r chi.Router, db db.Database, s storage.Manager, authMiddleware *auth.Middleware, sessionService *auth.SessionService) { handler := &handlers.Handler{ DB: db, diff --git a/server/internal/auth/session.go b/server/internal/auth/session.go index bc0d165..b6217d6 100644 --- a/server/internal/auth/session.go +++ b/server/internal/auth/session.go @@ -1,33 +1,25 @@ package auth import ( - "database/sql" "fmt" + "novamd/internal/db" + "novamd/internal/models" "time" "github.com/google/uuid" ) -// Session represents a user session in the database -type Session struct { - ID string // Unique session identifier - UserID int // ID of the user this session belongs to - RefreshToken string // The refresh token associated with this session - ExpiresAt time.Time // When this session expires - CreatedAt time.Time // When this session was created -} - // SessionService manages user sessions in the database type SessionService struct { - db *sql.DB // Database connection - jwtManager JWTManager // JWT Manager for token operations + db db.SessionStore // Database store for sessions + jwtManager JWTManager // JWT Manager for token operations } // NewSessionService creates a new session service // Parameters: // - db: database connection // - jwtManager: JWT service for token operations -func NewSessionService(db *sql.DB, jwtManager JWTManager) *SessionService { +func NewSessionService(db db.SessionStore, jwtManager JWTManager) *SessionService { return &SessionService{ db: db, jwtManager: jwtManager, @@ -42,7 +34,7 @@ func NewSessionService(db *sql.DB, jwtManager JWTManager) *SessionService { // - session: the created session // - accessToken: a new access token // - error: any error that occurred -func (s *SessionService) CreateSession(userID int, role string) (*Session, string, error) { +func (s *SessionService) CreateSession(userID int, role string) (*models.Session, string, error) { // Generate both access and refresh tokens accessToken, err := s.jwtManager.GenerateAccessToken(userID, role) if err != nil { @@ -61,7 +53,7 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin } // Create a new session record - session := &Session{ + session := &models.Session{ ID: uuid.New().String(), UserID: userID, RefreshToken: refreshToken, @@ -69,14 +61,9 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin CreatedAt: time.Now(), } - // Store the session in the database - _, err = s.db.Exec(` - INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at) - VALUES (?, ?, ?, ?, ?)`, - session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt, - ) - if err != nil { - return nil, "", fmt.Errorf("failed to store session: %w", err) + // Store the session + if err := s.db.CreateSession(session); err != nil { + return nil, "", err } return session, accessToken, nil @@ -89,28 +76,18 @@ func (s *SessionService) CreateSession(userID int, role string) (*Session, strin // - string: a new access token // - error: any error that occurred func (s *SessionService) RefreshSession(refreshToken string) (string, error) { + // Get session from database + _, err := s.db.GetSessionByRefreshToken(refreshToken) + if err != nil { + return "", fmt.Errorf("invalid session: %w", err) + } + // Validate the refresh token claims, err := s.jwtManager.ValidateToken(refreshToken) if err != nil { return "", fmt.Errorf("invalid refresh token: %w", err) } - // Check if the session exists and is not expired - var session Session - err = s.db.QueryRow(` - SELECT id, user_id, refresh_token, expires_at, created_at - FROM sessions - WHERE refresh_token = ? AND expires_at > ?`, - refreshToken, time.Now(), - ).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt) - - if err == sql.ErrNoRows { - return "", fmt.Errorf("session not found or expired") - } - if err != nil { - return "", fmt.Errorf("failed to fetch session: %w", err) - } - // Generate a new access token return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role) } @@ -121,20 +98,12 @@ func (s *SessionService) RefreshSession(refreshToken string) (string, error) { // Returns: // - error: any error that occurred func (s *SessionService) InvalidateSession(sessionID string) error { - _, err := s.db.Exec("DELETE FROM sessions WHERE id = ?", sessionID) - if err != nil { - return fmt.Errorf("failed to invalidate session: %w", err) - } - return nil + return s.db.DeleteSession(sessionID) } // CleanExpiredSessions removes all expired sessions from the database // Returns: // - error: any error that occurred func (s *SessionService) CleanExpiredSessions() error { - _, err := s.db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now()) - if err != nil { - return fmt.Errorf("failed to clean expired sessions: %w", err) - } - return nil + return s.db.CleanExpiredSessions() } diff --git a/server/internal/db/admin.go b/server/internal/db/admin.go deleted file mode 100644 index 15c8a84..0000000 --- a/server/internal/db/admin.go +++ /dev/null @@ -1,68 +0,0 @@ -// Package db provides the database access layer for the application. It contains methods for interacting with the database, such as creating, updating, and deleting records. -package db - -import "novamd/internal/models" - -// UserStats represents system-wide statistics -type UserStats struct { - TotalUsers int `json:"totalUsers"` - TotalWorkspaces int `json:"totalWorkspaces"` - ActiveUsers int `json:"activeUsers"` // Users with activity in last 30 days -} - -// GetAllUsers returns a list of all users in the system -func (db *DB) GetAllUsers() ([]*models.User, error) { - rows, err := db.Query(` - SELECT - id, email, display_name, role, created_at, - last_workspace_id - FROM users - ORDER BY id ASC`) - if err != nil { - return nil, err - } - defer rows.Close() - - var users []*models.User - for rows.Next() { - user := &models.User{} - err := rows.Scan( - &user.ID, &user.Email, &user.DisplayName, &user.Role, - &user.CreatedAt, &user.LastWorkspaceID, - ) - if err != nil { - return nil, err - } - users = append(users, user) - } - return users, nil -} - -// GetSystemStats returns system-wide statistics -func (db *DB) GetSystemStats() (*UserStats, error) { - stats := &UserStats{} - - // Get total users - err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&stats.TotalUsers) - if err != nil { - return nil, err - } - - // Get total workspaces - err = db.QueryRow("SELECT COUNT(*) FROM workspaces").Scan(&stats.TotalWorkspaces) - if err != nil { - return nil, err - } - - // Get active users (users with activity in last 30 days) - err = db.QueryRow(` - SELECT COUNT(DISTINCT user_id) - FROM sessions - WHERE created_at > datetime('now', '-30 days')`). - Scan(&stats.ActiveUsers) - if err != nil { - return nil, err - } - - return stats, nil -} diff --git a/server/internal/db/db.go b/server/internal/db/db.go index 8ca51c3..8cd04e9 100644 --- a/server/internal/db/db.go +++ b/server/internal/db/db.go @@ -1,3 +1,4 @@ +// Package db provides the database access layer for the application. It contains methods for interacting with the database, such as creating, updating, and deleting records. package db import ( @@ -5,18 +6,75 @@ import ( "fmt" "novamd/internal/crypto" + "novamd/internal/models" _ "github.com/mattn/go-sqlite3" // SQLite driver ) -// DB represents the database connection -type DB struct { +// UserStore defines the methods for interacting with user data in the database +type UserStore interface { + CreateUser(user *models.User) (*models.User, error) + GetUserByEmail(email string) (*models.User, error) + GetUserByID(userID int) (*models.User, error) + GetAllUsers() ([]*models.User, error) + UpdateUser(user *models.User) error + DeleteUser(userID int) error + UpdateLastWorkspace(userID int, workspaceName string) error + GetLastWorkspaceName(userID int) (string, error) + CountAdminUsers() (int, error) +} + +// WorkspaceStore defines the methods for interacting with workspace data in the database +type WorkspaceStore interface { + CreateWorkspace(workspace *models.Workspace) error + GetWorkspaceByID(workspaceID int) (*models.Workspace, error) + GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) + GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) + UpdateWorkspace(workspace *models.Workspace) error + DeleteWorkspace(workspaceID int) error + UpdateWorkspaceSettings(workspace *models.Workspace) error + DeleteWorkspaceTx(tx *sql.Tx, workspaceID int) error + UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error + UpdateLastOpenedFile(workspaceID int, filePath string) error + GetLastOpenedFile(workspaceID int) (string, error) + GetAllWorkspaces() ([]*models.Workspace, error) +} + +// SessionStore defines the methods for interacting with jwt sessions in the database +type SessionStore interface { + CreateSession(session *models.Session) error + GetSessionByRefreshToken(refreshToken string) (*models.Session, error) + DeleteSession(sessionID string) error + CleanExpiredSessions() error +} + +// SystemStore defines the methods for interacting with system settings and stats in the database +type SystemStore interface { + GetSystemStats() (*UserStats, error) + EnsureJWTSecret() (string, error) + GetSystemSetting(key string) (string, error) + SetSystemSetting(key, value string) error +} + +// Database defines the methods for interacting with the database +type Database interface { + UserStore + WorkspaceStore + SessionStore + SystemStore + Begin() (*sql.Tx, error) + Close() error + Migrate() error +} + +// database represents the database connection +type database struct { *sql.DB crypto *crypto.Crypto } // Init initializes the database connection -func Init(dbPath string, encryptionKey string) (*DB, error) { +func Init(dbPath string, encryptionKey string) (Database, error) { db, err := sql.Open("sqlite3", dbPath) if err != nil { return nil, err @@ -32,7 +90,7 @@ func Init(dbPath string, encryptionKey string) (*DB, error) { return nil, fmt.Errorf("failed to initialize encryption: %w", err) } - database := &DB{ + database := &database{ DB: db, crypto: cryptoService, } @@ -45,19 +103,19 @@ func Init(dbPath string, encryptionKey string) (*DB, error) { } // Close closes the database connection -func (db *DB) Close() error { +func (db *database) Close() error { return db.DB.Close() } // Helper methods for token encryption/decryption -func (db *DB) encryptToken(token string) (string, error) { +func (db *database) encryptToken(token string) (string, error) { if token == "" { return "", nil } return db.crypto.Encrypt(token) } -func (db *DB) decryptToken(token string) (string, error) { +func (db *database) decryptToken(token string) (string, error) { if token == "" { return "", nil } diff --git a/server/internal/db/migrations.go b/server/internal/db/migrations.go index 11cfc74..f59b20f 100644 --- a/server/internal/db/migrations.go +++ b/server/internal/db/migrations.go @@ -83,7 +83,7 @@ var migrations = []Migration{ } // Migrate applies all database migrations -func (db *DB) Migrate() error { +func (db *database) Migrate() error { // Create migrations table if it doesn't exist _, err := db.Exec(`CREATE TABLE IF NOT EXISTS migrations ( version INTEGER PRIMARY KEY diff --git a/server/internal/db/sessions.go b/server/internal/db/sessions.go new file mode 100644 index 0000000..596dc64 --- /dev/null +++ b/server/internal/db/sessions.go @@ -0,0 +1,70 @@ +package db + +import ( + "database/sql" + "fmt" + "time" + + "novamd/internal/models" +) + +// CreateSession inserts a new session record into the database +func (db *database) CreateSession(session *models.Session) error { + _, err := db.Exec(` + INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at) + VALUES (?, ?, ?, ?, ?)`, + session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt, + ) + if err != nil { + return fmt.Errorf("failed to store session: %w", err) + } + return nil +} + +// GetSessionByRefreshToken retrieves a session by its refresh token +func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) { + session := &models.Session{} + err := db.QueryRow(` + SELECT id, user_id, refresh_token, expires_at, created_at + FROM sessions + WHERE refresh_token = ? AND expires_at > ?`, + refreshToken, time.Now(), + ).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt) + + if err == sql.ErrNoRows { + return nil, fmt.Errorf("session not found or expired") + } + if err != nil { + return nil, fmt.Errorf("failed to fetch session: %w", err) + } + + return session, nil +} + +// DeleteSession removes a session from the database +func (db *database) DeleteSession(sessionID string) error { + result, err := db.Exec("DELETE FROM sessions WHERE id = ?", sessionID) + if err != nil { + return fmt.Errorf("failed to delete session: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rowsAffected == 0 { + return fmt.Errorf("session not found") + } + + return nil +} + +// CleanExpiredSessions removes all expired sessions from the database +func (db *database) CleanExpiredSessions() error { + _, err := db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now()) + if err != nil { + return fmt.Errorf("failed to clean expired sessions: %w", err) + } + return nil +} diff --git a/server/internal/db/system_settings.go b/server/internal/db/system.go similarity index 57% rename from server/internal/db/system_settings.go rename to server/internal/db/system.go index 76fdfa5..f954b34 100644 --- a/server/internal/db/system_settings.go +++ b/server/internal/db/system.go @@ -11,9 +11,16 @@ const ( JWTSecretKey = "jwt_secret" ) +// UserStats represents system-wide statistics +type UserStats struct { + TotalUsers int `json:"totalUsers"` + TotalWorkspaces int `json:"totalWorkspaces"` + ActiveUsers int `json:"activeUsers"` // Users with activity in last 30 days +} + // EnsureJWTSecret makes sure a JWT signing secret exists in the database // If no secret exists, it generates and stores a new one -func (db *DB) EnsureJWTSecret() (string, error) { +func (db *database) EnsureJWTSecret() (string, error) { // First, try to get existing secret secret, err := db.GetSystemSetting(JWTSecretKey) if err == nil { @@ -36,7 +43,7 @@ func (db *DB) EnsureJWTSecret() (string, error) { } // GetSystemSetting retrieves a system setting by key -func (db *DB) GetSystemSetting(key string) (string, error) { +func (db *database) GetSystemSetting(key string) (string, error) { var value string err := db.QueryRow("SELECT value FROM system_settings WHERE key = ?", key).Scan(&value) if err != nil { @@ -46,7 +53,7 @@ func (db *DB) GetSystemSetting(key string) (string, error) { } // SetSystemSetting stores or updates a system setting -func (db *DB) SetSystemSetting(key, value string) error { +func (db *database) SetSystemSetting(key, value string) error { _, err := db.Exec(` INSERT INTO system_settings (key, value) VALUES (?, ?) @@ -64,3 +71,32 @@ func generateRandomSecret(bytes int) (string, error) { } return base64.StdEncoding.EncodeToString(b), nil } + +// GetSystemStats returns system-wide statistics +func (db *database) GetSystemStats() (*UserStats, error) { + stats := &UserStats{} + + // Get total users + err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&stats.TotalUsers) + if err != nil { + return nil, err + } + + // Get total workspaces + err = db.QueryRow("SELECT COUNT(*) FROM workspaces").Scan(&stats.TotalWorkspaces) + if err != nil { + return nil, err + } + + // Get active users (users with activity in last 30 days) + err = db.QueryRow(` + SELECT COUNT(DISTINCT user_id) + FROM sessions + WHERE created_at > datetime('now', '-30 days')`). + Scan(&stats.ActiveUsers) + if err != nil { + return nil, err + } + + return stats, nil +} diff --git a/server/internal/db/users.go b/server/internal/db/users.go index 6a040fb..17cc374 100644 --- a/server/internal/db/users.go +++ b/server/internal/db/users.go @@ -6,7 +6,7 @@ import ( ) // CreateUser inserts a new user record into the database -func (db *DB) CreateUser(user *models.User) (*models.User, error) { +func (db *database) CreateUser(user *models.User) (*models.User, error) { tx, err := db.Begin() if err != nil { return nil, err @@ -62,7 +62,7 @@ func (db *DB) CreateUser(user *models.User) (*models.User, error) { } // Helper function to create a workspace in a transaction -func (db *DB) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error { +func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error { result, err := tx.Exec(` INSERT INTO workspaces ( user_id, name, @@ -87,7 +87,7 @@ func (db *DB) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error { } // GetUserByID retrieves a user by ID -func (db *DB) GetUserByID(id int) (*models.User, error) { +func (db *database) GetUserByID(id int) (*models.User, error) { user := &models.User{} err := db.QueryRow(` SELECT @@ -104,7 +104,7 @@ func (db *DB) GetUserByID(id int) (*models.User, error) { } // GetUserByEmail retrieves a user by email -func (db *DB) GetUserByEmail(email string) (*models.User, error) { +func (db *database) GetUserByEmail(email string) (*models.User, error) { user := &models.User{} err := db.QueryRow(` SELECT @@ -122,7 +122,7 @@ func (db *DB) GetUserByEmail(email string) (*models.User, error) { } // UpdateUser updates a user's information -func (db *DB) UpdateUser(user *models.User) error { +func (db *database) UpdateUser(user *models.User) error { _, err := db.Exec(` UPDATE users SET email = ?, display_name = ?, password_hash = ?, role = ?, last_workspace_id = ? @@ -131,8 +131,36 @@ func (db *DB) UpdateUser(user *models.User) error { return err } +// GetAllUsers returns a list of all users in the system +func (db *database) GetAllUsers() ([]*models.User, error) { + rows, err := db.Query(` + SELECT + id, email, display_name, role, created_at, + last_workspace_id + FROM users + ORDER BY id ASC`) + if err != nil { + return nil, err + } + defer rows.Close() + + var users []*models.User + for rows.Next() { + user := &models.User{} + err := rows.Scan( + &user.ID, &user.Email, &user.DisplayName, &user.Role, + &user.CreatedAt, &user.LastWorkspaceID, + ) + if err != nil { + return nil, err + } + users = append(users, user) + } + return users, nil +} + // UpdateLastWorkspace updates the last workspace the user accessed -func (db *DB) UpdateLastWorkspace(userID int, workspaceName string) error { +func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error { tx, err := db.Begin() if err != nil { return err @@ -155,7 +183,7 @@ func (db *DB) UpdateLastWorkspace(userID int, workspaceName string) error { } // DeleteUser deletes a user and all their workspaces -func (db *DB) DeleteUser(id int) error { +func (db *database) DeleteUser(id int) error { tx, err := db.Begin() if err != nil { return err @@ -178,7 +206,7 @@ func (db *DB) DeleteUser(id int) error { } // GetLastWorkspaceName returns the name of the last workspace the user accessed -func (db *DB) GetLastWorkspaceName(userID int) (string, error) { +func (db *database) GetLastWorkspaceName(userID int) (string, error) { var workspaceName string err := db.QueryRow(` SELECT @@ -189,3 +217,10 @@ func (db *DB) GetLastWorkspaceName(userID int) (string, error) { Scan(&workspaceName) return workspaceName, err } + +// CountAdminUsers returns the number of admin users in the system +func (db *database) CountAdminUsers() (int, error) { + var count int + err := db.QueryRow("SELECT COUNT(*) FROM users WHERE role = 'admin'").Scan(&count) + return count, err +} diff --git a/server/internal/db/workspaces.go b/server/internal/db/workspaces.go index 68be97a..ce39ce5 100644 --- a/server/internal/db/workspaces.go +++ b/server/internal/db/workspaces.go @@ -7,7 +7,7 @@ import ( ) // CreateWorkspace inserts a new workspace record into the database -func (db *DB) CreateWorkspace(workspace *models.Workspace) error { +func (db *database) CreateWorkspace(workspace *models.Workspace) error { // Set default settings if not provided if workspace.Theme == "" { workspace.GetDefaultSettings() @@ -42,7 +42,7 @@ func (db *DB) CreateWorkspace(workspace *models.Workspace) error { } // GetWorkspaceByID retrieves a workspace by its ID -func (db *DB) GetWorkspaceByID(id int) (*models.Workspace, error) { +func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) { workspace := &models.Workspace{} var encryptedToken string @@ -75,7 +75,7 @@ func (db *DB) GetWorkspaceByID(id int) (*models.Workspace, error) { } // GetWorkspaceByName retrieves a workspace by its name and user ID -func (db *DB) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) { +func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) { workspace := &models.Workspace{} var encryptedToken string @@ -108,7 +108,7 @@ func (db *DB) GetWorkspaceByName(userID int, workspaceName string) (*models.Work } // UpdateWorkspace updates a workspace record in the database -func (db *DB) UpdateWorkspace(workspace *models.Workspace) error { +func (db *database) UpdateWorkspace(workspace *models.Workspace) error { // Encrypt token before storing encryptedToken, err := db.encryptToken(workspace.GitToken) if err != nil { @@ -146,7 +146,7 @@ func (db *DB) UpdateWorkspace(workspace *models.Workspace) error { } // GetWorkspacesByUserID retrieves all workspaces for a user -func (db *DB) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) { +func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) { rows, err := db.Query(` SELECT id, user_id, name, created_at, @@ -189,7 +189,7 @@ func (db *DB) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) { // UpdateWorkspaceSettings updates only the settings portion of a workspace // This is useful when you don't want to modify the name or other core workspace properties -func (db *DB) UpdateWorkspaceSettings(workspace *models.Workspace) error { +func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error { _, err := db.Exec(` UPDATE workspaces SET @@ -218,31 +218,31 @@ func (db *DB) UpdateWorkspaceSettings(workspace *models.Workspace) error { } // DeleteWorkspace removes a workspace record from the database -func (db *DB) DeleteWorkspace(id int) error { +func (db *database) DeleteWorkspace(id int) error { _, err := db.Exec("DELETE FROM workspaces WHERE id = ?", id) return err } // DeleteWorkspaceTx removes a workspace record from the database within a transaction -func (db *DB) DeleteWorkspaceTx(tx *sql.Tx, id int) error { +func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error { _, err := tx.Exec("DELETE FROM workspaces WHERE id = ?", id) return err } // UpdateLastWorkspaceTx sets the last workspace for a user in with a transaction -func (db *DB) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error { +func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error { _, err := tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?", workspaceID, userID) return err } // UpdateLastOpenedFile updates the last opened file path for a workspace -func (db *DB) UpdateLastOpenedFile(workspaceID int, filePath string) error { +func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error { _, err := db.Exec("UPDATE workspaces SET last_opened_file_path = ? WHERE id = ?", filePath, workspaceID) return err } // GetLastOpenedFile retrieves the last opened file path for a workspace -func (db *DB) GetLastOpenedFile(workspaceID int) (string, error) { +func (db *database) GetLastOpenedFile(workspaceID int) (string, error) { var filePath sql.NullString err := db.QueryRow("SELECT last_opened_file_path FROM workspaces WHERE id = ?", workspaceID).Scan(&filePath) if err != nil { @@ -255,7 +255,7 @@ func (db *DB) GetLastOpenedFile(workspaceID int) (string, error) { } // GetAllWorkspaces retrieves all workspaces in the database -func (db *DB) GetAllWorkspaces() ([]*models.Workspace, error) { +func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) { rows, err := db.Query(` SELECT id, user_id, name, created_at, diff --git a/server/internal/handlers/auth_handlers.go b/server/internal/handlers/auth_handlers.go index 319b6d9..cce4c30 100644 --- a/server/internal/handlers/auth_handlers.go +++ b/server/internal/handlers/auth_handlers.go @@ -16,10 +16,10 @@ type LoginRequest struct { } type LoginResponse struct { - AccessToken string `json:"accessToken"` - RefreshToken string `json:"refreshToken"` - User *models.User `json:"user"` - Session *auth.Session `json:"session"` + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` + User *models.User `json:"user"` + Session *models.Session `json:"session"` } type RefreshRequest struct { diff --git a/server/internal/handlers/handlers.go b/server/internal/handlers/handlers.go index 743e40d..7af3611 100644 --- a/server/internal/handlers/handlers.go +++ b/server/internal/handlers/handlers.go @@ -9,12 +9,12 @@ import ( // Handler provides common functionality for all handlers type Handler struct { - DB *db.DB + DB db.Database Storage storage.Manager } // NewHandler creates a new handler with the given dependencies -func NewHandler(db *db.DB, s storage.Manager) *Handler { +func NewHandler(db db.Database, s storage.Manager) *Handler { return &Handler{ DB: db, Storage: s, diff --git a/server/internal/handlers/user_handlers.go b/server/internal/handlers/user_handlers.go index fa57020..013baf0 100644 --- a/server/internal/handlers/user_handlers.go +++ b/server/internal/handlers/user_handlers.go @@ -171,8 +171,7 @@ func (h *Handler) DeleteAccount() http.HandlerFunc { // Prevent admin from deleting their own account if they're the last admin if user.Role == "admin" { // Count number of admin users - adminCount := 0 - err := h.DB.QueryRow("SELECT COUNT(*) FROM users WHERE role = 'admin'").Scan(&adminCount) + adminCount, err := h.DB.CountAdminUsers() if err != nil { http.Error(w, "Failed to verify admin status", http.StatusInternalServerError) return diff --git a/server/internal/middleware/context.go b/server/internal/middleware/context.go index 95a72d4..288ed24 100644 --- a/server/internal/middleware/context.go +++ b/server/internal/middleware/context.go @@ -29,7 +29,7 @@ func WithUserContext(next http.Handler) http.Handler { } // Workspace context -func WithWorkspaceContext(db *db.DB) func(http.Handler) http.Handler { +func WithWorkspaceContext(db db.Database) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx, ok := httpcontext.GetRequestContext(w, r) diff --git a/server/internal/models/session.go b/server/internal/models/session.go new file mode 100644 index 0000000..d0b8119 --- /dev/null +++ b/server/internal/models/session.go @@ -0,0 +1,13 @@ +// Package models contains the data models used throughout the application. These models are used to represent data in the database, as well as to validate and serialize data in the application. +package models + +import "time" + +// Session represents a user session in the database +type Session struct { + ID string // Unique session identifier + UserID int // ID of the user this session belongs to + RefreshToken string // The refresh token associated with this session + ExpiresAt time.Time // When this session expires + CreatedAt time.Time // When this session was created +} diff --git a/server/internal/user/user.go b/server/internal/user/user.go index 638ec39..20383ca 100644 --- a/server/internal/user/user.go +++ b/server/internal/user/user.go @@ -13,11 +13,11 @@ import ( ) type UserService struct { - DB *db.DB + DB db.Database Storage storage.Manager } -func NewUserService(database *db.DB, s storage.Manager) *UserService { +func NewUserService(database db.Database, s storage.Manager) *UserService { return &UserService{ DB: database, Storage: s, From b3ec4e136ce4c6db3d6640964b1e43337eab34ce Mon Sep 17 00:00:00 2001 From: LordMathis Date: Fri, 22 Nov 2024 23:17:59 +0100 Subject: [PATCH 14/38] Implement auth tests --- server/internal/auth/jwt_test.go | 221 +++++++++++++++ server/internal/auth/middleware_test.go | 362 ++++++++++++++++++++++++ server/internal/auth/session_test.go | 298 +++++++++++++++++++ 3 files changed, 881 insertions(+) create mode 100644 server/internal/auth/jwt_test.go create mode 100644 server/internal/auth/middleware_test.go create mode 100644 server/internal/auth/session_test.go diff --git a/server/internal/auth/jwt_test.go b/server/internal/auth/jwt_test.go new file mode 100644 index 0000000..61aa3ad --- /dev/null +++ b/server/internal/auth/jwt_test.go @@ -0,0 +1,221 @@ +package auth_test + +import ( + "testing" + "time" + + "novamd/internal/auth" + + "github.com/golang-jwt/jwt/v5" +) + +// jwt_test.go tests + +func TestNewJWTService(t *testing.T) { + testCases := []struct { + name string + config auth.JWTConfig + wantErr bool + }{ + { + name: "valid configuration", + config: auth.JWTConfig{ + SigningKey: "test-key", + AccessTokenExpiry: 15 * time.Minute, + RefreshTokenExpiry: 24 * time.Hour, + }, + wantErr: false, + }, + { + name: "missing signing key", + config: auth.JWTConfig{ + AccessTokenExpiry: 15 * time.Minute, + RefreshTokenExpiry: 24 * time.Hour, + }, + wantErr: true, + }, + { + name: "zero expiry times", + config: auth.JWTConfig{ + SigningKey: "test-key", + }, + wantErr: false, // Should use default values + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + service, err := auth.NewJWTService(tc.config) + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + return + } + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if service == nil { + t.Error("expected service, got nil") + } + }) + } +} + +func TestGenerateAndValidateToken(t *testing.T) { + config := auth.JWTConfig{ + SigningKey: "test-key", + AccessTokenExpiry: 15 * time.Minute, + RefreshTokenExpiry: 24 * time.Hour, + } + service, _ := auth.NewJWTService(config) + + testCases := []struct { + name string + userID int + role string + tokenType auth.TokenType + wantErr bool + }{ + { + name: "valid access token", + userID: 1, + role: "admin", + tokenType: auth.AccessToken, + wantErr: false, + }, + { + name: "valid refresh token", + userID: 1, + role: "editor", + tokenType: auth.RefreshToken, + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var token string + var err error + + // Generate token based on type + if tc.tokenType == auth.AccessToken { + token, err = service.GenerateAccessToken(tc.userID, tc.role) + } else { + token, err = service.GenerateRefreshToken(tc.userID, tc.role) + } + + if err != nil { + t.Fatalf("failed to generate token: %v", err) + } + + // Validate token + claims, err := service.ValidateToken(token) + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + // Verify claims + if claims.UserID != tc.userID { + t.Errorf("userID = %v, want %v", claims.UserID, tc.userID) + } + if claims.Role != tc.role { + t.Errorf("role = %v, want %v", claims.Role, tc.role) + } + if claims.Type != tc.tokenType { + t.Errorf("type = %v, want %v", claims.Type, tc.tokenType) + } + }) + } +} + +func TestRefreshAccessToken(t *testing.T) { + config := auth.JWTConfig{ + SigningKey: "test-key", + AccessTokenExpiry: 15 * time.Minute, + RefreshTokenExpiry: 24 * time.Hour, + } + service, _ := auth.NewJWTService(config) + + testCases := []struct { + name string + userID int + role string + wantErr bool + setupFunc func() string // Added setup function to handle custom token creation + }{ + { + name: "valid refresh token", + userID: 1, + role: "admin", + wantErr: false, + setupFunc: func() string { + token, _ := service.GenerateRefreshToken(1, "admin") + return token + }, + }, + { + name: "expired refresh token", + userID: 1, + role: "admin", + wantErr: true, + setupFunc: func() string { + // Create a token that's already expired + claims := &auth.Claims{ + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(-time.Hour)), // Expired 1 hour ago + IssuedAt: jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)), + NotBefore: jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)), + }, + UserID: 1, + Role: "admin", + Type: auth.RefreshToken, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, _ := token.SignedString([]byte(config.SigningKey)) + return tokenString + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + refreshToken := tc.setupFunc() + newAccessToken, err := service.RefreshAccessToken(refreshToken) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + claims, err := service.ValidateToken(newAccessToken) + if err != nil { + t.Fatalf("failed to validate new access token: %v", err) + } + + if claims.UserID != tc.userID { + t.Errorf("userID = %v, want %v", claims.UserID, tc.userID) + } + if claims.Role != tc.role { + t.Errorf("role = %v, want %v", claims.Role, tc.role) + } + if claims.Type != auth.AccessToken { + t.Errorf("token type = %v, want %v", claims.Type, auth.AccessToken) + } + }) + } +} diff --git a/server/internal/auth/middleware_test.go b/server/internal/auth/middleware_test.go new file mode 100644 index 0000000..153bf33 --- /dev/null +++ b/server/internal/auth/middleware_test.go @@ -0,0 +1,362 @@ +package auth_test + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "novamd/internal/auth" + "novamd/internal/httpcontext" + "novamd/internal/models" +) + +// Complete mockResponseWriter implementation +type mockResponseWriter struct { + headers http.Header + statusCode int + written []byte +} + +func newMockResponseWriter() *mockResponseWriter { + return &mockResponseWriter{ + headers: make(http.Header), + } +} + +func (m *mockResponseWriter) Header() http.Header { + return m.headers +} + +func (m *mockResponseWriter) Write(b []byte) (int, error) { + m.written = b + return len(b), nil +} + +func (m *mockResponseWriter) WriteHeader(statusCode int) { + m.statusCode = statusCode +} + +func TestAuthenticateMiddleware(t *testing.T) { + config := auth.JWTConfig{ + SigningKey: "test-key", + AccessTokenExpiry: 15 * time.Minute, + RefreshTokenExpiry: 24 * time.Hour, + } + jwtService, _ := auth.NewJWTService(config) + middleware := auth.NewMiddleware(jwtService) + + testCases := []struct { + name string + setupAuth func() string + wantStatusCode int + }{ + { + name: "valid token", + setupAuth: func() string { + token, _ := jwtService.GenerateAccessToken(1, "admin") + return token + }, + wantStatusCode: http.StatusOK, + }, + { + name: "missing auth header", + setupAuth: func() string { + return "" + }, + wantStatusCode: http.StatusUnauthorized, + }, + { + name: "invalid auth format", + setupAuth: func() string { + return "InvalidFormat token" + }, + wantStatusCode: http.StatusUnauthorized, + }, + { + name: "invalid token", + setupAuth: func() string { + return "Bearer invalid.token.here" + }, + wantStatusCode: http.StatusUnauthorized, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create test request + req := httptest.NewRequest("GET", "/test", nil) + if token := tc.setupAuth(); token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + + // Create response recorder + w := newMockResponseWriter() + + // Create test handler + nextCalled := false + next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + }) + + // Execute middleware + middleware.Authenticate(next).ServeHTTP(w, req) + + // Check status code + if w.statusCode != tc.wantStatusCode { + t.Errorf("status code = %v, want %v", w.statusCode, tc.wantStatusCode) + } + + // Check if next handler was called when expected + if tc.wantStatusCode == http.StatusOK && !nextCalled { + t.Error("next handler was not called") + } + if tc.wantStatusCode != http.StatusOK && nextCalled { + t.Error("next handler was called when it shouldn't have been") + } + }) + } +} + +func TestRequireRole(t *testing.T) { + config := auth.JWTConfig{ + SigningKey: "test-key", + AccessTokenExpiry: 15 * time.Minute, + RefreshTokenExpiry: 24 * time.Hour, + } + jwtService, _ := auth.NewJWTService(config) + middleware := auth.NewMiddleware(jwtService) + + testCases := []struct { + name string + userRole string + requiredRole string + wantStatusCode int + }{ + { + name: "matching role", + userRole: "admin", + requiredRole: "admin", + wantStatusCode: http.StatusOK, + }, + { + name: "admin accessing other role", + userRole: "admin", + requiredRole: "editor", + wantStatusCode: http.StatusOK, + }, + { + name: "insufficient role", + userRole: "editor", + requiredRole: "admin", + wantStatusCode: http.StatusForbidden, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create context with user claims + ctx := context.WithValue(context.Background(), auth.UserContextKey, auth.UserClaims{ + UserID: 1, + Role: tc.userRole, + }) + req := httptest.NewRequest("GET", "/test", nil).WithContext(ctx) + w := newMockResponseWriter() + + // Create test handler + nextCalled := false + next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + }) + + // Execute middleware + middleware.RequireRole(tc.requiredRole)(next).ServeHTTP(w, req) + + // Check status code + if w.statusCode != tc.wantStatusCode { + t.Errorf("status code = %v, want %v", w.statusCode, tc.wantStatusCode) + } + + // Check if next handler was called when expected + if tc.wantStatusCode == http.StatusOK && !nextCalled { + t.Error("next handler was not called") + } + if tc.wantStatusCode != http.StatusOK && nextCalled { + t.Error("next handler was called when it shouldn't have been") + } + }) + } +} + +func TestRequireWorkspaceAccess(t *testing.T) { + config := auth.JWTConfig{ + SigningKey: "test-key", + } + jwtService, _ := auth.NewJWTService(config) + middleware := auth.NewMiddleware(jwtService) + + testCases := []struct { + name string + setupContext func() *httpcontext.HandlerContext + wantStatusCode int + }{ + { + name: "workspace owner access", + setupContext: func() *httpcontext.HandlerContext { + return &httpcontext.HandlerContext{ + UserID: 1, + UserRole: "editor", + Workspace: &models.Workspace{ + ID: 1, + UserID: 1, // Same as context UserID + }, + } + }, + wantStatusCode: http.StatusOK, + }, + { + name: "admin access to other's workspace", + setupContext: func() *httpcontext.HandlerContext { + return &httpcontext.HandlerContext{ + UserID: 2, + UserRole: "admin", + Workspace: &models.Workspace{ + ID: 1, + UserID: 1, // Different from context UserID + }, + } + }, + wantStatusCode: http.StatusOK, + }, + { + name: "unauthorized access attempt", + setupContext: func() *httpcontext.HandlerContext { + return &httpcontext.HandlerContext{ + UserID: 2, + UserRole: "editor", + Workspace: &models.Workspace{ + ID: 1, + UserID: 1, // Different from context UserID + }, + } + }, + wantStatusCode: http.StatusNotFound, + }, + { + name: "no workspace in context", + setupContext: func() *httpcontext.HandlerContext { + return &httpcontext.HandlerContext{ + UserID: 1, + UserRole: "editor", + Workspace: nil, + } + }, + wantStatusCode: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create request with context + ctx := context.WithValue(context.Background(), httpcontext.HandlerContextKey, tc.setupContext()) + req := httptest.NewRequest("GET", "/test", nil).WithContext(ctx) + w := newMockResponseWriter() + + // Create test handler + nextCalled := false + next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + }) + + // Execute middleware + middleware.RequireWorkspaceAccess(next).ServeHTTP(w, req) + + // Check status code + if w.statusCode != tc.wantStatusCode { + t.Errorf("status code = %v, want %v", w.statusCode, tc.wantStatusCode) + } + + // Check if next handler was called when expected + if tc.wantStatusCode == http.StatusOK && !nextCalled { + t.Error("next handler was not called") + } + if tc.wantStatusCode != http.StatusOK && nextCalled { + t.Error("next handler was called when it shouldn't have been") + } + }) + } +} + +func TestGetUserFromContext(t *testing.T) { + testCases := []struct { + name string + setupCtx func() context.Context + wantUserID int + wantRole string + wantErr bool + errContains string + }{ + { + name: "valid user context", + setupCtx: func() context.Context { + return context.WithValue(context.Background(), auth.UserContextKey, auth.UserClaims{ + UserID: 1, + Role: "admin", + }) + }, + wantUserID: 1, + wantRole: "admin", + wantErr: false, + }, + { + name: "missing user context", + setupCtx: func() context.Context { + return context.Background() + }, + wantErr: true, + errContains: "no user found in context", + }, + { + name: "invalid context value type", + setupCtx: func() context.Context { + return context.WithValue(context.Background(), auth.UserContextKey, "invalid") + }, + wantErr: true, + errContains: "no user found in context", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := tc.setupCtx() + claims, err := auth.GetUserFromContext(ctx) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } else if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("error = %v, want error containing %v", err, tc.errContains) + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if claims.UserID != tc.wantUserID { + t.Errorf("UserID = %v, want %v", claims.UserID, tc.wantUserID) + } + + if claims.Role != tc.wantRole { + t.Errorf("Role = %v, want %v", claims.Role, tc.wantRole) + } + }) + } +} diff --git a/server/internal/auth/session_test.go b/server/internal/auth/session_test.go new file mode 100644 index 0000000..00c7494 --- /dev/null +++ b/server/internal/auth/session_test.go @@ -0,0 +1,298 @@ +package auth_test + +import ( + "errors" + "strings" + "testing" + "time" + + "novamd/internal/auth" + "novamd/internal/models" +) + +// Mock SessionStore +type mockSessionStore struct { + sessions map[string]*models.Session + sessionsByToken map[string]*models.Session // Added index by refresh token +} + +func newMockSessionStore() *mockSessionStore { + return &mockSessionStore{ + sessions: make(map[string]*models.Session), + sessionsByToken: make(map[string]*models.Session), + } +} + +func (m *mockSessionStore) CreateSession(session *models.Session) error { + m.sessions[session.ID] = session + m.sessionsByToken[session.RefreshToken] = session + return nil +} + +func (m *mockSessionStore) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) { + session, exists := m.sessionsByToken[refreshToken] + if !exists { + return nil, errors.New("session not found") + } + if session.ExpiresAt.Before(time.Now()) { + return nil, errors.New("session expired") + } + return session, nil +} + +func (m *mockSessionStore) DeleteSession(sessionID string) error { + session, exists := m.sessions[sessionID] + if !exists { + return errors.New("session not found") + } + delete(m.sessionsByToken, session.RefreshToken) + delete(m.sessions, sessionID) + return nil +} + +func (m *mockSessionStore) CleanExpiredSessions() error { + for id, session := range m.sessions { + if session.ExpiresAt.Before(time.Now()) { + delete(m.sessionsByToken, session.RefreshToken) + delete(m.sessions, id) + } + } + return nil +} + +func TestCreateSession(t *testing.T) { + config := auth.JWTConfig{ + SigningKey: "test-key", + AccessTokenExpiry: 15 * time.Minute, + RefreshTokenExpiry: 24 * time.Hour, + } + jwtService, _ := auth.NewJWTService(config) + mockDB := newMockSessionStore() + sessionService := auth.NewSessionService(mockDB, jwtService) + + testCases := []struct { + name string + userID int + role string + wantErr bool + }{ + { + name: "successful session creation", + userID: 1, + role: "admin", + wantErr: false, + }, + { + name: "another successful session", + userID: 2, + role: "editor", + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + session, accessToken, err := sessionService.CreateSession(tc.userID, tc.role) + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + // Verify session + if session.UserID != tc.userID { + t.Errorf("userID = %v, want %v", session.UserID, tc.userID) + } + + // Verify the session was stored + storedSession, exists := mockDB.sessions[session.ID] + if !exists { + t.Error("session was not stored in database") + } + if storedSession.RefreshToken != session.RefreshToken { + t.Error("stored refresh token doesn't match") + } + + // Verify access token + claims, err := jwtService.ValidateToken(accessToken) + if err != nil { + t.Errorf("failed to validate access token: %v", err) + return + } + if claims.UserID != tc.userID { + t.Errorf("access token userID = %v, want %v", claims.UserID, tc.userID) + } + if claims.Role != tc.role { + t.Errorf("access token role = %v, want %v", claims.Role, tc.role) + } + if claims.Type != auth.AccessToken { + t.Errorf("token type = %v, want access token", claims.Type) + } + }) + } +} + +func TestRefreshSession(t *testing.T) { + config := auth.JWTConfig{ + SigningKey: "test-key", + AccessTokenExpiry: 15 * time.Minute, + RefreshTokenExpiry: 24 * time.Hour, + } + jwtService, _ := auth.NewJWTService(config) + mockDB := newMockSessionStore() + sessionService := auth.NewSessionService(mockDB, jwtService) + + testCases := []struct { + name string + setupSession func() string + wantErr bool + errorContains string + }{ + { + name: "valid refresh token", + setupSession: func() string { + token, _ := jwtService.GenerateRefreshToken(1, "admin") + session := &models.Session{ + ID: "test-session-1", + UserID: 1, + RefreshToken: token, + ExpiresAt: time.Now().Add(24 * time.Hour), + CreatedAt: time.Now(), + } + mockDB.CreateSession(session) + return token + }, + wantErr: false, + }, + { + name: "expired refresh token", + setupSession: func() string { + token, _ := jwtService.GenerateRefreshToken(1, "admin") + session := &models.Session{ + ID: "test-session-2", + UserID: 1, + RefreshToken: token, + ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired + CreatedAt: time.Now().Add(-2 * time.Hour), + } + mockDB.CreateSession(session) + return token + }, + wantErr: true, + errorContains: "session expired", + }, + { + name: "non-existent refresh token", + setupSession: func() string { + return "non-existent-token" + }, + wantErr: true, + errorContains: "session not found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + refreshToken := tc.setupSession() + newAccessToken, err := sessionService.RefreshSession(refreshToken) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } else if tc.errorContains != "" && !strings.Contains(err.Error(), tc.errorContains) { + t.Errorf("error = %v, want error containing %v", err, tc.errorContains) + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + // Verify new access token + claims, err := jwtService.ValidateToken(newAccessToken) + if err != nil { + t.Errorf("failed to validate new access token: %v", err) + return + } + if claims.Type != auth.AccessToken { + t.Errorf("token type = %v, want access token", claims.Type) + } + }) + } +} + +func TestInvalidateSession(t *testing.T) { + config := auth.JWTConfig{ + SigningKey: "test-key", + AccessTokenExpiry: 15 * time.Minute, + RefreshTokenExpiry: 24 * time.Hour, + } + jwtService, _ := auth.NewJWTService(config) + mockDB := newMockSessionStore() + sessionService := auth.NewSessionService(mockDB, jwtService) + + testCases := []struct { + name string + setupSession func() string + wantErr bool + errorContains string + }{ + { + name: "valid session invalidation", + setupSession: func() string { + session := &models.Session{ + ID: "test-session-1", + UserID: 1, + RefreshToken: "valid-token", + ExpiresAt: time.Now().Add(24 * time.Hour), + CreatedAt: time.Now(), + } + mockDB.CreateSession(session) + return session.ID + }, + wantErr: false, + }, + { + name: "non-existent session", + setupSession: func() string { + return "non-existent-session-id" + }, + wantErr: true, + errorContains: "session not found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + sessionID := tc.setupSession() + err := sessionService.InvalidateSession(sessionID) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), tc.errorContains) { + t.Errorf("error = %v, want error containing %v", err, tc.errorContains) + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + // Verify session was removed + if _, exists := mockDB.sessions[sessionID]; exists { + t.Error("session still exists after invalidation") + } + }) + } +} From ebdd7bd74140d1f90fbbbf788cbe569b00049049 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sat, 23 Nov 2024 00:29:26 +0100 Subject: [PATCH 15/38] Implement auth package tests --- server/internal/api/routes.go | 6 +- server/internal/auth/middleware.go | 51 ++------ server/internal/auth/middleware_test.go | 112 ++++-------------- server/internal/context/context.go | 62 ++++++++++ .../context.go => context/middleware.go} | 23 ++-- server/internal/handlers/admin_handlers.go | 4 +- server/internal/handlers/auth_handlers.go | 4 +- server/internal/handlers/file_handlers.go | 16 +-- server/internal/handlers/git_handlers.go | 6 +- server/internal/handlers/user_handlers.go | 8 +- .../internal/handlers/workspace_handlers.go | 16 +-- server/internal/httpcontext/context.go | 31 ----- 12 files changed, 136 insertions(+), 203 deletions(-) create mode 100644 server/internal/context/context.go rename server/internal/{middleware/context.go => context/middleware.go} (59%) delete mode 100644 server/internal/httpcontext/context.go diff --git a/server/internal/api/routes.go b/server/internal/api/routes.go index 35751aa..1409ddc 100644 --- a/server/internal/api/routes.go +++ b/server/internal/api/routes.go @@ -3,9 +3,9 @@ package api import ( "novamd/internal/auth" + "novamd/internal/context" "novamd/internal/db" "novamd/internal/handlers" - "novamd/internal/middleware" "novamd/internal/storage" "github.com/go-chi/chi/v5" @@ -29,7 +29,7 @@ func SetupRoutes(r chi.Router, db db.Database, s storage.Manager, authMiddleware r.Group(func(r chi.Router) { // Apply authentication middleware to all routes in this group r.Use(authMiddleware.Authenticate) - r.Use(middleware.WithUserContext) + r.Use(context.WithUserContextMiddleware) // Auth routes r.Post("/auth/logout", handler.Logout(sessionService)) @@ -67,7 +67,7 @@ func SetupRoutes(r chi.Router, db db.Database, s storage.Manager, authMiddleware // Single workspace routes r.Route("/{workspaceName}", func(r chi.Router) { - r.Use(middleware.WithWorkspaceContext(db)) + r.Use(context.WithWorkspaceContextMiddleware(db)) r.Use(authMiddleware.RequireWorkspaceAccess) r.Get("/", handler.GetWorkspace()) diff --git a/server/internal/auth/middleware.go b/server/internal/auth/middleware.go index 864ebd3..e669460 100644 --- a/server/internal/auth/middleware.go +++ b/server/internal/auth/middleware.go @@ -1,25 +1,12 @@ package auth import ( - "context" - "fmt" "net/http" "strings" - "novamd/internal/httpcontext" + "novamd/internal/context" ) -type contextKey string - -// UserContextKey is the key used to store user claims in the request context -const UserContextKey contextKey = "user" - -// UserClaims represents the user information stored in the request context -type UserClaims struct { - UserID int - Role string -} - // Middleware handles JWT authentication for protected routes type Middleware struct { jwtManager JWTManager @@ -70,14 +57,14 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler { return } - // Add user claims to request context - ctx := context.WithValue(r.Context(), UserContextKey, UserClaims{ - UserID: claims.UserID, - Role: claims.Role, - }) + // Create handler context with user information + hctx := &context.HandlerContext{ + UserID: claims.UserID, + UserRole: claims.Role, + } - // Call the next handler with the updated context - next.ServeHTTP(w, r.WithContext(ctx)) + // Add context to request and continue + next.ServeHTTP(w, context.WithHandlerContext(r, hctx)) }) } @@ -89,13 +76,12 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler { func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - claims, ok := r.Context().Value(UserContextKey).(UserClaims) + ctx, ok := context.GetRequestContext(w, r) if !ok { - http.Error(w, "Unauthorized", http.StatusUnauthorized) return } - if claims.Role != role && claims.Role != "admin" { + if ctx.UserRole != role && ctx.UserRole != "admin" { http.Error(w, "Insufficient permissions", http.StatusForbidden) return } @@ -112,8 +98,7 @@ func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler { // - http.Handler: the handler function func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Get our handler context - ctx, ok := httpcontext.GetRequestContext(w, r) + ctx, ok := context.GetRequestContext(w, r) if !ok { return } @@ -133,17 +118,3 @@ func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler { next.ServeHTTP(w, r) }) } - -// GetUserFromContext retrieves user claims from the request context -// Parameters: -// - ctx: the request context -// Returns: -// - *UserClaims: the user claims -// - error: any error that occurred -func GetUserFromContext(ctx context.Context) (*UserClaims, error) { - claims, ok := ctx.Value(UserContextKey).(UserClaims) - if !ok { - return nil, fmt.Errorf("no user found in context") - } - return &claims, nil -} diff --git a/server/internal/auth/middleware_test.go b/server/internal/auth/middleware_test.go index 153bf33..71bdc9d 100644 --- a/server/internal/auth/middleware_test.go +++ b/server/internal/auth/middleware_test.go @@ -1,15 +1,13 @@ package auth_test import ( - "context" "net/http" "net/http/httptest" - "strings" "testing" "time" "novamd/internal/auth" - "novamd/internal/httpcontext" + "novamd/internal/context" "novamd/internal/models" ) @@ -97,7 +95,7 @@ func TestAuthenticateMiddleware(t *testing.T) { // Create test handler nextCalled := false - next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { nextCalled = true w.WriteHeader(http.StatusOK) }) @@ -158,12 +156,15 @@ func TestRequireRole(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - // Create context with user claims - ctx := context.WithValue(context.Background(), auth.UserContextKey, auth.UserClaims{ - UserID: 1, - Role: tc.userRole, - }) - req := httptest.NewRequest("GET", "/test", nil).WithContext(ctx) + // Create handler context with user info + hctx := &context.HandlerContext{ + UserID: 1, + UserRole: tc.userRole, + } + + // Create request with handler context + req := httptest.NewRequest("GET", "/test", nil) + req = context.WithHandlerContext(req, hctx) w := newMockResponseWriter() // Create test handler @@ -201,13 +202,13 @@ func TestRequireWorkspaceAccess(t *testing.T) { testCases := []struct { name string - setupContext func() *httpcontext.HandlerContext + setupContext func() *context.HandlerContext wantStatusCode int }{ { name: "workspace owner access", - setupContext: func() *httpcontext.HandlerContext { - return &httpcontext.HandlerContext{ + setupContext: func() *context.HandlerContext { + return &context.HandlerContext{ UserID: 1, UserRole: "editor", Workspace: &models.Workspace{ @@ -220,8 +221,8 @@ func TestRequireWorkspaceAccess(t *testing.T) { }, { name: "admin access to other's workspace", - setupContext: func() *httpcontext.HandlerContext { - return &httpcontext.HandlerContext{ + setupContext: func() *context.HandlerContext { + return &context.HandlerContext{ UserID: 2, UserRole: "admin", Workspace: &models.Workspace{ @@ -234,8 +235,8 @@ func TestRequireWorkspaceAccess(t *testing.T) { }, { name: "unauthorized access attempt", - setupContext: func() *httpcontext.HandlerContext { - return &httpcontext.HandlerContext{ + setupContext: func() *context.HandlerContext { + return &context.HandlerContext{ UserID: 2, UserRole: "editor", Workspace: &models.Workspace{ @@ -248,8 +249,8 @@ func TestRequireWorkspaceAccess(t *testing.T) { }, { name: "no workspace in context", - setupContext: func() *httpcontext.HandlerContext { - return &httpcontext.HandlerContext{ + setupContext: func() *context.HandlerContext { + return &context.HandlerContext{ UserID: 1, UserRole: "editor", Workspace: nil, @@ -262,8 +263,8 @@ func TestRequireWorkspaceAccess(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Create request with context - ctx := context.WithValue(context.Background(), httpcontext.HandlerContextKey, tc.setupContext()) - req := httptest.NewRequest("GET", "/test", nil).WithContext(ctx) + req := httptest.NewRequest("GET", "/test", nil) + req = context.WithHandlerContext(req, tc.setupContext()) w := newMockResponseWriter() // Create test handler @@ -291,72 +292,3 @@ func TestRequireWorkspaceAccess(t *testing.T) { }) } } - -func TestGetUserFromContext(t *testing.T) { - testCases := []struct { - name string - setupCtx func() context.Context - wantUserID int - wantRole string - wantErr bool - errContains string - }{ - { - name: "valid user context", - setupCtx: func() context.Context { - return context.WithValue(context.Background(), auth.UserContextKey, auth.UserClaims{ - UserID: 1, - Role: "admin", - }) - }, - wantUserID: 1, - wantRole: "admin", - wantErr: false, - }, - { - name: "missing user context", - setupCtx: func() context.Context { - return context.Background() - }, - wantErr: true, - errContains: "no user found in context", - }, - { - name: "invalid context value type", - setupCtx: func() context.Context { - return context.WithValue(context.Background(), auth.UserContextKey, "invalid") - }, - wantErr: true, - errContains: "no user found in context", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ctx := tc.setupCtx() - claims, err := auth.GetUserFromContext(ctx) - - if tc.wantErr { - if err == nil { - t.Error("expected error, got nil") - } else if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) { - t.Errorf("error = %v, want error containing %v", err, tc.errContains) - } - return - } - - if err != nil { - t.Errorf("unexpected error: %v", err) - return - } - - if claims.UserID != tc.wantUserID { - t.Errorf("UserID = %v, want %v", claims.UserID, tc.wantUserID) - } - - if claims.Role != tc.wantRole { - t.Errorf("Role = %v, want %v", claims.Role, tc.wantRole) - } - }) - } -} diff --git a/server/internal/context/context.go b/server/internal/context/context.go new file mode 100644 index 0000000..dda7eeb --- /dev/null +++ b/server/internal/context/context.go @@ -0,0 +1,62 @@ +// Package context provides functions for managing request context +package context + +import ( + "context" + "fmt" + "net/http" + "novamd/internal/models" +) + +type contextKey string + +const ( + // HandlerContextKey is the key used to store handler context in the request context + HandlerContextKey contextKey = "handlerContext" +) + +// UserClaims represents user information from authentication +type UserClaims struct { + UserID int + Role string +} + +// HandlerContext holds the request-specific data available to all handlers +type HandlerContext struct { + UserID int + UserRole string + Workspace *models.Workspace // Optional, only set for workspace routes +} + +// GetRequestContext retrieves the handler context from the request +func GetRequestContext(w http.ResponseWriter, r *http.Request) (*HandlerContext, bool) { + ctx := r.Context().Value(HandlerContextKey) + if ctx == nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + return nil, false + } + return ctx.(*HandlerContext), true +} + +// WithHandlerContext adds handler context to the request +func WithHandlerContext(r *http.Request, hctx *HandlerContext) *http.Request { + return r.WithContext(context.WithValue(r.Context(), HandlerContextKey, hctx)) +} + +// GetUserFromContext retrieves user claims from the context +func GetUserFromContext(ctx context.Context) (*UserClaims, error) { + val := ctx.Value(HandlerContextKey) + if val == nil { + return nil, fmt.Errorf("no user found in context") + } + + hctx, ok := val.(*HandlerContext) + if !ok { + return nil, fmt.Errorf("invalid context type") + } + + return &UserClaims{ + UserID: hctx.UserID, + Role: hctx.UserRole, + }, nil +} diff --git a/server/internal/middleware/context.go b/server/internal/context/middleware.go similarity index 59% rename from server/internal/middleware/context.go rename to server/internal/context/middleware.go index 288ed24..9c1d9b3 100644 --- a/server/internal/middleware/context.go +++ b/server/internal/context/middleware.go @@ -1,38 +1,37 @@ -package middleware +package context import ( "net/http" - "novamd/internal/auth" "novamd/internal/db" - "novamd/internal/httpcontext" "github.com/go-chi/chi/v5" ) -// User ID and User Role context -func WithUserContext(next http.Handler) http.Handler { +// WithUserContextMiddleware extracts user information from JWT claims +// and adds it to the request context +func WithUserContextMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - claims, err := auth.GetUserFromContext(r.Context()) + claims, err := GetUserFromContext(r.Context()) if err != nil { http.Error(w, "Unauthorized", http.StatusUnauthorized) return } - hctx := &httpcontext.HandlerContext{ + hctx := &HandlerContext{ UserID: claims.UserID, UserRole: claims.Role, } - r = httpcontext.WithHandlerContext(r, hctx) + r = WithHandlerContext(r, hctx) next.ServeHTTP(w, r) }) } -// Workspace context -func WithWorkspaceContext(db db.Database) func(http.Handler) http.Handler { +// WithWorkspaceContextMiddleware adds workspace information to the request context +func WithWorkspaceContextMiddleware(db db.Database) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx, ok := httpcontext.GetRequestContext(w, r) + ctx, ok := GetRequestContext(w, r) if !ok { return } @@ -46,7 +45,7 @@ func WithWorkspaceContext(db db.Database) func(http.Handler) http.Handler { // Update existing context with workspace ctx.Workspace = workspace - r = httpcontext.WithHandlerContext(r, ctx) + r = WithHandlerContext(r, ctx) next.ServeHTTP(w, r) }) } diff --git a/server/internal/handlers/admin_handlers.go b/server/internal/handlers/admin_handlers.go index 059c8cb..5ed8d42 100644 --- a/server/internal/handlers/admin_handlers.go +++ b/server/internal/handlers/admin_handlers.go @@ -3,8 +3,8 @@ package handlers import ( "encoding/json" "net/http" + "novamd/internal/context" "novamd/internal/db" - "novamd/internal/httpcontext" "novamd/internal/models" "novamd/internal/storage" "strconv" @@ -172,7 +172,7 @@ func (h *Handler) AdminUpdateUser() http.HandlerFunc { // AdminDeleteUser deletes a specific user func (h *Handler) AdminDeleteUser() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - ctx, ok := httpcontext.GetRequestContext(w, r) + ctx, ok := context.GetRequestContext(w, r) if !ok { return } diff --git a/server/internal/handlers/auth_handlers.go b/server/internal/handlers/auth_handlers.go index cce4c30..d6aec83 100644 --- a/server/internal/handlers/auth_handlers.go +++ b/server/internal/handlers/auth_handlers.go @@ -4,7 +4,7 @@ import ( "encoding/json" "net/http" "novamd/internal/auth" - "novamd/internal/httpcontext" + "novamd/internal/context" "novamd/internal/models" "golang.org/x/crypto/bcrypt" @@ -129,7 +129,7 @@ func (h *Handler) RefreshToken(authService *auth.SessionService) http.HandlerFun // GetCurrentUser returns the currently authenticated user func (h *Handler) GetCurrentUser() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - ctx, ok := httpcontext.GetRequestContext(w, r) + ctx, ok := context.GetRequestContext(w, r) if !ok { return } diff --git a/server/internal/handlers/file_handlers.go b/server/internal/handlers/file_handlers.go index d970fa1..9f75eb2 100644 --- a/server/internal/handlers/file_handlers.go +++ b/server/internal/handlers/file_handlers.go @@ -5,14 +5,14 @@ import ( "io" "net/http" - "novamd/internal/httpcontext" + "novamd/internal/context" "github.com/go-chi/chi/v5" ) func (h *Handler) ListFiles() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - ctx, ok := httpcontext.GetRequestContext(w, r) + ctx, ok := context.GetRequestContext(w, r) if !ok { return } @@ -29,7 +29,7 @@ func (h *Handler) ListFiles() http.HandlerFunc { func (h *Handler) LookupFileByName() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - ctx, ok := httpcontext.GetRequestContext(w, r) + ctx, ok := context.GetRequestContext(w, r) if !ok { return } @@ -52,7 +52,7 @@ func (h *Handler) LookupFileByName() http.HandlerFunc { func (h *Handler) GetFileContent() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - ctx, ok := httpcontext.GetRequestContext(w, r) + ctx, ok := context.GetRequestContext(w, r) if !ok { return } @@ -71,7 +71,7 @@ func (h *Handler) GetFileContent() http.HandlerFunc { func (h *Handler) SaveFile() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - ctx, ok := httpcontext.GetRequestContext(w, r) + ctx, ok := context.GetRequestContext(w, r) if !ok { return } @@ -95,7 +95,7 @@ func (h *Handler) SaveFile() http.HandlerFunc { func (h *Handler) DeleteFile() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - ctx, ok := httpcontext.GetRequestContext(w, r) + ctx, ok := context.GetRequestContext(w, r) if !ok { return } @@ -114,7 +114,7 @@ func (h *Handler) DeleteFile() http.HandlerFunc { func (h *Handler) GetLastOpenedFile() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - ctx, ok := httpcontext.GetRequestContext(w, r) + ctx, ok := context.GetRequestContext(w, r) if !ok { return } @@ -136,7 +136,7 @@ func (h *Handler) GetLastOpenedFile() http.HandlerFunc { func (h *Handler) UpdateLastOpenedFile() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - ctx, ok := httpcontext.GetRequestContext(w, r) + ctx, ok := context.GetRequestContext(w, r) if !ok { return } diff --git a/server/internal/handlers/git_handlers.go b/server/internal/handlers/git_handlers.go index f8ee589..ea6eeff 100644 --- a/server/internal/handlers/git_handlers.go +++ b/server/internal/handlers/git_handlers.go @@ -4,12 +4,12 @@ import ( "encoding/json" "net/http" - "novamd/internal/httpcontext" + "novamd/internal/context" ) func (h *Handler) StageCommitAndPush() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - ctx, ok := httpcontext.GetRequestContext(w, r) + ctx, ok := context.GetRequestContext(w, r) if !ok { return } @@ -40,7 +40,7 @@ func (h *Handler) StageCommitAndPush() http.HandlerFunc { func (h *Handler) PullChanges() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - ctx, ok := httpcontext.GetRequestContext(w, r) + ctx, ok := context.GetRequestContext(w, r) if !ok { return } diff --git a/server/internal/handlers/user_handlers.go b/server/internal/handlers/user_handlers.go index 013baf0..678c8e5 100644 --- a/server/internal/handlers/user_handlers.go +++ b/server/internal/handlers/user_handlers.go @@ -4,7 +4,7 @@ import ( "encoding/json" "net/http" - "novamd/internal/httpcontext" + "novamd/internal/context" "golang.org/x/crypto/bcrypt" ) @@ -22,7 +22,7 @@ type DeleteAccountRequest struct { func (h *Handler) GetUser() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - ctx, ok := httpcontext.GetRequestContext(w, r) + ctx, ok := context.GetRequestContext(w, r) if !ok { return } @@ -40,7 +40,7 @@ func (h *Handler) GetUser() http.HandlerFunc { // UpdateProfile updates the current user's profile func (h *Handler) UpdateProfile() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - ctx, ok := httpcontext.GetRequestContext(w, r) + ctx, ok := context.GetRequestContext(w, r) if !ok { return } @@ -144,7 +144,7 @@ func (h *Handler) UpdateProfile() http.HandlerFunc { // DeleteAccount handles user account deletion func (h *Handler) DeleteAccount() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - ctx, ok := httpcontext.GetRequestContext(w, r) + ctx, ok := context.GetRequestContext(w, r) if !ok { return } diff --git a/server/internal/handlers/workspace_handlers.go b/server/internal/handlers/workspace_handlers.go index 0f2a012..8dee442 100644 --- a/server/internal/handlers/workspace_handlers.go +++ b/server/internal/handlers/workspace_handlers.go @@ -5,13 +5,13 @@ import ( "fmt" "net/http" - "novamd/internal/httpcontext" + "novamd/internal/context" "novamd/internal/models" ) func (h *Handler) ListWorkspaces() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - ctx, ok := httpcontext.GetRequestContext(w, r) + ctx, ok := context.GetRequestContext(w, r) if !ok { return } @@ -28,7 +28,7 @@ func (h *Handler) ListWorkspaces() http.HandlerFunc { func (h *Handler) CreateWorkspace() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - ctx, ok := httpcontext.GetRequestContext(w, r) + ctx, ok := context.GetRequestContext(w, r) if !ok { return } @@ -56,7 +56,7 @@ func (h *Handler) CreateWorkspace() http.HandlerFunc { func (h *Handler) GetWorkspace() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - ctx, ok := httpcontext.GetRequestContext(w, r) + ctx, ok := context.GetRequestContext(w, r) if !ok { return } @@ -83,7 +83,7 @@ func gitSettingsChanged(new, old *models.Workspace) bool { func (h *Handler) UpdateWorkspace() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - ctx, ok := httpcontext.GetRequestContext(w, r) + ctx, ok := context.GetRequestContext(w, r) if !ok { return } @@ -134,7 +134,7 @@ func (h *Handler) UpdateWorkspace() http.HandlerFunc { func (h *Handler) DeleteWorkspace() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - ctx, ok := httpcontext.GetRequestContext(w, r) + ctx, ok := context.GetRequestContext(w, r) if !ok { return } @@ -197,7 +197,7 @@ func (h *Handler) DeleteWorkspace() http.HandlerFunc { func (h *Handler) GetLastWorkspaceName() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - ctx, ok := httpcontext.GetRequestContext(w, r) + ctx, ok := context.GetRequestContext(w, r) if !ok { return } @@ -214,7 +214,7 @@ func (h *Handler) GetLastWorkspaceName() http.HandlerFunc { func (h *Handler) UpdateLastWorkspaceName() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - ctx, ok := httpcontext.GetRequestContext(w, r) + ctx, ok := context.GetRequestContext(w, r) if !ok { return } diff --git a/server/internal/httpcontext/context.go b/server/internal/httpcontext/context.go deleted file mode 100644 index 1e9b278..0000000 --- a/server/internal/httpcontext/context.go +++ /dev/null @@ -1,31 +0,0 @@ -package httpcontext - -import ( - "context" - "net/http" - "novamd/internal/models" -) - -// HandlerContext holds the request-specific data available to all handlers -type HandlerContext struct { - UserID int - UserRole string - Workspace *models.Workspace -} - -type contextKey string - -const HandlerContextKey contextKey = "handlerContext" - -func GetRequestContext(w http.ResponseWriter, r *http.Request) (*HandlerContext, bool) { - ctx := r.Context().Value(HandlerContextKey) - if ctx == nil { - http.Error(w, "Internal server error", http.StatusInternalServerError) - return nil, false - } - return ctx.(*HandlerContext), true -} - -func WithHandlerContext(r *http.Request, hctx *HandlerContext) *http.Request { - return r.WithContext(context.WithValue(r.Context(), HandlerContextKey, hctx)) -} From 1150c4ba39eeec705215498ec76bc30ccdec2eec Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sat, 23 Nov 2024 16:36:29 +0100 Subject: [PATCH 16/38] Test config package --- server/internal/config/config.go | 18 +-- server/internal/config/config_test.go | 215 ++++++++++++++++++++++++++ 2 files changed, 219 insertions(+), 14 deletions(-) create mode 100644 server/internal/config/config_test.go diff --git a/server/internal/config/config.go b/server/internal/config/config.go index 90049d3..0cfa5f7 100644 --- a/server/internal/config/config.go +++ b/server/internal/config/config.go @@ -1,9 +1,9 @@ +// Package config provides the configuration for the application package config import ( "fmt" "os" - "path/filepath" "strconv" "strings" "time" @@ -11,6 +11,7 @@ import ( "novamd/internal/crypto" ) +// Config holds the configuration for the application type Config struct { DBPath string WorkDir string @@ -27,6 +28,7 @@ type Config struct { IsDevelopment bool } +// DefaultConfig returns a new Config instance with default values func DefaultConfig() *Config { return &Config{ DBPath: "./novamd.db", @@ -39,6 +41,7 @@ func DefaultConfig() *Config { } } +// Validate checks if the configuration is valid func (c *Config) Validate() error { if c.AdminEmail == "" || c.AdminPassword == "" { return fmt.Errorf("NOVAMD_ADMIN_EMAIL and NOVAMD_ADMIN_PASSWORD must be set") @@ -63,16 +66,10 @@ func Load() (*Config, error) { if dbPath := os.Getenv("NOVAMD_DB_PATH"); dbPath != "" { config.DBPath = dbPath } - if err := ensureDir(filepath.Dir(config.DBPath)); err != nil { - return nil, fmt.Errorf("failed to create database directory: %w", err) - } if workDir := os.Getenv("NOVAMD_WORKDIR"); workDir != "" { config.WorkDir = workDir } - if err := ensureDir(config.WorkDir); err != nil { - return nil, fmt.Errorf("failed to create work directory: %w", err) - } if staticPath := os.Getenv("NOVAMD_STATIC_PATH"); staticPath != "" { config.StaticPath = staticPath @@ -115,10 +112,3 @@ func Load() (*Config, error) { return config, nil } - -func ensureDir(dir string) error { - if dir == "" { - return nil - } - return os.MkdirAll(dir, 0755) -} diff --git a/server/internal/config/config_test.go b/server/internal/config/config_test.go new file mode 100644 index 0000000..51aef69 --- /dev/null +++ b/server/internal/config/config_test.go @@ -0,0 +1,215 @@ +package config_test + +import ( + "os" + "testing" + "time" + + "novamd/internal/config" +) + +func TestDefaultConfig(t *testing.T) { + cfg := config.DefaultConfig() + + tests := []struct { + name string + got interface{} + expected interface{} + }{ + {"DBPath", cfg.DBPath, "./novamd.db"}, + {"WorkDir", cfg.WorkDir, "./data"}, + {"StaticPath", cfg.StaticPath, "../app/dist"}, + {"Port", cfg.Port, "8080"}, + {"RateLimitRequests", cfg.RateLimitRequests, 100}, + {"RateLimitWindow", cfg.RateLimitWindow, time.Minute * 15}, + {"IsDevelopment", cfg.IsDevelopment, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.got != tt.expected { + t.Errorf("DefaultConfig().%s = %v, want %v", tt.name, tt.got, tt.expected) + } + }) + } +} + +// setEnv is a helper function to set environment variables and check for errors +func setEnv(t *testing.T, key, value string) { + if err := os.Setenv(key, value); err != nil { + t.Fatalf("Failed to set environment variable %s: %v", key, err) + } +} + +func TestLoad(t *testing.T) { + // Helper function to reset environment variables + cleanup := func() { + envVars := []string{ + "NOVAMD_ENV", + "NOVAMD_DB_PATH", + "NOVAMD_WORKDIR", + "NOVAMD_STATIC_PATH", + "NOVAMD_PORT", + "NOVAMD_APP_URL", + "NOVAMD_CORS_ORIGINS", + "NOVAMD_ADMIN_EMAIL", + "NOVAMD_ADMIN_PASSWORD", + "NOVAMD_ENCRYPTION_KEY", + "NOVAMD_JWT_SIGNING_KEY", + "NOVAMD_RATE_LIMIT_REQUESTS", + "NOVAMD_RATE_LIMIT_WINDOW", + } + for _, env := range envVars { + if err := os.Unsetenv(env); err != nil { + t.Fatalf("Failed to unset environment variable %s: %v", env, err) + } + } + } + + t.Run("load with defaults", func(t *testing.T) { + cleanup() + defer cleanup() + + // Set required env vars + setEnv(t, "NOVAMD_ADMIN_EMAIL", "admin@example.com") + setEnv(t, "NOVAMD_ADMIN_PASSWORD", "password123") + setEnv(t, "NOVAMD_ENCRYPTION_KEY", "YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY=") // 32 bytes base64 encoded + + cfg, err := config.Load() + if err != nil { + t.Fatalf("Load() error = %v", err) + } + + if cfg.DBPath != "./novamd.db" { + t.Errorf("default DBPath = %v, want %v", cfg.DBPath, "./novamd.db") + } + }) + + t.Run("load with custom values", func(t *testing.T) { + cleanup() + defer cleanup() + + // Set all environment variables + envs := map[string]string{ + "NOVAMD_ENV": "development", + "NOVAMD_DB_PATH": "/custom/db/path.db", + "NOVAMD_WORKDIR": "/custom/work/dir", + "NOVAMD_STATIC_PATH": "/custom/static/path", + "NOVAMD_PORT": "3000", + "NOVAMD_APP_URL": "http://localhost:3000", + "NOVAMD_CORS_ORIGINS": "http://localhost:3000,http://localhost:3001", + "NOVAMD_ADMIN_EMAIL": "admin@example.com", + "NOVAMD_ADMIN_PASSWORD": "password123", + "NOVAMD_ENCRYPTION_KEY": "YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY=", + "NOVAMD_JWT_SIGNING_KEY": "secret-key", + "NOVAMD_RATE_LIMIT_REQUESTS": "200", + "NOVAMD_RATE_LIMIT_WINDOW": "30m", + } + + for k, v := range envs { + setEnv(t, k, v) + } + + cfg, err := config.Load() + if err != nil { + t.Fatalf("Load() error = %v", err) + } + + tests := []struct { + name string + got interface{} + expected interface{} + }{ + {"IsDevelopment", cfg.IsDevelopment, true}, + {"DBPath", cfg.DBPath, "/custom/db/path.db"}, + {"WorkDir", cfg.WorkDir, "/custom/work/dir"}, + {"StaticPath", cfg.StaticPath, "/custom/static/path"}, + {"Port", cfg.Port, "3000"}, + {"AppURL", cfg.AppURL, "http://localhost:3000"}, + {"AdminEmail", cfg.AdminEmail, "admin@example.com"}, + {"AdminPassword", cfg.AdminPassword, "password123"}, + {"JWTSigningKey", cfg.JWTSigningKey, "secret-key"}, + {"RateLimitRequests", cfg.RateLimitRequests, 200}, + {"RateLimitWindow", cfg.RateLimitWindow, 30 * time.Minute}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.got != tt.expected { + t.Errorf("%s = %v, want %v", tt.name, tt.got, tt.expected) + } + }) + } + + // Test CORS origins separately as it's a slice + expectedOrigins := []string{"http://localhost:3000", "http://localhost:3001"} + if len(cfg.CORSOrigins) != len(expectedOrigins) { + t.Errorf("CORSOrigins length = %v, want %v", len(cfg.CORSOrigins), len(expectedOrigins)) + } + for i, origin := range cfg.CORSOrigins { + if origin != expectedOrigins[i] { + t.Errorf("CORSOrigins[%d] = %v, want %v", i, origin, expectedOrigins[i]) + } + } + }) + + t.Run("validation failures", func(t *testing.T) { + testCases := []struct { + name string + setupEnv func(*testing.T) + expectedError string + }{ + { + name: "missing admin email", + setupEnv: func(t *testing.T) { + cleanup() + setEnv(t, "NOVAMD_ADMIN_PASSWORD", "password123") + setEnv(t, "NOVAMD_ENCRYPTION_KEY", "YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY=") + }, + expectedError: "NOVAMD_ADMIN_EMAIL and NOVAMD_ADMIN_PASSWORD must be set", + }, + { + name: "missing admin password", + setupEnv: func(t *testing.T) { + cleanup() + setEnv(t, "NOVAMD_ADMIN_EMAIL", "admin@example.com") + setEnv(t, "NOVAMD_ENCRYPTION_KEY", "YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY=") + }, + expectedError: "NOVAMD_ADMIN_EMAIL and NOVAMD_ADMIN_PASSWORD must be set", + }, + { + name: "missing encryption key", + setupEnv: func(t *testing.T) { + cleanup() + setEnv(t, "NOVAMD_ADMIN_EMAIL", "admin@example.com") + setEnv(t, "NOVAMD_ADMIN_PASSWORD", "password123") + }, + expectedError: "invalid NOVAMD_ENCRYPTION_KEY: encryption key is required", + }, + { + name: "invalid encryption key", + setupEnv: func(t *testing.T) { + cleanup() + setEnv(t, "NOVAMD_ADMIN_EMAIL", "admin@example.com") + setEnv(t, "NOVAMD_ADMIN_PASSWORD", "password123") + setEnv(t, "NOVAMD_ENCRYPTION_KEY", "invalid-key") + }, + expectedError: "invalid NOVAMD_ENCRYPTION_KEY: invalid base64 encoding: illegal base64 data at input byte 7", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.setupEnv(t) + _, err := config.Load() + if err == nil { + t.Error("expected error, got nil") + return + } + if err.Error() != tc.expectedError { + t.Errorf("error = %v, want error containing %v", err, tc.expectedError) + } + }) + } + }) +} From 8f2f8b30ddeaa3e58f205fae9c466c99f74c77d8 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sat, 23 Nov 2024 21:28:15 +0100 Subject: [PATCH 17/38] Test secrets package --- server/internal/config/config.go | 5 +- server/internal/crypto/crypto.go | 114 ----------- server/internal/db/db.go | 14 +- server/internal/secrets/secrets.go | 112 +++++++++++ server/internal/secrets/secrets_test.go | 257 ++++++++++++++++++++++++ 5 files changed, 378 insertions(+), 124 deletions(-) delete mode 100644 server/internal/crypto/crypto.go create mode 100644 server/internal/secrets/secrets.go create mode 100644 server/internal/secrets/secrets_test.go diff --git a/server/internal/config/config.go b/server/internal/config/config.go index 0cfa5f7..41c21a6 100644 --- a/server/internal/config/config.go +++ b/server/internal/config/config.go @@ -3,12 +3,11 @@ package config import ( "fmt" + "novamd/internal/secrets" "os" "strconv" "strings" "time" - - "novamd/internal/crypto" ) // Config holds the configuration for the application @@ -48,7 +47,7 @@ func (c *Config) Validate() error { } // Validate encryption key - if err := crypto.ValidateKey(c.EncryptionKey); err != nil { + if err := secrets.ValidateKey(c.EncryptionKey); err != nil { return fmt.Errorf("invalid NOVAMD_ENCRYPTION_KEY: %w", err) } diff --git a/server/internal/crypto/crypto.go b/server/internal/crypto/crypto.go deleted file mode 100644 index 76cf338..0000000 --- a/server/internal/crypto/crypto.go +++ /dev/null @@ -1,114 +0,0 @@ -package crypto - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "encoding/base64" - "fmt" - "io" -) - -var ( - ErrKeyRequired = fmt.Errorf("encryption key is required") - ErrInvalidKeySize = fmt.Errorf("encryption key must be 32 bytes (256 bits) when decoded") -) - -type Crypto struct { - key []byte -} - -// ValidateKey checks if the provided key is suitable for AES-256 -func ValidateKey(key string) error { - if key == "" { - return ErrKeyRequired - } - - // Attempt to decode base64 - keyBytes, err := base64.StdEncoding.DecodeString(key) - if err != nil { - return fmt.Errorf("invalid base64 encoding: %w", err) - } - - if len(keyBytes) != 32 { - return fmt.Errorf("%w: got %d bytes", ErrInvalidKeySize, len(keyBytes)) - } - - // Verify the key can be used for AES - _, err = aes.NewCipher(keyBytes) - if err != nil { - return fmt.Errorf("invalid encryption key: %w", err) - } - - return nil -} - -// New creates a new Crypto instance with the provided base64-encoded key -func New(key string) (*Crypto, error) { - if err := ValidateKey(key); err != nil { - return nil, err - } - - keyBytes, _ := base64.StdEncoding.DecodeString(key) - return &Crypto{key: keyBytes}, nil -} - -// Encrypt encrypts the plaintext using AES-256-GCM -func (c *Crypto) Encrypt(plaintext string) (string, error) { - if plaintext == "" { - return "", nil - } - - block, err := aes.NewCipher(c.key) - if err != nil { - return "", err - } - - gcm, err := cipher.NewGCM(block) - if err != nil { - return "", err - } - - nonce := make([]byte, gcm.NonceSize()) - if _, err := io.ReadFull(rand.Reader, nonce); err != nil { - return "", err - } - - ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil) - return base64.StdEncoding.EncodeToString(ciphertext), nil -} - -// Decrypt decrypts the ciphertext using AES-256-GCM -func (c *Crypto) Decrypt(ciphertext string) (string, error) { - if ciphertext == "" { - return "", nil - } - - data, err := base64.StdEncoding.DecodeString(ciphertext) - if err != nil { - return "", err - } - - block, err := aes.NewCipher(c.key) - if err != nil { - return "", err - } - - gcm, err := cipher.NewGCM(block) - if err != nil { - return "", err - } - - nonceSize := gcm.NonceSize() - if len(data) < nonceSize { - return "", fmt.Errorf("ciphertext too short") - } - - nonce, ciphertextBytes := data[:nonceSize], data[nonceSize:] - plaintext, err := gcm.Open(nil, nonce, ciphertextBytes, nil) - if err != nil { - return "", err - } - - return string(plaintext), nil -} diff --git a/server/internal/db/db.go b/server/internal/db/db.go index 8cd04e9..a2624ee 100644 --- a/server/internal/db/db.go +++ b/server/internal/db/db.go @@ -5,8 +5,8 @@ import ( "database/sql" "fmt" - "novamd/internal/crypto" "novamd/internal/models" + "novamd/internal/secrets" _ "github.com/mattn/go-sqlite3" // SQLite driver ) @@ -70,7 +70,7 @@ type Database interface { // database represents the database connection type database struct { *sql.DB - crypto *crypto.Crypto + secretsService secrets.Encryptor } // Init initializes the database connection @@ -85,14 +85,14 @@ func Init(dbPath string, encryptionKey string) (Database, error) { } // Initialize crypto service - cryptoService, err := crypto.New(encryptionKey) + secretsService, err := secrets.New(encryptionKey) if err != nil { return nil, fmt.Errorf("failed to initialize encryption: %w", err) } database := &database{ - DB: db, - crypto: cryptoService, + DB: db, + secretsService: secretsService, } if err := database.Migrate(); err != nil { @@ -112,12 +112,12 @@ func (db *database) encryptToken(token string) (string, error) { if token == "" { return "", nil } - return db.crypto.Encrypt(token) + return db.secretsService.Encrypt(token) } func (db *database) decryptToken(token string) (string, error) { if token == "" { return "", nil } - return db.crypto.Decrypt(token) + return db.secretsService.Decrypt(token) } diff --git a/server/internal/secrets/secrets.go b/server/internal/secrets/secrets.go new file mode 100644 index 0000000..abb81d4 --- /dev/null +++ b/server/internal/secrets/secrets.go @@ -0,0 +1,112 @@ +// Package secrets provides an Encryptor interface for encrypting and decrypting strings using AES-256-GCM. +package secrets + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "fmt" + "io" +) + +// Encryptor is an interface for encrypting and decrypting strings +type Encryptor interface { + Encrypt(plaintext string) (string, error) + Decrypt(ciphertext string) (string, error) +} + +type encryptor struct { + gcm cipher.AEAD +} + +// ValidateKey checks if the provided base64-encoded key is suitable for AES-256 +func ValidateKey(key string) error { + _, err := decodeAndValidateKey(key) + return err +} + +// decodeAndValidateKey validates and decodes the base64-encoded key +// Returns the decoded key bytes if valid +func decodeAndValidateKey(key string) ([]byte, error) { + if key == "" { + return nil, fmt.Errorf("encryption key is required") + } + + keyBytes, err := base64.StdEncoding.DecodeString(key) + if err != nil { + return nil, fmt.Errorf("invalid base64 encoding: %w", err) + } + + if len(keyBytes) != 32 { + return nil, fmt.Errorf("encryption key must be 32 bytes (256 bits): got %d bytes", len(keyBytes)) + } + + // Verify the key can be used for AES + _, err = aes.NewCipher(keyBytes) + if err != nil { + return nil, fmt.Errorf("invalid encryption key: %w", err) + } + + return keyBytes, nil +} + +// New creates a new Crypto instance with the provided base64-encoded key +func New(key string) (Encryptor, error) { + keyBytes, err := decodeAndValidateKey(key) + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(keyBytes) + if err != nil { + return nil, fmt.Errorf("failed to create cipher: %w", err) + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("failed to create GCM: %w", err) + } + + return &encryptor{gcm: gcm}, nil +} + +// Encrypt encrypts the plaintext using AES-256-GCM +func (e *encryptor) Encrypt(plaintext string) (string, error) { + if plaintext == "" { + return "", nil + } + + nonce := make([]byte, e.gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return "", err + } + + ciphertext := e.gcm.Seal(nonce, nonce, []byte(plaintext), nil) + return base64.StdEncoding.EncodeToString(ciphertext), nil +} + +// Decrypt decrypts the ciphertext using AES-256-GCM +func (e *encryptor) Decrypt(ciphertext string) (string, error) { + if ciphertext == "" { + return "", nil + } + + data, err := base64.StdEncoding.DecodeString(ciphertext) + if err != nil { + return "", fmt.Errorf("invalid base64 encoding: %w", err) + } + + nonceSize := e.gcm.NonceSize() + if len(data) < nonceSize { + return "", fmt.Errorf("invalid ciphertext: too short") + } + + nonce, ciphertextBytes := data[:nonceSize], data[nonceSize:] + plaintext, err := e.gcm.Open(nil, nonce, ciphertextBytes, nil) + if err != nil { + return "", err + } + + return string(plaintext), nil +} diff --git a/server/internal/secrets/secrets_test.go b/server/internal/secrets/secrets_test.go new file mode 100644 index 0000000..ed09d98 --- /dev/null +++ b/server/internal/secrets/secrets_test.go @@ -0,0 +1,257 @@ +package secrets_test + +import ( + "encoding/base64" + "strings" + "testing" + + "novamd/internal/secrets" +) + +func TestValidateKey(t *testing.T) { + testCases := []struct { + name string + key string + wantErr bool + errContains string + }{ + { + name: "valid 32-byte base64 key", + key: base64.StdEncoding.EncodeToString(make([]byte, 32)), + wantErr: false, + }, + { + name: "empty key", + key: "", + wantErr: true, + errContains: "encryption key is required", + }, + { + name: "invalid base64", + key: "not-base64!@#$", + wantErr: true, + errContains: "invalid base64 encoding", + }, + { + name: "wrong key size (16 bytes)", + key: base64.StdEncoding.EncodeToString(make([]byte, 16)), + wantErr: true, + errContains: "encryption key must be 32 bytes", + }, + { + name: "wrong key size (64 bytes)", + key: base64.StdEncoding.EncodeToString(make([]byte, 64)), + wantErr: true, + errContains: "encryption key must be 32 bytes", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := secrets.ValidateKey(tc.key) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + return + } + if !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("error = %v, want error containing %q", err, tc.errContains) + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + } +} + +func TestNew(t *testing.T) { + testCases := []struct { + name string + key string + wantErr bool + errContains string + }{ + { + name: "valid key", + key: base64.StdEncoding.EncodeToString(make([]byte, 32)), + wantErr: false, + }, + { + name: "empty key", + key: "", + wantErr: true, + errContains: "encryption key is required", + }, + { + name: "invalid key", + key: "invalid", + wantErr: true, + errContains: "invalid base64 encoding", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e, err := secrets.New(tc.key) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + return + } + if !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("error = %v, want error containing %q", err, tc.errContains) + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if e == nil { + t.Error("expected Encryptor instance, got nil") + } + }) + } +} + +func TestEncryptDecrypt(t *testing.T) { + // Generate a valid key for testing + key := base64.StdEncoding.EncodeToString(make([]byte, 32)) + e, err := secrets.New(key) + if err != nil { + t.Fatalf("failed to create Encryptor instance: %v", err) + } + + testCases := []struct { + name string + plaintext string + wantErr bool + }{ + { + name: "normal text", + plaintext: "Hello, World!", + wantErr: false, + }, + { + name: "empty string", + plaintext: "", + wantErr: false, + }, + { + name: "long text", + plaintext: strings.Repeat("Long text with lots of content. ", 100), + wantErr: false, + }, + { + name: "special characters", + plaintext: "!@#$%^&*()_+-=[]{}|;:,.<>?", + wantErr: false, + }, + { + name: "unicode characters", + plaintext: "Hello, 世界! नमस्ते", + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Test encryption + ciphertext, err := e.Encrypt(tc.plaintext) + if tc.wantErr { + if err == nil { + t.Error("expected encryption error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected encryption error: %v", err) + } + + // Verify ciphertext is different from plaintext + if tc.plaintext != "" && ciphertext == tc.plaintext { + t.Error("ciphertext matches plaintext") + } + + // Test decryption + decrypted, err := e.Decrypt(ciphertext) + if err != nil { + t.Fatalf("unexpected decryption error: %v", err) + } + + // Verify decrypted text matches original + if decrypted != tc.plaintext { + t.Errorf("decrypted text = %q, want %q", decrypted, tc.plaintext) + } + }) + } +} + +func TestDecryptInvalidCiphertext(t *testing.T) { + key := base64.StdEncoding.EncodeToString(make([]byte, 32)) + e, err := secrets.New(key) + if err != nil { + t.Fatalf("failed to create Encryptor instance: %v", err) + } + + testCases := []struct { + name string + ciphertext string + wantErr bool + errContains string + }{ + { + name: "empty ciphertext", + ciphertext: "", + wantErr: false, + }, + { + name: "invalid base64", + ciphertext: "not-base64!@#$", + wantErr: true, + errContains: "invalid base64 encoding", + }, + { + name: "invalid ciphertext (too short)", + ciphertext: base64.StdEncoding.EncodeToString(make([]byte, 10)), + wantErr: true, + errContains: "invalid ciphertext: too short", + }, + { + name: "tampered ciphertext", + ciphertext: base64.StdEncoding.EncodeToString(make([]byte, 50)), + wantErr: true, + errContains: "message authentication failed", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + decrypted, err := e.Decrypt(tc.ciphertext) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + return + } + if !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("error = %v, want error containing %q", err, tc.errContains) + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if decrypted != "" { + t.Errorf("expected empty string, got %q", decrypted) + } + }) + } +} From 9f241271a7caca56ce8f80e65a8de174d6b0205e Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sat, 23 Nov 2024 22:15:25 +0100 Subject: [PATCH 18/38] Test context package --- server/internal/context/context_test.go | 139 +++++++++++++++ server/internal/context/middleware.go | 2 +- server/internal/context/middleware_test.go | 197 +++++++++++++++++++++ server/internal/db/db.go | 33 +++- 4 files changed, 366 insertions(+), 5 deletions(-) create mode 100644 server/internal/context/context_test.go create mode 100644 server/internal/context/middleware_test.go diff --git a/server/internal/context/context_test.go b/server/internal/context/context_test.go new file mode 100644 index 0000000..8a0d947 --- /dev/null +++ b/server/internal/context/context_test.go @@ -0,0 +1,139 @@ +package context_test + +import ( + stdctx "context" + "net/http" + "net/http/httptest" + "testing" + + "novamd/internal/context" +) + +func TestGetRequestContext(t *testing.T) { + tests := []struct { + name string + setupCtx func() *context.HandlerContext + wantStatus int + wantOK bool + }{ + { + name: "valid context", + setupCtx: func() *context.HandlerContext { + return &context.HandlerContext{ + UserID: 1, + UserRole: "admin", + } + }, + wantStatus: http.StatusOK, + wantOK: true, + }, + { + name: "missing context", + setupCtx: func() *context.HandlerContext { + return nil + }, + wantStatus: http.StatusInternalServerError, + wantOK: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create test request + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + + if ctx := tt.setupCtx(); ctx != nil { + req = context.WithHandlerContext(req, ctx) + } + + gotCtx, ok := context.GetRequestContext(w, req) + + if ok != tt.wantOK { + t.Errorf("GetRequestContext() ok = %v, want %v", ok, tt.wantOK) + } + + if !tt.wantOK { + if w.Code != tt.wantStatus { + t.Errorf("GetRequestContext() status = %v, want %v", w.Code, tt.wantStatus) + } + return + } + + if gotCtx.UserID != tt.setupCtx().UserID { + t.Errorf("GetRequestContext() UserID = %v, want %v", gotCtx.UserID, tt.setupCtx().UserID) + } + + if gotCtx.UserRole != tt.setupCtx().UserRole { + t.Errorf("GetRequestContext() UserRole = %v, want %v", gotCtx.UserRole, tt.setupCtx().UserRole) + } + }) + } +} + +func TestGetUserFromContext(t *testing.T) { + tests := []struct { + name string + setupCtx func() stdctx.Context + wantUser *context.UserClaims + wantError bool + }{ + { + name: "valid user context", + setupCtx: func() stdctx.Context { + return stdctx.WithValue(stdctx.Background(), context.HandlerContextKey, &context.HandlerContext{ + UserID: 1, + UserRole: "admin", + }) + }, + wantUser: &context.UserClaims{ + UserID: 1, + Role: "admin", + }, + wantError: false, + }, + { + name: "missing context", + setupCtx: func() stdctx.Context { + return stdctx.Background() + }, + wantUser: nil, + wantError: true, + }, + { + name: "invalid context type", + setupCtx: func() stdctx.Context { + return stdctx.WithValue(stdctx.Background(), context.HandlerContextKey, "invalid") + }, + wantUser: nil, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := tt.setupCtx() + gotUser, err := context.GetUserFromContext(ctx) + + if tt.wantError { + if err == nil { + t.Error("GetUserFromContext() error = nil, want error") + } + return + } + + if err != nil { + t.Errorf("GetUserFromContext() unexpected error = %v", err) + return + } + + if gotUser.UserID != tt.wantUser.UserID { + t.Errorf("GetUserFromContext() UserID = %v, want %v", gotUser.UserID, tt.wantUser.UserID) + } + + if gotUser.Role != tt.wantUser.Role { + t.Errorf("GetUserFromContext() Role = %v, want %v", gotUser.Role, tt.wantUser.Role) + } + }) + } +} diff --git a/server/internal/context/middleware.go b/server/internal/context/middleware.go index 9c1d9b3..c916e96 100644 --- a/server/internal/context/middleware.go +++ b/server/internal/context/middleware.go @@ -28,7 +28,7 @@ func WithUserContextMiddleware(next http.Handler) http.Handler { } // WithWorkspaceContextMiddleware adds workspace information to the request context -func WithWorkspaceContextMiddleware(db db.Database) func(http.Handler) http.Handler { +func WithWorkspaceContextMiddleware(db db.WorkspaceReader) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx, ok := GetRequestContext(w, r) diff --git a/server/internal/context/middleware_test.go b/server/internal/context/middleware_test.go new file mode 100644 index 0000000..eae1c97 --- /dev/null +++ b/server/internal/context/middleware_test.go @@ -0,0 +1,197 @@ +package context_test + +import ( + stdctx "context" + "database/sql" + "net/http" + "net/http/httptest" + "testing" + + "novamd/internal/context" + "novamd/internal/models" +) + +// MockDB implements the minimal database interface needed for testing +type MockDB struct { + GetWorkspaceByNameFunc func(userID int, workspaceName string) (*models.Workspace, error) +} + +func (m *MockDB) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) { + return m.GetWorkspaceByNameFunc(userID, workspaceName) +} + +func (m *MockDB) GetWorkspaceByID(_ int) (*models.Workspace, error) { + return nil, nil +} + +func (m *MockDB) GetWorkspacesByUserID(_ int) ([]*models.Workspace, error) { + return nil, nil +} + +func (m *MockDB) GetAllWorkspaces() ([]*models.Workspace, error) { + return nil, nil +} + +func TestWithUserContextMiddleware(t *testing.T) { + tests := []struct { + name string + setupCtx func() *context.HandlerContext + wantStatus int + wantNext bool + }{ + { + name: "valid user context", + setupCtx: func() *context.HandlerContext { + return &context.HandlerContext{ + UserID: 1, + UserRole: "admin", + } + }, + wantStatus: http.StatusOK, + wantNext: true, + }, + { + name: "missing user context", + setupCtx: func() *context.HandlerContext { + return nil + }, + wantStatus: http.StatusUnauthorized, + wantNext: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + + if ctx := tt.setupCtx(); ctx != nil { + req = context.WithHandlerContext(req, ctx) + } + + nextCalled := false + next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + }) + + middleware := context.WithUserContextMiddleware(next) + middleware.ServeHTTP(w, req) + + if nextCalled != tt.wantNext { + t.Errorf("WithUserContextMiddleware() next called = %v, want %v", nextCalled, tt.wantNext) + } + + if w.Code != tt.wantStatus { + t.Errorf("WithUserContextMiddleware() status = %v, want %v", w.Code, tt.wantStatus) + } + }) + } +} + +func TestWithWorkspaceContextMiddleware(t *testing.T) { + tests := []struct { + name string + setupCtx func() *context.HandlerContext + workspaceName string + mockWorkspace *models.Workspace + mockError error + wantStatus int + wantNext bool + }{ + { + name: "valid workspace context", + setupCtx: func() *context.HandlerContext { + return &context.HandlerContext{ + UserID: 1, + UserRole: "admin", + } + }, + workspaceName: "test-workspace", + mockWorkspace: &models.Workspace{ + ID: 1, + UserID: 1, + Name: "test-workspace", + }, + mockError: nil, + wantStatus: http.StatusOK, + wantNext: true, + }, + { + name: "workspace not found", + setupCtx: func() *context.HandlerContext { + return &context.HandlerContext{ + UserID: 1, + UserRole: "admin", + } + }, + workspaceName: "nonexistent", + mockWorkspace: nil, + mockError: sql.ErrNoRows, + wantStatus: http.StatusNotFound, + wantNext: false, + }, + { + name: "missing user context", + setupCtx: func() *context.HandlerContext { return nil }, + workspaceName: "test-workspace", + mockWorkspace: nil, + mockError: nil, + wantStatus: http.StatusInternalServerError, + wantNext: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockDB := &MockDB{ + GetWorkspaceByNameFunc: func(_ int, _ string) (*models.Workspace, error) { + return tt.mockWorkspace, tt.mockError + }, + } + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + + if ctx := tt.setupCtx(); ctx != nil { + req = context.WithHandlerContext(req, ctx) + } + + // Add workspace name to request context via chi URL params + req = req.WithContext(stdctx.WithValue(req.Context(), "workspaceName", tt.workspaceName)) + + nextCalled := false + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + + // Verify workspace was added to context + if tt.mockWorkspace != nil { + ctx, ok := context.GetRequestContext(w, r) + if !ok { + t.Error("Failed to get request context in next handler") + return + } + if ctx.Workspace == nil { + t.Error("Workspace not set in context") + return + } + if ctx.Workspace.ID != tt.mockWorkspace.ID { + t.Errorf("Workspace ID = %v, want %v", ctx.Workspace.ID, tt.mockWorkspace.ID) + } + } + }) + + middleware := context.WithWorkspaceContextMiddleware(mockDB)(next) + middleware.ServeHTTP(w, req) + + if nextCalled != tt.wantNext { + t.Errorf("WithWorkspaceContextMiddleware() next called = %v, want %v", nextCalled, tt.wantNext) + } + + if w.Code != tt.wantStatus { + t.Errorf("WithWorkspaceContextMiddleware() status = %v, want %v", w.Code, tt.wantStatus) + } + }) + } +} diff --git a/server/internal/db/db.go b/server/internal/db/db.go index a2624ee..4c45ddc 100644 --- a/server/internal/db/db.go +++ b/server/internal/db/db.go @@ -24,12 +24,17 @@ type UserStore interface { CountAdminUsers() (int, error) } -// WorkspaceStore defines the methods for interacting with workspace data in the database -type WorkspaceStore interface { - CreateWorkspace(workspace *models.Workspace) error +// WorkspaceReader defines the methods for reading workspace data from the database +type WorkspaceReader interface { GetWorkspaceByID(workspaceID int) (*models.Workspace, error) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) + GetAllWorkspaces() ([]*models.Workspace, error) +} + +// WorkspaceWriter defines the methods for writing workspace data to the database +type WorkspaceWriter interface { + CreateWorkspace(workspace *models.Workspace) error UpdateWorkspace(workspace *models.Workspace) error DeleteWorkspace(workspaceID int) error UpdateWorkspaceSettings(workspace *models.Workspace) error @@ -37,7 +42,12 @@ type WorkspaceStore interface { UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error UpdateLastOpenedFile(workspaceID int, filePath string) error GetLastOpenedFile(workspaceID int) (string, error) - GetAllWorkspaces() ([]*models.Workspace, error) +} + +// WorkspaceStore defines the methods for interacting with workspace data in the database +type WorkspaceStore interface { + WorkspaceReader + WorkspaceWriter } // SessionStore defines the methods for interacting with jwt sessions in the database @@ -67,6 +77,21 @@ type Database interface { Migrate() error } +var ( + // Main Database interface + _ Database = (*database)(nil) + + // Component interfaces + _ UserStore = (*database)(nil) + _ WorkspaceStore = (*database)(nil) + _ SessionStore = (*database)(nil) + _ SystemStore = (*database)(nil) + + // Sub-interfaces + _ WorkspaceReader = (*database)(nil) + _ WorkspaceWriter = (*database)(nil) +) + // database represents the database connection type database struct { *sql.DB From 9d81b1036d678dc738481bb9f4cd5d6406f67732 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sat, 23 Nov 2024 22:33:55 +0100 Subject: [PATCH 19/38] Refactor db init --- server/cmd/server/main.go | 19 +++++++++++++++++-- server/internal/db/db.go | 15 ++------------- server/internal/secrets/secrets.go | 8 ++++---- server/internal/secrets/secrets_test.go | 6 +++--- 4 files changed, 26 insertions(+), 22 deletions(-) diff --git a/server/cmd/server/main.go b/server/cmd/server/main.go index 7fbf006..39758cf 100644 --- a/server/cmd/server/main.go +++ b/server/cmd/server/main.go @@ -19,6 +19,7 @@ import ( "novamd/internal/config" "novamd/internal/db" "novamd/internal/handlers" + "novamd/internal/secrets" "novamd/internal/storage" ) @@ -29,12 +30,26 @@ func main() { log.Fatal("Failed to load configuration:", err) } + // Initialize secrets service + secretsService, err := secrets.NewService(cfg.EncryptionKey) + if err != nil { + log.Fatal("Failed to initialize secrets service:", err) + } + // Initialize database - database, err := db.Init(cfg.DBPath, cfg.EncryptionKey) + database, err := db.Init(cfg.DBPath, secretsService) if err != nil { log.Fatal(err) } - defer database.Close() + err = database.Migrate() + if err != nil { + log.Fatal("Failed to apply database migrations:", err) + } + defer func() { + if err := database.Close(); err != nil { + log.Printf("Error closing database: %v", err) + } + }() // Get or generate JWT signing key signingKey := cfg.JWTSigningKey diff --git a/server/internal/db/db.go b/server/internal/db/db.go index 4c45ddc..04df282 100644 --- a/server/internal/db/db.go +++ b/server/internal/db/db.go @@ -3,7 +3,6 @@ package db import ( "database/sql" - "fmt" "novamd/internal/models" "novamd/internal/secrets" @@ -95,11 +94,11 @@ var ( // database represents the database connection type database struct { *sql.DB - secretsService secrets.Encryptor + secretsService secrets.Service } // Init initializes the database connection -func Init(dbPath string, encryptionKey string) (Database, error) { +func Init(dbPath string, secretsService secrets.Service) (Database, error) { db, err := sql.Open("sqlite3", dbPath) if err != nil { return nil, err @@ -109,21 +108,11 @@ func Init(dbPath string, encryptionKey string) (Database, error) { return nil, err } - // Initialize crypto service - secretsService, err := secrets.New(encryptionKey) - if err != nil { - return nil, fmt.Errorf("failed to initialize encryption: %w", err) - } - database := &database{ DB: db, secretsService: secretsService, } - if err := database.Migrate(); err != nil { - return nil, err - } - return database, nil } diff --git a/server/internal/secrets/secrets.go b/server/internal/secrets/secrets.go index abb81d4..2eab8d1 100644 --- a/server/internal/secrets/secrets.go +++ b/server/internal/secrets/secrets.go @@ -10,8 +10,8 @@ import ( "io" ) -// Encryptor is an interface for encrypting and decrypting strings -type Encryptor interface { +// Service is an interface for encrypting and decrypting strings +type Service interface { Encrypt(plaintext string) (string, error) Decrypt(ciphertext string) (string, error) } @@ -51,8 +51,8 @@ func decodeAndValidateKey(key string) ([]byte, error) { return keyBytes, nil } -// New creates a new Crypto instance with the provided base64-encoded key -func New(key string) (Encryptor, error) { +// NewService creates a new Encryptor instance with the provided base64-encoded key +func NewService(key string) (Service, error) { keyBytes, err := decodeAndValidateKey(key) if err != nil { return nil, err diff --git a/server/internal/secrets/secrets_test.go b/server/internal/secrets/secrets_test.go index ed09d98..f1db818 100644 --- a/server/internal/secrets/secrets_test.go +++ b/server/internal/secrets/secrets_test.go @@ -96,7 +96,7 @@ func TestNew(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - e, err := secrets.New(tc.key) + e, err := secrets.NewService(tc.key) if tc.wantErr { if err == nil { @@ -122,7 +122,7 @@ func TestNew(t *testing.T) { func TestEncryptDecrypt(t *testing.T) { // Generate a valid key for testing key := base64.StdEncoding.EncodeToString(make([]byte, 32)) - e, err := secrets.New(key) + e, err := secrets.NewService(key) if err != nil { t.Fatalf("failed to create Encryptor instance: %v", err) } @@ -194,7 +194,7 @@ func TestEncryptDecrypt(t *testing.T) { func TestDecryptInvalidCiphertext(t *testing.T) { key := base64.StdEncoding.EncodeToString(make([]byte, 32)) - e, err := secrets.New(key) + e, err := secrets.NewService(key) if err != nil { t.Fatalf("failed to create Encryptor instance: %v", err) } From 1e7cd0934e51e0634db213c3e5db57ab491039ec Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sun, 24 Nov 2024 00:17:08 +0100 Subject: [PATCH 20/38] Add migrations tests --- .vscode/settings.json | 4 +- server/internal/db/migrations.go | 3 + server/internal/db/migrations_test.go | 156 ++++++++++++++++++++++++++ server/internal/db/testing.go | 30 +++++ 4 files changed, 192 insertions(+), 1 deletion(-) create mode 100644 server/internal/db/migrations_test.go create mode 100644 server/internal/db/testing.go diff --git a/.vscode/settings.json b/.vscode/settings.json index 4d7334d..34f5162 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -14,6 +14,7 @@ "go.lintTool": "golangci-lint", "go.lintOnSave": "package", "go.formatTool": "goimports", + "go.testFlags": ["-tags=test"], "[go]": { "editor.formatOnSave": true, "editor.codeActionsOnSave": { @@ -23,6 +24,7 @@ }, "gopls": { "usePlaceholders": true, - "staticcheck": true + "staticcheck": true, + "buildFlags": ["-tags", "test"] } } diff --git a/server/internal/db/migrations.go b/server/internal/db/migrations.go index f59b20f..6ecb3fb 100644 --- a/server/internal/db/migrations.go +++ b/server/internal/db/migrations.go @@ -49,6 +49,9 @@ var migrations = []Migration{ { Version: 2, SQL: ` + -- Enable foreign key constraints + PRAGMA foreign_keys = ON; + -- Create sessions table for authentication CREATE TABLE IF NOT EXISTS sessions ( id TEXT PRIMARY KEY, diff --git a/server/internal/db/migrations_test.go b/server/internal/db/migrations_test.go new file mode 100644 index 0000000..6d1a5fb --- /dev/null +++ b/server/internal/db/migrations_test.go @@ -0,0 +1,156 @@ +package db_test + +import ( + "testing" + + "novamd/internal/db" + + _ "github.com/mattn/go-sqlite3" +) + +type mockSecrets struct{} + +func (m *mockSecrets) Encrypt(s string) (string, error) { return s, nil } +func (m *mockSecrets) Decrypt(s string) (string, error) { return s, nil } + +func TestMigrate(t *testing.T) { + database, err := db.NewTestDB(":memory:", &mockSecrets{}) + if err != nil { + t.Fatalf("failed to initialize database: %v", err) + } + defer database.Close() + + t.Run("migrations are applied in order", func(t *testing.T) { + if err := database.Migrate(); err != nil { + t.Fatalf("failed to run initial migrations: %v", err) + } + + // Check migration version + var version int + err := database.TestDB().QueryRow("SELECT MAX(version) FROM migrations").Scan(&version) + if err != nil { + t.Fatalf("failed to get migration version: %v", err) + } + + if version != 2 { // Current number of migrations in production code + t.Errorf("expected migration version 2, got %d", version) + } + + // Verify number of migration entries matches versions applied + var count int + err = database.TestDB().QueryRow("SELECT COUNT(*) FROM migrations").Scan(&count) + if err != nil { + t.Fatalf("failed to count migrations: %v", err) + } + + if count != 2 { + t.Errorf("expected 2 migration entries, got %d", count) + } + }) + + t.Run("migrations create expected schema", func(t *testing.T) { + // Verify tables exist + tables := []string{"users", "workspaces", "sessions", "system_settings", "migrations"} + for _, table := range tables { + if !tableExists(t, database, table) { + t.Errorf("table %q does not exist", table) + } + } + + // Verify indexes + indexes := []struct { + table string + name string + }{ + {"sessions", "idx_sessions_user_id"}, + {"sessions", "idx_sessions_expires_at"}, + {"sessions", "idx_sessions_refresh_token"}, + } + + for _, idx := range indexes { + if !indexExists(t, database, idx.table, idx.name) { + t.Errorf("index %q on table %q does not exist", idx.name, idx.table) + } + } + }) + + t.Run("migrations are idempotent", func(t *testing.T) { + // Run migrations again + if err := database.Migrate(); err != nil { + t.Fatalf("failed to re-run migrations: %v", err) + } + + // Verify migration count hasn't changed + var count int + err = database.TestDB().QueryRow("SELECT COUNT(*) FROM migrations").Scan(&count) + if err != nil { + t.Fatalf("failed to count migrations: %v", err) + } + + if count != 2 { + t.Errorf("expected 2 migration entries, got %d", count) + } + }) + + t.Run("rollback on migration failure", func(t *testing.T) { + // Create a test table that would conflict with a failing migration + _, err := database.TestDB().Exec("CREATE TABLE test_rollback (id INTEGER PRIMARY KEY)") + if err != nil { + t.Fatalf("failed to create test table: %v", err) + } + + // Start transaction + tx, err := database.Begin() + if err != nil { + t.Fatalf("failed to start transaction: %v", err) + } + + // Try operations that should fail and rollback + _, err = tx.Exec(` + CREATE TABLE test_rollback (id INTEGER PRIMARY KEY); + INSERT INTO nonexistent_table VALUES (1); + `) + if err == nil { + tx.Rollback() + t.Fatal("expected migration to fail") + } + tx.Rollback() + + // Verify the migration version hasn't changed + var version int + err = database.TestDB().QueryRow("SELECT MAX(version) FROM migrations").Scan(&version) + if err != nil { + t.Fatalf("failed to get migration version: %v", err) + } + + if version != 2 { + t.Errorf("expected migration version to remain at 2, got %d", version) + } + }) +} + +func tableExists(t *testing.T, database db.TestDatabase, tableName string) bool { + t.Helper() + + var name string + err := database.TestDB().QueryRow(` + SELECT name FROM sqlite_master + WHERE type='table' AND name=?`, + tableName, + ).Scan(&name) + + return err == nil +} + +func indexExists(t *testing.T, database db.TestDatabase, tableName, indexName string) bool { + t.Helper() + + var name string + err := database.TestDB().QueryRow(` + SELECT name FROM sqlite_master + WHERE type='index' AND tbl_name=? AND name=?`, + tableName, indexName, + ).Scan(&name) + + return err == nil +} diff --git a/server/internal/db/testing.go b/server/internal/db/testing.go new file mode 100644 index 0000000..41a63d0 --- /dev/null +++ b/server/internal/db/testing.go @@ -0,0 +1,30 @@ +//go:build test + +package db + +import ( + "database/sql" + "novamd/internal/secrets" +) + +type TestDatabase interface { + Database + TestDB() *sql.DB +} + +func NewTestDB(dbPath string, secretsService secrets.Service) (TestDatabase, error) { + db, err := Init(dbPath, secretsService) + if err != nil { + return nil, err + } + + return &testDatabase{db.(*database)}, nil +} + +type testDatabase struct { + *database +} + +func (td *testDatabase) TestDB() *sql.DB { + return td.DB +} From 9ac047d440a5230bee4f7592bbff64c21c01a20d Mon Sep 17 00:00:00 2001 From: LordMathis Date: Mon, 25 Nov 2024 20:54:49 +0100 Subject: [PATCH 21/38] Delete unused test case fixture --- server/internal/testutils/assertions.go | 13 ------------- 1 file changed, 13 deletions(-) delete mode 100644 server/internal/testutils/assertions.go diff --git a/server/internal/testutils/assertions.go b/server/internal/testutils/assertions.go deleted file mode 100644 index 5bbfa23..0000000 --- a/server/internal/testutils/assertions.go +++ /dev/null @@ -1,13 +0,0 @@ -package testutils - -import ( - "testing" -) - -// TestCase defines a generic test case structure that can be used across packages -type TestCase struct { - Name string - Setup func(t *testing.T, fixtures any) - Fixtures any - Validate func(t *testing.T, result any, err error) -} From 32bd202d6f9472b2bb629449436b50ee8df58ef6 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Mon, 25 Nov 2024 21:44:43 +0100 Subject: [PATCH 22/38] Implement session and system tests --- server/internal/db/db.go | 5 + server/internal/db/migrations.go | 3 - server/internal/db/migrations_test.go | 5 - server/internal/db/sessions_test.go | 294 +++++++++++++++++++ server/internal/db/system_test.go | 213 ++++++++++++++ server/internal/db/{testing.go => testdb.go} | 0 server/internal/db/testutil_test.go | 6 + server/internal/db/users.go | 2 +- server/internal/db/workspaces.go | 2 +- 9 files changed, 520 insertions(+), 10 deletions(-) create mode 100644 server/internal/db/sessions_test.go create mode 100644 server/internal/db/system_test.go rename server/internal/db/{testing.go => testdb.go} (100%) create mode 100644 server/internal/db/testutil_test.go diff --git a/server/internal/db/db.go b/server/internal/db/db.go index 04df282..a54ba4a 100644 --- a/server/internal/db/db.go +++ b/server/internal/db/db.go @@ -108,6 +108,11 @@ func Init(dbPath string, secretsService secrets.Service) (Database, error) { return nil, err } + // Enable foreign keys for this connection + if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil { + return nil, err + } + database := &database{ DB: db, secretsService: secretsService, diff --git a/server/internal/db/migrations.go b/server/internal/db/migrations.go index 6ecb3fb..f59b20f 100644 --- a/server/internal/db/migrations.go +++ b/server/internal/db/migrations.go @@ -49,9 +49,6 @@ var migrations = []Migration{ { Version: 2, SQL: ` - -- Enable foreign key constraints - PRAGMA foreign_keys = ON; - -- Create sessions table for authentication CREATE TABLE IF NOT EXISTS sessions ( id TEXT PRIMARY KEY, diff --git a/server/internal/db/migrations_test.go b/server/internal/db/migrations_test.go index 6d1a5fb..0b966c0 100644 --- a/server/internal/db/migrations_test.go +++ b/server/internal/db/migrations_test.go @@ -8,11 +8,6 @@ import ( _ "github.com/mattn/go-sqlite3" ) -type mockSecrets struct{} - -func (m *mockSecrets) Encrypt(s string) (string, error) { return s, nil } -func (m *mockSecrets) Decrypt(s string) (string, error) { return s, nil } - func TestMigrate(t *testing.T) { database, err := db.NewTestDB(":memory:", &mockSecrets{}) if err != nil { diff --git a/server/internal/db/sessions_test.go b/server/internal/db/sessions_test.go new file mode 100644 index 0000000..f5f87eb --- /dev/null +++ b/server/internal/db/sessions_test.go @@ -0,0 +1,294 @@ +package db_test + +import ( + "strings" + "testing" + "time" + + "novamd/internal/db" + "novamd/internal/models" + + "github.com/google/uuid" +) + +func TestSessionOperations(t *testing.T) { + database, err := db.NewTestDB(":memory:", &mockSecrets{}) + if err != nil { + t.Fatalf("failed to create test database: %v", err) + } + defer database.Close() + + if err := database.Migrate(); err != nil { + t.Fatalf("failed to run migrations: %v", err) + } + + // Create a test user first since sessions need a valid user ID + user, err := database.CreateUser(&models.User{ + Email: "test@example.com", + DisplayName: "Test User", + PasswordHash: "hash", + Role: "editor", + }) + if err != nil { + t.Fatalf("failed to create test user: %v", err) + } + + t.Run("CreateSession", func(t *testing.T) { + testCases := []struct { + name string + session *models.Session + wantErr bool + errContains string + }{ + { + name: "valid session", + session: &models.Session{ + ID: uuid.New().String(), + UserID: user.ID, + RefreshToken: "valid-token", + ExpiresAt: time.Now().Add(24 * time.Hour), + CreatedAt: time.Now(), + }, + wantErr: false, + }, + { + name: "invalid user ID", + session: &models.Session{ + ID: uuid.New().String(), + UserID: 99999, // Non-existent user ID + RefreshToken: "invalid-user-token", + ExpiresAt: time.Now().Add(24 * time.Hour), + CreatedAt: time.Now(), + }, + wantErr: true, + errContains: "FOREIGN KEY constraint failed", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := database.CreateSession(tc.session) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } else if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("error = %v, want error containing %v", err, tc.errContains) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify session was stored + stored, err := database.GetSessionByRefreshToken(tc.session.RefreshToken) + if err != nil { + t.Fatalf("failed to retrieve stored session: %v", err) + } + + // Compare fields + if stored.ID != tc.session.ID { + t.Errorf("ID = %v, want %v", stored.ID, tc.session.ID) + } + if stored.UserID != tc.session.UserID { + t.Errorf("UserID = %v, want %v", stored.UserID, tc.session.UserID) + } + if stored.RefreshToken != tc.session.RefreshToken { + t.Errorf("RefreshToken = %v, want %v", stored.RefreshToken, tc.session.RefreshToken) + } + // Compare times within a reasonable threshold + if diff := stored.ExpiresAt.Sub(tc.session.ExpiresAt); diff > time.Second || diff < -time.Second { + t.Errorf("ExpiresAt differs by %v, want difference less than 1s", diff) + } + }) + } + }) + + t.Run("GetSessionByRefreshToken", func(t *testing.T) { + // Create test sessions + validSession := &models.Session{ + ID: uuid.New().String(), + UserID: user.ID, + RefreshToken: "valid-get-token", + ExpiresAt: time.Now().Add(24 * time.Hour), + CreatedAt: time.Now(), + } + expiredSession := &models.Session{ + ID: uuid.New().String(), + UserID: user.ID, + RefreshToken: "expired-token", + ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired + CreatedAt: time.Now().Add(-2 * time.Hour), + } + + if err := database.CreateSession(validSession); err != nil { + t.Fatalf("failed to create valid session: %v", err) + } + if err := database.CreateSession(expiredSession); err != nil { + t.Fatalf("failed to create expired session: %v", err) + } + + testCases := []struct { + name string + refreshToken string + wantErr bool + errContains string + }{ + { + name: "valid token", + refreshToken: "valid-get-token", + wantErr: false, + }, + { + name: "expired token", + refreshToken: "expired-token", + wantErr: true, + errContains: "session not found or expired", + }, + { + name: "non-existent token", + refreshToken: "nonexistent-token", + wantErr: true, + errContains: "session not found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + session, err := database.GetSessionByRefreshToken(tc.refreshToken) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } else if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("error = %v, want error containing %v", err, tc.errContains) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if session.RefreshToken != tc.refreshToken { + t.Errorf("RefreshToken = %v, want %v", session.RefreshToken, tc.refreshToken) + } + }) + } + }) + + t.Run("DeleteSession", func(t *testing.T) { + session := &models.Session{ + ID: uuid.New().String(), + UserID: user.ID, + RefreshToken: "delete-token", + ExpiresAt: time.Now().Add(24 * time.Hour), + CreatedAt: time.Now(), + } + + if err := database.CreateSession(session); err != nil { + t.Fatalf("failed to create session: %v", err) + } + + testCases := []struct { + name string + sessionID string + wantErr bool + errContains string + }{ + { + name: "valid session ID", + sessionID: session.ID, + wantErr: false, + }, + { + name: "non-existent session ID", + sessionID: "nonexistent-id", + wantErr: true, + errContains: "session not found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := database.DeleteSession(tc.sessionID) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } else if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("error = %v, want error containing %v", err, tc.errContains) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify session was deleted + _, err = database.GetSessionByRefreshToken(session.RefreshToken) + if err == nil { + t.Error("session still exists after deletion") + } + }) + } + }) + + t.Run("CleanExpiredSessions", func(t *testing.T) { + // Create a mix of valid and expired sessions + sessions := []*models.Session{ + { + ID: uuid.New().String(), + UserID: user.ID, + RefreshToken: "valid-clean-token", + ExpiresAt: time.Now().Add(24 * time.Hour), + CreatedAt: time.Now(), + }, + { + ID: uuid.New().String(), + UserID: user.ID, + RefreshToken: "expired-clean-token-1", + ExpiresAt: time.Now().Add(-1 * time.Hour), + CreatedAt: time.Now().Add(-2 * time.Hour), + }, + { + ID: uuid.New().String(), + UserID: user.ID, + RefreshToken: "expired-clean-token-2", + ExpiresAt: time.Now().Add(-2 * time.Hour), + CreatedAt: time.Now().Add(-3 * time.Hour), + }, + } + + for _, s := range sessions { + if err := database.CreateSession(s); err != nil { + t.Fatalf("failed to create session: %v", err) + } + } + + // Clean expired sessions + if err := database.CleanExpiredSessions(); err != nil { + t.Fatalf("failed to clean expired sessions: %v", err) + } + + // Verify valid session still exists + validSession, err := database.GetSessionByRefreshToken("valid-clean-token") + if err != nil { + t.Errorf("valid session was unexpectedly deleted: %v", err) + } + if validSession == nil { + t.Error("valid session was unexpectedly deleted") + } + + // Verify expired sessions were deleted + expiredTokens := []string{"expired-clean-token-1", "expired-clean-token-2"} + for _, token := range expiredTokens { + if _, err := database.GetSessionByRefreshToken(token); err == nil { + t.Errorf("expired session with token %s still exists", token) + } + } + }) +} diff --git a/server/internal/db/system_test.go b/server/internal/db/system_test.go new file mode 100644 index 0000000..2e2ca03 --- /dev/null +++ b/server/internal/db/system_test.go @@ -0,0 +1,213 @@ +package db_test + +import ( + "encoding/base64" + "fmt" + "strings" + "testing" + "time" + + "novamd/internal/db" + "novamd/internal/models" + + "github.com/google/uuid" +) + +func TestSystemOperations(t *testing.T) { + database, err := db.NewTestDB(":memory:", &mockSecrets{}) + if err != nil { + t.Fatalf("failed to create test database: %v", err) + } + defer database.Close() + + if err := database.Migrate(); err != nil { + t.Fatalf("failed to run migrations: %v", err) + } + + t.Run("GetSystemSettings", func(t *testing.T) { + t.Run("non-existent setting", func(t *testing.T) { + _, err := database.GetSystemSetting("nonexistent-key") + if err == nil { + t.Error("expected error for non-existent key, got nil") + } + }) + + t.Run("existing setting", func(t *testing.T) { + // First set a value + err := database.SetSystemSetting("test-key", "test-value") + if err != nil { + t.Fatalf("failed to set system setting: %v", err) + } + + // Then get it back + value, err := database.GetSystemSetting("test-key") + if err != nil { + t.Fatalf("failed to get system setting: %v", err) + } + + if value != "test-value" { + t.Errorf("got value %q, want %q", value, "test-value") + } + }) + }) + + t.Run("SetSystemSettings", func(t *testing.T) { + testCases := []struct { + name string + key string + value string + wantErr bool + errContains string + }{ + { + name: "new setting", + key: "new-key", + value: "new-value", + wantErr: false, + }, + { + name: "update existing setting", + key: "update-key", + value: "original-value", + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := database.SetSystemSetting(tc.key, tc.value) + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("error = %v, want error containing %v", err, tc.errContains) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify the setting was stored + stored, err := database.GetSystemSetting(tc.key) + if err != nil { + t.Fatalf("failed to retrieve stored setting: %v", err) + } + if stored != tc.value { + t.Errorf("got value %q, want %q", stored, tc.value) + } + + // For the update case, test updating the value + if tc.name == "update existing setting" { + newValue := "updated-value" + err := database.SetSystemSetting(tc.key, newValue) + if err != nil { + t.Fatalf("failed to update setting: %v", err) + } + + stored, err := database.GetSystemSetting(tc.key) + if err != nil { + t.Fatalf("failed to retrieve updated setting: %v", err) + } + if stored != newValue { + t.Errorf("got updated value %q, want %q", stored, newValue) + } + } + }) + } + }) + + t.Run("EnsureJWTSecret", func(t *testing.T) { + // First call should generate a new secret + secret1, err := database.EnsureJWTSecret() + if err != nil { + t.Fatalf("failed to ensure JWT secret: %v", err) + } + + // Verify the secret is a valid base64-encoded string of sufficient length + decoded, err := base64.StdEncoding.DecodeString(secret1) + if err != nil { + t.Errorf("secret is not valid base64: %v", err) + } + if len(decoded) < 32 { + t.Errorf("secret length = %d, want >= 32", len(decoded)) + } + + // Second call should return the same secret + secret2, err := database.EnsureJWTSecret() + if err != nil { + t.Fatalf("failed to get existing JWT secret: %v", err) + } + + if secret2 != secret1 { + t.Errorf("got different secret on second call") + } + }) + + t.Run("GetSystemStats", func(t *testing.T) { + // Create some test users and sessions + users := []*models.User{ + { + Email: "user1@test.com", + DisplayName: "User 1", + PasswordHash: "hash1", + Role: "editor", + }, + { + Email: "user2@test.com", + DisplayName: "User 2", + PasswordHash: "hash2", + Role: "viewer", + }, + } + + for _, u := range users { + createdUser, err := database.CreateUser(u) + if err != nil { + t.Fatalf("failed to create test user: %v", err) + } + + // Create multiple workspaces per user + // Each user has one default workspace + for i := 0; i < 2; i++ { + workspace := &models.Workspace{ + UserID: createdUser.ID, + Name: fmt.Sprintf("Workspace %d", i), + } + if err := database.CreateWorkspace(workspace); err != nil { + t.Fatalf("failed to create test workspace: %v", err) + } + } + + // Create an active session for the first user + if createdUser.Email == "user1@test.com" { + session := &models.Session{ + ID: uuid.New().String(), + UserID: createdUser.ID, + RefreshToken: "test-token", + ExpiresAt: time.Now().Add(24 * time.Hour), + CreatedAt: time.Now(), + } + if err := database.CreateSession(session); err != nil { + t.Fatalf("failed to create test session: %v", err) + } + } + } + + stats, err := database.GetSystemStats() + if err != nil { + t.Fatalf("failed to get system stats: %v", err) + } + + // Verify stats + if stats.TotalUsers != 2 { + t.Errorf("TotalUsers = %d, want 2", stats.TotalUsers) + } + if stats.TotalWorkspaces != 6 { // 2 + 1 default workspace per user + t.Errorf("TotalWorkspaces = %d, want 6", stats.TotalWorkspaces) + } + if stats.ActiveUsers != 1 { // Only user1 has an active session + t.Errorf("ActiveUsers = %d, want 1", stats.ActiveUsers) + } + }) +} diff --git a/server/internal/db/testing.go b/server/internal/db/testdb.go similarity index 100% rename from server/internal/db/testing.go rename to server/internal/db/testdb.go diff --git a/server/internal/db/testutil_test.go b/server/internal/db/testutil_test.go new file mode 100644 index 0000000..103f4b5 --- /dev/null +++ b/server/internal/db/testutil_test.go @@ -0,0 +1,6 @@ +package db_test + +type mockSecrets struct{} + +func (m *mockSecrets) Encrypt(s string) (string, error) { return s, nil } +func (m *mockSecrets) Decrypt(s string) (string, error) { return s, nil } diff --git a/server/internal/db/users.go b/server/internal/db/users.go index 17cc374..132264b 100644 --- a/server/internal/db/users.go +++ b/server/internal/db/users.go @@ -69,7 +69,7 @@ func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) e theme, auto_save, show_hidden_files, git_enabled, git_url, git_user, git_token, git_auto_commit, git_commit_msg_template - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, workspace.UserID, workspace.Name, workspace.Theme, workspace.AutoSave, workspace.ShowHiddenFiles, workspace.GitEnabled, workspace.GitURL, workspace.GitUser, workspace.GitToken, diff --git a/server/internal/db/workspaces.go b/server/internal/db/workspaces.go index ce39ce5..004dcf5 100644 --- a/server/internal/db/workspaces.go +++ b/server/internal/db/workspaces.go @@ -24,7 +24,7 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error { user_id, name, theme, auto_save, show_hidden_files, git_enabled, git_url, git_user, git_token, git_auto_commit, git_commit_msg_template - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, workspace.UserID, workspace.Name, workspace.Theme, workspace.AutoSave, workspace.ShowHiddenFiles, workspace.GitEnabled, workspace.GitURL, workspace.GitUser, encryptedToken, workspace.GitAutoCommit, workspace.GitCommitMsgTemplate, From e8868dde398eeab8337ace9d28e7a9cf81d7c89c Mon Sep 17 00:00:00 2001 From: LordMathis Date: Mon, 25 Nov 2024 21:58:16 +0100 Subject: [PATCH 23/38] Test users and workspaces --- server/internal/db/users_test.go | 413 +++++++++++++++++++++++++ server/internal/db/workspaces_test.go | 430 ++++++++++++++++++++++++++ 2 files changed, 843 insertions(+) create mode 100644 server/internal/db/users_test.go create mode 100644 server/internal/db/workspaces_test.go diff --git a/server/internal/db/users_test.go b/server/internal/db/users_test.go new file mode 100644 index 0000000..e0683ea --- /dev/null +++ b/server/internal/db/users_test.go @@ -0,0 +1,413 @@ +package db_test + +import ( + "strings" + "testing" + + "novamd/internal/db" + "novamd/internal/models" +) + +func TestUserOperations(t *testing.T) { + database, err := db.NewTestDB(":memory:", &mockSecrets{}) + if err != nil { + t.Fatalf("failed to create test database: %v", err) + } + defer database.Close() + + if err := database.Migrate(); err != nil { + t.Fatalf("failed to run migrations: %v", err) + } + + t.Run("CreateUser", func(t *testing.T) { + testCases := []struct { + name string + user *models.User + wantErr bool + errContains string + }{ + { + name: "valid user", + user: &models.User{ + Email: "test@example.com", + DisplayName: "Test User", + PasswordHash: "hashed_password", + Role: models.RoleEditor, + }, + wantErr: false, + }, + { + name: "duplicate email", + user: &models.User{ + Email: "test@example.com", // Same as above + DisplayName: "Another User", + PasswordHash: "different_hash", + Role: models.RoleViewer, + }, + wantErr: true, + errContains: "UNIQUE constraint failed", + }, + { + name: "invalid role", + user: &models.User{ + Email: "invalid@example.com", + DisplayName: "Invalid Role User", + PasswordHash: "hash", + Role: "invalid_role", + }, + wantErr: true, + errContains: "CHECK constraint failed", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + user, err := database.CreateUser(tc.user) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("error = %v, want error containing %v", err, tc.errContains) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify user was created properly + if user.ID == 0 { + t.Error("expected non-zero user ID") + } + if user.Email != tc.user.Email { + t.Errorf("Email = %v, want %v", user.Email, tc.user.Email) + } + if user.DisplayName != tc.user.DisplayName { + t.Errorf("DisplayName = %v, want %v", user.DisplayName, tc.user.DisplayName) + } + if user.Role != tc.user.Role { + t.Errorf("Role = %v, want %v", user.Role, tc.user.Role) + } + if user.CreatedAt.IsZero() { + t.Error("CreatedAt should not be zero") + } + if user.LastWorkspaceID == 0 { + t.Error("expected non-zero LastWorkspaceID (default workspace)") + } + }) + } + }) + + t.Run("GetUserByID", func(t *testing.T) { + // Create a test user first + createdUser, err := database.CreateUser(&models.User{ + Email: "getbyid@example.com", + DisplayName: "Get By ID User", + PasswordHash: "hash", + Role: models.RoleEditor, + }) + if err != nil { + t.Fatalf("failed to create test user: %v", err) + } + + testCases := []struct { + name string + userID int + wantErr bool + }{ + { + name: "existing user", + userID: createdUser.ID, + wantErr: false, + }, + { + name: "non-existent user", + userID: 99999, + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + user, err := database.GetUserByID(tc.userID) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if user.ID != tc.userID { + t.Errorf("ID = %v, want %v", user.ID, tc.userID) + } + }) + } + }) + + t.Run("GetUserByEmail", func(t *testing.T) { + // Create a test user first + createdUser, err := database.CreateUser(&models.User{ + Email: "getbyemail@example.com", + DisplayName: "Get By Email User", + PasswordHash: "hash", + Role: models.RoleEditor, + }) + if err != nil { + t.Fatalf("failed to create test user: %v", err) + } + + testCases := []struct { + name string + email string + wantErr bool + }{ + { + name: "existing user", + email: createdUser.Email, + wantErr: false, + }, + { + name: "non-existent user", + email: "nonexistent@example.com", + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + user, err := database.GetUserByEmail(tc.email) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if user.Email != tc.email { + t.Errorf("Email = %v, want %v", user.Email, tc.email) + } + }) + } + }) + + t.Run("UpdateUser", func(t *testing.T) { + // Create a test user first + user, err := database.CreateUser(&models.User{ + Email: "update@example.com", + DisplayName: "Original Name", + PasswordHash: "original_hash", + Role: models.RoleEditor, + }) + if err != nil { + t.Fatalf("failed to create test user: %v", err) + } + + // Update user details + user.DisplayName = "Updated Name" + user.PasswordHash = "new_hash" + user.Role = models.RoleAdmin + + if err := database.UpdateUser(user); err != nil { + t.Fatalf("failed to update user: %v", err) + } + + // Verify updates + updated, err := database.GetUserByID(user.ID) + if err != nil { + t.Fatalf("failed to get updated user: %v", err) + } + + if updated.DisplayName != "Updated Name" { + t.Errorf("DisplayName = %v, want %v", updated.DisplayName, "Updated Name") + } + if updated.PasswordHash != "new_hash" { + t.Errorf("PasswordHash = %v, want %v", updated.PasswordHash, "new_hash") + } + if updated.Role != models.RoleAdmin { + t.Errorf("Role = %v, want %v", updated.Role, models.RoleAdmin) + } + }) + + t.Run("GetAllUsers", func(t *testing.T) { + // Create several test users + testUsers := []*models.User{ + { + Email: "user1@example.com", + DisplayName: "User One", + PasswordHash: "hash1", + Role: models.RoleEditor, + }, + { + Email: "user2@example.com", + DisplayName: "User Two", + PasswordHash: "hash2", + Role: models.RoleViewer, + }, + } + + for _, u := range testUsers { + _, err := database.CreateUser(u) + if err != nil { + t.Fatalf("failed to create test user: %v", err) + } + } + + // Get all users + users, err := database.GetAllUsers() + if err != nil { + t.Fatalf("failed to get all users: %v", err) + } + + // We should have at least as many users as we just created + // (there might be more from previous tests) + if len(users) < len(testUsers) { + t.Errorf("got %d users, want at least %d", len(users), len(testUsers)) + } + + // Verify each test user exists in the result + for _, expected := range testUsers { + found := false + for _, u := range users { + if u.Email == expected.Email { + found = true + if u.DisplayName != expected.DisplayName { + t.Errorf("DisplayName = %v, want %v", u.DisplayName, expected.DisplayName) + } + if u.Role != expected.Role { + t.Errorf("Role = %v, want %v", u.Role, expected.Role) + } + break + } + } + if !found { + t.Errorf("user with email %s not found in results", expected.Email) + } + } + }) + + t.Run("UpdateLastWorkspace", func(t *testing.T) { + // Create a test user with multiple workspaces + user, err := database.CreateUser(&models.User{ + Email: "workspace@example.com", + DisplayName: "Workspace User", + PasswordHash: "hash", + Role: models.RoleEditor, + }) + if err != nil { + t.Fatalf("failed to create test user: %v", err) + } + + // Create additional workspace + workspace := &models.Workspace{ + UserID: user.ID, + Name: "Second Workspace", + } + if err := database.CreateWorkspace(workspace); err != nil { + t.Fatalf("failed to create additional workspace: %v", err) + } + + // Update last workspace + err = database.UpdateLastWorkspace(user.ID, workspace.Name) + if err != nil { + t.Fatalf("failed to update last workspace: %v", err) + } + + // Verify update + lastWorkspace, err := database.GetLastWorkspaceName(user.ID) + if err != nil { + t.Fatalf("failed to get last workspace: %v", err) + } + + if lastWorkspace != workspace.Name { + t.Errorf("LastWorkspace = %v, want %v", lastWorkspace, workspace.Name) + } + }) + + t.Run("DeleteUser", func(t *testing.T) { + // Create a test user + user, err := database.CreateUser(&models.User{ + Email: "delete@example.com", + DisplayName: "Delete User", + PasswordHash: "hash", + Role: models.RoleEditor, + }) + if err != nil { + t.Fatalf("failed to create test user: %v", err) + } + + // Delete the user + if err := database.DeleteUser(user.ID); err != nil { + t.Fatalf("failed to delete user: %v", err) + } + + // Verify user is gone + _, err = database.GetUserByID(user.ID) + if err == nil { + t.Error("expected error getting deleted user, got nil") + } + + // Verify workspaces are gone + workspaces, err := database.GetWorkspacesByUserID(user.ID) + if err != nil { + t.Fatalf("unexpected error checking workspaces: %v", err) + } + if len(workspaces) > 0 { + t.Error("expected no workspaces for deleted user") + } + }) + + t.Run("CountAdminUsers", func(t *testing.T) { + // Create users with different roles + testUsers := []*models.User{ + { + Email: "admin1@example.com", + DisplayName: "Admin One", + PasswordHash: "hash1", + Role: models.RoleAdmin, + }, + { + Email: "admin2@example.com", + DisplayName: "Admin Two", + PasswordHash: "hash2", + Role: models.RoleAdmin, + }, + { + Email: "editor@example.com", + DisplayName: "Editor", + PasswordHash: "hash3", + Role: models.RoleEditor, + }, + } + + for _, u := range testUsers { + _, err := database.CreateUser(u) + if err != nil { + t.Fatalf("failed to create test user: %v", err) + } + } + + // Count admin users + count, err := database.CountAdminUsers() + if err != nil { + t.Fatalf("failed to count admin users: %v", err) + } + + // We should have at least 2 admin users (from our test cases) + // There might be more from previous tests + if count < 2 { + t.Errorf("AdminCount = %d, want at least 2", count) + } + }) +} diff --git a/server/internal/db/workspaces_test.go b/server/internal/db/workspaces_test.go new file mode 100644 index 0000000..924d4a9 --- /dev/null +++ b/server/internal/db/workspaces_test.go @@ -0,0 +1,430 @@ +package db_test + +import ( + "database/sql" + "strings" + "testing" + + "novamd/internal/db" + "novamd/internal/models" +) + +func TestWorkspaceOperations(t *testing.T) { + database, err := db.NewTestDB(":memory:", &mockSecrets{}) + if err != nil { + t.Fatalf("failed to create test database: %v", err) + } + defer database.Close() + + if err := database.Migrate(); err != nil { + t.Fatalf("failed to run migrations: %v", err) + } + + // Create a test user first + user, err := database.CreateUser(&models.User{ + Email: "test@example.com", + DisplayName: "Test User", + PasswordHash: "hash", + Role: models.RoleEditor, + }) + if err != nil { + t.Fatalf("failed to create test user: %v", err) + } + + t.Run("CreateWorkspace", func(t *testing.T) { + testCases := []struct { + name string + workspace *models.Workspace + wantErr bool + errContains string + }{ + { + name: "valid workspace", + workspace: &models.Workspace{ + UserID: user.ID, + Name: "Test Workspace", + }, + wantErr: false, + }, + { + name: "non-existent user", + workspace: &models.Workspace{ + UserID: 99999, + Name: "Invalid User", + }, + wantErr: true, + errContains: "FOREIGN KEY constraint failed", + }, + { + name: "with git settings", + workspace: &models.Workspace{ + UserID: user.ID, + Name: "Git Workspace", + Theme: "dark", + AutoSave: true, + ShowHiddenFiles: true, + GitEnabled: true, + GitURL: "https://github.com/user/repo", + GitUser: "username", + GitToken: "secret-token", + GitAutoCommit: true, + GitCommitMsgTemplate: "${action} ${filename}", + }, + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.workspace.Theme == "" { + tc.workspace.GetDefaultSettings() + } + + err := database.CreateWorkspace(tc.workspace) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("error = %v, want error containing %v", err, tc.errContains) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify workspace was created properly + if tc.workspace.ID == 0 { + t.Error("expected non-zero workspace ID") + } + + // Retrieve and verify workspace + stored, err := database.GetWorkspaceByID(tc.workspace.ID) + if err != nil { + t.Fatalf("failed to retrieve workspace: %v", err) + } + + verifyWorkspace(t, stored, tc.workspace) + }) + } + }) + + t.Run("GetWorkspaceByID", func(t *testing.T) { + // Create a test workspace first + workspace := &models.Workspace{ + UserID: user.ID, + Name: "Get By ID Workspace", + } + workspace.GetDefaultSettings() + if err := database.CreateWorkspace(workspace); err != nil { + t.Fatalf("failed to create test workspace: %v", err) + } + + testCases := []struct { + name string + workspaceID int + wantErr bool + }{ + { + name: "existing workspace", + workspaceID: workspace.ID, + wantErr: false, + }, + { + name: "non-existent workspace", + workspaceID: 99999, + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result, err := database.GetWorkspaceByID(tc.workspaceID) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.ID != tc.workspaceID { + t.Errorf("ID = %v, want %v", result.ID, tc.workspaceID) + } + }) + } + }) + + t.Run("GetWorkspaceByName", func(t *testing.T) { + // Create a test workspace first + workspace := &models.Workspace{ + UserID: user.ID, + Name: "Get By Name Workspace", + } + workspace.GetDefaultSettings() + if err := database.CreateWorkspace(workspace); err != nil { + t.Fatalf("failed to create test workspace: %v", err) + } + + testCases := []struct { + name string + userID int + workspaceName string + wantErr bool + }{ + { + name: "existing workspace", + userID: user.ID, + workspaceName: workspace.Name, + wantErr: false, + }, + { + name: "wrong user ID", + userID: 99999, + workspaceName: workspace.Name, + wantErr: true, + }, + { + name: "non-existent workspace", + userID: user.ID, + workspaceName: "Non-existent", + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result, err := database.GetWorkspaceByName(tc.userID, tc.workspaceName) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Name != tc.workspaceName { + t.Errorf("Name = %v, want %v", result.Name, tc.workspaceName) + } + if result.UserID != tc.userID { + t.Errorf("UserID = %v, want %v", result.UserID, tc.userID) + } + }) + } + }) + + t.Run("UpdateWorkspace", func(t *testing.T) { + // Create a test workspace first + workspace := &models.Workspace{ + UserID: user.ID, + Name: "Update Workspace", + } + workspace.GetDefaultSettings() + if err := database.CreateWorkspace(workspace); err != nil { + t.Fatalf("failed to create test workspace: %v", err) + } + + // Update workspace settings + workspace.Theme = "dark" + workspace.AutoSave = true + workspace.ShowHiddenFiles = true + workspace.GitEnabled = true + workspace.GitURL = "https://github.com/user/repo" + workspace.GitUser = "username" + workspace.GitToken = "new-token" + workspace.GitAutoCommit = true + workspace.GitCommitMsgTemplate = "custom ${filename}" + + if err := database.UpdateWorkspace(workspace); err != nil { + t.Fatalf("failed to update workspace: %v", err) + } + + // Verify updates + updated, err := database.GetWorkspaceByID(workspace.ID) + if err != nil { + t.Fatalf("failed to get updated workspace: %v", err) + } + + verifyWorkspace(t, updated, workspace) + }) + + t.Run("GetWorkspacesByUserID", func(t *testing.T) { + // Create several test workspaces + testWorkspaces := []*models.Workspace{ + { + UserID: user.ID, + Name: "User Workspace 1", + }, + { + UserID: user.ID, + Name: "User Workspace 2", + }, + } + + for _, w := range testWorkspaces { + w.GetDefaultSettings() + if err := database.CreateWorkspace(w); err != nil { + t.Fatalf("failed to create test workspace: %v", err) + } + } + + // Get all workspaces for user + workspaces, err := database.GetWorkspacesByUserID(user.ID) + if err != nil { + t.Fatalf("failed to get workspaces: %v", err) + } + + // We should have at least as many workspaces as we just created + // (there might be more from previous tests) + if len(workspaces) < len(testWorkspaces) { + t.Errorf("got %d workspaces, want at least %d", len(workspaces), len(testWorkspaces)) + } + + // Verify each test workspace exists in the result + for _, expected := range testWorkspaces { + found := false + for _, w := range workspaces { + if w.Name == expected.Name { + found = true + if w.UserID != expected.UserID { + t.Errorf("UserID = %v, want %v", w.UserID, expected.UserID) + } + break + } + } + if !found { + t.Errorf("workspace %s not found in results", expected.Name) + } + } + }) + + t.Run("UpdateLastOpenedFile", func(t *testing.T) { + // Create a test workspace + workspace := &models.Workspace{ + UserID: user.ID, + Name: "Last File Workspace", + } + workspace.GetDefaultSettings() + if err := database.CreateWorkspace(workspace); err != nil { + t.Fatalf("failed to create test workspace: %v", err) + } + + testCases := []struct { + name string + filePath string + wantErr bool + }{ + { + name: "valid file path", + filePath: "docs/test.md", + wantErr: false, + }, + { + name: "empty file path", + filePath: "", + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := database.UpdateLastOpenedFile(workspace.ID, tc.filePath) + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify update + path, err := database.GetLastOpenedFile(workspace.ID) + if err != nil { + t.Fatalf("failed to get last opened file: %v", err) + } + + if path != tc.filePath { + t.Errorf("LastOpenedFile = %v, want %v", path, tc.filePath) + } + }) + } + }) + + t.Run("DeleteWorkspace", func(t *testing.T) { + // Create a test workspace + workspace := &models.Workspace{ + UserID: user.ID, + Name: "Delete Workspace", + } + workspace.GetDefaultSettings() + if err := database.CreateWorkspace(workspace); err != nil { + t.Fatalf("failed to create test workspace: %v", err) + } + + // Delete the workspace + if err := database.DeleteWorkspace(workspace.ID); err != nil { + t.Fatalf("failed to delete workspace: %v", err) + } + + // Verify workspace is gone + _, err = database.GetWorkspaceByID(workspace.ID) + if err != sql.ErrNoRows { + t.Errorf("expected sql.ErrNoRows, got %v", err) + } + }) +} + +// Helper function to verify workspace fields +func verifyWorkspace(t *testing.T, actual, expected *models.Workspace) { + t.Helper() + + if actual.Name != expected.Name { + t.Errorf("Name = %v, want %v", actual.Name, expected.Name) + } + if actual.UserID != expected.UserID { + t.Errorf("UserID = %v, want %v", actual.UserID, expected.UserID) + } + if actual.Theme != expected.Theme { + t.Errorf("Theme = %v, want %v", actual.Theme, expected.Theme) + } + if actual.AutoSave != expected.AutoSave { + t.Errorf("AutoSave = %v, want %v", actual.AutoSave, expected.AutoSave) + } + if actual.ShowHiddenFiles != expected.ShowHiddenFiles { + t.Errorf("ShowHiddenFiles = %v, want %v", actual.ShowHiddenFiles, expected.ShowHiddenFiles) + } + if actual.GitEnabled != expected.GitEnabled { + t.Errorf("GitEnabled = %v, want %v", actual.GitEnabled, expected.GitEnabled) + } + if actual.GitURL != expected.GitURL { + t.Errorf("GitURL = %v, want %v", actual.GitURL, expected.GitURL) + } + if actual.GitUser != expected.GitUser { + t.Errorf("GitUser = %v, want %v", actual.GitUser, expected.GitUser) + } + if actual.GitToken != expected.GitToken { + t.Errorf("GitToken = %v, want %v", actual.GitToken, expected.GitToken) + } + if actual.GitAutoCommit != expected.GitAutoCommit { + t.Errorf("GitAutoCommit = %v, want %v", actual.GitAutoCommit, expected.GitAutoCommit) + } + if actual.GitCommitMsgTemplate != expected.GitCommitMsgTemplate { + t.Errorf("GitCommitMsgTemplate = %v, want %v", actual.GitCommitMsgTemplate, expected.GitCommitMsgTemplate) + } + if actual.CreatedAt.IsZero() { + t.Error("CreatedAt should not be zero") + } +} From 4ddf1f570f8bc27d5917afda374abf6bb370d019 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Tue, 26 Nov 2024 22:50:43 +0100 Subject: [PATCH 24/38] Implement auth handler integration test --- .vscode/settings.json | 4 +- server/go.mod | 4 + server/internal/auth/jwt.go | 10 + server/internal/auth/session.go | 9 +- server/internal/db/sessions.go | 6 +- server/internal/handlers/admin_handlers.go | 1 + server/internal/handlers/auth_handlers.go | 4 + .../auth_handlers_integration_test.go | 232 ++++++++++++++++++ server/internal/handlers/file_handlers.go | 19 +- server/internal/handlers/git_handlers.go | 2 + server/internal/handlers/integration_test.go | 188 ++++++++++++++ server/internal/handlers/static_handler.go | 2 + server/internal/handlers/user_handlers.go | 16 +- .../internal/handlers/workspace_handlers.go | 17 +- 14 files changed, 499 insertions(+), 15 deletions(-) create mode 100644 server/internal/handlers/auth_handlers_integration_test.go create mode 100644 server/internal/handlers/integration_test.go diff --git a/.vscode/settings.json b/.vscode/settings.json index 34f5162..52c38cb 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -14,7 +14,7 @@ "go.lintTool": "golangci-lint", "go.lintOnSave": "package", "go.formatTool": "goimports", - "go.testFlags": ["-tags=test"], + "go.testFlags": ["-tags=test,integration"], "[go]": { "editor.formatOnSave": true, "editor.codeActionsOnSave": { @@ -25,6 +25,6 @@ "gopls": { "usePlaceholders": true, "staticcheck": true, - "buildFlags": ["-tags", "test"] + "buildFlags": ["-tags", "test,integration"] } } diff --git a/server/go.mod b/server/go.mod index 0af2cf4..90ae840 100644 --- a/server/go.mod +++ b/server/go.mod @@ -11,6 +11,7 @@ require ( github.com/golang-jwt/jwt/v5 v5.2.1 github.com/google/uuid v1.6.0 github.com/mattn/go-sqlite3 v1.14.23 + github.com/stretchr/testify v1.9.0 github.com/unrolled/secure v1.17.0 golang.org/x/crypto v0.21.0 ) @@ -22,6 +23,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudflare/circl v1.3.7 // indirect github.com/cyphar/filepath-securejoin v0.2.4 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/emirpasic/gods v1.18.1 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect @@ -33,6 +35,7 @@ require ( github.com/kevinburke/ssh_config v1.2.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/pjbgf/sha1cd v0.3.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect github.com/skeema/knownhosts v1.2.2 // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect @@ -42,4 +45,5 @@ require ( golang.org/x/text v0.14.0 // indirect golang.org/x/tools v0.13.0 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/server/internal/auth/jwt.go b/server/internal/auth/jwt.go index 9b65db8..f9b4b23 100644 --- a/server/internal/auth/jwt.go +++ b/server/internal/auth/jwt.go @@ -2,6 +2,8 @@ package auth import ( + "crypto/rand" + "encoding/hex" "fmt" "time" @@ -87,11 +89,19 @@ func (s *jwtService) GenerateRefreshToken(userID int, role string) (string, erro // Returns the signed token string or an error func (s *jwtService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) { now := time.Now() + + // Add a random nonce to ensure uniqueness + nonce := make([]byte, 8) + if _, err := rand.Read(nonce); err != nil { + return "", fmt.Errorf("failed to generate nonce: %w", err) + } + claims := Claims{ RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(now.Add(expiry)), IssuedAt: jwt.NewNumericDate(now), NotBefore: jwt.NewNumericDate(now), + ID: hex.EncodeToString(nonce), }, UserID: userID, Role: role, diff --git a/server/internal/auth/session.go b/server/internal/auth/session.go index b6217d6..872b2c1 100644 --- a/server/internal/auth/session.go +++ b/server/internal/auth/session.go @@ -76,8 +76,8 @@ func (s *SessionService) CreateSession(userID int, role string) (*models.Session // - string: a new access token // - error: any error that occurred func (s *SessionService) RefreshSession(refreshToken string) (string, error) { - // Get session from database - _, err := s.db.GetSessionByRefreshToken(refreshToken) + // Get session from database first + session, err := s.db.GetSessionByRefreshToken(refreshToken) if err != nil { return "", fmt.Errorf("invalid session: %w", err) } @@ -88,6 +88,11 @@ func (s *SessionService) RefreshSession(refreshToken string) (string, error) { return "", fmt.Errorf("invalid refresh token: %w", err) } + // Double check that the claims match the session + if claims.UserID != session.UserID { + return "", fmt.Errorf("token does not match session") + } + // Generate a new access token return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role) } diff --git a/server/internal/db/sessions.go b/server/internal/db/sessions.go index 596dc64..ecc4a72 100644 --- a/server/internal/db/sessions.go +++ b/server/internal/db/sessions.go @@ -25,9 +25,9 @@ func (db *database) CreateSession(session *models.Session) error { func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) { session := &models.Session{} err := db.QueryRow(` - SELECT id, user_id, refresh_token, expires_at, created_at - FROM sessions - WHERE refresh_token = ? AND expires_at > ?`, + SELECT id, user_id, refresh_token, expires_at, created_at + FROM sessions + WHERE refresh_token = ? AND expires_at > ?`, refreshToken, time.Now(), ).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt) diff --git a/server/internal/handlers/admin_handlers.go b/server/internal/handlers/admin_handlers.go index 5ed8d42..0ae33e2 100644 --- a/server/internal/handlers/admin_handlers.go +++ b/server/internal/handlers/admin_handlers.go @@ -1,3 +1,4 @@ +// Package handlers contains the request handlers for the api routes. package handlers import ( diff --git a/server/internal/handlers/auth_handlers.go b/server/internal/handlers/auth_handlers.go index d6aec83..a1bc4b7 100644 --- a/server/internal/handlers/auth_handlers.go +++ b/server/internal/handlers/auth_handlers.go @@ -10,11 +10,13 @@ import ( "golang.org/x/crypto/bcrypt" ) +// LoginRequest represents a user login request type LoginRequest struct { Email string `json:"email"` Password string `json:"password"` } +// LoginResponse represents a user login response type LoginResponse struct { AccessToken string `json:"accessToken"` RefreshToken string `json:"refreshToken"` @@ -22,10 +24,12 @@ type LoginResponse struct { Session *models.Session `json:"session"` } +// RefreshRequest represents a refresh token request type RefreshRequest struct { RefreshToken string `json:"refreshToken"` } +// RefreshResponse represents a refresh token response type RefreshResponse struct { AccessToken string `json:"accessToken"` } diff --git a/server/internal/handlers/auth_handlers_integration_test.go b/server/internal/handlers/auth_handlers_integration_test.go new file mode 100644 index 0000000..d917e32 --- /dev/null +++ b/server/internal/handlers/auth_handlers_integration_test.go @@ -0,0 +1,232 @@ +//go:build integration + +package handlers_test + +import ( + "encoding/json" + "net/http" + "testing" + + "novamd/internal/handlers" + "novamd/internal/models" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAuthHandlers_Integration(t *testing.T) { + h := setupTestHarness(t) + defer h.teardown(t) + + t.Run("login", func(t *testing.T) { + t.Run("successful login - admin user", func(t *testing.T) { + loginReq := handlers.LoginRequest{ + Email: "admin@test.com", + Password: "admin123", + } + + rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil) + require.Equal(t, http.StatusOK, rr.Code) + + var resp handlers.LoginResponse + err := json.NewDecoder(rr.Body).Decode(&resp) + require.NoError(t, err) + + assert.NotEmpty(t, resp.AccessToken) + assert.NotEmpty(t, resp.RefreshToken) + assert.NotNil(t, resp.User) + assert.Equal(t, loginReq.Email, resp.User.Email) + assert.Equal(t, models.RoleAdmin, resp.User.Role) + }) + + t.Run("successful login - regular user", func(t *testing.T) { + loginReq := handlers.LoginRequest{ + Email: "user@test.com", + Password: "user123", + } + + rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil) + require.Equal(t, http.StatusOK, rr.Code) + + var resp handlers.LoginResponse + err := json.NewDecoder(rr.Body).Decode(&resp) + require.NoError(t, err) + + assert.NotEmpty(t, resp.AccessToken) + assert.NotEmpty(t, resp.RefreshToken) + assert.NotNil(t, resp.User) + assert.Equal(t, loginReq.Email, resp.User.Email) + assert.Equal(t, models.RoleEditor, resp.User.Role) + }) + + t.Run("login failures", func(t *testing.T) { + tests := []struct { + name string + request handlers.LoginRequest + wantCode int + }{ + { + name: "wrong password", + request: handlers.LoginRequest{ + Email: "user@test.com", + Password: "wrongpassword", + }, + wantCode: http.StatusUnauthorized, + }, + { + name: "non-existent user", + request: handlers.LoginRequest{ + Email: "nonexistent@test.com", + Password: "password123", + }, + wantCode: http.StatusUnauthorized, + }, + { + name: "empty email", + request: handlers.LoginRequest{ + Email: "", + Password: "password123", + }, + wantCode: http.StatusBadRequest, + }, + { + name: "empty password", + request: handlers.LoginRequest{ + Email: "user@test.com", + Password: "", + }, + wantCode: http.StatusBadRequest, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", tt.request, "", nil) + assert.Equal(t, tt.wantCode, rr.Code) + }) + } + }) + }) + + t.Run("refresh token", func(t *testing.T) { + t.Run("successful token refresh", func(t *testing.T) { + // First login to get refresh token + loginReq := handlers.LoginRequest{ + Email: "user@test.com", + Password: "user123", + } + + rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil) + require.Equal(t, http.StatusOK, rr.Code) + + var loginResp handlers.LoginResponse + err := json.NewDecoder(rr.Body).Decode(&loginResp) + require.NoError(t, err) + + // Now try to refresh the token + refreshReq := handlers.RefreshRequest{ + RefreshToken: loginResp.RefreshToken, + } + + rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/refresh", refreshReq, "", nil) + require.Equal(t, http.StatusOK, rr.Code) + + var refreshResp handlers.RefreshResponse + err = json.NewDecoder(rr.Body).Decode(&refreshResp) + require.NoError(t, err) + assert.NotEmpty(t, refreshResp.AccessToken) + }) + + t.Run("refresh failures", func(t *testing.T) { + tests := []struct { + name string + request handlers.RefreshRequest + wantCode int + }{ + { + name: "invalid refresh token", + request: handlers.RefreshRequest{ + RefreshToken: "invalid-token", + }, + wantCode: http.StatusUnauthorized, + }, + { + name: "empty refresh token", + request: handlers.RefreshRequest{ + RefreshToken: "", + }, + wantCode: http.StatusBadRequest, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/refresh", tt.request, "", nil) + assert.Equal(t, tt.wantCode, rr.Code) + }) + } + }) + }) + + t.Run("logout", func(t *testing.T) { + t.Run("successful logout", func(t *testing.T) { + // First login to get session + loginReq := handlers.LoginRequest{ + Email: "user@test.com", + Password: "user123", + } + + rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil) + require.Equal(t, http.StatusOK, rr.Code) + + var loginResp handlers.LoginResponse + err := json.NewDecoder(rr.Body).Decode(&loginResp) + require.NoError(t, err) + + // Now logout using session ID from login response + headers := map[string]string{ + "X-Session-ID": loginResp.Session.ID, + } + rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/logout", nil, loginResp.AccessToken, headers) + require.Equal(t, http.StatusOK, rr.Code) + + // Try to use the refresh token - should fail + refreshReq := handlers.RefreshRequest{ + RefreshToken: loginResp.RefreshToken, + } + + rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/refresh", refreshReq, "", nil) + assert.Equal(t, http.StatusUnauthorized, rr.Code) + }) + + t.Run("logout without session ID", func(t *testing.T) { + rr := h.makeRequest(t, http.MethodPost, "/api/v1/auth/logout", nil, h.RegularToken, nil) + assert.Equal(t, http.StatusBadRequest, rr.Code) + }) + }) + + t.Run("get current user", func(t *testing.T) { + t.Run("successful get current user", func(t *testing.T) { + rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var user models.User + err := json.NewDecoder(rr.Body).Decode(&user) + require.NoError(t, err) + + assert.Equal(t, h.RegularUser.ID, user.ID) + assert.Equal(t, h.RegularUser.Email, user.Email) + assert.Equal(t, h.RegularUser.Role, user.Role) + }) + + t.Run("get current user without token", func(t *testing.T) { + rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, "", nil) + assert.Equal(t, http.StatusUnauthorized, rr.Code) + }) + + t.Run("get current user with invalid token", func(t *testing.T) { + rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, "invalid-token", nil) + assert.Equal(t, http.StatusUnauthorized, rr.Code) + }) + }) +} diff --git a/server/internal/handlers/file_handlers.go b/server/internal/handlers/file_handlers.go index 9f75eb2..3f53c17 100644 --- a/server/internal/handlers/file_handlers.go +++ b/server/internal/handlers/file_handlers.go @@ -10,6 +10,7 @@ import ( "github.com/go-chi/chi/v5" ) +// ListFiles returns a list of all files in the workspace func (h *Handler) ListFiles() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -27,6 +28,7 @@ func (h *Handler) ListFiles() http.HandlerFunc { } } +// LookupFileByName returns the paths of files with the given name func (h *Handler) LookupFileByName() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -50,6 +52,7 @@ func (h *Handler) LookupFileByName() http.HandlerFunc { } } +// GetFileContent returns the content of a file func (h *Handler) GetFileContent() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -65,10 +68,15 @@ func (h *Handler) GetFileContent() http.HandlerFunc { } w.Header().Set("Content-Type", "text/plain") - w.Write(content) + _, err = w.Write(content) + if err != nil { + http.Error(w, "Failed to write response", http.StatusInternalServerError) + return + } } } +// SaveFile saves the content of a file func (h *Handler) SaveFile() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -93,6 +101,7 @@ func (h *Handler) SaveFile() http.HandlerFunc { } } +// DeleteFile deletes a file func (h *Handler) DeleteFile() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -108,10 +117,15 @@ func (h *Handler) DeleteFile() http.HandlerFunc { } w.WriteHeader(http.StatusOK) - w.Write([]byte("File deleted successfully")) + _, err = w.Write([]byte("File deleted successfully")) + if err != nil { + http.Error(w, "Failed to write response", http.StatusInternalServerError) + return + } } } +// GetLastOpenedFile returns the last opened file in the workspace func (h *Handler) GetLastOpenedFile() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -134,6 +148,7 @@ func (h *Handler) GetLastOpenedFile() http.HandlerFunc { } } +// UpdateLastOpenedFile updates the last opened file in the workspace func (h *Handler) UpdateLastOpenedFile() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) diff --git a/server/internal/handlers/git_handlers.go b/server/internal/handlers/git_handlers.go index ea6eeff..f34b12a 100644 --- a/server/internal/handlers/git_handlers.go +++ b/server/internal/handlers/git_handlers.go @@ -7,6 +7,7 @@ import ( "novamd/internal/context" ) +// StageCommitAndPush stages, commits, and pushes changes to the remote repository func (h *Handler) StageCommitAndPush() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -38,6 +39,7 @@ func (h *Handler) StageCommitAndPush() http.HandlerFunc { } } +// PullChanges pulls changes from the remote repository func (h *Handler) PullChanges() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) diff --git a/server/internal/handlers/integration_test.go b/server/internal/handlers/integration_test.go new file mode 100644 index 0000000..4201fdd --- /dev/null +++ b/server/internal/handlers/integration_test.go @@ -0,0 +1,188 @@ +//go:build integration + +package handlers_test + +import ( + "bytes" + "encoding/json" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "golang.org/x/crypto/bcrypt" + + "novamd/internal/api" + "novamd/internal/auth" + "novamd/internal/db" + "novamd/internal/handlers" + "novamd/internal/models" + "novamd/internal/secrets" + "novamd/internal/storage" +) + +// testHarness encapsulates all the dependencies needed for testing +type testHarness struct { + DB db.TestDatabase + Storage storage.Manager + Router *chi.Mux + Handler *handlers.Handler + JWTManager auth.JWTManager + SessionSvc *auth.SessionService + AdminUser *models.User + AdminToken string + RegularUser *models.User + RegularToken string + TempDirectory string +} + +// setupTestHarness creates a new test environment +func setupTestHarness(t *testing.T) *testHarness { + t.Helper() + + // Create temporary directory for test files + tempDir, err := os.MkdirTemp("", "novamd-test-*") + if err != nil { + t.Fatalf("Failed to create temp directory: %v", err) + } + + // Initialize test database + secretsSvc, err := secrets.NewService("YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY=") // test key + if err != nil { + t.Fatalf("Failed to initialize secrets service: %v", err) + } + + database, err := db.NewTestDB(":memory:", secretsSvc) + if err != nil { + t.Fatalf("Failed to initialize test database: %v", err) + } + + if err := database.Migrate(); err != nil { + t.Fatalf("Failed to run migrations: %v", err) + } + + // Initialize storage + storageSvc := storage.NewService(tempDir) + + // Initialize JWT service + jwtSvc, err := auth.NewJWTService(auth.JWTConfig{ + SigningKey: "test-key", + AccessTokenExpiry: 15 * time.Minute, + RefreshTokenExpiry: 24 * time.Hour, + }) + if err != nil { + t.Fatalf("Failed to initialize JWT service: %v", err) + } + + // Initialize session service + sessionSvc := auth.NewSessionService(database, jwtSvc) + + // Create handler + handler := &handlers.Handler{ + DB: database, + Storage: storageSvc, + } + + // Set up router with middlewares + router := chi.NewRouter() + authMiddleware := auth.NewMiddleware(jwtSvc) + router.Route("/api/v1", func(r chi.Router) { + api.SetupRoutes(r, database, storageSvc, authMiddleware, sessionSvc) + }) + + // Create test users + adminUser, adminToken := createTestUser(t, database, sessionSvc, "admin@test.com", "admin123", models.RoleAdmin) + regularUser, regularToken := createTestUser(t, database, sessionSvc, "user@test.com", "user123", models.RoleEditor) + + return &testHarness{ + DB: database, + Storage: storageSvc, + Router: router, + Handler: handler, + JWTManager: jwtSvc, + SessionSvc: sessionSvc, + AdminUser: adminUser, + AdminToken: adminToken, + RegularUser: regularUser, + RegularToken: regularToken, + TempDirectory: tempDir, + } +} + +// teardownTestHarness cleans up the test environment +func (h *testHarness) teardown(t *testing.T) { + t.Helper() + + if err := h.DB.Close(); err != nil { + t.Errorf("Failed to close database: %v", err) + } + + if err := os.RemoveAll(h.TempDirectory); err != nil { + t.Errorf("Failed to remove temp directory: %v", err) + } +} + +// createTestUser creates a test user and returns the user and access token +func createTestUser(t *testing.T, db db.Database, sessionSvc *auth.SessionService, email, password string, role models.UserRole) (*models.User, string) { + t.Helper() + + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + t.Fatalf("Failed to hash password: %v", err) + } + + user := &models.User{ + Email: email, + DisplayName: "Test User", + PasswordHash: string(hashedPassword), + Role: role, + } + + user, err = db.CreateUser(user) + if err != nil { + t.Fatalf("Failed to create user: %v", err) + } + + session, accessToken, err := sessionSvc.CreateSession(user.ID, string(user.Role)) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + if session == nil || accessToken == "" { + t.Fatal("Failed to get valid session or token") + } + + return user, accessToken +} + +// makeRequest is a helper function to make HTTP requests in tests +func (h *testHarness) makeRequest(t *testing.T, method, path string, body interface{}, token string, headers map[string]string) *httptest.ResponseRecorder { + t.Helper() + + var reqBody []byte + var err error + + if body != nil { + reqBody, err = json.Marshal(body) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + } + + req := httptest.NewRequest(method, path, bytes.NewBuffer(reqBody)) + if token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + req.Header.Set("Content-Type", "application/json") + + // Add any additional headers + for k, v := range headers { + req.Header.Set(k, v) + } + + rr := httptest.NewRecorder() + h.Router.ServeHTTP(rr, req) + + return rr +} diff --git a/server/internal/handlers/static_handler.go b/server/internal/handlers/static_handler.go index 8dfb710..8360b3c 100644 --- a/server/internal/handlers/static_handler.go +++ b/server/internal/handlers/static_handler.go @@ -12,12 +12,14 @@ type StaticHandler struct { staticPath string } +// NewStaticHandler creates a new StaticHandler with the given static path func NewStaticHandler(staticPath string) *StaticHandler { return &StaticHandler{ staticPath: staticPath, } } +// ServeHTTP serves the static files func (h *StaticHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Get the requested path requestedPath := r.URL.Path diff --git a/server/internal/handlers/user_handlers.go b/server/internal/handlers/user_handlers.go index 678c8e5..0249327 100644 --- a/server/internal/handlers/user_handlers.go +++ b/server/internal/handlers/user_handlers.go @@ -1,6 +1,7 @@ package handlers import ( + "database/sql" "encoding/json" "net/http" @@ -9,6 +10,7 @@ import ( "golang.org/x/crypto/bcrypt" ) +// UpdateProfileRequest represents a user profile update request type UpdateProfileRequest struct { DisplayName string `json:"displayName"` Email string `json:"email"` @@ -16,10 +18,12 @@ type UpdateProfileRequest struct { NewPassword string `json:"newPassword"` } +// DeleteAccountRequest represents a user account deletion request type DeleteAccountRequest struct { Password string `json:"password"` } +// GetUser returns the current user's profile func (h *Handler) GetUser() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -64,7 +68,11 @@ func (h *Handler) UpdateProfile() http.HandlerFunc { http.Error(w, "Failed to start transaction", http.StatusInternalServerError) return } - defer tx.Rollback() + defer func() { + if err := tx.Rollback(); err != nil && err != sql.ErrTxDone { + http.Error(w, "Failed to rollback transaction", http.StatusInternalServerError) + } + }() // Handle password update if requested if req.NewPassword != "" { @@ -188,7 +196,11 @@ func (h *Handler) DeleteAccount() http.HandlerFunc { http.Error(w, "Failed to start transaction", http.StatusInternalServerError) return } - defer tx.Rollback() + defer func() { + if err := tx.Rollback(); err != nil && err != sql.ErrTxDone { + http.Error(w, "Failed to rollback transaction", http.StatusInternalServerError) + } + }() // Get user's workspaces for cleanup workspaces, err := h.DB.GetWorkspacesByUserID(ctx.UserID) diff --git a/server/internal/handlers/workspace_handlers.go b/server/internal/handlers/workspace_handlers.go index 8dee442..10ae3cf 100644 --- a/server/internal/handlers/workspace_handlers.go +++ b/server/internal/handlers/workspace_handlers.go @@ -1,14 +1,15 @@ package handlers import ( + "database/sql" "encoding/json" - "fmt" "net/http" "novamd/internal/context" "novamd/internal/models" ) +// ListWorkspaces returns a list of all workspaces for the current user func (h *Handler) ListWorkspaces() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -26,6 +27,7 @@ func (h *Handler) ListWorkspaces() http.HandlerFunc { } } +// CreateWorkspace creates a new workspace func (h *Handler) CreateWorkspace() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -54,6 +56,7 @@ func (h *Handler) CreateWorkspace() http.HandlerFunc { } } +// GetWorkspace returns the current workspace func (h *Handler) GetWorkspace() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -81,6 +84,7 @@ func gitSettingsChanged(new, old *models.Workspace) bool { return false } +// UpdateWorkspace updates the current workspace func (h *Handler) UpdateWorkspace() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -132,6 +136,7 @@ func (h *Handler) UpdateWorkspace() http.HandlerFunc { } } +// DeleteWorkspace deletes the current workspace func (h *Handler) DeleteWorkspace() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -168,7 +173,11 @@ func (h *Handler) DeleteWorkspace() http.HandlerFunc { http.Error(w, "Failed to start transaction", http.StatusInternalServerError) return } - defer tx.Rollback() + defer func() { + if err := tx.Rollback(); err != nil && err != sql.ErrTxDone { + http.Error(w, "Failed to rollback transaction", http.StatusInternalServerError) + } + }() // Update last workspace ID first err = h.DB.UpdateLastWorkspaceTx(tx, ctx.UserID, nextWorkspaceID) @@ -195,6 +204,7 @@ func (h *Handler) DeleteWorkspace() http.HandlerFunc { } } +// GetLastWorkspaceName returns the name of the last opened workspace func (h *Handler) GetLastWorkspaceName() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -212,6 +222,7 @@ func (h *Handler) GetLastWorkspaceName() http.HandlerFunc { } } +// UpdateLastWorkspaceName updates the name of the last opened workspace func (h *Handler) UpdateLastWorkspaceName() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) @@ -224,13 +235,11 @@ func (h *Handler) UpdateLastWorkspaceName() http.HandlerFunc { } if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { - fmt.Println(err) http.Error(w, "Invalid request body", http.StatusBadRequest) return } if err := h.DB.UpdateLastWorkspace(ctx.UserID, requestBody.WorkspaceName); err != nil { - fmt.Println(err) http.Error(w, "Failed to update last workspace", http.StatusInternalServerError) return } From fbb8fa3a60f275686a59e93a669f1bc1c42f1a2f Mon Sep 17 00:00:00 2001 From: LordMathis Date: Wed, 27 Nov 2024 21:28:59 +0100 Subject: [PATCH 25/38] Implement admin handlers integration test --- server/internal/handlers/admin_handlers.go | 10 +- .../admin_handlers_integration_test.go | 243 ++++++++++++++++++ server/internal/handlers/integration_test.go | 29 ++- 3 files changed, 268 insertions(+), 14 deletions(-) create mode 100644 server/internal/handlers/admin_handlers_integration_test.go diff --git a/server/internal/handlers/admin_handlers.go b/server/internal/handlers/admin_handlers.go index 0ae33e2..02b436c 100644 --- a/server/internal/handlers/admin_handlers.go +++ b/server/internal/handlers/admin_handlers.go @@ -15,14 +15,16 @@ import ( "golang.org/x/crypto/bcrypt" ) -type createUserRequest struct { +// CreateUserRequest holds the request fields for creating a new user +type CreateUserRequest struct { Email string `json:"email"` DisplayName string `json:"displayName"` Password string `json:"password"` Role models.UserRole `json:"role"` } -type updateUserRequest struct { +// UpdateUserRequest holds the request fields for updating a user +type UpdateUserRequest struct { Email string `json:"email,omitempty"` DisplayName string `json:"displayName,omitempty"` Password string `json:"password,omitempty"` @@ -45,7 +47,7 @@ func (h *Handler) AdminListUsers() http.HandlerFunc { // AdminCreateUser creates a new user func (h *Handler) AdminCreateUser() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - var req createUserRequest + var req CreateUserRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Invalid request body", http.StatusBadRequest) return @@ -136,7 +138,7 @@ func (h *Handler) AdminUpdateUser() http.HandlerFunc { return } - var req updateUserRequest + var req UpdateUserRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Invalid request body", http.StatusBadRequest) return diff --git a/server/internal/handlers/admin_handlers_integration_test.go b/server/internal/handlers/admin_handlers_integration_test.go new file mode 100644 index 0000000..121ea4b --- /dev/null +++ b/server/internal/handlers/admin_handlers_integration_test.go @@ -0,0 +1,243 @@ +//go:build integration + +package handlers_test + +import ( + "encoding/json" + "fmt" + "net/http" + "testing" + + "novamd/internal/handlers" + "novamd/internal/models" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Helper function to check if a user exists in a slice of users +func containsUser(users []*models.User, searchUser *models.User) bool { + for _, u := range users { + if u.ID == searchUser.ID && + u.Email == searchUser.Email && + u.DisplayName == searchUser.DisplayName && + u.Role == searchUser.Role { + return true + } + } + return false +} + +func TestAdminHandlers_Integration(t *testing.T) { + h := setupTestHarness(t) + defer h.teardown(t) + + t.Run("user management", func(t *testing.T) { + t.Run("list users", func(t *testing.T) { + // Test with admin token + rr := h.makeRequest(t, http.MethodGet, "/api/v1/admin/users", nil, h.AdminToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var users []*models.User + err := json.NewDecoder(rr.Body).Decode(&users) + require.NoError(t, err) + + // Should have at least our admin and regular test users + assert.GreaterOrEqual(t, len(users), 2) + assert.True(t, containsUser(users, h.AdminUser), "Admin user not found in users list") + assert.True(t, containsUser(users, h.RegularUser), "Regular user not found in users list") + + // Test with non-admin token + rr = h.makeRequest(t, http.MethodGet, "/api/v1/admin/users", nil, h.RegularToken, nil) + assert.Equal(t, http.StatusForbidden, rr.Code) + + // Test without token + rr = h.makeRequest(t, http.MethodGet, "/api/v1/admin/users", nil, "", nil) + assert.Equal(t, http.StatusUnauthorized, rr.Code) + }) + + t.Run("create user", func(t *testing.T) { + createReq := handlers.CreateUserRequest{ + Email: "newuser@test.com", + DisplayName: "New User", + Password: "password123", + Role: models.RoleEditor, + } + + // Test with admin token + rr := h.makeRequest(t, http.MethodPost, "/api/v1/admin/users", createReq, h.AdminToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var createdUser models.User + err := json.NewDecoder(rr.Body).Decode(&createdUser) + require.NoError(t, err) + assert.Equal(t, createReq.Email, createdUser.Email) + assert.Equal(t, createReq.DisplayName, createdUser.DisplayName) + assert.Equal(t, createReq.Role, createdUser.Role) + assert.NotZero(t, createdUser.LastWorkspaceID) + + // Test duplicate email + rr = h.makeRequest(t, http.MethodPost, "/api/v1/admin/users", createReq, h.AdminToken, nil) + assert.Equal(t, http.StatusConflict, rr.Code) + + // Test invalid request (missing required fields) + invalidReq := handlers.CreateUserRequest{ + Email: "invalid@test.com", + // Missing password and role + } + rr = h.makeRequest(t, http.MethodPost, "/api/v1/admin/users", invalidReq, h.AdminToken, nil) + assert.Equal(t, http.StatusBadRequest, rr.Code) + + // Test with non-admin token + rr = h.makeRequest(t, http.MethodPost, "/api/v1/admin/users", createReq, h.RegularToken, nil) + assert.Equal(t, http.StatusForbidden, rr.Code) + }) + + t.Run("get user", func(t *testing.T) { + path := fmt.Sprintf("/api/v1/admin/users/%d", h.RegularUser.ID) + + // Test with admin token + rr := h.makeRequest(t, http.MethodGet, path, nil, h.AdminToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var user models.User + err := json.NewDecoder(rr.Body).Decode(&user) + require.NoError(t, err) + assert.Equal(t, h.RegularUser.ID, user.ID) + + // Test non-existent user + rr = h.makeRequest(t, http.MethodGet, "/api/v1/admin/users/999999", nil, h.AdminToken, nil) + assert.Equal(t, http.StatusNotFound, rr.Code) + + // Test with non-admin token + rr = h.makeRequest(t, http.MethodGet, path, nil, h.RegularToken, nil) + assert.Equal(t, http.StatusForbidden, rr.Code) + }) + + t.Run("update user", func(t *testing.T) { + path := fmt.Sprintf("/api/v1/admin/users/%d", h.RegularUser.ID) + updateReq := handlers.UpdateUserRequest{ + DisplayName: "Updated Name", + Role: models.RoleViewer, + } + + // Test with admin token + rr := h.makeRequest(t, http.MethodPut, path, updateReq, h.AdminToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var updatedUser models.User + err := json.NewDecoder(rr.Body).Decode(&updatedUser) + require.NoError(t, err) + assert.Equal(t, updateReq.DisplayName, updatedUser.DisplayName) + assert.Equal(t, updateReq.Role, updatedUser.Role) + + // Test with non-admin token + rr = h.makeRequest(t, http.MethodPut, path, updateReq, h.RegularToken, nil) + assert.Equal(t, http.StatusForbidden, rr.Code) + }) + + t.Run("delete user", func(t *testing.T) { + // Create a user to delete + createReq := handlers.CreateUserRequest{ + Email: "todelete@test.com", + DisplayName: "To Delete", + Password: "password123", + Role: models.RoleEditor, + } + + rr := h.makeRequest(t, http.MethodPost, "/api/v1/admin/users", createReq, h.AdminToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var createdUser models.User + err := json.NewDecoder(rr.Body).Decode(&createdUser) + require.NoError(t, err) + + path := fmt.Sprintf("/api/v1/admin/users/%d", createdUser.ID) + + // Test deleting own account (should fail) + adminPath := fmt.Sprintf("/api/v1/admin/users/%d", h.AdminUser.ID) + rr = h.makeRequest(t, http.MethodDelete, adminPath, nil, h.AdminToken, nil) + assert.Equal(t, http.StatusBadRequest, rr.Code) + + // Test with admin token + rr = h.makeRequest(t, http.MethodDelete, path, nil, h.AdminToken, nil) + assert.Equal(t, http.StatusNoContent, rr.Code) + + // Verify user is deleted + rr = h.makeRequest(t, http.MethodGet, path, nil, h.AdminToken, nil) + assert.Equal(t, http.StatusNotFound, rr.Code) + + // Test with non-admin token + rr = h.makeRequest(t, http.MethodDelete, path, nil, h.RegularToken, nil) + assert.Equal(t, http.StatusForbidden, rr.Code) + }) + }) + + t.Run("workspace management", func(t *testing.T) { + t.Run("list workspaces", func(t *testing.T) { + // Create a test workspace first + workspace := &models.Workspace{ + UserID: h.RegularUser.ID, + Name: "Test Workspace", + } + + rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + // Test with admin token + rr = h.makeRequest(t, http.MethodGet, "/api/v1/admin/workspaces", nil, h.AdminToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var workspaces []*handlers.WorkspaceStats + err := json.NewDecoder(rr.Body).Decode(&workspaces) + require.NoError(t, err) + + // Should have at least the default workspaces for admin and regular users + assert.NotEmpty(t, workspaces) + + // Verify workspace stats fields + for _, ws := range workspaces { + assert.NotZero(t, ws.UserID) + assert.NotEmpty(t, ws.UserEmail) + assert.NotZero(t, ws.WorkspaceID) + assert.NotEmpty(t, ws.WorkspaceName) + assert.NotZero(t, ws.WorkspaceCreatedAt) + assert.GreaterOrEqual(t, ws.TotalFiles, 0) + assert.GreaterOrEqual(t, ws.TotalSize, int64(0)) + } + + // Test with non-admin token + rr = h.makeRequest(t, http.MethodGet, "/api/v1/admin/workspaces", nil, h.RegularToken, nil) + assert.Equal(t, http.StatusForbidden, rr.Code) + }) + }) + + t.Run("system stats", func(t *testing.T) { + // Create some test data + workspace := &models.Workspace{ + UserID: h.RegularUser.ID, + Name: "Stats Test Workspace", + } + rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + // Test with admin token + rr = h.makeRequest(t, http.MethodGet, "/api/v1/admin/stats", nil, h.AdminToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var stats handlers.SystemStats + err := json.NewDecoder(rr.Body).Decode(&stats) + require.NoError(t, err) + + // Verify stats fields + assert.GreaterOrEqual(t, stats.TotalUsers, 2) // At least admin and regular user + assert.GreaterOrEqual(t, stats.TotalWorkspaces, 2) // At least default workspaces + assert.GreaterOrEqual(t, stats.ActiveUsers, 2) // Our test users should be active + assert.GreaterOrEqual(t, stats.TotalFiles, 0) + assert.GreaterOrEqual(t, stats.TotalSize, int64(0)) + + // Test with non-admin token + rr = h.makeRequest(t, http.MethodGet, "/api/v1/admin/stats", nil, h.RegularToken, nil) + assert.Equal(t, http.StatusForbidden, rr.Code) + }) +} diff --git a/server/internal/handlers/integration_test.go b/server/internal/handlers/integration_test.go index 4201fdd..ea26b90 100644 --- a/server/internal/handlers/integration_test.go +++ b/server/internal/handlers/integration_test.go @@ -91,23 +91,26 @@ func setupTestHarness(t *testing.T) *testHarness { api.SetupRoutes(r, database, storageSvc, authMiddleware, sessionSvc) }) - // Create test users - adminUser, adminToken := createTestUser(t, database, sessionSvc, "admin@test.com", "admin123", models.RoleAdmin) - regularUser, regularToken := createTestUser(t, database, sessionSvc, "user@test.com", "user123", models.RoleEditor) - - return &testHarness{ + h := &testHarness{ DB: database, Storage: storageSvc, Router: router, Handler: handler, JWTManager: jwtSvc, SessionSvc: sessionSvc, - AdminUser: adminUser, - AdminToken: adminToken, - RegularUser: regularUser, - RegularToken: regularToken, TempDirectory: tempDir, } + + // Create test users + adminUser, adminToken := h.createTestUser(t, database, sessionSvc, "admin@test.com", "admin123", models.RoleAdmin) + regularUser, regularToken := h.createTestUser(t, database, sessionSvc, "user@test.com", "user123", models.RoleEditor) + + h.AdminUser = adminUser + h.AdminToken = adminToken + h.RegularUser = regularUser + h.RegularToken = regularToken + + return h } // teardownTestHarness cleans up the test environment @@ -124,7 +127,7 @@ func (h *testHarness) teardown(t *testing.T) { } // createTestUser creates a test user and returns the user and access token -func createTestUser(t *testing.T, db db.Database, sessionSvc *auth.SessionService, email, password string, role models.UserRole) (*models.User, string) { +func (h *testHarness) createTestUser(t *testing.T, db db.Database, sessionSvc *auth.SessionService, email, password string, role models.UserRole) (*models.User, string) { t.Helper() hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) @@ -144,6 +147,12 @@ func createTestUser(t *testing.T, db db.Database, sessionSvc *auth.SessionServic t.Fatalf("Failed to create user: %v", err) } + // Initialize the default workspace directory in storage + err = h.Storage.InitializeUserWorkspace(user.ID, user.LastWorkspaceID) + if err != nil { + t.Fatalf("Failed to initialize user workspace: %v", err) + } + session, accessToken, err := sessionSvc.CreateSession(user.ID, string(user.Role)) if err != nil { t.Fatalf("Failed to create session: %v", err) From 91489ca633ffe925324d65d4667f62b11aada69e Mon Sep 17 00:00:00 2001 From: LordMathis Date: Thu, 28 Nov 2024 21:18:30 +0100 Subject: [PATCH 26/38] Update path validation error handling --- server/internal/handlers/file_handlers.go | 47 +++++++++++++++++++++-- server/internal/storage/errors.go | 24 ++++++++++++ server/internal/storage/workspace.go | 4 +- server/internal/storage/workspace_test.go | 2 +- 4 files changed, 70 insertions(+), 7 deletions(-) create mode 100644 server/internal/storage/errors.go diff --git a/server/internal/handlers/file_handlers.go b/server/internal/handlers/file_handlers.go index 3f53c17..93ddd99 100644 --- a/server/internal/handlers/file_handlers.go +++ b/server/internal/handlers/file_handlers.go @@ -4,8 +4,10 @@ import ( "encoding/json" "io" "net/http" + "os" "novamd/internal/context" + "novamd/internal/storage" "github.com/go-chi/chi/v5" ) @@ -63,7 +65,18 @@ func (h *Handler) GetFileContent() http.HandlerFunc { filePath := chi.URLParam(r, "*") content, err := h.Storage.GetFileContent(ctx.UserID, ctx.Workspace.ID, filePath) if err != nil { - http.Error(w, "Failed to read file", http.StatusNotFound) + + if storage.IsPathValidationError(err) { + http.Error(w, "Invalid file path", http.StatusBadRequest) + return + } + + if os.IsNotExist(err) { + http.Error(w, "Failed to read file", http.StatusNotFound) + return + } + + http.Error(w, "Failed to read file", http.StatusInternalServerError) return } @@ -93,6 +106,11 @@ func (h *Handler) SaveFile() http.HandlerFunc { err = h.Storage.SaveFile(ctx.UserID, ctx.Workspace.ID, filePath, content) if err != nil { + if storage.IsPathValidationError(err) { + http.Error(w, "Invalid file path", http.StatusBadRequest) + return + } + http.Error(w, "Failed to save file", http.StatusInternalServerError) return } @@ -112,6 +130,16 @@ func (h *Handler) DeleteFile() http.HandlerFunc { filePath := chi.URLParam(r, "*") err := h.Storage.DeleteFile(ctx.UserID, ctx.Workspace.ID, filePath) if err != nil { + if storage.IsPathValidationError(err) { + http.Error(w, "Invalid file path", http.StatusBadRequest) + return + } + + if os.IsNotExist(err) { + http.Error(w, "File not found", http.StatusNotFound) + return + } + http.Error(w, "Failed to delete file", http.StatusInternalServerError) return } @@ -165,10 +193,21 @@ func (h *Handler) UpdateLastOpenedFile() http.HandlerFunc { return } - // Validate the file path exists in the workspace + // Validate the file path in the workspace if requestBody.FilePath != "" { - if _, err := h.Storage.ValidatePath(ctx.UserID, ctx.Workspace.ID, requestBody.FilePath); err != nil { - http.Error(w, "Invalid file path", http.StatusBadRequest) + _, err := h.Storage.GetFileContent(ctx.UserID, ctx.Workspace.ID, requestBody.FilePath) + if err != nil { + if storage.IsPathValidationError(err) { + http.Error(w, "Invalid file path", http.StatusBadRequest) + return + } + + if os.IsNotExist(err) { + http.Error(w, "File not found", http.StatusNotFound) + return + } + + http.Error(w, "Failed to update file", http.StatusInternalServerError) return } } diff --git a/server/internal/storage/errors.go b/server/internal/storage/errors.go new file mode 100644 index 0000000..dbfdf02 --- /dev/null +++ b/server/internal/storage/errors.go @@ -0,0 +1,24 @@ +// storage/errors.go + +package storage + +import ( + "errors" + "fmt" +) + +// PathValidationError represents a path validation error (e.g., path traversal attempt) +type PathValidationError struct { + Path string + Message string +} + +func (e *PathValidationError) Error() string { + return fmt.Sprintf("%s: %s", e.Message, e.Path) +} + +// IsPathValidationError checks if the error is a PathValidationError +func IsPathValidationError(err error) bool { + var pathErr *PathValidationError + return err != nil && errors.As(err, &pathErr) +} diff --git a/server/internal/storage/workspace.go b/server/internal/storage/workspace.go index 7a264ad..2a4beca 100644 --- a/server/internal/storage/workspace.go +++ b/server/internal/storage/workspace.go @@ -27,7 +27,7 @@ func (s *Service) ValidatePath(userID, workspaceID int, path string) (string, er // First check if the path is absolute if filepath.IsAbs(path) { - return "", fmt.Errorf("invalid path: absolute paths not allowed") + return "", &PathValidationError{Path: path, Message: "absolute paths not allowed"} } // Join and clean the path @@ -36,7 +36,7 @@ func (s *Service) ValidatePath(userID, workspaceID int, path string) (string, er // Verify the path is still within the workspace if !strings.HasPrefix(cleanPath, workspacePath) { - return "", fmt.Errorf("invalid path: outside of workspace") + return "", &PathValidationError{Path: path, Message: "path traversal attempt"} } return cleanPath, nil diff --git a/server/internal/storage/workspace_test.go b/server/internal/storage/workspace_test.go index f6b1607..2fc11b2 100644 --- a/server/internal/storage/workspace_test.go +++ b/server/internal/storage/workspace_test.go @@ -48,7 +48,7 @@ func TestValidatePath(t *testing.T) { path: "../../../etc/passwd", want: "", wantErr: true, - errContains: "outside of workspace", + errContains: "path traversal attempt", }, { name: "absolute path attempt", From 3fb40a881729df7962ea0c1593a4a675c8d0c517 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Thu, 28 Nov 2024 21:18:47 +0100 Subject: [PATCH 27/38] Implement file handlers integration tests --- .../file_handlers_integration_test.go | 239 ++++++++++++++++++ server/internal/handlers/integration_test.go | 21 ++ 2 files changed, 260 insertions(+) create mode 100644 server/internal/handlers/file_handlers_integration_test.go diff --git a/server/internal/handlers/file_handlers_integration_test.go b/server/internal/handlers/file_handlers_integration_test.go new file mode 100644 index 0000000..70327d6 --- /dev/null +++ b/server/internal/handlers/file_handlers_integration_test.go @@ -0,0 +1,239 @@ +//go:build integration + +package handlers_test + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "testing" + + "novamd/internal/models" + "novamd/internal/storage" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFileHandlers_Integration(t *testing.T) { + h := setupTestHarness(t) + defer h.teardown(t) + + t.Run("file operations", func(t *testing.T) { + // Setup: Create a workspace first + workspace := &models.Workspace{ + UserID: h.RegularUser.ID, + Name: "File Test Workspace", + } + rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + err := json.NewDecoder(rr.Body).Decode(workspace) + require.NoError(t, err) + + // Construct base URL for file operations + baseURL := fmt.Sprintf("/api/v1/workspaces/%s/files", url.PathEscape(workspace.Name)) + + t.Run("list empty directory", func(t *testing.T) { + rr := h.makeRequest(t, http.MethodGet, baseURL, nil, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var files []storage.FileNode + err := json.NewDecoder(rr.Body).Decode(&files) + require.NoError(t, err) + assert.Empty(t, files, "Expected empty directory") + }) + + t.Run("save and get file", func(t *testing.T) { + content := "Test content for file operations" + filePath := "test.md" + + // Save file + rr := h.makeRequestRaw(t, http.MethodPost, baseURL+"/"+filePath, strings.NewReader(content), h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + // Get file content + rr = h.makeRequest(t, http.MethodGet, baseURL+"/"+filePath, nil, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, content, rr.Body.String()) + + // List directory should now show the file + rr = h.makeRequest(t, http.MethodGet, baseURL, nil, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var files []storage.FileNode + err := json.NewDecoder(rr.Body).Decode(&files) + require.NoError(t, err) + assert.Len(t, files, 1) + assert.Equal(t, filePath, files[0].Name) + }) + + t.Run("save and list nested files", func(t *testing.T) { + files := map[string]string{ + "docs/readme.md": "README content", + "docs/api/endpoints.md": "API documentation", + "notes/meeting-notes.md": "Meeting notes content", + "notes/todo.md": "TODO list", + } + + // Create all files + for path, content := range files { + rr := h.makeRequest(t, http.MethodPost, baseURL+"/"+path, content, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + } + + // List all files + rr := h.makeRequest(t, http.MethodGet, baseURL, nil, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var fileNodes []storage.FileNode + err := json.NewDecoder(rr.Body).Decode(&fileNodes) + require.NoError(t, err) + + // We should have 3 root items: docs/, notes/, and test.md + assert.Len(t, fileNodes, 3) + + // Verify directory structure + var docsDir, notesDir *storage.FileNode + for i := range fileNodes { + switch fileNodes[i].Name { + case "docs": + docsDir = &fileNodes[i] + case "notes": + notesDir = &fileNodes[i] + } + } + + require.NotNil(t, docsDir) + require.NotNil(t, notesDir) + assert.Len(t, docsDir.Children, 2) // readme.md and api/ + assert.Len(t, notesDir.Children, 2) // meeting-notes.md and todo.md + }) + + t.Run("lookup file by name", func(t *testing.T) { + // Look up a file that exists in multiple locations + filename := "readme.md" + dupContent := "Another readme" + rr := h.makeRequest(t, http.MethodPost, baseURL+"/projects/"+filename, dupContent, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + // Search for the file + rr = h.makeRequest(t, http.MethodGet, baseURL+"/lookup?filename="+filename, nil, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var response struct { + Paths []string `json:"paths"` + } + err := json.NewDecoder(rr.Body).Decode(&response) + require.NoError(t, err) + assert.Len(t, response.Paths, 2) + + // Search for non-existent file + rr = h.makeRequest(t, http.MethodGet, baseURL+"/lookup?filename=nonexistent.md", nil, h.RegularToken, nil) + assert.Equal(t, http.StatusNotFound, rr.Code) + }) + + t.Run("delete file", func(t *testing.T) { + filePath := "to-delete.md" + content := "This file will be deleted" + + // Create file + rr := h.makeRequest(t, http.MethodPost, baseURL+"/"+filePath, content, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + // Delete file + rr = h.makeRequest(t, http.MethodDelete, baseURL+"/"+filePath, nil, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + // Verify file is gone + rr = h.makeRequest(t, http.MethodGet, baseURL+"/"+filePath, nil, h.RegularToken, nil) + assert.Equal(t, http.StatusNotFound, rr.Code) + }) + + t.Run("last opened file", func(t *testing.T) { + // Initially should be empty + rr := h.makeRequest(t, http.MethodGet, baseURL+"/last", nil, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var response struct { + LastOpenedFilePath string `json:"lastOpenedFilePath"` + } + err := json.NewDecoder(rr.Body).Decode(&response) + require.NoError(t, err) + assert.Empty(t, response.LastOpenedFilePath) + + // Update last opened file + updateReq := struct { + FilePath string `json:"filePath"` + }{ + FilePath: "docs/readme.md", + } + rr = h.makeRequest(t, http.MethodPut, baseURL+"/last", updateReq, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + // Verify update + rr = h.makeRequest(t, http.MethodGet, baseURL+"/last", nil, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + err = json.NewDecoder(rr.Body).Decode(&response) + require.NoError(t, err) + assert.Equal(t, updateReq.FilePath, response.LastOpenedFilePath) + + // Test invalid file path + updateReq.FilePath = "nonexistent.md" + rr = h.makeRequest(t, http.MethodPut, baseURL+"/last", updateReq, h.RegularToken, nil) + assert.Equal(t, http.StatusNotFound, rr.Code) + }) + + t.Run("unauthorized access", func(t *testing.T) { + tests := []struct { + name string + method string + path string + body interface{} + }{ + {"list files", http.MethodGet, baseURL, nil}, + {"get file", http.MethodGet, baseURL + "/test.md", nil}, + {"save file", http.MethodPost, baseURL + "/test.md", "content"}, + {"delete file", http.MethodDelete, baseURL + "/test.md", nil}, + {"get last file", http.MethodGet, baseURL + "/last", nil}, + {"update last file", http.MethodPut, baseURL + "/last", struct{ FilePath string }{"test.md"}}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Test without token + rr := h.makeRequest(t, tc.method, tc.path, tc.body, "", nil) + assert.Equal(t, http.StatusUnauthorized, rr.Code) + + // Test with wrong user's token + rr = h.makeRequest(t, tc.method, tc.path, tc.body, h.AdminToken, nil) + assert.Equal(t, http.StatusNotFound, rr.Code) + }) + } + }) + + t.Run("path traversal attempts", func(t *testing.T) { + maliciousPaths := []string{ + "../../../etc/passwd", + "./../../secret.txt", + "/etc/shadow", + "test/../../../etc/passwd", + } + + for _, path := range maliciousPaths { + t.Run(path, func(t *testing.T) { + // Try to read + rr := h.makeRequest(t, http.MethodGet, baseURL+"/"+path, nil, h.RegularToken, nil) + assert.Equal(t, http.StatusBadRequest, rr.Code) + + // Try to write + rr = h.makeRequest(t, http.MethodPost, baseURL+"/"+path, "malicious content", h.RegularToken, nil) + assert.Equal(t, http.StatusBadRequest, rr.Code) + }) + } + }) + }) +} diff --git a/server/internal/handlers/integration_test.go b/server/internal/handlers/integration_test.go index ea26b90..087601a 100644 --- a/server/internal/handlers/integration_test.go +++ b/server/internal/handlers/integration_test.go @@ -5,6 +5,7 @@ package handlers_test import ( "bytes" "encoding/json" + "io" "net/http/httptest" "os" "testing" @@ -195,3 +196,23 @@ func (h *testHarness) makeRequest(t *testing.T, method, path string, body interf return rr } + +// makeRequestRaw is a helper function to make HTTP requests with raw body content +func (h *testHarness) makeRequestRaw(t *testing.T, method, path string, body io.Reader, token string, headers map[string]string) *httptest.ResponseRecorder { + t.Helper() + + req := httptest.NewRequest(method, path, body) + if token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + + // Add any additional headers + for k, v := range headers { + req.Header.Set(k, v) + } + + rr := httptest.NewRecorder() + h.Router.ServeHTTP(rr, req) + + return rr +} From 51ed9e53a4e72ac998c8c762db4a530a52d46e12 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Thu, 28 Nov 2024 21:33:28 +0100 Subject: [PATCH 28/38] Implement static handler tests --- .../static_handler_integration_test.go | 145 ++++++++++++++++++ 1 file changed, 145 insertions(+) create mode 100644 server/internal/handlers/static_handler_integration_test.go diff --git a/server/internal/handlers/static_handler_integration_test.go b/server/internal/handlers/static_handler_integration_test.go new file mode 100644 index 0000000..e6205ec --- /dev/null +++ b/server/internal/handlers/static_handler_integration_test.go @@ -0,0 +1,145 @@ +//go:build integration + +package handlers_test + +import ( + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "novamd/internal/handlers" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStaticHandler_Integration(t *testing.T) { + // Create temporary directory for test static files + tempDir, err := os.MkdirTemp("", "novamd-static-test-*") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + // Create test files + files := map[string][]byte{ + "index.html": []byte("Index"), + "assets/style.css": []byte("body { color: blue; }"), + "assets/style.css.gz": []byte("gzipped css content"), + "assets/script.js": []byte("console.log('test');"), + "assets/script.js.gz": []byte("gzipped js content"), + "subdir/page.html": []byte("Page"), + "subdir/page.html.gz": []byte("gzipped html content"), + } + + for path, content := range files { + fullPath := filepath.Join(tempDir, path) + err := os.MkdirAll(filepath.Dir(fullPath), 0755) + require.NoError(t, err) + err = os.WriteFile(fullPath, content, 0644) + require.NoError(t, err) + } + + // Create static handler + handler := handlers.NewStaticHandler(tempDir) + + tests := []struct { + name string + path string + acceptEncoding string + wantStatus int + wantBody []byte + wantType string + wantEncoding string + wantCacheHeader string + }{ + { + name: "serve index.html", + path: "/", + wantStatus: http.StatusOK, + wantBody: []byte("Index"), + wantType: "text/html; charset=utf-8", + }, + { + name: "serve CSS with gzip support", + path: "/assets/style.css", + acceptEncoding: "gzip", + wantStatus: http.StatusOK, + wantBody: []byte("gzipped css content"), + wantType: "text/css", + wantEncoding: "gzip", + wantCacheHeader: "public, max-age=31536000", + }, + { + name: "serve JS with gzip support", + path: "/assets/script.js", + acceptEncoding: "gzip", + wantStatus: http.StatusOK, + wantBody: []byte("gzipped js content"), + wantType: "application/javascript", + wantEncoding: "gzip", + wantCacheHeader: "public, max-age=31536000", + }, + { + name: "serve CSS without gzip", + path: "/assets/style.css", + wantStatus: http.StatusOK, + wantBody: []byte("body { color: blue; }"), + wantType: "text/css; charset=utf-8", + wantCacheHeader: "public, max-age=31536000", + }, + { + name: "SPA routing - nonexistent path", + path: "/nonexistent", + wantStatus: http.StatusOK, + wantBody: []byte("Index"), + wantType: "text/html; charset=utf-8", + }, + { + name: "SPA routing - deep path", + path: "/some/deep/path", + wantStatus: http.StatusOK, + wantBody: []byte("Index"), + wantType: "text/html; charset=utf-8", + }, + { + name: "block directory traversal", + path: "/../../../etc/passwd", + wantStatus: http.StatusBadRequest, + }, + { + name: "nonexistent file in assets", + path: "/assets/nonexistent.js", + wantStatus: http.StatusOK, // Should serve index.html + wantBody: []byte("Index"), + wantType: "text/html; charset=utf-8", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest("GET", tc.path, nil) + if tc.acceptEncoding != "" { + req.Header.Set("Accept-Encoding", tc.acceptEncoding) + } + + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, tc.wantStatus, w.Code) + + if tc.wantStatus == http.StatusOK { + assert.Equal(t, tc.wantBody, w.Body.Bytes()) + assert.Equal(t, tc.wantType, w.Header().Get("Content-Type")) + + if tc.wantEncoding != "" { + assert.Equal(t, tc.wantEncoding, w.Header().Get("Content-Encoding")) + } + + if tc.wantCacheHeader != "" { + assert.Equal(t, tc.wantCacheHeader, w.Header().Get("Cache-Control")) + } + } + }) + } +} From f5d616fe006e6e6b30c7638754bb83589211adb1 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Thu, 28 Nov 2024 21:53:03 +0100 Subject: [PATCH 29/38] Update documentation --- server/internal/auth/jwt.go | 26 ++----------- server/internal/auth/middleware.go | 16 -------- server/internal/auth/session.go | 29 ++------------ server/internal/db/users.go | 2 +- server/internal/db/workspaces.go | 2 +- server/internal/db/workspaces_test.go | 14 +++---- server/internal/git/client.go | 25 ++----------- server/internal/models/user.go | 4 ++ server/internal/models/workspace.go | 5 ++- server/internal/storage/files.go | 54 +++++---------------------- server/internal/storage/git.go | 31 +++------------ server/internal/storage/service.go | 19 +++------- server/internal/storage/workspace.go | 31 +++------------ 13 files changed, 53 insertions(+), 205 deletions(-) diff --git a/server/internal/auth/jwt.go b/server/internal/auth/jwt.go index f9b4b23..59790f0 100644 --- a/server/internal/auth/jwt.go +++ b/server/internal/auth/jwt.go @@ -62,31 +62,17 @@ func NewJWTService(config JWTConfig) (JWTManager, error) { return &jwtService{config: config}, nil } -// GenerateAccessToken creates a new access token for a user -// Parameters: -// - userID: the ID of the user -// - role: the role of the user -// Returns the signed token string or an error +// GenerateAccessToken creates a new access token for a user with the given userID and role func (s *jwtService) GenerateAccessToken(userID int, role string) (string, error) { return s.generateToken(userID, role, AccessToken, s.config.AccessTokenExpiry) } -// GenerateRefreshToken creates a new refresh token for a user -// Parameters: -// - userID: the ID of the user -// - role: the role of the user -// Returns the signed token string or an error +// GenerateRefreshToken creates a new refresh token for a user with the given userID and role func (s *jwtService) GenerateRefreshToken(userID int, role string) (string, error) { return s.generateToken(userID, role, RefreshToken, s.config.RefreshTokenExpiry) } // generateToken is an internal helper function that creates a new JWT token -// Parameters: -// - userID: the ID of the user -// - role: the role of the user -// - tokenType: the type of token (access or refresh) -// - expiry: how long the token should be valid -// Returns the signed token string or an error func (s *jwtService) generateToken(userID int, role string, tokenType TokenType, expiry time.Duration) (string, error) { now := time.Now() @@ -113,9 +99,6 @@ func (s *jwtService) generateToken(userID int, role string, tokenType TokenType, } // ValidateToken validates and parses a JWT token -// Parameters: -// - tokenString: the token to validate -// Returns the token claims if valid, or an error if invalid func (s *jwtService) ValidateToken(tokenString string) (*Claims, error) { token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { // Validate the signing method @@ -136,10 +119,7 @@ func (s *jwtService) ValidateToken(tokenString string) (*Claims, error) { return nil, fmt.Errorf("invalid token claims") } -// RefreshAccessToken creates a new access token using a refresh token -// Parameters: -// - refreshToken: the refresh token to use -// Returns a new access token if the refresh token is valid, or an error +// RefreshAccessToken creates a new access token using a refreshToken func (s *jwtService) RefreshAccessToken(refreshToken string) (string, error) { claims, err := s.ValidateToken(refreshToken) if err != nil { diff --git a/server/internal/auth/middleware.go b/server/internal/auth/middleware.go index e669460..8018612 100644 --- a/server/internal/auth/middleware.go +++ b/server/internal/auth/middleware.go @@ -13,10 +13,6 @@ type Middleware struct { } // NewMiddleware creates a new authentication middleware -// Parameters: -// - jwtManager: the JWT manager to use for token operations -// Returns: -// - *Middleware: the new middleware instance func NewMiddleware(jwtManager JWTManager) *Middleware { return &Middleware{ jwtManager: jwtManager, @@ -24,10 +20,6 @@ func NewMiddleware(jwtManager JWTManager) *Middleware { } // Authenticate middleware validates JWT tokens and sets user information in context -// Parameters: -// - next: the next handler to call -// Returns: -// - http.Handler: the handler function func (m *Middleware) Authenticate(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Extract token from Authorization header @@ -69,10 +61,6 @@ func (m *Middleware) Authenticate(next http.Handler) http.Handler { } // RequireRole returns a middleware that ensures the user has the required role -// Parameters: -// - role: the required role -// Returns: -// - func(http.Handler) http.Handler: the middleware function func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -92,10 +80,6 @@ func (m *Middleware) RequireRole(role string) func(http.Handler) http.Handler { } // RequireWorkspaceAccess returns a middleware that ensures the user has access to the workspace -// Parameters: -// - next: the next handler to call -// Returns: -// - http.Handler: the handler function func (m *Middleware) RequireWorkspaceAccess(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx, ok := context.GetRequestContext(w, r) diff --git a/server/internal/auth/session.go b/server/internal/auth/session.go index 872b2c1..897fb3d 100644 --- a/server/internal/auth/session.go +++ b/server/internal/auth/session.go @@ -15,10 +15,7 @@ type SessionService struct { jwtManager JWTManager // JWT Manager for token operations } -// NewSessionService creates a new session service -// Parameters: -// - db: database connection -// - jwtManager: JWT service for token operations +// NewSessionService creates a new session service with the given database and JWT manager func NewSessionService(db db.SessionStore, jwtManager JWTManager) *SessionService { return &SessionService{ db: db, @@ -26,14 +23,7 @@ func NewSessionService(db db.SessionStore, jwtManager JWTManager) *SessionServic } } -// CreateSession creates a new user session -// Parameters: -// - userID: the ID of the user -// - role: the role of the user -// Returns: -// - session: the created session -// - accessToken: a new access token -// - error: any error that occurred +// CreateSession creates a new user session for a user with the given userID and role func (s *SessionService) CreateSession(userID int, role string) (*models.Session, string, error) { // Generate both access and refresh tokens accessToken, err := s.jwtManager.GenerateAccessToken(userID, role) @@ -69,12 +59,7 @@ func (s *SessionService) CreateSession(userID int, role string) (*models.Session return session, accessToken, nil } -// RefreshSession creates a new access token using a refresh token -// Parameters: -// - refreshToken: the refresh token to use -// Returns: -// - string: a new access token -// - error: any error that occurred +// RefreshSession creates a new access token using a refreshToken func (s *SessionService) RefreshSession(refreshToken string) (string, error) { // Get session from database first session, err := s.db.GetSessionByRefreshToken(refreshToken) @@ -97,18 +82,12 @@ func (s *SessionService) RefreshSession(refreshToken string) (string, error) { return s.jwtManager.GenerateAccessToken(claims.UserID, claims.Role) } -// InvalidateSession removes a session from the database -// Parameters: -// - sessionID: the ID of the session to invalidate -// Returns: -// - error: any error that occurred +// InvalidateSession removes a session with the given sessionID from the database func (s *SessionService) InvalidateSession(sessionID string) error { return s.db.DeleteSession(sessionID) } // CleanExpiredSessions removes all expired sessions from the database -// Returns: -// - error: any error that occurred func (s *SessionService) CleanExpiredSessions() error { return s.db.CleanExpiredSessions() } diff --git a/server/internal/db/users.go b/server/internal/db/users.go index 132264b..618dd3d 100644 --- a/server/internal/db/users.go +++ b/server/internal/db/users.go @@ -38,7 +38,7 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) { UserID: user.ID, Name: "Main", } - defaultWorkspace.GetDefaultSettings() // Initialize default settings + defaultWorkspace.SetDefaultSettings() // Initialize default settings // Create workspace with settings err = db.createWorkspaceTx(tx, defaultWorkspace) diff --git a/server/internal/db/workspaces.go b/server/internal/db/workspaces.go index 004dcf5..efbbab4 100644 --- a/server/internal/db/workspaces.go +++ b/server/internal/db/workspaces.go @@ -10,7 +10,7 @@ import ( func (db *database) CreateWorkspace(workspace *models.Workspace) error { // Set default settings if not provided if workspace.Theme == "" { - workspace.GetDefaultSettings() + workspace.SetDefaultSettings() } // Encrypt token if present diff --git a/server/internal/db/workspaces_test.go b/server/internal/db/workspaces_test.go index 924d4a9..aa23163 100644 --- a/server/internal/db/workspaces_test.go +++ b/server/internal/db/workspaces_test.go @@ -77,7 +77,7 @@ func TestWorkspaceOperations(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { if tc.workspace.Theme == "" { - tc.workspace.GetDefaultSettings() + tc.workspace.SetDefaultSettings() } err := database.CreateWorkspace(tc.workspace) @@ -117,7 +117,7 @@ func TestWorkspaceOperations(t *testing.T) { UserID: user.ID, Name: "Get By ID Workspace", } - workspace.GetDefaultSettings() + workspace.SetDefaultSettings() if err := database.CreateWorkspace(workspace); err != nil { t.Fatalf("failed to create test workspace: %v", err) } @@ -167,7 +167,7 @@ func TestWorkspaceOperations(t *testing.T) { UserID: user.ID, Name: "Get By Name Workspace", } - workspace.GetDefaultSettings() + workspace.SetDefaultSettings() if err := database.CreateWorkspace(workspace); err != nil { t.Fatalf("failed to create test workspace: %v", err) } @@ -229,7 +229,7 @@ func TestWorkspaceOperations(t *testing.T) { UserID: user.ID, Name: "Update Workspace", } - workspace.GetDefaultSettings() + workspace.SetDefaultSettings() if err := database.CreateWorkspace(workspace); err != nil { t.Fatalf("failed to create test workspace: %v", err) } @@ -272,7 +272,7 @@ func TestWorkspaceOperations(t *testing.T) { } for _, w := range testWorkspaces { - w.GetDefaultSettings() + w.SetDefaultSettings() if err := database.CreateWorkspace(w); err != nil { t.Fatalf("failed to create test workspace: %v", err) } @@ -314,7 +314,7 @@ func TestWorkspaceOperations(t *testing.T) { UserID: user.ID, Name: "Last File Workspace", } - workspace.GetDefaultSettings() + workspace.SetDefaultSettings() if err := database.CreateWorkspace(workspace); err != nil { t.Fatalf("failed to create test workspace: %v", err) } @@ -369,7 +369,7 @@ func TestWorkspaceOperations(t *testing.T) { UserID: user.ID, Name: "Delete Workspace", } - workspace.GetDefaultSettings() + workspace.SetDefaultSettings() if err := database.CreateWorkspace(workspace); err != nil { t.Fatalf("failed to create test workspace: %v", err) } diff --git a/server/internal/git/client.go b/server/internal/git/client.go index 82790d5..0040042 100644 --- a/server/internal/git/client.go +++ b/server/internal/git/client.go @@ -33,14 +33,7 @@ type client struct { repo *git.Repository } -// New creates a new Client instance -// Parameters: -// - url: the URL of the Git repository -// - username: the username for the Git repository -// - token: the access token for the Git repository -// - workDir: the local directory to clone the repository to -// Returns: -// - Client: the Git client +// New creates a new git Client instance func New(url, username, token, workDir string) Client { return &client{ Config: Config{ @@ -53,8 +46,6 @@ func New(url, username, token, workDir string) Client { } // Clone clones the Git repository to the local directory -// Returns: -// - error: any error that occurred during cloning func (c *client) Clone() error { auth := &http.BasicAuth{ Username: c.Username, @@ -76,8 +67,6 @@ func (c *client) Clone() error { } // Pull pulls the latest changes from the remote repository -// Returns: -// - error: any error that occurred during pulling func (c *client) Pull() error { if c.repo == nil { return fmt.Errorf("repository not initialized") @@ -105,11 +94,7 @@ func (c *client) Pull() error { return nil } -// Commit commits the changes in the repository -// Parameters: -// - message: the commit message -// Returns: -// - error: any error that occurred during committing +// Commit commits the changes in the repository with the given message func (c *client) Commit(message string) error { if c.repo == nil { return fmt.Errorf("repository not initialized") @@ -134,8 +119,6 @@ func (c *client) Commit(message string) error { } // Push pushes the changes to the remote repository -// Returns: -// - error: any error that occurred during pushing func (c *client) Push() error { if c.repo == nil { return fmt.Errorf("repository not initialized") @@ -158,9 +141,7 @@ func (c *client) Push() error { return nil } -// EnsureRepo ensures the local repository is up-to-date -// Returns: -// - error: any error that occurred during the operation +// EnsureRepo ensures the local repository is cloned and up-to-date func (c *client) EnsureRepo() error { if _, err := os.Stat(filepath.Join(c.WorkDir, ".git")); os.IsNotExist(err) { return c.Clone() diff --git a/server/internal/models/user.go b/server/internal/models/user.go index e2efcbb..3832cda 100644 --- a/server/internal/models/user.go +++ b/server/internal/models/user.go @@ -8,14 +8,17 @@ import ( var validate = validator.New() +// UserRole represents the role of a user in the system type UserRole string +// User roles const ( RoleAdmin UserRole = "admin" RoleEditor UserRole = "editor" RoleViewer UserRole = "viewer" ) +// User represents a user in the system type User struct { ID int `json:"id" validate:"required,min=1"` Email string `json:"email" validate:"required,email"` @@ -26,6 +29,7 @@ type User struct { LastWorkspaceID int `json:"lastWorkspaceId"` } +// Validate validates the user struct func (u *User) Validate() error { return validate.Struct(u) } diff --git a/server/internal/models/workspace.go b/server/internal/models/workspace.go index 9f9e814..191e6d7 100644 --- a/server/internal/models/workspace.go +++ b/server/internal/models/workspace.go @@ -4,6 +4,7 @@ import ( "time" ) +// Workspace represents a user's workspace in the system type Workspace struct { ID int `json:"id" validate:"required,min=1"` UserID int `json:"userId" validate:"required,min=1"` @@ -23,11 +24,13 @@ type Workspace struct { GitCommitMsgTemplate string `json:"gitCommitMsgTemplate"` } +// Validate validates the workspace struct func (w *Workspace) Validate() error { return validate.Struct(w) } -func (w *Workspace) GetDefaultSettings() { +// SetDefaultSettings sets the default settings for the workspace +func (w *Workspace) SetDefaultSettings() { w.Theme = "light" w.AutoSave = false w.ShowHiddenFiles = false diff --git a/server/internal/storage/files.go b/server/internal/storage/files.go index 136cc84..51fff58 100644 --- a/server/internal/storage/files.go +++ b/server/internal/storage/files.go @@ -30,12 +30,7 @@ type FileNode struct { } // ListFilesRecursively returns a list of all files in the workspace directory and its subdirectories. -// Parameters: -// - userID: the ID of the user who owns the workspace -// - workspaceID: the ID of the workspace to list files in -// Returns: -// - nodes: a list of files and directories in the workspace -// - error: any error that occurred during listing +// Workspace is identified by the given userID and workspaceID. func (s *Service) ListFilesRecursively(userID, workspaceID int) ([]FileNode, error) { workspacePath := s.GetWorkspacePath(userID, workspaceID) return s.walkDirectory(workspacePath, "") @@ -106,13 +101,8 @@ func (s *Service) walkDirectory(dir, prefix string) ([]FileNode, error) { } // FindFileByName returns a list of file paths that match the given filename. -// Parameters: -// - userID: the ID of the user who owns the workspace -// - workspaceID: the ID of the workspace to search for the file -// - filename: the name of the file to search for -// Returns: -// - foundPaths: a list of file paths that match the filename -// - error: any error that occurred during the search +// Files are searched recursively in the workspace directory and its subdirectories. +// Workspace is identified by the given userID and workspaceID. func (s *Service) FindFileByName(userID, workspaceID int, filename string) ([]string, error) { var foundPaths []string workspacePath := s.GetWorkspacePath(userID, workspaceID) @@ -144,14 +134,8 @@ func (s *Service) FindFileByName(userID, workspaceID int, filename string) ([]st return foundPaths, nil } -// GetFileContent returns the content of the file at the given path. -// Parameters: -// - userID: the ID of the user who owns the workspace -// - workspaceID: the ID of the workspace to get the file from -// - filePath: the path of the file to get -// Returns: -// - content: the content of the file -// - error: any error that occurred during reading +// GetFileContent returns the content of the file at the given filePath. +// Path must be a relative path within the workspace directory given by userID and workspaceID. func (s *Service) GetFileContent(userID, workspaceID int, filePath string) ([]byte, error) { fullPath, err := s.ValidatePath(userID, workspaceID, filePath) if err != nil { @@ -160,14 +144,8 @@ func (s *Service) GetFileContent(userID, workspaceID int, filePath string) ([]by return s.fs.ReadFile(fullPath) } -// SaveFile writes the content to the file at the given path. -// Parameters: -// - userID: the ID of the user who owns the workspace -// - workspaceID: the ID of the workspace to save the file to -// - filePath: the path of the file to save -// - content: the content to write to the file -// Returns: -// - error: any error that occurred during saving +// SaveFile writes the content to the file at the given filePath. +// Path must be a relative path within the workspace directory given by userID and workspaceID. func (s *Service) SaveFile(userID, workspaceID int, filePath string, content []byte) error { fullPath, err := s.ValidatePath(userID, workspaceID, filePath) if err != nil { @@ -182,13 +160,8 @@ func (s *Service) SaveFile(userID, workspaceID int, filePath string, content []b return s.fs.WriteFile(fullPath, content, 0644) } -// DeleteFile deletes the file at the given path. -// Parameters: -// - userID: the ID of the user who owns the workspace -// - workspaceID: the ID of the workspace to delete the file from -// - filePath: the path of the file to delete -// Returns: -// - error: any error that occurred during deletion +// DeleteFile deletes the file at the given filePath. +// Path must be a relative path within the workspace directory given by userID and workspaceID. func (s *Service) DeleteFile(userID, workspaceID int, filePath string) error { fullPath, err := s.ValidatePath(userID, workspaceID, filePath) if err != nil { @@ -204,12 +177,7 @@ type FileCountStats struct { } // GetFileStats returns the total number of files and related statistics in a workspace -// Parameters: -// - userID: the ID of the user who owns the workspace -// - workspaceID: the ID of the workspace to count files in -// Returns: -// - result: statistics about the files in the workspace -// - error: any error that occurred during counting +// Workspace is identified by the given userID and workspaceID func (s *Service) GetFileStats(userID, workspaceID int) (*FileCountStats, error) { workspacePath := s.GetWorkspacePath(userID, workspaceID) @@ -223,8 +191,6 @@ func (s *Service) GetFileStats(userID, workspaceID int) (*FileCountStats, error) } // GetTotalFileStats returns the total file statistics for the storage. -// Returns: -// - result: statistics about the files in the storage func (s *Service) GetTotalFileStats() (*FileCountStats, error) { return s.countFilesInPath(s.RootDir) } diff --git a/server/internal/storage/git.go b/server/internal/storage/git.go index 09d3b0f..49d564b 100644 --- a/server/internal/storage/git.go +++ b/server/internal/storage/git.go @@ -13,15 +13,8 @@ type RepositoryManager interface { Pull(userID, workspaceID int) error } -// SetupGitRepo sets up a Git repository for the given user and workspace IDs. -// Parameters: -// - userID: the ID of the user who owns the workspace -// - workspaceID: the ID of the workspace to set up the Git repository for -// - gitURL: the URL of the Git repository -// - gitUser: the username for the Git repository -// - gitToken: the access token for the Git repository -// Returns: -// - error: any error that occurred during setup +// SetupGitRepo sets up a Git repository for the given userID and workspaceID. +// The repository is cloned from the given gitURL using the given gitUser and gitToken. func (s *Service) SetupGitRepo(userID, workspaceID int, gitURL, gitUser, gitToken string) error { workspacePath := s.GetWorkspacePath(userID, workspaceID) if _, ok := s.GitRepos[userID]; !ok { @@ -31,10 +24,7 @@ func (s *Service) SetupGitRepo(userID, workspaceID int, gitURL, gitUser, gitToke return s.GitRepos[userID][workspaceID].EnsureRepo() } -// DisableGitRepo disables the Git repository for the given user and workspace IDs. -// Parameters: -// - userID: the ID of the user who owns the workspace -// - workspaceID: the ID of the workspace to disable the Git repository for +// DisableGitRepo disables the Git repository for the given userID and workspaceID. func (s *Service) DisableGitRepo(userID, workspaceID int) { if userRepos, ok := s.GitRepos[userID]; ok { delete(userRepos, workspaceID) @@ -44,13 +34,8 @@ func (s *Service) DisableGitRepo(userID, workspaceID int) { } } -// StageCommitAndPush stages, commits, and pushes the changes to the Git repository. -// Parameters: -// - userID: the ID of the user who owns the workspace -// - workspaceID: the ID of the workspace to commit and push -// - message: the commit message -// Returns: -// - error: any error that occurred during the operation +// StageCommitAndPush stages, commit with the message, and pushes the changes to the Git repository. +// The git repository belongs to the given userID and is associated with the given workspaceID. func (s *Service) StageCommitAndPush(userID, workspaceID int, message string) error { repo, ok := s.getGitRepo(userID, workspaceID) if !ok { @@ -65,11 +50,7 @@ func (s *Service) StageCommitAndPush(userID, workspaceID int, message string) er } // Pull pulls the changes from the remote Git repository. -// Parameters: -// - userID: the ID of the user who owns the workspace -// - workspaceID: the ID of the workspace to pull changes for -// Returns: -// - error: any error that occurred during the operation +// The git repository belongs to the given userID and is associated with the given workspaceID. func (s *Service) Pull(userID, workspaceID int) error { repo, ok := s.getGitRepo(userID, workspaceID) if !ok { diff --git a/server/internal/storage/service.go b/server/internal/storage/service.go index 758e4d2..7b8b0f5 100644 --- a/server/internal/storage/service.go +++ b/server/internal/storage/service.go @@ -25,11 +25,7 @@ type Options struct { NewGitClient func(url, user, token, path string) git.Client } -// NewService creates a new Storage instance. -// Parameters: -// - rootDir: the root directory for the storage -// Returns: -// - result: the new Storage instance +// NewService creates a new Storage instance with the default options and the given rootDir root directory. func NewService(rootDir string) *Service { return NewServiceWithOptions(rootDir, Options{ Fs: &osFS{}, @@ -37,16 +33,11 @@ func NewService(rootDir string) *Service { }) } -// NewServiceWithOptions creates a new Storage instance with the given options. -// Parameters: -// - rootDir: the root directory for the storage -// - opts: the options for the storage service -// Returns: -// - result: the new Storage instance -func NewServiceWithOptions(rootDir string, opts Options) *Service { +// NewServiceWithOptions creates a new Storage instance with the given options and the given rootDir root directory. +func NewServiceWithOptions(rootDir string, options Options) *Service { return &Service{ - fs: opts.Fs, - newGitClient: opts.NewGitClient, + fs: options.Fs, + newGitClient: options.NewGitClient, RootDir: rootDir, GitRepos: make(map[int]map[int]git.Client), } diff --git a/server/internal/storage/workspace.go b/server/internal/storage/workspace.go index 2a4beca..560a1b0 100644 --- a/server/internal/storage/workspace.go +++ b/server/internal/storage/workspace.go @@ -14,14 +14,8 @@ type WorkspaceManager interface { DeleteUserWorkspace(userID, workspaceID int) error } -// ValidatePath validates the given path and returns the cleaned path if it is valid. -// Parameters: -// - userID: the ID of the user who owns the workspace -// - workspaceID: the ID of the workspace to validate the path for -// - path: the path to validate -// Returns: -// - result: the cleaned path if it is valid -// - error: any error that occurred during validation +// ValidatePath validates the if the given path is valid within the workspace directory. +// Workspace directory is defined as the directory for the given userID and workspaceID. func (s *Service) ValidatePath(userID, workspaceID int, path string) (string, error) { workspacePath := s.GetWorkspacePath(userID, workspaceID) @@ -42,22 +36,12 @@ func (s *Service) ValidatePath(userID, workspaceID int, path string) (string, er return cleanPath, nil } -// GetWorkspacePath returns the path to the workspace directory for the given user and workspace IDs. -// Parameters: -// - userID: the ID of the user who owns the workspace -// - workspaceID: the ID of the workspace -// Returns: -// - result: the path to the workspace directory +// GetWorkspacePath returns the path to the workspace directory for the given userID and workspaceID. func (s *Service) GetWorkspacePath(userID, workspaceID int) string { return filepath.Join(s.RootDir, fmt.Sprintf("%d", userID), fmt.Sprintf("%d", workspaceID)) } -// InitializeUserWorkspace creates the workspace directory for the given user and workspace IDs. -// Parameters: -// - userID: the ID of the user who owns the workspace -// - workspaceID: the ID of the workspace to initialize -// Returns: -// - error: any error that occurred during the operation +// InitializeUserWorkspace creates the workspace directory for the given userID and workspaceID. func (s *Service) InitializeUserWorkspace(userID, workspaceID int) error { workspacePath := s.GetWorkspacePath(userID, workspaceID) err := s.fs.MkdirAll(workspacePath, 0755) @@ -68,12 +52,7 @@ func (s *Service) InitializeUserWorkspace(userID, workspaceID int) error { return nil } -// DeleteUserWorkspace deletes the workspace directory for the given user and workspace IDs. -// Parameters: -// - userID: the ID of the user who owns the workspace -// - workspaceID: the ID of the workspace to delete -// Returns: -// - error: any error that occurred during the operation +// DeleteUserWorkspace deletes the workspace directory for the given userID and workspaceID. func (s *Service) DeleteUserWorkspace(userID, workspaceID int) error { workspacePath := s.GetWorkspacePath(userID, workspaceID) err := s.fs.RemoveAll(workspacePath) From 9b4db528caf8357341e04c55ca910e97f2702192 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Thu, 28 Nov 2024 21:55:01 +0100 Subject: [PATCH 30/38] Fix lint issues --- server/internal/auth/middleware_test.go | 2 +- server/internal/auth/session_test.go | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/server/internal/auth/middleware_test.go b/server/internal/auth/middleware_test.go index 71bdc9d..06c44be 100644 --- a/server/internal/auth/middleware_test.go +++ b/server/internal/auth/middleware_test.go @@ -95,7 +95,7 @@ func TestAuthenticateMiddleware(t *testing.T) { // Create test handler nextCalled := false - next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { nextCalled = true w.WriteHeader(http.StatusOK) }) diff --git a/server/internal/auth/session_test.go b/server/internal/auth/session_test.go index 00c7494..5457c37 100644 --- a/server/internal/auth/session_test.go +++ b/server/internal/auth/session_test.go @@ -165,7 +165,9 @@ func TestRefreshSession(t *testing.T) { ExpiresAt: time.Now().Add(24 * time.Hour), CreatedAt: time.Now(), } - mockDB.CreateSession(session) + if err := mockDB.CreateSession(session); err != nil { + t.Fatalf("failed to create session: %v", err) + } return token }, wantErr: false, @@ -181,7 +183,9 @@ func TestRefreshSession(t *testing.T) { ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired CreatedAt: time.Now().Add(-2 * time.Hour), } - mockDB.CreateSession(session) + if err := mockDB.CreateSession(session); err != nil { + t.Fatalf("failed to create session: %v", err) + } return token }, wantErr: true, @@ -255,7 +259,9 @@ func TestInvalidateSession(t *testing.T) { ExpiresAt: time.Now().Add(24 * time.Hour), CreatedAt: time.Now(), } - mockDB.CreateSession(session) + if err := mockDB.CreateSession(session); err != nil { + t.Fatalf("failed to create session: %v", err) + } return session.ID }, wantErr: false, From 6aa3fd6c655ccfb9e62a05628b413b41651ebbbf Mon Sep 17 00:00:00 2001 From: LordMathis Date: Thu, 28 Nov 2024 22:05:27 +0100 Subject: [PATCH 31/38] Add script for generating single file documentation --- server/gendocs.sh | 50 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100755 server/gendocs.sh diff --git a/server/gendocs.sh b/server/gendocs.sh new file mode 100755 index 0000000..01b0643 --- /dev/null +++ b/server/gendocs.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +set -euo pipefail + +# Function to generate anchor from package path +generate_anchor() { + echo "$1" | tr '/' '-' +} + +# Create documentation file +echo "# NovaMD Package Documentation + +Generated documentation for all packages in the NovaMD project. + +## Table of Contents +" > documentation.md + +# Find all directories containing .go files (excluding test files) +# Sort them for consistent output +PACKAGES=$(find . -type f -name "*.go" ! -name "*_test.go" -exec dirname {} \; | sort -u | grep -v "/\.") + +# Generate table of contents +for PKG in $PACKAGES; do + # Strip leading ./ + PKG_PATH=${PKG#./} + # Skip if empty + [ -z "$PKG_PATH" ] && continue + + ANCHOR=$(generate_anchor "$PKG_PATH") + echo "- [$PKG_PATH](#$ANCHOR)" >> documentation.md +done + +echo "" >> documentation.md + +# Generate documentation for each package +for PKG in $PACKAGES; do + # Strip leading ./ + PKG_PATH=${PKG#./} + # Skip if empty + [ -z "$PKG_PATH" ] && continue + + echo "## $PKG_PATH" >> documentation.md + echo "" >> documentation.md + echo '```go' >> documentation.md + go doc -all "./$PKG_PATH" >> documentation.md + echo '```' >> documentation.md + echo "" >> documentation.md +done + +echo "Documentation generated in documentation.md" \ No newline at end of file From 1ddf93a8bec9e95c8507d7edd27622b751e08d36 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Fri, 29 Nov 2024 23:14:36 +0100 Subject: [PATCH 32/38] Implement git handlers integration test --- .../handlers/git_handlers_integration_test.go | 181 ++++++++++++++++++ server/internal/handlers/integration_test.go | 15 +- server/internal/handlers/mock_git_test.go | 123 ++++++++++++ .../internal/handlers/workspace_handlers.go | 13 ++ server/internal/models/workspace.go | 23 ++- server/internal/storage/service.go | 8 + 6 files changed, 352 insertions(+), 11 deletions(-) create mode 100644 server/internal/handlers/git_handlers_integration_test.go create mode 100644 server/internal/handlers/mock_git_test.go diff --git a/server/internal/handlers/git_handlers_integration_test.go b/server/internal/handlers/git_handlers_integration_test.go new file mode 100644 index 0000000..6d26039 --- /dev/null +++ b/server/internal/handlers/git_handlers_integration_test.go @@ -0,0 +1,181 @@ +//go:build integration + +package handlers_test + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "testing" + + "novamd/internal/models" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGitHandlers_Integration(t *testing.T) { + h := setupTestHarness(t) + defer h.teardown(t) + + t.Run("git operations", func(t *testing.T) { + // Setup: Create a workspace with Git enabled + workspace := &models.Workspace{ + UserID: h.RegularUser.ID, + Name: "Git Test Workspace", + GitEnabled: true, + GitURL: "https://github.com/test/repo.git", + GitUser: "testuser", + GitToken: "testtoken", + GitAutoCommit: true, + GitCommitMsgTemplate: "Update: {{message}}", + } + + rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + err := json.NewDecoder(rr.Body).Decode(workspace) + require.NoError(t, err) + + // Construct base URL for Git operations + baseURL := "/api/v1/workspaces/" + url.PathEscape(workspace.Name) + "/git" + + t.Run("stage, commit and push", func(t *testing.T) { + h.MockGit.Reset() + + t.Run("successful commit", func(t *testing.T) { + commitMsg := "Test commit message" + requestBody := map[string]string{ + "message": commitMsg, + } + + rr := h.makeRequest(t, http.MethodPost, baseURL+"/commit", requestBody, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var response map[string]string + err := json.NewDecoder(rr.Body).Decode(&response) + require.NoError(t, err) + assert.Contains(t, response["message"], "successfully") + + // Verify mock was called correctly + assert.Equal(t, 1, h.MockGit.GetCommitCount(), "Commit should be called once") + assert.Equal(t, 1, h.MockGit.GetPushCount(), "Push should be called once") + assert.Equal(t, commitMsg, h.MockGit.GetLastCommitMessage(), "Commit message should match") + }) + + t.Run("empty commit message", func(t *testing.T) { + h.MockGit.Reset() + requestBody := map[string]string{ + "message": "", + } + + rr := h.makeRequest(t, http.MethodPost, baseURL+"/commit", requestBody, h.RegularToken, nil) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Equal(t, 0, h.MockGit.GetCommitCount(), "Commit should not be called") + }) + + t.Run("git error", func(t *testing.T) { + h.MockGit.Reset() + h.MockGit.SetError(fmt.Errorf("mock git error")) + + requestBody := map[string]string{ + "message": "Test message", + } + + rr := h.makeRequest(t, http.MethodPost, baseURL+"/commit", requestBody, h.RegularToken, nil) + assert.Equal(t, http.StatusInternalServerError, rr.Code) + + h.MockGit.SetError(nil) // Reset error state + }) + }) + + t.Run("pull changes", func(t *testing.T) { + h.MockGit.Reset() + + t.Run("successful pull", func(t *testing.T) { + rr := h.makeRequest(t, http.MethodPost, baseURL+"/pull", nil, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var response map[string]string + err := json.NewDecoder(rr.Body).Decode(&response) + require.NoError(t, err) + assert.Contains(t, response["message"], "Pulled changes") + + assert.Equal(t, 1, h.MockGit.GetPullCount(), "Pull should be called once") + }) + + t.Run("git error", func(t *testing.T) { + h.MockGit.Reset() + h.MockGit.SetError(fmt.Errorf("mock git error")) + + rr := h.makeRequest(t, http.MethodPost, baseURL+"/pull", nil, h.RegularToken, nil) + assert.Equal(t, http.StatusInternalServerError, rr.Code) + + h.MockGit.SetError(nil) // Reset error state + }) + }) + + t.Run("unauthorized access", func(t *testing.T) { + h.MockGit.Reset() + + tests := []struct { + name string + method string + path string + body interface{} + }{ + { + name: "commit without token", + method: http.MethodPost, + path: baseURL + "/commit", + body: map[string]string{"message": "test"}, + }, + { + name: "pull without token", + method: http.MethodPost, + path: baseURL + "/pull", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Test without token + rr := h.makeRequest(t, tc.method, tc.path, tc.body, "", nil) + assert.Equal(t, http.StatusUnauthorized, rr.Code) + + // Test with wrong user's token + rr = h.makeRequest(t, tc.method, tc.path, tc.body, h.AdminToken, nil) + assert.Equal(t, http.StatusNotFound, rr.Code) + }) + } + }) + + t.Run("workspace without git", func(t *testing.T) { + h.MockGit.Reset() + + // Create a workspace without Git enabled + nonGitWorkspace := &models.Workspace{ + UserID: h.RegularUser.ID, + Name: "Non-Git Workspace", + } + + rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", nonGitWorkspace, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + err := json.NewDecoder(rr.Body).Decode(nonGitWorkspace) + require.NoError(t, err) + + nonGitBaseURL := "/api/v1/workspaces/" + url.PathEscape(nonGitWorkspace.Name) + "/git" + + // Try to commit + commitMsg := map[string]string{"message": "test"} + rr = h.makeRequest(t, http.MethodPost, nonGitBaseURL+"/commit", commitMsg, h.RegularToken, nil) + assert.Equal(t, http.StatusInternalServerError, rr.Code) + + // Try to pull + rr = h.makeRequest(t, http.MethodPost, nonGitBaseURL+"/pull", nil, h.RegularToken, nil) + assert.Equal(t, http.StatusInternalServerError, rr.Code) + }) + }) +} diff --git a/server/internal/handlers/integration_test.go b/server/internal/handlers/integration_test.go index 087601a..5c65f16 100644 --- a/server/internal/handlers/integration_test.go +++ b/server/internal/handlers/integration_test.go @@ -17,6 +17,7 @@ import ( "novamd/internal/api" "novamd/internal/auth" "novamd/internal/db" + "novamd/internal/git" "novamd/internal/handlers" "novamd/internal/models" "novamd/internal/secrets" @@ -36,6 +37,7 @@ type testHarness struct { RegularUser *models.User RegularToken string TempDirectory string + MockGit *MockGitClient } // setupTestHarness creates a new test environment @@ -63,8 +65,16 @@ func setupTestHarness(t *testing.T) *testHarness { t.Fatalf("Failed to run migrations: %v", err) } - // Initialize storage - storageSvc := storage.NewService(tempDir) + // Create mock git client + mockGit := NewMockGitClient(false) + + // Create storage with mock git client + storageOpts := storage.Options{ + NewGitClient: func(url, user, token, path string) git.Client { + return mockGit + }, + } + storageSvc := storage.NewServiceWithOptions(tempDir, storageOpts) // Initialize JWT service jwtSvc, err := auth.NewJWTService(auth.JWTConfig{ @@ -100,6 +110,7 @@ func setupTestHarness(t *testing.T) *testHarness { JWTManager: jwtSvc, SessionSvc: sessionSvc, TempDirectory: tempDir, + MockGit: mockGit, } // Create test users diff --git a/server/internal/handlers/mock_git_test.go b/server/internal/handlers/mock_git_test.go new file mode 100644 index 0000000..23259f5 --- /dev/null +++ b/server/internal/handlers/mock_git_test.go @@ -0,0 +1,123 @@ +//go:build integration + +package handlers_test + +import ( + "fmt" +) + +// MockGitClient implements the git.Client interface for testing +type MockGitClient struct { + initialized bool + cloned bool + lastCommitMsg string + error error + + pullCount int + commitCount int + pushCount int + cloneCount int + ensureCount int +} + +// NewMockGitClient creates a new mock git client +func NewMockGitClient(shouldError bool) *MockGitClient { + var err error + if shouldError { + err = fmt.Errorf("mock git error") + } + return &MockGitClient{ + error: err, + } +} + +// Clone implements git.Client +func (m *MockGitClient) Clone() error { + if m.error != nil { + return m.error + } + m.cloneCount++ + m.cloned = true + return nil +} + +// Pull implements git.Client +func (m *MockGitClient) Pull() error { + if m.error != nil { + return m.error + } + m.pullCount++ + return nil +} + +// Commit implements git.Client +func (m *MockGitClient) Commit(message string) error { + if m.error != nil { + return m.error + } + m.commitCount++ + m.lastCommitMsg = message + return nil +} + +// Push implements git.Client +func (m *MockGitClient) Push() error { + if m.error != nil { + return m.error + } + m.pushCount++ + return nil +} + +// EnsureRepo implements git.Client +func (m *MockGitClient) EnsureRepo() error { + if m.error != nil { + return m.error + } + m.ensureCount++ + m.initialized = true + return nil +} + +// Helper methods for tests + +func (m *MockGitClient) GetCommitCount() int { + return m.commitCount +} + +func (m *MockGitClient) GetPushCount() int { + return m.pushCount +} + +func (m *MockGitClient) GetPullCount() int { + return m.pullCount +} + +func (m *MockGitClient) GetLastCommitMessage() string { + return m.lastCommitMsg +} + +func (m *MockGitClient) IsInitialized() bool { + return m.initialized +} + +func (m *MockGitClient) IsCloned() bool { + return m.cloned +} + +// Reset resets all counters and states +func (m *MockGitClient) Reset() { + m.initialized = false + m.cloned = false + m.lastCommitMsg = "" + m.pullCount = 0 + m.commitCount = 0 + m.pushCount = 0 + m.cloneCount = 0 + m.ensureCount = 0 +} + +// SetError sets the error state +func (m *MockGitClient) SetError(err error) { + m.error = err +} diff --git a/server/internal/handlers/workspace_handlers.go b/server/internal/handlers/workspace_handlers.go index 10ae3cf..06a2503 100644 --- a/server/internal/handlers/workspace_handlers.go +++ b/server/internal/handlers/workspace_handlers.go @@ -52,6 +52,19 @@ func (h *Handler) CreateWorkspace() http.HandlerFunc { return } + if workspace.GitEnabled { + if err := h.Storage.SetupGitRepo( + ctx.UserID, + workspace.ID, + workspace.GitURL, + workspace.GitUser, + workspace.GitToken, + ); err != nil { + http.Error(w, "Failed to setup git repo: "+err.Error(), http.StatusInternalServerError) + return + } + } + respondJSON(w, workspace) } } diff --git a/server/internal/models/workspace.go b/server/internal/models/workspace.go index 191e6d7..5a0b7eb 100644 --- a/server/internal/models/workspace.go +++ b/server/internal/models/workspace.go @@ -31,13 +31,18 @@ func (w *Workspace) Validate() error { // SetDefaultSettings sets the default settings for the workspace func (w *Workspace) SetDefaultSettings() { - w.Theme = "light" - w.AutoSave = false - w.ShowHiddenFiles = false - w.GitEnabled = false - w.GitURL = "" - w.GitUser = "" - w.GitToken = "" - w.GitAutoCommit = false - w.GitCommitMsgTemplate = "${action} ${filename}" + + if w.Theme == "" { + w.Theme = "light" + } + + w.AutoSave = w.AutoSave || false + w.ShowHiddenFiles = w.ShowHiddenFiles || false + w.GitEnabled = w.GitEnabled || false + + w.GitAutoCommit = w.GitEnabled && (w.GitAutoCommit || false) + + if w.GitCommitMsgTemplate == "" { + w.GitCommitMsgTemplate = "${action} ${filename}" + } } diff --git a/server/internal/storage/service.go b/server/internal/storage/service.go index 7b8b0f5..0516b5f 100644 --- a/server/internal/storage/service.go +++ b/server/internal/storage/service.go @@ -35,6 +35,14 @@ func NewService(rootDir string) *Service { // NewServiceWithOptions creates a new Storage instance with the given options and the given rootDir root directory. func NewServiceWithOptions(rootDir string, options Options) *Service { + if options.Fs == nil { + options.Fs = &osFS{} + } + + if options.NewGitClient == nil { + options.NewGitClient = git.New + } + return &Service{ fs: options.Fs, newGitClient: options.NewGitClient, From d47b601447e9b1c4c0ee3301a14ffa81006090eb Mon Sep 17 00:00:00 2001 From: LordMathis Date: Fri, 29 Nov 2024 23:15:37 +0100 Subject: [PATCH 33/38] Rename mock secrets --- server/internal/db/{testutil_test.go => mock_secrets_test.go} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename server/internal/db/{testutil_test.go => mock_secrets_test.go} (100%) diff --git a/server/internal/db/testutil_test.go b/server/internal/db/mock_secrets_test.go similarity index 100% rename from server/internal/db/testutil_test.go rename to server/internal/db/mock_secrets_test.go From af9ab42969ec62a6653dcd0614fae7380c46166f Mon Sep 17 00:00:00 2001 From: LordMathis Date: Fri, 29 Nov 2024 23:57:17 +0100 Subject: [PATCH 34/38] Add integration tests for use handlers --- server/internal/handlers/user_handlers.go | 35 --- .../user_handlers_integration_test.go | 212 ++++++++++++++++++ 2 files changed, 212 insertions(+), 35 deletions(-) create mode 100644 server/internal/handlers/user_handlers_integration_test.go diff --git a/server/internal/handlers/user_handlers.go b/server/internal/handlers/user_handlers.go index 0249327..4a0a7d9 100644 --- a/server/internal/handlers/user_handlers.go +++ b/server/internal/handlers/user_handlers.go @@ -23,24 +23,6 @@ type DeleteAccountRequest struct { Password string `json:"password"` } -// GetUser returns the current user's profile -func (h *Handler) GetUser() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx, ok := context.GetRequestContext(w, r) - if !ok { - return - } - - user, err := h.DB.GetUserByID(ctx.UserID) - if err != nil { - http.Error(w, "Failed to get user", http.StatusInternalServerError) - return - } - - respondJSON(w, user) - } -} - // UpdateProfile updates the current user's profile func (h *Handler) UpdateProfile() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { @@ -62,18 +44,6 @@ func (h *Handler) UpdateProfile() http.HandlerFunc { return } - // Start transaction for atomic updates - tx, err := h.DB.Begin() - if err != nil { - http.Error(w, "Failed to start transaction", http.StatusInternalServerError) - return - } - defer func() { - if err := tx.Rollback(); err != nil && err != sql.ErrTxDone { - http.Error(w, "Failed to rollback transaction", http.StatusInternalServerError) - } - }() - // Handle password update if requested if req.NewPassword != "" { // Current password must be provided to change password @@ -139,11 +109,6 @@ func (h *Handler) UpdateProfile() http.HandlerFunc { return } - if err := tx.Commit(); err != nil { - http.Error(w, "Failed to commit changes", http.StatusInternalServerError) - return - } - // Return updated user data respondJSON(w, user) } diff --git a/server/internal/handlers/user_handlers_integration_test.go b/server/internal/handlers/user_handlers_integration_test.go new file mode 100644 index 0000000..bfcc38e --- /dev/null +++ b/server/internal/handlers/user_handlers_integration_test.go @@ -0,0 +1,212 @@ +//go:build integration + +package handlers_test + +import ( + "encoding/json" + "net/http" + "testing" + + "novamd/internal/handlers" + "novamd/internal/models" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUserHandlers_Integration(t *testing.T) { + h := setupTestHarness(t) + defer h.teardown(t) + + t.Run("get current user", func(t *testing.T) { + t.Run("successful get", func(t *testing.T) { + rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var user models.User + err := json.NewDecoder(rr.Body).Decode(&user) + require.NoError(t, err) + + assert.Equal(t, h.RegularUser.ID, user.ID) + assert.Equal(t, h.RegularUser.Email, user.Email) + assert.Equal(t, h.RegularUser.DisplayName, user.DisplayName) + assert.Equal(t, h.RegularUser.Role, user.Role) + assert.Empty(t, user.PasswordHash, "Password hash should not be included in response") + }) + + t.Run("unauthorized", func(t *testing.T) { + rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, "", nil) + assert.Equal(t, http.StatusUnauthorized, rr.Code) + }) + }) + + t.Run("update profile", func(t *testing.T) { + t.Run("update display name only", func(t *testing.T) { + updateReq := handlers.UpdateProfileRequest{ + DisplayName: "Updated Name", + } + + rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var user models.User + err := json.NewDecoder(rr.Body).Decode(&user) + require.NoError(t, err) + assert.Equal(t, updateReq.DisplayName, user.DisplayName) + }) + + t.Run("update email", func(t *testing.T) { + updateReq := handlers.UpdateProfileRequest{ + Email: "newemail@test.com", + CurrentPassword: "user123", // Regular user's password from test harness + } + + rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var user models.User + err := json.NewDecoder(rr.Body).Decode(&user) + require.NoError(t, err) + assert.Equal(t, updateReq.Email, user.Email) + }) + + t.Run("update email without password", func(t *testing.T) { + updateReq := handlers.UpdateProfileRequest{ + Email: "anotheremail@test.com", + } + + rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil) + assert.Equal(t, http.StatusBadRequest, rr.Code) + }) + + t.Run("update email with wrong password", func(t *testing.T) { + updateReq := handlers.UpdateProfileRequest{ + Email: "wrongpass@test.com", + CurrentPassword: "wrongpassword", + } + + rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil) + assert.Equal(t, http.StatusUnauthorized, rr.Code) + }) + + t.Run("update password", func(t *testing.T) { + updateReq := handlers.UpdateProfileRequest{ + CurrentPassword: "user123", + NewPassword: "newpassword123", + } + + rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + // Verify can login with new password + loginReq := handlers.LoginRequest{ + Email: h.RegularUser.Email, + Password: "newpassword123", + } + + rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil) + assert.Equal(t, http.StatusOK, rr.Code) + }) + + t.Run("update password without current password", func(t *testing.T) { + updateReq := handlers.UpdateProfileRequest{ + NewPassword: "newpass123", + } + + rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil) + assert.Equal(t, http.StatusBadRequest, rr.Code) + }) + + t.Run("update password with wrong current password", func(t *testing.T) { + updateReq := handlers.UpdateProfileRequest{ + CurrentPassword: "wrongpassword", + NewPassword: "newpass123", + } + + rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil) + assert.Equal(t, http.StatusUnauthorized, rr.Code) + }) + + t.Run("update with short password", func(t *testing.T) { + updateReq := handlers.UpdateProfileRequest{ + CurrentPassword: "user123", + NewPassword: "short", + } + + rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil) + assert.Equal(t, http.StatusBadRequest, rr.Code) + }) + + t.Run("duplicate email", func(t *testing.T) { + updateReq := handlers.UpdateProfileRequest{ + Email: h.AdminUser.Email, + CurrentPassword: "user123", + } + + rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil) + assert.Equal(t, http.StatusConflict, rr.Code) + }) + }) + + t.Run("delete account", func(t *testing.T) { + // Create a new user that we can delete + createReq := handlers.CreateUserRequest{ + Email: "todelete@test.com", + DisplayName: "To Delete", + Password: "password123", + Role: models.RoleEditor, + } + + rr := h.makeRequest(t, http.MethodPost, "/api/v1/admin/users", createReq, h.AdminToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var newUser models.User + err := json.NewDecoder(rr.Body).Decode(&newUser) + require.NoError(t, err) + + // Get token for new user + loginReq := handlers.LoginRequest{ + Email: createReq.Email, + Password: createReq.Password, + } + + rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil) + require.Equal(t, http.StatusOK, rr.Code) + + var loginResp handlers.LoginResponse + err = json.NewDecoder(rr.Body).Decode(&loginResp) + require.NoError(t, err) + userToken := loginResp.AccessToken + + t.Run("successful delete", func(t *testing.T) { + deleteReq := handlers.DeleteAccountRequest{ + Password: createReq.Password, + } + + rr := h.makeRequest(t, http.MethodDelete, "/api/v1/profile", deleteReq, userToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + // Verify user is deleted + rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil) + assert.Equal(t, http.StatusUnauthorized, rr.Code) + }) + + t.Run("delete with wrong password", func(t *testing.T) { + deleteReq := handlers.DeleteAccountRequest{ + Password: "wrongpassword", + } + + rr := h.makeRequest(t, http.MethodDelete, "/api/v1/profile", deleteReq, h.RegularToken, nil) + assert.Equal(t, http.StatusUnauthorized, rr.Code) + }) + + t.Run("prevent last admin deletion", func(t *testing.T) { + deleteReq := handlers.DeleteAccountRequest{ + Password: "admin123", // Admin password from test harness + } + + rr := h.makeRequest(t, http.MethodDelete, "/api/v1/profile", deleteReq, h.AdminToken, nil) + assert.Equal(t, http.StatusForbidden, rr.Code) + }) + }) +} From 2a53be5a6ea913d85af22fd8684c96c52ff7145a Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sat, 30 Nov 2024 00:09:01 +0100 Subject: [PATCH 35/38] Fix user update tests --- .../handlers/user_handlers_integration_test.go | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/server/internal/handlers/user_handlers_integration_test.go b/server/internal/handlers/user_handlers_integration_test.go index bfcc38e..1722273 100644 --- a/server/internal/handlers/user_handlers_integration_test.go +++ b/server/internal/handlers/user_handlers_integration_test.go @@ -18,6 +18,9 @@ func TestUserHandlers_Integration(t *testing.T) { h := setupTestHarness(t) defer h.teardown(t) + currentEmail := h.RegularUser.Email + currentPassword := "user123" + t.Run("get current user", func(t *testing.T) { t.Run("successful get", func(t *testing.T) { rr := h.makeRequest(t, http.MethodGet, "/api/v1/auth/me", nil, h.RegularToken, nil) @@ -58,7 +61,7 @@ func TestUserHandlers_Integration(t *testing.T) { t.Run("update email", func(t *testing.T) { updateReq := handlers.UpdateProfileRequest{ Email: "newemail@test.com", - CurrentPassword: "user123", // Regular user's password from test harness + CurrentPassword: currentPassword, } rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil) @@ -68,6 +71,8 @@ func TestUserHandlers_Integration(t *testing.T) { err := json.NewDecoder(rr.Body).Decode(&user) require.NoError(t, err) assert.Equal(t, updateReq.Email, user.Email) + + currentEmail = updateReq.Email }) t.Run("update email without password", func(t *testing.T) { @@ -91,7 +96,7 @@ func TestUserHandlers_Integration(t *testing.T) { t.Run("update password", func(t *testing.T) { updateReq := handlers.UpdateProfileRequest{ - CurrentPassword: "user123", + CurrentPassword: currentPassword, NewPassword: "newpassword123", } @@ -100,12 +105,14 @@ func TestUserHandlers_Integration(t *testing.T) { // Verify can login with new password loginReq := handlers.LoginRequest{ - Email: h.RegularUser.Email, + Email: currentEmail, Password: "newpassword123", } rr = h.makeRequest(t, http.MethodPost, "/api/v1/auth/login", loginReq, "", nil) assert.Equal(t, http.StatusOK, rr.Code) + + currentPassword = updateReq.NewPassword }) t.Run("update password without current password", func(t *testing.T) { @@ -129,7 +136,7 @@ func TestUserHandlers_Integration(t *testing.T) { t.Run("update with short password", func(t *testing.T) { updateReq := handlers.UpdateProfileRequest{ - CurrentPassword: "user123", + CurrentPassword: currentPassword, NewPassword: "short", } @@ -140,7 +147,7 @@ func TestUserHandlers_Integration(t *testing.T) { t.Run("duplicate email", func(t *testing.T) { updateReq := handlers.UpdateProfileRequest{ Email: h.AdminUser.Email, - CurrentPassword: "user123", + CurrentPassword: currentPassword, } rr := h.makeRequest(t, http.MethodPut, "/api/v1/profile", updateReq, h.RegularToken, nil) From 8bed3614ee46f81dcd3bb8f312b0a944186bf770 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sat, 30 Nov 2024 00:12:14 +0100 Subject: [PATCH 36/38] Fix user deletion handler --- server/internal/handlers/user_handlers.go | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/server/internal/handlers/user_handlers.go b/server/internal/handlers/user_handlers.go index 4a0a7d9..210dd64 100644 --- a/server/internal/handlers/user_handlers.go +++ b/server/internal/handlers/user_handlers.go @@ -1,7 +1,6 @@ package handlers import ( - "database/sql" "encoding/json" "net/http" @@ -155,18 +154,6 @@ func (h *Handler) DeleteAccount() http.HandlerFunc { } } - // Start transaction for consistent deletion - tx, err := h.DB.Begin() - if err != nil { - http.Error(w, "Failed to start transaction", http.StatusInternalServerError) - return - } - defer func() { - if err := tx.Rollback(); err != nil && err != sql.ErrTxDone { - http.Error(w, "Failed to rollback transaction", http.StatusInternalServerError) - } - }() - // Get user's workspaces for cleanup workspaces, err := h.DB.GetWorkspacesByUserID(ctx.UserID) if err != nil { @@ -188,11 +175,6 @@ func (h *Handler) DeleteAccount() http.HandlerFunc { return } - if err := tx.Commit(); err != nil { - http.Error(w, "Failed to commit transaction", http.StatusInternalServerError) - return - } - respondJSON(w, map[string]string{"message": "Account deleted successfully"}) } } From ae48761d349e50331f51b4692a004c211cb9376e Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sat, 30 Nov 2024 11:44:17 +0100 Subject: [PATCH 37/38] Implement workspace handlers integration tests --- .../internal/handlers/workspace_handlers.go | 5 + .../workspace_handlers_integration_test.go | 297 ++++++++++++++++++ server/internal/models/workspace.go | 5 + 3 files changed, 307 insertions(+) create mode 100644 server/internal/handlers/workspace_handlers_integration_test.go diff --git a/server/internal/handlers/workspace_handlers.go b/server/internal/handlers/workspace_handlers.go index 06a2503..dc4d94d 100644 --- a/server/internal/handlers/workspace_handlers.go +++ b/server/internal/handlers/workspace_handlers.go @@ -41,6 +41,11 @@ func (h *Handler) CreateWorkspace() http.HandlerFunc { return } + if err := workspace.ValidateGitSettings(); err != nil { + http.Error(w, "Invalid workspace", http.StatusBadRequest) + return + } + workspace.UserID = ctx.UserID if err := h.DB.CreateWorkspace(&workspace); err != nil { http.Error(w, "Failed to create workspace", http.StatusInternalServerError) diff --git a/server/internal/handlers/workspace_handlers_integration_test.go b/server/internal/handlers/workspace_handlers_integration_test.go new file mode 100644 index 0000000..724350c --- /dev/null +++ b/server/internal/handlers/workspace_handlers_integration_test.go @@ -0,0 +1,297 @@ +//go:build integration + +package handlers_test + +import ( + "encoding/json" + "net/http" + "net/url" + "testing" + + "novamd/internal/models" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestWorkspaceHandlers_Integration(t *testing.T) { + h := setupTestHarness(t) + defer h.teardown(t) + + t.Run("list workspaces", func(t *testing.T) { + t.Run("successful list", func(t *testing.T) { + rr := h.makeRequest(t, http.MethodGet, "/api/v1/workspaces", nil, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var workspaces []*models.Workspace + err := json.NewDecoder(rr.Body).Decode(&workspaces) + require.NoError(t, err) + assert.NotEmpty(t, workspaces, "User should have at least one default workspace") + }) + + t.Run("unauthorized", func(t *testing.T) { + rr := h.makeRequest(t, http.MethodGet, "/api/v1/workspaces", nil, "", nil) + assert.Equal(t, http.StatusUnauthorized, rr.Code) + }) + }) + + t.Run("create workspace", func(t *testing.T) { + t.Run("successful create", func(t *testing.T) { + workspace := &models.Workspace{ + Name: "Test Workspace", + } + + rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var created models.Workspace + err := json.NewDecoder(rr.Body).Decode(&created) + require.NoError(t, err) + assert.Equal(t, workspace.Name, created.Name) + assert.Equal(t, h.RegularUser.ID, created.UserID) + assert.NotZero(t, created.ID) + }) + + t.Run("create with git settings", func(t *testing.T) { + workspace := &models.Workspace{ + Name: "Git Workspace", + GitEnabled: true, + GitURL: "https://github.com/test/repo.git", + GitUser: "testuser", + GitToken: "testtoken", + GitAutoCommit: true, + } + + rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var created models.Workspace + err := json.NewDecoder(rr.Body).Decode(&created) + require.NoError(t, err) + assert.Equal(t, workspace.GitEnabled, created.GitEnabled) + assert.Equal(t, workspace.GitURL, created.GitURL) + assert.Equal(t, workspace.GitUser, created.GitUser) + assert.Equal(t, workspace.GitToken, created.GitToken) + assert.Equal(t, workspace.GitAutoCommit, created.GitAutoCommit) + }) + + t.Run("invalid workspace", func(t *testing.T) { + workspace := &models.Workspace{ + Name: "", // Empty name + GitEnabled: true, + // Missing required Git settings + } + + rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil) + assert.Equal(t, http.StatusBadRequest, rr.Code) + }) + }) + + // Create a workspace for the remaining tests + workspace := &models.Workspace{ + Name: "Test Workspace Operations", + } + rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + err := json.NewDecoder(rr.Body).Decode(workspace) + require.NoError(t, err) + + escapedName := url.PathEscape(workspace.Name) + baseURL := "/api/v1/workspaces/" + escapedName + + t.Run("get workspace", func(t *testing.T) { + t.Run("successful get", func(t *testing.T) { + rr := h.makeRequest(t, http.MethodGet, baseURL, nil, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var got models.Workspace + err := json.NewDecoder(rr.Body).Decode(&got) + require.NoError(t, err) + assert.Equal(t, workspace.ID, got.ID) + assert.Equal(t, workspace.Name, got.Name) + }) + + t.Run("nonexistent workspace", func(t *testing.T) { + rr := h.makeRequest(t, http.MethodGet, "/api/v1/workspaces/nonexistent", nil, h.RegularToken, nil) + assert.Equal(t, http.StatusNotFound, rr.Code) + }) + + t.Run("unauthorized access", func(t *testing.T) { + // Try accessing with another user's token + rr := h.makeRequest(t, http.MethodGet, baseURL, nil, h.AdminToken, nil) + assert.Equal(t, http.StatusNotFound, rr.Code) + }) + }) + + t.Run("update workspace", func(t *testing.T) { + t.Run("update name", func(t *testing.T) { + workspace.Name = "Updated Workspace" + + rr := h.makeRequest(t, http.MethodPut, baseURL, workspace, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var updated models.Workspace + err := json.NewDecoder(rr.Body).Decode(&updated) + require.NoError(t, err) + assert.Equal(t, workspace.Name, updated.Name) + + // Update baseURL for remaining tests + escapedName = url.PathEscape(workspace.Name) + baseURL = "/api/v1/workspaces/" + escapedName + }) + + t.Run("update settings", func(t *testing.T) { + update := &models.Workspace{ + Name: workspace.Name, + Theme: "dark", + AutoSave: true, + ShowHiddenFiles: true, + } + + rr := h.makeRequest(t, http.MethodPut, baseURL, update, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var updated models.Workspace + err := json.NewDecoder(rr.Body).Decode(&updated) + require.NoError(t, err) + assert.Equal(t, update.Theme, updated.Theme) + assert.Equal(t, update.AutoSave, updated.AutoSave) + assert.Equal(t, update.ShowHiddenFiles, updated.ShowHiddenFiles) + }) + + t.Run("enable git", func(t *testing.T) { + update := &models.Workspace{ + Name: workspace.Name, + Theme: "dark", + GitEnabled: true, + GitURL: "https://github.com/test/repo.git", + GitUser: "testuser", + GitToken: "testtoken", + GitAutoCommit: true, + } + + rr := h.makeRequest(t, http.MethodPut, baseURL, update, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var updated models.Workspace + err := json.NewDecoder(rr.Body).Decode(&updated) + require.NoError(t, err) + assert.Equal(t, update.GitEnabled, updated.GitEnabled) + assert.Equal(t, update.GitURL, updated.GitURL) + assert.Equal(t, update.GitUser, updated.GitUser) + assert.Equal(t, update.GitToken, updated.GitToken) + + // Mock should have been called to setup git + assert.True(t, h.MockGit.IsInitialized()) + }) + + t.Run("invalid git settings", func(t *testing.T) { + update := &models.Workspace{ + Name: workspace.Name, + GitEnabled: true, + // Missing required Git settings + } + + rr := h.makeRequest(t, http.MethodPut, baseURL, update, h.RegularToken, nil) + assert.Equal(t, http.StatusBadRequest, rr.Code) + }) + }) + + t.Run("last workspace", func(t *testing.T) { + t.Run("get last workspace", func(t *testing.T) { + rr := h.makeRequest(t, http.MethodGet, "/api/v1/workspaces/last", nil, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var response struct { + LastWorkspaceName string `json:"lastWorkspaceName"` + } + err := json.NewDecoder(rr.Body).Decode(&response) + require.NoError(t, err) + assert.NotEmpty(t, response.LastWorkspaceName) + }) + + t.Run("update last workspace", func(t *testing.T) { + req := struct { + WorkspaceName string `json:"workspaceName"` + }{ + WorkspaceName: workspace.Name, + } + + rr := h.makeRequest(t, http.MethodPut, "/api/v1/workspaces/last", req, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + // Verify the update + rr = h.makeRequest(t, http.MethodGet, "/api/v1/workspaces/last", nil, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var response struct { + LastWorkspaceName string `json:"lastWorkspaceName"` + } + err := json.NewDecoder(rr.Body).Decode(&response) + require.NoError(t, err) + assert.Equal(t, workspace.Name, response.LastWorkspaceName) + }) + }) + + t.Run("delete workspace", func(t *testing.T) { + // Get current workspaces to know how many we have + rr := h.makeRequest(t, http.MethodGet, "/api/v1/workspaces", nil, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var existingWorkspaces []*models.Workspace + err := json.NewDecoder(rr.Body).Decode(&existingWorkspaces) + require.NoError(t, err) + + // Create a new workspace we can safely delete + newWorkspace := &models.Workspace{ + Name: "Workspace To Delete", + } + rr = h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", newWorkspace, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + err = json.NewDecoder(rr.Body).Decode(newWorkspace) + require.NoError(t, err) + + t.Run("successful delete", func(t *testing.T) { + rr := h.makeRequest(t, http.MethodDelete, "/api/v1/workspaces/"+url.PathEscape(newWorkspace.Name), nil, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + var response struct { + NextWorkspaceName string `json:"nextWorkspaceName"` + } + err := json.NewDecoder(rr.Body).Decode(&response) + require.NoError(t, err) + assert.NotEmpty(t, response.NextWorkspaceName) + + // Verify workspace is deleted + rr = h.makeRequest(t, http.MethodGet, "/api/v1/workspaces/"+url.PathEscape(newWorkspace.Name), nil, h.RegularToken, nil) + assert.Equal(t, http.StatusNotFound, rr.Code) + }) + + t.Run("prevent deleting last workspace", func(t *testing.T) { + // Delete all but one workspace + for i := 0; i < len(existingWorkspaces)-1; i++ { + ws := existingWorkspaces[i] + rr := h.makeRequest(t, http.MethodDelete, "/api/v1/workspaces/"+url.PathEscape(ws.Name), nil, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + } + + // Try to delete the last remaining workspace + lastWs := existingWorkspaces[len(existingWorkspaces)-1] + rr := h.makeRequest(t, http.MethodDelete, "/api/v1/workspaces/"+url.PathEscape(lastWs.Name), nil, h.RegularToken, nil) + assert.Equal(t, http.StatusBadRequest, rr.Code) + }) + + t.Run("unauthorized deletion", func(t *testing.T) { + // Create a workspace to attempt unauthorized deletion + workspace := &models.Workspace{ + Name: "Unauthorized Delete Test", + } + rr := h.makeRequest(t, http.MethodPost, "/api/v1/workspaces", workspace, h.RegularToken, nil) + require.Equal(t, http.StatusOK, rr.Code) + + // Try to delete with wrong user's token + rr = h.makeRequest(t, http.MethodDelete, "/api/v1/workspaces/"+url.PathEscape(workspace.Name), nil, h.AdminToken, nil) + assert.Equal(t, http.StatusNotFound, rr.Code) + }) + }) +} diff --git a/server/internal/models/workspace.go b/server/internal/models/workspace.go index 5a0b7eb..584155a 100644 --- a/server/internal/models/workspace.go +++ b/server/internal/models/workspace.go @@ -29,6 +29,11 @@ func (w *Workspace) Validate() error { return validate.Struct(w) } +// ValidateGitSettings validates the git settings if git is enabled +func (w *Workspace) ValidateGitSettings() error { + return validate.StructExcept(w, "ID", "UserID", "Theme") +} + // SetDefaultSettings sets the default settings for the workspace func (w *Workspace) SetDefaultSettings() { From 842513f8a5a884bdfe924620da22f8e3c4b8fd1d Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sat, 30 Nov 2024 11:48:57 +0100 Subject: [PATCH 38/38] Add test tags to github workflow --- .github/workflows/go-test.yml | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/.github/workflows/go-test.yml b/.github/workflows/go-test.yml index 108f629..33bbbb6 100644 --- a/.github/workflows/go-test.yml +++ b/.github/workflows/go-test.yml @@ -27,13 +27,8 @@ jobs: go-version: "1.23" cache: true - # - name: Install dependencies - # run: | - # sudo apt-get update - # sudo apt-get install -y gcc - - name: Run Tests - run: go test ./... -v + run: go test -tags=test,integration ./... -v - name: Run Tests with Race Detector - run: go test -race ./... -v + run: go test -tags=test,integration -race ./... -v