Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 8bfded8

Browse files
Trace functions which return Awaitable (#15650)
1 parent 4e6390c commit 8bfded8

File tree

3 files changed

+59
-22
lines changed

3 files changed

+59
-22
lines changed

changelog.d/15650.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add support for tracing functions which return `Awaitable`s.

synapse/logging/opentracing.py

+26-11
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def set_fates(clotho, lachesis, atropos, father="Zues", mother="Themis"):
171171
from typing import (
172172
TYPE_CHECKING,
173173
Any,
174+
Awaitable,
174175
Callable,
175176
Collection,
176177
ContextManager,
@@ -903,6 +904,7 @@ def _wrapping_logic(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) ->
903904
"""
904905

905906
if inspect.iscoroutinefunction(func):
907+
# For this branch, we handle async functions like `async def func() -> RInner`.
906908
# In this branch, R = Awaitable[RInner], for some other type RInner
907909
@wraps(func)
908910
async def _wrapper(
@@ -914,36 +916,49 @@ async def _wrapper(
914916
return await func(*args, **kwargs) # type: ignore[misc]
915917

916918
else:
917-
# The other case here handles both sync functions and those
918-
# decorated with inlineDeferred.
919+
# The other case here handles sync functions including those decorated with
920+
# `@defer.inlineCallbacks` or that return a `Deferred` or other `Awaitable`.
919921
@wraps(func)
920-
def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
922+
def _wrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
921923
scope = wrapping_logic(func, *args, **kwargs)
922924
scope.__enter__()
923925

924926
try:
925927
result = func(*args, **kwargs)
928+
926929
if isinstance(result, defer.Deferred):
927930

928931
def call_back(result: R) -> R:
929932
scope.__exit__(None, None, None)
930933
return result
931934

932935
def err_back(result: R) -> R:
936+
# TODO: Pass the error details into `scope.__exit__(...)` for
937+
# consistency with the other paths.
933938
scope.__exit__(None, None, None)
934939
return result
935940

936941
result.addCallbacks(call_back, err_back)
937942

943+
elif inspect.isawaitable(result):
944+
945+
async def wrap_awaitable() -> Any:
946+
try:
947+
assert isinstance(result, Awaitable)
948+
awaited_result = await result
949+
scope.__exit__(None, None, None)
950+
return awaited_result
951+
except Exception as e:
952+
scope.__exit__(type(e), None, e.__traceback__)
953+
raise
954+
955+
# The original method returned an awaitable, eg. a coroutine, so we
956+
# create another awaitable wrapping it that calls
957+
# `scope.__exit__(...)`.
958+
return wrap_awaitable()
938959
else:
939-
if inspect.isawaitable(result):
940-
logger.error(
941-
"@trace may not have wrapped %s correctly! "
942-
"The function is not async but returned a %s.",
943-
func.__qualname__,
944-
type(result).__name__,
945-
)
946-
960+
# Just a simple sync function so we can just exit the scope and
961+
# return the result without any fuss.
947962
scope.__exit__(None, None, None)
948963

949964
return result

tests/logging/test_opentracing.py

+32-11
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import cast
15+
from typing import Awaitable, cast
1616

1717
from twisted.internet import defer
1818
from twisted.test.proto_helpers import MemoryReactorClock
@@ -227,8 +227,6 @@ def test_trace_decorator_deferred(self) -> None:
227227
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
228228
with functions that return deferreds
229229
"""
230-
reactor = MemoryReactorClock()
231-
232230
with LoggingContext("root context"):
233231

234232
@trace_with_opname("fixture_deferred_func", tracer=self._tracer)
@@ -240,9 +238,6 @@ def fixture_deferred_func() -> "defer.Deferred[str]":
240238

241239
result_d1 = fixture_deferred_func()
242240

243-
# let the tasks complete
244-
reactor.pump((2,) * 8)
245-
246241
self.assertEqual(self.successResultOf(result_d1), "foo")
247242

248243
# the span should have been reported
@@ -256,8 +251,6 @@ def test_trace_decorator_async(self) -> None:
256251
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
257252
with async functions
258253
"""
259-
reactor = MemoryReactorClock()
260-
261254
with LoggingContext("root context"):
262255

263256
@trace_with_opname("fixture_async_func", tracer=self._tracer)
@@ -267,13 +260,41 @@ async def fixture_async_func() -> str:
267260

268261
d1 = defer.ensureDeferred(fixture_async_func())
269262

270-
# let the tasks complete
271-
reactor.pump((2,) * 8)
272-
273263
self.assertEqual(self.successResultOf(d1), "foo")
274264

275265
# the span should have been reported
276266
self.assertEqual(
277267
[span.operation_name for span in self._reporter.get_spans()],
278268
["fixture_async_func"],
279269
)
270+
271+
def test_trace_decorator_awaitable_return(self) -> None:
272+
"""
273+
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
274+
with functions that return an awaitable (e.g. a coroutine)
275+
"""
276+
with LoggingContext("root context"):
277+
# Something we can return without `await` to get a coroutine
278+
async def fixture_async_func() -> str:
279+
return "foo"
280+
281+
# The actual kind of function we want to test that returns an awaitable
282+
@trace_with_opname("fixture_awaitable_return_func", tracer=self._tracer)
283+
@tag_args
284+
def fixture_awaitable_return_func() -> Awaitable[str]:
285+
return fixture_async_func()
286+
287+
# Something we can run with `defer.ensureDeferred(runner())` and pump the
288+
# whole async tasks through to completion.
289+
async def runner() -> str:
290+
return await fixture_awaitable_return_func()
291+
292+
d1 = defer.ensureDeferred(runner())
293+
294+
self.assertEqual(self.successResultOf(d1), "foo")
295+
296+
# the span should have been reported
297+
self.assertEqual(
298+
[span.operation_name for span in self._reporter.get_spans()],
299+
["fixture_awaitable_return_func"],
300+
)

0 commit comments

Comments
 (0)