diff --git a/pytorch_tabnet/tab_network.py b/pytorch_tabnet/tab_network.py index 516211a8..6df98528 100644 --- a/pytorch_tabnet/tab_network.py +++ b/pytorch_tabnet/tab_network.py @@ -32,14 +32,10 @@ def __init__(self, input_dim, virtual_batch_size=128, momentum=0.01): self.bn = BatchNorm1d(self.input_dim, momentum=momentum) def forward(self, x): - chunks = x.chunk(x.shape[0] // self.virtual_batch_size + - ((x.shape[0] % self.virtual_batch_size) > 0)) - res = torch.Tensor([]).to(x.device) - for x_ in chunks: - y = self.bn(x_) - res = torch.cat([res, y], dim=0) - - return res + chunks = x.chunk(int(np.ceil(x.shape[0] / self.virtual_batch_size)), 0) + res = [self.bn(x_) for x_ in chunks] + + return torch.cat(res, dim=0) class TabNetNoEmbeddings(torch.nn.Module):