Skip to content

Commit

Permalink
Merge pull request #129 from uzh-dqbm-cmi/set_up_intervention
Browse files Browse the repository at this point in the history
Set up intervention
  • Loading branch information
mcmahom5 committed Aug 8, 2023
2 parents 18b2220 + 77288f1 commit 244659e
Show file tree
Hide file tree
Showing 10 changed files with 381 additions and 54 deletions.
30 changes: 15 additions & 15 deletions conf/base/parameters.yml
Original file line number Diff line number Diff line change
Expand Up @@ -382,27 +382,27 @@ models:
flavor: DataSet
config:
features:
- 'no_show_before'
# - 'no_show_before'
- 'appts_before'
- 'show_before'
# - 'show_before'
- 'no_show_rate'
- 'sched_days_advanced'
- 'month'
# - 'month'
- 'age'
- 'modality'
- 'occupation'
- 'reason'
- 'sex'
# - 'sex'
- 'hour_sched'
- 'distance_to_usz'
- 'day_of_week_str'
- 'marital'
- 'times_rescheduled'
target: NoShow
Stratifier:
flavor: PartitionedLabelStratifier
flavor: PartitionedFeatureStratifier
config:
n_partitions: 5
split_feature: 'year'
Architecture:
flavor: Pipeline
config:
Expand All @@ -417,14 +417,14 @@ models:
with_mean: True
args:
columns:
- 'no_show_before'
# - 'no_show_before'
- 'sched_days_advanced'
- 'age'
- 'hour_sched'
- 'distance_to_usz'
- 'times_rescheduled'
- 'appts_before'
- 'show_before'
# - 'show_before'
- 'no_show_rate'
- name: 'onehot'
flavor: sklearn.preprocessing.OneHotEncoder
Expand All @@ -437,13 +437,13 @@ models:
- 'reason'
- 'modality'
- 'day_of_week_str'
- name: 'cyc'
flavor: mridle.utilities.modeling.CyclicalTransformer
config:
period: 12
args:
columns:
- 'month'
#- name: 'cyc'
# flavor: mridle.utilities.modeling.CyclicalTransformer
# config:
# period: 12
# args:
# columns:
# - 'month'
- flavor: XGBClassifier
name: 'classifier'
config:
Expand Down
3 changes: 2 additions & 1 deletion src/mridle/experiment/architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import skorch
from sklearn.base import BaseEstimator
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
Expand Down Expand Up @@ -34,6 +34,7 @@ class ArchitectureInterface(ComponentInterface):

registered_flavors = {
'RandomForestClassifier': RandomForestClassifier, # TODO enable auto-loading from sklearn
'RandomForestRegressor': RandomForestRegressor, # TODO enable auto-loading from sklearn
'LogisticRegression': LogisticRegression,
'XGBClassifier': xgb.XGBClassifier,
'Pipeline': Pipeline,
Expand Down
9 changes: 7 additions & 2 deletions src/mridle/experiment/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,13 @@ def validate_config(config, data):
if col not in data.columns:
raise ValueError(f'Feature column {col} not found in dataset.')

if config['target'] not in data.columns:
raise ValueError(f"Target column {config['target']} not found in dataset.")
if isinstance(config['target'], str):
if config['target'] not in data.columns:
raise ValueError(f"Target column {config['target']} not found in dataset.")
elif isinstance(config['target'], list):
if not set(config['target']).issubset(data.columns):
not_in_list = list(set(config['target']).difference(data.columns))
raise ValueError(f"Target columns {not_in_list} not found in dataset.")

return True

Expand Down
4 changes: 4 additions & 0 deletions src/mridle/experiment/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ class MetricInterface(ComponentInterface):
'AUPRC': AUPRC,
'AUROC': AUROC,
'LogLoss': LogLoss,
'MAE': MAE,
'MSE': MSE,
'RMSE': RMSE,
'MedianAbsoluteError': MedianAbsoluteError
}

