Skip to content

Commit

Permalink
finish test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
ourownstory committed Nov 10, 2020
1 parent 2134706 commit 6288294
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 22 deletions.
5 changes: 3 additions & 2 deletions arnet/ar_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def fit_with_defaults(self, series):
def plot_weights(self, **kwargs):
plotting.plot_weights(
ar_val=self.ar_order,
weights=self.coeff,
weights=self.coeff[0],
ar=self.ar_params,
**kwargs,
)
Expand All @@ -200,9 +200,10 @@ def plot_errors(self, **kwargs):

def save_model(self, results_path="results", model_name=None):
# self.learn.freeze()
sparsity = 1.0 if self.sparsity is None else self.sparsity
if model_name is None:
model_name = "ar{}_sparse_{:.3f}_ahead_{}_epoch_{}.pkl".format(
self.ar_order, self.sparsity, self.n_forecasts, self.n_epoch
self.ar_order, sparsity, self.n_forecasts, self.n_epoch
)
self.learn.export(fname=os.path.join(results_path, model_name))
return self
Expand Down
23 changes: 3 additions & 20 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,6 @@
import warnings

warnings.filterwarnings("ignore", message=".*nonzero.*", category=UserWarning)

from fastai.learner import load_learner

## lazy imports ala fastai2 style (needed for nice print functionality)
# from fastai.basics import *
# from fastai.tabular.all import *

import arnet

log = logging.getLogger("ARNet.test")
Expand Down Expand Up @@ -69,12 +62,13 @@ def test_save_load(self):
m = m.fit_with_defaults(series=df)

# Optional:save and create inference learner
model_name = "ar{}_sparse_{:.3f}_ahead_{}_epoch_{}.pkl".format(m.ar_order, m.sparsity, m.n_forecasts, m.n_epoch)
sparsity = 1.0 if m.sparsity is None else m.sparsity
model_name = "ar{}_sparse_{:.3f}_ahead_{}_epoch_{}.pkl".format(m.ar_order, sparsity, m.n_forecasts, m.n_epoch)
m = m.save_model(results_path=results_path, model_name=model_name)
# can be loaded like this
m = m.load_model(results_path, model_name)
# can unfreeze the model and fine_tune
m.learn.fit_one_cycle(2, 0.0001)
log.info("loaded coeff: {}".format(m.coeff))

shutil.rmtree(results_path)

Expand All @@ -100,14 +94,3 @@ def test_ar_data(self):
# Look at Coeff
log.info("ar params: {}".format(arnet.nice_print_list(ar_params)))
log.info("model weights: {}".format(arnet.nice_print_list(m.coeff)))

# if save:
# # Optional:save and create inference learner
# learn.freeze()
# model_name = "ar{}_sparse_{:.3f}_ahead_{}_epoch_{}.pkl".format(ar_order, sparsity, n_forecasts, n_epoch)
# learn.export(fname=os.path.join(results_path, model_name))
# # can be loaded like this
# infer = load_learner(fname=os.path.join(results_path, model_name), cpu=True)
# # can unfreeze the model and fine_tune
# learn.unfreeze()
# learn.fit_one_cycle(1, lr_at_min / 100)

0 comments on commit 6288294

Please sign in to comment.