From d725101a559c6be49a6f8e20c3e68b18b8eb7b01 Mon Sep 17 00:00:00 2001
From: Optimox
Date: Mon, 8 Nov 2021 15:41:52 +0100
Subject: [PATCH] feat: add warm_start matching scikit-learn

---
 README.md                        |  4 +++
 census_example.ipynb             | 51 +++++++++++++++++++-------------
 pytorch_tabnet/abstract_model.py | 15 ++++++++--
 pytorch_tabnet/pretraining.py    |  6 +++-
 pytorch_tabnet/utils.py          | 11 +++++++
 5 files changed, 63 insertions(+), 24 deletions(-)

diff --git a/README.md b/README.md
index e71512d6..1e9e11c8 100644
--- a/README.md
+++ b/README.md
@@ -360,3 +360,7 @@ loaded_clf.load_model(saved_filepath)
 /!\ TabNetPretrainer Only : Percentage of input features to mask during pretraining.
 Should be between 0 and 1. The bigger the harder the reconstruction task is.
+
+- `warm_start` : bool (default=False)
+    Set to False by default to match the scikit-learn API.
+    When True, a second call to `fit` resumes training from the current model weights instead of reinitializing them.
diff --git a/census_example.ipynb b/census_example.ipynb
index 3ceb0ad6..7ef1caf8 100755
--- a/census_example.ipynb
+++ b/census_example.ipynb
@@ -158,15 +158,18 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "clf = TabNetClassifier(cat_idxs=cat_idxs,\n",
-    "                       cat_dims=cat_dims,\n",
-    "                       cat_emb_dim=1,\n",
-    "                       optimizer_fn=torch.optim.Adam,\n",
-    "                       optimizer_params=dict(lr=2e-2),\n",
-    "                       scheduler_params={\"step_size\":50, # how to use learning rate scheduler\n",
-    "                                         \"gamma\":0.9},\n",
-    "                       scheduler_fn=torch.optim.lr_scheduler.StepLR,\n",
-    "                       mask_type='entmax' # \"sparsemax\"\n",
+    "tabnet_params = {\"cat_idxs\":cat_idxs,\n",
+    "                 \"cat_dims\":cat_dims,\n",
+    "                 \"cat_emb_dim\":1,\n",
+    "                 \"optimizer_fn\":torch.optim.Adam,\n",
+    "                 \"optimizer_params\":dict(lr=2e-2),\n",
+    "                 \"scheduler_params\":{\"step_size\":50, # how to use learning rate scheduler\n",
+    "                                     \"gamma\":0.9},\n",
+    "                 \"scheduler_fn\":torch.optim.lr_scheduler.StepLR,\n",
+    "                 \"mask_type\":'entmax' # \"sparsemax\"\n",
+    "                }\n",
+    "\n",
+    "clf = TabNetClassifier(**tabnet_params\n",
     "                      )"
    ]
   },
@@ -199,7 +202,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "max_epochs = 1000 if not os.getenv(\"CI\", False) else 2"
+    "max_epochs = 100 if not os.getenv(\"CI\", False) else 2"
    ]
   },
   {
@@ -210,17 +213,23 @@
    },
    "outputs": [],
    "source": [
-    "clf.fit(\n",
-    "    X_train=X_train, y_train=y_train,\n",
-    "    eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
-    "    eval_name=['train', 'valid'],\n",
-    "    eval_metric=['auc'],\n",
-    "    max_epochs=max_epochs , patience=20,\n",
-    "    batch_size=1024, virtual_batch_size=128,\n",
-    "    num_workers=0,\n",
-    "    weights=1,\n",
-    "    drop_last=False\n",
-    ") "
+    "# This illustrates the warm_start=False behaviour\n",
+    "save_history = []\n",
+    "for _ in range(2):\n",
+    "    clf.fit(\n",
+    "        X_train=X_train, y_train=y_train,\n",
+    "        eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
+    "        eval_name=['train', 'valid'],\n",
+    "        eval_metric=['auc'],\n",
+    "        max_epochs=max_epochs, patience=20,\n",
+    "        batch_size=1024, virtual_batch_size=128,\n",
+    "        num_workers=0,\n",
+    "        weights=1,\n",
+    "        drop_last=False\n",
+    "    )\n",
+    "    save_history.append(clf.history[\"valid_auc\"])\n",
+    "\n",
+    "assert np.all(np.array(save_history[0]) == np.array(save_history[1]))"
    ]
   },
   {
diff --git a/pytorch_tabnet/abstract_model.py b/pytorch_tabnet/abstract_model.py
index db5e9369..9f5cc023 100644
--- a/pytorch_tabnet/abstract_model.py
+++ b/pytorch_tabnet/abstract_model.py
@@ -13,7 +13,8 @@
     create_dataloaders,
     define_device,
     ComplexEncoder,
-    check_input
+    check_input,
+    check_warm_start
 )
 from pytorch_tabnet.callbacks import (
     CallbackContainer,
@@ -73,6 +74,10 @@ def __post_init__(self):
         if self.verbose != 0:
             warnings.warn(f"Device used : {self.device}")
 
+        # create deep copies of mutable parameters
+        self.optimizer_fn = copy.deepcopy(self.optimizer_fn)
+        self.scheduler_fn = copy.deepcopy(self.scheduler_fn)
+
     def __update__(self, **kwargs):
         """
         Updates parameters.
@@ -120,6 +125,7 @@ def fit(
         callbacks=None,
         pin_memory=True,
         from_unsupervised=None,
+        warm_start=False
     ):
         """Train a neural network stored in self.network
         Using train_dataloader for training data and
@@ -163,6 +169,8 @@
             Whether to set pin_memory to True or False during training
         from_unsupervised: unsupervised trained model
             Use a previously self supervised model as starting weights
+        warm_start: bool
+            If True, training resumes from the model's current weights instead of reinitializing
         """
 
         # update model name
@@ -184,6 +192,7 @@
         self.loss_fn = loss_fn
 
         check_input(X_train)
+        check_warm_start(warm_start, from_unsupervised)
 
         self.update_fit_params(
             X_train,
@@ -203,7 +212,8 @@
             # Update parameters to match self pretraining
             self.__update__(**from_unsupervised.get_params())
 
-        if not hasattr(self, "network"):
+        if not hasattr(self, "network") or not warm_start:
+            # model has never been fitted before or warm_start is False
             self._set_network()
         self._update_network_params()
         self._set_metrics(eval_metric, eval_names)
@@ -542,6 +552,7 @@ def _predict_batch(self, X):
 
     def _set_network(self):
         """Setup the network and explain matrix."""
+        torch.manual_seed(self.seed)
         self.network = tab_network.TabNet(
             self.input_dim,
             self.output_dim,
diff --git a/pytorch_tabnet/pretraining.py b/pytorch_tabnet/pretraining.py
index fd6fdc8c..cceff32e 100644
--- a/pytorch_tabnet/pretraining.py
+++ b/pytorch_tabnet/pretraining.py
@@ -58,6 +58,7 @@ def fit(
         drop_last=True,
         callbacks=None,
         pin_memory=True,
+        warm_start=False
     ):
         """Train a neural network stored in self.network
         Using train_dataloader for training data and
@@ -130,8 +131,10 @@
             X_train, eval_set
         )
 
-        if not hasattr(self, 'network'):
+        if not hasattr(self, "network") or not warm_start:
+            # model has never been fitted before or warm_start is False
             self._set_network()
+        self._update_network_params()
         self._set_metrics(eval_names)
         self._set_optimizer()
 
@@ -168,6 +171,7 @@ def _set_network(self):
         """Setup the network and explain matrix."""
         if not hasattr(self, 'pretraining_ratio'):
             self.pretraining_ratio = 0.5
+        torch.manual_seed(self.seed)
         self.network = tab_network.TabNetPretraining(
             self.input_dim,
             pretraining_ratio=self.pretraining_ratio,
diff --git a/pytorch_tabnet/utils.py b/pytorch_tabnet/utils.py
index 5eedae7e..aaf908cf 100644
--- a/pytorch_tabnet/utils.py
+++ b/pytorch_tabnet/utils.py
@@ -6,6 +6,7 @@
 import json
 from sklearn.utils import check_array
 import pandas as pd
+import warnings
 
 
 class TorchDataset(Dataset):
@@ -349,4 +350,14 @@ def check_input(X):
         err_message = "Pandas DataFrame are not supported: apply X.values when calling fit"
         raise(ValueError, err_message)
     check_array(X)
+
+
+def check_warm_start(warm_start, from_unsupervised):
+    """
+    Warn about ambiguous usage when both parameters are set.
+    """
+    if warm_start and from_unsupervised is not None:
+        warn_msg = "warm_start=True and from_unsupervised != None: "
+        warn_msg += "warm_start will be ignored, training will start from unsupervised weights"
+        warnings.warn(warn_msg)
     return
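
Below is a minimal usage sketch of the new `warm_start` flag. It is not part of the patch: the toy data, `seed=0`, and epoch counts are illustrative placeholders; only `TabNetClassifier`, `fit`, and the new `warm_start` argument come from the change above.

```python
import numpy as np
from pytorch_tabnet.tab_model import TabNetClassifier

# Toy stand-in data (placeholder): 256 rows, 10 numeric features, binary target.
X = np.random.rand(256, 10).astype(np.float32)
y = np.random.randint(0, 2, 256)

clf = TabNetClassifier(seed=0)

# Default warm_start=False: _set_network() reseeds torch and rebuilds the
# network, so fitting twice from scratch gives identical training histories
# (this is what the census_example notebook assert checks).
clf.fit(X_train=X, y_train=y, max_epochs=5)
first_losses = list(clf.history["loss"])

# warm_start=True: the existing network is kept and training resumes from
# the weights left by the previous fit call.
clf.fit(X_train=X, y_train=y, max_epochs=5, warm_start=True)
```

Note that passing both `warm_start=True` and `from_unsupervised` now emits the `check_warm_start` warning: warm_start is ignored and training starts from the unsupervised weights.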