Skip to content

Commit

Permalink
fix: compute unsupervised loss using numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
eduardocarvp authored and Optimox committed Dec 27, 2021
1 parent ca14b76 commit 49bd61b
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 27 deletions.
106 changes: 86 additions & 20 deletions pretraining_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -33,7 +43,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -44,9 +54,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading file...\n"
]
}
],
"source": [
"out.parent.mkdir(parents=True, exist_ok=True)\n",
"if out.exists():\n",
Expand All @@ -65,7 +83,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -90,9 +108,31 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"39 73\n",
" State-gov 9\n",
" Bachelors 16\n",
" 13 16\n",
" Never-married 7\n",
" Adm-clerical 15\n",
" Not-in-family 6\n",
" White 5\n",
" Male 2\n",
" 2174 119\n",
" 0 92\n",
" 40 94\n",
" United-States 42\n",
" <=50K 2\n",
"Set 3\n"
]
}
],
"source": [
"nunique = train.nunique()\n",
"types = train.dtypes\n",
Expand Down Expand Up @@ -120,7 +160,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -135,7 +175,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -158,7 +198,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -167,9 +207,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/work/pytorch_tabnet/abstract_model.py:75: UserWarning: Device used : cpu\n",
" warnings.warn(f\"Device used : {self.device}\")\n"
]
}
],
"source": [
"# TabNetPretrainer\n",
"unsupervised_model = TabNetPretrainer(\n",
Expand All @@ -193,20 +242,38 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"max_epochs = 1000 if not os.getenv(\"CI\", False) else 2"
"max_epochs = 2 if not os.getenv(\"CI\", False) else 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 27,
"metadata": {
"scrolled": false
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 0 | loss: 6.48655 | val_0_unsup_loss_numpy: 2.0507700443267822| 0:00:07s\n",
"epoch 1 | loss: 1.61586 | val_0_unsup_loss_numpy: 1.2413300275802612| 0:00:15s\n",
"Stop training because you reached max_epochs = 2 with best_epoch = 1 and best_val_0_unsup_loss_numpy = 1.2413300275802612\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/work/pytorch_tabnet/callbacks.py:172: UserWarning: Best weights from best epoch are automatically used!\n",
" warnings.warn(wrn_msg)\n"
]
}
],
"source": [
"unsupervised_model.fit(\n",
" X_train=X_train,\n",
Expand All @@ -216,7 +283,6 @@
" num_workers=0,\n",
" drop_last=False,\n",
" pretraining_ratio=0.8,\n",
"\n",
") "
]
},
Expand Down
50 changes: 50 additions & 0 deletions pytorch_tabnet/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,21 @@ def UnsupervisedLoss(y_pred, embedded_x, obf_vars, eps=1e-9):
return loss


def UnsupervisedLossNumpy(y_pred, embedded_x, obf_vars, eps=1e-9):
    """
    NumPy implementation of the unsupervised reconstruction loss.

    Squared reconstruction errors are kept only where ``obf_vars`` is
    non-zero, weighted by the inverse per-column variance of ``embedded_x``,
    divided by the number of obfuscated entries of each row, and finally
    averaged over all rows.

    Parameters
    ----------
    y_pred : np.ndarray
        Reconstructed values.
    embedded_x : np.ndarray
        Reference values the reconstruction is compared against.
    obf_vars : np.ndarray
        Mask multiplied element-wise with the errors; per the original
        comments, it marks the obfuscated variables to reconstruct.
    eps : float
        Small constant added to the variances and to the per-row mask
        counts so that neither division is a division by zero.

    Returns
    -------
    float
        Mean of the per-row losses (a NumPy floating scalar).

    Notes
    -----
    The inputs are expected to be 2-D arrays of identical shape
    (rows x columns) -- the code reduces over axis 0 and axis 1; confirm
    against the callers.
    """
    # Zero out the errors of entries that were not obfuscated, then square.
    masked_residuals = np.multiply(y_pred - embedded_x, obf_vars)
    squared_residuals = masked_residuals ** 2

    # Per-column variance (np.std squared), shifted by eps before inversion.
    inverse_variances = 1 / (np.std(embedded_x, axis=0) ** 2 + eps)

    # Variance-weighted sum of squared errors for each row.
    per_row_loss = np.matmul(squared_residuals, inverse_variances)

    # Number of obfuscated variables to reconstruct in each row.
    masked_counts = np.sum(obf_vars, axis=1) + eps

    # Mean over the rows of the batch, contrary to the paper.
    return np.mean(per_row_loss / masked_counts)



