diff --git a/.flake8 b/.flake8
index 60df12f5..de7f8abf 100644
--- a/.flake8
+++ b/.flake8
@@ -1,6 +1,6 @@
 [flake8]
 max-line-length = 100
-ignore = E203
+ignore = E203, W503
 count = True
 exclude =
     .git,
diff --git a/README.md b/README.md
index 5ca16ef6..3635b263 100644
--- a/README.md
+++ b/README.md
@@ -96,6 +96,44 @@ clf.fit(
 
 A specific customization example notebook is available here : https://github.com/dreamquark-ai/tabnet/blob/develop/customizing_example.ipynb
 
+# Semi-supervised pre-training
+
+Added later to TabNet's original paper, semi-supervised pre-training is now available via the class `TabNetPretrainer`:
+
+```python
+# TabNetPretrainer
+unsupervised_model = TabNetPretrainer(
+    optimizer_fn=torch.optim.Adam,
+    optimizer_params=dict(lr=2e-2),
+    mask_type='entmax' # "sparsemax"
+)
+
+unsupervised_model.fit(
+    X_train=X_train,
+    eval_set=[X_valid],
+    pretraining_ratio=0.8,
+)
+
+clf = TabNetClassifier(
+    optimizer_fn=torch.optim.Adam,
+    optimizer_params=dict(lr=2e-2),
+    scheduler_params={"step_size":10, # how to use learning rate scheduler
+                      "gamma":0.9},
+    scheduler_fn=torch.optim.lr_scheduler.StepLR,
+    mask_type='sparsemax' # This will be overwritten if using pretrain model
+)
+
+clf.fit(
+    X_train=X_train, y_train=y_train,
+    eval_set=[(X_train, y_train), (X_valid, y_valid)],
+    eval_name=['train', 'valid'],
+    eval_metric=['auc'],
+    from_unsupervised=unsupervised_model
+)
+```
+
+A complete example can be found within the notebook `pretraining_example.ipynb`.
+
 # Useful links
 
 - [explanatory video](https://youtu.be/ysBaZO8YmX8)
diff --git a/pretraining_example.ipynb b/pretraining_example.ipynb
index af91f4dc..6dab9c3b 100644
--- a/pretraining_example.ipynb
+++ b/pretraining_example.ipynb
@@ -202,7 +202,7 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
-    "scrolled": true
+    "scrolled": false
    },
    "outputs": [],
    "source": [
@@ -210,7 +210,7 @@
    "    X_train=X_train,\n",
    "    eval_set=[X_valid],\n",
    "    max_epochs=max_epochs , patience=5,\n",
-    "    batch_size=1024, virtual_batch_size=128,\n",
+    "    batch_size=2048, virtual_batch_size=128,\n",
    "    num_workers=0,\n",
    "    drop_last=False,\n",
    "    pretraining_ratio=0.8,\n",
@@ -228,6 +228,30 @@
    "assert(reconstructed_X.shape==embedded_X.shape)"
   ]
  },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {
+   "scrolled": true
+  },
+  "outputs": [],
+  "source": [
+   "unsupervised_explain_matrix, unsupervised_masks = unsupervised_model.explain(X_valid)"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "fig, axs = plt.subplots(1, 3, figsize=(20,20))\n",
+   "\n",
+   "for i in range(3):\n",
+   "    axs[i].imshow(unsupervised_masks[i][:50])\n",
+   "    axs[i].set_title(f\"mask {i}\")\n"
+  ]
+ },
  {
   "cell_type": "markdown",
   "metadata": {},
@@ -272,7 +296,7 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
-    "scrolled": true
+    "scrolled": false
    },
    "outputs": [],
    "source": [
@@ -468,7 +492,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.7.6"
+   "version": "3.7.5"
   },
  "toc": {
   "base_numbering": 1,
diff --git a/pytorch_tabnet/abstract_model.py b/pytorch_tabnet/abstract_model.py
index 689697aa..d5a18887 100644
--- a/pytorch_tabnet/abstract_model.py
+++ b/pytorch_tabnet/abstract_model.py
@@ -31,6 +31,7 @@
 import warnings
 import copy
 
+
 @dataclass
 class TabModel(BaseEstimator):
     """ Class for TabNet model."""
@@ -74,22 +75,24 @@ def __update__(self, **kwargs):
         If does not already exists, creates it.
         Otherwise overwrite with warnings.
""" - update_list = ["cat_dims", - "cat_emb_dim", - "cat_idxs", - "input_dim", - "mask_type", - "n_a", - "n_d", - "n_independent", - "n_shared", - "n_steps"] + update_list = [ + "cat_dims", + "cat_emb_dim", + "cat_idxs", + "input_dim", + "mask_type", + "n_a", + "n_d", + "n_independent", + "n_shared", + "n_steps", + ] for var_name, value in kwargs.items(): if var_name in update_list: try: exec(f"global previous_val; previous_val = self.{var_name}") - if previous_val != value: # noqa - wrn_msg = f"Pretraining: {var_name} changed from {previous_val} to {value}" # noqa + if previous_val != value: # noqa + wrn_msg = f"Pretraining: {var_name} changed from {previous_val} to {value}" # noqa warnings.warn(wrn_msg) exec(f"self.{var_name} = value") except AttributeError: @@ -112,7 +115,7 @@ def fit( drop_last=False, callbacks=None, pin_memory=True, - from_unsupervised=None + from_unsupervised=None, ): """Train a neural network stored in self.network Using train_dataloader for training data and @@ -196,8 +199,9 @@ def fit( # Update parameters to match self pretraining self.__update__(**from_unsupervised.get_params()) - if not hasattr(self, 'network'): + if not hasattr(self, "network"): self._set_network() + self._update_network_params() self._set_metrics(eval_metric, eval_names) self._set_optimizer() self._set_callbacks(callbacks) @@ -318,7 +322,7 @@ def explain(self, X): def load_weights_from_unsupervised(self, unsupervised_model): update_state_dict = copy.deepcopy(self.network.state_dict()) for param, weights in unsupervised_model.network.state_dict().items(): - if param.startswith('encoder'): + if param.startswith("encoder"): # Convert encoder's layers name to match new_param = "tabnet." + param else: @@ -686,6 +690,9 @@ def _compute_feature_importances(self, loader): ) self.feature_importances_ = feature_importances_ / np.sum(feature_importances_) + def _update_network_params(self): + self.network.virtual_batch_size = self.virtual_batch_size + @abstractmethod def update_fit_params(self, X_train, y_train, eval_set, weights): """ diff --git a/pytorch_tabnet/callbacks.py b/pytorch_tabnet/callbacks.py index 7b6701c1..94892a8b 100644 --- a/pytorch_tabnet/callbacks.py +++ b/pytorch_tabnet/callbacks.py @@ -161,9 +161,11 @@ def on_train_end(self, logs=None): ) print(msg) else: - msg = (f"Stop training because you reached max_epochs = {self.trainer.max_epochs}" - + f" with best_epoch = {self.best_epoch} and " - + f"best_{self.early_stopping_metric} = {round(self.best_loss, 5)}") + msg = ( + f"Stop training because you reached max_epochs = {self.trainer.max_epochs}" + + f" with best_epoch = {self.best_epoch} and " + + f"best_{self.early_stopping_metric} = {round(self.best_loss, 5)}" + ) print(msg) print("Best weights from best epoch are automatically used!") @@ -196,7 +198,7 @@ def on_train_begin(self, logs=None): self.history.update({"lr": []}) self.history.update({name: [] for name in self.trainer._metrics_names}) self.start_time = logs["start_time"] - self.epoch_loss = 0. 
+        self.epoch_loss = 0.0
 
     def on_epoch_begin(self, epoch, logs=None):
         self.epoch_metrics = {"loss": 0.0}
@@ -220,8 +222,9 @@ def on_epoch_end(self, epoch, logs=None):
 
     def on_batch_end(self, batch, logs=None):
         batch_size = logs["batch_size"]
-        self.epoch_loss = (self.samples_seen * self.epoch_loss + batch_size * logs["loss"]
-                           ) / (self.samples_seen + batch_size)
+        self.epoch_loss = (
+            self.samples_seen * self.epoch_loss + batch_size * logs["loss"]
+        ) / (self.samples_seen + batch_size)
         self.samples_seen += batch_size
 
     def __getitem__(self, name):
@@ -256,11 +259,11 @@ class LRSchedulerCallback(Callback):
     early_stopping_metric: str
     is_batch_level: bool = False
 
-    def __post_init__(self, ):
-        self.is_metric_related = hasattr(self.scheduler_fn,
-                                         "is_better")
-        self.scheduler = self.scheduler_fn(self.optimizer,
-                                           **self.scheduler_params)
+    def __post_init__(
+        self,
+    ):
+        self.is_metric_related = hasattr(self.scheduler_fn, "is_better")
+        self.scheduler = self.scheduler_fn(self.optimizer, **self.scheduler_params)
         super().__init__()
 
     def on_batch_end(self, batch, logs=None):
diff --git a/pytorch_tabnet/metrics.py b/pytorch_tabnet/metrics.py
index fd8e1155..473f31b0 100644
--- a/pytorch_tabnet/metrics.py
+++ b/pytorch_tabnet/metrics.py
@@ -24,7 +24,7 @@ def UnsupervisedLoss(y_pred, embedded_x, obf_vars, eps=1e-9):
         Orginal input embedded by network
     obf_vars : torch.Tensor
         Binary mask for obfuscated variables.
-        1 means the variables was obfuscated so reconstruction is based on this.
+        1 means the variable was obfuscated so reconstruction is based on this.
     eps : float
         A small floating point to avoid ZeroDivisionError
         This can happen in degenerated case when a feature has only one value
@@ -32,12 +32,18 @@
     Returns
     -------
     loss : torch float
-        Unsupervised loss, average value over batch samples. 
+        Unsupervised loss, average value over batch samples.
     """
     errors = y_pred - embedded_x
-    reconstruction_errors = torch.mul(errors, obf_vars)**2
-    batch_stds = torch.std(embedded_x, dim=0)**2 + eps
+    reconstruction_errors = torch.mul(errors, obf_vars) ** 2
+    batch_stds = torch.std(embedded_x, dim=0) ** 2 + eps
     features_loss = torch.matmul(reconstruction_errors, 1 / batch_stds)
+    # compute the number of obfuscated variables reconstructed for each sample
+    nb_used_variables = torch.sum(obf_vars, dim=1)
+    # print(nb_used_variables)
+    # take the mean of the used variable errors
+    # print(features_loss)
+    features_loss = features_loss / (nb_used_variables + eps)
     # here we take the mean per batch, contrary to the paper
     loss = torch.mean(features_loss)
     return loss
@@ -162,7 +168,9 @@ def get_metrics_by_names(cls, names):
         available_names = [metric()._name for metric in available_metrics]
         metrics = []
         for name in names:
-            assert name in available_names, f"{name} is not available, choose in {available_names}"
+            assert (
+                name in available_names
+            ), f"{name} is not available, choose in {available_names}"
             idx = available_names.index(name)
             metric = available_metrics[idx]()
             metrics.append(metric)
diff --git a/pytorch_tabnet/multiclass_utils.py b/pytorch_tabnet/multiclass_utils.py
index 2076c62e..2f4b44aa 100644
--- a/pytorch_tabnet/multiclass_utils.py
+++ b/pytorch_tabnet/multiclass_utils.py
@@ -140,7 +140,7 @@ def _is_integral_float(y):
 
 
 def is_multilabel(y):
-    """ Check if ``y`` is in a multilabel format.
+    """Check if ``y`` is in a multilabel format.
 
     Parameters
     ----------
@@ -394,17 +394,15 @@ def infer_multitask_output(y_train):
 
     if len(y_train.shape) < 2:
         raise ValueError(
-            f"""y_train shoud be of shape (n_examples, n_tasks) """
-            + f"""but got {y_train.shape}"""
+            "y_train should be of shape (n_examples, n_tasks)"
+            + f"but got {y_train.shape}"
         )
     nb_tasks = y_train.shape[1]
     tasks_dims = []
     tasks_labels = []
     for task_idx in range(nb_tasks):
         try:
-            output_dim, train_labels = infer_output_dim(
-                y_train[:, task_idx]
-            )
+            output_dim, train_labels = infer_output_dim(y_train[:, task_idx])
             tasks_dims.append(output_dim)
             tasks_labels.append(train_labels)
         except ValueError as err:
diff --git a/pytorch_tabnet/pretraining.py b/pytorch_tabnet/pretraining.py
index 4e658cde..38a27320 100644
--- a/pytorch_tabnet/pretraining.py
+++ b/pytorch_tabnet/pretraining.py
@@ -131,6 +131,7 @@ def fit(
 
         if not hasattr(self, 'network'):
             self._set_network()
+        self._update_network_params()
         self._set_metrics(eval_names)
         self._set_optimizer()
         self._set_callbacks(callbacks)
@@ -192,6 +193,10 @@ def _set_network(self):
             self.network.post_embed_dim,
         )
 
+    def _update_network_params(self):
+        self.network.virtual_batch_size = self.virtual_batch_size
+        self.network.pretraining_ratio = self.pretraining_ratio
+
     def _set_metrics(self, eval_names):
         """Set attributes relative to the metrics.
 
diff --git a/pytorch_tabnet/pretraining_utils.py b/pytorch_tabnet/pretraining_utils.py
index 1c7ab610..3a05f0cc 100644
--- a/pytorch_tabnet/pretraining_utils.py
+++ b/pytorch_tabnet/pretraining_utils.py
@@ -1,7 +1,8 @@
 from torch.utils.data import DataLoader
-from pytorch_tabnet.utils import (create_sampler,
-                                  PredictDataset,
-                                  )
+from pytorch_tabnet.utils import (
+    create_sampler,
+    PredictDataset,
+)
 from sklearn.utils import check_array
 
 
@@ -49,7 +50,7 @@ def create_dataloaders(
             shuffle=need_shuffle,
             num_workers=num_workers,
             drop_last=drop_last,
-            pin_memory=pin_memory
+            pin_memory=pin_memory,
         )
 
     valid_dataloaders = []
@@ -62,7 +63,7 @@ def create_dataloaders(
                 shuffle=need_shuffle,
                 num_workers=num_workers,
                 drop_last=drop_last,
-                pin_memory=pin_memory
+                pin_memory=pin_memory,
             )
         )
 
@@ -94,7 +95,7 @@ def validate_eval_set(eval_set, eval_name, X_train):
     for set_nb, X in enumerate(eval_set):
         check_array(X)
         msg = (
-            f"Number of columns is different between eval set {set_nb} "
+            f"Number of columns is different between eval set {set_nb}"
             + f"({X.shape[1]}) and X_train ({X_train.shape[1]})"
         )
         assert X.shape[1] == X_train.shape[1], msg
diff --git a/pytorch_tabnet/tab_network.py b/pytorch_tabnet/tab_network.py
index 3750a530..345a7a12 100644
--- a/pytorch_tabnet/tab_network.py
+++ b/pytorch_tabnet/tab_network.py
@@ -364,8 +364,7 @@ def forward(self, x):
             masked_x, obf_vars = self.masker(embedded_x)
             # set prior of encoder with obf_mask
             prior = 1 - obf_vars
-            steps_out, _ = self.encoder(masked_x,
-                                        prior=prior)
+            steps_out, _ = self.encoder(masked_x, prior=prior)
             res = self.decoder(steps_out)
             return res, embedded_x, obf_vars
         else:
@@ -373,6 +372,10 @@ def forward(self, x):
             res = self.decoder(steps_out)
             return res, embedded_x, torch.ones(embedded_x.shape).to(x.device)
 
+    def forward_masks(self, x):
+        embedded_x = self.embedder(x)
+        return self.encoder.forward_masks(embedded_x)
+
 
 class TabNetNoEmbeddings(torch.nn.Module):
     def __init__(
@@ -887,6 +890,8 @@ def forward(self, x):
         -------
         masked input and obfuscated variables.
""" - obfuscated_vars = torch.bernoulli(self.pretraining_ratio * torch.ones(x.shape)).to(x.device) + obfuscated_vars = torch.bernoulli( + self.pretraining_ratio * torch.ones(x.shape) + ).to(x.device) masked_input = torch.mul(1 - obfuscated_vars, x) return masked_input, obfuscated_vars diff --git a/pytorch_tabnet/utils.py b/pytorch_tabnet/utils.py index 53f5b2f1..7a66bbc9 100644 --- a/pytorch_tabnet/utils.py +++ b/pytorch_tabnet/utils.py @@ -54,7 +54,7 @@ def __getitem__(self, index): def create_sampler(weights, y_train): """ This creates a sampler from the given weights - + Parameters ---------- weights : either 0, 1, dict or iterable @@ -146,7 +146,7 @@ def create_dataloaders( shuffle=need_shuffle, num_workers=num_workers, drop_last=drop_last, - pin_memory=pin_memory + pin_memory=pin_memory, ) valid_dataloaders = [] @@ -157,7 +157,7 @@ def create_dataloaders( batch_size=batch_size, shuffle=False, num_workers=num_workers, - pin_memory=pin_memory + pin_memory=pin_memory, ) )