@@ -806,19 +806,34 @@ function dpnntest.ReinforceCategorical()
806
806
end
807
807
808
808
-- Unit test for nn.VRClassReward wrapped around an nn.Reinforce module.
--
-- Builds a float input pair {13x10 tensor, 13x1 tensor} and an IntTensor
-- target of 13 values drawn from 1..10, then checks that:
--   * forward returns the negated mean reward, where the reward is 1 for
--     rows whose arg-max over dimension 2 of input[1] equals the target
--     and 0 otherwise;
--   * the gradient w.r.t. input[2] matches nn.MSECriterion's backward
--     of input[2] against that reward;
--   * the gradient w.r.t. input[1] is zero everywhere.
-- When 'cunn' can be loaded, the same forward/backward is repeated on CUDA
-- tensors and compared against the float results.
function dpnntest.VRClassReward()
   local eps = 0.000001

   local input = {torch.randn(13,10):float(), torch.randn(13,1):float()}
   local target = torch.IntTensor(13):random(1,10)
   local rf = nn.Reinforce():float()
   local vrc = nn.VRClassReward(rf):float()

   local err = vrc:forward(input, target)
   local gradInput = vrc:backward(input, target)

   -- Reference reward: 1 where the arg-max column equals the target, else 0.
   local _, idx = input[1]:max(2)
   local reward = torch.eq(idx:select(2,1):int(), target):float()
   local err2 = -reward:mean()
   -- Tolerance comparison, consistent with the CUDA check below; exact
   -- equality on floating-point results is brittle.
   mytester:assert(math.abs(err - err2) < eps, " VRClassReward forward err")

   local gradInput2 = nn.MSECriterion():float():backward(input[2], reward)
   mytester:assertTensorEq(gradInput[2], gradInput2, eps, " VRClassReward backward baseline err")

   -- Sum of absolute values: a plain sum() would let positive and negative
   -- entries cancel and hide a non-zero gradient. torch.abs returns a new
   -- tensor, so gradInput[1] is left untouched for the CUDA comparison.
   mytester:assert(torch.abs(gradInput[1]):sum() < eps, " VRClassReward backward class err")

   if pcall(function() require 'cunn' end) then
      -- Keep the float results under their own name for the comparison.
      local floatGradInput = {gradInput[1], gradInput[2]}
      local cudaInput = {input[1]:cuda(), input[2]:cuda()}
      local cudaTarget = target:cuda()
      rf:cuda()
      vrc:cuda()

      local cudaErr = vrc:forward(cudaInput, cudaTarget)
      local cudaGradInput = vrc:backward(cudaInput, cudaTarget)

      mytester:assert(math.abs(err - cudaErr) < eps, " VRClassReward forward cuda err")
      mytester:assertTensorEq(floatGradInput[2], cudaGradInput[2]:float(), eps, " VRClassReward backward baseline cuda err")
      mytester:assertTensorEq(floatGradInput[1], cudaGradInput[1]:float(), eps, " VRClassReward backward class cuda err")
   end
end
823
838
824
839
function dpnntest .BinaryClassReward ()
0 commit comments