Skip to content

Commit

Permalink
fix: add softmax to predict_proba
Browse files Browse the repository at this point in the history
  • Loading branch information
Optimox authored and Hartorn committed Nov 4, 2019
1 parent fede1ec commit bea966f
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
../.history/
data/
.ipynb_checkpoints/
*.pt

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down Expand Up @@ -128,4 +129,3 @@ dmypy.json

# Pyre type checker
.pyre/

45 changes: 42 additions & 3 deletions census_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@
"metadata": {},
"outputs": [],
"source": [
"unused_feat = []\n",
"unused_feat = ['Set']\n",
"\n",
"features = [ col for col in train.columns if col not in unused_feat+[target]] \n",
"\n",
Expand Down Expand Up @@ -259,7 +259,7 @@
"source": [
"model.load_best_model()\n",
"\n",
"preds, M_explain, masks = model.predict_proba(X_test)\n",
"preds = model.predict_proba(X_test)\n",
"\n",
"y_true = y_test\n",
"\n",
Expand All @@ -269,6 +269,45 @@
"print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_auc}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Local explainability and masks"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"explain_matrix, masks = model.explain(X_test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from matplotlib import pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, axs = plt.subplots(1, network_params['n_steps'])\n",
"\n",
"for i in range(network_params['n_steps']):\n",
" axs[i].imshow(masks[i][:50])\n",
" axs[i].set_title(f\"mask {i}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -359,4 +398,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
36 changes: 34 additions & 2 deletions forest_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@
"source": [
"model.load_best_model()\n",
"\n",
"preds, M_explain, masks = model.predict_proba(X_test)\n",
"preds = model.predict_proba(X_test)\n",
"\n",
"y_true = y_test\n",
"\n",
Expand All @@ -312,6 +312,38 @@
"print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_acc}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Local explainability and masks"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"explain_matrix, masks = model.explain(X_test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from matplotlib import pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"fig, axs = plt.subplots(1, network_params['n_steps'])\n",
"\n",
"for i in range(network_params['n_steps']):\n",
" axs[i].imshow(masks[i][:50])\n",
" axs[i].set_title(f\"mask {i}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -395,4 +427,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
122 changes: 92 additions & 30 deletions pytorch_tabnet/tab_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,7 @@
from torch.nn.utils import clip_grad_norm_
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler


class TorchDataset(Dataset):
    """
    Pair a feature container with its targets for use in a DataLoader.

    Parameters
    ----------
    x: indexable (e.g. 2D array)
        The input matrix, one sample per index
    y: indexable (e.g. 2D array)
        The targets, aligned index-by-index with x
    """

    def __init__(self, x, y):
        self.x, self.y = x, y
        # Starts out empty; nothing inside this class writes to it.
        self.timer = []

    def __len__(self):
        # The dataset size is driven by the inputs only.
        return len(self.x)

    def __getitem__(self, index):
        # Return the (input, target) pair stored at the same position.
        return self.x[index], self.y[index]
from pytorch_tabnet.utils import TorchDataset, PredictDataset


class Model(object):
Expand Down Expand Up @@ -434,7 +410,7 @@ def load_best_model(self):

def predict_proba(self, X):
"""
Make predictions on a batch (valid)
Make predictions for classification on a batch (valid)
Parameters
----------
Expand All @@ -448,9 +424,95 @@ def predict_proba(self, X):
batch_outs: dict
"""
self.network.eval()
data = torch.Tensor(X).to(self.device).float()

output, M_loss, M_explain, masks = self.network(data)
predictions = output.cpu().detach().numpy()
dataloader = DataLoader(PredictDataset(X),
batch_size=self.batch_size, shuffle=False)

for batch_nb, data in enumerate(dataloader):
data = data.to(self.device).float()

output, M_loss, M_explain, masks = self.network(data)
predictions = torch.nn.Softmax(dim=1)(output).cpu().detach().numpy()
if batch_nb == 0:
res = predictions
else:
res = np.vstack([res, predictions])
return res

def predict(self, X):
    """
    Make predictions for every row of X, running the network batch by batch.

    Parameters
    ----------
    X: 2D array
        The input matrix. It is wrapped in a PredictDataset and served
        in chunks of ``self.batch_size`` rows, in order.

    Returns
    -------
    predictions: np.array
        1D array with one value per row of X. When ``self.output_dim``
        is 1 this is the flattened network output; otherwise it is the
        index of the class with the highest softmax score.
        An empty array is returned when X contains no rows.
    """
    self.network.eval()
    dataloader = DataLoader(PredictDataset(X),
                            batch_size=self.batch_size, shuffle=False)

    batch_results = []
    # Inference only: skip autograd bookkeeping for the forward passes.
    with torch.no_grad():
        for data in dataloader:
            data = data.to(self.device).float()

            # Only the first element of the network output is needed here.
            output, _, _, _ = self.network(data)
            if self.output_dim == 1:
                predictions = output
            else:
                predictions = torch.argmax(torch.nn.Softmax(dim=1)(output),
                                           dim=1)
            batch_results.append(
                predictions.cpu().detach().numpy().reshape(-1))

    # No batch was produced (X is empty): avoid referencing an unset result.
    if not batch_results:
        return np.array([])

    # Concatenate once instead of re-copying the result on every batch.
    return np.hstack(batch_results)

return predictions, M_explain, masks
def explain(self, X):
    """
    Return local explanations for every row of X.

    Parameters
    ----------
    X: 2D array
        The input matrix. It is wrapped in a PredictDataset and served
        in chunks of ``self.batch_size`` rows, in order.

    Returns
    -------
    res_explain: np.array
        Importance per sample, per column, stacked over all batches.
    res_masks: dict of np.array
        For each key of the masks returned by the network, the attention
        masks stacked over all batches.
        When X contains no rows, an empty (0, 0) array and an empty dict
        are returned.
    """
    self.network.eval()

    dataloader = DataLoader(PredictDataset(X),
                            batch_size=self.batch_size, shuffle=False)

    explain_batches = []
    mask_batches = {}
    # Inference only: skip autograd bookkeeping for the forward passes.
    with torch.no_grad():
        for data in dataloader:
            data = data.to(self.device).float()

            _, _, M_explain, masks = self.network(data)
            explain_batches.append(M_explain.cpu().detach().numpy())
            for key, value in masks.items():
                mask_batches.setdefault(key, []).append(
                    value.cpu().detach().numpy())

    # No batch was produced (X is empty): nothing to stack.
    if not explain_batches:
        return np.empty((0, 0)), {}

    # Stack once at the end. The accumulated matrix is returned, not the
    # last batch's tensor, so every row of X gets its explanation.
    res_explain = np.vstack(explain_batches)
    res_masks = {key: np.vstack(values)
                 for key, values in mask_batches.items()}
    return res_explain, res_masks
46 changes: 46 additions & 0 deletions pytorch_tabnet/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from torch.utils.data import Dataset


class TorchDataset(Dataset):
    """
    Pair a feature container with its targets for use in a DataLoader.

    Parameters
    ----------
    x: indexable (e.g. 2D array)
        The input matrix, one sample per index
    y: indexable (e.g. 2D array)
        The targets, aligned index-by-index with x
    """

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        # The dataset size is driven by the inputs only.
        return len(self.x)

    def __getitem__(self, index):
        # Return the (input, target) pair stored at the same position.
        return self.x[index], self.y[index]


class PredictDataset(Dataset):
    """
    Expose an input container, without targets, as a Dataset.

    Parameters
    ----------
    x: indexable (e.g. 2D array)
        The input matrix, one sample per index
    """

    def __init__(self, x):
        self.x = x

    def __len__(self):
        # One item per entry of the wrapped container.
        return len(self.x)

    def __getitem__(self, index):
        # Hand back the stored sample unchanged.
        return self.x[index]

0 comments on commit bea966f

Please sign in to comment.