Commit 83d952d

Merge pull request #122 from Cadene/master
Add LearningRateDecay to Adam
2 parents e24fd85 + 7b32fd2

File tree

2 files changed (+8 −2 lines)

adam.lua

+7 −2
@@ -7,6 +7,7 @@ ARGS:
 - 'x' : the initial point
 - 'config` : a table with configuration parameters for the optimizer
 - 'config.learningRate' : learning rate
+- `config.learningRateDecay` : learning rate decay
 - 'config.beta1' : first moment coefficient
 - 'config.beta2' : second moment coefficient
 - 'config.epsilon' : for numerical stability
@@ -25,6 +26,7 @@ function optim.adam(opfunc, x, config, state)
    local config = config or {}
    local state = state or config
    local lr = config.learningRate or 0.001
+   local lrd = config.learningRateDecay or 0

    local beta1 = config.beta1 or 0.9
    local beta2 = config.beta2 or 0.999
@@ -48,6 +50,9 @@ function optim.adam(opfunc, x, config, state)
    -- A tmp tensor to hold the sqrt(v) + epsilon
    state.denom = state.denom or x.new(dfdx:size()):zero()

+   -- (3) learning rate decay (annealing)
+   local clr = lr / (1 + state.t*lrd)
+
    state.t = state.t + 1

    -- Decay the first and second moment running average coefficient
@@ -58,8 +63,8 @@ function optim.adam(opfunc, x, config, state)

    local biasCorrection1 = 1 - beta1^state.t
    local biasCorrection2 = 1 - beta2^state.t
-   local stepSize = lr * math.sqrt(biasCorrection2)/biasCorrection1
-   -- (3) update x
+   local stepSize = clr * math.sqrt(biasCorrection2)/biasCorrection1
+   -- (4) update x
    x:addcdiv(-stepSize, state.m, state.denom)

    -- return x*, f(x) before optimization
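
For context, a minimal usage sketch of the new option (not part of the commit; assumes torch and the optim package are available, and the quadratic objective is a made-up example). Because `state` defaults to `config`, the step counter `state.t` persists across calls, so the effective rate `lr / (1 + state.t*lrd)` anneals over iterations:

-- usage sketch (hypothetical example, not from this commit):
-- minimize f(x) = x^T x with optim.adam and the new learningRateDecay option
require 'optim'

local x = torch.randn(10)
local config = {
   learningRate = 0.001,
   learningRateDecay = 1e-4  -- 'lrd' above; 0 (the default) disables annealing
}

-- optim.adam expects opfunc to return f(x) and df/dx
local function opfunc(x)
   return x:dot(x), x * 2
end

for i = 1, 100 do
   -- state defaults to config, so state.t, state.m, state.v persist here
   x = optim.adam(opfunc, x, config)
end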

doc/algos.md

+1 −0
@@ -200,6 +200,7 @@ Arguments:
 * `x`: the initial point
 * `config`: a table with configuration parameters for the optimizer
 * `config.learningRate`: learning rate
+* `config.learningRateDecay`: learning rate decay
 * `config.beta1`: first moment coefficient
 * `config.beta2`: second moment coefficient
 * `config.epsilon`: for numerical stability
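
As the patch computes it, the effective learning rate at step `t` is `lr / (1 + t*lrd)`; for example, with `learningRate = 0.001` and, say, `learningRateDecay = 1e-4`, the rate halves to 0.0005 once `state.t` reaches 10000. The default of 0 keeps the previous constant-rate behaviour.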
