feat: add warm_start matching scikit-learn
Optimox committed Nov 12, 2021
1 parent a0fd306 commit d725101
Showing 5 changed files with 63 additions and 24 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -360,3 +360,7 @@ loaded_clf.load_model(saved_filepath)
/!\ TabNetPretrainer Only : Percentage of input features to mask during pretraining.

Should be between 0 and 1. The bigger it is, the harder the reconstruction task becomes.

- `warm_start` : bool (default=False)
To match the scikit-learn API, this defaults to False.
If True, a second call to `fit` resumes training from the current model weights instead of re-initializing the network.
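For reference, a minimal usage sketch of the new flag (per the diff below, `warm_start` is a `fit` parameter; the random data and hyperparameters here are illustrative only):

```python
import numpy as np
from pytorch_tabnet.tab_model import TabNetClassifier

X_train = np.random.rand(256, 10).astype(np.float32)
y_train = np.random.randint(0, 2, 256)

clf = TabNetClassifier()
clf.fit(X_train=X_train, y_train=y_train, max_epochs=5,
        batch_size=128, drop_last=False)

# warm_start=True resumes from the weights learned above instead of
# re-initializing the network (scikit-learn-style behaviour).
clf.fit(X_train=X_train, y_train=y_train, max_epochs=5,
        batch_size=128, drop_last=False, warm_start=True)
```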
51 changes: 30 additions & 21 deletions census_example.ipynb
@@ -158,15 +158,18 @@
"metadata": {},
"outputs": [],
"source": [
"clf = TabNetClassifier(cat_idxs=cat_idxs,\n",
" cat_dims=cat_dims,\n",
" cat_emb_dim=1,\n",
" optimizer_fn=torch.optim.Adam,\n",
" optimizer_params=dict(lr=2e-2),\n",
" scheduler_params={\"step_size\":50, # how to use learning rate scheduler\n",
" \"gamma\":0.9},\n",
" scheduler_fn=torch.optim.lr_scheduler.StepLR,\n",
" mask_type='entmax' # \"sparsemax\"\n",
"tabnet_params = {\"cat_idxs\":cat_idxs,\n",
" \"cat_dims\":cat_dims,\n",
" \"cat_emb_dim\":1,\n",
" \"optimizer_fn\":torch.optim.Adam,\n",
" \"optimizer_params\":dict(lr=2e-2),\n",
" \"scheduler_params\":{\"step_size\":50, # how to use learning rate scheduler\n",
" \"gamma\":0.9},\n",
" \"scheduler_fn\":torch.optim.lr_scheduler.StepLR,\n",
" \"mask_type\":'entmax' # \"sparsemax\"\n",
" }\n",
"\n",
"clf = TabNetClassifier(**tabnet_params\n",
" )"
]
},
@@ -199,7 +202,7 @@
"metadata": {},
"outputs": [],
"source": [
"max_epochs = 1000 if not os.getenv(\"CI\", False) else 2"
"max_epochs = 100 if not os.getenv(\"CI\", False) else 2"
]
},
{
@@ -210,17 +213,23 @@
},
"outputs": [],
"source": [
"clf.fit(\n",
" X_train=X_train, y_train=y_train,\n",
" eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
" eval_name=['train', 'valid'],\n",
" eval_metric=['auc'],\n",
" max_epochs=max_epochs , patience=20,\n",
" batch_size=1024, virtual_batch_size=128,\n",
" num_workers=0,\n",
" weights=1,\n",
" drop_last=False\n",
") "
"# This illustrates the warm_start=False behaviour\n",
"save_history = []\n",
"for _ in range(2):\n",
" clf.fit(\n",
" X_train=X_train, y_train=y_train,\n",
" eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
" eval_name=['train', 'valid'],\n",
" eval_metric=['auc'],\n",
" max_epochs=max_epochs , patience=20,\n",
" batch_size=1024, virtual_batch_size=128,\n",
" num_workers=0,\n",
" weights=1,\n",
" drop_last=False\n",
" )\n",
" save_history.append(clf.history[\"valid_auc\"])\n",
" \n",
"assert(np.all(np.array(save_history[0]==np.array(save_history[1]))))"
]
},
{
15 changes: 13 additions & 2 deletions pytorch_tabnet/abstract_model.py
@@ -13,7 +13,8 @@
create_dataloaders,
define_device,
ComplexEncoder,
check_input
check_input,
check_warm_start
)
from pytorch_tabnet.callbacks import (
CallbackContainer,
@@ -73,6 +74,10 @@ def __post_init__(self):
if self.verbose != 0:
warnings.warn(f"Device used : {self.device}")

# create deep copies of mutable parameters
self.optimizer_fn = copy.deepcopy(self.optimizer_fn)
self.scheduler_fn = copy.deepcopy(self.scheduler_fn)

def __update__(self, **kwargs):
"""
Updates parameters.
@@ -120,6 +125,7 @@ def fit(
callbacks=None,
pin_memory=True,
from_unsupervised=None,
warm_start=False
):
"""Train a neural network stored in self.network
Using train_dataloader for training data and
@@ -163,6 +169,8 @@
Whether to set pin_memory to True or False during training
from_unsupervised: unsupervised trained model
Use a previously self supervised model as starting weights
warm_start: bool
If True, current model parameters are used to start training
"""
# update model name

@@ -184,6 +192,7 @@
self.loss_fn = loss_fn

check_input(X_train)
check_warm_start(warm_start, from_unsupervised)

self.update_fit_params(
X_train,
Expand All @@ -203,7 +212,8 @@ def fit(
# Update parameters to match self pretraining
self.__update__(**from_unsupervised.get_params())

if not hasattr(self, "network"):
if not hasattr(self, "network") or not warm_start:
# model has never been fitted before or warm_start is False
self._set_network()
self._update_network_params()
self._set_metrics(eval_metric, eval_names)
Expand Down Expand Up @@ -542,6 +552,7 @@ def _predict_batch(self, X):

def _set_network(self):
"""Setup the network and explain matrix."""
torch.manual_seed(self.seed)
self.network = tab_network.TabNet(
self.input_dim,
self.output_dim,
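The branching above reduces to a small pattern; here is a standalone sketch (the `Model` class is hypothetical, not the library's API) of how `warm_start` interacts with the new `torch.manual_seed` call:

```python
import torch

class Model:
    """Hypothetical sketch of the warm-start branch in fit()."""
    def __init__(self, seed=0):
        self.seed = seed

    def fit(self, warm_start=False):
        # Rebuild the network on the first call, or whenever warm_start
        # is False; otherwise keep training the existing weights.
        if not hasattr(self, "network") or not warm_start:
            self._set_network()
        # ... training loop would run here ...

    def _set_network(self):
        # Re-seeding before every (re)build makes two cold starts produce
        # identical initial weights, which is what the notebook's
        # history-equality assert relies on.
        torch.manual_seed(self.seed)
        self.network = torch.nn.Linear(4, 2)

model = Model(seed=42)
model.fit()                 # builds the network
model.fit(warm_start=True)  # reuses the existing weights
model.fit()                 # warm_start=False: rebuilds from the same seed
```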
6 changes: 5 additions & 1 deletion pytorch_tabnet/pretraining.py
@@ -58,6 +58,7 @@ def fit(
drop_last=True,
callbacks=None,
pin_memory=True,
warm_start=False
):
"""Train a neural network stored in self.network
Using train_dataloader for training data and
@@ -130,8 +131,10 @@
X_train, eval_set
)

if not hasattr(self, 'network'):
if not hasattr(self, "network") or not warm_start:
# model has never been fitted before or warm_start is False
self._set_network()

self._update_network_params()
self._set_metrics(eval_names)
self._set_optimizer()
@@ -168,6 +171,7 @@ def _set_network(self):
"""Setup the network and explain matrix."""
if not hasattr(self, 'pretraining_ratio'):
self.pretraining_ratio = 0.5
torch.manual_seed(self.seed)
self.network = tab_network.TabNetPretraining(
self.input_dim,
pretraining_ratio=self.pretraining_ratio,
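For context, a hedged end-to-end sketch of how pretraining, `from_unsupervised`, and `warm_start` fit together (data shapes and epoch counts are illustrative; per `check_warm_start` below, `from_unsupervised` wins if both are set):

```python
import numpy as np
from pytorch_tabnet.pretraining import TabNetPretrainer
from pytorch_tabnet.tab_model import TabNetClassifier

X_train = np.random.rand(256, 10).astype(np.float32)
y_train = np.random.randint(0, 2, 256)

# Self-supervised pretraining.
pretrainer = TabNetPretrainer()
pretrainer.fit(X_train=X_train, max_epochs=2,
               batch_size=128, virtual_batch_size=64)

# Supervised fine-tuning from the pretrained weights. Passing
# warm_start=True as well would trigger the check_warm_start()
# warning, and the unsupervised weights would take precedence.
clf = TabNetClassifier()
clf.fit(X_train=X_train, y_train=y_train, max_epochs=2,
        batch_size=128, drop_last=False,
        from_unsupervised=pretrainer)
```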
11 changes: 11 additions & 0 deletions pytorch_tabnet/utils.py
@@ -6,6 +6,7 @@
import json
from sklearn.utils import check_array
import pandas as pd
import warnings


class TorchDataset(Dataset):
@@ -349,4 +350,14 @@ def check_input(X):
err_message = "Pandas DataFrames are not supported: apply X.values when calling fit"
raise ValueError(err_message)
check_array(X)


def check_warm_start(warm_start, from_unsupervised):
"""
Gives a warning about ambiguous usage of the two parameters.
"""
if warm_start and from_unsupervised is not None:
warn_msg = "warm_start=True and from_unsupervised != None: "
warn_msg = "warm_start will be ignore, training will start from unsupervised weights"
warnings.warn(warn_msg)
return
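A quick sketch of the warning in action (assuming the message concatenation fix above; any non-None object stands in for a pretrained model):

```python
import warnings
from pytorch_tabnet.utils import check_warm_start

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    check_warm_start(warm_start=True, from_unsupervised=object())

assert "warm_start will be ignored" in str(caught[0].message)
```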
