Skip to content

Commit

Permalink
feat: update readme and notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
queraq authored and eduardocarvp committed Oct 9, 2020
1 parent cc57d62 commit 9cb38d2
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 54 deletions.
50 changes: 44 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ TabNet is now scikit-compatible, training a TabNetClassifier or TabNetRegressor
from pytorch_tabnet.tab_model import TabNetClassifier, TabNetRegressor
clf = TabNetClassifier() #TabNetRegressor()
clf.fit(X_train, Y_train, X_valid, y_valid)
clf.fit(
X_train, Y_train,
eval_set=[(X_valid, y_valid)]
)
preds = clf.predict(X_test)
```

Expand All @@ -60,10 +63,37 @@ or for TabNetMultiTaskClassifier :
```
from pytorch_tabnet.multitask import TabNetMultiTaskClassifier
clf = TabNetMultiTaskClassifier()
clf.fit(X_train, Y_train, X_valid, y_valid)
clf.fit(
X_train, Y_train,
eval_set=[(X_valid, y_valid)]
)
preds = clf.predict(X_test)
```

### Custom early_stopping_metrics

```
from pytorch_tabnet.metrics import Metric
from sklearn.metrics import roc_auc_score
class Gini(Metric):
def __init__(self):
self._name = "gini"
self._maximize = True
def __call__(self, y_true, y_score):
auc = roc_auc_score(y_true, y_score[:, 1])
return max(2*auc - 1, 0.)
clf = TabNetClassifier()
clf.fit(
X_train, Y_train,
eval_set=[(X_valid, y_valid)],
eval_metric=[Gini]
)
```

# Useful links

- explanatory video : https://youtu.be/ysBaZO8YmX8
Expand Down Expand Up @@ -175,13 +205,18 @@ preds = clf.predict(X_test)

Training targets

- X_valid : np.array
- eval_set: list of tuple

Validation features for early stopping
List of evaluation sets, each given as a tuple (X, y).
The last one is used for early stopping.

- y_valid : np.array for early stopping
- eval_name: list of str
List of eval set names.

- eval_metric : list of str
List of evaluation metrics.
The last metric is used for early stopping.

Validation targets
- max_epochs : int (default = 200)

Maximum number of epochs for training.
Expand Down Expand Up @@ -218,3 +253,6 @@ preds = clf.predict(X_test)
- drop_last : bool (default=False)

Whether to drop last batch if not complete during training

- callbacks : list of callback functions
List of custom callbacks
36 changes: 24 additions & 12 deletions census_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -194,14 +194,14 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [],
"source": [
"clf.fit(\n",
" X_train=X_train, y_train=y_train,\n",
" X_valid=X_valid, y_valid=y_valid,\n",
" eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
" eval_name=['train', 'valid'],\n",
" eval_metric=['auc'],\n",
" max_epochs=max_epochs , patience=20,\n",
" batch_size=1024, virtual_batch_size=128,\n",
" num_workers=0,\n",
Expand All @@ -217,8 +217,7 @@
"outputs": [],
"source": [
"# plot losses\n",
"plt.plot(clf.history['train']['loss'])\n",
"plt.plot(clf.history['valid']['loss'])"
"plt.plot(clf.history['loss'])"
]
},
{
Expand All @@ -228,8 +227,8 @@
"outputs": [],
"source": [
"# plot auc\n",
"plt.plot([-x for x in clf.history['train']['metric']])\n",
"plt.plot([-x for x in clf.history['valid']['metric']])"
"plt.plot(clf.history['train_auc'])\n",
"plt.plot(clf.history['valid_auc'])"
]
},
{
Expand All @@ -239,7 +238,7 @@
"outputs": [],
"source": [
"# plot learning rates\n",
"plt.plot([x for x in clf.history['train']['lr']])"
"plt.plot(clf.history['lr'])"
]
},
{
Expand Down Expand Up @@ -421,9 +420,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": ".shap",
"language": "python",
"name": "python3"
"name": ".shap"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -435,7 +434,20 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.6.8"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
Expand Down
33 changes: 23 additions & 10 deletions forest_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@
"metadata": {},
"outputs": [],
"source": [
"max_epochs = 1000 if not os.getenv(\"CI\", False) else 2"
"max_epochs = 5 if not os.getenv(\"CI\", False) else 2"
]
},
{
Expand All @@ -250,7 +250,8 @@
"source": [
"clf.fit(\n",
" X_train=X_train, y_train=y_train,\n",
" X_valid=X_valid, y_valid=y_valid,\n",
" eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
" eval_name=['train', 'valid'],\n",
" max_epochs=max_epochs, patience=100,\n",
" batch_size=16384, virtual_batch_size=256\n",
") "
Expand All @@ -263,8 +264,7 @@
"outputs": [],
"source": [
"# plot losses\n",
"plt.plot(clf.history['train']['loss'])\n",
"plt.plot(clf.history['valid']['loss'])"
"plt.plot(clf.history['loss'])"
]
},
{
Expand All @@ -273,9 +273,9 @@
"metadata": {},
"outputs": [],
"source": [
"# plot accuracies\n",
"plt.plot([-x for x in clf.history['train']['metric']])\n",
"plt.plot([-x for x in clf.history['valid']['metric']])"
"# plot accuracy\n",
"plt.plot(clf.history['train_accuracy'])\n",
"plt.plot(clf.history['valid_accuracy'])"
]
},
{
Expand Down Expand Up @@ -495,9 +495,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": ".shap",
"language": "python",
"name": "python3"
"name": ".shap"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -509,7 +509,20 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
"version": "3.6.8"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
Expand Down
42 changes: 28 additions & 14 deletions multi_regression_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
"\n",
"import os\n",
"import wget\n",
"from pathlib import Path"
"from pathlib import Path\n",
"\n",
"\n",
"%load_ext autoreload\n",
"\n",
"%autoreload 2"
]
},
{
Expand Down Expand Up @@ -188,20 +193,21 @@
"metadata": {},
"outputs": [],
"source": [
"max_epochs = 1000 if not os.getenv(\"CI\", False) else 2"
"max_epochs = 10 if not os.getenv(\"CI\", False) else 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
"scrolled": false
},
"outputs": [],
"source": [
"clf.fit(\n",
" X_train=X_train, y_train=y_train,\n",
" X_valid=X_valid, y_valid=y_valid,\n",
" eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
" eval_name=['train', 'valid'],\n",
" max_epochs=max_epochs,\n",
" patience=50,\n",
" batch_size=1024, virtual_batch_size=128,\n",
Expand All @@ -216,17 +222,12 @@
"metadata": {},
"outputs": [],
"source": [
"# Deprecated : best model is automatically loaded at end of fit\n",
"# clf.load_best_model()\n",
"\n",
"preds = clf.predict(X_test)\n",
"\n",
"y_true = y_test\n",
"\n",
"test_score = mean_squared_error(y_pred=preds, y_true=y_true)\n",
"test_mse = mean_squared_error(y_pred=preds, y_true=y_test)\n",
"\n",
"print(f\"BEST VALID SCORE FOR {dataset_name} : {clf.best_cost}\")\n",
"print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_score}\")"
"print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_mse}\")"
]
},
{
Expand Down Expand Up @@ -296,9 +297,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": ".shap",
"language": "python",
"name": "python3"
"name": ".shap"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -310,7 +311,20 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
"version": "3.6.8"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 9cb38d2

Please sign in to comment.