Training with mamba seems to be working, need to handle generation properly now
jloveric committed Dec 30, 2023
1 parent e87f3e6 commit 990b88f
Showing 5 changed files with 63 additions and 15 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -96,8 +96,8 @@ python examples/high_order_interpolation.py data.type=sequence net=conv max_epoc
### mamba
Work in progress
```
python examples/high_order_interpolation.py data.type=sequence net=mamba
```
python examples/high_order_interpolation.py data.type=sequence net=mamba optimizer.lr=1e-4 data.max_features=16 batch_size=1024
```
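As a rough illustration of how those overrides reach the script (a sketch only; the config path and config name below are assumptions, while the keys `optimizer.lr`, `data.max_features`, and `batch_size` come from the command above), Hydra parses each `key=value` argument into the `cfg` object before training starts:
```
import hydra
from omegaconf import DictConfig, OmegaConf

# Hypothetical minimal entry point: each key=value argument on the command line
# (e.g. optimizer.lr=1e-4 data.max_features=16 batch_size=1024) overrides the
# corresponding field of the composed config before main() runs.
@hydra.main(config_path="config", config_name="language", version_base=None)
def main(cfg: DictConfig) -> None:
    print(OmegaConf.to_yaml(cfg))  # cfg.optimizer.lr is now 1e-4, cfg.batch_size is 1024, ...
    print(cfg.net.model_type)      # "mamba" when net=mamba is selected


if __name__ == "__main__":
    main()
```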

### tail focus network
Using the tail focus network you can handle much longer sequences; however, the accuracy needs to be much higher to avoid getting garbage (random ASCII characters that don't resemble any language) for a given input
33 changes: 28 additions & 5 deletions examples/high_order_interpolation.py
@@ -1,5 +1,5 @@
import torch
from language_interpolation.networks import ASCIIPredictionNet
from language_interpolation.networks import ASCIIPredictionNet, MambaASCIIPredictionNet
import os
from omegaconf import DictConfig, OmegaConf
import hydra
@@ -18,7 +18,7 @@
TextGenerationSampler,
create_gutenberg_cache,
)
from language_interpolation.lightning_datamodule import TransformerDataModule
from language_interpolation.lightning_datamodule import TransformerDataModule, MambaDataModule

import logging
import traceback
@@ -52,7 +52,6 @@ def run_language_interpolation(cfg: DictConfig):
if cfg.net.model_type in [
"high_order_transformer",
"high_order_input_transformer",
"mamba"
]:
# dataset_generator is only one type so using the default
datamodule = TransformerDataModule(
@@ -73,7 +72,28 @@
test_filenames=cfg.data.test.filenames,
max_size=cfg.data.max_size,
repeats=cfg.data.repeats,
as_index = True if cfg.net.model_type=="mamba" else False
as_index = False
)
elif cfg.net.model_type in ["mamba"] :
datamodule = MambaDataModule(
characters_per_feature=cfg.data.characters_per_feature,
max_features=cfg.data.max_features,
batch_size=cfg.batch_size,
targets=1,
num_workers=cfg.data.num_workers,
pre_process_workers=cfg.data.pre_process_workers,
gutenberg_ids_train=cfg.data.train.gutenberg_ids,
gutenberg_ids_val=cfg.data.val.gutenberg_ids,
gutenberg_ids_test=cfg.data.test.gutenberg_ids,
gutenberg_range_train=cfg.data.train.gutenberg_range,
gutenberg_range_val=cfg.data.val.gutenberg_range,
gutenberg_range_test=cfg.data.test.gutenberg_range,
train_filenames=cfg.data.train.filenames,
val_filenames=cfg.data.val.filenames,
test_filenames=cfg.data.test.filenames,
max_size=cfg.data.max_size,
repeats=cfg.data.repeats,
as_index = True
)
else:
datamodule = GutenbergDataModule(
@@ -112,7 +132,10 @@
accumulate_grad_batches=cfg.accumulate_grad_batches,
)

model = ASCIIPredictionNet(cfg)
if cfg.net.model_type == "mamba" :
model = MambaASCIIPredictionNet(cfg)
else :
model = ASCIIPredictionNet(cfg)
trainer.fit(model, datamodule=datamodule)
logger.info("testing")

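In short, this file now routes the "mamba" model type to its own datamodule and prediction net instead of reusing the transformer path. A condensed sketch of the model dispatch introduced above (imports taken from this diff; everything else in the function is omitted for brevity):
```
from omegaconf import DictConfig

from language_interpolation.networks import ASCIIPredictionNet, MambaASCIIPredictionNet


def select_model(cfg: DictConfig):
    # Mamba gets the new sequence-level prediction net; every other
    # model type keeps the original ASCII prediction net.
    if cfg.net.model_type == "mamba":
        return MambaASCIIPredictionNet(cfg)
    return ASCIIPredictionNet(cfg)
```
The datamodule choice mirrors this split: `MambaDataModule` with `as_index=True` (raw character indices, full-sequence targets) for mamba, `TransformerDataModule` with `as_index=False` otherwise.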
12 changes: 7 additions & 5 deletions language_interpolation/lightning_datamodule.py
@@ -251,6 +251,7 @@ class TransformerMixin :
def collate_fn(self, batch) -> tuple[Tensor, Tensor, list[int]]:
# TODO: this does not make sense to me
# The max size includes the output
print('something seems wrong here. Fix')
max_size = max(self._max_size, batch[0][0].shape[0])
this_size = random.randint(1, max_size - 1)
final_features = torch.stack([sample[0][:this_size] for sample in batch])
@@ -271,13 +272,12 @@ def collate_fn(self, batch) -> tuple[Tensor, Tensor, list[int]]:

class MambaMixin :
def collate_fn(self, batch) -> tuple[Tensor, Tensor, list[int]]:
# The targets are just the features shifted by 1
# The max size includes the output
max_size = max(self._max_size, batch[0][0].shape[0])
this_size = random.randint(1, max_size - 1)
final_features = torch.stack([sample[0][:this_size] for sample in batch])
final_features = torch.stack([sample[0][:-1] for sample in batch])

# grab the first letter of the next token
final_targets = torch.stack([sample[0][this_size][0] for sample in batch])
final_targets = torch.stack([sample[0][1:] for sample in batch])

final_indexes = [sample[1] for sample in batch]
if self._as_index is True:
@@ -287,7 +287,7 @@ def collate_fn(self, batch) -> tuple[Tensor, Tensor, list[int]]:
final_indexes,
)

return self.normalize(final_features), final_targets, final_indexes
return final_features, final_targets, final_indexes



@@ -328,7 +328,9 @@ def __init__(
"""
super().__init__()
self._characters_per_feature = characters_per_feature

self._max_features = max_features

self._targets = targets
self._batch_size = batch_size
self._num_workers = num_workers
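The substantive change here is in `MambaMixin.collate_fn`: instead of sampling a random prefix length and predicting only the single next character, each batch now pairs the full sequence with itself shifted left by one, and the raw index tensors are returned without normalization. A standalone sketch of that shift-by-one pairing, using dummy token ids rather than the repo's actual sample format:
```
import torch
from torch import Tensor


def shift_by_one_collate(batch: list[tuple[Tensor, int]]) -> tuple[Tensor, Tensor, list[int]]:
    # Each sample is (token_indices, source_index); a stand-in for the real dataset items.
    features = torch.stack([sample[0][:-1] for sample in batch])  # every token except the last
    targets = torch.stack([sample[0][1:] for sample in batch])    # the same tokens shifted left by one
    indexes = [sample[1] for sample in batch]
    return features, targets, indexes


# Tiny usage example with fake character indices.
batch = [(torch.tensor([5, 6, 7, 8]), 0), (torch.tensor([9, 10, 11, 12]), 1)]
x, y, idx = shift_by_one_collate(batch)
print(x)  # tensor([[ 5,  6,  7], [ 9, 10, 11]])
print(y)  # tensor([[ 6,  7,  8], [10, 11, 12]])
```
Every position therefore contributes a next-character prediction, which is what the reshaped loss in `MambaClassificationMixin` (next file) expects.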
25 changes: 24 additions & 1 deletion language_interpolation/networks.py
@@ -121,7 +121,7 @@ class ClassificationMixin:
def eval_step(self, batch: Tensor, name: str):
x, y, idx = batch
y_hat = self(x)
#print('y_hat.shape',y_hat.shape, 'y shape', y.shape)

loss = self.loss(y_hat, y.flatten())

diff = torch.argmax(y_hat, dim=1) - y.flatten()
@@ -131,6 +131,18 @@ def eval_step(self, batch: Tensor, name: str):
self.log(f"{name}_acc", accuracy, prog_bar=True)
return loss

class MambaClassificationMixin:
def eval_step(self, batch: Tensor, name: str):
x, y, idx = batch
y_hat = self(x)
loss = self.loss(y_hat.reshape(y.shape[0]*y.shape[1],-1), y.flatten())

diff = torch.argmax(y_hat, dim=2, keepdim=True) - y
accuracy = torch.where(diff == 0, 1, 0).sum() / torch.numel(diff)

self.log(f"{name}_loss", loss, prog_bar=True)
self.log(f"{name}_acc", accuracy, prog_bar=True)
return loss

class RegressionMixin:
def eval_step(self, batch: Tensor, name: str):
@@ -895,6 +907,17 @@ def __init__(self, cfg: DictConfig):
self.loss = torch.nn.CrossEntropyLoss()
self.accuracy = Accuracy(top_k=1, task="multiclass", num_classes=128)

class MambaASCIIPredictionNet(MambaClassificationMixin, PredictionNetMixin, LightningModule):
def __init__(self, cfg: DictConfig):
super().__init__()
self.save_hyperparameters(cfg)
self.cfg = cfg

self.model = select_network(cfg)

self.loss = torch.nn.CrossEntropyLoss()
self.accuracy = Accuracy(top_k=1, task="multiclass", num_classes=128)


class RegressionNet(RegressionMixin, PredictionNetMixin, LightningModule):
def __init__(self, cfg: DictConfig):
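The new `MambaClassificationMixin` has to score a whole sequence of predictions rather than a single next character, so the `(batch, seq, classes)` logits are flattened to `(batch * seq, classes)` before `CrossEntropyLoss`, and accuracy is computed per position. A quick shape check with dummy tensors (the sizes are arbitrary; 128 classes matches the ASCII vocabulary used in these modules):
```
import torch

batch, seq_len, num_classes = 4, 16, 128
y_hat = torch.randn(batch, seq_len, num_classes)        # one class distribution per position
y = torch.randint(0, num_classes, (batch, seq_len, 1))  # integer target per position

loss_fn = torch.nn.CrossEntropyLoss()
# Collapse (batch, seq) so every position is scored as an independent classification.
loss = loss_fn(y_hat.reshape(batch * seq_len, num_classes), y.flatten())

# Per-position accuracy: argmax over the class dimension, compared element-wise with the targets.
diff = torch.argmax(y_hat, dim=2, keepdim=True) - y
accuracy = torch.where(diff == 0, 1, 0).sum() / torch.numel(diff)
print(loss.item(), accuracy.item())
```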
4 changes: 2 additions & 2 deletions language_interpolation/state_space_network.py
@@ -94,9 +94,9 @@ class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ss

x = self.norm_f(x)
logits = self.lm_head(x)
reduced = logits[:,-1,:].reshape(logits.shape[0], logits.shape[2])
#reduced = logits[:,-1,:].reshape(logits.shape[0], logits.shape[2])

return reduced #logits
return logits #reduced #logits


class ResidualBlock(nn.Module):
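With the language-model head now returning logits for every position instead of only the final one, training can consume the shift-by-one targets directly, while generation has to slice the last step itself (presumably the "handle generation properly" part left in the commit message). A small sketch of that split, using a random tensor in place of the model output:
```
import torch

batch, seq_len, vocab = 2, 10, 128
logits = torch.randn(batch, seq_len, vocab)  # what the forward pass now returns

# Training: every position feeds the per-position cross entropy.
per_position = logits.reshape(batch * seq_len, vocab)

# Generation: only the last position's distribution matters for the next character.
next_token_logits = logits[:, -1, :]                   # shape (batch, vocab)
next_token = torch.argmax(next_token_logits, dim=-1)   # greedy choice of the next character
print(per_position.shape, next_token.shape)
```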
