Skip to content

Commit

Permalink
feat/336 : check if pandas df and drop_last default to True
Browse files Browse the repository at this point in the history
  • Loading branch information
Optimox committed Nov 12, 2021
1 parent 233f74e commit a0fd306
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 10 deletions.
7 changes: 4 additions & 3 deletions pytorch_tabnet/abstract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
create_dataloaders,
define_device,
ComplexEncoder,
check_input
)
from pytorch_tabnet.callbacks import (
CallbackContainer,
Expand All @@ -22,7 +23,7 @@
)
from pytorch_tabnet.metrics import MetricContainer, check_metrics
from sklearn.base import BaseEstimator
from sklearn.utils import check_array

from torch.utils.data import DataLoader
import io
import json
Expand Down Expand Up @@ -115,7 +116,7 @@ def fit(
batch_size=1024,
virtual_batch_size=128,
num_workers=0,
drop_last=False,
drop_last=True,
callbacks=None,
pin_memory=True,
from_unsupervised=None,
Expand Down Expand Up @@ -182,7 +183,7 @@ def fit(
else:
self.loss_fn = loss_fn

check_array(X_train)
check_input(X_train)

self.update_fit_params(
X_train,
Expand Down
8 changes: 4 additions & 4 deletions pytorch_tabnet/pretraining.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import torch
import numpy as np
from sklearn.utils import check_array
from torch.utils.data import DataLoader
from pytorch_tabnet import tab_network
from pytorch_tabnet.utils import (
create_explain_matrix,
filter_weights,
PredictDataset
PredictDataset,
check_input
)
from torch.nn.utils import clip_grad_norm_
from pytorch_tabnet.pretraining_utils import (
Expand Down Expand Up @@ -55,7 +55,7 @@ def fit(
batch_size=1024,
virtual_batch_size=128,
num_workers=0,
drop_last=False,
drop_last=True,
callbacks=None,
pin_memory=True,
):
Expand Down Expand Up @@ -118,7 +118,7 @@ def fit(
else:
self.loss_fn = loss_fn

check_array(X_train)
check_input(X_train)

self.update_fit_params(
weights,
Expand Down
4 changes: 2 additions & 2 deletions pytorch_tabnet/pretraining_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from pytorch_tabnet.utils import (
create_sampler,
PredictDataset,
check_input
)
from sklearn.utils import check_array


def create_dataloaders(
Expand Down Expand Up @@ -93,7 +93,7 @@ def validate_eval_set(eval_set, eval_name, X_train):
), "eval_set and eval_name have not the same length"

for set_nb, X in enumerate(eval_set):
check_array(X)
check_input(X)
msg = (
f"Number of columns is different between eval set {set_nb}"
+ f"({X.shape[1]}) and X_train ({X_train.shape[1]})"
Expand Down
15 changes: 14 additions & 1 deletion pytorch_tabnet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import scipy
import json
from sklearn.utils import check_array
import pandas as pd


class TorchDataset(Dataset):
Expand Down Expand Up @@ -271,7 +272,7 @@ def validate_eval_set(eval_set, eval_name, X_train, y_train):
len(elem) == 2 for elem in eval_set
), "Each tuple of eval_set need to have two elements"
for name, (X, y) in zip(eval_name, eval_set):
check_array(X)
check_input(X)
msg = (
f"Dimension mismatch between X_{name} "
+ f"{X.shape} and X_train {X_train.shape}"
Expand Down Expand Up @@ -337,3 +338,15 @@ def default(self, obj):
return int(obj)
# Let the base class default method raise the TypeError
return json.JSONEncoder.default(self, obj)


def check_input(X):
"""
Raise a clear error if X is a pandas dataframe
and check array according to scikit rules
"""
if isinstance(X, (pd.DataFrame, pd.Series)):
err_message = "Pandas DataFrame are not supported: apply X.values when calling fit"
raise(ValueError, err_message)
check_array(X)
return

0 comments on commit a0fd306

Please sign in to comment.