Skip to content

Commit 7578110

Browse files
authored
Use spawn in _compat_test.py to avoid fork problems (#6374)
Review: @dstrain115
1 parent 30b6c39 commit 7578110

File tree

3 files changed

+50
-43
lines changed

3 files changed

+50
-43
lines changed

cirq-core/cirq/_compat_test.py

+45-38
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def test_wrap_module():
346346

347347

348348
def test_deprecate_attributes_assert_attributes_in_sys_modules():
349-
subprocess_context(_test_deprecate_attributes_assert_attributes_in_sys_modules)()
349+
run_in_subprocess(_test_deprecate_attributes_assert_attributes_in_sys_modules)
350350

351351

352352
def _test_deprecate_attributes_assert_attributes_in_sys_modules():
@@ -635,42 +635,49 @@ def _type_repr_in_deprecated_module():
635635
] + _deprecation_origin
636636

637637

638-
def _trace_unhandled_exceptions(*args, queue: 'multiprocessing.Queue', func: Callable, **kwargs):
638+
def _trace_unhandled_exceptions(*args, queue: 'multiprocessing.Queue', func: Callable):
639639
try:
640-
func(*args, **kwargs)
640+
func(*args)
641641
queue.put(None)
642642
except BaseException as ex:
643643
msg = str(ex)
644644
queue.put((type(ex).__name__, msg, traceback.format_exc()))
645645

646646

647-
def subprocess_context(test_func):
648-
"""Ensures that sys.modules changes in subprocesses won't impact the parent process."""
647+
def run_in_subprocess(test_func, *args):
648+
"""Run a function in a subprocess.
649+
650+
This ensures that sys.modules changes in subprocesses won't impact the parent process.
651+
652+
Args:
653+
test_func: The function to be run in a subprocess.
654+
*args: Positional args to pass to the function.
655+
"""
656+
649657
assert callable(test_func), (
650-
"subprocess_context expects a function. Did you call the function instead of passing "
658+
"run_in_subprocess expects a function. Did you call the function instead of passing "
651659
"it to this method?"
652660
)
653661

654-
ctx = multiprocessing.get_context('spawn' if os.name == 'nt' else 'fork')
655-
656-
exception = ctx.Queue()
662+
# Use spawn to ensure subprocesses are isolated.
663+
# See https://github.com/quantumlib/Cirq/issues/6373
664+
ctx = multiprocessing.get_context('spawn')
657665

658-
def isolated_func(*args, **kwargs):
659-
kwargs['queue'] = exception
660-
kwargs['func'] = test_func
661-
p = ctx.Process(target=_trace_unhandled_exceptions, args=args, kwargs=kwargs)
662-
p.start()
663-
p.join()
664-
result = exception.get()
665-
if result: # pragma: no cover
666-
ex_type, msg, ex_trace = result
667-
if ex_type == "Skipped":
668-
warnings.warn(f"Skipping: {ex_type}: {msg}\n{ex_trace}")
669-
pytest.skip(f'{ex_type}: {msg}\n{ex_trace}')
670-
else:
671-
pytest.fail(f'{ex_type}: {msg}\n{ex_trace}')
666+
queue = ctx.Queue()
672667

673-
return isolated_func
668+
p = ctx.Process(
669+
target=_trace_unhandled_exceptions, args=args, kwargs={'queue': queue, 'func': test_func}
670+
)
671+
p.start()
672+
p.join()
673+
result = queue.get()
674+
if result: # pragma: no cover
675+
ex_type, msg, ex_trace = result
676+
if ex_type == "Skipped":
677+
warnings.warn(f"Skipping: {ex_type}: {msg}\n{ex_trace}")
678+
pytest.skip(f'{ex_type}: {msg}\n{ex_trace}')
679+
else:
680+
pytest.fail(f'{ex_type}: {msg}\n{ex_trace}')
674681

675682

676683
@mock.patch.dict(os.environ, {"CIRQ_FORCE_DEDUPE_MODULE_DEPRECATION": "1"})
@@ -698,7 +705,7 @@ def isolated_func(*args, **kwargs):
698705
],
699706
)
700707
def test_deprecated_module(outdated_method, deprecation_messages):
701-
subprocess_context(_test_deprecated_module_inner)(outdated_method, deprecation_messages)
708+
run_in_subprocess(_test_deprecated_module_inner, outdated_method, deprecation_messages)
702709

703710

