From ef2a4cce525bc29cbff7575116233bc22ee6b769 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 14 Sep 2024 11:02:39 -0700 Subject: [PATCH 01/11] wip --- src/levanter/eval.py | 131 +++++++++++++++++++++++++++---- src/levanter/main/train_lm.py | 1 + src/levanter/models/lm_model.py | 1 + src/levanter/utils/hf_utils.py | 5 ++ src/levanter/utils/stat_utils.py | 2 +- 5 files changed, 123 insertions(+), 17 deletions(-) diff --git a/src/levanter/eval.py b/src/levanter/eval.py index 2aa9b7ff3..c03646f20 100644 --- a/src/levanter/eval.py +++ b/src/levanter/eval.py @@ -5,6 +5,7 @@ from collections import defaultdict from typing import Callable, Mapping, Optional, Sequence, TypeVar +import equinox as eqx import jax.numpy as jnp import jmp import numpy as np @@ -19,7 +20,8 @@ from levanter.logging import LoadingTimeTrackerIterator from levanter.models.lm_model import LmExample, LmHeadModel, compute_next_token_loss from levanter.trainer import StepInfo -from levanter.utils.stat_utils import RunningMean +from levanter.utils.hf_utils import HfTokenizer +from levanter.utils.stat_utils import Arrayish, RunningMean from levanter.utils.tree_utils import inference_mode @@ -37,6 +39,11 @@ class EvalResult: tag_macro_losses: dict[str, float] # per tag average-per-token loss tag_micro_losses: dict[str, float] # per tag total loss, for "parent" tags total_eval_loading_time: float + micro_bpb: Optional[float] = None + macro_bpb: Optional[float] = None + tag_macro_bpb: Optional[dict[str, float]] = None + tag_micro_bpb: Optional[dict[str, float]] = None + # This class doesn't try to be async or work with incomplete datasets, because it's eval @@ -152,6 +159,7 @@ def _join_prefix(prefix: str, tag: str) -> str: def cb_tagged_lm_evaluate( EvalBatch: hax.Axis, tagged_eval_sets: Sequence[tuple[AsyncDataset[LmExample], Sequence[str]]], + tokenizer: Optional[HfTokenizer] = None, device_mesh: Optional[Mesh] = None, axis_mapping: ResourceMapping = None, max_examples_per_dataset: Optional[int] = None, @@ -173,12 +181,15 @@ def cb_tagged_lm_evaluate( Args: EvalBatch: The axis for the evaluation batch (mostly for the batch size) tagged_eval_sets: A list of datasets, each with its own domain tag + tokenizer: The tokenizer to use for bits-per-byte evaluation (optional) device_mesh: The mesh to use for evaluation axis_mapping: The axis mapping to use for evaluation + max_examples_per_dataset: The maximum number of examples to use from each dataset + prefix: The prefix to use for logging the losses """ evaluator = TaggedEvaluator( - EvalBatch, tagged_eval_sets, device_mesh, axis_mapping, max_examples_per_dataset, mp=mp + EvalBatch, tagged_eval_sets, tokenizer, device_mesh, axis_mapping, max_examples_per_dataset, mp=mp ) def eval_callback(step: StepInfo): @@ -213,6 +224,14 @@ def eval_callback(step: StepInfo): log_dict[_join_prefix(prefix, tag) + "/micro_loss"] = loss logger.info(f"{tag} micro loss: {loss:.3f}") + if tokenizer is not None: + log_dict[_join_prefix(prefix, "bpb")] = result.micro_bpb + log_dict[_join_prefix(prefix, "macro_bpb")] = result.macro_bpb + for tag, bpb in result.tag_micro_bpb.items(): + log_dict[_join_prefix(prefix, tag) + "/bpb"] = bpb + for tag, bpb in result.tag_macro_bpb.items(): + log_dict[_join_prefix(prefix, tag) + "/macro_bpb"] = bpb + levanter.tracker.log_metrics(log_dict, step=step.step) return result @@ -234,6 +253,7 @@ def __init__( self, EvalBatch: hax.Axis, tagged_eval_sets: Sequence[tuple[AsyncDataset, Sequence[str]]], + tokenizer: Optional[HfTokenizer] = None, device_mesh=None, 
axis_mapping=None, max_examples_per_dataset=None, @@ -249,6 +269,8 @@ def __init__( axis_resources=axis_mapping, ) self.mp = mp + self.tokenizer = tokenizer + self.bytes_per_token = self._calculate_bytes_per_token_type(tokenizer) # tags are arranged hierarchically with "/" as separator. We want to log the average loss for each tag. hierarchy: dict[str, list[int]] = {} @@ -265,28 +287,45 @@ def __init__( @hax.named_jit(out_axis_resources=axis_mapping) def accum_for_batch( - m: LmHeadModel, state: tuple[RunningMean, RunningMean], batch: LmExample, tags: hax.NamedArray + m: LmHeadModel, state: _EvalRunningMeans, batch: LmExample, tags: hax.NamedArray ): m = inference_mode(m, True) if self.mp is not None: m = self.mp.cast_to_compute(m) + with hax.axis_mapping(axis_mapping): - total_mean, mean_per_tag = state losses = compute_next_token_loss(m, batch, reduction=None, reduction_axis=()) - mask = batch.loss_mask # [Batch, Token] + mask = batch.loss_mask # [Batch, Pos] this_tokens = hax.sum(mask) this_loss = hax.einsum("->", losses, mask) # to scalar this_tokens_per_tag = hax.einsum("-> tag", mask, tags) this_loss_per_tag = hax.einsum("-> tag", mask, losses, tags) # [Tag] - mean = total_mean.add(this_loss / this_tokens, this_tokens) + if self.bytes_per_token is not None: + next_tokens = hax.roll(batch.tokens, -1, m.Pos) + bytes_per_pos = self.bytes_per_token.take("vocab", next_tokens) # [Batch, Pos] + bytes_per_tag = hax.einsum("-> tag", bytes_per_pos, mask, tags) # [Tag] + # log loss -> bits is log2(e) * loss + bits_per_tag = this_loss_per_tag * jnp.log2(jnp.e) + # this max is to avoid 0 bytes, which happens with special tokens + bpb_per_tag = bits_per_tag / hax.maximum(bytes_per_tag, 1) + bpb = this_loss / hax.maximum(hax.sum(bytes_per_pos), 1) * jnp.log2(jnp.e) + + mean = state.loss_per_token.add(this_loss / this_tokens, this_tokens) # careful: this_tokens_per_tag can be 0 if there are no tokens for that tag safe_mean = hax.where(this_tokens_per_tag, this_loss_per_tag / this_tokens_per_tag, 0.0) - mean_per_tag = mean_per_tag.add(safe_mean, this_tokens_per_tag) + mean_per_tag = state.loss_per_tag.add(safe_mean, this_tokens_per_tag) - return mean, mean_per_tag + state = dataclasses.replace(state, loss_per_token=mean, loss_per_tag=mean_per_tag) + + if self.bytes_per_token is not None: + bpb_mean = state.bpb.add(bpb, this_tokens) + bpb_per_tag_mean = state.bpb_per_tag.add(bpb_per_tag, this_tokens_per_tag) + state = dataclasses.replace(state, bpb=bpb_mean, bpb_per_tag=bpb_per_tag_mean) + + return state self.accum_for_batch = accum_for_batch @@ -294,7 +333,8 @@ def evaluate(self, m: LmHeadModel): total_loss = jnp.zeros(()) mean_losses_per_tag = hax.zeros(self.dataset.Tag, dtype=np.float32) - state = (RunningMean.zeros_like(total_loss), RunningMean.zeros_like(mean_losses_per_tag)) + state = _EvalRunningMeans.zeros_like(total_loss, mean_losses_per_tag) + del total_loss, mean_losses_per_tag state = hax.shard(state) iterator = LoadingTimeTrackerIterator(self.loader) @@ -304,19 +344,31 @@ def evaluate(self, m: LmHeadModel): state = self.accum_for_batch(m, state, batch, tags) n += 1 - total_loss, losses_per_tag = state - - micro_avg_loss = total_loss.mean.item() - tag_avg_loss = losses_per_tag.mean + micro_avg_loss = state.loss_per_token.mean.item() + tag_avg_loss = state.loss_per_tag.mean # TODO: why do i have to jit this macro_avg_loss = hax.named_jit(lambda x: hax.mean(x).array)(tag_avg_loss).item() + if self.bytes_per_token is not None: + micro_bpb = state.bpb.mean.item() + tag_avg_bpb = 
state.bpb_per_tag.mean + macro_avg_bpb = hax.named_jit(lambda x: hax.mean(x).array)(tag_avg_bpb).item() + else: + micro_bpb = None + tag_avg_bpb = None + macro_avg_bpb = None + tag_macro_loss = {} tag_micro_loss = {} + tag_macro_bpb = {} + tag_micro_bpb = {} + + mean_loss_per_tag_cpu = np.array(state.loss_per_tag.mean.array) + total_tokens_per_tag_cpu = np.array(state.loss_per_tag.mean.array) - mean_loss_per_tag_cpu = np.array(losses_per_tag.mean.array) # type: ignore - total_tokens_per_tag_cpu = np.array(losses_per_tag.total.array) # type: ignore + mean_bits_per_tag_cpu = np.array(state.bpb_per_tag.mean.array) + total_bytes_per_tag_cpu = np.array(state.bpb_per_tag.mean.array) # add in the hierarchy for parent, children in self.hierarchy.items(): @@ -333,8 +385,55 @@ def evaluate(self, m: LmHeadModel): # (average doesn't support where directly so we just 0 out the weights) tag_micro_loss[parent] = np.average(mean_loss_per_tag_cpu, weights=total_tokens_per_tag_cpu * mask) + if self.bytes_per_token is not None: + tag_macro_bpb[parent] = np.mean(mean_bits_per_tag_cpu, where=mask) + tag_micro_bpb[parent] = np.average(mean_bits_per_tag_cpu, weights=total_bytes_per_tag_cpu * mask) + for tag, index in self.dataset.tag_to_index.items(): tag_micro_loss[tag] = mean_loss_per_tag_cpu[index] # no macro loss for the leaf tags - return EvalResult(micro_avg_loss, macro_avg_loss, tag_macro_loss, tag_micro_loss, iterator.total_time) + if self.bytes_per_token is not None: + tag_micro_bpb[tag] = mean_bits_per_tag_cpu[index] + # tag_macro_bpb[tag] = None + + return EvalResult( + micro_avg_loss, macro_avg_loss, tag_macro_loss, tag_micro_loss, iterator.total_time, + micro_bpb, macro_avg_bpb, tag_macro_bpb, tag_micro_bpb + ) + + def _calculate_bytes_per_token_type(self, tokenizer: HfTokenizer) -> Optional[hax.NamedArray]: + if tokenizer is None: + return None + else: + # calculate the number of bytes in each token + Vocab = hax.Axis("vocab", len(tokenizer.get_vocab())) + bytes = np.ndarray((Vocab.size,), dtype=np.int32) + tok = tokenizer + + for i in range(Vocab.size): + if i in tok.all_special_ids: + # NB: special tokens don't have bytes, but they contribute to perplexity/bits + bytes[i] = 0 + continue + token_str = tok.convert_tokens_to_string([tok.convert_ids_to_tokens(i)]) + bytes[i] = len(token_str.encode("utf-8")) + + return hax.named(jnp.array(bytes), Vocab) + + + + +class _EvalRunningMeans(eqx.Module): + loss_per_token: RunningMean + loss_per_tag: RunningMean + bpb: RunningMean + bpb_per_tag: RunningMean + + @staticmethod + def zeros_like(total: Arrayish, per_tag: Arrayish) -> "_EvalRunningMeans": + z = RunningMean.zeros_like(total) + per_tag = RunningMean.zeros_like(per_tag) + return _EvalRunningMeans(z, per_tag, z, per_tag) + + diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 8e905b064..6c96f8b62 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -174,6 +174,7 @@ def main(config: TrainLmConfig): cb = levanter.eval.cb_tagged_lm_evaluate( EvalBatch, causal_datasets, + tokenizer, trainer.device_mesh, compute_axis_mapping, max_eval_examples_per_ds, diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 468f6a4a4..dd9a11017 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -22,6 +22,7 @@ class LmExample(eqx.Module): loss_mask: hax.NamedArray attn_mask: AttentionMask | NamedArray = AttentionMask.causal() + @staticmethod def causal( tokens: hax.NamedArray, *, loss_mask: 
Optional[hax.NamedArray] = None, ignore_id: Optional[int] = None diff --git a/src/levanter/utils/hf_utils.py b/src/levanter/utils/hf_utils.py index e5a576236..09ba199b0 100644 --- a/src/levanter/utils/hf_utils.py +++ b/src/levanter/utils/hf_utils.py @@ -1,4 +1,7 @@ import os +from typing import TypeAlias + +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from levanter.logging import silence_transformer_nag from levanter.utils.py_utils import logical_cpu_core_count @@ -8,6 +11,8 @@ _HF_TOKENIZER_OFF_VALUES = {"off", "false", "f", "no", "n", "0"} +HfTokenizer: TypeAlias = PreTrainedTokenizerFast | PreTrainedTokenizer + def num_cpus_used_by_tokenizer(tokenizer) -> int: if getattr(tokenizer, "is_fast", False): diff --git a/src/levanter/utils/stat_utils.py b/src/levanter/utils/stat_utils.py index e51918d2f..6111be42e 100644 --- a/src/levanter/utils/stat_utils.py +++ b/src/levanter/utils/stat_utils.py @@ -7,7 +7,7 @@ import haliax as hax -Arrayish: typing.TypeAlias = hax.NamedArray | np.ndarray | jnp.ndarray | float +Arrayish: typing.TypeAlias = hax.NamedArray | np.ndarray | jnp.ndarray class RunningMean(eqx.Module): From a205b7df350b63ab97544aefee76ed7efdc32cb2 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 14 Sep 2024 11:40:50 -0700 Subject: [PATCH 02/11] wip --- src/levanter/eval.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/levanter/eval.py b/src/levanter/eval.py index c03646f20..7bb2b04d7 100644 --- a/src/levanter/eval.py +++ b/src/levanter/eval.py @@ -244,6 +244,8 @@ class TaggedEvaluator: Evaluates multiple tagged datasets using a given evaluation function. Scores for each tag are aggregated and logged separately, as well as getting an overall score. + TaggedEvaluator computes both log-perplexity and bits-per-byte for each tag, if a tokenizer is provided. + Tags are arranged hierarchically with "/" as separator, and we log both a micro and macro average loss for each tag. @@ -303,24 +305,24 @@ def accum_for_batch( this_tokens_per_tag = hax.einsum("-> tag", mask, tags) this_loss_per_tag = hax.einsum("-> tag", mask, losses, tags) # [Tag] + mean = state.loss_per_token.add(this_loss / this_tokens, this_tokens) + # careful: this_tokens_per_tag can be 0 if there are no tokens for that tag + safe_mean = hax.where(this_tokens_per_tag, this_loss_per_tag / this_tokens_per_tag, 0.0) + mean_per_tag = state.loss_per_tag.add(safe_mean, this_tokens_per_tag) + + state = dataclasses.replace(state, loss_per_token=mean, loss_per_tag=mean_per_tag) + if self.bytes_per_token is not None: next_tokens = hax.roll(batch.tokens, -1, m.Pos) bytes_per_pos = self.bytes_per_token.take("vocab", next_tokens) # [Batch, Pos] - bytes_per_tag = hax.einsum("-> tag", bytes_per_pos, mask, tags) # [Tag] + bytes_per_pos = hax.einsum("... 
-> ...", bytes_per_pos, mask) # [Batch, Pos] + bytes_per_tag = hax.einsum("-> tag", bytes_per_pos, tags) # [Tag] # log loss -> bits is log2(e) * loss bits_per_tag = this_loss_per_tag * jnp.log2(jnp.e) # this max is to avoid 0 bytes, which happens with special tokens bpb_per_tag = bits_per_tag / hax.maximum(bytes_per_tag, 1) bpb = this_loss / hax.maximum(hax.sum(bytes_per_pos), 1) * jnp.log2(jnp.e) - mean = state.loss_per_token.add(this_loss / this_tokens, this_tokens) - # careful: this_tokens_per_tag can be 0 if there are no tokens for that tag - safe_mean = hax.where(this_tokens_per_tag, this_loss_per_tag / this_tokens_per_tag, 0.0) - mean_per_tag = state.loss_per_tag.add(safe_mean, this_tokens_per_tag) - - state = dataclasses.replace(state, loss_per_token=mean, loss_per_tag=mean_per_tag) - - if self.bytes_per_token is not None: bpb_mean = state.bpb.add(bpb, this_tokens) bpb_per_tag_mean = state.bpb_per_tag.add(bpb_per_tag, this_tokens_per_tag) state = dataclasses.replace(state, bpb=bpb_mean, bpb_per_tag=bpb_per_tag_mean) From 230c4212afc7048406dde1e37afdce85dbfaa19a Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 14 Sep 2024 22:25:51 -0700 Subject: [PATCH 03/11] it feels correct? --- src/levanter/eval.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/src/levanter/eval.py b/src/levanter/eval.py index 7bb2b04d7..4dda31f93 100644 --- a/src/levanter/eval.py +++ b/src/levanter/eval.py @@ -45,7 +45,6 @@ class EvalResult: tag_micro_bpb: Optional[dict[str, float]] = None - # This class doesn't try to be async or work with incomplete datasets, because it's eval @@ -288,9 +287,7 @@ def __init__( self.hierarchy = hierarchy @hax.named_jit(out_axis_resources=axis_mapping) - def accum_for_batch( - m: LmHeadModel, state: _EvalRunningMeans, batch: LmExample, tags: hax.NamedArray - ): + def accum_for_batch(m: LmHeadModel, state: _EvalRunningMeans, batch: LmExample, tags: hax.NamedArray): m = inference_mode(m, True) if self.mp is not None: @@ -315,13 +312,12 @@ def accum_for_batch( if self.bytes_per_token is not None: next_tokens = hax.roll(batch.tokens, -1, m.Pos) bytes_per_pos = self.bytes_per_token.take("vocab", next_tokens) # [Batch, Pos] - bytes_per_pos = hax.einsum("... 
-> ...", bytes_per_pos, mask) # [Batch, Pos] + bytes_per_pos = bytes_per_pos * mask # [Batch, Pos] bytes_per_tag = hax.einsum("-> tag", bytes_per_pos, tags) # [Tag] + total_bytes = hax.sum(bytes_per_tag) # log loss -> bits is log2(e) * loss - bits_per_tag = this_loss_per_tag * jnp.log2(jnp.e) - # this max is to avoid 0 bytes, which happens with special tokens - bpb_per_tag = bits_per_tag / hax.maximum(bytes_per_tag, 1) - bpb = this_loss / hax.maximum(hax.sum(bytes_per_pos), 1) * jnp.log2(jnp.e) + bpb_per_tag = this_loss_per_tag / hax.maximum(bytes_per_tag, 1) * jnp.log2(jnp.e) + bpb = this_loss / hax.maximum(total_bytes, 1) * jnp.log2(jnp.e) bpb_mean = state.bpb.add(bpb, this_tokens) bpb_per_tag_mean = state.bpb_per_tag.add(bpb_per_tag, this_tokens_per_tag) @@ -400,8 +396,15 @@ def evaluate(self, m: LmHeadModel): # tag_macro_bpb[tag] = None return EvalResult( - micro_avg_loss, macro_avg_loss, tag_macro_loss, tag_micro_loss, iterator.total_time, - micro_bpb, macro_avg_bpb, tag_macro_bpb, tag_micro_bpb + micro_avg_loss, + macro_avg_loss, + tag_macro_loss, + tag_micro_loss, + iterator.total_time, + micro_bpb, + macro_avg_bpb, + tag_macro_bpb, + tag_micro_bpb, ) def _calculate_bytes_per_token_type(self, tokenizer: HfTokenizer) -> Optional[hax.NamedArray]: @@ -424,8 +427,6 @@ def _calculate_bytes_per_token_type(self, tokenizer: HfTokenizer) -> Optional[ha return hax.named(jnp.array(bytes), Vocab) - - class _EvalRunningMeans(eqx.Module): loss_per_token: RunningMean loss_per_tag: RunningMean @@ -437,5 +438,3 @@ def zeros_like(total: Arrayish, per_tag: Arrayish) -> "_EvalRunningMeans": z = RunningMean.zeros_like(total) per_tag = RunningMean.zeros_like(per_tag) return _EvalRunningMeans(z, per_tag, z, per_tag) - - From 133f1a5a5b95c55718d5de2ca2b603d62481d650 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 15 Sep 2024 00:07:37 -0700 Subject: [PATCH 04/11] fix bpb for llama tokenizer (which has weird space handling) --- src/levanter/eval.py | 10 ++-------- src/levanter/utils/hf_utils.py | 21 ++++++++++++++++++++- tests/test_hf_utils.py | 26 ++++++++++++++++++++++++++ tests/test_utils.py | 6 ++---- 4 files changed, 50 insertions(+), 13 deletions(-) diff --git a/src/levanter/eval.py b/src/levanter/eval.py index 4dda31f93..b5e80c272 100644 --- a/src/levanter/eval.py +++ b/src/levanter/eval.py @@ -20,7 +20,7 @@ from levanter.logging import LoadingTimeTrackerIterator from levanter.models.lm_model import LmExample, LmHeadModel, compute_next_token_loss from levanter.trainer import StepInfo -from levanter.utils.hf_utils import HfTokenizer +from levanter.utils.hf_utils import HfTokenizer, byte_length_of_token from levanter.utils.stat_utils import Arrayish, RunningMean from levanter.utils.tree_utils import inference_mode @@ -414,15 +414,9 @@ def _calculate_bytes_per_token_type(self, tokenizer: HfTokenizer) -> Optional[ha # calculate the number of bytes in each token Vocab = hax.Axis("vocab", len(tokenizer.get_vocab())) bytes = np.ndarray((Vocab.size,), dtype=np.int32) - tok = tokenizer for i in range(Vocab.size): - if i in tok.all_special_ids: - # NB: special tokens don't have bytes, but they contribute to perplexity/bits - bytes[i] = 0 - continue - token_str = tok.convert_tokens_to_string([tok.convert_ids_to_tokens(i)]) - bytes[i] = len(token_str.encode("utf-8")) + bytes[i] = byte_length_of_token(tokenizer, i) return hax.named(jnp.array(bytes), Vocab) diff --git a/src/levanter/utils/hf_utils.py b/src/levanter/utils/hf_utils.py index 09ba199b0..1b800bf37 100644 --- a/src/levanter/utils/hf_utils.py +++ 
b/src/levanter/utils/hf_utils.py @@ -11,7 +11,7 @@ _HF_TOKENIZER_OFF_VALUES = {"off", "false", "f", "no", "n", "0"} -HfTokenizer: TypeAlias = PreTrainedTokenizerFast | PreTrainedTokenizer +HfTokenizer: TypeAlias = PreTrainedTokenizerFast | PreTrainedTokenizer def num_cpus_used_by_tokenizer(tokenizer) -> int: @@ -25,3 +25,22 @@ def num_cpus_used_by_tokenizer(tokenizer) -> int: return min(max(1, logical_cpu_core_count() - 2), 12) else: return 1 + + +def byte_length_of_token(tokenizer, idx: int) -> int: + # this is a pain because we want the prefix spaces, but we don't want extra noise for bytes + # e.g. in llama + # >>> t.convert_ids_to_tokens(q[2]) + # '▁this' + # >>> t.convert_ids_to_tokens(25) + # '<0x16>' + # We want the _ but not the <0x16>, which should instead be a single byte \x16 + # decode strips the prefix spaces, but does correctly handle the <0x16> case + # we can avoid prefix space issues by prepending another token before decoding, then stripping + if idx in tokenizer.all_special_ids: + # NB: special tokens don't have bytes, but they contribute to perplexity/bits + return 0 + extra_token = tokenizer(".", add_special_tokens=False)["input_ids"][0] + excess_bytes = len(".".encode("utf-8")) + decoded = tokenizer.decode([extra_token, idx]).encode("utf-8") + return len(decoded) - excess_bytes diff --git a/tests/test_hf_utils.py b/tests/test_hf_utils.py index e6a6158e2..a7a378bd1 100644 --- a/tests/test_hf_utils.py +++ b/tests/test_hf_utils.py @@ -3,6 +3,8 @@ from fsspec import AbstractFileSystem from levanter.compat.hf_checkpoints import load_tokenizer +from levanter.utils.hf_utils import byte_length_of_token +from test_utils import skip_if_hf_model_not_accessible def test_load_tokenizer_in_memory_fs(): @@ -22,3 +24,27 @@ def test_load_tokenizer_in_memory_fs(): ) tokenizer = load_tokenizer("memory://foo/") assert len(tokenizer) == 5027 + + +@skip_if_hf_model_not_accessible("meta-llama/Llama-2-7b-hf") +def test_byte_length_of_token(): + tok = load_tokenizer("meta-llama/Llama-2-7b-hf") + ids = tok("this is hello a test", add_special_tokens=False)["input_ids"] + assert byte_length_of_token(tok, ids[2]) == len(" hello".encode("utf-8")) + assert byte_length_of_token(tok, 25) == 1 + # llama prepends a space to the token. 
ideally it wouldn't b/c it technically throws off our bpb calculations + # but it's a small difference + assert byte_length_of_token(tok, ids[0]) == len(" this".encode("utf-8")) + + bos = tok.bos_token_id + assert byte_length_of_token(tok, bos) == 0 + + +@skip_if_hf_model_not_accessible("gpt2") +def test_byte_length_of_token_gpt2(): + tok = load_tokenizer("gpt2") + ids = tok("this is hello a test", add_special_tokens=False)["input_ids"] + assert byte_length_of_token(tok, ids[2]) == len(" hello".encode("utf-8")) + + eos = tok.eos_token_id + assert byte_length_of_token(tok, eos) == 0 diff --git a/tests/test_utils.py b/tests/test_utils.py index 1bf03b624..6206ec2ff 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -11,7 +11,7 @@ from equinox import nn as nn from equinox import static_field from jax._src.random import PRNGKey -from transformers import BatchEncoding +from transformers import AutoConfig, BatchEncoding import haliax as hax @@ -171,9 +171,7 @@ def try_load_path(path): def skip_if_hf_model_not_accessible(model_id: str): def try_load_hf(model_id): try: - from transformers import AutoModel - - AutoModel.from_pretrained(model_id) + AutoConfig.from_pretrained(model_id) except Exception: return False else: From 0827737a3afd675e9bc2e5b270b76cade772d856 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 15 Sep 2024 00:30:23 -0700 Subject: [PATCH 05/11] pre-commit --- src/levanter/models/lm_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index dd9a11017..468f6a4a4 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -22,7 +22,6 @@ class LmExample(eqx.Module): loss_mask: hax.NamedArray attn_mask: AttentionMask | NamedArray = AttentionMask.causal() - @staticmethod def causal( tokens: hax.NamedArray, *, loss_mask: Optional[hax.NamedArray] = None, ignore_id: Optional[int] = None From 3f60b823d47adbe1b898cd89662a0247ede1ab5d Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 15 Sep 2024 22:58:17 -0700 Subject: [PATCH 06/11] fix train_lm entry --- tests/tiny_test_corpus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tiny_test_corpus.py b/tests/tiny_test_corpus.py index 91597c137..fb09f362a 100644 --- a/tests/tiny_test_corpus.py +++ b/tests/tiny_test_corpus.py @@ -69,7 +69,7 @@ def construct_small_data_cache( validation_urls=[f"file://{path}/validation/docs.jsonl"], cache_dir=f"{path}/cache", vocab_size=vocab_size, - tokenizer="passthrough", + tokenizer="gpt2", ) return config, caches From 6838884c5980e9f092868d511bdb0e41e57bb683 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 16 Sep 2024 10:52:20 -0700 Subject: [PATCH 07/11] will this make the tests happier? 
--- src/levanter/utils/thread_utils.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/levanter/utils/thread_utils.py b/src/levanter/utils/thread_utils.py index 0b4abcdaf..4ac4e1f1e 100644 --- a/src/levanter/utils/thread_utils.py +++ b/src/levanter/utils/thread_utils.py @@ -34,24 +34,41 @@ class AsyncIteratorWrapper(Iterator): def __init__(self, async_iter): self.async_iter = async_iter self.loop = asyncio.new_event_loop() - self.executor = ThreadPoolExecutor(max_workers=1) self.thread = threading.Thread(target=self._run_loop, daemon=True) self.thread.start() + self._exhausted = False # Flag to indicate if the iterator is exhausted def _run_loop(self): asyncio.set_event_loop(self.loop) self.loop.run_forever() def _run_async_task(self, coro): - return asyncio.run_coroutine_threadsafe(coro, self.loop).result() + if not self.loop.is_running() or not self.thread.is_alive(): + raise StopIteration # Loop is not running or thread has been joined + try: + future = asyncio.run_coroutine_threadsafe(coro, self.loop) + return future.result() + except (RuntimeError, asyncio.CancelledError): + raise StopIteration # Either the loop was closed or the coroutine was cancelled def __iter__(self): return self def __next__(self): + if self._exhausted: + raise StopIteration try: return self._run_async_task(self.async_iter.__anext__()) except StopAsyncIteration: - self.loop.call_soon_threadsafe(self.loop.stop) - self.thread.join() + self._exhausted = True # Mark the iterator as exhausted + if self.loop.is_running(): + self.loop.call_soon_threadsafe(self.loop.stop) + self.thread.join() # Ensure the thread is safely joined raise StopIteration + + def close(self): + """Close the event loop and thread gracefully.""" + if self.loop.is_running(): + self.loop.call_soon_threadsafe(self.loop.stop) + self.thread.join() # Join the thread to ensure the loop is fully stopped + self.loop.close() # Explicitly close the loop to avoid dangling tasks From 14975d36957cd62e9e78b14b488686e9c3f4022e Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 16 Sep 2024 13:38:35 -0700 Subject: [PATCH 08/11] remove chatgpt comments --- src/levanter/utils/thread_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/levanter/utils/thread_utils.py b/src/levanter/utils/thread_utils.py index 4ac4e1f1e..fad60ad31 100644 --- a/src/levanter/utils/thread_utils.py +++ b/src/levanter/utils/thread_utils.py @@ -44,12 +44,12 @@ def _run_loop(self): def _run_async_task(self, coro): if not self.loop.is_running() or not self.thread.is_alive(): - raise StopIteration # Loop is not running or thread has been joined + raise StopIteration try: future = asyncio.run_coroutine_threadsafe(coro, self.loop) return future.result() except (RuntimeError, asyncio.CancelledError): - raise StopIteration # Either the loop was closed or the coroutine was cancelled + raise StopIteration def __iter__(self): return self @@ -63,12 +63,12 @@ def __next__(self): self._exhausted = True # Mark the iterator as exhausted if self.loop.is_running(): self.loop.call_soon_threadsafe(self.loop.stop) - self.thread.join() # Ensure the thread is safely joined + self.thread.join() raise StopIteration def close(self): """Close the event loop and thread gracefully.""" if self.loop.is_running(): self.loop.call_soon_threadsafe(self.loop.stop) - self.thread.join() # Join the thread to ensure the loop is fully stopped - self.loop.close() # Explicitly close the loop to avoid dangling tasks + self.thread.join() + 
self.loop.close() From 52c26f22c52f23487886a304ef5ecec4cc84c615 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 17 Sep 2024 16:06:26 -0700 Subject: [PATCH 09/11] pr comments --- src/levanter/eval.py | 30 ++++++++++++------------ src/levanter/utils/hf_utils.py | 16 +++++++++---- tests/test_hf_utils.py | 42 ++++++++++++++++++++++++++++++++-- 3 files changed, 66 insertions(+), 22 deletions(-) diff --git a/src/levanter/eval.py b/src/levanter/eval.py index b5e80c272..c03e2b519 100644 --- a/src/levanter/eval.py +++ b/src/levanter/eval.py @@ -299,10 +299,11 @@ def accum_for_batch(m: LmHeadModel, state: _EvalRunningMeans, batch: LmExample, this_tokens = hax.sum(mask) this_loss = hax.einsum("->", losses, mask) # to scalar + # all the *_per_tag variables are [Tag] this_tokens_per_tag = hax.einsum("-> tag", mask, tags) this_loss_per_tag = hax.einsum("-> tag", mask, losses, tags) # [Tag] - mean = state.loss_per_token.add(this_loss / this_tokens, this_tokens) + mean = state.token_avg_loss.add(this_loss / this_tokens, this_tokens) # careful: this_tokens_per_tag can be 0 if there are no tokens for that tag safe_mean = hax.where(this_tokens_per_tag, this_loss_per_tag / this_tokens_per_tag, 0.0) mean_per_tag = state.loss_per_tag.add(safe_mean, this_tokens_per_tag) @@ -310,11 +311,12 @@ def accum_for_batch(m: LmHeadModel, state: _EvalRunningMeans, batch: LmExample, state = dataclasses.replace(state, loss_per_token=mean, loss_per_tag=mean_per_tag) if self.bytes_per_token is not None: - next_tokens = hax.roll(batch.tokens, -1, m.Pos) + next_tokens = hax.roll(batch.tokens, -1, m.Pos) # [Batch, Pos], rolled by 1 for next token task bytes_per_pos = self.bytes_per_token.take("vocab", next_tokens) # [Batch, Pos] bytes_per_pos = bytes_per_pos * mask # [Batch, Pos] bytes_per_tag = hax.einsum("-> tag", bytes_per_pos, tags) # [Tag] total_bytes = hax.sum(bytes_per_tag) + # log loss -> bits is log2(e) * loss bpb_per_tag = this_loss_per_tag / hax.maximum(bytes_per_tag, 1) * jnp.log2(jnp.e) bpb = this_loss / hax.maximum(total_bytes, 1) * jnp.log2(jnp.e) @@ -342,7 +344,7 @@ def evaluate(self, m: LmHeadModel): state = self.accum_for_batch(m, state, batch, tags) n += 1 - micro_avg_loss = state.loss_per_token.mean.item() + micro_avg_loss = state.token_avg_loss.mean.item() tag_avg_loss = state.loss_per_tag.mean # TODO: why do i have to jit this @@ -354,13 +356,12 @@ def evaluate(self, m: LmHeadModel): macro_avg_bpb = hax.named_jit(lambda x: hax.mean(x).array)(tag_avg_bpb).item() else: micro_bpb = None - tag_avg_bpb = None macro_avg_bpb = None - tag_macro_loss = {} - tag_micro_loss = {} - tag_macro_bpb = {} - tag_micro_bpb = {} + tag_macro_loss: dict[str, float] = {} + tag_micro_loss: dict[str, float] = {} + tag_macro_bpb: dict[str, float] = {} + tag_micro_bpb: dict[str, float] = {} mean_loss_per_tag_cpu = np.array(state.loss_per_tag.mean.array) total_tokens_per_tag_cpu = np.array(state.loss_per_tag.mean.array) @@ -388,12 +389,11 @@ def evaluate(self, m: LmHeadModel): tag_micro_bpb[parent] = np.average(mean_bits_per_tag_cpu, weights=total_bytes_per_tag_cpu * mask) for tag, index in self.dataset.tag_to_index.items(): - tag_micro_loss[tag] = mean_loss_per_tag_cpu[index] + tag_micro_loss[tag] = float(mean_loss_per_tag_cpu[index]) # no macro loss for the leaf tags if self.bytes_per_token is not None: - tag_micro_bpb[tag] = mean_bits_per_tag_cpu[index] - # tag_macro_bpb[tag] = None + tag_micro_bpb[tag] = float(mean_bits_per_tag_cpu[index]) return EvalResult( micro_avg_loss, @@ -422,10 +422,10 @@ def 
_calculate_bytes_per_token_type(self, tokenizer: HfTokenizer) -> Optional[ha class _EvalRunningMeans(eqx.Module): - loss_per_token: RunningMean - loss_per_tag: RunningMean - bpb: RunningMean - bpb_per_tag: RunningMean + token_avg_loss: RunningMean # average loss averaged over all tokens + loss_per_tag: RunningMean # average loss per tag + bpb: RunningMean # bits per byte averaged over all tokens + bpb_per_tag: RunningMean # bits per byte per tag @staticmethod def zeros_like(total: Arrayish, per_tag: Arrayish) -> "_EvalRunningMeans": diff --git a/src/levanter/utils/hf_utils.py b/src/levanter/utils/hf_utils.py index 1b800bf37..922de4830 100644 --- a/src/levanter/utils/hf_utils.py +++ b/src/levanter/utils/hf_utils.py @@ -1,4 +1,5 @@ import os +import re from typing import TypeAlias from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -34,13 +35,18 @@ def byte_length_of_token(tokenizer, idx: int) -> int: # '▁this' # >>> t.convert_ids_to_tokens(25) # '<0x16>' - # We want the _ but not the <0x16>, which should instead be a single byte \x16 + # We want the _ (as a single byte, not the 3 it's encoded as) but not the <0x16>, which should instead be a single byte \x16 # decode strips the prefix spaces, but does correctly handle the <0x16> case # we can avoid prefix space issues by prepending another token before decoding, then stripping + repr = tokenizer.convert_ids_to_tokens(idx) if idx in tokenizer.all_special_ids: # NB: special tokens don't have bytes, but they contribute to perplexity/bits return 0 - extra_token = tokenizer(".", add_special_tokens=False)["input_ids"][0] - excess_bytes = len(".".encode("utf-8")) - decoded = tokenizer.decode([extra_token, idx]).encode("utf-8") - return len(decoded) - excess_bytes + # handle bytes specially. This is a bit of a hack, but there's no other way + elif m := re.match(r"<0x([0-9A-Fa-f]+)>", repr): + return len(bytes.fromhex(m.group(1))) + else: + extra_token = tokenizer(".", add_special_tokens=False)["input_ids"][0] + excess_bytes = len(".".encode("utf-8")) + decoded = tokenizer.decode([extra_token, idx]).encode("utf-8") + return len(decoded) - excess_bytes diff --git a/tests/test_hf_utils.py b/tests/test_hf_utils.py index a7a378bd1..fadb515c7 100644 --- a/tests/test_hf_utils.py +++ b/tests/test_hf_utils.py @@ -26,19 +26,57 @@ def test_load_tokenizer_in_memory_fs(): assert len(tokenizer) == 5027 -@skip_if_hf_model_not_accessible("meta-llama/Llama-2-7b-hf") def test_byte_length_of_token(): tok = load_tokenizer("meta-llama/Llama-2-7b-hf") ids = tok("this is hello a test", add_special_tokens=False)["input_ids"] assert byte_length_of_token(tok, ids[2]) == len(" hello".encode("utf-8")) assert byte_length_of_token(tok, 25) == 1 - # llama prepends a space to the token. ideally it wouldn't b/c it technically throws off our bpb calculations + # llama prepends a space to the string. 
ideally it wouldn't b/c it technically throws off our bpb calculations # but it's a small difference assert byte_length_of_token(tok, ids[0]) == len(" this".encode("utf-8")) bos = tok.bos_token_id assert byte_length_of_token(tok, bos) == 0 + # 632: "▁▁▁▁▁▁▁▁▁▁▁▁" which is just 12 spaces + # assert byte_length_of_token(tok, 632) == len(" ".encode("utf-8")) + # 8535: "ными" + # assert byte_length_of_token(tok, 8535) == len("ными".encode("utf-8")) + + checks = { + 632: " " * 12, + 8535: "ными", + 25: " ", + } + + for token_id, expected_length in checks.items(): + assert byte_length_of_token(tok, token_id) == len(expected_length.encode("utf-8")) + + # now just test all tokens and print the ones that aren't expected + # the ones less than 259 are bytes or special tokens + for i in range(3, 259): + byte_length = byte_length_of_token(tok, i) + assert byte_length == 1, f"Token {i} has length {byte_length} but expected 1" + + for i in range(259, tok.vocab_size): + byte_length = byte_length_of_token(tok, i) + expected_length = len(tok.convert_ids_to_tokens(i).replace("▁", " ").encode("utf-8")) + assert byte_length == expected_length, f"Token {i} has length {byte_length} but expected {expected_length}" + + +@skip_if_hf_model_not_accessible("meta-llama/Llama-2-7b-hf") +def test_byte_length_of_token_multi(): + tok = load_tokenizer("meta-llama/Llama-2-7b-hf") + multi_checks = [ + "👍你好", + ] + + for expr in multi_checks: + # stupid llama adds a prefix space + token_ids = tok.encode(expr, add_special_tokens=False)[1:] + total_length = sum(byte_length_of_token(tok, token_id) for token_id in token_ids) + assert total_length == len(expr.encode("utf-8")) + @skip_if_hf_model_not_accessible("gpt2") def test_byte_length_of_token_gpt2(): From 33cdb925e2641c41077e17273ea82dafe03e925b Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 17 Sep 2024 16:19:25 -0700 Subject: [PATCH 10/11] grr --- src/levanter/eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/eval.py b/src/levanter/eval.py index c03e2b519..555dd1466 100644 --- a/src/levanter/eval.py +++ b/src/levanter/eval.py @@ -308,7 +308,7 @@ def accum_for_batch(m: LmHeadModel, state: _EvalRunningMeans, batch: LmExample, safe_mean = hax.where(this_tokens_per_tag, this_loss_per_tag / this_tokens_per_tag, 0.0) mean_per_tag = state.loss_per_tag.add(safe_mean, this_tokens_per_tag) - state = dataclasses.replace(state, loss_per_token=mean, loss_per_tag=mean_per_tag) + state = dataclasses.replace(state, token_avg_loss=mean, loss_per_tag=mean_per_tag) if self.bytes_per_token is not None: next_tokens = hax.roll(batch.tokens, -1, m.Pos) # [Batch, Pos], rolled by 1 for next token task From d6d21c17352c0c2f31a5f613225e35035ec6b23a Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 17 Sep 2024 21:19:39 -0700 Subject: [PATCH 11/11] oops --- tests/test_hf_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_hf_utils.py b/tests/test_hf_utils.py index fadb515c7..c3c322cf0 100644 --- a/tests/test_hf_utils.py +++ b/tests/test_hf_utils.py @@ -26,6 +26,7 @@ def test_load_tokenizer_in_memory_fs(): assert len(tokenizer) == 5027 +@skip_if_hf_model_not_accessible("meta-llama/Llama-2-7b-hf") def test_byte_length_of_token(): tok = load_tokenizer("meta-llama/Llama-2-7b-hf") ids = tok("this is hello a test", add_special_tokens=False)["input_ids"]
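
The metric this series adds is bits-per-byte: the summed next-token cross-entropy (in nats) is converted to bits with a factor of log2(e) and divided by the total UTF-8 byte length of the target tokens, with a floor of one byte so that spans made up only of special tokens do not divide by zero. A minimal standalone sketch of that conversion -- the numbers are illustrative, not taken from any run:

    import numpy as np

    def bits_per_byte(total_nats: float, total_bytes: int) -> float:
        # nats -> bits is a factor of log2(e); max(..., 1) guards against spans
        # that contain only special tokens, which contribute zero bytes
        return total_nats * np.log2(np.e) / max(total_bytes, 1)

    # e.g. an average loss of 1.2 nats/token over 1000 tokens averaging ~4.1 bytes each
    print(bits_per_byte(total_nats=1.2 * 1000, total_bytes=4100))  # ~0.42 bpb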
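
The per-token byte counts come from byte_length_of_token (hf_utils.py): each vocabulary id is decoded behind a throwaway prefix token so that leading-space markers like Llama's '▁' keep their byte, raw byte tokens such as <0x16> count as a single byte, and special tokens count as zero. A rough usage sketch, assuming the series is applied and the public gpt2 tokenizer can be downloaded:

    from transformers import AutoTokenizer
    from levanter.utils.hf_utils import byte_length_of_token

    tok = AutoTokenizer.from_pretrained("gpt2")
    ids = tok("this is hello a test", add_special_tokens=False)["input_ids"]
    # the leading space of " hello" is counted as a byte
    assert byte_length_of_token(tok, ids[2]) == len(" hello".encode("utf-8"))
    # special tokens contribute loss but no bytes
    assert byte_length_of_token(tok, tok.eos_token_id) == 0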
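
The evaluator reports both micro and macro per-tag aggregates: micro weights each child tag by its token (or byte) total, while macro gives every child tag equal weight, matching the np.average / np.mean calls in TaggedEvaluator.evaluate. A toy illustration of how the two can diverge, with hypothetical numbers:

    import numpy as np

    tag_loss = np.array([2.0, 3.0])      # mean loss per child tag
    tag_tokens = np.array([9000, 1000])  # tokens seen per child tag
    micro = np.average(tag_loss, weights=tag_tokens)  # 2.1 -- dominated by the large tag
    macro = tag_loss.mean()                           # 2.5 -- each tag counts equally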