@@ -69,7 +69,7 @@ function Module:sharedClone(shareParams, shareGradParams, stepClone)
                if param then
                   params[paramName] = param
                   obj[paramName] = nil
-                  if param:storage() then
+                  if torch.isTensor(param) and param.storage and param:storage() then
                      pointers[torch.pointer(param:storage():data())] = true
                   end
                end
@@ -82,7 +82,7 @@ function Module:sharedClone(shareParams, shareGradParams, stepClone)
                if gradParam then
                   params[paramName] = gradParam
                   obj[paramName] = nil
-                  if gradParam:storage() then
+                  if torch.isTensor(gradParam) and gradParam.storage and gradParam:storage() then
                      pointers[torch.pointer(gradParam:storage():data())] = true
                   end
                end
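
Why the guard needs all three clauses: torch.MultiCudaTensor (dpnn's weight slice spread over two GPUs) still counts as a tensor for torch.isTensor, since torch7 matches the typename against the 'torch.*Tensor' pattern, but it defines no storage method, so it is the param.storage field test that prevents the old method-call error. A rough sketch of the behaviour, using a hypothetical registered class as a stand-in because constructing a real MultiCudaTensor needs two GPUs:

   require 'torch'

   -- Hypothetical stand-in: its typename matches the 'torch.*Tensor'
   -- pattern that torch.isTensor checks, but it has no storage() method,
   -- just like torch.MultiCudaTensor.
   torch.class('torch.FakeMultiTensor')
   local fake = torch.FakeMultiTensor()

   print(torch.isTensor(fake)) --> true: the typename pattern matches
   print(fake.storage)         --> nil: this clause short-circuits the guard
   -- fake:storage() would raise an error; the guard never evaluates it.
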
@@ -144,8 +144,13 @@ function Module:sharedClone(shareParams, shareGradParams, stepClone)
             clone[k] = param
             original[k] = param
          elseif torch.isTensor(param) then
-            clone[k] = param.new():set(param)
-            original[k] = param
+            if param.storage then
+               clone[k] = param.new():set(param)
+               original[k] = param
+            else -- for torch.MultiCudaTensor
+               clone[k] = param
+               original[k] = param
+            end
          elseif type(param) == 'table' then
             recursiveSet(clone[k], original[k], param)
          end
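
For context, param.new():set(param) is the standard torch idiom for making a second tensor header over the same storage: the clone shares the original's weights without copying any data. A storage-less type like torch.MultiCudaTensor cannot be re-headed this way, so the else branch shares it by plain reference instead. A quick illustration with ordinary tensors:

   require 'torch'

   local param = torch.Tensor(3):fill(1)
   local shared = param.new():set(param) -- new header, same storage
   shared[1] = 5
   print(param[1])                       --> 5: both headers see the write
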
@@ -397,7 +402,7 @@ function Module:gradParamClip(cutoffNorm, moduleLocal)
    local norm = 0
    if moduleLocal and self.modules then
       for i,module in ipairs(self.modules) do
-         norm = norm + math.pow(module:gradParamClip(maxOutNorm, maxInNorm), 2)
+         norm = norm + math.pow(module:gradParamClip(cutoffNorm, moduleLocal), 2)
       end
       norm = math.sqrt(norm)
    else
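
The replaced line looks like a copy-paste slip: maxOutNorm and maxInNorm are not in gradParamClip's scope (they are the parameters of dpnn's maxParamNorm), so the recursion was effectively forwarding nils; the fix passes cutoffNorm and moduleLocal as intended. The aggregation itself treats the container's gradient norm as the L2 norm over its children's norms, for example:

   -- Worked sketch with hypothetical child gradient norms of 3 and 4:
   -- the combined norm is sqrt(3^2 + 4^2) = 5.
   local norm = 0
   for _, n in ipairs({3, 4}) do
      norm = norm + math.pow(n, 2)
   end
   print(math.sqrt(norm)) --> 5
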
@@ -406,13 +411,25 @@ function Module:gradParamClip(cutoffNorm, moduleLocal)
          return norm
       end
       for k,gradParam in pairs(gradParams) do -- pairs for sparse params
-         norm = norm + math.pow(gradParam:norm(),2)
+         if torch.type(gradParam) == 'torch.CudaTensor' then
+            cutorch.withDevice(gradParam:getDevice(), function() -- support multi-device models
+               norm = norm + math.pow(gradParam:norm(),2)
+            end)
+         else
+            norm = norm + math.pow(gradParam:norm(),2)
+         end
       end
       norm = math.sqrt(norm)
       if norm > cutoffNorm then
          -- rescale gradParams to obtain desired cutoffNorm
          for k,gradParam in pairs(gradParams) do
-            gradParam:mul(cutoffNorm/norm)
+            if torch.type(gradParam) == 'torch.CudaTensor' then
+               cutorch.withDevice(gradParam:getDevice(), function() -- support multi-device models
+                  gradParam:mul(cutoffNorm/norm)
+               end)
+            else
+               gradParam:mul(cutoffNorm/norm)
+            end
          end
       end
    end
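
cutorch.withDevice(dev, fn) runs fn with GPU dev set as the current device and restores the previous device on exit, so the norm and mul calls above execute on whichever GPU actually holds each gradient; without it, operations on a tensor living on another GPU can fail or misbehave. A minimal sketch of the pattern, assuming cutorch is installed and at least two devices are available:

   require 'cutorch'

   local before = cutorch.getDevice()
   local t = cutorch.withDevice(2, function()
      return torch.CudaTensor(4):fill(1) -- allocated on device 2
   end)
   print(t:getDevice())                 --> 2
   print(cutorch.getDevice() == before) --> true: device was restored
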
@@ -455,7 +472,13 @@ function Module:momentumGradParameters()
       end
       self.momGradParams = {}
       for i,gradParam in pairs(gradParams) do
-         self.momGradParams[i] = gradParam.new():resizeAs(gradParam):copy(gradParam)
+         if torch.type(gradParam) == 'torch.CudaTensor' then
+            cutorch.withDevice(gradParam:getDevice(), function() -- support multi-device models
+               self.momGradParams[i] = gradParam.new():resizeAs(gradParam):copy(gradParam)
+            end)
+         else
+            self.momGradParams[i] = gradParam.new():resizeAs(gradParam):copy(gradParam)
+         end
       end
    end
    return self.momGradParams
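
Unlike the new():set() sharing earlier, gradParam.new():resizeAs(gradParam):copy(gradParam) allocates an independent buffer, so the momentum state can be updated without touching the live gradients. The device wrapper matters here too: new() allocates on the current GPU, and withDevice pins that allocation to the device the source gradient lives on. Illustrated with ordinary tensors:

   require 'torch'

   local grad = torch.Tensor(3):fill(2)
   local mom = grad.new():resizeAs(grad):copy(grad) -- independent copy
   mom:mul(0.9)
   print(grad[1]) --> 2: the original gradient is untouched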