Skip to content

Commit

Permalink
feat: start PyTorch TabNet Paper Implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Optimox committed Oct 28, 2019
0 parents commit e7dc059
Show file tree
Hide file tree
Showing 15 changed files with 2,964 additions and 0 deletions.
1 change: 1 addition & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*
29 changes: 29 additions & 0 deletions .gitatttributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
* text=auto
# Basic .gitattributes for a python repo.

# Source files
# ============
*.pxd text diff=python
*.py text diff=python
*.py3 text diff=python
*.pyw text diff=python
*.pyx text diff=python
*.pyz text diff=python

# Binary files
# ============
*.db binary
*.p binary
*.pkl binary
*.pickle binary
*.pyc binary
*.pyd binary
*.pyo binary

# Jupyter notebook
*.ipynb text

# Note: .db, .p, and .pkl files are associated
# with the python modules ``pickle``, ``dbm.*``,
# ``shelve``, ``marshal``, ``anydbm``, & ``bsddb``
# (among others).
131 changes: 131 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
.cache/
.history/
data/
.ipynb_checkpoints/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

14 changes: 14 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
FROM python:3.7-slim-buster
# apt-get (not apt) is the stable CLI for scripts; skip recommends and
# drop the package lists afterwards to keep the image small.
RUN apt-get update \
    && apt-get install -y --no-install-recommends curl make git \
    && rm -rf /var/lib/apt/lists/*
# Install Poetry via its official installer script.
RUN curl -sSL https://raw.githubusercontent.com/sdispater/poetry/master/get-poetry.py | python
ENV SHELL /bin/bash -l

# All caches live under /work, which the Makefile bind-mounts from the host,
# so downloaded packages and Jupyter state survive container restarts.
ENV POETRY_CACHE /work/.cache/poetry
ENV PIP_CACHE_DIR /work/.cache/pip
ENV JUPYTER_RUNTIME_DIR /work/.cache/jupyter/runtime
ENV JUPYTER_CONFIG_DIR /work/.cache/jupyter/config

# Store Poetry virtualenvs in the mounted cache for the same reason.
RUN $HOME/.poetry/bin/poetry config settings.virtualenvs.path $POETRY_CACHE

# ENTRYPOINT ["poetry", "run"]
CMD ["bash", "-l"]
21 changes: 21 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019 DreamQuark

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
47 changes: 47 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# set default shell
SHELL := $(shell which bash)
# $(CURDIR) is resolved once by make; avoids spawning a subshell per use
FOLDER := $(CURDIR)
# default shell options
.SHELLFLAGS = -c
NO_COLOR=\\e[39m
OK_COLOR=\\e[32m
ERROR_COLOR=\\e[31m
WARN_COLOR=\\e[33m
PORT=8889
.SILENT: ;
default: help; # default target

IMAGE_NAME=python-poetry:latest

build: ## Build the Docker image
	echo "Building Dockerfile"
	docker build -t ${IMAGE_NAME} .
.PHONY: build

start: build ## Build, then start an interactive container with the repo mounted at /work
	echo "Starting container ${IMAGE_NAME}"
	docker run --rm -it -v ${FOLDER}:/work -w /work -p ${PORT}:${PORT} -e "JUPYTER_PORT=${PORT}" ${IMAGE_NAME}
.PHONY: start

notebook: ## Launch Jupyter on ${PORT} (run this inside the container)
	poetry run jupyter notebook --allow-root --ip 0.0.0.0 --port ${PORT} --no-browser --notebook-dir .
.PHONY: notebook

root_bash: ## Open a root shell in the running container
	docker exec -it --user root $$(docker ps --filter ancestor=${IMAGE_NAME} --filter expose=${PORT} -q) bash
.PHONY: root_bash

help: ## Describe the available targets (must match the rules defined above)
	echo -e "make [ACTION]"
	echo
	echo -e "This image uses Poetry for dependency management (https://poetry.eustace.io/)"
	echo
	echo -e "Default port for Jupyter notebook is ${PORT}"
	echo
	echo -e "ACTIONS:"
	echo -e "- build: build the Docker image"
	echo -e "- start: build the image and start an interactive container"
	echo -e "- notebook: run the Jupyter notebook inside the container"
	echo -e "- root_bash: open a root shell in the running container"
.PHONY: help
116 changes: 116 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# README

# TabNet : Attentive Interpretable Tabular Learning

This is a PyTorch implementation of TabNet (Arik, S. O., & Pfister, T. (2019). TabNet: Attentive Interpretable Tabular Learning. arXiv preprint arXiv:1908.07442.) https://arxiv.org/pdf/1908.07442.pdf.

# Installation

You can install using pip by running:
`pip install tabnet`

If you want to use it locally within a docker container:

`git clone [email protected]:dreamquark-ai/tabnet.git`

`cd tabnet` to get inside the repository

`make start` to build and get inside the container

`poetry install` to install all the dependencies, including jupyter

`make notebook` inside the same terminal

You can then follow the link to a jupyter notebook with tabnet installed.



GPU version is available and should be working but is not supported yet.

# How to use it?

The implementation makes it easy to try different architectures of TabNet.
All you need to do is change the network parameters and training parameters. All parameters are briefly described below; to get a better understanding of what each parameter does, please refer to the original paper.

You can also get comfortable with how the code works by playing with the **notebooks tutorials** for the adult census income dataset and the forest cover type dataset.

## Network parameters

- input_dim : int

Number of initial features of the dataset

- output_dim : int

Size of the desired output. Ex :
- 1 for regression task
- 2 for binary classification
    - N > 2 for multiclass classification

- nd : int

Width of the decision prediction layer. Bigger values give more capacity to the model, with the risk of overfitting.
Values typically range from 8 to 64.

- na : int

Width of the attention embedding for each mask.
According to the paper nd=na is usually a good choice.

- n_steps : int
Number of steps in the architecture (usually between 3 and 10)

- gamma : float
This is the coefficient for feature reusage in the masks.
A value close to 1 will make mask selection least correlated between layers.
Values range from 1.0 to 2.0
- cat_idxs : list of int

List of categorical features indices.
- cat_emb_dim : list of int

List of embeddings size for each categorical features.
- n_independent : int

Number of independent Gated Linear Units layers at each step.
Usual values range from 1 to 5 (default=2)
- n_shared : int

Number of shared Gated Linear Units at each step
Usual values range from 1 to 5 (default=2)
- virtual_batch_size : int

Size of the mini batches used for Ghost Batch Normalization

## Training parameters

- max_epochs : int (default = 200)

Maximum number of epochs for training.
- patience : int (default = 15)

Number of consecutive epochs without improvement before performing early stopping.
- lr : float (default = 0.02)

Initial learning rate used for training. As mentioned in the original paper, a large initial learning rate of ```0.02 ``` with decay is a good option.
- clip_value : float (default None)

If a float is given this will clip the gradient at clip_value.
- lambda_sparse : float (default = 1e-3)

This is the extra sparsity loss coefficient as proposed in the original paper. The bigger this coefficient is, the sparser your model will be in terms of feature selection. Depending on the difficulty of your problem, reducing this value could help.
- model_name : str (default = 'DQTabNet')

Name of the model used for saving in disk, you can customize this to easily retrieve and reuse your trained models.
- saving_path : str (default = './')

Path defining where to save models.
- scheduler_fn : torch.optim.lr_scheduler (default = None)

Pytorch Scheduler to change learning rates during training.
- scheduler_params: dict

Parameters dictionary for the scheduler_fn. Ex : {"gamma": 0.95, "step_size": 10}
- verbose : int (default=-1)

Verbosity for notebooks plots, set to 1 to see every epoch.
Loading

0 comments on commit e7dc059

Please sign in to comment.