import functools
import os
+ import math
import warnings

import torch
@@ -20,7 +21,7 @@ def _shard_map(func, mesh, input_specs, output_specs):

  Note:
    ``shard_map`` is an experimental API, and still subject to change. For an
-     introduction to sharded data, refer to :ref:`sharded-computation`. For a more
+     introduction to sharded data and a more
    in-depth look at using ``shard_map``, refer to
    [SPMD multi-device parallelism with shard_map](https://docs.jax.dev/en/latest/notebooks/shard_map.html)

@@ -43,7 +44,7 @@ def _shard_map(func, mesh, input_specs, output_specs):
    the ``mesh`` and ``out_specs``.

  Reference:
-     This function is identical to Jax's shard_map:
+     This function behaves identically to Jax's shard_map:
    https://docs.jax.dev/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html
  """
@@ -56,18 +57,14 @@ def _full_shape(a, spec):
    result_shape = []
    for axis_size, axis_sharding in zip(a.shape, spec):
      if axis_sharding is None:
-         new_size = axis_size
-       else:
-         if isinstance(axis_sharding, (str, int)):
-           mesh_mult = mesh_name_to_size[axis_sharding]
-         else:
-           # tuple or list
-           mesh_mult = math.prod(mesh_name_to_size[a]
-                                 for a in axis_sharding
-                                 if mesh_name_to_size[a] is not None)
-
-         if mesh_mult is not None:
-           new_size = axis_size * mesh_mult
+         axis_sharding = ()
+       mesh_mult = []
+       if isinstance(axis_sharding, (str, int)):
+         axis_sharding = [axis_sharding]
+       for axis in axis_sharding:
+         size = mesh_name_to_size[axis] or 1
+         mesh_mult.append(size)
+       new_size = axis_size * math.prod(mesh_mult)
      result_shape.append(new_size)
    return tuple(result_shape)
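To make the new ``_full_shape`` logic concrete, here is a standalone sketch of the same computation with a hand-written ``mesh_name_to_size`` table; the axis names and sizes are made up for illustration.

```python
import math

# Assumed mesh: 4 devices on a 'data' axis, 2 on a 'model' axis.
mesh_name_to_size = {'data': 4, 'model': 2}

def full_shape(shard_shape, spec):
  result = []
  for axis_size, axis_sharding in zip(shard_shape, spec):
    if axis_sharding is None:
      axis_sharding = ()                      # replicated dim: no scaling
    if isinstance(axis_sharding, (str, int)):
      axis_sharding = [axis_sharding]         # single mesh axis
    mult = math.prod(mesh_name_to_size[a] or 1 for a in axis_sharding)
    result.append(axis_size * mult)
  return tuple(result)

# A (2, 8) local shard with dim 0 split over both axes and dim 1 replicated
# corresponds to a (16, 8) global tensor; a single-axis spec scales by 4.
assert full_shape((2, 8), (('data', 'model'), None)) == (16, 8)
assert full_shape((2, 8), ('data', None)) == (8, 8)
```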
@@ -76,7 +73,7 @@ def wrapped(*args):
        input_specs), f'args={len(args)}; input_specs={len(input_specs)}'
    new_args = []
    for i, (a, spec) in enumerate(zip(args, input_specs)):
-       if isinstance(a, torch.Tensor) and spec is not None:
+       if isinstance(a, torch.Tensor):
        assert (len(a.shape) == len(spec)
               ), f'{i}th input has wrong shape: {a.shape} for {spec}'
        new_a = xs.enable_manual_sharding(a, spec, mesh=mesh).global_tensor
@@ -87,22 +84,21 @@ def wrapped(*args):
    res = func(*new_args)
    if isinstance(res, tuple):
      res_updated = []
-       for i, r in enumerate(res):
-         if isinstance(r, torch.Tensor):
+       for i, (r, spec) in enumerate(zip(res, output_specs)):
+         if isinstance(r, torch.Tensor) and spec is not None:
          assert str(r.device).startswith('xla'), f'{i}th device is {r.device}'
          assert len(r.shape) == len(
-               output_specs[i]
-           ), f'{i}th shape is {r.shape}, sharding is {output_specs[i]}'
-       return tuple(
-           xs.disable_manual_sharding(a, spec, _full_shape(a, spec), mesh=mesh).
-           global_tensor
-           if isinstance(a, torch.Tensor) and spec is not None else a
-           for a, spec in zip(res, output_specs))
+               spec), f'{i}th shape is {r.shape}, sharding is {output_specs[i]}'
+           new_r = xs.disable_manual_sharding(
+               r, spec, _full_shape(r, spec), mesh=mesh).global_tensor
+         else:
+           new_r = r
+         res_updated.append(new_r)
+       return res_updated
    else:
      return xs.disable_manual_sharding(
          res, output_specs[0], _full_shape(res, output_specs[0]),
          mesh=mesh).global_tensor
-       return res

  return wrapped
@@ -309,6 +305,24 @@ def wrapped_kernel(kernel: Callable,
  return functools.partial(wrapped_kernel, kernel, output_shape_dtype_fn)


+ def _maybe_reshape_input_output_funcs(current_shape, non_batch_dims=3):
+   batch_dims = len(current_shape) - non_batch_dims
+   orig_batch_dims = current_shape[:batch_dims]
+   other_dims = current_shape[batch_dims:]
+
+   def reshape_input(tensor):
+     if tensor is None:
+       return None
+     return tensor.reshape(-1, *tensor.shape[batch_dims:])
+
+   def reshape_output(tensor):
+     if tensor is None:
+       return None
+     return tensor.reshape(*orig_batch_dims, *tensor.shape[1:])
+
+   return reshape_input, reshape_output
+
+
def _fa_custom_forward_single_device(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool,
    q_segment_ids: torch.Tensor, kv_segment_ids: torch.Tensor, sm_scale: float,
@@ -318,20 +332,16 @@ def _fa_custom_forward_single_device(

  num_batches = None
  batch_size = None
-   if len(q.shape) == 5:
-     num_batches, batch_size, *rest = q.shape
-     q = q.reshape(-1, *rest)
-     k = k.reshape(-1, *rest)
-     v = v.reshape(-1, *rest)
-     if q_segment_ids is not None:
-       q_segment_ids = q_segment_ids.reshape(-1, *rest)
-     if kv_segment_ids is not None:
-       kv_segment_ids = kv_segment_ids.reshape(-1, *rest)
-     if ab is not None:
-       ab = ab.reshape(-1, *rest)
-
-   # Suprisingly, any tensor that is input to the custom_op decorated function will show
-   # requires_grad=False. Is this a bug or feature? We have to pass ctx_grad to record the
+   reshape_to_4d, undo_reshape = _maybe_reshape_input_output_funcs(q.shape, 3)
+   q = reshape_to_4d(q)
+   v = reshape_to_4d(v)
+   k = reshape_to_4d(k)
+   q_segment_ids = reshape_to_4d(q_segment_ids)
+   kv_segment_ids = reshape_to_4d(kv_segment_ids)
+   ab = reshape_to_4d(ab)
+
+   # Surprisingly, any tensor that is input to the custom_op decorated function will show
+   # requires_grad=False by design. We have to pass ctx_grad to record the
  # requires_grad for inputs.
  # Originally we used save_residuals = q.requires_grad or k.requires_grad or v.requires_grad
  save_residuals = any(ctx_grad[:3])
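The ``ctx_grad`` comment is easier to follow with a small standalone sketch. The tensor shapes and the way the flag list is assembled here are assumptions for illustration; only the ``any(ctx_grad[:3])`` rule comes from this hunk.

```python
import torch

# Toy stand-ins; shapes are arbitrary for illustration.
q = torch.randn(4, 8, 128, 64, requires_grad=True)
k = torch.randn(4, 8, 128, 64, requires_grad=True)
v = torch.randn(4, 8, 128, 64)

# Flags captured *outside* the custom op, where requires_grad is still visible.
ctx_grad = [q.requires_grad, k.requires_grad, v.requires_grad]

# Matches the hunk above: residuals are saved iff any of q/k/v needs a gradient.
save_residuals = any(ctx_grad[:3])
assert save_residuals
```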
@@ -401,12 +411,9 @@ def _fa_custom_forward_single_device(
  o, *aux = custom_call_output
  l, m = (v[..., 0] for v in aux[-2:])

-   if num_batches is not None:
-     o = o.reshape(num_batches, batch_size, *o.shape[1:])
-     if l is not None:
-       l = l.reshape(num_batches, batch_size, *l.shape[1:])
-     if m is not None:
-       m = m.reshape(num_batches, batch_size, *m.shape[1:])
+   o = undo_reshape(o)
+   l = undo_reshape(l)
+   m = undo_reshape(m)

  return o, l, m
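The reshape pairing used in the forward pass (flatten leading batch dims on the way in, restore them on the way out) can be exercised standalone. The sketch below mirrors ``_maybe_reshape_input_output_funcs`` under a local name, with assumed 5-D attention shapes.

```python
import torch

def make_reshapers(current_shape, non_batch_dims=3):
  # Mirrors the helper added in this diff: collapse the leading batch dims,
  # then restore them later. `make_reshapers` is a local name for illustration.
  batch_dims = len(current_shape) - non_batch_dims
  orig_batch_dims = current_shape[:batch_dims]

  def reshape_input(t):
    return None if t is None else t.reshape(-1, *t.shape[batch_dims:])

  def reshape_output(t):
    return None if t is None else t.reshape(*orig_batch_dims, *t.shape[1:])

  return reshape_input, reshape_output

# Assumed 5-D layout: (num_batches, batch, heads, seq, head_dim).
q = torch.randn(2, 3, 8, 16, 64)
to_4d, undo = make_reshapers(q.shape, 3)
assert to_4d(q).shape == (6, 8, 16, 64)   # leading dims collapsed to one
assert undo(to_4d(q)).shape == q.shape    # round trip restores them
assert to_4d(None) is None                # optional inputs pass through
```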
@@ -518,21 +525,18 @@ def _fa_custom_backward_single_device(

  num_batches = None
  batch_size = None
-   if len(q.shape) == 5:
-     num_batches, batch_size, *rest = q.shape
-     grad_output = grad_output.reshape(-1, *rest)
-     q = q.reshape(-1, *rest)
-     k = k.reshape(-1, *rest)
-     v = v.reshape(-1, *rest)
-     o = o.reshape(-1, *rest)
-     l = l.reshape(-1, *rest)
-     m = m.reshape(-1, *rest)
-     if q_segment_ids is not None:
-       q_segment_ids = q_segment_ids.reshape(-1, *rest)
-     if kv_segment_ids is not None:
-       kv_segment_ids = kv_segment_ids.reshape(-1, *rest)
-     if ab is not None:
-       ab = ab.reshape(-1, *rest)
+   reshape_to_4d, undo_reshape = _maybe_reshape_input_output_funcs(q.shape, 3)
+
+   grad_output = reshape_to_4d(grad_output)
+   q = reshape_to_4d(q)
+   k = reshape_to_4d(k)
+   v = reshape_to_4d(v)
+   o = reshape_to_4d(o)
+   l = reshape_to_4d(l)
+   m = reshape_to_4d(m)
+   q_segment_ids = reshape_to_4d(q_segment_ids)
+   kv_segment_ids = reshape_to_4d(kv_segment_ids)
+   ab = reshape_to_4d(ab)

  require_grad_q, require_grad_k, require_grad_v, *rest = ctx_grad
  require_grad_ab = ctx_grad[-3]
@@ -646,17 +650,10 @@ def _fa_custom_backward_single_device(
  if require_grad_v:
    grad_v = grads[1]

-   if num_batches is not None:
-
-     def _reshape(x):
-       if x is not None:
-         return x.reshape(num_batches, batch_size, *x.shape[1:])
-       return None
-
-     grad_q = _reshape(grad_q)
-     grad_k = _reshape(grad_k)
-     grad_v = _reshape(grad_v)
-     grad_ab = _reshape(grad_ab)
+   grad_q = undo_reshape(grad_q)
+   grad_k = undo_reshape(grad_k)
+   grad_v = undo_reshape(grad_v)
+   grad_ab = undo_reshape(grad_ab)

  return grad_q, grad_k, grad_v, grad_ab