@@ -83,13 +83,13 @@ async def test_leaves_provisioning_job_unchanged_if_runner_not_alive(
83
83
submitted_at = datetime (2023 , 1 , 2 , 5 , 12 , 30 , 5 , tzinfo = timezone .utc ),
84
84
job_provisioning_data = job_provisioning_data ,
85
85
)
86
- with patch (
87
- "dstack._internal.server.services.runner.ssh.SSHTunnel"
88
- ) as SSHTunnelMock , patch (
89
- "dstack._internal.server.services.runner.client.RunnerClient"
90
- ) as RunnerClientMock , patch (
91
- "dstack._internal.utils.common.get_current_datetime"
92
- ) as datetime_mock :
86
+ with (
87
+ patch ( "dstack._internal.server.services.runner.ssh.SSHTunnel" ) as SSHTunnelMock ,
88
+ patch (
89
+ "dstack._internal.server.services.runner.client.RunnerClient"
90
+ ) as RunnerClientMock ,
91
+ patch ( "dstack._internal.utils.common.get_current_datetime" ) as datetime_mock ,
92
+ ):
93
93
datetime_mock .return_value = datetime (2023 , 1 , 2 , 5 , 12 , 30 , 10 , tzinfo = timezone .utc )
94
94
runner_client_mock = RunnerClientMock .return_value
95
95
runner_client_mock .healthcheck = Mock ()
@@ -123,11 +123,12 @@ async def test_runs_provisioning_job(self, test_db, session: AsyncSession):
123
123
status = JobStatus .PROVISIONING ,
124
124
job_provisioning_data = job_provisioning_data ,
125
125
)
126
- with patch (
127
- "dstack._internal.server.services.runner.ssh.SSHTunnel"
128
- ) as SSHTunnelMock , patch (
129
- "dstack._internal.server.services.runner.client.RunnerClient"
130
- ) as RunnerClientMock :
126
+ with (
127
+ patch ("dstack._internal.server.services.runner.ssh.SSHTunnel" ) as SSHTunnelMock ,
128
+ patch (
129
+ "dstack._internal.server.services.runner.client.RunnerClient"
130
+ ) as RunnerClientMock ,
131
+ ):
131
132
runner_client_mock = RunnerClientMock .return_value
132
133
runner_client_mock .healthcheck .return_value = HealthcheckResponse (
133
134
service = "dstack-runner" , version = "0.0.1.dev2"
@@ -164,11 +165,13 @@ async def test_updates_running_job(self, test_db, session: AsyncSession, tmp_pat
164
165
status = JobStatus .RUNNING ,
165
166
job_provisioning_data = job_provisioning_data ,
166
167
)
167
- with patch (
168
- "dstack._internal.server.services.runner.ssh.SSHTunnel"
169
- ) as SSHTunnelMock , patch (
170
- "dstack._internal.server.services.runner.client.RunnerClient"
171
- ) as RunnerClientMock , patch .object (settings , "SERVER_DIR_PATH" , tmp_path ):
168
+ with (
169
+ patch ("dstack._internal.server.services.runner.ssh.SSHTunnel" ) as SSHTunnelMock ,
170
+ patch (
171
+ "dstack._internal.server.services.runner.client.RunnerClient"
172
+ ) as RunnerClientMock ,
173
+ patch .object (settings , "SERVER_DIR_PATH" , tmp_path ),
174
+ ):
172
175
runner_client_mock = RunnerClientMock .return_value
173
176
runner_client_mock .pull .return_value = PullResponse (
174
177
job_states = [JobStateEvent (timestamp = 1 , state = JobStatus .RUNNING )],
@@ -182,11 +185,12 @@ async def test_updates_running_job(self, test_db, session: AsyncSession, tmp_pat
182
185
assert job is not None
183
186
assert job .status == JobStatus .RUNNING
184
187
assert job .runner_timestamp == 1
185
- with patch (
186
- "dstack._internal.server.services.runner.ssh.SSHTunnel"
187
- ) as SSHTunnelMock , patch (
188
- "dstack._internal.server.services.runner.client.RunnerClient"
189
- ) as RunnerClientMock :
188
+ with (
189
+ patch ("dstack._internal.server.services.runner.ssh.SSHTunnel" ) as SSHTunnelMock ,
190
+ patch (
191
+ "dstack._internal.server.services.runner.client.RunnerClient"
192
+ ) as RunnerClientMock ,
193
+ ):
190
194
runner_client_mock = RunnerClientMock .return_value
191
195
runner_client_mock .pull .return_value = PullResponse (
192
196
job_states = [JobStateEvent (timestamp = 1 , state = JobStatus .DONE )],
@@ -251,11 +255,10 @@ async def test_provisioning_shim_with_volumes(
251
255
status = JobStatus .PROVISIONING ,
252
256
job_provisioning_data = job_provisioning_data ,
253
257
)
254
- with patch (
255
- "dstack._internal.server.services.runner.ssh.SSHTunnel"
256
- ) as SSHTunnelMock , patch (
257
- "dstack._internal.server.services.runner.client.ShimClient"
258
- ) as ShimClientMock :
258
+ with (
259
+ patch ("dstack._internal.server.services.runner.ssh.SSHTunnel" ) as SSHTunnelMock ,
260
+ patch ("dstack._internal.server.services.runner.client.ShimClient" ) as ShimClientMock ,
261
+ ):
259
262
ShimClientMock .return_value .healthcheck .return_value = HealthcheckResponse (
260
263
service = "dstack-shim" , version = "0.0.1.dev2"
261
264
)
@@ -303,13 +306,13 @@ async def test_pulling_shim(self, test_db, session: AsyncSession):
303
306
status = JobStatus .PULLING ,
304
307
job_provisioning_data = job_provisioning_data ,
305
308
)
306
- with patch (
307
- "dstack._internal.server.services.runner.ssh.SSHTunnel"
308
- ) as SSHTunnelMock , patch (
309
- "dstack._internal.server.services.runner.client.RunnerClient"
310
- ) as RunnerClientMock , patch (
311
- "dstack._internal.server.services.runner.client.ShimClient"
312
- ) as ShimClientMock :
309
+ with (
310
+ patch ( "dstack._internal.server.services.runner.ssh.SSHTunnel" ) as SSHTunnelMock ,
311
+ patch (
312
+ "dstack._internal.server.services.runner.client.RunnerClient"
313
+ ) as RunnerClientMock ,
314
+ patch ( "dstack._internal.server.services.runner.client.ShimClient" ) as ShimClientMock ,
315
+ ):
313
316
RunnerClientMock .return_value .healthcheck .return_value = HealthcheckResponse (
314
317
service = "dstack-runner" , version = "0.0.1.dev2"
315
318
)
@@ -355,9 +358,10 @@ async def test_pulling_shim_failed(self, test_db, session: AsyncSession):
355
358
job_provisioning_data = job_provisioning_data ,
356
359
instance = instance ,
357
360
)
358
- with patch (
359
- "dstack._internal.server.services.runner.ssh.SSHTunnel"
360
- ) as SSHTunnelMock , patch ("dstack._internal.server.services.runner.ssh.time.sleep" ):
361
+ with (
362
+ patch ("dstack._internal.server.services.runner.ssh.SSHTunnel" ) as SSHTunnelMock ,
363
+ patch ("dstack._internal.server.services.runner.ssh.time.sleep" ),
364
+ ):
361
365
SSHTunnelMock .side_effect = SSHError
362
366
await process_running_jobs ()
363
367
assert SSHTunnelMock .call_count == 3
0 commit comments