Skip to content

Commit a959ba3

Browse files
committed
fixing to be tensor type agnostic
1 parent 83d952d commit a959ba3

File tree

1 file changed

+14
-15
lines changed

1 file changed

+14
-15
lines changed

lswolfe.lua

+14-15
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ function optim.lswolfe(opfunc,x,t,d,f,g,gtd,options)
3434
local abs = torch.abs
3535
local min = math.min
3636
local max = math.max
37-
local Tensor = torch.Tensor
3837

3938
-- verbose function
4039
local function verbose(...)
@@ -56,25 +55,25 @@ function optim.lswolfe(opfunc,x,t,d,f,g,gtd,options)
5655
while LSiter < maxIter do
5756
-- check conditions:
5857
if (f_new > (f + c1*t*gtd)) or (LSiter > 1 and f_new >= f_prev) then
59-
bracket = Tensor{t_prev,t}
60-
bracketFval = Tensor{f_prev,f_new}
61-
bracketGval = Tensor(2,g_new:size(1))
58+
bracket = x.new{t_prev,t}
59+
bracketFval = x.new{f_prev,f_new}
60+
bracketGval = x.new(2,g_new:size(1))
6261
bracketGval[1] = g_prev
6362
bracketGval[2] = g_new
6463
break
6564

6665
elseif abs(gtd_new) <= -c2*gtd then
67-
bracket = Tensor{t}
68-
bracketFval = Tensor{f_new}
69-
bracketGval = Tensor(1,g_new:size(1))
66+
bracket = x.new{t}
67+
bracketFval = x.new{f_new}
68+
bracketGval = x.new(1,g_new:size(1))
7069
bracketGval[1] = g_new
7170
done = true
7271
break
7372

7473
elseif gtd_new >= 0 then
75-
bracket = Tensor{t_prev,t}
76-
bracketFval = Tensor{f_prev,f_new}
77-
bracketGval = Tensor(2,g_new:size(1))
74+
bracket = x.new{t_prev,t}
75+
bracketFval = x.new{f_prev,f_new}
76+
bracketGval = x.new(2,g_new:size(1))
7877
bracketGval[1] = g_prev
7978
bracketGval[2] = g_new
8079
break
@@ -86,7 +85,7 @@ function optim.lswolfe(opfunc,x,t,d,f,g,gtd,options)
8685
t_prev = t
8786
local minStep = t + 0.01*(t-tmp)
8887
local maxStep = t*10
89-
t = optim.polyinterp(Tensor{{tmp,f_prev,gtd_prev},
88+
t = optim.polyinterp(x.new{{tmp,f_prev,gtd_prev},
9089
{t,f_new,gtd_new}},
9190
minStep, maxStep)
9291

@@ -104,9 +103,9 @@ function optim.lswolfe(opfunc,x,t,d,f,g,gtd,options)
104103

105104
-- reached max nb of iterations?
106105
if LSiter == maxIter then
107-
bracket = Tensor{0,t}
108-
bracketFval = Tensor{f,f_new}
109-
bracketGval = Tensor(2,g_new:size(1))
106+
bracket = x.new{0,t}
107+
bracketFval = x.new{f,f_new}
108+
bracketGval = x.new(2,g_new:size(1))
110109
bracketGval[1] = g
111110
bracketGval[2] = g_new
112111
end
@@ -123,7 +122,7 @@ function optim.lswolfe(opfunc,x,t,d,f,g,gtd,options)
123122
local HIpos = -LOpos+3
124123

125124
-- compute new trial value
126-
t = optim.polyinterp(Tensor{{bracket[1],bracketFval[1],bracketGval[1]*d},
125+
t = optim.polyinterp(x.new{{bracket[1],bracketFval[1],bracketGval[1]*d},
127126
{bracket[2],bracketFval[2],bracketGval[2]*d}})
128127

129128
-- test what we are making sufficient progress

0 commit comments

Comments
 (0)