diff --git a/tests/test_legacy.py b/tests/test_legacy.py index a13b7eb..f0b3225 100644 --- a/tests/test_legacy.py +++ b/tests/test_legacy.py @@ -84,7 +84,7 @@ def test_legacy_ar(self): arnet.plot_weights( ar_val=len(ar_params[0]), weights=coeff[0], ar=ar_params[0], save=not self.plot, savedir=results_path ) - arnet.plot_prediction_sample(preds, y, num_obs=100, save=not self.plot, savedir=results_path) + arnet.plot_prediction_sample(preds[:100], y[:100], save=not self.plot, savedir=results_path) arnet.plot_error_scatter(preds, y, save=not self.plot, savedir=results_path) if self.save: