
Commit

Add mamba-specific generator
jloveric committed Dec 30, 2023
1 parent 990b88f commit dd1b550
Showing 1 changed file with 67 additions and 0 deletions.
language_interpolation/utils.py (67 additions, 0 deletions)
@@ -158,6 +158,63 @@ def justify_sample(sample):

return results

def generate_mamba_text(
model: nn.Module,
characters_per_feature: int,
max_characters: int,
text_list: List[str],
output_size: int,
topk: int = 1,
):
"""
TODO: This can be done much more efficiently because right now I'm re-computing
the output from the 0th input every time, while those may be identical (I think)
and so can be re-used. I think I have quadratic generation here when it should
be linear.
:param characters_per_feature: typically 1, the number of characters that make up
a feature.
:param max_characters: The maximum number of characters to use, acts like a moving
window. Set to 0 if all characters should be used.
:param text_list: List of prompts
:param output_size: The number of characters to generate
:param topk: weighted random selection of the topk next characters
:returns: the continuation of the prompts, the original text + the next output_size
characters
"""
model.eval()

results = []
for text_raw in text_list:
text_in = text_raw
for i in range(output_size):
encoding, text_used = encode_input_from_text(
text_in=text_in, features=max_characters
)
            encoding = encoding.to(model._device).reshape(
                1, -1, characters_per_feature
            )

output = model(encoding)
            values, indices, ascii = decode_output_to_text(
                encoding=output[0, -1, :], topk=topk
            )

            # Pick the next character weighted by the probability of each
            # candidate; this prevents identical responses for every query.
            values = values.nan_to_num(nan=1.0)
            actual = random.choices(ascii, weights=values.tolist(), k=1)
text_in = text_in + actual[0]

results.append(text_in.replace("\n", " "))

return results

class TextGenerationSampler(Callback):
def __init__(self, cfg):
@@ -180,6 +237,16 @@ def on_train_epoch_end(self, trainer, pl_module, outputs=None):
output_size=self._cfg.num_predict,
topk=topk,
)
elif self._cfg.net.model_type in ["mamba"]:
predictions = generate_mamba_text(
pl_module,
characters_per_feature=self._cfg.data.characters_per_feature,
max_characters=self._cfg.data.characters_per_feature
* self._cfg.data.max_features,
text_list=self._cfg.prompts,
output_size=self._cfg.num_predict,
topk=topk,
)
else:
predictions = generate_text(
pl_module,

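The TODO in the docstring names the core inefficiency: the full prompt is re-encoded and re-run through the model on every step, so generating N characters does O(N^2) forward work. A minimal sketch of the linear-time alternative, assuming a hypothetical stateful step API (init_state and step are placeholders, not functions in this repository):

import torch
from torch import nn

def generate_linear(model: nn.Module, prompt_ids: list, output_size: int) -> list:
    # Run the prompt once, carrying the recurrent state forward; afterwards
    # feed only the newest character each step: O(N) instead of O(N^2).
    # Assumes a non-empty prompt.
    state = model.init_state()                  # hypothetical API
    for tok in prompt_ids:
        logits, state = model.step(tok, state)  # hypothetical API
    out = list(prompt_ids)
    for _ in range(output_size):
        nxt = int(torch.argmax(logits))         # greedy; topk sampling also fits here
        out.append(nxt)
        logits, state = model.step(nxt, state)
    return out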
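The inner loop leaves the actual sampling to decode_output_to_text plus random.choices. A self-contained sketch of that top-k weighted sampling step, under the assumption that the model emits one logit per character class (sample_topk is illustrative, not code from this repository):

import random
import torch

def sample_topk(logits: torch.Tensor, topk: int = 1) -> int:
    # Keep the topk highest-scoring candidate characters.
    values, indices = torch.topk(logits, k=topk)
    # Softmax the surviving scores so they can act as sampling weights.
    weights = torch.softmax(values, dim=-1).tolist()
    # Weighted random choice; this is what keeps repeated prompts from
    # producing identical continuations.
    return random.choices(indices.tolist(), weights=weights, k=1)[0]

logits = torch.randn(128)  # e.g. one score per ASCII code point
next_char = chr(sample_topk(logits, topk=5))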

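The callback dispatches on cfg.net.model_type, so routing a run to generate_mamba_text only requires the config fields read above. A hypothetical skeleton of those fields (the real project presumably builds this object through its own config system; SimpleNamespace is used here purely for illustration):

from types import SimpleNamespace

cfg = SimpleNamespace(
    net=SimpleNamespace(model_type="mamba"),
    data=SimpleNamespace(characters_per_feature=1, max_features=256),
    prompts=["In the beginning"],
    num_predict=100,  # characters to generate per prompt
)
# The window passed as max_characters is then
# characters_per_feature * max_features = 256 characters.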