add bits-per-byte calculation to levanter #729

Merged · 11 commits merged on Sep 18, 2024

Changes from 4 commits
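This PR adds bits-per-byte (bpb) reporting to levanter's tagged evaluator: each predicted position's next-token loss (in nats) is paired with the UTF-8 byte length of the token being predicted, both are summed over the loss mask, and the ratio is converted from nats to bits. A minimal NumPy sketch of that conversion follows; the function and variable names here are illustrative, not the PR's API.

import numpy as np

def bits_per_byte(losses: np.ndarray, byte_lengths: np.ndarray, mask: np.ndarray) -> float:
    """Illustrative sketch: losses are per-position next-token losses in nats,
    byte_lengths is the UTF-8 byte length of each target token, and mask is
    the 0/1 loss mask."""
    total_nats = float((losses * mask).sum())
    total_bytes = float((byte_lengths * mask).sum())
    # nats -> bits is a factor of log2(e); dividing by total bytes gives bits-per-byte
    return total_nats / max(total_bytes, 1.0) * np.log2(np.e)

# e.g. four positions at ~0.7 nats each, over 8 bytes total, is ~0.5 bits per byte
print(bits_per_byte(np.full(4, 0.7), np.full(4, 2), np.ones(4)))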
130 changes: 112 additions & 18 deletions src/levanter/eval.py
@@ -5,6 +5,7 @@
from collections import defaultdict
from typing import Callable, Mapping, Optional, Sequence, TypeVar

import equinox as eqx
import jax.numpy as jnp
import jmp
import numpy as np
@@ -19,7 +20,8 @@
from levanter.logging import LoadingTimeTrackerIterator
from levanter.models.lm_model import LmExample, LmHeadModel, compute_next_token_loss
from levanter.trainer import StepInfo
from levanter.utils.stat_utils import RunningMean
from levanter.utils.hf_utils import HfTokenizer, byte_length_of_token
from levanter.utils.stat_utils import Arrayish, RunningMean
from levanter.utils.tree_utils import inference_mode


@@ -37,6 +39,10 @@ class EvalResult:
tag_macro_losses: dict[str, float] # per tag average-per-token loss
tag_micro_losses: dict[str, float] # per tag total loss, for "parent" tags
total_eval_loading_time: float
micro_bpb: Optional[float] = None
macro_bpb: Optional[float] = None
tag_macro_bpb: Optional[dict[str, float]] = None
tag_micro_bpb: Optional[dict[str, float]] = None


# This class doesn't try to be async or work with incomplete datasets, because it's eval
@@ -152,6 +158,7 @@ def _join_prefix(prefix: str, tag: str) -> str:
def cb_tagged_lm_evaluate(
EvalBatch: hax.Axis,
tagged_eval_sets: Sequence[tuple[AsyncDataset[LmExample], Sequence[str]]],
tokenizer: Optional[HfTokenizer] = None,
device_mesh: Optional[Mesh] = None,
axis_mapping: ResourceMapping = None,
max_examples_per_dataset: Optional[int] = None,
@@ -173,12 +180,15 @@ def cb_tagged_lm_evaluate(
Args:
EvalBatch: The axis for the evaluation batch (mostly for the batch size)
tagged_eval_sets: A list of datasets, each with its own domain tag
tokenizer: The tokenizer to use for bits-per-byte evaluation (optional)
device_mesh: The mesh to use for evaluation
axis_mapping: The axis mapping to use for evaluation
max_examples_per_dataset: The maximum number of examples to use from each dataset
prefix: The prefix to use for logging the losses
"""

evaluator = TaggedEvaluator(
EvalBatch, tagged_eval_sets, device_mesh, axis_mapping, max_examples_per_dataset, mp=mp
EvalBatch, tagged_eval_sets, tokenizer, device_mesh, axis_mapping, max_examples_per_dataset, mp=mp
)

def eval_callback(step: StepInfo):
Expand Down Expand Up @@ -213,6 +223,14 @@ def eval_callback(step: StepInfo):
log_dict[_join_prefix(prefix, tag) + "/micro_loss"] = loss
logger.info(f"{tag} micro loss: {loss:.3f}")

if tokenizer is not None:
log_dict[_join_prefix(prefix, "bpb")] = result.micro_bpb
log_dict[_join_prefix(prefix, "macro_bpb")] = result.macro_bpb
for tag, bpb in result.tag_micro_bpb.items():
log_dict[_join_prefix(prefix, tag) + "/bpb"] = bpb
for tag, bpb in result.tag_macro_bpb.items():
log_dict[_join_prefix(prefix, tag) + "/macro_bpb"] = bpb

levanter.tracker.log_metrics(log_dict, step=step.step)

return result
@@ -225,6 +243,8 @@ class TaggedEvaluator:
Evaluates multiple tagged datasets using a given evaluation function.
Scores for each tag are aggregated and logged separately, as well as getting an overall score.

TaggedEvaluator computes both log-perplexity and bits-per-byte for each tag, if a tokenizer is provided.

Tags are arranged hierarchically with "/" as separator, and we log both a micro and macro average loss
for each tag.

@@ -234,6 +254,7 @@ def __init__(
self,
EvalBatch: hax.Axis,
tagged_eval_sets: Sequence[tuple[AsyncDataset, Sequence[str]]],
tokenizer: Optional[HfTokenizer] = None,
device_mesh=None,
axis_mapping=None,
max_examples_per_dataset=None,
@@ -249,6 +270,8 @@ def __init__(
axis_resources=axis_mapping,
)
self.mp = mp
self.tokenizer = tokenizer
self.bytes_per_token = self._calculate_bytes_per_token_type(tokenizer)

# tags are arranged hierarchically with "/" as separator. We want to log the average loss for each tag.
hierarchy: dict[str, list[int]] = {}
@@ -264,37 +287,52 @@ def __init__(
self.hierarchy = hierarchy

@hax.named_jit(out_axis_resources=axis_mapping)
def accum_for_batch(
m: LmHeadModel, state: tuple[RunningMean, RunningMean], batch: LmExample, tags: hax.NamedArray
):
def accum_for_batch(m: LmHeadModel, state: _EvalRunningMeans, batch: LmExample, tags: hax.NamedArray):
m = inference_mode(m, True)

if self.mp is not None:
m = self.mp.cast_to_compute(m)

with hax.axis_mapping(axis_mapping):
total_mean, mean_per_tag = state
losses = compute_next_token_loss(m, batch, reduction=None, reduction_axis=())
mask = batch.loss_mask # [Batch, Token]
mask = batch.loss_mask # [Batch, Pos]
this_tokens = hax.sum(mask)
this_loss = hax.einsum("->", losses, mask) # to scalar

this_tokens_per_tag = hax.einsum("-> tag", mask, tags)
this_loss_per_tag = hax.einsum("-> tag", mask, losses, tags) # [Tag]

mean = total_mean.add(this_loss / this_tokens, this_tokens)
mean = state.loss_per_token.add(this_loss / this_tokens, this_tokens)
# careful: this_tokens_per_tag can be 0 if there are no tokens for that tag
safe_mean = hax.where(this_tokens_per_tag, this_loss_per_tag / this_tokens_per_tag, 0.0)
mean_per_tag = mean_per_tag.add(safe_mean, this_tokens_per_tag)
mean_per_tag = state.loss_per_tag.add(safe_mean, this_tokens_per_tag)

state = dataclasses.replace(state, loss_per_token=mean, loss_per_tag=mean_per_tag)

if self.bytes_per_token is not None:
next_tokens = hax.roll(batch.tokens, -1, m.Pos)
bytes_per_pos = self.bytes_per_token.take("vocab", next_tokens) # [Batch, Pos]
bytes_per_pos = bytes_per_pos * mask # [Batch, Pos]
bytes_per_tag = hax.einsum("-> tag", bytes_per_pos, tags) # [Tag]
total_bytes = hax.sum(bytes_per_tag)
# log loss -> bits is log2(e) * loss
bpb_per_tag = this_loss_per_tag / hax.maximum(bytes_per_tag, 1) * jnp.log2(jnp.e)
bpb = this_loss / hax.maximum(total_bytes, 1) * jnp.log2(jnp.e)

return mean, mean_per_tag
bpb_mean = state.bpb.add(bpb, this_tokens)
bpb_per_tag_mean = state.bpb_per_tag.add(bpb_per_tag, this_tokens_per_tag)
state = dataclasses.replace(state, bpb=bpb_mean, bpb_per_tag=bpb_per_tag_mean)

return state

self.accum_for_batch = accum_for_batch

def evaluate(self, m: LmHeadModel):
total_loss = jnp.zeros(())
mean_losses_per_tag = hax.zeros(self.dataset.Tag, dtype=np.float32)

state = (RunningMean.zeros_like(total_loss), RunningMean.zeros_like(mean_losses_per_tag))
state = _EvalRunningMeans.zeros_like(total_loss, mean_losses_per_tag)
del total_loss, mean_losses_per_tag
state = hax.shard(state)

iterator = LoadingTimeTrackerIterator(self.loader)
@@ -304,19 +342,31 @@ def evaluate(self, m: LmHeadModel):
state = self.accum_for_batch(m, state, batch, tags)
n += 1

total_loss, losses_per_tag = state

micro_avg_loss = total_loss.mean.item()
tag_avg_loss = losses_per_tag.mean
micro_avg_loss = state.loss_per_token.mean.item()
tag_avg_loss = state.loss_per_tag.mean

# TODO: why do i have to jit this
macro_avg_loss = hax.named_jit(lambda x: hax.mean(x).array)(tag_avg_loss).item()

if self.bytes_per_token is not None:
micro_bpb = state.bpb.mean.item()
tag_avg_bpb = state.bpb_per_tag.mean
macro_avg_bpb = hax.named_jit(lambda x: hax.mean(x).array)(tag_avg_bpb).item()
else:
micro_bpb = None
tag_avg_bpb = None
macro_avg_bpb = None

tag_macro_loss = {}
tag_micro_loss = {}
tag_macro_bpb = {}
tag_micro_bpb = {}

mean_loss_per_tag_cpu = np.array(state.loss_per_tag.mean.array)
total_tokens_per_tag_cpu = np.array(state.loss_per_tag.total.array)

mean_loss_per_tag_cpu = np.array(losses_per_tag.mean.array) # type: ignore
total_tokens_per_tag_cpu = np.array(losses_per_tag.total.array) # type: ignore
mean_bits_per_tag_cpu = np.array(state.bpb_per_tag.mean.array)
total_bytes_per_tag_cpu = np.array(state.bpb_per_tag.total.array)

# add in the hierarchy
for parent, children in self.hierarchy.items():
@@ -333,8 +383,52 @@
# (average doesn't support where directly so we just 0 out the weights)
tag_micro_loss[parent] = np.average(mean_loss_per_tag_cpu, weights=total_tokens_per_tag_cpu * mask)

if self.bytes_per_token is not None:
tag_macro_bpb[parent] = np.mean(mean_bits_per_tag_cpu, where=mask)
tag_micro_bpb[parent] = np.average(mean_bits_per_tag_cpu, weights=total_bytes_per_tag_cpu * mask)

for tag, index in self.dataset.tag_to_index.items():
tag_micro_loss[tag] = mean_loss_per_tag_cpu[index]
# no macro loss for the leaf tags

return EvalResult(micro_avg_loss, macro_avg_loss, tag_macro_loss, tag_micro_loss, iterator.total_time)
if self.bytes_per_token is not None:
tag_micro_bpb[tag] = mean_bits_per_tag_cpu[index]
# tag_macro_bpb[tag] = None

return EvalResult(
micro_avg_loss,
macro_avg_loss,
tag_macro_loss,
tag_micro_loss,
iterator.total_time,
micro_bpb,
macro_avg_bpb,
tag_macro_bpb,
tag_micro_bpb,
)

def _calculate_bytes_per_token_type(self, tokenizer: HfTokenizer) -> Optional[hax.NamedArray]:
if tokenizer is None:
return None
else:
# calculate the number of bytes in each token
Vocab = hax.Axis("vocab", len(tokenizer.get_vocab()))
bytes = np.ndarray((Vocab.size,), dtype=np.int32)

for i in range(Vocab.size):
bytes[i] = byte_length_of_token(tokenizer, i)

return hax.named(jnp.array(bytes), Vocab)


class _EvalRunningMeans(eqx.Module):
loss_per_token: RunningMean
loss_per_tag: RunningMean
bpb: RunningMean
bpb_per_tag: RunningMean

@staticmethod
def zeros_like(total: Arrayish, per_tag: Arrayish) -> "_EvalRunningMeans":
z = RunningMean.zeros_like(total)
per_tag = RunningMean.zeros_like(per_tag)
return _EvalRunningMeans(z, per_tag, z, per_tag)
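The _EvalRunningMeans module above bundles four of levanter's RunningMean accumulators, used via add(mean_value, weight) plus the .mean and .total fields. Below is a rough plain-Python sketch of that weighted-mean bookkeeping, inferred from the call sites above rather than from the real haliax-based implementation in levanter.utils.stat_utils.

import dataclasses

@dataclasses.dataclass(frozen=True)
class RunningMeanSketch:
    # illustration only: mirrors how levanter's RunningMean is used above
    mean: float = 0.0   # current weighted mean
    total: float = 0.0  # total weight seen so far (token or byte counts)

    def add(self, value: float, weight: float) -> "RunningMeanSketch":
        new_total = self.total + weight
        if new_total == 0:
            return self
        # standard incremental update of a weighted mean
        new_mean = self.mean + (value - self.mean) * (weight / new_total)
        return RunningMeanSketch(new_mean, new_total)

acc = RunningMeanSketch()
acc = acc.add(2.0, 10)  # batch 1: mean loss 2.0 over 10 tokens
acc = acc.add(4.0, 30)  # batch 2: mean loss 4.0 over 30 tokens
print(acc.mean)         # 3.5 == (2.0 * 10 + 4.0 * 30) / 40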
1 change: 1 addition & 0 deletions src/levanter/main/train_lm.py
@@ -174,6 +174,7 @@ def main(config: TrainLmConfig):
cb = levanter.eval.cb_tagged_lm_evaluate(
EvalBatch,
causal_datasets,
tokenizer,
trainer.device_mesh,
compute_axis_mapping,
max_eval_examples_per_ds,
1 change: 1 addition & 0 deletions src/levanter/models/lm_model.py
@@ -22,6 +22,7 @@ class LmExample(eqx.Module):
loss_mask: hax.NamedArray
attn_mask: AttentionMask | NamedArray = AttentionMask.causal()


@staticmethod
def causal(
tokens: hax.NamedArray, *, loss_mask: Optional[hax.NamedArray] = None, ignore_id: Optional[int] = None
24 changes: 24 additions & 0 deletions src/levanter/utils/hf_utils.py
@@ -1,4 +1,7 @@
import os
from typing import TypeAlias

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from levanter.logging import silence_transformer_nag
from levanter.utils.py_utils import logical_cpu_core_count
@@ -8,6 +11,8 @@

_HF_TOKENIZER_OFF_VALUES = {"off", "false", "f", "no", "n", "0"}

HfTokenizer: TypeAlias = PreTrainedTokenizerFast | PreTrainedTokenizer


def num_cpus_used_by_tokenizer(tokenizer) -> int:
if getattr(tokenizer, "is_fast", False):
@@ -20,3 +25,22 @@ def num_cpus_used_by_tokenizer(tokenizer) -> int:
return min(max(1, logical_cpu_core_count() - 2), 12)
else:
return 1


def byte_length_of_token(tokenizer, idx: int) -> int:
# this is a pain because we want the prefix spaces, but we don't want extra noise for bytes

Review comment: Does this correspond to other implementations that need to compute bpb? Would be good to reference and comment on whether we're doing the same thing.

Member Author: lm-eval-harness does the more obvious thing where you just take the whole untokenized string and get its length, but our eval pipeline starts from the tokenized and chunked sequences, so we have to back it out. I imagine it's not exactly the same value, but it's close enough, and we can use lm-eval-harness to get "real" numbers for reporting if we're worried about it: https://github.com/EleutherAI/lm-evaluation-harness/blob/fb963f0f0a5b28b69763590bb59676072cf43a01/lm_eval/tasks/french_bench/preprocess_wikitext.py#L39-L48

# e.g. in llama
# >>> t.convert_ids_to_tokens(q[2])
# '▁this'
# >>> t.convert_ids_to_tokens(25)
# '<0x16>'
# We want the _ but not the <0x16>, which should instead be a single byte \x16
# decode strips the prefix spaces, but does correctly handle the <0x16> case
# we can avoid prefix space issues by prepending another token before decoding, then stripping
if idx in tokenizer.all_special_ids:
# NB: special tokens don't have bytes, but they contribute to perplexity/bits
return 0
extra_token = tokenizer(".", add_special_tokens=False)["input_ids"][0]
excess_bytes = len(".".encode("utf-8"))
decoded = tokenizer.decode([extra_token, idx]).encode("utf-8")
return len(decoded) - excess_bytes
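As a quick illustration of byte_length_of_token, here is a hedged usage sketch with a GPT-2 tokenizer; loading via transformers.AutoTokenizer is an assumption made for this sketch, while the PR's own tests (below) go through load_tokenizer.

from transformers import AutoTokenizer

from levanter.utils.hf_utils import byte_length_of_token

# assumption for illustration: any HF tokenizer works; gpt2 is just convenient
tok = AutoTokenizer.from_pretrained("gpt2")
ids = tok("hello world", add_special_tokens=False)["input_ids"]

# " world" is a single GPT-2 token; its byte length includes the leading space
assert byte_length_of_token(tok, ids[1]) == len(" world".encode("utf-8"))

# special tokens contribute zero bytes, even though they still contribute loss
assert byte_length_of_token(tok, tok.eos_token_id) == 0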
2 changes: 1 addition & 1 deletion src/levanter/utils/stat_utils.py
@@ -7,7 +7,7 @@
import haliax as hax


Arrayish: typing.TypeAlias = hax.NamedArray | np.ndarray | jnp.ndarray | float
Arrayish: typing.TypeAlias = hax.NamedArray | np.ndarray | jnp.ndarray


class RunningMean(eqx.Module):
26 changes: 26 additions & 0 deletions tests/test_hf_utils.py
@@ -3,6 +3,8 @@
from fsspec import AbstractFileSystem

from levanter.compat.hf_checkpoints import load_tokenizer
from levanter.utils.hf_utils import byte_length_of_token
from test_utils import skip_if_hf_model_not_accessible


def test_load_tokenizer_in_memory_fs():
@@ -22,3 +24,27 @@ def test_load_tokenizer_in_memory_fs():
)
tokenizer = load_tokenizer("memory://foo/")
assert len(tokenizer) == 5027


Review comment: Should we add some tests with more funky Unicode characters and tokens that don't align on character boundaries?

Member Author: sure

Member Author: good call on the extra tests. caught an issue with the single byte tokens. i now test every token in the llama vocab

@skip_if_hf_model_not_accessible("meta-llama/Llama-2-7b-hf")
def test_byte_length_of_token():
tok = load_tokenizer("meta-llama/Llama-2-7b-hf")
ids = tok("this is hello a test", add_special_tokens=False)["input_ids"]
assert byte_length_of_token(tok, ids[2]) == len(" hello".encode("utf-8"))
assert byte_length_of_token(tok, 25) == 1
# llama prepends a space to the token. ideally it wouldn't b/c it technically throws off our bpb calculations
# but it's a small difference
assert byte_length_of_token(tok, ids[0]) == len(" this".encode("utf-8"))

bos = tok.bos_token_id
assert byte_length_of_token(tok, bos) == 0


@skip_if_hf_model_not_accessible("gpt2")
def test_byte_length_of_token_gpt2():
tok = load_tokenizer("gpt2")
ids = tok("this is hello a test", add_special_tokens=False)["input_ids"]
assert byte_length_of_token(tok, ids[2]) == len(" hello".encode("utf-8"))

eos = tok.eos_token_id
assert byte_length_of_token(tok, eos) == 0
6 changes: 2 additions & 4 deletions tests/test_utils.py
@@ -11,7 +11,7 @@
from equinox import nn as nn
from equinox import static_field
from jax._src.random import PRNGKey
from transformers import BatchEncoding
from transformers import AutoConfig, BatchEncoding

import haliax as hax

@@ -171,9 +171,7 @@ def try_load_path(path):
def skip_if_hf_model_not_accessible(model_id: str):
def try_load_hf(model_id):
try:
from transformers import AutoModel

AutoModel.from_pretrained(model_id)
AutoConfig.from_pretrained(model_id)
except Exception:
return False
else: