From ddf02dab9bdc41c6d7736f0be509950e907909a4 Mon Sep 17 00:00:00 2001 From: Eduardo Carvalho Date: Mon, 20 Dec 2021 23:58:22 +0100 Subject: [PATCH] fix: replace std 0 by the mean or 1 if mean is 0 --- pytorch_tabnet/metrics.py | 12 ++++++++++-- tests/unsupervised_loss.py | 22 +++++++++++++++++----- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/pytorch_tabnet/metrics.py b/pytorch_tabnet/metrics.py index 32847c16..a9fd88fd 100644 --- a/pytorch_tabnet/metrics.py +++ b/pytorch_tabnet/metrics.py @@ -39,7 +39,11 @@ def UnsupervisedLoss(y_pred, embedded_x, obf_vars, eps=1e-9): """ errors = y_pred - embedded_x reconstruction_errors = torch.mul(errors, obf_vars) ** 2 - batch_stds = torch.std(embedded_x, dim=0) ** 2 + eps + batch_means = torch.mean(embedded_x, dim=0) + batch_means[batch_means == 0] = 1 + + batch_stds = torch.std(embedded_x, dim=0) ** 2 + batch_stds[batch_stds == 0] = batch_means[batch_stds == 0] features_loss = torch.matmul(reconstruction_errors, 1 / batch_stds) # compute the number of obfuscated variables to reconstruct nb_reconstructed_variables = torch.sum(obf_vars, dim=1) @@ -53,7 +57,11 @@ def UnsupervisedLoss(y_pred, embedded_x, obf_vars, eps=1e-9): def UnsupervisedLossNumpy(y_pred, embedded_x, obf_vars, eps=1e-9): errors = y_pred - embedded_x reconstruction_errors = np.multiply(errors, obf_vars) ** 2 - batch_stds = np.std(embedded_x, axis=0, ddof=1) ** 2 + eps + batch_means = np.mean(embedded_x, axis=0) + batch_means = np.where(batch_means == 0, 1, batch_means) + + batch_stds = np.std(embedded_x, axis=0, ddof=1) ** 2 + batch_stds = np.where(batch_stds == 0, batch_means, batch_stds) features_loss = np.matmul(reconstruction_errors, 1 / batch_stds) # compute the number of obfuscated variables to reconstruct nb_reconstructed_variables = np.sum(obf_vars, axis=1) diff --git a/tests/unsupervised_loss.py b/tests/unsupervised_loss.py index cdb046fa..4627d3e7 100644 --- a/tests/unsupervised_loss.py +++ b/tests/unsupervised_loss.py @@ -1,15 +1,27 @@ import numpy as np import torch +import pytest from pytorch_tabnet.metrics import UnsupervisedLoss, UnsupervisedLossNumpy torch.set_printoptions(precision=10) -def test_equal_losses(): - y_pred = np.random.uniform(low=-2, high=2, size=(20, 100)) - embedded_x = np.random.uniform(low=-2, high=2, size=(20, 100)) - obf_vars = np.random.choice([0, 1], size=(20, 100), replace=True) - +@pytest.mark.parametrize( + "y_pred,embedded_x,obf_vars", + [ + ( + np.random.uniform(low=-2, high=2, size=(20, 100)), + np.random.uniform(low=-2, high=2, size=(20, 100)), + np.random.choice([0, 1], size=(20, 100), replace=True) + ), + ( + np.random.uniform(low=-2, high=2, size=(30, 50)), + np.ones((30, 50)), + np.random.choice([0, 1], size=(30, 50), replace=True) + ) + ] +) +def test_equal_losses(y_pred, embedded_x, obf_vars): numpy_loss = UnsupervisedLossNumpy( y_pred=y_pred, embedded_x=embedded_x,