Skip to content

Commit

Permalink
Remove neural
Browse files Browse the repository at this point in the history
  • Loading branch information
mcmahom5 committed Jun 6, 2024
1 parent 76b34ef commit eaf452c
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 18 deletions.
4 changes: 0 additions & 4 deletions conf/base/catalog/model_comparison.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,3 @@ xgboost_permutation_imp_validation:
filepath: /data/mridle/data/kedro_data_catalog/08_reporting/model_comparison/permutation_importances/xgboost_permutation_imp_validation.html
layer: reporting

neural_net_permutation_imp_validation:
type: mridle.extras.datasets.altair_dataset.AltairDataSet
filepath: /data/mridle/data/kedro_data_catalog/08_reporting/model_comparison/permutation_importances/neural_net_permutation_imp_validation.html
layer: reporting
16 changes: 7 additions & 9 deletions src/mridle/pipelines/data_science/model_comparison/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def create_evaluation_table(harvey_model_log_reg, harvey_random_forest, logistic_regression_model, random_forest_model,
xgboost_model, neural_net_model, validation_data):
xgboost_model, validation_data):
"""
Function to create a table of metrics for the models.
Expand All @@ -26,15 +26,14 @@ def create_evaluation_table(harvey_model_log_reg, harvey_random_forest, logistic
logistic_regression_model: serialised logistic regression model
random_forest_model: serialised random forest model
xgboost_model: serialised xgboost model
neural_net_model: serialised neural net model
validation_data: validation data, split out from master_feature_set before experiments were run
Returns:
"""
serialised_models = [('Harvey LogReg', harvey_model_log_reg), ('Harvey RandomForest', harvey_random_forest),
('Logistic Regression', logistic_regression_model), ('RandomForest', random_forest_model),
('XGBoost', xgboost_model), ('Neural Net', neural_net_model)]
('XGBoost', xgboost_model)]

evaluation_table = []
avg_appts_per_week = 105 # taken from aggregation of df_features_original data for the year 2017 (in notebook 52)
Expand All @@ -47,8 +46,7 @@ def create_evaluation_table(harvey_model_log_reg, harvey_random_forest, logistic
experiment = Experiment.deserialize(serialised_m)

preds_prob = experiment.final_predictor.predict_proba(val_dataset.x)
if model_name == 'Neural Net':
preds_prob = [prob for [prob] in preds_prob]

preds_prob_sorted = np.sort(preds_prob)[::-1]

calc_list = []
Expand Down Expand Up @@ -155,11 +153,11 @@ def create_model_precision_comparison_plot(evaluation_table_df: pd.DataFrame) ->


def plot_pr_roc_curve_comparison(harvey_model_log_reg, harvey_random_forest, logistic_regression_model,
random_forest_model, xgboost_model, neural_net_model, validation_data):
random_forest_model, xgboost_model, validation_data):

serialised_models = [('Harvey LogReg', harvey_model_log_reg), ('Harvey RandomForest', harvey_random_forest),
('Logistic Regression', logistic_regression_model), ('RandomForest', random_forest_model),
('XGBoost', xgboost_model), ('Neural Net', neural_net_model)]
('XGBoost', xgboost_model)]

alt.data_transformers.disable_max_rows()
all_pr_df = pd.DataFrame()
Expand Down Expand Up @@ -224,14 +222,14 @@ def plot_pr_roc_curve_comparison(harvey_model_log_reg, harvey_random_forest, log


def plot_permutation_importance_charts(harvey_model_log_reg, harvey_random_forest, logistic_regression_model,
random_forest_model, xgboost_model, neural_net_model, train_data,
random_forest_model, xgboost_model, train_data,
validation_data):
"""
Create permutation importance charts for each combination of the supplied models and the two supplied datasets
"""
serialised_models = [('Harvey LogReg', harvey_model_log_reg), ('Harvey RandomForest', harvey_random_forest),
('Logistic Regression', logistic_regression_model), ('RandomForest', random_forest_model),
('XGBoost', xgboost_model), ('Neural Net', neural_net_model)]
('XGBoost', xgboost_model)]
dfs = [('Training', train_data), ('Validation', validation_data)]

p_imp_plot_list = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def create_pipeline(**kwargs):
node(
func=create_evaluation_table,
inputs=["harvey_model_logistic_reg", "harvey_model_random_forest", "logistic_regression_model",
"random_forest_model", "xgboost_model", "neural_net_model", "validation_data"],
"random_forest_model", "xgboost_model", "validation_data"],
outputs="evaluation_table",
name="create_evaluation_table"
),
Expand All @@ -22,21 +22,20 @@ def create_pipeline(**kwargs):
node(
func=plot_pr_roc_curve_comparison,
inputs=["harvey_model_logistic_reg", "harvey_model_random_forest", "logistic_regression_model",
"random_forest_model", "xgboost_model", "neural_net_model", "validation_data"],
"random_forest_model", "xgboost_model", "validation_data"],
outputs=["pr_curve_comparison", "roc_curve_comparison"],
name="plot_pr_roc_curve_comparison"
),
node(
func=plot_permutation_importance_charts,
inputs=["harvey_model_logistic_reg", "harvey_model_random_forest", "logistic_regression_model",
"random_forest_model", "xgboost_model", "neural_net_model", "train_data", "validation_data"],
"random_forest_model", "xgboost_model", "train_data", "validation_data"],
outputs=["harvey_logistic_reg_permutation_imp_train", "harvey_logistic_reg_permutation_imp_validation",
"harvey_random_forest_permutation_imp_train",
"harvey_random_forest_permutation_imp_validation",
"logistic_regression_permutation_imp_train", "logistic_regression_permutation_imp_validation",
"random_forest_permutation_imp_train", "random_forest_permutation_imp_validation",
"xgboost_permutation_imp_train", "xgboost_permutation_imp_validation",
"neural_net_permutation_imp_train", "neural_net_permutation_imp_validation"
"xgboost_permutation_imp_train", "xgboost_permutation_imp_validation"
],
name="plot_permutation_imp_xgboost"
)
Expand Down

0 comments on commit eaf452c

Please sign in to comment.