Skip to content

Commit 5b2eb7f

Browse files
committed
allow user to specify a clone function when using sharedClone
1 parent ead91a2 commit 5b2eb7f

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

Module.lua

+12-2
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ function Module:sharedClone(shareParams, shareGradParams, stepClone)
5959
moduleTree = obj
6060
obj = nil
6161
isTable = false
62+
elseif obj.dpnn_sharedClone then
63+
-- allow to use a custom sharedClone method on one module
64+
moduleTree = obj
65+
obj = nil
66+
isTable = false
6267
elseif scdone[torch.pointer(obj)] then
6368
moduleTree = scdone[torch.pointer(obj)]
6469
else
@@ -142,8 +147,13 @@ function Module:sharedClone(shareParams, shareGradParams, stepClone)
142147
if scdone[torch.pointer(original)] then
143148
for k,param in pairs(moduleTree) do
144149
if torch.isTypeOf(param,'nn.Module') then
145-
-- AbstractRecurrent instances branch here with stepClone = true
146-
clone[k] = param
150+
if param.dpnn_sharedClone then
151+
-- Call the custom sharedClone
152+
clone[k] = param:dpnn_sharedClone()
153+
else
154+
-- AbstractRecurrent instances branch here with stepClone = true
155+
clone[k] = param
156+
end
147157
original[k] = param
148158
elseif torch.isTensor(param) then
149159
if param.storage then

0 commit comments

Comments
 (0)