Commit 0a27ca4 (parent 1745778)
add 5d output

1 file changed: +50 additions, -10 deletions

torch_xla/experimental/custom_kernel.py

Lines changed: 50 additions & 10 deletions
@@ -25,7 +25,8 @@ def _shard_map(
   Note:
     ``shard_map`` is an experimental API, and still subject to change. For an
     introduction to sharded data, refer to :ref:`sharded-computation`. For a more
-    in-depth look at using ``shard_map``, refer to `SPMD multi-device parallelism with shard_map`_.
+    in-depth look at using ``shard_map``, refer to
+    [SPMD multi-device parallelism with shard_map](https://docs.jax.dev/en/latest/notebooks/shard_map.html)

   Args:
     func: callable to be mapped. Each application of ``f``, or "instance" of ``f``,
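For readers unfamiliar with the API the updated link points to, here is a minimal, hedged sketch of JAX's `shard_map` itself (the device mesh, axis name, and specs below are illustrative assumptions, not taken from this file):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Illustrative 1D device mesh built from whatever devices are available.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

def per_shard_scale(x_block):
  # Runs once per device on that device's local shard of x.
  return x_block * 2.0

scaled = shard_map(per_shard_scale, mesh=mesh,
                   in_specs=P("data"), out_specs=P("data"))

x = jnp.ones((8 * len(devices),))
print(scaled(x).shape)  # same global shape as x
```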
@@ -315,7 +316,18 @@ def _fa_custom_forward_single_device(
   mesh = xs.get_global_mesh() or Mesh.from_str(mesh)
   from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_impl

-  q_full_shape = None
+  num_batches = None
+  batch_size = None
+  if len(q.shape) == 5:
+    num_batches, batch_size, *rest = q.shape
+    q = q.reshape(-1, *rest)
+    k = k.reshape(-1, *rest)
+    v = v.reshape(-1, *rest)
+    q_segment_ids = q_segment_ids.reshape(-1, *rest)
+    kv_segment_ids = kv_segment_ids.reshape(-1, *rest)
+    if ab is not None:
+      ab = ab.reshape(-1, *rest)
+
   # Suprisingly, any tensor that is input to the custom_op decorated function will show
   # requires_grad=False. Is this a bug or feature? We have to pass ctx_grad to record the
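The new 5D branch above collapses the two leading batch dimensions into a single batch dimension before calling the 4D flash-attention kernel; the next hunk restores them on the outputs. A minimal, self-contained sketch of that round trip (the shapes and the stand-in "kernel" are illustrative assumptions):

```python
import torch

# Illustrative 5D query layout: (num_batches, batch_size, num_heads, seq_len, head_dim).
q = torch.randn(2, 3, 8, 128, 64)

num_batches = None
batch_size = None
if len(q.shape) == 5:
  num_batches, batch_size, *rest = q.shape
  q = q.reshape(-1, *rest)  # (6, 8, 128, 64): what the 4D kernel expects

o = q.clone()  # stand-in for the Pallas flash-attention output

if num_batches is not None:
  o = o.reshape(num_batches, batch_size, *o.shape[1:])  # back to 5D

assert o.shape == (2, 3, 8, 128, 64)
```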
@@ -389,6 +401,11 @@ def _fa_custom_forward_single_device(
   o, *aux = o
   l, m = (v[..., 0] for v in aux[-2:])

+  if num_batches is not None:
+    o = o.reshape(num_batches, batch_size, *o.shape[1:])
+    l = l.reshape(num_batches, batch_size, *l.shape[1:])
+    m = m.reshape(num_batches, batch_size, *m.shape[1:])
+
   return o, l, m

@@ -403,8 +420,6 @@ def fa_custom_forward(
   partition_spec = eval(partition_spec)
   mesh = xs.get_global_mesh() or Mesh.from_str(mesh)

-  q_full_shape = None
-
   # Suprisingly, any tensor that is input to the custom_op decorated function will show
   # requires_grad=False. Is this a bug or feature? We have to pass ctx_grad to record the
   # requires_grad for inputs.
@@ -434,7 +449,10 @@ def fa_custom_forward(
         ab, max(block_k_major, block_k), 3, padding_minus_inf=True)

   if partition_spec is not None:
-    segment_id_partition_spec = (partition_spec[0], partition_spec[2])
+    if len(partition_spec) == 5:
+      segment_id_partition_spec = (partition_spec[0], partition_spec[1], partition_spec[3])
+    else:
+      segment_id_partition_spec = (partition_spec[0], partition_spec[2])

     input_specs = [
         partition_spec,  # q
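When the inputs are 5D, `partition_spec` carries five entries (one per q dimension), while the segment-id tensors have no head or head-dim axes, so those entries are dropped when building their spec. A small illustration with made-up axis names (the names and axis layout are assumptions; only the index selection mirrors the change):

```python
# Hypothetical 5D spec: one entry per (num_batches, batch, heads, seq, head_dim) axis of q.
partition_spec = ("dcn", "data", "model", None, None)

if len(partition_spec) == 5:
  # keep the two leading batch entries and the sequence entry
  segment_id_partition_spec = (partition_spec[0], partition_spec[1],
                               partition_spec[3])
else:
  # 4D case: batch entry plus sequence entry
  segment_id_partition_spec = (partition_spec[0], partition_spec[2])

print(segment_id_partition_spec)  # ('dcn', 'data', None) for the 5D spec above
```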
@@ -455,13 +473,13 @@ def fa_custom_forward(
     ]

     fa_forward_callable = _shard_map(
-        _fa_custom_forward_one_device,
+        _fa_custom_forward_single_device,
         mesh,
         input_specs,
         output_specs,
     )
   else:
-    fa_forward_callable = _fa_custom_forward_one_device
+    fa_forward_callable = _fa_custom_forward_single_device

   o, l, m = fa_forward_callable(
       q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab, ctx_grad
@@ -505,6 +523,25 @@ def _fa_custom_backward_single_device(
   from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv
   grad_q = grad_k = grad_v = grad_ab = segment_ids = None

+
+  num_batches = None
+  batch_size = None
+  if len(q.shape) == 5:
+    num_batches, batch_size, *rest = q.shape
+    grad_output = grad_output.reshape(-1, *rest)
+    q = q.reshape(-1, *rest)
+    k = k.reshape(-1, *rest)
+    v = v.reshape(-1, *rest)
+    o = o.reshape(-1, *rest)
+    l = l.reshape(-1, *rest)
+    m = m.reshape(-1, *rest)
+    if q_segment_ids is not None:
+      q_segment_ids = q_segment_ids.reshape(-1, *rest)
+    if kv_segment_ids is not None:
+      kv_segment_ids = kv_segment_ids.reshape(-1, *rest)
+    if ab is not None:
+      ab = ab.reshape(-1, *rest)
+
   require_grad_q, require_grad_k, require_grad_v, *rest = ctx_grad
   require_grad_ab = ctx_grad[-3]

@@ -643,7 +680,10 @@ def fa_custom_backward(


   if partition_spec:
-    segment_id_partition_spec = (partition_spec[0], partition_spec[2])
+    if len(partition_spec) == 5:
+      segment_id_partition_spec = (partition_spec[0], partition_spec[1], partition_spec[3])
+    else:
+      segment_id_partition_spec = (partition_spec[0], partition_spec[2])
     input_specs = [
         partition_spec,  # grad_output
         partition_spec,  # q
@@ -669,7 +709,7 @@ def fa_custom_backward(
         partition_spec,
     ]
     fa_backward_callable = _shard_map(
-        _fa_custom_backward_single_device
+        _fa_custom_backward_single_device,
         mesh,
         input_specs,
         output_specs
@@ -678,7 +718,7 @@ def fa_custom_backward(
     fa_backward_callable = _fa_custom_backward_single_device

   res = fa_backward_callable(
-      grad_output, q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab, causal, sm_scale
+      grad_output, q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab, causal, sm_scale,
       q_full_shape, kv_full_shape, ab_full_shape, ctx_grad
   )

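As a usage sketch only: assuming the module's public `flash_attention` wrapper routes 5D tensors through the path added in this commit, a call might look like the following on a TPU host. The keyword arguments other than `q`, `k`, `v`, and the shapes are assumptions, not confirmed by this diff.

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention  # assumed entry point

device = xm.xla_device()  # the Pallas kernel requires a TPU runtime

# Assumed 5D layout: (num_batches, batch_size, num_heads, seq_len, head_dim).
q = torch.randn(2, 4, 8, 1024, 128, device=device)
k = torch.randn(2, 4, 8, 1024, 128, device=device)
v = torch.randn(2, 4, 8, 1024, 128, device=device)

o = flash_attention(q, k, v, causal=True)
print(o.shape)  # expected to match q: torch.Size([2, 4, 8, 1024, 128])
```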