Skip to content

Commit 21c4d87

Browse files
KubaMichalczykJakub Michalczyk
and
Jakub Michalczyk
authored
[BugFix] correct dim for resolving dtype in _split_and_pad_sequence (#2801)
Co-authored-by: Jakub Michalczyk <[email protected]>
1 parent fb641de commit 21c4d87

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

torchrl/objectives/value/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,9 @@ def _split_and_pad_sequence(
286286

287287
# int16 supports length up to 32767
288288
dtype = (
289-
torch.int16 if tensor.shape[-2] < torch.iinfo(torch.int16).max else torch.int32
289+
torch.int16
290+
if tensor.size(time_dim) < torch.iinfo(torch.int16).max
291+
else torch.int32
290292
)
291293
arange = torch.arange(max_seq_len, device=tensor.device, dtype=dtype).unsqueeze(0)
292294
mask = arange < splits.unsqueeze(1)

0 commit comments

Comments
 (0)