Skip to content

Commit

Permalink
RedFlag metric
Browse files Browse the repository at this point in the history
  • Loading branch information
mcmahom5 committed Aug 14, 2023
1 parent 35ee62e commit e7c891c
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 2 deletions.
11 changes: 11 additions & 0 deletions src/mridle/experiment/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,17 @@ def calculate(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
return metric


class RedFlag(Metric):
name = 'redflag'
metric_type = 'regression'

def calculate(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
abs_error = np.abs(y_true - y_pred)
pct_error = abs_error/y_true
metric = np.mean(pct_error > 0.2)
return metric


class RMSE(Metric):
name = 'rmse'
metric_type = 'regression'
Expand Down
4 changes: 3 additions & 1 deletion src/mridle/experiment/stratifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,14 @@ def validate_config(cls, config):
class TimeSeriesStratifier(Stratifier):

def partition_data(self, data_set: DataSet) -> List[Tuple[List[int], List[int]]]:
"""Split dataset by feature values of provided column."""
"""Split dataset by time. Take provided time variable and split the dataset based on the dates provided.
Automatically take the time range between first and second provided date as the test size."""
data_set_copy = data_set.data.copy()
data_set_copy = data_set_copy.reset_index()
time_feature = self.config['time_feature']
ordered_dates = self.config['ordered_dates']
test_size_time = (pd.to_datetime(ordered_dates[1]) - pd.to_datetime(ordered_dates[0]))
print(test_size_time)
partitions = []
for l_id, d in enumerate(ordered_dates):
print(d)
Expand Down
5 changes: 4 additions & 1 deletion src/mridle/experiment/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Dict, List, Tuple
from .architecture import Architecture
from .ConfigurableComponent import ConfigurableComponent, ComponentInterface
from .metric import AUPRC, LogLoss, F1_Macro, AUROC, BrierScore, MSE, MAE, MAPE
from .metric import AUPRC, LogLoss, F1_Macro, AUROC, BrierScore, MSE, MAE, MAPE, RedFlag


class Tuner(ConfigurableComponent):
Expand Down Expand Up @@ -129,6 +129,9 @@ def hyperopt_objective(cls, params, model, x_train, y_train, scoring_fn: str, id
elif scoring_fn == 'mape':
y_preds = model_copy.predict(x_test_cv)
loss = MAPE().calculate(y_test_cv, y_preds)
elif scoring_fn == 'redflag':
y_preds = model_copy.predict(x_test_cv)
loss = RedFlag().calculate(y_test_cv, y_preds)
else:
raise NotImplementedError(
'scoring_fn should be one of ''f1_macro'', ''log_loss'', ''auprc'', ''auroc'', ''brier_score' +
Expand Down

0 comments on commit e7c891c

Please sign in to comment.