fix: allow smaller different nshared and nindependent
Optimox authored and eduardocarvp committed Feb 28, 2020
1 parent ae49f69 commit 4b365a7
Showing 2 changed files with 17 additions and 45 deletions.
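In practice the fix means the shared and step-specific feature-transformer depths can now differ, including a smaller shared depth. A minimal usage sketch, assuming only the constructor parameters visible in the diff below (the values themselves are illustrative):

from pytorch_tabnet.tab_model import TabNetClassifier

# n_shared and n_independent no longer need to be equal; a smaller shared
# depth such as n_shared=1 with n_independent=2 is now accepted.
clf = TabNetClassifier(n_d=8, n_a=8, n_steps=3, n_shared=1, n_independent=2)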
49 changes: 9 additions & 40 deletions pytorch_tabnet/tab_model.py
@@ -9,11 +9,12 @@
from pytorch_tabnet.utils import (PredictDataset,
create_dataloaders,
create_explain_matrix)
from sklearn.base import BaseEstimator
from torch.utils.data import DataLoader
from datetime import datetime


class TabModel(object):
class TabModel(BaseEstimator):
def __init__(self, n_d=8, n_a=8, n_steps=3, gamma=1.3, cat_idxs=[], cat_dims=[], cat_emb_dim=1,
n_independent=2, n_shared=2, epsilon=1e-15, momentum=0.02,
lambda_sparse=1e-3, seed=0,
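Switching the base class from object to sklearn's BaseEstimator is what makes the hand-written __repr__ methods removed further down redundant: sklearn derives get_params(), set_params() and the printed representation from the __init__ signature. A minimal sketch with a hypothetical estimator, not code from this repository:

from sklearn.base import BaseEstimator

class TinyModel(BaseEstimator):
    # BaseEstimator introspects the __init__ keyword arguments, so each
    # parameter only needs to be stored under the same attribute name.
    def __init__(self, n_d=8, n_a=8, n_steps=3):
        self.n_d = n_d
        self.n_a = n_a
        self.n_steps = n_steps

model = TinyModel(n_d=16)
print(model.get_params())  # {'n_a': 8, 'n_d': 16, 'n_steps': 3}
print(model)               # repr built by BaseEstimator, e.g. TinyModel(n_d=16)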
@@ -48,13 +49,10 @@ def __init__(self, n_d=8, n_a=8, n_steps=3, gamma=1.3, cat_idxs=[], cat_dims=[],
self.device_name = device_name
self.saving_path = saving_path
self.model_name = model_name

self.scheduler_params = scheduler_params
self.scheduler_fn = scheduler_fn

self.opt_params = {}
self.opt_params['lr'] = self.lr

self.seed = seed
torch.manual_seed(self.seed)
# Defining device
@@ -139,7 +137,7 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None,
self.network.post_embed_dim)

self.optimizer = self.optimizer_fn(self.network.parameters(),
**self.opt_params)
lr=self.lr)

if self.scheduler_fn:
self.scheduler = self.scheduler_fn(self.optimizer, **self.scheduler_params)
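With the opt_params dict gone, the optimizer is built directly from the stored learning rate. A short sketch of the equivalent call, assuming torch.optim.Adam as optimizer_fn and a toy stand-in network:

import torch

network = torch.nn.Linear(10, 2)        # stand-in for self.network
optimizer_fn = torch.optim.Adam
lr = 2e-2

optimizer = optimizer_fn(network.parameters(), lr=lr)  # replaces **self.opt_params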
@@ -283,8 +281,11 @@ def predict_batch(self, data, targets):
"""
raise NotImplementedError('users must define predict_batch to use this base class')

def load_best_model(self):
self.network = torch.load(self.saving_path+f"{self.model_name}.pt")
def load_best_model(self, saving_path=None, model_name=None):
if saving_path is None:
saving_path = self.saving_path
model_name = self.model_name
self.network = torch.load(saving_path+f"{self.model_name}.pt")

@abstractmethod
def predict(self, X):
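load_best_model now takes an optional location; with no arguments it falls back to the path and name stored on the instance. Continuing with the clf instance from the earlier sketch (the directory below is illustrative):

# Reload the checkpoint saved during training.
clf.load_best_model()                    # uses self.saving_path and self.model_name

# Point at a different directory; per the diff, the file name itself still
# comes from self.model_name, so only the directory changes here.
clf.load_best_model(saving_path="./saved_models/")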
@@ -351,22 +352,6 @@ def explain(self, X):

class TabNetClassifier(TabModel):

def __repr__(self):
repr_ = f"""TabNetClassifier(n_d={self.n_d}, n_a={self.n_a}, n_steps={self.n_steps},
lr={self.lr}, seed={self.seed},
gamma={self.gamma}, n_independent={self.n_independent}, n_shared={self.n_shared},
cat_idxs={self.cat_idxs},
cat_dims={self.cat_dims},
cat_emb_dim={self.cat_emb_dim},
lambda_sparse={self.lambda_sparse}, momentum={self.momentum},
clip_value={self.clip_value},
verbose={self.verbose}, device_name="{self.device_name}",
model_name="{self.model_name}", epsilon={self.epsilon},
optimizer_fn={str(self.optimizer_fn)},
scheduler_params={self.scheduler_params},
scheduler_fn={self.scheduler_fn}, saving_path="{self.saving_path}")"""
return repr_

def infer_output_dim(self, y_train, y_valid):
"""
Infer output_dim from targets
@@ -684,22 +669,6 @@ def predict_proba(self, X):

class TabNetRegressor(TabModel):

def __repr__(self):
repr_ = f"""TabNetRegressor(n_d={self.n_d}, n_a={self.n_a}, n_steps={self.n_steps},
lr={self.lr}, seed={self.seed},
gamma={self.gamma}, n_independent={self.n_independent}, n_shared={self.n_shared},
cat_idxs={self.cat_idxs},
cat_dims={self.cat_dims},
cat_emb_dim={self.cat_emb_dim},
lambda_sparse={self.lambda_sparse}, momentum={self.momentum},
clip_value={self.clip_value},
verbose={self.verbose}, device_name="{self.device_name}",
model_name="{self.model_name}",
optimizer_fn={str(self.optimizer_fn)},
scheduler_params={self.scheduler_params}, scheduler_fn={self.scheduler_fn},
epsilon={self.epsilon}, saving_path="{self.saving_path}")"""
return repr_

def construct_loaders(self, X_train, y_train, X_valid, y_valid, weights, batch_size):
"""
Returns
13 changes: 8 additions & 5 deletions pytorch_tabnet/tab_network.py
@@ -101,7 +101,7 @@ def __init__(self, input_dim, output_dim,
shared_feat_transform = None

self.initial_splitter = FeatTransformer(self.input_dim, n_d+n_a, shared_feat_transform,
n_glu=self.n_independent,
n_glu_independent=self.n_independent,
virtual_batch_size=self.virtual_batch_size,
momentum=momentum)

@@ -110,7 +110,7 @@ def __init__(self, input_dim, output_dim,

for step in range(n_steps):
transformer = FeatTransformer(self.input_dim, n_d+n_a, shared_feat_transform,
n_glu=self.n_independent,
n_glu_independent=self.n_independent,
virtual_batch_size=self.virtual_batch_size,
momentum=momentum)
attention = AttentiveTransformer(n_a, self.input_dim,
@@ -296,7 +296,7 @@ def forward(self, priors, processed_feat):


class FeatTransformer(torch.nn.Module):
def __init__(self, input_dim, output_dim, shared_layers, n_glu,
def __init__(self, input_dim, output_dim, shared_layers, n_glu_independent,
virtual_batch_size=128, momentum=0.02):
super(FeatTransformer, self).__init__()
"""
@@ -308,14 +308,15 @@ def __init__(self, input_dim, output_dim, shared_layers, n_glu,
Input size
- output_dim : int
Outpu_size
- n_glu_independant
- shared_blocks : torch.nn.ModuleList
The shared block that should be common to every step
- momentum : float
Float value between 0 and 1 which will be used for momentum in batch norm
"""

params = {
'n_glu': n_glu,
'n_glu': n_glu_independent,
'virtual_batch_size': virtual_batch_size,
'momentum': momentum
}
@@ -329,7 +330,9 @@ def __init__(self, input_dim, output_dim, shared_layers, n_glu,
self.shared = GLU_Block(input_dim, output_dim,
first=True,
shared_layers=shared_layers,
**params)
n_glu=len(shared_layers),
virtual_batch_size=virtual_batch_size,
momentum=momentum)
self.specifics = GLU_Block(output_dim, output_dim,
**params)

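The core of the fix is visible in this last hunk: the shared GLU block is now sized by the number of shared layers actually passed in (len(shared_layers), i.e. n_shared), while only the step-specific block keeps n_glu_independent. A simplified, hypothetical sketch of the decoupling (layer widths and names are illustrative, not the repository's exact code):

import torch

n_d, n_a = 8, 8
n_shared, n_independent = 1, 3      # differing values are exactly what this fix allows
input_dim = 16

if n_shared > 0:
    # one Linear per shared GLU layer, reused by every step's FeatTransformer
    shared_feat_transform = torch.nn.ModuleList(
        torch.nn.Linear(input_dim, 2 * (n_d + n_a)) for _ in range(n_shared)
    )
else:
    shared_feat_transform = None

# Before this commit both blocks were sized from the same n_glu value; now:
n_glu_shared = len(shared_feat_transform) if shared_feat_transform else 0  # follows n_shared
n_glu_specific = n_independent                                             # follows n_independent
print(n_glu_shared, n_glu_specific)  # -> 1 3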
