Skip to content

Commit

Permalink
fix: functional balanced version
Browse files Browse the repository at this point in the history
  • Loading branch information
Optimox committed Nov 20, 2019
1 parent 0bf45d2 commit fab7f16
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions pytorch_tabnet/tab_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,9 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None,
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
train_dataloader = DataLoader(TorchDataset(X_train, y_train),
batch_size=self.batch_size, sampler=sampler)
valid_dataloader = DataLoader(TorchDataset(X_valid, y_valid),
batch_size=self.batch_size, shuffle=False)

train_dataloader = DataLoader(TorchDataset(X_train, y_train),
batch_size=self.batch_size, shuffle=True)
else:
train_dataloader = DataLoader(TorchDataset(X_train, y_train),
batch_size=self.batch_size, shuffle=True)
valid_dataloader = DataLoader(TorchDataset(X_valid, y_valid),
batch_size=self.batch_size, shuffle=False)

Expand Down

0 comments on commit fab7f16

Please sign in to comment.