From e7dc059d8d45ce207b3c24e975dda68fec2155ba Mon Sep 17 00:00:00 2001 From: Optimox Date: Thu, 17 Oct 2019 13:17:33 +0200 Subject: [PATCH] feat: start PyTorch TabNet Paper Implementation --- .dockerignore | 1 + .gitatttributes | 29 ++ .gitignore | 131 ++++++++ Dockerfile | 14 + LICENSE | 21 ++ Makefile | 47 +++ README.md | 116 +++++++ census_example.ipynb | 362 ++++++++++++++++++++ forest_example.ipynb | 397 ++++++++++++++++++++++ poetry.lock | 764 ++++++++++++++++++++++++++++++++++++++++++ pyproject.toml | 26 ++ requirements.txt | 8 + tabnet/sparsemax.py | 264 +++++++++++++++ tabnet/tab_model.py | 449 +++++++++++++++++++++++++ tabnet/tab_network.py | 335 ++++++++++++++++++ 15 files changed, 2964 insertions(+) create mode 100644 .dockerignore create mode 100644 .gitatttributes create mode 100644 .gitignore create mode 100644 Dockerfile create mode 100644 LICENSE create mode 100644 Makefile create mode 100644 README.md create mode 100644 census_example.ipynb create mode 100644 forest_example.ipynb create mode 100644 poetry.lock create mode 100644 pyproject.toml create mode 100644 requirements.txt create mode 100644 tabnet/sparsemax.py create mode 100644 tabnet/tab_model.py create mode 100644 tabnet/tab_network.py diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 00000000..72e8ffc0 --- /dev/null +++ b/.dockerignore @@ -0,0 +1 @@ +* diff --git a/.gitatttributes b/.gitatttributes new file mode 100644 index 00000000..b887e7d6 --- /dev/null +++ b/.gitatttributes @@ -0,0 +1,29 @@ +* text=auto +# Basic .gitattributes for a python repo. + +# Source files +# ============ +*.pxd text diff=python +*.py text diff=python +*.py3 text diff=python +*.pyw text diff=python +*.pyx text diff=python +*.pyz text diff=python + +# Binary files +# ============ +*.db binary +*.p binary +*.pkl binary +*.pickle binary +*.pyc binary +*.pyd binary +*.pyo binary + +# Jupyter notebook +*.ipynb text + +# Note: .db, .p, and .pkl files are associated +# with the python modules ``pickle``, ``dbm.*``, +# ``shelve``, ``marshal``, ``anydbm``, & ``bsddb`` +# (among others). diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..aac5a890 --- /dev/null +++ b/.gitignore @@ -0,0 +1,131 @@ +.cache/ +../.history/ +data/ +.ipynb_checkpoints/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..be0228c9 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,14 @@ +FROM python:3.7-slim-buster +RUN apt update && apt install curl make git -y +RUN curl -sSL https://raw.githubusercontent.com/sdispater/poetry/master/get-poetry.py | python +ENV SHELL /bin/bash -l + +ENV POETRY_CACHE /work/.cache/poetry +ENV PIP_CACHE_DIR /work/.cache/pip +ENV JUPYTER_RUNTIME_DIR /work/.cache/jupyter/runtime +ENV JUPYTER_CONFIG_DIR /work/.cache/jupyter/config + +RUN $HOME/.poetry/bin/poetry config settings.virtualenvs.path $POETRY_CACHE + +# ENTRYPOINT ["poetry", "run"] +CMD ["bash", "-l"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..dbd33c7e --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 DreamQuark + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..5df9e523 --- /dev/null +++ b/Makefile @@ -0,0 +1,47 @@ +# set default shell +SHELL := $(shell which bash) +FOLDER=$$(pwd) +# default shell options +.SHELLFLAGS = -c +NO_COLOR=\\e[39m +OK_COLOR=\\e[32m +ERROR_COLOR=\\e[31m +WARN_COLOR=\\e[33m +PORT=8889 +.SILENT: ; +default: help; # default target + +IMAGE_NAME=python-poetry:latest + +build: + echo "Building Dockerfile" + docker build -t ${IMAGE_NAME} . +.PHONY: build + +start: build + echo "Starting container ${IMAGE_NAME}" + docker run --rm -it -v ${FOLDER}:/work -w /work -p ${PORT}:${PORT} -e "JUPYTER_PORT=${PORT}" ${IMAGE_NAME} +.PHONY: start + +notebook: + poetry run jupyter notebook --allow-root --ip 0.0.0.0 --port ${PORT} --no-browser --notebook-dir . +.PHONY: notebook + +root_bash: + docker exec -it --user root $$(docker ps --filter ancestor=${IMAGE_NAME} --filter expose=${PORT} -q) bash +.PHONY: root_bash + +help: + echo -e "make [ACTION] " + echo + echo -e "This image uses Poetry for dependency management (https://poetry.eustace.io/)" + echo + echo -e "Default port for Jupyter notebook is 8888" + echo + echo -e "$(UDLINE_TEXT)ACTIONS$(NORMAL_TEXT):" + echo -e "- $(BOLD_TEXT)init$(NORMAL_TEXT): create pyproject.toml interactive and install virtual env" + echo -e "- $(BOLD_TEXT)run$(NORMAL_TEXT) port=: run the Jupyter notebook on the given port" + echo -e "- $(BOLD_TEXT)stop$(NORMAL_TEXT) port=: stop the running notebook on this port" + echo -e "- $(BOLD_TEXT)logs$(NORMAL_TEXT) port=: show and tail the logs of the notebooks" + echo -e "- $(BOLD_TEXT)shell$(NORMAL_TEXT) port=: open a poetry shell" +.PHONY: help diff --git a/README.md b/README.md new file mode 100644 index 00000000..7c7afea8 --- /dev/null +++ b/README.md @@ -0,0 +1,116 @@ +# README + +# TabNet : Attentive Interpretable Tabular Learning + +This is a pyTorch implementation of Tabnet (Arik, S. O., & Pfister, T. (2019). TabNet: Attentive Interpretable Tabular Learning. arXiv preprint arXiv:1908.07442.) https://arxiv.org/pdf/1908.07442.pdf. + +# Installation + +You can install using pip by running: +`pip install tabnet` + +If you wan to use it locally within a docker container: + +`git clone git@github.com:dreamquark-ai/tabnet.git` + +`cd tabnet` to get inside the repository + +`make start` to build and get inside the container + +`poetry install` to install all the dependencies, including jupyter + +`make notebook` inside the same terminal + +You can then follow the link to a jupyter notebook with tabnet installed. + + + +GPU version is available and should be working but is not supported yet. + +# How to use it? + +The implementation makes it easy to try different architectures of TabNet. +All you need is to change the network parameters and training parameters. All parameters are quickly describe bellow, to get a better understanding of what each parameters do please refer to the orginal paper. + +You can also get comfortable with the code works by playing with the **notebooks tutorials** for adult census income dataset and forest cover type dataset. + +## Network parameters + +- input_dim : int + + Number of initial features of the dataset + +- output_dim : int + + Size of the desired output. Ex : + - 1 for regression task + - 2 for binary classification + - N > 2 for multiclass classifcation + +- nd : int + + Width of the decision prediction layer. Bigger values gives more capacity to the model with the risk of overfitting. + Values typically range from 8 to 64. + +- na : int + + Width of the attention embedding for each mask. + According to the paper nd=na is usually a good choice. + +- n_steps : int + Number of steps in the architecture (usually between 3 and 10) + +- gamma : float + This is the coefficient for feature reusage in the masks. + A value close to 1 will make mask selection least correlated between layers. + Values range from 1.0 to 2.0 +- cat_idxs : list of int + + List of categorical features indices. +- cat_emb_dim : list of int + + List of embeddings size for each categorical features. +- n_independent : int + + Number of independent Gated Linear Units layers at each step. + Usual values range from 1 to 5 (default=2) +- n_shared : int + + Number of shared Gated Linear Units at each step + Usual values range from 1 to 5 (default=2) +- virtual_batch_size : int + + Size of the mini batches used for Ghost Batch Normalization + +## Training parameters + +- max_epochs : int (default = 200) + + Maximum number of epochs for trainng. +- patience : int (default = 15) + + Number of consecutive epochs without improvement before performing early stopping. +- lr : float (default = 0.02) + + Initial learning rate used for training. As mentionned in the original paper, a large initial learning of ```0.02 ``` with decay is a good option. +- clip_value : float (default None) + + If a float is given this will clip the gradient at clip_value. +- lambda_sparse : float (default = 1e-3) + + This is the extra sparsity loss coefficient as proposed in the original paper. The bigger this coefficient is, the sparser your model will be in terms of feature selection. Depending on the difficulty of your problem, reducing this value could help. +- model_name : str (default = 'DQTabNet') + + Name of the model used for saving in disk, you can customize this to easily retrieve and reuse your trained models. +- saving_path : str (default = './') + + Path defining where to save models. +- scheduler_fn : torch.optim.lr_scheduler (default = None) + + Pytorch Scheduler to change learning rates during training. +- scheduler_params: dict + + Parameters dictionnary for the scheduler_fn. Ex : {"gamma": 0.95, "step_size": 10} +- verbose : int (default=-1) + + Verbosity for notebooks plots, set to 1 to see every epoch. diff --git a/census_example.ipynb b/census_example.ipynb new file mode 100644 index 00000000..2bb8c8a8 --- /dev/null +++ b/census_example.ipynb @@ -0,0 +1,362 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from tabnet import tab_network\n", + "from tabnet.tab_model import Model\n", + "\n", + "import torch\n", + "from sklearn.preprocessing import LabelEncoder\n", + "from sklearn.metrics import roc_auc_score\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "np.random.seed(0)\n", + "\n", + "\n", + "import os\n", + "import wget\n", + "from pathlib import Path" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Download census-income dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\"\n", + "dataset_name = 'census-income'\n", + "out = Path(os.getcwd().rsplit(\"/\", 1)[0]+'/data/'+dataset_name+'.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "out.parent.mkdir(parents=True, exist_ok=True)\n", + "if out.exists():\n", + " print(\"File already exists.\")\n", + "else:\n", + " print(\"Downloading file...\")\n", + " wget.download(url, out.as_posix())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load data and split" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train = pd.read_csv(out)\n", + "target = ' <=50K'\n", + "if \"Set\" not in train.columns:\n", + " train[\"Set\"] = np.random.choice([\"train\", \"valid\", \"test\"], p =[.8, .1, .1], size=(train.shape[0],))\n", + "\n", + "train_indices = train[train.Set==\"train\"].index\n", + "valid_indices = train[train.Set==\"valid\"].index\n", + "test_indices = train[train.Set==\"test\"].index" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Simple preprocessing\n", + "\n", + "Label encode categorical features and fill empty cells." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "categorical_columns = []\n", + "categorical_dims = {}\n", + "for col in train.columns[train.dtypes == object]:\n", + " print(col, train[col].nunique())\n", + " l_enc = LabelEncoder()\n", + " train[col] = train[col].fillna(\"VV_likely\")\n", + " train[col] = l_enc.fit_transform(train[col].values)\n", + " categorical_columns.append(col)\n", + " categorical_dims[col] = len(l_enc.classes_)\n", + "\n", + "for col in train.columns[train.dtypes == 'float64']:\n", + " train.fillna(train.loc[train_indices, col].mean(), inplace=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Define categorical features for categorical embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "unused_feat = [\"Set\", \"ID\"]\n", + "\n", + "features = [ col for col in train.columns if col not in unused_feat+[target]] \n", + "\n", + "cat_idxs = [ i for i, f in enumerate(features) if f in categorical_columns]\n", + "\n", + "cat_dims = [ categorical_dims[f] for i, f in enumerate(features) if f in categorical_columns]\n", + "\n", + "train[target] = train[target].astype(int)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Network parameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "num_workers= 5\n", + "LR = 2e-2\n", + "batch_size = 1024 #64\n", + "mini_batch_size = 128\n", + "device = 'cuda' if torch.cuda.is_available() else 'cpu'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "network_params = {\"input_dim\" : len(features),\n", + " \"n_d\" : 8,\n", + " \"n_a\" : 8,\n", + " \"n_independent\": 2,\n", + " \"n_shared\": 2,\n", + " \"n_steps\": 3,\n", + " \"gamma\": 1.3,\n", + " \"output_dim\" : 2,\n", + " \"momentum\": 0.1,\n", + " \"cat_idxs\":cat_idxs,\n", + " \"cat_dims\": cat_dims,\n", + " \"cat_emb_dim\": 1,\n", + " \"virtual_batch_size\": mini_batch_size,\n", + "}\n", + "\n", + "description = f\"test_TabNet_LR_{LR}_BS_{batch_size}_DS_{dataset_name}\"\n", + "description += f\"_miniBS_{mini_batch_size}\"\n", + "description += f\"_nd_{network_params['n_d']}\"\n", + "description += f\"_na_{network_params['n_a']}\"\n", + "description += f\"_nsteps_{network_params['n_steps']}\"\n", + "description += f\"_gamma_{network_params['gamma']}\"\n", + "description += f\"_momentum_{network_params['momentum']}\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "my_scheduler = torch.optim.lr_scheduler.StepLR\n", + "scheduler_params = {\"gamma\": 0.9,\n", + " \"step_size\": 20}\n", + "\n", + "training_params = {\"model_name\": description,\n", + " \"lambda_sparse\": 1e-3,\n", + " \"lr\":LR,\n", + " \"patience\": 200,\n", + " \"optimizer_fn\":torch.optim.Adam,\n", + " \"scheduler_fn\": my_scheduler,\n", + " \"scheduler_params\":scheduler_params,\n", + " \"max_epochs\": 1000,\n", + " \"batch_size\": batch_size,\n", + " \"clip_value\": 0.5,\n", + " \"device\":device\n", + " }" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "X_train = train.iloc[train_indices][features].values\n", + "y_train = train.iloc[train_indices][target].values\n", + "\n", + "X_valid = train.iloc[valid_indices][features].values\n", + "y_valid = train.iloc[valid_indices][target].values\n", + "\n", + "X_test = train.iloc[test_indices][features].values\n", + "y_test = train.iloc[test_indices][target].values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "network = tab_network.TabNet\n", + "model = Model()\n", + "\n", + "\n", + "model.def_network(network, **network_params)\n", + "model.set_params(**training_params)\n", + "\n", + "model.fit(\n", + " X_train=X_train, y_train=y_train,\n", + " X_valid=X_valid, y_valid=y_valid,\n", + " balanced=False, #True,\n", + " weights=None, #{0: 1, 1:10}\n", + ") " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.load_best_model()\n", + "\n", + "preds, M_explain, masks = model.predict_proba(X_test)\n", + "\n", + "y_true = y_test\n", + "\n", + "test_auc = roc_auc_score(y_score=preds[:,1], y_true=y_true)\n", + "\n", + "print(f\"BEST VALID SCORE FOR {dataset_name} : {model.best_cost}\")\n", + "print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_auc}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# XGB" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "from xgboost import XGBClassifier\n", + "\n", + "clf = XGBClassifier(max_depth=8,\n", + " learning_rate=0.1,\n", + " n_estimators=1000,\n", + " verbosity=0,\n", + " silent=None,\n", + " objective='binary:logistic',\n", + " booster='gbtree',\n", + " n_jobs=-1,\n", + " nthread=None,\n", + " gamma=0,\n", + " min_child_weight=1,\n", + " max_delta_step=0,\n", + " subsample=0.7,\n", + " colsample_bytree=1,\n", + " colsample_bylevel=1,\n", + " colsample_bynode=1,\n", + " reg_alpha=0,\n", + " reg_lambda=1,\n", + " scale_pos_weight=1,\n", + " base_score=0.5,\n", + " random_state=0,\n", + " seed=None,)\n", + "\n", + "clf.fit(X_train, y_train,\n", + " eval_set=[(X_valid, y_valid)],\n", + " early_stopping_rounds=40,\n", + " verbose=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "preds = np.array(clf.predict_proba(X_valid))\n", + "valid_auc = roc_auc_score(y_score=preds[:,1], y_true=y_valid)\n", + "print(valid_auc)\n", + "\n", + "preds = np.array(clf.predict_proba(X_test))\n", + "test_auc = roc_auc_score(y_score=preds[:,1], y_true=y_test)\n", + "print(test_auc)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/forest_example.ipynb b/forest_example.ipynb new file mode 100644 index 00000000..bda3a5f0 --- /dev/null +++ b/forest_example.ipynb @@ -0,0 +1,397 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from tabnet.tab_model import Model\n", + "from tabnet import tab_network\n", + "import torch\n", + "from sklearn.preprocessing import LabelEncoder\n", + "from sklearn.metrics import accuracy_score\n", + "from sklearn.model_selection import train_test_split\n", + "import pandas as pd\n", + "import numpy as np\n", + "np.random.seed(0)\n", + "\n", + "\n", + "import os\n", + "import wget\n", + "from pathlib import Path\n", + "import shutil\n", + "import gzip" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Download ForestCoverType dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz\"\n", + "dataset_name = 'forest-cover-type'\n", + "tmp_out = Path(os.getcwd().rsplit(\"/\", 1)[0]+'/data/'+dataset_name+'.gz')\n", + "out = Path(os.getcwd().rsplit(\"/\", 1)[0]+'/data/'+dataset_name+'.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "out.parent.mkdir(parents=True, exist_ok=True)\n", + "if out.exists():\n", + " print(\"File already exists.\")\n", + "else:\n", + " print(\"Downloading file...\")\n", + " wget.download(url, tmp_out.as_posix())\n", + " with gzip.open(tmp_out, 'rb') as f_in:\n", + " with open(out, 'wb') as f_out:\n", + " shutil.copyfileobj(f_in, f_out)\n", + " \n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load data and split\n", + "Same split as in original paper" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "target = \"Covertype\"\n", + "\n", + "bool_columns = [\n", + " \"Wilderness_Area1\", \"Wilderness_Area2\", \"Wilderness_Area3\",\n", + " \"Wilderness_Area4\", \"Soil_Type1\", \"Soil_Type2\", \"Soil_Type3\", \"Soil_Type4\",\n", + " \"Soil_Type5\", \"Soil_Type6\", \"Soil_Type7\", \"Soil_Type8\", \"Soil_Type9\",\n", + " \"Soil_Type10\", \"Soil_Type11\", \"Soil_Type12\", \"Soil_Type13\", \"Soil_Type14\",\n", + " \"Soil_Type15\", \"Soil_Type16\", \"Soil_Type17\", \"Soil_Type18\", \"Soil_Type19\",\n", + " \"Soil_Type20\", \"Soil_Type21\", \"Soil_Type22\", \"Soil_Type23\", \"Soil_Type24\",\n", + " \"Soil_Type25\", \"Soil_Type26\", \"Soil_Type27\", \"Soil_Type28\", \"Soil_Type29\",\n", + " \"Soil_Type30\", \"Soil_Type31\", \"Soil_Type32\", \"Soil_Type33\", \"Soil_Type34\",\n", + " \"Soil_Type35\", \"Soil_Type36\", \"Soil_Type37\", \"Soil_Type38\", \"Soil_Type39\",\n", + " \"Soil_Type40\"\n", + "]\n", + "\n", + "int_columns = [\n", + " \"Elevation\", \"Aspect\", \"Slope\", \"Horizontal_Distance_To_Hydrology\",\n", + " \"Vertical_Distance_To_Hydrology\", \"Horizontal_Distance_To_Roadways\",\n", + " \"Hillshade_9am\", \"Hillshade_Noon\", \"Hillshade_3pm\",\n", + " \"Horizontal_Distance_To_Fire_Points\"\n", + "]\n", + "\n", + "feature_columns = (\n", + " int_columns + bool_columns + [target])\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train = pd.read_csv(out, header=None, names=feature_columns)\n", + "\n", + "n_total = len(train)\n", + "\n", + "# Train, val and test split follows\n", + "# Rory Mitchell, Andrey Adinets, Thejaswi Rao, and Eibe Frank.\n", + "# Xgboost: Scalable GPU accelerated learning. arXiv:1806.11248, 2018.\n", + "\n", + "train_val_indices, test_indices = train_test_split(\n", + " range(n_total), test_size=0.2, random_state=0)\n", + "train_indices, valid_indices = train_test_split(\n", + " train_val_indices, test_size=0.2 / 0.6, random_state=0)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Simple preprocessing\n", + "\n", + "Label encode categorical features and fill empty cells." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "categorical_columns = []\n", + "categorical_dims = {}\n", + "for col in train.columns[train.dtypes == object]:\n", + " print(col, train[col].nunique())\n", + " l_enc = LabelEncoder()\n", + " train[col] = train[col].fillna(\"VV_likely\")\n", + " train[col] = l_enc.fit_transform(train[col].values)\n", + " categorical_columns.append(col)\n", + " categorical_dims[col] = len(l_enc.classes_)\n", + "\n", + "for col in train.columns[train.dtypes == 'float64']:\n", + " train.fillna(train.loc[train_indices, col].mean(), inplace=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Define categorical features for categorical embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "unused_feat = [\"Set\", \"ID\"]\n", + "\n", + "features = [ col for col in train.columns if col not in unused_feat+[target]] \n", + "\n", + "cat_idxs = [ i for i, f in enumerate(features) if f in categorical_columns]\n", + "\n", + "cat_dims = [ categorical_dims[f] for i, f in enumerate(features) if f in categorical_columns]\n", + "\n", + "\n", + "train[target] = train[target].astype(int)-1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Network parameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "num_workers= 5\n", + "LR = 2e-2\n", + "batch_size = 16384 #64\n", + "mini_batch_size = 256\n", + "device = 'cuda' if torch.cuda.is_available() else 'cpu'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "network_params = {\"input_dim\" : len(features),\n", + " \"n_d\" : 32,\n", + " \"n_a\" : 32,\n", + " \"n_independent\": 2,\n", + " \"n_shared\": 2,\n", + " \"n_steps\": 6,\n", + " \"gamma\": 1.5,\n", + " \"output_dim\" :7,\n", + " \"momentum\": 0.3,\n", + " \"cat_idxs\":cat_idxs,\n", + " \"cat_dims\": cat_dims,\n", + " \"cat_emb_dim\": 1,\n", + " \"virtual_batch_size\": mini_batch_size,\n", + "}\n", + "\n", + "description = f\"test_TabNet_LR_{LR}_BS_{batch_size}_DS_{dataset_name}\"\n", + "description += f\"_miniBS_{mini_batch_size}\"\n", + "description += f\"_nd_{network_params['n_d']}\"\n", + "description += f\"_na_{network_params['n_a']}\"\n", + "description += f\"_nsteps_{network_params['n_steps']}\"\n", + "description += f\"_gamma_{network_params['gamma']}\"\n", + "description += f\"_momentum_{network_params['momentum']}\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "my_scheduler = torch.optim.lr_scheduler.StepLR\n", + "scheduler_params = {\"gamma\": 0.95,\n", + " \"step_size\": 20}\n", + "\n", + "training_params = {\"model_name\": description,\n", + " \"lambda_sparse\": 1e-4,\n", + " \"lr\":LR,\n", + " \"patience\": 60,\n", + " \"optimizer_fn\":torch.optim.Adam,\n", + " \"scheduler_fn\": my_scheduler,\n", + " \"scheduler_params\":scheduler_params,\n", + " \"max_epochs\": 1000,\n", + " \"batch_size\": batch_size,\n", + " \"clip_value\": 2.0,\n", + " \"device\":device\n", + " }" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "X_train = train.iloc[train_indices][features].values\n", + "y_train = train.iloc[train_indices][target].values\n", + "\n", + "X_valid = train.iloc[valid_indices][features].values\n", + "y_valid = train.iloc[valid_indices][target].values\n", + "\n", + "X_test = train.iloc[test_indices][features].values\n", + "y_test = train.iloc[test_indices][target].values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "network = tab_network.TabNet\n", + "model = Model()\n", + "\n", + "\n", + "model.def_network(network, **network_params)\n", + "model.set_params(**training_params)\n", + "\n", + "model.fit(\n", + " X_train=X_train, y_train=y_train,\n", + " X_valid=X_valid, y_valid=y_valid\n", + ") " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.load_best_model()\n", + "\n", + "preds, M_explain, masks = model.predict_proba(X_test)\n", + "\n", + "y_true = y_test\n", + "\n", + "test_acc = accuracy_score(y_pred=np.argmax(preds, axis=1), y_true=y_true)\n", + "\n", + "print(f\"BEST VALID SCORE FOR {dataset_name} : {model.best_cost}\")\n", + "print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_acc}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# XGB" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "from xgboost import XGBClassifier\n", + "\n", + "clf = XGBClassifier(max_depth=8,\n", + " learning_rate=0.1,\n", + " n_estimators=1000,\n", + " verbosity=0,\n", + " silent=None,\n", + " objective=\"multi:softmax\",\n", + " booster='gbtree',\n", + " n_jobs=-1,\n", + " nthread=None,\n", + " gamma=0,\n", + " min_child_weight=1,\n", + " max_delta_step=0,\n", + " subsample=0.7,\n", + " colsample_bytree=1,\n", + " colsample_bylevel=1,\n", + " colsample_bynode=1,\n", + " reg_alpha=0,\n", + " reg_lambda=1,\n", + " scale_pos_weight=1,\n", + " base_score=0.5,\n", + " random_state=0,\n", + " seed=None,)\n", + "\n", + "clf.fit(X_train, y_train,\n", + " eval_set=[(X_valid, y_valid)],\n", + " early_stopping_rounds=40,\n", + " verbose=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "preds_valid = np.array(clf.predict_proba(X_valid, ))\n", + "valid_acc = accuracy_score(y_pred=np.argmax(preds_valid, axis=1), y_true=y_valid)\n", + "print(valid_acc)\n", + "\n", + "preds_test = np.array(clf.predict_proba(X_test))\n", + "test_acc = accuracy_score(y_pred=np.argmax(preds_test, axis=1), y_true=y_test)\n", + "print(test_acc)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/poetry.lock b/poetry.lock new file mode 100644 index 00000000..449f1c76 --- /dev/null +++ b/poetry.lock @@ -0,0 +1,764 @@ +[[package]] +category = "dev" +description = "Disable App Nap on OS X 10.9" +marker = "python_version >= \"3.3\" and sys_platform == \"darwin\" or sys_platform == \"darwin\"" +name = "appnope" +optional = false +python-versions = "*" +version = "0.1.0" + +[[package]] +category = "dev" +description = "Classes Without Boilerplate" +name = "attrs" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +version = "19.2.0" + +[[package]] +category = "dev" +description = "Specifications for callback functions passed in to an API" +name = "backcall" +optional = false +python-versions = "*" +version = "0.1.0" + +[[package]] +category = "dev" +description = "An easy safelist-based HTML-sanitizing tool." +name = "bleach" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +version = "3.1.0" + +[package.dependencies] +six = ">=1.9.0" +webencodings = "*" + +[[package]] +category = "dev" +description = "Cross-platform colored terminal text." +marker = "python_version >= \"3.3\" and sys_platform == \"win32\" or sys_platform == \"win32\"" +name = "colorama" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +version = "0.4.1" + +[[package]] +category = "dev" +description = "Composable style cycles" +name = "cycler" +optional = false +python-versions = "*" +version = "0.10.0" + +[package.dependencies] +six = "*" + +[[package]] +category = "dev" +description = "Better living through Python with decorators" +name = "decorator" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*" +version = "4.4.0" + +[[package]] +category = "dev" +description = "XML bomb protection for Python stdlib modules" +name = "defusedxml" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +version = "0.6.0" + +[[package]] +category = "dev" +description = "Discover and load entry points from installed packages." +name = "entrypoints" +optional = false +python-versions = ">=2.7" +version = "0.3" + +[[package]] +category = "dev" +description = "Read metadata from Python packages" +name = "importlib-metadata" +optional = false +python-versions = ">=2.7,!=3.0,!=3.1,!=3.2,!=3.3" +version = "0.23" + +[package.dependencies] +zipp = ">=0.5" + +[[package]] +category = "dev" +description = "IPython Kernel for Jupyter" +name = "ipykernel" +optional = false +python-versions = ">=3.4" +version = "5.1.2" + +[package.dependencies] +ipython = ">=5.0.0" +jupyter-client = "*" +tornado = ">=4.2" +traitlets = ">=4.1.0" + +[[package]] +category = "dev" +description = "IPython: Productive Interactive Computing" +name = "ipython" +optional = false +python-versions = ">=3.5" +version = "7.8.0" + +[package.dependencies] +appnope = "*" +backcall = "*" +colorama = "*" +decorator = "*" +jedi = ">=0.10" +pexpect = "*" +pickleshare = "*" +prompt-toolkit = ">=2.0.0,<2.1.0" +pygments = "*" +setuptools = ">=18.5" +traitlets = ">=4.2" + +[[package]] +category = "dev" +description = "Vestigial utilities from IPython" +name = "ipython-genutils" +optional = false +python-versions = "*" +version = "0.2.0" + +[[package]] +category = "dev" +description = "IPython HTML widgets for Jupyter" +name = "ipywidgets" +optional = false +python-versions = "*" +version = "7.5.1" + +[package.dependencies] +ipykernel = ">=4.5.1" +nbformat = ">=4.2.0" +traitlets = ">=4.3.1" +widgetsnbextension = ">=3.5.0,<3.6.0" + +[package.dependencies.ipython] +python = ">=3.3" +version = ">=4.0.0" + +[[package]] +category = "dev" +description = "An autocompletion tool for Python that can be used for text editors." +name = "jedi" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +version = "0.15.1" + +[package.dependencies] +parso = ">=0.5.0" + +[[package]] +category = "dev" +description = "A very fast and expressive template engine." +name = "jinja2" +optional = false +python-versions = "*" +version = "2.10.3" + +[package.dependencies] +MarkupSafe = ">=0.23" + +[[package]] +category = "main" +description = "Lightweight pipelining: using Python functions as pipeline jobs." +name = "joblib" +optional = false +python-versions = "*" +version = "0.14.0" + +[[package]] +category = "dev" +description = "An implementation of JSON Schema validation for Python" +name = "jsonschema" +optional = false +python-versions = "*" +version = "3.1.1" + +[package.dependencies] +attrs = ">=17.4.0" +importlib-metadata = "*" +pyrsistent = ">=0.14.0" +setuptools = "*" +six = ">=1.11.0" + +[[package]] +category = "dev" +description = "Jupyter metapackage. Install all the Jupyter components in one go." +name = "jupyter" +optional = false +python-versions = "*" +version = "1.0.0" + +[package.dependencies] +ipykernel = "*" +ipywidgets = "*" +jupyter-console = "*" +nbconvert = "*" +notebook = "*" +qtconsole = "*" + +[[package]] +category = "dev" +description = "Jupyter protocol implementation and client libraries" +name = "jupyter-client" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +version = "5.3.4" + +[package.dependencies] +jupyter-core = ">=4.6.0" +python-dateutil = ">=2.1" +pywin32 = ">=1.0" +pyzmq = ">=13" +tornado = ">=4.1" +traitlets = "*" + +[[package]] +category = "dev" +description = "Jupyter terminal console" +name = "jupyter-console" +optional = false +python-versions = ">=3.5" +version = "6.0.0" + +[package.dependencies] +ipykernel = "*" +ipython = "*" +jupyter-client = "*" +prompt-toolkit = ">=2.0.0,<2.1.0" +pygments = "*" + +[[package]] +category = "dev" +description = "Jupyter core package. A base package on which Jupyter projects rely." +name = "jupyter-core" +optional = false +python-versions = ">=2.7, !=3.0, !=3.1, !=3.2" +version = "4.6.0" + +[package.dependencies] +pywin32 = ">=1.0" +traitlets = "*" + +[[package]] +category = "dev" +description = "A fast implementation of the Cassowary constraint solver" +name = "kiwisolver" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +version = "1.1.0" + +[package.dependencies] +setuptools = "*" + +[[package]] +category = "dev" +description = "Safely add untrusted strings to HTML/XML markup." +name = "markupsafe" +optional = false +python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*" +version = "1.1.1" + +[[package]] +category = "dev" +description = "Python plotting package" +name = "matplotlib" +optional = false +python-versions = ">=3.6" +version = "3.1.1" + +[package.dependencies] +cycler = ">=0.10" +kiwisolver = ">=1.0.1" +numpy = ">=1.11" +pyparsing = ">=2.0.1,<2.0.4 || >2.0.4,<2.1.2 || >2.1.2,<2.1.6 || >2.1.6" +python-dateutil = ">=2.1" + +[[package]] +category = "dev" +description = "The fastest markdown parser in pure Python" +name = "mistune" +optional = false +python-versions = "*" +version = "0.8.4" + +[[package]] +category = "dev" +description = "More routines for operating on iterables, beyond itertools" +name = "more-itertools" +optional = false +python-versions = ">=3.4" +version = "7.2.0" + +[[package]] +category = "dev" +description = "Converting Jupyter Notebooks" +name = "nbconvert" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +version = "5.6.0" + +[package.dependencies] +bleach = "*" +defusedxml = "*" +entrypoints = ">=0.2.2" +jinja2 = ">=2.4" +jupyter-core = "*" +mistune = ">=0.8.1,<2" +nbformat = ">=4.4" +pandocfilters = ">=1.4.1" +pygments = "*" +testpath = "*" +traitlets = ">=4.2" + +[[package]] +category = "dev" +description = "The Jupyter Notebook format" +name = "nbformat" +optional = false +python-versions = "*" +version = "4.4.0" + +[package.dependencies] +ipython-genutils = "*" +jsonschema = ">=2.4,<2.5.0 || >2.5.0" +jupyter-core = "*" +traitlets = ">=4.1" + +[[package]] +category = "dev" +description = "A web-based notebook environment for interactive computing" +name = "notebook" +optional = false +python-versions = ">=3.5" +version = "6.0.1" + +[package.dependencies] +Send2Trash = "*" +ipykernel = "*" +ipython-genutils = "*" +jinja2 = "*" +jupyter-client = ">=5.3.1" +jupyter-core = ">=4.4.0" +nbconvert = "*" +nbformat = "*" +prometheus-client = "*" +pyzmq = ">=17" +terminado = ">=0.8.1" +tornado = ">=5.0" +traitlets = ">=4.2.1" + +[[package]] +category = "main" +description = "NumPy is the fundamental package for array computing with Python." +name = "numpy" +optional = false +python-versions = ">=3.5" +version = "1.17.2" + +[[package]] +category = "main" +description = "Powerful data structures for data analysis, time series, and statistics" +name = "pandas" +optional = false +python-versions = ">=3.5.3" +version = "0.25.1" + +[package.dependencies] +numpy = ">=1.13.3" +python-dateutil = ">=2.6.1" +pytz = ">=2017.2" + +[[package]] +category = "dev" +description = "Utilities for writing pandoc filters in python" +name = "pandocfilters" +optional = false +python-versions = "*" +version = "1.4.2" + +[[package]] +category = "dev" +description = "A Python Parser" +name = "parso" +optional = false +python-versions = "*" +version = "0.5.1" + +[[package]] +category = "dev" +description = "Pexpect allows easy control of interactive console applications." +marker = "python_version >= \"3.3\" and sys_platform != \"win32\" or sys_platform != \"win32\"" +name = "pexpect" +optional = false +python-versions = "*" +version = "4.7.0" + +[package.dependencies] +ptyprocess = ">=0.5" + +[[package]] +category = "dev" +description = "Tiny 'shelve'-like database with concurrency support" +name = "pickleshare" +optional = false +python-versions = "*" +version = "0.7.5" + +[[package]] +category = "dev" +description = "Python client for the Prometheus monitoring system." +name = "prometheus-client" +optional = false +python-versions = "*" +version = "0.7.1" + +[[package]] +category = "dev" +description = "Library for building powerful interactive command lines in Python" +name = "prompt-toolkit" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" +version = "2.0.10" + +[package.dependencies] +six = ">=1.9.0" +wcwidth = "*" + +[[package]] +category = "dev" +description = "Run a subprocess in a pseudo terminal" +marker = "sys_platform != \"win32\" or os_name != \"nt\" or python_version >= \"3.3\" and sys_platform != \"win32\"" +name = "ptyprocess" +optional = false +python-versions = "*" +version = "0.6.0" + +[[package]] +category = "dev" +description = "Pygments is a syntax highlighting package written in Python." +name = "pygments" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +version = "2.4.2" + +[[package]] +category = "dev" +description = "Python parsing module" +name = "pyparsing" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" +version = "2.4.2" + +[[package]] +category = "dev" +description = "Persistent/Functional/Immutable data structures" +name = "pyrsistent" +optional = false +python-versions = "*" +version = "0.15.4" + +[package.dependencies] +six = "*" + +[[package]] +category = "main" +description = "Extensions to the standard Python datetime module" +name = "python-dateutil" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +version = "2.8.0" + +[package.dependencies] +six = ">=1.5" + +[[package]] +category = "main" +description = "World timezone definitions, modern and historical" +name = "pytz" +optional = false +python-versions = "*" +version = "2019.3" + +[[package]] +category = "dev" +description = "Python for Window Extensions" +marker = "sys_platform == \"win32\"" +name = "pywin32" +optional = false +python-versions = "*" +version = "225" + +[[package]] +category = "dev" +description = "Python bindings for the winpty library" +marker = "os_name == \"nt\"" +name = "pywinpty" +optional = false +python-versions = "*" +version = "0.5.5" + +[[package]] +category = "dev" +description = "Python bindings for 0MQ" +name = "pyzmq" +optional = false +python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*" +version = "18.1.0" + +[[package]] +category = "dev" +description = "Jupyter Qt console" +name = "qtconsole" +optional = false +python-versions = "*" +version = "4.5.5" + +[package.dependencies] +ipykernel = ">=4.1" +ipython-genutils = "*" +jupyter-client = ">=4.1" +jupyter-core = "*" +pygments = "*" +traitlets = "*" + +[[package]] +category = "main" +description = "A set of python modules for machine learning and data mining" +name = "scikit-learn" +optional = false +python-versions = ">=3.5" +version = "0.21.3" + +[package.dependencies] +joblib = ">=0.11" +numpy = ">=1.11.0" +scipy = ">=0.17.0" + +[[package]] +category = "main" +description = "SciPy: Scientific Library for Python" +name = "scipy" +optional = false +python-versions = ">=3.5" +version = "1.3.1" + +[package.dependencies] +numpy = ">=1.13.3" + +[[package]] +category = "dev" +description = "Send file to trash natively under Mac OS X, Windows and Linux." +name = "send2trash" +optional = false +python-versions = "*" +version = "1.5.0" + +[[package]] +category = "main" +description = "Python 2 and 3 compatibility utilities" +name = "six" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*" +version = "1.12.0" + +[[package]] +category = "dev" +description = "Terminals served to xterm.js using Tornado websockets" +name = "terminado" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +version = "0.8.2" + +[package.dependencies] +ptyprocess = "*" +pywinpty = ">=0.5" +tornado = ">=4" + +[[package]] +category = "dev" +description = "Test utilities for code working with files and commands" +name = "testpath" +optional = false +python-versions = "*" +version = "0.4.2" + +[[package]] +category = "main" +description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" +name = "torch" +optional = false +python-versions = "*" +version = "1.0.1" + +[[package]] +category = "dev" +description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." +name = "tornado" +optional = false +python-versions = ">= 3.5" +version = "6.0.3" + +[[package]] +category = "main" +description = "Fast, Extensible Progress Meter" +name = "tqdm" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*" +version = "4.30.0" + +[[package]] +category = "dev" +description = "Traitlets Python config system" +name = "traitlets" +optional = false +python-versions = "*" +version = "4.3.3" + +[package.dependencies] +decorator = "*" +ipython-genutils = "*" +six = "*" + +[[package]] +category = "dev" +description = "Measures number of Terminal column cells of wide-character codes" +name = "wcwidth" +optional = false +python-versions = "*" +version = "0.1.7" + +[[package]] +category = "dev" +description = "Character encoding aliases for legacy web content" +name = "webencodings" +optional = false +python-versions = "*" +version = "0.5.1" + +[[package]] +category = "dev" +description = "pure python download utility" +name = "wget" +optional = false +python-versions = "*" +version = "3.2" + +[[package]] +category = "dev" +description = "IPython HTML widgets for Jupyter" +name = "widgetsnbextension" +optional = false +python-versions = "*" +version = "3.5.1" + +[package.dependencies] +notebook = ">=4.4.1" + +[[package]] +category = "dev" +description = "XGBoost Python Package" +name = "xgboost" +optional = false +python-versions = ">=3.4" +version = "0.90" + +[package.dependencies] +numpy = "*" +scipy = "*" + +[[package]] +category = "dev" +description = "Backport of pathlib-compatible object wrapper for zip files" +name = "zipp" +optional = false +python-versions = ">=2.7" +version = "0.6.0" + +[package.dependencies] +more-itertools = "*" + +[metadata] +content-hash = "eee4cbcd17241646b0beb644f7c6ea5451319a72763de4151e73ed2a70a6053b" +python-versions = "^3.6.8" + +[metadata.hashes] +appnope = ["5b26757dc6f79a3b7dc9fab95359328d5747fcb2409d331ea66d0272b90ab2a0", "8b995ffe925347a2138d7ac0fe77155e4311a0ea6d6da4f5128fe4b3cbe5ed71"] +attrs = ["ec20e7a4825331c1b5ebf261d111e16fa9612c1f7a5e1f884f12bd53a664dfd2", "f913492e1663d3c36f502e5e9ba6cd13cf19d7fab50aa13239e420fef95e1396"] +backcall = ["38ecd85be2c1e78f77fd91700c76e14667dc21e2713b63876c0eb901196e01e4", "bbbf4b1e5cd2bdb08f915895b51081c041bac22394fdfcfdfbe9f14b77c08bf2"] +bleach = ["213336e49e102af26d9cde77dd2d0397afabc5a6bf2fed985dc35b5d1e285a16", "3fdf7f77adcf649c9911387df51254b813185e32b2c6619f690b593a617e19fa"] +colorama = ["05eed71e2e327246ad6b38c540c4a3117230b19679b875190486ddd2d721422d", "f8ac84de7840f5b9c4e3347b3c1eaa50f7e49c2b07596221daec5edaabbd7c48"] +cycler = ["1d8a5ae1ff6c5cf9b93e8811e581232ad8920aeec647c37316ceac982b08cb2d", "cd7b2d1018258d7247a71425e9f26463dfb444d411c39569972f4ce586b0c9d8"] +decorator = ["86156361c50488b84a3f148056ea716ca587df2f0de1d34750d35c21312725de", "f069f3a01830ca754ba5258fde2278454a0b5b79e0d7f5c13b3b97e57d4acff6"] +defusedxml = ["6687150770438374ab581bb7a1b327a847dd9c5749e396102de3fad4e8a3ef93", "f684034d135af4c6cbb949b8a4d2ed61634515257a67299e5f940fbaa34377f5"] +entrypoints = ["589f874b313739ad35be6e0cd7efde2a4e9b6fea91edcc34e58ecbb8dbe56d19", "c70dd71abe5a8c85e55e12c19bd91ccfeec11a6e99044204511f9ed547d48451"] +importlib-metadata = ["aa18d7378b00b40847790e7c27e11673d7fed219354109d0e7b9e5b25dc3ad26", "d5f18a79777f3aa179c145737780282e27b508fc8fd688cb17c7a813e8bd39af"] +ipykernel = ["167c3ef08450f5e060b76c749905acb0e0fbef9365899377a4a1eae728864383", "b503913e0b4cce7ed2de965457dfb2edd633e8234161a60e23f2fe2161345d12"] +ipython = ["c4ab005921641e40a68e405e286e7a1fcc464497e14d81b6914b4fd95e5dee9b", "dd76831f065f17bddd7eaa5c781f5ea32de5ef217592cf019e34043b56895aa1"] +ipython-genutils = ["72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8", "eb2e116e75ecef9d4d228fdc66af54269afa26ab4463042e33785b887c628ba8"] +ipywidgets = ["13ffeca438e0c0f91ae583dc22f50379b9d6b28390ac7be8b757140e9a771516", "e945f6e02854a74994c596d9db83444a1850c01648f1574adf144fbbabe05c97"] +jedi = ["786b6c3d80e2f06fd77162a07fed81b8baa22dde5d62896a790a331d6ac21a27", "ba859c74fa3c966a22f2aeebe1b74ee27e2a462f56d3f5f7ca4a59af61bfe42e"] +jinja2 = ["74320bb91f31270f9551d46522e33af46a80c3d619f4a4bf42b3164d30b5911f", "9fe95f19286cfefaa917656583d020be14e7859c6b0252588391e47db34527de"] +joblib = ["006108c7576b3eb6c5b27761ddbf188eb6e6347696325ab2027ea1ee9a4b922d", "6fcc57aacb4e89451fd449e9412687c51817c3f48662c3d8f38ba3f8a0a193ff"] +jsonschema = ["2fa0684276b6333ff3c0b1b27081f4b2305f0a36cf702a23db50edb141893c3f", "94c0a13b4a0616458b42529091624e66700a17f847453e52279e35509a5b7631"] +jupyter = ["3e1f86076bbb7c8c207829390305a2b1fe836d471ed54be66a3b8c41e7f46cc7", "5b290f93b98ffbc21c0c7e749f054b3267782166d72fa5e3ed1ed4eaf34a2b78", "d9dc4b3318f310e34c82951ea5d6683f67bed7def4b259fafbfe4f1beb1d8e5f"] +jupyter-client = ["60e6faec1031d63df57f1cc671ed673dced0ed420f4377ea33db37b1c188b910", "d0c077c9aaa4432ad485e7733e4d91e48f87b4f4bab7d283d42bb24cbbba0a0f"] +jupyter-console = ["308ce876354924fb6c540b41d5d6d08acfc946984bf0c97777c1ddcb42e0b2f5", "cc80a97a5c389cbd30252ffb5ce7cefd4b66bde98219edd16bf5cb6f84bb3568"] +jupyter-core = ["1368a838bba378c3c99f54c2961489831ea929ec7689a1d59d9844e584bc27dc", "85103cee6548992780912c1a0a9ec2583a4a18f1ef79a248ec0db4446500bce3"] +kiwisolver = ["05b5b061e09f60f56244adc885c4a7867da25ca387376b02c1efc29cc16bcd0f", "26f4fbd6f5e1dabff70a9ba0d2c4bd30761086454aa30dddc5b52764ee4852b7", "3b2378ad387f49cbb328205bda569b9f87288d6bc1bf4cd683c34523a2341efe", "400599c0fe58d21522cae0e8b22318e09d9729451b17ee61ba8e1e7c0346565c", "47b8cb81a7d18dbaf4fed6a61c3cecdb5adec7b4ac292bddb0d016d57e8507d5", "53eaed412477c836e1b9522c19858a8557d6e595077830146182225613b11a75", "58e626e1f7dfbb620d08d457325a4cdac65d1809680009f46bf41eaf74ad0187", "5a52e1b006bfa5be04fe4debbcdd2688432a9af4b207a3f429c74ad625022641", "5c7ca4e449ac9f99b3b9d4693debb1d6d237d1542dd6a56b3305fe8a9620f883", "682e54f0ce8f45981878756d7203fd01e188cc6c8b2c5e2cf03675390b4534d5", "79bfb2f0bd7cbf9ea256612c9523367e5ec51d7cd616ae20ca2c90f575d839a2", "7f4dd50874177d2bb060d74769210f3bce1af87a8c7cf5b37d032ebf94f0aca3", "8944a16020c07b682df861207b7e0efcd2f46c7488619cb55f65882279119389", "8aa7009437640beb2768bfd06da049bad0df85f47ff18426261acecd1cf00897", "939f36f21a8c571686eb491acfffa9c7f1ac345087281b412d63ea39ca14ec4a", "9733b7f64bd9f807832d673355f79703f81f0b3e52bfce420fc00d8cb28c6a6c", "a02f6c3e229d0b7220bd74600e9351e18bc0c361b05f29adae0d10599ae0e326", "a0c0a9f06872330d0dd31b45607197caab3c22777600e88031bfe66799e70bb0", "acc4df99308111585121db217681f1ce0eecb48d3a828a2f9bbf9773f4937e9e", "b64916959e4ae0ac78af7c3e8cef4becee0c0e9694ad477b4c6b3a536de6a544", "d3fcf0819dc3fea58be1fd1ca390851bdb719a549850e708ed858503ff25d995", "d52e3b1868a4e8fd18b5cb15055c76820df514e26aa84cc02f593d99fef6707f", "db1a5d3cc4ae943d674718d6c47d2d82488ddd94b93b9e12d24aabdbfe48caee", "e3a21a720791712ed721c7b95d433e036134de6f18c77dbe96119eaf7aa08004", "e8bf074363ce2babeb4764d94f8e65efd22e6a7c74860a4f05a6947afc020ff2", "f16814a4a96dc04bf1da7d53ee8d5b1d6decfc1a92a63349bb15d37b6a263dd9", "f2b22153870ca5cf2ab9c940d7bc38e8e9089fa0f7e5856ea195e1cf4ff43d5a", "f790f8b3dff3d53453de6a7b7ddd173d2e020fb160baff578d578065b108a05f"] +markupsafe = ["00bc623926325b26bb9605ae9eae8a215691f33cae5df11ca5424f06f2d1f473", "09027a7803a62ca78792ad89403b1b7a73a01c8cb65909cd876f7fcebd79b161", "09c4b7f37d6c648cb13f9230d847adf22f8171b1ccc4d5682398e77f40309235", "1027c282dad077d0bae18be6794e6b6b8c91d58ed8a8d89a89d59693b9131db5", "24982cc2533820871eba85ba648cd53d8623687ff11cbb805be4ff7b4c971aff", "29872e92839765e546828bb7754a68c418d927cd064fd4708fab9fe9c8bb116b", "43a55c2930bbc139570ac2452adf3d70cdbb3cfe5912c71cdce1c2c6bbd9c5d1", "46c99d2de99945ec5cb54f23c8cd5689f6d7177305ebff350a58ce5f8de1669e", "500d4957e52ddc3351cabf489e79c91c17f6e0899158447047588650b5e69183", "535f6fc4d397c1563d08b88e485c3496cf5784e927af890fb3c3aac7f933ec66", "62fe6c95e3ec8a7fad637b7f3d372c15ec1caa01ab47926cfdf7a75b40e0eac1", "6dd73240d2af64df90aa7c4e7481e23825ea70af4b4922f8ede5b9e35f78a3b1", "717ba8fe3ae9cc0006d7c451f0bb265ee07739daf76355d06366154ee68d221e", "79855e1c5b8da654cf486b830bd42c06e8780cea587384cf6545b7d9ac013a0b", "7c1699dfe0cf8ff607dbdcc1e9b9af1755371f92a68f706051cc8c37d447c905", "88e5fcfb52ee7b911e8bb6d6aa2fd21fbecc674eadd44118a9cc3863f938e735", "8defac2f2ccd6805ebf65f5eeb132adcf2ab57aa11fdf4c0dd5169a004710e7d", "98c7086708b163d425c67c7a91bad6e466bb99d797aa64f965e9d25c12111a5e", "9add70b36c5666a2ed02b43b335fe19002ee5235efd4b8a89bfcf9005bebac0d", "9bf40443012702a1d2070043cb6291650a0841ece432556f784f004937f0f32c", "ade5e387d2ad0d7ebf59146cc00c8044acbd863725f887353a10df825fc8ae21", "b00c1de48212e4cc9603895652c5c410df699856a2853135b3967591e4beebc2", "b1282f8c00509d99fef04d8ba936b156d419be841854fe901d8ae224c59f0be5", "b2051432115498d3562c084a49bba65d97cf251f5a331c64a12ee7e04dacc51b", "ba59edeaa2fc6114428f1637ffff42da1e311e29382d81b339c1817d37ec93c6", "c8716a48d94b06bb3b2524c2b77e055fb313aeb4ea620c8dd03a105574ba704f", "cd5df75523866410809ca100dc9681e301e3c27567cf498077e8551b6d20e42f", "e249096428b3ae81b08327a63a485ad0878de3fb939049038579ac0ef61e17e7"] +matplotlib = ["1febd22afe1489b13c6749ea059d392c03261b2950d1d45c17e3aed812080c93", "31a30d03f39528c79f3a592857be62a08595dec4ac034978ecd0f814fa0eec2d", "4442ce720907f67a79d45de9ada47be81ce17e6c2f448b3c64765af93f6829c9", "796edbd1182cbffa7e1e7a97f1e141f875a8501ba8dd834269ae3cd45a8c976f", "934e6243df7165aad097572abf5b6003c77c9b6c480c3c4de6f2ef1b5fdd4ec0", "bab9d848dbf1517bc58d1f486772e99919b19efef5dd8596d4b26f9f5ee08b6b", "c1fe1e6cdaa53f11f088b7470c2056c0df7d80ee4858dadf6cbe433fcba4323b", "e5b8aeca9276a3a988caebe9f08366ed519fff98f77c6df5b64d7603d0e42e36", "ec6bd0a6a58df3628ff269978f4a4b924a0d371ad8ce1f8e2b635b99e482877a"] +mistune = ["59a3429db53c50b5c6bcc8a07f8848cb00d7dc8bdb431a4ab41920d201d4756e", "88a1051873018da288eee8538d476dffe1262495144b33ecb586c4ab266bb8d4"] +more-itertools = ["409cd48d4db7052af495b09dec721011634af3753ae1ef92d2b32f73a745f832", "92b8c4b06dac4f0611c0729b2f2ede52b2e1bac1ab48f089c7ddc12e26bb60c4"] +nbconvert = ["427a468ec26e7d68a529b95f578d5cbf018cb4c1f889e897681c2b6d11897695", "48d3c342057a2cf21e8df820d49ff27ab9f25fc72b8f15606bd47967333b2709"] +nbformat = ["b9a0dbdbd45bb034f4f8893cafd6f652ea08c8c1674ba83f2dc55d3955743b0b", "f7494ef0df60766b7cabe0a3651556345a963b74dbc16bc7c18479041170d402"] +notebook = ["660976fe4fe45c7aa55e04bf4bccb9f9566749ff637e9020af3422f9921f9a5d", "b0a290f5cc7792d50a21bec62b3c221dd820bf00efa916ce9aeec4b5354bde20"] +numpy = ["05dbfe72684cc14b92568de1bc1f41e5f62b00f714afc9adee42f6311738091f", "0d82cb7271a577529d07bbb05cb58675f2deb09772175fab96dc8de025d8ac05", "10132aa1fef99adc85a905d82e8497a580f83739837d7cbd234649f2e9b9dc58", "12322df2e21f033a60c80319c25011194cd2a21294cc66fee0908aeae2c27832", "16f19b3aa775dddc9814e02a46b8e6ae6a54ed8cf143962b4e53f0471dbd7b16", "3d0b0989dd2d066db006158de7220802899a1e5c8cf622abe2d0bd158fd01c2c", "438a3f0e7b681642898fd7993d38e2bf140a2d1eafaf3e89bb626db7f50db355", "5fd214f482ab53f2cea57414c5fb3e58895b17df6e6f5bca5be6a0bb6aea23bb", "73615d3edc84dd7c4aeb212fa3748fb83217e00d201875a47327f55363cef2df", "7bd355ad7496f4ce1d235e9814ec81ee3d28308d591c067ce92e49f745ba2c2f", "7d077f2976b8f3de08a0dcf5d72083f4af5411e8fddacd662aae27baa2601196", "a4092682778dc48093e8bda8d26ee8360153e2047826f95a3f5eae09f0ae3abf", "b458de8624c9f6034af492372eb2fee41a8e605f03f4732f43fc099e227858b2", "e70fc8ff03a961f13363c2c95ef8285e0cf6a720f8271836f852cc0fa64e97c8", "ee8e9d7cad5fe6dde50ede0d2e978d81eafeaa6233fb0b8719f60214cf226578", "f4a4f6aba148858a5a5d546a99280f71f5ee6ec8182a7d195af1a914195b21a2"] +pandas = ["18d91a9199d1dfaa01ad645f7540370ba630bdcef09daaf9edf45b4b1bca0232", "3f26e5da310a0c0b83ea50da1fd397de2640b02b424aa69be7e0784228f656c9", "4182e32f4456d2c64619e97c58571fa5ca0993d1e8c2d9ca44916185e1726e15", "426e590e2eb0e60f765271d668a30cf38b582eaae5ec9b31229c8c3c10c5bc21", "5eb934a8f0dc358f0e0cdf314072286bbac74e4c124b64371395e94644d5d919", "717928808043d3ea55b9bcde636d4a52d2236c246f6df464163a66ff59980ad8", "8145f97c5ed71827a6ec98ceaef35afed1377e2d19c4078f324d209ff253ecb5", "8744c84c914dcc59cbbb2943b32b7664df1039d99e834e1034a3372acb89ea4d", "c1ac1d9590d0c9314ebf01591bd40d4c03d710bfc84a3889e5263c97d7891dee", "cb2e197b7b0687becb026b84d3c242482f20cbb29a9981e43604eb67576da9f6", "d4001b71ad2c9b84ff18b182cea22b7b6cbf624216da3ea06fb7af28d1f93165", "d8930772adccb2882989ab1493fa74bd87d47c8ac7417f5dd3dd834ba8c24dc9", "dfbb0173ee2399bc4ed3caf2d236e5c0092f948aafd0a15fbe4a0e77ee61a958", "eebfbba048f4fa8ac711b22c78516e16ff8117d05a580e7eeef6b0c2be554c18", "f1b21bc5cf3dbea53d33615d1ead892dfdae9d7052fa8898083bec88be20dcd2"] +pandocfilters = ["b3dd70e169bb5449e6bc6ff96aea89c5eea8c5f6ab5e207fc2f521a2cf4a0da9"] +parso = ["63854233e1fadb5da97f2744b6b24346d2750b85965e7e399bec1620232797dc", "666b0ee4a7a1220f65d367617f2cd3ffddff3e205f3f16a0284df30e774c2a9c"] +pexpect = ["2094eefdfcf37a1fdbfb9aa090862c1a4878e5c7e0e7e7088bdb511c558e5cd1", "9e2c1fd0e6ee3a49b28f95d4b33bc389c89b20af6a1255906e90ff1262ce62eb"] +pickleshare = ["87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca", "9649af414d74d4df115d5d718f82acb59c9d418196b7b4290ed47a12ce62df56"] +prometheus-client = ["71cd24a2b3eb335cb800c7159f423df1bd4dcd5171b234be15e3f31ec9f622da"] +prompt-toolkit = ["46642344ce457641f28fc9d1c9ca939b63dadf8df128b86f1b9860e59c73a5e4", "e7f8af9e3d70f514373bf41aa51bc33af12a6db3f71461ea47fea985defb2c31", "f15af68f66e664eaa559d4ac8a928111eebd5feda0c11738b5998045224829db"] +ptyprocess = ["923f299cc5ad920c68f2bc0bc98b75b9f838b93b599941a6b63ddbc2476394c0", "d7cc528d76e76342423ca640335bd3633420dc1366f258cb31d05e865ef5ca1f"] +pygments = ["71e430bc85c88a430f000ac1d9b331d2407f681d6f6aec95e8bcfbc3df5b0127", "881c4c157e45f30af185c1ffe8d549d48ac9127433f2c380c24b84572ad66297"] +pyparsing = ["6f98a7b9397e206d78cc01df10131398f1c8b8510a2f4d97d9abd82e1aacdd80", "d9338df12903bbf5d65a0e4e87c2161968b10d2e489652bb47001d82a9b028b4"] +pyrsistent = ["34b47fa169d6006b32e99d4b3c4031f155e6e68ebcc107d6454852e8e0ee6533"] +python-dateutil = ["7e6584c74aeed623791615e26efd690f29817a27c73085b78e4bad02493df2fb", "c89805f6f4d64db21ed966fda138f8a5ed7a4fdbc1a8ee329ce1b74e3c74da9e"] +pytz = ["1c557d7d0e871de1f5ccd5833f60fb2550652da6be2693c1e02300743d21500d", "b02c06db6cf09c12dd25137e563b31700d3b80fcc4ad23abb7a315f2789819be"] +pywin32 = ["0443e9bb196e72480f50cbddc2cf98fbb858a77d02e281ba79489ea3287b36e9", "09bbe7cdb29eb40ab2e83f7a232eeeedde864be7a0622b70a90f456aad07a234", "0d8e0f47808798d320c983574c36c49db642678902933a210edd40157d206fd0", "0db7c9f4b93528afd080d35912a60be2f86a1d6c49c0a9cf9cedd106eed81ea3", "749e590875051661ecefbd9dfa957a485016de0f25e43f5e70f888ef1e29587b", "779d3e9d4b934f2445d2920c3941416d99af72eb7f7fd57a63576cc8aa540ad6", "7c89d2c11a31c7aaa16dc4d25054d7e0e99d6f6b24193cf62c83850484658c87", "81f7732b662c46274d7d8c411c905d53e71999cba95457a0686467c3ebc745ca", "9db1fb8830bfa99c5bfd335d4482c14db5c6f5028db3b006787ef4200206242b", "bd8d04835db28646d9e07fd0ab7c7b18bd90e89dfdc559e60389179495ef30da", "fc6822a68afd79e97b015985dd455767c72009b81bcd18957068626c43f11e75", "fe6cfc2045931866417740b575231c7e12d69d481643be1493487ad53b089959"] +pywinpty = ["0e01321e53a230233358a6d608a1a8bc86c3882cf82769ba3c62ca387dc9cc51", "333e0bc5fca8ad9e9a1516ebedb2a65da38dc1f399f8b2ea57d6cccec1ff2cc8", "3ca3123aa6340ab31bbf9bd012b92e72f9ec905e4c9ee152cc997403e1778cd3", "44a6dddcf2abf402e22f87e2c9a341f7d0b296afbec3d28184c8de4d7f514ee4", "53d94d574c3d4da2df5b1c3ae728b8d90e4d33502b0388576bbd4ddeb4de0f77", "c3955f162c53dde968f3fc11361658f1d83b683bfe601d4b6f94bb01ea4300bc", "cec9894ecb34de3d7b1ca121dd98433035b9f8949b5095e84b103b349231509c", "dcd45912e2fe2e6f72cee997a4da6ed1ad2056165a277ce5ec7f7ac98dcdf667", "f2bcdd9a2ffd8b223752a971b3d377fb7bfed85f140ec9710f1218d760f2ccb7"] +pyzmq = ["01636e95a88d60118479041c6aaaaf5419c6485b7b1d37c9c4dd424b7b9f1121", "021dba0d1436516092c624359e5da51472b11ba8edffa334218912f7e8b65467", "0463bd941b6aead494d4035f7eebd70035293dd6caf8425993e85ad41de13fa3", "05fd51edd81eed798fccafdd49c936b6c166ffae7b32482e4d6d6a2e196af4e6", "1fadc8fbdf3d22753c36d4172169d184ee6654f8d6539e7af25029643363c490", "22efa0596cf245a78a99060fe5682c4cd00c58bb7614271129215c889062db80", "260c70b7c018905ec3659d0f04db735ac830fe27236e43b9dc0532cf7c9873ef", "2762c45e289732d4450406cedca35a9d4d71e449131ba2f491e0bf473e3d2ff2", "2fc6cada8dc53521c1189596f1898d45c5f68603194d3a6453d6db4b27f4e12e", "343b9710a61f2b167673bea1974e70b5dccfe64b5ed10626798f08c1f7227e72", "41bf96d5f554598a0632c3ec28e3026f1d6591a50f580df38eff0b8067efb9e7", "856b2cdf7a1e2cbb84928e1e8db0ea4018709b39804103d3a409e5584f553f57", "85b869abc894672de9aecdf032158ea8ad01e2f0c3b09ef60e3687fb79418096", "93f44739db69234c013a16990e43db1aa0af3cf5a4b8b377d028ff24515fbeb3", "98fa3e75ccb22c0dc99654e3dd9ff693b956861459e8c8e8734dd6247b89eb29", "9a22c94d2e93af8bebd4fcf5fa38830f5e3b1ff0d4424e2912b07651eb1bafb4", "a7d3f4b4bbb5d7866ae727763268b5c15797cbd7b63ea17f3b0ec1067da8994b", "b645a49376547b3816433a7e2d2a99135c8e651e50497e7ecac3bd126e4bea16", "cf0765822e78cf9e45451647a346d443f66792aba906bc340f4e0ac7870c169c", "dc398e1e047efb18bfab7a8989346c6921a847feae2cad69fedf6ca12fb99e2c", "dd5995ae2e80044e33b5077fb4bc2b0c1788ac6feaf15a6b87a00c14b4bdd682", "e03fe5e07e70f245dc9013a9d48ae8cc4b10c33a1968039c5a3b64b5d01d083d", "ea09a306144dff2795e48439883349819bef2c53c0ee62a3c2fae429451843bb", "f4e37f33da282c3c319849877e34f97f0a3acec09622ec61b7333205bdd13b52", "fa4bad0d1d173dee3e8ef3c3eb6b2bb6c723fc7a661eeecc1ecb2fa99860dd45"] +qtconsole = ["40d5d8e00d070ea266dbf6f0da74c4b9597b8b8d67cd8233c3ffd8debf923703", "b91e7412587e6cfe1644696538f73baf5611e837be5406633218443b2827c6d9"] +scikit-learn = ["1ac81293d261747c25ea5a0ee8cd2bb1f3b5ba9ec05421a7f9f0feb4eb7c4116", "289361cf003d90b007f5066b27fcddc2d71324c82f1c88e316fedacb0dfdd516", "3a14d0abd4281fc3fd2149c486c3ec7cedad848b8d5f7b6f61522029d65a29f8", "5083a5e50d9d54548e4ada829598ae63a05651dd2bb319f821ffd9e8388384a6", "777cdd5c077b7ca9cb381396c81990cf41d2fa8350760d3cad3b4c460a7db644", "8bf2ff63da820d09b96b18e88f9625228457bff8df4618f6b087e12442ef9e15", "8d319b71c449627d178f21c57614e21747e54bb3fc9602b6f42906c3931aa320", "928050b65781fea9542dfe9bfe02d8c4f5530baa8472ec60782ea77347d2c836", "92c903613ff50e22aa95d589f9fff5deb6f34e79f7f21f609680087f137bb524", "ae322235def5ce8fae645b439e332e6f25d34bb90d6a6c8e261f17eb476457b7", "c1cd6b29eb1fd1cc672ac5e4a8be5f6ea936d094a3dc659ada0746d6fac750b1", "c41a6e2685d06bcdb0d26533af2540f54884d40db7e48baed6a5bcbf1a7cc642", "d07fcb0c0acbc043faa0e7cf4d2037f71193de3fb04fb8ed5c259b089af1cf5c", "d146d5443cda0a41f74276e42faf8c7f283fef49e8a853b832885239ef544e05", "eb2b7bed0a26ba5ce3700e15938b28a4f4513578d3e54a2156c29df19ac5fd01", "eb9b8ebf59eddd8b96366428238ab27d05a19e89c5516ce294abc35cea75d003"] +scipy = ["0baa64bf42592032f6f6445a07144e355ca876b177f47ad8d0612901c9375bef", "243b04730d7223d2b844bda9500310eecc9eda0cba9ceaf0cde1839f8287dfa8", "2643cfb46d97b7797d1dbdb6f3c23fe3402904e3c90e6facfe6a9b98d808c1b5", "396eb4cdad421f846a1498299474f0a3752921229388f91f60dc3eda55a00488", "3ae3692616975d3c10aca6d574d6b4ff95568768d4525f76222fb60f142075b9", "435d19f80b4dcf67dc090cc04fde2c5c8a70b3372e64f6a9c58c5b806abfa5a8", "46a5e55850cfe02332998b3aef481d33f1efee1960fe6cfee0202c7dd6fc21ab", "75b513c462e58eeca82b22fc00f0d1875a37b12913eee9d979233349fce5c8b2", "7ccfa44a08226825126c4ef0027aa46a38c928a10f0a8a8483c80dd9f9a0ad44", "89dd6a6d329e3f693d1204d5562dd63af0fd7a17854ced17f9cbc37d5b853c8d", "a81da2fe32f4eab8b60d56ad43e44d93d392da228a77e229e59b51508a00299c", "a9d606d11eb2eec7ef893eb825017fbb6eef1e1d0b98a5b7fc11446ebeb2b9b1", "ac37eb652248e2d7cbbfd89619dce5ecfd27d657e714ed049d82f19b162e8d45", "cbc0611699e420774e945f6a4e2830f7ca2b3ee3483fca1aa659100049487dd5", "d02d813ec9958ed63b390ded463163685af6025cb2e9a226ec2c477df90c6957", "dd3b52e00f93fd1c86f2d78243dfb0d02743c94dd1d34ffea10055438e63b99d"] +send2trash = ["60001cc07d707fe247c94f74ca6ac0d3255aabcb930529690897ca2a39db28b2", "f1691922577b6fa12821234aeb57599d887c4900b9ca537948d2dac34aea888b"] +six = ["3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c", "d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73"] +terminado = ["d9d012de63acb8223ac969c17c3043337c2fcfd28f3aea1ee429b345d01ef460", "de08e141f83c3a0798b050ecb097ab6259c3f0331b2f7b7750c9075ced2c20c2"] +testpath = ["46c89ebb683f473ffe2aab0ed9f12581d4d078308a3cb3765d79c6b2317b0109", "b694b3d9288dbd81685c5d2e7140b81365d46c29f5db4bc659de5aa6b98780f8"] +torch = ["0932756a2de0ea9a47a4aee34e6cd475734a355477e5149a006fc8faf57a3229", "1330e1c47302113f05e65b17e518e9ebcf41b53982e38ee4e662fbc5390bb46c", "13d09d5022e0dd251a88b6be0415eecddafb093b067a253ff0c5c0f5acd12077", "6154a8b92d869982d586d6a31955071d4bceb89e170153efd2861555bccd84c1", "6618b915124d22309d6ba7d80cf7539084bc7146f21837a9329a1d9e3a4e647d", "743ab46bf82eef8b71042f9423eeef6e2bae0974694f3b3e918a287f69dd693a", "b71e072dc68ef49afc3d9aaf0af8bcb20ce03bfce5cb43ee45be2d9cae5edf40", "d7a88d0e8c58effe46a4b31531e26340657375c61da4f5a2002b0e4b07f85437", "ebf899165c96cba8468237c8bb0a0cc9e1a838ecd05fb0272934a83f33594a77"] +tornado = ["349884248c36801afa19e342a77cc4458caca694b0eda633f5878e458a44cb2c", "398e0d35e086ba38a0427c3b37f4337327231942e731edaa6e9fd1865bbd6f60", "4e73ef678b1a859f0cb29e1d895526a20ea64b5ffd510a2307b5998c7df24281", "559bce3d31484b665259f50cd94c5c28b961b09315ccd838f284687245f416e5", "abbe53a39734ef4aba061fca54e30c6b4639d3e1f59653f0da37a0003de148c7", "c845db36ba616912074c5b1ee897f8e0124df269468f25e4fe21fe72f6edd7a9", "c9399267c926a4e7c418baa5cbe91c7d1cf362d505a1ef898fde44a07c9dd8a5"] +tqdm = ["13f018038711256ed27aae118a80d63929588e90f00d072a0f4eb7aa3333b4dc", "dd60ea2567baa013c625153ce41fd274209c69a5814513e1d635f20e5cd61b97"] +traitlets = ["70b4c6a1d9019d7b4f6846832288f86998aa3b9207c6821f3578a6a6a467fe44", "d023ee369ddd2763310e4c3eae1ff649689440d4ae59d7485eb4cfbbe3e359f7"] +wcwidth = ["3df37372226d6e63e1b1e1eda15c594bca98a22d33a23832a90998faa96bc65e", "f4ebe71925af7b40a864553f761ed559b43544f8f71746c2d756c7fe788ade7c"] +webencodings = ["a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78", "b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923"] +wget = ["35e630eca2aa50ce998b9b1a127bb26b30dfee573702782aa982f875e3f16061"] +widgetsnbextension = ["079f87d87270bce047512400efd70238820751a11d2d8cb137a5a5bdbaf255c7", "bd314f8ceb488571a5ffea6cc5b9fc6cba0adaf88a9d2386b93a489751938bcd"] +xgboost = ["5ec073f6d68348784e9afdb831371fefb89de896d8eb58e79244ad05177c5753", "898f26bb66589c644d17deff1b03961504f7ad79296ed434d0d7a5e9cb4deae6", "d69f90d61a63e8889fd39a31ad00c629bac1ca627f8406b9b6d4594c9e29ab84"] +zipp = ["3718b1cbcd963c7d4c5511a8240812904164b7f381b647143a89d3b98f9bcd8e", "f06903e9f1f43b12d371004b4ac7b06ab39a44adc747266928ae6debfa7b3335"] diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..dec39339 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,26 @@ +[tool.poetry] +name = "tabnet" +version = "0.1.0" +description = "" +readme = "README.md" +authors = [] +exclude = ["tabnet/*.ipynb"] + +[tool.poetry.dependencies] +python = "^3.6.8" + +numpy="1.17.2" +torch="1.0.1" +tqdm="4.30.0" +scikit_learn="0.21.3" +pandas="0.25.1" + +[tool.poetry.dev-dependencies] +jupyter="1.0.0" +xgboost="0.90" +matplotlib="3.1.1" +wget="3.2" + +[build-system] +requires = ["poetry>=0.12"] +build-backend = "poetry.masonry.api" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..616a882b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +numpy>=1.16.4 +matplotlib>=3.1.1 +torch>=1.0.1 +tqdm>=4.30.0 +ipython>=7.8.0 +scikit_learn>=0.21.3 +xgboost>=0.90 +tqdm>=4.30.0 diff --git a/tabnet/sparsemax.py b/tabnet/sparsemax.py new file mode 100644 index 00000000..99482a93 --- /dev/null +++ b/tabnet/sparsemax.py @@ -0,0 +1,264 @@ +from torch import nn +from torch.autograd import Function +import torch + + +# Other possible implementations: +# https://github.com/KrisKorrel/sparsemax-pytorch/blob/master/sparsemax.py +# https://github.com/msobroza/SparsemaxPytorch/blob/master/mnist/sparsemax.py +# https://github.com/vene/sparse-structured-attention/blob/master/pytorch/torchsparseattn/sparsemax.py + + + +# credits to Yandex https://github.com/Qwicen/node/blob/master/lib/nn_utils.py +def _make_ix_like(input, dim=0): + d = input.size(dim) + rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype) + view = [1] * input.dim() + view[0] = -1 + return rho.view(view).transpose(0, dim) + + +class SparsemaxFunction(Function): + """ + An implementation of sparsemax (Martins & Astudillo, 2016). See + :cite:`DBLP:journals/corr/MartinsA16` for detailed description. + By Ben Peters and Vlad Niculae + """ + + @staticmethod + def forward(ctx, input, dim=-1): + """sparsemax: normalizing sparse transform (a la softmax) + Parameters: + input (Tensor): any shape + dim: dimension along which to apply sparsemax + Returns: + output (Tensor): same shape as input + """ + ctx.dim = dim + max_val, _ = input.max(dim=dim, keepdim=True) + input -= max_val # same numerical stability trick as for softmax + tau, supp_size = SparsemaxFunction._threshold_and_support(input, dim=dim) + output = torch.clamp(input - tau, min=0) + ctx.save_for_backward(supp_size, output) + return output + + @staticmethod + def backward(ctx, grad_output): + supp_size, output = ctx.saved_tensors + dim = ctx.dim + grad_input = grad_output.clone() + grad_input[output == 0] = 0 + + v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze() + v_hat = v_hat.unsqueeze(dim) + grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) + return grad_input, None + + + @staticmethod + def _threshold_and_support(input, dim=-1): + """Sparsemax building block: compute the threshold + Args: + input: any dimension + dim: dimension along which to apply the sparsemax + Returns: + the threshold value + """ + + input_srt, _ = torch.sort(input, descending=True, dim=dim) + input_cumsum = input_srt.cumsum(dim) - 1 + rhos = _make_ix_like(input, dim) + support = rhos * input_srt > input_cumsum + + support_size = support.sum(dim=dim).unsqueeze(dim) + tau = input_cumsum.gather(dim, support_size - 1) + tau /= support_size.to(input.dtype) + return tau, support_size + + +#sparsemax = lambda input, dim=-1: SparsemaxFunction.apply(input, dim) + + +sparsemax = SparsemaxFunction.apply + + +class Sparsemax(nn.Module): + + def __init__(self, dim=-1): + self.dim = dim + super(Sparsemax, self).__init__() + + def forward(self, input): + return sparsemax(input, self.dim) + + +class Entmax15Function(Function): + """ + An implementation of exact Entmax with alpha=1.5 (B. Peters, V. Niculae, A. Martins). See + :cite:`https://arxiv.org/abs/1905.05702 for detailed description. + Source: https://github.com/deep-spin/entmax + """ + + @staticmethod + def forward(ctx, input, dim=-1): + ctx.dim = dim + + max_val, _ = input.max(dim=dim, keepdim=True) + input = input - max_val # same numerical stability trick as for softmax + input = input / 2 # divide by 2 to solve actual Entmax + + tau_star, _ = Entmax15Function._threshold_and_support(input, dim) + output = torch.clamp(input - tau_star, min=0) ** 2 + ctx.save_for_backward(output) + return output + + @staticmethod + def backward(ctx, grad_output): + Y, = ctx.saved_tensors + gppr = Y.sqrt() # = 1 / g'' (Y) + dX = grad_output * gppr + q = dX.sum(ctx.dim) / gppr.sum(ctx.dim) + q = q.unsqueeze(ctx.dim) + dX -= q * gppr + return dX, None + + @staticmethod + def _threshold_and_support(input, dim=-1): + Xsrt, _ = torch.sort(input, descending=True, dim=dim) + + rho = _make_ix_like(input, dim) + mean = Xsrt.cumsum(dim) / rho + mean_sq = (Xsrt ** 2).cumsum(dim) / rho + ss = rho * (mean_sq - mean ** 2) + delta = (1 - ss) / rho + + # NOTE this is not exactly the same as in reference algo + # Fortunately it seems the clamped values never wrongly + # get selected by tau <= sorted_z. Prove this! + delta_nz = torch.clamp(delta, 0) + tau = mean - torch.sqrt(delta_nz) + + support_size = (tau <= Xsrt).sum(dim).unsqueeze(dim) + tau_star = tau.gather(dim, support_size - 1) + return tau_star, support_size + + +class Entmoid15(Function): + """ A highly optimized equivalent of labda x: Entmax15([x, 0]) """ + + @staticmethod + def forward(ctx, input): + output = Entmoid15._forward(input) + ctx.save_for_backward(output) + return output + + @staticmethod + def _forward(input): + input, is_pos = abs(input), input >= 0 + tau = (input + torch.sqrt(F.relu(8 - input ** 2))) / 2 + tau.masked_fill_(tau <= input, 2.0) + y_neg = 0.25 * F.relu(tau - input, inplace=True) ** 2 + return torch.where(is_pos, 1 - y_neg, y_neg) + + @staticmethod + def backward(ctx, grad_output): + return Entmoid15._backward(ctx.saved_tensors[0], grad_output) + + @staticmethod + def _backward(output, grad_output): + gppr0, gppr1 = output.sqrt(), (1 - output).sqrt() + grad_input = grad_output * gppr0 + q = grad_input / (gppr0 + gppr1) + grad_input -= q * gppr0 + return grad_input + + + +entmax15 = Entmax15Function.apply +entmoid15 = Entmoid15.apply + +class Entmax15(nn.Module): + + def __init__(self, dim=-1): + self.dim = dim + super(Entmax15, self).__init__() + + def forward(self, input): + return entmax15(input, self.dim) + + + + +# Credits were lost... +# def _make_ix_like(input, dim=0): +# d = input.size(dim) +# rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype) +# view = [1] * input.dim() +# view[0] = -1 +# return rho.view(view).transpose(0, dim) +# +# +# def _threshold_and_support(input, dim=0): +# """Sparsemax building block: compute the threshold +# Args: +# input: any dimension +# dim: dimension along which to apply the sparsemax +# Returns: +# the threshold value +# """ +# +# input_srt, _ = torch.sort(input, descending=True, dim=dim) +# input_cumsum = input_srt.cumsum(dim) - 1 +# rhos = _make_ix_like(input, dim) +# support = rhos * input_srt > input_cumsum +# +# support_size = support.sum(dim=dim).unsqueeze(dim) +# tau = input_cumsum.gather(dim, support_size - 1) +# tau /= support_size.to(input.dtype) +# return tau, support_size +# +# +# class SparsemaxFunction(Function): +# +# @staticmethod +# def forward(ctx, input, dim=0): +# """sparsemax: normalizing sparse transform (a la softmax) +# Parameters: +# input (Tensor): any shape +# dim: dimension along which to apply sparsemax +# Returns: +# output (Tensor): same shape as input +# """ +# ctx.dim = dim +# max_val, _ = input.max(dim=dim, keepdim=True) +# input -= max_val # same numerical stability trick as for softmax +# tau, supp_size = _threshold_and_support(input, dim=dim) +# output = torch.clamp(input - tau, min=0) +# ctx.save_for_backward(supp_size, output) +# return output +# +# @staticmethod +# def backward(ctx, grad_output): +# supp_size, output = ctx.saved_tensors +# dim = ctx.dim +# grad_input = grad_output.clone() +# grad_input[output == 0] = 0 +# +# v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze() +# v_hat = v_hat.unsqueeze(dim) +# grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) +# return grad_input, None +# +# +# sparsemax = SparsemaxFunction.apply +# +# +# class Sparsemax(nn.Module): +# +# def __init__(self, dim=0): +# self.dim = dim +# super(Sparsemax, self).__init__() +# +# def forward(self, input): +# return sparsemax(input, self.dim) diff --git a/tabnet/tab_model.py b/tabnet/tab_model.py new file mode 100644 index 00000000..94f45b0a --- /dev/null +++ b/tabnet/tab_model.py @@ -0,0 +1,449 @@ +import torch +import numpy as np +from tqdm import tqdm +import time +from sklearn.metrics import roc_auc_score, mean_squared_error, accuracy_score +from torch.autograd import Variable +from IPython.display import clear_output +from torch.nn.utils import clip_grad_norm_ +import matplotlib.pyplot as plt +from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler + + +class TorchDataset(Dataset): + """ + Format for numpy array + + Parameters + ---------- + X: 2D array + The input matrix + y: 2D array + The one-hot encoded target + """ + + def __init__(self, x, y): + self.x = x + self.y = y + self.timer = [] + + def __len__(self): + return len(self.x) + + def __getitem__(self, index): + x, y = self.x[index], self.y[index] + return x, y + + +class Model(object): + def __init__(self, + device_name='auto',): + """ Class for TabNet model + + Parameters + ---------- + type: str + Model type ('classification' or 'regression'). + device_name: str + 'cuda' if running on GPU, 'cpu' if not, 'auto' to autodetect + save: bool + If True, save model in path. Model name will be the time stamp + of its execution. + load: str + Name of model that should be loaded. + """ + + # Defining device + if device_name == 'auto': + if torch.cuda.is_available(): + device_name = 'cuda' + else: + device_name = 'cpu' + self.device = torch.device(device_name) + print(f"Device used : {self.device}") + + def def_network(self, network, **kwargs): + """Defines network architecture and attributes **kwargs to network + parameters, e.g. `input_dim`, `output_dim` and `layers`. + If load is passed to model init, it ignores all parameters and load + the file located at path/load.pt, + + Parameters + ---------- + network: a :class: `nn.Module` + The network whose weights will be updated in training. + See `network.py` for possible networks. + """ + self.network = network(**kwargs).to(self.device) + + def set_params(self, **kwargs): + """Sets default hyperparameters and overrides default with + values set in **kwargs. + + Parameters + ---------- + loss_fn: :class: `torch.nn.functional` + The loss function. A few options: + - torch.nn.functional.mse_loss + - torch.nn.functional.binary_cross_entropy + - torch.nn.functional.cross_entropy + - torch.nn.functional.l1_loss + max_epochs: int + The maximum number of epochs for training + patience: int + learning_rate: float + The initial learning rate + schedule: str + The learning rate schedule('lambda', 'cos', + 'exp' or 'step' or None). + lr_params: dict + Additional infos on the learning rate scheduler + See https://pytorch.org/docs/master/optim.html for params + of each scheduler type. + optimizer_fn: :class: torch.optim object + The optimizer function to be used. A few options: + - torch.optim.SGD + - torch.optim.Adam + - torch.optim.Adadelta + opt_params: dict + Further parameters to be used by optimizer_fn. + or None) + + """ + # default params + self.max_epochs = 100 + self.patience = 15 + self.lr = 2e-2 + self.scheduler = None + self.lr_params = {} + self.opt_params = {} + self.optimizer_fn = None + self.clip_value = None + self.model_name = "DQTabNet" + self.lambda_sparse = 1e-3 + self.scheduler_fn = None + self.patience_counter = 0 + self.batch_size = 1024 + self.saving_path = "./" + self.verbose = 1 + + # Overrides parametersk + self.__dict__.update(kwargs) + + self.output_dim = self.network.output_dim + + if self.output_dim == 1: + self.loss_fn = torch.nn.functional.mse_loss + else: + self.loss_fn = torch.nn.functional.cross_entropy + + self.opt_params['lr'] = self.lr + + if self.optimizer_fn is None: + self.optimizer = torch.optim.Adam(self.network.parameters(), + **self.opt_params) + else: + self.optimizer = self.optimizer_fn(self.network.parameters(), + **self.opt_params) + + if self.scheduler_fn: + self.scheduler = self.scheduler_fn(self.optimizer, **self.scheduler_params) + + def fit(self, X_train, y_train, X_valid=None, y_valid=None, + balanced=False, weights=None): + """Train a neural network stored in self.network + Using train_dataloader for training data and + valid_dataloader for validation. + + Parameters + ---------- + X_train: np.ndarray + Train set + y_train : np.array + Train targets + X_train: np.ndarray + Train set + y_train : np.array + Train targets + balanced : bool + If set to True, training will oversample less frequent classes + weights : dictionnary + For classification problems only, a dictionnary with keys ranging from + 0 to output_dim - 1, with corresponding weights for each class + + """ + # Initialize counters and histories. + self.patience_counter = 0 + self.epoch = 0 + self.best_cost = np.inf + + if balanced: + if weights: + samples_weight = np.array([weights[t] for t in y_train]) + else: + class_sample_count = np.array( + [len(np.where(y_train == t)[0]) for t in np.unique(y_train)]) + + weights = 1. / class_sample_count + + samples_weight = np.array([weights[t] for t in y_train]) + + samples_weight = torch.from_numpy(samples_weight) + samples_weigth = samples_weight.double() + sampler = WeightedRandomSampler(samples_weight, len(samples_weight)) + train_dataloader = DataLoader(TorchDataset(X_train, y_train), batch_size=self.batch_size, sampler=sampler) + valid_dataloader = DataLoader(TorchDataset(X_valid, y_valid), batch_size=self.batch_size, shuffle=False) + + train_dataloader = DataLoader(TorchDataset(X_train, y_train), batch_size=self.batch_size, shuffle=True) + valid_dataloader = DataLoader(TorchDataset(X_valid, y_valid), batch_size=self.batch_size, shuffle=False) + + losses_train = [] + losses_valid = [] + + metrics_train = [] + metrics_valid = [] + + while (self.epoch < self.max_epochs and + self.patience_counter < self.patience): + print(f"EPOCH : {self.epoch}") + fit_metrics = self.fit_epoch(train_dataloader, valid_dataloader) + losses_train.append(fit_metrics['train']['loss_avg']) + losses_valid.append(fit_metrics['valid']['total_loss']) + metrics_train.append(fit_metrics['train']['stopping_loss']) + metrics_valid.append(fit_metrics['valid']['stopping_loss']) + + + stopping_loss = fit_metrics['valid']['stopping_loss'] + if stopping_loss < self.best_cost: + self.best_cost = stopping_loss + self.patience_counter = 0 + + print("saving model") + torch.save(self.network, self.saving_path+f"{self.model_name}.pt") + else: + self.patience_counter += 1 + + print("Best metric valid: ", self.best_cost) + self.epoch += 1 + + if self.epoch % self.verbose == 0: + clear_output() + fig = plt.figure(figsize=(15, 5)) + plt.subplot(1, 2, 1) + plt.plot(range(len(losses_train)), losses_train, label='Train') + plt.plot(range(len(losses_valid)), losses_valid, label='Valid') + plt.grid() + plt.title('Losses') + plt.legend() + #plt.show() + + plt.subplot(1, 2, 2) + plt.plot(range(len(metrics_train)), metrics_train, label='Train') + plt.plot(range(len(metrics_valid)), metrics_valid, label='Valid') + plt.grid() + plt.title('Training Metrics') + plt.legend() + plt.show() + + def fit_epoch(self, train_dataloader, valid_dataloader): + """ + Evaluates and updates network for one epoch. + + Parameters + ---------- + train_dataloader: a :class: `torch.utils.data.Dataloader` + DataLoader with train set + valid_dataloader: a :class: `torch.utils.data.Dataloader` + DataLoader with valid set + """ + train_metrics = self.train_epoch(train_dataloader) + valid_metrics = self.predict_epoch(valid_dataloader) + + fit_metrics = {'train': train_metrics, + 'valid': valid_metrics} + + return fit_metrics + + def train_epoch(self, train_loader): + """ + Trains one epoch of the network in self.network + + Parameters + ---------- + train_loader: a :class: `torch.utils.data.Dataloader` + DataLoader with train set + """ + + self.network.train() + y_preds = [] + ys = [] + total_loss = 0 + + with tqdm() as pbar: + for data, targets in train_loader: + batch_outs = self.train_batch(data, targets) + if self.output_dim == 1: + y_preds.append(batch_outs["y_preds"].cpu().detach().numpy()) + elif self.output_dim == 2: + y_preds.append(batch_outs["y_preds"][:, 1].cpu().detach().numpy()) + else: + values, indices = torch.max(batch_outs["y_preds"], dim=1) + y_preds.append(indices.cpu().detach().numpy()) + ys.append(batch_outs["y"].cpu().detach().numpy()) + total_loss+=batch_outs["loss"] + pbar.update(1) + + y_preds = np.hstack(y_preds) + ys = np.hstack(ys) + + if self.output_dim == 2: + stopping_loss = -roc_auc_score(y_true=ys, y_score=y_preds) + # print("AUC train: ", -stopping_loss) + elif self.output_dim == 1: + stopping_loss = mean_squared_error(y_true=ys, y_pred=y_preds) + # print("MSE train: ", stopping_loss) + else: + stopping_loss = -accuracy_score(y_true=ys, y_pred=y_preds) + # print("ACCURACY Train ", -stopping_loss) + total_loss = total_loss / len(train_loader) + epoch_metrics = {'loss_avg': total_loss, + 'stopping_loss': stopping_loss + } + + if self.scheduler is not None: + self.scheduler.step() + print("Current learning rate: ", self.optimizer.param_groups[-1]["lr"]) + return epoch_metrics + + def train_batch(self, data, targets): + """ + Trains one batch of data + + Parameters + ---------- + data: a :tensor: `torch.tensor` + Input data + target: a :tensor: `torch.tensor` + Target data + """ + self.network.train() + data = data.to(self.device).float() + targets = targets.to(self.device).long() + self.optimizer.zero_grad() + + output, M_loss, M_explain, _ = self.network(data) + + loss = self.loss_fn(output, targets) + loss -= self.lambda_sparse*M_loss + + loss.backward() + if self.clip_value: + clip_grad_norm_(self.network.parameters(), self.clip_value) + self.optimizer.step() + + loss_value = loss.item() + batch_outs = {'loss': loss_value, + 'y_preds': output, + 'y': targets} + return batch_outs + + def predict_epoch(self, loader): + """ + Validates one epoch of the network in self.network + + Parameters + ---------- + loader: a :class: `torch.utils.data.Dataloader` + DataLoader with validation set + """ + y_preds = [] + ys = [] + self.network.eval() + total_loss = 0 + + for data, targets in loader: + batch_outs = self.predict_batch(data, targets) + total_loss += batch_outs["loss"] + if self.output_dim == 1: + y_preds.append(batch_outs["y_preds"].cpu().detach().numpy()) + elif self.output_dim == 2: + y_preds.append(batch_outs["y_preds"][:, 1].cpu().detach().numpy()) + else: + values, indices = torch.max(batch_outs["y_preds"], dim=1) + y_preds.append(indices.cpu().detach().numpy()) + ys.append(batch_outs["y"].cpu().detach().numpy()) + + y_preds = np.hstack(y_preds) + ys = np.hstack(ys) + + if self.output_dim == 2: + stopping_loss = -roc_auc_score(y_true=ys, y_score=y_preds) + # print("AUC Valid: ", -stopping_loss) + elif self.output_dim == 1: + stopping_loss = mean_squared_error(y_true=ys, y_pred=y_preds) + # print("MSE Valid: ", stopping_loss) + else: + stopping_loss = -accuracy_score(y_true=ys, y_pred=y_preds) + # print("ACCURACY Valid ", -stopping_loss) + + total_loss = total_loss / len(loader) + epoch_metrics = {'total_loss': total_loss, + 'stopping_loss': stopping_loss} + + return epoch_metrics + + def predict_batch(self, data, targets): + """ + Make predictions on a batch (valid) + + Parameters + ---------- + data: a :tensor: `torch.Tensor` + Input data + target: a :tensor: `torch.Tensor` + Target data + + Returns + ------- + batch_outs: dict + """ + self.network.eval() + data = data.to(self.device).float() + targets = targets.to(self.device).long() + + output, M_loss, M_explain, _ = self.network(data) + + loss = self.loss_fn(output, targets) + loss -= self.lambda_sparse*M_loss + + loss_value = loss.item() + batch_outs = {'loss': loss_value, + 'y_preds': output, + 'y': targets} + return batch_outs + + def load_best_model(self): + self.network = torch.load(self.saving_path+f"{self.model_name}.pt") + + def predict_proba(self, X): + """ + Make predictions on a batch (valid) + + Parameters + ---------- + data: a :tensor: `torch.Tensor` + Input data + target: a :tensor: `torch.Tensor` + Target data + + Returns + ------- + batch_outs: dict + """ + self.network.eval() + data = torch.Tensor(X).to(self.device).float() + + output, M_loss, M_explain, masks = self.network(data) + predictions = output.cpu().detach().numpy() + + return predictions, M_explain, masks diff --git a/tabnet/tab_network.py b/tabnet/tab_network.py new file mode 100644 index 00000000..189fb851 --- /dev/null +++ b/tabnet/tab_network.py @@ -0,0 +1,335 @@ +import torch +from torch.nn import Linear, BatchNorm1d, ReLU +import numpy as np +from tabnet import sparsemax +from copy import deepcopy + + +def initialize_non_glu(module, input_dim, output_dim): + gain_value = np.sqrt((input_dim+output_dim)/np.sqrt(4*input_dim)) + torch.nn.init.xavier_normal_(module.weight, gain=gain_value) + #torch.nn.init.zeros_(module.bias) + return + + +def initialize_glu(module, input_dim, output_dim): + gain_value = np.sqrt((input_dim+output_dim)/np.sqrt(input_dim)) + torch.nn.init.xavier_normal_(module.weight, gain=gain_value) + #torch.nn.init.zeros_(module.bias) + return + + +class GBN(torch.nn.Module): + """ + Ghost Batch Normalization + https://arxiv.org/abs/1705.08741 + """ + def __init__(self, input_dim, virtual_batch_size=128, momentum=0.01, device='cpu'): + super(GBN, self).__init__() + + self.input_dim = input_dim + self.virtual_batch_size = virtual_batch_size + self.bn = BatchNorm1d(self.input_dim, momentum=momentum) + self.device = device + + def forward(self, x): + chunks = x.chunk(x.shape[0] // self.virtual_batch_size + ((x.shape[0] % self.virtual_batch_size) > 0)) + res = torch.Tensor([]).to(self.device) + for x_ in chunks: + y = self.bn(x_) + res = torch.cat([res, y], dim=0) + + return res + + +class TabNet(torch.nn.Module): + def __init__(self, input_dim, output_dim, n_d, n_a, + n_steps, gamma, cat_idxs, cat_dims, cat_emb_dim=1, + n_independent=2, n_shared=2, epsilon=1e-15, + virtual_batch_size=128, momentum=0.02, device_name='auto'): + """ + Defines TabNet network + + Parameters + ---------- + - input_dim : int + Initial number of features + - output_dim : int + Dimension of network output + examples : one for regression, 2 for binary classification etc... + - n_d : int + Dimension of the prediction layer (usually between 4 and 64) + - n_a : int + Dimension of the attention layer (usually between 4 and 64) + - n_steps: int + Number of sucessive steps in the newtork (usually betwenn 3 and 10) + - gamma : float + Float above 1, scaling factor for attention updates (usually betwenn 1.0 to 2.0) + - cat_idxs : list of int + Index of each categorical column in the dataset + - cat_dims : list of int + Number of categories in each categorical column + - cat_emb_dim : int or list of int + Size of the embedding of categorical features + if int, all categorical features will have same embedding size + if list of int, every corresponding feature will have specific size + - momentum : float + Float value between 0 and 1 which will be used for momentum in all batch norm + - n_independent : int + Number of independent GLU layer in each GLU block (default 2) + - n_shared : int + Number of independent GLU layer in each GLU block (default 2) + - epsilon: float + Avoid log(0), this should be kept very low + """ + super(TabNet, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.n_d = n_d + self.n_a = n_a + self.n_steps = n_steps + self.gamma = gamma + self.epsilon = epsilon + self.n_independent = n_independent + self.n_shared = n_shared + self.cat_idxs = cat_idxs or [] + self.cat_dims = cat_dims or [] + self.cat_emb_dim = cat_emb_dim + self.virtual_batch_size = virtual_batch_size + + # Defining device + if device_name == 'auto': + if torch.cuda.is_available(): + device_name = 'cuda' + else: + device_name = 'cpu' + self.device = torch.device(device_name) + + if type(cat_emb_dim) == int: + self.cat_emb_dims = [cat_emb_dim]*len(self.cat_idxs) + else: + # check that all embeddings are provided + assert(len(cat_emb_dim) == len(cat_dims)) + self.cat_emb_dims = cat_emb_dim + self.embeddings = torch.nn.ModuleList() + for cat_dim, emb_dim in zip(self.cat_dims, self.cat_emb_dims): + self.embeddings.append(torch.nn.Embedding(cat_dim, emb_dim)) + + # record continuous indices + self.continuous_idx = torch.ones(self.input_dim, dtype=torch.uint8) + self.continuous_idx[self.cat_idxs] = 0 + self.post_embed_dim = self.input_dim + (cat_emb_dim - 1)*len(self.cat_idxs) + self.initial_bn = BatchNorm1d(self.post_embed_dim, momentum=0.01) + + if self.n_shared > 0: + shared_feat_transform = GLU_Block(self.post_embed_dim, + n_d+n_a, + n_glu=self.n_shared, + virtual_batch_size=self.virtual_batch_size, + first=True, + momentum=momentum, + device=self.device) + else: + shared_feat_transform = None + + self.initial_splitter = FeatTransformer(self.post_embed_dim, n_d+n_a, shared_feat_transform, + n_glu=self.n_independent, + virtual_batch_size=self.virtual_batch_size, + momentum=momentum, + device=self.device) + + # self.shared_feat_transformers = torch.nn.ModuleList() + self.feat_transformers = torch.nn.ModuleList() + self.att_transformers = torch.nn.ModuleList() + + for step in range(n_steps): + transformer = FeatTransformer(self.post_embed_dim, n_d+n_a, shared_feat_transform, + n_glu=self.n_independent, + virtual_batch_size=self.virtual_batch_size, + momentum=momentum, + device=self.device) + attention = AttentiveTransformer(n_a, self.post_embed_dim, + virtual_batch_size=self.virtual_batch_size, + momentum=momentum, + device=self.device) + self.feat_transformers.append(transformer) + self.att_transformers.append(attention) + + self.soft_max = torch.nn.Softmax(dim=1) + self.final_mapping = Linear(n_d, output_dim, bias=False) + initialize_non_glu(self.final_mapping, n_d, output_dim) + + def apply_embeddings(self, x): + """Apply embdeddings to raw inputs""" + # Getting categorical data + cat_cols = [] + for icat, cat_idx in enumerate(self.cat_idxs): + cat_col = x[:, cat_idx].long() + cat_col = self.embeddings[icat](cat_col) + cat_cols.append(cat_col) + post_embeddings = torch.cat([x[:, self.continuous_idx].float()] + cat_cols, dim=1) + post_embeddings = post_embeddings.float() + return post_embeddings + + def forward(self, x): + res = 0 + x = self.apply_embeddings(x) + x = self.initial_bn(x) + + prior = torch.ones(x.shape).to(self.device) + M_explain = torch.zeros(x.shape).to(self.device) + M_loss = 0 + att = self.initial_splitter(x)[:, self.n_d:] + masks = {} + + for step in range(self.n_steps): + M = self.att_transformers[step](prior, att) + masks[step] = M + M_loss += torch.mean(torch.sum(torch.mul(M, torch.log(M+self.epsilon)), dim=1)) / (self.n_steps) + # update prior + prior = torch.mul(self.gamma - M, prior) + # output + masked_x = torch.mul(M, x) + out = self.feat_transformers[step](masked_x) + d = ReLU()(out[:, :self.n_d]) + res = torch.add(res, d) + # explain + step_importance = torch.sum(d, dim=1) + M_explain += torch.mul(M, step_importance.unsqueeze(dim=1)) + # update attention + att = out[:, self.n_d:] + + res = self.final_mapping(res) + return res, M_loss, M_explain, masks + + +class AttentiveTransformer(torch.nn.Module): + def __init__(self, input_dim, output_dim, virtual_batch_size=128, momentum=0.02, device='cpu'): + """ + Initialize an attention transformer. + + Parameters + ---------- + - input_dim : int + Input size + - output_dim : int + Outpu_size + - momentum : float + Float value between 0 and 1 which will be used for momentum in batch norm + """ + super(AttentiveTransformer, self).__init__() + self.fc = Linear(input_dim, output_dim, bias=False) + initialize_non_glu(self.fc, input_dim, output_dim) + self.bn = GBN(output_dim, virtual_batch_size=virtual_batch_size, momentum=momentum, device=device) + + # Sparsemax + self.sp_max = sparsemax.Sparsemax(dim=-1) + #Entmax + # self.sp_max = sparsemax.Entmax15(dim=-1) + + def forward(self, priors, processed_feat): + x = self.fc(processed_feat) + x = self.bn(x) + x = torch.mul(x, priors) + x = self.sp_max(x) + return x + + +class FeatTransformer(torch.nn.Module): + def __init__(self, input_dim, output_dim, shared_blocks, n_glu, + virtual_batch_size=128, momentum=0.02, device='cpu'): + super(FeatTransformer, self).__init__() + """ + Initialize a feature transformer. + + Parameters + ---------- + - input_dim : int + Input size + - output_dim : int + Outpu_size + - shared_blocks : torch.nn.Module + The shared block that should be common to every step + - momentum : float + Float value between 0 and 1 which will be used for momentum in batch norm + """ + + self.shared = deepcopy(shared_blocks) + if self.shared is not None: + for l in self.shared.glu_layers: + l.bn = GBN(2*output_dim, virtual_batch_size=virtual_batch_size, momentum=momentum, device=device) + + if self.shared is None: + self.specifics = GLU_Block(input_dim, output_dim, + n_glu=n_glu, + first=True, + virtual_batch_size=virtual_batch_size, + momentum=momentum, + device=device) + else: + self.specifics = GLU_Block(output_dim, output_dim, + n_glu=n_glu, + virtual_batch_size=virtual_batch_size, + momentum=momentum, + device=device) + + def forward(self, x): + if self.shared is not None: + x = self.shared(x) + x = self.specifics(x) + return x + + +class GLU_Block(torch.nn.Module): + """ + Independant GLU block, specific to each step + """ + def __init__(self, input_dim, output_dim, n_glu=2, first=False, + virtual_batch_size=128, momentum=0.02, device='cpu'): + super(GLU_Block, self).__init__() + self.first = first + self.n_glu = n_glu + self.glu_layers = torch.nn.ModuleList() + self.scale = torch.sqrt(torch.FloatTensor([0.5]).to(device)) + for glu_id in range(self.n_glu): + if glu_id == 0: + self.glu_layers.append(GLU_Layer(input_dim, output_dim, + virtual_batch_size=virtual_batch_size, + momentum=momentum, + device=device)) + else: + self.glu_layers.append(GLU_Layer(output_dim, output_dim, + virtual_batch_size=virtual_batch_size, + momentum=momentum, + device=device)) + + + def forward(self, x): + if self.first: # the first layer of the block has no scale multiplication + x = self.glu_layers[0](x) + layers_left = range(1, self.n_glu) + else: + layers_left = range(self.n_glu) + + for glu_id in layers_left: + x = torch.add(x, self.glu_layers[glu_id](x)) + x = x*self.scale + return x + + +class GLU_Layer(torch.nn.Module): + def __init__(self, input_dim, output_dim, + virtual_batch_size=128, momentum=0.02, device='cpu'): + super(GLU_Layer, self).__init__() + + self.output_dim = output_dim + self.fc = Linear(input_dim, 2*output_dim, bias=False) + initialize_glu(self.fc, input_dim, 2*output_dim) + + self.bn = GBN(2*output_dim, virtual_batch_size=virtual_batch_size, momentum=momentum, device=device) + + def forward(self, x): + x = self.fc(x) + x = self.bn(x) + out = torch.mul(x[:, :self.output_dim], torch.sigmoid(x[:, self.output_dim:])) + return out