
Commit 7d52af6

phipleg authored and fchollet committed
Added logsumexp to backend. (#6346)
1 parent 70ffba0 commit 7d52af6

File tree

3 files changed: +80 -0 lines changed

keras/backend/tensorflow_backend.py (+22)

@@ -1304,6 +1304,28 @@ def log(x):
     return tf.log(x)
 
 
+def logsumexp(x, axis=None, keepdims=False):
+    """Computes log(sum(exp(elements across dimensions of a tensor))).
+
+    This function is more numerically stable than log(sum(exp(x))).
+    It avoids overflows caused by taking the exp of large inputs and
+    underflows caused by taking the log of small inputs.
+
+    # Arguments
+        x: A tensor or variable.
+        axis: An integer, the axis to reduce over.
+        keepdims: A boolean, whether to keep the dimensions or not.
+            If `keepdims` is `False`, the rank of the tensor is reduced
+            by 1. If `keepdims` is `True`, the reduced dimension is
+            retained with length 1.
+
+    # Returns
+        The reduced tensor.
+    """
+    axis = _normalize_axis(axis, ndim(x))
+    return tf.reduce_logsumexp(x, reduction_indices=axis, keep_dims=keepdims)
+
+
 def round(x):
     """Element-wise rounding to the closest integer.
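Note (not part of the commit): the docstring's stability claim is easy to demonstrate. A minimal NumPy sketch, contrasting the naive expression with the standard max-shift identity that stable implementations such as tf.reduce_logsumexp rely on:

import numpy as np

x = np.array([1e4, 1e-4])

# Naive evaluation: exp(1e4) overflows float64, so the result is inf.
naive = np.log(np.sum(np.exp(x)))

# Max-shift identity: logsumexp(x) = max(x) + log(sum(exp(x - max(x)))).
# exp(x - max(x)) is at most 1, so nothing overflows.
m = np.max(x)
stable = m + np.log(np.sum(np.exp(x - m)))  # ~10000.0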

keras/backend/theano_backend.py (+23)

@@ -528,6 +528,29 @@ def log(x):
     return T.log(x)
 
 
+def logsumexp(x, axis=None, keepdims=False):
+    """Computes log(sum(exp(elements across dimensions of a tensor))).
+
+    This function is more numerically stable than log(sum(exp(x))).
+    It avoids overflows caused by taking the exp of large inputs and
+    underflows caused by taking the log of small inputs.
+
+    # Arguments
+        x: A tensor or variable.
+        axis: An integer, the axis to reduce over.
+        keepdims: A boolean, whether to keep the dimensions or not.
+            If `keepdims` is `False`, the rank of the tensor is reduced
+            by 1. If `keepdims` is `True`, the reduced dimension is
+            retained with length 1.
+
+    # Returns
+        The reduced tensor.
+    """
+    # Theano has a built-in optimization for logsumexp (see https://github.com/Theano/Theano/pull/4736)
+    # so we can just write the expression directly:
+    return T.log(T.sum(T.exp(x), axis=axis, keepdims=keepdims))
+
+
 def round(x):
     return T.round(x, mode='half_to_even')
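For reference (not from the commit): because of the Theano rewrite mentioned in the comment above, the naive graph is compiled into a numerically stable form. A minimal sketch, assuming a Theano build that includes the optimization from the linked PR:

import theano
import theano.tensor as T

x = T.vector('x')
# Written naively; with the default optimization mode, Theano's graph
# optimizer rewrites this into the max-shifted stable form.
lse = T.log(T.sum(T.exp(x)))
f = theano.function([x], lse)
print(f([1e4, 1e-4]))  # ~10000.0 rather than inf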

tests/keras/backend/backend_test.py (+35)

@@ -580,6 +580,41 @@ def step_function(x, states):
         assert_allclose(tf_last_output, th_last_output, atol=1e-04)
         assert_allclose(tf_outputs, th_outputs, atol=1e-04)
 
+    @pytest.mark.parametrize('x_np,axis,keepdims', [
+        (np.array([1.1, 0.8, 0.9]), 0, False),
+        (np.array([[1.1, 0.8, 0.9]]), 0, False),
+        (np.array([[1.1, 0.8, 0.9]]), 1, False),
+        (np.array([[1.1, 0.8, 0.9]]), -1, False),
+        (np.array([[1.1, 0.8, 0.9]]), 1, True),
+        (np.array([[1.1], [1.2]]), 0, False),
+        (np.array([[1.1], [1.2]]), 1, False),
+        (np.array([[1.1], [1.2]]), -1, False),
+        (np.array([[1.1], [1.2]]), -1, True),
+        (np.array([[1.1, 1.2, 1.3], [0.9, 0.7, 1.4]]), None, False),
+        (np.array([[1.1, 1.2, 1.3], [0.9, 0.7, 1.4]]), 0, False),
+        (np.array([[1.1, 1.2, 1.3], [0.9, 0.7, 1.4]]), 1, False),
+        (np.array([[1.1, 1.2, 1.3], [0.9, 0.7, 1.4]]), -1, False),
+    ])
+    @pytest.mark.parametrize('K', [KTH, KTF], ids=["KTH", "KTF"])
+    def test_logsumexp(self, x_np, axis, keepdims, K):
+        '''
+        Check if K.logsumexp works properly for values close to one.
+        '''
+        x = K.variable(x_np)
+        assert_allclose(K.eval(K.logsumexp(x, axis=axis, keepdims=keepdims)),
+                        np.log(np.sum(np.exp(x_np), axis=axis, keepdims=keepdims)),
+                        rtol=1e-5)
+
+    @pytest.mark.parametrize('K', [KTH, KTF], ids=["KTH", "KTF"])
+    def test_logsumexp_optim(self, K):
+        '''
+        Check if optimization works.
+        '''
+        x_np = np.array([1e+4, 1e-4])
+        assert_allclose(K.eval(K.logsumexp(K.variable(x_np), axis=0)),
+                        1e4,
+                        rtol=1e-5)
+
     def test_switch(self):
         val = np.random.random()
         xth = KTH.variable(val)
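Usage sketch (not part of the commit) showing the new backend function; `K` is either backend module and the printed values are approximate:

from keras import backend as K
import numpy as np

x = K.variable(np.array([[1.1, 0.8, 0.9]]))
print(K.eval(K.logsumexp(x, axis=1)))                       # ~[2.0399]
print(K.eval(K.logsumexp(x, axis=1, keepdims=True)).shape)  # (1, 1)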
