Skip to content

Commit 5998c5e

Browse files
Jakub MichalczykVincent Moens
Jakub Michalczyk
authored and
Vincent Moens
committed
correct dim for resolving dtype in _split_and_pad_sequence
1 parent fb641de commit 5998c5e

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

torchrl/objectives/value/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,9 @@ def _split_and_pad_sequence(
286286

287287
# int16 supports length up to 32767
288288
dtype = (
289-
torch.int16 if tensor.shape[-2] < torch.iinfo(torch.int16).max else torch.int32
289+
torch.int16
290+
if tensor.size(time_dim) < torch.iinfo(torch.int16).max
291+
else torch.int32
290292
)
291293
arange = torch.arange(max_seq_len, device=tensor.device, dtype=dtype).unsqueeze(0)
292294
mask = arange < splits.unsqueeze(1)

0 commit comments

Comments
 (0)