Skip to content

Commit b470a75

Browse files
authored
Use functools.wraps on when wrapping callbacks (#137)
This preserves the wrapped callback name when stacking multiple middleware layers Add subscribe_broadcast and subscribe_temporary to TimerMiddleware
1 parent b91d8de commit b470a75

File tree

3 files changed

+67
-5
lines changed

3 files changed

+67
-5
lines changed

src/workflows/transport/middleware/__init__.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,45 @@ def __init__(self, logger: logging.Logger = None, level=logging.INFO):
178178
self.level = level
179179

180180
def subscribe(self, call_next: Callable, channel, callback, **kwargs) -> int:
181+
@functools.wraps(callback)
182+
def wrapped_callback(header, message):
183+
start_time = time.perf_counter()
184+
result = callback(header, message)
185+
end_time = time.perf_counter()
186+
source = get_callback_source(callback)
187+
self.logger.log(
188+
self.level,
189+
f"Callback for {source} took {end_time - start_time:.4f} seconds",
190+
)
191+
return result
192+
193+
return call_next(channel, wrapped_callback, **kwargs)
194+
195+
def subscribe_temporary(
196+
self,
197+
call_next: Callable,
198+
channel_hint: Optional[str],
199+
callback: MessageCallback,
200+
**kwargs,
201+
) -> TemporarySubscription:
202+
@functools.wraps(callback)
203+
def wrapped_callback(header, message):
204+
start_time = time.perf_counter()
205+
result = callback(header, message)
206+
end_time = time.perf_counter()
207+
source = get_callback_source(callback)
208+
self.logger.log(
209+
self.level,
210+
f"Callback for {source} took {end_time - start_time:.4f} seconds",
211+
)
212+
return result
213+
214+
return call_next(channel_hint, wrapped_callback, **kwargs)
215+
216+
def subscribe_broadcast(
217+
self, call_next: Callable, channel, callback, **kwargs
218+
) -> int:
219+
@functools.wraps(callback)
181220
def wrapped_callback(header, message):
182221
start_time = time.perf_counter()
183222
result = callback(header, message)

src/workflows/transport/middleware/prometheus.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import functools
34
import time
45
from typing import Callable, Optional
56

@@ -82,6 +83,7 @@ def __init__(self, source: str):
8283
self.source = source
8384

8485
def subscribe(self, call_next: Callable, channel, callback, **kwargs) -> int:
86+
@functools.wraps(callback)
8587
def wrapped_callback(header, message):
8688
start_time = time.perf_counter()
8789
result = callback(header, message)
@@ -102,6 +104,7 @@ def subscribe_temporary(
102104
callback: MessageCallback,
103105
**kwargs,
104106
) -> TemporarySubscription:
107+
@functools.wraps(callback)
105108
def wrapped_callback(header, message):
106109
start_time = time.perf_counter()
107110
result = callback(header, message)
@@ -118,6 +121,7 @@ def wrapped_callback(header, message):
118121
def subscribe_broadcast(
119122
self, call_next: Callable, channel, callback, **kwargs
120123
) -> int:
124+
@functools.wraps(callback)
121125
def wrapped_callback(header, message):
122126
start_time = time.perf_counter()
123127
result = callback(header, message)

tests/transport/test_middleware.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,34 @@ def test_timer_middleware(caplog):
5454
def callback(header, message):
5555
time.sleep(1)
5656

57-
subscription_id = offline.subscribe(str(mock.sentinel.channel), callback)
57+
sid_1 = offline.subscribe(str(mock.sentinel.channel), callback)
58+
sid_2 = offline.subscribe_broadcast(str(mock.sentinel.channel), callback)
59+
ts = offline.subscribe_temporary(str(mock.sentinel.channel), callback)
60+
61+
expected_text = (
62+
"Callback for test_middleware:test_timer_middleware.<locals>.callback took"
63+
)
64+
5865
with caplog.at_level(logging.DEBUG):
59-
offline.subscription_callback(subscription_id)(
66+
offline.subscription_callback(sid_1)(
6067
{"destination": "foo"}, str(mock.sentinel.message)
6168
)
62-
assert (
63-
"Callback for test_middleware:test_timer_middleware.<locals>.callback took"
64-
in caplog.text
69+
assert expected_text in caplog.text
70+
caplog.clear()
71+
72+
with caplog.at_level(logging.DEBUG):
73+
offline.subscription_callback(sid_2)(
74+
{"destination": "bar"}, str(mock.sentinel.message)
75+
)
76+
assert expected_text in caplog.text
77+
caplog.clear()
78+
79+
with caplog.at_level(logging.DEBUG):
80+
offline.subscription_callback(ts.subscription_id)(
81+
{"destination": "foobar"}, str(mock.sentinel.message)
6582
)
83+
assert expected_text in caplog.text
84+
caplog.clear()
6685

6786

6887
def test_prometheus_middleware():

0 commit comments

Comments
 (0)