mirror of
https://github.com/lordmathis/lemma.git
synced 2025-11-05 23:44:22 +00:00
Test context package
This commit is contained in:
139
server/internal/context/context_test.go
Normal file
139
server/internal/context/context_test.go
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
package context_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
stdctx "context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"novamd/internal/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetRequestContext(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setupCtx func() *context.HandlerContext
|
||||||
|
wantStatus int
|
||||||
|
wantOK bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid context",
|
||||||
|
setupCtx: func() *context.HandlerContext {
|
||||||
|
return &context.HandlerContext{
|
||||||
|
UserID: 1,
|
||||||
|
UserRole: "admin",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
wantOK: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing context",
|
||||||
|
setupCtx: func() *context.HandlerContext {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
wantStatus: http.StatusInternalServerError,
|
||||||
|
wantOK: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Create test request
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
if ctx := tt.setupCtx(); ctx != nil {
|
||||||
|
req = context.WithHandlerContext(req, ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotCtx, ok := context.GetRequestContext(w, req)
|
||||||
|
|
||||||
|
if ok != tt.wantOK {
|
||||||
|
t.Errorf("GetRequestContext() ok = %v, want %v", ok, tt.wantOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tt.wantOK {
|
||||||
|
if w.Code != tt.wantStatus {
|
||||||
|
t.Errorf("GetRequestContext() status = %v, want %v", w.Code, tt.wantStatus)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotCtx.UserID != tt.setupCtx().UserID {
|
||||||
|
t.Errorf("GetRequestContext() UserID = %v, want %v", gotCtx.UserID, tt.setupCtx().UserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotCtx.UserRole != tt.setupCtx().UserRole {
|
||||||
|
t.Errorf("GetRequestContext() UserRole = %v, want %v", gotCtx.UserRole, tt.setupCtx().UserRole)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserFromContext(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setupCtx func() stdctx.Context
|
||||||
|
wantUser *context.UserClaims
|
||||||
|
wantError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid user context",
|
||||||
|
setupCtx: func() stdctx.Context {
|
||||||
|
return stdctx.WithValue(stdctx.Background(), context.HandlerContextKey, &context.HandlerContext{
|
||||||
|
UserID: 1,
|
||||||
|
UserRole: "admin",
|
||||||
|
})
|
||||||
|
},
|
||||||
|
wantUser: &context.UserClaims{
|
||||||
|
UserID: 1,
|
||||||
|
Role: "admin",
|
||||||
|
},
|
||||||
|
wantError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing context",
|
||||||
|
setupCtx: func() stdctx.Context {
|
||||||
|
return stdctx.Background()
|
||||||
|
},
|
||||||
|
wantUser: nil,
|
||||||
|
wantError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid context type",
|
||||||
|
setupCtx: func() stdctx.Context {
|
||||||
|
return stdctx.WithValue(stdctx.Background(), context.HandlerContextKey, "invalid")
|
||||||
|
},
|
||||||
|
wantUser: nil,
|
||||||
|
wantError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctx := tt.setupCtx()
|
||||||
|
gotUser, err := context.GetUserFromContext(ctx)
|
||||||
|
|
||||||
|
if tt.wantError {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("GetUserFromContext() error = nil, want error")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("GetUserFromContext() unexpected error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotUser.UserID != tt.wantUser.UserID {
|
||||||
|
t.Errorf("GetUserFromContext() UserID = %v, want %v", gotUser.UserID, tt.wantUser.UserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotUser.Role != tt.wantUser.Role {
|
||||||
|
t.Errorf("GetUserFromContext() Role = %v, want %v", gotUser.Role, tt.wantUser.Role)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -28,7 +28,7 @@ func WithUserContextMiddleware(next http.Handler) http.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// WithWorkspaceContextMiddleware adds workspace information to the request context
|
// WithWorkspaceContextMiddleware adds workspace information to the request context
|
||||||
func WithWorkspaceContextMiddleware(db db.Database) func(http.Handler) http.Handler {
|
func WithWorkspaceContextMiddleware(db db.WorkspaceReader) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx, ok := GetRequestContext(w, r)
|
ctx, ok := GetRequestContext(w, r)
|
||||||
|
|||||||
197
server/internal/context/middleware_test.go
Normal file
197
server/internal/context/middleware_test.go
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
package context_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
stdctx "context"
|
||||||
|
"database/sql"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"novamd/internal/context"
|
||||||
|
"novamd/internal/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockDB implements the minimal database interface needed for testing
|
||||||
|
type MockDB struct {
|
||||||
|
GetWorkspaceByNameFunc func(userID int, workspaceName string) (*models.Workspace, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDB) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) {
|
||||||
|
return m.GetWorkspaceByNameFunc(userID, workspaceName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDB) GetWorkspaceByID(_ int) (*models.Workspace, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDB) GetWorkspacesByUserID(_ int) ([]*models.Workspace, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDB) GetAllWorkspaces() ([]*models.Workspace, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithUserContextMiddleware(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setupCtx func() *context.HandlerContext
|
||||||
|
wantStatus int
|
||||||
|
wantNext bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid user context",
|
||||||
|
setupCtx: func() *context.HandlerContext {
|
||||||
|
return &context.HandlerContext{
|
||||||
|
UserID: 1,
|
||||||
|
UserRole: "admin",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
wantNext: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing user context",
|
||||||
|
setupCtx: func() *context.HandlerContext {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
wantNext: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
if ctx := tt.setupCtx(); ctx != nil {
|
||||||
|
req = context.WithHandlerContext(req, ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
nextCalled := false
|
||||||
|
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
nextCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
middleware := context.WithUserContextMiddleware(next)
|
||||||
|
middleware.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if nextCalled != tt.wantNext {
|
||||||
|
t.Errorf("WithUserContextMiddleware() next called = %v, want %v", nextCalled, tt.wantNext)
|
||||||
|
}
|
||||||
|
|
||||||
|
if w.Code != tt.wantStatus {
|
||||||
|
t.Errorf("WithUserContextMiddleware() status = %v, want %v", w.Code, tt.wantStatus)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithWorkspaceContextMiddleware(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setupCtx func() *context.HandlerContext
|
||||||
|
workspaceName string
|
||||||
|
mockWorkspace *models.Workspace
|
||||||
|
mockError error
|
||||||
|
wantStatus int
|
||||||
|
wantNext bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid workspace context",
|
||||||
|
setupCtx: func() *context.HandlerContext {
|
||||||
|
return &context.HandlerContext{
|
||||||
|
UserID: 1,
|
||||||
|
UserRole: "admin",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
workspaceName: "test-workspace",
|
||||||
|
mockWorkspace: &models.Workspace{
|
||||||
|
ID: 1,
|
||||||
|
UserID: 1,
|
||||||
|
Name: "test-workspace",
|
||||||
|
},
|
||||||
|
mockError: nil,
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
wantNext: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "workspace not found",
|
||||||
|
setupCtx: func() *context.HandlerContext {
|
||||||
|
return &context.HandlerContext{
|
||||||
|
UserID: 1,
|
||||||
|
UserRole: "admin",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
workspaceName: "nonexistent",
|
||||||
|
mockWorkspace: nil,
|
||||||
|
mockError: sql.ErrNoRows,
|
||||||
|
wantStatus: http.StatusNotFound,
|
||||||
|
wantNext: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing user context",
|
||||||
|
setupCtx: func() *context.HandlerContext { return nil },
|
||||||
|
workspaceName: "test-workspace",
|
||||||
|
mockWorkspace: nil,
|
||||||
|
mockError: nil,
|
||||||
|
wantStatus: http.StatusInternalServerError,
|
||||||
|
wantNext: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
mockDB := &MockDB{
|
||||||
|
GetWorkspaceByNameFunc: func(_ int, _ string) (*models.Workspace, error) {
|
||||||
|
return tt.mockWorkspace, tt.mockError
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
if ctx := tt.setupCtx(); ctx != nil {
|
||||||
|
req = context.WithHandlerContext(req, ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add workspace name to request context via chi URL params
|
||||||
|
req = req.WithContext(stdctx.WithValue(req.Context(), "workspaceName", tt.workspaceName))
|
||||||
|
|
||||||
|
nextCalled := false
|
||||||
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
nextCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
|
// Verify workspace was added to context
|
||||||
|
if tt.mockWorkspace != nil {
|
||||||
|
ctx, ok := context.GetRequestContext(w, r)
|
||||||
|
if !ok {
|
||||||
|
t.Error("Failed to get request context in next handler")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if ctx.Workspace == nil {
|
||||||
|
t.Error("Workspace not set in context")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if ctx.Workspace.ID != tt.mockWorkspace.ID {
|
||||||
|
t.Errorf("Workspace ID = %v, want %v", ctx.Workspace.ID, tt.mockWorkspace.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
middleware := context.WithWorkspaceContextMiddleware(mockDB)(next)
|
||||||
|
middleware.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if nextCalled != tt.wantNext {
|
||||||
|
t.Errorf("WithWorkspaceContextMiddleware() next called = %v, want %v", nextCalled, tt.wantNext)
|
||||||
|
}
|
||||||
|
|
||||||
|
if w.Code != tt.wantStatus {
|
||||||
|
t.Errorf("WithWorkspaceContextMiddleware() status = %v, want %v", w.Code, tt.wantStatus)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -24,12 +24,17 @@ type UserStore interface {
|
|||||||
CountAdminUsers() (int, error)
|
CountAdminUsers() (int, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// WorkspaceStore defines the methods for interacting with workspace data in the database
|
// WorkspaceReader defines the methods for reading workspace data from the database
|
||||||
type WorkspaceStore interface {
|
type WorkspaceReader interface {
|
||||||
CreateWorkspace(workspace *models.Workspace) error
|
|
||||||
GetWorkspaceByID(workspaceID int) (*models.Workspace, error)
|
GetWorkspaceByID(workspaceID int) (*models.Workspace, error)
|
||||||
GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error)
|
GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error)
|
||||||
GetWorkspacesByUserID(userID int) ([]*models.Workspace, error)
|
GetWorkspacesByUserID(userID int) ([]*models.Workspace, error)
|
||||||
|
GetAllWorkspaces() ([]*models.Workspace, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WorkspaceWriter defines the methods for writing workspace data to the database
|
||||||
|
type WorkspaceWriter interface {
|
||||||
|
CreateWorkspace(workspace *models.Workspace) error
|
||||||
UpdateWorkspace(workspace *models.Workspace) error
|
UpdateWorkspace(workspace *models.Workspace) error
|
||||||
DeleteWorkspace(workspaceID int) error
|
DeleteWorkspace(workspaceID int) error
|
||||||
UpdateWorkspaceSettings(workspace *models.Workspace) error
|
UpdateWorkspaceSettings(workspace *models.Workspace) error
|
||||||
@@ -37,7 +42,12 @@ type WorkspaceStore interface {
|
|||||||
UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error
|
UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error
|
||||||
UpdateLastOpenedFile(workspaceID int, filePath string) error
|
UpdateLastOpenedFile(workspaceID int, filePath string) error
|
||||||
GetLastOpenedFile(workspaceID int) (string, error)
|
GetLastOpenedFile(workspaceID int) (string, error)
|
||||||
GetAllWorkspaces() ([]*models.Workspace, error)
|
}
|
||||||
|
|
||||||
|
// WorkspaceStore defines the methods for interacting with workspace data in the database
|
||||||
|
type WorkspaceStore interface {
|
||||||
|
WorkspaceReader
|
||||||
|
WorkspaceWriter
|
||||||
}
|
}
|
||||||
|
|
||||||
// SessionStore defines the methods for interacting with jwt sessions in the database
|
// SessionStore defines the methods for interacting with jwt sessions in the database
|
||||||
@@ -67,6 +77,21 @@ type Database interface {
|
|||||||
Migrate() error
|
Migrate() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// Main Database interface
|
||||||
|
_ Database = (*database)(nil)
|
||||||
|
|
||||||
|
// Component interfaces
|
||||||
|
_ UserStore = (*database)(nil)
|
||||||
|
_ WorkspaceStore = (*database)(nil)
|
||||||
|
_ SessionStore = (*database)(nil)
|
||||||
|
_ SystemStore = (*database)(nil)
|
||||||
|
|
||||||
|
// Sub-interfaces
|
||||||
|
_ WorkspaceReader = (*database)(nil)
|
||||||
|
_ WorkspaceWriter = (*database)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
// database represents the database connection
|
// database represents the database connection
|
||||||
type database struct {
|
type database struct {
|
||||||
*sql.DB
|
*sql.DB
|
||||||
|
|||||||
Reference in New Issue
Block a user