Add object mode fallback for Numba RandomVariables #1249

Open · wants to merge 1 commit into base: main
123 changes: 91 additions & 32 deletions pytensor/link/numba/dispatch/random.py
@@ -386,50 +386,45 @@ def numba_funcify_RandomVariable_core(op: RandomVariable, **kwargs):
    )


def rv_fallback_impl(op: RandomVariableWithCoreShape, node):
    """Create a fallback implementation for random variables using object mode."""
    import warnings

    [rv_node] = op.fgraph.apply_nodes
    rv_op: RandomVariable = rv_node.op

    warnings.warn(
        f"Numba will use object mode to execute the random variable {rv_op.name}",
        UserWarning,
    )

Reviewer comment (Member): We already have a generic fallback implementation function; can't we just use it like we do for other Ops? It may just need to do the unboxing of the RV that the other function is doing.

    size = rv_op.size_param(rv_node)
    dist_params = rv_op.dist_params(rv_node)
    size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size)
    inplace = rv_op.inplace

    def random_wrapper(core_shape, rng, size, *dist_params):
        if not inplace:
            rng = copy(rng)

        fixed_size = (
            None if size_len is None else numba_ndarray.to_fixed_tuple(size, size_len)
        )

        with numba.objmode(res="UniTuple(types.npy_rng, types.pyobject)"):
            # Convert tuple params back to arrays for perform method
            np_dist_params = [np.asarray(p) for p in dist_params]

            # Prepare output storage for perform method
            outputs = [[None], [None]]

            # Call the perform method directly
            rv_op.perform(rv_node, [rng, fixed_size, *np_dist_params], outputs)

            next_rng = outputs[0][0]
            result = outputs[1][0]
            res = (next_rng, result)

        return res

    def random(core_shape, rng, size, *dist_params):
        raise NotImplementedError("Non-jitted random variable not implemented")

@@ -439,3 +434,67 @@ def ov_random(core_shape, rng, size, *dist_params):
        return random_wrapper

    return random
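
The fallback above hinges on numba's objmode context manager. The following is a minimal, self-contained sketch (an illustration, not code from this PR) of how a jitted function can drop back into the interpreter for one block, which is the mechanism rv_fallback_impl relies on to call the Op's perform method:

import numba
import numpy as np


@numba.njit
def draw_standard_normal(seed):
    # Declare the type of every variable that escapes the object-mode block.
    with numba.objmode(out="float64[:]"):
        # This block runs as ordinary interpreted Python/NumPy, so APIs that
        # numba cannot compile are available here.
        rng = np.random.default_rng(seed)
        out = rng.standard_normal(3)
    return out


print(draw_standard_normal(123))  # three draws, computed in object mode

The trade-off is speed: everything inside the objmode block runs at interpreter speed, which is why it is only used as a fallback.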


@numba_funcify.register
def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs):
    core_shape = node.inputs[0]

    [rv_node] = op.fgraph.apply_nodes
    rv_op: RandomVariable = rv_node.op
    size = rv_op.size_param(rv_node)
    dist_params = rv_op.dist_params(rv_node)
    size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size)
    core_shape_len = get_vector_length(core_shape)
    inplace = rv_op.inplace

    try:
        core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)

Reviewer comment (Member): Only this line should be inside the try/except.

        nin = 1 + len(dist_params)  # rng + params
        core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1)

        batch_ndim = rv_op.batch_ndim(rv_node)

        # numba doesn't support nested literals right now...
        input_bc_patterns = encode_literals(
            tuple(
                input_var.type.broadcastable[:batch_ndim] for input_var in dist_params
            )
        )
        output_bc_patterns = encode_literals(
            (rv_node.outputs[1].type.broadcastable[:batch_ndim],)
        )
        output_dtypes = encode_literals((rv_node.default_output().type.dtype,))
        inplace_pattern = encode_literals(())

        def random_wrapper(core_shape, rng, size, *dist_params):
            if not inplace:
                rng = copy(rng)

            draws = _vectorized(
                core_op_fn,
                input_bc_patterns,
                output_bc_patterns,
                output_dtypes,
                inplace_pattern,
                (rng,),
                dist_params,
                (numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),),
                None
                if size_len is None
                else numba_ndarray.to_fixed_tuple(size, size_len),
            )
            return rng, draws

        def random(core_shape, rng, size, *dist_params):
            raise NotImplementedError("Non-jitted random variable not implemented")

        @overload(random, jit_options=_jit_options)
        def ov_random(core_shape, rng, size, *dist_params):
            return random_wrapper

        return random

    except NotImplementedError:
        # Fall back to object mode for random variables that don't have a core implementation
        return rv_fallback_impl(op, node)
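
The try/except dispatch above, together with the reviewer's request to narrow the guarded region, boils down to the pattern sketched below. The names lookup_core_impl and object_mode_impl are stand-ins invented for this sketch, not PyTensor API: only the lookup that can legitimately raise NotImplementedError is guarded, so bugs in the jitted path are not silently converted into object-mode fallbacks.

import warnings


def lookup_core_impl(op):
    """Stand-in for numba_core_rv_funcify: raises NotImplementedError if unsupported."""
    raise NotImplementedError(f"No core implementation for {op!r}")


def object_mode_impl(op):
    """Stand-in for rv_fallback_impl: an interpreted fallback."""
    return lambda *args: None


def funcify_with_fallback(op):
    try:
        core_fn = lookup_core_impl(op)  # only the lookup sits in the try block
    except NotImplementedError:
        warnings.warn(f"Numba will use object mode to execute {op!r}", UserWarning)
        return object_mode_impl(op)

    def jitted(*args):
        # Built only when a core implementation exists.
        return core_fn(*args)

    return jitted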
49 changes: 49 additions & 0 deletions tests/link/numba/test_random.py
@@ -705,3 +705,52 @@ def test_repeated_args():
    final_node = fn.maker.fgraph.outputs[0].owner
    assert isinstance(final_node.op, RandomVariableWithCoreShape)
    assert final_node.inputs[-2] is final_node.inputs[-1]


def test_unsupported_rv_fallback():
    """Test that unsupported random variables fall back to object mode."""
    import warnings

    # Create a mock random variable that doesn't have a numba implementation
    class CustomRV(ptr.RandomVariable):
        name = "custom"
        signature = "(d)->(d)"  # We need a parameter for the test to pass

Reviewer comment (Member): Create a univariate RV, which would make for a simpler test.

        dtype = "float64"

        def _supp_shape_from_params(self, dist_params, param_shapes=None):
            # Return the shape of the support
            return [1]

        def rng_fn(self, rng, value, size=None):
            # Just return the value plus a random number
            return value + rng.standard_normal()
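
    # A sketch of the reviewer-suggested simplification (illustrative only, not part
    # of this PR): a univariate RV with a scalar signature would not need
    # _supp_shape_from_params at all.
    #
    # class CustomScalarRV(ptr.RandomVariable):
    #     name = "custom_scalar"
    #     signature = "()->()"
    #     dtype = "float64"
    #
    #     def rng_fn(self, rng, value, size=None):
    #         return value + rng.standard_normal(size)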

    custom_rv = CustomRV()

    # Create a graph with the unsupported RV
    rng = shared(np.random.default_rng(123))
    value = np.array(1.0)
    x = custom_rv(value, rng=rng)

    # Capture warnings to check for the fallback warning
    with warnings.catch_warnings(record=True) as w:

Reviewer comment (Member): Use pytest.warns.

        warnings.simplefilter("always")

        # Compile with numba mode
        fn = function([], x, mode=numba_mode)

        # Execute to trigger the fallback
        result = fn()

    # Check that a warning was raised about object mode
    assert any("object mode" in str(warning.message) for warning in w)

    # Verify the result is as expected
    assert isinstance(result, np.ndarray)

    # Run again to make sure the compiled function works properly
    result2 = fn()
    assert isinstance(result2, np.ndarray)
    assert not np.array_equal(
        result, result2
    )  # Results should differ with different RNG states

Reviewer comment (Member): This will fail because updates were not set.

Reviewer comment (Member): Actually, test both with and without updates; the result should change in the first case and stay the same in the second.

Reviewer comment (Member): Also set the seed twice and compare the draws, to make sure the RNG is actually following it.
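
A hedged sketch of what the reviewers ask for: pytest.warns instead of catch_warnings, an explicit RNG update so consecutive calls advance the generator, and a re-seed check. It reuses CustomRV, shared, function, and numba_mode from the surrounding test module and assumes the RV node exposes its next-RNG output as owner.outputs[0]; none of this code is part of the PR.

import numpy as np
import pytest


def test_fallback_warns_and_updates():
    custom_rv = CustomRV()
    rng = shared(np.random.default_rng(123))
    next_rng, x = custom_rv(np.array(1.0), rng=rng).owner.outputs

    # The fallback warning is emitted at compile time, so wrap function().
    with pytest.warns(UserWarning, match="object mode"):
        fn = function([], x, updates={rng: next_rng}, mode=numba_mode)

    # With the update registered, the RNG state advances between calls.
    first, second = fn(), fn()
    assert not np.array_equal(first, second)

    # Re-seeding the shared RNG should reproduce the original draw.
    rng.set_value(np.random.default_rng(123))
    assert np.array_equal(fn(), first)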