Skip to content

Commit

Permalink
fix: torch.load map_location in Py36 fallback
Browse files Browse the repository at this point in the history
  • Loading branch information
athewsey authored and Optimox committed Oct 15, 2020
1 parent ba09980 commit 63cb8c4
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion pytorch_tabnet/abstract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,10 @@ def load_model(self, filepath):
# In Python <3.7, the returned file object is not seekable (which at least
# some versions of PyTorch require) - so we'll try buffering it in to a
# BytesIO instead:
saved_state_dict = torch.load(io.BytesIO(f.read()))
saved_state_dict = torch.load(
io.BytesIO(f.read()),
map_location=self.device,
)
except KeyError:
raise KeyError("Your zip file is missing at least one component")

Expand Down

0 comments on commit 63cb8c4

Please sign in to comment.