feat: TabNetMultiTaskClassifier
Optimox authored and eduardocarvp committed Sep 15, 2020
1 parent 13c2d7a commit 5764a43
Showing 9 changed files with 899 additions and 31 deletions.
17 changes: 17 additions & 0 deletions .circleci/config.yml
@@ -151,6 +151,20 @@ jobs:
          shell: bash -leo pipefail
          command: |
            make test-nb-regression
  test-nb-multi-task:
    executor: python-executor
    steps:
      - checkout
      # Download and cache dependencies
      - restore_cache:
          keys:
            - v1-dependencies-{{ checksum "poetry.lock" }}
      - install_poetry
      - run:
          name: run test-nb-multi-task
          shell: bash -leo pipefail
          command: |
            make test-nb-multi-task
workflows:
  version: 2
@@ -171,6 +185,9 @@ workflows:
      - test-nb-forest:
          requires:
            - install
      - test-nb-multi-task:
          requires:
            - install
      - lint-code:
          requires:
            - install
4 changes: 4 additions & 0 deletions Makefile
@@ -89,6 +89,10 @@ test-nb-multi-regression: ## run multi regression example tests using notebooks
	$(MAKE) _run_notebook NB_FILE="./multi_regression_example.ipynb"
.PHONY: test-obfuscator

test-nb-multi-task: ## run multi task classification example tests using notebooks
	$(MAKE) _run_notebook NB_FILE="./multi_task_example.ipynb"
.PHONY: test-obfuscator

help: ## Display help
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
.PHONY: help
29 changes: 26 additions & 3 deletions README.md
@@ -10,7 +10,7 @@ This is a pyTorch implementation of Tabnet (Arik, S. O., & Pfister, T. (2019). T

![PyPI - Downloads](https://img.shields.io/pypi/dm/pytorch-tabnet)

Any questions ? Want to contribute ? To talk with us ? You can join us on [Slack](https://join.slack.com/t/mltooling/shared_invite/zt-e4y14rbp-IReRxmjoNy27XQFQgh~4ZQ)
Any questions ? Want to contribute ? To talk with us ? You can join us on [Slack](https://join.slack.com/t/mltooling/shared_invite/zt-fxaj0qk7-SWy2_~EWyhj4x9SD6gbRvg)

# Installation

@@ -37,6 +37,11 @@ If you wan to use it locally within a docker container:

- `make notebook` inside the same terminal. You can then follow the link to a jupyter notebook with tabnet installed.

# What problems does pytorch-tabnet handle?

- TabNetClassifier : binary classification and multi-class classification problems
- TabNetRegressor : simple and multi-task regression problems
- TabNetMultiTaskClassifier: multi-task multi-classification problems

# How to use it?

@@ -50,7 +55,23 @@ clf.fit(X_train, Y_train, X_valid, y_valid)
preds = clf.predict(X_test)
```

You can also get comfortable with how the code works by playing with the **notebooks tutorials** for adult census income dataset and forest cover type dataset.
or for TabNetMultiTaskClassifier :

```
from pytorch_tabnet.multitask import TabNetMultiTaskClassifier
clf = TabNetMultiTaskClassifier()
clf.fit(X_train, Y_train, X_valid, y_valid)
preds = clf.predict(X_test)
```

# Useful links

- explanatory video : https://youtu.be/ysBaZO8YmX8
- binary classification examples : https://github.com/dreamquark-ai/tabnet/blob/develop/census_example.ipynb
- multi-class classification examples : https://github.com/dreamquark-ai/tabnet/blob/develop/forest_example.ipynb
- regression examples : https://github.com/dreamquark-ai/tabnet/blob/develop/regression_example.ipynb
- multi-task regression examples : https://github.com/dreamquark-ai/tabnet/blob/develop/multi_regression_example.ipynb
- multi-task multi-class classification examples : https://www.kaggle.com/optimo/tabnetmultitaskclassifier

## Model parameters

@@ -176,9 +197,11 @@ You can also get comfortable with how the code works by playing with the **noteb
1 : automated sampling with inverse class occurences
dict : keys are classes, values are weights for each class

- loss_fn : torch.loss
- loss_fn : torch.loss or list of torch.loss

Loss function for training (default to mse for regression and cross entropy for classification)
When using TabNetMultiTaskClassifier you can set a list of same length as number of tasks,
each task will be assigned its own loss function

- batch_size : int (default=1024)
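
The per-task `loss_fn` list described above can be pictured with a small plain-Python sketch. This is not the library's code: `cross_entropy`, `scaled` and `multitask_loss` are hypothetical helpers written for illustration, and averaging the per-task losses is an assumption rather than something the README states. The sketch only shows the pairing rule: a single loss is shared by every task, while a list must contain exactly one loss per task.

```python
import math


def cross_entropy(logits, target):
    """Cross entropy for one sample: -log(softmax(logits)[target])."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def scaled(loss_fn, weight):
    """Hypothetical helper: wrap a loss so its value is multiplied by `weight`."""
    return lambda logits, target: weight * loss_fn(logits, target)


def multitask_loss(task_logits, targets, loss_fns):
    """Combine per-task losses for one sample.

    task_logits: one list of logits per task
    targets:     one integer class label per task
    loss_fns:    a single callable shared by all tasks,
                 or a list holding one callable per task
    """
    if callable(loss_fns):
        loss_fns = [loss_fns] * len(task_logits)
    if len(loss_fns) != len(task_logits):
        raise ValueError("need exactly one loss function per task")
    losses = [fn(z, y) for fn, z, y in zip(loss_fns, task_logits, targets)]
    # Assumption for this sketch: tasks are combined by a plain average.
    return sum(losses) / len(losses)


# Two tasks: a binary one and a three-class one, both with uniform logits.
logits = [[0.0, 0.0], [0.0, 0.0, 0.0]]
targets = [1, 2]

shared = multitask_loss(logits, targets, cross_entropy)
per_task = multitask_loss(
    logits, targets, [cross_entropy, scaled(cross_entropy, 2.0)]
)
print(shared, per_task)
```

With the real estimator, the equivalent would be passing a list of torch loss functions as `loss_fn`, one per column of the multi-task target.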
