From 46a301fc5ae702f56f2f54ccabf61762da26588d Mon Sep 17 00:00:00 2001 From: Optimox Date: Thu, 15 Oct 2020 15:53:43 +0200 Subject: [PATCH] fix: specify device --- pytorch_tabnet/abstract_model.py | 8 ++------ pytorch_tabnet/utils.py | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/pytorch_tabnet/abstract_model.py b/pytorch_tabnet/abstract_model.py index eac76d3f..64201b46 100644 --- a/pytorch_tabnet/abstract_model.py +++ b/pytorch_tabnet/abstract_model.py @@ -12,6 +12,7 @@ validate_eval_set, create_dataloaders, check_nans, + define_device, ) from pytorch_tabnet.callbacks import ( CallbackContainer, @@ -62,12 +63,7 @@ def __post_init__(self): self.virtual_batch_size = 1024 torch.manual_seed(self.seed) # Defining device - if self.device_name == "auto": - if torch.cuda.is_available(): - device_name = "cuda" - else: - device_name = "cpu" - self.device = torch.device(device_name) + self.device = torch.device(define_device(self.device_name)) print(f"Device used : {self.device}") def fit( diff --git a/pytorch_tabnet/utils.py b/pytorch_tabnet/utils.py index d5e2e0aa..5bac2713 100644 --- a/pytorch_tabnet/utils.py +++ b/pytorch_tabnet/utils.py @@ -266,3 +266,25 @@ def check_nans(array): raise ValueError("NaN were found, TabNet does not allow nans.") if np.isinf(array).any(): raise ValueError("Infinite values were found, TabNet does not allow inf.") + + +def define_device(device_name): + """ + Define the device to use during training and inference. + If auto it will detect automatically whether to use cuda or cpu + Parameters + ---------- + - device_name : str + Either "auto", "cpu" or "cuda" + Returns + ------- + - str + Either "cpu" or "cuda" + """ + if device_name == "auto": + if torch.cuda.is_available(): + return "cuda" + else: + return "cpu" + else: + return device_name