Commit 8a5a29f

add 5d output

1 parent 1745778 commit 8a5a29f
1 file changed (+62 −10)

torch_xla/experimental/custom_kernel.py

Lines changed: 62 additions & 10 deletions
```diff
@@ -25,7 +25,8 @@ def _shard_map(
   Note:
     ``shard_map`` is an experimental API, and still subject to change. For an
     introduction to sharded data, refer to :ref:`sharded-computation`. For a more
-    in-depth look at using ``shard_map``, refer to `SPMD multi-device parallelism with shard_map`_.
+    in-depth look at using ``shard_map``, refer to
+    [SPMD multi-device parallelism with shard_map](https://docs.jax.dev/en/latest/notebooks/shard_map.html)

   Args:
     func: callable to be mapped. Each application of ``f``, or "instance" of ``f``,
```
```diff
@@ -315,7 +316,20 @@ def _fa_custom_forward_single_device(
   mesh = xs.get_global_mesh() or Mesh.from_str(mesh)
   from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_impl

-  q_full_shape = None
+  num_batches = None
+  batch_size = None
+  if len(q.shape) == 5:
+    num_batches, batch_size, *rest = q.shape
+    q = q.reshape(-1, *rest)
+    k = k.reshape(-1, *rest)
+    v = v.reshape(-1, *rest)
+    if q_segment_ids is not None:
+      q_segment_ids = q_segment_ids.reshape(-1, *rest)
+    if kv_segment_ids is not None:
+      kv_segment_ids = kv_segment_ids.reshape(-1, *rest)
+    if ab is not None:
+      ab = ab.reshape(-1, *rest)
+

   # Suprisingly, any tensor that is input to the custom_op decorated function will show
   # requires_grad=False. Is this a bug or feature? We have to pass ctx_grad to record the
```
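For readers following the new 5-D path: the block added above flattens the two leading batch dimensions of `q`, `k`, `v` (and, when present, the segment ids and `ab`) into a single dimension before the Pallas kernel runs, and a later hunk restores those dimensions on the outputs `o`, `l`, and `m`. A minimal sketch of that round trip, with made-up shapes (`num_batches=2`, `batch_size=4`, 8 heads, sequence length 128, head dim 64):

```python
import torch

q = torch.randn(2, 4, 8, 128, 64)           # (num_batches, batch_size, heads, seq, head_dim)

num_batches, batch_size, *rest = q.shape    # rest == [8, 128, 64]
q_flat = q.reshape(-1, *rest)               # leading dims collapsed: (8, 8, 128, 64)

o_flat = q_flat                             # stand-in for the kernel output

# Restore the original 5-D layout, as the later output-reshape hunk does.
o = o_flat.reshape(num_batches, batch_size, *o_flat.shape[1:])
assert o.shape == (2, 4, 8, 128, 64)
```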
```diff
@@ -389,6 +403,11 @@ def _fa_custom_forward_single_device(
     o, *aux = o
     l, m = (v[..., 0] for v in aux[-2:])

+  if num_batches is not None:
+    o = o.reshape(num_batches, batch_size, *o.shape[1:])
+    l = l.reshape(num_batches, batch_size, *l.shape[1:])
+    m = m.reshape(num_batches, batch_size, *m.shape[1:])
+
   return o, l, m


```
```diff
@@ -403,8 +422,6 @@ def fa_custom_forward(
   partition_spec = eval(partition_spec)
   mesh = xs.get_global_mesh() or Mesh.from_str(mesh)

-  q_full_shape = None
-
   # Suprisingly, any tensor that is input to the custom_op decorated function will show
   # requires_grad=False. Is this a bug or feature? We have to pass ctx_grad to record the
   # requires_grad for inputs.
@@ -434,7 +451,10 @@ def fa_custom_forward(
         ab, max(block_k_major, block_k), 3, padding_minus_inf=True)

   if partition_spec is not None:
-    segment_id_partition_spec = (partition_spec[0], partition_spec[2])
+    if len(partition_spec) == 5:
+      segment_id_partition_spec = (partition_spec[0], partition_spec[1], partition_spec[3])
+    else:
+      segment_id_partition_spec = (partition_spec[0], partition_spec[2])

     input_specs = [
         partition_spec,  # q
```
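The reasoning behind the index choice above: the segment-id tensors carry no head or head-dim axes, so their partition spec is the full spec with those entries dropped, i.e. indices (0, 1, 3) of a 5-element spec and (0, 2) of a 4-element spec. A small sketch with hypothetical mesh axis names:

```python
# Hypothetical partition specs, only to illustrate the index selection above.
spec_4d = ("data", None, None, None)            # (batch, heads, seq, head_dim)
spec_5d = ("fsdp", "data", None, None, None)    # (num_batches, batch_size, heads, seq, head_dim)

# Segment ids have no head or head_dim axis, so those entries are dropped.
assert (spec_4d[0], spec_4d[2]) == ("data", None)
assert (spec_5d[0], spec_5d[1], spec_5d[3]) == ("fsdp", "data", None)
```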
```diff
@@ -455,13 +475,13 @@ def fa_custom_forward(
     ]

     fa_forward_callable = _shard_map(
-        _fa_custom_forward_one_device,
+        _fa_custom_forward_single_device,
         mesh,
         input_specs,
         output_specs,
     )
   else:
-    fa_forward_callable = _fa_custom_forward_one_device
+    fa_forward_callable = _fa_custom_forward_single_device

   o, l, m = fa_forward_callable(
       q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab, ctx_grad
```
```diff
@@ -505,6 +525,25 @@ def _fa_custom_backward_single_device(
   from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv
   grad_q = grad_k = grad_v = grad_ab = segment_ids = None

+
+  num_batches = None
+  batch_size = None
+  if len(q.shape) == 5:
+    num_batches, batch_size, *rest = q.shape
+    grad_output = grad_output.reshape(-1, *rest)
+    q = q.reshape(-1, *rest)
+    k = k.reshape(-1, *rest)
+    v = v.reshape(-1, *rest)
+    o = o.reshape(-1, *rest)
+    l = l.reshape(-1, *rest)
+    m = m.reshape(-1, *rest)
+    if q_segment_ids is not None:
+      q_segment_ids = q_segment_ids.reshape(-1, *rest)
+    if kv_segment_ids is not None:
+      kv_segment_ids = kv_segment_ids.reshape(-1, *rest)
+    if ab is not None:
+      ab = ab.reshape(-1, *rest)
+
   require_grad_q, require_grad_k, require_grad_v, *rest = ctx_grad
   require_grad_ab = ctx_grad[-3]

```
```diff
@@ -617,6 +656,16 @@ def _fa_custom_backward_single_device(
   if require_grad_v:
     grad_v = grads[1]

+  if num_batches is not None:
+    def _reshape(x):
+      if x is not None:
+        return x.reshape(num_batches, batch_size, *x.shape[1:])
+      return None
+    grad_q = _reshape(grad_q)
+    grad_k = _reshape(grad_k)
+    grad_v = _reshape(grad_v)
+    grad_ab = _reshape(grad_ab)
+
   return grad_q, grad_k, grad_v, grad_ab

 @custom_op("xla::fa_custom_backward", mutates_args=())
```
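On entry, the backward single-device path applies the same flattening to `grad_output`, `o`, `l`, and `m` as the forward does to `q`, `k`, and `v`; on exit, the `_reshape` helper added above restores the leading `(num_batches, batch_size)` dims on whichever gradients exist and passes `None` through for the rest. A small sketch of that helper's behavior, with made-up shapes:

```python
import torch

num_batches, batch_size = 2, 4              # made-up values

def _reshape(x):
  # Restore the leading (num_batches, batch_size) dims; pass None through.
  if x is not None:
    return x.reshape(num_batches, batch_size, *x.shape[1:])
  return None

grad_q = torch.randn(8, 8, 128, 64)         # leading dim is num_batches * batch_size
grad_ab = None                              # e.g. no attention-bias gradient was requested

assert _reshape(grad_q).shape == (2, 4, 8, 128, 64)
assert _reshape(grad_ab) is None
```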
```diff
@@ -643,7 +692,10 @@ def fa_custom_backward(


   if partition_spec:
-    segment_id_partition_spec = (partition_spec[0], partition_spec[2])
+    if len(partition_spec) == 5:
+      segment_id_partition_spec = (partition_spec[0], partition_spec[1], partition_spec[3])
+    else:
+      segment_id_partition_spec = (partition_spec[0], partition_spec[2])
     input_specs = [
         partition_spec,  # grad_output
         partition_spec,  # q
@@ -669,7 +721,7 @@ def fa_custom_backward(
         partition_spec,
     ]
     fa_backward_callable = _shard_map(
-        _fa_custom_backward_single_device
+        _fa_custom_backward_single_device,
         mesh,
         input_specs,
         output_specs
@@ -678,7 +730,7 @@ def fa_custom_backward(
     fa_backward_callable = _fa_custom_backward_single_device

   res = fa_backward_callable(
-      grad_output, q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab, causal, sm_scale
+      grad_output, q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab, causal, sm_scale,
       q_full_shape, kv_full_shape, ab_full_shape, ctx_grad
   )
```