We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 7ae70cd commit 8c76fddCopy full SHA for 8c76fdd
torchrec/pt2/utils.py
@@ -75,12 +75,16 @@ def kjt_for_pt2_tracing(
75
76
values = kjt.values().long()
77
torch._dynamo.decorators.mark_unbacked(values, 0)
78
+ weights = kjt.weights_or_none()
79
+ if weights is not None:
80
+ weights = weights.float()
81
+ torch._dynamo.decorators.mark_unbacked(weights, 0)
82
83
return KeyedJaggedTensor(
84
keys=kjt.keys(),
85
values=values,
86
lengths=lengths,
- weights=kjt.weights_or_none(),
87
+ weights=weights,
88
stride=stride if not is_vb else None,
89
stride_per_key_per_rank=kjt.stride_per_key_per_rank() if is_vb else None,
90
inverse_indices=inverse_indices,
0 commit comments