Commit fc5c20c

multigpu runs
1 parent 62696d9 commit fc5c20c

4 files changed: +82 -10 lines

Container.lua (+22)
@@ -27,3 +27,25 @@ function Container:sparseParameters()
    end
    return params, gradParams, scales, size
 end
+
+function Container:parameters()
+   local function tinsert(to, from)
+      if torch.type(from) == 'table' then -- we change this line so that it works with torch.MultiCudaTensor
+         for i=1,#from do
+            tinsert(to,from[i])
+         end
+      else
+         table.insert(to,from)
+      end
+   end
+   local w = {}
+   local gw = {}
+   for i=1,#self.modules do
+      local mw,mgw = self.modules[i]:parameters()
+      if mw then
+         tinsert(w,mw)
+         tinsert(gw,mgw)
+      end
+   end
+   return w,gw
+end
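
Note: the torch.type() check in tinsert (rather than Lua's stock type()) is what keeps the flattener from descending into a torch.MultiCudaTensor, which is presumably a table-backed torch class: type() would report 'table' and the loop would try to iterate the tensor's internals, while torch.type() reports the registered class name, so the tensor is inserted whole. A minimal sketch of the distinction, using torch.MyTensorish as a hypothetical stand-in class:

    require 'torch'

    -- Hypothetical stand-in for a table-backed class like torch.MultiCudaTensor.
    local MyTensorish = torch.class('torch.MyTensorish')

    local obj = torch.MyTensorish()
    print(type(obj))       -- 'table': stock type() only sees the backing table
    print(torch.type(obj)) -- 'torch.MyTensorish': the registered class name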

GPU.lua (+26)
@@ -80,3 +80,29 @@ function GPU:fromBatch(...)
       return parent.fromBatch(self, unpack(args))
    end
 end
+
+-- set the device of the decorated module
+function GPU:setDevice(device)
+   self.device = device or self.device
+
+   local function recursiveModuleDevice(obj)
+      if type(obj) == 'table' and not (torch.isTypeOf(obj, 'nn.GPU') or torch.type(obj) == 'torch.MultiCudaTensor') then
+         for k,v in pairs(obj) do
+            obj[k] = recursiveModuleDevice(v)
+         end
+      elseif torch.type(obj):match('torch.Cuda.*Tensor') then
+         if obj:getDevice() ~= self.device then
+            obj = obj:clone() -- this will reallocate it to self.device
+            local newdevice = obj:getDevice()
+            -- when nElement() == 0 newdevice is 0
+            assert(newdevice == self.device or newdevice == 0)
+         end
+      end
+      assert(obj ~= nil)
+      return obj
+   end
+
+   assert(self.modules[1])
+   self.modules[1] = cutorch.withDevice(self.device, function() return recursiveModuleDevice(self.modules[1]) end)
+   return self
+end
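
Note: GPU:setDevice recursively walks the decorated module and clones any CUDA tensor that sits on the wrong device; because the clone happens inside cutorch.withDevice(self.device, ...), the new allocation lands on the target device. Nested nn.GPU instances and torch.MultiCudaTensor values are left alone, since they manage their own placement. A hedged usage sketch, assuming cunn is installed and at least two GPUs are present:

    require 'cunn'

    local m = nn.GPU(nn.Linear(10, 10), 1):cuda() -- run the Linear on device 1
    m:setDevice(2) -- reallocate the decorated module's tensors onto device 2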

Module.lua (+31 -8)
@@ -69,7 +69,7 @@ function Module:sharedClone(shareParams, shareGradParams, stepClone)
          if param then
             params[paramName] = param
             obj[paramName] = nil
-            if param:storage() then
+            if torch.isTensor(param) and param.storage and param:storage() then
                pointers[torch.pointer(param:storage():data())] = true
             end
          end
@@ -82,7 +82,7 @@ function Module:sharedClone(shareParams, shareGradParams, stepClone)
          if gradParam then
             params[paramName] = gradParam
             obj[paramName] = nil
-            if gradParam:storage() then
+            if torch.isTensor(gradParam) and gradParam.storage and gradParam:storage() then
                pointers[torch.pointer(gradParam:storage():data())] = true
             end
          end
@@ -144,8 +144,13 @@ function Module:sharedClone(shareParams, shareGradParams, stepClone)
             clone[k] = param
             original[k] = param
          elseif torch.isTensor(param) then
-            clone[k] = param.new():set(param)
-            original[k] = param
+            if param.storage then
+               clone[k] = param.new():set(param)
+               original[k] = param
+            else -- for torch.MultiCudaTensor
+               clone[k] = param
+               original[k] = param
+            end
          elseif type(param) == 'table' then
             recursiveSet(clone[k], original[k], param)
          end
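
Note: all three sharedClone hunks guard the same assumption, namely that a parameter is a plain tensor with a single :storage(); a torch.MultiCudaTensor presumably spreads its data over several devices and has no single backing storage, so such parameters are shared by reference instead. The guard, pulled out as a hypothetical helper (storageOf is not part of the codebase):

    -- Hypothetical helper mirroring the guard above: only dereference
    -- :storage() on objects that are real tensors and implement it.
    local function storageOf(param)
       if torch.isTensor(param) and param.storage and param:storage() then
          return param:storage()
       end
       return nil -- e.g. torch.MultiCudaTensor: no single backing storage
    end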
@@ -397,7 +402,7 @@ function Module:gradParamClip(cutoffNorm, moduleLocal)
    local norm = 0
    if moduleLocal and self.modules then
       for i,module in ipairs(self.modules) do
-         norm = norm + math.pow(module:gradParamClip(maxOutNorm, maxInNorm), 2)
+         norm = norm + math.pow(module:gradParamClip(cutoffNorm, moduleLocal), 2)
       end
       norm = math.sqrt(norm)
    else
@@ -406,13 +411,25 @@ function Module:gradParamClip(cutoffNorm, moduleLocal)
         return norm
      end
      for k,gradParam in pairs(gradParams) do -- pairs for sparse params
-        norm = norm + math.pow(gradParam:norm(),2)
+        if torch.type(gradParam) == 'torch.CudaTensor' then
+           cutorch.withDevice(gradParam:getDevice(), function() -- support multi-device models
+              norm = norm + math.pow(gradParam:norm(),2)
+           end)
+        else
+           norm = norm + math.pow(gradParam:norm(),2)
+        end
      end
      norm = math.sqrt(norm)
      if norm > cutoffNorm then
         -- rescale gradParams to obtain desired cutoffNorm
         for k,gradParam in pairs(gradParams) do
-           gradParam:mul(cutoffNorm/norm)
+           if torch.type(gradParam) == 'torch.CudaTensor' then
+              cutorch.withDevice(gradParam:getDevice(), function() -- support multi-device models
+                 gradParam:mul(cutoffNorm/norm)
+              end)
+           else
+              gradParam:mul(cutoffNorm/norm)
+           end
         end
      end
   end
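
Note: besides the multi-device handling, the first gradParamClip hunk fixes a genuine recursion bug: the old code forwarded maxOutNorm and maxInNorm (arguments of maxParamNorm, both nil here) instead of cutoffNorm and moduleLocal. A short usage sketch of the existing API, assuming dpnn is installed:

    require 'dpnn'

    local model = nn.Linear(10, 10)
    model:zeroGradParameters()
    model:backward(torch.randn(10), torch.randn(10))
    local norm = model:gradParamClip(2) -- returns the pre-clip gradient norm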
@@ -455,7 +472,13 @@ function Module:momentumGradParameters()
       end
       self.momGradParams = {}
       for i,gradParam in pairs(gradParams) do
-         self.momGradParams[i] = gradParam.new():resizeAs(gradParam):copy(gradParam)
+         if torch.type(gradParam) == 'torch.CudaTensor' then
+            cutorch.withDevice(gradParam:getDevice(), function() -- support multi-device models
+               self.momGradParams[i] = gradParam.new():resizeAs(gradParam):copy(gradParam)
+            end)
+         else
+            self.momGradParams[i] = gradParam.new():resizeAs(gradParam):copy(gradParam)
+         end
       end
    end
    return self.momGradParams
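
Note: gradParamClip and momentumGradParameters now wrap per-tensor work in cutorch.withDevice, which makes the given device current while the function runs, restores the previous device afterwards, and, as the GPU.lua hunk above relies on, returns the function's results. In isolation, assuming at least two devices:

    require 'cutorch'

    local t = cutorch.withDevice(2, function()
       return torch.CudaTensor(3):fill(1) -- allocated on device 2
    end)
    print(t:getDevice()) -- 2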

test/test.lua (+3 -2)
@@ -773,8 +773,9 @@ function dpnntest.ReinforceNormal()
 end
 
 function dpnntest.ReinforceGamma()
-   require 'randomkit'
-   require 'cephes'
+   if not pcall(function() require 'randomkit'; require 'cephes' end) then
+      return
+   end
    local input = torch.rand(500,1000):fill(250) -- shapes
    local gradOutput = torch.Tensor() -- will be ignored
    local reward = torch.randn(500)
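
Note: the test change makes randomkit and cephes optional; when either require throws, pcall returns false and ReinforceGamma skips silently instead of aborting the suite. The same pattern in isolation:

    -- Skip-if-missing pattern for an optional test dependency.
    if not pcall(require, 'randomkit') then return end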
