feat: mask-dependent loss
chore: fix lint

chore: update README

feat: add explain to unsupervised training

feat: update network parameters

When the network is already defined, we still need to update some parameters
fed through the fit function, such as the virtual batch size and, in the case
of unsupervised pretraining, the pretraining_ratio.
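A rough sketch of that flow, assuming a simplified model class (it mirrors the `_update_network_params` hooks added in `abstract_model.py` and `pretraining.py` below, and is not the library's exact code):

```python
from types import SimpleNamespace


class _UnsupervisedModelSketch:
    """Illustrative only: how fit-time arguments reach an already-built network."""

    def _set_network(self):
        # Stand-in for the real TabNet network construction.
        self.network = SimpleNamespace(virtual_batch_size=None, pretraining_ratio=None)

    def _update_network_params(self):
        # Copy the values received by fit() onto the (possibly pre-existing) network.
        self.network.virtual_batch_size = self.virtual_batch_size
        self.network.pretraining_ratio = self.pretraining_ratio

    def fit(self, X_train, virtual_batch_size=128, pretraining_ratio=0.5):
        self.virtual_batch_size = virtual_batch_size
        self.pretraining_ratio = pretraining_ratio
        if not hasattr(self, "network"):
            self._set_network()        # first call: build the network
        self._update_network_params()  # every call: sync fit-time parameters
        # ... training loop omitted ...
```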
eduardocarvp authored and Optimox committed Dec 7, 2020
1 parent d4af838 commit 64052b0
Showing 11 changed files with 143 additions and 54 deletions.
2 changes: 1 addition & 1 deletion .flake8
@@ -1,6 +1,6 @@
[flake8]
max-line-length = 100
ignore = E203
ignore = E203, W503
count = True
exclude =
.git,
38 changes: 38 additions & 0 deletions README.md
@@ -96,6 +96,44 @@ clf.fit(

A specific customization example notebook is available here: https://github.com/dreamquark-ai/tabnet/blob/develop/customizing_example.ipynb

# Semi-supervised pre-training

Introduced after TabNet's original paper, semi-supervised pre-training is now available via the `TabNetPretrainer` class:

```python
# TabNetPretrainer
unsupervised_model = TabNetPretrainer(
optimizer_fn=torch.optim.Adam,
optimizer_params=dict(lr=2e-2),
mask_type='entmax' # "sparsemax"
)

unsupervised_model.fit(
X_train=X_train,
eval_set=[X_valid],
pretraining_ratio=0.8,
)

clf = TabNetClassifier(
optimizer_fn=torch.optim.Adam,
optimizer_params=dict(lr=2e-2),
scheduler_params={"step_size":10, # how to use learning rate scheduler
"gamma":0.9},
scheduler_fn=torch.optim.lr_scheduler.StepLR,
mask_type='sparsemax' # This will be overwritten if using pretrain model
)

clf.fit(
X_train=X_train, y_train=y_train,
eval_set=[(X_train, y_train), (X_valid, y_valid)],
eval_name=['train', 'valid'],
eval_metric=['auc'],
from_unsupervised=unsupervised_model
)
```

A complete example can be found in the notebook `pretraining_example.ipynb`.

# Useful links

- [explanatory video](https://youtu.be/ysBaZO8YmX8)
32 changes: 28 additions & 4 deletions pretraining_example.ipynb
@@ -202,15 +202,15 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
"scrolled": false
},
"outputs": [],
"source": [
"unsupervised_model.fit(\n",
" X_train=X_train,\n",
" eval_set=[X_valid],\n",
" max_epochs=max_epochs , patience=5,\n",
" batch_size=1024, virtual_batch_size=128,\n",
" batch_size=2048, virtual_batch_size=128,\n",
" num_workers=0,\n",
" drop_last=False,\n",
" pretraining_ratio=0.8,\n",
@@ -228,6 +228,30 @@
"assert(reconstructed_X.shape==embedded_X.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"unsupervised_explain_matrix, unsupervised_masks = unsupervised_model.explain(X_valid)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, axs = plt.subplots(1, 3, figsize=(20,20))\n",
"\n",
"for i in range(3):\n",
" axs[i].imshow(unsupervised_masks[i][:50])\n",
" axs[i].set_title(f\"mask {i}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -272,7 +296,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
"scrolled": false
},
"outputs": [],
"source": [
@@ -468,7 +492,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.7.5"
},
"toc": {
"base_numbering": 1,
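For readability, the notebook cells added above boil down to the following plain Python (illustrative; `unsupervised_model` and `X_valid` are defined in earlier cells of the notebook):

```python
import matplotlib.pyplot as plt

# Feature-attribution matrix and per-step masks from the pretrained model
unsupervised_explain_matrix, unsupervised_masks = unsupervised_model.explain(X_valid)

# Plot the first 50 rows of each of the three step masks
fig, axs = plt.subplots(1, 3, figsize=(20, 20))
for i in range(3):
    axs[i].imshow(unsupervised_masks[i][:50])
    axs[i].set_title(f"mask {i}")
```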
37 changes: 22 additions & 15 deletions pytorch_tabnet/abstract_model.py
@@ -31,6 +31,7 @@
import warnings
import copy


@dataclass
class TabModel(BaseEstimator):
""" Class for TabNet model."""
@@ -74,22 +75,24 @@ def __update__(self, **kwargs):
If it does not already exist, creates it.
Otherwise overwrite with warnings.
"""
update_list = ["cat_dims",
"cat_emb_dim",
"cat_idxs",
"input_dim",
"mask_type",
"n_a",
"n_d",
"n_independent",
"n_shared",
"n_steps"]
update_list = [
"cat_dims",
"cat_emb_dim",
"cat_idxs",
"input_dim",
"mask_type",
"n_a",
"n_d",
"n_independent",
"n_shared",
"n_steps",
]
for var_name, value in kwargs.items():
if var_name in update_list:
try:
exec(f"global previous_val; previous_val = self.{var_name}")
if previous_val != value: # noqa
wrn_msg = f"Pretraining: {var_name} changed from {previous_val} to {value}" # noqa
if previous_val != value: # noqa
wrn_msg = f"Pretraining: {var_name} changed from {previous_val} to {value}" # noqa
warnings.warn(wrn_msg)
exec(f"self.{var_name} = value")
except AttributeError:
@@ -112,7 +115,7 @@ def fit(
drop_last=False,
callbacks=None,
pin_memory=True,
from_unsupervised=None
from_unsupervised=None,
):
"""Train a neural network stored in self.network
Using train_dataloader for training data and
@@ -196,8 +199,9 @@
# Update parameters to match self pretraining
self.__update__(**from_unsupervised.get_params())

if not hasattr(self, 'network'):
if not hasattr(self, "network"):
self._set_network()
self._update_network_params()
self._set_metrics(eval_metric, eval_names)
self._set_optimizer()
self._set_callbacks(callbacks)
@@ -318,7 +322,7 @@ def explain(self, X):
def load_weights_from_unsupervised(self, unsupervised_model):
update_state_dict = copy.deepcopy(self.network.state_dict())
for param, weights in unsupervised_model.network.state_dict().items():
if param.startswith('encoder'):
if param.startswith("encoder"):
# Convert encoder's layers name to match
new_param = "tabnet." + param
else:
@@ -686,6 +690,9 @@ def _compute_feature_importances(self, loader):
)
self.feature_importances_ = feature_importances_ / np.sum(feature_importances_)

def _update_network_params(self):
self.network.virtual_batch_size = self.virtual_batch_size

@abstractmethod
def update_fit_params(self, X_train, y_train, eval_set, weights):
"""
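The `exec`-based attribute update in `__update__` above can also be expressed with `getattr`/`setattr`; a behaviour-equivalent sketch (not the committed code):

```python
import warnings


def update_from_pretrained(model, update_list, **kwargs):
    # Overwrite whitelisted attributes on `model`, warning when a value changes
    # (same behaviour as TabModel.__update__ above, without exec).
    for var_name, value in kwargs.items():
        if var_name not in update_list:
            continue
        if hasattr(model, var_name) and getattr(model, var_name) != value:
            warnings.warn(
                f"Pretraining: {var_name} changed from {getattr(model, var_name)} to {value}"
            )
        setattr(model, var_name, value)
```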
25 changes: 14 additions & 11 deletions pytorch_tabnet/callbacks.py
@@ -161,9 +161,11 @@ def on_train_end(self, logs=None):
)
print(msg)
else:
msg = (f"Stop training because you reached max_epochs = {self.trainer.max_epochs}"
+ f" with best_epoch = {self.best_epoch} and "
+ f"best_{self.early_stopping_metric} = {round(self.best_loss, 5)}")
msg = (
f"Stop training because you reached max_epochs = {self.trainer.max_epochs}"
+ f" with best_epoch = {self.best_epoch} and "
+ f"best_{self.early_stopping_metric} = {round(self.best_loss, 5)}"
)
print(msg)
print("Best weights from best epoch are automatically used!")

@@ -196,7 +198,7 @@ def on_train_begin(self, logs=None):
self.history.update({"lr": []})
self.history.update({name: [] for name in self.trainer._metrics_names})
self.start_time = logs["start_time"]
self.epoch_loss = 0.
self.epoch_loss = 0.0

def on_epoch_begin(self, epoch, logs=None):
self.epoch_metrics = {"loss": 0.0}
@@ -220,8 +222,9 @@ def on_batch_end(self, batch, logs=None):

def on_batch_end(self, batch, logs=None):
batch_size = logs["batch_size"]
self.epoch_loss = (self.samples_seen * self.epoch_loss + batch_size * logs["loss"]
) / (self.samples_seen + batch_size)
self.epoch_loss = (
self.samples_seen * self.epoch_loss + batch_size * logs["loss"]
) / (self.samples_seen + batch_size)
self.samples_seen += batch_size

def __getitem__(self, name):
@@ -256,11 +259,11 @@ class LRSchedulerCallback(Callback):
early_stopping_metric: str
is_batch_level: bool = False

def __post_init__(self, ):
self.is_metric_related = hasattr(self.scheduler_fn,
"is_better")
self.scheduler = self.scheduler_fn(self.optimizer,
**self.scheduler_params)
def __post_init__(
self,
):
self.is_metric_related = hasattr(self.scheduler_fn, "is_better")
self.scheduler = self.scheduler_fn(self.optimizer, **self.scheduler_params)
super().__init__()

def on_batch_end(self, batch, logs=None):
18 changes: 13 additions & 5 deletions pytorch_tabnet/metrics.py
@@ -24,20 +24,26 @@ def UnsupervisedLoss(y_pred, embedded_x, obf_vars, eps=1e-9):
Orginal input embedded by network
obf_vars : torch.Tensor
Binary mask for obfuscated variables.
1 means the variables was obfuscated so reconstruction is based on this.
1 means the variable was obfuscated so reconstruction is based on this.
eps : float
A small floating point to avoid ZeroDivisionError
This can happen in degenerated case when a feature has only one value
Returns
-------
loss : torch float
Unsupervised loss, average value over batch samples.
Unsupervised loss, average value over batch samples.
"""
errors = y_pred - embedded_x
reconstruction_errors = torch.mul(errors, obf_vars)**2
batch_stds = torch.std(embedded_x, dim=0)**2 + eps
reconstruction_errors = torch.mul(errors, obf_vars) ** 2
batch_stds = torch.std(embedded_x, dim=0) ** 2 + eps
features_loss = torch.matmul(reconstruction_errors, 1 / batch_stds)
# compute the number of obfuscated (reconstructed) variables per sample
nb_used_variables = torch.sum(obf_vars, dim=1)
# print(nb_used_variables)
# take the mean of the used variable errors
# print(features_loss)
features_loss = features_loss / (nb_used_variables + eps)
# here we take the mean per batch, contrary to the paper
loss = torch.mean(features_loss)
return loss
@@ -162,7 +168,9 @@ def get_metrics_by_names(cls, names):
available_names = [metric()._name for metric in available_metrics]
metrics = []
for name in names:
assert name in available_names, f"{name} is not available, choose in {available_names}"
assert (
name in available_names
), f"{name} is not available, choose in {available_names}"
idx = available_names.index(name)
metric = available_metrics[idx]()
metrics.append(metric)
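Assembled from the hunk above, the mask-dependent unsupervised loss reads as follows (a self-contained restatement; the function name is illustrative):

```python
import torch


def unsupervised_loss_sketch(y_pred, embedded_x, obf_vars, eps=1e-9):
    # Squared reconstruction error, restricted to the obfuscated entries (obf_vars == 1)
    errors = y_pred - embedded_x
    reconstruction_errors = torch.mul(errors, obf_vars) ** 2
    # Normalize each feature by its batch variance (eps guards constant features)
    batch_stds = torch.std(embedded_x, dim=0) ** 2 + eps
    features_loss = torch.matmul(reconstruction_errors, 1 / batch_stds)
    # Mask-dependent part: average over the number of obfuscated variables per sample
    nb_obfuscated_variables = torch.sum(obf_vars, dim=1)
    features_loss = features_loss / (nb_obfuscated_variables + eps)
    # Mean over the batch
    return torch.mean(features_loss)
```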
10 changes: 4 additions & 6 deletions pytorch_tabnet/multiclass_utils.py
@@ -140,7 +140,7 @@ def _is_integral_float(y):


def is_multilabel(y):
""" Check if ``y`` is in a multilabel format.
"""Check if ``y`` is in a multilabel format.
Parameters
----------
@@ -394,17 +394,15 @@ def infer_multitask_output(y_train):

if len(y_train.shape) < 2:
raise ValueError(
f"""y_train shoud be of shape (n_examples, n_tasks) """
+ f"""but got {y_train.shape}"""
"y_train shoud be of shape (n_examples, n_tasks)"
+ f"but got {y_train.shape}"
)
nb_tasks = y_train.shape[1]
tasks_dims = []
tasks_labels = []
for task_idx in range(nb_tasks):
try:
output_dim, train_labels = infer_output_dim(
y_train[:, task_idx]
)
output_dim, train_labels = infer_output_dim(y_train[:, task_idx])
tasks_dims.append(output_dim)
tasks_labels.append(train_labels)
except ValueError as err:
5 changes: 5 additions & 0 deletions pytorch_tabnet/pretraining.py
@@ -131,6 +131,7 @@ def fit(

if not hasattr(self, 'network'):
self._set_network()
self._update_network_params()
self._set_metrics(eval_names)
self._set_optimizer()
self._set_callbacks(callbacks)
@@ -192,6 +193,10 @@ def _set_network(self):
self.network.post_embed_dim,
)

def _update_network_params(self):
self.network.virtual_batch_size = self.virtual_batch_size
self.network.pretraining_ratio = self.pretraining_ratio

def _set_metrics(self, eval_names):
"""Set attributes relative to the metrics.
13 changes: 7 additions & 6 deletions pytorch_tabnet/pretraining_utils.py
@@ -1,7 +1,8 @@
from torch.utils.data import DataLoader
from pytorch_tabnet.utils import (create_sampler,
PredictDataset,
)
from pytorch_tabnet.utils import (
create_sampler,
PredictDataset,
)
from sklearn.utils import check_array


@@ -49,7 +50,7 @@ def create_dataloaders(
shuffle=need_shuffle,
num_workers=num_workers,
drop_last=drop_last,
pin_memory=pin_memory
pin_memory=pin_memory,
)

valid_dataloaders = []
@@ -62,7 +63,7 @@
shuffle=need_shuffle,
num_workers=num_workers,
drop_last=drop_last,
pin_memory=pin_memory
pin_memory=pin_memory,
)
)

@@ -94,7 +95,7 @@ def validate_eval_set(eval_set, eval_name, X_train):
for set_nb, X in enumerate(eval_set):
check_array(X)
msg = (
f"Number of columns is different between eval set {set_nb} "
f"Number of columns is different between eval set {set_nb}"
+ f"({X.shape[1]}) and X_train ({X_train.shape[1]})"
)
assert X.shape[1] == X_train.shape[1], msg