-
Notifications
You must be signed in to change notification settings - Fork 129
Add object mode fallback for Numba RandomVariables #1249
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -386,50 +386,45 @@ def numba_funcify_RandomVariable_core(op: RandomVariable, **kwargs): | |
) | ||
|
||
|
||
@numba_funcify.register | ||
def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs): | ||
core_shape = node.inputs[0] | ||
def rv_fallback_impl(op: RandomVariableWithCoreShape, node): | ||
"""Create a fallback implementation for random variables using object mode.""" | ||
import warnings | ||
|
||
[rv_node] = op.fgraph.apply_nodes | ||
rv_op: RandomVariable = rv_node.op | ||
|
||
warnings.warn( | ||
f"Numba will use object mode to execute the random variable {rv_op.name}", | ||
UserWarning, | ||
) | ||
|
||
size = rv_op.size_param(rv_node) | ||
dist_params = rv_op.dist_params(rv_node) | ||
size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size) | ||
core_shape_len = get_vector_length(core_shape) | ||
inplace = rv_op.inplace | ||
|
||
core_rv_fn = numba_core_rv_funcify(rv_op, rv_node) | ||
nin = 1 + len(dist_params) # rng + params | ||
core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1) | ||
|
||
batch_ndim = rv_op.batch_ndim(rv_node) | ||
|
||
# numba doesn't support nested literals right now... | ||
input_bc_patterns = encode_literals( | ||
tuple(input_var.type.broadcastable[:batch_ndim] for input_var in dist_params) | ||
) | ||
output_bc_patterns = encode_literals( | ||
(rv_node.outputs[1].type.broadcastable[:batch_ndim],) | ||
) | ||
output_dtypes = encode_literals((rv_node.default_output().type.dtype,)) | ||
inplace_pattern = encode_literals(()) | ||
|
||
def random_wrapper(core_shape, rng, size, *dist_params): | ||
if not inplace: | ||
rng = copy(rng) | ||
|
||
draws = _vectorized( | ||
core_op_fn, | ||
input_bc_patterns, | ||
output_bc_patterns, | ||
output_dtypes, | ||
inplace_pattern, | ||
(rng,), | ||
dist_params, | ||
(numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),), | ||
None if size_len is None else numba_ndarray.to_fixed_tuple(size, size_len), | ||
fixed_size = ( | ||
None if size_len is None else numba_ndarray.to_fixed_tuple(size, size_len) | ||
) | ||
return rng, draws | ||
|
||
with numba.objmode(res="UniTuple(types.npy_rng, types.pyobject)"): | ||
# Convert tuple params back to arrays for perform method | ||
np_dist_params = [np.asarray(p) for p in dist_params] | ||
|
||
# Prepare output storage for perform method | ||
outputs = [[None], [None]] | ||
|
||
# Call the perform method directly | ||
rv_op.perform(rv_node, [rng, fixed_size, *np_dist_params], outputs) | ||
|
||
next_rng = outputs[0][0] | ||
result = outputs[1][0] | ||
res = (next_rng, result) | ||
|
||
return res | ||
|
||
def random(core_shape, rng, size, *dist_params): | ||
raise NotImplementedError("Non-jitted random variable not implemented") | ||
|
@@ -439,3 +434,67 @@ def ov_random(core_shape, rng, size, *dist_params): | |
return random_wrapper | ||
|
||
return random | ||
|
||
|
||
@numba_funcify.register | ||
def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs): | ||
core_shape = node.inputs[0] | ||
|
||
[rv_node] = op.fgraph.apply_nodes | ||
rv_op: RandomVariable = rv_node.op | ||
size = rv_op.size_param(rv_node) | ||
dist_params = rv_op.dist_params(rv_node) | ||
size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size) | ||
core_shape_len = get_vector_length(core_shape) | ||
inplace = rv_op.inplace | ||
|
||
try: | ||
core_rv_fn = numba_core_rv_funcify(rv_op, rv_node) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. only this line should be in the try except |
||
nin = 1 + len(dist_params) # rng + params | ||
core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1) | ||
|
||
batch_ndim = rv_op.batch_ndim(rv_node) | ||
|
||
# numba doesn't support nested literals right now... | ||
input_bc_patterns = encode_literals( | ||
tuple( | ||
input_var.type.broadcastable[:batch_ndim] for input_var in dist_params | ||
) | ||
) | ||
output_bc_patterns = encode_literals( | ||
(rv_node.outputs[1].type.broadcastable[:batch_ndim],) | ||
) | ||
output_dtypes = encode_literals((rv_node.default_output().type.dtype,)) | ||
inplace_pattern = encode_literals(()) | ||
|
||
def random_wrapper(core_shape, rng, size, *dist_params): | ||
if not inplace: | ||
rng = copy(rng) | ||
|
||
draws = _vectorized( | ||
core_op_fn, | ||
input_bc_patterns, | ||
output_bc_patterns, | ||
output_dtypes, | ||
inplace_pattern, | ||
(rng,), | ||
dist_params, | ||
(numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),), | ||
None | ||
if size_len is None | ||
else numba_ndarray.to_fixed_tuple(size, size_len), | ||
) | ||
return rng, draws | ||
|
||
def random(core_shape, rng, size, *dist_params): | ||
raise NotImplementedError("Non-jitted random variable not implemented") | ||
|
||
@overload(random, jit_options=_jit_options) | ||
def ov_random(core_shape, rng, size, *dist_params): | ||
return random_wrapper | ||
|
||
return random | ||
|
||
except NotImplementedError: | ||
# Fall back to object mode for random variables that don't have core implementation | ||
return rv_fallback_impl(op, node) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -705,3 +705,52 @@ def test_repeated_args(): | |
final_node = fn.maker.fgraph.outputs[0].owner | ||
assert isinstance(final_node.op, RandomVariableWithCoreShape) | ||
assert final_node.inputs[-2] is final_node.inputs[-1] | ||
|
||
|
||
def test_unsupported_rv_fallback(): | ||
"""Test that unsupported random variables fallback to object mode.""" | ||
import warnings | ||
|
||
# Create a mock random variable that doesn't have a numba implementation | ||
class CustomRV(ptr.RandomVariable): | ||
name = "custom" | ||
signature = "(d)->(d)" # We need a parameter for test to pass | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. create a univariate rv which will be a simpler test |
||
dtype = "float64" | ||
|
||
def _supp_shape_from_params(self, dist_params, param_shapes=None): | ||
# Return the shape of the support | ||
return [1] | ||
|
||
def rng_fn(self, rng, value, size=None): | ||
# Just return the value plus a random number | ||
return value + rng.standard_normal() | ||
|
||
custom_rv = CustomRV() | ||
|
||
# Create a graph with the unsupported RV | ||
rng = shared(np.random.default_rng(123)) | ||
value = np.array(1.0) | ||
x = custom_rv(value, rng=rng) | ||
|
||
# Capture warnings to check for the fallback warning | ||
with warnings.catch_warnings(record=True) as w: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use pytest.warns |
||
warnings.simplefilter("always") | ||
|
||
# Compile with numba mode | ||
fn = function([], x, mode=numba_mode) | ||
|
||
# Execute to trigger the fallback | ||
result = fn() | ||
|
||
# Check that a warning was raised about object mode | ||
assert any("object mode" in str(warning.message) for warning in w) | ||
|
||
# Verify the result is as expected | ||
assert isinstance(result, np.ndarray) | ||
|
||
# Run again to make sure the compiled function works properly | ||
result2 = fn() | ||
assert isinstance(result2, np.ndarray) | ||
assert not np.array_equal( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will fail because updates were not set There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually test with and without updates, in which case it should change or stay the same There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also set seed twice and compare to make sure it's following it |
||
result, result2 | ||
) # Results should differ with different RNG states |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We already have a generic fallback implementation function can't we just use it like we do for other Ops?
May just need to do the unboxing of the RV that the other function is doing