From d25f214af7793e104f8f2da940eac3b4ad25293d Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 11 Apr 2025 11:32:13 -0400 Subject: [PATCH 01/54] mlx poc --- pytensor/compile/mode.py | 19 +++++ pytensor/link/mlx/dispatch/__init__.py | 5 ++ pytensor/link/mlx/dispatch/basic.py | 61 +++++++++++++ pytensor/link/mlx/dispatch/math.py | 12 +++ pytensor/link/mlx/linker.py | 113 +++++++++++++++++++++++++ pytensor/link/pytorch/linker.py | 16 ++-- 6 files changed, 218 insertions(+), 8 deletions(-) create mode 100644 pytensor/link/mlx/dispatch/__init__.py create mode 100644 pytensor/link/mlx/dispatch/basic.py create mode 100644 pytensor/link/mlx/dispatch/math.py create mode 100644 pytensor/link/mlx/linker.py diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index f80dfaaf5c..ce58561212 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -27,6 +27,7 @@ from pytensor.link.basic import Linker, PerformLinker from pytensor.link.c.basic import CLinker, OpWiseCLinker from pytensor.link.jax.linker import JAXLinker +from pytensor.link.mlx.linker import MLXLinker from pytensor.link.numba.linker import NumbaLinker from pytensor.link.pytorch.linker import PytorchLinker from pytensor.link.vm import VMLinker @@ -50,6 +51,7 @@ "jax": JAXLinker(), "pytorch": PytorchLinker(), "numba": NumbaLinker(), + "mlx": MLXLinker(), } @@ -494,6 +496,20 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): ), ) +MLX = Mode( + MLXLinker(), + RewriteDatabaseQuery( + include=["fast_run"], + exclude=[ + "cxx_only", + "BlasOpt", + "fusion", + "inplace", + "scan_save_mem_prealloc", + ], + ), +) + predefined_modes = { "FAST_COMPILE": FAST_COMPILE, @@ -501,6 +517,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): "JAX": JAX, "NUMBA": NUMBA, "PYTORCH": PYTORCH, + "MLX": MLX, } _CACHED_RUNTIME_MODES: dict[str, Mode] = {} @@ -585,6 +602,8 @@ def get_target_language(mode=None) -> tuple[Literal["py", "c", "numba", "jax"], return ("py",) if isinstance(linker, CLinker): return ("c",) + if isinstance(linker, MLXLinker): + return ("py",) if isinstance(linker, VMLinker | OpWiseCLinker): return ("c", "py") if config.cxx else ("py",) diff --git a/pytensor/link/mlx/dispatch/__init__.py b/pytensor/link/mlx/dispatch/__init__.py new file mode 100644 index 0000000000..7acb41e1b5 --- /dev/null +++ b/pytensor/link/mlx/dispatch/__init__.py @@ -0,0 +1,5 @@ +# isort: off +from pytensor.link.mlx.dispatch.basic import mlx_funcify, mlx_typify + +import pytensor.link.mlx.dispatch.math +# isort: on diff --git a/pytensor/link/mlx/dispatch/basic.py b/pytensor/link/mlx/dispatch/basic.py new file mode 100644 index 0000000000..9cbb92118d --- /dev/null +++ b/pytensor/link/mlx/dispatch/basic.py @@ -0,0 +1,61 @@ +from functools import singledispatch +from types import NoneType + +import mlx.core as mx +import numpy as np + +from pytensor.compile.ops import DeepCopyOp +from pytensor.graph.fg import FunctionGraph +from pytensor.link.utils import fgraph_to_python + + +@singledispatch +def mlx_typify(data, **kwargs): + raise NotImplementedError(f"mlx_typify is not implemented for {type(data)}") + + +@mlx_typify.register(np.ndarray) +@mlx_typify.register(mx.array) +def mlx_typify_tensor(data, dtype=None, **kwargs): + return mx.array(data, dtype=dtype) + + +@mlx_typify.register(slice) +@mlx_typify.register(NoneType) +@mlx_typify.register(np.number) +def mlx_typify_no_conversion_needed(data, **kwargs): + return data + + +@singledispatch +def mlx_funcify(op, node=None, storage_map=None, **kwargs): + """Create a MLX 
compatible function from a PyTensor `Op`."""
+    raise NotImplementedError(
+        f"No MLX conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/1350` for progress or to request we prioritize this operation"
+    )
+
+
+@mlx_funcify.register(FunctionGraph)
+def mlx_funcify_FunctionGraph(
+    fgraph,
+    node=None,
+    fgraph_name="mlx_funcified_fgraph",
+    conversion_func=mlx_funcify,
+    **kwargs,
+):
+    built_kwargs = {"conversion_func": conversion_func, **kwargs}
+    return fgraph_to_python(
+        fgraph,
+        conversion_func,
+        type_conversion_fn=mlx_typify,
+        fgraph_name=fgraph_name,
+        **built_kwargs,
+    )
+
+
+@mlx_funcify.register(DeepCopyOp)
+def mlx_funcify_DeepCopyOp(op, **kwargs):
+    def deepcopyop(x):
+        return x.copy()
+
+    return deepcopyop
diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py
new file mode 100644
index 0000000000..1ef7ec4608
--- /dev/null
+++ b/pytensor/link/mlx/dispatch/math.py
@@ -0,0 +1,12 @@
+import mlx.core as mx
+
+from pytensor.link.mlx.dispatch import mlx_funcify
+from pytensor.tensor.math import Dot
+
+
+@mlx_funcify.register(Dot)
+def mlx_funcify_Dot(op, **kwargs):
+    def dot(x, y):
+        return mx.matmul(x, y)
+
+    return dot
diff --git a/pytensor/link/mlx/linker.py b/pytensor/link/mlx/linker.py
new file mode 100644
index 0000000000..8cfd9a0ff5
--- /dev/null
+++ b/pytensor/link/mlx/linker.py
@@ -0,0 +1,113 @@
+from pytensor.link.basic import JITLinker
+from pytensor.link.utils import unique_name_generator
+
+
+class MLXLinker(JITLinker):
+    """A `Linker` that JIT-compiles NumPy-based operations using Apple's MLX."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.gen_functors = []
+
+    def fgraph_convert(
+        self,
+        fgraph,
+        order,
+        input_storage,
+        output_storage,
+        storage_map,
+        **kwargs,
+    ):
+        """Convert a PyTensor FunctionGraph to an MLX-compatible function.
+
+        Parameters
+        ----------
+        fgraph : FunctionGraph
+            The function graph to convert
+        order : list
+            The order in which to compute the nodes
+        input_storage : list
+            Storage for the input variables
+        output_storage : list
+            Storage for the output variables
+        storage_map : dict
+            Map from variables to their storage
+
+        Returns
+        -------
+        callable
+            An MLX-compatible function
+        """
+        from pytensor.link.mlx.dispatch import mlx_funcify
+
+        # We want to have globally unique names
+        # across the entire pytensor graph, not
+        # just the subgraph
+        generator = unique_name_generator(["mlx_linker"])
+
+        # Ensure that torch is aware of the generated
+        # code so we can compile without graph breaks
+        def conversion_func_register(*args, **kwargs):
+            functor = mlx_funcify(*args, **kwargs)
+            name = kwargs["unique_name"](functor)
+            self.gen_functors.append((f"_{name}", functor))
+            return functor
+
+        built_kwargs = {
+            "unique_name": generator,
+            "conversion_func": conversion_func_register,
+            **kwargs,
+        }
+        return mlx_funcify(
+            fgraph,
+            input_storage=input_storage,
+            storage_map=storage_map,
+            **built_kwargs,
+        )
+
+    def jit_compile(self, fn):
+        """JIT compile an MLX function.
+
+        Parameters
+        ----------
+        fn : callable
+            The function to compile
+
+        Returns
+        -------
+        callable
+            The compiled function
+        """
+        import mlx.core as mx
+
+        return mx.compile(fn)
+
+    def create_thunk_inputs(self, storage_map):
+        """Create inputs for the MLX thunk. 
+ + Parameters + ---------- + storage_map : dict + Map from variables to their storage + + Returns + ------- + list + The inputs for the thunk + """ + from numpy.random import Generator, RandomState + + from pytensor.link.mlx.dispatch import mlx_typify + + thunk_inputs = [] + for n in self.fgraph.inputs: + sinput = storage_map[n] + # Handle random number generators specially + if isinstance(sinput[0], RandomState | Generator): + new_value = mlx_typify( + sinput[0], dtype=getattr(sinput[0], "dtype", None) + ) + sinput[0] = new_value + thunk_inputs.append(sinput) + + return thunk_inputs diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index b8475e3157..0a057a9e8d 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -31,16 +31,16 @@ def conversion_func_register(*args, **kwargs): **kwargs, } return pytorch_funcify( - fgraph, input_storage=input_storage, storage_map=storage_map, **built_kwargs + fgraph, + input_storage=input_storage, + storage_map=storage_map, + **built_kwargs, ) def jit_compile(self, fn): - import torch + import mlx.core as mx - # flag that tend to help our graphs - torch._dynamo.config.capture_dynamic_output_shape_ops = True - - from pytensor.link.pytorch.dispatch import pytorch_typify + from pytensor.link.mlx.dispatch import mlx_typify class wrapper: """ @@ -54,7 +54,7 @@ class wrapper: """ def __init__(self, fn, gen_functors): - self.fn = torch.compile(fn) + self.fn = mx.compile(fn) self.gen_functors = gen_functors.copy() def __call__(self, *inputs, **kwargs): @@ -65,7 +65,7 @@ def __call__(self, *inputs, **kwargs): setattr(pytensor.link.utils, n[1:], fn) # Torch does not accept numpy inputs and may return GPU objects - outs = self.fn(*(pytorch_typify(inp) for inp in inputs), **kwargs) + outs = self.fn(*(mlx_typify(inp) for inp in inputs), **kwargs) # unset attrs for n, _ in self.gen_functors: From edacc0ed246c13a7bc664a5b4a5f078fccba68eb Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 11 Apr 2025 11:38:56 -0400 Subject: [PATCH 02/54] add test for dot --- tests/link/mlx/dispatch/test_math.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 tests/link/mlx/dispatch/test_math.py diff --git a/tests/link/mlx/dispatch/test_math.py b/tests/link/mlx/dispatch/test_math.py new file mode 100644 index 0000000000..5608321a80 --- /dev/null +++ b/tests/link/mlx/dispatch/test_math.py @@ -0,0 +1,19 @@ +import numpy as np + +import pytensor +from pytensor.tensor.type import matrix + + +def test_mlx_dot(): + x = matrix("x") + y = matrix("y") + + out = x.dot(y) + fn = pytensor.function([x, y], out, mode="MLX") + + test_x = np.random.normal(size=(3, 2)) + test_y = np.random.normal(size=(2, 4)) + np.testing.assert_allclose( + fn(test_x, test_y), + np.dot(test_x, test_y), + ) From 052fdc23e80ad097373e5782eb06b65f75dead17 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 11 Apr 2025 12:26:49 -0400 Subject: [PATCH 03/54] restore pytorch --- pytensor/link/pytorch/linker.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index 0a057a9e8d..18824a5b71 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -38,9 +38,11 @@ def conversion_func_register(*args, **kwargs): ) def jit_compile(self, fn): - import mlx.core as mx + import torch - from pytensor.link.mlx.dispatch import mlx_typify + torch._dynamo.config.capture_dynamic_output_shape_ops = True + + from 
pytensor.link.pytorch.dispatch import pytorch_typify class wrapper: """ @@ -54,7 +56,7 @@ class wrapper: """ def __init__(self, fn, gen_functors): - self.fn = mx.compile(fn) + self.fn = torch.compile(fn) self.gen_functors = gen_functors.copy() def __call__(self, *inputs, **kwargs): @@ -65,7 +67,7 @@ def __call__(self, *inputs, **kwargs): setattr(pytensor.link.utils, n[1:], fn) # Torch does not accept numpy inputs and may return GPU objects - outs = self.fn(*(mlx_typify(inp) for inp in inputs), **kwargs) + outs = self.fn(*(pytorch_typify(inp) for inp in inputs), **kwargs) # unset attrs for n, _ in self.gen_functors: From a9ecad0f8e41ae32e0b0148e2af15695d9c735f7 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 11 Apr 2025 12:31:07 -0400 Subject: [PATCH 04/54] wrap in mx.array --- tests/link/mlx/dispatch/test_math.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/link/mlx/dispatch/test_math.py b/tests/link/mlx/dispatch/test_math.py index 5608321a80..3b2c41167f 100644 --- a/tests/link/mlx/dispatch/test_math.py +++ b/tests/link/mlx/dispatch/test_math.py @@ -1,3 +1,4 @@ +import mlx.core as mx import numpy as np import pytensor @@ -11,8 +12,8 @@ def test_mlx_dot(): out = x.dot(y) fn = pytensor.function([x, y], out, mode="MLX") - test_x = np.random.normal(size=(3, 2)) - test_y = np.random.normal(size=(2, 4)) + test_x = mx.array(np.random.normal(size=(3, 2))) + test_y = mx.array(np.random.normal(size=(2, 4))) np.testing.assert_allclose( fn(test_x, test_y), np.dot(test_x, test_y), From e690bff174f4030d713bd51c61a3af46ad6652f6 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 11 Apr 2025 12:32:29 -0400 Subject: [PATCH 05/54] modify the pytorch jit --- pytensor/link/mlx/linker.py | 44 +++++++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/pytensor/link/mlx/linker.py b/pytensor/link/mlx/linker.py index 8cfd9a0ff5..c2c970aebf 100644 --- a/pytensor/link/mlx/linker.py +++ b/pytensor/link/mlx/linker.py @@ -66,21 +66,41 @@ def conversion_func_register(*args, **kwargs): ) def jit_compile(self, fn): - """JIT compile an MLX function. + import mlx.core as mx - Parameters - ---------- - fn : callable - The function to compile + from pytensor.link.mlx.dispatch import mlx_typify - Returns - ------- - callable - The compiled function - """ - import mlx.core as mx + class wrapper: + def __init__(self, fn, gen_functors): + self.fn = mx.compile(fn) + self.gen_functors = gen_functors.copy() + + def __call__(self, *inputs, **kwargs): + import pytensor.link.utils + + # set attrs + for n, fn in self.gen_functors: + setattr(pytensor.link.utils, n[1:], fn) + + # MLX doesn't support np.ndarray as input + outs = self.fn(*(mlx_typify(inp) for inp in inputs), **kwargs) + + return outs + + # unset attrs + for n, _ in self.gen_functors: + if getattr(pytensor.link.utils, n[1:], False): + delattr(pytensor.link.utils, n[1:]) + + return tuple(out.cpu().numpy() for out in outs) + + def __del__(self): + del self.gen_functors + + inner_fn = wrapper(fn, self.gen_functors) + self.gen_functors = [] - return mx.compile(fn) + return inner_fn def create_thunk_inputs(self, storage_map): """Create inputs for the MLX thunk. 
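A quick illustration of the call path that the jit_compile wrapper in the patch above (pytensor/link/mlx/linker.py) sets up: mx.compile traces the funcified graph, and every input is first routed through mlx_typify because, as the code comments note, MLX does not accept np.ndarray inputs to the compiled callable. Below is a minimal standalone sketch, assuming mlx is installed (Apple silicon); `f` here is a hypothetical stand-in for the funcified FunctionGraph, and the gen_functors bookkeeping is omitted:

import mlx.core as mx
import numpy as np

def f(x, y):
    # hypothetical stand-in for the funcified FunctionGraph
    return mx.matmul(x, y)

compiled = mx.compile(f)

def wrapper(*inputs):
    # mirror the mlx_typify pass: convert np.ndarray inputs to mx.array
    # before calling the compiled graph
    return compiled(*(mx.array(inp) for inp in inputs))

out = wrapper(np.ones((3, 2)), np.ones((2, 4)))  # mx.array of shape (3, 4)

PATCH 15/54 below collapses this wrapper into a plain closure once the gen_functors bookkeeping is dropped.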
From ad29c1780ef05a16cea7a06a08e258f97852eeb1 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 11 Apr 2025 12:38:06 -0400 Subject: [PATCH 06/54] move file --- tests/link/mlx/{dispatch => }/test_math.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/link/mlx/{dispatch => }/test_math.py (100%) diff --git a/tests/link/mlx/dispatch/test_math.py b/tests/link/mlx/test_math.py similarity index 100% rename from tests/link/mlx/dispatch/test_math.py rename to tests/link/mlx/test_math.py From ba29b373df4fd283a1b542e0d0df50ab3fb35269 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 11 Apr 2025 12:47:38 -0400 Subject: [PATCH 07/54] dont wrap --- tests/link/mlx/test_math.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/link/mlx/test_math.py b/tests/link/mlx/test_math.py index 3b2c41167f..28397d7643 100644 --- a/tests/link/mlx/test_math.py +++ b/tests/link/mlx/test_math.py @@ -1,4 +1,3 @@ -import mlx.core as mx import numpy as np import pytensor @@ -12,9 +11,12 @@ def test_mlx_dot(): out = x.dot(y) fn = pytensor.function([x, y], out, mode="MLX") - test_x = mx.array(np.random.normal(size=(3, 2))) - test_y = mx.array(np.random.normal(size=(2, 4))) - np.testing.assert_allclose( - fn(test_x, test_y), - np.dot(test_x, test_y), - ) + seed = sum(map(ord, "test_mlx_dot")) + rng = np.random.default_rng(seed) + + test_x = rng.normal(size=(3, 2)) + test_y = rng.normal(size=(2, 4)) + + actual = fn(test_x, test_y) + expected = np.dot(test_x, test_y) + np.testing.assert_allclose(actual, expected) From 87168707285534ca73b449c99ca2a3b2ecd588ed Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 11 Apr 2025 12:53:48 -0400 Subject: [PATCH 08/54] attempt to fix github action --- .github/workflows/test.yml | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9677615206..46800a1e13 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -82,6 +82,7 @@ jobs: install-numba: [0] install-jax: [0] install-torch: [0] + install-mlx: [0] part: - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse" - "tests/scan" @@ -115,6 +116,7 @@ jobs: install-numba: 0 install-jax: 0 install-torch: 0 + install-mlx: 0 - install-numba: 1 os: "ubuntu-latest" python-version: "3.10" @@ -150,6 +152,13 @@ jobs: fast-compile: 0 float32: 0 part: "tests/link/pytorch" + - install-mlx: 1 + os: "ubuntu-latest" + python-version: "3.10" + numpy-version: ">=2.0" + fast-compile: 0 + float32: 0 + part: "tests/link/mlx" - os: macos-15 python-version: "3.13" numpy-version: ">=2.0" @@ -196,6 +205,7 @@ jobs: if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi + if [[ $INSTALL_MLX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" mlx; fi pip install pytest-sphinx pip install -e ./ @@ -212,6 +222,7 @@ jobs: INSTALL_NUMBA: ${{ matrix.install-numba }} INSTALL_JAX: ${{ matrix.install-jax }} INSTALL_TORCH: ${{ matrix.install-torch}} + INSTALL_MLX: ${{ matrix.install-mlx }} OS: ${{ matrix.os}} - name: Run tests From 9bf7edfb9a6e73675594f6ee1964086ff9f67d75 
Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 11 Apr 2025 12:55:58 -0400 Subject: [PATCH 09/54] change the rtol --- tests/link/mlx/test_math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/link/mlx/test_math.py b/tests/link/mlx/test_math.py index 28397d7643..8a9c700a52 100644 --- a/tests/link/mlx/test_math.py +++ b/tests/link/mlx/test_math.py @@ -19,4 +19,4 @@ def test_mlx_dot(): actual = fn(test_x, test_y) expected = np.dot(test_x, test_y) - np.testing.assert_allclose(actual, expected) + np.testing.assert_allclose(actual, expected, rtol=1e-6) From 96ba1162e5e87ec635e3dfbd36f9061f6d6fbac0 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 11 Apr 2025 13:17:57 -0400 Subject: [PATCH 10/54] add init file --- pytensor/link/mlx/__init__.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 pytensor/link/mlx/__init__.py diff --git a/pytensor/link/mlx/__init__.py b/pytensor/link/mlx/__init__.py new file mode 100644 index 0000000000..d5a6ab19ff --- /dev/null +++ b/pytensor/link/mlx/__init__.py @@ -0,0 +1 @@ +from pytensor.link.mlx.linker import MLXLinker From e116fa1d4c26814b663be3d88aebaeb718416b4e Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 11 Apr 2025 13:21:47 -0400 Subject: [PATCH 11/54] skip if not installed --- tests/link/mlx/test_basic.py | 4 ++++ tests/link/mlx/test_math.py | 1 + 2 files changed, 5 insertions(+) create mode 100644 tests/link/mlx/test_basic.py diff --git a/tests/link/mlx/test_basic.py b/tests/link/mlx/test_basic.py new file mode 100644 index 0000000000..f4e5149d67 --- /dev/null +++ b/tests/link/mlx/test_basic.py @@ -0,0 +1,4 @@ +import pytest + + +mx = pytest.importorskip("mlx.core") diff --git a/tests/link/mlx/test_math.py b/tests/link/mlx/test_math.py index 8a9c700a52..f3839a1cac 100644 --- a/tests/link/mlx/test_math.py +++ b/tests/link/mlx/test_math.py @@ -1,6 +1,7 @@ import numpy as np import pytensor +import tests.link.mlx.test_basic # noqa: F401 from pytensor.tensor.type import matrix From 5d5f7546d53cd19c3c4a0f88a57bbfd82c853f07 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 11 Apr 2025 13:28:11 -0400 Subject: [PATCH 12/54] remove torch related code / comments --- pytensor/link/mlx/linker.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pytensor/link/mlx/linker.py b/pytensor/link/mlx/linker.py index c2c970aebf..f8159a120b 100644 --- a/pytensor/link/mlx/linker.py +++ b/pytensor/link/mlx/linker.py @@ -45,8 +45,6 @@ def fgraph_convert( # just the subgraph generator = unique_name_generator(["mlx_linker"]) - # Ensure that torch is aware of the generated - # code so we can compile without graph breaks def conversion_func_register(*args, **kwargs): functor = mlx_funcify(*args, **kwargs) name = kwargs["unique_name"](functor) @@ -85,14 +83,12 @@ def __call__(self, *inputs, **kwargs): # MLX doesn't support np.ndarray as input outs = self.fn(*(mlx_typify(inp) for inp in inputs), **kwargs) - return outs - # unset attrs for n, _ in self.gen_functors: if getattr(pytensor.link.utils, n[1:], False): delattr(pytensor.link.utils, n[1:]) - return tuple(out.cpu().numpy() for out in outs) + return outs def __del__(self): del self.gen_functors From b8cee3f779a4472431022459cc198003be671ac9 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sat, 12 Apr 2025 16:25:22 -0400 Subject: [PATCH 13/54] simplify the fgraph_convert --- pytensor/link/mlx/linker.py | 39 ++----------------------------------- 1 file changed, 2 insertions(+), 37 deletions(-) diff --git a/pytensor/link/mlx/linker.py b/pytensor/link/mlx/linker.py index 
f8159a120b..f512c041d3 100644 --- a/pytensor/link/mlx/linker.py +++ b/pytensor/link/mlx/linker.py @@ -1,5 +1,4 @@ from pytensor.link.basic import JITLinker -from pytensor.link.utils import unique_name_generator class MLXLinker(JITLinker): @@ -9,29 +8,13 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.gen_functors = [] - def fgraph_convert( - self, - fgraph, - order, - input_storage, - output_storage, - storage_map, - **kwargs, - ): + def fgraph_convert(self, fgraph, **kwargs): """Convert a PyTensor FunctionGraph to an MLX-compatible function. Parameters ---------- fgraph : FunctionGraph The function graph to convert - order : list - The order in which to compute the nodes - input_storage : list - Storage for the input variables - output_storage : list - Storage for the output variables - storage_map : dict - Map from variables to their storage Returns ------- @@ -40,27 +23,9 @@ def fgraph_convert( """ from pytensor.link.mlx.dispatch import mlx_funcify - # We want to have globally unique names - # across the entire pytensor graph, not - # just the subgraph - generator = unique_name_generator(["mlx_linker"]) - - def conversion_func_register(*args, **kwargs): - functor = mlx_funcify(*args, **kwargs) - name = kwargs["unique_name"](functor) - self.gen_functors.append((f"_{name}", functor)) - return functor - - built_kwargs = { - "unique_name": generator, - "conversion_func": conversion_func_register, - **kwargs, - } return mlx_funcify( fgraph, - input_storage=input_storage, - storage_map=storage_map, - **built_kwargs, + **kwargs, ) def jit_compile(self, fn): From d057453c071d441e375b0deaa05122a774ba4e31 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sat, 12 Apr 2025 16:25:44 -0400 Subject: [PATCH 14/54] assert type --- tests/link/mlx/test_math.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/link/mlx/test_math.py b/tests/link/mlx/test_math.py index f3839a1cac..1380e01ca4 100644 --- a/tests/link/mlx/test_math.py +++ b/tests/link/mlx/test_math.py @@ -1,8 +1,8 @@ import numpy as np import pytensor -import tests.link.mlx.test_basic # noqa: F401 from pytensor.tensor.type import matrix +from tests.link.mlx.test_basic import mx def test_mlx_dot(): @@ -19,5 +19,6 @@ def test_mlx_dot(): test_y = rng.normal(size=(2, 4)) actual = fn(test_x, test_y) + assert isinstance(actual, mx.array) expected = np.dot(test_x, test_y) np.testing.assert_allclose(actual, expected, rtol=1e-6) From ae202e669b238f285da04ca9daee43eb74f849de Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 13:51:02 -0400 Subject: [PATCH 15/54] simplify the internal --- pytensor/link/mlx/linker.py | 31 ++++--------------------------- 1 file changed, 4 insertions(+), 27 deletions(-) diff --git a/pytensor/link/mlx/linker.py b/pytensor/link/mlx/linker.py index f512c041d3..e057bb942c 100644 --- a/pytensor/link/mlx/linker.py +++ b/pytensor/link/mlx/linker.py @@ -33,35 +33,12 @@ def jit_compile(self, fn): from pytensor.link.mlx.dispatch import mlx_typify - class wrapper: - def __init__(self, fn, gen_functors): - self.fn = mx.compile(fn) - self.gen_functors = gen_functors.copy() + inner_fn = mx.compile(fn) - def __call__(self, *inputs, **kwargs): - import pytensor.link.utils + def fn(*inputs, inner_fn=inner_fn): + return inner_fn(*(mlx_typify(inp) for inp in inputs)) - # set attrs - for n, fn in self.gen_functors: - setattr(pytensor.link.utils, n[1:], fn) - - # MLX doesn't support np.ndarray as input - outs = self.fn(*(mlx_typify(inp) for inp in inputs), **kwargs) - - # unset 
attrs - for n, _ in self.gen_functors: - if getattr(pytensor.link.utils, n[1:], False): - delattr(pytensor.link.utils, n[1:]) - - return outs - - def __del__(self): - del self.gen_functors - - inner_fn = wrapper(fn, self.gen_functors) - self.gen_functors = [] - - return inner_fn + return fn def create_thunk_inputs(self, storage_map): """Create inputs for the MLX thunk. From f1941fe1c7951d1208cbc4c82ce7d7da99e8dc75 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 14:00:29 -0400 Subject: [PATCH 16/54] remove the language --- pytensor/compile/mode.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index ce58561212..8dc7c742bc 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -602,8 +602,6 @@ def get_target_language(mode=None) -> tuple[Literal["py", "c", "numba", "jax"], return ("py",) if isinstance(linker, CLinker): return ("c",) - if isinstance(linker, MLXLinker): - return ("py",) if isinstance(linker, VMLinker | OpWiseCLinker): return ("c", "py") if config.cxx else ("py",) From 7c8eae7aa670d914bb9ad3552f58b4e4bd5279ce Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 13:08:40 -0500 Subject: [PATCH 17/54] Adding operations in pytensor --- pytensor/link/mlx/dispatch/elemwise | 77 +++++++++++++++++++++++++++++ pytensor/link/mlx/dispatch/math.py | 43 ++++++++++++++++ pytensor/link/mlx/dispatch/shape.py | 15 ++++++ 3 files changed, 135 insertions(+) create mode 100644 pytensor/link/mlx/dispatch/elemwise create mode 100644 pytensor/link/mlx/dispatch/shape.py diff --git a/pytensor/link/mlx/dispatch/elemwise b/pytensor/link/mlx/dispatch/elemwise new file mode 100644 index 0000000000..5c938cac10 --- /dev/null +++ b/pytensor/link/mlx/dispatch/elemwise @@ -0,0 +1,77 @@ +from pytensor.link.mlx.dispatch.basic import mlx_funcify +from pytensor.tensor.elemwise import CAReduce, DimShuffle +from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad +from pytensor.scalar.basic import Add, Mul, Any, AND, OR, ScalarMaximum, ScalarMinimum + +import mlx.core as mx + +@mlx_funcify.register(DimShuffle) +def mlx_funcify_DimShuffle(op, **kwargs): + def dimshuffle(x): + res = mx.transpose(x, op.transposition) + + shape = list(res.shape[: len(op.shuffle)]) + + for augm in op.augment: + shape.insert(augm, 1) + + return mx.reshape(res, shape) + + return dimshuffle + +@mlx_funcify.register(CAReduce) +def mlx_funcify_CAReduce(op, **kwargs): + if isinstance(op.scalar_op, Add): + def sum(x): + return mx.sum(x, axis=op.axis) + + return sum + elif isinstance(op.scalar_op, Mul): + def prod(x): + return mx.prod(x, axis=op.axis) + + return prod + elif isinstance(op.scalar_op, AND): + def all(x): + return mx.all(x, axis=op.axis) + + return all + elif isinstance(op.scalar_op, OR): + def any(x): + return mx.any(x, axis=op.axis) + + return any + elif isinstance(op.scalar_op, ScalarMaximum): + def max(x): + return mx.max(x, axis=op.axis) + + return max + elif isinstance(op.scalar_op, ScalarMinimum): + def min(x): + return mx.min(x, axis=op.axis) + + return min + + else: + raise NotImplementedError(f"MLX does not support {op.scalar_op}") + + +@mlx_funcify.register(Softmax) +def mlx_funcify_Softmax(op, **kwargs): + axis = op.axis + + def softmax(x): + return mx.softmax(x, axis=axis) + + return softmax + + +@mlx_funcify.register(SoftmaxGrad) +def mlx_funcify_SoftmaxGrad(op, **kwargs): + axis = op.axis + + def softmax_grad(dy, sm): + dy_times_sm = dy * sm + return 
dy_times_sm - mx.sum(dy_times_sm, axis=axis, keepdims=True) * sm + + return softmax_grad \ No newline at end of file diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index 1ef7ec4608..842181b046 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -1,7 +1,10 @@ import mlx.core as mx from pytensor.link.mlx.dispatch import mlx_funcify + +from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.math import Dot +from pytensor.scalar.basic import Add, Mul, Sub, Exp, Log, Sin, Cos @mlx_funcify.register(Dot) @@ -10,3 +13,43 @@ def dot(x, y): return mx.matmul(x, y) return dot + +@mlx_funcify.register(Elemwise) +def mlx_funcify_Elemwise(op, **kwargs): + if isinstance(op.scalar_op, Add): + def add(x, y): + return mx.add(x, y) + + return add + elif isinstance(op.scalar_op, Sub): + def sub(x, y): + return mx.sub(x, y) + + return sub + elif isinstance(op.scalar_op, Mul): + def mul(x, y): + return mx.mul(x, y) + + return mul + elif isinstance(op.scalar_op, Exp): + def exp(x): + return mx.exp(x) + + return exp + elif isinstance(op.scalar_op, Log): + def log(x): + return mx.log(x) + + return log + elif isinstance(op.scalar_op, Sin): + def sin(x): + return mx.sin(x) + + return sin + elif isinstance(op.scalar_op, Cos): + def cos(x): + return mx.cos(x) + + return cos + else: + raise NotImplementedError(f"MLX does not support {op.scalar_op}") \ No newline at end of file diff --git a/pytensor/link/mlx/dispatch/shape.py b/pytensor/link/mlx/dispatch/shape.py new file mode 100644 index 0000000000..c22ecea704 --- /dev/null +++ b/pytensor/link/mlx/dispatch/shape.py @@ -0,0 +1,15 @@ +from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape +from pytensor.link.mlx.dispatch.basic import mlx_funcify + +@mlx_funcify.register(SpecifyShape) +def mlx_funcify_SpecifyShape(op, node, **kwargs): + def specifyshape(x, *shape): + assert x.ndim == len(shape) + for actual, expected in zip(x.shape, shape, strict=True): + if expected is None: + continue + if actual != expected: + raise ValueError(f"Invalid shape: Expected {shape} but got {x.shape}") + return x + + return specifyshape \ No newline at end of file From 67a74fb5e5ecb3a32c49a820e939509f0c2556d0 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 14:15:34 -0400 Subject: [PATCH 18/54] add extension --- .../mlx/dispatch/{elemwise => elemwise.py} | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) rename pytensor/link/mlx/dispatch/{elemwise => elemwise.py} (90%) diff --git a/pytensor/link/mlx/dispatch/elemwise b/pytensor/link/mlx/dispatch/elemwise.py similarity index 90% rename from pytensor/link/mlx/dispatch/elemwise rename to pytensor/link/mlx/dispatch/elemwise.py index 5c938cac10..7ec124623e 100644 --- a/pytensor/link/mlx/dispatch/elemwise +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -1,9 +1,10 @@ +import mlx.core as mx + from pytensor.link.mlx.dispatch.basic import mlx_funcify +from pytensor.scalar.basic import AND, OR, Add, Mul, ScalarMaximum, ScalarMinimum from pytensor.tensor.elemwise import CAReduce, DimShuffle -from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad -from pytensor.scalar.basic import Add, Mul, Any, AND, OR, ScalarMaximum, ScalarMinimum +from pytensor.tensor.special import Softmax, SoftmaxGrad -import mlx.core as mx @mlx_funcify.register(DimShuffle) def mlx_funcify_DimShuffle(op, **kwargs): @@ -19,42 +20,49 @@ def dimshuffle(x): return dimshuffle + @mlx_funcify.register(CAReduce) def 
mlx_funcify_CAReduce(op, **kwargs): if isinstance(op.scalar_op, Add): + def sum(x): return mx.sum(x, axis=op.axis) return sum elif isinstance(op.scalar_op, Mul): + def prod(x): return mx.prod(x, axis=op.axis) return prod elif isinstance(op.scalar_op, AND): + def all(x): return mx.all(x, axis=op.axis) return all elif isinstance(op.scalar_op, OR): + def any(x): return mx.any(x, axis=op.axis) return any elif isinstance(op.scalar_op, ScalarMaximum): + def max(x): return mx.max(x, axis=op.axis) return max elif isinstance(op.scalar_op, ScalarMinimum): + def min(x): return mx.min(x, axis=op.axis) return min - + else: raise NotImplementedError(f"MLX does not support {op.scalar_op}") - + @mlx_funcify.register(Softmax) def mlx_funcify_Softmax(op, **kwargs): @@ -74,4 +82,4 @@ def softmax_grad(dy, sm): dy_times_sm = dy * sm return dy_times_sm - mx.sum(dy_times_sm, axis=axis, keepdims=True) * sm - return softmax_grad \ No newline at end of file + return softmax_grad From fb5eb523dbae16ee572261c95ecb03ace969816b Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 14:38:21 -0400 Subject: [PATCH 19/54] make compare function --- tests/link/mlx/test_basic.py | 74 ++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/tests/link/mlx/test_basic.py b/tests/link/mlx/test_basic.py index f4e5149d67..746aab10bc 100644 --- a/tests/link/mlx/test_basic.py +++ b/tests/link/mlx/test_basic.py @@ -1,4 +1,78 @@ +from collections.abc import Callable, Iterable +from functools import partial + +import numpy as np import pytest +from pytensor.compile.function import function +from pytensor.compile.mode import Mode +from pytensor.graph.basic import Variable +from pytensor.link.mlx import MLXLinker + mx = pytest.importorskip("mlx.core") + +mlx_mode = Mode(linker=MLXLinker()) +py_mode = Mode(linker="py", optimizer=None) + + +def compare_mlx_and_py( + graph_inputs: Iterable[Variable], + graph_outputs: Variable | Iterable[Variable], + test_inputs: Iterable, + *, + assert_fn: Callable | None = None, + must_be_device_array: bool = True, + mlx_mode=mlx_mode, + py_mode=py_mode, +): + """Function to compare python function output and mlx compiled output for testing equality + + The inputs and outputs are then passed to this function which then compiles the given function in both + mlx and python, runs the calculation in both and checks if the results are the same + + Parameters + ---------- + graph_inputs: + Symbolic inputs to the graph + outputs: + Symbolic outputs of the graph + test_inputs: iter + Numerical inputs for testing the function. + assert_fn: func, opt + Assert function used to check for equality between python and mlx. If not + provided uses np.testing.assert_allclose + must_be_device_array: Bool + Checks for instance of jax.interpreters.xla.DeviceArray. 
For testing purposes + if this device array is found it indicates if the result was computed by jax + + Returns + ------- + mlx_res + + """ + if assert_fn is None: + assert_fn = partial(np.testing.assert_allclose, rtol=1e-4) + + if any(inp.owner is not None for inp in graph_inputs): + raise ValueError("Inputs must be root variables") + + pytensor_mlx_fn = function(graph_inputs, graph_outputs, mode=mlx_mode) + mlx_res = pytensor_mlx_fn(*test_inputs) + + if must_be_device_array: + if isinstance(mlx_res, list): + assert all(isinstance(res, mx.array) for res in mlx_res) + else: + assert isinstance(mlx_res, mx.array) + + pytensor_py_fn = function(graph_inputs, graph_outputs, mode=py_mode) + py_res = pytensor_py_fn(*test_inputs) + + if isinstance(graph_outputs, list | tuple): + for j, p in zip(mlx_res, py_res, strict=True): + assert_fn(j, p) + else: + assert_fn(mlx_res, py_res) + + return pytensor_mlx_fn, mlx_res From 516b5958b32669f6844d301d8a8f971b7611bab5 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 14:38:52 -0400 Subject: [PATCH 20/54] rename function --- tests/link/mlx/test_math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/link/mlx/test_math.py b/tests/link/mlx/test_math.py index 1380e01ca4..0781ea4e22 100644 --- a/tests/link/mlx/test_math.py +++ b/tests/link/mlx/test_math.py @@ -5,7 +5,7 @@ from tests.link.mlx.test_basic import mx -def test_mlx_dot(): +def test_dot(): x = matrix("x") y = matrix("y") From 67bb8da51aadfacbd9913ee6f13308d2857aa2be Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 15:00:45 -0400 Subject: [PATCH 21/54] correct the function name --- pytensor/link/mlx/dispatch/math.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index 842181b046..42f1ec7b72 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -1,10 +1,9 @@ import mlx.core as mx from pytensor.link.mlx.dispatch import mlx_funcify - +from pytensor.scalar.basic import Add, Cos, Exp, Log, Mul, Sin, Sub from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.math import Dot -from pytensor.scalar.basic import Add, Mul, Sub, Exp, Log, Sin, Cos @mlx_funcify.register(Dot) @@ -14,42 +13,50 @@ def dot(x, y): return dot + @mlx_funcify.register(Elemwise) def mlx_funcify_Elemwise(op, **kwargs): if isinstance(op.scalar_op, Add): + def add(x, y): return mx.add(x, y) return add elif isinstance(op.scalar_op, Sub): + def sub(x, y): - return mx.sub(x, y) + return mx.subtract(x, y) return sub elif isinstance(op.scalar_op, Mul): + def mul(x, y): - return mx.mul(x, y) + return mx.multiply(x, y) return mul elif isinstance(op.scalar_op, Exp): + def exp(x): return mx.exp(x) return exp elif isinstance(op.scalar_op, Log): + def log(x): return mx.log(x) return log elif isinstance(op.scalar_op, Sin): + def sin(x): return mx.sin(x) - + return sin elif isinstance(op.scalar_op, Cos): + def cos(x): return mx.cos(x) return cos else: - raise NotImplementedError(f"MLX does not support {op.scalar_op}") \ No newline at end of file + raise NotImplementedError(f"MLX does not support {op.scalar_op}") From 82bb9642daa7514c07e233b3486bd75f85503d8c Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 15:01:21 -0400 Subject: [PATCH 22/54] tests for elemwise --- tests/link/mlx/test_math.py | 35 +++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/tests/link/mlx/test_math.py 
b/tests/link/mlx/test_math.py index 0781ea4e22..c4f16d2f28 100644 --- a/tests/link/mlx/test_math.py +++ b/tests/link/mlx/test_math.py @@ -1,13 +1,14 @@ import numpy as np +import pytest import pytensor -from pytensor.tensor.type import matrix -from tests.link.mlx.test_basic import mx +import pytensor.tensor as pt +from tests.link.mlx.test_basic import compare_mlx_and_py, mx def test_dot(): - x = matrix("x") - y = matrix("y") + x = pt.matrix("x") + y = pt.matrix("y") out = x.dot(y) fn = pytensor.function([x, y], out, mode="MLX") @@ -22,3 +23,29 @@ def test_dot(): assert isinstance(actual, mx.array) expected = np.dot(test_x, test_y) np.testing.assert_allclose(actual, expected, rtol=1e-6) + + +@pytest.mark.parametrize( + "op", + [pt.exp, pt.log, pt.sin, pt.cos], + ids=["exp", "log", "sin", "cos"], +) +def test_elemwise_one_input(op) -> None: + x = pt.vector("x") + out = op(x) + x_test = mx.array([1.0, 2.0, 3.0]) + compare_mlx_and_py([x], out, [x_test]) + + +@pytest.mark.parametrize( + "op", + [pt.add, pt.sub, pt.mul], + ids=["add", "sub", "mul"], +) +def test_elemwise_two_inputs(op) -> None: + x = pt.vector("x") + y = pt.vector("y") + out = op(x, y) + x_test = mx.array([1.0, 2.0, 3.0]) + y_test = mx.array([4.0, 5.0, 6.0]) + compare_mlx_and_py([x, y], out, [x_test, y_test]) From 877d79fe2981d132194e4a528601d6c1f5b6105b Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 14:05:11 -0500 Subject: [PATCH 23/54] Changes --- pytensor/link/mlx/dispatch/__init__.py | 7 +- pytensor/link/mlx/dispatch/elemwise.py | 2 +- pytensor/link/mlx/dispatch/math.py | 53 ++++++++++-- pytensor/link/mlx/dispatch/subtensor.py | 110 ++++++++++++++++++++++++ 4 files changed, 164 insertions(+), 8 deletions(-) create mode 100644 pytensor/link/mlx/dispatch/subtensor.py diff --git a/pytensor/link/mlx/dispatch/__init__.py b/pytensor/link/mlx/dispatch/__init__.py index 7acb41e1b5..2d7dd19974 100644 --- a/pytensor/link/mlx/dispatch/__init__.py +++ b/pytensor/link/mlx/dispatch/__init__.py @@ -2,4 +2,9 @@ from pytensor.link.mlx.dispatch.basic import mlx_funcify, mlx_typify import pytensor.link.mlx.dispatch.math -# isort: on +import pytensor.link.mlx.dispatch.basic +import pytensor.link.mlx.dispatch.elemwise +import pytensor.link.mlx.dispatch.shape +import pytensor.link.mlx.dispatch.subtensor +import pytensor.link.mlx.dispatch.core +# isort: on \ No newline at end of file diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py index 7ec124623e..d4bfaeab51 100644 --- a/pytensor/link/mlx/dispatch/elemwise.py +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -1,7 +1,7 @@ import mlx.core as mx from pytensor.link.mlx.dispatch.basic import mlx_funcify -from pytensor.scalar.basic import AND, OR, Add, Mul, ScalarMaximum, ScalarMinimum +from pytensor.scalar.basic import AND, OR, Add, Mul, ScalarMaximum, ScalarMinimum, Switch from pytensor.tensor.elemwise import CAReduce, DimShuffle from pytensor.tensor.special import Softmax, SoftmaxGrad diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index 42f1ec7b72..ce5064bf92 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -4,6 +4,8 @@ from pytensor.scalar.basic import Add, Cos, Exp, Log, Mul, Sin, Sub from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.math import Dot +from pytensor.scalar.math import Sigmoid +from pytensor.scalar.basic import Add, Mul, Sub, Exp, Log, Sin, Cos, LE, LT, GE, GT, EQ, NEQ 
@mlx_funcify.register(Dot) @@ -17,9 +19,11 @@ def dot(x, y): @mlx_funcify.register(Elemwise) def mlx_funcify_Elemwise(op, **kwargs): if isinstance(op.scalar_op, Add): - - def add(x, y): - return mx.add(x, y) + def add(*args): + result = args[0] + for arg in args[1:]: + result = mx.add(result, arg) + return result return add elif isinstance(op.scalar_op, Sub): @@ -29,9 +33,11 @@ def sub(x, y): return sub elif isinstance(op.scalar_op, Mul): - - def mul(x, y): - return mx.multiply(x, y) + def mul(*args): + result = args[0] + for arg in args[1:]: + result = mx.multiply(result, arg) + return result return mul elif isinstance(op.scalar_op, Exp): @@ -58,5 +64,40 @@ def cos(x): return mx.cos(x) return cos + elif isinstance(op.scalar_op, Sigmoid): + def sigmoid(x): + return mx.sigmoid(x) + + return sigmoid + elif isinstance(op.scalar_op, LE): + def le(x, y): + return mx.less_equal(x, y) + + return le + elif isinstance(op.scalar_op, LT): + def lt(x, y): + return mx.less(x, y) + + return lt + elif isinstance(op.scalar_op, GE): + def ge(x, y): + return mx.greater_equal(x, y) + + return ge + elif isinstance(op.scalar_op, GT): + def gt(x, y): + return mx.greater(x, y) + + return gt + elif isinstance(op.scalar_op, EQ): + def eq(x, y): + return mx.equal(x, y) + + return eq + elif isinstance(op.scalar_op, NEQ): + def neq(x, y): + return mx.not_equal(x, y) + + return neq else: raise NotImplementedError(f"MLX does not support {op.scalar_op}") diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py new file mode 100644 index 0000000000..7f8b55f18e --- /dev/null +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -0,0 +1,110 @@ +from pytensor.link.mlx.dispatch.basic import mlx_funcify +from pytensor.tensor.subtensor import ( + AdvancedIncSubtensor, + AdvancedIncSubtensor1, + AdvancedSubtensor, + AdvancedSubtensor1, + IncSubtensor, + Subtensor, + indices_from_subtensor, +) +from pytensor.tensor.type_other import MakeSlice + + +BOOLEAN_MASK_ERROR = """MLX does not support resizing arrays with boolean +masks. In some cases, however, it is possible to re-express your model +in a form that MLX can compile: + +>>> import pytensor.tensor as pt +>>> x_pt = pt.vector('x') +>>> y_pt = x_pt[x_pt > 0].sum() + +can be re-expressed as: + +>>> import pytensor.tensor as pt +>>> x_pt = pt.vector('x') +>>> y_pt = pt.where(x_pt > 0, x_pt, 0).sum() +""" + +DYNAMIC_SLICE_LENGTH_ERROR = """MLX does not support slicing arrays with a dynamic +slice length. 
+""" + + +@mlx_funcify.register(Subtensor) +@mlx_funcify.register(AdvancedSubtensor) +@mlx_funcify.register(AdvancedSubtensor1) +def mlx_funcify_Subtensor(op, node, **kwargs): + idx_list = getattr(op, "idx_list", None) + + def subtensor(x, *ilists): + indices = indices_from_subtensor(ilists, idx_list) + if len(indices) == 1: + indices = indices[0] + + return x.__getitem__(indices) + + return subtensor + + +@mlx_funcify.register(IncSubtensor) +@mlx_funcify.register(AdvancedIncSubtensor1) +def mlx_funcify_IncSubtensor(op, node, **kwargs): + idx_list = getattr(op, "idx_list", None) + + if getattr(op, "set_instead_of_inc", False): + + def mlx_fn(x, indices, y): + if not op.inplace: + x = x.copy() + x[indices] = y + return x + + else: + + def mlx_fn(x, indices, y): + if not op.inplace: + x = x.copy() + x[indices] += y + return x + + def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list): + indices = indices_from_subtensor(ilist, idx_list) + if len(indices) == 1: + indices = indices[0] + + return mlx_fn(x, indices, y) + + return incsubtensor + + +@mlx_funcify.register(AdvancedIncSubtensor) +def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs): + if getattr(op, "set_instead_of_inc", False): + + def mlx_fn(x, indices, y): + if not op.inplace: + x = x.copy() + x[indices] = y + return x + + else: + + def mlx_fn(x, indices, y): + if not op.inplace: + x = x.copy() + x[indices] += y + return x + + def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn): + return mlx_fn(x, ilist, y) + + return advancedincsubtensor + + +@mlx_funcify.register(MakeSlice) +def mlx_funcify_MakeSlice(op, **kwargs): + def makeslice(*x): + return slice(*x) + + return makeslice From fafedd66d579636eb3292245a1d5f7c0e874c2da Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 14:09:35 -0500 Subject: [PATCH 24/54] Toma tu tomate William --- pytensor/link/mlx/dispatch/__init__.py | 2 +- pytensor/link/mlx/dispatch/elemwise.py | 2 +- pytensor/link/mlx/dispatch/math.py | 49 ++++++++++++++++++++++++-- pytensor/link/mlx/dispatch/shape.py | 5 +-- 4 files changed, 51 insertions(+), 7 deletions(-) diff --git a/pytensor/link/mlx/dispatch/__init__.py b/pytensor/link/mlx/dispatch/__init__.py index 2d7dd19974..7e835d238b 100644 --- a/pytensor/link/mlx/dispatch/__init__.py +++ b/pytensor/link/mlx/dispatch/__init__.py @@ -7,4 +7,4 @@ import pytensor.link.mlx.dispatch.shape import pytensor.link.mlx.dispatch.subtensor import pytensor.link.mlx.dispatch.core -# isort: on \ No newline at end of file +# isort: on diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py index d4bfaeab51..7ec124623e 100644 --- a/pytensor/link/mlx/dispatch/elemwise.py +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -1,7 +1,7 @@ import mlx.core as mx from pytensor.link.mlx.dispatch.basic import mlx_funcify -from pytensor.scalar.basic import AND, OR, Add, Mul, ScalarMaximum, ScalarMinimum, Switch +from pytensor.scalar.basic import AND, OR, Add, Mul, ScalarMaximum, ScalarMinimum from pytensor.tensor.elemwise import CAReduce, DimShuffle from pytensor.tensor.special import Softmax, SoftmaxGrad diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index ce5064bf92..a0f68324d9 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -1,11 +1,27 @@ import mlx.core as mx from pytensor.link.mlx.dispatch import mlx_funcify -from pytensor.scalar.basic import Add, Cos, Exp, Log, Mul, Sin, Sub +from 
pytensor.scalar.basic import ( + EQ, + GE, + GT, + LE, + LT, + NEQ, + Add, + Cos, + Exp, + Log, + Mul, + Pow, + Sin, + Sub, + Switch, + TrueDiv, +) +from pytensor.scalar.math import Sigmoid from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.math import Dot -from pytensor.scalar.math import Sigmoid -from pytensor.scalar.basic import Add, Mul, Sub, Exp, Log, Sin, Cos, LE, LT, GE, GT, EQ, NEQ @mlx_funcify.register(Dot) @@ -19,6 +35,7 @@ def dot(x, y): @mlx_funcify.register(Elemwise) def mlx_funcify_Elemwise(op, **kwargs): if isinstance(op.scalar_op, Add): + def add(*args): result = args[0] for arg in args[1:]: @@ -33,6 +50,7 @@ def sub(x, y): return sub elif isinstance(op.scalar_op, Mul): + def mul(*args): result = args[0] for arg in args[1:]: @@ -65,39 +83,64 @@ def cos(x): return cos elif isinstance(op.scalar_op, Sigmoid): + def sigmoid(x): return mx.sigmoid(x) return sigmoid elif isinstance(op.scalar_op, LE): + def le(x, y): return mx.less_equal(x, y) return le elif isinstance(op.scalar_op, LT): + def lt(x, y): return mx.less(x, y) return lt elif isinstance(op.scalar_op, GE): + def ge(x, y): return mx.greater_equal(x, y) return ge elif isinstance(op.scalar_op, GT): + def gt(x, y): return mx.greater(x, y) return gt elif isinstance(op.scalar_op, EQ): + def eq(x, y): return mx.equal(x, y) return eq elif isinstance(op.scalar_op, NEQ): + def neq(x, y): return mx.not_equal(x, y) return neq + elif isinstance(op.scalar_op, Switch): + + def switch(cond, x, y): + return mx.where(cond, x, y) + + return switch + elif isinstance(op.scalar_op, Pow): + + def pow(x, y): + return mx.power(x, y) + + return pow + elif isinstance(op.scalar_op, TrueDiv): + + def true_div(x, y): + return mx.divide(x, y) + + return true_div else: raise NotImplementedError(f"MLX does not support {op.scalar_op}") diff --git a/pytensor/link/mlx/dispatch/shape.py b/pytensor/link/mlx/dispatch/shape.py index c22ecea704..1d48eae1f5 100644 --- a/pytensor/link/mlx/dispatch/shape.py +++ b/pytensor/link/mlx/dispatch/shape.py @@ -1,5 +1,6 @@ -from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from pytensor.link.mlx.dispatch.basic import mlx_funcify +from pytensor.tensor.shape import SpecifyShape + @mlx_funcify.register(SpecifyShape) def mlx_funcify_SpecifyShape(op, node, **kwargs): @@ -12,4 +13,4 @@ def specifyshape(x, *shape): raise ValueError(f"Invalid shape: Expected {shape} but got {x.shape}") return x - return specifyshape \ No newline at end of file + return specifyshape From 60acb8d7ac60d44e99a7c2d680702134cee446dd Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 14:15:32 -0500 Subject: [PATCH 25/54] Pushing changes with the core shit. --- .gitignore | 1 - pytensor/link/mlx/dispatch/core.py | 174 +++++++++++++++++++++++++++++ 2 files changed, 174 insertions(+), 1 deletion(-) create mode 100644 pytensor/link/mlx/dispatch/core.py diff --git a/.gitignore b/.gitignore index dfe862b868..ebe8e61bd0 100644 --- a/.gitignore +++ b/.gitignore @@ -27,7 +27,6 @@ __pycache__ \#*\# build compiled/*.cpp -core.* cutils_ext.cpp dist doc/.build/ diff --git a/pytensor/link/mlx/dispatch/core.py b/pytensor/link/mlx/dispatch/core.py new file mode 100644 index 0000000000..84ae042682 --- /dev/null +++ b/pytensor/link/mlx/dispatch/core.py @@ -0,0 +1,174 @@ +""" +pytensor/link/mlx/dispatch/basic.py +----------------------------------- + +First‑cut MLX translations for the most common tensor Ops. 
+ +The structure intentionally follows pytensor's JAX dispatcher so that +once these kernels stabilise they can be optimised further (e.g. fusing +element‑wise graphs, adding in‑place updates, RNG thinning, etc.). +""" +from __future__ import annotations + +import warnings +import numpy as np + +import mlx.core as mx # MLX +from pytensor.link.mlx.dispatch.basic import mlx_funcify # MLX + +from pytensor.tensor import get_vector_length +from pytensor.tensor.basic import ( + Join, Split, ExtractDiag, Eye, MakeVector, + ScalarFromTensor, TensorFromScalar, Tri, + get_scalar_constant_value, +) +from pytensor.tensor.exceptions import NotScalarConstantError + + +# ------------------------------------------------------------------ +# Join +# ------------------------------------------------------------------ +@mlx_funcify.register(Join) # MLX +def mlx_funcify_Join(op, **kwargs): + def join(axis, *tensors): + view = op.view + if (view != -1) and all( + tensors[i].shape[axis] == 0 # MLX + for i in list(range(view)) + list(range(view + 1, len(tensors))) + ): + return tensors[view] + + return mx.concatenate(tensors, axis=axis) # MLX + + return join + + +# ------------------------------------------------------------------ +# Split +# ------------------------------------------------------------------ +@mlx_funcify.register(Split) # MLX +def mlx_funcify_Split(op: Split, node, **kwargs): + _, axis_sym, splits_sym = node.inputs + + try: + constant_axis = get_scalar_constant_value(axis_sym) + except NotScalarConstantError: + constant_axis = None + warnings.warn( + "Split node does not have a constant axis. MLX implementation may fail." + ) + + try: + constant_splits = np.array( + [get_scalar_constant_value(splits_sym[i]) + for i in range(get_vector_length(splits_sym))] + ) + except (ValueError, NotScalarConstantError): + constant_splits = None + warnings.warn( + "Split node does not have constant split positions. MLX implementation may fail." 
+ ) + + def split(x, axis, splits): + # Resolve constants (avoids tracing extra ops) + if constant_axis is not None: + axis = int(constant_axis) + + if constant_splits is not None: + splits = constant_splits + cumsum_splits = np.cumsum(splits[:-1]) + else: + # dynamic ‑– keep in graph + splits_arr = mx.array(splits) # MLX + cumsum_splits = mx.cumsum(splits_arr[:-1]).tolist() # python list for mx.split + + if len(splits) != op.len_splits: + raise ValueError("Length of 'splits' is not equal to n_splits") + if np.sum(np.asarray(splits)) != x.shape[axis]: + raise ValueError("Split sizes do not sum to the input length on the chosen axis.") + if np.any(np.asarray(splits) < 0): + raise ValueError("Split sizes cannot be negative.") + + return mx.split(x, cumsum_splits, axis=axis) # MLX + + return split + + +# ------------------------------------------------------------------ +# ExtractDiag +# ------------------------------------------------------------------ +@mlx_funcify.register(ExtractDiag) # MLX +def mlx_funcify_ExtractDiag(op, **kwargs): + offset, axis1, axis2 = op.offset, op.axis1, op.axis2 + + def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2): + return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) # MLX + + return extract_diag + + +# ------------------------------------------------------------------ +# Eye +# ------------------------------------------------------------------ +@mlx_funcify.register(Eye) # MLX +def mlx_funcify_Eye(op, **kwargs): + dtype = op.dtype + + def eye(N, M, k): + return mx.eye(int(N), int(M), int(k), dtype=dtype) # MLX + + return eye + + +# ------------------------------------------------------------------ +# MakeVector +# ------------------------------------------------------------------ +@mlx_funcify.register(MakeVector) # MLX +def mlx_funcify_MakeVector(op, **kwargs): + def makevector(*x): + return mx.array(x, dtype=op.dtype) # MLX + + return makevector + + +# ------------------------------------------------------------------ +# TensorFromScalar (identity for MLX) +# ------------------------------------------------------------------ +@mlx_funcify.register(TensorFromScalar) # MLX +def mlx_funcify_TensorFromScalar(op, **kwargs): + def tensor_from_scalar(x): + return x # already an MLX array / scalar + + return tensor_from_scalar + + +# ------------------------------------------------------------------ +# ScalarFromTensor +# ------------------------------------------------------------------ +@mlx_funcify.register(ScalarFromTensor) # MLX +def mlx_funcify_ScalarFromTensor(op, **kwargs): + def scalar_from_tensor(x): + return mx.array(x).reshape(-1)[0] # MLX + + return scalar_from_tensor + + +# ------------------------------------------------------------------ +# Tri +# ------------------------------------------------------------------ +@mlx_funcify.register(Tri) # MLX +def mlx_funcify_Tri(op, node, **kwargs): + # node.inputs -> N, M, k + const_args = [getattr(inp, "data", None) for inp in node.inputs] + + def tri(*args): + # Replace args with compile‑time constants when available + args = [ + arg if const_a is None else const_a + for arg, const_a in zip(args, const_args, strict=True) + ] + return mx.tri(*args, dtype=op.dtype) # MLX + + return tri + +## Change the code to use the mlx functions \ No newline at end of file From 242aba75e3f212673faaa9b21d77029d76550029 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 15:34:02 -0400 Subject: [PATCH 26/54] add more tests --- tests/link/mlx/test_math.py | 21 +++++++++++++++++---- 1 file 
changed, 17 insertions(+), 4 deletions(-) diff --git a/tests/link/mlx/test_math.py b/tests/link/mlx/test_math.py index c4f16d2f28..57b717360d 100644 --- a/tests/link/mlx/test_math.py +++ b/tests/link/mlx/test_math.py @@ -27,8 +27,13 @@ def test_dot(): @pytest.mark.parametrize( "op", - [pt.exp, pt.log, pt.sin, pt.cos], - ids=["exp", "log", "sin", "cos"], + [ + pytest.param(pt.exp, id="exp"), + pytest.param(pt.log, id="log"), + pytest.param(pt.sin, id="sin"), + pytest.param(pt.cos, id="cos"), + pytest.param(pt.sigmoid, id="sigmoid"), + ], ) def test_elemwise_one_input(op) -> None: x = pt.vector("x") @@ -39,8 +44,16 @@ def test_elemwise_one_input(op) -> None: @pytest.mark.parametrize( "op", - [pt.add, pt.sub, pt.mul], - ids=["add", "sub", "mul"], + [ + pytest.param(pt.add, id="add"), + pytest.param(pt.sub, id="sub"), + pytest.param(pt.mul, id="mul"), + pytest.param(pt.power, id="power"), + pytest.param(pt.le, id="le"), + pytest.param(pt.lt, id="lt"), + pytest.param(pt.ge, id="ge"), + pytest.param(pt.gt, id="gt"), + ], ) def test_elemwise_two_inputs(op) -> None: x = pt.vector("x") From 6cb47fc61b316fe5d22c1a56bc408d2a8e6c6366 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 15:43:39 -0400 Subject: [PATCH 27/54] additional tests --- tests/link/mlx/test_math.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/link/mlx/test_math.py b/tests/link/mlx/test_math.py index 57b717360d..25a0198cc1 100644 --- a/tests/link/mlx/test_math.py +++ b/tests/link/mlx/test_math.py @@ -53,6 +53,9 @@ def test_elemwise_one_input(op) -> None: pytest.param(pt.lt, id="lt"), pytest.param(pt.ge, id="ge"), pytest.param(pt.gt, id="gt"), + pytest.param(pt.eq, id="eq"), + pytest.param(pt.neq, id="neq"), + pytest.param(pt.true_div, id="true_div"), ], ) def test_elemwise_two_inputs(op) -> None: From bc98e09caa1053becbeab2017d3958b6e9bd77ca Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 15:50:17 -0400 Subject: [PATCH 28/54] test for switch with mlx --- tests/link/mlx/test_basic.py | 6 ++++-- tests/link/mlx/test_math.py | 12 ++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/link/mlx/test_basic.py b/tests/link/mlx/test_basic.py index 746aab10bc..4a6e67f406 100644 --- a/tests/link/mlx/test_basic.py +++ b/tests/link/mlx/test_basic.py @@ -5,14 +5,16 @@ import pytest from pytensor.compile.function import function -from pytensor.compile.mode import Mode +from pytensor.compile.mode import MLX, Mode +from pytensor.graph import RewriteDatabaseQuery from pytensor.graph.basic import Variable from pytensor.link.mlx import MLXLinker mx = pytest.importorskip("mlx.core") -mlx_mode = Mode(linker=MLXLinker()) +optimizer = RewriteDatabaseQuery(include=["mlx"], exclude=MLX._optimizer.exclude) +mlx_mode = Mode(linker=MLXLinker(), optimizer=optimizer) py_mode = Mode(linker="py", optimizer=None) diff --git a/tests/link/mlx/test_math.py b/tests/link/mlx/test_math.py index 25a0198cc1..86fa999451 100644 --- a/tests/link/mlx/test_math.py +++ b/tests/link/mlx/test_math.py @@ -42,6 +42,18 @@ def test_elemwise_one_input(op) -> None: compare_mlx_and_py([x], out, [x_test]) +def test_switch() -> None: + x = pt.vector("x") + y = pt.vector("y") + + out = pt.switch(x > 0, y, x) + + x_test = mx.array([-1.0, 2.0, 3.0]) + y_test = mx.array([4.0, 5.0, 6.0]) + + compare_mlx_and_py([x, y], out, [x_test, y_test]) + + @pytest.mark.parametrize( "op", [ From 4d5b34b4b324c9f78da83e8d5a94c1f8d692738b Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: 
Fri, 18 Apr 2025 14:52:19 -0500 Subject: [PATCH 29/54] Pushing code --- pytensor/link/mlx/dispatch/__init__.py | 2 + pytensor/link/mlx/dispatch/core.py | 95 ++++++++++++------- pytensor/link/mlx/dispatch/shape.py | 10 +- pytensor/link/mlx/dispatch/signal/__init__.py | 0 pytensor/link/mlx/dispatch/signal/conv.py | 14 +++ 5 files changed, 88 insertions(+), 33 deletions(-) create mode 100644 pytensor/link/mlx/dispatch/signal/__init__.py create mode 100644 pytensor/link/mlx/dispatch/signal/conv.py diff --git a/pytensor/link/mlx/dispatch/__init__.py b/pytensor/link/mlx/dispatch/__init__.py index 7e835d238b..2dd4e8a02d 100644 --- a/pytensor/link/mlx/dispatch/__init__.py +++ b/pytensor/link/mlx/dispatch/__init__.py @@ -7,4 +7,6 @@ import pytensor.link.mlx.dispatch.shape import pytensor.link.mlx.dispatch.subtensor import pytensor.link.mlx.dispatch.core +import pytensor.link.mlx.dispatch.signal +import pytensor.link.mlx.dispatch.signal.conv # isort: on diff --git a/pytensor/link/mlx/dispatch/core.py b/pytensor/link/mlx/dispatch/core.py index 84ae042682..6985c2b656 100644 --- a/pytensor/link/mlx/dispatch/core.py +++ b/pytensor/link/mlx/dispatch/core.py @@ -2,24 +2,33 @@ pytensor/link/mlx/dispatch/basic.py ----------------------------------- -First‑cut MLX translations for the most common tensor Ops. +First-cut MLX translations for the most common tensor Ops. The structure intentionally follows pytensor's JAX dispatcher so that once these kernels stabilise they can be optimised further (e.g. fusing -element‑wise graphs, adding in‑place updates, RNG thinning, etc.). +element-wise graphs, adding in-place updates, RNG thinning, etc.). """ + from __future__ import annotations import warnings -import numpy as np -import mlx.core as mx # MLX -from pytensor.link.mlx.dispatch.basic import mlx_funcify # MLX +import mlx.core as mx # MLX +import numpy as np +from pytensor.link.mlx.dispatch.basic import mlx_funcify # MLX from pytensor.tensor import get_vector_length from pytensor.tensor.basic import ( - Join, Split, ExtractDiag, Eye, MakeVector, - ScalarFromTensor, TensorFromScalar, Tri, + Alloc, + AllocEmpty, + ExtractDiag, + Eye, + Join, + MakeVector, + ScalarFromTensor, + Split, + TensorFromScalar, + Tri, get_scalar_constant_value, ) from pytensor.tensor.exceptions import NotScalarConstantError @@ -28,17 +37,17 @@ # ------------------------------------------------------------------ # Join # ------------------------------------------------------------------ -@mlx_funcify.register(Join) # MLX +@mlx_funcify.register(Join) # MLX def mlx_funcify_Join(op, **kwargs): def join(axis, *tensors): view = op.view if (view != -1) and all( - tensors[i].shape[axis] == 0 # MLX + tensors[i].shape[axis] == 0 # MLX for i in list(range(view)) + list(range(view + 1, len(tensors))) ): return tensors[view] - return mx.concatenate(tensors, axis=axis) # MLX + return mx.concatenate(tensors, axis=axis) # MLX return join @@ -46,7 +55,7 @@ def join(axis, *tensors): # ------------------------------------------------------------------ # Split # ------------------------------------------------------------------ -@mlx_funcify.register(Split) # MLX +@mlx_funcify.register(Split) # MLX def mlx_funcify_Split(op: Split, node, **kwargs): _, axis_sym, splits_sym = node.inputs @@ -60,8 +69,10 @@ def mlx_funcify_Split(op: Split, node, **kwargs): try: constant_splits = np.array( - [get_scalar_constant_value(splits_sym[i]) - for i in range(get_vector_length(splits_sym))] + [ + get_scalar_constant_value(splits_sym[i]) + for i in 
range(get_vector_length(splits_sym)) + ] ) except (ValueError, NotScalarConstantError): constant_splits = None @@ -78,18 +89,22 @@ def split(x, axis, splits): splits = constant_splits cumsum_splits = np.cumsum(splits[:-1]) else: - # dynamic ‑– keep in graph - splits_arr = mx.array(splits) # MLX - cumsum_splits = mx.cumsum(splits_arr[:-1]).tolist() # python list for mx.split + # dynamic - keep in graph + splits_arr = mx.array(splits) # MLX + cumsum_splits = mx.cumsum( + splits_arr[:-1] + ).tolist() # python list for mx.split if len(splits) != op.len_splits: raise ValueError("Length of 'splits' is not equal to n_splits") if np.sum(np.asarray(splits)) != x.shape[axis]: - raise ValueError("Split sizes do not sum to the input length on the chosen axis.") + raise ValueError( + "Split sizes do not sum to the input length on the chosen axis." + ) if np.any(np.asarray(splits) < 0): raise ValueError("Split sizes cannot be negative.") - return mx.split(x, cumsum_splits, axis=axis) # MLX + return mx.split(x, cumsum_splits, axis=axis) # MLX return split @@ -97,12 +112,12 @@ def split(x, axis, splits): # ------------------------------------------------------------------ # ExtractDiag # ------------------------------------------------------------------ -@mlx_funcify.register(ExtractDiag) # MLX +@mlx_funcify.register(ExtractDiag) # MLX def mlx_funcify_ExtractDiag(op, **kwargs): offset, axis1, axis2 = op.offset, op.axis1, op.axis2 def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2): - return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) # MLX + return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) # MLX return extract_diag @@ -110,12 +125,12 @@ def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2): # ------------------------------------------------------------------ # Eye # ------------------------------------------------------------------ -@mlx_funcify.register(Eye) # MLX +@mlx_funcify.register(Eye) # MLX def mlx_funcify_Eye(op, **kwargs): dtype = op.dtype def eye(N, M, k): - return mx.eye(int(N), int(M), int(k), dtype=dtype) # MLX + return mx.eye(int(N), int(M), int(k), dtype=dtype) # MLX return eye @@ -123,10 +138,10 @@ def eye(N, M, k): # ------------------------------------------------------------------ # MakeVector # ------------------------------------------------------------------ -@mlx_funcify.register(MakeVector) # MLX +@mlx_funcify.register(MakeVector) # MLX def mlx_funcify_MakeVector(op, **kwargs): def makevector(*x): - return mx.array(x, dtype=op.dtype) # MLX + return mx.array(x, dtype=op.dtype) # MLX return makevector @@ -134,10 +149,10 @@ def makevector(*x): # ------------------------------------------------------------------ # TensorFromScalar (identity for MLX) # ------------------------------------------------------------------ -@mlx_funcify.register(TensorFromScalar) # MLX +@mlx_funcify.register(TensorFromScalar) # MLX def mlx_funcify_TensorFromScalar(op, **kwargs): def tensor_from_scalar(x): - return x # already an MLX array / scalar + return x # already an MLX array / scalar return tensor_from_scalar @@ -145,10 +160,10 @@ def tensor_from_scalar(x): # ------------------------------------------------------------------ # ScalarFromTensor # ------------------------------------------------------------------ -@mlx_funcify.register(ScalarFromTensor) # MLX +@mlx_funcify.register(ScalarFromTensor) # MLX def mlx_funcify_ScalarFromTensor(op, **kwargs): def scalar_from_tensor(x): - return mx.array(x).reshape(-1)[0] # MLX + return mx.array(x).reshape(-1)[0] # MLX 
return scalar_from_tensor @@ -156,19 +171,35 @@ def scalar_from_tensor(x): # ------------------------------------------------------------------ # Tri # ------------------------------------------------------------------ -@mlx_funcify.register(Tri) # MLX +@mlx_funcify.register(Tri) # MLX def mlx_funcify_Tri(op, node, **kwargs): # node.inputs -> N, M, k const_args = [getattr(inp, "data", None) for inp in node.inputs] def tri(*args): - # Replace args with compile‑time constants when available + # Replace args with compile-time constants when available args = [ arg if const_a is None else const_a for arg, const_a in zip(args, const_args, strict=True) ] - return mx.tri(*args, dtype=op.dtype) # MLX + return mx.tri(*args, dtype=op.dtype) # MLX return tri -## Change the code to use the mlx functions \ No newline at end of file + +@mlx_funcify.register(AllocEmpty) +def mlx_funcify_AllocEmpty(op, **kwargs): + def allocempty(*shape): + return mx.zeros(shape, dtype=op.dtype) + + return allocempty + + +@mlx_funcify.register(Alloc) +def mlx_funcify_Alloc(op, node, **kwargs): + def alloc(x, *shape): + res = mx.broadcast_to(x, shape) + Alloc._check_runtime_broadcast(node, mx.array(x), res.shape) + return res + + return alloc diff --git a/pytensor/link/mlx/dispatch/shape.py b/pytensor/link/mlx/dispatch/shape.py index 1d48eae1f5..a0b8193b42 100644 --- a/pytensor/link/mlx/dispatch/shape.py +++ b/pytensor/link/mlx/dispatch/shape.py @@ -1,5 +1,5 @@ from pytensor.link.mlx.dispatch.basic import mlx_funcify -from pytensor.tensor.shape import SpecifyShape +from pytensor.tensor.shape import Shape_i, SpecifyShape @mlx_funcify.register(SpecifyShape) @@ -14,3 +14,11 @@ def specifyshape(x, *shape): return x return specifyshape + + +@mlx_funcify.register(Shape_i) +def mlx_funcify_Shape_i(op, node, **kwargs): + def shape_i(x, i): + return x.shape[op.i] + + return shape_i diff --git a/pytensor/link/mlx/dispatch/signal/__init__.py b/pytensor/link/mlx/dispatch/signal/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pytensor/link/mlx/dispatch/signal/conv.py b/pytensor/link/mlx/dispatch/signal/conv.py new file mode 100644 index 0000000000..a383695437 --- /dev/null +++ b/pytensor/link/mlx/dispatch/signal/conv.py @@ -0,0 +1,14 @@ +from pytensor.link.mlx.dispatch import mlx_funcify +from pytensor.tensor.signal.conv import Conv1d + +import mlx.core as mx + + +@mlx_funcify.register(Conv1d) +def mlx_funcify_Conv1d(op, node, **kwargs): + mode = op.mode + + def conv1d(data, kernel): + return mx.convolve(data, kernel, mode=mode) + + return conv1d From 5abd32d6b6e67a45f1a3f7f50f08131bb23c98eb Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 15:00:08 -0500 Subject: [PATCH 30/54] Changes --- pytensor/link/mlx/dispatch/blockwise.py | 16 ++++++++++++++++ pytensor/link/mlx/dispatch/signal/conv.py | 4 ++-- 2 files changed, 18 insertions(+), 2 deletions(-) create mode 100644 pytensor/link/mlx/dispatch/blockwise.py diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py new file mode 100644 index 0000000000..240ee1ad21 --- /dev/null +++ b/pytensor/link/mlx/dispatch/blockwise.py @@ -0,0 +1,16 @@ +import mlx.core as mx + +from pytensor.link.mlx.dispatch import mlx_funcify +from pytensor.tensor.blockwise import Blockwise + +@mlx_funcify.register(Blockwise) +def funcify_Blockwise(op: Blockwise, node, *args, **kwargs): + core_f = mlx_funcify(op.core_op) + batched_f = core_f + for _ in range(op.batch_ndim(node)): + 
batched_f = mx.vmap(batched_f) + + def wrapped_blockwise_f(*inputs): + return batched_f(*inputs) + + return wrapped_blockwise_f diff --git a/pytensor/link/mlx/dispatch/signal/conv.py b/pytensor/link/mlx/dispatch/signal/conv.py index a383695437..d3725b7f3e 100644 --- a/pytensor/link/mlx/dispatch/signal/conv.py +++ b/pytensor/link/mlx/dispatch/signal/conv.py @@ -1,8 +1,8 @@ +import mlx.core as mx + from pytensor.link.mlx.dispatch import mlx_funcify from pytensor.tensor.signal.conv import Conv1d -import mlx.core as mx - @mlx_funcify.register(Conv1d) def mlx_funcify_Conv1d(op, node, **kwargs): From 12daeacfd356807fef2c1a0c9e4902bf4f49fa56 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 15:23:13 -0500 Subject: [PATCH 31/54] A lot of new code --- pytensor/link/mlx/dispatch/__init__.py | 1 + pytensor/link/mlx/dispatch/blockwise.py | 68 +++++++++++++++++--- pytensor/link/mlx/dispatch/elemwise.py | 24 +++++++- pytensor/link/mlx/dispatch/math.py | 75 ++++++++++++++++++++++- pytensor/link/mlx/dispatch/signal/conv.py | 2 +- 5 files changed, 157 insertions(+), 13 deletions(-) diff --git a/pytensor/link/mlx/dispatch/__init__.py b/pytensor/link/mlx/dispatch/__init__.py index 2dd4e8a02d..f039263a37 100644 --- a/pytensor/link/mlx/dispatch/__init__.py +++ b/pytensor/link/mlx/dispatch/__init__.py @@ -9,4 +9,5 @@ import pytensor.link.mlx.dispatch.core import pytensor.link.mlx.dispatch.signal import pytensor.link.mlx.dispatch.signal.conv +import pytensor.link.mlx.dispatch.blockwise # isort: on diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py index 240ee1ad21..5a5ed8584a 100644 --- a/pytensor/link/mlx/dispatch/blockwise.py +++ b/pytensor/link/mlx/dispatch/blockwise.py @@ -1,16 +1,66 @@ import mlx.core as mx +from pytensor.graph import FunctionGraph from pytensor.link.mlx.dispatch import mlx_funcify from pytensor.tensor.blockwise import Blockwise + @mlx_funcify.register(Blockwise) def funcify_Blockwise(op: Blockwise, node, *args, **kwargs): - core_f = mlx_funcify(op.core_op) - batched_f = core_f - for _ in range(op.batch_ndim(node)): - batched_f = mx.vmap(batched_f) - - def wrapped_blockwise_f(*inputs): - return batched_f(*inputs) - - return wrapped_blockwise_f + # Create a function graph for the core operation + core_node = op._create_dummy_core_node(node.inputs) + core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs) + + # Convert the core function graph to an MLX function + tuple_core_fn = mlx_funcify(core_fgraph, **kwargs) + + # If there's only one output, unwrap it from the tuple + if len(node.outputs) == 1: + + def core_fn(*inputs): + return tuple_core_fn(*inputs)[0] + else: + core_fn = tuple_core_fn + + # Apply vmap for each batch dimension + batch_ndims = op.batch_ndim(node) + vmap_fn = core_fn + for _ in range(batch_ndims): + vmap_fn = mx.vmap(vmap_fn) + + def blockwise_fn(*inputs): + # Check for runtime broadcasting compatibility + op._check_runtime_broadcast(node, inputs) + + # Handle broadcasting for batched dimensions + if batch_ndims > 0: + # Get batch shapes for broadcasting + batch_shapes = [inp.shape[:batch_ndims] for inp in inputs] + + # Calculate the broadcasted batch shape + from functools import reduce + + def broadcast_shapes(shape1, shape2): + return tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2, strict=True)) + + if batch_shapes: + broadcasted_shape = reduce(broadcast_shapes, batch_shapes) + + # Broadcast inputs to the common batch shape + 
broadcasted_inputs = [] + for inp in inputs: + if inp.shape[:batch_ndims] != broadcasted_shape: + # Create the full target shape + target_shape = broadcasted_shape + inp.shape[batch_ndims:] + # Broadcast the input + broadcasted_inputs.append(mx.broadcast_to(inp, target_shape)) + else: + broadcasted_inputs.append(inp) + + # Apply the vectorized function to the broadcasted inputs + return vmap_fn(*broadcasted_inputs) + + # No broadcasting needed + return vmap_fn(*inputs) + + return blockwise_fn diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py index 7ec124623e..a5374b8cb4 100644 --- a/pytensor/link/mlx/dispatch/elemwise.py +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -1,6 +1,7 @@ import mlx.core as mx from pytensor.link.mlx.dispatch.basic import mlx_funcify +from pytensor.scalar import Softplus from pytensor.scalar.basic import AND, OR, Add, Mul, ScalarMaximum, ScalarMinimum from pytensor.tensor.elemwise import CAReduce, DimShuffle from pytensor.tensor.special import Softmax, SoftmaxGrad @@ -59,9 +60,8 @@ def min(x): return mx.min(x, axis=op.axis) return min - else: - raise NotImplementedError(f"MLX does not support {op.scalar_op}") + raise NotImplementedError(f"MLX does not support Elemwise {op.scalar_op}") @mlx_funcify.register(Softmax) @@ -83,3 +83,23 @@ def softmax_grad(dy, sm): return dy_times_sm - mx.sum(dy_times_sm, axis=axis, keepdims=True) * sm return softmax_grad + + +@mlx_funcify.register(Softplus) +def mlx_funcify_Softplus(op, **kwargs): + def softplus(x): + return mx.where( + x < -37.0, + mx.exp(x), + mx.where( + x < 18.0, + mx.log1p(mx.exp(x)), + mx.where( + x < 33.3, + x + mx.exp(-x), + x, + ), + ), + ) + + return softplus diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index a0f68324d9..1adf547e3f 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -8,6 +8,7 @@ LE, LT, NEQ, + Abs, Add, Cos, Exp, @@ -15,14 +16,21 @@ Mul, Pow, Sin, + Sqr, + Sqrt, Sub, Switch, TrueDiv, + Neg, + AND, + OR, + ScalarMaximum, + ScalarMinimum, ) from pytensor.scalar.math import Sigmoid from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.math import Dot - +from pytensor.scalar import Softplus @mlx_funcify.register(Dot) def mlx_funcify_Dot(op, **kwargs): @@ -142,5 +150,70 @@ def true_div(x, y): return mx.divide(x, y) return true_div + elif isinstance(op.scalar_op, Sqr): + + def sqr(x): + return mx.square(x) + + return sqr + elif isinstance(op.scalar_op, Sqrt): + + def sqrt(x): + return mx.sqrt(x) + + return sqrt + elif isinstance(op.scalar_op, Abs): + + def abs(x): + return mx.abs(x) + + return abs + elif isinstance(op.scalar_op, Softplus): + def softplus(x): + return mx.where( + x < -37.0, + mx.exp(x), + mx.where( + x < 18.0, + mx.log1p(mx.exp(x)), + mx.where( + x < 33.3, + x + mx.exp(-x), + x, + ), + ), + ) + + return softplus + elif isinstance(op.scalar_op, Neg): + + def neg(x): + return mx.negative(x) + + return neg + elif isinstance(op.scalar_op, AND): + + def all(x): + return mx.all(x, axis=op.axis) + + return all + elif isinstance(op.scalar_op, OR): + + def any(x): + return mx.any(x, axis=op.axis) + + return any + elif isinstance(op.scalar_op, ScalarMaximum): + + def max(x): + return mx.max(x, axis=op.axis) + + return max + elif isinstance(op.scalar_op, ScalarMinimum): + + def min(x): + return mx.min(x, axis=op.axis) + + return min else: raise NotImplementedError(f"MLX does not support {op.scalar_op}") diff --git a/pytensor/link/mlx/dispatch/signal/conv.py 
b/pytensor/link/mlx/dispatch/signal/conv.py index d3725b7f3e..8f84ebb42f 100644 --- a/pytensor/link/mlx/dispatch/signal/conv.py +++ b/pytensor/link/mlx/dispatch/signal/conv.py @@ -5,7 +5,7 @@ @mlx_funcify.register(Conv1d) -def mlx_funcify_Conv1d(op, node, **kwargs): +def mlx_funcify_Conv1d(op, node=None, **kwargs): mode = op.mode def conv1d(data, kernel): From ac93949d4b50f6eff5c165781b8fa9a444265e3b Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 18:20:17 -0500 Subject: [PATCH 32/54] almost there baby william --- pytensor/link/mlx/dispatch/basic.py | 16 +++++++ pytensor/link/mlx/dispatch/blockwise.py | 62 +++---------------------- pytensor/link/mlx/dispatch/math.py | 37 ++++++++++++--- pytensor/link/mlx/dispatch/shape.py | 2 +- pytensor/link/mlx/dispatch/subtensor.py | 32 +++++-------- 5 files changed, 66 insertions(+), 83 deletions(-) diff --git a/pytensor/link/mlx/dispatch/basic.py b/pytensor/link/mlx/dispatch/basic.py index 9cbb92118d..a99772dba3 100644 --- a/pytensor/link/mlx/dispatch/basic.py +++ b/pytensor/link/mlx/dispatch/basic.py @@ -1,3 +1,4 @@ +import warnings from functools import singledispatch from types import NoneType @@ -7,6 +8,7 @@ from pytensor.compile.ops import DeepCopyOp from pytensor.graph.fg import FunctionGraph from pytensor.link.utils import fgraph_to_python +from pytensor.raise_op import Assert, CheckAndRaise @singledispatch @@ -59,3 +61,17 @@ def deepcopyop(x): return x.copy() return deepcopyop + + +@mlx_funcify.register(Assert) +@mlx_funcify.register(CheckAndRaise) +def mlx_funcify_CheckAndRaise(op, **kwargs): + warnings.warn( + f"""Skipping `CheckAndRaise` Op (assertion: {op.msg}) as MLX tracing would remove it.""", + stacklevel=2, + ) + + def assert_fn(x, *inputs): + return x + + return assert_fn diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py index 5a5ed8584a..550a1c9616 100644 --- a/pytensor/link/mlx/dispatch/blockwise.py +++ b/pytensor/link/mlx/dispatch/blockwise.py @@ -1,66 +1,18 @@ import mlx.core as mx -from pytensor.graph import FunctionGraph from pytensor.link.mlx.dispatch import mlx_funcify from pytensor.tensor.blockwise import Blockwise @mlx_funcify.register(Blockwise) def funcify_Blockwise(op: Blockwise, node, *args, **kwargs): - # Create a function graph for the core operation core_node = op._create_dummy_core_node(node.inputs) - core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs) + core_f = mlx_funcify(op.core_op, core_node) + blockwise_f = core_f + for i in range(op.batch_ndim(node)): + blockwise_f = mx.vmap(blockwise_f) - # Convert the core function graph to an MLX function - tuple_core_fn = mlx_funcify(core_fgraph, **kwargs) + def blockwise_fun(*inputs): + return blockwise_f(*inputs) - # If there's only one output, unwrap it from the tuple - if len(node.outputs) == 1: - - def core_fn(*inputs): - return tuple_core_fn(*inputs)[0] - else: - core_fn = tuple_core_fn - - # Apply vmap for each batch dimension - batch_ndims = op.batch_ndim(node) - vmap_fn = core_fn - for _ in range(batch_ndims): - vmap_fn = mx.vmap(vmap_fn) - - def blockwise_fn(*inputs): - # Check for runtime broadcasting compatibility - op._check_runtime_broadcast(node, inputs) - - # Handle broadcasting for batched dimensions - if batch_ndims > 0: - # Get batch shapes for broadcasting - batch_shapes = [inp.shape[:batch_ndims] for inp in inputs] - - # Calculate the broadcasted batch shape - from functools import reduce - - def 
broadcast_shapes(shape1, shape2): - return tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2, strict=True)) - - if batch_shapes: - broadcasted_shape = reduce(broadcast_shapes, batch_shapes) - - # Broadcast inputs to the common batch shape - broadcasted_inputs = [] - for inp in inputs: - if inp.shape[:batch_ndims] != broadcasted_shape: - # Create the full target shape - target_shape = broadcasted_shape + inp.shape[batch_ndims:] - # Broadcast the input - broadcasted_inputs.append(mx.broadcast_to(inp, target_shape)) - else: - broadcasted_inputs.append(inp) - - # Apply the vectorized function to the broadcasted inputs - return vmap_fn(*broadcasted_inputs) - - # No broadcasting needed - return vmap_fn(*inputs) - - return blockwise_fn + return blockwise_fun diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index 1adf547e3f..6696635ed5 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -1,36 +1,40 @@ import mlx.core as mx from pytensor.link.mlx.dispatch import mlx_funcify +from pytensor.scalar import Softplus from pytensor.scalar.basic import ( + AND, EQ, GE, GT, LE, LT, NEQ, + OR, Abs, Add, + Cast, Cos, Exp, Log, Mul, + Neg, Pow, + ScalarMaximum, + ScalarMinimum, + Sign, Sin, Sqr, Sqrt, Sub, Switch, TrueDiv, - Neg, - AND, - OR, - ScalarMaximum, - ScalarMinimum, + Log1p ) from pytensor.scalar.math import Sigmoid from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.math import Dot -from pytensor.scalar import Softplus + @mlx_funcify.register(Dot) def mlx_funcify_Dot(op, **kwargs): @@ -169,6 +173,7 @@ def abs(x): return abs elif isinstance(op.scalar_op, Softplus): + def softplus(x): return mx.where( x < -37.0, @@ -194,7 +199,7 @@ def neg(x): elif isinstance(op.scalar_op, AND): def all(x): - return mx.all(x, axis=op.axis) + return mx.all(x) return all elif isinstance(op.scalar_op, OR): @@ -215,5 +220,23 @@ def min(x): return mx.min(x, axis=op.axis) return min + elif isinstance(op.scalar_op, Cast): + + def cast(x): + return mx.cast(x, op.dtype) + + return cast + elif isinstance(op.scalar_op, Sign): + + def sign(x): + return mx.sign(x) + + return sign + elif isinstance(op.scalar_op, Log1p): + + def log1p(x): + return mx.log1p(x) + + return log1p else: raise NotImplementedError(f"MLX does not support {op.scalar_op}") diff --git a/pytensor/link/mlx/dispatch/shape.py b/pytensor/link/mlx/dispatch/shape.py index a0b8193b42..bd5b5941d9 100644 --- a/pytensor/link/mlx/dispatch/shape.py +++ b/pytensor/link/mlx/dispatch/shape.py @@ -18,7 +18,7 @@ def specifyshape(x, *shape): @mlx_funcify.register(Shape_i) def mlx_funcify_Shape_i(op, node, **kwargs): - def shape_i(x, i): + def shape_i(x): return x.shape[op.i] return shape_i diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py index 7f8b55f18e..b45a10519c 100644 --- a/pytensor/link/mlx/dispatch/subtensor.py +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -11,40 +11,32 @@ from pytensor.tensor.type_other import MakeSlice -BOOLEAN_MASK_ERROR = """MLX does not support resizing arrays with boolean -masks. 
In some cases, however, it is possible to re-express your model -in a form that MLX can compile: - ->>> import pytensor.tensor as pt ->>> x_pt = pt.vector('x') ->>> y_pt = x_pt[x_pt > 0].sum() - -can be re-expressed as: +@mlx_funcify.register(Subtensor) +def mlx_funcify_Subtensor(op, node, **kwargs): + idx_list = getattr(op, "idx_list", None) ->>> import pytensor.tensor as pt ->>> x_pt = pt.vector('x') ->>> y_pt = pt.where(x_pt > 0, x_pt, 0).sum() -""" + def subtensor(x, *ilists): + indices = indices_from_subtensor([int(element) for element in ilists], idx_list) + if len(indices) == 1: + indices = indices[0] -DYNAMIC_SLICE_LENGTH_ERROR = """MLX does not support slicing arrays with a dynamic -slice length. -""" + return x.__getitem__(indices) + return subtensor -@mlx_funcify.register(Subtensor) @mlx_funcify.register(AdvancedSubtensor) @mlx_funcify.register(AdvancedSubtensor1) -def mlx_funcify_Subtensor(op, node, **kwargs): +def mlx_funcify_AdvancedSubtensor(op, node, **kwargs): idx_list = getattr(op, "idx_list", None) - def subtensor(x, *ilists): + def advanced_subtensor(x, *ilists): indices = indices_from_subtensor(ilists, idx_list) if len(indices) == 1: indices = indices[0] return x.__getitem__(indices) - return subtensor + return advanced_subtensor @mlx_funcify.register(IncSubtensor) From a19cbc87180adf5bc96cad8c91be9ad19ab0856d Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 18:28:04 -0500 Subject: [PATCH 33/54] Another push small --- pytensor/link/mlx/dispatch/elemwise.py | 2 +- pytensor/link/mlx/dispatch/math.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py index a5374b8cb4..81cdf2b2ca 100644 --- a/pytensor/link/mlx/dispatch/elemwise.py +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -39,7 +39,7 @@ def prod(x): elif isinstance(op.scalar_op, AND): def all(x): - return mx.all(x, axis=op.axis) + return mx.all(a=x, axis=op.axis) return all elif isinstance(op.scalar_op, OR): diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index 6696635ed5..5398183f29 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -199,7 +199,7 @@ def neg(x): elif isinstance(op.scalar_op, AND): def all(x): - return mx.all(x) + return mx.all(a=x, axis=op.axis) return all elif isinstance(op.scalar_op, OR): From 5c97bc8d31185d82b7dd97c962cc60664873e301 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 19:34:11 -0400 Subject: [PATCH 34/54] fix for all --- pytensor/link/mlx/dispatch/elemwise.py | 2 +- pytensor/link/mlx/dispatch/math.py | 4 ++-- tests/link/mlx/test_elemwise.py | 12 ++++++++++++ 3 files changed, 15 insertions(+), 3 deletions(-) create mode 100644 tests/link/mlx/test_elemwise.py diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py index 81cdf2b2ca..57103c12ff 100644 --- a/pytensor/link/mlx/dispatch/elemwise.py +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -39,7 +39,7 @@ def prod(x): elif isinstance(op.scalar_op, AND): def all(x): - return mx.all(a=x, axis=op.axis) + return x.all(axis=op.axis) return all elif isinstance(op.scalar_op, OR): diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index 5398183f29..305f86c90b 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -17,6 +17,7 @@ Cos, Exp, Log, + Log1p, Mul, Neg, Pow, @@ -29,7 +30,6 @@ 
Sub, Switch, TrueDiv, - Log1p ) from pytensor.scalar.math import Sigmoid from pytensor.tensor.elemwise import Elemwise @@ -199,7 +199,7 @@ def neg(x): elif isinstance(op.scalar_op, AND): def all(x): - return mx.all(a=x, axis=op.axis) + return x.all(axis=op.axis) return all elif isinstance(op.scalar_op, OR): diff --git a/tests/link/mlx/test_elemwise.py b/tests/link/mlx/test_elemwise.py new file mode 100644 index 0000000000..d7e17b6654 --- /dev/null +++ b/tests/link/mlx/test_elemwise.py @@ -0,0 +1,12 @@ +import pytensor.tensor as pt +from tests.link.mlx.test_basic import compare_mlx_and_py, mx + + +def test_all() -> None: + x = pt.vector("x") + + out = pt.all(x > 0) + + x_test = mx.array([-1.0, 2.0, 3.0]) + + compare_mlx_and_py([x], out, [x_test]) From 2fc81bc8701949f555d2c85e8b3c668313d745c4 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 19:45:54 -0400 Subject: [PATCH 35/54] fix for carlos --- tests/link/mlx/test_elemwise.py | 11 ++++++----- tests/link/mlx/test_math.py | 10 ++++++++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/tests/link/mlx/test_elemwise.py b/tests/link/mlx/test_elemwise.py index d7e17b6654..7819df06be 100644 --- a/tests/link/mlx/test_elemwise.py +++ b/tests/link/mlx/test_elemwise.py @@ -1,12 +1,13 @@ +import pytest + import pytensor.tensor as pt from tests.link.mlx.test_basic import compare_mlx_and_py, mx -def test_all() -> None: +@pytest.mark.parametrize("op", [pt.any, pt.all, pt.max, pt.min]) +def test_input(op) -> None: x = pt.vector("x") - - out = pt.all(x > 0) - - x_test = mx.array([-1.0, 2.0, 3.0]) + out = op(x > 0) + x_test = mx.array([1.0, 2.0, 3.0]) compare_mlx_and_py([x], out, [x_test]) diff --git a/tests/link/mlx/test_math.py b/tests/link/mlx/test_math.py index 86fa999451..850d9e754d 100644 --- a/tests/link/mlx/test_math.py +++ b/tests/link/mlx/test_math.py @@ -54,6 +54,16 @@ def test_switch() -> None: compare_mlx_and_py([x, y], out, [x_test, y_test]) +@pytest.mark.parametrize("op", [pt.sum, pt.prod]) +def test_input(op) -> None: + x = pt.vector("x") + y = pt.vector("y") + out = op([x, y, x + y]) + x_test = mx.array([1.0, 2.0, 3.0]) + y_test = mx.array([4.0, 5.0, 6.0]) + compare_mlx_and_py([x, y], out, [x_test, y_test]) + + @pytest.mark.parametrize( "op", [ From e6437cc33d4ea165adec28a9900086e38f6e2bd1 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 20:09:21 -0400 Subject: [PATCH 36/54] just return the compiled func --- pytensor/link/mlx/linker.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytensor/link/mlx/linker.py b/pytensor/link/mlx/linker.py index e057bb942c..1dc06aefc6 100644 --- a/pytensor/link/mlx/linker.py +++ b/pytensor/link/mlx/linker.py @@ -35,6 +35,8 @@ def jit_compile(self, fn): inner_fn = mx.compile(fn) + return inner_fn + def fn(*inputs, inner_fn=inner_fn): return inner_fn(*(mlx_typify(inp) for inp in inputs)) From c3a3e1a81d1b436282eabf390c1f3f7c66c20e47 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 19:19:39 -0500 Subject: [PATCH 37/54] A change for willy may! 
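
Inside an Elemwise node, AND and OR are binary scalar ops: they take two
arrays and return the elementwise conjunction/disjunction, so lowering
them to axis reductions had the wrong arity and the wrong semantics (the
reducing variants are handled by CAReduce). A minimal sketch of the
distinction, using only public mlx.core functions on toy inputs:

    import mlx.core as mx

    x = mx.array([True, True, False])
    y = mx.array([True, False, False])

    # Elemwise AND/OR: two inputs, same-shape output.
    print(mx.bitwise_and(x, y))  # [True, False, False]
    print(mx.bitwise_or(x, y))   # [True, True, False]

    # Reducing over an axis is CAReduce territory, not Elemwise.
    print(mx.all(x))             # False
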
--- pytensor/link/mlx/dispatch/math.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index 305f86c90b..293a7cfa0a 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -198,14 +198,14 @@ def neg(x): return neg elif isinstance(op.scalar_op, AND): - def all(x): - return x.all(axis=op.axis) + def all(x, y): + return mx.bitwise_and(x, y) return all elif isinstance(op.scalar_op, OR): - def any(x): - return mx.any(x, axis=op.axis) + def any(x, y): + return mx.bitwise_or(x, y) return any elif isinstance(op.scalar_op, ScalarMaximum): From e7cf10ea0ed889c1ed91e80243f825087c30d1dd Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 19:24:10 -0500 Subject: [PATCH 38/54] FINALLY BABY LETS PARTY! (IF YOU ARE READING THIS MAKE MORE PRs) --- pytensor/link/mlx/linker.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytensor/link/mlx/linker.py b/pytensor/link/mlx/linker.py index 1dc06aefc6..e057bb942c 100644 --- a/pytensor/link/mlx/linker.py +++ b/pytensor/link/mlx/linker.py @@ -35,8 +35,6 @@ def jit_compile(self, fn): inner_fn = mx.compile(fn) - return inner_fn - def fn(*inputs, inner_fn=inner_fn): return inner_fn(*(mlx_typify(inp) for inp in inputs)) From 880dd5cf3beb875f955b3c5b7d6bbe39560322f6 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 20:51:53 -0400 Subject: [PATCH 39/54] refactor to use getattr --- pytensor/link/mlx/dispatch/elemwise.py | 86 ++++++++++++++------------ 1 file changed, 47 insertions(+), 39 deletions(-) diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py index 57103c12ff..c71de48b12 100644 --- a/pytensor/link/mlx/dispatch/elemwise.py +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -2,7 +2,6 @@ from pytensor.link.mlx.dispatch.basic import mlx_funcify from pytensor.scalar import Softplus -from pytensor.scalar.basic import AND, OR, Add, Mul, ScalarMaximum, ScalarMinimum from pytensor.tensor.elemwise import CAReduce, DimShuffle from pytensor.tensor.special import Softmax, SoftmaxGrad @@ -24,44 +23,53 @@ def dimshuffle(x): @mlx_funcify.register(CAReduce) def mlx_funcify_CAReduce(op, **kwargs): - if isinstance(op.scalar_op, Add): - - def sum(x): - return mx.sum(x, axis=op.axis) - - return sum - elif isinstance(op.scalar_op, Mul): - - def prod(x): - return mx.prod(x, axis=op.axis) - - return prod - elif isinstance(op.scalar_op, AND): - - def all(x): - return x.all(axis=op.axis) - - return all - elif isinstance(op.scalar_op, OR): - - def any(x): - return mx.any(x, axis=op.axis) - - return any - elif isinstance(op.scalar_op, ScalarMaximum): - - def max(x): - return mx.max(x, axis=op.axis) - - return max - elif isinstance(op.scalar_op, ScalarMinimum): - - def min(x): - return mx.min(x, axis=op.axis) - - return min - else: - raise NotImplementedError(f"MLX does not support Elemwise {op.scalar_op}") + axis = op.axis + op_nfunc_spec = getattr(op, "nfunc_spec", None) + scalar_nfunc_spec = getattr(op.scalar_op, "nfunc_spec", None) + scalar_op_name = getattr(op.scalar_op, "name", None) + scalar_op_identity = getattr(op.scalar_op, "identity", None) + acc_dtype = getattr(op, "acc_dtype", None) + + def careduce(x): + nonlocal \ + axis, \ + op_nfunc_spec, \ + scalar_nfunc_spec, \ + scalar_op_name, \ + scalar_op_identity, \ + acc_dtype + + if axis is None: + axis = list(range(x.ndim)) + + if acc_dtype is None: + acc_dtype = x.dtype.type + + 
if op_nfunc_spec: + mlx_op = getattr(mx, op_nfunc_spec[0]) + return mlx_op(x, axis=axis) + return mlx_op(x, axis=axis).astype(acc_dtype) + + # The PyTensor `Op` didn't tell us which NumPy equivalent to use (or + # there isn't one), so we use this fallback approach + if scalar_nfunc_spec: + scalar_fn_name = scalar_nfunc_spec[0] + elif scalar_op_name: + scalar_fn_name = scalar_op_name + + to_reduce = sorted(axis, reverse=True) + + if to_reduce: + raise NotImplementedError("Not implemented yet") + # In this case, we need to use the `jax.lax` function (if there + # is one), and not the `jnp` version. + mlx_op = getattr(mx, scalar_fn_name) + init_value = mx.array(scalar_op_identity, dtype=acc_dtype) + return mx.reduce(x, init_value, mlx_op, to_reduce).astype(acc_dtype) + else: + return x + + return careduce @mlx_funcify.register(Softmax) From 1e6addd79c2c839bb2a3f42a309e10d9a2585b2c Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 20:52:17 -0400 Subject: [PATCH 40/54] bring argmax test --- tests/link/mlx/test_math.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/link/mlx/test_math.py b/tests/link/mlx/test_math.py index 850d9e754d..2c08d986c9 100644 --- a/tests/link/mlx/test_math.py +++ b/tests/link/mlx/test_math.py @@ -3,6 +3,7 @@ import pytensor import pytensor.tensor as pt +from pytensor.tensor.math import Argmax, Max from tests.link.mlx.test_basic import compare_mlx_and_py, mx @@ -87,3 +88,14 @@ def test_elemwise_two_inputs(op) -> None: x_test = mx.array([1.0, 2.0, 3.0]) y_test = mx.array([4.0, 5.0, 6.0]) compare_mlx_and_py([x, y], out, [x_test, y_test]) + + +@pytest.mark.xfail(reason="Argmax not implemented yet") +def test_mlx_max_and_argmax(): + # Test that a single output of a multi-output `Op` can be used as input to + # another `Op` + x = pt.dvector() + mx = Max([0])(x) + amx = Argmax([0])(x) + out = mx * amx + compare_mlx_and_py([x], [out], [np.r_[1, 2]]) From aabbb788ff70668d37ea8523e72be4032e5430e0 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 21:10:21 -0400 Subject: [PATCH 41/54] use deepcopy --- pytensor/link/mlx/dispatch/basic.py | 3 ++- pytensor/link/mlx/dispatch/subtensor.py | 11 +++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/pytensor/link/mlx/dispatch/basic.py b/pytensor/link/mlx/dispatch/basic.py index a99772dba3..d0b3d451f5 100644 --- a/pytensor/link/mlx/dispatch/basic.py +++ b/pytensor/link/mlx/dispatch/basic.py @@ -1,4 +1,5 @@ import warnings +from copy import deepcopy from functools import singledispatch from types import NoneType @@ -58,7 +59,7 @@ def mlx_funcify_FunctionGraph( @mlx_funcify.register(DeepCopyOp) def mlx_funcify_DeepCopyOp(op, **kwargs): def deepcopyop(x): - return x.copy() + return deepcopy(x) return deepcopyop diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py index b45a10519c..ce14d08246 100644 --- a/pytensor/link/mlx/dispatch/subtensor.py +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -1,3 +1,5 @@ +from copy import deepcopy + from pytensor.link.mlx.dispatch.basic import mlx_funcify from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, @@ -24,6 +26,7 @@ def subtensor(x, *ilists): return subtensor + @mlx_funcify.register(AdvancedSubtensor) @mlx_funcify.register(AdvancedSubtensor1) def mlx_funcify_AdvancedSubtensor(op, node, **kwargs): @@ -48,7 +51,7 @@ def mlx_funcify_IncSubtensor(op, node, **kwargs): def mlx_fn(x, indices, y): if not op.inplace: - x = x.copy() + x = deepcopy(x) x[indices] = y return x @@ -56,7 +59,7 @@ 
def mlx_fn(x, indices, y): def mlx_fn(x, indices, y): if not op.inplace: - x = x.copy() + x = deepcopy(x) x[indices] += y return x @@ -76,7 +79,7 @@ def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs): def mlx_fn(x, indices, y): if not op.inplace: - x = x.copy() + x = deepcopy(x) x[indices] = y return x @@ -84,7 +87,7 @@ def mlx_fn(x, indices, y): def mlx_fn(x, indices, y): if not op.inplace: - x = x.copy() + x = deepcopy(x) x[indices] += y return x From 0812c55398f19c4ed95a3d7cca4891ff75f68b27 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 21:11:30 -0400 Subject: [PATCH 42/54] move some tests --- tests/link/mlx/test_shape.py | 78 ++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 tests/link/mlx/test_shape.py diff --git a/tests/link/mlx/test_shape.py b/tests/link/mlx/test_shape.py new file mode 100644 index 0000000000..7a548df8f8 --- /dev/null +++ b/tests/link/mlx/test_shape.py @@ -0,0 +1,78 @@ +import numpy as np +import pytest + +import pytensor.tensor as pt +from pytensor.compile.ops import DeepCopyOp, ViewOp +from pytensor.configdefaults import config +from pytensor.tensor.shape import Shape, Shape_i, reshape +from pytensor.tensor.type import iscalar, vector +from tests.link.mlx.test_basic import compare_mlx_and_py + + +@pytest.mark.xfail(reason="Shape Op is not supported yet") +def test_mlx_shape_ops(): + x_np = np.zeros((20, 3)) + x = Shape()(pt.as_tensor_variable(x_np)) + + compare_mlx_and_py([], [x], [], must_be_device_array=False) + + x = Shape_i(1)(pt.as_tensor_variable(x_np)) + + compare_mlx_and_py([], [x], [], must_be_device_array=False) + + +@pytest.mark.xfail(reason="Shape Op is not supported yet") +def test_mlx_specify_shape(): + in_pt = pt.matrix("in") + x = pt.specify_shape(in_pt, (4, None)) + compare_mlx_and_py([in_pt], [x], [np.ones((4, 5)).astype(config.floatX)]) + + # When used to assert two arrays have similar shapes + in_pt = pt.matrix("in") + shape_pt = pt.matrix("shape") + x = pt.specify_shape(in_pt, shape_pt.shape) + + compare_mlx_and_py( + [in_pt, shape_pt], + [x], + [np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)], + ) + + +@pytest.mark.xfail(reason="Reshape Op is not supported yet") +def test_mlx_Reshape_constant(): + a = vector("a") + x = reshape(a, (2, 2)) + compare_mlx_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) + + +@pytest.mark.xfail(reason="Reshape Op is not supported yet") +def test_mlx_Reshape_concrete_shape(): + """MLX should compile when a concrete value is passed for the `shape` parameter.""" + a = vector("a") + x = reshape(a, a.shape) + compare_mlx_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) + + x = reshape(a, (a.shape[0] // 2, a.shape[0] // 2)) + compare_mlx_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) + + +@pytest.mark.xfail(reason="`shape_pt` should be specified as a static argument") +def test_mlx_Reshape_shape_graph_input(): + a = vector("a") + shape_pt = iscalar("b") + x = reshape(a, (shape_pt, shape_pt)) + compare_mlx_and_py( + [a, shape_pt], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2] + ) + + +@pytest.mark.xfail(reason="ViewOp Op is not supported yet") +def test_mlx_compile_ops(): + x = DeepCopyOp()(pt.as_tensor_variable(1.1)) + compare_mlx_and_py([], [x], []) + + x_np = np.zeros((20, 1, 1)) + x = ViewOp()(pt.as_tensor_variable(x_np)) + + compare_mlx_and_py([], [x], []) From 294c271ca2258186fa8d3ce68c9fabc1a9a0c261 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo 
<59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 20:13:34 -0500 Subject: [PATCH 43/54] THE SUPER BLOCKWISEE YA YA YA YA JUUUUU --- pytensor/link/mlx/dispatch/blockwise.py | 32 ++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py index 550a1c9616..378fe861a8 100644 --- a/pytensor/link/mlx/dispatch/blockwise.py +++ b/pytensor/link/mlx/dispatch/blockwise.py @@ -2,11 +2,41 @@ from pytensor.link.mlx.dispatch import mlx_funcify from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.signal.conv import Conv1d +def blockwise_conv1d(op, node): + if op.core_op.mode != "valid": + raise NotImplementedError("Only 'valid' mode is supported for conv1d") + batches_ndim = op.batch_ndim(node) + if batches_ndim != 1: + raise NotImplementedError("Only 1D batches are supported for conv1d") + + _, kernel = node.inputs + if not all(kernel.type.broadcastable[:batches_ndim]): + raise NotImplementedError("Only 1D batches are supported for conv1d") + + def inner_f(x, kernel): + x_reshaped = x.reshape(-1, x.shape[-1]).T # shape equals to (N, B) -> N Time as batches all together + b = x_reshaped.shape[1] # + kernel_squeeze = kernel.reshape(-1) + f = kernel_squeeze.shape[0] # Number of filters + kernel_reshaped = mx.broadcast_to(a=kernel_squeeze[None, :, None], shape=(b, f, b)) + conv_result = mx.conv1d(x_reshaped[None, :, :], kernel_reshaped, stride=1, padding=0, dilation=1) + _, conv_shape, _ = conv_result.shape + return mx.moveaxis(a=conv_result, source=-1, destination=0).reshape(x.shape[:-1] + (conv_shape,)) + return inner_f @mlx_funcify.register(Blockwise) -def funcify_Blockwise(op: Blockwise, node, *args, **kwargs): +def funcify_Blockwise(op: Blockwise, node, **kwargs): + if isinstance(op.core_op, Conv1d): + return blockwise_conv1d(op, node, **kwargs) + + core_f = mlx_funcify(op.core_op) + + def blockwise_f(*inputs): + return blockwise_f(*inputs) core_node = op._create_dummy_core_node(node.inputs) + core_f = mlx_funcify(op.core_op, core_node) blockwise_f = core_f for i in range(op.batch_ndim(node)): From 9f31ab109c870a05a0207c9e4f9019e301129601 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 20:45:51 -0500 Subject: [PATCH 44/54] Guys, I'm getting sad. We need help yisus!!!!! 
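
Work in progress on the Conv1d Blockwise: rather than assuming a single
shared batch dimension, compute the broadcast of the leading batch dims
of data and kernel with np.broadcast_shapes, then flatten all batch dims
into one axis before the mx.conv1d call. The shape bookkeeping being
attempted, as a rough sketch with toy shapes (not the final
implementation):

    import numpy as np

    bx, bk = (4, 1), (1, 3)          # leading batch dims of x and kernel
    b = np.broadcast_shapes(bx, bk)  # -> (4, 3)
    # x is then expanded to b + (T,), the kernel to b + (K,), and both
    # are flattened to (prod(b), T) and (prod(b), K) for one batched conv.
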
--- pytensor/link/mlx/dispatch/blockwise.py | 36 ++++++++++++++++--------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py index 378fe861a8..9c8d67b69a 100644 --- a/pytensor/link/mlx/dispatch/blockwise.py +++ b/pytensor/link/mlx/dispatch/blockwise.py @@ -4,26 +4,36 @@ from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.signal.conv import Conv1d -def blockwise_conv1d(op, node): +import numpy as np + +def blockwise_conv1d(op, node, **kwargs): if op.core_op.mode != "valid": raise NotImplementedError("Only 'valid' mode is supported for conv1d") - batches_ndim = op.batch_ndim(node) - if batches_ndim != 1: - raise NotImplementedError("Only 1D batches are supported for conv1d") + # batches_ndim = op.batch_ndim(node) + # if batches_ndim != 1: + # raise NotImplementedError("Only 1D batches are supported for conv1d") - _, kernel = node.inputs - if not all(kernel.type.broadcastable[:batches_ndim]): - raise NotImplementedError("Only 1D batches are supported for conv1d") + # _, kernel = node.inputs + # if not all(kernel.type.broadcastable[:batches_ndim]): + # raise NotImplementedError("Only 1D batches are supported for conv1d") def inner_f(x, kernel): - x_reshaped = x.reshape(-1, x.shape[-1]).T # shape equals to (N, B) -> N Time as batches all together - b = x_reshaped.shape[1] # - kernel_squeeze = kernel.reshape(-1) - f = kernel_squeeze.shape[0] # Number of filters - kernel_reshaped = mx.broadcast_to(a=kernel_squeeze[None, :, None], shape=(b, f, b)) + *bx, t = x.shape + *bk, h = kernel.shape + + b = np.broadcast_shapes(bx, bk) + + x = x.reshape(b + (t,)) + kernel = kernel.reshape(b + (h,)) + + x_reshaped = x.reshape(-1, t).T # shape equals to (N, B) -> N Time as batches all together + kernel_squeeze = kernel.reshape(-1, h) + b_prod = kernel_squeeze.shape[0] + + kernel_reshaped = mx.broadcast_to(a=kernel_squeeze[None, :, None], shape=(b_prod, h, b_prod)) conv_result = mx.conv1d(x_reshaped[None, :, :], kernel_reshaped, stride=1, padding=0, dilation=1) _, conv_shape, _ = conv_result.shape - return mx.moveaxis(a=conv_result, source=-1, destination=0).reshape(x.shape[:-1] + (conv_shape,)) + return mx.moveaxis(a=conv_result, source=-1, destination=0).reshape(b + (conv_shape,)) return inner_f @mlx_funcify.register(Blockwise) From 37440ff1a8a959ae8f1edbb573e665334239d8e8 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 21:00:15 -0500 Subject: [PATCH 45/54] WILLIAM YOU NEED TO GO ANOTHER MILE! GO ON MY MATEEEEEEE, GO PHILLIES! 
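
The previous revision tried to materialize the broadcast batch shape
with reshape, which cannot change the number of elements; mx.broadcast_to
can. A sketch of the difference (toy shapes):

    import mlx.core as mx

    x = mx.zeros((1, 100))
    # x.reshape(4, 100) would fail: 400 elements requested, 100 available.
    xb = mx.broadcast_to(x, (4, 100))  # repeats the row along axis 0
    print(xb.shape)                    # (4, 100)

The print statements in the diff below are temporary debug output.
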
--- pytensor/link/mlx/dispatch/blockwise.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py index 9c8d67b69a..cdabaf8315 100644 --- a/pytensor/link/mlx/dispatch/blockwise.py +++ b/pytensor/link/mlx/dispatch/blockwise.py @@ -23,17 +23,20 @@ def inner_f(x, kernel): b = np.broadcast_shapes(bx, bk) - x = x.reshape(b + (t,)) - kernel = kernel.reshape(b + (h,)) + x = mx.broadcast_to(x, b + (t,)) + kernel = mx.broadcast_to(kernel, b + (h,)) x_reshaped = x.reshape(-1, t).T # shape equals to (N, B) -> N Time as batches all together kernel_squeeze = kernel.reshape(-1, h) b_prod = kernel_squeeze.shape[0] - kernel_reshaped = mx.broadcast_to(a=kernel_squeeze[None, :, None], shape=(b_prod, h, b_prod)) + print(kernel_squeeze.shape) + + print(b_prod, h, b_prod) + kernel_reshaped = mx.broadcast_to(kernel_squeeze[:, :, None], shape=(b_prod, h, b_prod)) conv_result = mx.conv1d(x_reshaped[None, :, :], kernel_reshaped, stride=1, padding=0, dilation=1) _, conv_shape, _ = conv_result.shape - return mx.moveaxis(a=conv_result, source=-1, destination=0).reshape(b + (conv_shape,)) + mx.moveaxis(conv_result, source=-1, destination=0).reshape(b + (conv_shape,)) return inner_f @mlx_funcify.register(Blockwise) From 4e4923fa6d02a24b25f473690bee5aaf839abfaf Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 21:03:02 -0500 Subject: [PATCH 46/54] RETURN, WHAT A SHAME! Sad times are coming. --- pytensor/link/mlx/dispatch/blockwise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py index cdabaf8315..5393483f20 100644 --- a/pytensor/link/mlx/dispatch/blockwise.py +++ b/pytensor/link/mlx/dispatch/blockwise.py @@ -36,7 +36,7 @@ def inner_f(x, kernel): kernel_reshaped = mx.broadcast_to(kernel_squeeze[:, :, None], shape=(b_prod, h, b_prod)) conv_result = mx.conv1d(x_reshaped[None, :, :], kernel_reshaped, stride=1, padding=0, dilation=1) _, conv_shape, _ = conv_result.shape - mx.moveaxis(conv_result, source=-1, destination=0).reshape(b + (conv_shape,)) + return mx.moveaxis(conv_result, source=-1, destination=0).reshape(b + (conv_shape,)) return inner_f @mlx_funcify.register(Blockwise) From 6b27dc4bbebd787b102b1f28b3a6af8d7013cdd9 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 22:15:05 -0500 Subject: [PATCH 47/54] AI COULD BE COOL? OR WE ARE JUST FUCKING AROUND? 
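
Rework inner_f around MLX's grouped convolution. mx.conv1d uses NLC
layout, input (N, L, C_in) and weights (C_out, K, C_in // groups), so
moving the batch axis into the channel dimension and setting groups=B
lets a single call apply a different kernel to each sequence. A minimal
sketch of the layout trick (shapes only; assumes grouped mx.conv1d as
available in current MLX releases):

    import mlx.core as mx

    B, T, K = 2, 10, 3
    x = mx.ones((B, T))                 # B sequences of length T
    w = mx.ones((B, K))                 # one kernel per sequence

    x_in = x.T[None, :, :]              # (1, T, B): batch of 1, B channels
    wts = w[:, :, None]                 # (B, K, 1): single-channel filters
    y = mx.conv1d(x_in, wts, groups=B)  # (1, T - K + 1, B)
    out = y[0].T                        # (B, T - K + 1)

Note that mx.conv1d computes cross-correlation, so the kernels still
need flipping for a true convolution; that fix lands in the next commit.
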
--- pytensor/link/mlx/dispatch/blockwise.py | 34 ++++++++++++++----------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py index 5393483f20..cf9b0fa830 100644 --- a/pytensor/link/mlx/dispatch/blockwise.py +++ b/pytensor/link/mlx/dispatch/blockwise.py @@ -18,25 +18,29 @@ def blockwise_conv1d(op, node, **kwargs): # raise NotImplementedError("Only 1D batches are supported for conv1d") def inner_f(x, kernel): - *bx, t = x.shape - *bk, h = kernel.shape + # 1) Validate shapes + B, T = x.shape + Bk, K = kernel.shape + if B != Bk: + raise ValueError(f"Batch mismatch: x has {B}, kernels has {Bk}") - b = np.broadcast_shapes(bx, bk) + # 2) Reshape x so that 'channels' = B, batch size = 1 + # → input shape (N=1, H=T, C_in=B) + x_in = x.T[None, :, :] # shape (1, T, B) - x = mx.broadcast_to(x, b + (t,)) - kernel = mx.broadcast_to(kernel, b + (h,)) + # 3) Build weight array of shape (C_out=B, H_f=K, C_in=1) + # groups = B will slice C_in into B single-channel groups + w = kernel[:, :, None] # shape (B, K, 1) - x_reshaped = x.reshape(-1, t).T # shape equals to (N, B) -> N Time as batches all together - kernel_squeeze = kernel.reshape(-1, h) - b_prod = kernel_squeeze.shape[0] + # 4) Convolve with one group per sequence + y = mx.conv1d(x_in, w, + stride=1, + padding=0, + dilation=1, + groups=B) - print(kernel_squeeze.shape) - - print(b_prod, h, b_prod) - kernel_reshaped = mx.broadcast_to(kernel_squeeze[:, :, None], shape=(b_prod, h, b_prod)) - conv_result = mx.conv1d(x_reshaped[None, :, :], kernel_reshaped, stride=1, padding=0, dilation=1) - _, conv_shape, _ = conv_result.shape - return mx.moveaxis(conv_result, source=-1, destination=0).reshape(b + (conv_shape,)) + # 5) y has shape (1, T - K + 1, B); drop the batch axis and transpose + return y[0].T # final shape (B, T - K + 1) return inner_f @mlx_funcify.register(Blockwise) From e308f838086042ef8e6551617962699c9cbfcdd3 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 22:28:31 -0500 Subject: [PATCH 48/54] AI RULES BABY MY MATE --- pytensor/link/mlx/dispatch/blockwise.py | 31 +++++++++++++------------ 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py index cf9b0fa830..a30203f17f 100644 --- a/pytensor/link/mlx/dispatch/blockwise.py +++ b/pytensor/link/mlx/dispatch/blockwise.py @@ -18,29 +18,30 @@ def blockwise_conv1d(op, node, **kwargs): # raise NotImplementedError("Only 1D batches are supported for conv1d") def inner_f(x, kernel): - # 1) Validate shapes B, T = x.shape Bk, K = kernel.shape if B != Bk: raise ValueError(f"Batch mismatch: x has {B}, kernels has {Bk}") - # 2) Reshape x so that 'channels' = B, batch size = 1 - # → input shape (N=1, H=T, C_in=B) - x_in = x.T[None, :, :] # shape (1, T, B) + # 1) Flip each kernel for true convolution + kernels_flipped = kernel[:, ::-1] # shape (B, K) - # 3) Build weight array of shape (C_out=B, H_f=K, C_in=1) - # groups = B will slice C_in into B single-channel groups - w = kernel[:, :, None] # shape (B, K, 1) + # 2) Reshape input into (N=1, H=T, C_in=B) + x_in = x.T[None, :, :] - # 4) Convolve with one group per sequence - y = mx.conv1d(x_in, w, - stride=1, - padding=0, - dilation=1, - groups=B) + # 3) Build weight tensor of shape (C_out=B, H_f=K, C_in=1) + w = kernels_flipped[:, :, None] - # 5) y has shape (1, T - K + 1, B); drop the batch axis and transpose 
- return y[0].T # final shape (B, T - K + 1) + # 4) Convolve with one group per channel → valid mode + y = mx.conv1d( + x_in, w, + stride=1, + padding=0, + dilation=1, + groups=B + ) + # y: (1, T-K+1, B) → drop batch and transpose to (B, T-K+1) + return y[0].T return inner_f @mlx_funcify.register(Blockwise) From 3744a180db6c84474168fe3819f9ddb8054d2099 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 23:31:18 -0400 Subject: [PATCH 49/54] test conv1d case --- tests/link/mlx/test_blockwise.py | 64 ++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 tests/link/mlx/test_blockwise.py diff --git a/tests/link/mlx/test_blockwise.py b/tests/link/mlx/test_blockwise.py new file mode 100644 index 0000000000..9b271186c9 --- /dev/null +++ b/tests/link/mlx/test_blockwise.py @@ -0,0 +1,64 @@ +import numpy as np + +import pytensor.tensor as pt +from pytensor.tensor import tensor +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.math import Dot +from tests.link.mlx.test_basic import compare_mlx_and_py + + +# Equivalent blockwise to matmul but with dumb signature +odd_matmul = Blockwise(Dot(), signature="(i00,i01),(i10,i11)->(o00,o01)") + + +# @pytest.mark.parametrize("matmul_op", (matmul, odd_matmul)) +# def test_matmul(matmul_op): +# rng = np.random.default_rng(14) +# a = tensor("a", shape=(2, 3, 5)) +# b = tensor("b", shape=(2, 5, 3)) +# test_values = [ +# rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (a, b) +# ] +# +# out = matmul_op(a, b) +# assert isinstance(out.owner.op, Blockwise) +# fn, _ = compare_mlx_and_py([a, b], [out], test_values) +# +## Check we are not adding any unnecessary stuff +# jaxpr = str(jax.make_jaxpr(fn.vm.jit_fn)(*test_values)) +# jaxpr = jaxpr.replace("name=jax_funcified_fgraph", "name=matmul") +# expected_jaxpr = str(jax.make_jaxpr(jax.jit(jax.numpy.matmul))(*test_values)) +# assert jaxpr == expected_jaxpr + + +# conv1d +# (2, 100) +# (8, 100) +# mode = valid + + +def test_blockwise_conv1d(): + rng = np.random.default_rng(14) + a = tensor("a", shape=(2, 100)) + b = tensor("b", shape=(2, 8)) + + # a_test = np.broadcast_to(np.arange(100), (2, 100)) + a_test = rng.normal(size=(2, 100)) + b_test = rng.normal(size=(2, 8)) + # b_test = np.concatenate( + # [ + # np.ones((1, 8)), + # np.zeros((1, 8)), + # np.zeros((1, 8)), + # np.array([1, 0, 0, 0, 0, 0, 0, 0]).reshape(1, 8), + # np.array([1, 0, 0, 0, 0, 0, 0, 0]).reshape(1, 8), + # ], + # axis=0, + # ) + + test_values = [a_test, b_test] + + out = pt.signal.convolve1d(a, b, mode="valid") + + # assert isinstance(out.owner.op, Blockwise) + compare_mlx_and_py([a, b], [out], test_values, must_be_device_array=True) From b41cab00f96ce8af6de4a642daef5ce3b7590c52 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 22:40:40 -0500 Subject: [PATCH 50/54] I'm going for pizzas, it was an incredible day! 
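
Generalize the helper into batched_conv1d with explicit mode handling:
"valid" uses no padding (output length T - K + 1), while "full" is
emulated with zero padding of (K - 1) * dilation on each side (output
length T + K - 1). Kernels are flipped up front because mx.conv1d
computes cross-correlation rather than convolution. A small sketch
checking the "full" case against NumPy (an illustration, separate from
the test suite):

    import mlx.core as mx
    import numpy as np

    x = np.array([1.0, 2.0, 3.0, 4.0])
    k = np.array([1.0, -1.0, 2.0])

    ref = np.convolve(x, k, mode="full")          # length T + K - 1 = 6
    y = mx.conv1d(
        mx.array(x)[None, :, None],               # (1, T, 1) in NLC layout
        mx.array(k[::-1].copy())[None, :, None],  # (1, K, 1), flipped
        padding=len(k) - 1,                       # emulate "full"
    )
    assert np.allclose(np.array(y[0, :, 0]), ref)
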
--- pytensor/link/mlx/dispatch/blockwise.py | 105 ++++++++++++++++++------ 1 file changed, 82 insertions(+), 23 deletions(-) diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py index a30203f17f..95fc9d0f9a 100644 --- a/pytensor/link/mlx/dispatch/blockwise.py +++ b/pytensor/link/mlx/dispatch/blockwise.py @@ -7,49 +7,108 @@ import numpy as np def blockwise_conv1d(op, node, **kwargs): - if op.core_op.mode != "valid": - raise NotImplementedError("Only 'valid' mode is supported for conv1d") - # batches_ndim = op.batch_ndim(node) - # if batches_ndim != 1: - # raise NotImplementedError("Only 1D batches are supported for conv1d") + # if op.core_op.mode != "valid": + # raise NotImplementedError("Only 'valid' mode is supported for conv1d") - # _, kernel = node.inputs - # if not all(kernel.type.broadcastable[:batches_ndim]): - # raise NotImplementedError("Only 1D batches are supported for conv1d") + # def inner_f(x, kernel): + # B, T = x.shape + # Bk, K = kernel.shape + # if B != Bk: + # raise ValueError(f"Batch mismatch: x has {B}, kernels has {Bk}") + + # # 1) Flip each kernel for true convolution + # kernels_flipped = kernel[:, ::-1] # shape (B, K) + + # # 2) Reshape input into (N=1, H=T, C_in=B) + # x_in = x.T[None, :, :] + + # # 3) Build weight tensor of shape (C_out=B, H_f=K, C_in=1) + # w = kernels_flipped[:, :, None] + + # # 4) Convolve with one group per channel → valid mode + # y = mx.conv1d( + # x_in, w, + # stride=1, + # padding=0, + # dilation=1, + # groups=B + # ) + # # y: (1, T-K+1, B) → drop batch and transpose to (B, T-K+1) + # return y[0].T - def inner_f(x, kernel): + def batched_conv1d( + x: mx.array, + kernels: mx.array, + mode: str = op.core_op.mode, + stride: int = 1, + dilation: int = 1) -> mx.array: + """ + Apply B separate 1D convolutions (full or valid) to B sequences in parallel. + + Parameters + ---------- + x : array of shape (B, T) + B sequences of length T. + kernels : array of shape (B, K) + B kernels of length K. 
+        mode : {"valid", "full"}
+            "valid" → no padding, output length = T - K + 1
+            "full"  → zero‑pad so output length = T + K - 1
+        stride : int, convolution stride (default=1)
+        dilation : int, convolution dilation (default=1)
+
+        Returns
+        -------
+        out : array of shape (B, L)
+            where L =
+              - T - K + 1   if mode="valid"
+              - T + K - 1   if mode="full"
+        """
+        # --- 1) shape checks ---
         B, T = x.shape
-        Bk, K = kernel.shape
+        Bk, K = kernels.shape
         if B != Bk:
             raise ValueError(f"Batch mismatch: x has {B}, kernels has {Bk}")

-        # 1) Flip each kernel for true convolution
-        kernels_flipped = kernel[:, ::-1]  # shape (B, K)
+        # --- 2) flip kernels for convolution ---
+        kernels_flipped = kernels[:, ::-1]  # shape (B, K)
+
+        # --- 3) decide padding ---
+        if mode == "valid":
+            pad = 0
+        elif mode == "full":
+            pad = (K - 1) * dilation
+        else:
+            raise ValueError(f"Unsupported mode {mode!r}: choose 'valid' or 'full'")

-        # 2) Reshape input into (N=1, H=T, C_in=B)
-        x_in = x.T[None, :, :]
+        # --- 4) reshape into MLX conv1d form ---
+        # input: (N=1, H=T, C_in=B)
+        x_in = x.T[None, :, :]

-        # 3) Build weight tensor of shape (C_out=B, H_f=K, C_in=1)
-        w = kernels_flipped[:, :, None]
+        # weight: (C_out=B, H_f=K, C_in=1)
+        w = kernels_flipped[:, :, None]

-        # 4) Convolve with one group per channel → valid mode
+        # --- 5) run grouped conv1d ---
         y = mx.conv1d(
             x_in, w,
-            stride=1,
-            padding=0,
-            dilation=1,
+            stride=stride,
+            padding=pad,
+            dilation=dilation,
             groups=B
         )
-        # y: (1, T-K+1, B) → drop batch and transpose to (B, T-K+1)
+        # y shape: (1, H_out, B)
+
+        # --- 6) return shape (B, H_out) ---
         return y[0].T

-    return inner_f
+
+    return batched_conv1d


 @mlx_funcify.register(Blockwise)
 def funcify_Blockwise(op: Blockwise, node, **kwargs):
     if isinstance(op.core_op, Conv1d):
         return blockwise_conv1d(op, node, **kwargs)

-    core_f = mlx_funcify(op.core_op)
+    core_f = mlx_funcify(op.core_op, node)

     def blockwise_f(*inputs):
         return blockwise_f(*inputs)

From 9766975453a5d43518b9b63567d1eac722e04d71 Mon Sep 17 00:00:00 2001
From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com>
Date: Fri, 18 Apr 2025 23:32:07 -0500
Subject: [PATCH 51/54] Vectorize Blockwise core ops with a single mx.vmap
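
Replace the repeated-vmap loop in `funcify_Blockwise` with one mx.vmap
call over the core function, using an explicit in_axes that maps the
leading axis of the batched arguments and keeps the remaining arguments
static. In elemwise.py, read the CAReduce accumulation dtype from
`x.dtype` rather than `x.dtype.type`, and comment out the unreachable
astype fallback.

A minimal sketch of the in_axes semantics this relies on (illustrative
sketch, not part of this diff; assumes only that mlx is installed):

    import mlx.core as mx

    def core(a, b):
        return a @ b  # core matmul on unbatched (2, 3) @ (3, 5) inputs

    # 0 maps the leading axis of an argument; None leaves it unbatched.
    batched = mx.vmap(core, in_axes=(0, 0))

    a = mx.zeros((4, 2, 3))
    b = mx.zeros((4, 3, 5))
    assert batched(a, b).shape == (4, 2, 5)
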
Co-Authored-By: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> Co-Authored-By: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com> --- pytensor/link/mlx/dispatch/blockwise.py | 20 ++++++++++++-------- pytensor/link/mlx/dispatch/elemwise.py | 4 ++-- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py index 95fc9d0f9a..00f774fe08 100644 --- a/pytensor/link/mlx/dispatch/blockwise.py +++ b/pytensor/link/mlx/dispatch/blockwise.py @@ -105,20 +105,24 @@ def batched_conv1d( @mlx_funcify.register(Blockwise) def funcify_Blockwise(op: Blockwise, node, **kwargs): + # 1) If it's a Conv1d Blockwise, use the custom implementation if isinstance(op.core_op, Conv1d): return blockwise_conv1d(op, node, **kwargs) - - core_f = mlx_funcify(op.core_op, node) - def blockwise_f(*inputs): - return blockwise_f(*inputs) + # 2) Otherwise, get the core python function for this Blockwise core_node = op._create_dummy_core_node(node.inputs) - core_f = mlx_funcify(op.core_op, core_node) - blockwise_f = core_f - for i in range(op.batch_ndim(node)): - blockwise_f = mx.vmap(blockwise_f) + # 3) Determine how many inputs correspond to batch dimensions + n_batch = op.batch_ndim(node) + + # 4) Build in_axes: map only the first n_batch args, keep the rest static + in_axes = tuple(0 if i < n_batch else None for i in range(len(node.inputs))) + + # 5) Vectorize (vmap) with in_axes + blockwise_f = mx.vmap(core_f, in_axes=in_axes) + + # 6) Return the mapped function def blockwise_fun(*inputs): return blockwise_f(*inputs) diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py index c71de48b12..926da572e0 100644 --- a/pytensor/link/mlx/dispatch/elemwise.py +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -43,12 +43,12 @@ def careduce(x): axis = list(range(x.ndim)) if acc_dtype is None: - acc_dtype = x.dtype.type + acc_dtype = x.dtype if op_nfunc_spec: mlx_op = getattr(mx, op_nfunc_spec[0]) return mlx_op(x, axis=axis) - return mlx_op(x, axis=axis).astype(acc_dtype) + # return mlx_op(x, axis=axis).astype(acc_dtype) # The PyTensor `Op` didn't tell us which NumPy equivalent to use (or # there isn't one), so we use this fallback approach From 5ffc5ef8fc6a933ff55038ce1f84c59be7876ba0 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 23:33:44 -0500 Subject: [PATCH 52/54] pre-commit --- pytensor/link/mlx/dispatch/blockwise.py | 58 ++++++------------------- 1 file changed, 14 insertions(+), 44 deletions(-) diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py index 00f774fe08..74bb018a68 100644 --- a/pytensor/link/mlx/dispatch/blockwise.py +++ b/pytensor/link/mlx/dispatch/blockwise.py @@ -4,44 +4,19 @@ from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.signal.conv import Conv1d -import numpy as np def blockwise_conv1d(op, node, **kwargs): - # if op.core_op.mode != "valid": - # raise NotImplementedError("Only 'valid' mode is supported for conv1d") - - # def inner_f(x, kernel): - # B, T = x.shape - # Bk, K = kernel.shape - # if B != Bk: - # raise ValueError(f"Batch mismatch: x has {B}, kernels has {Bk}") - - # # 1) Flip each kernel for true convolution - # kernels_flipped = kernel[:, ::-1] # shape (B, K) - - # # 2) Reshape input into (N=1, H=T, C_in=B) - # x_in = x.T[None, :, :] - - # # 3) Build weight tensor of shape (C_out=B, H_f=K, C_in=1) - # w = 
kernels_flipped[:, :, None]
-
-    # # 4) Convolve with one group per channel → valid mode
-    # y = mx.conv1d(
-    #     x_in, w,
-    #     stride=1,
-    #     padding=0,
-    #     dilation=1,
-    #     groups=B
-    # )
-    # # y: (1, T-K+1, B) → drop batch and transpose to (B, T-K+1)
-    # return y[0].T
-
+    """
+    Custom implementation of Blockwise.conv1d for MLX.
+    """
+
     def batched_conv1d(
-        x: mx.array,
-        kernels: mx.array,
-        mode: str = op.core_op.mode,
-        stride: int = 1,
-        dilation: int = 1) -> mx.array:
+        x: mx.array,
+        kernels: mx.array,
+        mode: str = op.core_op.mode,
+        stride: int = 1,
+        dilation: int = 1,
+    ) -> mx.array:
         """
         Apply B separate 1D convolutions (full or valid) to B sequences in parallel.

@@ -53,14 +28,14 @@ def batched_conv1d(
             B kernels of length K.
         mode : {"valid", "full"}
             "valid" → no padding, output length = T - K + 1
-            "full"  → zero‑pad so output length = T + K - 1
+            "full"  → zero-pad so output length = T + K - 1
         stride : int, convolution stride (default=1)
         dilation : int, convolution dilation (default=1)

         Returns
         -------
         out : array of shape (B, L)
-            where L = 
+            where L =
              - T - K + 1   if mode="valid"
              - T + K - 1   if mode="full"
         """
@@ -89,13 +64,7 @@ def batched_conv1d(
         w = kernels_flipped[:, :, None]

         # --- 5) run grouped conv1d ---
-        y = mx.conv1d(
-            x_in, w,
-            stride=stride,
-            padding=pad,
-            dilation=dilation,
-            groups=B
-        )
+        y = mx.conv1d(x_in, w, stride=stride, padding=pad, dilation=dilation, groups=B)
         # y shape: (1, H_out, B)

         # --- 6) return shape (B, H_out) ---
@@ -103,6 +72,7 @@ def batched_conv1d(

     return batched_conv1d

+
 @mlx_funcify.register(Blockwise)
 def funcify_Blockwise(op: Blockwise, node, **kwargs):
     # 1) If it's a Conv1d Blockwise, use the custom implementation

From 597f84ef2140b9e818c67aed6520796d9a25b2b7 Mon Sep 17 00:00:00 2001
From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com>
Date: Sat, 19 Apr 2025 11:32:15 -0500
Subject: [PATCH 53/54] Dispatch CAReduce on the scalar op type
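
Rewrite mlx_funcify_CAReduce to dispatch explicitly on the scalar op
(Add -> mx.sum, Mul -> mx.prod, AND -> all, OR -> mx.any, plus
maximum/minimum), raising NotImplementedError for unsupported
reductions instead of falling back to nfunc-spec string lookups. The
elemwise max/min helpers in math.py switch to the array-method form.

The reduction mapping in isolation (illustrative sketch, not part of
this diff; assumes only that mlx is installed):

    import mlx.core as mx

    x = mx.array([[1.0, 2.0], [3.0, 4.0]])
    assert mx.sum(x, axis=0).tolist() == [4.0, 6.0]    # CAReduce(Add)
    assert mx.prod(x, axis=1).tolist() == [2.0, 12.0]  # CAReduce(Mul)
    assert x.max(axis=1).tolist() == [2.0, 4.0]        # CAReduce(Maximum)
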
---
 pytensor/link/mlx/dispatch/elemwise.py | 125 +++++++++++++++----------
 pytensor/link/mlx/dispatch/math.py     |   4 +-
 2 files changed, 80 insertions(+), 49 deletions(-)

diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py
index 926da572e0..7e7a27c5ab 100644
--- a/pytensor/link/mlx/dispatch/elemwise.py
+++ b/pytensor/link/mlx/dispatch/elemwise.py
@@ -5,6 +5,35 @@
 from pytensor.tensor.elemwise import CAReduce, DimShuffle
 from pytensor.tensor.special import Softmax, SoftmaxGrad

+from pytensor.scalar.basic import (
+    AND,
+    EQ,
+    GE,
+    GT,
+    LE,
+    LT,
+    NEQ,
+    OR,
+    Abs,
+    Add,
+    Cast,
+    Cos,
+    Exp,
+    Log,
+    Log1p,
+    Mul,
+    Neg,
+    Pow,
+    ScalarMaximum,
+    ScalarMinimum,
+    Sign,
+    Sin,
+    Sqr,
+    Sqrt,
+    Sub,
+    Switch,
+    TrueDiv,
+)

 @mlx_funcify.register(DimShuffle)
 def mlx_funcify_DimShuffle(op, **kwargs):
@@ -21,55 +50,57 @@ def dimshuffle(x):

     return dimshuffle

+@mlx_funcify.register(DimShuffle)
+def mlx_funcify_DimShuffle(op, **kwargs):
+    def dimshuffle(x):
+        res = mx.transpose(x, op.transposition)
+        shape = list(res.shape[: len(op.shuffle)])
+        for augm in op.augment:
+            shape.insert(augm, 1)
+        return mx.reshape(res, shape)
+    return dimshuffle
+

 @mlx_funcify.register(CAReduce)
 def mlx_funcify_CAReduce(op, **kwargs):
-    axis = op.axis
-    op_nfunc_spec = getattr(op, "nfunc_spec", None)
-    scalar_nfunc_spec = getattr(op.scalar_op, "nfunc_spec", None)
-    scalar_op_name = getattr(op.scalar_op, "name", None)
-    scalar_op_identity = getattr(op.scalar_op, "identity", None)
-    acc_dtype = getattr(op, "acc_dtype", None)
-
-    def careduce(x):
-        nonlocal \
-            axis, \
-            op_nfunc_spec, \
-            scalar_nfunc_spec, \
-            scalar_op_name, \
-            scalar_op_identity, \
-            acc_dtype
-
-        if axis is None:
-            axis = list(range(x.ndim))
-
-        if acc_dtype is None:
-            acc_dtype = x.dtype
-
-        if op_nfunc_spec:
-            mlx_op = getattr(mx, op_nfunc_spec[0])
-            return mlx_op(x, axis=axis)
-        # return mlx_op(x, axis=axis).astype(acc_dtype)
-
-        # The PyTensor `Op` didn't tell us which NumPy equivalent to use (or
-        # there isn't one), so we use this fallback approach
-        if scalar_nfunc_spec:
-            scalar_fn_name = scalar_nfunc_spec[0]
-        elif scalar_op_name:
-            scalar_fn_name = scalar_op_name
-
-        to_reduce = sorted(axis, reverse=True)
-
-        if to_reduce:
-            raise NotImplementedError("Not implemented yet")
-            # In this case, we need to use the `jax.lax` function (if there
-            # is one), and not the `jnp` version.
-            mlx_op = getattr(mx, scalar_fn_name)
-            init_value = mx.array(scalar_op_identity, dtype=acc_dtype)
-            return mx.reduce(x, init_value, mlx_op, to_reduce).astype(acc_dtype)
-        else:
-            return x
-
-    return careduce
+    if isinstance(op.scalar_op, Add):
+
+        def sum(x):
+            return mx.sum(x, axis=op.axis)
+
+        return sum
+    elif isinstance(op.scalar_op, Mul):
+
+        def prod(x):
+            return mx.prod(x, axis=op.axis)
+
+        return prod
+    elif isinstance(op.scalar_op, AND):
+
+        def all(x):
+            return x.all(axis=op.axis)
+
+        return all
+    elif isinstance(op.scalar_op, OR):
+
+        def any(x):
+            return mx.any(x, axis=op.axis)
+
+        return any
+    elif isinstance(op.scalar_op, ScalarMaximum):
+
+        def max(x):
+            return x.max(axis=op.axis)
+
+        return max
+    elif isinstance(op.scalar_op, ScalarMinimum):
+
+        def min(x):
+            return x.min(axis=op.axis)
+
+        return min
+    else:
+        raise NotImplementedError(f"MLX does not support Elemwise {op.scalar_op}")
+

 @mlx_funcify.register(Softmax)
diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py
index 890a8db601..153f049b0e 100644
--- a/pytensor/link/mlx/dispatch/math.py
+++ b/pytensor/link/mlx/dispatch/math.py
@@ -211,13 +211,13 @@ def any(x, y):
     elif isinstance(op.scalar_op, ScalarMaximum):

         def max(x):
-            return mx.max(x, axis=op.axis)
+            return x.max(axis=op.axis)

         return max
     elif isinstance(op.scalar_op, ScalarMinimum):

         def min(x):
-            return mx.min(x, axis=op.axis)
+            return x.min(axis=op.axis)

         return min
     elif isinstance(op.scalar_op, Cast):

From fb8fd2f12ca6061d5b9002bdab155a382b5a842c Mon Sep 17 00:00:00 2001
From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com>
Date: Wed, 23 Apr 2025 11:15:11 +0300
Subject: [PATCH 54/54] Map PyTensor dtype strings to MLX dtype objects

---
 pytensor/link/mlx/dispatch/core.py     | 60 +++++++++++++++++++---
 pytensor/link/mlx/dispatch/elemwise.py | 70 +++++++-------------------
 pytensor/link/mlx/dispatch/math.py     | 20 +++++---
 3 files changed, 87 insertions(+), 63 deletions(-)

diff --git a/pytensor/link/mlx/dispatch/core.py b/pytensor/link/mlx/dispatch/core.py
index 6985c2b656..3a0b279cd3 100644
--- a/pytensor/link/mlx/dispatch/core.py
+++ b/pytensor/link/mlx/dispatch/core.py
@@ -127,7 +127,7 @@ def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2):
 # ------------------------------------------------------------------
 @mlx_funcify.register(Eye)  # MLX
 def mlx_funcify_Eye(op, **kwargs):
-    dtype = op.dtype
+    dtype = convert_dtype_to_mlx(op.dtype)

     def eye(N, M, k):
         return mx.eye(int(N), int(M), int(k), dtype=dtype)  # MLX
@@ -135,13 +135,56 @@ def eye(N, M, k):
     return eye


+def convert_dtype_to_mlx(dtype_str):
+    """Convert PyTensor dtype strings to MLX dtype objects.
+
+    MLX expects dtype objects rather than string literals for type conversion.
+ This function maps common dtype strings to their MLX equivalents. + """ + if isinstance(dtype_str, str): + if dtype_str == "bool": + return mx.bool_ + elif dtype_str == "int8": + return mx.int8 + elif dtype_str == "int16": + return mx.int16 + elif dtype_str == "int32": + return mx.int32 + elif dtype_str == "int64": + return mx.int64 + elif dtype_str == "uint8": + return mx.uint8 + elif dtype_str == "uint16": + return mx.uint16 + elif dtype_str == "uint32": + return mx.uint32 + elif dtype_str == "uint64": + return mx.uint64 + elif dtype_str == "float16": + return mx.float16 + elif dtype_str == "float32": + return mx.float32 + elif dtype_str == "float64": + return mx.float64 + elif dtype_str == "bfloat16": + return mx.bfloat16 + elif dtype_str == "complex64": + return mx.complex64 + elif dtype_str == "complex128": + return mx.complex128 + # Return as is if it's already an MLX dtype or not a recognized string + return dtype_str + + # ------------------------------------------------------------------ # MakeVector # ------------------------------------------------------------------ @mlx_funcify.register(MakeVector) # MLX def mlx_funcify_MakeVector(op, **kwargs): + dtype = convert_dtype_to_mlx(op.dtype) + def makevector(*x): - return mx.array(x, dtype=op.dtype) # MLX + return mx.array(x, dtype=dtype) # MLX return makevector @@ -175,6 +218,7 @@ def scalar_from_tensor(x): def mlx_funcify_Tri(op, node, **kwargs): # node.inputs -> N, M, k const_args = [getattr(inp, "data", None) for inp in node.inputs] + dtype = convert_dtype_to_mlx(op.dtype) def tri(*args): # Replace args with compile-time constants when available @@ -182,15 +226,17 @@ def tri(*args): arg if const_a is None else const_a for arg, const_a in zip(args, const_args, strict=True) ] - return mx.tri(*args, dtype=op.dtype) # MLX + return mx.tri(*args, dtype=dtype) # MLX return tri @mlx_funcify.register(AllocEmpty) def mlx_funcify_AllocEmpty(op, **kwargs): + dtype = convert_dtype_to_mlx(op.dtype) + def allocempty(*shape): - return mx.zeros(shape, dtype=op.dtype) + return mx.zeros(shape, dtype=dtype) return allocempty @@ -198,8 +244,10 @@ def allocempty(*shape): @mlx_funcify.register(Alloc) def mlx_funcify_Alloc(op, node, **kwargs): def alloc(x, *shape): - res = mx.broadcast_to(x, shape) - Alloc._check_runtime_broadcast(node, mx.array(x), res.shape) + # Convert x to an MLX array with the correct dtype if it's a scalar + x_array = mx.array(x) + res = mx.broadcast_to(x_array, shape) + Alloc._check_runtime_broadcast(node, x_array, res.shape) return res return alloc diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py index 7e7a27c5ab..aaf04968de 100644 --- a/pytensor/link/mlx/dispatch/elemwise.py +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -1,65 +1,37 @@ import mlx.core as mx +import numpy as np from pytensor.link.mlx.dispatch.basic import mlx_funcify +from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx from pytensor.scalar import Softplus -from pytensor.tensor.elemwise import CAReduce, DimShuffle -from pytensor.tensor.special import Softmax, SoftmaxGrad - from pytensor.scalar.basic import ( AND, - EQ, - GE, - GT, - LE, - LT, - NEQ, OR, - Abs, Add, Cast, - Cos, - Exp, - Log, - Log1p, Mul, - Neg, - Pow, - ScalarMaximum, - ScalarMinimum, - Sign, - Sin, - Sqr, - Sqrt, - Sub, - Switch, - TrueDiv, ) +from pytensor.tensor.elemwise import CAReduce, DimShuffle +from pytensor.tensor.special import Softmax, SoftmaxGrad + @mlx_funcify.register(DimShuffle) def mlx_funcify_DimShuffle(op, **kwargs): 
def dimshuffle(x): + # Convert scalar to array if needed + if isinstance(x, int | float) or ( + isinstance(x, np.number) and not isinstance(x, np.ndarray) + ): + x = mx.array(x) res = mx.transpose(x, op.transposition) - shape = list(res.shape[: len(op.shuffle)]) - for augm in op.augment: shape.insert(augm, 1) - return mx.reshape(res, shape) return dimshuffle -@mlx_funcify.register(DimShuffle) -def mlx_funcify_DimShuffle(op, **kwargs): - def dimshuffle(x): - res = mx.transpose(x, op.transposition) - shape = list(res.shape[: len(op.shuffle)]) - for augm in op.augment: - shape.insert(augm, 1) - return mx.reshape(res, shape) - return dimshuffle - @mlx_funcify.register(CAReduce) def mlx_funcify_CAReduce(op, **kwargs): if isinstance(op.scalar_op, Add): @@ -86,23 +58,10 @@ def any(x): return mx.any(x, axis=op.axis) return any - elif isinstance(op.scalar_op, ScalarMaximum): - - def max(x): - return x.max(axis=op.axis) - - return max - elif isinstance(op.scalar_op, ScalarMinimum): - - def min(x): - return x.min(axis=op.axis) - - return min else: raise NotImplementedError(f"MLX does not support Elemwise {op.scalar_op}") - @mlx_funcify.register(Softmax) def mlx_funcify_Softmax(op, **kwargs): axis = op.axis @@ -142,3 +101,12 @@ def softplus(x): ) return softplus + + +@mlx_funcify.register(Cast) +def mlx_funcify_Cast(op, **kwargs): + def cast(x): + dtype = convert_dtype_to_mlx(op.scalar_op.o_type.dtype) + return x.astype(dtype) + + return cast diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index 890a8db601..153f049b0e 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -1,6 +1,7 @@ import mlx.core as mx -from pytensor.link.mlx.dispatch import mlx_funcify +from pytensor.link.mlx.dispatch import mlx_funcify, mlx_typify +from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx from pytensor.scalar import Softplus from pytensor.scalar.basic import ( AND, @@ -36,6 +37,12 @@ from pytensor.tensor.math import Dot +@mlx_typify.register(int) +@mlx_typify.register(float) +def mlx_typify_python_scalar(data, **kwargs): + return mx.array(data) + + @mlx_funcify.register(Dot) def mlx_funcify_Dot(op, **kwargs): def dot(x, y): @@ -210,20 +217,21 @@ def any(x, y): return any elif isinstance(op.scalar_op, ScalarMaximum): - def max(x): - return x.max(axis=op.axis) + def max(x, y): + return mx.maximum(x, y) return max elif isinstance(op.scalar_op, ScalarMinimum): - def min(x): - return x.min(axis=op.axis) + def min(x, y): + return mx.minimum(x, y) return min elif isinstance(op.scalar_op, Cast): def cast(x): - return mx.cast(x, op.dtype) + dtype = convert_dtype_to_mlx(op.scalar_op.o_type.dtype) + return x.astype(dtype) return cast elif isinstance(op.scalar_op, Sign):