@@ -20,7 +20,7 @@ def _shard_map(func, mesh, input_specs, output_specs):
   Note:
     ``shard_map`` is an experimental API, and still subject to change. For an
-    introduction to sharded data, refer to :ref:`sharded-computation`. For a more
+    introduction to sharded data, refer to the JAX documentation. For a more
     in-depth look at using ``shard_map``, refer to
     [SPMD multi-device parallelism with shard_map](https://docs.jax.dev/en/latest/notebooks/shard_map.html)
@@ -43,7 +43,7 @@ def _shard_map(func, mesh, input_specs, output_specs):
     the ``mesh`` and ``out_specs``.

   Reference:
-    This function is identical Jax's shard_map:
+    This function behaves identically to Jax's shard_map:
     https://docs.jax.dev/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html
   """
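For reference, here is a minimal JAX sketch of the sharded-in/sharded-out contract that this wrapper mirrors; the 1-D mesh axis name 'x' and the toy per-shard computation are illustrative only, not part of this module:

    import jax
    import jax.numpy as jnp
    from jax.experimental import mesh_utils
    from jax.experimental.shard_map import shard_map
    from jax.sharding import Mesh, PartitionSpec as P

    # Build a 1-D device mesh with a single named axis 'x'.
    devices = mesh_utils.create_device_mesh((jax.device_count(),))
    mesh = Mesh(devices, axis_names=('x',))

    def per_shard_double(block):
      # Each device sees only its slice along the first dimension.
      return block * 2

    f = shard_map(per_shard_double, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
    x = jnp.arange(4 * jax.device_count(), dtype=jnp.float32)
    y = f(x)  # same global shape as x; the doubling ran shard-by-shard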
@@ -56,18 +56,14 @@ def _full_shape(a, spec):
   result_shape = []
   for axis_size, axis_sharding in zip(a.shape, spec):
     if axis_sharding is None:
-      new_size = axis_size
-    else:
-      if isinstance(axis_sharding, (str, int)):
-        mesh_mult = mesh_name_to_size[axis_sharding]
-      else:
-        # tuple or list
-        mesh_mult = math.prod(mesh_name_to_size[a]
-                              for a in axis_sharding
-                              if mesh_name_to_size[a] is not None)
-
-      if mesh_mult is not None:
-        new_size = axis_size * mesh_mult
+      axis_sharding = ()
+    mesh_mult = []
+    if isinstance(axis_sharding, (str, int)):
+      axis_sharding = [axis_sharding]
+    for a in axis_sharding:
+      size = mesh_name_to_size[a] or 1
+      mesh_mult.append(size)
+    new_size = axis_size * math.prod(mesh_mult)
     result_shape.append(new_size)
   return tuple(result_shape)
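A standalone walk-through of the new shape arithmetic in _full_shape, using a made-up mesh_name_to_size mapping (the axis names and sizes below are purely illustrative):

    import math

    # Hypothetical mesh: axis 'data' has 4 devices, 'model' has 2,
    # and 'replica' has no recorded size (treated as 1 by the `or 1` fallback).
    mesh_name_to_size = {'data': 4, 'model': 2, 'replica': None}

    local_shape = (2, 128, 64)
    spec = ('data', ('model', 'replica'), None)

    result = []
    for axis_size, axis_sharding in zip(local_shape, spec):
      if axis_sharding is None:
        axis_sharding = ()
      if isinstance(axis_sharding, (str, int)):
        axis_sharding = [axis_sharding]
      mesh_mult = [mesh_name_to_size[a] or 1 for a in axis_sharding]
      result.append(axis_size * math.prod(mesh_mult))

    assert tuple(result) == (8, 256, 64)  # 2*4, 128*(2*1), 64 (unsharded)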
@@ -87,17 +83,16 @@ def wrapped(*args):
     res = func(*new_args)
     if isinstance(res, tuple):
       res_updated = []
-      for i, r in enumerate(res):
-        if isinstance(r, torch.Tensor):
+      for i, (r, spec) in enumerate(zip(res, output_specs)):
+        if isinstance(r, torch.Tensor) and spec is not None:
           assert str(r.device).startswith('xla'), f'{i}th device is {r.device}'
           assert len(r.shape) == len(
-              output_specs[i]
-          ), f'{i}th shape is {r.shape}, sharding is {output_specs[i]}'
-      return tuple(
-          xs.disable_manual_sharding(a, spec, _full_shape(a, spec), mesh=mesh).
-          global_tensor
-          if isinstance(a, torch.Tensor) and spec is not None else a
-          for a, spec in zip(res, output_specs))
+              spec), f'{i}th shape is {r.shape}, sharding is {output_specs[i]}'
+          new_r = xs.disable_manual_sharding(
+              r, spec, _full_shape(r, spec), mesh=mesh).global_tensor
+        else:
+          new_r = r
+        res_updated.append(new_r)
     else:
       return xs.disable_manual_sharding(
           res, output_specs[0], _full_shape(res, output_specs[0]),
@@ -309,6 +304,24 @@ def wrapped_kernel(kernel: Callable,
   return functools.partial(wrapped_kernel, kernel, output_shape_dtype_fn)


+def _maybe_reshape_input_output_funcs(current_shape, non_batch_dims=3):
+  batch_dims = len(current_shape) - non_batch_dims
+  orig_batch_dims = current_shape[:batch_dims]
+  other_dims = current_shape[batch_dims:]
+
+  def reshape_input(tensor):
+    if tensor is None:
+      return None
+    return tensor.reshape(-1, *tensor.shape[batch_dims:])
+
+  def reshape_output(tensor):
+    if tensor is None:
+      return None
+    return tensor.reshape(*orig_batch_dims, *tensor.shape[1:])
+
+  return reshape_input, reshape_output
+
+
 def _fa_custom_forward_single_device(
     q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool,
     q_segment_ids: torch.Tensor, kv_segment_ids: torch.Tensor, sm_scale: float,
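A quick sketch of how the new helper flattens and restores batch dimensions, assuming `_maybe_reshape_input_output_funcs` from the hunk above is in scope (the shapes are illustrative):

    import torch

    # Two leading batch dims plus 3 non-batch dims: (2, 4, heads, seq, head_dim).
    q = torch.zeros(2, 4, 8, 128, 64)

    reshape_to_4d, undo_reshape = _maybe_reshape_input_output_funcs(
        q.shape, non_batch_dims=3)

    flat = reshape_to_4d(q)        # shape (8, 8, 128, 64): leading dims collapsed
    restored = undo_reshape(flat)  # shape (2, 4, 8, 128, 64): batch dims restored
    assert restored.shape == q.shape

    # A plain 4-D input keeps its shape, and None inputs (optional segment ids,
    # optional ab bias) pass straight through as None.
    assert reshape_to_4d(None) is None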
@@ -318,20 +331,16 @@ def _fa_custom_forward_single_device(
 
   num_batches = None
   batch_size = None
-  if len(q.shape) == 5:
-    num_batches, batch_size, *rest = q.shape
-    q = q.reshape(-1, *rest)
-    k = k.reshape(-1, *rest)
-    v = v.reshape(-1, *rest)
-    if q_segment_ids is not None:
-      q_segment_ids = q_segment_ids.reshape(-1, *rest)
-    if kv_segment_ids is not None:
-      kv_segment_ids = kv_segment_ids.reshape(-1, *rest)
-    if ab is not None:
-      ab = ab.reshape(-1, *rest)
-
-  # Suprisingly, any tensor that is input to the custom_op decorated function will show
-  # requires_grad=False. Is this a bug or feature? We have to pass ctx_grad to record the
+  reshape_to_4d, undo_reshape = _maybe_reshape_input_output_funcs(q.shape, 3)
+  q = reshape_to_4d(q)
+  v = reshape_to_4d(v)
+  k = reshape_to_4d(k)
+  q_segment_ids = reshape_to_4d(q_segment_ids)
+  kv_segment_ids = reshape_to_4d(kv_segment_ids)
+  ab = reshape_to_4d(ab)
+
+  # Surprisingly, any tensor that is input to the custom_op decorated function will show
+  # requires_grad=False by design. We have to pass ctx_grad to record the
   # requires_grad for inputs.
   # Original we use save_residuals = q.requires_grad or k.requires_grad or v.requires_grad
   save_residuals = any(ctx_grad[:3])
401
410
o , * aux = custom_call_output
402
411
l , m = (v [..., 0 ] for v in aux [- 2 :])
403
412
404
- if num_batches is not None :
405
- o = o .reshape (num_batches , batch_size , * o .shape [1 :])
406
- if l is not None :
407
- l = l .reshape (num_batches , batch_size , * l .shape [1 :])
408
- if m is not None :
409
- m = m .reshape (num_batches , batch_size , * m .shape [1 :])
413
+ o = undo_reshape (o )
414
+ l = undo_reshape (l )
415
+ m = undo_reshape (m )
410
416
411
417
return o , l , m
412
418
@@ -518,21 +524,18 @@ def _fa_custom_backward_single_device(
 
   num_batches = None
   batch_size = None
-  if len(q.shape) == 5:
-    num_batches, batch_size, *rest = q.shape
-    grad_output = grad_output.reshape(-1, *rest)
-    q = q.reshape(-1, *rest)
-    k = k.reshape(-1, *rest)
-    v = v.reshape(-1, *rest)
-    o = o.reshape(-1, *rest)
-    l = l.reshape(-1, *rest)
-    m = m.reshape(-1, *rest)
-    if q_segment_ids is not None:
-      q_segment_ids = q_segment_ids.reshape(-1, *rest)
-    if kv_segment_ids is not None:
-      kv_segment_ids = kv_segment_ids.reshape(-1, *rest)
-    if ab is not None:
-      ab = ab.reshape(-1, *rest)
+  reshape_to_4d, undo_reshape = _maybe_reshape_input_output_funcs(q.shape, 3)
+
+  grad_output = reshape_to_4d(grad_output)
+  q = reshape_to_4d(q)
+  k = reshape_to_4d(k)
+  v = reshape_to_4d(v)
+  o = reshape_to_4d(o)
+  l = reshape_to_4d(l)
+  m = reshape_to_4d(m)
+  q_segment_ids = reshape_to_4d(q_segment_ids)
+  kv_segment_ids = reshape_to_4d(kv_segment_ids)
+  ab = reshape_to_4d(ab)

   require_grad_q, require_grad_k, require_grad_v, *rest = ctx_grad
   require_grad_ab = ctx_grad[-3]
@@ -646,17 +649,10 @@ def _fa_custom_backward_single_device(
   if require_grad_v:
     grad_v = grads[1]

-  if num_batches is not None:
-
-    def _reshape(x):
-      if x is not None:
-        return x.reshape(num_batches, batch_size, *x.shape[1:])
-      return None
-
-    grad_q = _reshape(grad_q)
-    grad_k = _reshape(grad_k)
-    grad_v = _reshape(grad_v)
-    grad_ab = _reshape(grad_ab)
+  grad_q = undo_reshape(grad_q)
+  grad_k = undo_reshape(grad_k)
+  grad_v = undo_reshape(grad_v)
+  grad_ab = undo_reshape(grad_ab)

   return grad_q, grad_k, grad_v, grad_ab