704711
def _test_deprecated_module_inner(outdated_method, deprecation_messages):
@@ -736,7 +743,7 @@ def test_same_name_submodule_earlier_in_subtree():
736743
cirq.ops.engine.calibration packages. The wrong resolution resulted in false circular
737744
imports!
738745
"""
739-
subprocess_context(_test_same_name_submodule_earlier_in_subtree_inner)()
746+
run_in_subprocess(_test_same_name_submodule_earlier_in_subtree_inner)
740747

741748

742749
def _test_same_name_submodule_earlier_in_subtree_inner():
@@ -748,7 +755,7 @@ def _test_same_name_submodule_earlier_in_subtree_inner():
748755
def test_metadata_search_path():
749756
# to cater for metadata path finders
750757
# https://docs.python.org/3/library/importlib.metadata.html#extending-the-search-algorithm
751-
subprocess_context(_test_metadata_search_path_inner)()
758+
run_in_subprocess(_test_metadata_search_path_inner)
752759

753760

754761
def _test_metadata_search_path_inner(): # pragma: no cover
@@ -760,7 +767,7 @@ def _test_metadata_search_path_inner(): # pragma: no cover
760767

761768

762769
def test_metadata_distributions_after_deprecated_submodule():
763-
subprocess_context(_test_metadata_distributions_after_deprecated_submodule)()
770+
run_in_subprocess(_test_metadata_distributions_after_deprecated_submodule)
764771

765772

766773
def _test_metadata_distributions_after_deprecated_submodule():
@@ -779,7 +786,7 @@ def _test_metadata_distributions_after_deprecated_submodule():
779786

780787

781788
def test_parent_spec_after_deprecated_submodule():
782-
subprocess_context(_test_parent_spec_after_deprecated_submodule)()
789+
run_in_subprocess(_test_parent_spec_after_deprecated_submodule)
783790

784791

785792
def _test_parent_spec_after_deprecated_submodule():
@@ -791,7 +798,7 @@ def _test_parent_spec_after_deprecated_submodule():
791798
def test_type_repr_in_new_module():
792799
# to cater for metadata path finders
793800
# https://docs.python.org/3/library/importlib.metadata.html#extending-the-search-algorithm
794-
subprocess_context(_test_type_repr_in_new_module_inner)()
801+
run_in_subprocess(_test_type_repr_in_new_module_inner)
795802

796803

797804
def _test_type_repr_in_new_module_inner():
@@ -849,19 +856,19 @@ def _test_broken_module_3_inner():
849856

850857

851858
def test_deprecated_module_error_handling_1():
852-
subprocess_context(_test_broken_module_1_inner)()
859+
run_in_subprocess(_test_broken_module_1_inner)
853860

854861

855862
def test_deprecated_module_error_handling_2():
856-
subprocess_context(_test_broken_module_2_inner)()
863+
run_in_subprocess(_test_broken_module_2_inner)
857864

858865

859866
def test_deprecated_module_error_handling_3():
860-
subprocess_context(_test_broken_module_3_inner)()
867+
run_in_subprocess(_test_broken_module_3_inner)
861868

862869

863870
def test_new_module_is_top_level():
864-
subprocess_context(_test_new_module_is_top_level_inner)()
871+
run_in_subprocess(_test_new_module_is_top_level_inner)
865872

866873

867874
def _test_new_module_is_top_level_inner():
@@ -877,7 +884,7 @@ def _test_new_module_is_top_level_inner():
877884

878885

879886
def test_import_deprecated_with_no_attribute():
880-
subprocess_context(_test_import_deprecated_with_no_attribute_inner)()
887+
run_in_subprocess(_test_import_deprecated_with_no_attribute_inner)
881888

882889

883890
def _test_import_deprecated_with_no_attribute_inner():
@@ -970,23 +977,23 @@ def module_repr(self, module: ModuleType) -> str:
970977

971978
def test_subprocess_test_failure():
972979
with pytest.raises(Failed, match='ValueError.*this fails'):
973-
subprocess_context(_test_subprocess_test_failure_inner)()
980+
run_in_subprocess(_test_subprocess_test_failure_inner)
974981

975982

976983
def _test_subprocess_test_failure_inner():
977984
raise ValueError('this fails')
978985

979986

980987
def test_dir_is_still_valid():
981-
subprocess_context(_dir_is_still_valid_inner)()
988+
run_in_subprocess(_dir_is_still_valid_inner)
982989

983990

984991
def _dir_is_still_valid_inner():
985992
"""to ensure that create_attribute=True keeps the dir(module) intact"""
986993

987994
import cirq.testing._compat_test_data as mod
988995

989-
for m in ['fake_a', 'info', 'module_a', 'sys']:
996+
for m in ['fake_a', 'logging', 'module_a']:
990997
assert m in dir(mod)
991998

992999

cirq-core/cirq/testing/_compat_test_data/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
See cirq/_compat_test.py for the tests.
44
This module contains example deprecations for modules.
55
"""
6-
import sys
7-
from logging import info
6+
import logging
7+
88
from cirq import _compat
99

10-
info("init:compat_test_data")
10+
logging.info("init:compat_test_data")
1111

1212
# simulates a rename of a child module
1313
# fake_a -> module_a
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# pylint: disable=wrong-or-nonexistent-copyright-notice
22
"""module_a for module deprecation tests"""
33

4-
from logging import info
4+
import logging
55

66
from cirq.testing._compat_test_data.module_a import module_b
77

@@ -11,4 +11,4 @@
1111

1212
MODULE_A_ATTRIBUTE = "module_a"
1313

14-
info("init:module_a")
14+
logging.info("init:module_a")

0 commit comments

Comments
 (0)