Skip to content

Commit

Permalink
fix: sort by cat_idx into embedding generator
Browse files Browse the repository at this point in the history
  • Loading branch information
queraq authored and Optimox committed Jun 2, 2020
1 parent fd91028 commit 9ab3ad5
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions pytorch_tabnet/tab_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,12 @@ def __init__(self, input_dim, cat_dims, cat_idxs, cat_emb_dim):
self.post_embed_dim = int(input_dim + np.sum(self.cat_emb_dims) - len(self.cat_emb_dims))

self.embeddings = torch.nn.ModuleList()

# Sort dims by cat_idx
sorted_idxs = np.argsort(cat_idxs)
cat_dims = [cat_dims[i] for i in sorted_idxs]
self.cat_emb_dims = [self.cat_emb_dims[i] for i in sorted_idxs]

for cat_dim, emb_dim in zip(cat_dims, self.cat_emb_dims):
self.embeddings.append(torch.nn.Embedding(cat_dim, emb_dim))

Expand Down

0 comments on commit 9ab3ad5

Please sign in to comment.