@dataclass
class UnsupMetricContainer:
"""Container holding a list of metrics.
Expand Down Expand Up @@ -413,6 +428,41 @@ def __call__(self, y_pred, embedded_x, obf_vars):
return loss.item()


class UnsupervisedNumpyMetric(Metric):
    """
    Metric wrapper around ``UnsupervisedLossNumpy``.

    Exposes the name ``unsup_loss_numpy`` and is flagged as a metric to
    minimize (``_maximize = False``).
    """

    def __init__(self):
        self._maximize = False
        self._name = "unsup_loss_numpy"

    def __call__(self, y_pred, embedded_x, obf_vars):
        """
        Evaluate the unsupervised reconstruction loss with NumPy.

        Parameters
        ----------
        y_pred : np.ndarray
            Reconstructed prediction (with embeddings).
        embedded_x : np.ndarray
            Original input embedded by the network.
        obf_vars : np.ndarray
            Binary mask for obfuscated variables.
            1 means the variable was obfuscated, so the reconstruction
            error is computed on it.

        Returns
        -------
        float
            Value returned by ``UnsupervisedLossNumpy``, called with its
            default ``eps``.
        """
        loss = UnsupervisedLossNumpy(y_pred, embedded_x, obf_vars)
        return loss


class RMSE(Metric):
"""
Root Mean Squared Error.
Expand Down
14 changes: 7 additions & 7 deletions pytorch_tabnet/pretraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __post_init__(self):
super(TabNetPretrainer, self).__post_init__()
self._task = 'unsupervised'
self._default_loss = UnsupervisedLoss
self._default_metric = 'unsup_loss'
self._default_metric = 'unsup_loss_numpy'

def prepare_target(self, y):
    """Return the target ``y`` unchanged."""
    return y
Expand Down Expand Up @@ -341,9 +341,9 @@ def _predict_epoch(self, name, loader):
# Main loop
for batch_idx, X in enumerate(loader):
output, embedded_x, obf_vars = self._predict_batch(X)
list_output.append(output)
list_embedded_x.append(embedded_x)
list_obfuscation.append(obf_vars)
list_output.append(output.cpu().detach().numpy())
list_embedded_x.append(embedded_x.cpu().detach().numpy())
list_obfuscation.append(obf_vars.cpu().detach().numpy())

output, embedded_x, obf_vars = self.stack_batches(list_output,
list_embedded_x,
Expand Down Expand Up @@ -372,9 +372,9 @@ def _predict_batch(self, X):
return self.network(X)

def stack_batches(self, list_output, list_embedded_x, list_obfuscation):
    """
    Merge the per-batch lists into a single object per quantity.

    Parameters
    ----------
    list_output : list
        Per-batch outputs.
    list_embedded_x : list
        Per-batch embedded inputs.
    list_obfuscation : list
        Per-batch obfuscation masks.

    Returns
    -------
    tuple
        ``(output, embedded_x, obf_vars)``, each built by stacking the
        corresponding list along the first axis.
    """
    # NOTE(review): each name below is assigned twice, so only the
    # np.vstack results are returned. The torch.cat lines look like the
    # removed side of this diff hunk -- confirm before relying on them.
    output = torch.cat(list_output, axis=0)
    embedded_x = torch.cat(list_embedded_x, axis=0)
    obf_vars = torch.cat(list_obfuscation, axis=0)
    output = np.vstack(list_output)
    embedded_x = np.vstack(list_embedded_x)
    obf_vars = np.vstack(list_obfuscation)
    return output, embedded_x, obf_vars

def predict(self, X):
Expand Down

0 comments on commit 49bd61b

Please sign in to comment.