From 24a930908ce210549ee5dbd4284a43c769603b0b Mon Sep 17 00:00:00 2001 From: Simon Knott Date: Thu, 30 Apr 2026 15:50:42 +0200 Subject: [PATCH] fix: propagate contextvars to event handlers Capture the caller's contextvars context when an event handler is registered, and re-establish it when the handler runs. Without this, contextvars set in user code (e.g. request IDs in logging frameworks) were not visible inside event handlers because events are dispatched from a different greenlet (sync mode) or asyncio task (async mode) than the one that registered the handler. Implementation per mode: * Async-mode coroutine handler: spawn an inner Task inside the captured context so the Task adopts it (Tasks copy the active context at construction). * Async-mode sync handler: run via Context.run. * Sync mode: temporarily set the EventGreenlet's gr_context to the captured context for the duration of the handler. Context.run is not used here because handlers like route.fulfill internally greenlet.switch, and Context.run does not compose with greenlet switches. Fixes #1816 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- playwright/_impl/_impl_to_api_mapping.py | 46 ++++++++++++++++++++++-- tests/async/test_page.py | 38 ++++++++++++++++++++ tests/sync/test_page_event_request.py | 22 ++++++++++++ 3 files changed, 103 insertions(+), 3 deletions(-) diff --git a/playwright/_impl/_impl_to_api_mapping.py b/playwright/_impl/_impl_to_api_mapping.py index e26d22025..e7b0dc3ca 100644 --- a/playwright/_impl/_impl_to_api_mapping.py +++ b/playwright/_impl/_impl_to_api_mapping.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio +import contextvars import inspect from typing import Any, Callable, Dict, List, Optional, Sequence, Union +import greenlet + from playwright._impl._errors import Error from playwright._impl._map import Map @@ -118,11 +122,47 @@ def to_impl( raise Error("Maximum argument depth exceeded") def wrap_handler(self, handler: Callable[..., Any]) -> Callable[..., None]: + # Capture the caller's context at registration time so contextvars + # set in user code are available when the event handler runs, even + # though events are dispatched from a different greenlet/task. + # See: https://github.com/microsoft/playwright-python/issues/1816 + context = contextvars.copy_context() + is_coroutine = inspect.iscoroutinefunction(handler) + def wrapper_func(*args: Any) -> Any: arg_count = len(inspect.signature(handler).parameters) - return handler( - *list(map(lambda a: self.from_maybe_impl(a), args))[:arg_count] - ) + mapped_args = list(map(lambda a: self.from_maybe_impl(a), args))[:arg_count] + if is_coroutine: + # Async-mode coroutine handler: propagate context to the + # handler's awaits by spawning an inner Task inside our + # captured context (Tasks copy the active context at + # construction). + async def _coro_wrapper() -> Any: + loop = asyncio.get_running_loop() + inner = context.run(lambda: loop.create_task(handler(*mapped_args))) + return await inner + + return _coro_wrapper() + # Sync handler. Two cases: + # * Async mode: no greenlet is involved in event dispatch + # (asyncio Task), so we use Context.run to run the handler + # in the captured context. + # * Sync mode: events are dispatched inside a fresh + # EventGreenlet whose default gr_context is empty. We can't + # use Context.run here because handlers like route.fulfill + # internally use greenlet.switch, and Context.run does not + # compose with greenlet switches. Instead we set the + # greenlet's gr_context to our captured context for the + # duration of the handler, then restore it. + current = greenlet.getcurrent() + if current.parent is None: + return context.run(handler, *mapped_args) + saved_context = current.gr_context + current.gr_context = context + try: + return handler(*mapped_args) + finally: + current.gr_context = saved_context if inspect.ismethod(handler): wrapper = getattr(handler.__self__, IMPL_ATTR + handler.__name__, None) diff --git a/tests/async/test_page.py b/tests/async/test_page.py index 2fcf4c92d..3589b60c0 100644 --- a/tests/async/test_page.py +++ b/tests/async/test_page.py @@ -1605,3 +1605,41 @@ async def test_page_should_ignore_deprecated_is_hidden_and_visible_timeout( await page.set_content("
foo
") assert await page.is_hidden("div", timeout=10) is False assert await page.is_visible("div", timeout=10) is True + + +async def test_should_propagate_contextvars_to_event_handlers( + page: Page, server: Server +) -> None: + import contextvars + + shared_var: "contextvars.ContextVar[str]" = contextvars.ContextVar("shared_var") + shared_var.set("expected value") + + sync_seen: List[Optional[str]] = [] + async_seen: List[Optional[str]] = [] + + def on_request_sync(request: Any) -> None: + try: + sync_seen.append(shared_var.get()) + except LookupError: + sync_seen.append(None) + + async def on_request_async(request: Any) -> None: + try: + async_seen.append(shared_var.get()) + except LookupError: + async_seen.append(None) + await asyncio.sleep(0) + try: + async_seen.append(shared_var.get()) + except LookupError: + async_seen.append(None) + + page.on("request", on_request_sync) + page.on("request", on_request_async) + await page.goto(server.EMPTY_PAGE) + await asyncio.sleep(0.1) + assert sync_seen + assert all(v == "expected value" for v in sync_seen) + assert async_seen + assert all(v == "expected value" for v in async_seen) diff --git a/tests/sync/test_page_event_request.py b/tests/sync/test_page_event_request.py index 77515091f..803ccb7bf 100644 --- a/tests/sync/test_page_event_request.py +++ b/tests/sync/test_page_event_request.py @@ -52,3 +52,25 @@ def gather_response(request: Request) -> dict: url = server.PREFIX + f"/fetch?{i}" expected.append({"url": url, "text": f"url:{url}"}) assert received == expected + + +def test_should_propagate_contextvars_to_event_handlers( + page: Page, server: Server +) -> None: + import contextvars + + shared_var: "contextvars.ContextVar[str]" = contextvars.ContextVar("shared_var") + shared_var.set("expected value") + + seen: list = [] + + def on_request(request: Request) -> None: + try: + seen.append(shared_var.get()) + except LookupError: + seen.append(None) + + page.on("request", on_request) + page.goto(server.EMPTY_PAGE) + assert seen + assert all(v == "expected value" for v in seen)