diff --git a/pytorch_tabnet/tab_network.py b/pytorch_tabnet/tab_network.py index 55e84e06..597e7f8a 100644 --- a/pytorch_tabnet/tab_network.py +++ b/pytorch_tabnet/tab_network.py @@ -117,7 +117,7 @@ def __init__(self, input_dim, output_dim, n_d, n_a, self.embeddings.append(torch.nn.Embedding(cat_dim, emb_dim)) # record continuous indices - self.continuous_idx = torch.ones(self.input_dim, dtype=torch.uint8) + self.continuous_idx = torch.ones(self.input_dim, dtype=torch.bool) self.continuous_idx[self.cat_idxs] = 0 self.post_embed_dim = self.input_dim + (cat_emb_dim - 1)*len(self.cat_idxs) self.initial_bn = BatchNorm1d(self.post_embed_dim, momentum=0.01)