Skip to content

Commit ead91a2

Browse files
merged
1 parent b7f5362 commit ead91a2

File tree

2 files changed

+37
-12
lines changed

2 files changed

+37
-12
lines changed

VRClassReward.lua

+16-6
Original file line number | Diff line number | Diff line change
@@ -22,19 +22,27 @@ function VRClassReward:updateOutput(input, target)
2222
assert(torch.type(input) == 'table')
2323
local input = self:toBatch(input[1], 1)
2424
self._maxVal = self._maxVal or input.new()
25-
self._maxIdx = self._maxIdx or torch.type(input) == 'torch.CudaTensor' and input.new() or torch.LongTensor()
25+
self._maxIdx = self._maxIdx or torch.type(input) == 'torch.CudaTensor' and torch.CudaLongTensor() or torch.LongTensor()
2626

2727
-- max class value is class prediction
2828
self._maxVal:max(self._maxIdx, input, 2)
29-
if torch.type(self._maxIdx) ~= torch.type(target) then
30-
self._target = self._target or self._maxIdx.new()
29+
30+
-- reward = scale when correctly classified
31+
local maxIdx = self._maxIdx
32+
if torch.type(self._maxIdx) == 'torch.CudaLongTensor' then
33+
self.__maxIdx = self.__maxIdx or torch.CudaTensor()
34+
self.__maxIdx:resize(maxIdx:size()):copy(maxIdx)
35+
maxIdx = self.__maxIdx
36+
end
37+
38+
if torch.type(maxIdx) ~= torch.type(target) then
39+
self._target = self._target or maxIdx.new()
3140
self._target:resize(target:size()):copy(target)
3241
target = self._target
3342
end
3443

35-
-- reward = scale when correctly classified
36-
self._reward = self._maxIdx.new()
37-
self._reward:eq(self._maxIdx, target)
44+
self._reward = self._reward or maxIdx.new()
45+
self._reward:eq(maxIdx, target)
3846
self.reward = self.reward or input.new()
3947
self.reward:resize(self._reward:size(1)):copy(self._reward)
4048
self.reward:mul(self.scale)
@@ -66,6 +74,7 @@ function VRClassReward:updateGradInput(inputTable, target)
6674
self.gradInput[1] = self:fromBatch(self.gradInput[1], 1)
6775

6876
-- learn the baseline reward
77+
self.criterion:forward(baseline, self.reward)
6978
self.gradInput[2] = self.criterion:backward(baseline, self.reward)
7079
self.gradInput[2] = self:fromBatch(self.gradInput[2], 1)
7180
return self.gradInput
@@ -74,6 +83,7 @@ end
7483
function VRClassReward:type(type)
7584
self._maxVal = nil
7685
self._maxIdx = nil
86+
self.__maxIdx = nil
7787
self._target = nil
7888
local module = self.module
7989
self.module = nil

test/test.lua

+21-6
Original file line number | Diff line number | Diff line change
@@ -806,19 +806,34 @@ function dpnntest.ReinforceCategorical()
806806
end
807807

808808
function dpnntest.VRClassReward()
809-
local input = {torch.randn(13,10), torch.randn(13,1)}
809+
local input = {torch.randn(13,10):float(), torch.randn(13,1):float()}
810810
local target = torch.IntTensor(13):random(1,10)
811-
local rf = nn.Reinforce()
812-
local vrc = nn.VRClassReward(rf)
811+
local rf = nn.Reinforce():float()
812+
local vrc = nn.VRClassReward(rf):float()
813813
local err = vrc:forward(input, target)
814814
local gradInput = vrc:backward(input, target)
815815
local val, idx = input[1]:max(2)
816-
local reward = torch.eq(idx:select(2,1):int(), target):double()
816+
local reward = torch.eq(idx:select(2,1):int(), target):float()
817817
local err2 = -reward:mean()
818818
mytester:assert(err == err2, "VRClassReward forward err")
819-
local gradInput2 = nn.MSECriterion():backward(input[2], reward)
819+
local gradInput2 = nn.MSECriterion():float():backward(input[2], reward)
820820
mytester:assertTensorEq(gradInput[2], gradInput2, 0.000001, "VRClassReward backward baseline err")
821-
mytester:assertTensorEq(gradInput[1], input[1]:zero(), 0.000001, "VRClassReward backward class err")
821+
mytester:assert(math.abs(gradInput[1]:sum()) < 0.000001, "VRClassReward backward class err")
822+
823+
if pcall(function() require 'cunn' end) then
824+
local gradInput = {gradInput[1], gradInput[2]}
825+
input[1], input[2] = input[1]:cuda(), input[2]:cuda()
826+
target = target:cuda()
827+
rf:cuda()
828+
vrc:cuda()
829+
830+
local err2 = vrc:forward(input, target)
831+
local gradInput2 = vrc:backward(input, target)
832+
833+
mytester:assert(math.abs(err - err2) < 0.000001, "VRClassReward forward cuda err")
834+
mytester:assertTensorEq(gradInput[2], gradInput2[2]:float(), 0.000001, "VRClassReward backward baseline cuda err")
835+
mytester:assertTensorEq(gradInput[1], gradInput2[1]:float(), 0.000001, "VRClassReward backward class cuda err")
836+
end
822837
end
823838

824839
function dpnntest.Clip()

0 commit comments

Comments (0)