@@ -25,7 +25,8 @@ def _shard_map(
   Note:
     ``shard_map`` is an experimental API, and still subject to change. For an
     introduction to sharded data, refer to :ref:`sharded-computation`. For a more
-    in-depth look at using ``shard_map``, refer to `SPMD multi-device parallelism with shard_map`_.
+    in-depth look at using ``shard_map``, refer to
+    [SPMD multi-device parallelism with shard_map](https://docs.jax.dev/en/latest/notebooks/shard_map.html)
 
   Args:
     func: callable to be mapped. Each application of ``f``, or "instance" of ``f``,
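For orientation, the wrapper documented above is what this change later uses to run the single-device flash-attention bodies shard by shard. Below is a minimal sketch of the call pattern, assuming an SPMD setup with a global mesh; the axis name `"data"`, the shapes, and the toy per-shard function are illustrative assumptions, not part of this change.

```python
# Minimal sketch, assuming torch_xla SPMD is initialised and a global mesh
# with an axis named "data" exists. Only the _shard_map call pattern mirrors
# what this change does for the flash-attention callables below.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.custom_kernel import _shard_map


def _toy_single_device(a, b):
  # Per-shard body: each tensor arrives already sliced along the mesh axes.
  return a + b, a - b


mesh = xs.get_global_mesh()
spec = ("data", None)  # shard dim 0 across "data", replicate dim 1

toy_callable = _shard_map(
    _toy_single_device,
    mesh,
    [spec, spec],  # one partition spec per positional input
    [spec, spec],  # one per output
)

device = xm.xla_device()
a = torch.randn(16, 128, device=device)
b = torch.randn(16, 128, device=device)
s, d = toy_callable(a, b)
```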
@@ -315,7 +316,18 @@ def _fa_custom_forward_single_device(
   mesh = xs.get_global_mesh() or Mesh.from_str(mesh)
   from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_impl
 
-  q_full_shape = None
+  num_batches = None
+  batch_size = None
+  if len(q.shape) == 5:
+    num_batches, batch_size, *rest = q.shape
+    q = q.reshape(-1, *rest)
+    k = k.reshape(-1, *rest)
+    v = v.reshape(-1, *rest)
+    q_segment_ids = q_segment_ids.reshape(-1, *q_segment_ids.shape[2:])
+    kv_segment_ids = kv_segment_ids.reshape(-1, *kv_segment_ids.shape[2:])
+    if ab is not None:
+      ab = ab.reshape(-1, *ab.shape[2:])
+
 
   # Suprisingly, any tensor that is input to the custom_op decorated function will show
   # requires_grad=False. Is this a bug or feature? We have to pass ctx_grad to record the
@@ -389,6 +401,11 @@ def _fa_custom_forward_single_device(
   o, *aux = o
   l, m = (v[..., 0] for v in aux[-2:])
 
+  if num_batches is not None:
+    o = o.reshape(num_batches, batch_size, *o.shape[1:])
+    l = l.reshape(num_batches, batch_size, *l.shape[1:])
+    m = m.reshape(num_batches, batch_size, *m.shape[1:])
+
   return o, l, m
 
 
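The two forward hunks above follow a collapse-then-restore pattern: a 5D input of shape [num_batches, batch, num_heads, seq_len, head_dim] is flattened to the 4D layout the Pallas kernel expects, and the kernel outputs are expanded back afterwards (the backward hunks below apply the same trick). A standalone sketch of that round trip, with made-up sizes:

```python
# Standalone sketch of the collapse/restore pattern; sizes are illustrative.
import torch

num_batches, batch_size, heads, seq, head_dim = 2, 4, 8, 128, 64
q = torch.randn(num_batches, batch_size, heads, seq, head_dim)

leading = None
if q.dim() == 5:
  leading = q.shape[:2]              # remember (num_batches, batch_size)
  q = q.reshape(-1, *q.shape[2:])    # -> [num_batches * batch_size, heads, seq, head_dim]

o = q.clone()                        # stand-in for the 4D flash-attention kernel

if leading is not None:
  o = o.reshape(*leading, *o.shape[1:])  # restore the 5D layout

assert o.shape == (num_batches, batch_size, heads, seq, head_dim)
```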
@@ -403,8 +420,6 @@ def fa_custom_forward(
   partition_spec = eval(partition_spec)
   mesh = xs.get_global_mesh() or Mesh.from_str(mesh)
 
-  q_full_shape = None
-
   # Suprisingly, any tensor that is input to the custom_op decorated function will show
   # requires_grad=False. Is this a bug or feature? We have to pass ctx_grad to record the
   # requires_grad for inputs.
@@ -434,7 +449,10 @@ def fa_custom_forward(
         ab, max(block_k_major, block_k), 3, padding_minus_inf=True)
 
   if partition_spec is not None:
-    segment_id_partition_spec = (partition_spec[0], partition_spec[2])
+    if len(partition_spec) == 5:
+      segment_id_partition_spec = (partition_spec[0], partition_spec[1], partition_spec[3])
+    else:
+      segment_id_partition_spec = (partition_spec[0], partition_spec[2])
 
     input_specs = [
         partition_spec,  # q
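The branch added above accounts for segment ids having no head or head_dim axis: with a 5-element q spec over [num_batches, batch, heads, seq, head_dim], the segment-id spec keeps only the num_batches, batch, and seq entries (indices 0, 1, 3). A tiny illustration with an assumed spec; the mesh axis names are not from this change:

```python
# Illustration only; the mesh axis names below are assumptions.
partition_spec = ("dcn", "data", None, None, None)  # [num_batches, batch, heads, seq, head_dim]

if len(partition_spec) == 5:
  # segment ids are [num_batches, batch, seq]
  segment_id_partition_spec = (partition_spec[0], partition_spec[1], partition_spec[3])
else:
  # segment ids are [batch, seq]
  segment_id_partition_spec = (partition_spec[0], partition_spec[2])

assert segment_id_partition_spec == ("dcn", "data", None)
```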
@@ -455,13 +473,13 @@ def fa_custom_forward(
     ]
 
     fa_forward_callable = _shard_map(
-        _fa_custom_forward_one_device,
+        _fa_custom_forward_single_device,
         mesh,
         input_specs,
         output_specs,
     )
   else:
-    fa_forward_callable = _fa_custom_forward_one_device
+    fa_forward_callable = _fa_custom_forward_single_device
 
   o, l, m = fa_forward_callable(
       q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab, ctx_grad
@@ -505,6 +523,25 @@ def _fa_custom_backward_single_device(
   from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv
   grad_q = grad_k = grad_v = grad_ab = segment_ids = None
 
+
+  num_batches = None
+  batch_size = None
+  if len(q.shape) == 5:
+    num_batches, batch_size, *rest = q.shape
+    grad_output = grad_output.reshape(-1, *rest)
+    q = q.reshape(-1, *rest)
+    k = k.reshape(-1, *rest)
+    v = v.reshape(-1, *rest)
+    o = o.reshape(-1, *rest)
+    l = l.reshape(-1, *l.shape[2:])
+    m = m.reshape(-1, *m.shape[2:])
+    if q_segment_ids is not None:
+      q_segment_ids = q_segment_ids.reshape(-1, *q_segment_ids.shape[2:])
+    if kv_segment_ids is not None:
+      kv_segment_ids = kv_segment_ids.reshape(-1, *kv_segment_ids.shape[2:])
+    if ab is not None:
+      ab = ab.reshape(-1, *ab.shape[2:])
+
   require_grad_q, require_grad_k, require_grad_v, *rest = ctx_grad
   require_grad_ab = ctx_grad[-3]
 
@@ -643,7 +680,10 @@ def fa_custom_backward(
 
 
   if partition_spec:
-    segment_id_partition_spec = (partition_spec[0], partition_spec[2])
+    if len(partition_spec) == 5:
+      segment_id_partition_spec = (partition_spec[0], partition_spec[1], partition_spec[3])
+    else:
+      segment_id_partition_spec = (partition_spec[0], partition_spec[2])
     input_specs = [
         partition_spec,  # grad_output
         partition_spec,  # q
@@ -669,7 +709,7 @@ def fa_custom_backward(
         partition_spec,
     ]
     fa_backward_callable = _shard_map(
-        _fa_custom_backward_single_device
+        _fa_custom_backward_single_device,
         mesh,
         input_specs,
         output_specs
@@ -678,7 +718,7 @@ def fa_custom_backward(
     fa_backward_callable = _fa_custom_backward_single_device
 
   res = fa_backward_callable(
-      grad_output, q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab, causal, sm_scale
+      grad_output, q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab, causal, sm_scale,
       q_full_shape, kv_full_shape, ab_full_shape, ctx_grad
   )
 
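Taken together, these hunks let the sharded flash-attention path accept 5D [num_batches, batch, heads, seq, head_dim] inputs by flattening the two leading dimensions around the Pallas kernel and extending the partition-spec handling to match. A hedged usage sketch follows; it assumes the public `flash_attention` wrapper in this module routes through the callables patched above, and the mesh axis names and shapes are illustrative.

```python
# Hedged sketch, assuming a TPU SPMD setup; the mesh layout, axis names and
# shapes are illustrative, and `flash_attention` is the public wrapper in
# torch_xla.experimental.custom_kernel that this change extends.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.custom_kernel import flash_attention

device = xm.xla_device()
num_batches, batch, heads, seq, dim = 2, 4, 8, 1024, 128
q = torch.randn(num_batches, batch, heads, seq, dim, device=device)
k = torch.randn_like(q)
v = torch.randn_like(q)

# A 5-element spec matching the 5D layout; which axes carry mesh names
# depends on how the global mesh was created.
spec = (None, "data", None, None, None)
out = flash_attention(
    q, k, v, causal=True, partition_spec=spec, mesh=xs.get_global_mesh())
print(out.shape)  # expected: [num_batches, batch, heads, seq, dim]
```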