Skip to content

Commit 117f39f

Browse files
Microvefacebook-github-bot
authored andcommitted
Mark weights unbacked (#2583)
Summary: This is to avoid recompilations caused by the shape changes of `_weights` in KJT. Differential Revision: D66342695
1 parent 2962be0 commit 117f39f

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

Diff for: torchrec/pt2/utils.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,15 @@ def kjt_for_pt2_tracing(
7575

7676
values = kjt.values().long()
7777
torch._dynamo.decorators.mark_unbacked(values, 0)
78+
weights = kjt.weights_or_none()
79+
if weights is not None:
80+
torch._dynamo.decorators.mark_unbacked(weights, 0)
7881

7982
return KeyedJaggedTensor(
8083
keys=kjt.keys(),
8184
values=values,
8285
lengths=lengths,
83-
weights=kjt.weights_or_none(),
86+
weights=weights,
8487
stride=stride if not is_vb else None,
8588
stride_per_key_per_rank=kjt.stride_per_key_per_rank() if is_vb else None,
8689
inverse_indices=inverse_indices,

0 commit comments

Comments
 (0)