
Commit

feat: pretraining matches paper
Optimox committed May 27, 2021
1 parent 8c3b795 commit 5adb804
Showing 4 changed files with 42 additions and 18 deletions.
26 changes: 21 additions & 5 deletions README.md
@@ -155,6 +155,18 @@ A complete example can be found within the notebook `pretraining_example.ipynb`.

/!\ : the current implementation tries to reconstruct the original inputs, but Batch Normalization applies a random transformation that can't be deduced from a single row, making reconstruction harder. Lowering the `batch_size` might make pretraining easier.
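For instance, a minimal sketch of pretraining with a smaller `batch_size` (the dataset and hyper-parameters below are placeholders, not a recommendation):

```
from pytorch_tabnet.pretraining import TabNetPretrainer

# Sketch only: X_train / X_valid are placeholder numpy arrays
unsupervised_model = TabNetPretrainer()
unsupervised_model.fit(
    X_train=X_train,
    eval_set=[X_valid],
    pretraining_ratio=0.8,
    batch_size=256,  # smaller batches can make reconstruction through BatchNorm easier
)
```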

# Easy saving and loading

It's really easy to save and re-load a trained model, which makes TabNet production-ready.
```
# save tabnet model
saving_path_name = "./tabnet_model_test_1"
saved_filepath = clf.save_model(saving_path_name)
# define new model with basic parameters and load state dict weights
loaded_clf = TabNetClassifier()
loaded_clf.load_model(saved_filepath)
```
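As a quick sanity check (a sketch, assuming `numpy` and a held-out `X_test` are available), the reloaded model should produce exactly the same predictions as the original:

```
import numpy as np

# Sketch only: X_test is a placeholder test set
assert np.array_equal(clf.predict_proba(X_test), loaded_clf.predict_proba(X_test))
```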

# Useful links

@@ -251,10 +263,6 @@ A complete example can be found within the notebook `pretraining_example.ipynb`.

Name of the model used for saving on disk; you can customize this to easily retrieve and reuse your trained models.

- `saving_path` : str (default = './')

Path defining where to save models.

- `verbose` : int (default=1)

Verbosity for notebook plots; set to 1 to see every epoch, 0 for no output.
@@ -263,7 +271,15 @@
'cpu' for cpu training, 'gpu' for gpu training, 'auto' to automatically detect gpu.

- `mask_type: str` (default='sparsemax')
Either "sparsemax" or "entmax" : this is the masking function to use for selecting features
Either "sparsemax" or "entmax" : this is the masking function to use for selecting features.

- `n_shared_decoder` : int (default=1)

Number of shared GLU blocks in the decoder; this is only useful for `TabNetPretrainer`.

- `n_indep_decoder` : int (default=1)

Number of independent GLU blocks in the decoder; this is only useful for `TabNetPretrainer` (see the sketch below).
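A minimal sketch of passing these decoder options when building a pretrainer (the values shown are the defaults; the surrounding configuration is illustrative):

```
from pytorch_tabnet.pretraining import TabNetPretrainer

# Sketch only: decoder options shown with their default values
unsupervised_model = TabNetPretrainer(
    mask_type='entmax',    # or 'sparsemax'
    n_shared_decoder=1,    # shared GLU blocks in the decoder
    n_indep_decoder=1,     # independent GLU blocks in the decoder
)
```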

## Fit parameters

7 changes: 5 additions & 2 deletions pretraining_example.ipynb
@@ -178,7 +178,9 @@
" cat_emb_dim=3,\n",
" optimizer_fn=torch.optim.Adam,\n",
" optimizer_params=dict(lr=2e-2),\n",
" mask_type='entmax' # \"sparsemax\"\n",
" mask_type='entmax', # \"sparsemax\",\n",
" n_shared_decoder=1, # nb shared glu for decoding\n",
" n_indep_decoder=1, # nb independent glu for decoding\n",
")"
]
},
@@ -214,6 +216,7 @@
" num_workers=0,\n",
" drop_last=False,\n",
" pretraining_ratio=0.8,\n",
"\n",
") "
]
},
@@ -492,7 +495,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
"version": "3.7.6"
},
"toc": {
"base_numbering": 1,
2 changes: 2 additions & 0 deletions pytorch_tabnet/abstract_model.py
@@ -60,6 +60,8 @@ class TabModel(BaseEstimator):
input_dim: int = None
output_dim: int = None
device_name: str = "auto"
n_shared_decoder: int = 1
n_indep_decoder: int = 1

def __post_init__(self):
self.batch_size = 1024
25 changes: 14 additions & 11 deletions pytorch_tabnet/tab_network.py
@@ -206,8 +206,8 @@ def __init__(
input_dim,
n_d=8,
n_steps=3,
n_independent=2,
n_shared=2,
n_independent=1,
n_shared=1,
virtual_batch_size=128,
momentum=0.02,
):
@@ -228,9 +228,9 @@ def __init__(
gamma : float
Float above 1, scaling factor for attention updates (usually between 1.0 to 2.0)
n_independent : int
Number of independent GLU layer in each GLU block (default 2)
Number of independent GLU layers in each GLU block (default 1)
n_shared : int
Number of independent GLU layer in each GLU block (default 2)
Number of shared GLU layers in each GLU block (default 1)
virtual_batch_size : int
Batch size for Ghost Batch Normalization
momentum : float
@@ -245,7 +245,6 @@
self.virtual_batch_size = virtual_batch_size

self.feat_transformers = torch.nn.ModuleList()
self.reconstruction_layers = torch.nn.ModuleList()

if self.n_shared > 0:
shared_feat_transform = torch.nn.ModuleList()
@@ -268,16 +267,16 @@
momentum=momentum,
)
self.feat_transformers.append(transformer)
reconstruction_layer = Linear(n_d, self.input_dim, bias=False)
initialize_non_glu(reconstruction_layer, n_d, self.input_dim)
self.reconstruction_layers.append(reconstruction_layer)

self.reconstruction_layer = Linear(n_d, self.input_dim, bias=False)
initialize_non_glu(self.reconstruction_layer, n_d, self.input_dim)

def forward(self, steps_output):
res = 0
for step_nb, step_output in enumerate(steps_output):
x = self.feat_transformers[step_nb](step_output)
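For context, a sketch of how the decoder's `forward` plausibly completes after this change (the hunk above is truncated): with the per-step `reconstruction_layers` removed, the transformed step outputs are summed and passed once through the single `reconstruction_layer`.

```
# Sketch of the assumed continuation of the forward pass above
def forward(self, steps_output):
    res = 0
    for step_nb, step_output in enumerate(steps_output):
        x = self.feat_transformers[step_nb](step_output)
        res = torch.add(res, x)
    res = self.reconstruction_layer(res)  # single reconstruction applied to the aggregate
    return res
```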