diff --git a/src/mridle/experiment/architecture.py b/src/mridle/experiment/architecture.py index 1f4f290a..b22cc690 100644 --- a/src/mridle/experiment/architecture.py +++ b/src/mridle/experiment/architecture.py @@ -3,7 +3,7 @@ import skorch from sklearn.base import BaseEstimator from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor -from sklearn.linear_model import LogisticRegression +from sklearn.linear_model import LogisticRegression, Lasso from sklearn.pipeline import Pipeline from sklearn.compose import ColumnTransformer from sklearn.preprocessing import FunctionTransformer @@ -35,6 +35,7 @@ class ArchitectureInterface(ComponentInterface): registered_flavors = { 'RandomForestClassifier': RandomForestClassifier, # TODO enable auto-loading from sklearn 'RandomForestRegressor': RandomForestRegressor, # TODO enable auto-loading from sklearn + 'LassoRegressor': Lasso, # TODO enable auto-loading from sklearn 'LogisticRegression': LogisticRegression, 'XGBClassifier': xgb.XGBClassifier, 'Pipeline': Pipeline,