Skip to content

Commit d091e90

Browse files
r4victorpranitnaik43
authored andcommitted
Print message on dstack attach exit (dstackai#2358)
1 parent 7db10a7 commit d091e90

File tree

2 files changed

+29
-10
lines changed

2 files changed

+29
-10
lines changed

src/dstack/_internal/cli/commands/attach.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
from dstack._internal.cli.commands import APIBaseCommand
88
from dstack._internal.cli.services.args import port_mapping
99
from dstack._internal.cli.services.completion import RunNameCompleter
10+
from dstack._internal.cli.services.configurators.run import (
11+
get_run_exit_code,
12+
print_finished_message,
13+
)
1014
from dstack._internal.cli.utils.common import console
1115
from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT
1216
from dstack._internal.core.errors import CLIError
@@ -100,6 +104,21 @@ def _command(self, args: argparse.Namespace):
100104
pass
101105
finally:
102106
run.detach()
107+
# TODO: Handle run resubmissions similar to dstack apply
108+
109+
# After reading the logs, the run may not be marked as finished immediately.
110+
# Give the run some time to transition to a finished state before exiting.
111+
for _ in range(30):
112+
run.refresh()
113+
if run.status.is_finished():
114+
print_finished_message(run)
115+
exit(get_run_exit_code(run))
116+
time.sleep(1)
117+
console.print(
118+
"[error]Lost run connection. Timed out waiting for run final status."
119+
" Check `dstack ps` to see if it's done or failed."
120+
)
121+
exit(1)
103122

104123

105124
_IGNORED_PORTS = [DSTACK_RUNNER_HTTP_PORT]

src/dstack/_internal/cli/services/configurators/run.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,8 @@ def apply_configuration(
238238
reattach = True
239239
break
240240
if run.status.is_finished():
241-
_print_finished_message(run)
242-
exit(_get_run_exit_code(run))
241+
print_finished_message(run)
242+
exit(get_run_exit_code(run))
243243
time.sleep(1)
244244
if not reattach:
245245
console.print(
@@ -439,7 +439,7 @@ def apply_args(
439439
):
440440
super().apply_args(conf, args, unknown)
441441
if args.ports:
442-
conf.ports = list(merge_ports(conf.ports, args.ports).values())
442+
conf.ports = list(_merge_ports(conf.ports, args.ports).values())
443443

444444

445445
class TaskConfigurator(RunWithPortsConfigurator):
@@ -475,17 +475,17 @@ def apply_args(self, conf: ServiceConfiguration, args: argparse.Namespace, unkno
475475
self.interpolate_run_args(conf.commands, unknown)
476476

477477

478-
def merge_ports(conf: List[PortMapping], args: List[PortMapping]) -> Dict[int, PortMapping]:
479-
unique_ports_constraint([pm.container_port for pm in conf])
480-
unique_ports_constraint([pm.container_port for pm in args])
478+
def _merge_ports(conf: List[PortMapping], args: List[PortMapping]) -> Dict[int, PortMapping]:
479+
_unique_ports_constraint([pm.container_port for pm in conf])
480+
_unique_ports_constraint([pm.container_port for pm in args])
481481
ports = {pm.container_port: pm for pm in conf}
482482
for pm in args: # override conf
483483
ports[pm.container_port] = pm
484-
unique_ports_constraint([pm.local_port for pm in ports.values() if pm.local_port is not None])
484+
_unique_ports_constraint([pm.local_port for pm in ports.values() if pm.local_port is not None])
485485
return ports
486486

487487

488-
def unique_ports_constraint(ports: List[int]):
488+
def _unique_ports_constraint(ports: List[int]):
489489
used_ports = set()
490490
for i in ports:
491491
if i in used_ports:
@@ -514,7 +514,7 @@ def _print_service_urls(run: Run) -> None:
514514
console.print()
515515

516516

517-
def _print_finished_message(run: Run):
517+
def print_finished_message(run: Run):
518518
if run.status == RunStatus.DONE:
519519
console.print("[code]Done[/]")
520520
return
@@ -542,7 +542,7 @@ def _print_finished_message(run: Run):
542542
console.print(f"[error]{message}[/]")
543543

544544

545-
def _get_run_exit_code(run: Run) -> int:
545+
def get_run_exit_code(run: Run) -> int:
546546
if run.status == RunStatus.DONE:
547547
return 0
548548
return 1

0 commit comments

Comments
 (0)