Skip to content

Commit cdc6a39

Browse files
MarcoGorellitwiecki
authored andcommitted
no print statements
1 parent 04cdd96 commit cdc6a39

File tree

4 files changed

+13
-6
lines changed

4 files changed

+13
-6
lines changed

.pre-commit-config.yaml

+5
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ repos:
3333
- id: pylint
3434
args: [--rcfile=.pylintrc]
3535
files: ^pymc3/
36+
- repo: https://github.com/MarcoGorelli/madforhooks
37+
rev: 0.2.1
38+
hooks:
39+
- id: no-print-statements
40+
files: ^pymc3/
3641
- repo: local
3742
hooks:
3843
- id: check-no-tests-are-ignored

pymc3/math.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def perform(self, node, inputs, outputs, params=None):
294294
log_det = np.sum(np.log(np.abs(s)))
295295
z[0] = np.asarray(log_det, dtype=x.dtype)
296296
except Exception:
297-
print(f"Failed to compute logdet of {x}.")
297+
print(f"Failed to compute logdet of {x}.", file=sys.stdout)
298298
raise
299299

300300
def grad(self, inputs, g_outputs):

pymc3/sampling_jax.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# pylint: skip-file
22
import os
33
import re
4+
import sys
45
import warnings
56

67
xla_flags = os.getenv("XLA_FLAGS", "").lstrip("--")
@@ -206,7 +207,7 @@ def sample_numpyro_nuts(
206207
rv_samples.name = rv.name
207208
sample_outputs.append(rv_samples)
208209

209-
print("Compiling...")
210+
print("Compiling...", file=sys.stdout)
210211

211212
tic1 = pd.Timestamp.now()
212213
_sample = compile_rv_inplace(
@@ -219,14 +220,14 @@ def sample_numpyro_nuts(
219220
)
220221
tic2 = pd.Timestamp.now()
221222

222-
print("Compilation time = ", tic2 - tic1)
223+
print("Compilation time = ", tic2 - tic1, file=sys.stdout)
223224

224-
print("Sampling...")
225+
print("Sampling...", file=sys.stdout)
225226

226227
*mcmc_samples, leapfrogs_taken = _sample()
227228
tic3 = pd.Timestamp.now()
228229

229-
print("Sampling time = ", tic3 - tic2)
230+
print("Sampling time = ", tic3 - tic2, file=sys.stdout)
230231

231232
posterior = {k.name: v for k, v in zip(sample_outputs, mcmc_samples)}
232233

pymc3/tuning/starting.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
@author: johnsalvatier
1919
"""
2020
import copy
21+
import sys
2122

2223
import aesara.gradient as tg
2324
import numpy as np
@@ -153,7 +154,7 @@ def dlogp_func(x):
153154
assert isinstance(cost_func.progress, ProgressBar)
154155
cost_func.progress.total = last_v
155156
cost_func.progress.update(last_v)
156-
print()
157+
print(file=sys.stdout)
157158

158159
mx0 = RaveledVars(mx0, x0.point_map_info)
159160

0 commit comments

Comments
 (0)