
Commit 8983df7

lint
1 parent 1f24d81 commit 8983df7

File tree

1 file changed: 111 additions, 128 deletions


torch_xla/experimental/custom_kernel.py

Lines changed: 111 additions & 128 deletions
@@ -14,13 +14,9 @@
 
 _XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0") == "1"
 
-def _shard_map(
-    func,
-    mesh,
-    input_specs,
-    output_specs
-):
-  """Map a function over shards of data.
+
+def _shard_map(func, mesh, input_specs, output_specs):
+  """Map a function over shards of data.
 
   Note:
     ``shard_map`` is an experimental API, and still subject to change. For an
@@ -51,57 +47,58 @@ def _shard_map(
   https://docs.jax.dev/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html
   """
 
-  def _full_shape(a, spec):
-    # a is local tensor
-    # spec is the sharding spec
-    # return logical shape of global tensor
-    mesh_name_to_size = dict(
-        zip(mesh.axis_names, mesh.mesh_shape)
-    )
-
-    result_shape = []
-    for axis_size, axis_sharding in zip(a.shape, spec):
-      if axis_sharding is None:
-        new_size = axis_size
-      else:
-        if isinstance(axis_sharding, str):
-          mesh_mult = mesh_name_to_size[axis_sharding]
-        else:
-          # tuple or list
-          mesh_mult = math.prod(
-              mesh_name_to_size[a] for a in axis_sharding
-              if mesh_name_to_size[a] is not None)
-
-      if mesh_mult is not None:
-        new_size = axis_size * mesh_mult
-      result_shape.append(new_size)
-    return tuple(result_shape)
-
-  def wrapped(*args):
-    assert len(args) == len(input_specs), f'args={len(args)}; input_specs={len(input_specs)}'
-    new_args = []
-    for i, (a, spec) in enumerate(zip(args, input_specs)):
-      if isinstance(a, torch.Tensor) and spec is not None:
-        assert(len(a.shape) == len(spec)), f'{i}th input has wrong shape: {a.shape} for {spec}'
-        new_a = xs.enable_manual_sharding(a, spec, mesh=mesh).global_tensor
-        new_args.append(new_a)
-      else:
-        new_args.append(a)
-
-    res = func(*new_args)
-    if isinstance(res, tuple):
-      return tuple(
-          xs.disable_manual_sharding(
-              a, spec, _full_shape(a, spec), mesh=mesh).global_tensor
-          if isinstance(a, torch.Tensor) and spec is not None else a
-          for a, spec in zip(res, output_specs)
-      )
+  def _full_shape(a, spec):
+    # a is local tensor
+    # spec is the sharding spec
+    # return logical shape of global tensor
+    mesh_name_to_size = dict(zip(mesh.axis_names, mesh.mesh_shape))
+
+    result_shape = []
+    for axis_size, axis_sharding in zip(a.shape, spec):
+      if axis_sharding is None:
+        new_size = axis_size
+      else:
+        if isinstance(axis_sharding, str):
+          mesh_mult = mesh_name_to_size[axis_sharding]
         else:
-      return xs.disable_manual_sharding(
-          res, output_specs[0],
-          _full_shape(res, output_specs[0]), mesh=mesh).global_tensor
-    return res
-  return wrapped
+          # tuple or list
+          mesh_mult = math.prod(mesh_name_to_size[a]
+                                for a in axis_sharding
+                                if mesh_name_to_size[a] is not None)
+
+      if mesh_mult is not None:
+        new_size = axis_size * mesh_mult
+      result_shape.append(new_size)
+    return tuple(result_shape)
+
+  def wrapped(*args):
+    assert len(args) == len(
+        input_specs), f'args={len(args)}; input_specs={len(input_specs)}'
+    new_args = []
+    for i, (a, spec) in enumerate(zip(args, input_specs)):
+      if isinstance(a, torch.Tensor) and spec is not None:
+        assert (len(a.shape) == len(spec)
+               ), f'{i}th input has wrong shape: {a.shape} for {spec}'
+        new_a = xs.enable_manual_sharding(a, spec, mesh=mesh).global_tensor
+        new_args.append(new_a)
+      else:
+        new_args.append(a)
+
+    res = func(*new_args)
+    if isinstance(res, tuple):
+      return tuple(
+          xs.disable_manual_sharding(a, spec, _full_shape(a, spec), mesh=mesh).
+          global_tensor
+          if isinstance(a, torch.Tensor) and spec is not None else a
+          for a, spec in zip(res, output_specs))
+    else:
+      return xs.disable_manual_sharding(
+          res, output_specs[0], _full_shape(res, output_specs[0]),
+          mesh=mesh).global_tensor
+    return res
+
+  return wrapped
+
 
 
 def safe_empty_like(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
   """Returns empty tensor like input, or None if input is None."""
@@ -306,16 +303,10 @@ def wrapped_kernel(kernel: Callable,
 
 
 def _fa_custom_forward_single_device(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    causal: bool,
-    q_segment_ids: torch.Tensor,
-    kv_segment_ids: torch.Tensor,
-    sm_scale: float,
-    ab: Optional[torch.Tensor],
-    ctx_grad: List[bool]
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool,
+    q_segment_ids: torch.Tensor, kv_segment_ids: torch.Tensor, sm_scale: float,
+    ab: Optional[torch.Tensor],
+    ctx_grad: List[bool]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
   from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_impl
 
   num_batches = None
@@ -331,15 +322,13 @@ def _fa_custom_forward_single_device(
     kv_segment_ids = kv_segment_ids.reshape(-1, *rest)
     if ab is not None:
       ab = ab.reshape(-1, *rest)
-
 
   # Suprisingly, any tensor that is input to the custom_op decorated function will show
   # requires_grad=False. Is this a bug or feature? We have to pass ctx_grad to record the
   # requires_grad for inputs.
   # Original we use save_residuals = q.requires_grad or k.requires_grad or v.requires_grad
   save_residuals = any(ctx_grad[:3])
 
-
   block_k_major = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"],
                       k.shape[2])
   block_k = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2])
@@ -456,42 +445,42 @@ def fa_custom_forward(
 
   if partition_spec is not None:
     if len(partition_spec) == 5:
-      segment_id_partition_spec = (partition_spec[0], partition_spec[1], partition_spec[3])
+      segment_id_partition_spec = (partition_spec[0], partition_spec[1],
+                                   partition_spec[3])
       lm_partition_spec = partition_spec[:4]
     else:
       segment_id_partition_spec = (partition_spec[0], partition_spec[2])
      lm_partition_spec = partition_spec[:3]
 
    input_specs = [
-      partition_spec,  # q
-      partition_spec,  # k
-      partition_spec,  # v
-      None,
-      segment_id_partition_spec,
-      segment_id_partition_spec,
-      None,
-      partition_spec,
-      None,
+        partition_spec,  # q
+        partition_spec,  # k
+        partition_spec,  # v
+        None,
+        segment_id_partition_spec,
+        segment_id_partition_spec,
+        None,
+        partition_spec,
+        None,
    ]
 
    output_specs = [
-      partition_spec,  # o
-      lm_partition_spec,  # l
-      lm_partition_spec,  # m
+        partition_spec,  # o
+        lm_partition_spec,  # l
+        lm_partition_spec,  # m
    ]
 
    fa_forward_callable = _shard_map(
-      _fa_custom_forward_single_device,
-      mesh,
-      input_specs,
-      output_specs,
+        _fa_custom_forward_single_device,
+        mesh,
+        input_specs,
+        output_specs,
    )
  else:
    fa_forward_callable = _fa_custom_forward_single_device
 
-  o, l, m = fa_forward_callable(
-      q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab, ctx_grad
-  )
+  o, l, m = fa_forward_callable(q, k, v, causal, q_segment_ids, kv_segment_ids,
+                                sm_scale, ab, ctx_grad)
 
   outs = [o] + [full_q, full_k, full_v, l, m, full_ab]
   return tuple(outs)
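
As a worked illustration of the partition-spec handling in the hunk above (the spec values are hypothetical, not taken from the commit): for a 4-element partition_spec over (batch, num_heads, q_seq_len, head_dim), the non-5-element branch derives the segment-id and l/m specs as follows.

# Hypothetical 4-element partition spec: shard the batch over a "data" axis.
partition_spec = ("data", None, None, None)

# Non-5-element branch from fa_custom_forward above:
segment_id_partition_spec = (partition_spec[0], partition_spec[2])  # ("data", None)
lm_partition_spec = partition_spec[:3]                              # ("data", None, None)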
@@ -523,15 +512,14 @@ def _fa_custom_backward_single_device(
     v: torch.Tensor, o: torch.Tensor, l: torch.Tensor, m: torch.Tensor,
     q_segment_ids: Optional[torch.Tensor],
     kv_segment_ids: Optional[torch.Tensor], ab: Optional[torch.Tensor],
-    causal: bool, sm_scale: float,
-    q_full_shape: List[int], kv_full_shape: List[int],
-    ab_full_shape: Optional[List[int]], ctx_grad: List[bool]
+    causal: bool, sm_scale: float, q_full_shape: List[int],
+    kv_full_shape: List[int], ab_full_shape: Optional[List[int]],
+    ctx_grad: List[bool]
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
 
   from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv
   grad_q = grad_k = grad_v = grad_ab = segment_ids = None
 
-
   num_batches = None
   batch_size = None
   if len(q.shape) == 5:
@@ -663,17 +651,20 @@ def _fa_custom_backward_single_device(
     grad_v = grads[1]
 
   if num_batches is not None:
+
     def _reshape(x):
       if x is not None:
         return x.reshape(num_batches, batch_size, *x.shape[1:])
       return None
+
     grad_q = _reshape(grad_q)
     grad_k = _reshape(grad_k)
     grad_v = _reshape(grad_v)
     grad_ab = _reshape(grad_ab)
 
   return grad_q, grad_k, grad_v, grad_ab
 
+
 @custom_op("xla::fa_custom_backward", mutates_args=())
 def fa_custom_backward(
     grad_output: torch.Tensor, q: torch.Tensor, k: torch.Tensor,
@@ -696,57 +687,49 @@ def fa_custom_backward(
   ab_full_shape = torch.Size(
       ab_full_shape) if ab_full_shape is not None else None
 
-
   if partition_spec:
     if len(partition_spec) == 5:
-      segment_id_partition_spec = (partition_spec[0], partition_spec[1], partition_spec[3])
+      segment_id_partition_spec = (partition_spec[0], partition_spec[1],
+                                   partition_spec[3])
       lm_partition_spec = partition_spec[:4]
     else:
       segment_id_partition_spec = (partition_spec[0], partition_spec[2])
      lm_partition_spec = partition_spec[:3]
    input_specs = [
-      partition_spec,  # grad_output
-      partition_spec,  # q
-      partition_spec,  # k
-      partition_spec,  # v
-      partition_spec,  # o
-      lm_partition_spec,  # l
-      lm_partition_spec,  # m
-      segment_id_partition_spec,  # q_segment_ids
-      segment_id_partition_spec,  # kv_segment_ids
-      partition_spec,  # ab
-      None,  # causal
-      None,  # sm_scale
-      None,  # q_full_shape
-      None,  # kv_full_shape
-      None,  # ab_full_shape
-      None,  # ctx_grad
+        partition_spec,  # grad_output
+        partition_spec,  # q
+        partition_spec,  # k
+        partition_spec,  # v
+        partition_spec,  # o
+        lm_partition_spec,  # l
+        lm_partition_spec,  # m
+        segment_id_partition_spec,  # q_segment_ids
+        segment_id_partition_spec,  # kv_segment_ids
+        partition_spec,  # ab
+        None,  # causal
+        None,  # sm_scale
+        None,  # q_full_shape
+        None,  # kv_full_shape
+        None,  # ab_full_shape
+        None,  # ctx_grad
    ]
    output_specs = [
-      partition_spec,
-      partition_spec,
-      partition_spec,
-      partition_spec,
+        partition_spec,
+        partition_spec,
+        partition_spec,
+        partition_spec,
    ]
-    fa_backward_callable = _shard_map(
-        _fa_custom_backward_single_device,
-        mesh,
-        input_specs,
-        output_specs
-    )
+    fa_backward_callable = _shard_map(_fa_custom_backward_single_device, mesh,
+                                      input_specs, output_specs)
  else:
    fa_backward_callable = _fa_custom_backward_single_device
 
-  res = fa_backward_callable(
-      grad_output, q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab, causal, sm_scale,
-      q_full_shape, kv_full_shape, ab_full_shape, ctx_grad
-  )
+  res = fa_backward_callable(grad_output, q, k, v, o, l, m, q_segment_ids,
+                             kv_segment_ids, ab, causal, sm_scale, q_full_shape,
+                             kv_full_shape, ab_full_shape, ctx_grad)
 
   return res
 
-
-
-
 
 @fa_custom_forward.register_fake
 def fa_custom_forward_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
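
For reference, a small self-contained sketch (not the library function) of the global-shape computation that _full_shape performs: each sharded axis of a local shard is multiplied by the size of the mesh axes it is sharded over, while unsharded axes keep their local size. The mesh sizes and shapes below are made-up assumptions.

import math

# Hypothetical 4x2 device mesh.
mesh_name_to_size = {"data": 4, "model": 2}


def full_shape(local_shape, spec):
  # Mirrors _full_shape above: scale each sharded axis by the product of the
  # sizes of the mesh axes named in its spec entry.
  result = []
  for axis_size, axis_sharding in zip(local_shape, spec):
    if axis_sharding is None:
      result.append(axis_size)
    elif isinstance(axis_sharding, str):
      result.append(axis_size * mesh_name_to_size[axis_sharding])
    else:  # tuple or list of mesh axis names
      result.append(axis_size * math.prod(
          mesh_name_to_size[a] for a in axis_sharding))
  return tuple(result)


# A (2, 8, 128) local shard with spec (("data", "model"), "data", None)
# corresponds to a (16, 32, 128) global tensor on the 4x2 mesh.
print(full_shape((2, 8, 128), (("data", "model"), "data", None)))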

0 commit comments
