Implement infer_shape automatically from gufunc_signature #1257

Open
ricardoV94 opened this issue Feb 28, 2025 · 0 comments · May be fixed by #1294

@ricardoV94
Member

Description

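For Ops that define a `gufunc_signature`, we can implement `infer_shape` automatically instead of writing it by hand. For example, `Cholesky` hard-codes an `infer_shape` that its signature `(m,m)->(m,m)` already implies: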

```python
from typing import Literal

from pytensor.graph.op import Op


class Cholesky(Op):
    # TODO: LAPACK wrapper with in-place behavior, for solve also
    __props__ = ("lower", "check_finite", "on_error", "overwrite_a")
    gufunc_signature = "(m,m)->(m,m)"

    def __init__(
        self,
        *,
        lower: bool = True,
        check_finite: bool = True,
        on_error: Literal["raise", "nan"] = "raise",
        overwrite_a: bool = False,
    ):
        self.lower = lower
        self.check_finite = check_finite
        if on_error not in ("raise", "nan"):
            raise ValueError('on_error must be one of "raise" or "nan"')
        self.on_error = on_error
        self.overwrite_a = overwrite_a
        if self.overwrite_a:
            # Output 0 overwrites input 0 in place
            self.destroy_map = {0: [0]}

    # ... make_node, perform, etc. omitted from this excerpt

    def infer_shape(self, fgraph, node, shapes):
        # Written by hand, but fully implied by gufunc_signature:
        # the output has the same shape as the input
        return [shapes[0]]
```
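A minimal sketch of the first step, using only the standard library: parse the signature and, for each output core dimension, find an input axis with the same name. The helpers `parse_gufunc_signature` and `output_dim_sources` are hypothetical names, not existing PyTensor API, and the parser treats every entry as a plain dimension name (no special handling for fixed sizes or NumPy's `?` modifier).

```python
import re

# One parenthesised operand, e.g. "(m,n)" or "()"; captures the text inside.
_OPERAND_RE = re.compile(r"\(([^()]*)\)")


def parse_gufunc_signature(signature: str) -> tuple[list[tuple[str, ...]], list[tuple[str, ...]]]:
    """Split e.g. "(m,n),(n)->(m)" into per-operand tuples of core dim names."""
    inputs_part, outputs_part = signature.replace(" ", "").split("->")

    def parse_side(side: str) -> list[tuple[str, ...]]:
        # "" (from "()") yields an empty tuple, i.e. a scalar core operand.
        return [tuple(n for n in group.split(",") if n) for group in _OPERAND_RE.findall(side)]

    return parse_side(inputs_part), parse_side(outputs_part)


def output_dim_sources(signature: str) -> list[list[tuple[int, int]]]:
    """For every output core dim, return (input_index, axis) of the first input axis with that name."""
    inputs, outputs = parse_gufunc_signature(signature)
    first_seen: dict[str, tuple[int, int]] = {}
    for i, dims in enumerate(inputs):
        for axis, name in enumerate(dims):
            first_seen.setdefault(name, (i, axis))

    sources = []
    for dims in outputs:
        missing = [name for name in dims if name not in first_seen]
        if missing:
            # Such dims cannot be inferred from the input shapes alone.
            raise ValueError(f"Output dims {missing} do not appear in any input")
        sources.append([first_seen[name] for name in dims])
    return sources


print(output_dim_sources("(m,m)->(m,m)"))    # Cholesky: [[(0, 0), (0, 0)]]
print(output_dim_sources("(m,n),(n)->(m)"))  # matrix-vector product: [[(0, 0)]]
```

For `Cholesky`, both output axes map to axis 0 of input 0. That is equivalent to the hand-written `return [shapes[0]]`, because the signature declares the input square.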

We already do this for the core output shape in the `Blockwise` wrapper:

```python
# The output dim is the same as another input dim
if dim_name in core_dims:
    core_out_shape.append(core_dims[dim_name])
```
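The same lookup can be lifted into a generic, signature-driven `infer_shape` body. This sketch assumes the signature has already been parsed into per-operand tuples of dim names, and that `input_shapes` contains only the core dimensions of each input (batch dimensions are assumed to be handled elsewhere, e.g. by `Blockwise`). `infer_core_output_shapes` is a hypothetical name. Output dims that appear in no input raise `NotImplementedError`, since the signature alone cannot determine them.

```python
from collections.abc import Sequence
from typing import TypeVar

# Shape entries: ints in this example, symbolic scalars inside a real Op.
T = TypeVar("T")


def infer_core_output_shapes(
    input_sig: Sequence[tuple[str, ...]],
    output_sig: Sequence[tuple[str, ...]],
    input_shapes: Sequence[Sequence[T]],
) -> list[tuple[T, ...]]:
    """Derive core output shapes from a parsed gufunc signature (hypothetical helper)."""
    if len(input_sig) != len(input_shapes):
        raise ValueError("Number of input shapes does not match the signature")

    # Record the length of the first input axis carrying each core dim name.
    core_dims: dict[str, T] = {}
    for dims, shape in zip(input_sig, input_shapes):
        if len(dims) != len(shape):
            raise ValueError(f"Expected {len(dims)} core dims, got {len(shape)}")
        for dim_name, length in zip(dims, shape):
            core_dims.setdefault(dim_name, length)

    out_shapes: list[tuple[T, ...]] = []
    for dims in output_sig:
        core_out_shape = []
        for dim_name in dims:
            # The output dim is the same as another input dim
            if dim_name in core_dims:
                core_out_shape.append(core_dims[dim_name])
            else:
                # Not determined by the signature; the Op would need custom logic.
                raise NotImplementedError(f"Cannot infer output dim {dim_name!r} from inputs")
        out_shapes.append(tuple(core_out_shape))
    return out_shapes


print(infer_core_output_shapes([("m", "m")], [("m", "m")], [(3, 3)]))  # [(3, 3)]
```

Inside an Op, `infer_shape(self, fgraph, node, shapes)` could return this function's result, with `self.gufunc_signature` parsed once beforehand. It works unchanged on symbolic shape entries because it only stores and looks them up by dim name and never compares them.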
