From cc57d62698ef629d63dcc8878d4d48f231f3cd77 Mon Sep 17 00:00:00 2001 From: Quentin Raquet Date: Thu, 8 Oct 2020 14:41:01 +0200 Subject: [PATCH] feat: refacto models with metrics and callbacks --- pytorch_tabnet/abstract_model.py | 828 +++++++++++++++++------------ pytorch_tabnet/multiclass_utils.py | 179 +++++-- pytorch_tabnet/multitask.py | 335 +++--------- pytorch_tabnet/tab_model.py | 537 +++---------------- pytorch_tabnet/utils.py | 111 +++- 5 files changed, 864 insertions(+), 1126 deletions(-) diff --git a/pytorch_tabnet/abstract_model.py b/pytorch_tabnet/abstract_model.py index 2944a68c..d3c8ec31 100644 --- a/pytorch_tabnet/abstract_model.py +++ b/pytorch_tabnet/abstract_model.py @@ -1,14 +1,25 @@ +from dataclasses import dataclass, field +from typing import List, Any, Dict import torch +from torch.nn.utils import clip_grad_norm_ import numpy as np from scipy.sparse import csc_matrix -import time from abc import abstractmethod from pytorch_tabnet import tab_network -from pytorch_tabnet.utils import (PredictDataset, - create_explain_matrix) +from pytorch_tabnet.utils import ( + PredictDataset, + create_explain_matrix, + validate_eval_set, + create_dataloaders, +) +from pytorch_tabnet.callbacks import ( + CallbackContainer, + History, + EarlyStopping, +) +from pytorch_tabnet.metrics import MetricContainer, check_metrics from sklearn.base import BaseEstimator from torch.utils.data import DataLoader -from copy import deepcopy import io import json from pathlib import Path @@ -16,119 +27,64 @@ import zipfile +@dataclass class TabModel(BaseEstimator): - def __init__(self, n_d=8, n_a=8, n_steps=3, gamma=1.3, cat_idxs=[], cat_dims=[], cat_emb_dim=1, - n_independent=2, n_shared=2, epsilon=1e-15, momentum=0.02, - lambda_sparse=1e-3, seed=0, - clip_value=1, verbose=1, - optimizer_fn=torch.optim.Adam, - optimizer_params=dict(lr=2e-2), - scheduler_params=None, scheduler_fn=None, - mask_type="sparsemax", - input_dim=None, output_dim=None, - device_name='auto'): - """ Class for TabNet model - - Parameters - ---------- - device_name: str - 'cuda' if running on GPU, 'cpu' if not, 'auto' to autodetect - """ - - self.n_d = n_d - self.n_a = n_a - self.n_steps = n_steps - self.gamma = gamma - self.cat_idxs = cat_idxs - self.cat_dims = cat_dims - self.cat_emb_dim = cat_emb_dim - self.n_independent = n_independent - self.n_shared = n_shared - self.epsilon = epsilon - self.momentum = momentum - self.lambda_sparse = lambda_sparse - self.clip_value = clip_value - self.verbose = verbose - self.optimizer_fn = optimizer_fn - self.optimizer_params = optimizer_params - self.device_name = device_name - self.scheduler_params = scheduler_params - self.scheduler_fn = scheduler_fn - self.mask_type = mask_type - self.input_dim = input_dim - self.output_dim = output_dim - + """ Class for TabNet model.""" + + n_d: int = 8 + n_a: int = 8 + n_steps: int = 3 + gamma: float = 1.3 + cat_idxs: List[int] = field(default_factory=list) + cat_dims: List[int] = field(default_factory=list) + cat_emb_dim: int = 1 + n_independent: int = 2 + n_shared: int = 2 + epsilon: float = 1e-15 + momentum: float = 0.02 + lambda_sparse: float = 1e-3 + seed: int = 0 + clip_value: int = 1 + verbose: int = 1 + optimizer_fn: Any = torch.optim.Adam + optimizer_params: Dict = field(default_factory=lambda: dict(lr=2e-2)) + scheduler_fn: Any = None + scheduler_params: Dict = field(default_factory=dict) + mask_type: str = "sparsemax" + input_dim: int = None + output_dim: int = None + device_name: str = "auto" + + def __post_init__(self): self.batch_size = 
1024 - - self.seed = seed + self.virtual_batch_size = 1024 torch.manual_seed(self.seed) # Defining device - if device_name == 'auto': + if self.device_name == "auto": if torch.cuda.is_available(): - device_name = 'cuda' + device_name = "cuda" else: - device_name = 'cpu' + device_name = "cpu" self.device = torch.device(device_name) print(f"Device used : {self.device}") - @abstractmethod - def construct_loaders(self, X_train, y_train, X_valid, y_valid, - weights, batch_size, num_workers, drop_last): - """ - Returns - ------- - train_dataloader, valid_dataloader : torch.DataLoader, torch.DataLoader - Training and validation dataloaders - ------- - """ - raise NotImplementedError('users must define construct_loaders to use this base class') - - def init_network( - self, - input_dim, - output_dim, - n_d, - n_a, - n_steps, - gamma, - cat_idxs, - cat_dims, - cat_emb_dim, - n_independent, - n_shared, - epsilon, - virtual_batch_size, - momentum, - device_name, - mask_type, - ): - self.network = tab_network.TabNet( - input_dim, - output_dim, - n_d=n_d, - n_a=n_a, - n_steps=n_steps, - gamma=gamma, - cat_idxs=cat_idxs, - cat_dims=cat_dims, - cat_emb_dim=cat_emb_dim, - n_independent=n_independent, - n_shared=n_shared, - epsilon=epsilon, - virtual_batch_size=virtual_batch_size, - momentum=momentum, - device_name=device_name, - mask_type=mask_type).to(self.device) - - self.reducing_matrix = create_explain_matrix( - self.network.input_dim, - self.network.cat_emb_dim, - self.network.cat_idxs, - self.network.post_embed_dim) - - def fit(self, X_train, y_train, X_valid=None, y_valid=None, loss_fn=None, - weights=0, max_epochs=100, patience=10, batch_size=1024, - virtual_batch_size=128, num_workers=0, drop_last=False): + def fit( + self, + X_train, + y_train, + eval_set=None, + eval_name=None, + eval_metric=None, + loss_fn=None, + weights=0, + max_epochs=100, + patience=10, + batch_size=1024, + virtual_batch_size=128, + num_workers=0, + drop_last=False, + callbacks=None, + ): """Train a neural network stored in self.network Using train_dataloader for training data and valid_dataloader for validation. @@ -139,10 +95,14 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None, loss_fn=None, Train set y_train : np.array Train targets - X_train: np.ndarray - Train set - y_train : np.array - Train targets + eval_set: list of tuple + List of eval tuple set (X, y). + The last one is used for early stopping + eval_name: list of str + List of eval set names. + eval_metric : list of str + List of evaluation metrics. + The last metric is used for early stopping. 
weights : bool or dictionnary 0 for no balancing 1 for automated balancing @@ -159,125 +119,155 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None, loss_fn=None, Number of workers used in torch.utils.data.DataLoader drop_last : bool Whether to drop last batch during training + callbacks : list of callback function + List of custom callbacks """ # update model name - self.update_fit_params(X_train, y_train, X_valid, y_valid, loss_fn, - weights, max_epochs, patience, batch_size, - virtual_batch_size, num_workers, drop_last) - - train_dataloader, valid_dataloader = self.construct_loaders(X_train, - y_train, - X_valid, - y_valid, - self.updated_weights, - self.batch_size, - self.num_workers, - self.drop_last) - - self.init_network( - input_dim=self.input_dim, - output_dim=self.output_dim, - n_d=self.n_d, - n_a=self.n_a, - n_steps=self.n_steps, - gamma=self.gamma, - cat_idxs=self.cat_idxs, - cat_dims=self.cat_dims, - cat_emb_dim=self.cat_emb_dim, - n_independent=self.n_independent, - n_shared=self.n_shared, - epsilon=self.epsilon, - virtual_batch_size=self.virtual_batch_size, - momentum=self.momentum, - device_name=self.device_name, - mask_type=self.mask_type - ) + self.max_epochs = max_epochs + self.patience = patience + self.batch_size = batch_size + self.virtual_batch_size = virtual_batch_size + self.num_workers = num_workers + self.drop_last = drop_last + self.input_dim = X_train.shape[1] + self._stop_training = False - self.optimizer = self.optimizer_fn(self.network.parameters(), - **self.optimizer_params) + if eval_set is None: + eval_set = [] - if self.scheduler_fn: - self.scheduler = self.scheduler_fn(self.optimizer, **self.scheduler_params) + if loss_fn is None: + self.loss_fn = self._default_loss else: - self.scheduler = None - - self.losses_train = [] - self.losses_valid = [] - self.learning_rates = [] - self.metrics_train = [] - self.metrics_valid = [] - - if self.verbose > 0: - print("Will train until validation stopping metric", - f"hasn't improved in {self.patience} rounds.") - msg_epoch = f'| EPOCH | train | valid | total time (s)' - print('---------------------------------------') - print(msg_epoch) - - total_time = 0 - while (self.epoch < self.max_epochs and - self.patience_counter < self.patience): - starting_time = time.time() - # updates learning rate history - self.learning_rates.append(self.optimizer.param_groups[-1]["lr"]) - - fit_metrics = self.fit_epoch(train_dataloader, valid_dataloader) - - # leaving it here, may be used for callbacks later - self.losses_train.append(fit_metrics['train']['loss_avg']) - self.losses_valid.append(fit_metrics['valid']['total_loss']) - self.metrics_train.append(fit_metrics['train']['stopping_loss']) - self.metrics_valid.append(fit_metrics['valid']['stopping_loss']) - - stopping_loss = fit_metrics['valid']['stopping_loss'] - if stopping_loss < self.best_cost: - self.best_cost = stopping_loss - self.patience_counter = 0 - # Saving model - self.best_network = deepcopy(self.network) - else: - self.patience_counter += 1 - - if self.scheduler is not None: - if isinstance(self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau): - self.scheduler.step(stopping_loss) - else: - self.scheduler.step() - - self.epoch += 1 - total_time += time.time() - starting_time - if self.verbose > 0: - if self.epoch % self.verbose == 0: - separator = "|" - msg_epoch = f"| {self.epoch:<5} | " - msg_epoch += f"{-fit_metrics['train']['stopping_loss']:.5f}" - msg_epoch += f' {separator:<2} ' - msg_epoch += 
f"{-fit_metrics['valid']['stopping_loss']:.5f}" - msg_epoch += f' {separator:<2} ' - msg_epoch += f" {np.round(total_time, 1):<10}" - print(msg_epoch) - - if self.verbose > 0: - if self.patience_counter == self.patience: - print(f"Early stopping occured at epoch {self.epoch}") - print(f"Training done in {total_time:.3f} seconds.") - print('---------------------------------------') - - self.history = {"train": {"loss": self.losses_train, - "metric": self.metrics_train, - "lr": self.learning_rates}, - "valid": {"loss": self.losses_valid, - "metric": self.metrics_valid}} - # load best models post training - self.load_best_model() + self.loss_fn = loss_fn + + self.update_fit_params( + X_train, y_train, eval_set, weights, + ) + + # Validate and reformat eval set depending on training data + eval_names, eval_set = validate_eval_set(eval_set, eval_name, X_train, y_train) + + train_dataloader, valid_dataloaders = self._construct_loaders( + X_train, y_train, eval_set + ) + + self._set_network() + self._set_metrics(eval_metric, eval_names) + self._set_callbacks(callbacks) + self._set_optimizer() + self._set_scheduler() + + # Call method on_train_begin for all callbacks + self._callback_container.on_train_begin() + + # Training loop over epochs + for epoch_idx in range(self.max_epochs): + + # Call method on_epoch_begin for all callbacks + self._callback_container.on_epoch_begin(epoch_idx) + + self._train_epoch(train_dataloader) + + # Apply predict epoch to all eval sets + for eval_name, valid_dataloader in zip(eval_names, valid_dataloaders): + self._predict_epoch(eval_name, valid_dataloader) + + # Call method on_epoch_end for all callbacks + self._callback_container.on_epoch_end(epoch_idx, self.history.batch_metrics) + + if self._stop_training: + break + + # Call method on_train_end for all callbacks + self._callback_container.on_train_end() + self.network.eval() # compute feature importance once the best model is defined self._compute_feature_importances(train_dataloader) - def save_model(self, path): + def predict(self, X): + """ + Make predictions on a batch (valid) + + Parameters + ---------- + X: a :tensor: `torch.Tensor` + Input data + + Returns + ------- + predictions: np.array + Predictions of the regression problem + """ + self.network.eval() + dataloader = DataLoader( + PredictDataset(X), batch_size=self.batch_size, shuffle=False + ) + + results = [] + for batch_nb, data in enumerate(dataloader): + data = data.to(self.device).float() + output, M_loss = self.network(data) + predictions = output.cpu().detach().numpy() + results.append(predictions) + res = np.vstack(results) + return self.predict_func(res) + + def explain(self, X): + """ + Return local explanation + + Parameters + ---------- + X: tensor: `torch.Tensor` + Input data + + Returns + ------- + M_explain: matrix + Importance per sample, per columns. + masks: matrix + Sparse matrix showing attention masks used by network. """ - Saving model with two distinct files. 
+ self.network.eval() + + dataloader = DataLoader( + PredictDataset(X), batch_size=self.batch_size, shuffle=False + ) + + res_explain = [] + + for batch_nb, data in enumerate(dataloader): + data = data.to(self.device).float() + + M_explain, masks = self.network.forward_masks(data) + for key, value in masks.items(): + masks[key] = csc_matrix.dot( + value.cpu().detach().numpy(), self.reducing_matrix + ) + + res_explain.append( + csc_matrix.dot(M_explain.cpu().detach().numpy(), self.reducing_matrix) + ) + + if batch_nb == 0: + res_masks = masks + else: + for key, value in masks.items(): + res_masks[key] = np.vstack([res_masks[key], value]) + + res_explain = np.vstack(res_explain) + + return res_explain, res_masks + + def save_model(self, path): + """Saving TabNet model in two distinct files. + + Parameters + ---------- + filepath : str + Path of the model. """ saved_params = {} for key, val in self.get_params().items(): @@ -296,13 +286,19 @@ def save_model(self, path): # Save state_dict torch.save(self.network.state_dict(), Path(path).joinpath("network.pt")) - shutil.make_archive(path, 'zip', path) + shutil.make_archive(path, "zip", path) shutil.rmtree(path) print(f"Successfully saved model at {path}.zip") return f"{path}.zip" def load_model(self, filepath): + """Load TabNet model. + Parameters + ---------- + filepath : str + Path of the model. + """ try: with zipfile.ZipFile(filepath) as z: with z.open("model_params.json") as f: @@ -320,170 +316,278 @@ def load_model(self, filepath): self.__init__(**loaded_params) - self.init_network( - input_dim=self.input_dim, - output_dim=self.output_dim, - n_d=self.n_d, - n_a=self.n_a, - n_steps=self.n_steps, - gamma=self.gamma, - cat_idxs=self.cat_idxs, - cat_dims=self.cat_dims, - cat_emb_dim=self.cat_emb_dim, - n_independent=self.n_independent, - n_shared=self.n_shared, - epsilon=self.epsilon, - virtual_batch_size=1024, - momentum=self.momentum, - device_name=self.device_name, - mask_type=self.mask_type - ) + self._set_network() self.network.load_state_dict(saved_state_dict) self.network.eval() return - def fit_epoch(self, train_dataloader, valid_dataloader): + def _train_epoch(self, train_loader): """ - Evaluates and updates network for one epoch. 
+ Trains one epoch of the network in self.network Parameters ---------- - train_dataloader: a :class: `torch.utils.data.Dataloader` + train_loader: a :class: `torch.utils.data.Dataloader` DataLoader with train set - valid_dataloader: a :class: `torch.utils.data.Dataloader` - DataLoader with valid set """ - train_metrics = self.train_epoch(train_dataloader) - valid_metrics = self.predict_epoch(valid_dataloader) + self.network.train() - fit_metrics = {'train': train_metrics, - 'valid': valid_metrics} + for batch_idx, (X, y) in enumerate(train_loader): + self._callback_container.on_batch_begin(batch_idx) - return fit_metrics + batch_logs = self._train_batch(X, y) - @abstractmethod - def train_epoch(self, train_loader): - """ - Trains one epoch of the network in self.network + self._callback_container.on_batch_end(batch_idx, batch_logs) - Parameters - ---------- - train_loader: a :class: `torch.utils.data.Dataloader` - DataLoader with train set - """ - raise NotImplementedError('users must define train_epoch to use this base class') + epoch_logs = {"lr": self._optimizer.param_groups[-1]["lr"]} + self.history.batch_metrics.update(epoch_logs) - @abstractmethod - def train_batch(self, data, targets): + return + + def _train_batch(self, X, y): """ Trains one batch of data Parameters ---------- - data: a :tensor: `torch.tensor` - Input data - target: a :tensor: `torch.tensor` - Target data + X: torch.tensor + Train matrix + y: torch.tensor + Target matrix + + Returns + ------- + batch_outs : dict + Dictionnary with "y": target and "score": prediction scores. + batch_logs : dict + Dictionnary with "batch_size" and "loss". """ - raise NotImplementedError('users must define train_batch to use this base class') + batch_logs = {"batch_size": X.shape[0]} - @abstractmethod - def predict_epoch(self, loader): + X = X.to(self.device).float() + y = y.to(self.device).float() + + self._optimizer.zero_grad() + + output, M_loss = self.network(X) + + loss = self.compute_loss(output, y) + # Add the overall sparsity loss + loss -= self.lambda_sparse * M_loss + + # Perform backward pass and optimization + loss.backward() + if self.clip_value: + clip_grad_norm_(self.network.parameters(), self.clip_value) + self._optimizer.step() + + batch_logs["loss"] = loss.cpu().detach().numpy().item() + + if self._scheduler is not None: + self._scheduler.step() + + return batch_logs + + def _predict_epoch(self, name, loader): """ - Validates one epoch of the network in self.network + Predict an epoch and update metrics. Parameters ---------- - loader: a :class: `torch.utils.data.Dataloader` + name: str + Name of the validation set + loader: torch.utils.data.Dataloader DataLoader with validation set """ - raise NotImplementedError('users must define predict_epoch to use this base class') + # Setting network on evaluation mode (no dropout etc...) + self.network.eval() - @abstractmethod - def predict_batch(self, data, targets): + list_y_true = [] + list_y_score = [] + + # Main loop + for batch_idx, (X, y) in enumerate(loader): + scores = self._predict_batch(X) + list_y_true.append(y) + list_y_score.append(scores) + + y_true, scores = self.stack_batches(list_y_true, list_y_score) + + metrics_logs = self._metric_container_dict[name](y_true, scores) + self.network.train() + self.history.batch_metrics.update(metrics_logs) + return + + def _predict_batch(self, X): """ - Make predictions on a batch (valid) + Predict one batch of data. 
Parameters ---------- - data: a :tensor: `torch.Tensor` - Input data - target: a :tensor: `torch.Tensor` - Target data + x: torch.tensor + Owned products Returns ------- - batch_outs: dict + np.array + model scores """ - raise NotImplementedError('users must define predict_batch to use this base class') + X = X.to(self.device).float() - def load_best_model(self): - if self.best_network is not None: - self.network = self.best_network + # compute model output + scores, _ = self.network(X) - @abstractmethod - def predict(self, X): - """ - Make predictions on a batch (valid) + if isinstance(scores, list): + scores = [x.cpu().detach().numpy() for x in scores] + else: + scores = scores.cpu().detach().numpy() + + return scores + + def _set_network(self): + """Setup the network and explain matrix.""" + self.network = tab_network.TabNet( + self.input_dim, + self.output_dim, + n_d=self.n_d, + n_a=self.n_a, + n_steps=self.n_steps, + gamma=self.gamma, + cat_idxs=self.cat_idxs, + cat_dims=self.cat_dims, + cat_emb_dim=self.cat_emb_dim, + n_independent=self.n_independent, + n_shared=self.n_shared, + epsilon=self.epsilon, + virtual_batch_size=self.virtual_batch_size, + momentum=self.momentum, + device_name=self.device_name, + mask_type=self.mask_type, + ).to(self.device) + + self.reducing_matrix = create_explain_matrix( + self.network.input_dim, + self.network.cat_emb_dim, + self.network.cat_idxs, + self.network.post_embed_dim, + ) + + def _set_metrics(self, metrics, eval_names): + """Set attributes relative to the metrics. Parameters ---------- - data: a :tensor: `torch.Tensor` - Input data - target: a :tensor: `torch.Tensor` - Target data + metrics : list of str + List of eval metric names. + eval_names : list of str + List of eval set names. - Returns - ------- - predictions: np.array - Predictions of the regression problem or the last class """ - raise NotImplementedError('users must define predict to use this base class') + metrics = metrics or [self._default_metric] + + metrics = check_metrics(metrics) + # Set metric container for each sets + self._metric_container_dict = {} + for name in eval_names: + self._metric_container_dict.update( + {name: MetricContainer(metrics, prefix=f"{name}_")} + ) + + self._metrics = [] + self._metrics_names = [] + for _, metric_container in self._metric_container_dict.items(): + self._metrics.extend(metric_container.metrics) + self._metrics_names.extend(metric_container.names) + + # Early stopping metric is the last eval metric + self.early_stopping_metric = ( + self._metrics_names[-1] if len(self._metrics_names) > 0 else None + ) - def explain(self, X): - """ - Return local explanation + def _set_callbacks(self, custom_callbacks): + """Setup the callbacks functions. Parameters ---------- - data: a :tensor: `torch.Tensor` - Input data - target: a :tensor: `torch.Tensor` - Target data + callbacks : list of func + List of callback functions. - Returns - ------- - M_explain: matrix - Importance per sample, per columns. - masks: matrix - Sparse matrix showing attention masks used by network. 
""" - self.network.eval() + # Setup default callbacks history and early stopping + self.history = History(self, verbose=self.verbose) + early_stopping = EarlyStopping( + early_stopping_metric=self.early_stopping_metric, + is_maximize=( + self._metrics[-1]._maximize if len(self._metrics) > 0 else None + ), + patience=self.patience, + ) + callbacks = [self.history, early_stopping] + if custom_callbacks: + callbacks.extend(custom_callbacks) + self._callback_container = CallbackContainer(callbacks) + self._callback_container.set_trainer(self) + + def _set_optimizer(self): + """Setup optimizer.""" + self._optimizer = self.optimizer_fn( + self.network.parameters(), **self.optimizer_params + ) - dataloader = DataLoader(PredictDataset(X), - batch_size=self.batch_size, shuffle=False) + def _set_scheduler(self): + """Setup scheduler.""" + self._scheduler = None + if self.scheduler_fn: + self._scheduler = self.scheduler_fn( + self._optimizer, **self.scheduler_params + ) - for batch_nb, data in enumerate(dataloader): - data = data.to(self.device).float() + def _construct_loaders(self, X_train, y_train, eval_set): + """Generate dataloaders for train and eval set. - M_explain, masks = self.network.forward_masks(data) - for key, value in masks.items(): - masks[key] = csc_matrix.dot(value.cpu().detach().numpy(), - self.reducing_matrix) + Parameters + ---------- + X_train : np.array + Train set. + y_train : np.array + Train targets. + eval_set: list of tuple + List of eval tuple set (X, y). - if batch_nb == 0: - res_explain = csc_matrix.dot(M_explain.cpu().detach().numpy(), - self.reducing_matrix) - res_masks = masks - else: - res_explain = np.vstack([res_explain, - csc_matrix.dot(M_explain.cpu().detach().numpy(), - self.reducing_matrix)]) - for key, value in masks.items(): - res_masks[key] = np.vstack([res_masks[key], value]) - return res_explain, res_masks + Returns + ------- + train_dataloader : `torch.utils.data.Dataloader` + Training dataloader. + valid_dataloaders : list of `torch.utils.data.Dataloader` + List of validation dataloaders. + + """ + # all weights are not allowed for this type of model + y_train_mapped = self.prepare_target(y_train) + for i, (X, y) in enumerate(eval_set): + y_mapped = self.prepare_target(y) + eval_set[i] = (X, y_mapped) + + train_dataloader, valid_dataloaders = create_dataloaders( + X_train, + y_train_mapped, + eval_set, + self.updated_weights, + self.batch_size, + self.num_workers, + self.drop_last, + ) + return train_dataloader, valid_dataloaders def _compute_feature_importances(self, loader): + """Compute global feature importance. + + Parameters + ---------- + loader : `torch.utils.data.Dataloader` + Pytorch dataloader. + + """ self.network.eval() feature_importances_ = np.zeros((self.network.post_embed_dim)) for data, targets in loader: @@ -491,6 +595,68 @@ def _compute_feature_importances(self, loader): M_explain, masks = self.network.forward_masks(data) feature_importances_ += M_explain.sum(dim=0).cpu().detach().numpy() - feature_importances_ = csc_matrix.dot(feature_importances_, - self.reducing_matrix) + feature_importances_ = csc_matrix.dot( + feature_importances_, self.reducing_matrix + ) self.feature_importances_ = feature_importances_ / np.sum(feature_importances_) + + @abstractmethod + def update_fit_params(self, X_train, y_train, eval_set, weights): + """ + Set attributes relative to fit function. + + Parameters + ---------- + X_train: np.ndarray + Train set + y_train : np.array + Train targets + eval_set: list of tuple + List of eval tuple set (X, y). 
+ weights : bool or dictionnary + 0 for no balancing + 1 for automated balancing + """ + raise NotImplementedError( + "users must define update_fit_params to use this base class" + ) + + @abstractmethod + def compute_loss(self, y_score, y_true): + """ + Compute the loss. + + Parameters + ---------- + y_score: a :tensor: `torch.Tensor` + Score matrix + y_true: a :tensor: `torch.Tensor` + Target matrix + + Returns + ------- + float + Loss value + """ + raise NotImplementedError( + "users must define compute_loss to use this base class" + ) + + @abstractmethod + def prepare_target(self, y): + """ + Prepare target before training. + + Parameters + ---------- + y: a :tensor: `torch.Tensor` + Target matrix. + + Returns + ------- + `torch.Tensor` + Converted target matrix. + """ + raise NotImplementedError( + "users must define prepare_target to use this base class" + ) diff --git a/pytorch_tabnet/multiclass_utils.py b/pytorch_tabnet/multiclass_utils.py index 5598eb53..aaa85c90 100644 --- a/pytorch_tabnet/multiclass_utils.py +++ b/pytorch_tabnet/multiclass_utils.py @@ -26,17 +26,21 @@ def _assert_all_finite(X, allow_nan=False): # everything is finite; fall back to O(n) space np.isfinite to prevent # false positives from overflow in sum method. The sum is also calculated # safely to reduce dtype induced overflows. - is_float = X.dtype.kind in 'fc' + is_float = X.dtype.kind in "fc" if is_float and (np.isfinite(np.sum(X))): pass elif is_float: msg_err = "Input contains {} or a value too large for {!r}." - if (allow_nan and np.isinf(X).any() or - not allow_nan and not np.isfinite(X).all()): - type_err = 'infinity' if allow_nan else 'NaN, infinity' + if ( + allow_nan + and np.isinf(X).any() + or not allow_nan + and not np.isfinite(X).all() + ): + type_err = "infinity" if allow_nan else "NaN, infinity" raise ValueError(msg_err.format(type_err, X.dtype)) # for object dtype data, we only check for NaNs (GH-13254) - elif X.dtype == np.dtype('object') and not allow_nan: + elif X.dtype == np.dtype("object") and not allow_nan: if np.isnan(X).any(): raise ValueError("Input contains NaN") @@ -54,7 +58,7 @@ def assert_all_finite(X, allow_nan=False): def _unique_multiclass(y): - if hasattr(y, '__array__'): + if hasattr(y, "__array__"): return np.unique(np.asarray(y)) else: return set(y) @@ -68,9 +72,9 @@ def _unique_indicator(y): _FN_UNIQUE_LABELS = { - 'binary': _unique_multiclass, - 'multiclass': _unique_multiclass, - 'multilabel-indicator': _unique_indicator, + "binary": _unique_multiclass, + "multiclass": _unique_multiclass, + "multilabel-indicator": _unique_indicator, } @@ -106,7 +110,7 @@ def unique_labels(*ys): array([ 1, 2, 5, 10, 11]) """ if not ys: - raise ValueError('No argument has been passed.') + raise ValueError("No argument has been passed.") # Check that we don't mix label format ys_types = set(type_of_target(x) for x in ys) @@ -126,14 +130,14 @@ def unique_labels(*ys): ys_labels = set(chain.from_iterable(_unique_labels(y) for y in ys)) # Check that we don't mix string type with number type - if (len(set(isinstance(label, str) for label in ys_labels)) > 1): + if len(set(isinstance(label, str) for label in ys_labels)) > 1: raise ValueError("Mix of label input types (string and number)") return np.array(sorted(ys_labels)) def _is_integral_float(y): - return y.dtype.kind == 'f' and np.all(y.astype(int) == y) + return y.dtype.kind == "f" and np.all(y.astype(int) == y) def is_multilabel(y): @@ -164,7 +168,7 @@ def is_multilabel(y): >>> is_multilabel(np.array([[1, 0, 0]])) True """ - if hasattr(y, 
'__array__'): + if hasattr(y, "__array__"): y = np.asarray(y) if not (hasattr(y, "shape") and y.ndim == 2 and y.shape[1] > 1): return False @@ -172,14 +176,20 @@ def is_multilabel(y): if issparse(y): if isinstance(y, (dok_matrix, lil_matrix)): y = y.tocsr() - return (len(y.data) == 0 or np.unique(y.data).size == 1 and - (y.dtype.kind in 'biu' or # bool, int, uint - _is_integral_float(np.unique(y.data)))) + return ( + len(y.data) == 0 + or np.unique(y.data).size == 1 + and ( + y.dtype.kind in "biu" + or _is_integral_float(np.unique(y.data)) # bool, int, uint + ) + ) else: labels = np.unique(y) - return len(labels) < 3 and (y.dtype.kind in 'biu' or # bool, int, uint - _is_integral_float(labels)) + return len(labels) < 3 and ( + y.dtype.kind in "biu" or _is_integral_float(labels) # bool, int, uint + ) def check_classification_targets(y): @@ -194,8 +204,13 @@ def check_classification_targets(y): y : array-like """ y_type = type_of_target(y) - if y_type not in ['binary', 'multiclass', 'multiclass-multioutput', - 'multilabel-indicator', 'multilabel-sequences']: + if y_type not in [ + "binary", + "multiclass", + "multiclass-multioutput", + "multilabel-indicator", + "multilabel-sequences", + ]: raise ValueError("Unknown label type: %r" % y_type) @@ -263,45 +278,51 @@ def type_of_target(y): >>> type_of_target(np.array([[0, 1], [1, 1]])) 'multilabel-indicator' """ - valid = ((isinstance(y, (Sequence, spmatrix)) or hasattr(y, '__array__')) - and not isinstance(y, str)) + valid = ( + isinstance(y, (Sequence, spmatrix)) or hasattr(y, "__array__") + ) and not isinstance(y, str) if not valid: - raise ValueError('Expected array-like (array or non-string sequence), ' - 'got %r' % y) + raise ValueError( + "Expected array-like (array or non-string sequence), " "got %r" % y + ) - sparseseries = (y.__class__.__name__ == 'SparseSeries') + sparseseries = y.__class__.__name__ == "SparseSeries" if sparseseries: raise ValueError("y cannot be class 'SparseSeries'.") if is_multilabel(y): - return 'multilabel-indicator' + return "multilabel-indicator" try: y = np.asarray(y) except ValueError: # Known to fail in numpy 1.3 for array of arrays - return 'unknown' + return "unknown" # The old sequence of sequences format try: - if (not hasattr(y[0], '__array__') and isinstance(y[0], Sequence) - and not isinstance(y[0], str)): - raise ValueError('You appear to be using a legacy multi-label data' - ' representation. Sequence of sequences are no' - ' longer supported; use a binary array or sparse' - ' matrix instead - the MultiLabelBinarizer' - ' transformer can convert to this format.') + if ( + not hasattr(y[0], "__array__") + and isinstance(y[0], Sequence) + and not isinstance(y[0], str) + ): + raise ValueError( + "You appear to be using a legacy multi-label data" + " representation. Sequence of sequences are no" + " longer supported; use a binary array or sparse" + " matrix instead - the MultiLabelBinarizer" + " transformer can convert to this format." 
+ ) except IndexError: pass # Invalid inputs - if y.ndim > 2 or (y.dtype == object and len(y) and - not isinstance(y.flat[0], str)): - return 'unknown' # [[[1, 2]]] or [obj_1] and not ["label_1"] + if y.ndim > 2 or (y.dtype == object and len(y) and not isinstance(y.flat[0], str)): + return "unknown" # [[[1, 2]]] or [obj_1] and not ["label_1"] if y.ndim == 2 and y.shape[1] == 0: - return 'unknown' # [[]] + return "unknown" # [[]] if y.ndim == 2 and y.shape[1] > 1: suffix = "-multioutput" # [[1, 2], [1, 2]] @@ -309,12 +330,86 @@ def type_of_target(y): suffix = "" # [1, 2, 3] or [[1], [2], [3]] # check float and contains non-integer float values - if y.dtype.kind == 'f' and np.any(y != y.astype(int)): + if y.dtype.kind == "f" and np.any(y != y.astype(int)): # [.1, .2, 3] or [[.1, .2, 3]] or [[1., .2]] and not [1., 2., 3.] _assert_all_finite(y) - return 'continuous' + suffix + return "continuous" + suffix if (len(np.unique(y)) > 2) or (y.ndim >= 2 and len(y[0]) > 1): - return 'multiclass' + suffix # [1, 2, 3] or [[1., 2., 3]] or [[1, 2]] + return "multiclass" + suffix # [1, 2, 3] or [[1., 2., 3]] or [[1, 2]] else: - return 'binary' # [1, 2] or [["a"], ["b"]] + return "binary" # [1, 2] or [["a"], ["b"]] + + +def infer_output_dim(y_train): + """ + Infer output_dim from targets + + Parameters + ---------- + y_train : np.array + Training targets + + Returns + ------- + output_dim : int + Number of classes for output + train_labels : list + Sorted list of initial classes + """ + train_labels = unique_labels(y_train) + output_dim = len(train_labels) + + return output_dim, train_labels + + +def check_output_dim(labels, y): + if y is not None: + valid_labels = unique_labels(y) + if not set(valid_labels).issubset(set(labels)): + raise ValueError( + f"""Valid set -- {set(valid_labels)} -- + contains unkown targets from training -- + {set(labels)}""" + ) + return + + +def infer_multitask_output(y_train): + """ + Infer output_dim from targets + This is for multiple tasks. 
+ + Parameters + ---------- + y_train : np.ndarray + Training targets + y_valid : np.ndarray + Validation targets + + Returns + ------- + tasks_dims : list + Number of classes for output + tasks_labels : list + List of sorted list of initial classes + """ + + if len(y_train.shape) < 2: + raise ValueError( + f"""y_train shoud be of shape (n_examples, n_tasks) """ + + f"""but got {y_train.shape}""" + ) + nb_tasks = y_train.shape[1] + tasks_dims = [] + tasks_labels = [] + for task_idx in range(nb_tasks): + try: + output_dim, train_labels = infer_output_dim( + y_train[:, task_idx] + ) + tasks_dims.append(output_dim) + tasks_labels.append(train_labels) + except ValueError as err: + raise ValueError(f"""Error for task {task_idx} : {err}""") + return tasks_dims, tasks_labels diff --git a/pytorch_tabnet/multitask.py b/pytorch_tabnet/multitask.py index db52ab15..0a3bb42c 100644 --- a/pytorch_tabnet/multitask.py +++ b/pytorch_tabnet/multitask.py @@ -1,282 +1,77 @@ import torch import numpy as np -from pytorch_tabnet.multiclass_utils import unique_labels -from torch.nn.utils import clip_grad_norm_ -from pytorch_tabnet.utils import (PredictDataset, - create_dataloaders, - filter_weights) +from scipy.special import softmax +from pytorch_tabnet.utils import PredictDataset, filter_weights from pytorch_tabnet.abstract_model import TabModel +from pytorch_tabnet.multiclass_utils import infer_multitask_output from torch.utils.data import DataLoader class TabNetMultiTaskClassifier(TabModel): - - def infer_output_dim(self, y_train, y_valid): - """ - Infer output_dim from targets - This is for simple 1D np arrays - Parameters - ---------- - y_train : np.array - Training targets - y_valid : np.array - Validation targets - - Returns - ------- - output_dim : int - Number of classes for output - train_labels : list - Sorted list of initial classes - """ - train_labels = unique_labels(y_train) - output_dim = len(train_labels) - - if y_valid is not None: - valid_labels = unique_labels(y_valid) - if not set(valid_labels).issubset(set(train_labels)): - raise ValueError(f"""Valid set -- {set(valid_labels)} -- - contains unkown targets from training -- - {set(train_labels)}""") - return output_dim, train_labels - - def infer_multitask_output(self, y_train, y_valid): - """ - Infer output_dim from targets - This is for multiple tasks. 
- - Parameters - ---------- - y_train : np.ndarray - Training targets - y_valid : np.ndarray - Validation targets - - Returns - ------- - tasks_dims : list - Number of classes for output - tasks_labels : list - List of sorted list of initial classes - """ - - if len(y_train.shape) < 2: - raise ValueError(f"""y_train shoud be of shape (n_examples, n_tasks) """ + - f"""but got {y_train.shape}""") - nb_tasks = y_train.shape[1] - tasks_dims = [] - tasks_labels = [] - for task_idx in range(nb_tasks): - try: - output_dim, train_labels = self.infer_output_dim(y_train[:, task_idx], - y_valid[:, task_idx]) - tasks_dims.append(output_dim) - tasks_labels.append(train_labels) - except ValueError as err: - raise ValueError(f"""Error for task {task_idx} : {err}""") - return tasks_dims, tasks_labels - - def construct_loaders(self, X_train, y_train, X_valid, y_valid, weights, - batch_size, num_workers, drop_last): - """ - Returns - ------- - train_dataloader, valid_dataloader : torch.DataLoader, torch.DataLoader - Training and validation dataloaders - ------- - """ - # all weights are not allowed for this type of model - filter_weights(weights) - y_train_mapped = y_train.copy() - y_valid_mapped = y_valid.copy() - for task_idx in range(y_train.shape[1]): + def __post_init__(self): + super(TabNetMultiTaskClassifier, self).__post_init__() + self._task = 'classification' + self._default_loss = torch.nn.functional.cross_entropy + self._default_metric = 'logloss' + + def prepare_target(self, y): + y_mapped = y.copy() + for task_idx in range(y.shape[1]): task_mapper = self.target_mapper[task_idx] - y_train_mapped[:, task_idx] = np.vectorize(task_mapper.get)(y_train[:, task_idx]) - y_valid_mapped[:, task_idx] = np.vectorize(task_mapper.get)(y_valid[:, task_idx]) - train_dataloader, valid_dataloader = create_dataloaders(X_train, - y_train_mapped, - X_valid, - y_valid_mapped, - weights, - batch_size, - num_workers, - drop_last) - return train_dataloader, valid_dataloader - - def update_fit_params(self, X_train, y_train, X_valid, y_valid, loss_fn, - weights, max_epochs, patience, - batch_size, virtual_batch_size, num_workers, drop_last): - - if loss_fn is None: - self.loss_fn = torch.nn.functional.cross_entropy - else: - self.loss_fn = loss_fn - assert X_train.shape[1] == X_valid.shape[1], "Dimension mismatch X_train X_valid" - self.input_dim = X_train.shape[1] - - output_dim, train_labels = self.infer_multitask_output(y_train, y_valid) - self.output_dim = output_dim - self.classes_ = train_labels - self.target_mapper = [{class_label: index - for index, class_label in enumerate(classes)} - for classes in self.classes_] - self.preds_mapper = [{index: class_label - for index, class_label in enumerate(classes)} - for classes in self.classes_] - self.weights = weights - self.updated_weights = weights - - self.max_epochs = max_epochs - self.patience = patience - self.batch_size = batch_size - self.virtual_batch_size = virtual_batch_size - # Initialize counters and histories. 
- self.patience_counter = 0 - self.epoch = 0 - self.best_cost = np.inf - self.num_workers = num_workers - self.drop_last = drop_last + y_mapped[:, task_idx] = np.vectorize(task_mapper.get)(y[:, task_idx]) + return y_mapped - def train_epoch(self, train_loader): - """ - Trains one epoch of the network in self.network - - Parameters - ---------- - train_loader: a :class: `torch.utils.data.Dataloader` - DataLoader with train set - """ - - self.network.train() - total_loss = 0 - - for data, targets in train_loader: - batch_outs = self.train_batch(data, targets) - total_loss += batch_outs["loss"] - # TODO : add stopping loss - total_loss = total_loss / len(train_loader) - epoch_metrics = {'loss_avg': total_loss, - 'stopping_loss': total_loss, - } - - return epoch_metrics - - def compute_multi_loss(self, output, targets): + def compute_loss(self, y_pred, y_true): """ Computes the loss according to network output and targets Parameters ---------- - output: list of tensors + y_pred: list of tensors Output of network - targets: LongTensor + y_true: LongTensor Targets label encoded """ loss = 0 + y_true = y_true.long() if isinstance(self.loss_fn, list): # if you specify a different loss for each task - for task_loss, task_output, task_id in zip(self.loss_fn, - output, - range(len(self.loss_fn))): - loss += task_loss(task_output, targets[:, task_id]) + for task_loss, task_output, task_id in zip( + self.loss_fn, y_pred, range(len(self.loss_fn)) + ): + loss += task_loss(task_output, y_true[:, task_id]) else: # same loss function is applied to all tasks - for task_id, task_output in enumerate(output): - loss += self.loss_fn(task_output, targets[:, task_id]) + for task_id, task_output in enumerate(y_pred): + loss += self.loss_fn(task_output, y_true[:, task_id]) - loss /= len(output) + loss /= len(y_pred) return loss - def train_batch(self, data, targets): - """ - Trains one batch of data - - Parameters - ---------- - data: a :tensor: `torch.tensor` - Input data - target: a :tensor: `torch.tensor` - Target data - - Returns - ------- - batch_outs = {'loss': loss_value, - 'y_preds': output, - 'y': targets} - """ - self.network.train() - data = data.to(self.device).float() - - targets = targets.to(self.device).long() - self.optimizer.zero_grad() - - output, M_loss = self.network(data) - - loss = self.compute_multi_loss(output, targets) - # Add the overall sparsity loss - loss -= self.lambda_sparse*M_loss - - loss.backward() - if self.clip_value: - clip_grad_norm_(self.network.parameters(), self.clip_value) - self.optimizer.step() - - loss_value = loss.item() - batch_outs = {'loss': loss_value, - 'y_preds': output, - 'y': targets} - return batch_outs - - def predict_epoch(self, loader): - """ - Validates one epoch of the network in self.network - - Parameters - ---------- - loader: a :class: `torch.utils.data.Dataloader` - DataLoader with validation set - """ - self.network.eval() - total_loss = 0 - for data, targets in loader: - batch_outs = self.predict_batch(data, targets) - total_loss += batch_outs["loss"] - # TODO : add stopping loss - total_loss = total_loss / len(loader) - epoch_metrics = {'total_loss': total_loss, - 'stopping_loss': total_loss, - } - return epoch_metrics - - def predict_batch(self, data, targets): - """ - Make predictions on a batch (valid) - - Parameters - ---------- - data: a :tensor: `torch.Tensor` - Input data - target: a :tensor: `torch.Tensor` - Target data - - Returns - ------- - batch_outs: dict - """ - self.network.eval() - data = data.to(self.device).float() - targets = 
targets.to(self.device).long() - output, _ = self.network(data) - - loss = self.compute_multi_loss(output, targets) - # Here we do not compute the sparsity loss - - loss_value = loss.item() - batch_outs = {'loss': loss_value, - 'y_preds': output, - 'y': targets} - return batch_outs + def stack_batches(self, list_y_true, list_y_score): + y_true = np.vstack(list_y_true) + y_score = [] + for i in range(len(self.output_dim)): + score = np.vstack([x[i] for x in list_y_score]) + score = softmax(score, axis=1) + y_score.append(score) + return y_true, y_score + + def update_fit_params(self, X_train, y_train, eval_set, weights): + output_dim, train_labels = infer_multitask_output(y_train) + self.output_dim = output_dim + self.classes_ = train_labels + self.target_mapper = [ + {class_label: index for index, class_label in enumerate(classes)} + for classes in self.classes_ + ] + self.preds_mapper = [ + {index: class_label for index, class_label in enumerate(classes)} + for classes in self.classes_ + ] + self.updated_weights = weights + filter_weights(self.updated_weights) def predict(self, X): """ @@ -295,24 +90,32 @@ def predict(self, X): Predictions of the most probable class """ self.network.eval() - dataloader = DataLoader(PredictDataset(X), - batch_size=self.batch_size, shuffle=False) + dataloader = DataLoader( + PredictDataset(X), batch_size=self.batch_size, shuffle=False + ) results = {} for data in dataloader: data = data.to(self.device).float() output, _ = self.network(data) - predictions = [torch.argmax(torch.nn.Softmax(dim=1)(task_output), - dim=1).cpu().detach().numpy().reshape(-1) - for task_output in output] + predictions = [ + torch.argmax(torch.nn.Softmax(dim=1)(task_output), dim=1) + .cpu() + .detach() + .numpy() + .reshape(-1) + for task_output in output + ] for task_idx in range(len(self.output_dim)): results[task_idx] = results.get(task_idx, []) + [predictions[task_idx]] # stack all task individually results = [np.hstack(task_res) for task_res in results.values()] # map all task individually - results = [np.vectorize(self.preds_mapper[task_idx].get)(task_res) - for task_idx, task_res in enumerate(results)] + results = [ + np.vectorize(self.preds_mapper[task_idx].get)(task_res) + for task_idx, task_res in enumerate(results) + ] return results def predict_proba(self, X): @@ -332,16 +135,18 @@ def predict_proba(self, X): """ self.network.eval() - dataloader = DataLoader(PredictDataset(X), - batch_size=self.batch_size, - shuffle=False) + dataloader = DataLoader( + PredictDataset(X), batch_size=self.batch_size, shuffle=False + ) results = {} for data in dataloader: data = data.to(self.device).float() output, _ = self.network(data) - predictions = [torch.nn.Softmax(dim=1)(task_output).cpu().detach().numpy() - for task_output in output] + predictions = [ + torch.nn.Softmax(dim=1)(task_output).cpu().detach().numpy() + for task_output in output + ] for task_idx in range(len(self.output_dim)): results[task_idx] = results.get(task_idx, []) + [predictions[task_idx]] res = [np.vstack(task_res) for task_res in results.values()] diff --git a/pytorch_tabnet/tab_model.py b/pytorch_tabnet/tab_model.py index 5f4fdee2..3011c191 100755 --- a/pytorch_tabnet/tab_model.py +++ b/pytorch_tabnet/tab_model.py @@ -1,45 +1,18 @@ import torch import numpy as np -from pytorch_tabnet.multiclass_utils import unique_labels -from sklearn.metrics import roc_auc_score, mean_squared_error, accuracy_score -from torch.nn.utils import clip_grad_norm_ -from pytorch_tabnet.utils import (PredictDataset, - create_dataloaders, 
- filter_weights) +from scipy.special import softmax +from pytorch_tabnet.utils import PredictDataset, filter_weights from pytorch_tabnet.abstract_model import TabModel +from pytorch_tabnet.multiclass_utils import infer_output_dim, check_output_dim from torch.utils.data import DataLoader class TabNetClassifier(TabModel): - - def infer_output_dim(self, y_train, y_valid): - """ - Infer output_dim from targets - - Parameters - ---------- - y_train : np.array - Training targets - y_valid : np.array - Validation targets - - Returns - ------- - output_dim : int - Number of classes for output - train_labels : list - Sorted list of initial classes - """ - train_labels = unique_labels(y_train) - output_dim = len(train_labels) - - if y_valid is not None: - valid_labels = unique_labels(y_valid) - if not set(valid_labels).issubset(set(train_labels)): - raise ValueError(f"""Valid set -- {set(valid_labels)} -- - contains unkown targets from training -- - {set(train_labels)}""") - return output_dim, train_labels + def __post_init__(self): + super(TabNetClassifier, self).__post_init__() + self._task = 'classification' + self._default_loss = torch.nn.functional.cross_entropy + self._default_metric = 'accuracy' def weight_updater(self, weights): """ @@ -58,236 +31,46 @@ def weight_updater(self, weights): if isinstance(weights, int): return weights elif isinstance(weights, dict): - return {self.target_mapper[key]: value - for key, value in weights.items()} + return {self.target_mapper[key]: value for key, value in weights.items()} else: return weights - def construct_loaders(self, X_train, y_train, X_valid, y_valid, - weights, batch_size, num_workers, drop_last): - """ - Returns - ------- - train_dataloader, valid_dataloader : torch.DataLoader, torch.DataLoader - Training and validation dataloaders - ------- - """ - y_train_mapped = np.vectorize(self.target_mapper.get)(y_train) - y_valid_mapped = np.vectorize(self.target_mapper.get)(y_valid) - train_dataloader, valid_dataloader = create_dataloaders(X_train, - y_train_mapped, - X_valid, - y_valid_mapped, - weights, - batch_size, - num_workers, - drop_last) - return train_dataloader, valid_dataloader - - def update_fit_params(self, X_train, y_train, X_valid, y_valid, loss_fn, - weights, max_epochs, patience, - batch_size, virtual_batch_size, num_workers, drop_last): - if loss_fn is None: - self.loss_fn = torch.nn.functional.cross_entropy - else: - self.loss_fn = loss_fn - assert X_train.shape[1] == X_valid.shape[1], "Dimension mismatch X_train X_valid" - self.input_dim = X_train.shape[1] - - output_dim, train_labels = self.infer_output_dim(y_train, y_valid) + def prepare_target(self, y): + return np.vectorize(self.target_mapper.get)(y) + + def compute_loss(self, y_pred, y_true): + return self.loss_fn(y_pred, y_true.long()) + + def update_fit_params( + self, + X_train, + y_train, + eval_set, + weights, + ): + output_dim, train_labels = infer_output_dim(y_train) + for X, y in eval_set: + check_output_dim(train_labels, y) self.output_dim = output_dim + self._default_metric = ('auc' if self.output_dim == 2 else 'accuracy') self.classes_ = train_labels - self.target_mapper = {class_label: index - for index, class_label in enumerate(self.classes_)} - self.preds_mapper = {index: class_label - for index, class_label in enumerate(self.classes_)} - self.weights = weights - self.updated_weights = self.weight_updater(self.weights) - - self.max_epochs = max_epochs - self.patience = patience - self.batch_size = batch_size - self.virtual_batch_size = virtual_batch_size - # 
Initialize counters and histories. - self.patience_counter = 0 - self.epoch = 0 - self.best_cost = np.inf - self.num_workers = num_workers - self.drop_last = drop_last - - def train_epoch(self, train_loader): - """ - Trains one epoch of the network in self.network - - Parameters - ---------- - train_loader: a :class: `torch.utils.data.Dataloader` - DataLoader with train set - """ - - self.network.train() - y_preds = [] - ys = [] - total_loss = 0 - - for data, targets in train_loader: - batch_outs = self.train_batch(data, targets) - if self.output_dim == 2: - y_preds.append(torch.nn.Softmax(dim=1)(batch_outs["y_preds"])[:, 1] - .cpu().detach().numpy()) - else: - values, indices = torch.max(batch_outs["y_preds"], dim=1) - y_preds.append(indices.cpu().detach().numpy()) - ys.append(batch_outs["y"].cpu().detach().numpy()) - total_loss += batch_outs["loss"] - - y_preds = np.hstack(y_preds) - ys = np.hstack(ys) - - if self.output_dim == 2: - stopping_loss = -roc_auc_score(y_true=ys, y_score=y_preds) - else: - stopping_loss = -accuracy_score(y_true=ys, y_pred=y_preds) - total_loss = total_loss / len(train_loader) - epoch_metrics = {'loss_avg': total_loss, - 'stopping_loss': stopping_loss, - } - - return epoch_metrics - - def train_batch(self, data, targets): - """ - Trains one batch of data - - Parameters - ---------- - data: a :tensor: `torch.tensor` - Input data - target: a :tensor: `torch.tensor` - Target data - """ - self.network.train() - data = data.to(self.device).float() - - targets = targets.to(self.device).long() - self.optimizer.zero_grad() - - output, M_loss = self.network(data) - - loss = self.loss_fn(output, targets) - loss -= self.lambda_sparse*M_loss - - loss.backward() - if self.clip_value: - clip_grad_norm_(self.network.parameters(), self.clip_value) - self.optimizer.step() - - loss_value = loss.item() - batch_outs = {'loss': loss_value, - 'y_preds': output, - 'y': targets} - return batch_outs - - def predict_epoch(self, loader): - """ - Validates one epoch of the network in self.network - - Parameters - ---------- - loader: a :class: `torch.utils.data.Dataloader` - DataLoader with validation set - """ - y_preds = [] - ys = [] - self.network.eval() - total_loss = 0 - - for data, targets in loader: - batch_outs = self.predict_batch(data, targets) - total_loss += batch_outs["loss"] - if self.output_dim == 2: - y_preds.append(torch.nn.Softmax(dim=1)(batch_outs["y_preds"])[:, 1] - .cpu().detach().numpy()) - else: - values, indices = torch.max(batch_outs["y_preds"], dim=1) - y_preds.append(indices.cpu().detach().numpy()) - ys.append(batch_outs["y"].cpu().detach().numpy()) - - y_preds = np.hstack(y_preds) - ys = np.hstack(ys) - - if self.output_dim == 2: - stopping_loss = -roc_auc_score(y_true=ys, y_score=y_preds) - else: - stopping_loss = -accuracy_score(y_true=ys, y_pred=y_preds) - - total_loss = total_loss / len(loader) - epoch_metrics = {'total_loss': total_loss, - 'stopping_loss': stopping_loss} - - return epoch_metrics - - def predict_batch(self, data, targets): - """ - Make predictions on a batch (valid) - - Parameters - ---------- - data: a :tensor: `torch.Tensor` - Input data - target: a :tensor: `torch.Tensor` - Target data - - Returns - ------- - batch_outs: dict - """ - self.network.eval() - data = data.to(self.device).float() - targets = targets.to(self.device).long() - output, M_loss = self.network(data) - - loss = self.loss_fn(output, targets) - loss -= self.lambda_sparse*M_loss - - loss_value = loss.item() - batch_outs = {'loss': loss_value, - 'y_preds': output, - 'y': 
targets} - return batch_outs - - def predict(self, X): - """ - Make predictions on a batch (valid) - - Parameters - ---------- - data: a :tensor: `torch.Tensor` - Input data - target: a :tensor: `torch.Tensor` - Target data - - Returns - ------- - predictions: np.array - Predictions of the most probable class - """ - self.network.eval() - dataloader = DataLoader(PredictDataset(X), - batch_size=self.batch_size, shuffle=False) - - for batch_nb, data in enumerate(dataloader): - data = data.to(self.device).float() - output, M_loss = self.network(data) - predictions = torch.argmax(torch.nn.Softmax(dim=1)(output), - dim=1) - predictions = predictions.cpu().detach().numpy().reshape(-1) - if batch_nb == 0: - res = predictions - else: - res = np.hstack([res, predictions]) - - return np.vectorize(self.preds_mapper.get)(res) + self.target_mapper = { + class_label: index for index, class_label in enumerate(self.classes_) + } + self.preds_mapper = { + index: class_label for index, class_label in enumerate(self.classes_) + } + self.updated_weights = self.weight_updater(weights) + + def stack_batches(self, list_y_true, list_y_score): + y_true = np.hstack(list_y_true) + y_score = np.vstack(list_y_score) + y_score = softmax(y_score, axis=1) + return y_true, y_score + + def predict_func(self, outputs): + outputs = np.argmax(outputs, axis=1) + return np.vectorize(self.preds_mapper.get)(outputs) def predict_proba(self, X): """ @@ -306,8 +89,9 @@ def predict_proba(self, X): """ self.network.eval() - dataloader = DataLoader(PredictDataset(X), - batch_size=self.batch_size, shuffle=False) + dataloader = DataLoader( + PredictDataset(X), batch_size=self.batch_size, shuffle=False + ) results = [] for batch_nb, data in enumerate(dataloader): @@ -321,211 +105,34 @@ def predict_proba(self, X): class TabNetRegressor(TabModel): - - def construct_loaders(self, X_train, y_train, X_valid, y_valid, weights, - batch_size, num_workers, drop_last): - """ - Returns - ------- - train_dataloader, valid_dataloader : torch.DataLoader, torch.DataLoader - Training and validation dataloaders - ------- - """ - # all weights are not allowed for this type of model - filter_weights(weights) - train_dataloader, valid_dataloader = create_dataloaders(X_train, - y_train, - X_valid, - y_valid, - weights, - batch_size, - num_workers, - drop_last) - return train_dataloader, valid_dataloader - - def update_fit_params(self, X_train, y_train, X_valid, y_valid, loss_fn, - weights, max_epochs, patience, - batch_size, virtual_batch_size, num_workers, drop_last): - - if loss_fn is None: - self.loss_fn = torch.nn.functional.mse_loss - else: - self.loss_fn = loss_fn - - assert X_train.shape[1] == X_valid.shape[1], "Dimension mismatch X_train X_valid" - self.input_dim = X_train.shape[1] - - if len(y_train.shape) == 1: - raise ValueError("""Please apply reshape(-1, 1) to your targets - if doing single regression.""") - assert y_train.shape[1] == y_valid.shape[1], "Dimension mismatch y_train y_valid" + def __post_init__(self): + super(TabNetRegressor, self).__post_init__() + self._task = 'regression' + self._default_loss = torch.nn.functional.mse_loss + self._default_metric = 'mse' + + def prepare_target(self, y): + return y + + def compute_loss(self, y_pred, y_true): + return self.loss_fn(y_pred, y_true) + + def update_fit_params( + self, + X_train, + y_train, + eval_set, + weights + ): self.output_dim = y_train.shape[1] self.updated_weights = weights + filter_weights(self.updated_weights) - self.max_epochs = max_epochs - self.patience = patience - 
self.batch_size = batch_size - self.virtual_batch_size = virtual_batch_size - # Initialize counters and histories. - self.patience_counter = 0 - self.epoch = 0 - self.best_cost = np.inf - self.num_workers = num_workers - self.drop_last = drop_last - - def train_epoch(self, train_loader): - """ - Trains one epoch of the network in self.network - - Parameters - ---------- - train_loader: a :class: `torch.utils.data.Dataloader` - DataLoader with train set - """ - - self.network.train() - y_preds = [] - ys = [] - total_loss = 0 - - for data, targets in train_loader: - batch_outs = self.train_batch(data, targets) - y_preds.append(batch_outs["y_preds"].cpu().detach().numpy()) - ys.append(batch_outs["y"].cpu().detach().numpy()) - total_loss += batch_outs["loss"] - - y_preds = np.vstack(y_preds) - ys = np.vstack(ys) - - stopping_loss = mean_squared_error(y_true=ys, y_pred=y_preds) - total_loss = total_loss / len(train_loader) - epoch_metrics = {'loss_avg': total_loss, - 'stopping_loss': stopping_loss, - } - - return epoch_metrics - - def train_batch(self, data, targets): - """ - Trains one batch of data - - Parameters - ---------- - data: a :tensor: `torch.tensor` - Input data - target: a :tensor: `torch.tensor` - Target data - """ - self.network.train() - data = data.to(self.device).float() - - targets = targets.to(self.device).float() - self.optimizer.zero_grad() - - output, M_loss = self.network(data) - - loss = self.loss_fn(output, targets) - loss -= self.lambda_sparse*M_loss - - loss.backward() - if self.clip_value: - clip_grad_norm_(self.network.parameters(), self.clip_value) - self.optimizer.step() - - loss_value = loss.item() - batch_outs = {'loss': loss_value, - 'y_preds': output, - 'y': targets} - return batch_outs - - def predict_epoch(self, loader): - """ - Validates one epoch of the network in self.network - - Parameters - ---------- - loader: a :class: `torch.utils.data.Dataloader` - DataLoader with validation set - """ - y_preds = [] - ys = [] - self.network.eval() - total_loss = 0 - - for data, targets in loader: - batch_outs = self.predict_batch(data, targets) - total_loss += batch_outs["loss"] - y_preds.append(batch_outs["y_preds"].cpu().detach().numpy()) - ys.append(batch_outs["y"].cpu().detach().numpy()) - - y_preds = np.vstack(y_preds) - ys = np.vstack(ys) - - stopping_loss = mean_squared_error(y_true=ys, y_pred=y_preds) - - total_loss = total_loss / len(loader) - epoch_metrics = {'total_loss': total_loss, - 'stopping_loss': stopping_loss} - - return epoch_metrics - - def predict_batch(self, data, targets): - """ - Make predictions on a batch (valid) - - Parameters - ---------- - data: a :tensor: `torch.Tensor` - Input data - target: a :tensor: `torch.Tensor` - Target data - - Returns - ------- - batch_outs: dict - """ - self.network.eval() - data = data.to(self.device).float() - targets = targets.to(self.device).float() - - output, M_loss = self.network(data) - - loss = self.loss_fn(output, targets) - loss -= self.lambda_sparse*M_loss - - loss_value = loss.item() - batch_outs = {'loss': loss_value, - 'y_preds': output, - 'y': targets} - return batch_outs + def predict_func(self, outputs): + return outputs - def predict(self, X): - """ - Make predictions on a batch (valid) - - Parameters - ---------- - data: a :tensor: `torch.Tensor` - Input data - target: a :tensor: `torch.Tensor` - Target data - - Returns - ------- - predictions: np.array - Predictions of the regression problem - """ - self.network.eval() - dataloader = DataLoader(PredictDataset(X), - 
batch_size=self.batch_size, shuffle=False) - - results = [] - for batch_nb, data in enumerate(dataloader): - data = data.to(self.device).float() - - output, M_loss = self.network(data) - predictions = output.cpu().detach().numpy() - results.append(predictions) - res = np.vstack(results) - return res + def stack_batches(self, list_y_true, list_y_score): + y_true = np.vstack(list_y_true) + y_score = np.vstack(list_y_score) + return y_true, y_score diff --git a/pytorch_tabnet/utils.py b/pytorch_tabnet/utils.py index a20662a3..efe740d7 100644 --- a/pytorch_tabnet/utils.py +++ b/pytorch_tabnet/utils.py @@ -50,8 +50,9 @@ def __getitem__(self, index): return x -def create_dataloaders(X_train, y_train, X_valid, y_valid, weights, - batch_size, num_workers, drop_last): +def create_dataloaders( + X_train, y_train, eval_set, weights, batch_size, num_workers, drop_last +): """ Create dataloaders with or wihtout subsampling depending on weights and balanced. @@ -84,9 +85,10 @@ def create_dataloaders(X_train, y_train, X_valid, y_valid, weights, elif weights == 1: need_shuffle = False class_sample_count = np.array( - [len(np.where(y_train == t)[0]) for t in np.unique(y_train)]) + [len(np.where(y_train == t)[0]) for t in np.unique(y_train)] + ) - weights = 1. / class_sample_count + weights = 1.0 / class_sample_count samples_weight = np.array([weights[t] for t in y_train]) @@ -94,7 +96,7 @@ def create_dataloaders(X_train, y_train, X_valid, y_valid, weights, samples_weight = samples_weight.double() sampler = WeightedRandomSampler(samples_weight, len(samples_weight)) else: - raise ValueError('Weights should be either 0, 1, dictionnary or list.') + raise ValueError("Weights should be either 0, 1, dictionnary or list.") elif isinstance(weights, dict): # custom weights per class need_shuffle = False @@ -103,24 +105,32 @@ def create_dataloaders(X_train, y_train, X_valid, y_valid, weights, else: # custom weights if len(weights) != len(y_train): - raise ValueError('Custom weights should match number of train samples.') + raise ValueError("Custom weights should match number of train samples.") need_shuffle = False samples_weight = np.array(weights) sampler = WeightedRandomSampler(samples_weight, len(samples_weight)) - train_dataloader = DataLoader(TorchDataset(X_train, y_train), - batch_size=batch_size, - sampler=sampler, - shuffle=need_shuffle, - num_workers=num_workers, - drop_last=drop_last) - - valid_dataloader = DataLoader(TorchDataset(X_valid, y_valid), - batch_size=batch_size, - shuffle=False, - num_workers=num_workers) - - return train_dataloader, valid_dataloader + train_dataloader = DataLoader( + TorchDataset(X_train, y_train), + batch_size=batch_size, + sampler=sampler, + shuffle=need_shuffle, + num_workers=num_workers, + drop_last=drop_last, + ) + + valid_dataloaders = [] + for X, y in eval_set: + valid_dataloaders.append( + DataLoader( + TorchDataset(X, y), + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + ) + ) + + return train_dataloader, valid_dataloaders def create_explain_matrix(input_dim, cat_emb_dim, cat_idxs, post_embed_dim): @@ -148,18 +158,20 @@ def create_explain_matrix(input_dim, cat_emb_dim, cat_idxs, post_embed_dim): """ if isinstance(cat_emb_dim, int): - all_emb_impact = [cat_emb_dim-1]*len(cat_idxs) + all_emb_impact = [cat_emb_dim - 1] * len(cat_idxs) else: - all_emb_impact = [emb_dim-1 for emb_dim in cat_emb_dim] + all_emb_impact = [emb_dim - 1 for emb_dim in cat_emb_dim] acc_emb = 0 nb_emb = 0 indices_trick = [] for i in range(input_dim): if i not in cat_idxs: - 
indices_trick.append([i+acc_emb])
+            indices_trick.append([i + acc_emb])
         else:
-            indices_trick.append(range(i+acc_emb, i+acc_emb+all_emb_impact[nb_emb]+1))
+            indices_trick.append(
+                range(i + acc_emb, i + acc_emb + all_emb_impact[nb_emb] + 1)
+            )
             acc_emb += all_emb_impact[nb_emb]
             nb_emb += 1
 
@@ -190,3 +202,56 @@ def filter_weights(weights):
     if isinstance(weights, dict):
         raise ValueError(err_msg + "Dict given.")
     return
+
+
+def validate_eval_set(eval_set, eval_name, X_train, y_train):
+    """Check if the shapes of eval_set are compatible with (X_train, y_train).
+
+    Parameters
+    ----------
+    eval_set: list of tuple
+        List of eval tuples (X, y).
+        The last one is used for early stopping.
+    eval_name: list of str
+        List of eval set names.
+    X_train: np.ndarray
+        Train set feature matrix.
+    y_train : np.ndarray
+        Train set targets.
+
+    Returns
+    -------
+    eval_name : list of str
+        Validated list of eval names.
+    eval_set : list of tuple
+        Validated list of eval_set.
+
+    """
+    eval_name = eval_name or [f"val_{i}" for i in range(len(eval_set))]
+
+    assert len(eval_set) == len(
+        eval_name
+    ), "eval_set and eval_name must have the same length"
+    if len(eval_set) > 0:
+        assert all(
+            len(elem) == 2 for elem in eval_set
+        ), "Each tuple of eval_set needs to have two elements"
+    for name, (X, y) in zip(eval_name, eval_set):
+        msg = (
+            f"Number of columns is different between X_{name} "
+            + f"({X.shape[1]}) and X_train ({X_train.shape[1]})"
+        )
+        assert X.shape[1] == X_train.shape[1], msg
+        if len(y_train.shape) == 2:
+            msg = (
+                f"Number of columns is different between y_{name} "
+                + f"({y.shape[1]}) and y_train ({y_train.shape[1]})"
+            )
+            assert y.shape[1] == y_train.shape[1], msg
+        msg = (
+            f"You need the same number of rows between X_{name} "
+            + f"({X.shape[0]}) and y_{name} ({y.shape[0]})"
+        )
+        assert X.shape[0] == y.shape[0], msg
+
+    return eval_name, eval_set
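
Usage note (not part of the patch): the sketch below illustrates the refactored eval-set API introduced in pytorch_tabnet/utils.py above. validate_eval_set fills in default set names and checks shape compatibility, and create_dataloaders now takes a list of (X, y) tuples and returns one validation dataloader per tuple. It assumes the patched pytorch_tabnet package is importable; the toy array shapes, batch_size, and the weights=0 setting are illustrative assumptions, not values taken from the patch.

    import numpy as np
    from pytorch_tabnet.utils import validate_eval_set, create_dataloaders

    # Toy data (illustrative only): 100 train rows, 20 validation rows, 5 features.
    X_train = np.random.rand(100, 5).astype(np.float32)
    y_train = np.random.randint(0, 2, size=100)
    X_valid = np.random.rand(20, 5).astype(np.float32)
    y_valid = np.random.randint(0, 2, size=20)

    # No names given, so validate_eval_set falls back to "val_0", "val_1", ...
    eval_name, eval_set = validate_eval_set(
        eval_set=[(X_valid, y_valid)],
        eval_name=None,
        X_train=X_train,
        y_train=y_train,
    )

    # One train dataloader plus a list with one validation dataloader per eval tuple.
    train_dl, valid_dls = create_dataloaders(
        X_train,
        y_train,
        eval_set,
        weights=0,      # accepted value per the error message above; skips weighted sampling
        batch_size=32,
        num_workers=0,
        drop_last=False,
    )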