Skip to content

Commit cd27c81

Browse files
faizan-m authored and tensorflower-gardener committed
Migrate experimental_relax_shapes to reduce_retracing
PiperOrigin-RevId: 437742142
1 parent 953f063 commit cd27c81

File tree

5 files changed

+12
-12
lines changed

5 files changed

+12
-12
lines changed

discussion/examples/windowed_sampling.ipynb

+2-2
Original file line numberDiff line numberDiff line change
@@ -1837,9 +1837,9 @@
18371837
"WARNING:tensorflow:Note that RandomStandardNormal inside pfor op may not give same output as inside a sequential loop.\n",
18381838
"Fast window 75\n",
18391839
"Slow window 25\n",
1840-
"WARNING:tensorflow:5 out of the last 5 calls to \u003cfunction slow_window at 0x7f9456031ea0\u003e triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n",
1840+
"WARNING:tensorflow:5 out of the last 5 calls to \u003cfunction slow_window at 0x7f9456031ea0\u003e triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n",
18411841
"Slow window 50\n",
1842-
"WARNING:tensorflow:6 out of the last 6 calls to \u003cfunction slow_window at 0x7f9456031ea0\u003e triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n",
1842+
"WARNING:tensorflow:6 out of the last 6 calls to \u003cfunction slow_window at 0x7f9456031ea0\u003e triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n",
18431843
"Slow window 100\n",
18441844
"Slow window 200\n",
18451845
"Fast window 75\n",

discussion/turnkey_inference_candidate/window_tune_nuts_sampling.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def _sample_posterior(target_log_prob_unconstrained,
4646
parallel_iterations=10,
4747
jit_compile=True,
4848
use_input_signature=False,
49-
experimental_relax_shapes=False):
49+
reduce_retracing=False):
5050
"""MCMC sampling with HMC/NUTS using an expanding epoch tuning scheme."""
5151

5252
seed_stream = tfp.util.SeedStream(seed, 'window_tune_nuts_sampling')
@@ -117,7 +117,7 @@ def _sample_posterior(target_log_prob_unconstrained,
117117
input_signature=input_signature,
118118
autograph=False,
119119
jit_compile=jit_compile,
120-
experimental_relax_shapes=experimental_relax_shapes)
120+
reduce_retracing=reduce_retracing)
121121
def fast_adaptation_interval(num_steps, previous_state):
122122
"""Step size only adaptation interval.
123123
@@ -179,7 +179,7 @@ def body_fn_window2(
179179
input_signature=input_signature,
180180
autograph=False,
181181
jit_compile=jit_compile,
182-
experimental_relax_shapes=experimental_relax_shapes)
182+
reduce_retracing=reduce_retracing)
183183
def slow_adaptation_interval(num_steps, previous_n, previous_state,
184184
previous_mean, previous_cov):
185185
"""Interval that tunes the mass matrix and step size simultaneously.
@@ -328,7 +328,7 @@ def window_tune_nuts_sampling(target_log_prob,
328328
parallel_iterations=10,
329329
jit_compile=True,
330330
use_input_signature=True,
331-
experimental_relax_shapes=False):
331+
reduce_retracing=False):
332332
"""Sample from a density with NUTS and an expanding window tuning scheme.
333333
334334
This function implements a turnkey MCMC sampling routine using NUTS and an
@@ -347,7 +347,7 @@ def window_tune_nuts_sampling(target_log_prob,
347347
of the tuning epoch (window 1, 2, and 3 in Stan [1]) run with two @tf.function
348348
compiled functions. The user can control the compilation options using the
349349
kwargs `jit_compile`, `use_input_signature`, and
350-
`experimental_relax_shapes`. Setting all to True would compile to XLA and
350+
`reduce_retracing`. Setting all to True would compile to XLA and
351351
potentially avoid the small overhead of function recompilation (note that it
352352
is not yet the case in XLA right now). It is not yet clear whether doing it
353353
this way is better than just wrapping the full inference routine in
@@ -403,7 +403,7 @@ def window_tune_nuts_sampling(target_log_prob,
403403
function is always compiled by XLA.
404404
use_input_signature: If True, generate an input_signature kwarg to pass to
405405
tf.function decorator.
406-
experimental_relax_shapes: kwarg pass to tf.function decorator. When True,
406+
reduce_retracing: kwarg pass to tf.function decorator. When True,
407407
tf.function may generate fewer, graphs that are less specialized on input
408408
shapes.
409409
@@ -564,6 +564,6 @@ def target_log_prob_unconstrained_concated(x):
564564
parallel_iterations=parallel_iterations,
565565
jit_compile=jit_compile,
566566
use_input_signature=use_input_signature,
567-
experimental_relax_shapes=experimental_relax_shapes)
567+
reduce_retracing=reduce_retracing)
568568
return forward_transform(
569569
split_and_reshape(nuts_samples)), diagnostic, conditioning_bijector

tensorflow_probability/examples/jupyter_notebooks/TFP_Release_Notebook_0_13_0.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -954,7 +954,7 @@
954954
" c1=<tf.Tensor: shape=(), dtype=float32, numpy=1.7696666>,\n",
955955
" counts=<tf.Tensor: shape=(10,), dtype=float32, numpy=array([ 6., 10., 23., 7., 2., 20., 14., 16., 22., 17.], dtype=float32)>\n",
956956
")\n",
957-
"WARNING:tensorflow:6 out of the last 6 calls to <function windowed_adaptive_hmc at 0x7fda42bed8c0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n",
957+
"WARNING:tensorflow:6 out of the last 6 calls to <function windowed_adaptive_hmc at 0x7fda42bed8c0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n",
958958
"StructTuple(\n",
959959
" c0=<tf.Tensor: shape=(), dtype=float32, numpy=0.7161876>,\n",
960960
" c1=<tf.Tensor: shape=(), dtype=float32, numpy=1.7696666>,\n",

tensorflow_probability/python/distributions/gamma.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,7 @@ def _random_gamma_noncpu(
510510
# tf.function required to access Grappler's implementation_selector.
511511
@implementation_selection.never_runs_functions_eagerly
512512
# TODO(b/163029794): Shape relaxation breaks XLA.
513-
@tf.function(autograph=False, experimental_relax_shapes=False)
513+
@tf.function(autograph=False, reduce_retracing=False)
514514
def _random_gamma_no_gradient(
515515
shape, concentration, rate, log_rate, seed, log_space):
516516
"""Sample a gamma, CPU specialized to stateless_gamma.

tensorflow_probability/python/internal/backend/numpy/v2.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060

6161
def _function(func=None, input_signature=None, autograph=True, # pylint: disable=unused-argument
6262
experimental_autograph_options=None, # pylint: disable=unused-argument
63-
experimental_relax_shapes=False, jit_compile=None): # pylint: disable=unused-argument
63+
reduce_retracing=False, jit_compile=None): # pylint: disable=unused-argument
6464
"""Like `tf.function`, for JAX."""
6565
transform = lambda fn: fn
6666
if jit_compile:

0 commit comments

Comments (0)