Skip to content

Commit

Permalink
Fix broken test
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed Jan 1, 2024
1 parent f9b2fc8 commit a516d3f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 22 deletions.
18 changes: 9 additions & 9 deletions language_interpolation/lightning_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,12 +246,11 @@ def test_dataloader(self) -> DataLoader:
)


class TransformerMixin :

class TransformerMixin:
def collate_fn(self, batch) -> tuple[Tensor, Tensor, list[int]]:
# TODO: this does not make sense to me
# The max size includes the output
print('something seems wrong here. Fix')

max_size = max(self._max_size, batch[0][0].shape[0])
this_size = random.randint(1, max_size - 1)
final_features = torch.stack([sample[0][:this_size] for sample in batch])
Expand All @@ -270,7 +269,7 @@ def collate_fn(self, batch) -> tuple[Tensor, Tensor, list[int]]:
return self.normalize(final_features), final_targets, final_indexes


class MambaMixin :
class MambaMixin:
def collate_fn(self, batch) -> tuple[Tensor, Tensor, list[int]]:
# The targets are just the features shifted by 1
# The max size includes the output
Expand All @@ -290,7 +289,6 @@ def collate_fn(self, batch) -> tuple[Tensor, Tensor, list[int]]:
return final_features, final_targets, final_indexes



class SequenceDataModule(pl.LightningDataModule):
def __init__(
self,
Expand Down Expand Up @@ -328,9 +326,9 @@ def __init__(
"""
super().__init__()
self._characters_per_feature = characters_per_feature

self._max_features = max_features

self._targets = targets
self._batch_size = batch_size
self._num_workers = num_workers
Expand Down Expand Up @@ -455,8 +453,10 @@ def test_dataloader(self) -> DataLoader:
collate_fn=self.collate_fn,
)

class TransformerDataModule(TransformerMixin, SequenceDataModule) :

class TransformerDataModule(TransformerMixin, SequenceDataModule):
    """Sequence data module batched for transformer-style training.

    ``TransformerMixin`` is listed first so its ``collate_fn`` (which draws a
    random truncation length and returns ``(features, targets, indexes)``)
    takes MRO precedence over anything provided by ``SequenceDataModule``.
    All dataset construction and dataloader behavior comes from
    ``SequenceDataModule``.
    """

    pass


class MambaDataModule(MambaMixin, SequenceDataModule):
    """Sequence data module batched for Mamba (state-space) training.

    ``MambaMixin`` is listed first so its ``collate_fn`` (targets are the
    features shifted by one position) takes MRO precedence over anything
    provided by ``SequenceDataModule``.  All dataset construction and
    dataloader behavior comes from ``SequenceDataModule``.
    """

    # Fix: the original body contained a duplicated ``pass`` statement
    # (apparent merge/diff artifact); one is sufficient.
    pass
14 changes: 1 addition & 13 deletions language_interpolation/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import time
from lion_pytorch import Lion
from high_order_layers_torch.sparse_optimizers import SparseLion
from language_interpolation.state_space_network import Mamba, ModelArgs
from language_interpolation.state_space_network import Mamba

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)
Expand Down Expand Up @@ -863,18 +863,6 @@ def select_network(cfg: DictConfig, device: str = None):
widths, output_sizes = tail_focus.compute_sizes(cfg.net.features)
logger.info(f"TailFocusNetwork widths {widths} output_sizes {output_sizes}")
elif cfg.net.model_type == "mamba":
model_args = ModelArgs(
d_model=cfg.net.d_model,
n_layer=cfg.net.n_layer,
vocab_size=cfg.net.vocab_size,
d_state=cfg.net.d_state,
expand=cfg.net.expand,
dt_rank=cfg.net.dt_rank,
d_conv=cfg.net.d_conv,
pad_vocab_size_multiple=cfg.net.pad_vocab_size_multiple,
conv_bias=cfg.net.conv_bias,
bias=cfg.net.bias,
)
model = Mamba(
d_model=cfg.net.d_model,
n_layer=cfg.net.n_layer,
Expand Down

0 comments on commit a516d3f

Please sign in to comment.