
Commit c97f143

fix merge

2 parents: 2267dc7 + 5b2eb7f

File tree

3 files changed: +49 -13 lines changed


Module.lua

+12-2
@@ -59,6 +59,11 @@ function Module:sharedClone(shareParams, shareGradParams, stepClone)
          moduleTree = obj
          obj = nil
          isTable = false
+      elseif obj.dpnn_sharedClone then
+         -- allow to use a custom sharedClone method on one module
+         moduleTree = obj
+         obj = nil
+         isTable = false
       elseif scdone[torch.pointer(obj)] then
          moduleTree = scdone[torch.pointer(obj)]
       else
@@ -142,8 +147,13 @@ function Module:sharedClone(shareParams, shareGradParams, stepClone)
    if scdone[torch.pointer(original)] then
       for k,param in pairs(moduleTree) do
          if torch.isTypeOf(param,'nn.Module') then
-            -- AbstractRecurrent instances branch here with stepClone = true
-            clone[k] = param
+            if param.dpnn_sharedClone then
+               -- Call the custom sharedClone
+               clone[k] = param:dpnn_sharedClone()
+            else
+               -- AbstractRecurrent instances branch here with stepClone = true
+               clone[k] = param
+            end
             original[k] = param
          elseif torch.isTensor(param) then
             if param.storage then
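
Together, these two hunks give sharedClone an escape hatch: a module that defines a dpnn_sharedClone method is no longer decomposed field by field, and the clone step delegates to that method instead. A minimal sketch of a module opting into the hook (nn.MySharedLinear and its body are hypothetical, shown only to illustrate the expected contract of returning a clone whose parameters share storage with the original):

require 'dpnn'

-- hypothetical module illustrating the dpnn_sharedClone contract
local MySharedLinear, parent = torch.class('nn.MySharedLinear', 'nn.Linear')

function MySharedLinear:dpnn_sharedClone()
   -- build a regular clone, then re-point its parameters (and their
   -- gradients) at the original module's storages
   local clone = self:clone()
   clone:share(self, 'weight', 'bias', 'gradWeight', 'gradBias')
   return clone
end

With the hook defined, calling sharedClone on a container holding an nn.MySharedLinear invokes this method for that child instead of walking its tensors.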

VRClassReward.lua

+16-5
@@ -22,19 +22,28 @@ function VRClassReward:updateOutput(input, target)
    assert(torch.type(input) == 'table')
    local input = self:toBatch(input[1], 1)
    self._maxVal = self._maxVal or input.new()
-   self._maxIdx = self._maxIdx or torch.type(input) == 'torch.CudaTensor' and input.new() or torch.LongTensor()
+   self._maxIdx = self._maxIdx or torch.type(input) == 'torch.CudaTensor' and torch.CudaLongTensor() or torch.LongTensor()
 
    -- max class value is class prediction
    self._maxVal:max(self._maxIdx, input, 2)
-   if torch.type(self._maxIdx) ~= torch.type(target) then
-      self._target = self._target or self._maxIdx.new()
+
+   -- reward = scale when correctly classified
+   local maxIdx = self._maxIdx
+   if torch.type(self._maxIdx) == 'torch.CudaLongTensor' then
+      self.__maxIdx = self.__maxIdx or torch.CudaTensor()
+      self.__maxIdx:resize(maxIdx:size()):copy(maxIdx)
+      maxIdx = self.__maxIdx
+   end
+
+   if torch.type(maxIdx) ~= torch.type(target) then
+      self._target = self._target or maxIdx.new()
       self._target:resize(target:size()):copy(target)
       target = self._target
    end
 
    -- reward = scale when correctly classified
-   self._reward = self._reward or self._maxIdx.new()
-   self._reward:eq(self._maxIdx, target)
+   self._reward = self._reward or maxIdx.new()
+   self._reward:eq(maxIdx, target)
    self.reward = self.reward or input.new()
    self.reward:resize(self._reward:size(1)):copy(self._reward)
    self.reward:mul(self.scale)
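
The substance of this hunk is a type fix: as the removed line implies, recent cutorch writes the argmax indices of a CudaTensor max into a torch.CudaLongTensor rather than a CudaTensor, and :eq() cannot compare long indices against the differently typed target directly, so the new __maxIdx buffer copy-casts the indices first. A minimal standalone sketch of that cast, assuming cutorch is installed (tensor names are illustrative):

require 'cutorch'

local input = torch.CudaTensor(4, 10):uniform()
local maxVal = torch.CudaTensor()
local maxIdx = torch.CudaLongTensor()
maxVal:max(maxIdx, input, 2)  -- per-row argmax indices come back as longs

-- copy-cast the indices into a CudaTensor so :eq() can compare them
-- elementwise against a CudaTensor target of the same type
local buffer = torch.CudaTensor():resize(maxIdx:size()):copy(maxIdx)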
@@ -66,6 +75,7 @@ function VRClassReward:updateGradInput(inputTable, target)
    self.gradInput[1] = self:fromBatch(self.gradInput[1], 1)
 
    -- learn the baseline reward
+   self.criterion:forward(baseline, self.reward)
    self.gradInput[2] = self.criterion:backward(baseline, self.reward)
    self.gradInput[2] = self:fromBatch(self.gradInput[2], 1)
    return self.gradInput
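
The single added line fixes a call-order bug: the nn criterion contract expects forward before backward, since a criterion's output (and any state some criteria cache during forward) is only valid after forward has run. A sketch of the required order, using nn.MSECriterion, which the test below also uses as the baseline criterion:

require 'nn'

local criterion = nn.MSECriterion()
local baseline = torch.Tensor(13, 1):zero()  -- predicted baseline reward
local reward = torch.Tensor(13, 1):fill(1)   -- observed reward

local loss = criterion:forward(baseline, reward)          -- must run first
local gradBaseline = criterion:backward(baseline, reward) -- then the gradient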
@@ -74,6 +84,7 @@ end
 function VRClassReward:type(type)
    self._maxVal = nil
    self._maxIdx = nil
+   self.__maxIdx = nil
    self._target = nil
    local module = self.module
    self.module = nil
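
Clearing the new __maxIdx buffer here keeps :type() consistent with the module's other caches: every lazily allocated buffer is dropped on type conversion so the next forward pass recreates it with types matching the new input. Usage sketch, constructing the module as the test below does:

local vrc = nn.VRClassReward(nn.Reinforce())
vrc:float()  -- _maxVal, _maxIdx, __maxIdx and _target are reset to nil;
             -- the next updateOutput call reallocates them appropriately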

test/test.lua

+21-6
@@ -806,19 +806,34 @@ function dpnntest.ReinforceCategorical()
 end
 
 function dpnntest.VRClassReward()
-   local input = {torch.randn(13,10), torch.randn(13,1)}
+   local input = {torch.randn(13,10):float(), torch.randn(13,1):float()}
    local target = torch.IntTensor(13):random(1,10)
-   local rf = nn.Reinforce()
-   local vrc = nn.VRClassReward(rf)
+   local rf = nn.Reinforce():float()
+   local vrc = nn.VRClassReward(rf):float()
    local err = vrc:forward(input, target)
    local gradInput = vrc:backward(input, target)
    local val, idx = input[1]:max(2)
-   local reward = torch.eq(idx:select(2,1):int(), target):double()
+   local reward = torch.eq(idx:select(2,1):int(), target):float()
    local err2 = -reward:mean()
    mytester:assert(err == err2, "VRClassReward forward err")
-   local gradInput2 = nn.MSECriterion():backward(input[2], reward)
+   local gradInput2 = nn.MSECriterion():float():backward(input[2], reward)
    mytester:assertTensorEq(gradInput[2], gradInput2, 0.000001, "VRClassReward backward baseline err")
-   mytester:assertTensorEq(gradInput[1], input[1]:zero(), 0.000001, "VRClassReward backward class err")
+   mytester:assert(math.abs(gradInput[1]:sum()) < 0.000001, "VRClassReward backward class err")
+
+   if pcall(function() require 'cunn' end) then
+      local gradInput = {gradInput[1], gradInput[2]}
+      input[1], input[2] = input[1]:cuda(), input[2]:cuda()
+      target = target:cuda()
+      rf:cuda()
+      vrc:cuda()
+
+      local err2 = vrc:forward(input, target)
+      local gradInput2 = vrc:backward(input, target)
+
+      mytester:assert(math.abs(err - err2) < 0.000001, "VRClassReward forward cuda err")
+      mytester:assertTensorEq(gradInput[2], gradInput2[2]:float(), 0.000001, "VRClassReward backward baseline cuda err")
+      mytester:assertTensorEq(gradInput[1], gradInput2[1]:float(), 0.000001, "VRClassReward backward class cuda err")
+   end
 end
 
 function dpnntest.BinaryClassReward()
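
The updated unit now runs on FloatTensors and, when cunn is available, repeats the forward/backward pass on CUDA and checks the results against the float run. For reference, mytester follows the standard torch.Tester pattern; a sketch of running a similar check standalone (dpnnTester and the wrapper function are hypothetical, and dpnn's own harness may expose a different entry point):

require 'dpnn'

local dpnnTester = torch.Tester()
local suite = torch.TestSuite()

-- hypothetical wrapper: smoke-test forward/backward like the unit above
function suite.VRClassReward()
   local input = {torch.randn(13,10):float(), torch.randn(13,1):float()}
   local target = torch.IntTensor(13):random(1,10)
   local vrc = nn.VRClassReward(nn.Reinforce():float()):float()
   local err = vrc:forward(input, target)
   local gradInput = vrc:backward(input, target)
   dpnnTester:assert(torch.type(err) == 'number', "forward should return a loss")
   dpnnTester:assert(#gradInput == 2, "backward should return two gradients")
end

dpnnTester:add(suite)
dpnnTester:run()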
