diff --git a/ngboost/api.py b/ngboost/api.py index a80c6a0..4433f0e 100644 --- a/ngboost/api.py +++ b/ngboost/api.py @@ -97,6 +97,8 @@ def __init__( early_stopping_rounds, ) + self._estimator_type = "regressor" + def __getstate__(self): state = super().__getstate__() # Remove the unpicklable entries. @@ -172,6 +174,7 @@ def __init__( tol, random_state, ) + self._estimator_type = "classifier" def predict_proba(self, X, max_iter=None): """