We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 2962be0 commit 117f39fCopy full SHA for 117f39f
torchrec/pt2/utils.py
@@ -75,12 +75,15 @@ def kjt_for_pt2_tracing(
75
76
values = kjt.values().long()
77
torch._dynamo.decorators.mark_unbacked(values, 0)
78
+ weights = kjt.weights_or_none()
79
+ if weights is not None:
80
+ torch._dynamo.decorators.mark_unbacked(weights, 0)
81
82
return KeyedJaggedTensor(
83
keys=kjt.keys(),
84
values=values,
85
lengths=lengths,
- weights=kjt.weights_or_none(),
86
+ weights=weights,
87
stride=stride if not is_vb else None,
88
stride_per_key_per_rank=kjt.stride_per_key_per_rank() if is_vb else None,
89
inverse_indices=inverse_indices,
0 commit comments