Commit 62696d9
NCE:multicuda()
1 parent f46acf0 commit 62696d9

3 files changed (+164, -13 lines)

Module.lua (+2, -1)

@@ -484,7 +484,8 @@ function Module:updateGradParameters(momFactor, momDamp, momNesterov)
    end
    local momGradParams = self:momentumGradParameters()
    for i,gradParam in pairs(gradParams) do
-      momGradParams[i]:mul(momFactor):add(1-momDamp, gradParam)
+      momGradParams[i]:mul(momFactor)
+      momGradParams[i]:add(1-momDamp, gradParam)
    end
 
    if momNesterov then
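The Module.lua change splits the chained `mul():add()` momentum update into two separate statements, presumably so the same code path also works when a momentum buffer is a `torch.MultiCudaTensor`, whose in-place methods may not support method chaining. A minimal sketch of the two equivalent forms on a plain tensor (the values are illustrative only):

-- both forms compute momGradParam = momFactor*momGradParam + (1-momDamp)*gradParam
local momFactor, momDamp = 0.9, 0.9
local momGradParam = torch.Tensor(3):fill(1)   -- stands in for momGradParams[i]
local gradParam = torch.Tensor(3):fill(0.5)

-- old form: relies on :mul() returning the tensor so :add() can be chained
-- momGradParam:mul(momFactor):add(1-momDamp, gradParam)

-- new form: the same update as two separate in-place calls
momGradParam:mul(momFactor)
momGradParam:add(1-momDamp, gradParam)
print(momGradParam)   -- each element: 1*0.9 + (1-0.9)*0.5 = 0.95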

NCEModule.lua (+44, -6)

@@ -3,7 +3,7 @@
 -- Ref.: A. https://www.cs.toronto.edu/~amnih/papers/ncelm.pdf
 ------------------------------------------------------------------------
 local NCEModule, parent = torch.class("nn.NCEModule", "nn.Linear")
-NCEModule.version = 3
+NCEModule.version = 4 -- added multicuda()
 
 -- for efficient serialization
 local empty = _.clone(parent.dpnn_mediumEmpty)
@@ -62,7 +62,7 @@ function NCEModule:updateOutput(inputTable)
       if self.addBuffer:nElement() ~= batchsize then
         self.addBuffer:resize(batchsize):fill(1)
       end
-      self.linout:addmm(0, self.linout, 1, input, self.weight:t())
+      self.weight.addmm(self.linout, 0, self.linout, 1, input, self.weight:t())
       if self.bias then self.linout:addr(1, self.addBuffer, self.bias) end
       self.output = torch.type(self.output) == 'table' and input.new() or self.output
       if self.logsoftmax then
@@ -102,8 +102,8 @@ function NCEModule:updateOutput(inputTable)
       end
 
       -- build (batchsize x k+1 x inputsize) weight tensor
-      self._weight = self._weight or self.weight.new()
-      self._weight:index(self.weight, 1, self.sampleidx:view(-1))
+      self._weight = self._weight or self.bias.new()
+      self.weight.index(self._weight, self.weight, 1, self.sampleidx:view(-1))
       assert(self._weight:nElement() == batchsize*(self.k+1)*inputsize)
       self._weight:resize(batchsize, self.k+1, inputsize)
 
@@ -190,7 +190,7 @@ function NCEModule:accGradParameters(inputTable, gradOutput, scale)
    local batchsize = input:size(1)
    local inputsize = self.weight:size(2)
 
-   self._gradWeight = self._gradWeight or self.gradWeight.new()
+   self._gradWeight = self._gradWeight or self.bias.new()
    self._gradWeight:resizeAs(self._weight):zero() -- batchsize x k+1 x inputsize
    self._gradOutput:resize(batchsize, self.k+1, 1)
    self._gradOutput:mul(scale)
@@ -216,7 +216,23 @@ function NCEModule:type(type, cache)
    local unigrams = self.unigrams
    self.unigrams = nil
    local am = self.aliasmultinomial
-   local rtn = parent.type(self, type, cache)
+
+   local rtn
+   if type and torch.type(self.weight) == 'torch.MultiCudaTensor' then
+      assert(type == 'torch.CudaTensor', "Cannot convert a multicuda NCEModule to anything other than cuda")
+      local weight = self.weight
+      local gradWeight = self.gradWeight
+      self.weight = nil
+      self.gradWeight = nil
+
+      rtn = parent.type(self, type, cache)
+
+      self.weight = weight
+      self.gradWeight = gradWeight
+   else
+      rtn = parent.type(self, type, cache)
+   end
+
    self.unigrams = unigrams
    self.aliasmultinomial = am
    return rtn
@@ -264,3 +280,25 @@ function NCEModule:clearState()
       gradInput:set()
    end
 end
+
+function NCEModule:multicuda(device1, device2)
+   assert(device1 and device2, "specify two devices as arguments")
+   require 'torchx'
+   assert(torchx.version and torchx.version >= 1, "update torchx: luarocks install torchx")
+
+   self:float()
+
+   local isize = self.weight:size(2)
+   local weights = {
+      cutorch.withDevice(device1, function() return self.weight[{{}, {1, torch.round(isize/2)}}]:cuda() end),
+      cutorch.withDevice(device2, function() return self.weight[{{}, {torch.round(isize/2)+1, isize}}]:cuda() end)
+   }
+   self.weight = torch.MultiCudaTensor(2, weights)
+   local gradWeights = {
+      cutorch.withDevice(device1, function() return self.gradWeight[{{}, {1, torch.round(isize/2)}}]:cuda() end),
+      cutorch.withDevice(device2, function() return self.gradWeight[{{}, {torch.round(isize/2)+1, isize}}]:cuda() end)
+   }
+   self.gradWeight = torch.MultiCudaTensor(2, gradWeights)
+
+   self:cuda()
+end
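The new NCEModule:multicuda(device1, device2) splits the output-layer weight and gradWeight column-wise (along the inputsize dimension) across two GPUs and wraps the two halves in a torch.MultiCudaTensor; the rest of the module is converted with self:cuda(). Based on this method and the tests below, usage would look roughly like the following sketch (the sizes and the device ids 1 and 2 are assumptions for illustration):

require 'dpnn'
require 'cunn'
require 'torchx'  -- provides torch.MultiCudaTensor

local hiddensize, nclass, k = 200, 1000000, 25   -- illustrative sizes
local unigrams = torch.Tensor(nclass):uniform(0, 1)
local nce = nn.NCEModule(hiddensize, nclass, k, unigrams)

-- distribute the (nclass x hiddensize) weight and gradWeight over GPUs 1 and 2
nce:multicuda(1, 2)

local batchsize = 50
local input = torch.randn(batchsize, hiddensize):cuda()
local target = torch.LongTensor(batchsize):random(1, nclass):cuda()
local output = nce:forward{input, target}   -- a table of four tensors, as asserted in dpnntest.NCE_multicuda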

test/test.lua (+118, -6)

@@ -2199,7 +2199,7 @@ function dpnntest.OneHot()
    end
 end
 
-function dpnntest.NCE()
+function dpnntest.NCE_main()
    local batchsize = 4
    local k = 10
    local inputsize = 3
@@ -2353,7 +2353,7 @@ function dpnntest.NCE()
    local linear = nn.Linear(inputsize, outputsize)
    linear.weight:copy(ncem.weight)
    linear.bias:copy(ncem.bias)
-   local mlp = nn.Sequential():add(linear):add(nn.Exp())
+   local mlp = nn.Sequential():add(linear):add(nn.Exp()):add(nn.MulConstant(1/ncem.Z[1]))
    mlp:cuda()
 
    local output2_ = mlp:forward(input)
@@ -2455,6 +2455,8 @@ function dpnntest.NCE_multinomial()
 end
 
 function dpnnbigtest.NCE_benchmark()
+   pcall(function() require 'cunn' end) -- make sure to import cunn before initializing large tensors, else weird segfault...
+
    local nclass = 1000000
    local hiddensize = 200
    local batchsize = 50
@@ -2483,8 +2485,6 @@ function dpnnbigtest.NCE_benchmark()
       sync = function() cutorch.synchronize() end
    end
 
-   print(torch.type(nce.unigrams))
-
    local output = nce:forward{input, target}
   local loss = crit:forward(output, target)
   local gradOutput = crit:backward(output, target)
@@ -2494,8 +2494,8 @@ function dpnnbigtest.NCE_benchmark()
    local loss = nll:forward(output, target)
    local gradOutput = nll:backward(output, target)
    local gradInput = mlp:backward(input, gradOutput)
-   sync()
 
+   sync()
    local a = torch.Timer()
    for i=1,nloop do
       output = nce:forward{input, target}
@@ -2525,7 +2525,6 @@ function dpnnbigtest.NCE_benchmark()
    local ncebwd = a:time().real
 
    -- mlp nll
-
    local a = torch.Timer()
    for i=1,nloop do
       output = mlp:forward(input)
@@ -2561,6 +2560,38 @@ function dpnnbigtest.NCE_benchmark()
    print("criterion:backward (nce vs nll)", critbwd, nllbwd)
    print("module:backward (nce vs linear)", ncebwd, mlpbwd)
    print("total (nce vs linear)", ncetotal, lintotal, lintotal/ncetotal)
+
+   if not (cunn and cutorch.getDeviceCount() > 1) then
+      return
+   end
+
+   nce:multicuda(1,2)
+
+   local output = nce:forward{input, target}
+   local loss = crit:forward(output, target)
+   local gradOutput = crit:backward(output, target)
+   local gradInput = nce:backward({input, target}, gradOutput)
+   sync()
+
+   local a = torch.Timer()
+   for i=1,nloop do
+      output = nce:forward{input, target}
+   end
+   sync()
+   local ncefwd2 = a:time().real
+
+   a:reset()
+   for i=1,nloop do
+      gradInput = nce:backward({input, target}, gradOutput)
+   end
+   sync()
+   local ncebwd2 = a:time().real
+
+   local total1 = ncefwd+ncebwd
+   local total2 = ncefwd2+ncebwd2
+   print("module:forward (1 vs 2 gpu)", ncefwd, ncefwd2)
+   print("module:backward (1 vs 2 gpu)", ncebwd, ncebwd2)
+   print("total (1 vs 2 gpu)", total1, total2, total2/total1)
 end
 
 function dpnntest.NaN()
@@ -2599,6 +2630,87 @@ function dpnntest.NaN()
    mytester:assert(not pcall(function() nan:backward(input, gradOutput) end))
 end
 
+function dpnntest.NCE_multicuda()
+   if not pcall(function() require 'torchx' end) then
+      return
+   end
+   if not pcall(function() require 'cunn' end) then
+      return
+   end
+   if cutorch.getDeviceCount() < 2 then
+      return
+   end
+   assert(torchx.version and torchx.version >= 1, "Update torchx")
+
+   local nclass = 1000
+   local hiddensize = 20
+   local batchsize = 5
+   local k = 25
+   local unigrams = torch.Tensor(nclass):uniform(0,1)
+   local noise = torch.LongTensor(batchsize, k):random(1,nclass)
+
+   local crit = nn.NCECriterion():cuda()
+   local crit2 = nn.NCECriterion():cuda()
+
+   local nce = nn.NCEModule(hiddensize, nclass, k, unigrams)
+
+   -- make it deterministic
+   nce.noiseSample = function(self, sampleidx, batchsize, k)
+      sampleidx:resize(batchsize, k)
+      sampleidx:copy(noise)
+      return sampleidx
+   end
+
+   local nce2 = nce:clone()
+   nce2:cuda()
+
+   local input = torch.randn(batchsize, hiddensize):cuda()
+   local target = torch.LongTensor(batchsize):random(1,nclass):cuda()
+
+   nce:multicuda(1, 2)
+
+   local output = nce:forward{input, target}
+   local loss = crit:forward(output, target)
+   local gradOutput = crit:backward(output, target)
+   nce:zeroGradParameters()
+   local gradInput = nce:backward({input, target}, gradOutput)
+
+   local output2 = nce2:forward{input, target}
+   local loss2 = crit2:forward(output2, target)
+   local gradOutput2 = crit2:backward(output2, target)
+   nce2:zeroGradParameters()
+   local gradInput2 = nce2:backward({input, target}, gradOutput2)
+
+   mytester:assertTensorEq(output[1], output2[1], 0.00001)
+   mytester:assertTensorEq(output[2], output2[2], 0.00001)
+   mytester:assertTensorEq(output[3], output2[3], 0.00001)
+   mytester:assertTensorEq(output[4], output2[4], 0.00001)
+
+   mytester:assertTensorEq(gradInput[1], gradInput2[1], 0.00001)
+   mytester:assertTensorEq(gradInput[2], gradInput2[2], 0.00001)
+
+   nce2:updateParameters(0.1)
+   nce:updateParameters(0.1)
+
+   mytester:assertTensorEq(nce2.bias, nce.bias, 0.000001)
+   mytester:assertTensorEq(nce2.gradBias, nce.gradBias, 0.000001)
+   mytester:assertTensorEq(nce2.weight[{{},{1,hiddensize/2}}]:float(), nce.weight.tensors[1]:float(), 0.000001)
+   mytester:assertTensorEq(nce2.weight[{{},{1+(hiddensize/2), hiddensize}}]:float(), nce.weight.tensors[2]:float(), 0.000001)
+   mytester:assertTensorEq(nce2.gradWeight[{{},{1,hiddensize/2}}]:float(), nce.gradWeight.tensors[1]:float(), 0.000001)
+   mytester:assertTensorEq(nce2.gradWeight[{{},{1+(hiddensize/2), hiddensize}}]:float(), nce.gradWeight.tensors[2]:float(), 0.000001)
+
+   -- test momentum
+   nce2:updateGradParameters(0.9)
+   nce:updateGradParameters(0.9)
+
+   mytester:assertTensorEq(nce2.gradBias, nce.gradBias, 0.000001)
+   mytester:assertTensorEq(nce2.momGradParams[1][{{},{1,hiddensize/2}}]:float(), nce.momGradParams[1].tensors[1]:float(), 0.000001)
+   mytester:assertTensorEq(nce2.momGradParams[1][{{},{1+(hiddensize/2), hiddensize}}]:float(), nce.momGradParams[1].tensors[2]:float(), 0.000001)
+   mytester:assertTensorEq(nce2.gradWeight[{{},{1,hiddensize/2}}]:float(), nce.gradWeight.tensors[1]:float(), 0.000001)
+   mytester:assertTensorEq(nce2.gradWeight[{{},{1+(hiddensize/2), hiddensize}}]:float(), nce.gradWeight.tensors[2]:float(), 0.000001)
+end
+
 function dpnn.test(tests)
    mytester = torch.Tester()
    mytester:add(dpnntest)
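Assuming the usual dpnn test entry point (dpnn.test accepts an optional list of test names that it passes on to torch.Tester:run), the new unit test could be run on its own with something like:

-- a sketch, assuming torchx >= 1, cunn and at least two CUDA devices are installed;
-- the test returns early (silently passes) otherwise
local dpnn = require 'dpnn'
dpnn.test{'NCE_multicuda'}

If the corresponding dpnn.bigtest entry point is present, the two-GPU timings are printed by dpnn.bigtest{'NCE_benchmark'}, which skips the multicuda portion when fewer than two devices are visible.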
