diff --git a/pytorch_tabnet/tab_model.py b/pytorch_tabnet/tab_model.py index 0fb3c5ab..0b40364e 100755 --- a/pytorch_tabnet/tab_model.py +++ b/pytorch_tabnet/tab_model.py @@ -13,6 +13,7 @@ from sklearn.base import BaseEstimator from torch.utils.data import DataLoader from copy import deepcopy +import io import json from pathlib import Path import shutil @@ -305,7 +306,13 @@ def load_model(self, filepath): with z.open("model_params.json") as f: loaded_params = json.load(f) with z.open("network.pt") as f: - saved_state_dict = torch.load(f) + try: + saved_state_dict = torch.load(f) + except io.UnsupportedOperation: + # In Python <3.7, the returned file object is not seekable (which at least + # some versions of PyTorch require) - so we'll try buffering it in to a + # BytesIO instead: + saved_state_dict = torch.load(io.BytesIO(f.read())) except KeyError: raise KeyError("Your zip file is missing at least one component")