From a516d3fdfaea7a180fb3527104bdbe4b5d477dd0 Mon Sep 17 00:00:00 2001
From: John
Date: Mon, 1 Jan 2024 13:22:01 -0800
Subject: [PATCH] Fix broken test

---
 language_interpolation/lightning_datamodule.py | 18 +++++++++---------
 language_interpolation/networks.py             | 14 +-------------
 2 files changed, 10 insertions(+), 22 deletions(-)

diff --git a/language_interpolation/lightning_datamodule.py b/language_interpolation/lightning_datamodule.py
index 9e0039d..91a5804 100644
--- a/language_interpolation/lightning_datamodule.py
+++ b/language_interpolation/lightning_datamodule.py
@@ -246,12 +246,11 @@ def test_dataloader(self) -> DataLoader:
         )
 
 
-class TransformerMixin :
-
+class TransformerMixin:
     def collate_fn(self, batch) -> tuple[Tensor, Tensor, list[int]]:
         # TODO: this does not make sense to me
         # The max size includes the output
-        print('something seems wrong here. Fix')
+
         max_size = max(self._max_size, batch[0][0].shape[0])
         this_size = random.randint(1, max_size - 1)
         final_features = torch.stack([sample[0][:this_size] for sample in batch])
@@ -270,7 +269,7 @@ def collate_fn(self, batch) -> tuple[Tensor, Tensor, list[int]]:
         return self.normalize(final_features), final_targets, final_indexes
 
 
-class MambaMixin :
+class MambaMixin:
     def collate_fn(self, batch) -> tuple[Tensor, Tensor, list[int]]:
         # The targets are just the features shifted by 1
         # The max size includes the output
@@ -290,7 +289,6 @@ def collate_fn(self, batch) -> tuple[Tensor, Tensor, list[int]]:
         return final_features, final_targets, final_indexes
 
 
-
 class SequenceDataModule(pl.LightningDataModule):
     def __init__(
         self,
@@ -328,9 +326,9 @@ def __init__(
         """
         super().__init__()
         self._characters_per_feature = characters_per_feature
-        
+
         self._max_features = max_features
-        
+
         self._targets = targets
         self._batch_size = batch_size
         self._num_workers = num_workers
@@ -455,8 +453,10 @@ def test_dataloader(self) -> DataLoader:
             collate_fn=self.collate_fn,
         )
 
-class TransformerDataModule(TransformerMixin, SequenceDataModule) :
+
+class TransformerDataModule(TransformerMixin, SequenceDataModule):
     pass
 
+
 class MambaDataModule(MambaMixin, SequenceDataModule):
-    pass
\ No newline at end of file
+    pass
diff --git a/language_interpolation/networks.py b/language_interpolation/networks.py
index 8c895d6..c6e6c50 100644
--- a/language_interpolation/networks.py
+++ b/language_interpolation/networks.py
@@ -24,7 +24,7 @@
 import time
 from lion_pytorch import Lion
 from high_order_layers_torch.sparse_optimizers import SparseLion
-from language_interpolation.state_space_network import Mamba, ModelArgs
+from language_interpolation.state_space_network import Mamba
 
 logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.DEBUG)
@@ -863,18 +863,6 @@ def select_network(cfg: DictConfig, device: str = None):
         widths, output_sizes = tail_focus.compute_sizes(cfg.net.features)
         logger.info(f"TailFocusNetwork widths {widths} output_sizes {output_sizes}")
     elif cfg.net.model_type == "mamba":
-        model_args = ModelArgs(
-            d_model=cfg.net.d_model,
-            n_layer=cfg.net.n_layer,
-            vocab_size=cfg.net.vocab_size,
-            d_state=cfg.net.d_state,
-            expand=cfg.net.expand,
-            dt_rank=cfg.net.dt_rank,
-            d_conv=cfg.net.d_conv,
-            pad_vocab_size_multiple=cfg.net.pad_vocab_size_multiple,
-            conv_bias=cfg.net.conv_bias,
-            bias=cfg.net.bias,
-        )
         model = Mamba(
             d_model=cfg.net.d_model,
             n_layer=cfg.net.n_layer,