Skip to content

Commit

Permalink
fix: specify device
Browse files Browse the repository at this point in the history
  • Loading branch information
Optimox committed Oct 15, 2020
1 parent 63cb8c4 commit 46a301f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
8 changes: 2 additions & 6 deletions pytorch_tabnet/abstract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
validate_eval_set,
create_dataloaders,
check_nans,
define_device,
)
from pytorch_tabnet.callbacks import (
CallbackContainer,
Expand Down Expand Up @@ -62,12 +63,7 @@ def __post_init__(self):
self.virtual_batch_size = 1024
torch.manual_seed(self.seed)
# Defining device
if self.device_name == "auto":
if torch.cuda.is_available():
device_name = "cuda"
else:
device_name = "cpu"
self.device = torch.device(device_name)
self.device = torch.device(define_device(self.device_name))
print(f"Device used : {self.device}")

def fit(
Expand Down
22 changes: 22 additions & 0 deletions pytorch_tabnet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,25 @@ def check_nans(array):
raise ValueError("NaN were found, TabNet does not allow nans.")
if np.isinf(array).any():
raise ValueError("Infinite values were found, TabNet does not allow inf.")


def define_device(device_name):
"""
Define the device to use during training and inference.
If auto it will detect automatically whether to use cuda or cpu
Parameters
----------
- device_name : str
Either "auto", "cpu" or "cuda"
Returns
-------
- str
Either "cpu" or "cuda"
"""
if device_name == "auto":
if torch.cuda.is_available():
return "cuda"
else:
return "cpu"
else:
return device_name

0 comments on commit 46a301f

Please sign in to comment.