_XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0") == "1"

-def _shard_map(
-    func,
-    mesh,
-    input_specs,
-    output_specs
-):
-  """Map a function over shards of data.
+
+def _shard_map(func, mesh, input_specs, output_specs):
+  """Map a function over shards of data.

  Note:
    ``shard_map`` is an experimental API, and still subject to change. For an
@@ -51,57 +47,58 @@ def _shard_map(
    https://docs.jax.dev/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html
  """

-  def _full_shape(a, spec):
-    # a is local tensor
-    # spec is the sharding spec
-    # return logical shape of global tensor
-    mesh_name_to_size = dict(
-      zip(mesh.axis_names, mesh.mesh_shape)
-    )
-
-    result_shape = []
-    for axis_size, axis_sharding in zip(a.shape, spec):
-      if axis_sharding is None:
-        new_size = axis_size
-      else:
-        if isinstance(axis_sharding, str):
-          mesh_mult = mesh_name_to_size[axis_sharding]
-        else:
-          # tuple or list
-          mesh_mult = math.prod(
-            mesh_name_to_size[a] for a in axis_sharding
-            if mesh_name_to_size[a] is not None)
-
-        if mesh_mult is not None:
-          new_size = axis_size * mesh_mult
-      result_shape.append(new_size)
-    return tuple(result_shape)
-
-  def wrapped(*args):
-    assert len(args) == len(input_specs), f'args={len(args)}; input_specs={len(input_specs)}'
-    new_args = []
-    for i, (a, spec) in enumerate(zip(args, input_specs)):
-      if isinstance(a, torch.Tensor) and spec is not None:
-        assert (len(a.shape) == len(spec)), f'{i}th input has wrong shape: {a.shape} for {spec}'
-        new_a = xs.enable_manual_sharding(a, spec, mesh=mesh).global_tensor
-        new_args.append(new_a)
-      else:
-        new_args.append(a)
-
-    res = func(*new_args)
-    if isinstance(res, tuple):
-      return tuple(
-        xs.disable_manual_sharding(
-          a, spec, _full_shape(a, spec), mesh=mesh).global_tensor
-        if isinstance(a, torch.Tensor) and spec is not None else a
-        for a, spec in zip(res, output_specs)
-      )
+  def _full_shape(a, spec):
+    # a is local tensor
+    # spec is the sharding spec
+    # return logical shape of global tensor
+    mesh_name_to_size = dict(zip(mesh.axis_names, mesh.mesh_shape))
+
+    result_shape = []
+    for axis_size, axis_sharding in zip(a.shape, spec):
+      if axis_sharding is None:
+        new_size = axis_size
+      else:
+        if isinstance(axis_sharding, str):
+          mesh_mult = mesh_name_to_size[axis_sharding]
        else:
-      return xs.disable_manual_sharding(
-        res, output_specs[0],
-        _full_shape(res, output_specs[0]), mesh=mesh).global_tensor
-    return res
-  return wrapped
+          # tuple or list
+          mesh_mult = math.prod(mesh_name_to_size[a]
+                                for a in axis_sharding
+                                if mesh_name_to_size[a] is not None)
+
+        if mesh_mult is not None:
+          new_size = axis_size * mesh_mult
+      result_shape.append(new_size)
+    return tuple(result_shape)
+
+  def wrapped(*args):
+    assert len(args) == len(
+        input_specs), f'args={len(args)}; input_specs={len(input_specs)}'
+    new_args = []
+    for i, (a, spec) in enumerate(zip(args, input_specs)):
+      if isinstance(a, torch.Tensor) and spec is not None:
+        assert (len(a.shape) == len(spec)
+               ), f'{i}th input has wrong shape: {a.shape} for {spec}'
+        new_a = xs.enable_manual_sharding(a, spec, mesh=mesh).global_tensor
+        new_args.append(new_a)
+      else:
+        new_args.append(a)
+
+    res = func(*new_args)
+    if isinstance(res, tuple):
+      return tuple(
+          xs.disable_manual_sharding(a, spec, _full_shape(a, spec), mesh=mesh).
+          global_tensor
+          if isinstance(a, torch.Tensor) and spec is not None else a
+          for a, spec in zip(res, output_specs))
+    else:
+      return xs.disable_manual_sharding(
+          res, output_specs[0], _full_shape(res, output_specs[0]),
+          mesh=mesh).global_tensor
+    return res
+
+  return wrapped
+

def safe_empty_like(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
  """Returns empty tensor like input, or None if input is None."""
@@ -306,16 +303,10 @@ def wrapped_kernel(kernel: Callable,


def _fa_custom_forward_single_device(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    causal: bool,
-    q_segment_ids: torch.Tensor,
-    kv_segment_ids: torch.Tensor,
-    sm_scale: float,
-    ab: Optional[torch.Tensor],
-    ctx_grad: List[bool]
-  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool,
+    q_segment_ids: torch.Tensor, kv_segment_ids: torch.Tensor, sm_scale: float,
+    ab: Optional[torch.Tensor],
+    ctx_grad: List[bool]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_impl

  num_batches = None
@@ -331,15 +322,13 @@ def _fa_custom_forward_single_device(
    kv_segment_ids = kv_segment_ids.reshape(-1, *rest)
    if ab is not None:
      ab = ab.reshape(-1, *rest)
-

  # Surprisingly, any tensor that is input to the custom_op decorated function will show
  # requires_grad=False. Is this a bug or a feature? We have to pass ctx_grad to record the
  # requires_grad for inputs.
  # Originally we used save_residuals = q.requires_grad or k.requires_grad or v.requires_grad
  save_residuals = any(ctx_grad[:3])

-
  block_k_major = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"],
                      k.shape[2])
  block_k = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2])
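Annotation (not part of the commit): the `ctx_grad` workaround described in the comments above is easiest to see from the caller's side. A hedged sketch of that convention, with made-up tensors and shapes; only the pattern matters:

# Capture requires_grad while it is still visible; inside the custom_op body the
# same tensors would all report requires_grad=False.
import torch

q = torch.randn(4, 2, 128, 64, requires_grad=True)
k = torch.randn(4, 2, 128, 64, requires_grad=True)
v = torch.randn(4, 2, 128, 64, requires_grad=False)
ab = None

ctx_grad = [
    t.requires_grad if isinstance(t, torch.Tensor) else False
    for t in (q, k, v, ab)
]
save_residuals = any(ctx_grad[:3])  # q, k, or v needs grad -> keep l and m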
@@ -456,42 +445,42 @@ def fa_custom_forward(

  if partition_spec is not None:
    if len(partition_spec) == 5:
-      segment_id_partition_spec = (partition_spec[0], partition_spec[1], partition_spec[3])
+      segment_id_partition_spec = (partition_spec[0], partition_spec[1],
+                                   partition_spec[3])
      lm_partition_spec = partition_spec[:4]
    else:
      segment_id_partition_spec = (partition_spec[0], partition_spec[2])
      lm_partition_spec = partition_spec[:3]

    input_specs = [
-      partition_spec, # q
-      partition_spec, # k
-      partition_spec, # v
-      None,
-      segment_id_partition_spec,
-      segment_id_partition_spec,
-      None,
-      partition_spec,
-      None,
+        partition_spec,  # q
+        partition_spec,  # k
+        partition_spec,  # v
+        None,
+        segment_id_partition_spec,
+        segment_id_partition_spec,
+        None,
+        partition_spec,
+        None,
    ]

    output_specs = [
-      partition_spec, # o
-      lm_partition_spec, # l
-      lm_partition_spec, # m
+        partition_spec,  # o
+        lm_partition_spec,  # l
+        lm_partition_spec,  # m
    ]

    fa_forward_callable = _shard_map(
-      _fa_custom_forward_single_device,
-      mesh,
-      input_specs,
-      output_specs,
+        _fa_custom_forward_single_device,
+        mesh,
+        input_specs,
+        output_specs,
    )
  else:
    fa_forward_callable = _fa_custom_forward_single_device

-  o, l, m = fa_forward_callable(
-    q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab, ctx_grad
-  )
+  o, l, m = fa_forward_callable(q, k, v, causal, q_segment_ids, kv_segment_ids,
+                                sm_scale, ab, ctx_grad)

  outs = [o] + [full_q, full_k, full_v, l, m, full_ab]
  return tuple(outs)
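Annotation (not part of the commit): a worked example of the spec derivation in this hunk, using an illustrative mesh axis name:

# Illustrative values only. A rank-4 spec for q/k/v shaped [batch, heads, seq, head_dim],
# sharding the batch dimension over a hypothetical 'data' mesh axis.
partition_spec = ('data', None, None, None)

# Segment ids are [batch, seq]: keep positions 0 and 2 of the spec.
segment_id_partition_spec = (partition_spec[0], partition_spec[2])  # ('data', None)

# l and m are [batch, heads, q_seq]: keep the first three entries.
lm_partition_spec = partition_spec[:3]  # ('data', None, None)

# For the rank-5 layout [num_batches, batch, heads, seq, head_dim] the code above
# instead keeps positions (0, 1, 3) for segment ids and partition_spec[:4] for l/m.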
@@ -523,15 +512,14 @@ def _fa_custom_backward_single_device(
    v: torch.Tensor, o: torch.Tensor, l: torch.Tensor, m: torch.Tensor,
    q_segment_ids: Optional[torch.Tensor],
    kv_segment_ids: Optional[torch.Tensor], ab: Optional[torch.Tensor],
-    causal: bool, sm_scale: float,
-    q_full_shape: List[int], kv_full_shape: List[int],
-    ab_full_shape: Optional[List[int]], ctx_grad: List[bool]
+    causal: bool, sm_scale: float, q_full_shape: List[int],
+    kv_full_shape: List[int], ab_full_shape: Optional[List[int]],
+    ctx_grad: List[bool]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

  from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv
  grad_q = grad_k = grad_v = grad_ab = segment_ids = None

-
  num_batches = None
  batch_size = None
  if len(q.shape) == 5:
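Annotation (not part of the commit): the `len(q.shape) == 5` line this hunk ends on, together with the `reshape(-1, *rest)` calls in the forward path and the `_reshape` helper in the next hunk, points to a flatten-then-restore convention for the extra leading batch dimension. A small sketch with made-up shapes:

# Illustrative shapes only (not from this commit).
import torch

q5 = torch.empty(2, 4, 8, 128, 64)  # [num_batches, batch, heads, q_seq, head_dim]
num_batches, batch_size, *rest = q5.shape
q4 = q5.reshape(-1, *rest)          # [8, 8, 128, 64], what the kernel actually sees

grad_q4 = torch.empty_like(q4)      # stand-in for the kernel's output
# Restore the leading [num_batches, batch] split, mirroring _reshape() below.
grad_q5 = grad_q4.reshape(num_batches, batch_size, *grad_q4.shape[1:])
assert grad_q5.shape == q5.shape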
@@ -663,17 +651,20 @@ def _fa_custom_backward_single_device(
    grad_v = grads[1]

  if num_batches is not None:
+
    def _reshape(x):
      if x is not None:
        return x.reshape(num_batches, batch_size, *x.shape[1:])
      return None
+
    grad_q = _reshape(grad_q)
    grad_k = _reshape(grad_k)
    grad_v = _reshape(grad_v)
    grad_ab = _reshape(grad_ab)

  return grad_q, grad_k, grad_v, grad_ab

+
@custom_op("xla::fa_custom_backward", mutates_args=())
def fa_custom_backward(
    grad_output: torch.Tensor, q: torch.Tensor, k: torch.Tensor,
@@ -696,57 +687,49 @@ def fa_custom_backward(
  ab_full_shape = torch.Size(
      ab_full_shape) if ab_full_shape is not None else None

-
  if partition_spec:
    if len(partition_spec) == 5:
-      segment_id_partition_spec = (partition_spec[0], partition_spec[1], partition_spec[3])
+      segment_id_partition_spec = (partition_spec[0], partition_spec[1],
+                                   partition_spec[3])
      lm_partition_spec = partition_spec[:4]
    else:
      segment_id_partition_spec = (partition_spec[0], partition_spec[2])
      lm_partition_spec = partition_spec[:3]
    input_specs = [
-      partition_spec, # grad_output
-      partition_spec, # q
-      partition_spec, # k
-      partition_spec, # v
-      partition_spec, # o
-      lm_partition_spec, # l
-      lm_partition_spec, # m
-      segment_id_partition_spec, # q_segment_ids
-      segment_id_partition_spec, # kv_segment_ids
-      partition_spec, # ab
-      None, # causal
-      None, # sm_scale
-      None, # q_full_shape
-      None, # kv_full_shape
-      None, # ab_full_shape
-      None, # ctx_grad
+        partition_spec,  # grad_output
+        partition_spec,  # q
+        partition_spec,  # k
+        partition_spec,  # v
+        partition_spec,  # o
+        lm_partition_spec,  # l
+        lm_partition_spec,  # m
+        segment_id_partition_spec,  # q_segment_ids
+        segment_id_partition_spec,  # kv_segment_ids
+        partition_spec,  # ab
+        None,  # causal
+        None,  # sm_scale
+        None,  # q_full_shape
+        None,  # kv_full_shape
+        None,  # ab_full_shape
+        None,  # ctx_grad
    ]
    output_specs = [
-      partition_spec,
-      partition_spec,
-      partition_spec,
-      partition_spec,
+        partition_spec,
+        partition_spec,
+        partition_spec,
+        partition_spec,
    ]
-    fa_backward_callable = _shard_map(
-      _fa_custom_backward_single_device,
-      mesh,
-      input_specs,
-      output_specs
-    )
+    fa_backward_callable = _shard_map(_fa_custom_backward_single_device, mesh,
+                                      input_specs, output_specs)
  else:
    fa_backward_callable = _fa_custom_backward_single_device

-  res = fa_backward_callable(
-    grad_output, q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab, causal, sm_scale,
-    q_full_shape, kv_full_shape, ab_full_shape, ctx_grad
-  )
+  res = fa_backward_callable(grad_output, q, k, v, o, l, m, q_segment_ids,
+                             kv_segment_ids, ab, causal, sm_scale, q_full_shape,
+                             kv_full_shape, ab_full_shape, ctx_grad)

  return res

-
-
-

@fa_custom_forward.register_fake
def fa_custom_forward_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,