fix: load from cpu when saved on gpu
eduardocarvp authored and Optimox committed Dec 10, 2020
1 parent ebdb9ff commit 451bd86
Showing 5 changed files with 8 additions and 16 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -10,6 +10,7 @@ data/
 .vscode/
 *.pt
 *~
+.vscode/
 
 # Notebook to python
 forest_example.py

8 changes: 5 additions & 3 deletions pytorch_tabnet/abstract_model.py
@@ -64,10 +64,10 @@ def __post_init__(self):
         self.batch_size = 1024
         self.virtual_batch_size = 128
         torch.manual_seed(self.seed)
-        torch.manual_seed(self.seed)
         # Defining device
         self.device = torch.device(define_device(self.device_name))
-        print(f"Device used : {self.device}")
+        if self.verbose != 0:
+            print(f"Device used : {self.device}")
 
     def __update__(self, **kwargs):
         """
@@ -330,6 +330,7 @@ def load_weights_from_unsupervised(self, unsupervised_model):
             if self.network.state_dict().get(new_param) is not None:
                 # update only common layers
                 update_state_dict[new_param] = weights
+
         self.network.load_state_dict(update_state_dict)
 
     def save_model(self, path):
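
The added blank line is cosmetic, but the surrounding logic is worth spelling out: only parameters whose names also exist in the target network are copied over. A self-contained sketch of that pattern, with hypothetical stand-in modules (the real code maps pretrained TabNet weights onto the supervised network):

import copy
import torch.nn as nn

source = nn.Linear(4, 4)                 # stands in for the unsupervised model
target = nn.Sequential(nn.Linear(4, 4))  # stands in for self.network

# Start from the target's own weights, then overwrite only the entries
# whose (remapped) names exist in the target's state_dict.
update_state_dict = copy.deepcopy(target.state_dict())
for name, weights in source.state_dict().items():
    new_param = "0." + name              # hypothetical key remapping
    if target.state_dict().get(new_param) is not None:
        update_state_dict[new_param] = weights
target.load_state_dict(update_state_dict)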
@@ -380,6 +381,7 @@ def load_model(self, filepath):
         with zipfile.ZipFile(filepath) as z:
             with z.open("model_params.json") as f:
                 loaded_params = json.load(f)
+                loaded_params["device_name"] = self.device_name
             with z.open("network.pt") as f:
                 try:
                     saved_state_dict = torch.load(f, map_location=self.device)
@@ -399,6 +401,7 @@ def load_model(self, filepath):
         self._set_network()
         self.network.load_state_dict(saved_state_dict)
         self.network.eval()
+
         return
 
     def _train_epoch(self, train_loader):
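
These two hunks are the heart of the fix. Overriding the saved device_name with the current machine's value means _set_network rebuilds the model for the local device, and map_location=self.device makes torch.load remap tensors saved on GPU onto CPU. A hedged sketch of the same idea outside the class (file name hypothetical):

import torch
import torch.nn as nn

# Save on one machine (possibly with the tensors on GPU)...
net = nn.Linear(4, 2)
torch.save(net.state_dict(), "network.pt")

# ...and load on another: map_location retargets every tensor to the
# local device instead of the device recorded in the checkpoint.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
saved_state_dict = torch.load("network.pt", map_location=device)
net.load_state_dict(saved_state_dict)
net.eval()  # mirror load_model: switch to inference mode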
@@ -539,7 +542,6 @@ def _set_network(self):
             epsilon=self.epsilon,
             virtual_batch_size=self.virtual_batch_size,
             momentum=self.momentum,
-            device_name=self.device_name,
             mask_type=self.mask_type,
         ).to(self.device)

1 change: 0 additions & 1 deletion pytorch_tabnet/pretraining.py
@@ -182,7 +182,6 @@ def _set_network(self):
             epsilon=self.epsilon,
             virtual_batch_size=self.virtual_batch_size,
             momentum=self.momentum,
-            device_name=self.device_name,
             mask_type=self.mask_type,
         ).to(self.device)

12 changes: 0 additions & 12 deletions pytorch_tabnet/tab_network.py
@@ -298,7 +298,6 @@ def __init__(
         epsilon=1e-15,
         virtual_batch_size=128,
         momentum=0.02,
-        device_name="auto",
         mask_type="sparsemax",
     ):
         super(TabNetPretraining, self).__init__()
@@ -499,7 +498,6 @@ def __init__(
         epsilon=1e-15,
         virtual_batch_size=128,
         momentum=0.02,
-        device_name="auto",
         mask_type="sparsemax",
     ):
         """
@@ -538,7 +536,6 @@ def __init__(
             Batch size for Ghost Batch Normalization
         momentum : float
             Float value between 0 and 1 which will be used for momentum in all batch norm
-        device_name : {'auto', 'cuda', 'cpu'}
         mask_type : str
             Either "sparsemax" or "entmax" : this is the masking function to use
         """
@@ -581,15 +578,6 @@ def __init__(
             mask_type,
         )
 
-        # Defining device
-        if device_name == "auto":
-            if torch.cuda.is_available():
-                device_name = "cuda"
-            else:
-                device_name = "cpu"
-        self.device = torch.device(device_name)
-        self.to(self.device)
-
     def forward(self, x):
         x = self.embedder(x)
         return self.tabnet(x)
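
With these deletions, TabNet and TabNetPretraining no longer decide their own placement; construction is device-agnostic and the caller does the .to(device) (see the _set_network hunks above). A minimal sketch of the pattern with a hypothetical module:

import torch
import torch.nn as nn

class TinyNet(nn.Module):          # hypothetical stand-in for TabNet
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)  # no self.to(...) during construction

    def forward(self, x):
        return self.fc(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TinyNet().to(device)       # placement decided by the caller
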
2 changes: 2 additions & 0 deletions pytorch_tabnet/utils.py
@@ -311,5 +311,7 @@ def define_device(device_name):
             return "cuda"
         else:
             return "cpu"
+    elif device_name == "cuda" and not torch.cuda.is_available():
+        return "cpu"
     else:
         return device_name
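
With the new branch, define_device degrades gracefully: requesting "cuda" on a CPU-only machine now returns "cpu" instead of letting torch.device("cuda") fail later when tensors are moved. Expected behavior on a machine without CUDA:

from pytorch_tabnet.utils import define_device

define_device("auto")  # -> "cpu"
define_device("cuda")  # -> "cpu"  (fallback added in this commit)
define_device("cpu")   # -> "cpu"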
