diff --git a/server/internal/context/context_test.go b/server/internal/context/context_test.go new file mode 100644 index 0000000..8a0d947 --- /dev/null +++ b/server/internal/context/context_test.go @@ -0,0 +1,139 @@ +package context_test + +import ( + stdctx "context" + "net/http" + "net/http/httptest" + "testing" + + "novamd/internal/context" +) + +func TestGetRequestContext(t *testing.T) { + tests := []struct { + name string + setupCtx func() *context.HandlerContext + wantStatus int + wantOK bool + }{ + { + name: "valid context", + setupCtx: func() *context.HandlerContext { + return &context.HandlerContext{ + UserID: 1, + UserRole: "admin", + } + }, + wantStatus: http.StatusOK, + wantOK: true, + }, + { + name: "missing context", + setupCtx: func() *context.HandlerContext { + return nil + }, + wantStatus: http.StatusInternalServerError, + wantOK: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create test request + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + + if ctx := tt.setupCtx(); ctx != nil { + req = context.WithHandlerContext(req, ctx) + } + + gotCtx, ok := context.GetRequestContext(w, req) + + if ok != tt.wantOK { + t.Errorf("GetRequestContext() ok = %v, want %v", ok, tt.wantOK) + } + + if !tt.wantOK { + if w.Code != tt.wantStatus { + t.Errorf("GetRequestContext() status = %v, want %v", w.Code, tt.wantStatus) + } + return + } + + if gotCtx.UserID != tt.setupCtx().UserID { + t.Errorf("GetRequestContext() UserID = %v, want %v", gotCtx.UserID, tt.setupCtx().UserID) + } + + if gotCtx.UserRole != tt.setupCtx().UserRole { + t.Errorf("GetRequestContext() UserRole = %v, want %v", gotCtx.UserRole, tt.setupCtx().UserRole) + } + }) + } +} + +func TestGetUserFromContext(t *testing.T) { + tests := []struct { + name string + setupCtx func() stdctx.Context + wantUser *context.UserClaims + wantError bool + }{ + { + name: "valid user context", + setupCtx: func() stdctx.Context { + return stdctx.WithValue(stdctx.Background(), context.HandlerContextKey, &context.HandlerContext{ + UserID: 1, + UserRole: "admin", + }) + }, + wantUser: &context.UserClaims{ + UserID: 1, + Role: "admin", + }, + wantError: false, + }, + { + name: "missing context", + setupCtx: func() stdctx.Context { + return stdctx.Background() + }, + wantUser: nil, + wantError: true, + }, + { + name: "invalid context type", + setupCtx: func() stdctx.Context { + return stdctx.WithValue(stdctx.Background(), context.HandlerContextKey, "invalid") + }, + wantUser: nil, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := tt.setupCtx() + gotUser, err := context.GetUserFromContext(ctx) + + if tt.wantError { + if err == nil { + t.Error("GetUserFromContext() error = nil, want error") + } + return + } + + if err != nil { + t.Errorf("GetUserFromContext() unexpected error = %v", err) + return + } + + if gotUser.UserID != tt.wantUser.UserID { + t.Errorf("GetUserFromContext() UserID = %v, want %v", gotUser.UserID, tt.wantUser.UserID) + } + + if gotUser.Role != tt.wantUser.Role { + t.Errorf("GetUserFromContext() Role = %v, want %v", gotUser.Role, tt.wantUser.Role) + } + }) + } +} diff --git a/server/internal/context/middleware.go b/server/internal/context/middleware.go index 9c1d9b3..c916e96 100644 --- a/server/internal/context/middleware.go +++ b/server/internal/context/middleware.go @@ -28,7 +28,7 @@ func WithUserContextMiddleware(next http.Handler) http.Handler { } // WithWorkspaceContextMiddleware adds workspace information to the request context -func WithWorkspaceContextMiddleware(db db.Database) func(http.Handler) http.Handler { +func WithWorkspaceContextMiddleware(db db.WorkspaceReader) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx, ok := GetRequestContext(w, r) diff --git a/server/internal/context/middleware_test.go b/server/internal/context/middleware_test.go new file mode 100644 index 0000000..eae1c97 --- /dev/null +++ b/server/internal/context/middleware_test.go @@ -0,0 +1,197 @@ +package context_test + +import ( + stdctx "context" + "database/sql" + "net/http" + "net/http/httptest" + "testing" + + "novamd/internal/context" + "novamd/internal/models" +) + +// MockDB implements the minimal database interface needed for testing +type MockDB struct { + GetWorkspaceByNameFunc func(userID int, workspaceName string) (*models.Workspace, error) +} + +func (m *MockDB) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) { + return m.GetWorkspaceByNameFunc(userID, workspaceName) +} + +func (m *MockDB) GetWorkspaceByID(_ int) (*models.Workspace, error) { + return nil, nil +} + +func (m *MockDB) GetWorkspacesByUserID(_ int) ([]*models.Workspace, error) { + return nil, nil +} + +func (m *MockDB) GetAllWorkspaces() ([]*models.Workspace, error) { + return nil, nil +} + +func TestWithUserContextMiddleware(t *testing.T) { + tests := []struct { + name string + setupCtx func() *context.HandlerContext + wantStatus int + wantNext bool + }{ + { + name: "valid user context", + setupCtx: func() *context.HandlerContext { + return &context.HandlerContext{ + UserID: 1, + UserRole: "admin", + } + }, + wantStatus: http.StatusOK, + wantNext: true, + }, + { + name: "missing user context", + setupCtx: func() *context.HandlerContext { + return nil + }, + wantStatus: http.StatusUnauthorized, + wantNext: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + + if ctx := tt.setupCtx(); ctx != nil { + req = context.WithHandlerContext(req, ctx) + } + + nextCalled := false + next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + }) + + middleware := context.WithUserContextMiddleware(next) + middleware.ServeHTTP(w, req) + + if nextCalled != tt.wantNext { + t.Errorf("WithUserContextMiddleware() next called = %v, want %v", nextCalled, tt.wantNext) + } + + if w.Code != tt.wantStatus { + t.Errorf("WithUserContextMiddleware() status = %v, want %v", w.Code, tt.wantStatus) + } + }) + } +} + +func TestWithWorkspaceContextMiddleware(t *testing.T) { + tests := []struct { + name string + setupCtx func() *context.HandlerContext + workspaceName string + mockWorkspace *models.Workspace + mockError error + wantStatus int + wantNext bool + }{ + { + name: "valid workspace context", + setupCtx: func() *context.HandlerContext { + return &context.HandlerContext{ + UserID: 1, + UserRole: "admin", + } + }, + workspaceName: "test-workspace", + mockWorkspace: &models.Workspace{ + ID: 1, + UserID: 1, + Name: "test-workspace", + }, + mockError: nil, + wantStatus: http.StatusOK, + wantNext: true, + }, + { + name: "workspace not found", + setupCtx: func() *context.HandlerContext { + return &context.HandlerContext{ + UserID: 1, + UserRole: "admin", + } + }, + workspaceName: "nonexistent", + mockWorkspace: nil, + mockError: sql.ErrNoRows, + wantStatus: http.StatusNotFound, + wantNext: false, + }, + { + name: "missing user context", + setupCtx: func() *context.HandlerContext { return nil }, + workspaceName: "test-workspace", + mockWorkspace: nil, + mockError: nil, + wantStatus: http.StatusInternalServerError, + wantNext: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockDB := &MockDB{ + GetWorkspaceByNameFunc: func(_ int, _ string) (*models.Workspace, error) { + return tt.mockWorkspace, tt.mockError + }, + } + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + + if ctx := tt.setupCtx(); ctx != nil { + req = context.WithHandlerContext(req, ctx) + } + + // Add workspace name to request context via chi URL params + req = req.WithContext(stdctx.WithValue(req.Context(), "workspaceName", tt.workspaceName)) + + nextCalled := false + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + + // Verify workspace was added to context + if tt.mockWorkspace != nil { + ctx, ok := context.GetRequestContext(w, r) + if !ok { + t.Error("Failed to get request context in next handler") + return + } + if ctx.Workspace == nil { + t.Error("Workspace not set in context") + return + } + if ctx.Workspace.ID != tt.mockWorkspace.ID { + t.Errorf("Workspace ID = %v, want %v", ctx.Workspace.ID, tt.mockWorkspace.ID) + } + } + }) + + middleware := context.WithWorkspaceContextMiddleware(mockDB)(next) + middleware.ServeHTTP(w, req) + + if nextCalled != tt.wantNext { + t.Errorf("WithWorkspaceContextMiddleware() next called = %v, want %v", nextCalled, tt.wantNext) + } + + if w.Code != tt.wantStatus { + t.Errorf("WithWorkspaceContextMiddleware() status = %v, want %v", w.Code, tt.wantStatus) + } + }) + } +} diff --git a/server/internal/db/db.go b/server/internal/db/db.go index a2624ee..4c45ddc 100644 --- a/server/internal/db/db.go +++ b/server/internal/db/db.go @@ -24,12 +24,17 @@ type UserStore interface { CountAdminUsers() (int, error) } -// WorkspaceStore defines the methods for interacting with workspace data in the database -type WorkspaceStore interface { - CreateWorkspace(workspace *models.Workspace) error +// WorkspaceReader defines the methods for reading workspace data from the database +type WorkspaceReader interface { GetWorkspaceByID(workspaceID int) (*models.Workspace, error) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) + GetAllWorkspaces() ([]*models.Workspace, error) +} + +// WorkspaceWriter defines the methods for writing workspace data to the database +type WorkspaceWriter interface { + CreateWorkspace(workspace *models.Workspace) error UpdateWorkspace(workspace *models.Workspace) error DeleteWorkspace(workspaceID int) error UpdateWorkspaceSettings(workspace *models.Workspace) error @@ -37,7 +42,12 @@ type WorkspaceStore interface { UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error UpdateLastOpenedFile(workspaceID int, filePath string) error GetLastOpenedFile(workspaceID int) (string, error) - GetAllWorkspaces() ([]*models.Workspace, error) +} + +// WorkspaceStore defines the methods for interacting with workspace data in the database +type WorkspaceStore interface { + WorkspaceReader + WorkspaceWriter } // SessionStore defines the methods for interacting with jwt sessions in the database @@ -67,6 +77,21 @@ type Database interface { Migrate() error } +var ( + // Main Database interface + _ Database = (*database)(nil) + + // Component interfaces + _ UserStore = (*database)(nil) + _ WorkspaceStore = (*database)(nil) + _ SessionStore = (*database)(nil) + _ SystemStore = (*database)(nil) + + // Sub-interfaces + _ WorkspaceReader = (*database)(nil) + _ WorkspaceWriter = (*database)(nil) +) + // database represents the database connection type database struct { *sql.DB