-
Notifications
You must be signed in to change notification settings - Fork 0
/
reproducibility.py
31 lines (23 loc) · 1013 Bytes
/
reproducibility.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import os
import random
import numpy as np
import torch
def make_reproducible(seed: int = 42) -> None:
    """
    Seed every random number generator in use and force deterministic kernels.

    Determinism may cost performance. Identical results are still not
    guaranteed across PyTorch releases, individual commits, or different
    platforms, nor between CPU and GPU executions, even with identical seeds.
    See https://pytorch.org/docs/stable/notes/randomness.html for details.

    Parameters
    ----------
    seed : int
        value handed to the Python, NumPy and PyTorch random number generators
    """
    # cuBLAS workspace setting required for deterministic behaviour; background:
    # https://github.com/pytorch/pytorch/issues/47672
    # https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    # Seed Python, NumPy and PyTorch in turn with the same value.
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)

    # Turn off cuDNN benchmarking and request deterministic cuDNN behaviour.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False
    cudnn_backend.deterministic = True

    torch.use_deterministic_algorithms(True)