Skip to content

Commit 34656ae

Browse files
committed
.10, #4
1 parent 0746ba4 commit 34656ae

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

benchmarks/test_rnn_parity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def test_pytorch_parity(tmpdir):
6666
for pl_out, pt_out in zip(lightning_outs, manual_outs):
6767
np.testing.assert_almost_equal(pl_out, pt_out, 8)
6868

69-
tutils.assert_speed_parity(pl_times, pt_times, num_epochs)
69+
tutils.assert_speed_parity(pl_times, pt_times)
7070

7171

7272
def vanilla_loop(MODEL, num_runs=10, num_epochs=10):

benchmarks/test_trainer_parity.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_pytorch_parity(tmpdir):
5252
:param tmpdir:
5353
:return:
5454
"""
55-
num_epochs = 5
55+
num_epochs = 4
5656
num_rums = 3
5757
lightning_outs, pl_times = lightning_loop(ParityModuleMNIST, num_rums, num_epochs)
5858
manual_outs, pt_times = vanilla_loop(ParityModuleMNIST, num_rums, num_epochs)
@@ -62,7 +62,7 @@ def test_pytorch_parity(tmpdir):
6262
np.testing.assert_almost_equal(pl_out, pt_out, 5)
6363

6464
# the fist run initialize dataset (download & filter)
65-
tutils.assert_speed_parity(pl_times[1:], pt_times[1:], num_epochs)
65+
tutils.assert_speed_parity(pl_times[1:], pt_times[1:])
6666

6767

6868
def vanilla_loop(cls_model, num_runs=10, num_epochs=10):

tests/base/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
from tests.base.model_template import EvalModelTemplate
1313

1414

15-
def assert_speed_parity(pl_times, pt_times, num_epochs, max_diff_per_epoch=0.05):
15+
def assert_speed_parity(pl_times, pt_times, max_diff_per_epoch=0.1):
1616
# assert speeds
1717
diffs = np.asarray(pl_times) - np.asarray(pt_times)
18-
# norm by nb epochs and tha vanila time
19-
diffs = diffs / num_epochs / np.asarray(pt_times)
18+
# norm by vanila time
19+
diffs = diffs / np.asarray(pt_times)
2020
assert np.alltrue(diffs < max_diff_per_epoch), \
2121
f"lightning {diffs} was slower than PT (threshold {max_diff_per_epoch})"
2222

0 commit comments

Comments
 (0)