Skip to content

Commit 3c8197b

Browse files
authored
[BugFix] Fix examples (#1173)
1 parent 98cafa5 commit 3c8197b

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

torchrl/objectives/dqn.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class DQNLoss(LossModule):
5656
5757
"""
5858

59-
default_value_estimator = ValueEstimators.TDLambda
59+
default_value_estimator = ValueEstimators.TD0
6060

6161
def __init__(
6262
self,

torchrl/objectives/value/functional.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,16 @@ def transpose_tensor(tensor):
5858
or tensor.numel() <= 1
5959
):
6060
return tensor, False
61-
if time_dim < 0:
62-
timedim = tensor.ndim + time_dim
61+
if time_dim >= 0:
62+
timedim = time_dim - tensor.ndim
6363
else:
6464
timedim = time_dim
65-
if timedim < 0 or timedim >= tensor.ndim:
65+
if timedim < -tensor.ndim or timedim >= 0:
6666
raise RuntimeError(ERROR.format(tensor.shape, timedim))
6767
if tensor.ndim >= 2:
6868
single_dim = False
6969
tensor = tensor.transpose(timedim, -2)
70-
elif tensor.ndim == 1 and timedim == 0:
70+
elif tensor.ndim == 1 and timedim == -1:
7171
single_dim = True
7272
tensor = tensor.unsqueeze(-1)
7373
else:

0 commit comments

Comments
 (0)