Skip to content

Commit 8c76fdd

Browse files
Microvefacebook-github-bot
authored andcommitted
Mark weights unbacked
Summary: This is to avoid recompilations caused by the shape changes of `_weights` in KJT. Differential Revision: D66342695
1 parent 7ae70cd commit 8c76fdd

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

Diff for: torchrec/pt2/utils.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,16 @@ def kjt_for_pt2_tracing(
7575

7676
values = kjt.values().long()
7777
torch._dynamo.decorators.mark_unbacked(values, 0)
78+
weights = kjt.weights_or_none()
79+
if weights is not None:
80+
weights = weights.float()
81+
torch._dynamo.decorators.mark_unbacked(weights, 0)
7882

7983
return KeyedJaggedTensor(
8084
keys=kjt.keys(),
8185
values=values,
8286
lengths=lengths,
83-
weights=kjt.weights_or_none(),
87+
weights=weights,
8488
stride=stride if not is_vb else None,
8589
stride_per_key_per_rank=kjt.stride_per_key_per_rank() if is_vb else None,
8690
inverse_indices=inverse_indices,

0 commit comments

Comments
 (0)