@classmethod
Expand Down
26 changes: 25 additions & 1 deletion src/mridle/experiment/stratifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def materialize_partition(self, partition_id: int, data_set: DataSet) -> Tuple[p
class PartitionedLabelStratifier(Stratifier):

def partition_data(self, data_set: DataSet) -> List[Tuple[List[int], List[int]]]:
"""Randomly shuffle and split the doc_list into n_partitions roughly equal lists, stratified by label."""
"""Randomly shuffle and split the data_set into n_partitions roughly equal lists, stratified by label."""
label_list = data_set.y
skf = StratifiedKFold(n_splits=self.config['n_partitions'], random_state=42, shuffle=True)
x = np.zeros(len(label_list)) # split takes a X argument for backwards compatibility and is not used
Expand Down Expand Up @@ -100,9 +100,33 @@ def validate_config(cls, config):
return True


class PartitionedFeatureStratifier(Stratifier):

def partition_data(self, data_set: DataSet) -> List[Tuple[List[int], List[int]]]:
"""Split dataset by feature values of provided column."""
data_set_copy = data_set.data.copy()
data_set_copy = data_set_copy.reset_index()
label_list = data_set_copy[self.config['split_feature']].unique()
partitions = []
for l_id, f_label in enumerate(label_list):
print(f_label)
train_ids = np.array(data_set_copy[data_set_copy[self.config['split_feature']] != f_label].index)
test_ids = np.array(data_set_copy[data_set_copy[self.config['split_feature']] == f_label].index)
partitions.append([train_ids, test_ids])
return partitions

@classmethod
def validate_config(cls, config):
for key in ['split_feature', ]:
if key not in config:
raise ValueError(f"{cls.__name__} config must contain entry '{key}'.")
return True


class StratifierInterface(ComponentInterface):

registered_flavors = {
'PartitionedFeatureStratifier': PartitionedFeatureStratifier,
'PartitionedLabelStratifier': PartitionedLabelStratifier,
'TrainTestStratifier': TrainTestStratifier,
}
Expand Down
49 changes: 42 additions & 7 deletions src/mridle/pipelines/data_science/feature_engineering/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ def daterange(date1, date2):
yield date1 + timedelta(n)


def get_last_non_na(x):
if x.last_valid_index() is None:
return '0'
else:
return x[x.last_valid_index()]


def generate_training_data(status_df, valid_date_range, append_outcome=True, add_no_show_before=True):
"""
Build data for use in models by trying to replicate the conditions under which the model would be used in reality
Expand Down Expand Up @@ -201,13 +208,15 @@ def build_feature_set(status_df: pd.DataFrame, valid_date_range: List[str], mast
'distance_to_usz_sq': 'last',
'close_to_usz': 'last',
'times_rescheduled': 'last',
'start_time': 'last'
'start_time': 'last',
'Telefon': lambda x: get_last_non_na(x)
}

slot_df = build_slot_df(status_df, valid_date_range, agg_dict, build_future_slots=build_future_slots,
include_id_cols=True)

slot_df = feature_days_scheduled_in_advance(status_df, slot_df)
slot_df = feature_year(slot_df)
slot_df = feature_month(slot_df)
slot_df = feature_hour_sched(slot_df)
slot_df = feature_day_of_week(slot_df)
Expand All @@ -217,7 +226,7 @@ def build_feature_set(status_df: pd.DataFrame, valid_date_range: List[str], mast
slot_df = feature_cyclical_month(slot_df)
slot_df = slot_df[slot_df['day_of_week_str'].isin(['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday'])]
slot_df = slot_df[slot_df['sched_days_advanced'] > 2]

slot_df = limit_to_day_hours(slot_df)
return slot_df


Expand Down Expand Up @@ -274,6 +283,20 @@ def feature_month(slot_df: pd.DataFrame) -> pd.DataFrame:
return slot_df


def feature_year(slot_df: pd.DataFrame) -> pd.DataFrame:
"""
Append the year feature to the dataframe.
Args:
slot_df: A dataframe containing appointment slots.
Returns: A row-per-status-change dataframe with additional column 'year'.
"""
slot_df['year'] = slot_df['start_time'].dt.year
return slot_df


def feature_hour_sched(slot_df: pd.DataFrame) -> pd.DataFrame:
"""
Append the hour_sched feature to the dataframe using was_sched_for_date.
Expand Down Expand Up @@ -628,10 +651,6 @@ def feature_occupation(df):

df_remap.loc[df_remap['Beruf'] == 'nan', 'occupation'] = 'none_given'
df_remap.loc[df_remap['Beruf'] == '-', 'occupation'] = 'none_given'
df_remap.loc[df_remap['Beruf'].apply(regex_search, search_str='rentner|Renter|pensioniert|pens.|rente'),
'occupation'] = 'retired'
df_remap.loc[df_remap['Beruf'].apply(regex_search, search_str='keine Angaben|keine Ang'),
'occupation'] = 'none_given'
df_remap.loc[df_remap['Beruf'].apply(regex_search,
search_str='Angestellte|ang.|baue|angest.|Hauswart|dozent|designer|^KV$|'
'masseu|Raumpflegerin|Apothekerin|Ing.|fotog|Psycholog|'
Expand All @@ -649,14 +668,24 @@ def feature_occupation(df):
'ingenieur|Kauf|mitarbeiter|Verkäufer|Informatiker|koch|'
'lehrer|arbeiter|architekt'),
'occupation'] = 'employed'
df_remap.loc[df_remap['Beruf'].apply(regex_search, search_str='rentner|Renter|pensioniert|pens.|rente'),
'occupation'] = 'retired'
df_remap.loc[df_remap['Beruf'].apply(regex_search, search_str='IV-Rentner'),
'occupation'] = 'iv_retired'

df_remap.loc[df_remap['Beruf'].apply(regex_search, search_str='keine Angaben|keine Ang'),
'occupation'] = 'none_given'

df_remap.loc[df_remap['Beruf'].apply(regex_search, search_str='student|Schüler|Doktorand|'
'Kind|Stud.|Ausbildung|^MA$'),
'occupation'] = 'student'
df_remap.loc[df_remap['Beruf'].apply(regex_search, search_str='^IV$|^IV-Bezüger|^$|arbeitslos|ohne Arbeit|'
'ohne|o.A.|nicht Arbeitstätig|'
'Sozialhilfeempfänger|o. Arbeit|keine Arbeit|'
'Asyl|RAV|Hausfrau|Hausmann'),
'Asyl|RAV'),
'occupation'] = 'unemployed'
df_remap.loc[
df_remap['Beruf'].apply(regex_search, search_str='Hausfrau|Hausmann'), 'occupation'] = 'stay_at_home_parent'
df_remap.loc[df_remap['Beruf'].apply(regex_search, search_str='selbst'), 'occupation'] = 'self_employed'
df_remap.loc[df_remap['Beruf'].apply(regex_search, search_str='arzt|aerzt|ärzt|pflegefachfrau|Pflegehelfer|'
'MTRA|Erzieherin|Fachfrau Betreuung|'
Expand All @@ -667,7 +696,9 @@ def feature_occupation(df):

df_remap.loc[df_remap['occupation'] == '', 'occupation'] = 'other'
df_remap.loc[df_remap['occupation'].isna(), 'occupation'] = 'other'

df_remap = df_remap.drop('Beruf', axis=1)

return df_remap


Expand Down Expand Up @@ -715,5 +746,9 @@ def feature_duration(dicom_df: pd.DataFrame) -> pd.DataFrame:
return dicom_df


def limit_to_day_hours(df):
return df[(df['hour_sched'] < 18) & (df['hour_sched'] > 6)]


def regex_search(x, search_str):
return bool(re.search(search_str, x, re.IGNORECASE))
2 changes: 1 addition & 1 deletion src/mridle/pipelines/data_science/live_data/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def get_slt_with_outcome():
'/data/mridle/data/silent_live_test/live_files/all/out_features_data/features_master_slt_features.csv',
parse_dates=['start_time', 'end_time'])
preds.drop(columns=['NoShow'], inplace=True)
actuals = pd.read_csv('/data/mridle/data/silent_live_test/live_files/all/actuals/master_actuals_with_filename.csv',
actuals = pd.read_csv('/data/mridle/data/silent_live_test/live_files/all/actuals/master_actuals.csv',
parse_dates=['start_time', 'end_time'])

preds['MRNCmpdId'] = preds['MRNCmpdId'].astype(str)
Expand Down
Loading

0 comments on commit 244659e

Please sign in to comment.