Skip to content

Commit

Permalink
feat: add check nan and inf
Browse files Browse the repository at this point in the history
  • Loading branch information
Optimox committed Oct 12, 2020
1 parent b01339a commit d871406
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pytorch_tabnet/abstract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
create_explain_matrix,
validate_eval_set,
create_dataloaders,
check_nans,
)
from pytorch_tabnet.callbacks import (
CallbackContainer,
Expand Down Expand Up @@ -140,6 +141,8 @@ def fit(
else:
self.loss_fn = loss_fn

check_nans(X_train)
check_nans(y_train)
self.update_fit_params(
X_train, y_train, eval_set, weights,
)
Expand Down
9 changes: 9 additions & 0 deletions pytorch_tabnet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ def validate_eval_set(eval_set, eval_name, X_train, y_train):
len(elem) == 2 for elem in eval_set
), "Each tuple of eval_set need to have two elements"
for name, (X, y) in zip(eval_name, eval_set):
check_nans(X)
check_nans(y)
msg = (
f"Number of columns is different between X_{name} "
+ f"({X.shape[1]}) and X_train ({X_train.shape[1]})"
Expand All @@ -255,3 +257,10 @@ def validate_eval_set(eval_set, eval_name, X_train, y_train):
assert X.shape[0] == y.shape[0], msg

return eval_name, eval_set


def check_nans(array):
if np.isnan(array).any():
raise ValueError("NaN were found, TabNet does not allow nans.")
if np.isinf(array).any():
raise ValueError("Infinite values were found, TabNet does not allow inf.")

0 comments on commit d871406

Please sign in to comment.