MLX backend POC #1365

Status: Open. Wants to merge 56 commits into base: main.

Changes shown are from 2 of the 56 commits.

Commits:
d25f214  mlx poc (williambdean, Apr 11, 2025)
edacc0e  add test for dot (williambdean, Apr 11, 2025)
052fdc2  restore pytorch (williambdean, Apr 11, 2025)
a9ecad0  wrap in mx.array (williambdean, Apr 11, 2025)
e690bff  modify the pytorch jit (williambdean, Apr 11, 2025)
ad29c17  move file (williambdean, Apr 11, 2025)
ba29b37  dont wrap (williambdean, Apr 11, 2025)
8716870  attempt to fix github action (williambdean, Apr 11, 2025)
9bf7edf  change the rtol (williambdean, Apr 11, 2025)
96ba116  add init file (williambdean, Apr 11, 2025)
e116fa1  skip if not installed (williambdean, Apr 11, 2025)
5d5f754  remove torch related code / comments (williambdean, Apr 11, 2025)
b8cee3f  simplify the fgraph_convert (williambdean, Apr 12, 2025)
d057453  assert type (williambdean, Apr 12, 2025)
ae202e6  simplify the internal (williambdean, Apr 18, 2025)
f1941fe  remove the language (williambdean, Apr 18, 2025)
7c8eae7  Adding operations in pytensor (cetagostini, Apr 18, 2025)
67a74fb  add extension (williambdean, Apr 18, 2025)
fb5eb52  make compare function (williambdean, Apr 18, 2025)
516b595  rename function (williambdean, Apr 18, 2025)
67bb8da  correct the function name (williambdean, Apr 18, 2025)
82bb964  tests for elemwise (williambdean, Apr 18, 2025)
877d79f  Changes (cetagostini, Apr 18, 2025)
fafedd6  Toma tu tomate William (cetagostini, Apr 18, 2025)
60acb8d  Pushing changes with the core shit. (cetagostini, Apr 18, 2025)
242aba7  add more tests (williambdean, Apr 18, 2025)
6cb47fc  additional tests (williambdean, Apr 18, 2025)
bc98e09  test for switch with mlx (williambdean, Apr 18, 2025)
4d5b34b  Pushing code (cetagostini, Apr 18, 2025)
5abd32d  Changes (cetagostini, Apr 18, 2025)
12daeac  A lot of new code (cetagostini, Apr 18, 2025)
ac93949  almost there baby william (cetagostini, Apr 18, 2025)
a19cbc8  Another push small (cetagostini, Apr 18, 2025)
5c97bc8  fix for all (williambdean, Apr 18, 2025)
2fc81bc  fix for carlos (williambdean, Apr 18, 2025)
e6437cc  just return the compiled func (williambdean, Apr 19, 2025)
c3a3e1a  A change for willy may! (cetagostini, Apr 19, 2025)
e7cf10e  FINALLY BABY LETS PARTY! (IF YOU ARE READING THIS MAKE MORE PRs) (cetagostini, Apr 19, 2025)
880dd5c  refactor to use getattr (williambdean, Apr 19, 2025)
1e6addd  bring argmax test (williambdean, Apr 19, 2025)
aabbb78  use deepcopy (williambdean, Apr 19, 2025)
0812c55  move some tests (williambdean, Apr 19, 2025)
294c271  THE SUPER BLOCKWISEE YA YA YA YA JUUUUU (cetagostini, Apr 19, 2025)
9d3eca8  Merge branch 'mlx-poc' of https://github.com/williambdean/pytensor in… (cetagostini, Apr 19, 2025)
9f31ab1  Guys, I'm getting sad. We need help yisus!!!!! (cetagostini, Apr 19, 2025)
37440ff  WILLIAM YOU NEED TO GO ANOTHER MILE! GO ON MY MATEEEEEEE, GO PHILLIES! (cetagostini, Apr 19, 2025)
4e4923f  RETURN, WHAT A SHAME! Sad times are coming. (cetagostini, Apr 19, 2025)
6b27dc4  AI COULD BE COOL? OR WE ARE JUST FUCKING AROUND? (cetagostini, Apr 19, 2025)
e308f83  AI RULES BABY MY MATE (cetagostini, Apr 19, 2025)
3744a18  test conv1d case (williambdean, Apr 19, 2025)
b41cab0  I'm going for pizzas, it was an incredible day! (cetagostini, Apr 19, 2025)
323fa9d  Merge branch 'mlx-poc' of https://github.com/williambdean/pytensor in… (cetagostini, Apr 19, 2025)
9766975  SUUUUUUUUU!!!!!! LIFE IS GOING WELL. MLX FOR MEDIA MIX MODELS BAY (cetagostini, Apr 19, 2025)
5ffc5ef  pre-commit (cetagostini, Apr 19, 2025)
597f84e  Almost working (cetagostini, Apr 19, 2025)
fb8fd2f  Last PR sampling working (cetagostini, Apr 23, 2025)
19 changes: 19 additions & 0 deletions pytensor/compile/mode.py
@@ -27,6 +27,7 @@
from pytensor.link.basic import Linker, PerformLinker
from pytensor.link.c.basic import CLinker, OpWiseCLinker
from pytensor.link.jax.linker import JAXLinker
from pytensor.link.mlx.linker import MLXLinker
from pytensor.link.numba.linker import NumbaLinker
from pytensor.link.pytorch.linker import PytorchLinker
from pytensor.link.vm import VMLinker
@@ -50,6 +51,7 @@
"jax": JAXLinker(),
"pytorch": PytorchLinker(),
"numba": NumbaLinker(),
"mlx": MLXLinker(),
}


@@ -494,13 +496,28 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
),
)

MLX = Mode(
MLXLinker(),
RewriteDatabaseQuery(
include=["fast_run"],
exclude=[
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
"scan_save_mem_prealloc",
],
),
)


predefined_modes = {
"FAST_COMPILE": FAST_COMPILE,
"FAST_RUN": FAST_RUN,
"JAX": JAX,
"NUMBA": NUMBA,
"PYTORCH": PYTORCH,
"MLX": MLX,
}

_CACHED_RUNTIME_MODES: dict[str, Mode] = {}
@@ -585,6 +602,8 @@ def get_target_language(mode=None) -> tuple[Literal["py", "c", "numba", "jax"],
return ("py",)
if isinstance(linker, CLinker):
return ("c",)
if isinstance(linker, MLXLinker):
return ("py",)

if isinstance(linker, VMLinker | OpWiseCLinker):
return ("c", "py") if config.cxx else ("py",)
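For reference, a minimal sketch of how the new mode is exercised end to end (assuming mlx is installed; at this point in the PR only Dot has a dispatch entry):

import numpy as np

import pytensor
from pytensor.tensor.type import matrix

x = matrix("x")
y = matrix("y")
# "MLX" resolves to the Mode registered in predefined_modes above
fn = pytensor.function([x, y], x.dot(y), mode="MLX")
fn(np.ones((3, 2)), np.ones((2, 4)))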
5 changes: 5 additions & 0 deletions pytensor/link/mlx/dispatch/__init__.py
@@ -0,0 +1,5 @@
# isort: off
from pytensor.link.mlx.dispatch.basic import mlx_funcify, mlx_typify

import pytensor.link.mlx.dispatch.math
# isort: on
61 changes: 61 additions & 0 deletions pytensor/link/mlx/dispatch/basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from functools import singledispatch
from types import NoneType

import mlx.core as mx
import numpy as np

from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python


@singledispatch
def mlx_typify(data, **kwargs):
raise NotImplementedError(f"mlx_typify is not implemented for {type(data)}")


@mlx_typify.register(np.ndarray)
@mlx_typify.register(mx.array)
Review comment (Member): mx.array should be registered in mlx_typify_no_conversion_needed

def mlx_typify_tensor(data, dtype=None, **kwargs):
return mx.array(data, dtype=dtype)


@mlx_typify.register(slice)
@mlx_typify.register(NoneType)
@mlx_typify.register(np.number)
def mlx_typify_no_conversion_needed(data, **kwargs):
return data


@singledispatch
def mlx_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a MLX compatible function from an PyTensor `Op`."""
raise NotImplementedError(
f"No MLX conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/1350` for progress or to request we prioritize this operation"
)


@mlx_funcify.register(FunctionGraph)
def mlx_funcify_FunctionGraph(
fgraph,
node=None,
fgraph_name="mlx_funcified_fgraph",
conversion_func=mlx_funcify,
**kwargs,
):
built_kwargs = {"conversion_func": conversion_func, **kwargs}
return fgraph_to_python(
fgraph,
conversion_func,
type_conversion_fn=mlx_typify,
fgraph_name=fgraph_name,
**built_kwargs,
)


@mlx_funcify.register(DeepCopyOp)
def mlx_funcify_DeepCopyOp(op, **kwargs):
def deepcopyop(x):
return x.copy()

return deepcopyop
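The singledispatch pair above is the entire extension surface: supporting another Op only needs one more mlx_funcify.register entry. A hypothetical sketch of the pattern (Shape is an illustrative choice, not covered by this diff):

import mlx.core as mx

from pytensor.link.mlx.dispatch import mlx_funcify
from pytensor.tensor.shape import Shape


@mlx_funcify.register(Shape)
def mlx_funcify_Shape(op, **kwargs):
    def shape(x):
        # the shape of an mx.array is a plain sequence of Python ints
        return mx.array(x.shape)

    return shape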
12 changes: 12 additions & 0 deletions pytensor/link/mlx/dispatch/math.py
@@ -0,0 +1,12 @@
import mlx.core as mx

from pytensor.link.mlx.dispatch import mlx_funcify
from pytensor.tensor.math import Dot


@mlx_funcify.register(Dot)
def mlx_funcify_Dot(op, **kwargs):
def dot(x, y):
return mx.matmul(x, y)

return dot
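A quick standalone check of the mapping above, as a sketch assuming mlx is installed (mx.matmul broadcasts like np.matmul, which covers the matrix-matrix case Dot handles here):

import mlx.core as mx
import numpy as np

a = mx.array(np.ones((3, 2)))
b = mx.array(np.ones((2, 4)))
assert tuple(mx.matmul(a, b).shape) == (3, 4)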
113 changes: 113 additions & 0 deletions pytensor/link/mlx/linker.py
@@ -0,0 +1,113 @@
from pytensor.link.basic import JITLinker
from pytensor.link.utils import unique_name_generator


class MLXLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using Apple's MLX."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.gen_functors = []

def fgraph_convert(
self,
fgraph,
order,
input_storage,
output_storage,
storage_map,
**kwargs,
):
"""Convert a PyTensor FunctionGraph to an MLX-compatible function.

Parameters
----------
fgraph : FunctionGraph
The function graph to convert
order : list
The order in which to compute the nodes
input_storage : list
Storage for the input variables
output_storage : list
Storage for the output variables
storage_map : dict
Map from variables to their storage

Returns
-------
callable
An MLX-compatible function
"""
from pytensor.link.mlx.dispatch import mlx_funcify

# We want to have globally unique names
# across the entire pytensor graph, not
# just the subgraph
generator = unique_name_generator(["mlx_linker"])

# Ensure that MLX is aware of the generated
# code so we can compile without graph breaks
def conversion_func_register(*args, **kwargs):
functor = mlx_funcify(*args, **kwargs)
name = kwargs["unique_name"](functor)
self.gen_functors.append((f"_{name}", functor))
return functor

built_kwargs = {
"unique_name": generator,
"conversion_func": conversion_func_register,
**kwargs,
}
return mlx_funcify(
fgraph,
input_storage=input_storage,
storage_map=storage_map,
**built_kwargs,
)

def jit_compile(self, fn):
"""JIT compile an MLX function.

Parameters
----------
fn : callable
The function to compile

Returns
-------
callable
The compiled function
"""
import mlx.core as mx

return mx.compile(fn)

def create_thunk_inputs(self, storage_map):
"""Create inputs for the MLX thunk.

Parameters
----------
storage_map : dict
Map from variables to their storage

Returns
-------
list
The inputs for the thunk
"""
from numpy.random import Generator, RandomState

from pytensor.link.mlx.dispatch import mlx_typify

thunk_inputs = []
for n in self.fgraph.inputs:
sinput = storage_map[n]
# Handle random number generators specially
if isinstance(sinput[0], RandomState | Generator):
new_value = mlx_typify(
sinput[0], dtype=getattr(sinput[0], "dtype", None)
)
sinput[0] = new_value
thunk_inputs.append(sinput)
Review comment (Member) on lines +61 to +69: Since we don't have Random stuff yet we shouldn't include the code

return thunk_inputs
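For context, the mx.compile call in jit_compile traces the funcified graph on first execution and caches the compiled version for later calls with the same shapes and dtypes. A minimal sketch of the same pattern outside the linker:

import mlx.core as mx


def f(x, y):
    return mx.matmul(x, y) + 1.0


compiled = mx.compile(f)
out = compiled(mx.ones((2, 3)), mx.ones((3, 2)))  # first call triggers compilation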
16 changes: 8 additions & 8 deletions pytensor/link/pytorch/linker.py
@@ -31,16 +31,16 @@ def conversion_func_register(*args, **kwargs):
**kwargs,
}
return pytorch_funcify(
-    fgraph, input_storage=input_storage, storage_map=storage_map, **built_kwargs
+    fgraph,
+    input_storage=input_storage,
+    storage_map=storage_map,
+    **built_kwargs,
)

def jit_compile(self, fn):
-    import torch
+    import mlx.core as mx

-    # flag that tend to help our graphs
-    torch._dynamo.config.capture_dynamic_output_shape_ops = True

-    from pytensor.link.pytorch.dispatch import pytorch_typify
+    from pytensor.link.mlx.dispatch import mlx_typify

class wrapper:
"""
@@ -54,7 +54,7 @@ class wrapper:
"""

def __init__(self, fn, gen_functors):
-    self.fn = torch.compile(fn)
+    self.fn = mx.compile(fn)
self.gen_functors = gen_functors.copy()

def __call__(self, *inputs, **kwargs):
@@ -65,7 +65,7 @@ def __call__(self, *inputs, **kwargs):
setattr(pytensor.link.utils, n[1:], fn)

# Torch does not accept numpy inputs and may return GPU objects
-    outs = self.fn(*(pytorch_typify(inp) for inp in inputs), **kwargs)
+    outs = self.fn(*(mlx_typify(inp) for inp in inputs), **kwargs)

# unset attrs
for n, _ in self.gen_functors:
19 changes: 19 additions & 0 deletions tests/link/mlx/dispatch/test_math.py
@@ -0,0 +1,19 @@
import numpy as np

import pytensor
from pytensor.tensor.type import matrix


def test_mlx_dot():
x = matrix("x")
y = matrix("y")

out = x.dot(y)
fn = pytensor.function([x, y], out, mode="MLX")

test_x = np.random.normal(size=(3, 2))
test_y = np.random.normal(size=(2, 4))
np.testing.assert_allclose(
fn(test_x, test_y),
np.dot(test_x, test_y),
)
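One caveat for follow-up tests: MLX arrays default to float32 (float64 NumPy inputs are downcast on conversion), which is presumably what the "change the rtol" commit above addresses. Comparisons against float64 NumPy references may therefore need an explicit tolerance, e.g. (rtol value hypothetical):

np.testing.assert_allclose(
    fn(test_x, test_y),
    np.dot(test_x, test_y),
    rtol=1e-6,  # float32 precision limits agreement with the float64 reference
)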