Skip to content

Commit c2ce47f

Browse files
ricardoV94twiecki
authored andcommitted
Make all but draws keyword-only arguments in sample.
Reorder arguments more logically [citation needed]
1 parent 0660efa commit c2ce47f

File tree

2 files changed

+61
-60
lines changed

2 files changed

+61
-60
lines changed

pymc/sampling/mcmc.py

Lines changed: 59 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -311,27 +311,27 @@ def _sample_external_nuts(
311311

312312
def sample(
313313
draws: int = 1000,
314+
*,
315+
tune: int = 1000,
316+
chains: Optional[int] = None,
317+
cores: Optional[int] = None,
318+
random_seed: RandomState = None,
319+
progressbar: bool = True,
314320
step=None,
321+
nuts_sampler: str = "pymc",
322+
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
315323
init: str = "auto",
324+
jitter_max_retries: int = 10,
316325
n_init: int = 200_000,
317-
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
318326
trace: Optional[BaseTrace] = None,
319-
chains: Optional[int] = None,
320-
cores: Optional[int] = None,
321-
tune: int = 1000,
322-
progressbar: bool = True,
323-
model: Optional[Model] = None,
324-
random_seed: RandomState = None,
325327
discard_tuned_samples: bool = True,
326328
compute_convergence_checks: bool = True,
327-
callback=None,
328-
jitter_max_retries: int = 10,
329-
*,
330-
nuts_sampler: str = "pymc",
331-
return_inferencedata: bool = True,
332329
keep_warning_stat: bool = False,
330+
return_inferencedata: bool = True,
333331
idata_kwargs: dict = None,
332+
callback=None,
334333
mp_ctx=None,
334+
model: Optional[Model] = None,
335335
**kwargs,
336336
) -> Union[InferenceData, MultiTrace]:
337337
r"""Draw samples from the posterior using the given step methods.
@@ -343,78 +343,79 @@ def sample(
343343
draws : int
344344
The number of samples to draw. Defaults to 1000. The number of tuned samples are discarded
345345
by default. See ``discard_tuned_samples``.
346-
init : str
347-
Initialization method to use for auto-assigned NUTS samplers. See `pm.init_nuts` for a list
348-
of all options. This argument is ignored when manually passing the NUTS step method.
346+
tune : int
347+
Number of iterations to tune, defaults to 1000. Samplers adjust the step sizes, scalings or
348+
similar during tuning. Tuning samples will be drawn in addition to the number specified in
349+
the ``draws`` argument, and will be discarded unless ``discard_tuned_samples`` is set to
350+
False.
351+
chains : int
352+
The number of chains to sample. Running independent chains is important for some
353+
convergence statistics and can also reveal multiple modes in the posterior. If ``None``,
354+
then set to either ``cores`` or 2, whichever is larger.
355+
cores : int
356+
The number of chains to run in parallel. If ``None``, set to the number of CPUs in the
357+
system, but at most 4.
358+
random_seed : int, array-like of int, RandomState or Generator, optional
359+
Random seed(s) used by the sampling steps. If a list, tuple or array of ints
360+
is passed, each entry will be used to seed each chain. A ValueError will be
361+
raised if the length does not match the number of chains.
362+
progressbar : bool, optional default=True
363+
Whether or not to display a progress bar in the command line. The bar shows the percentage
364+
of completion, the sampling speed in samples per second (SPS), and the estimated remaining
365+
time until completion ("expected time of arrival"; ETA).
349366
Only applicable to the pymc nuts sampler.
350367
step : function or iterable of functions
351368
A step function or collection of functions. If there are variables without step methods,
352369
step methods for those variables will be assigned automatically. By default the NUTS step
353370
method will be used, if appropriate to the model.
354-
n_init : int
355-
Number of iterations of initializer. Only works for 'ADVI' init methods.
371+
nuts_sampler : str
372+
Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"].
373+
This requires the chosen sampler to be installed.
374+
All samplers, except "pymc", require the full model to be continuous.
356375
initvals : optional, dict, array of dict
357376
Dict or list of dicts with initial value strategies to use instead of the defaults from
358377
`Model.initial_values`. The keys should be names of transformed random variables.
359378
Initialization methods for NUTS (see ``init`` keyword) can overwrite the default.
379+
init : str
380+
Initialization method to use for auto-assigned NUTS samplers. See `pm.init_nuts` for a list
381+
of all options. This argument is ignored when manually passing the NUTS step method.
382+
Only applicable to the pymc nuts sampler.
383+
jitter_max_retries : int
384+
Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform
385+
jitter that yields a finite probability. This applies to ``jitter+adapt_diag`` and
386+
``jitter+adapt_full`` init methods.
387+
n_init : int
388+
Number of iterations of initializer. Only works for 'ADVI' init methods.
360389
trace : backend, optional
361390
A backend instance or None.
362391
If None, the NDArray backend is used.
363-
chains : int
364-
The number of chains to sample. Running independent chains is important for some
365-
convergence statistics and can also reveal multiple modes in the posterior. If ``None``,
366-
then set to either ``cores`` or 2, whichever is larger.
367-
cores : int
368-
The number of chains to run in parallel. If ``None``, set to the number of CPUs in the
369-
system, but at most 4.
370-
tune : int
371-
Number of iterations to tune, defaults to 1000. Samplers adjust the step sizes, scalings or
372-
similar during tuning. Tuning samples will be drawn in addition to the number specified in
373-
the ``draws`` argument, and will be discarded unless ``discard_tuned_samples`` is set to
374-
False.
375-
progressbar : bool, optional default=True
376-
Whether or not to display a progress bar in the command line. The bar shows the percentage
377-
of completion, the sampling speed in samples per second (SPS), and the estimated remaining
378-
time until completion ("expected time of arrival"; ETA).
379-
model : Model (optional if in ``with`` context)
380-
Model to sample from. The model needs to have free random variables.
381-
random_seed : int, array-like of int, RandomState or Generator, optional
382-
Random seed(s) used by the sampling steps. If a list, tuple or array of ints
383-
is passed, each entry will be used to seed each chain. A ValueError will be
384-
raised if the length does not match the number of chains.
385392
discard_tuned_samples : bool
386393
Whether to discard posterior samples of the tune interval.
387394
compute_convergence_checks : bool, default=True
388395
Whether to compute sampler statistics like Gelman-Rubin and ``effective_n``.
389-
callback : function, default=None
390-
A function which gets called for every sample from the trace of a chain. The function is
391-
called with the trace and the current draw and will contain all samples for a single trace.
392-
the ``draw.chain`` argument can be used to determine which of the active chains the sample
393-
is drawn from.
394-
Sampling can be interrupted by throwing a ``KeyboardInterrupt`` in the callback.
395-
jitter_max_retries : int
396-
Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform
397-
jitter that yields a finite probability. This applies to ``jitter+adapt_diag`` and
398-
``jitter+adapt_full`` init methods.
399-
nuts_sampler : str
400-
Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"].
401-
This requires the chosen sampler to be installed.
402-
All samplers, except "pymc", require the full model to be continuous.
403-
return_inferencedata : bool
404-
Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a
405-
`MultiTrace` (False). Defaults to `True`.
406-
idata_kwargs : dict, optional
407-
Keyword arguments for :func:`pymc.to_inference_data`
408396
keep_warning_stat : bool
409397
If ``True`` the "warning" stat emitted by, for example, HMC samplers will be kept
410398
in the returned ``idata.sample_stat`` group.
411399
This leads to the ``idata`` not supporting ``.to_netcdf()`` or ``.to_zarr()`` and
412400
should only be set to ``True`` if you intend to use the "warning" objects right away.
413401
Defaults to ``False`` such that ``pm.drop_warning_stat`` is applied automatically,
414402
making the ``InferenceData`` compatible with saving.
403+
return_inferencedata : bool
404+
Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a
405+
`MultiTrace` (False). Defaults to `True`.
406+
idata_kwargs : dict, optional
407+
Keyword arguments for :func:`pymc.to_inference_data`
408+
callback : function, default=None
409+
A function which gets called for every sample from the trace of a chain. The function is
410+
called with the trace and the current draw and will contain all samples for a single trace.
411+
the ``draw.chain`` argument can be used to determine which of the active chains the sample
412+
is drawn from.
413+
Sampling can be interrupted by throwing a ``KeyboardInterrupt`` in the callback.
415414
mp_ctx : multiprocessing.context.BaseContent
416415
A multiprocessing context for parallel sampling.
417416
See multiprocessing documentation for details.
417+
model : Model (optional if in ``with`` context)
418+
Model to sample from. The model needs to have free random variables.
418419
419420
Returns
420421
-------

pymc/tests/distributions/test_mixture.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ def test_single_poisson_sampling(self):
456456
warnings.filterwarnings("ignore", "overflow encountered in exp", RuntimeWarning)
457457
trace = sample(
458458
5000,
459-
step,
459+
step=step,
460460
random_seed=self.random_seed,
461461
progressbar=False,
462462
chains=1,
@@ -783,7 +783,7 @@ def test_normal_mixture_sampling(self):
783783
warnings.filterwarnings("ignore", "overflow encountered in exp", RuntimeWarning)
784784
trace = sample(
785785
5000,
786-
step,
786+
step=step,
787787
random_seed=self.random_seed,
788788
progressbar=False,
789789
chains=1,

0 commit comments

Comments
 (0)