@@ -25,7 +25,8 @@ def _shard_map(
   Note:
     ``shard_map`` is an experimental API, and still subject to change. For an
     introduction to sharded data, refer to :ref:`sharded-computation`. For a more
-    in-depth look at using ``shard_map``, refer to `SPMD multi-device parallelism with shard_map`_.
+    in-depth look at using ``shard_map``, refer to
+    [SPMD multi-device parallelism with shard_map](https://docs.jax.dev/en/latest/notebooks/shard_map.html)

   Args:
     func: callable to be mapped. Each application of ``f``, or "instance" of ``f``,
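A minimal usage sketch, illustrative only and not part of this patch: it assumes a TPU SPMD setup and that ``_shard_map`` is importable from the module this diff touches (e.g. ``torch_xla.experimental.custom_kernel``) with the ``(func, mesh, input_specs, output_specs)`` signature seen at the call sites further down.

import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.custom_kernel import _shard_map  # assumed import path

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

def per_device_add(a, b):
  # Each instance of the mapped function sees only its local shard of a and b.
  return a + b

spec = ('data', None)  # shard dim 0 across the 'data' mesh axis, replicate dim 1
sharded_add = _shard_map(per_device_add, mesh, [spec, spec], [spec])
# sharded_add is then invoked with the same arguments as per_device_add.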
@@ -315,7 +316,20 @@ def _fa_custom_forward_single_device(
   mesh = xs.get_global_mesh() or Mesh.from_str(mesh)
   from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_impl

-  q_full_shape = None
+  num_batches = None
+  batch_size = None
+  if len(q.shape) == 5:
+    num_batches, batch_size, *rest = q.shape
+    q = q.reshape(-1, *rest)
+    k = k.reshape(-1, *rest)
+    v = v.reshape(-1, *rest)
+    if q_segment_ids is not None:
+      q_segment_ids = q_segment_ids.reshape(-1, *rest)
+    if kv_segment_ids is not None:
+      kv_segment_ids = kv_segment_ids.reshape(-1, *rest)
+    if ab is not None:
+      ab = ab.reshape(-1, *rest)
+
   # Surprisingly, any tensor that is input to the custom_op decorated function will show
   # requires_grad=False. Is this a bug or feature? We have to pass ctx_grad to record the
@@ -389,6 +403,11 @@ def _fa_custom_forward_single_device(
   o, *aux = o
   l, m = (v[..., 0] for v in aux[-2:])

+  if num_batches is not None:
+    o = o.reshape(num_batches, batch_size, *o.shape[1:])
+    l = l.reshape(num_batches, batch_size, *l.shape[1:])
+    m = m.reshape(num_batches, batch_size, *m.shape[1:])
+
   return o, l, m

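For reference, a self-contained illustration (shapes invented) of the reshape round trip the two hunks above introduce: a 5D [num_batches, batch_size, heads, seq, head_dim] input is flattened to the 4D layout the Pallas flash-attention kernel expects, and the kernel outputs are folded back afterwards.

import torch

x = torch.randn(2, 4, 8, 128, 64)   # [num_batches, batch_size, heads, seq, head_dim]
num_batches, batch_size, *rest = x.shape
flat = x.reshape(-1, *rest)          # [8, 8, 128, 64], what the 4D kernel sees
restored = flat.reshape(num_batches, batch_size, *flat.shape[1:])
assert torch.equal(x, restored)      # the round trip is lossless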
@@ -403,8 +422,6 @@ def fa_custom_forward(
   partition_spec = eval(partition_spec)
   mesh = xs.get_global_mesh() or Mesh.from_str(mesh)

-  q_full_shape = None
-
   # Surprisingly, any tensor that is input to the custom_op decorated function will show
   # requires_grad=False. Is this a bug or feature? We have to pass ctx_grad to record the
   # requires_grad for inputs.
@@ -434,7 +451,10 @@ def fa_custom_forward(
         ab, max(block_k_major, block_k), 3, padding_minus_inf=True)

   if partition_spec is not None:
-    segment_id_partition_spec = (partition_spec[0], partition_spec[2])
+    if len(partition_spec) == 5:
+      segment_id_partition_spec = (partition_spec[0], partition_spec[1], partition_spec[3])
+    else:
+      segment_id_partition_spec = (partition_spec[0], partition_spec[2])

     input_specs = [
         partition_spec,  # q
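An illustrative helper, not in the patch and with invented axis names, showing what the new branch computes: segment ids carry no head or head_dim dimension, so those entries are dropped from the q/k/v partition spec in both the 4D and the new 5D layouts.

def segment_id_spec(partition_spec):
  # Mirrors the branch added above.
  if len(partition_spec) == 5:
    return (partition_spec[0], partition_spec[1], partition_spec[3])
  return (partition_spec[0], partition_spec[2])

print(segment_id_spec(('fsdp', None, None, None)))        # ('fsdp', None)
print(segment_id_spec((None, 'fsdp', None, None, None)))  # (None, 'fsdp', None)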
@@ -455,13 +475,13 @@ def fa_custom_forward(
     ]

     fa_forward_callable = _shard_map(
-        _fa_custom_forward_one_device,
+        _fa_custom_forward_single_device,
         mesh,
         input_specs,
         output_specs,
     )
   else:
-    fa_forward_callable = _fa_custom_forward_one_device
+    fa_forward_callable = _fa_custom_forward_single_device

   o, l, m = fa_forward_callable(
       q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab, ctx_grad
@@ -505,6 +525,25 @@ def _fa_custom_backward_single_device(
   from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv
   grad_q = grad_k = grad_v = grad_ab = segment_ids = None

+
+  num_batches = None
+  batch_size = None
+  if len(q.shape) == 5:
+    num_batches, batch_size, *rest = q.shape
+    grad_output = grad_output.reshape(-1, *rest)
+    q = q.reshape(-1, *rest)
+    k = k.reshape(-1, *rest)
+    v = v.reshape(-1, *rest)
+    o = o.reshape(-1, *rest)
+    l = l.reshape(-1, *rest)
+    m = m.reshape(-1, *rest)
+    if q_segment_ids is not None:
+      q_segment_ids = q_segment_ids.reshape(-1, *rest)
+    if kv_segment_ids is not None:
+      kv_segment_ids = kv_segment_ids.reshape(-1, *rest)
+    if ab is not None:
+      ab = ab.reshape(-1, *rest)
+
   require_grad_q, require_grad_k, require_grad_v, *rest = ctx_grad
   require_grad_ab = ctx_grad[-3]

@@ -617,6 +656,16 @@ def _fa_custom_backward_single_device(
   if require_grad_v:
     grad_v = grads[1]

+  if num_batches is not None:
+    def _reshape(x):
+      if x is not None:
+        return x.reshape(num_batches, batch_size, *x.shape[1:])
+      return None
+    grad_q = _reshape(grad_q)
+    grad_k = _reshape(grad_k)
+    grad_v = _reshape(grad_v)
+    grad_ab = _reshape(grad_ab)
+
   return grad_q, grad_k, grad_v, grad_ab

 @custom_op("xla::fa_custom_backward", mutates_args=())
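A small, runnable sketch (names and shapes invented) of the None-safe restore step added above: gradients that exist get their leading [num_batches, batch_size] dimensions back, while gradients that were never computed stay None.

import torch

def restore(x, num_batches, batch_size):
  if x is None:
    return None
  return x.reshape(num_batches, batch_size, *x.shape[1:])

grad_q = torch.randn(8, 8, 128, 64)   # flattened [num_batches * batch_size, heads, seq, head_dim]
grad_ab = None                        # ab was not used, so its gradient is None
print(restore(grad_q, 2, 4).shape)    # torch.Size([2, 4, 8, 128, 64])
print(restore(grad_ab, 2, 4))         # None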
@@ -643,7 +692,10 @@ def fa_custom_backward(


   if partition_spec:
-    segment_id_partition_spec = (partition_spec[0], partition_spec[2])
+    if len(partition_spec) == 5:
+      segment_id_partition_spec = (partition_spec[0], partition_spec[1], partition_spec[3])
+    else:
+      segment_id_partition_spec = (partition_spec[0], partition_spec[2])
     input_specs = [
         partition_spec,  # grad_output
         partition_spec,  # q
@@ -669,7 +721,7 @@ def fa_custom_backward(
         partition_spec,
     ]
     fa_backward_callable = _shard_map(
-        _fa_custom_backward_single_device
+        _fa_custom_backward_single_device,
         mesh,
         input_specs,
         output_specs
@@ -678,7 +730,7 @@ def fa_custom_backward(
     fa_backward_callable = _fa_custom_backward_single_device

   res = fa_backward_callable(
-      grad_output, q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab, causal, sm_scale
+      grad_output, q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab, causal, sm_scale,
       q_full_shape, kv_full_shape, ab_full_shape, ctx_grad
   )