Test context package

This commit is contained in:
2024-11-23 22:15:25 +01:00
parent 8f2f8b30dd
commit 9f241271a7
4 changed files with 366 additions and 5 deletions

View File

@@ -0,0 +1,139 @@
package context_test
import (
stdctx "context"
"net/http"
"net/http/httptest"
"testing"
"novamd/internal/context"
)
func TestGetRequestContext(t *testing.T) {
tests := []struct {
name string
setupCtx func() *context.HandlerContext
wantStatus int
wantOK bool
}{
{
name: "valid context",
setupCtx: func() *context.HandlerContext {
return &context.HandlerContext{
UserID: 1,
UserRole: "admin",
}
},
wantStatus: http.StatusOK,
wantOK: true,
},
{
name: "missing context",
setupCtx: func() *context.HandlerContext {
return nil
},
wantStatus: http.StatusInternalServerError,
wantOK: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create test request
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
if ctx := tt.setupCtx(); ctx != nil {
req = context.WithHandlerContext(req, ctx)
}
gotCtx, ok := context.GetRequestContext(w, req)
if ok != tt.wantOK {
t.Errorf("GetRequestContext() ok = %v, want %v", ok, tt.wantOK)
}
if !tt.wantOK {
if w.Code != tt.wantStatus {
t.Errorf("GetRequestContext() status = %v, want %v", w.Code, tt.wantStatus)
}
return
}
if gotCtx.UserID != tt.setupCtx().UserID {
t.Errorf("GetRequestContext() UserID = %v, want %v", gotCtx.UserID, tt.setupCtx().UserID)
}
if gotCtx.UserRole != tt.setupCtx().UserRole {
t.Errorf("GetRequestContext() UserRole = %v, want %v", gotCtx.UserRole, tt.setupCtx().UserRole)
}
})
}
}
func TestGetUserFromContext(t *testing.T) {
tests := []struct {
name string
setupCtx func() stdctx.Context
wantUser *context.UserClaims
wantError bool
}{
{
name: "valid user context",
setupCtx: func() stdctx.Context {
return stdctx.WithValue(stdctx.Background(), context.HandlerContextKey, &context.HandlerContext{
UserID: 1,
UserRole: "admin",
})
},
wantUser: &context.UserClaims{
UserID: 1,
Role: "admin",
},
wantError: false,
},
{
name: "missing context",
setupCtx: func() stdctx.Context {
return stdctx.Background()
},
wantUser: nil,
wantError: true,
},
{
name: "invalid context type",
setupCtx: func() stdctx.Context {
return stdctx.WithValue(stdctx.Background(), context.HandlerContextKey, "invalid")
},
wantUser: nil,
wantError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := tt.setupCtx()
gotUser, err := context.GetUserFromContext(ctx)
if tt.wantError {
if err == nil {
t.Error("GetUserFromContext() error = nil, want error")
}
return
}
if err != nil {
t.Errorf("GetUserFromContext() unexpected error = %v", err)
return
}
if gotUser.UserID != tt.wantUser.UserID {
t.Errorf("GetUserFromContext() UserID = %v, want %v", gotUser.UserID, tt.wantUser.UserID)
}
if gotUser.Role != tt.wantUser.Role {
t.Errorf("GetUserFromContext() Role = %v, want %v", gotUser.Role, tt.wantUser.Role)
}
})
}
}

View File

