Skip to content

Commit fa71ece

Browse files
committed
Add more tests, including asyncio_ensure_future wrapper
1 parent 1c76b88 commit fa71ece

File tree

4 files changed

+41
-13
lines changed

4 files changed

+41
-13
lines changed

sdk/core/azure-core/azure/core/polling/_async_poller.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,11 @@ def __init__(
128128

129129
self._polling_method.initialize(client, initial_response, deserialization_callback)
130130

131+
def polling_method(self) -> AsyncPollingMethod[PollingReturnType]:
132+
"""Return the polling method associated to this poller.
133+
"""
134+
return self._polling_method
135+
131136
def continuation_token(self) -> str:
132137
"""Return a continuation token that allows to restart the poller later.
133138

sdk/core/azure-core/azure/core/polling/_poller.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,12 @@ def _start(self):
196196
call(self._polling_method)
197197
callbacks, self._callbacks = self._callbacks, []
198198

199+
def polling_method(self):
200+
# type: () -> PollingMethod[PollingReturnType]
201+
"""Return the polling method associated to this poller.
202+
"""
203+
return self._polling_method
204+
199205
def continuation_token(self):
200206
# type: () -> str
201207
"""Return a continuation token that allows to restart the poller later.

sdk/core/azure-core/tests/azure_core_asynctests/test_polling.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
# THE SOFTWARE.
2424
#
2525
#--------------------------------------------------------------------------
26+
import asyncio
2627
import time
2728
try:
2829
from unittest import mock
@@ -89,6 +90,7 @@ async def run(self):
8990
"""Empty run, no polling.
9091
"""
9192
self._finished = True
93+
await asyncio.sleep(self._sleep) # Give me time to add callbacks!
9294

9395
def status(self):
9496
"""Return the current status as a string.
@@ -129,39 +131,53 @@ def deserialization_callback(response):
129131

130132
method = AsyncNoPolling()
131133

132-
poller = AsyncLROPoller(client, initial_response, deserialization_callback, method)
134+
raw_poller = AsyncLROPoller(client, initial_response, deserialization_callback, method)
135+
poller = asyncio.ensure_future(raw_poller.result())
136+
137+
done_cb = mock.MagicMock()
138+
poller.add_done_callback(done_cb)
133139

134-
result = await poller.result()
140+
result = await poller
135141
assert poller.done()
136142
assert result == "Treated: "+initial_response
137-
assert poller.status() == "succeeded"
143+
assert raw_poller.status() == "succeeded"
144+
assert raw_poller.polling_method() is method
145+
done_cb.assert_called_once_with(poller)
138146

139147
# Test with a basic Model
140148
poller = AsyncLROPoller(client, initial_response, Model, method)
141149
assert poller._polling_method._deserialization_callback == Model.deserialize
142150

143151
# Test poller that method do a run
144152
method = PollingTwoSteps(sleep=1)
145-
poller = AsyncLROPoller(client, initial_response, deserialization_callback, method)
153+
raw_poller = AsyncLROPoller(client, initial_response, deserialization_callback, method)
154+
poller = asyncio.ensure_future(raw_poller.result())
155+
156+
done_cb = mock.MagicMock()
157+
done_cb2 = mock.MagicMock()
158+
poller.add_done_callback(done_cb)
159+
poller.remove_done_callback(done_cb2)
146160

147-
result = await poller.result()
161+
result = await poller
148162
assert result == "Treated: "+initial_response
149-
assert poller.status() == "succeeded"
163+
assert raw_poller.status() == "succeeded"
164+
done_cb.assert_called_once_with(poller)
165+
done_cb2.assert_not_called()
150166

151167
# Test continuation token
152-
cont_token = poller.continuation_token()
168+
cont_token = raw_poller.continuation_token()
153169

154170
method = PollingTwoSteps(sleep=1)
155171
new_poller = AsyncLROPoller.from_continuation_token(
156172
continuation_token=cont_token,
157173
client=client,
158174
initial_response=initial_response,
159-
deserialization_callback=Model,
175+
deserialization_callback=deserialization_callback,
160176
polling_method=method
161177
)
162-
result = await poller.result()
178+
result = await new_poller.result()
163179
assert result == "Treated: "+initial_response
164-
assert poller.status() == "succeeded"
180+
assert new_poller.status() == "succeeded"
165181

166182

167183
@pytest.mark.asyncio

sdk/core/azure-core/tests/test_polling.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ def deserialization_callback(response):
163163
assert poller.done()
164164
assert result == "Treated: "+initial_response
165165
assert poller.status() == "succeeded"
166+
assert poller.polling_method() is method
166167
done_cb.assert_called_once_with(method)
167168

168169
# Test with a basic Model
@@ -196,12 +197,12 @@ def deserialization_callback(response):
196197
continuation_token=cont_token,
197198
client=client,
198199
initial_response=initial_response,
199-
deserialization_callback=Model,
200+
deserialization_callback=deserialization_callback,
200201
polling_method=method
201202
)
202-
result = poller.result()
203+
result = new_poller.result()
203204
assert result == "Treated: "+initial_response
204-
assert poller.status() == "succeeded"
205+
assert new_poller.status() == "succeeded"
205206

206207

207208
def test_broken_poller(client):

0 commit comments

Comments
 (0)