Skip to content

Commit

Permalink
[bug] Fix saved model size (#1425)
Browse files Browse the repository at this point in the history
* initial try

* tidy up

* fixed

* reversed test changes

* reversed test changes

* removed minimal version
  • Loading branch information
leoniewgnr committed Sep 21, 2023
1 parent b04e141 commit 50d5e41
Showing 1 changed file with 34 additions and 15 deletions.
49 changes: 34 additions & 15 deletions neuralprophet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,50 @@


def save(forecaster, path: str):
"""save a fitted np model to a disk file.
"""Save a fitted Neural Prophet model to disk.
Parameters
----------
Parameters:
forecaster : np.forecaster.NeuralProphet
input forecaster that is fitted
path : str
path and filename to be saved. filename could be any but suggested to have extension .np.
Examples
--------
After you fitted a model, you may save the model to save_test_model.np
>>> from neuralprophet import save
>>> save(forecaster, "test_save_model.np")
"""
# Remove the Lightning trainer since it does not serialise correcly with torch.save
attrs_to_remove = ["trainer"]
# List of attributes to remove
attrs_to_remove_forecaster = ["trainer"]
attrs_to_remove_model = ["_trainer"]

# Store removed attributes temporarily
removed_attrs = {}
for attr in attrs_to_remove:
removed_attrs[attr] = getattr(forecaster, attr)
setattr(forecaster, attr, None)
torch.save(forecaster, path)

# Restore the Lightning trainer
for attr in attrs_to_remove:
setattr(forecaster, attr, removed_attrs[attr])

# Remove specified attributes from forecaster
for attr in attrs_to_remove_forecaster:
if hasattr(forecaster, attr):
removed_attrs[attr] = getattr(forecaster, attr)
setattr(forecaster, attr, None)

# Remove specified attributes from forecaster.model
for attr in attrs_to_remove_model:
if hasattr(forecaster.model, attr):
removed_attrs[attr] = getattr(forecaster.model, attr)
setattr(forecaster.model, attr, None)

# Perform the save operation
try:
torch.save(forecaster, path)
except Exception as e:
print(f"An error occurred while saving the model: {e}")
raise
finally:
# Restore the removed attributes
for attr, value in removed_attrs.items():
if hasattr(forecaster, attr):
setattr(forecaster, attr, value)
elif hasattr(forecaster.model, attr):
setattr(forecaster.model, attr, value)


def load(path: str):
Expand Down

0 comments on commit 50d5e41

Please sign in to comment.