From 451bd8669038ddf7869843f45ca872ae92e2260d Mon Sep 17 00:00:00 2001
From: Eduardo Carvalho
Date: Mon, 26 Oct 2020 11:06:11 +0100
Subject: [PATCH] fix: load from cpu when saved on gpu

---
 .gitignore                       |  1 +
 pytorch_tabnet/abstract_model.py |  8 +++++---
 pytorch_tabnet/pretraining.py    |  1 -
 pytorch_tabnet/tab_network.py    | 12 ------------
 pytorch_tabnet/utils.py          |  2 ++
 5 files changed, 8 insertions(+), 16 deletions(-)

diff --git a/.gitignore b/.gitignore
index d4d43ba1..4644f73d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,6 +10,7 @@ data/
 .vscode/
 *.pt
 *~
+.vscode/
 
 # Notebook to python
 forest_example.py
diff --git a/pytorch_tabnet/abstract_model.py b/pytorch_tabnet/abstract_model.py
index d5a18887..714f9ceb 100644
--- a/pytorch_tabnet/abstract_model.py
+++ b/pytorch_tabnet/abstract_model.py
@@ -64,10 +64,10 @@ def __post_init__(self):
         self.batch_size = 1024
         self.virtual_batch_size = 128
         torch.manual_seed(self.seed)
-        torch.manual_seed(self.seed)
         # Defining device
         self.device = torch.device(define_device(self.device_name))
-        print(f"Device used : {self.device}")
+        if self.verbose != 0:
+            print(f"Device used : {self.device}")
 
     def __update__(self, **kwargs):
         """
@@ -330,6 +330,7 @@ def load_weights_from_unsupervised(self, unsupervised_model):
             if self.network.state_dict().get(new_param) is not None:
                 # update only common layers
                 update_state_dict[new_param] = weights
+
         self.network.load_state_dict(update_state_dict)
 
     def save_model(self, path):
@@ -380,6 +381,7 @@ def load_model(self, filepath):
         with zipfile.ZipFile(filepath) as z:
             with z.open("model_params.json") as f:
                 loaded_params = json.load(f)
+                loaded_params["device_name"] = self.device_name
             with z.open("network.pt") as f:
                 try:
                     saved_state_dict = torch.load(f, map_location=self.device)
@@ -399,6 +401,7 @@ def load_model(self, filepath):
         self._set_network()
         self.network.load_state_dict(saved_state_dict)
         self.network.eval()
+
         return
 
     def _train_epoch(self, train_loader):
@@ -539,7 +542,6 @@ def _set_network(self):
             epsilon=self.epsilon,
             virtual_batch_size=self.virtual_batch_size,
             momentum=self.momentum,
-            device_name=self.device_name,
             mask_type=self.mask_type,
         ).to(self.device)
 
diff --git a/pytorch_tabnet/pretraining.py b/pytorch_tabnet/pretraining.py
index 38a27320..3600b61d 100644
--- a/pytorch_tabnet/pretraining.py
+++ b/pytorch_tabnet/pretraining.py
@@ -182,7 +182,6 @@ def _set_network(self):
             epsilon=self.epsilon,
             virtual_batch_size=self.virtual_batch_size,
             momentum=self.momentum,
-            device_name=self.device_name,
             mask_type=self.mask_type,
         ).to(self.device)
 
diff --git a/pytorch_tabnet/tab_network.py b/pytorch_tabnet/tab_network.py
index 345a7a12..8c259eab 100644
--- a/pytorch_tabnet/tab_network.py
+++ b/pytorch_tabnet/tab_network.py
@@ -298,7 +298,6 @@ def __init__(
         epsilon=1e-15,
         virtual_batch_size=128,
         momentum=0.02,
-        device_name="auto",
         mask_type="sparsemax",
     ):
         super(TabNetPretraining, self).__init__()
@@ -499,7 +498,6 @@ def __init__(
         epsilon=1e-15,
         virtual_batch_size=128,
         momentum=0.02,
-        device_name="auto",
         mask_type="sparsemax",
     ):
         """
@@ -538,7 +536,6 @@ def __init__(
             Batch size for Ghost Batch Normalization
         momentum : float
             Float value between 0 and 1 which will be used for momentum in all batch norm
-        device_name : {'auto', 'cuda', 'cpu'}
         mask_type : str
             Either "sparsemax" or "entmax" : this is the masking function to use
         """
@@ -581,15 +578,6 @@ def __init__(
             mask_type,
         )
 
-        # Defining device
-        if device_name == "auto":
-            if torch.cuda.is_available():
-                device_name = "cuda"
-            else:
-                device_name = "cpu"
-        self.device = torch.device(device_name)
-        self.to(self.device)
-
     def forward(self, x):
         x = self.embedder(x)
         return self.tabnet(x)
diff --git a/pytorch_tabnet/utils.py b/pytorch_tabnet/utils.py
index 7a66bbc9..594777a5 100644
--- a/pytorch_tabnet/utils.py
+++ b/pytorch_tabnet/utils.py
@@ -311,5 +311,7 @@ def define_device(device_name):
             return "cuda"
         else:
             return "cpu"
+    elif device_name == "cuda" and not torch.cuda.is_available():
+        return "cpu"
     else:
         return device_name
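
Note: a minimal sketch of the round trip this patch fixes. It is illustrative, not part of the commit: the synthetic data, the single training epoch, and the "./tabnet_model" path are made up, and it assumes the first half runs on a CUDA machine while the second half runs on a CPU-only one.

    import numpy as np
    from pytorch_tabnet.tab_model import TabNetClassifier

    X = np.random.rand(256, 10)
    y = np.random.randint(0, 2, 256)

    # On the GPU machine: define_device("auto") resolves to "cuda".
    clf = TabNetClassifier(device_name="auto")
    clf.fit(X, y, max_epochs=1)
    saved_path = clf.save_model("./tabnet_model")  # writes ./tabnet_model.zip

    # On the CPU-only machine: load_model now overwrites the archived
    # device_name with the local instance's, and define_device falls back
    # to "cpu" when CUDA is unavailable, so the GPU-saved weights are
    # restored via torch.load(..., map_location=self.device) instead of
    # raising at re-initialization.
    clf_cpu = TabNetClassifier()
    clf_cpu.load_model(saved_path)
    preds = clf_cpu.predict(X)

Before this change, load_model re-initialized the estimator with the archived device_name (e.g. "cuda"), and tab_network.py additionally resolved a device and called .to() at construction time, which failed on hosts without a GPU; device placement is now handled in one place, abstract_model's .to(self.device).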