Skip to content

Commit

Permalink
feat: speed boost and code simplification for GBN
Browse files Browse the repository at this point in the history
  • Loading branch information
eduardocarvp authored and Optimox committed May 9, 2020
1 parent 313d074 commit 1642909
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions pytorch_tabnet/tab_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,10 @@ def __init__(self, input_dim, virtual_batch_size=128, momentum=0.01):
self.bn = BatchNorm1d(self.input_dim, momentum=momentum)

def forward(self, x):
chunks = x.chunk(x.shape[0] // self.virtual_batch_size +
((x.shape[0] % self.virtual_batch_size) > 0))
res = torch.Tensor([]).to(x.device)
for x_ in chunks:
y = self.bn(x_)
res = torch.cat([res, y], dim=0)

return res
chunks = x.chunk(int(np.ceil(x.shape[0] / self.virtual_batch_size)), 0)
res = [self.bn(x_) for x_ in chunks]

return torch.cat(res, dim=0)


class TabNetNoEmbeddings(torch.nn.Module):
Expand Down

0 comments on commit 1642909

Please sign in to comment.