Use RandomVariables for Minibatch #6277

Closed
ricardoV94 opened this issue Nov 8, 2022 · 2 comments · Fixed by #6304

Comments

@ricardoV94
Member

ricardoV94 commented Nov 8, 2022

Minibatch creates a stochastic view of the underlying dataset using some pretty funky Aesara magic.

In a straightforward implementation it would do something like:

import pymc as pm
import numpy as np

data = pm.Normal.dist(size=(100, 2)).eval()

with pm.Model() as m:
  data = pm.Data("data", data)

  mb_slice_start, mb_slice_end = pm.Uniform.dist(0, data.shape[0], size=(2)).sort()
  minibatch_data = data[mb_slice_start: mb_slice_end]
  
  x = pm.Normal("x", 0, 1, observed=minibatch_data)

And then PyMC also figures out that the logp of x has to be scaled in each iteration according to how many values were left unobserved.
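
To make the scaling concrete, here is a minimal NumPy sketch. It is not PyMC's implementation: the helper names `normal_logp` and `scaled_minibatch_logp` are hypothetical, and it assumes the scaling factor is simply the total number of rows divided by the number of rows in the minibatch.

```python
import numpy as np


def normal_logp(values, mu=0.0, sigma=1.0):
    """Elementwise log-density of a Normal(mu, sigma) distribution."""
    z = (values - mu) / sigma
    return -0.5 * z**2 - np.log(sigma) - 0.5 * np.log(2.0 * np.pi)


def scaled_minibatch_logp(batch, total_size):
    """Minibatch logp rescaled to estimate the logp of the full dataset.

    Assumption: the scaling factor is total_size / batch_size.
    """
    scale = total_size / batch.shape[0]
    return scale * normal_logp(batch).sum()


rng = np.random.default_rng(0)
data = rng.normal(size=(100, 2))
full_logp = normal_logp(data).sum()

# Averaging the rescaled logp over disjoint batches that cover the data
# recovers the full-data logp, so each rescaled batch is an estimate of it.
batches = np.split(data, 10)  # ten batches of 10 rows each
estimates = [scaled_minibatch_logp(b, total_size=data.shape[0]) for b in batches]
```

Without the rescaling, the minibatch likelihood would be weighted as if only a fraction of the data had been seen, which would let the prior dominate.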

One pressing issue is that it relies on the deprecated MRG sampler (#4523), which also has no support for JAX or Numba. Switching to the default RandomStream, which returns RandomVariables, isn't a simple matter, because we treat the shared RNG variables used by RandomVariables in a special manner inside aesaraf.compile_pymc (the function that all PyMC functions go through before calling aesara.function). Specifically, we always set the values of the distinct shared RNG variables to something new, based on the random_seed that is passed to compile_pymc.

This is a problem when you need to synchronize multiple mini-batches of data, say when you have a linear regression model with a minibatch of x and y. The way it currently works is that you set the same random_seed for each (the default is a hard-coded 42), and then rely on this initial value being the same so that the endpoints of the slice stay synchronized. But if we were to use the standard RandomStream and pass the graph to compile_pymc, the two RNGs associated with the two slices would be overwritten with different values, regardless of the fact that they started with the same value.
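
The failure mode can be reproduced with plain NumPy generators. This is only an illustration under stated assumptions: `draw_endpoints` is a hypothetical helper, and reseeding with the values 1 and 2 merely stands in for whatever distinct seeds a compile step would assign to each RNG.

```python
import numpy as np


def draw_endpoints(rng, n_rows, n_draws=5):
    """Draw n_draws sorted (start, end) pairs, as floats, from rng."""
    return np.sort(rng.uniform(0, n_rows, size=(n_draws, 2)), axis=1)


# Two independent generators that merely start from the same seed.
rng_x = np.random.default_rng(42)
rng_y = np.random.default_rng(42)
same_seed_in_sync = np.array_equal(
    draw_endpoints(rng_x, 100), draw_endpoints(rng_y, 100)
)

# Stand-in for a compile step that gives every generator a distinct new seed.
rng_x = np.random.default_rng(1)
rng_y = np.random.default_rng(2)
reseeded_in_sync = np.array_equal(
    draw_endpoints(rng_x, 100), draw_endpoints(rng_y, 100)
)
```

The two streams agree only for as long as nothing touches either generator's state, which is exactly the guarantee that reseeding removes.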

Anyway, this trick might also not have been enough, since there was a need to introduce align_minibatches in #2760:

pymc/pymc/data.py

Lines 463 to 474 in faebc60

def align_minibatches(batches=None):
    if batches is None:
        for rngs in Minibatch.RNG.values():
            for rng in rngs:
                rng.seed()
    else:
        for b in batches:
            if not isinstance(b, Minibatch):
                raise TypeError(f"{b} is not a Minibatch")
            for rng in Minibatch.RNG[id(b)]:
                rng.seed()

https://github.com/pymc-devs/pymc/blob/2296350959e4035b4f1ee13fab88014d4f0fa545/pymc/tests/test_data.py#L754-L773

The big issue here is that we are not representing that the slice endpoints are the same symbolically. The current graph looks something like:

import aesara

with pm.Model() as m:
  x = pm.Data("x", x)
  y = pm.Data("y", y)

  # First RNG, used only for the slice of x
  rng = aesara.shared(np.random.default_rng(42))
  mb_slice_start, mb_slice_end = pm.Uniform.dist(0, x.shape[0], size=(2,), rng=rng).sort()
  minibatch_x = x[mb_slice_start: mb_slice_end]

  # Second, unrelated RNG that merely starts from the same seed
  rng = aesara.shared(np.random.default_rng(42))
  mb_slice_start, mb_slice_end = pm.Uniform.dist(0, y.shape[0], size=(2,), rng=rng).sort()
  minibatch_y = y[mb_slice_start: mb_slice_end]

  obs = pm.Normal("obs", minibatch_x, 1, observed=minibatch_y)

The two rngs were set to the same initial value, but other than that they are completely different as far as Aesara cares. The correct graph would be:

with pm.Model() as m:
  x = pm.Data("x", x)
  y = pm.Data("y", y)

  # A single slicing RV, shared by both minibatches
  rng = aesara.shared(np.random.default_rng(42))
  mb_slice_start, mb_slice_end = pm.Uniform.dist(0, x.shape[0], size=(2,), rng=rng).sort()
  minibatch_x = x[mb_slice_start: mb_slice_end]
  minibatch_y = y[mb_slice_start: mb_slice_end]

  obs = pm.Normal("obs", minibatch_x, 1, observed=minibatch_y)

That is, using the same slicing RV for both mini-batches. Regardless of the seed value that is set by compile_pymc, the two minibatches will always be compatible.
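
The same idea in plain NumPy, as a sketch rather than the Aesara graph itself. It assumes integer endpoints (the pseudo-code above draws from a Uniform) and uses toy data in which row i of y is a known function of row i of x, so the pairing can be checked.

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.arange(100, dtype=float)
y = 2.0 * x + 1.0  # row i of y belongs with row i of x

# A single draw of the endpoints is reused for both arrays.
start, end = np.sort(rng.integers(0, x.shape[0] + 1, size=2))
minibatch_x = x[start:end]
minibatch_y = y[start:end]
```

Because both slices are built from the same `start` and `end`, the rows stay paired no matter which seed the generator was given.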

The incremental proposal

Refactor Minibatch so that it can accept multiple variables at once, which will share the same random slices. Then start using RandomVariables instead of the MRG stuff, and get rid of seed at the minibatch level.

with pm.Model() as m:
  x, y = pm.Minibatch({x: (10, ...), y: (10)})

Maybe even sprinkle in some dims if you want.

The radical proposal

Offer the same API, but implement the minibatch view with the more straightforward code from the pseudo-examples in this issue.

No need to keep a global list of all the RNGs:

RNG = collections.defaultdict(list) # type: Dict[str, List[Any]]

The whole Op view magic shouldn't be needed AFAICT. The only thing is that right now PyMC complains if you pass something to observed other than a shared variable or constant data. The concern here is that the observations should not depend on other model RVs, but that's easy to check with aesara.graph.ancestors()

Apply(aesara.compile.view_op, inputs=[self.minibatch], outputs=[self])

The other issue is that we don't allow orphan RVs in the logp graph, but we can easily create a subtype like MinibatchUniformRV, that is allowed, just like SimulatorRVs are allowed.

This should make the Minibatch code much more readable and maintainable, as well as compatible with the non-C backends (I have no idea if either supports View).

@ferrine
Member

ferrine commented Nov 15, 2022

The existing API is far too complicated. I'll take a radical approach and remove some functionality for the sake of simplicity and maintenance:

def Minibatch(*variables: TensorVariable, batch_size) -> Tuple[TensorVariable]:
    """
    Get random slices from variables from the leading dimension.
    """

@ferrine
Member

ferrine commented Nov 15, 2022

If we need more functionality, we'll properly motivate and test it.
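
For reference, a NumPy analogue of the simplified signature proposed above might look like the sketch below. This is not the PyMC function: `minibatch_numpy` and its explicit `rng` argument are hypothetical, and sampling row indices with replacement is an assumption rather than something settled in this thread.

```python
from typing import Tuple

import numpy as np


def minibatch_numpy(
    *arrays: np.ndarray, batch_size: int, rng: np.random.Generator
) -> Tuple[np.ndarray, ...]:
    """Take the same random rows (leading dimension) from every array."""
    if not arrays:
        raise ValueError("at least one array is required")
    n_rows = arrays[0].shape[0]
    if any(a.shape[0] != n_rows for a in arrays):
        raise ValueError("all arrays must share the same leading dimension")
    # One set of indices, drawn once (with replacement), indexes every array.
    idx = rng.integers(0, n_rows, size=batch_size)
    return tuple(a[idx] for a in arrays)


rng = np.random.default_rng(0)
x = np.arange(200, dtype=float).reshape(100, 2)
y = x.sum(axis=1)  # row i of y is determined by row i of x
mb_x, mb_y = minibatch_numpy(x, y, batch_size=10, rng=rng)
```

Passing all variables in one call is what keeps them aligned: there is only one index draw, so there is nothing to synchronize afterwards.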
