You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
# AOT compatiable funtion only accepts argument types listed https://github.com/pytorch/pytorch/blob/82859f61857ef39898b34a5cdf0ae56ec25704d9/torch/_functorch/_aot_autograd/utils.py#L23-L34, so we serliaze partition_spec and mesh into string.
850
856
outs=fa_custom_forward(*custom_op_arg, ctx_grads)
851
857
858
+
fori, oinenumerate(outs):
859
+
ifisinstance(o, torch.Tensor):
860
+
print(f'{i}: {o.shape}')
861
+
852
862
o=outs[0]
853
863
full_q, full_k, full_v, l, m, full_ab= [xforxinouts[1:]]
0 commit comments