Skip to content

Commit eaa2e90

Browse files
clee2000malfet
andauthored
Set random seed (#2438)
To make tutorial builds predictable, but still keep randomness when one rans it on Collab. Also, reset default_device after every tutorial runCo-authored-by: Nikita Shulga <[email protected]> Co-authored-by: Nikita Shulga <[email protected]>
1 parent 730029b commit eaa2e90

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

Diff for: conf.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import pytorch_sphinx_theme
3535
import torch
3636
import glob
37+
import random
3738
import shutil
3839
from custom_directives import IncludeDirective, GalleryItemDirective, CustomGalleryItemDirective, CustomCalloutItemDirective, CustomCardItemDirective
3940
import distutils.file_util
@@ -85,6 +86,11 @@
8586

8687
# -- Sphinx-gallery configuration --------------------------------------------
8788

89+
def reset_seeds(gallery_conf, fname):
90+
torch.manual_seed(42)
91+
torch.set_default_device(None)
92+
random.seed(10)
93+
8894
sphinx_gallery_conf = {
8995
'examples_dirs': ['beginner_source', 'intermediate_source',
9096
'advanced_source', 'recipes_source', 'prototype_source'],
@@ -94,7 +100,8 @@
94100
'backreferences_dir': None,
95101
'first_notebook_cell': ("# For tips on running notebooks in Google Colab, see\n"
96102
"# https://pytorch.org/tutorials/beginner/colab\n"
97-
"%matplotlib inline")
103+
"%matplotlib inline"),
104+
'reset_modules': (reset_seeds)
98105
}
99106

100107
if os.getenv('GALLERY_PATTERN'):

Diff for: recipes_source/recipes/changing_default_device.py

-3
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,6 @@
4343
print(mod.weight.device)
4444
print(mod(torch.randn(128, 20)).device)
4545

46-
# And then globally return it back to CPU
47-
torch.set_default_device('cpu')
48-
4946
################################################################
5047
# This function imposes a slight performance cost on every Python
5148
# call to the torch API (not just factory functions). If this

0 commit comments

Comments
 (0)