diff --git a/census_example.ipynb b/census_example.ipynb index 78b49dd5..2041a352 100644 --- a/census_example.ipynb +++ b/census_example.ipynb @@ -19,7 +19,12 @@ "\n", "import os\n", "import wget\n", - "from pathlib import Path" + "from pathlib import Path\n", + "\n", + "\n", + "# This is due to torch1.3 bug : https://github.com/pytorch/pytorch/issues/27972\n", + "import warnings\n", + "warnings.simplefilter(\"ignore\", UserWarning)" ] }, { @@ -172,7 +177,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "scrolled": false + "scrolled": true }, "outputs": [], "source": [ @@ -336,7 +341,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.5" + "version": "3.7.6" } }, "nbformat": 4, diff --git a/forest_example.ipynb b/forest_example.ipynb index 39921dc4..c223ffdf 100644 --- a/forest_example.ipynb +++ b/forest_example.ipynb @@ -21,7 +21,11 @@ "import wget\n", "from pathlib import Path\n", "import shutil\n", - "import gzip" + "import gzip\n", + "\n", + "# This is due to torch1.3 bug : https://github.com/pytorch/pytorch/issues/27972\n", + "import warnings\n", + "warnings.simplefilter(\"ignore\", UserWarning)" ] }, { @@ -381,7 +385,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.5" + "version": "3.7.6" } }, "nbformat": 4, diff --git a/pytorch_tabnet/tab_model.py b/pytorch_tabnet/tab_model.py index c1b11cb1..073e5d90 100644 --- a/pytorch_tabnet/tab_model.py +++ b/pytorch_tabnet/tab_model.py @@ -150,14 +150,14 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None, if self.verbose > 0: print("Will train until validation stopping metric", f"hasn't improved in {self.patience} rounds.") - msg_epoch = f'| EPOCH | train | valid | total time (s)' + msg_epoch = f'| EPOCH | train | valid | total time (s)' print('---------------------------------------') print(msg_epoch) - starting_time = time.time() total_time = 0 while (self.epoch < self.max_epochs and self.patience_counter < self.patience): + starting_time = time.time() fit_metrics = self.fit_epoch(train_dataloader, valid_dataloader) # leaving it here, may be used for callbacks later @@ -181,9 +181,12 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None, total_time += time.time() - starting_time if self.verbose > 0: if self.epoch % self.verbose == 0: + separator = "|" msg_epoch = f"| {self.epoch:<5} | " - msg_epoch += f"{-np.round(fit_metrics['train']['stopping_loss'], 5):<5} | " - msg_epoch += f"{-np.round(fit_metrics['valid']['stopping_loss'], 5):<5} | " + msg_epoch += f"{-fit_metrics['train']['stopping_loss']:.5f}" + msg_epoch += f' {separator:<2} ' + msg_epoch += f"{-fit_metrics['valid']['stopping_loss']:.5f}" + msg_epoch += f' {separator:<2} ' msg_epoch += f" {np.round(total_time, 1):<10}" print(msg_epoch) diff --git a/regression_example.ipynb b/regression_example.ipynb index c4ccb9f0..ccfd15ab 100644 --- a/regression_example.ipynb +++ b/regression_example.ipynb @@ -19,7 +19,11 @@ "\n", "import os\n", "import wget\n", - "from pathlib import Path" + "from pathlib import Path\n", + "\n", + "# This is due to torch1.3 bug : https://github.com/pytorch/pytorch/issues/27972\n", + "import warnings\n", + "warnings.simplefilter(\"ignore\", UserWarning)" ] }, { @@ -335,7 +339,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.5" + "version": "3.7.6" } }, "nbformat": 4,