diff --git a/src/mridle/experiment/tuner.py b/src/mridle/experiment/tuner.py index 13f737a2..fc2cc2fb 100644 --- a/src/mridle/experiment/tuner.py +++ b/src/mridle/experiment/tuner.py @@ -104,7 +104,7 @@ def hyperopt_objective(cls, params, model, x_train, y_train, scoring_fn: str, id model_copy = model_copy.fit(x_train_cv, y_train_cv) - if scoring_fn not in ['mse', 'mae', 'rmse', 'mape']: + if scoring_fn not in ['mse', 'mae', 'rmse', 'mape', 'redflag']: y_proba_preds = model_copy.predict_proba(x_test_cv) y_proba_preds = np.clip(y_proba_preds, 1e-5, 1 - 1e-5) if y_proba_preds.shape[1] == 2: