Commit 494df92

qihqipgmoka authored and committed
write _shard_map; refactor flash attention to support 5d inputs. (#8730)
1 parent b922fa0 commit 494df92

File tree

1 file changed: +87 -0 lines changed


torch_xla/experimental/custom_kernel.py

Lines changed: 87 additions & 0 deletions
@@ -17,6 +17,93 @@
 DEFAULT_MASK_VALUE = -0.7 * float(torch.finfo(torch.float32).max)
 
 
+def _shard_map(func, mesh, input_specs, output_specs):
+  """Map a function over shards of data.
+
+  Note:
+    ``shard_map`` is an experimental API and still subject to change. For an
+    introduction to sharded data and for a more
+    in-depth look at using ``shard_map``, refer to
+    [SPMD multi-device parallelism with shard_map](https://docs.jax.dev/en/latest/notebooks/shard_map.html)
+
+  Args:
+    func: callable to be mapped. Each application of ``func``, or "instance" of ``func``,
+      takes as input a shard of the mapped-over arguments and produces a shard
+      of the output.
+    mesh: a ``Mesh`` representing the array of devices over which
+      to shard the data and on which to execute instances of ``func``. The names of
+      the ``Mesh`` can be used in collective communication operations in ``func``.
+      This is typically created by a utility function like
+      :func:`jax.experimental.mesh_utils.create_device_mesh`.
+    input_specs: a tuple of tuples of str. Each is the partition spec of a positional
+      input of ``func``. Keyword arguments are not supported yet.
+    output_specs: a pytree of :class:`~tuple[tuple[str]]`, with the same length
+      as the number of outputs.
+
+  Returns:
+    A callable that applies the input function ``func`` across data sharded according to
+    the ``mesh`` and ``output_specs``.
+
+  Reference:
+    This function behaves identically to JAX's shard_map:
+    https://docs.jax.dev/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html
+  """
+
+  def _full_shape(a, spec):
+    # a is the local tensor.
+    # spec is its sharding spec.
+    # Returns the logical shape of the global tensor.
+    mesh_name_to_size = mesh.shape()
+
+    result_shape = []
+    for axis_size, axis_sharding in zip(a.shape, spec):
+      if axis_sharding is None:
+        axis_sharding = ()
+      mesh_mult = []
+      if isinstance(axis_sharding, (str, int)):
+        axis_sharding = [axis_sharding]
+      for axis in axis_sharding:
+        size = mesh_name_to_size[axis] or 1
+        mesh_mult.append(size)
+      new_size = axis_size * math.prod(mesh_mult)
+      result_shape.append(new_size)
+    return tuple(result_shape)
+
+  def wrapped(*args):
+    assert len(args) == len(
+        input_specs), f'args={len(args)}; input_specs={len(input_specs)}'
+    new_args = []
+    for i, (a, spec) in enumerate(zip(args, input_specs)):
+      if isinstance(a, torch.Tensor):
+        assert (len(a.shape) == len(spec)
+               ), f'{i}th input has wrong shape: {a.shape} for {spec}'
+        new_a = xs.enable_manual_sharding(a, spec, mesh=mesh).global_tensor
+        new_args.append(new_a)
+      else:
+        new_args.append(a)
+
+    res = func(*new_args)
+    if isinstance(res, tuple):
+      res_updated = []
+      for i, (r, spec) in enumerate(zip(res, output_specs)):
+        if isinstance(r, torch.Tensor) and spec is not None:
+          assert str(r.device).startswith('xla'), f'{i}th device is {r.device}'
+          assert len(r.shape) == len(
+              spec), f'{i}th shape is {r.shape}, sharding is {output_specs[i]}'
+          new_r = xs.disable_manual_sharding(
+              r, spec, _full_shape(r, spec), mesh=mesh).global_tensor
+        else:
+          new_r = r
+        res_updated.append(new_r)
+      return res_updated
+    else:
+      return xs.disable_manual_sharding(
+          res, output_specs[0], _full_shape(res, output_specs[0]),
+          mesh=mesh).global_tensor
+
+  return wrapped
+
+
 def _shard_map(func, mesh, input_specs, output_specs):
   """Map a function over shards of data.
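
For readers skimming the diff, the shape arithmetic in _full_shape can be illustrated in isolation: each dimension of the local shard is multiplied by the product of the sizes of the mesh axes named in that dimension's part of the spec, and None means the dimension is replicated. The snippet below is a standalone sketch of that rule only; the mesh sizes and specs are invented for illustration and it does not call torch_xla.

import math

# Hypothetical mesh-axis sizes, standing in for mesh.shape() in the diff above.
mesh_name_to_size = {'x': 4, 'y': 2}

def full_shape(local_shape, spec):
  # Same rule as _full_shape: global size = local size * product of the
  # sizes of the mesh axes that shard that dimension (None = replicated).
  result = []
  for axis_size, axis_sharding in zip(local_shape, spec):
    if axis_sharding is None:
      axis_sharding = ()
    if isinstance(axis_sharding, (str, int)):
      axis_sharding = [axis_sharding]
    result.append(axis_size * math.prod(mesh_name_to_size[a] for a in axis_sharding))
  return tuple(result)

print(full_shape((8, 128), ('x', None)))         # (32, 128): dim 0 sharded over x (4 devices)
print(full_shape((8, 128), (('x', 'y'), None)))  # (64, 128): dim 0 sharded over x*y (8 devices)
print(full_shape((8, 128), (None, None)))        # (8, 128): fully replicated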

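The commit does not show a call site, but a minimal usage sketch may help. It assumes an SPMD-enabled XLA runtime (e.g. a TPU VM), a one-dimensional mesh with axis name 'x', and that importing the private helper _shard_map directly is acceptable for experimentation; local_fn, the specs, and the tensor shapes are all made up for the example.

# Hypothetical usage sketch; not part of the commit.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.custom_kernel import _shard_map

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
# One-dimensional device mesh with a single axis named 'x'.
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('x',))

def local_fn(a, b):
  # Runs on per-device shards of the inputs; ordinary torch ops apply.
  return a + b

# Shard dim 0 of both inputs over the 'x' axis; replicate dim 1.
input_specs = (('x', None), ('x', None))
# Single (non-tuple) output, sharded the same way; output_specs[0] is used.
output_specs = (('x', None),)

device = xm.xla_device()
a = torch.randn(8 * num_devices, 128, device=device)
b = torch.randn(8 * num_devices, 128, device=device)

mapped = _shard_map(local_fn, mesh, input_specs, output_specs)
out = mapped(a, b)  # global tensor reassembled according to output_specs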
0 commit comments