mirror of
https://github.com/lordmathis/lemma.git
synced 2025-11-05 15:44:21 +00:00
141 lines
3.1 KiB
Go
141 lines
3.1 KiB
Go
package context_test
|
|
|
|
import (
|
|
stdctx "context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"lemma/internal/context"
|
|
_ "lemma/internal/testenv"
|
|
)
|
|
|
|
func TestGetRequestContext(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
setupCtx func() *context.HandlerContext
|
|
wantStatus int
|
|
wantOK bool
|
|
}{
|
|
{
|
|
name: "valid context",
|
|
setupCtx: func() *context.HandlerContext {
|
|
return &context.HandlerContext{
|
|
UserID: 1,
|
|
UserRole: "admin",
|
|
}
|
|
},
|
|
wantStatus: http.StatusOK,
|
|
wantOK: true,
|
|
},
|
|
{
|
|
name: "missing context",
|
|
setupCtx: func() *context.HandlerContext {
|
|
return nil
|
|
},
|
|
wantStatus: http.StatusInternalServerError,
|
|
wantOK: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Create test request
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
if ctx := tt.setupCtx(); ctx != nil {
|
|
req = context.WithHandlerContext(req, ctx)
|
|
}
|
|
|
|
gotCtx, ok := context.GetRequestContext(w, req)
|
|
|
|
if ok != tt.wantOK {
|
|
t.Errorf("GetRequestContext() ok = %v, want %v", ok, tt.wantOK)
|
|
}
|
|
|
|
if !tt.wantOK {
|
|
if w.Code != tt.wantStatus {
|
|
t.Errorf("GetRequestContext() status = %v, want %v", w.Code, tt.wantStatus)
|
|
}
|
|
return
|
|
}
|
|
|
|
if gotCtx.UserID != tt.setupCtx().UserID {
|
|
t.Errorf("GetRequestContext() UserID = %v, want %v", gotCtx.UserID, tt.setupCtx().UserID)
|
|
}
|
|
|
|
if gotCtx.UserRole != tt.setupCtx().UserRole {
|
|
t.Errorf("GetRequestContext() UserRole = %v, want %v", gotCtx.UserRole, tt.setupCtx().UserRole)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetUserFromContext(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
setupCtx func() stdctx.Context
|
|
wantUser *context.UserClaims
|
|
wantError bool
|
|
}{
|
|
{
|
|
name: "valid user context",
|
|
setupCtx: func() stdctx.Context {
|
|
return stdctx.WithValue(stdctx.Background(), context.HandlerContextKey, &context.HandlerContext{
|
|
UserID: 1,
|
|
UserRole: "admin",
|
|
})
|
|
},
|
|
wantUser: &context.UserClaims{
|
|
UserID: 1,
|
|
Role: "admin",
|
|
},
|
|
wantError: false,
|
|
},
|
|
{
|
|
name: "missing context",
|
|
setupCtx: func() stdctx.Context {
|
|
return stdctx.Background()
|
|
},
|
|
wantUser: nil,
|
|
wantError: true,
|
|
},
|
|
{
|
|
name: "invalid context type",
|
|
setupCtx: func() stdctx.Context {
|
|
return stdctx.WithValue(stdctx.Background(), context.HandlerContextKey, "invalid")
|
|
},
|
|
wantUser: nil,
|
|
wantError: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ctx := tt.setupCtx()
|
|
gotUser, err := context.GetUserFromContext(ctx)
|
|
|
|
if tt.wantError {
|
|
if err == nil {
|
|
t.Error("GetUserFromContext() error = nil, want error")
|
|
}
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
t.Errorf("GetUserFromContext() unexpected error = %v", err)
|
|
return
|
|
}
|
|
|
|
if gotUser.UserID != tt.wantUser.UserID {
|
|
t.Errorf("GetUserFromContext() UserID = %v, want %v", gotUser.UserID, tt.wantUser.UserID)
|
|
}
|
|
|
|
if gotUser.Role != tt.wantUser.Role {
|
|
t.Errorf("GetUserFromContext() Role = %v, want %v", gotUser.Role, tt.wantUser.Role)
|
|
}
|
|
})
|
|
}
|
|
}
|