diff --git a/pretraining_example.ipynb b/pretraining_example.ipynb
index 2594403d..f443bf62 100644
--- a/pretraining_example.ipynb
+++ b/pretraining_example.ipynb
@@ -2,7 +2,17 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -33,7 +43,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -44,9 +54,17 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Downloading file...\n"
+     ]
+    }
+   ],
    "source": [
     "out.parent.mkdir(parents=True, exist_ok=True)\n",
     "if out.exists():\n",
@@ -65,7 +83,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -90,9 +108,31 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "39 73\n",
+      " State-gov 9\n",
+      " Bachelors 16\n",
+      " 13 16\n",
+      " Never-married 7\n",
+      " Adm-clerical 15\n",
+      " Not-in-family 6\n",
+      " White 5\n",
+      " Male 2\n",
+      " 2174 119\n",
+      " 0 92\n",
+      " 40 94\n",
+      " United-States 42\n",
+      " <=50K 2\n",
+      "Set 3\n"
+     ]
+    }
+   ],
    "source": [
     "nunique = train.nunique()\n",
     "types = train.dtypes\n",
@@ -120,7 +160,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -135,7 +175,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -158,7 +198,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 24,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -167,9 +207,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/work/pytorch_tabnet/abstract_model.py:75: UserWarning: Device used : cpu\n",
+      " warnings.warn(f\"Device used : {self.device}\")\n"
+     ]
+    }
+   ],
    "source": [
     "# TabNetPretrainer\n",
     "unsupervised_model = TabNetPretrainer(\n",
@@ -193,20 +242,38 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 26,
    "metadata": {},
    "outputs": [],
    "source": [
-    "max_epochs = 1000 if not os.getenv(\"CI\", False) else 2"
+    "max_epochs = 2 if not os.getenv(\"CI\", False) else 2"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 27,
    "metadata": {
     "scrolled": false
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "epoch 0 | loss: 6.48655 | val_0_unsup_loss_numpy: 2.0507700443267822| 0:00:07s\n",
+      "epoch 1 | loss: 1.61586 | val_0_unsup_loss_numpy: 1.2413300275802612| 0:00:15s\n",
+      "Stop training because you reached max_epochs = 2 with best_epoch = 1 and best_val_0_unsup_loss_numpy = 1.2413300275802612\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/work/pytorch_tabnet/callbacks.py:172: UserWarning: Best weights from best epoch are automatically used!\n",
+      " warnings.warn(wrn_msg)\n"
+     ]
+    }
+   ],
"source": [ "unsupervised_model.fit(\n", " X_train=X_train,\n", @@ -216,7 +283,6 @@ " num_workers=0,\n", " drop_last=False,\n", " pretraining_ratio=0.8,\n", - "\n", ") " ] }, diff --git a/pytorch_tabnet/metrics.py b/pytorch_tabnet/metrics.py index 230f9d80..d0f22bd9 100644 --- a/pytorch_tabnet/metrics.py +++ b/pytorch_tabnet/metrics.py @@ -50,6 +50,21 @@ def UnsupervisedLoss(y_pred, embedded_x, obf_vars, eps=1e-9): return loss +def UnsupervisedLossNumpy(y_pred, embedded_x, obf_vars, eps=1e-9): + errors = y_pred - embedded_x + reconstruction_errors = np.multiply(errors, obf_vars) ** 2 + batch_stds = np.std(embedded_x, axis=0) ** 2 + eps + features_loss = np.matmul(reconstruction_errors, 1 / batch_stds) + # compute the number of obfuscated variables to reconstruct + nb_reconstructed_variables = np.sum(obf_vars, axis=1) + # take the mean of the reconstructed variable errors + features_loss = features_loss / (nb_reconstructed_variables + eps) + # here we take the mean per batch, contrary to the paper + loss = np.mean(features_loss) + return loss + + + @dataclass class UnsupMetricContainer: """Container holding a list of metrics. @@ -413,6 +428,41 @@ def __call__(self, y_pred, embedded_x, obf_vars): return loss.item() +class UnsupervisedNumpyMetric(Metric): + """ + Unsupervised metric + """ + + def __init__(self): + self._name = "unsup_loss_numpy" + self._maximize = False + + def __call__(self, y_pred, embedded_x, obf_vars): + """ + Compute MSE (Mean Squared Error) of predictions. + + Parameters + ---------- + y_pred : torch.Tensor or np.array + Reconstructed prediction (with embeddings) + embedded_x : torch.Tensor + Original input embedded by network + obf_vars : torch.Tensor + Binary mask for obfuscated variables. + 1 means the variables was obfuscated so reconstruction is based on this. + + Returns + ------- + float + MSE of predictions vs targets. + """ + return UnsupervisedLossNumpy( + y_pred, + embedded_x, + obf_vars + ) + + class RMSE(Metric): """ Root Mean Squared Error. diff --git a/pytorch_tabnet/pretraining.py b/pytorch_tabnet/pretraining.py index cceff32e..fb34ef3b 100644 --- a/pytorch_tabnet/pretraining.py +++ b/pytorch_tabnet/pretraining.py @@ -26,7 +26,7 @@ def __post_init__(self): super(TabNetPretrainer, self).__post_init__() self._task = 'unsupervised' self._default_loss = UnsupervisedLoss - self._default_metric = 'unsup_loss' + self._default_metric = 'unsup_loss_numpy' def prepare_target(self, y): return y @@ -341,9 +341,9 @@ def _predict_epoch(self, name, loader): # Main loop for batch_idx, X in enumerate(loader): output, embedded_x, obf_vars = self._predict_batch(X) - list_output.append(output) - list_embedded_x.append(embedded_x) - list_obfuscation.append(obf_vars) + list_output.append(output.cpu().detach().numpy()) + list_embedded_x.append(embedded_x.cpu().detach().numpy()) + list_obfuscation.append(obf_vars.cpu().detach().numpy()) output, embedded_x, obf_vars = self.stack_batches(list_output, list_embedded_x, @@ -372,9 +372,9 @@ def _predict_batch(self, X): return self.network(X) def stack_batches(self, list_output, list_embedded_x, list_obfuscation): - output = torch.cat(list_output, axis=0) - embedded_x = torch.cat(list_embedded_x, axis=0) - obf_vars = torch.cat(list_obfuscation, axis=0) + output = np.vstack(list_output) + embedded_x = np.vstack(list_embedded_x) + obf_vars = np.vstack(list_obfuscation) return output, embedded_x, obf_vars def predict(self, X):