diff --git a/pytorch_tabnet/abstract_model.py b/pytorch_tabnet/abstract_model.py index 4ef3f4d4..eac76d3f 100644 --- a/pytorch_tabnet/abstract_model.py +++ b/pytorch_tabnet/abstract_model.py @@ -318,7 +318,10 @@ def load_model(self, filepath): # In Python <3.7, the returned file object is not seekable (which at least # some versions of PyTorch require) - so we'll try buffering it in to a # BytesIO instead: - saved_state_dict = torch.load(io.BytesIO(f.read())) + saved_state_dict = torch.load( + io.BytesIO(f.read()), + map_location=self.device, + ) except KeyError: raise KeyError("Your zip file is missing at least one component")