
Commit

feat: remove mask computations from forward
eduardocarvp authored and Optimox committed Jun 2, 2020
1 parent 9ab3ad5 commit 44d1a47
Showing 2 changed files with 58 additions and 37 deletions.
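
In short, the commit splits TabNet's forward pass in two: forward now returns only the predictions and the sparsity loss, while a new forward_masks method computes the explanation matrix and per-step masks on demand, and feature importances are computed once after training instead of on every batch. A minimal sketch of the resulting call pattern (not part of the diff; network and x stand for an instantiated TabNet module and a float input batch):

    # forward is now the lean training/inference path
    output, M_loss = network(x)

    # mask and explanation computations live in a dedicated method
    M_explain, masks = network.forward_masks(x)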
57 changes: 25 additions & 32 deletions pytorch_tabnet/tab_model.py
@@ -178,8 +178,6 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None, loss_fn=None,
self.patience_counter = 0
# Saving model
self.best_network = copy.deepcopy(self.network)
# Updating feature_importances_
self.feature_importances_ = fit_metrics['train']['feature_importances_']
else:
self.patience_counter += 1

@@ -209,6 +207,9 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None, loss_fn=None,
# load best models post training
self.load_best_model()

# compute feature importance once the best model is defined
self._compute_feature_importances(train_dataloader)

def fit_epoch(self, train_dataloader, valid_dataloader):
"""
Evaluates and updates network for one epoch.
@@ -333,7 +334,7 @@ def explain(self, X):
for batch_nb, data in enumerate(dataloader):
data = data.to(self.device).float()

output, M_loss, M_explain, masks = self.network(data)
M_explain, masks = self.network.forward_masks(data)
for key, value in masks.items():
masks[key] = csc_matrix.dot(value.cpu().detach().numpy(),
self.reducing_matrix)
@@ -350,6 +351,18 @@ def explain(self, X):
res_masks[key] = np.vstack([res_masks[key], value])
return res_explain, res_masks

def _compute_feature_importances(self, loader):
self.network.eval()
feature_importances_ = np.zeros((self.network.post_embed_dim))
for data, targets in loader:
data = data.to(self.device).float()
M_explain, masks = self.network.forward_masks(data)
feature_importances_ += M_explain.sum(dim=0).cpu().detach().numpy()

feature_importances_ = csc_matrix.dot(feature_importances_,
self.reducing_matrix)
self.feature_importances_ = feature_importances_ / np.sum(feature_importances_)


class TabNetClassifier(TabModel):

@@ -471,7 +484,6 @@ def train_epoch(self, train_loader):
y_preds = []
ys = []
total_loss = 0
feature_importances_ = np.zeros((self.network.post_embed_dim))

for data, targets in train_loader:
batch_outs = self.train_batch(data, targets)
@@ -483,13 +495,6 @@ def train_epoch(self, train_loader):
y_preds.append(indices.cpu().detach().numpy())
ys.append(batch_outs["y"].cpu().detach().numpy())
total_loss += batch_outs["loss"]
feature_importances_ += batch_outs['batch_importance']

# Reduce to initial input_dim
feature_importances_ = csc_matrix.dot(feature_importances_,
self.reducing_matrix)
# Normalize feature_importances_
feature_importances_ = feature_importances_ / np.sum(feature_importances_)

y_preds = np.hstack(y_preds)
ys = np.hstack(ys)
@@ -501,7 +506,6 @@ def train_epoch(self, train_loader):
total_loss = total_loss / len(train_loader)
epoch_metrics = {'loss_avg': total_loss,
'stopping_loss': stopping_loss,
'feature_importances_': feature_importances_
}

if self.scheduler is not None:
@@ -525,7 +529,7 @@ def train_batch(self, data, targets):
targets = targets.to(self.device).long()
self.optimizer.zero_grad()

output, M_loss, M_explain, _ = self.network(data)
output, M_loss = self.network(data)

loss = self.loss_fn(output, targets)
loss -= self.lambda_sparse*M_loss
@@ -538,8 +542,7 @@ def train_batch(self, data, targets):
loss_value = loss.item()
batch_outs = {'loss': loss_value,
'y_preds': output,
'y': targets,
'batch_importance': M_explain.sum(dim=0).cpu().detach().numpy()}
'y': targets}
return batch_outs

def predict_epoch(self, loader):
@@ -599,7 +602,7 @@ def predict_batch(self, data, targets):
self.network.eval()
data = data.to(self.device).float()
targets = targets.to(self.device).long()
output, M_loss, M_explain, _ = self.network(data)
output, M_loss = self.network(data)

loss = self.loss_fn(output, targets)
loss -= self.lambda_sparse*M_loss
@@ -632,7 +635,7 @@ def predict(self, X):

for batch_nb, data in enumerate(dataloader):
data = data.to(self.device).float()
output, M_loss, M_explain, masks = self.network(data)
output, M_loss = self.network(data)
predictions = torch.argmax(torch.nn.Softmax(dim=1)(output),
dim=1)
predictions = predictions.cpu().detach().numpy().reshape(-1)
@@ -667,7 +670,7 @@ def predict_proba(self, X):
for batch_nb, data in enumerate(dataloader):
data = data.to(self.device).float()

output, M_loss, M_explain, masks = self.network(data)
output, M_loss = self.network(data)
predictions = torch.nn.Softmax(dim=1)(output).cpu().detach().numpy()
results.append(predictions)
res = np.vstack(results)
@@ -741,20 +744,12 @@ def train_epoch(self, train_loader):
y_preds = []
ys = []
total_loss = 0
feature_importances_ = np.zeros((self.network.post_embed_dim))

for data, targets in train_loader:
batch_outs = self.train_batch(data, targets)
y_preds.append(batch_outs["y_preds"].cpu().detach().numpy())
ys.append(batch_outs["y"].cpu().detach().numpy())
total_loss += batch_outs["loss"]
feature_importances_ += batch_outs['batch_importance']

# Reduce to initial input_dim
feature_importances_ = csc_matrix.dot(feature_importances_,
self.reducing_matrix)
# Normalize feature_importances_
feature_importances_ = feature_importances_ / np.sum(feature_importances_)

y_preds = np.vstack(y_preds)
ys = np.vstack(ys)
@@ -763,7 +758,6 @@ def train_epoch(self, train_loader):
total_loss = total_loss / len(train_loader)
epoch_metrics = {'loss_avg': total_loss,
'stopping_loss': stopping_loss,
'feature_importances_': feature_importances_
}

if self.scheduler is not None:
@@ -788,7 +782,7 @@ def train_batch(self, data, targets):
targets = targets.to(self.device).float()
self.optimizer.zero_grad()

output, M_loss, M_explain, _ = self.network(data)
output, M_loss = self.network(data)

loss = self.loss_fn(output, targets)
loss -= self.lambda_sparse*M_loss
@@ -801,8 +795,7 @@ def train_batch(self, data, targets):
loss_value = loss.item()
batch_outs = {'loss': loss_value,
'y_preds': output,
'y': targets,
'batch_importance': M_explain.sum(dim=0).cpu().detach().numpy()}
'y': targets}
return batch_outs

def predict_epoch(self, loader):
Expand Down Expand Up @@ -855,7 +848,7 @@ def predict_batch(self, data, targets):
data = data.to(self.device).float()
targets = targets.to(self.device).float()

output, M_loss, M_explain, _ = self.network(data)
output, M_loss = self.network(data)

loss = self.loss_fn(output, targets)
loss -= self.lambda_sparse*M_loss
@@ -890,7 +883,7 @@ def predict(self, X):
for batch_nb, data in enumerate(dataloader):
data = data.to(self.device).float()

output, M_loss, M_explain, masks = self.network(data)
output, M_loss = self.network(data)
predictions = output.cpu().detach().numpy()
results.append(predictions)
res = np.vstack(results)
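
On the model side, train_epoch and train_batch no longer accumulate importances; instead, fit calls the new _compute_feature_importances on the training loader once the best network has been reloaded. From a user's perspective the attribute is unchanged; a hedged usage sketch (variable names are placeholders, not from the diff):

    clf = TabNetClassifier()
    clf.fit(X_train, y_train, X_valid, y_valid)
    # populated at the end of fit() by _compute_feature_importances,
    # normalized and reduced back to the original input dimension
    print(clf.feature_importances_)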
38 changes: 33 additions & 5 deletions pytorch_tabnet/tab_network.py
@@ -83,6 +83,8 @@ def __init__(self, input_dim, output_dim,
self.n_shared = n_shared
self.virtual_batch_size = virtual_batch_size

self.initial_bn = BatchNorm1d(self.input_dim, momentum=0.01)

if self.n_shared > 0:
shared_feat_transform = torch.nn.ModuleList()
for i in range(self.n_shared):
@@ -120,6 +122,32 @@ def __init__(self, input_dim, output_dim,

def forward(self, x):
res = 0
x = self.initial_bn(x)

prior = torch.ones(x.shape).to(x.device)
M_loss = 0
att = self.initial_splitter(x)[:, self.n_d:]

for step in range(self.n_steps):
M = self.att_transformers[step](prior, att)
M_loss += torch.mean(torch.sum(torch.mul(M, torch.log(M+self.epsilon)),
dim=1))
# update prior
prior = torch.mul(self.gamma - M, prior)
# output
masked_x = torch.mul(M, x)
out = self.feat_transformers[step](masked_x)
d = ReLU()(out[:, :self.n_d])
res = torch.add(res, d)
# update attention
att = out[:, self.n_d:]

M_loss /= self.n_steps
res = self.final_mapping(res)
return res, M_loss

def forward_masks(self, x):
x = self.initial_bn(x)

prior = torch.ones(x.shape).to(x.device)
M_explain = torch.zeros(x.shape).to(x.device)
@@ -138,15 +166,13 @@ def forward(self, x):
masked_x = torch.mul(M, x)
out = self.feat_transformers[step](masked_x)
d = ReLU()(out[:, :self.n_d])
res = torch.add(res, d)
# explain
step_importance = torch.sum(d, dim=1)
M_explain += torch.mul(M, step_importance.unsqueeze(dim=1))
# update attention
att = out[:, self.n_d:]

res = self.final_mapping(res)
return res, M_loss, M_explain, masks
return M_explain, masks


class TabNet(torch.nn.Module):
@@ -215,7 +241,6 @@ def __init__(self, input_dim, output_dim, n_d=8, n_a=8,
self.tabnet = TabNetNoEmbeddings(self.post_embed_dim, output_dim, n_d, n_a, n_steps,
gamma, n_independent, n_shared, epsilon,
virtual_batch_size, momentum)
self.initial_bn = BatchNorm1d(self.post_embed_dim, momentum=0.01)

# Defining device
if device_name == 'auto':
@@ -228,9 +253,12 @@ def __init__(self, input_dim, output_dim, n_d=8, n_a=8,

def forward(self, x):
x = self.embedder(x)
x = self.initial_bn(x)
return self.tabnet(x)

def forward_masks(self, x):
x = self.embedder(x)
return self.tabnet.forward_masks(x)


class AttentiveTransformer(torch.nn.Module):
def __init__(self, input_dim, output_dim, virtual_batch_size=128, momentum=0.02):
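
For reference, the importance aggregation performed by forward_masks can be read as weighting each step's mask by that step's contribution to the decision output. A standalone sketch of that aggregation, assuming lists of per-step masks (batch x input_dim) and per-step decision outputs (batch x n_d); the function name and arguments are illustrative, not from the diff:

    import torch

    def aggregate_importance(step_masks, step_decisions):
        # start from zero importance per input feature
        M_explain = torch.zeros_like(step_masks[0])
        for M, d in zip(step_masks, step_decisions):
            # a step matters as much as its (ReLU'd) decision contribution
            step_importance = torch.sum(torch.relu(d), dim=1)
            # weight that step's mask by its contribution
            M_explain += M * step_importance.unsqueeze(dim=1)
        return M_explain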