@@ -28,7 +28,7 @@ func WithUserContextMiddleware(next http.Handler) http.Handler {
} }
// WithWorkspaceContextMiddleware adds workspace information to the request context // WithWorkspaceContextMiddleware adds workspace information to the request context
func WithWorkspaceContextMiddleware(db db.Database) func(http.Handler) http.Handler { func WithWorkspaceContextMiddleware(db db.WorkspaceReader) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, ok := GetRequestContext(w, r) ctx, ok := GetRequestContext(w, r)

View File

@@ -0,0 +1,197 @@
package context_test
import (
stdctx "context"
"database/sql"
"net/http"
"net/http/httptest"
"testing"
"novamd/internal/context"
"novamd/internal/models"
)
// MockDB implements the minimal database interface needed for testing
type MockDB struct {
GetWorkspaceByNameFunc func(userID int, workspaceName string) (*models.Workspace, error)
}
func (m *MockDB) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) {
return m.GetWorkspaceByNameFunc(userID, workspaceName)
}
func (m *MockDB) GetWorkspaceByID(_ int) (*models.Workspace, error) {
return nil, nil
}
func (m *MockDB) GetWorkspacesByUserID(_ int) ([]*models.Workspace, error) {
return nil, nil
}
func (m *MockDB) GetAllWorkspaces() ([]*models.Workspace, error) {
return nil, nil
}
func TestWithUserContextMiddleware(t *testing.T) {
tests := []struct {
name string
setupCtx func() *context.HandlerContext
wantStatus int
wantNext bool
}{
{
name: "valid user context",
setupCtx: func() *context.HandlerContext {
return &context.HandlerContext{
UserID: 1,
UserRole: "admin",
}
},
wantStatus: http.StatusOK,
wantNext: true,
},
{
name: "missing user context",
setupCtx: func() *context.HandlerContext {
return nil
},
wantStatus: http.StatusUnauthorized,
wantNext: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
if ctx := tt.setupCtx(); ctx != nil {
req = context.WithHandlerContext(req, ctx)
}
nextCalled := false
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
})
middleware := context.WithUserContextMiddleware(next)
middleware.ServeHTTP(w, req)
if nextCalled != tt.wantNext {
t.Errorf("WithUserContextMiddleware() next called = %v, want %v", nextCalled, tt.wantNext)
}
if w.Code != tt.wantStatus {
t.Errorf("WithUserContextMiddleware() status = %v, want %v", w.Code, tt.wantStatus)
}
})
}
}
func TestWithWorkspaceContextMiddleware(t *testing.T) {
tests := []struct {
name string
setupCtx func() *context.HandlerContext
workspaceName string
mockWorkspace *models.Workspace
mockError error
wantStatus int
wantNext bool
}{
{
name: "valid workspace context",
setupCtx: func() *context.HandlerContext {
return &context.HandlerContext{
UserID: 1,
UserRole: "admin",
}
},
workspaceName: "test-workspace",
mockWorkspace: &models.Workspace{
ID: 1,
UserID: 1,
Name: "test-workspace",
},
mockError: nil,
wantStatus: http.StatusOK,
wantNext: true,
},
{
name: "workspace not found",
setupCtx: func() *context.HandlerContext {
return &context.HandlerContext{
UserID: 1,
UserRole: "admin",
}
},
workspaceName: "nonexistent",
mockWorkspace: nil,
mockError: sql.ErrNoRows,
wantStatus: http.StatusNotFound,
wantNext: false,
},
{
name: "missing user context",
setupCtx: func() *context.HandlerContext { return nil },
workspaceName: "test-workspace",
mockWorkspace: nil,
mockError: nil,
wantStatus: http.StatusInternalServerError,
wantNext: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockDB := &MockDB{
GetWorkspaceByNameFunc: func(_ int, _ string) (*models.Workspace, error) {
return tt.mockWorkspace, tt.mockError
},
}
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
if ctx := tt.setupCtx(); ctx != nil {
req = context.WithHandlerContext(req, ctx)
}
// Add workspace name to request context via chi URL params
req = req.WithContext(stdctx.WithValue(req.Context(), "workspaceName", tt.workspaceName))
nextCalled := false
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
// Verify workspace was added to context
if tt.mockWorkspace != nil {
ctx, ok := context.GetRequestContext(w, r)
if !ok {
t.Error("Failed to get request context in next handler")
return
}
if ctx.Workspace == nil {
t.Error("Workspace not set in context")
return
}
if ctx.Workspace.ID != tt.mockWorkspace.ID {
t.Errorf("Workspace ID = %v, want %v", ctx.Workspace.ID, tt.mockWorkspace.ID)
}
}
})
middleware := context.WithWorkspaceContextMiddleware(mockDB)(next)
middleware.ServeHTTP(w, req)
if nextCalled != tt.wantNext {
t.Errorf("WithWorkspaceContextMiddleware() next called = %v, want %v", nextCalled, tt.wantNext)
}
if w.Code != tt.wantStatus {
t.Errorf("WithWorkspaceContextMiddleware() status = %v, want %v", w.Code, tt.wantStatus)
}
})
}
}

View File

@@ -24,12 +24,17 @@ type UserStore interface {
CountAdminUsers() (int, error) CountAdminUsers() (int, error)
} }
// WorkspaceStore defines the methods for interacting with workspace data in the database // WorkspaceReader defines the methods for reading workspace data from the database
type WorkspaceStore interface { type WorkspaceReader interface {
CreateWorkspace(workspace *models.Workspace) error
GetWorkspaceByID(workspaceID int) (*models.Workspace, error) GetWorkspaceByID(workspaceID int) (*models.Workspace, error)
GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error)
GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error)
GetAllWorkspaces() ([]*models.Workspace, error)
}
// WorkspaceWriter defines the methods for writing workspace data to the database
type WorkspaceWriter interface {
CreateWorkspace(workspace *models.Workspace) error
UpdateWorkspace(workspace *models.Workspace) error UpdateWorkspace(workspace *models.Workspace) error
DeleteWorkspace(workspaceID int) error DeleteWorkspace(workspaceID int) error
UpdateWorkspaceSettings(workspace *models.Workspace) error UpdateWorkspaceSettings(workspace *models.Workspace) error
@@ -37,7 +42,12 @@ type WorkspaceStore interface {
UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error
UpdateLastOpenedFile(workspaceID int, filePath string) error UpdateLastOpenedFile(workspaceID int, filePath string) error
GetLastOpenedFile(workspaceID int) (string, error) GetLastOpenedFile(workspaceID int) (string, error)
GetAllWorkspaces() ([]*models.Workspace, error) }
// WorkspaceStore defines the methods for interacting with workspace data in the database
type WorkspaceStore interface {
WorkspaceReader
WorkspaceWriter
} }
// SessionStore defines the methods for interacting with jwt sessions in the database // SessionStore defines the methods for interacting with jwt sessions in the database
@@ -67,6 +77,21 @@ type Database interface {
Migrate() error Migrate() error
} }
var (
// Main Database interface
_ Database = (*database)(nil)
// Component interfaces
_ UserStore = (*database)(nil)
_ WorkspaceStore = (*database)(nil)
_ SessionStore = (*database)(nil)
_ SystemStore = (*database)(nil)
// Sub-interfaces
_ WorkspaceReader = (*database)(nil)
_ WorkspaceWriter = (*database)(nil)
)
// database represents the database connection // database represents the database connection
type database struct { type database struct {
*sql.DB *sql.DB