diff --git a/aws_xray_sdk/core/async_context.py b/aws_xray_sdk/core/async_context.py index b287a42f..acba00e2 100644 --- a/aws_xray_sdk/core/async_context.py +++ b/aws_xray_sdk/core/async_context.py @@ -1,5 +1,6 @@ import asyncio import sys +import copy from .context import Context as _Context @@ -108,6 +109,18 @@ def task_factory(loop, coro): else: current_task = asyncio.Task.current_task(loop=loop) if current_task is not None and hasattr(current_task, 'context'): - setattr(task, 'context', current_task.context) + if current_task.context.get('entities'): + # NOTE: (enowell) Because the `AWSXRayRecorder`'s `Context` decides + # the parent by looking at its `_local.entities`, we must copy the entities + # for concurrent subsegments. Otherwise, the subsegments would be + # modifying the same `entities` list and sugsegments would take other + # subsegments as parents instead of the original `segment`. + # + # See more: https://github.com/aws/aws-xray-sdk-python/blob/0f13101e4dba7b5c735371cb922f727b1d9f46d8/aws_xray_sdk/core/context.py#L90-L101 + new_context = copy.copy(current_task.context) + new_context['entities'] = [item for item in current_task.context['entities']] + else: + new_context = current_task.context + setattr(task, 'context', new_context) return task diff --git a/tests/test_async_recorder.py b/tests/test_async_recorder.py index eba147f7..0367fb3c 100644 --- a/tests/test_async_recorder.py +++ b/tests/test_async_recorder.py @@ -3,6 +3,7 @@ from .util import get_new_stubbed_recorder from aws_xray_sdk.version import VERSION from aws_xray_sdk.core.async_context import AsyncContext +import asyncio xray_recorder = get_new_stubbed_recorder() @@ -43,6 +44,28 @@ async def test_capture(loop): assert platform.python_implementation() == service.get('runtime') assert platform.python_version() == service.get('runtime_version') +async def test_concurrent_calls(loop): + xray_recorder.configure(service='test', sampling=False, context=AsyncContext(loop=loop)) + async with xray_recorder.in_segment_async('segment') as segment: + global counter + counter = 0 + total_tasks = 10 + flag = asyncio.Event() + async def assert_task(): + async with xray_recorder.in_subsegment_async('segment') as subsegment: + global counter + counter += 1 + # Begin all subsegments before closing any to ensure they overlap + if counter < total_tasks: + await flag.wait() + else: + flag.set() + return subsegment.parent_id + tasks = [assert_task() for task in range(total_tasks)] + subsegs_parent_ids = await asyncio.gather(*tasks) + for subseg_parent_id in subsegs_parent_ids: + assert subseg_parent_id == segment.id + async def test_async_context_managers(loop): xray_recorder.configure(service='test', sampling=False, context=AsyncContext(loop=loop))