File tree 2 files changed +5
-5
lines changed
2 files changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -56,7 +56,7 @@ class DQNLoss(LossModule):
56
56
57
57
"""
58
58
59
- default_value_estimator = ValueEstimators .TDLambda
59
+ default_value_estimator = ValueEstimators .TD0
60
60
61
61
def __init__ (
62
62
self ,
Original file line number Diff line number Diff line change @@ -58,16 +58,16 @@ def transpose_tensor(tensor):
58
58
or tensor .numel () <= 1
59
59
):
60
60
return tensor , False
61
- if time_dim < 0 :
62
- timedim = tensor .ndim + time_dim
61
+ if time_dim >= 0 :
62
+ timedim = time_dim - tensor .ndim
63
63
else :
64
64
timedim = time_dim
65
- if timedim < 0 or timedim >= tensor . ndim :
65
+ if timedim < - tensor . ndim or timedim >= 0 :
66
66
raise RuntimeError (ERROR .format (tensor .shape , timedim ))
67
67
if tensor .ndim >= 2 :
68
68
single_dim = False
69
69
tensor = tensor .transpose (timedim , - 2 )
70
- elif tensor .ndim == 1 and timedim == 0 :
70
+ elif tensor .ndim == 1 and timedim == - 1 :
71
71
single_dim = True
72
72
tensor = tensor .unsqueeze (- 1 )
73
73
else :
You can’t perform that action at this time.
0 commit comments