Skip to content

Commit

Permalink
fix: remove deepcopy from shared blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
eduardocarvp authored and Optimox committed Feb 3, 2020
1 parent 5ce5aca commit 123932a
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions pytorch_tabnet/tab_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from torch.nn import Linear, BatchNorm1d, ReLU
import numpy as np
from pytorch_tabnet import sparsemax
from copy import deepcopy


def initialize_non_glu(module, input_dim, output_dim):
Expand Down Expand Up @@ -263,7 +262,7 @@ def __init__(self, input_dim, output_dim, shared_blocks, n_glu,
Float value between 0 and 1 which will be used for momentum in batch norm
"""

self.shared = deepcopy(shared_blocks)
self.shared = shared_blocks
if self.shared is not None:
for l in self.shared.glu_layers:
l.bn = GBN(2*output_dim, virtual_batch_size=virtual_batch_size,
Expand Down

0 comments on commit 123932a

Please sign in to comment.