From 55c09e5c47e6ec58276c301a5af7afa2dc529bc1 Mon Sep 17 00:00:00 2001 From: Alex Thewsey Date: Wed, 29 Jul 2020 14:00:17 +0000 Subject: [PATCH] fix: load_model fallback to BytesIO for Py3.6 Catch io.UnsupportedOperation raised in Python <3.7 and buffer file contents into a BytesIO to work around the error. --- pytorch_tabnet/tab_model.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pytorch_tabnet/tab_model.py b/pytorch_tabnet/tab_model.py index 0fb3c5ab..0b40364e 100755 --- a/pytorch_tabnet/tab_model.py +++ b/pytorch_tabnet/tab_model.py @@ -13,6 +13,7 @@ from sklearn.base import BaseEstimator from torch.utils.data import DataLoader from copy import deepcopy +import io import json from pathlib import Path import shutil @@ -305,7 +306,13 @@ def load_model(self, filepath): with z.open("model_params.json") as f: loaded_params = json.load(f) with z.open("network.pt") as f: - saved_state_dict = torch.load(f) + try: + saved_state_dict = torch.load(f) + except io.UnsupportedOperation: + # In Python <3.7, the returned file object is not seekable (which at least + # some versions of PyTorch require) - so we'll try buffering it in to a + # BytesIO instead: + saved_state_dict = torch.load(io.BytesIO(f.read())) except KeyError: raise KeyError("Your zip file is missing at least one component")