From 05c3a641da730d0d7c89fc5d8d7f0ad49215bef4 Mon Sep 17 00:00:00 2001 From: "Gabriele N. Tornetta" Date: Thu, 20 Mar 2025 10:58:18 +0000 Subject: [PATCH 1/5] Implement PEP 302 optional get_code loader method We implement the optional get_code loader method. This increases compatibility with other tools/libraries that need to manipulate the code object of a module before it is executed. --- src/_pytest/assertion/rewrite.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 2e606d1903a..9d621977e21 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -80,6 +80,7 @@ def __init__(self, config: Config) -> None: self._basenames_to_check_rewrite = {"conftest"} self._marked_for_rewrite_cache: dict[str, bool] = {} self._session_paths_checked = False + self.fn: str | None = None def set_session(self, session: Session | None) -> None: self.session = session @@ -126,7 +127,7 @@ def find_spec( ): return None else: - fn = spec.origin + self.fn = fn = spec.origin if not self._should_rewrite(name, fn, state): return None @@ -143,14 +144,11 @@ def create_module( ) -> types.ModuleType | None: return None # default behaviour is fine - def exec_module(self, module: types.ModuleType) -> None: - assert module.__spec__ is not None - assert module.__spec__.origin is not None - fn = Path(module.__spec__.origin) + def get_code(self, fullname: str) -> types.CodeType + assert self.fn is not None + fn = Path(self.fn) state = self.config.stash[assertstate_key] - self._rewritten_names[module.__name__] = fn - # The requested module looks like a test file, so rewrite it. This is # the most magical part of the process: load the source, rewrite the # asserts, and load the rewritten source. We also cache the rewritten @@ -183,7 +181,15 @@ def exec_module(self, module: types.ModuleType) -> None: self._writing_pyc = False else: state.trace(f"found cached rewritten pyc for {fn}") - exec(co, module.__dict__) + + return co + + def exec_module(self, module: types.ModuleType) -> None: + module_name = module.__name__ + + self._rewritten_names[module_name] = fn + + exec(self.get_code(module_name), module.__dict__) def _early_rewrite_bailout(self, name: str, state: AssertionState) -> bool: """A fast way to get out of rewriting modules. From 6ecc59da371947e65f0cc8f074e3c7345281a85e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 20 Mar 2025 11:00:33 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/_pytest/assertion/rewrite.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 9d621977e21..50ade97eae6 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -183,12 +183,12 @@ def get_code(self, fullname: str) -> types.CodeType state.trace(f"found cached rewritten pyc for {fn}") return co - + def exec_module(self, module: types.ModuleType) -> None: module_name = module.__name__ - + self._rewritten_names[module_name] = fn - + exec(self.get_code(module_name), module.__dict__) def _early_rewrite_bailout(self, name: str, state: AssertionState) -> bool: From 316b8aa4bf9ff914aaebe6fb1c54d9391aa73057 Mon Sep 17 00:00:00 2001 From: "Gabriele N. Tornetta" Date: Thu, 20 Mar 2025 11:46:12 +0000 Subject: [PATCH 3/5] Update src/_pytest/assertion/rewrite.py --- src/_pytest/assertion/rewrite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 50ade97eae6..a465fdba98c 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -144,7 +144,7 @@ def create_module( ) -> types.ModuleType | None: return None # default behaviour is fine - def get_code(self, fullname: str) -> types.CodeType + def get_code(self, fullname: str) -> types.CodeType: assert self.fn is not None fn = Path(self.fn) state = self.config.stash[assertstate_key] From 8cc76128842d52d1819cd7805eb1c73bbd03b1e6 Mon Sep 17 00:00:00 2001 From: "Gabriele N. Tornetta" Date: Thu, 20 Mar 2025 12:53:26 +0000 Subject: [PATCH 4/5] Update src/_pytest/assertion/rewrite.py --- src/_pytest/assertion/rewrite.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index a465fdba98c..55f0522b14e 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -185,9 +185,11 @@ def get_code(self, fullname: str) -> types.CodeType: return co def exec_module(self, module: types.ModuleType) -> None: + assert self.fn is not None + module_name = module.__name__ - self._rewritten_names[module_name] = fn + self._rewritten_names[module_name] = Path(self.fn) exec(self.get_code(module_name), module.__dict__) From 4f156127a57795d317ddaec99036b712b8cb680d Mon Sep 17 00:00:00 2001 From: "Gabriele N. Tornetta" Date: Fri, 21 Mar 2025 11:00:56 +0000 Subject: [PATCH 5/5] map module fullname to origin --- src/_pytest/assertion/rewrite.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 55f0522b14e..3b7b6431ffd 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -80,7 +80,7 @@ def __init__(self, config: Config) -> None: self._basenames_to_check_rewrite = {"conftest"} self._marked_for_rewrite_cache: dict[str, bool] = {} self._session_paths_checked = False - self.fn: str | None = None + self._fns: dict[str, str] = {} def set_session(self, session: Session | None) -> None: self.session = session @@ -127,7 +127,7 @@ def find_spec( ): return None else: - self.fn = fn = spec.origin + self._fns[name] = fn = spec.origin if not self._should_rewrite(name, fn, state): return None @@ -145,8 +145,8 @@ def create_module( return None # default behaviour is fine def get_code(self, fullname: str) -> types.CodeType: - assert self.fn is not None - fn = Path(self.fn) + assert fullname in self._fns + fn = Path(self._fns[fullname]) state = self.config.stash[assertstate_key] # The requested module looks like a test file, so rewrite it. This is @@ -185,11 +185,15 @@ def get_code(self, fullname: str) -> types.CodeType: return co def exec_module(self, module: types.ModuleType) -> None: - assert self.fn is not None - module_name = module.__name__ - self._rewritten_names[module_name] = Path(self.fn) + assert ( + module_name in self._fns + and module.__spec__ is not None + and module.__spec__.origin == self._fns[module_name] + ) + + self._rewritten_names[module_name] = Path(self._fns[module_name]) exec(self.get_code(module_name), module.__dict__)