feat: enable feature grouping for attention mechanism
Optimox authored and eduardocarvp committed Dec 12, 2022
1 parent 4fa545d commit bcae5f4
Showing 13 changed files with 367 additions and 265 deletions.
15 changes: 12 additions & 3 deletions .circleci/config.yml
@@ -32,8 +32,8 @@ commands:
name: Install prerequisites and poetry
command: |
apt update && apt install curl make git libopenblas-base build-essential -y
- curl -sSL https://raw.githubusercontent.com/sdispater/poetry/master/get-poetry.py | python
- source $HOME/.poetry/env
+ curl -sSL https://install.python-poetry.org | python3 -
+ export PATH="/root/.local/bin:$PATH"
poetry config virtualenvs.path $POETRY_CACHE
poetry run pip install --upgrade --no-cache-dir pip==20.1;
@@ -70,6 +70,7 @@ jobs:
name: LintCode
shell: bash -leo pipefail
command: |
+ export PATH="/root/.local/bin:$PATH"
poetry run flake8
install:
executor: python-executor
@@ -87,7 +88,7 @@ jobs:
name: Install dependencies
shell: bash -leo pipefail
command: |
- source $HOME/.poetry/env
+ export PATH="/root/.local/bin:$PATH"
poetry config virtualenvs.path $POETRY_CACHE
poetry run pip install torch==1.4.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
poetry install
@@ -108,6 +109,7 @@ jobs:
name: run unit-tests
shell: bash -leo pipefail
command: |
+ export PATH="/root/.local/bin:$PATH"
make unit-tests
test-nb-census:
executor: python-executor
@@ -122,6 +124,7 @@ jobs:
name: run test-nb-census
shell: bash -leo pipefail
command: |
+ export PATH="/root/.local/bin:$PATH"
make test-nb-census
test-nb-multi-regression:
executor: python-executor
@@ -136,6 +139,7 @@ jobs:
name: run test-nb-multi-regression
shell: bash -leo pipefail
command: |
+ export PATH="/root/.local/bin:$PATH"
make test-nb-multi-regression
test-nb-forest:
executor: python-executor
@@ -150,6 +154,7 @@ jobs:
name: run test-nb-forest
shell: bash -leo pipefail
command: |
+ export PATH="/root/.local/bin:$PATH"
make test-nb-forest
test-nb-regression:
executor: python-executor
@@ -164,6 +169,7 @@ jobs:
name: run test-nb-regression
shell: bash -leo pipefail
command: |
+ export PATH="/root/.local/bin:$PATH"
make test-nb-regression
test-nb-multi-task:
executor: python-executor
@@ -178,6 +184,7 @@ jobs:
name: run test-nb-multi-task
shell: bash -leo pipefail
command: |
+ export PATH="/root/.local/bin:$PATH"
make test-nb-multi-task
test-nb-customization:
executor: python-executor
@@ -192,6 +199,7 @@ jobs:
name: run test-nb-customization
shell: bash -leo pipefail
command: |
+ export PATH="/root/.local/bin:$PATH"
make test-nb-customization
test-nb-pretraining:
executor: python-executor
@@ -206,6 +214,7 @@ jobs:
name: run test-nb-pretraining
shell: bash -leo pipefail
command: |
+ export PATH="/root/.local/bin:$PATH"
make test-nb-pretraining
workflows:
version: 2
6 changes: 3 additions & 3 deletions Dockerfile
@@ -1,15 +1,15 @@
FROM python:3.7-slim-buster@sha256:50de4af76270c893fe36a9ae428951057d6e1a681312d11861970baa150a62e2
RUN apt update && apt install curl make git libopenblas-base -y
- RUN curl -sSL https://raw.githubusercontent.com/sdispater/poetry/master/get-poetry.py | python
+ RUN curl -sSL https://install.python-poetry.org | python3 -
ENV SHELL /bin/bash -l

ENV POETRY_CACHE /work/.cache/poetry
ENV PIP_CACHE_DIR /work/.cache/pip
ENV JUPYTER_RUNTIME_DIR /work/.cache/jupyter/runtime
ENV JUPYTER_CONFIG_DIR /work/.cache/jupyter/config

- RUN $HOME/.poetry/bin/poetry config virtualenvs.path $POETRY_CACHE
+ RUN /root/.local/bin/poetry config virtualenvs.path $POETRY_CACHE

- ENV PATH ${PATH}:/root/.poetry/bin:/bin:/usr/local/bin:/usr/bin
+ ENV PATH /root/.local/bin:/bin:/usr/local/bin:/usr/bin

CMD ["bash", "-l"]
8 changes: 4 additions & 4 deletions Dockerfile_gpu
@@ -1,5 +1,5 @@
# GENERATED FROM SCRIPTS
- FROM nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04@sha256:5f16bff6a7272eed75d070b13020c98f89fc5be4ebf6cdc95adffa2b5dce4a31
+ FROM nvidia/cuda:11.3.1-runtime-ubuntu20.04

# Avoid tzdata interactive action
ENV DEBIAN_FRONTEND noninteractive
@@ -199,7 +199,7 @@ RUN set -eux; \

RUN apt update && apt install curl make git libopenblas-base -y

- RUN curl -sSL https://raw.githubusercontent.com/sdispater/poetry/master/get-poetry.py | python
+ RUN curl -sSL https://install.python-poetry.org | python3 -

ENV SHELL /bin/bash -l

@@ -211,8 +211,8 @@ ENV JUPYTER_RUNTIME_DIR /work/.cache/jupyter/runtime

ENV JUPYTER_CONFIG_DIR /work/.cache/jupyter/config

- RUN $HOME/.poetry/bin/poetry config virtualenvs.path $POETRY_CACHE
+ RUN /root/.local/bin/poetry config virtualenvs.path $POETRY_CACHE

- ENV PATH /root/.poetry/bin:/bin:/usr/local/bin:/usr/bin
+ ENV PATH /root/.local/bin:/bin:/usr/local/bin:/usr/bin

CMD ["bash", "-l"]
12 changes: 11 additions & 1 deletion README.md
@@ -2,7 +2,7 @@

# TabNet : Attentive Interpretable Tabular Learning

- This is a pyTorch implementation of Tabnet (Arik, S. O., & Pfister, T. (2019). TabNet: Attentive Interpretable Tabular Learning. arXiv preprint arXiv:1908.07442.) https://arxiv.org/pdf/1908.07442.pdf.
+ This is a PyTorch implementation of TabNet (Arik, S. O., & Pfister, T. (2019). TabNet: Attentive Interpretable Tabular Learning. arXiv preprint arXiv:1908.07442.) https://arxiv.org/pdf/1908.07442.pdf. Please note that some design choices have been made over time to improve the library, and these may differ from the original paper.

<!--- BADGES: START --->
[![CircleCI](https://circleci.com/gh/dreamquark-ai/tabnet.svg?style=svg)](https://circleci.com/gh/dreamquark-ai/tabnet)
@@ -68,6 +68,10 @@ If you want to use it locally within a docker container:

- `make notebook` inside the same terminal. You can then follow the link to a jupyter notebook with tabnet installed.

+ # What is new?

+ - from version **> 4.0** attention is now embedding-aware. This aims to maintain a good attention mechanism even with a large number of embeddings. It is also now possible to specify attention groups (using `grouped_features`): attention is then shared at the group level rather than the feature level. This is especially useful when a dataset has many columns coming from a single source of data (for example, a text column transformed with TF-IDF); see the sketch below.
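For illustration, here is a minimal sketch of building such a group for TF-IDF output columns. The data, column counts, and indices are invented for the example and are not part of this commit:

```python
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer

# Invented example: one text column alongside 5 ordinary tabular columns.
docs = ["credit card payment", "mortgage loan", "card fraud alert", "loan payment"]
tabular = np.random.rand(len(docs), 5)

# TF-IDF expands the single text column into many correlated columns.
tfidf = TfidfVectorizer().fit_transform(docs).toarray()
X = np.hstack([tabular, tfidf])  # columns 0-4 tabular, 5 onwards from TF-IDF

# All TF-IDF columns form one attention group.
grouped_features = [list(range(5, 5 + tfidf.shape[1]))]
```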

# Contributing

When contributing to the TabNet repository, please make sure to first discuss the change you wish to make via a new or already existing issue.
@@ -316,6 +320,12 @@ loaded_clf.load_model(saved_filepath)
- `mask_type: str` (default='sparsemax')
Either "sparsemax" or "entmax" : this is the masking function to use for selecting features.

+ - `grouped_features: list of list of ints` (default=None)
+ This allows the model to share its attention across features inside the same group.
+ This can be especially useful when your preprocessing generates correlated or dependent features: for example, if you use TF-IDF or PCA on a text column.
+ Note that feature importance will be exactly the same for features within the same group.
+ Please also note that the embeddings generated for a categorical variable are always placed in the same group.
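Below is a minimal usage sketch of this parameter (the data and group indices are made up for illustration; this snippet is not part of the commit):

```python
import numpy as np
from pytorch_tabnet.tab_model import TabNetClassifier

# Synthetic data: 100 rows, 10 features. Suppose columns 0-2 and 8-9 each come
# from a single preprocessing step, so each set should share its attention.
X_train = np.random.rand(100, 10).astype(np.float32)
y_train = np.random.randint(0, 2, 100)

clf = TabNetClassifier(grouped_features=[[0, 1, 2], [8, 9]])
clf.fit(X_train, y_train, max_epochs=5, batch_size=32)

# Features inside the same group receive identical importance scores.
print(clf.feature_importances_)
```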

- `n_shared_decoder` : int (default=1)

Number of shared GLU blocks in the decoder; this is only useful for `TabNetPretrainer`.
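For context, a minimal pretraining sketch (synthetic data and illustrative parameter values; not part of this commit):

```python
import numpy as np
from pytorch_tabnet.pretraining import TabNetPretrainer

X_train = np.random.rand(256, 10).astype(np.float32)  # synthetic data

pretrainer = TabNetPretrainer(
    n_shared_decoder=1,  # number of shared GLU blocks in the decoder
    n_indep_decoder=1,   # number of independent GLU blocks in the decoder
)
pretrainer.fit(X_train, pretraining_ratio=0.8, max_epochs=5, batch_size=64)
```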
69 changes: 64 additions & 5 deletions census_example.ipynb
@@ -25,6 +25,26 @@
"%matplotlib inline"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "os.environ['CUDA_VISIBLE_DEVICES'] = f\"1\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "torch.__version__"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
@@ -145,6 +165,35 @@
"cat_dims = [ categorical_dims[f] for i, f in enumerate(features) if f in categorical_columns]\n"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Grouped features\n",
+ "\n",
+ "You can now specify groups of features which will share a common attention.\n",
+ "\n",
+ "This may be very useful for features coming from the same preprocessing technique, like PCA for example."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "len(features)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "grouped_features = [[0, 1, 2], [8, 9, 10]]"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
@@ -160,13 +209,14 @@
"source": [
"tabnet_params = {\"cat_idxs\":cat_idxs,\n",
" \"cat_dims\":cat_dims,\n",
- " \"cat_emb_dim\":1,\n",
+ " \"cat_emb_dim\":2,\n",
" \"optimizer_fn\":torch.optim.Adam,\n",
" \"optimizer_params\":dict(lr=2e-2),\n",
" \"scheduler_params\":{\"step_size\":50, # how to use learning rate scheduler\n",
" \"gamma\":0.9},\n",
" \"scheduler_fn\":torch.optim.lr_scheduler.StepLR,\n",
- " \"mask_type\":'entmax' # \"sparsemax\"\n",
+ " \"mask_type\":'entmax', # \"sparsemax\"\n",
+ " \"grouped_features\" : grouped_features\n",
" }\n",
"\n",
"clf = TabNetClassifier(**tabnet_params\n",
@@ -202,7 +252,7 @@
"metadata": {},
"outputs": [],
"source": [
- "max_epochs = 100 if not os.getenv(\"CI\", False) else 2"
+ "max_epochs = 50 if not os.getenv(\"CI\", False) else 2"
]
},
{
@@ -416,9 +466,11 @@
"source": [
"fig, axs = plt.subplots(1, 3, figsize=(20,20))\n",
"\n",
+ "\n",
"for i in range(3):\n",
" axs[i].imshow(masks[i][:50])\n",
- " axs[i].set_title(f\"mask {i}\")\n"
+ " axs[i].set_title(f\"mask {i}\")\n",
+ " axs[i].set_xticklabels(labels = features, rotation=45)"
]
},
{
@@ -481,6 +533,13 @@
"test_auc = roc_auc_score(y_score=preds[:,1], y_true=y_test)\n",
"print(test_auc)"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
}
],
"metadata": {
@@ -499,7 +558,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.7.5"
+ "version": "3.7.13"
},
"toc": {
"base_numbering": 1,
Expand Down
6 changes: 4 additions & 2 deletions forest_example.ipynb
@@ -166,6 +166,8 @@
"metadata": {},
"outputs": [],
"source": [
+ "# This is a generic pipeline but actually no categorical features are available for this dataset\n",
+ "\n",
"unused_feat = []\n",
"\n",
"features = [ col for col in train.columns if col not in unused_feat+[target]] \n",
@@ -237,7 +239,7 @@
"metadata": {},
"outputs": [],
"source": [
- "max_epochs = 50 if not os.getenv(\"CI\", False) else 2"
+ "max_epochs = 100 if not os.getenv(\"CI\", False) else 2"
]
},
{
@@ -513,7 +515,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.7.5"
+ "version": "3.7.13"
},
"toc": {
"base_numbering": 1,
2 changes: 1 addition & 1 deletion multi_task_example.ipynb
@@ -457,7 +457,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.7.5"
+ "version": "3.7.13"
},
"toc": {
"base_numbering": 1,
15 changes: 9 additions & 6 deletions pretraining_example.ipynb
@@ -191,6 +191,8 @@
" mask_type='entmax', # \"sparsemax\",\n",
" n_shared_decoder=1, # nb shared glu for decoding\n",
" n_indep_decoder=1, # nb independent glu for decoding\n",
+ "# grouped_features=[[0, 1]], # you can group features together here\n",
+ " verbose=5,\n",
")"
]
},
@@ -207,7 +209,7 @@
"metadata": {},
"outputs": [],
"source": [
- "max_epochs = 1000 if not os.getenv(\"CI\", False) else 2"
+ "max_epochs = 100 if not os.getenv(\"CI\", False) else 2 # 1000"
]
},
{
@@ -225,7 +227,7 @@
" batch_size=2048, virtual_batch_size=128,\n",
" num_workers=0,\n",
" drop_last=False,\n",
- " pretraining_ratio=0.8,\n",
+ " pretraining_ratio=0.5,\n",
") "
]
},
@@ -296,11 +298,12 @@
"outputs": [],
"source": [
"clf = TabNetClassifier(optimizer_fn=torch.optim.Adam,\n",
- " optimizer_params=dict(lr=2e-2),\n",
+ " optimizer_params=dict(lr=2e-3),\n",
" scheduler_params={\"step_size\":10, # how to use learning rate scheduler\n",
" \"gamma\":0.9},\n",
" scheduler_fn=torch.optim.lr_scheduler.StepLR,\n",
- " mask_type='sparsemax' # This will be overwritten if using pretrain model\n",
+ " mask_type='sparsemax', # This will be overwritten if using pretrain model\n",
+ " verbose=5,\n",
" )"
]
},
@@ -322,7 +325,7 @@
" num_workers=0,\n",
" weights=1,\n",
" drop_last=False,\n",
- " from_unsupervised=loaded_pretrain\n",
+ " from_unsupervised=loaded_pretrain,\n",
" \n",
") "
]
@@ -504,7 +507,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.7.6"
+ "version": "3.7.13"
},
"toc": {
"base_numbering": 1,