Skip to content

Commit ec46270

Browse files
authored
fix: progress bar display (#420)
* fix: progress bar display * edit PathStatus.ELBO_ARGMAX_AT_ZERO message for clearer explanation
1 parent 00a4ca3 commit ec46270

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

pymc_extras/inference/pathfinder/pathfinder.py

+22-6
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
from pytensor.tensor import TensorConstant, TensorVariable
6161
from rich.console import Console, Group
6262
from rich.padding import Padding
63+
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
6364
from rich.table import Table
6465
from rich.text import Text
6566

@@ -1395,7 +1396,7 @@ def _get_status_warning(mpr: MultiPathfinderResult) -> list[str]:
13951396

13961397
path_status_message = {
13971398
PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter is may be too close to the mean posterior and a poor exploration of the parameter space. Consider increasing jitter if this occurence is high relative to the number of paths.",
1398-
PathStatus.INVALID_LOGP: "Invalid logP values occur when a path's logP values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
1399+
PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter values are concentrated in high-density regions in the target distribution and may result in poor exploration of the parameter space. Consider increasing jitter if this occurrence is high relative to the number of paths.",
13991400
PathStatus.INVALID_LOGQ: "Invalid logQ values occur when a path's logQ values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
14001401
}
14011402

@@ -1521,12 +1522,20 @@ def multipath_pathfinder(
15211522
results = []
15221523
compute_start = time.time()
15231524
try:
1524-
with CustomProgress(
1525+
desc = f"Paths Complete: {{path_idx}}/{num_paths}"
1526+
progress = CustomProgress(
1527+
"[progress.description]{task.description}",
1528+
BarColumn(),
1529+
"[progress.percentage]{task.percentage:>3.0f}%",
1530+
TimeRemainingColumn(),
1531+
TextColumn("/"),
1532+
TimeElapsedColumn(),
15251533
console=Console(theme=default_progress_theme),
15261534
disable=not progressbar,
1527-
) as progress:
1528-
task = progress.add_task("Fitting", total=num_paths)
1529-
for result in generator:
1535+
)
1536+
with progress:
1537+
task = progress.add_task(desc.format(path_idx=0), completed=0, total=num_paths)
1538+
for path_idx, result in enumerate(generator, start=1):
15301539
try:
15311540
if isinstance(result, Exception):
15321541
raise result
@@ -1552,7 +1561,14 @@ def multipath_pathfinder(
15521561
lbfgs_status=LBFGSStatus.LBFGS_FAILED,
15531562
)
15541563
)
1555-
progress.update(task, advance=1)
1564+
finally:
1565+
# TODO: display LBFGS and Path Status in real time
1566+
progress.update(
1567+
task,
1568+
description=desc.format(path_idx=path_idx),
1569+
completed=path_idx,
1570+
refresh=True,
1571+
)
15561572
except (KeyboardInterrupt, StopIteration) as e:
15571573
# if exception is raised here, MultiPathfinderResult will collect all the successful results and report the results. User is free to abort the process earlier and the results will still be collected and return az.InferenceData.
15581574
if isinstance(e, StopIteration):

0 commit comments

Comments
 (0)