From fbe27bc7f3591a6403fef3b2b6c805114e9215f4 Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 5 Sep 2024 11:10:11 -0700 Subject: [PATCH] Completely rework dataset/cache system: instant resume, perfect shuffle, stable mixtures and more (#716) Introduces a massive rework of Levanter's cache system to support instant resume, perfect shuffle, stable mixtures and such. The basic idea is to use TensorStore to store all of our data as a kind of janky column store (implemented in JaggedArrayStore) and pytrees of such (implemented in TreeStore). TensorStore provides efficient storage and access to very large arrays. We still support streaming from an in progress cache via a new AsyncDataset class. I've successfully tests this on the pile and, modulo the usual issues with the llama tokenizer on long documents/books, it behaves well. Closes #626 #311 #119 #34 --- .dockerignore | 1 + .github/workflows/run_entry_tests.yaml | 2 +- .github/workflows/run_ray_tests.yaml | 4 +- .github/workflows/run_tests.yaml | 2 +- config/data/redpajama_1b_source.yaml | 1 - config/data/redpajama_1t_source.yaml | 1 - config/data/rpv1_llama.yaml | 1 - config/gpt2_nano_mixture.yaml | 2 +- config/llama2_small_fast_mix.yaml | 163 ++ docs/Fine-Tuning.md | 2 +- docs/Training-On-Your-Data.md | 32 +- docs/design/Data-Loader-Design.md | 334 ++-- examples/alpaca-lora/alpaca_lora.py | 2 +- examples/alpaca/alpaca.py | 82 +- examples/gsm8k-lora/gsm8k_lora.py | 69 +- pyproject.toml | 3 +- scripts/repair_cache.py | 60 - src/levanter/callbacks.py | 5 +- src/levanter/checkpoint.py | 5 +- src/levanter/compat/hf_checkpoints.py | 2 +- src/levanter/data/__init__.py | 36 +- src/levanter/data/_preprocessor.py | 20 +- src/levanter/data/_process_interleave.py | 338 ---- src/levanter/data/_prp.py | 63 + src/levanter/data/_queue.py | 41 +- src/levanter/data/audio.py | 194 ++- src/levanter/data/dataset.py | 380 +++- src/levanter/data/loader.py | 452 +++-- src/levanter/data/metrics_monitor.py | 25 +- src/levanter/data/mixture.py | 224 ++- src/levanter/data/permutation.py | 135 ++ src/levanter/data/shard_cache.py | 1521 ----------------- ...arded_dataset.py => sharded_datasource.py} | 82 +- src/levanter/data/text.py | 552 +++--- src/levanter/doremi.py | 28 +- src/levanter/eval.py | 101 +- src/levanter/logging.py | 4 +- src/levanter/main/cache_dataset.py | 4 +- src/levanter/main/eval_lm.py | 6 +- src/levanter/main/lora_lm.py | 14 +- src/levanter/main/train_asr.py | 13 +- src/levanter/main/train_lm.py | 32 +- src/levanter/main/viz_logprobs.py | 8 +- src/levanter/store/__init__.py | 6 + src/levanter/store/cache.py | 1321 ++++++++++++++ src/levanter/store/jagged_array.py | 508 ++++++ src/levanter/store/stress_test_new_cache.py | 149 ++ src/levanter/store/tree_store.py | 237 +++ src/levanter/tracker/wandb.py | 10 +- src/levanter/trainer.py | 35 +- src/levanter/utils/background_iterable.py | 85 +- src/levanter/utils/fsspec_utils.py | 6 + src/levanter/utils/index.py | 46 + src/levanter/utils/jax_utils.py | 4 +- src/levanter/utils/py_utils.py | 7 - src/levanter/utils/ray_utils.py | 4 +- src/levanter/utils/thread_utils.py | 28 + tests/__init__.py | 0 tests/test_audio.py | 52 +- tests/test_background_iterable.py | 70 +- tests/test_checkpoint.py | 7 - tests/test_data_mixture.py | 126 -- tests/test_doremi.py | 85 +- tests/test_in_progress_sequence.py | 124 -- tests/test_jagged_array.py | 305 ++++ tests/test_llama.py | 3 - tests/test_lora.py | 5 +- tests/test_mixture.py | 155 ++ tests/test_new_cache.py | 921 ++++++++++ ...eplicated_loader.py => 
test_new_loader.py} | 143 +- tests/test_newdataset.py | 142 ++ tests/test_prp.py | 87 + tests/test_shard_cache.py | 383 ----- tests/test_sharded_dataset.py | 4 +- tests/test_sharded_loader.py | 299 ---- tests/test_shuffle_dataset.py | 30 - tests/test_text.py | 30 +- tests/test_tokenized_document_cache.py | 216 --- tests/test_tree_store.py | 435 +++++ tests/test_utils.py | 12 +- tests/tiny_test_corpus.py | 20 +- 81 files changed, 6691 insertions(+), 4455 deletions(-) create mode 100644 config/llama2_small_fast_mix.yaml delete mode 100644 scripts/repair_cache.py delete mode 100644 src/levanter/data/_process_interleave.py create mode 100644 src/levanter/data/_prp.py create mode 100644 src/levanter/data/permutation.py delete mode 100644 src/levanter/data/shard_cache.py rename src/levanter/data/{sharded_dataset.py => sharded_datasource.py} (89%) create mode 100644 src/levanter/store/__init__.py create mode 100644 src/levanter/store/cache.py create mode 100644 src/levanter/store/jagged_array.py create mode 100644 src/levanter/store/stress_test_new_cache.py create mode 100644 src/levanter/store/tree_store.py create mode 100644 src/levanter/utils/index.py create mode 100644 src/levanter/utils/thread_utils.py create mode 100644 tests/__init__.py delete mode 100644 tests/test_data_mixture.py delete mode 100644 tests/test_in_progress_sequence.py create mode 100644 tests/test_jagged_array.py create mode 100644 tests/test_mixture.py create mode 100644 tests/test_new_cache.py rename tests/{test_replicated_loader.py => test_new_loader.py} (62%) create mode 100644 tests/test_newdataset.py create mode 100644 tests/test_prp.py delete mode 100644 tests/test_shard_cache.py delete mode 100644 tests/test_sharded_loader.py delete mode 100644 tests/test_shuffle_dataset.py delete mode 100644 tests/test_tokenized_document_cache.py create mode 100644 tests/test_tree_store.py diff --git a/.dockerignore b/.dockerignore index 17fbbcfe1..45dfa95e6 100644 --- a/.dockerignore +++ b/.dockerignore @@ -2,6 +2,7 @@ scratch cache +new-cache wandb checkpoints diff --git a/.github/workflows/run_entry_tests.yaml b/.github/workflows/run_entry_tests.yaml index 9ab96773e..ab08013ee 100644 --- a/.github/workflows/run_entry_tests.yaml +++ b/.github/workflows/run_entry_tests.yaml @@ -21,7 +21,7 @@ jobs: run: | python -m pip install --upgrade pip pip install flake8 pytest - pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" + pip install .[test] "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" pip install soundfile librosa - name: Run entry tests with pytest run: | diff --git a/.github/workflows/run_ray_tests.yaml b/.github/workflows/run_ray_tests.yaml index c82611793..42139e576 100644 --- a/.github/workflows/run_ray_tests.yaml +++ b/.github/workflows/run_ray_tests.yaml @@ -21,8 +21,8 @@ jobs: run: | python -m pip install --upgrade pip pip install flake8 pytest - pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" + pip install .[test] "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" pip install soundfile librosa - name: Run ray tests with pytest run: | - XLA_FLAGS=--xla_force_host_platform_device_count=8 PYTHONPATH=$(pwd)/tests:$(pwd)/src:$(pwd):. pytest tests -m ray + PYTHONPATH=$(pwd)/tests:$(pwd)/src:$(pwd):. 
pytest tests -m ray diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index a6d4f7cab..6e9ed7024 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -21,7 +21,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" + pip install .[test] "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" pip install -r ./tests/requirements.txt - name: Test with pytest run: | diff --git a/config/data/redpajama_1b_source.yaml b/config/data/redpajama_1b_source.yaml index 1a873ed9a..aaa817399 100644 --- a/config/data/redpajama_1b_source.yaml +++ b/config/data/redpajama_1b_source.yaml @@ -3,4 +3,3 @@ cache_dir: gs://levanter-data/tokenized/redpajama-sample/ tokenizer: EleutherAI/gpt-neox-20b splits: - train -rows_per_chunk: 32768 diff --git a/config/data/redpajama_1t_source.yaml b/config/data/redpajama_1t_source.yaml index f70f7c192..4a4b29474 100644 --- a/config/data/redpajama_1t_source.yaml +++ b/config/data/redpajama_1t_source.yaml @@ -3,4 +3,3 @@ cache_dir: gs://levanter-data/tokenized/redpajama/ tokenizer: EleutherAI/gpt-neox-20b splits: - train -rows_per_chunk: 4096 diff --git a/config/data/rpv1_llama.yaml b/config/data/rpv1_llama.yaml index 92a46b50c..75a7b7ff2 100644 --- a/config/data/rpv1_llama.yaml +++ b/config/data/rpv1_llama.yaml @@ -1,5 +1,4 @@ cache_dir: gs://levanter-data/tokenized/redpajama_v1_llama_mixture -rows_per_chunk: 4096 tokenizer: "meta-llama/Llama-2-7b-hf" configs: arxiv: diff --git a/config/gpt2_nano_mixture.yaml b/config/gpt2_nano_mixture.yaml index 673187312..2939b9e5e 100644 --- a/config/gpt2_nano_mixture.yaml +++ b/config/gpt2_nano_mixture.yaml @@ -7,7 +7,7 @@ data: id: dlwh/wikitext_103_detokenized train_weights: wikitext: 1.0 - w2: 0 + w2: 1.0 model: type: gpt2 hidden_dim: 32 diff --git a/config/llama2_small_fast_mix.yaml b/config/llama2_small_fast_mix.yaml new file mode 100644 index 000000000..aabd17fae --- /dev/null +++ b/config/llama2_small_fast_mix.yaml @@ -0,0 +1,163 @@ +data: + tokenizer: "meta-llama/Llama-2-7b-hf" + cache_dir: "gs://levanter-data/new-tokenized/pile_mix/" + shuffle: + era_length: 10000 + configs: + arxiv: + train_urls: + - gs://levanter-data/pile-domains/arxiv/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/arxiv/val.jsonl.zst + books2: + train_urls: + - gs://levanter-data/pile-domains/books2/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/books2/val.jsonl.zst + books3: + train_urls: + - gs://levanter-data/pile-domains/books3/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/books3/val.jsonl.zst + dm_math: + train_urls: + - gs://levanter-data/pile-domains/dm_math/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/dm_math/val.jsonl.zst + enron: + train_urls: + - gs://levanter-data/pile-domains/enron/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/enron/val.jsonl.zst + europarl: + train_urls: + - gs://levanter-data/pile-domains/europarl/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/europarl/val.jsonl.zst + free_law: + train_urls: + - gs://levanter-data/pile-domains/freelaw/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/freelaw/val.jsonl.zst + github: + train_urls: + - gs://levanter-data/pile-domains/github/{00..29}.jsonl.zst + validation_urls: + - 
gs://levanter-data/pile-domains/github/val.jsonl.zst + hackernews: + train_urls: + - gs://levanter-data/pile-domains/hackernews/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/hackernews/val.jsonl.zst + nih: + train_urls: + - gs://levanter-data/pile-domains/nih/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/nih/val.jsonl.zst + opensubtitles: + train_urls: + - gs://levanter-data/pile-domains/opensubtitles/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/opensubtitles/val.jsonl.zst + owt2: + train_urls: + - gs://levanter-data/pile-domains/owt2/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/owt2/val.jsonl.zst + pg_19: + train_urls: + - gs://levanter-data/pile-domains/pg_19/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/pg_19/val.jsonl.zst + philpapers: + train_urls: + - gs://levanter-data/pile-domains/philpapers/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/philpapers/val.jsonl.zst + pile_cc: + train_urls: + - gs://levanter-data/pile-domains/pile_cc/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/pile_cc/val.jsonl.zst + pubmed_abs: + train_urls: + - gs://levanter-data/pile-domains/pubmed_abs/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/pubmed_abs/val.jsonl.zst + pubmed_central: + train_urls: + - gs://levanter-data/pile-domains/pubmed_central/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/pubmed_central/val.jsonl.zst + stack_exchange: + train_urls: + - gs://levanter-data/pile-domains/stack_exchange/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/stack_exchange/val.jsonl.zst + ubuntu_irc: + train_urls: + - gs://levanter-data/pile-domains/ubuntu_irc/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/ubuntu_irc/val.jsonl.zst + uspto: + train_urls: + - gs://levanter-data/pile-domains/uspto/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/uspto/val.jsonl.zst + wiki_en: + train_urls: + - gs://levanter-data/pile-domains/wiki_en/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/wiki_en/val.jsonl.zst + youtube_subtitles: + train_urls: + - gs://levanter-data/pile-domains/youtube_subtitles/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/youtube_subtitles/val.jsonl.zst + train_weights: + # these weights come from the paper https://arxiv.org/pdf/2101.00027.pdf + pile_cc: 0.1811 + pubmed_central: 0.1440 + books3: 0.1207 + owt2: 0.1001 + arxiv: 0.0896 + github: 0.0759 + free_law: 0.0612 + stack_exchange: 0.0513 + uspto: 0.0365 + pubmed_abs: 0.0307 + pg_19: 0.0217 + opensubtitles: 0.0155 + wiki_en: 0.0153 + dm_math: 0.0124 + ubuntu_irc: 0.0088 + books2: 0.0075 + europarl: 0.0073 + hackernews: 0.0062 + youtube_subtitles: 0.0060 + philpapers: 0.0038 + nih: 0.0030 + enron: 0.0014 +model: + type: llama + hidden_dim: 768 + intermediate_dim: 2048 + num_heads: 6 + num_kv_heads: 6 + num_layers: 12 + seq_len: 1024 + gradient_checkpointing: true +trainer: + tracker: + project: "levanter" + tags: [ "pile", "llama", "itest"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + + train_batch_size: 256 + num_train_steps: 20000 +optimizer: + learning_rate: 1E-3 + weight_decay: 0.1 + warmup: 0.01 diff --git a/docs/Fine-Tuning.md b/docs/Fine-Tuning.md index 0b7cecd58..ebe0377ea 100644 --- a/docs/Fine-Tuning.md +++ b/docs/Fine-Tuning.md @@ -406,7 +406,7 @@ def 
_get_data_source(path_or_id): if fsspec_utils.exists(path_or_id): return JsonDataset([path_or_id]) else: - return levanter.data.dataset_from_hf(path_or_id, split="train") + return levanter.data.datasource_from_hf(path_or_id, split="train") ``` Preprocessing in Levanter typically happens in two phases: diff --git a/docs/Training-On-Your-Data.md b/docs/Training-On-Your-Data.md index 381ab9dc3..3fac85a07 100644 --- a/docs/Training-On-Your-Data.md +++ b/docs/Training-On-Your-Data.md @@ -106,10 +106,6 @@ data: ### Mixture of Sources -!!! warning - - This feature is experimental and may change in the future. - If you have multiple sources of data (e.g., multiple domains, or distinct subsets of data), you can use the `data` section of your training configuration to specify them: ```yaml @@ -145,14 +141,14 @@ validation data. ## Data Preprocessing Levanter supports both online and offline preprocessing. Online preprocessing is done on-the-fly -during training. With online preprocessing, you don't need to think about preprocessing your data. +during training. With online preprocessing, you don't need to think about preprocessing your data +except to make sure it's in the right format and where you'd like to store the cached preprocessing +results. Our data loading pipeline will automatically break and concatenate documents into chunks equal to the model's `seq_len` parameter. It will also automatically add special tokens to the end of documents. -We don't yet handle sequence-to-sequence tasks, but we plan to. - ### Online Preprocessing We have a sophisticated caching mechanism using [Ray](https://docs.ray.io/en/latest/) @@ -160,8 +156,7 @@ that builds a cache of preprocessed data on the fly. Online caching happens tran in the background, using the mostly-idle CPU-cores of the machine(s) you are training on. The cache that is built is fully reproducible, and can be used for future training runs. -Training will start as soon as each training machine has its first shard of data cached -and once the validation data is cached. +Training will start as soon as the system has the data it needs. ### Offline Preprocessing @@ -190,19 +185,28 @@ python -m levanter.main.cache_dataset \ ### Direct Cache Construction As a final option, you can directly construct a cache of preprocessed data without using Ray. This is useful if you -have custom preprocessing logic or Ray isn't working for you for some reason. To do so, you can use [levanter.data.SerialCacheWriter][] +have custom preprocessing logic or Ray isn't working for you for some reason. To do so, you can use [levanter.store.SerialCacheWriter][] to write batches directly. Here's an example: ```python -from levanter.data import SerialCacheWriter +import numpy as np + +from levanter.store import SerialCacheWriter + +exemplar = { + "input_ids": np.zeros((0), dtype=np.int32), + "attention_mask": np.zeros((0), dtype=np.int32), + "labels": np.zeros((0), dtype=np.int32), +} -with SerialCacheWriter(cache_dir, rows_per_chunk=1024) as writer: +with SerialCacheWriter(cache_dir, exemplar) as writer: for batch in process_batches(): + # batch should be a list of dicts, each with keys "input_ids", "attention_mask", and "labels" writer.write_batch(batch) ``` -`batch` can be a `list[dict]`, `dict[list]`, or `pyarrow.RecordBatch`. To work with `train_lm`, it should have an -`input_ids` key that is a list of `int`s. +In this case, `batch` should be a list of dicts, each with keys `"input_ids"`, `"attention_mask"`, and `"labels"`. 
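For concreteness, here is a minimal sketch of such a batch (the key names follow the exemplar above; the token values are made up purely for illustration):

```python
import numpy as np

# A hypothetical batch of two already-tokenized examples, matching the exemplar's keys and dtypes.
batch = [
    {
        "input_ids": np.array([101, 2023, 2003, 102], dtype=np.int32),
        "attention_mask": np.array([1, 1, 1, 1], dtype=np.int32),
        "labels": np.array([2023, 2003, 102, -100], dtype=np.int32),
    },
    {
        "input_ids": np.array([101, 2061, 102], dtype=np.int32),
        "attention_mask": np.array([1, 1, 1], dtype=np.int32),
        "labels": np.array([2061, 102, -100], dtype=np.int32),
    },
]

# writer.write_batch(batch)  # called inside the `with SerialCacheWriter(...)` block above
```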
+To work with `train_lm`, it should have an `input_ids` key that is a list of `int`s. To use a cache like this, you can use the `passthrough` tokenizer: diff --git a/docs/design/Data-Loader-Design.md b/docs/design/Data-Loader-Design.md index e9386a762..b000e93c0 100644 --- a/docs/design/Data-Loader-Design.md +++ b/docs/design/Data-Loader-Design.md @@ -1,254 +1,174 @@ # Data Loader Design -## Design as of 2023-04-18 -### Goals +## Context -We want to support the following: -1) Deterministic batches, even for a changing number of readers (or writers). That is, for any cluster size -during training, we want the same batches to be generated in the same order. -2) Sharded reading and writing. We want to be able to read and write from multiple shards in parallel. -3) Simultaneous reading and writing of shards. We want to be able to start training while we are still building the cache. -4) Fast resumption without losing too much progress. This applies to both *writing* and *reading* the cache. That is, when we - resume a training run, we want to finish producing the cache and also jump to the right place in the cache for reads. -5) (eventually) shuffling/random access -6) We want to be able to build the cache offline too. -7) We want to support batches that are composed of fragments of documents. In particular, we take a moving window of tokens - from documents. This implies that the mapping from "documents" to "batches" is not 1:1, or easy to compute. +Levanter, like any LM training framework, needs to read (usually text) data to feed it to the model. This +process involves reading lots of raw text, tokenizing it, and splitting it up into model-sized chunks. +Unlike many other ML workloads, the mapping from raw data to model-sized chunks is not 1:1, but in general +many-to-many. This is because we typically take a moving window of tokens from a list of documents. -We want to support the following use cases: -1) We have a larger training dataset, and we want to draw samples from it more or less independently on a large number of machines. - We don't really care about "epochs"/"passes", but we do want to be able to handle resumes and be deterministic. Ideally, each - machine only reads from the chunks that it needs to read from. -2) We have a smaller validation dataset, and we want to do a single pass over it. We don't care about resuming, and it's ok if -we have to read the whole dataset on each machine. -3) (Eventually) Like (1) but we want to jump around the dataset. We still care about resuming and determinism, but don't care about epochs. +Levanter is designed to be completely deterministic, meaning that if you run the same code on the same data on +the same hardware, you should get the same results. This is important for debugging and for reproducibility. +In order to guarantee determinism, our data loading pipeline must be deterministic as well. +Moreover, to the extent possible, we want deterministic batch order even if the number of machines changes. -We focus on (1) and (2) for now. +Data is usually stored in compressed shards, each morally equivalent to an iterator over a list of documents. +In particular, we don't usually have random access. This implies that we need to produce a cache of processed +documents that does allow for random access. Random access is important for resuming training quickly, +as well as for shuffling. +Early on in Levanter's development, we made the decision to support "quick start" training, where we can start +training while we are still building the cache. 
This is helpful when iterating on the data pipeline +and removes a step from the training process. This implies that we need to support simultaneous reading and writing +of the cache. -## Some terminology +Levanter also wants to support dynamic mixtures of data, where we reweight different datasets on the fly. To do so, +we need separate caches for each dataset. -* **Shard**: A shard is a list of *raw* documents that not been tokenized/preprocessed. -* **Chunk**: A chunk is a list of *processed* documents that have been tokenized/preprocessed. -* **Reader**: A reader is a process that reads from the cache. Typically there is one reader per machine. -* **Writer**: A writer is a process that writes to the cache. Typically there is one writer per machine. -* **Global ordering**: The global ordering is the ordering of chunks in the cache. This is the order in which - documents are read by readers. The global ordering is defined with respect to an "idealized" number of readers R*. (See below.) -* **Processor** or **Tokenizer**: A function that takes a raw document and returns a processed document. -* **Example** is a single datum that is fed into the model. Examples are typically composed of fragments of documents. - For example, we might take a moving window of tokens from the concatenation of a list of preprocessed documents. - - -We say there are K input shards, W writers, R readers. We assume K >= W (though typically K is not too large), and W ≈ R. -We produce N chunks. We also define an idealized number of readers R*, which defines the global ordering over the data. -Typically R* should be the maximum number of readers we expect to actually use. - - -## Cache structure -We define a shard cache as a list of "chunks", where each chunk is a parquet file (plus metadata) with an equal -number of documents (except for the last chunks for each shard.) -Each chunk is a list of processed documents. Chunks are ordered round robin from the input shards, so that the c'th global chunk is the -c%K'th chunk of the c/K'th shard, so long as all shards have at least c/K chunks. (After that, we remove shards that -have been exhausted and continue round robin.) -We keep the following metadata: -* For each shard, we keep a list of chunks written so far and whether or not we are done processing that shard. -* For each chunk, we keep the number of documents, token counts/length of various fields, and the number of bytes. - (This metadata can be used for seeking.) -* For the cache overall, we keep the global ordering of chunks, the number of chunks, and the number of documents. - -### Chunk format - -A Chunk is an Apache Parquet file with schema dependent on the task. For example, for language modeling, we might have -just a sequence of input_ids per document. We use Apache Parquet because it's compact and doesn't require us to know -much about the datatypes we're using. - -Chunks also have metadata stored in a separate json file. This metadata includes the total number of documents in the -chunk, as well as token counts/lengths of various fields. This metadata is used for seeking. - -## Cache construction - -We use Ray to handle the construction of the cache. There are 4 types of processes/actors that we create using Ray: - -* A ChunkCacheBroker actor, whose job is to dispense chunks to readers while the cache is being built. It is also - responsible for keeping track of the global ordering of chunks. -* A ChunkCacheBuilder actor, which is responsible for building the cache. 
It forks off processes for processing - input shards. It acts as a callback for these processes, and accepts chunk metadata from them. -* Shard writer processes, one per input shard. The function _produce_cache_for_shard is the entry point for these processes. - This function is responsible for reading from the input shard and forking off processes to process chunks of documents. -* Chunk processor processes, which are responsible for processing documents and creating chunks. _produce_chunk is the - entry point for these processes. - -Readers are managed by the model training processes, which read by sending requests to the broker via the Ray. They -are not themselves Ray actors/processes. +In practice, even for the relatively "small" examples one has in LM training (compared to vision, for example), +we also want to do sharded loading. -## Reproducible Sharded Reading for Training +## Goals -We want to be able to read from the cache in a way that is deterministic and reproducible, even if the number of readers -changes. We also want readers to only read from the chunks that they need to read from. -We pretend the list of data is infinite by cycling. We do track epochs when reading this way. +To summarize: -NB Our goal is a deterministic ordering over examples, and not merely chunks or even documents. +* **Deterministic batches**: For any cluster size during training, we want the same batches to be + generated in the same order. +* **Instant Resume**: We want training to be able to resume training quickly, without losing too much progress. +* **Quick Start**: Unless it is logically impossible (e.g. for shuffling), we want to be able to start training + while we are still building the cache. +* **Random Access**: We want to be able to jump around the dataset, for shuffling and for resuming. -Given a list of chunks and the idealized number of readers R*, we define the global ordering over chunks as follows: -First define R* iterators over chunks, with `chunk_iterators[r]` being defined as `loop(all_chunks)[r::R*]`. +## Cache Design -Next, define a function `mk_examples(chunk_iterator)` that takes a list of iterators over chunks and returns -a list of examples. Define `chunk_examples[r] = mk_examples(chunk_iterators[r])`. -This function depends on our sequence length, etc. Then the ordering over examples is: +### Terminology -`chunk_examples[0][0], chunk_examples[1][0], ..., chunk_examples[R*-1][0], ..., chunk_examples[0][1], chunk_examples[1][1], ..., chunk_examples[R*-1][1], ...` -that is, `example[i] == chunk_examples[i % R*][i // R*]` - -If we have $R*$ readers, then each `reader_iterator[r][j] == chunk_examples[r][j] == example[j * R* + r]`. 
-Moreover, if either R or R* is a multiple of the other, then we still get a nice property where -each reader reads from a strided slice of the chunk_iterators: - -(Boring math) -* If we have R readers, then `reader_iterator[r][j] == example[j * R + r] == chunk_examples[(j * R + r) % R*][(j * R + r) // R*]` -* If we have `R == n * R*`, then `reader_iterator[r][j] == example[j * R + r] == chunk_examples[(j * R + r) % R*][(j * R + r) // R*] - == chunk_examples[r % R*][(j * n * R* + r) // R*] == chunk_examples[r % R*][j * n + r // R*],` so each reader reads from -a strided slice (specifically `islice(..., r//R*, None, n)`) -* If we have `R* == n * R`, then `reader_iterator[r][j] == example[j * R + r] == chunk_examples[(j * R + r) % R*][(j * R + r) // R*] -== chunk_examples[R * (j % n) + r][(j * R + r) // R*]` and so each reader reads from n different chunk_exampless. -so we round-robin over a slice of the chunk_examples. - -For other cases (R and R* don't divide each other), there's no simple relationship between the reader and chunk iterators -and you end up reading from everywhere, but that's ok. - -# Single-Pass Reading for Evaluation -When we want to do a single pass over the data, we don't cycle and we don't shuffle. We just read the data in order. Boring -and simple. - - -## Resuming - -We need to think about resuming in two cases: resuming writes and resuming reads. - -## Resuming Writes - -Resuming writes is relatively easy, since we can just keep track of the number of chunks written for each shard and the -number of documents written for each chunk. Then you just skip to the appropriate document and start writing. - -### Resuming Reads - -We want to understand how to seek to the b'th batch. +* **Document**: A document is a single datum that is fed into the model. Documents are typically tokenized and + preprocessed. For LM training, a document is just a string, but for other tasks, it might be more complex. For example, + there might be a "prompt" and a "response". +* **Shard**: A shard is a list of *raw* documents that not been tokenized/preprocessed. +* **Cache**: A cache is a list of *processed* documents that have been tokenized/preprocessed. These are stored as +a group of [TensorStore](https://google.github.io/tensorstore/) arrays structured to behave like a column store. These +arrays are stored as Zarray arrays, typically compressed. +* **Reader**: A reader is a process that reads from the cache. Typically, there is one reader per machine. +* **Writer**: A writer is a process that writes to the cache. Typically, there is one writer per *cache*. +* **Global ordering**: Each document in a cache is assigned a global index. This index is deterministic, but +a bit hard to compute a priori. +* **Processor** or **Tokenizer**: A function that takes a batch of raw documents and returns a batch of processed documents. +* **Example**: A single datum that is fed into the model. Examples are typically composed of fragments of documents. + For example, we might take a moving window of tokens from the concatenation of a list of preprocessed documents. +* **Ledger**: A ledger is a list of metadata about the cache. This includes the number of documents in each shard +as well as some information to make it less likely that you accidentally reuse a cache. -There are two cases of resuming we need to think about: -1) The "easy" case where 1 example == 1 (preprocessed) document. 
-2) The "hard" case where the mapping from examples to documents is not 1:1, but there is some easily computable relationship. +### Cache Structure -In the first case, each reader `r` reads `documents[r::R]`. The `b`th batch -is `documents[b * batch_size:(b+1) * batch_size]`. Assuming `batch_size % R == 0`, then for the b'th batch, reader r -needs to read `documents[b * batch_size + r: (b+1) * batch_size + r: R] == docs(chunk_iterator[r])[b * batch_size // R:(b+1) * batch_size // R]`. -If we know how many documents are in each chunk, then we can seek to the right place in the chunk. +A cache is a [PyTree](https://jax.readthedocs.io/en/latest/pytrees.html) of [levanter.store.JaggedArray][]s, each +representing a different field of the processed documents. Each JaggedArray is a group of either two or three arrays: -The second case is broadly similar. In particular, we consider the case where we take moving windows of concatenated documents. -If our metadata includes token counts, then we can skip chunks until we pass `batch_size * tokens_per_example // R` tokens. +* **Data**: The actual data, stored as a Zarray array. All the "tokens" for a given field for all documents are stored in a single flat array. +* **Offsets**: The offsets into the data array for each document. This is a 1-D array of length N+1, where N is the number of documents. +* **Shape** (optional): The shape of the data for each document. This is only present for fields that are not 1-D. +For tokenized documents, a cache looks like this: -## Shuffling +``` +cache +├── train +│ ├── input_ids +│ │ ├── data +│ │ │ ├── c +│ │ │ │ └── 0 +│ │ │ └── zarr.json +│ │ └── offsets +│ │ ├── c +│ │ │ └── 0 +│ │ └── zarr.json +│ ├── shard_ledger.json +``` -### A brief digression -Why do we want to shuffle during training? Shuffling reduces variance in the gradients. If we have batches -where every example is from the same document/domain, then the gradients for those batches will be correlated. +(Typically there's a lot more files in the `c` directories, but I've omitted them for brevity.) -That said, in our setting where we use moving windows from documents, if we round-robin from chunks (which are produced -from different documents), and R* is roughly equal to the batch size, then we will read from a different chunk for every -example in a batch, which reduces correlation within a batch. +The stuff under `input_ids/data` is the actual data, and the stuff under `input_ids/offsets` is the offsets. -However, we still have (undesirable) correlation between batches: if we -read from chunks consecutively and our documents are long, then many examples in the next batch will be from the -same document as an example in the previous batch. Ideally this wouldn't happen. I'm not convinced that it matters -that much. +In code, this is modeled in [levanter.store.TreeStore][]. -Proper shuffling is incompatible with streaming at a fundamental level. Our choices are something like: +### Cache Construction -* Randomly shuffle before preprocessing. Makes life a bit less pleasant for people with a new dataset. Can't be changed after preprocessing. Doesn't solve the problem of correlated batches. -* Reservoir sampling. Makes resumes hard, but is easy to implement. -* "Epochal" reservoir sampling, where we periodically "flush" the reservoir. Resumes are easier because you can start from the latest "epoch" -* No shuffling in the first pass, but shuffle in subsequent passes. -* Shuffle within a range of chunks that grows as the run progresses. 
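To make the JaggedArray layout described above concrete, here is a minimal sketch of how a single document is recovered from the flat `data` array and the `offsets` array. Plain NumPy arrays stand in for the TensorStore/Zarr-backed arrays used by the real implementation in `src/levanter/store/jagged_array.py`:

```python
import numpy as np

# Stand-in for one field of a cache (e.g. input_ids): all tokens for every document
# live in one flat array, and offsets[i]:offsets[i+1] delimits document i.
data = np.array([10, 11, 12, 20, 21, 30, 31, 32, 33], dtype=np.int32)
offsets = np.array([0, 3, 5, 9], dtype=np.int64)  # N+1 entries for N=3 documents

def get_document(i: int) -> np.ndarray:
    """Random access to document i: one offset lookup plus one contiguous read."""
    return data[offsets[i] : offsets[i + 1]]

assert get_document(1).tolist() == [20, 21]
assert len(offsets) - 1 == 3  # number of documents
```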
+We use Ray to handle the construction of the cache. There are 4 types of processes/actors that we create using Ray: -My hunch is that we can skip this for now, and revisit if we find that it's a problem. +- `_TreeStoreCacheBuilder`: This actor is responsible for building the cache. It forks off actors for reading + shards and processing documents. It acts as a callback for these processes. +- `_OrderedCacheWriter`: This actor is responsible for writing to the cache. It is responsible for writing the + processed documents to the cache in the correct order. +- `WorkQueueDispatcherActor`: This actor is responsible for reading batches of documents from a group of shards. It dispatches + documents to a group of processors, which are responsible for processing the documents. +- `_BatchProcessorQueue`: This actor is responsible for managing the queue of batches of documents to be processed. It + actually calls the processors to process the documents and then forwards the results to the writer. +The basic flow is that the builder forks off a bunch of `WorkQueueDispatcherActor`s, which read from the shards and +dispatch the documents to the processors. The processors process the documents and send the results to the writer, +which writes them to the cache. -## Current Status as of 2022-10-10 +The writer is responsible for writing the documents to the cache in the correct order. In particular, fix a batch +size B. The writer writes the documents in batches of size B, round-robin from the shards. Once a shard is exhausted, +it is removed from the list of shards. -The current data loader (in levanter/data/text.py and levanter/data/sharded.py) works as follows: +The writer maintains a "ledger" of the cache, which has the number of documents processed in each shard, as well as +whether or not the shard is done. This ledger is used for resuming cache construction. -### TokenizedDocumentCache -* We build a TokenizedDocumentCache, which creates a (user-specified) number of shards (default 128 for training). Documents are tokenized (via an HF tokenizer) and written to the cache in batches of 1000 (by default), with each batch being written to the *smallest* shard. -* The underlying format is a Parquet file, for text data this means a sequence of input_ids stored in a batched columnar layout and compressed -* When you iterate through the TokenizedDocumentCache, it reads the shards in a round-robin fashion, and yields batches of documents, as they were written. -* It can optionally "flatten" a batch of documents into a single doc (which are delimited by eos), which is what we do with TokenSeqDataset. +## Datasets and the Data Loader +Along with the cache, we introduce interfaces and classes for working with the cache. The main classes are: -### TokenSeqDataset -* At load time, a TokenizedDocumentCache is typically wrapped in an TokenSeqDataset, which just wraps the -TokenizedDocumentCache and sets a max_seq_len. This is the fundamental data structure that is used by the data loader. -* The TokenSeqDataset iterates through the TokenizedDocumentCache one batch at a time. The docs are (implicitly) -concatenated together. If a concatenated doc is longer than max_seq_len, then it is split into chunks of max_seq_len. Any left over at the end of a batch is currently thrown out, matching Mistral's behavior. +- [levanter.data.AsyncDataset][]: This is the base class for all datasets. 
The main method it exposes is + `get_batch(indices: Sequence[int])` which (asynchronously) returns a batch of documents for the given indices. +- [levanter.data.DataLoader][]: This is a class that wraps a dataset and provides an iterator over the dataset. It prefetches + the parts of batches that each machine needs. It has an iterator and supports "seeking" to a particular batch. +- [levanter.store.TreeCache][]: This is an AsyncDataset that wraps a cache and exposes a `get_batch` method that returns + a batch of documents for the given indices. +- [levanter.data.TokenSeqDataset][]: This is an async dataset that does the chunking of documents into examples. It + takes a cache and a `max_seq_len` and returns examples of length `max_seq_len`. +- [levanter.data.PermutationDataset][]: This is a dataset that permutes the indices of another dataset. It is used for shuffling. +- [levanter.data.EraShufflingDataset][]: This is a dataset that emulates the behavior of a shuffle buffer, while + still supporting random access. It is used for shuffling while still building the cache. +- [levanter.data.MixtureDataset][]: This is a dataset that mixes together multiple datasets with different weights. -### ShardedTokenSeqDataset +### [levanter.data.PermutationDataset][] -* Recall that model computation is done by creating a 2-D grid of TPUs, with the first axis being "data" and the other being "model". All devices on the same row process the same slice of a batch. Typically a row does not span multiple nodes, but usually a node will have multiple rows. -* We can conceptually group the rows into "row groups" such that either a row group is just 1 row, or it spans all rows that are on the same node. -* The job of the ShardedTokenSeqDataset is to shard the TokenSeqDataset into a number of shards and loads the data so that each row gets its own data. Each row group of the 2-d mesh is assigned a shard (i.e. a set of cache files) that it loads from exclusively. -* For each batch, a node reads however many examples it needs to fill its row group. We then create a GlobalDeviceArray which orchestrates the shards together. +The PermutationDataset is a dataset that permutes the indices of another dataset. It is backed by a pseudo-random +permutation (PRP). PRPs give you random access to a permutation with O(1) time and memory. -### Misc notes / problems -* There's no randomness anywhere. -* If documents are very long, this means we're reading from the same doc repeatedly for a batch, which is not ideal. -* We get bitwise determinism so long as the grid configuration doesn't change. -* Because we write to the smallest shard, with a large enough dataset, we should have roughly the same number of tokens in each shard, but it won't be exact. -* Because of the above (and because I didn't know how to do it) we don't have a way for one process to signal that it's done. So we just loop the dataset forever. This isn't ideal for evaluation, if nothing else. -* We haven't implemented seeking in the DataLoader, so resumes are expensive. This is not super hard in principle, but it's not implemented. -* Mentioning again that we drop the last batch of a shard if it's not full. This is not ideal. We should pad it and return masks. -## Resumable, Streaming Dataset with good randomness +### [levanter.data.EraShufflingDataset][] +The EraShufflingDataset is a dataset that emulates the behavior of a shuffle buffer, while still supporting random access.
+It works by defining an "era" length, which is the number of samples that are shuffled together. After an era is exhausted, +the dataset shuffles the next era. -Goal: a streaming dataset that is: -1. disk-seek efficient (meaning we don't jump to a random position in a random shard for every sample) -2. reasonably random, including robust to long documents. -3. resumable (meaning it's relatively cheap to resume if a run crashes) -4. doesn't eat too much disk -5. fully replicable with same configuration (meaning that if you run the same code on the same data, you get the same results) -6. (stretch) fully replicable with different configurations (meaning that if you run the same code on the same data, you get the same results, even if you change the number of nodes) +### [levanter.data.MixtureDataset][] -It's easy to get (1) with streaming, and (2) by jumping to random offsets for every sample. Shuffle buffers get you (1) -and (2) together, but only if documents aren't too long. (3) comes easily if you do streaming OR random jumping -constantly, but is a bit painful with a shuffle buffer. You can get (1), (2) and (3) if you are willing to lay out the -entire shuffled dataset on disk for every epoch. But that's not ideal. +We implement "stable mixtures" where the number of samples from each domain for each batch is fixed. This acts +as a kind of variance reduction, while also enabling random access and sampling without replacement. -For (1)-(4), we take a middle path: we stream into a shuffle buffer, but we jump to random offsets after every K samples. Moreover, -we serialize the shuffle buffer, the positions of the datasets, and the random seed to disk when we checkpoint, so that we can resume easily. +Note: I believe it's impossible to sample without replacement and have random access with sampled batches. +This is because for each item `i`, you sample a domain `d_i`, but you can't know which indices in the domain have +been sampled. With replacement is easy so long as you know how big each domain is ahead of time, which means +you can't do streaming. -(5) is easy if you serialize the relevant state, or can just restart your iterators deterministically. -(6) is hard to do in a sharded way. It's easy to "scale down" by emulating a larger number of nodes with a smaller -number of nodes, but it's hard to "scale up". To do this, we can think of each row as having its own stream of data, -perhaps sliced out of a larger stream? TODO for version 3 +## Performance -### Tasks ### Reading from the Cache -#### TokenSeqDataset -* [] supports seek that jumps to a random offset in a random shard in the TokenizedDocumentCache -* [] can report current position for serialization -* [] can resume from a position +TensorStore can sustain high throughput but has pretty terrible latency (when hitting GCS). -#### JumpingDataset +The latency can be on the order of a second. We mitigate this by prefetching the data in the DataLoader. -* [] has a policy for jumping around in an TokenSeqDataset -* [] has a random key and a position within the policy -* [] can report key and current position for serialization +With prefetching we can sustain about a million tokens per second per host, which is sufficient. +In particular, when training a GPT-2 small model on a v3-32, loading is able to keep up with training. +However, 3/4 of evaluation time is spent blocked on loading data, so we could potentially speed up evaluation. +(However it's still twice as fast as with the old cache and data loader.)
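A minimal sketch of the era-shuffling idea described above, using a per-era seeded NumPy permutation (the actual `EraShufflingDataset` in `src/levanter/data/permutation.py` builds on the pseudo-random permutation in `src/levanter/data/_prp.py` instead of materializing each era's permutation):

```python
import numpy as np

def era_shuffled_index(idx: int, era_length: int, seed: int) -> int:
    """Map a logical index to a shuffled index, permuting only within the era that contains it.

    Simplification: assumes every era is full; the last, partial era needs special handling.
    """
    era = idx // era_length
    rng = np.random.default_rng([seed, era])  # deterministic, distinct generator per era
    perm = rng.permutation(era_length)        # permutation of positions within this era
    return era * era_length + int(perm[idx % era_length])

# Indices 0..9 shuffle among themselves, 10..19 among themselves, and so on.
print([era_shuffled_index(i, era_length=10, seed=0) for i in range(20)])
```

Because the mapping depends only on the era an index falls in, later (possibly not-yet-written) eras never need to be consulted, which is what lets this kind of shuffling coexist with an in-progress cache.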
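Similarly, the "stable mixture" scheme above can be pictured as fixing per-domain counts for every batch up front. The following is a sketch of one way to round weights to integer counts (largest remainder); the exact assignment `MixtureDataset` uses may differ:

```python
def stable_batch_counts(weights: dict[str, float], batch_size: int) -> dict[str, int]:
    """Fix how many examples each domain contributes to every batch."""
    total = sum(weights.values())
    ideal = {name: w / total * batch_size for name, w in weights.items()}
    counts = {name: int(v) for name, v in ideal.items()}
    # hand the remaining slots to the domains with the largest fractional parts
    leftover = batch_size - sum(counts.values())
    for name in sorted(ideal, key=lambda n: ideal[n] - counts[n], reverse=True)[:leftover]:
        counts[name] += 1
    return counts

# e.g. the two equally weighted sources in gpt2_nano_mixture.yaml: every batch of 32 takes 16 from each
print(stable_batch_counts({"wikitext": 1.0, "w2": 1.0}, batch_size=32))
```

Because every batch draws the same counts from the same domains, batch composition never depends on a random draw, which is what makes the mixture both lower-variance and randomly accessible.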
-#### ShuffleBufferDataset -* [] has a shuffle buffer of size ≤N -* [] has a random key -* [] can report key and shuffle buffer for serialization +### Writing to the Cache -#### Misc -* [] dataset hierarchy that exposes the interfaces we need (tree_leaves probably for serialization?) -* [] serialize the dataset on all nodes. This logic might need to be a bit different than for models, since the models all use GDAs and only write plain old arrays once. -* [] make sure we can resume from a checkpoint with bitwise determinism as before +Writes are also slow, but we also batch up the writes, typically writing 8K documents at a time. diff --git a/examples/alpaca-lora/alpaca_lora.py b/examples/alpaca-lora/alpaca_lora.py index 9488809ba..000b5a715 100644 --- a/examples/alpaca-lora/alpaca_lora.py +++ b/examples/alpaca-lora/alpaca_lora.py @@ -124,7 +124,7 @@ def loraize_hf_model(model): # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large # datasets. We use replicated here since the dataset is small. - loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) + loader = trainer.data_loader(train_dataset, trainer.TrainBatch) loader = non_caching_cycle(loader) if int(state.step) != 0: diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py index 6578bc46c..a2201de76 100644 --- a/examples/alpaca/alpaca.py +++ b/examples/alpaca/alpaca.py @@ -15,8 +15,7 @@ import levanter from levanter.compat.hf_checkpoints import HFCheckpointConverter, save_hf_checkpoint_callback -from levanter.data import Dataset -from levanter.data.sharded_dataset import JsonDataset, JsonlDataset, WrappedHFDataset +from levanter.data import PermutationDataset from levanter.models.lm_model import LmExample, LmHeadModel, compute_next_token_loss from levanter.optim import OptimizerConfig from levanter.trainer import Trainer, TrainerConfig @@ -100,53 +99,21 @@ class TrainArgs: hf_save_steps: int = 1000 # How often to save the HuggingFace checkpoint. -# Encoder/Decoder dataset for Alpaca. -# We basically do string interpolation of the (instruction, input, output) triples with the prompt, -# and mask out the prompt and padding. -class SupervisedDataset(Dataset[LmExample]): - def __init__(self, preproc_dataset, tokenizer, mask_inputs): - self.preproc_dataset = preproc_dataset - self.tokenizer = tokenizer - self.mask_inputs = mask_inputs - - def __iter__(self): - for ex in self.preproc_dataset: - # annoyingly, pad expects things to be batched so we have to prepend a batch axis - ex = self.tokenizer.pad( - {k: np.expand_dims(v, 0) for k, v in ex.items()}, return_tensors="np", padding="max_length" - ) - ex = {k: v[0] for k, v in ex.items()} - input_ids = hax.named(ex["input_ids"], "position") - - # mask out padding and anything before the start of the target - Pos = input_ids.resolve_axis("position") - if self.mask_inputs: - loss_mask = hax.arange(Pos) >= ex["source_lens"] - - # don't predict the padding - targets = hax.roll(input_ids, -1, Pos) - loss_mask = loss_mask & (targets != self.tokenizer.pad_token_id) - else: - loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.float32) - - yield LmExample(input_ids, loss_mask) - - def _get_data_source(path_or_id): """The original alpaca.py used a json file, but it's since been moved to the HF dataset hub. 
You can use any dataset that's compatible with the structure of the alpaca dataset.""" if fsspec_utils.exists(path_or_id): # we're a bit generous here b/c we support compression if ".jsonl" in path_or_id: - return JsonlDataset([path_or_id]) + return levanter.data.datasource_from_jsonl([path_or_id]) elif ".json" in path_or_id: - return JsonDataset([path_or_id]) + return levanter.data.datasource_from_json([path_or_id]) else: raise ValueError( f"We only support HF Datasets or a data file with .json or .jsonl extensions, not {path_or_id}!" ) else: - return WrappedHFDataset(path_or_id, split="train") + return levanter.data.datasource_from_hf(path_or_id, split="train") def mk_dataset(config: TrainArgs, tokenizer: transformers.PreTrainedTokenizerBase): @@ -175,12 +142,37 @@ def format_example(ex): "source_lens": source_lens, } - dataset = dataset.map_batches(preprocess, batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer)) - dataset = dataset.build_or_load_cache(config.data_cache_dir, await_finished=True) - - dataset = SupervisedDataset(dataset, tokenizer, mask_inputs=config.mask_inputs) + dataset = dataset.map_batches(preprocess, batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer)) # type: ignore + dataset = dataset.build_or_load_cache(config.data_cache_dir, await_finished=True) # type: ignore + + def _prepare_example(ex: dict) -> LmExample: + """ + Prepare an example for training. This function converts the (cached) batch encoding into an LmExample. + + It goes through the following steps: + + 1. Pad the batch to the maximum length. + 2. Mask out the input and prompt if requested. + 3. Create an LmExample with the input_ids as the input and the next token as the target. + """ + # annoyingly, pad expects things to be batched so we have to prepend a batch axis + ex = tokenizer.pad({k: np.expand_dims(v, 0) for k, v in ex.items()}, return_tensors="np", padding="max_length") + ex = {k: v[0] for k, v in ex.items()} + input_ids = hax.named(ex["input_ids"], "position") + # mask out padding and anything before the start of the target + Pos = input_ids.resolve_axis("position") + if config.mask_inputs: + loss_mask = hax.arange(Pos) >= ex["source_lens"] + + # don't predict the padding + targets = hax.roll(input_ids, -1, Pos) + loss_mask = loss_mask & (targets != tokenizer.pad_token_id) + else: + loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.float32) + lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask) + return lm_ex - return dataset + return dataset.map(_prepare_example) def get_prompts(prompt_path) -> dict: @@ -208,7 +200,7 @@ def train(config: TrainArgs): ) # Randomness in JAX is tightly controlled. We pass around a key that is used to generate random numbers. - training_key = jrandom.PRNGKey(config.trainer.seed) + training_key, data_key = jrandom.split(jrandom.PRNGKey(config.trainer.seed), 2) # This is largely the same as in Alpaca. Only change is we use the fast tokenizer. tokenizer = transformers.AutoTokenizer.from_pretrained( @@ -224,6 +216,7 @@ def train(config: TrainArgs): converter = converter.replaced(tokenizer=tokenizer) train_dataset = mk_dataset(config, tokenizer) + train_dataset = PermutationDataset(train_dataset, data_key) optimizer = config.optimizer.build(config.trainer.num_train_steps) @@ -243,11 +236,12 @@ def train(config: TrainArgs): # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for # single pass training. 
Sharded only loads a subset of the data on each device, and is more efficient for large # datasets. We use replicated here since the dataset is small. - loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) + loader = trainer.data_loader(train_dataset, trainer.TrainBatch) loader = non_caching_cycle(loader) state = trainer.initial_state(training_key, model=model) + # TODO: remove this. we don't need it now if int(state.step) != 0: logger.info(f"Resuming training from step {state.step}") for i in range(state.step): diff --git a/examples/gsm8k-lora/gsm8k_lora.py b/examples/gsm8k-lora/gsm8k_lora.py index b7ac3945c..97a9c06ab 100644 --- a/examples/gsm8k-lora/gsm8k_lora.py +++ b/examples/gsm8k-lora/gsm8k_lora.py @@ -14,9 +14,8 @@ import haliax as hax import levanter -from levanter.data import Dataset -from levanter.data.dataset import ShuffleDataset -from levanter.data.sharded_dataset import WrappedHFDataset +from levanter.data import PermutationDataset +from levanter.data.sharded_datasource import WrappedHFDataSource from levanter.lora import ( LoraConfig, lora_trainable_params_filter, @@ -67,37 +66,8 @@ class TrainArgs: merged_hf_upload: Optional[str] = None -class SupervisedDataset(Dataset[LmExample]): - def __init__(self, preproc_dataset, tokenizer, mask_inputs): - self.preproc_dataset = preproc_dataset - self.tokenizer = tokenizer - self.mask_inputs = mask_inputs - - def __iter__(self): - for ex in self.preproc_dataset: - # annoyingly, pad expects things to be batched so we have to prepend a batch axis - ex = self.tokenizer.pad( - {k: np.expand_dims(v, 0) for k, v in ex.items()}, return_tensors="np", padding="max_length" - ) - ex = {k: v[0] for k, v in ex.items()} - input_ids = hax.named(ex["input_ids"], "position") - - # mask out padding and anything before the start of the target - Pos = input_ids.resolve_axis("position") - if self.mask_inputs: - loss_mask = hax.arange(Pos) >= ex["source_lens"] - - # don't predict the padding - targets = hax.roll(input_ids, -1, Pos) - loss_mask = loss_mask & (targets != self.tokenizer.pad_token_id) - else: - loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.float32) - - yield LmExample.causal(input_ids, loss_mask=loss_mask) - - def mk_dataset(config: TrainArgs, tokenizer: transformers.PreTrainedTokenizerBase): - dataset = WrappedHFDataset("gsm8k", split="train", name="main") + dataset = WrappedHFDataSource("gsm8k", split="train", name="main") def preprocess(batch): def format_example(ex): @@ -125,9 +95,34 @@ def format_output(ex): dataset = dataset.map_batches(preprocess, batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer)) # type: ignore dataset = dataset.build_or_load_cache(config.data_cache_dir, await_finished=True) # type: ignore - dataset = SupervisedDataset(dataset, tokenizer, mask_inputs=config.mask_inputs) # type: ignore + def _prepare_example(ex: dict) -> LmExample: + """ + Prepare an example for training. This function converts the (cached) batch encoding into an LmExample. + + It goes through the following steps: + + 1. Pad the batch to the maximum length. + 2. Mask out the input and prompt if requested. + 3. Create an LmExample with the input_ids as the input and the next token as the target. 
+ """ + # annoyingly, pad expects things to be batched so we have to prepend a batch axis + ex = tokenizer.pad({k: np.expand_dims(v, 0) for k, v in ex.items()}, return_tensors="np", padding="max_length") + ex = {k: v[0] for k, v in ex.items()} + input_ids = hax.named(ex["input_ids"], "position") + # mask out padding and anything before the start of the target + Pos = input_ids.resolve_axis("position") + if config.mask_inputs: + loss_mask = hax.arange(Pos) >= ex["source_lens"] + + # don't predict the padding + targets = hax.roll(input_ids, -1, Pos) + loss_mask = loss_mask & (targets != tokenizer.pad_token_id) + else: + loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.float32) + lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask) + return lm_ex - return dataset + return dataset.map(_prepare_example) def train(config: TrainArgs): @@ -151,7 +146,7 @@ def train(config: TrainArgs): data_key = jrandom.PRNGKey(config.data_seed) train_dataset = mk_dataset(config, tokenizer) - train_dataset = ShuffleDataset(train_dataset, data_key, buffer_size=1000 * 1000) + train_dataset = PermutationDataset(train_dataset, data_key) optimizer = config.optimizer.build(config.trainer.num_train_steps) @@ -196,7 +191,7 @@ def loraize_hf_model(model): # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large # datasets. We use replicated here since the dataset is small. - loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) + loader = trainer.data_loader(train_dataset, trainer.TrainBatch) loader = non_caching_cycle(loader) if int(state.step) != 0: diff --git a/pyproject.toml b/pyproject.toml index c94ec5a6a..8712d16a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ ] dependencies = [ "haliax>=1.4.dev307", - "equinox>=0.11.4", + "equinox==0.11.4", "jaxtyping>=0.2.20", "tokenizers>=0.15.2", "transformers>=4.41.2", @@ -49,6 +49,7 @@ dependencies = [ "rich~=13.0", "filelock~=3.13", # "ai2-olmo", + "async-lru~=2.0", ] [project.urls] diff --git a/scripts/repair_cache.py b/scripts/repair_cache.py deleted file mode 100644 index 49f9e36c6..000000000 --- a/scripts/repair_cache.py +++ /dev/null @@ -1,60 +0,0 @@ -import os -import os.path -from dataclasses import dataclass -from typing import List - -import fsspec -import pyarrow -from tqdm import tqdm - -import levanter -from levanter.data.shard_cache import LEDGER_FILE_NAME, CacheLedger, ChunkMetadata, _serialize_json_and_commit - - -@dataclass -class RepairCacheArgs: - cache_path: str - - -@levanter.config.main() -def main(args: RepairCacheArgs): - """Repairs a broken cache by recreating the ledger""" - for split in ["train", "validation"]: - # find train files in the dir, which can be in cloud - fs = fsspec.get_fs_token_paths(args.cache_path)[0] - paths = os.path.join(args.cache_path, split, "*.parquet") - files = fs.glob(paths) - - # We're basically redoing this, but without the old ledger: - chunks: List[ChunkMetadata] = [] - - pbar = tqdm(files) - total_input_ids = 0 - for file in pbar: - file = f"gs://{file}" - table = pyarrow.parquet.read_metadata(file) - - input_ids = 0 - for g in range(table.num_row_groups): - input_ids += table.row_group(g).column(0).statistics.num_values - - file = file.replace(os.path.join(args.cache_path, split), "").lstrip("/") - - chunks.append( - ChunkMetadata( - name=file.replace(".parquet", ""), - num_rows=table.num_rows, - 
field_counts={"input_ids": input_ids}, - ) - ) - - total_input_ids += input_ids - - pbar.set_postfix(num_rows=table.num_rows, total_input_ids=total_input_ids) - - ledger = CacheLedger(chunks=chunks) - _serialize_json_and_commit(os.path.join(args.cache_path, split, LEDGER_FILE_NAME), ledger) - - -if __name__ == "__main__": - main() diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 406a7b39a..21aaf5faa 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -8,13 +8,14 @@ import threading import time import warnings -from typing import Callable, Iterable, Optional +from typing import Callable, Optional import humanfriendly import jax from tqdm import tqdm import levanter.tracker +from levanter.data import DataLoader from levanter.logging import save_xla_dumps_to_wandb from levanter.tracker.helpers import log_optimizer_hyperparams from levanter.tracker.wandb import WandbConfig @@ -69,7 +70,7 @@ def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, n def compute_validation_loss( loss_fn: Callable, # [[M, ...], jax.numpy.ndarray], - dataset: Iterable, + dataset: DataLoader, max_batches: Optional[int] = None, name: Optional[str] = None, ): diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index 7802a7f07..b102198d7 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -5,7 +5,6 @@ import os import pathlib import queue -import sys import threading import time import urllib.parse @@ -110,7 +109,7 @@ def __init__( self._manager = GlobalAsyncCheckpointManager(timeout_secs=60 * 30) if jax.process_index() == 0: - self._async_checkpoint_remover_queue: queue.Queue[str] = queue.Queue() + self._async_checkpoint_remover_queue: queue.Queue[str] = queue.Queue(maxsize=-1) self._async_checkpoint_remover_thread = threading.Thread( target=self._async_checkpoint_remover, daemon=True ) @@ -224,7 +223,7 @@ def wait_until_finished(self): def _rm_checkpoint(self, checkpoint): if jax.process_index() == 0: - print(f"Removing checkpoint {checkpoint}", file=sys.stderr, flush=True) + logger.info(f"Removing checkpoint {checkpoint}") self._async_checkpoint_remover_queue.put(checkpoint) def _do_rm_checkpoint(self, checkpoint): diff --git a/src/levanter/compat/hf_checkpoints.py b/src/levanter/compat/hf_checkpoints.py index 5727f4360..ce267041c 100644 --- a/src/levanter/compat/hf_checkpoints.py +++ b/src/levanter/compat/hf_checkpoints.py @@ -975,7 +975,7 @@ def select_if_missing(missing_leaf, new_value): else: return None - return jax.tree_map(select_if_missing, dtype_structs, new_model, is_leaf=lambda x: x is None) + return jax.tree.map(select_if_missing, dtype_structs, new_model, is_leaf=lambda x: x is None) new_buffers = _init_buffers() diff --git a/src/levanter/data/__init__.py b/src/levanter/data/__init__.py index 534ec6dbf..85d99f8ab 100644 --- a/src/levanter/data/__init__.py +++ b/src/levanter/data/__init__.py @@ -1,22 +1,24 @@ -from levanter.data.dataset import Dataset, ShardableDataset, ShuffleDataset -from levanter.data.loader import BatchLoader, ReplicatedBatchLoader, ShardedBatchLoader -from levanter.data.shard_cache import SerialCacheWriter, ShardCache, build_or_load_cache -from levanter.data.sharded_dataset import ShardedDataset, dataset_from_hf, dataset_from_jsonl -from levanter.data.utils import batched +from ._preprocessor import BatchProcessor +from .dataset import AsyncDataset, ListAsyncDataset, MappedAsyncDataset, SyncDataset +from .loader import DataLoader +from .mixture import MixtureDataset, 
StopStrategy +from .permutation import EraShufflingDataset, PermutationDataset +from .sharded_datasource import ShardedDataSource, datasource_from_hf, datasource_from_json, datasource_from_jsonl +from .utils import batched __all__ = [ "batched", - "Dataset", - "ShardableDataset", - "ShuffleDataset", - "BatchLoader", - "ReplicatedBatchLoader", - "ShardedBatchLoader", - "build_or_load_cache", - "ShardCache", - "ShardedDataset", - "SerialCacheWriter", - "dataset_from_hf", - "dataset_from_jsonl", + "ShardedDataSource", + "datasource_from_hf", + "datasource_from_jsonl", + "datasource_from_json", + "BatchProcessor", + "AsyncDataset", + "MappedAsyncDataset", + "SyncDataset", + "ListAsyncDataset", + "DataLoader", + "MixtureDataset", + "StopStrategy", ] diff --git a/src/levanter/data/_preprocessor.py b/src/levanter/data/_preprocessor.py index 08e287c54..9ee1e2dc2 100644 --- a/src/levanter/data/_preprocessor.py +++ b/src/levanter/data/_preprocessor.py @@ -17,7 +17,7 @@ """ -class BatchProcessor(Generic[T_contra], ABC): +class BatchProcessor(Generic[T_contra, U], ABC): """ A BatchProcessor is the main interface for preprocessing data. It takes a batch of data and returns a batch of processed data. It can be used to tokenize data, convert it to a RecordBatch, or do any other kind of preprocessing. @@ -25,7 +25,7 @@ class BatchProcessor(Generic[T_contra], ABC): """ @abstractmethod - def __call__(self, batch: Sequence[T_contra]) -> BatchResult: + def __call__(self, batch: Sequence[T_contra]) -> Sequence[U] | U: # U can be batched "structure of arrays" form """ Process a batch of data. You should return either a RecordBatch, a sequence of dicts (one per output example), or a dict of sequences (one per output field). @@ -34,6 +34,14 @@ def __call__(self, batch: Sequence[T_contra]) -> BatchResult: """ raise NotImplementedError + @property + @abstractmethod + def output_exemplar(self): + """ + An exemplar of what this processor returns. This is used to determine the output schema of a dataset. + """ + raise NotImplementedError + @property def resources(self) -> Dict[str, float]: """Any resources that this processor needs to run. 
Ray uses this to schedule tasks.""" @@ -113,7 +121,7 @@ def _construct_composite_batch_processor(dataset): """ def rec(dataset): - from levanter.data.sharded_dataset import _TransformedDataset + from levanter.data.sharded_datasource import _TransformedDataset if isinstance(dataset, _TransformedDataset): source, transforms, batch_transform = rec(dataset.source) @@ -165,6 +173,10 @@ def num_gpus(self): def resources(self): return self._resources + @property + def output_exemplar(self): + return self.transforms[-1].output_exemplar + def __call__(self, batch): # batch is initially a list of elements, but after a BatchMapTransform # it can be a recordbatch, dict of lists, or list of dicts @@ -196,7 +208,7 @@ def __call__(self, batch): return batch -def dict_from_record_batch(b): +def dict_from_record_batch(b) -> dict: # we follow the convention from hf batchencoding where homogeneous-lengthed arrays are turned into nd arrays # while heterogeneous lists are left as lists of arrays diff --git a/src/levanter/data/_process_interleave.py b/src/levanter/data/_process_interleave.py deleted file mode 100644 index b4586d130..000000000 --- a/src/levanter/data/_process_interleave.py +++ /dev/null @@ -1,338 +0,0 @@ -import asyncio -import heapq -from typing import Generic, Optional, Sequence, TypeVar - -import ray - - -G = TypeVar("G") -T = TypeVar("T") - - -# this is what we want: -# shards.permute().group(G).flatmap_interleaved(f, num_workers) # produces an iterator over T - - -# TODO: can we work with this? - -# def flatmap_interleaved(f, iterable, *, num_workers, ray_remote_args=None): -# """Apply f to each element of iterable, returning an interleaved list of results. -# -# Args: -# f: A function to apply to each element of iterable. Should return an iterator -# iterable: An iterable of elements to apply f to. -# num_workers: The number of workers to use. -# -# Returns: -# iterator over the results of applying f to each element of iterable, interleaving the results -# """ -# iterable = list(enumerate(iterable)) -# # group the elements by worker -# grouped = [iterable[i::num_workers] for i in range(num_workers)] -# -# sink = RoundRobinSink.remote(range(len(iterable))) -# -# results = [_compute_round_robin.options(**(ray_remote_args or {})).remote(f, group, sink) for group in grouped] -# ray.get(results) -# -# return sink._buffer.drain() -# -# -# @ray.remote -# def _compute_round_robin(f, groups, sink): -# serials = [0] * len(groups) -# emitters = [(group_id, f(group)) for group_id, group in groups] -# done_emitters = set() -# -# while len(done_emitters) < len(groups): -# for idx in range(len(groups)): -# group_id, emitter = emitters[idx] -# if group_id in done_emitters: -# continue -# item = next(emitter, None) -# if item is None: -# done_emitters.add(group_id) -# emitters[idx] = (group_id, None) -# del emitter -# sink.group_total_known(group_id, serials[group_id]) -# else: -# sink.append_to_group(group_id, serials[group_id], item) -# serials[group_id] += 1 - - -@ray.remote -class RoundRobinSink: - def __init__(self, groups): - self._buffer = GroupRoundRobinBuffer(groups) - - def append_to_group(self, group, item_serial, item): - self._buffer.append_to_group(group, item_serial, item) - - def group_total_known(self, group, total): - self._buffer.group_total_known(group, total) - - -class GroupRoundRobinBuffer(Generic[G, T]): - """ - A buffer that holds items from multiple groups and returns them in a round-robin fashion. - The groups need not have the same number of items. 
If a group is exhausted, it is removed from the rotation. - """ - - def __init__(self, groups: Sequence[G]): - self.groups = list(groups) - self._current_group = 0 - self.buffers: dict[G, list[tuple[int, T]]] = {group: [] for group in groups} - self._remaining_groups = set(groups) - self._totals_written: dict[G, int] = {group: 0 for group in groups} - self._totals_expected: dict[G, Optional[int]] = {group: None for group in groups} - - def append_to_group(self, group: G, item_serial: int, item: T): - if group not in self.groups: - raise ValueError(f"Group {group} not in {self.groups}") - - if group not in self._remaining_groups: - raise ValueError(f"Group {group} already finished") - - heapq.heappush(self.buffers[group], (item_serial, item)) - - def group_total_known(self, group: G, total: int): - if group not in self.groups: - raise ValueError(f"Group {group} not in {self.groups}") - - if group not in self._remaining_groups: - raise ValueError(f"Group {group} already finished: {total} vs {self._totals_expected[group]}") - - self._totals_expected[group] = total - - if self._totals_written[group] == total: - assert len(self.buffers[group]) == 0 - self._remaining_groups.remove(group) - - def is_finished(self): - return len(self._remaining_groups) == 0 - - def pop(self) -> Optional[T]: - group = self._next_group_to_read_from() - if group is None: - return None - - if len(self.buffers[group]) == 0: - return None - - cur_serial, item = self.buffers[group][0] - - if cur_serial != self._totals_written[group]: - return None - - heapq.heappop(self.buffers[group]) - - self._totals_written[group] += 1 - - if self._totals_written[group] == self._totals_expected[group]: - assert len(self.buffers[group]) == 0 - assert group in self._remaining_groups - self._remaining_groups.remove(group) - - self._current_group = (self._current_group + 1) % len(self.groups) - - return item - - def drain(self) -> list[T]: - items = [] - while True: - item = self.pop() - if item is None: - break - items.append(item) - - return items - - def _next_group_to_read_from(self): - if len(self._remaining_groups) == 0: - return None - - while True: - group = self.groups[self._current_group] - if group not in self._remaining_groups: - assert self._totals_written[group] == self._totals_expected[group] - assert len(self.buffers[group]) == 0 - self._current_group = (self._current_group + 1) % len(self.groups) - else: - break - return group - - -_SENTINEL = object() - - -class _BoxedError: - def __init__(self, exc): - self.exc = exc - - def __repr__(self): - return f"BoxedError({self.exc})" - - def __str__(self): - return f"BoxedError({self.exc})" - - def __eq__(self, other): - return isinstance(other, _BoxedError) and self.exc == other.exc - - def __hash__(self): - return hash(self.exc) - - -def _is_internal_item(item): - return item is _SENTINEL or isinstance(item, _BoxedError) - - -class InProgressSequence(Generic[T]): - def __init__(self): - self._buffer: list = [] - self._total_added = 0 - self._promises: dict[int, asyncio.Future] = {} - self._finished_length: Optional[int] = None - self._finished_promise = asyncio.Future() - - def append(self, item: T): - if self._finished_length is not None and len(self._buffer) >= self._finished_length: - raise IndexError("Index out of range") - self._buffer.append(item) - self._total_added += 1 - self._fulfill_promise(len(self._buffer) - 1) - - def to_list(self): - if not self.is_finished(): - raise ValueError("Not finished") - return list(self._buffer) - - def set_item(self, idx: int, item: 
T): - # self._buffer.append(item) - # return self._fulfill_promises() - - if idx < 0: - raise IndexError("Negative indices not supported") - - if self._finished_length is not None and idx >= self._finished_length: - raise IndexError("Index out of range") - - if idx >= len(self._buffer): - self._buffer.extend([_SENTINEL] * (idx - len(self._buffer) + 1)) - - if self._buffer[idx] is _SENTINEL: - self._total_added += 1 - - self._buffer[idx] = item - self._fulfill_promise(idx) - - def item_exception(self, idx: int, exc: Exception): - if idx < 0: - raise IndexError("Negative indices not supported") - - if self._finished_length is not None and idx >= self._finished_length: - raise IndexError("Index out of range") - - promise = self._promises.pop(idx, None) - if promise is not None: - promise.set_exception(exc) - - if idx >= len(self._buffer): - self._buffer.extend([_SENTINEL] * (idx - len(self._buffer) + 1)) - - self._buffer[idx] = _BoxedError(exc) - - self.set_exception(exc) - - def set_finished_length(self, length): - if self._finished_length is not None: - raise ValueError("Finished length already set") - self._finished_length = length - return self._flush_promises() - - def set_exception(self, exc: Exception): - if not self._finished_promise.done(): - self._finished_promise.set_exception(exc) - for promise in self._promises.values(): - promise.set_exception(exc) - - self._promises.clear() - - def is_finished(self): - return self._finished_length is not None and len(self._buffer) == self._finished_length - - @property - def finished_promise(self): - return self._finished_promise - - def final_length(self): - return self._finished_length - - def current_length(self): - return len(self._buffer) - - def get_promise(self, idx): - if idx < 0: - raise IndexError("Negative indices not supported") - - if self._finished_length is not None and idx >= self._finished_length: - raise IndexError("Index out of range") - - if self._finished_promise.done() and self._finished_promise.exception(): - return self._finished_promise - - if idx < len(self._buffer): - promise = asyncio.Future() - result = self._buffer[idx] - if isinstance(result, _BoxedError): - promise.set_exception(result.exc) - return promise - elif result is not _SENTINEL: - promise.set_result(result) - return promise - - if idx in self._promises: - return self._promises[idx] - - promise = asyncio.Future() - self._promises[idx] = promise - return promise - - def finalize(self): - if self._finished_length is None: - self._finished_length = len(self._buffer) - self._flush_promises() - - assert ( - self._total_added == self._finished_length - ), f"Finalize called with {self._total_added} != {self._finished_length}" - - async def get(self, idx): - if idx < len(self._buffer): - result = self._buffer[idx] - if isinstance(result, _BoxedError): - raise result.exc - elif result is not _SENTINEL: - return result - - return await self.get_promise(idx) - - def _fulfill_promise(self, idx): - promise = self._promises.pop(idx, None) - if promise is not None: - promise.set_result(self._buffer[idx]) - - if self._total_added == self._finished_length: - self._finished_promise.set_result(None) - - def _flush_promises(self): - assert self._finished_length is not None - - if self._total_added == self._finished_length: - self._finished_promise.set_result(None) - - for idx, promise in self._promises.items(): - if idx < self._finished_length: - if self._buffer[idx] is not _SENTINEL: - promise.set_result(self._buffer[idx]) - else: - promise.set_exception(IndexError("Index 
out of range")) diff --git a/src/levanter/data/_prp.py b/src/levanter/data/_prp.py new file mode 100644 index 000000000..65a86e66f --- /dev/null +++ b/src/levanter/data/_prp.py @@ -0,0 +1,63 @@ +import typing + +import jax.lax +import jax.numpy as jnp +import jax.random as jrandom +import numpy as np + + +# TODO: do we make this a pytree +class Permutation: + # Pseudo-Random Permutation Code + """A stateless pseudo-random permutation. + + This class generates a pseudo-random permutation of a given length. The permutation is generated using a PRNG + with a fixed key. The permutation is generated by finding a random `a` and `b` such that `gcd(a, length) != 1` and + then computing the permutation as `p(x) = (a * x + b) % length`. + + This is not a very good PRP, but it is probably good enough for our purposes. + """ + # TODO: is it actually good enough for our purposes? + + def __init__(self, length, prng_key): + self.length = length + self.prng_key = prng_key + a_key, b_key = jrandom.split(prng_key) + self._a = jrandom.randint(a_key, (), 1, length) + self._b = jrandom.randint(b_key, (), 0, length) + + cond = lambda a_and_key: jnp.all(jnp.gcd(a_and_key[0], length) != 1) + + def loop_body(a_and_key): + a, key = a_and_key + this_key, key = jrandom.split(key) + a = jrandom.randint(this_key, (), 1, length) + return a, key + + self._a, key = jax.lax.while_loop(cond, loop_body, (self._a, a_key)) + + self._a = int(self._a) + self._b = int(self._b) + + @typing.overload + def __call__(self, indices: int) -> int: + ... + + @typing.overload + def __call__(self, indices: jnp.ndarray) -> jnp.ndarray: + ... + + def __call__(self, indices): + if isinstance(indices, jnp.ndarray): + # TODO: use error_if? + # import equinox as eqx + if jnp.any(indices < 0) or jnp.any(indices >= self.length): + raise IndexError(f"index {indices} is out of bounds for length {self.length}") + elif isinstance(indices, np.ndarray): + if np.any(indices < 0) or np.any(indices >= self.length): + raise IndexError(f"index {indices} is out of bounds for length {self.length}") + else: + if indices < 0 or indices >= self.length: + raise IndexError(f"index {indices} is out of bounds for length {self.length}") + + return (self._a * indices + self._b) % self.length diff --git a/src/levanter/data/_queue.py b/src/levanter/data/_queue.py index b29327a83..fd8f84860 100644 --- a/src/levanter/data/_queue.py +++ b/src/levanter/data/_queue.py @@ -8,18 +8,18 @@ from queue import PriorityQueue from typing import List, Optional, Protocol, Sequence, TypeVar -import pyarrow as pa import ray from ray.actor import ActorHandle from levanter.utils.ray_utils import RefBox -from ._preprocessor import BatchProcessor, as_record_batch +from ._preprocessor import BatchProcessor logger = pylogging.getLogger(__name__) T = TypeVar("T") +U = TypeVar("U") LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" @@ -65,28 +65,20 @@ def __le__(self, other: "PriorityWorkItem"): return self.priority <= other.priority -def _mk_queue_aware_process_task(processor: BatchProcessor[T], queue: ActorHandle): +def _mk_queue_aware_process_task(processor: BatchProcessor[T, U], queue: ActorHandle): @ray.remote(num_cpus=processor.num_cpus, num_gpus=processor.num_gpus, resources=processor.resources) - def process_task(desc, batch: List[T]) -> pa.RecordBatch: - pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) + def process_task(desc, batch: List[T]): + # pylogging.basicConfig(level=DEFAULT_LOG_LEVEL, format=LOG_FORMAT) logger.debug(f"Processing batch {desc}") 
queue.task_running.remote() - # timer_thread = WaitTimeReportingThread( - # lambda t: logger.info(f"Waiting for {desc} to be processed for {t} seconds"), interval=30 - # ) - # timer_thread.start() try: result = processor(batch) - del batch - result = as_record_batch(result) logger.debug(f"Finished processing batch {desc}") return result except Exception as e: logger.exception(f"Error while processing batch {desc}") raise e finally: - # timer_thread.shutdown() - # timer_thread.join() pass return process_task @@ -120,7 +112,7 @@ class _BatchProcessorQueue: # (Generic[T]): ray doesn't like generics def batch_size(self): return self.processor.batch_size - def __init__(self, batch_processor: BatchProcessor[T]): + def __init__(self, batch_processor: BatchProcessor[T, U]): self.pqueue = PriorityQueue() self.processor = batch_processor self._next_task_id = 0 @@ -145,7 +137,10 @@ def _maybe_start_task(self): self.ready = False item = self.pqueue.get() batch = item.batch - item.task_future.set_result(self._task_processor.remote(item.desc, batch)) + try: + item.task_future.set_result(self._task_processor.remote(item.desc, batch)) + except Exception as e: + item.task_future.set_exception(e) def task_running(self): self.ready = True @@ -153,7 +148,7 @@ def task_running(self): @ray.remote(num_cpus=0.5, scheduling_strategy="SPREAD") -class PriorityProcessorActor: +class WorkQueueDispatcherActor: def __init__(self, max_in_flight: Optional[int] = 200): pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) self._queue: list[PriorityWorkItem] = [] # heapq @@ -162,9 +157,17 @@ def __init__(self, max_in_flight: Optional[int] = 200): self._current_item: Optional[PriorityWorkItem] = None self._max_in_flight = max_in_flight + self._max_priority: Optional[float] = None self._processing_thread = threading.Thread(target=self._loop, daemon=True) self._processing_thread.start() + def set_max_dispatch_priority(self, max_priority: Optional[float]): + """ + When the sink is full, we will not dispatch items with a priority higher than this. + """ + with self._queue_lock: + self._max_priority = max_priority + def assign_work(self, group: PriorityWorkTaskGroupSpec): items = group.build().items() with self._queue_lock: @@ -196,7 +199,7 @@ def shutdown(self): if self._processing_thread.is_alive(): self._processing_thread.join() - def _loop(self: "PriorityProcessorActor"): + def _loop(self: "WorkQueueDispatcherActor"): should_sleep = False backpressure_queue: list[ray.ObjectRef] = [] @@ -220,6 +223,10 @@ def drain_backpressure_to(count): should_sleep = False item = heapq.heappop(self._queue) + if self._max_priority is not None and item.priority > self._max_priority: + logger.debug(f"Item {item.name} has priority {item.priority} which is too high. 
Rescheduling.") + heapq.heappush(self._queue, item) + continue self._current_item = item try: diff --git a/src/levanter/data/audio.py b/src/levanter/data/audio.py index 9a1f98d93..d04479a24 100644 --- a/src/levanter/data/audio.py +++ b/src/levanter/data/audio.py @@ -19,17 +19,18 @@ from haliax import Axis from levanter.compat.hf_checkpoints import load_processor, load_tokenizer -from levanter.data._preprocessor import BatchProcessor, dict_from_record_batch -from levanter.data.dataset import ShardableDataset +from levanter.data import AsyncDataset +from levanter.data._preprocessor import BatchProcessor +from levanter.data.dataset import MappedAsyncDataset from levanter.data.metrics_monitor import LoggerMetricsMonitor, LoggingMetricsMonitor, MetricsMonitor -from levanter.data.shard_cache import DEFAULT_ROWS_PER_CHUNK, ShardCache, build_or_load_cache -from levanter.data.sharded_dataset import AudioTextUrlDataset, ShardedDataset, WrappedHFDataset +from levanter.data.sharded_datasource import AudioTextUrlDataSource, ShardedDataSource, WrappedHFDataSource from levanter.data.text import BatchTokenizer # intercept the logging nonsense here from levanter.logging import silence_transformer_nag from levanter.models.asr_model import AudioTextExample -from levanter.utils.jax_utils import use_cpu_device +from levanter.store.cache import TreeCache, build_or_load_cache +from levanter.utils.jax_utils import local_cpu_mesh silence_transformer_nag() # noqa @@ -44,15 +45,6 @@ logger = logging.getLogger("levanter.data.audio") -AudioTextStorageBatch = TypedDict( - "AudioTextStorageBatch", - { - "input_features": np.ndarray, - "input_ids": np.ndarray, - "attention_mask": np.ndarray, - "audio_shape": Sequence[Tuple[int, int]], - }, -) AudioTextDict = TypedDict( "AudioTextDict", { @@ -62,8 +54,14 @@ }, ) +AudioTextDict_exemplar = { + "input_features": np.zeros((1, 1), dtype=np.float32), + "input_ids": np.zeros((0,), dtype=np.int32), + "attention_mask": np.zeros((0,), dtype=np.int32), +} -class BatchAudioProcessor(BatchProcessor[Tuple[np.ndarray, int, str]]): + +class BatchAudioProcessor(BatchProcessor[Tuple[np.ndarray, int, str], AudioTextDict]): """ A batch processor that converts raw audio into the expected inputs of a model. """ @@ -81,7 +79,7 @@ def __init__( padding=True, ): self.feature_extractor: SequenceFeatureExtractor = processor.feature_extractor - self.bt: PreTrainedTokenizerBase = BatchTokenizer( + self.bt = BatchTokenizer( tokenizer, enforce_bos=enforce_bos, enforce_eos=enforce_eos, @@ -95,7 +93,7 @@ def __init__( self.override_resources = override_resources self._batch_size = batch_size - def __call__(self, batch: Sequence[Tuple[np.ndarray, int, str]]) -> AudioTextStorageBatch: + def __call__(self, batch: Sequence[Tuple[np.ndarray, int, str]]) -> Sequence[AudioTextDict]: """ Process a batch of data. 
""" @@ -106,15 +104,28 @@ def __call__(self, batch: Sequence[Tuple[np.ndarray, int, str]]) -> AudioTextSto uniq_sampling_rates: set[int] = set(sampling_rates) assert len(uniq_sampling_rates) == 1, "Sampling rates should be standardized" audio_features: BatchFeature = self.feature_extractor(audio_batch, sampling_rate=uniq_sampling_rates.pop()) - text_features: BatchEncoding = self.bt(text_batch) - combined_features = audio_features | text_features - combined_features["input_ids"] = np.array(combined_features["input_ids"]) - combined_features["attention_mask"] = np.array(combined_features["attention_mask"]) - a_features = np.array(combined_features["input_features"]) - a_shape = a_features.shape - combined_features["audio_shape"] = [a_shape[1:]] * a_shape[0] - combined_features["input_features"] = a_features.reshape(a_shape[0], -1) - return combined_features + audio_features["input_features"] = np.array(audio_features["input_features"]) + text_features: list[dict] = self.bt(text_batch) + text_features = [ + {k: np.array(tf[k], dtype=np.int32) for k in ["input_ids", "attention_mask"]} for tf in text_features + ] + + # debatch and return + out = [] + for i, text in enumerate(text_features): + out.append( + { + "input_features": audio_features["input_features"][i], + "input_ids": text["input_ids"], + "attention_mask": text["attention_mask"], + } + ) + + return out # type: ignore + + @property + def output_exemplar(self): + return AudioTextDict_exemplar @property def num_cpus(self) -> int: @@ -146,10 +157,10 @@ class AudioDatasetSourceConfig: train_urls: List[str] = () # type: ignore validation_urls: List[str] = () # type:ignore - def get_shard_source(self, split) -> Optional[ShardedDataset[Tuple[np.ndarray, int, str]]]: + def get_shard_source(self, split) -> Optional[ShardedDataSource[Tuple[np.ndarray, int, str]]]: if self.id is not None: try: - ds = WrappedHFDataset(self.id, split=split, name=self.name, streaming=self.stream) + ds = WrappedHFDataSource(self.id, split=split, name=self.name, streaming=self.stream) except ValueError as e: # if the message starts with Bad split, then just return None if str(e).startswith("Bad split"): @@ -164,7 +175,7 @@ def get_shard_source(self, split) -> Optional[ShardedDataset[Tuple[np.ndarray, i def decode(x): text = x[self.text_key] audio_pointer = x[self.audio_key] - audio = AudioTextUrlDataset.resolve_audio_pointer(audio_pointer, self.sampling_rate) + audio = AudioTextUrlDataSource.resolve_audio_pointer(audio_pointer, self.sampling_rate) return (audio["array"], audio["sampling_rate"], text) return ds.map(decode) @@ -172,7 +183,7 @@ def decode(x): split_urls = self.urls_for_split(split) if len(split_urls) == 0: return None - return AudioTextUrlDataset(split_urls, self.text_key, self.audio_key, sampling_rate=self.sampling_rate) + return AudioTextUrlDataSource(split_urls, self.text_key, self.audio_key, sampling_rate=self.sampling_rate) def doc_iterator(self, split: str) -> Iterator[Tuple[np.ndarray, int, str]]: if self.id is not None: @@ -182,7 +193,7 @@ def doc_iterator(self, split: str) -> Iterator[Tuple[np.ndarray, int, str]]: else: urls = self.urls_for_split(split) - yield from AudioTextUrlDataset(urls, self.text_key, self.audio_key, sampling_rate=self.sampling_rate) + yield from AudioTextUrlDataSource(urls, self.text_key, self.audio_key, sampling_rate=self.sampling_rate) def urls_for_split(self, split): if split == "train": @@ -211,7 +222,6 @@ class AudioTaskConfig(abc.ABC): train_split: str = "train" validation_split: str = "validation" cache_dir: str = 
"cache/" - rows_per_chunk: int = DEFAULT_ROWS_PER_CHUNK # number of rows to process and cache per chunk enforce_bos: bool = True # whether to append bos even if the tokenizer doesn't enforce_eos: bool = True # whether to append eos even if the tokenizer doesn't max_length: int = 448 @@ -232,64 +242,72 @@ def the_tokenizer(self) -> PreTrainedTokenizerBase: return load_tokenizer(self.tokenizer) @cached_property - def the_feature_extractor(self) -> PreTrainedTokenizerBase: + def the_feature_extractor(self) -> SequenceFeatureExtractor: return self.the_processor.feature_extractor @abc.abstractmethod def train_set( self, batch_size: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> ShardableDataset[np.ndarray]: + ) -> AsyncDataset[np.ndarray]: pass @abc.abstractmethod def validation_sets( - self, batch_size: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> Mapping[str, ShardableDataset[np.ndarray]]: + self, monitors: Union[bool, List[MetricsMonitor]] = True + ) -> Mapping[str, AsyncDataset[np.ndarray]]: pass -class ProcessedAudioCache(ShardableDataset[AudioTextStorageBatch]): +class ProcessedAudioCache(AsyncDataset[AudioTextDict]): """ Represents a cache of data with both pre-processed audio and tokenized text, which is a directory of parquet files with a ledger file. """ - def __init__(self, chunk_cache: ShardCache): - # Separates Batching For Processing from Batching For Training - self.chunk_cache = chunk_cache.with_batch_size(1) + def __init__(self, cache: TreeCache[AudioTextDict]): + self.cache = cache + + async def async_len(self) -> int: + return await self.cache.async_len() + + async def final_length_is_known(self) -> bool: + return await self.cache.final_length_is_known() + + def is_finite(self) -> bool: + return self.cache.is_finite() - def __iter__(self): - for batch in self._chunks(): - unarrow = dict_from_record_batch(batch) - # Flatten Singleton Batch Dimension - singleton_dict = {key: unarrow[key].squeeze() for key in unarrow} - singleton_dict["input_features"] = singleton_dict["input_features"].reshape(singleton_dict["audio_shape"]) - del singleton_dict["audio_shape"] - yield singleton_dict + async def current_len(self) -> Optional[int]: + return await self.cache.current_len() - def _chunks(self): - return self.chunk_cache.iter_batches_from_chunks() + async def get_batch(self, indices: Sequence[int]) -> Sequence[AudioTextDict]: + return await self.cache.get_batch(indices) + + # def _convert_to_example(self, storage: AudioTextStorageBatch) -> AudioTextDict: + # storage["input_features"] = storage["input_features"].reshape(storage["audio_shape"]) + # del storage["audio_shape"] + # return storage @staticmethod def build_or_load( cache_dir: str, - source: ShardedDataset[Tuple[np.ndarray, int, str]], + source: ShardedDataSource[Tuple[np.ndarray, int, str]], processor: ProcessorMixin, tokenizer: PreTrainedTokenizerBase, enforce_bos=True, enforce_eos=True, batch_size=128, - rows_per_chunk=DEFAULT_ROWS_PER_CHUNK, monitors=None, await_finished=True, override_resources=None, + max_length=448, ) -> "ProcessedAudioCache": - bp: BatchProcessor[Tuple[np.ndarray, int, str]] = BatchAudioProcessor( + bp = BatchAudioProcessor( processor, tokenizer, enforce_bos=enforce_bos, enforce_eos=enforce_eos, batch_size=batch_size, override_resources=override_resources, + max_length=max_length, ) monitors = monitors or [] cache = build_or_load_cache( @@ -297,8 +315,6 @@ def build_or_load( source, bp, await_finished=await_finished, - batch_size=batch_size, - 
rows_per_chunk=rows_per_chunk, monitors=monitors, ) if cache.is_finished: @@ -311,9 +327,9 @@ def build_or_load( return ProcessedAudioCache(cache) @staticmethod - def load(cache_dir, batch_size: int = 128): + def load(cache_dir): """ - Load a TokenizedDocumentCache from a directory. If the ledger file is not present, this will raise a + Load a ProcessedAudioCache from a directory. If the ledger file is not present, this will raise a FileNotFoundError. NOTE: ATM this attempts to migrate old caches to the new format, but this will be removed in the future. @@ -323,7 +339,7 @@ def load(cache_dir, batch_size: int = 128): """ try: - cache = ShardCache.load(cache_dir, batch_size=batch_size) + cache = TreeCache.load(cache_dir, AudioTextDict_exemplar) return ProcessedAudioCache(cache) except FileNotFoundError: raise FileNotFoundError(f"{cache_dir} is not a complete cache") @@ -331,15 +347,6 @@ def load(cache_dir, batch_size: int = 128): logger.exception("error loading cache") raise - def shard(self, shard_index, num_shards): - if num_shards <= shard_index: - raise ValueError(f"Shard index {shard_index} is out of range") - - if num_shards == 1: - return self - - return ProcessedAudioCache(self.chunk_cache.shard(shard_index, num_shards)) - @dataclass class AudioIODatasetConfig(AudioDatasetSourceConfig, AudioTaskConfig): @@ -351,16 +358,12 @@ def train_set(self, batch_size: int, monitors: Union[bool, List[MetricsMonitor]] raise ValueError("No training set!") return ds - def validation_set( - self, batch_size: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> Optional[ProcessedAudioCache]: - return self.build_or_load_cache(self.validation_split, batch_size=batch_size, monitors=monitors) + def validation_set(self, monitors: Union[bool, List[MetricsMonitor]] = True) -> Optional[ProcessedAudioCache]: + return self.build_or_load_cache(self.validation_split, monitors=monitors) - def validation_sets( - self, batch_size: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> Mapping[str, ProcessedAudioCache]: + def validation_sets(self, monitors: Union[bool, List[MetricsMonitor]] = True) -> Mapping[str, ProcessedAudioCache]: if self._has_validation_set: - validation_set = self.validation_set(batch_size, monitors) + validation_set = self.validation_set(monitors) if validation_set is not None: return {"": validation_set} return {} @@ -393,7 +396,7 @@ def build_or_load_cache( name = logger_name or os.path.basename(self.cache_dir) try: - return ProcessedAudioCache.load(split_cache_dir, batch_size=batch_size) + return ProcessedAudioCache.load(split_cache_dir) except FileNotFoundError: pass @@ -420,16 +423,16 @@ def build_or_load_cache( enforce_bos=self.enforce_bos, enforce_eos=self.enforce_eos, batch_size=batch_size, - rows_per_chunk=self.rows_per_chunk, monitors=monitors, await_finished=(split == "validation"), + max_length=self.max_length, ) -class AudioTextDataset(ShardableDataset[AudioTextExample]): +class AudioTextDataset(MappedAsyncDataset[AudioTextDict, AudioTextExample]): def __init__( self, - dataset: ShardableDataset[AudioTextStorageBatch], + dataset: AsyncDataset[AudioTextDict], TextPos: Axis, AudioPos: hax.AxisSelector, KPos: Axis, @@ -443,28 +446,23 @@ def __init__( self.key = key self.ignore_id = ignore_index - def shard(self, shard_id: int, num_shards: int) -> "AudioTextDataset": - return AudioTextDataset( - self.dataset.shard(shard_id, num_shards), - self.TextPos, - self.AudioPos, - self.KPos, - self.key, - self.ignore_id, - ) - - def __iter__(self) -> 
Iterator[AudioTextExample]: sharding = jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0]) - with use_cpu_device(): - - @functools.partial(eqx.filter_jit, out_shardings=sharding) - def _convert_example(inputs: AudioTextDict) -> "AudioTextExample": + @functools.partial(eqx.filter_jit, out_shardings=sharding) + def _convert_example(inputs: AudioTextDict) -> "AudioTextExample": + with local_cpu_mesh(): tokens = hax.named(inputs["input_ids"], self.TextPos) audio_features = hax.named(inputs["input_features"], self.AudioPos) - return AudioTextExample.init(audio_features, tokens, ignore_id=self.ignore_id) - for example in self.dataset: - converted_example = _convert_example(example) - yield converted_example + super().__init__(self.dataset, _convert_example) + + # def __iter__(self) -> Iterator[AudioTextExample]: + # + # + # with use_cpu_device(): + # + # + # for example in self.dataset: + # converted_example = _convert_example(example) + # yield converted_example diff --git a/src/levanter/data/dataset.py b/src/levanter/data/dataset.py index 14c8979b3..def0c158a 100644 --- a/src/levanter/data/dataset.py +++ b/src/levanter/data/dataset.py @@ -1,66 +1,356 @@ -from abc import ABC, abstractmethod -from typing import Iterable, Iterator, List, TypeVar +import abc +import asyncio +import logging +from concurrent.futures import ThreadPoolExecutor +from typing import Callable, Generic, Optional, Sequence, TypeVar -import jax.random as jrandom +import jax.random +import numpy as np from jax.random import PRNGKey +from levanter.utils import thread_utils -T = TypeVar("T", covariant=True) +logger = logging.getLogger(__name__) -class Dataset(Iterable[T], ABC): - @abstractmethod - def __iter__(self) -> Iterator[T]: + +T_co = TypeVar("T_co", covariant=True) +T = TypeVar("T") +U = TypeVar("U") + + +_executor = ThreadPoolExecutor(max_workers=10) + + +class DatasetBase(abc.ABC, Generic[T_co]): + """ + Base class for sync and async datasets. This class is not meant to be used directly. + """ + + @abc.abstractmethod + def as_async_dataset(self) -> "AsyncDataset[T_co]": + raise NotImplementedError("...") + + @abc.abstractmethod + def as_sync_dataset(self) -> "SyncDataset[T_co]": + raise NotImplementedError("...") + + +class AsyncDataset(DatasetBase[T_co]): + """ + An asynchronous dataset that can be used with async/await syntax. In Levanter, we use AsyncDataset for two purposes: + * To represent datasets that are inherently asynchronous (e.g. reading from disk, network, etc.). + * To represent datasets that are still being constructed. + + The core methods in this class are: + * `async_len`: Returns the final length of the dataset. + * `get_batch`: Returns a batch of items from the dataset. + * `current_len`: Returns the current length of the dataset. This may be None if no current length is known. + """ + + @abc.abstractmethod + async def async_len(self) -> int: + raise NotImplementedError + + @abc.abstractmethod + async def final_length_is_known(self) -> bool: + """Returns whether the final length of the dataset is known. + If this returns False, the current_len of the dataset may change in the future.""" + raise NotImplementedError + + @abc.abstractmethod + def is_finite(self) -> bool: + """ + Returns whether the dataset will have a known length in the future (e.g. if it's being constructed). + If this returns False, the length of the dataset is infinite or unknowable. 
+ """ raise NotImplementedError + @abc.abstractmethod + async def current_len(self) -> Optional[int]: + """ + Returns the current length of the dataset that won't require (expensive) waiting. -class ShardableDataset(Dataset[T], ABC): - @abstractmethod - def shard(self, shard_id: int, num_shards: int) -> "ShardableDataset[T]": + If the current length is not known, returns None. This might block temporarily for a short time to get the + current length. + """ raise NotImplementedError - @abstractmethod - def __iter__(self) -> Iterator[T]: + async def getitem_async(self, index: int) -> T_co: + """ + Returns the item at the given index. Typically implemented as a wrapper around `get_batch`. + + In general, it is better to call (and override) `get_batch` instead of this method. + """ + return (await self.get_batch([index]))[0] + + @abc.abstractmethod + async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]: raise NotImplementedError + async def wait_until_len_at_least(self, length: int) -> int: + """ + Returns the length of the dataset once it is at least `length` or if the dataset has a known (finished) length. + + The default implementation is a naive busy-wait loop. You should override this method for more efficient + implementations. + """ + return await naive_busy_wait_until_len_at_least(self, length) + + def as_sync_dataset(self): + return SyncifiedDataset(self) + + def as_async_dataset(self) -> "AsyncDataset[T_co]": + return self + + def map(self, fn: Callable[[T_co], U], *extra_args, **extra_kwargs) -> "MappedAsyncDataset[T_co, U]": + return MappedAsyncDataset(self, fn, *extra_args, **extra_kwargs) + + def shuffle(self, key: PRNGKey): + import levanter.data.permutation as permutation + + return permutation.PermutationDataset(self, key) + + def era_shuffle(self, era_length: int, key: PRNGKey): + import levanter.data.permutation as permutation + + return permutation.EraShufflingDataset(self, era_length, key=key) + + +async def naive_busy_wait_until_len_at_least(dataset: AsyncDataset[T_co], length: int) -> int: + """ + Runs a busy-wait loop until the dataset has at least `length` items or the final length is known. + + Returns the current length of the dataset when either the dataset has at least `length` items or the final length is + known. + + You should probably implement this in a more efficient way. This is just a naive implementation. + """ + while not await dataset.final_length_is_known(): + current_len = await dataset.current_len() + if current_len is None: + raise ValueError("Dataset has unknown length") + if current_len <= length: + await asyncio.sleep(0.1) + else: + return current_len + + return await dataset.async_len() + + +class SyncDataset(DatasetBase[T_co]): + """ + A synchronous dataset that can be used with regular Python syntax. In Levanter, we mainly do not use this class. + You can use this class if it's easier, then convert it to an AsyncDataset using `as_async_dataset`. This + is not as efficient as using an AsyncDataset directly, but it can be useful for testing or for simpler code. + """ + + @abc.abstractmethod + def __len__(self) -> int: + """ + Returns the final length of the data store. + May raise if the length is not known. + """ + + @abc.abstractmethod + def has_len(self) -> bool: + """ + Whether the data store currently has a known length. If this returns False, then the length of the data store + may change in the future. 
+ """ + pass + + @abc.abstractmethod + def current_len(self) -> Optional[int]: + """ + Returns the current length of the data store. If the length is infinite or not known, returns None. + """ + pass + + def __getitem__(self, index: int) -> T_co: + return self.get_batch([index])[0] + + @abc.abstractmethod + def get_batch(self, indices: Sequence[int] | np.ndarray) -> Sequence[T_co]: + pass + + def as_async_dataset(self) -> "AsyncDataset[T_co]": + return AsyncifiedDataset(self) + + def as_sync_dataset(self) -> "SyncDataset[T_co]": + return self + -class InMemoryDataset(ShardableDataset[T]): - def __init__(self, items: List[T]): - self.items = items +class SyncifiedDataset(SyncDataset[T_co]): + def __init__(self, dataset: AsyncDataset[T_co]): + self.dataset = dataset + + def _run_coroutine(self, coro): + return thread_utils.blocking_wait(coro) + + def __len__(self) -> int: + return self._run_coroutine(self.dataset.async_len()) + + def has_len(self) -> bool: + return self.dataset.is_finite() + + def current_len(self) -> Optional[int]: + return self._run_coroutine(self.dataset.current_len()) - def __iter__(self) -> Iterator[T]: - return iter(self.items) + def get_batch(self, indices: Sequence[int] | np.ndarray) -> Sequence[T_co]: + return self._run_coroutine(self.dataset.get_batch(indices)) - def shard(self, shard_id: int, num_shards: int) -> "InMemoryDataset[T]": - return InMemoryDataset(self.items[shard_id::num_shards]) + def __getitem__(self, index: int) -> T_co: + return self._run_coroutine(self.dataset.getitem_async(index)) -class ShuffleDataset(ShardableDataset[T]): - def __init__(self, dataset: Dataset[T], key: PRNGKey, buffer_size: int): +class AsyncifiedDataset(AsyncDataset[T_co]): + def __init__(self, dataset: SyncDataset[T_co]): self.dataset = dataset - self.buffer_size = buffer_size - self.key = key - - def shard(self, shard_id: int, num_shards: int) -> "ShuffleDataset": - key = jrandom.fold_in(self.key, shard_id) - return ShuffleDataset(self.dataset.shard(shard_id, num_shards), key, self.buffer_size) # type: ignore - - def __iter__(self) -> Iterator[T]: - inner = iter(self.dataset) - buffer: List[T] = [] - current_key = self.key - - for item in inner: - if len(buffer) == self.buffer_size: - current_key, subkey = jrandom.split(current_key) - i = jrandom.randint(subkey, (), 0, len(buffer)) - yield buffer[i] - buffer[i] = item + + async def async_len(self) -> int: + return len(self.dataset) + + async def final_length_is_known(self) -> bool: + return self.dataset.has_len() + + def is_finite(self) -> bool: + return self.dataset.has_len() + + async def current_len(self) -> Optional[int]: + return self.dataset.current_len() + + async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]: + return self.dataset.get_batch(indices) + + async def getitem_async(self, index: int) -> T_co: + return self.dataset[index] + + def __repr__(self): + return f"WrappedAsyncDataset({repr(self.dataset)})" + + def __str__(self): + return f"WrappedAsyncDataset({str(self.dataset)})" + + +class ListAsyncDataset(AsyncDataset[T]): + """ + A simple dataset that wraps a list. Mostly for testing. 
+ """ + + def __init__(self, data: list[T], is_complete: bool = False): + self.data = data + self.is_complete = is_complete + if not is_complete: + self.complete_promise: Optional[asyncio.Future[None]] = asyncio.Future() + self.length_updated: Optional[asyncio.Condition] = asyncio.Condition() + else: + self.complete_promise = None + self.length_updated = None + + async def async_len(self) -> int: + # this is the final length + if not self.is_complete: + assert self.complete_promise is not None + await self.complete_promise + return len(self.data) + + async def final_length_is_known(self) -> bool: + return self.is_complete + + def is_finite(self) -> bool: + return True + + async def current_len(self) -> Optional[int]: + return len(self.data) + + async def get_batch(self, indices: Sequence[int]) -> Sequence[T]: + await self.wait_until_len_at_least(max(indices) + 1) + return [self.data[i] for i in indices] + + def append(self, item: T): + if self.is_complete: + raise ValueError("Cannot append to a finalized dataset") + self.data.append(item) + asyncio.create_task(self.notify_length_update()) + + def finalize(self): + self.is_complete = True + if self.complete_promise is not None: + self.complete_promise.set_result(None) + if not asyncio.get_event_loop().is_running(): + _executor.submit(lambda: asyncio.run(self.notify_length_update())) else: - buffer.append(item) + asyncio.create_task(self.notify_length_update()) + + async def notify_length_update(self): + async with self.length_updated: + self.length_updated.notify_all() + + async def wait_until_len_at_least(self, length: int) -> int: + if self.is_complete: + return len(self.data) + + assert self.length_updated is not None + + async with self.length_updated: + while len(self.data) < length and not self.is_complete: + await self.length_updated.wait() + + return len(self.data) + + +class MappedAsyncDataset(AsyncDataset[U], Generic[T, U]): + """ + A dataset that applies a function to each item in the dataset. + You can pass extra arguments to the function using `*extra_args` and `**extra_kwargs`. + If a kwarg called `key` is passed, it will be treated as a PRNGKey and folded in with the index of the item + for each call to the function. 
+ """ + + def __init__( + self, + dataset: AsyncDataset[T], + fn: Callable[[T], U] | Callable[[T, Optional[PRNGKey]], U], + *extra_args, + **extra_kwargs, + ): + self.dataset = dataset + self.fn = fn + self._extra_args = extra_args + self._extra_kwargs = extra_kwargs + + async def async_len(self) -> int: + return await self.dataset.async_len() + + async def final_length_is_known(self) -> bool: + return await self.dataset.final_length_is_known() + + def is_finite(self) -> bool: + return self.dataset.is_finite() + + async def current_len(self) -> Optional[int]: + return await self.dataset.current_len() + + def _maybe_fold_in_key(self, key, index): + if key is not None: + key = jax.random.fold_in(key, index) + return key + + async def get_batch(self, indices: Sequence[int]) -> Sequence[U]: + items = await self.dataset.get_batch(indices) + return [self._call_fn(i, item) for i, item in zip(indices, items)] + + async def getitem_async(self, index: int) -> U: + return self._call_fn(index, await self.dataset.getitem_async(index)) + + async def wait_until_len_at_least(self, length: int) -> int: + return await self.dataset.wait_until_len_at_least(length) - while len(buffer) > 0: - current_key, subkey = jrandom.split(current_key) - i = jrandom.randint(subkey, (), 0, len(buffer)) - yield buffer[i] - del buffer[i] + def _call_fn(self, index, item): + if "key" in self._extra_kwargs: + key = self._maybe_fold_in_key(self._extra_kwargs["key"], index) + kwargs = {**self._extra_kwargs, "key": key} + else: + kwargs = self._extra_kwargs + return self.fn(item, *self._extra_args, **kwargs) diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py index b6e7f673f..ab97e0827 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -1,229 +1,257 @@ -import abc import functools import logging +import time from collections import defaultdict -from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple, TypeVar, Union +from typing import Iterable, Iterator, Optional, Tuple, TypeVar import jax -import jax.numpy as jnp -import jax.tree_util as jtu +from jax import Array +from jax import numpy as jnp +from jax import tree_util as jtu from jax.experimental import multihost_utils from jax.sharding import Mesh, PartitionSpec -from jaxtyping import Array, PyTree +from jaxtyping import PyTree import haliax as hax -from haliax import NamedArray +from haliax import is_named_array +from haliax._src.util import index_where from haliax.partitioning import ResourceMapping -from haliax.util import is_named_array -from levanter.data import Dataset -from levanter.data.dataset import ShardableDataset -from levanter.mesh import local_devices_mapping, process_mesh_mapping +from levanter.data.dataset import AsyncDataset +from levanter.data.utils import batched from levanter.shapes import NamedShapeSpec, ShapeSpec, to_raw_shape from levanter.utils.background_iterable import BackgroundIterable -from levanter.utils.py_utils import non_caching_cycle +from levanter.utils.thread_utils import blocking_wait Ex = TypeVar("Ex") +_TensorSliceIndex = tuple[slice, ...] logger = logging.getLogger(__name__) -# TODO: write tests to verify this works when data spans multiple processes -_TensorSliceIndex = Tuple[slice, ...] +class DataLoader(Iterable[Ex]): + def __init__( + self, + Batch: hax.Axis, + data: AsyncDataset[Ex], + max_buffered_batches: Optional[int], + mesh: Mesh, + axis_resources: Optional[ResourceMapping], + # this is set heuristically for the typical tokenseqdataset we use. 
Should probably tune + prefetch_size: int = 32, + ): + """ + TODO: document this -class BatchLoader(Iterable[Ex], abc.ABC): - Batch: hax.Axis - mesh: Mesh - axis_resources: Optional[ResourceMapping] + Args: + Batch (hax.Axis): The batch axis + data (AsyncDataset[Ex]): The dataset to load from + max_buffered_batches (Optional[int]): The maximum number of batches to buffer. If None, the buffer is unbounded. + If <0, the buffer is disabled and single threaded operation is used. + axis_resources (Optional[ResourceMapping]): axis mapping + prefetch_size (int): The number of batches to prefetch at once + mesh (Mesh): The mesh to use - def __init__(self, max_capacity: Optional[int], axis_resources: Optional[ResourceMapping]): - """ - :param max_capacity: if not None, the maximum number of batches to keep in memory at once. If <0 then load in the main thread - :param axis_resources: """ - self.max_capacity = max_capacity + self.max_buffered_batches = max_buffered_batches + self.prefetch_size = prefetch_size self.axis_resources = axis_resources + self.data_store = data + self.mesh = mesh + self.Batch = Batch - def __iter__(self) -> Iterator[Ex]: - ax_resources = self.axis_resources - if ax_resources is None: - ax_resources = hax.partitioning.current_thread_local_mapping() - - def produce_batches(): - with hax.axis_mapping(ax_resources): - for batch in self._produce_batches(): - yield batch - - if self.max_capacity is not None and self.max_capacity < 0: - yield from produce_batches() - else: - bg_iter = BackgroundIterable(produce_batches, max_capacity=self.max_capacity) - yield from bg_iter - - @abc.abstractmethod - def _produce_batches(self) -> Iterator[Ex]: - raise NotImplementedError + def _exemplar_shape(): + return blocking_wait(self.data_store.getitem_async(0)) + + self._ex_leaves, self._ex_structure = jax.tree_flatten(_exemplar_shape(), is_leaf=is_named_array) + + local_device_indices, local_indices = self._compute_local_device_indices() + + self._local_device_indices: dict[jax.Device, range] = local_device_indices + # this is just the flat indices + self._local_indices: list[int] = local_indices + + def _compute_local_device_indices(self): + sharding: jax.sharding.Sharding = hax.partitioning.sharding_for_axis( + self.Batch.name, self.axis_resources, self.mesh + ) + # this is a map from devices to the slice of the array that they contain (in the global array) + local_indices_map = sharding.addressable_devices_indices_map((self.batch_size,)) + # we just want all the indices + local_device_indices: dict[jax.Device, range] = { + device1: range(*idx[0].indices(self.batch_size)) + for device1, idx in local_indices_map.items() + if idx is not None + } + local_indices: list[int] = [] + for device, indices in local_device_indices.items(): + local_indices.extend(indices) + return local_device_indices, local_indices @property - def batch_size(self) -> int: + def batch_size(self): return self.Batch.size - def _construct_global_array_for_tree(self, item_exemplar: PyTree, get_batch_items: Callable[[int, int], PyTree]): - # ok this is a bit messy: we want to create a batch of items from our dataset, only loading - # the relevant data for each process. - # In general an item is represented as a PyTree, whose leaves are (named or unnamed) arrays. - # To make a batch we just want to add a leading dimension to each leaf array by stacking. 
- # That is, we have (conceptually) a List[PyTree[Array]] and we want to produce a PyTree[List[Array]] - # The difference is that we want to do this in a way that only loads the relevant data for each process - # So it's more that we have a LocalBatch[PyTree[Array]] and we want to produce a PyTree[GlobalBatch[Array]] - # because more than one device can get the same data, we need to make sure we only load it once since we're - # streaming. This is the cache - stacked_local_batch: Dict[Tuple[int, int], List[Array | hax.NamedArray]] = {} - - def get_local_batch(begin: int, end: int) -> List[Array]: - key = (begin, end) - if key in stacked_local_batch: - return stacked_local_batch[key] - - individual_datums = get_batch_items(begin, end) - - device_batch = _stack_tree(self.Batch.name, individual_datums) - batch_leaves = jtu.tree_leaves(device_batch) - - stacked_local_batch[key] = batch_leaves - - return batch_leaves - - def get_local_data_for_leaf(indices: _TensorSliceIndex, leaf_index: int) -> Array: - batch_slice = indices[0] - begin, end, _ = batch_slice.indices(self.Batch.size) - local_batch = get_local_batch(begin, end) - leaf = local_batch[leaf_index] - other_indices = indices[1:] - if all(idx == slice(None) for idx in other_indices): - return leaf - else: - return leaf[(..., *indices[1:])] - - def make_global_array_for_leaf(leaf_index, item_leaf_shape: Union[ShapeSpec, NamedShapeSpec]): - raw_array = jax.make_array_from_callback( - to_raw_shape(item_leaf_shape), - jax.sharding.NamedSharding(self.mesh, self._pspec_for(item_leaf_shape)), - lambda indices: get_local_data_for_leaf(indices, leaf_index), - ) - if isinstance(item_leaf_shape, NamedShapeSpec): - return NamedArray(raw_array, item_leaf_shape.shape) - else: - return raw_array - - item_leaves, item_shape = jtu.tree_flatten(item_exemplar, is_leaf=is_named_array) - - gda_leaves = [ - make_global_array_for_leaf(leaf_index, _batchified_shape(self.Batch, item_leaf)) - for leaf_index, item_leaf in enumerate(item_leaves) - ] - - gda_tree = jtu.tree_unflatten(item_shape, gda_leaves) - - return gda_tree - - def _pspec_for(self, shape_spec: Union[ShapeSpec, NamedShapeSpec]) -> PartitionSpec: + def __iter__(self): + return self.iter_from_step(None) + + def iter_from_step(self, start_from_batch: Optional[int] = None): + return DataLoaderIterator(self, start_from_batch=start_from_batch) + + +class DataLoaderIterator(Iterator[Ex]): + def __init__(self, data_loader: DataLoader, start_from_batch: Optional[int] = None): + self.dl = data_loader + self._start_from_batch = start_from_batch + self.mapping = self.dl.axis_resources + if self.mapping is None: + self.mapping = hax.partitioning.current_thread_local_mapping() + + # TODO: bring back non-prefetching version + buffered_batches = self.dl.max_buffered_batches + self._batches = iter(BackgroundIterable(self._produce_batches, max_capacity=buffered_batches)) + + def __next__(self): + time_start = time.time() + out = next(self._batches) + time_end = time.time() + if (time_end - time_start) > 0.5: + logger.info(f"Prefetch wasn't fast enough: {time_end - time_start:.3f}") + return out + + async def _produce_batches(self): + batch_number = self._start_from_batch or 0 + total_ex_loaded = 0 + done = False + while not done: + next_batch_numbers = [] + for i in range(self.dl.prefetch_size): + if self.dl.data_store.is_finite(): + next_end = (batch_number + 1) * self.dl.batch_size + available_len = await self.dl.data_store.wait_until_len_at_least(next_end) + if available_len < next_end: + done = True + break + 
+ next_batch_numbers.append(batch_number) + batch_number += 1 + + async for batch in self._retrieve_batches(next_batch_numbers): + yield batch + + total_ex_loaded += self.dl.batch_size * len(next_batch_numbers) + + async def _retrieve_batches(self, batch_numbers: list[int]): + with hax.axis_mapping(self.mapping), self.dl.mesh: + indices_for_this_batch_of_batches: list[int] = [] + for bn in batch_numbers: + indices_this_batch = range(bn * self.dl.batch_size, (bn + 1) * self.dl.batch_size, 1) + indices_this_batch_this_process = [indices_this_batch[i] for i in self.dl._local_indices] + indices_for_this_batch_of_batches.extend(indices_this_batch_this_process) + + time_start = time.time() + individual_datums = await self.dl.data_store.get_batch(indices_for_this_batch_of_batches) + time_end = time.time() + logger.debug(f"Time to get {len(batch_numbers)} batches: {time_end - time_start:.3f}") + time_start = time.time() + # reshape to be per batch + individual_datums = list(batched(individual_datums, len(self.dl._local_indices))) + + # below we're gonna get the indices relative to this batch (i.e. 0 to batch_size) + index_to_datum = [ + {index: datum for index, datum in zip(self.dl._local_indices, individual_data_batch)} + for individual_data_batch in individual_datums + ] + + def get_local_batch(bn: int, begin: int, end: int) -> list: + # TODO: if we ever do "big data" (i.e. huge examples) we might want to be able to load part of an example + # which will require support from the datastore (i.e. tensorstore) + device_batch = _stack_tree(self.dl.Batch.name, [index_to_datum[bn][i] for i in range(begin, end)]) + batch_leaves = hax.tree_util.tree_leaves(device_batch) + return batch_leaves + + def get_local_data_for_leaf(bn, indices: _TensorSliceIndex, leaf_index: int) -> Array: + batch_slice = indices[0] + begin, end, stride = batch_slice.indices(self.dl.batch_size) + if stride != 1: + raise ValueError("Stride must be 1") + + leaf_data = (get_local_batch(bn, begin, end))[leaf_index] + + if isinstance(leaf_data, hax.NamedArray): + # select out the batch axis + batch_index = index_where(lambda ax: ax.name == self.dl.Batch.name, leaf_data.axes) + new_indices = list(indices) + new_indices[batch_index] = slice(None) + return leaf_data.array[tuple(new_indices)] + + else: + other_indices = indices[1:] + if all(idx == slice(None) for idx in other_indices): + return leaf_data + else: + # TODO: this doesn't work with named axes + return leaf_data[(..., *other_indices)] + + for batch_offset, bn in enumerate(batch_numbers): + + def make_global_array_for_leaf(leaf_index, item_leaf_shape: ShapeSpec | NamedShapeSpec): + def get_data(indices): + return get_local_data_for_leaf(batch_offset, indices, leaf_index) + + raw_array = jax.make_array_from_callback( + to_raw_shape(item_leaf_shape), + jax.sharding.NamedSharding(self.dl.mesh, self._pspec_for(item_leaf_shape)), + get_data, + ) + if isinstance(item_leaf_shape, NamedShapeSpec): + return hax.NamedArray(raw_array, item_leaf_shape.shape) + else: + return raw_array + + gda_leaves = [ + make_global_array_for_leaf(leaf_index, _batchified_shape(self.dl.Batch, item_leaf)) + for leaf_index, item_leaf in enumerate(self.dl._ex_leaves) + ] + + gda_tree = jax.tree.unflatten(self.dl._ex_structure, gda_leaves) + yield gda_tree + + def _pspec_for(self, shape_spec: ShapeSpec | NamedShapeSpec) -> PartitionSpec: if isinstance(shape_spec, ShapeSpec): # type: ignore - batch_name = hax.partitioning.physical_axis_name(self.Batch, self.axis_resources) + batch_name = 
hax.partitioning.physical_axis_name(self.dl.Batch, self.dl.axis_resources) return PartitionSpec(batch_name, *((None,) * (len(shape_spec.shape) - 1))) else: - return hax.partitioning.pspec_for_axis(shape_spec.shape, self.axis_resources) # type: ignore - - -class ShardedBatchLoader(BatchLoader[Ex]): - """ - ShardedBatchLoader wraps a "local dataset" (a dataset that is shardable and can be iterated over) to produce - distributed/sharded jax.Arrays representing batches of data. Each array that has a global shape - but only has the data for some of the chunks of the array (namely, the ones on the local devices). - Thus, each process loads the data for its devices. - - **NOTE: ShardedBatchLoader loops forever since it's hard to do coordination.** - - The details are a bit complex: We have a device mesh of shape (data, model). We want each row of the device mesh to - get batch_size//num_rows examples. Usually, a process will be responsible for one or more entire rows, meaning - that it wil load data that is distinct from every other process. However, if num_cols > num_devices_per_process, - then some processes will need to load the same data. We use the process_mesh_position to determine which data to - load, by determining which row(s) of the device mesh the process is responsible for. - - :arg local_dataset: a dataset that is shardable and can be iterated over - :arg mesh: the device mesh - :arg Batch: the batch size - :param max_capacity: if not None, the maximum number of batches to keep in memory at once. If <0 then load in the main thread - """ - - def __init__( - self, - local_dataset: ShardableDataset[Ex], - mesh: Mesh, - Batch: hax.Axis, - axis_resources: Optional[ResourceMapping] = None, - max_capacity: Optional[int] = 64, - *, - override_process_data_pos: Optional[int] = None, # for testing - override_process_data_groups: Optional[int] = None, # for testing - ): - self.mesh = mesh - self.Batch = Batch - - process_mesh_map = process_mesh_mapping(self.mesh) - local_devices_map = local_devices_mapping(self.mesh) - process_data_pos = override_process_data_pos or process_mesh_map[jax.process_index()] - num_data_process_groups = override_process_data_groups or max(process_mesh_map.values()) + 1 - - if not override_process_data_groups: - assert num_data_process_groups <= jax.process_count() - - self.process_data_pos = process_data_pos - self.num_data_process_groups = num_data_process_groups - assert self.Batch.size % num_data_process_groups == 0 + return hax.partitioning.pspec_for_axis(shape_spec.shape, self.dl.axis_resources) # type: ignore - self.process_mesh_map = process_mesh_map - self.local_devices_map = local_devices_map - self.per_device_batch_size = self.batch_size // self.mesh.devices.shape[0] // self.mesh.devices.shape[1] - self.item_dataset = local_dataset.shard(process_data_pos, num_data_process_groups) - super().__init__(max_capacity, axis_resources) +def _abstractify(x): + def _abstractify_array(x): + if isinstance(x, jax.numpy.ndarray): + return ShapeSpec(x.shape, x.dtype) + elif isinstance(x, hax.NamedArray): + return NamedShapeSpec(x.axes, x.dtype) - def _produce_batches(self) -> Iterator[PyTree]: - one_item_generator = non_caching_cycle(self.item_dataset) - batched = _batched(one_item_generator, self.local_batch_size) + return x - def batch_callback(global_begin, _): - # global_begin is uid for DP/FSDP - # DP_id * per_device_bs = global_begin - device_pos = global_begin // self.per_device_batch_size + return hax.tree_util.tree_map(_abstractify_array, x) - begin = 
self.local_devices_map[device_pos] * self.per_device_batch_size - end = begin + self.per_device_batch_size - return local_batch[begin:end] - - while True: - local_batch: List[PyTree] = next(batched) - - batch = self._construct_global_array_for_tree( - item_exemplar=local_batch[0], - get_batch_items=batch_callback, - ) - - yield batch +def _batchified_shape(Batch, leaf: hax.NamedArray | Array) -> ShapeSpec | NamedShapeSpec: + if is_named_array(leaf): + return NamedShapeSpec((Batch,) + leaf.axes, leaf.dtype) + else: + return ShapeSpec((Batch.size,) + leaf.shape, leaf.dtype) - @property - def batch_size(self) -> int: - """Returns the 'global' batch size: the effective number of examples in a batch across all devices/hosts""" - return self.Batch.size - @property - def local_batch_size(self) -> int: - """Returns the 'local' batch size: the number of examples in a batch on this host""" - return self.batch_size // self.num_data_process_groups +def _pspec_for(self, shape_spec: ShapeSpec | NamedShapeSpec) -> PartitionSpec: + if isinstance(shape_spec, ShapeSpec): # type: ignore + batch_name = hax.partitioning.physical_axis_name(self.Batch, self.axis_resources) + return PartitionSpec(batch_name, *((None,) * (len(shape_spec.shape) - 1))) + else: + return hax.partitioning.pspec_for_axis(shape_spec.shape, self.axis_resources) # type: ignore @functools.partial(jax.jit, static_argnums=(0,)) @@ -234,50 +262,7 @@ def _stack_leaves_unchecked(*leaves): else: return jnp.stack(leaves) - return jax.tree_map(_stack_leaves_unchecked, *individual_datums, is_leaf=is_named_array) - - -class ReplicatedBatchLoader(BatchLoader[Ex]): - """A batch loader that creates batches without sharded data loading. All examples are loaded on all machines and then - sharded. This is useful if you have a small dataset and want to make a single pass over it. - - Note: this class discards the final batch if it is smaller than the batch size. - - :arg item_dataset: a dataset that is shardable and can be iterated over - :arg mesh: the device mesh - :arg Batch: the batch size - :arg axis_resources: the resources for the batch axis - :param max_capacity: if not None, the maximum number of batches to keep in memory at once. 
If <0 then load in the main thread - """ - - def __init__( - self, - item_dataset: Dataset[Ex], - mesh: Mesh, - Batch: hax.Axis, - axis_resources: Optional[ResourceMapping] = None, - max_capacity: Optional[int] = 64, - ): - assert item_dataset is not None - self.item_dataset = item_dataset - self.mesh = mesh - self.Batch = Batch - - super().__init__(max_capacity, axis_resources) - - def _produce_batches(self): - for batch in _batched(self.item_dataset, self.Batch.size): - sharded = self._construct_global_array_for_tree( - item_exemplar=batch[0], get_batch_items=lambda begin, end: batch[begin:end] - ) - yield sharded - - -def _batchified_shape(Batch, leaf: Union[NamedArray, Array]): - if isinstance(leaf, NamedArray): - return NamedShapeSpec((Batch,) + leaf.axes, leaf.dtype) - else: - return ShapeSpec((Batch.size,) + leaf.shape, leaf.dtype) + return jax.tree.map(_stack_leaves_unchecked, *individual_datums, is_leaf=is_named_array) def check_sharded_consistency(tree: PyTree, check_disjoint_indices_are_different: bool = False): @@ -340,12 +325,3 @@ def _to_tuple(index: Tuple[slice, ...]) -> Tuple[Tuple[int, int], ...]: for leaf in jtu.tree_leaves(tree): check_array(leaf) - - -def _batched(item_iter, size): - batch = [] - for item in item_iter: - batch.append(item) - if len(batch) == size: - yield batch - batch = [] diff --git a/src/levanter/data/metrics_monitor.py b/src/levanter/data/metrics_monitor.py index 264229cdc..4e4619ffb 100644 --- a/src/levanter/data/metrics_monitor.py +++ b/src/levanter/data/metrics_monitor.py @@ -25,7 +25,6 @@ @dataclass class InProgressCacheMetrics: rows_finished: int = 0 - chunks_finished: int = 0 shards_finished: int = 0 field_counts: Dict[str, int] = dataclasses.field(default_factory=dict) is_finished: bool = False @@ -63,7 +62,6 @@ def _init_progress(self, metrics): columns = [ BarColumn(), TaskProgressColumn(), - TextColumn("| {task.fields[chunks_finished]} chunks", justify="center"), TextColumn("| {task.fields[rows_finished]} docs", justify="center"), ] @@ -103,7 +101,6 @@ def __call__(self, metrics: InProgressCacheMetrics): to_log: Dict[str, Any] = {} to_log[f"{self.prefix}/shards"] = metrics.shards_finished - to_log[f"{self.prefix}/chunks"] = metrics.chunks_finished to_log[f"{self.prefix}/rows"] = metrics.rows_finished for field, count in metrics.field_counts.items(): @@ -117,7 +114,6 @@ def __call__(self, metrics: InProgressCacheMetrics): # assert self.last_time is not None # elapsed = time.time() - self.last_time # to_log[f"{self.prefix}/shards_per_s"] = (metrics.shards_finished - self.last_metrics.shards_finished) / elapsed - # to_log[f"{self.prefix}/chunks_per_s"] = (metrics.chunks_finished - self.last_metrics.chunks_finished) / elapsed # to_log[f"{self.prefix}/rows_per_s"] = (metrics.rows_finished - self.last_metrics.rows_finished) / elapsed # # for field, count in metrics.field_counts.items(): @@ -132,19 +128,28 @@ def __call__(self, metrics: InProgressCacheMetrics): class LoggerMetricsMonitor(MetricsMonitor): # TODO: I'd like to get the trainer pbar migrated to rich and just use rich everywhere, but until then, # we have separate logging - def __init__(self, logger: Optional[Union[pylogging.Logger, str]] = None, level=pylogging.INFO): + def __init__( + self, + logger: Optional[Union[pylogging.Logger, str]] = None, + level=pylogging.INFO, + log_interval: float | int = 30.0, + ): if isinstance(logger, str): logger = pylogging.getLogger(logger) self.logger = logger or pylogging.getLogger(__name__) self.level = level + self.log_interval = log_interval + 
self._last_log_time = time.time()

     def __call__(self, metrics: InProgressCacheMetrics):
         if jax.process_index() == 0:
-            self.logger.log(
-                self.level,
-                f" done: Shards: {metrics.shards_finished} | Chunks: {metrics.chunks_finished} | Docs:"
-                f" {metrics.rows_finished}",
-            )
+            if time.time() - self._last_log_time > self.log_interval:
+                self._last_log_time = time.time()
+
+                self.logger.log(
+                    self.level,
+                    f" done: Shards: {metrics.shards_finished} | Docs: {metrics.rows_finished}",
+                )

         if metrics.is_finished:
             self.logger.info("Cache creation finished")
diff --git a/src/levanter/data/mixture.py b/src/levanter/data/mixture.py
index ba7ae674b..eb1bdfaaf 100644
--- a/src/levanter/data/mixture.py
+++ b/src/levanter/data/mixture.py
@@ -1,12 +1,18 @@
-from typing import Dict, Iterator, Mapping, TypeVar
+import asyncio
+import warnings
+from typing import Mapping, Optional, Sequence, TypeVar

-import jax.random
+import jax
 import numpy as np
+from async_lru import alru_cache
+from jax.random import PRNGKey
 from jaxtyping import PRNGKeyArray

 from haliax.util import StringHolderEnum

-from levanter.data import ShardableDataset
+from levanter.data import AsyncDataset
+from levanter.utils.index import Index
+from levanter.utils.thread_utils import future_from_value


 T = TypeVar("T")
@@ -18,15 +24,19 @@ class StopStrategy(metaclass=StringHolderEnum):
     RESTART_STRATEGY = "restart"


-class MixtureDataset(ShardableDataset[T]):
+class MixtureDataset(AsyncDataset[T]):
     """
     MixtureDataset supports loading data from multiple datasets. It takes a list of datasets and yields from them
     according to the weights.

+    Creating a random-access MixtureDataset is challenging because we need to keep track of the current index of each
+    dataset. To solve this, we instead use "block-deterministic" mixtures, where the number of samples from each dataset
+    in each block is always identical (and we shuffle the order of the dataset ids in each block).
+
     Args:
         datasets: A dict of datasets, where the key is the name of the dataset and the value is the dataset itself
         weights: weights for each dataset
-        stop_strategy: strategy for stopping the iteration, by default RESTART_STRATEGY
+        stop_strategy: strategy for stopping the iteration, by default RESTART_STRATEGY. 
(Currently only RESTART_STRATEGY is supported) - FIRST_STOP_STRATEGY: stop when one dataset has been exhausted - ALL_STOP_STRATEGY: stop when all datasets have been exhausted - RESTART_STRATEGY: restart the dataset when it has been exhausted @@ -35,57 +45,187 @@ class MixtureDataset(ShardableDataset[T]): def __init__( self, - datasets: Mapping[str, ShardableDataset[T]], - weights: Dict[str, float], - key: int | PRNGKeyArray, + datasets: Mapping[str, AsyncDataset[T]], + weights: dict[str, float], + block_size: int, + *, + randomize_blocks: bool = True, + key: PRNGKeyArray | int, stop_strategy: str = StopStrategy.RESTART_STRATEGY, ): - self.datasets = datasets self.weights = MixtureDataset._normalize_weights(weights) + self.datasets = {name: dataset for name, dataset in datasets.items() if self.weights.get(name, 0) > 0} + self.dataset_index = Index(self.datasets.keys()) + self.block_size = block_size + # we pack index and ds id into a single 32 bit, so block size must be at most 2^16 + if block_size >= 2**16: + raise ValueError(f"Block size must be at most 2^16, got {block_size}") + + self.randomize_blocks = randomize_blocks + + if isinstance(key, int): + key = PRNGKey(key) + + self.key = key if stop_strategy not in StopStrategy: # type: ignore raise ValueError(f"Stop strategy {stop_strategy} is not supported.") - self.stop_strategy = stop_strategy + # for now, just support restart strategy + if stop_strategy != StopStrategy.RESTART_STRATEGY: + raise NotImplementedError("Only restart strategy is supported for now.") - if not isinstance(key, int): - key = jax.random.randint(key, (), 0, 2**20).item() + self.stop_strategy = stop_strategy - self.key = key + self._counts_per_block = self._compute_expected_counts_per_block(block_size) + # precompute a list of ids for each block + # the ids contain both the dataset index and the index within the dataset + self._unpermuted_ids = self._compute_unpermuted_ids(self._counts_per_block) + + def _compute_expected_counts_per_block(self, block_size): + _expected_values_per_block = np.zeros(len(self.datasets), dtype=np.int32) + for i, dsname in enumerate(self.dataset_index): + _expected_values_per_block[i] = self.weights[dsname] * block_size + + # handle remainder by adding to the largest dataset + largest_dataset = np.argmax(_expected_values_per_block) + _expected_values_per_block[largest_dataset] += block_size - _expected_values_per_block.sum() + + # check if any dataset has 0 samples (and nonzero weight) + for i, dsname in enumerate(self.dataset_index): + if _expected_values_per_block[i] == 0 and self.weights[dsname] > 0: + warnings.warn( + f"Dataset {dsname} has 0 samples in the block, but weight of {self.weights[dsname]}." + " Recommend increasing block size." 
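# Worked example (hypothetical numbers, for illustration) of the per-block count computation
# above: with weights {"a": 0.65, "b": 0.35} and block_size=10, the truncated counts are [6, 3];
# the shortfall (block_size - sum = 1) is added to the largest dataset, giving [7, 3], so every
# block contains exactly block_size examples with a fixed per-dataset count. Each block entry is
# then packed as (dataset_index << 16) + offset_within_block, which is why block_size must stay
# below 2**16.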
+ ) + + return _expected_values_per_block + + def _compute_unpermuted_ids(self, counts_per_block): + unpermuted_ids = np.zeros(int(counts_per_block.sum()), dtype=np.int64) + start = 0 + for i, dsname in enumerate(self.dataset_index): + count = counts_per_block[i] + unpermuted_ids[start : start + count] = (i << 16) + np.arange(count) + start += count + return unpermuted_ids @staticmethod - def _normalize_weights(weights: Dict[str, float]): + def _normalize_weights(weights: dict[str, float]): """Normalize the weights to sum to 1""" total = sum(weights.values()) if total == 0: raise ValueError(f"Datasets' weights cannot sum to 0, got {weights}") return {name: weight / total for name, weight in weights.items() if weight > 0} - def shard(self, shard_id: int, num_shards: int) -> "MixtureDataset": - """Return a MixtureDataset with the sharded datasets""" - sharded = {name: dset.shard(shard_id, num_shards) for name, dset in self.datasets.items()} - my_key = int(jax.random.randint(jax.random.PRNGKey(self.key), (num_shards,), 0, 2**20)[shard_id]) - return MixtureDataset(datasets=sharded, weights=self.weights, stop_strategy=self.stop_strategy, key=my_key) - - def __iter__(self) -> Iterator[np.ndarray]: - iterators = {name: iter(dataset) for name, dataset in self.datasets.items()} - current_weights = self._normalize_weights(self.weights) - rng = np.random.default_rng(self.key) - - while True: - dataset_name = rng.choice(list(current_weights.keys()), p=list(current_weights.values())) - try: - item = next(iterators[dataset_name]) - yield item - except StopIteration: - match self.stop_strategy: - case StopStrategy.RESTART_STRATEGY: - iterators[dataset_name] = iter(self.datasets[dataset_name]) - case StopStrategy.FIRST_STOP_STRATEGY: - break - case StopStrategy.ALL_STOP_STRATEGY: - del iterators[dataset_name] - del current_weights[dataset_name] - if len(current_weights) == 0: - break - current_weights = self._normalize_weights(current_weights) + async def async_len(self) -> int: + if self.stop_strategy == StopStrategy.RESTART_STRATEGY: + raise ValueError("Length is infinite for restart strategy") + + raise NotImplementedError("Length is not implemented for other strategies") + + async def final_length_is_known(self) -> bool: + if self.stop_strategy == StopStrategy.RESTART_STRATEGY: + return False + + raise NotImplementedError("Length is not known for other strategies") + + def is_finite(self) -> bool: + if self.stop_strategy == StopStrategy.RESTART_STRATEGY: + return False + + return True + + async def current_len(self) -> Optional[int]: + if self.stop_strategy == StopStrategy.RESTART_STRATEGY: + return None + + raise NotImplementedError("Length is not known for other strategies") + + @alru_cache + async def _get_block(self, index: int) -> Optional[np.ndarray]: + if not self.randomize_blocks: + return self._unpermuted_ids + + return np.array(_compute_block_assignment(self._unpermuted_ids, index, self.key)) + + def _index_into_dataset_for_id(self, id: int, block_id) -> tuple[int, int]: + dataset_id = id >> 16 + dataset_index = id & 0xFFFF + return dataset_id, dataset_index + block_id * self._counts_per_block[dataset_id] + + async def get_batch(self, indices: Sequence[int]) -> Sequence[T]: + block_ids = np.array([idx // self.block_size for idx in indices]) + blocks = [self._get_block(block_id) for block_id in block_ids] + blocks = await asyncio.gather(*blocks) + + # split the indices into batches for each dataset + batches_per_dataset: list[list[int]] = [[] for _ in range(len(self.datasets))] + 
indices_in_final_batch: list[list[int]] = [[] for _ in range(len(self.datasets))] + + assert len(indices) == len(blocks) == len(block_ids) + + for batch_index, (idx, block, block_id) in enumerate(zip(indices, blocks, block_ids)): + index_within_block = idx % self.block_size # which element of the block to get + id = block[index_within_block] # for this block, which dataset+base dataset offset + dataset_id, dataset_index = self._index_into_dataset_for_id(id, block_id) + batches_per_dataset[dataset_id].append(dataset_index) + indices_in_final_batch[dataset_id].append(batch_index) + + # get the batches from each dataset + batch_futures = [] + for dataset_id, indices_for_dataset in enumerate(batches_per_dataset): + if len(indices_for_dataset) == 0: + batch_futures.append(future_from_value([])) + else: + dataset = self._dataset_of_id(dataset_id) + indices_for_dataset = await self._remap_indices(dataset, indices_for_dataset) + batch_futures.append(dataset.get_batch(indices_for_dataset)) + + batches = await asyncio.gather(*batch_futures) + + # reassemble the final batch + final_batch = [None] * len(indices) + + for dataset_id, indices_into_batch in enumerate(indices_in_final_batch): + for i, idx in enumerate(indices_into_batch): + assert final_batch[idx] is None + assert len(final_batch) > idx + final_batch[idx] = batches[dataset_id][i] + + return final_batch # type: ignore + + async def getitem_async(self, index: int) -> T: + # simpler implementation because there's only one + block_id = index // self.block_size + index = index % self.block_size + permuted_ids = await self._get_block(block_id) + dataset_id, dataset_index = self._index_into_dataset_for_id(permuted_ids[index], block_id) + + dataset = self._dataset_of_id(dataset_id) + dataset_index = (await self._remap_indices(dataset, [dataset_index]))[0] + + return await dataset.getitem_async(dataset_index) + + async def _remap_indices(self, ds, indices_into_ds): + """ + Handles wrap around for datasets that have finite length + """ + if self.stop_strategy == StopStrategy.RESTART_STRATEGY: + if ds.is_finite(): + max_elem = max(indices_into_ds) + length_of_dataset = await ds.wait_until_len_at_least(max_elem + 1) + indices_into_ds = [idx % length_of_dataset for idx in indices_into_ds] + + return indices_into_ds + + raise NotImplementedError("Length is not known for other strategies") + + def _dataset_of_id(self, id): + return self.datasets[self.dataset_index[id]] + + +def _compute_block_assignment(base_ids, index, key): + rng = jax.random.fold_in(key, index) + permuted_ids = jax.random.permutation(rng, base_ids) + return permuted_ids diff --git a/src/levanter/data/permutation.py b/src/levanter/data/permutation.py new file mode 100644 index 000000000..a0f0566f4 --- /dev/null +++ b/src/levanter/data/permutation.py @@ -0,0 +1,135 @@ +import dataclasses +from typing import Optional, Sequence + +import jax.random +from async_lru import alru_cache + +from levanter.data import AsyncDataset +from levanter.data._prp import Permutation +from levanter.data.dataset import T_co + + +class PermutationDataset(AsyncDataset[T_co]): + """A permutation dataset that wraps another dataset and applies a permutation to the indices.""" + + # TODO: add epoch reshuffling + + def __init__(self, dataset: AsyncDataset[T_co], key: jax.random.PRNGKey): + self.dataset = dataset + self.key = key + self._permutation: Optional[Permutation] = None + + async def async_len(self) -> int: + return await self.dataset.async_len() + + async def final_length_is_known(self) -> bool: + return 
await self.dataset.final_length_is_known() + + def is_finite(self) -> bool: + return self.dataset.is_finite() + + async def current_len(self) -> Optional[int]: + return await self.dataset.current_len() + + async def getitem_async(self, index: int) -> T_co: + permutation = await self._get_permutation() + return await self.dataset.getitem_async(permutation(index)) + + async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]: + permutation = await self._get_permutation() + return await self.dataset.get_batch([permutation(i) for i in indices]) + + async def _get_permutation(self): + if self._permutation is None: + self._permutation = Permutation(await self.dataset.async_len(), self.key) + return self._permutation + + +class EraShufflingDataset(AsyncDataset[T_co]): + """ + A dataset that shuffles the data in "eras" of fixed length. Era shuffling is somewhere in between a shuffle buffer + and a permutation. It's a "local" permutation where pi(i) \in [ (i//L) * L, (i//L + 1) * L ) for some era length L. + + The advantages of era shuffling are: + - It's stateless, so resumes are easy + - Like shuffle buffers, it's a decent compromise between full shuffling and no shuffling + - Like a shuffle buffer, it's streaming: we don't need to know the length of the data in advance + + The disadvantages are: + - It's not as good as full shuffling + - It distributes less well than a shuffle buffer does. It's more like a "local" shuffle buffer. + - You have to wait for an era to fill before you can start shuffling it. With prefetching, this is less of an issue. + + + # TODO: given the way tokenization works (where it runs way ahead of training), we can probably increase the era + length # over time. This would be a nice feature to have. + """ + + def __init__(self, dataset: AsyncDataset[T_co], era_length: int, *, key: jax.random.PRNGKey): + self.dataset = dataset + self.era_length = era_length + self.key = key + + @alru_cache(maxsize=4) # we're mostly going to be going sequentially + async def gen_era_permutation(era: int) -> Permutation: + # TODO: support epochs + # edge case: final era may be shorter than era_length + current_len = await self.dataset.wait_until_len_at_least((era + 1) * self.era_length) + era_length = min(self.era_length, current_len - era * self.era_length) + + mix_key = jax.random.fold_in(key, era) + return Permutation(era_length, mix_key) + + self.gen_era_permutation = gen_era_permutation + + async def _get_index(self, idx: int) -> int: + if idx < 0: + raise ValueError("Negative indices are not supported") + era = idx // self.era_length + permutation = await self.gen_era_permutation(era) + return permutation(idx - era * self.era_length) + era * self.era_length + + async def async_len(self) -> int: + return await self.dataset.async_len() + + async def final_length_is_known(self) -> bool: + return await self.dataset.final_length_is_known() + + def is_finite(self) -> bool: + return self.dataset.is_finite() + + async def current_len(self) -> Optional[int]: + # nb this is the no-wait length, which means we might be a bit behind the length of the inner dataset + inner_current_len = await self.dataset.current_len() + if inner_current_len is None: + return None + + # if we have the final length, and it's the inner_current_len, then we can return the final length + if await self.final_length_is_known() and inner_current_len == await self.async_len(): + return inner_current_len + + # otherwise, we need to wait for the era to fill + era = inner_current_len // self.era_length + return era * 
self.era_length + + async def getitem_async(self, index: int) -> T_co: + return await self.dataset.getitem_async(await self._get_index(index)) + + async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]: + return await self.dataset.get_batch([await self._get_index(i) for i in indices]) + + def __repr__(self): + return f"EraShufflingDataset({repr(self.dataset)}, era_length={self.era_length})" + + def __str__(self): + return f"EraShufflingDataset({str(self.dataset)})" + + async def wait_until_len_at_least(self, length: int) -> int: + # wait until we hit the next era + next_era_end = (length // self.era_length + 1) * self.era_length + return await self.dataset.wait_until_len_at_least(next_era_end) + + +@dataclasses.dataclass +class EraConfig: + era_length: int diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py deleted file mode 100644 index 8956412b5..000000000 --- a/src/levanter/data/shard_cache.py +++ /dev/null @@ -1,1521 +0,0 @@ -# Dataset for preprocessing data, tokenizing, and caching to disk. -import asyncio -import dataclasses -import heapq -import logging as pylogging -import os -import threading -import time -from contextlib import AbstractContextManager -from dataclasses import dataclass -from typing import IO, Any, Callable, Dict, Iterable, Iterator, List, Optional, Sequence, TypeVar - -import fsspec.core -import pyarrow as pa -import pyarrow.parquet as pq -import ray -from dataclasses_json import dataclass_json -from fsspec import AbstractFileSystem -from ray.actor import ActorHandle -from ray.exceptions import GetTimeoutError - -from ..utils.ray_utils import ( - ExceptionInfo, - RefBox, - SnitchRecipient, - current_actor_handle, - log_failures_to, - ser_exc_info, -) -from ._preprocessor import BatchProcessor, BatchResult, as_record_batch, dict_from_record_batch -from ._queue import ( - PriorityProcessorActor, - PriorityWorkItem, - PriorityWorkTaskGroup, - PriorityWorkTaskGroupSpec, - _BatchProcessorQueue, -) -from .dataset import ShardableDataset -from .metrics_monitor import InProgressCacheMetrics, LoggerMetricsMonitor, MetricsMonitor -from .sharded_dataset import ShardedDataset - - -G = TypeVar("G") -T = TypeVar("T") -T_co = TypeVar("T_co", covariant=True) - - -logger = pylogging.getLogger(__name__) - -DEFAULT_ROWS_PER_CHUNK = 8192 -DEFAULT_MAX_BYTES_PER_BATCH = 256 * 1024 * 1024 # 256 MB, this is pre-preprocessing python object size -DEFAULT_MAX_SHARDS_TO_READ_AT_ONCE = 32 -LEDGER_FILE_NAME = "cache_ledger.json" - -LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" -LEVEL_TO_LOG = pylogging.INFO - - -def build_or_load_cache( - cache_dir: str, - input_shards: ShardedDataset[T], - processor: BatchProcessor[T], - batch_size: int = 1, - rows_per_chunk: int = DEFAULT_ROWS_PER_CHUNK, - await_finished: bool = True, - monitors: Optional[Sequence["MetricsMonitor"]] = None, - cache_config: Optional[Dict[str, Any]] = None, -) -> "ShardCache": - """ - Produces a sharded cache of the dataset using Ray for distributed processing. The cache can be any path - on any file system understood by fsspec. - - This system is designed with tokenization and similar processes in mind, but it can potentially be used for any kind - of preprocessing that converts input batches to output batches. The main design goal is to make it easy to - parallelize preprocessing across multiple machines while maintaining reproducibility and fault tolerance. 
- Usually the machines in question are the ones doing the training, but they could be separate machines as well. - - See the [Dataloader Design Doc](https://github.com/stanford-crfm/levanter/blob/main/docs/design/Data-Loader-Design.md) - for a somewhat out of date overview of the design. - - Args: - cache_dir: The directory to write the cache to. This can be any path understood by fsspec. - input_shards: A ShardedDataset that will be used to read the input data. Conceptually, it's just a mapping - from shard names to iterators over the data in that shard. - processor: A BatchProcessor that will be used to process batches of data. This is the main place where - you can customize the preprocessing pipeline. - batch_size: When reading from the cache, how many examples to read at a time. - rows_per_chunk: The number of rows to write to each chunk. May be smaller at the end of a shard. - await_finished: If True, this function will block until the cache is finished. If False, it will return - immediately. - monitors: a list of MetricsMonitors to attach to the cache. These will be called periodically with - metrics about the cache build process. If None, will add a LoggerMetricsMonitor. - - Returns: - (ShardCache) A ShardCache object that can be used to read the cache. - - """ - # first see if we need to do anything - cache = ShardCache.build_or_load( - cache_dir=cache_dir, - shard_source=input_shards, - processor=processor, - batch_size=batch_size, - rows_per_chunk=rows_per_chunk, - cache_config=cache_config, - ) - - if cache.is_finished: - logger.info("Cache already finished. Skipping.") - return cache - - if monitors is None: - monitors = [LoggerMetricsMonitor()] - - for monitor in monitors: - cache.attach_metrics_monitor(monitor) - - while await_finished: - try: - cache.await_finished(4.0) - break - except TimeoutError: - pass - - return cache - - -@dataclass_json -@dataclass -class ChunkMetadata: - name: str - num_rows: int - field_counts: Dict[str, int] - - -@dataclass_json -@dataclass -class ShardMetadata: - chunks: List[ChunkMetadata] = dataclasses.field(default_factory=list) - is_finished: bool = False - - @property - def total_rows(self): - return sum(chunk.num_rows for chunk in self.chunks) - - @property - def total_chunks_produced(self): - return len(self.chunks) - - -@dataclass_json -@dataclass -class CacheLedger: - """Written at the end of the cache build process. Contains the global chunk order.""" - - chunks: List[ChunkMetadata] = dataclasses.field(default_factory=list) - metadata: Dict[str, Any] = dataclasses.field(default_factory=dict) - - -class SerialCacheWriter(AbstractContextManager): - """ - Writes ShardCache-compatible caches to disk. This is a serial version of ShardCacheWriter that doesn't use Ray. - Mostly for scripts and debugging. - - Examples: - >>> with SerialCacheWriter(cache_dir, rows_per_chunk=1024) as writer: - ... for batch in process_batches(): - ... 
writer.write_batch(batch) - """ - - def __init__( - self, - cache_dir: str, - rows_per_chunk: int = DEFAULT_ROWS_PER_CHUNK, - cache_config: Optional[Dict[str, Any]] = None, - ): - if rows_per_chunk <= 0: - raise ValueError("rows_per_chunk must be positive") - self.cache_dir = cache_dir - self.cache_config = cache_config - self._rows_per_chunk = rows_per_chunk - self._chunks: List[ChunkMetadata] = [] - self._current_chunk_writer: Optional[_ChunkWriter] = None - self._is_closed = False - - def __enter__(self) -> "SerialCacheWriter": - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - # if successful, write the ledger - if self._current_chunk_writer is not None: - self._current_chunk_writer.__exit__(exc_type, exc_val, exc_tb) - self._chunks.append(self._current_chunk_writer.get_metadata()) - self._current_chunk_writer = None - - if exc_type is None: - _serialize_json_and_commit( - os.path.join(self.cache_dir, LEDGER_FILE_NAME), CacheLedger(self._chunks, self.cache_config) - ) - logger.info(f"Cache ledger written to {self.cache_dir}") - self._is_closed = True - - def result(self, batch_size: int = 1) -> "ShardCache": - if not self._is_closed: - raise RuntimeError("Cannot get result until ShardCacheWriter is closed") - return ShardCache.load(self.cache_dir, batch_size=batch_size) - - def write_batch(self, batch: BatchResult): - rb = as_record_batch(batch) - - while rb.num_rows > 0: - if self._current_chunk_writer is None: - self._current_chunk_writer = _ChunkWriter( - self.cache_dir, f"chunk-{len(self._chunks)}", rb.schema - ).__enter__() - - slice = rb.slice(0, min(rb.num_rows, self._rows_per_chunk - self._current_chunk_writer.num_rows)) - self._current_chunk_writer.write_batch(slice) - rb = rb.slice(slice.num_rows) - - if self._current_chunk_writer.num_rows >= self._rows_per_chunk: - self._current_chunk_writer.__exit__(None, None, None) - self._chunks.append(self._current_chunk_writer.get_metadata()) - self._current_chunk_writer = None - - -class _ChunkWriter: - def __init__(self, cache_dir: str, chunk_name: str, schema: pa.Schema): - self.cache_dir = cache_dir - self.chunk_name = chunk_name - self.schema = schema - self.file: Optional[IO] = None - self.writer: Optional[pq.ParquetWriter] = None - self.num_rows = 0 - self.field_counts: Dict[str, int] = {} - - self.is_finished = False - - def __enter__(self): - self.file = fsspec.open(os.path.join(self.cache_dir, f"{self.chunk_name}.parquet"), "wb").__enter__() - self.writer = pq.ParquetWriter(self.file, self.schema, version="2.6", compression="ZSTD").__enter__() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.writer is not None: - self.writer.__exit__(exc_type, exc_val, exc_tb) - if self.file is not None: - self.file.__exit__(exc_type, exc_val, exc_tb) - - self.is_finished = True - - def write_batch(self, batch: pa.RecordBatch): - assert not self.is_finished - assert self.writer is not None - self.writer.write_batch(batch) - self.num_rows += batch.num_rows - - for i in range(batch.num_columns): - name = batch.field(i).name - value = batch.column(i) - if isinstance(value, pa.ListArray): - value = value.flatten() - self.field_counts[name] = self.field_counts.get(name, 0) + len(value) - elif isinstance(value, pa.ChunkedArray): - self.field_counts[name] = self.field_counts.get(name, 0) + value.length() - - def get_metadata(self) -> ChunkMetadata: - if not self.is_finished: - raise RuntimeError("Cannot get metadata for unfinished chunk") - return ChunkMetadata(self.chunk_name, self.num_rows, 
self.field_counts) - - -class _ShardMetadataWriter: - def __init__(self, metadata_path): - self.metadata_path = metadata_path - try: - with fsspec.open(self.metadata_path, "r") as file: - self.metadata = ShardMetadata.from_json(file.read()) # type: ignore - except FileNotFoundError: - self.metadata = ShardMetadata() - - @property - def is_finished(self): - return self.metadata.is_finished - - @property - def chunks(self): - return self.metadata.chunks - - @property - def num_chunks(self): - return len(self.metadata.chunks) - - def commit_chunk(self, chunk: ChunkMetadata): - assert not self.metadata.is_finished - self.metadata.chunks.append(chunk) - self._commit() - - def finish(self): - self.metadata.is_finished = True - self._commit() - - def _commit(self): - _serialize_json_and_commit(self.metadata_path, self.metadata) - - -# thinking through the design of the cache system - -# we decided to use Ray, which was maybe a mistake, but here we are. -# Ray doesn't like it when the number of actors gets too large, so we can't have one actor per shard. -# we have N nodes and K shards. We want to produce chunks of size C examples, from each shards. -# We define a global order over chunks [shard[0].chunk[0], shard[1].chunk[0], ... shard[K].chunk[0], shard[0].chunk[1], ...] -# with the obvious extension for if one shard has more chunks than another. -# We want to produce chunks in roughly this order, but we want to do it in parallel. -# We also want to be able to recover from failures, and we want to be able to resume a cache build. - -# at a high level, we have 3 steps: -# 1. read batches from the source -# 2. process batches, concatenating them into chunks -# 3. write chunks to disk - -# The difficulty is that we want parallelism and we want to control the order of chunks. -# reading batches requires CPU and network. This means we should limit the number to roughly the number of nodes, maybe times 2. -# We want to prioritize so that we read 1 chunks worth of batches from each shard before reading more from another shard. -# We also want to prioritize reading earlier shards before later shards (within a chunk generation round). -# Ray also seems to get upset about having too many processes, and we can't serialize the iterators over shards. 
- - -def _shard_reader_generator(shard_source: ShardedDataset[T], shard_name: str, start_row: int, batch_size: int): - shard_iter = shard_source.open_shard_at_row(shard_name, start_row) - batch = [] - for row in shard_iter: - batch.append(row) - - if len(batch) == batch_size: - yield batch - batch = [] - - if len(batch) > 0: - yield batch - - -@dataclass -class ShardGroupToBeProcessed(PriorityWorkTaskGroupSpec): - name: str - builder_ref: ray.actor.ActorHandle # _ChunkCacheBuilder - writer: ray.actor.ActorHandle # _GroupedShardWriter - shard_source: ShardedDataset - shard_names: Sequence[str] - priority_fn: Callable[[int, int], float] - processor_actor: ray.actor.ActorHandle # BatchProcessorQueue - batch_size: int - num_rows_per_chunk: int - group_id: int - - def build(self) -> "PriorityWorkTaskGroup": - return ShardGroupTaskGroup(self) - - -class ShardGroupTaskGroup(PriorityWorkTaskGroup): - def __init__(self, spec: ShardGroupToBeProcessed): - self.spec = spec - self.logger = pylogging.getLogger(f"shard_reader.{spec.group_id}.{spec.name}") - - try: - metadata: dict[str, ShardMetadata] = _initial_shard_metadatas( - self.spec.shard_source, self.spec.shard_names, self.spec.writer - ) - except Exception as e: - self.spec.builder_ref.other_failed.remote(ser_exc_info()) - raise e - - batch_size = min(self.spec.batch_size, self.spec.num_rows_per_chunk) - - self._items: list[PriorityWorkItem] = [] - - for shard_name in self.spec.shard_names: - shard_idx = self.spec.shard_source.shard_names.index(shard_name) - try: - shard_metadata = metadata[shard_name] - reader = _shard_reader_generator( - self.spec.shard_source, shard_name, shard_metadata.total_rows, batch_size - ) - - if shard_metadata.is_finished: - self.logger.info(f"Shard {shard_name} already finished. Skipping.") - - task_name = f"shard_reader.{self.spec.name}.{shard_name}" - - chunk_idx = len(shard_metadata.chunks) - item = ShardReaderItem(self, task_name, shard_name, shard_idx, chunk_idx, reader) - - heapq.heappush(self._items, item) - except Exception as e: - self.logger.exception(f"Error while initializing shard {shard_name}") - self.spec.writer[shard_name].shard_failed.remote(ser_exc_info()) - raise e - - @property - def name(self): - return self.spec.name - - def items(self) -> Sequence["PriorityWorkItem"]: - return self._items - - -# NB This class is stateful -@dataclass -class ShardReaderItem(PriorityWorkItem): - """ - Each time execute is called, this class reads one chunk's worth of batches from the shard - and dispatches them to the processor. 
- """ - - group: ShardGroupTaskGroup - name: str - shard_name: str - shard_idx: int - chunk_idx: int - reader: Iterator[list] - - @property - def priority(self): - return self.group.spec.priority_fn(self.shard_idx, self.chunk_idx) - - @property - def spec(self): - return self.group.spec - - def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: - exhausted_shard = False - writer = self.spec.writer - - chunk_batch_idx = 0 # the index of the batch within the chunk - chunk_filled = False # whether or not we've filled the chunk to max size - total_chunk_rows = 0 # the total number of rows in the chunk - batch_result_ref = None - - self.group.logger.debug(f"Reading one chunk of shard {self.shard_name}: {self.chunk_idx}") - - try: - while not chunk_filled: - batch = next(self.reader, None) - if batch is None: - exhausted_shard = True - break - - exhausted_shard = len(batch) < self.spec.batch_size - total_chunk_rows += len(batch) - - if batch: - priority = self.spec.priority_fn(self.shard_idx, self.chunk_idx) - # these times aren't exact because the times might be from different machines - # but they're just for logging - time_in = time.time() - batch_result_ref = ray.get( - self.spec.processor_actor.submit.remote( - priority=priority, - desc=f"{self.shard_name}.{self.chunk_idx}.{chunk_batch_idx}", - batch=RefBox(ray.put(batch)), - ) - ) - writer.chunk_batch_finished.remote( - self.shard_name, self.chunk_idx, chunk_batch_idx, RefBox(batch_result_ref), time_in - ) - chunk_batch_idx += 1 - del batch - - if total_chunk_rows >= self.spec.num_rows_per_chunk or exhausted_shard: - chunk_filled = True - - if chunk_batch_idx > 0: - writer.chunk_finished_reading.remote(self.shard_name, self.chunk_idx, chunk_batch_idx) - old_prio = self.priority - self.chunk_idx += 1 - assert self.priority > old_prio - - if exhausted_shard: - writer.shard_finished_reading.remote(self.shard_name, self.chunk_idx) - - self.group.logger.debug( - f"Finished reading one chunk of shard {self.shard_name}: {self.chunk_idx} {exhausted_shard}" - ) - - return exhausted_shard, batch_result_ref - except Exception as e: # noqa - self.group.logger.exception(f"Error while processing shard {self.shard_name}") - # fire and forget - writer.shard_failed.remote(self.shard_name, ser_exc_info()) - raise e - - -def _initial_shard_metadatas(shard_source, shard_names, shard_group_writer): - shard_metadatas: dict[str, ShardMetadata] = {} - _metadata_futures = [shard_group_writer.current_metadata.remote(name) for name in shard_names] - shard_metadatas_rs = ray.get(_metadata_futures) - for shard_name, shard_metadata in zip(shard_names, shard_metadatas_rs): - shard_metadatas[shard_name] = shard_metadata - return shard_metadatas - - -def _serialize_json_and_commit(path, obj): - # just to be paranoid, we write to a temp file and then rename it - # TODO: probably we could do better here - with fsspec.open(f"{path}.tmp", "w") as file: - file.write(obj.to_json()) - # now copy the old file to a backup - fs: AbstractFileSystem = fsspec.core.url_to_fs(path)[0] - fs.mkdirs(os.path.dirname(path), exist_ok=True) - if fs.exists(path): - fs.copy(path, f"{path}.bak") - fs.rename(f"{path}.tmp", path) - - -def _load_cache_ledger(cache_dir) -> CacheLedger: - try: - ledger_path = os.path.join(cache_dir, LEDGER_FILE_NAME) - logger.debug(f"Attempting to load cache ledger from {ledger_path}") - with fsspec.open(ledger_path) as file: - cache_ledger = CacheLedger.from_json(file.read()) # type: ignore - return cache_ledger - except FileNotFoundError: - raise 
FileNotFoundError(f"Cache ledger not found at {ledger_path}") - - -@dataclass -class _ShardStatus: - num_chunks_sent: int = 0 - current_buffer: list[ChunkMetadata] = dataclasses.field(default_factory=list) - expected_num_chunks: Optional[int] = None - - def pop_chunk_to_send(self) -> Optional[ChunkMetadata]: - if len(self.current_buffer) == 0: - return None - else: - self.num_chunks_sent += 1 - return self.current_buffer.pop(0) - - @property - def is_finished_and_buffer_empty(self): - return self.expected_num_chunks is not None and self.num_chunks_sent >= self.expected_num_chunks - - -# Ray does poorly with large numbers of actors (grumble grumble), so we can't have one actor per shard. -# This class wraps a map of shard names to _ShardWriterWorkers, and manages the lifecycle of the workers. -@ray.remote(num_cpus=0.0, scheduling_strategy="SPREAD") # type: ignore -class _GroupShardWriterWorker: - def __init__(self, parent_ref, cache_dir: str, shard_names: Sequence[str]): - with log_failures_to(parent_ref): - pylogging.basicConfig(level=LEVEL_TO_LOG, format=LOG_FORMAT) - self.cache_dir = cache_dir - self.shard_names = shard_names - self.shard_writers: dict[str, _ShardWriterWorker] = { - shard_name: _ShardWriterWorker(parent_ref, cache_dir, shard_name) for shard_name in shard_names - } - - def current_metadata(self, shard_name: str): - return self.shard_writers[shard_name].current_metadata() - - async def chunk_batch_finished(self, shard_name: str, chunk_id: int, batch_idx: int, batch: RefBox, time_in): - # batch is a pa.RecordBatch ref box - try: - time_mid = time.time() - logger.debug( - f"Received in progress batch {batch_idx} of chunk {chunk_id} of shard {shard_name} in" - f" {time_mid - time_in}" - ) - # do a backoff loop until the batch is actually processed. log if it's been a while - timeout_interval = 20 - total_time_waited = 0 - - while True: - try: - # batch = await asyncio.wait_for(asyncio.shield(batch.ref), timeout_interval) - batch = await batch.ref - break - except asyncio.TimeoutError: - # to keep to round numbers, we log how much we asked for rather than how much we got - total_time_waited += timeout_interval - timeout_interval = min(2 * timeout_interval, 100) - logger.info( - f"Waiting for {shard_name}.{chunk_id}.{batch_idx} to be processed. " - f"Waited {total_time_waited} seconds." - ) - - if logger.isEnabledFor(pylogging.DEBUG): - logger.debug( - f"Received finished {shard_name}.{chunk_id}.{batch_idx} in {(time.time() - time_in):.2f} seconds." - ) - elif total_time_waited > 40: - logger.info( - f"Waited {total_time_waited} seconds for {shard_name}.{chunk_id}.{batch_idx} to be processed." 
- ) - return self.shard_writers[shard_name].chunk_batch_finished(chunk_id, batch_idx, batch) - except Exception as e: - print(f"Error while processing batch {batch_idx} of chunk {chunk_id} of shard {shard_name}", flush=True) - self.shard_writers[shard_name].chunk_failed(chunk_id, ser_exc_info()) - raise e - - def chunk_finished_reading(self, shard_name: str, chunk_id: int, expected_num_batches: int): - return self.shard_writers[shard_name].chunk_finished_reading(chunk_id, expected_num_batches) - - def chunk_failed(self, shard_name: str, chunk_id: int, error: ExceptionInfo): - return self.shard_writers[shard_name].chunk_failed(chunk_id, error) - - def shard_finished_reading(self, shard_name: str, expected_num_chunks: int): - return self.shard_writers[shard_name].shard_finished_reading(expected_num_chunks) - - def shard_failed(self, shard_name: str, error: ExceptionInfo): - return self.shard_writers[shard_name].shard_failed(error) - - -class _ShardWriterWorker: # type: ignore - """ - Actor that writes chunks to disk and updates the ShardMetadata. It reports to the ChunkCacheBroker - """ - - def __init__( - self, - parent_ref: ActorHandle, # ChunkCacheBuilder - cache_dir: str, - shard_name: str, - ): - pylogging.basicConfig(level=LEVEL_TO_LOG, format=LOG_FORMAT) - self.parent_ref = parent_ref - self.cache_dir = cache_dir - self.shard_name = shard_name - self.uncommited_chunks: list[tuple[int, ChunkMetadata]] = [] # heapq of (chunk index, chunk) - - self.metadata_writer = _ShardMetadataWriter(os.path.join(cache_dir, f"{shard_name}.json")) - self._expected_num_chunks: Optional[int] = None - - if self.metadata_writer.num_chunks > 0: - self.parent_ref.new_chunk.remote(shard_name, *self.metadata_writer.chunks) - - if self.metadata_writer.is_finished: - self._expected_num_chunks = self.metadata_writer.num_chunks - self.parent_ref.shard_finished.remote(self.shard_name, self._expected_num_chunks) - self.finished = True - else: - self.finished = False - - self.collator = _ChunkCollator(cache_dir, shard_name) - - def current_metadata(self): - return self.metadata_writer.metadata - - # forward some methods to the collator, handle any metadata that comes back - def chunk_batch_finished(self, chunk_id: int, batch_idx: int, batch: pa.RecordBatch): - metadata = self.collator.new_batch(chunk_id, batch_idx, batch) - if metadata is not None: - self._finished_chunk(chunk_id, metadata) - - return metadata - - def chunk_finished_reading(self, chunk_id: int, expected_num_batches: int): - metadata = self.collator.chunk_finished_reading(chunk_id, expected_num_batches) - if metadata is not None: - self._finished_chunk(chunk_id, metadata) - - return metadata - - def chunk_failed(self, chunk_id: int, error: ExceptionInfo): - self.collator.chunk_failed(chunk_id, error) - print(f"Error while processing chunk {chunk_id} of shard {self.shard_name}", flush=True) - self.parent_ref.shard_failed.remote(self.shard_name, error) - - def _finished_chunk(self, idx: int, chunk: ChunkMetadata): - if (idx < self.metadata_writer.num_chunks) or ( - self._expected_num_chunks is not None and idx >= self._expected_num_chunks - ): - logger.error(f"Received chunk {idx} for {self.shard_name} but it's already finished") - error = RuntimeError(f"Received chunk {idx} for {self.shard_name} but it's already finished") - self.parent_ref.shard_failed.remote(self.shard_name, ser_exc_info(error)) - raise error - - heapq.heappush(self.uncommited_chunks, (idx, chunk)) - self._attempt_to_commit_chunks() - - def shard_finished_reading(self, 
expected_num_chunks: int): - # TODO: add metadata that we're done reading to metrics - self._expected_num_chunks = expected_num_chunks - self._attempt_to_commit_chunks() - - def shard_failed(self, error: ExceptionInfo): - self.parent_ref.shard_failed.remote(self.shard_name, error) - - def _attempt_to_commit_chunks(self): - chunks_committed = [] - while len(self.uncommited_chunks) > 0 and self.uncommited_chunks[0][0] == self.metadata_writer.num_chunks: - _, chunk = heapq.heappop(self.uncommited_chunks) - chunk_number = self.metadata_writer.num_chunks - logger.debug(f"Committing chunk {chunk.name} of shard {self.shard_name}. It is chunk {chunk_number}") - self.metadata_writer.commit_chunk(chunk) - chunks_committed.append(chunk) - - if len(chunks_committed) > 0: - if self.finished: - raise RuntimeError("Tried to commit chunks after shard finished") - # TODO: this is called inside an async call so we need to not block, but we do need to sequence - # this to come before the shard_finished - self.parent_ref.new_chunk.remote(self.shard_name, *chunks_committed) - - if not self.finished and self.metadata_writer.num_chunks == self._expected_num_chunks: - self.metadata_writer.finish() - self.finished = True - self.parent_ref.shard_finished.remote(self.shard_name, self._expected_num_chunks) - - -class _ChunkCollator: - """ - This class is responsible for taking batches from the processor and writing them to disk in order. - It also handles the logic of when to commit chunks to disk. - - For each chunk (that is has data for and hasn't finished), it keeps a heapq of batches that have been - processed but not yet written to disk. When a new batch comes in, it checks if it's the next batch in the - chunk. If so, it writes it to disk and flushes any other batches that are ready to be written. - - A chunk isn't finished until it's received all the batches it's expecting and it knows how many batches - to expect. - - """ - - def __init__(self, cache_dir: str, shard_name: str): - self.cache_dir = cache_dir - self.shard_name = shard_name - self.chunk_writers: dict[int, _ChunkWriter] = {} # chunk index -> writer - self.batch_counts: dict[int, int] = {} # chunk index -> number of batches written - self.expected_totals: dict[int, int] = {} # chunk index -> expected num batches. 
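# A minimal sketch (illustrative only) of the in-order flush pattern described in the docstring
# above: out-of-order batches are parked in a heap keyed by batch index, and the writer drains
# the heap only while the smallest parked index is exactly the next one expected.
#
#     import heapq
#     pending: list[tuple[int, object]] = []
#     next_expected = 0
#
#     def add(batch_idx, batch, write):
#         global next_expected
#         heapq.heappush(pending, (batch_idx, batch))
#         while pending and pending[0][0] == next_expected:
#             _, b = heapq.heappop(pending)
#             write(b)
#             next_expected += 1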
- self.failed_chunks: dict[int, ExceptionInfo] = {} # chunk index -> error - self.chunk_partial_batches: dict[ - int, list[tuple[int, pa.RecordBatch]] - ] = {} # chunk index -> heapq of (batch index, batch) - - def new_batch(self, chunk_id, batch_idx, batch) -> Optional[ChunkMetadata]: - if chunk_id not in self.chunk_partial_batches: - self.chunk_partial_batches[chunk_id] = [] - self.batch_counts[chunk_id] = 0 - - heapq.heappush(self.chunk_partial_batches[chunk_id], (batch_idx, batch)) - - return self._attempt_to_write_chunk_fragments(chunk_id) - - def chunk_finished_reading(self, chunk_id, expected_num_batches) -> Optional[ChunkMetadata]: - self.expected_totals[chunk_id] = expected_num_batches - return self._attempt_to_write_chunk_fragments(chunk_id) - - def chunk_failed(self, chunk_id, error: ExceptionInfo): - self.failed_chunks[chunk_id] = error - if chunk_id in self.chunk_writers: - self.chunk_writers[chunk_id].__exit__(*error.restore()) - del self.chunk_writers[chunk_id] - - def _attempt_to_write_chunk_fragments(self, chunk_id) -> Optional[ChunkMetadata]: - if chunk_id in self.failed_chunks: - logger.error(f"Chunk {chunk_id} of shard {self.shard_name} already failed, not writing more") - raise self.failed_chunks[chunk_id].restore() - - if chunk_id in self.chunk_partial_batches: - chunk_batches = self.chunk_partial_batches[chunk_id] - - while len(chunk_batches) > 0: - batch_id, batch = chunk_batches[0] - if batch_id != self.batch_counts[chunk_id]: - break - - # we can write this batch - batch_id, batch = heapq.heappop(chunk_batches) - - if chunk_id not in self.chunk_writers: - assert batch_id == 0, f"Expected batch 0 but got {batch_id}" - chunk_name = os.path.join(self.shard_name, f"chunk-{chunk_id}") - writer = _ChunkWriter(self.cache_dir, chunk_name, batch.schema) - writer.__enter__() - self.chunk_writers[chunk_id] = writer - - self.chunk_writers[chunk_id].write_batch(batch) - self.batch_counts[chunk_id] += 1 - - if chunk_id not in self.batch_counts: - return None - - if chunk_id in self.expected_totals and self.batch_counts[chunk_id] == self.expected_totals[chunk_id]: - assert len(chunk_batches) == 0 - # we're done with this chunk - writer = self.chunk_writers[chunk_id] - writer.__exit__(None, None, None) - del self.chunk_writers[chunk_id] - del self.batch_counts[chunk_id] - del self.chunk_partial_batches[chunk_id] - return writer.get_metadata() - else: - return None - - -@ray.remote(num_cpus=0.5) # keep this small b/c it doesn't do a lot -class ChunkCacheBuilder(SnitchRecipient): - """ - Actor that manages the in-progress global ordering on chunks. ChunkCacheWriter's job is to hold the list of all - chunks as well as chunks from each shard while caching is running. - - This is a separate actor from the ChunkCacheBroker because - we need something that gets messages from shards in-order, and async methods make actors - lose that property. 
- """ - - def __init__( - self, - broker_ref, - cache_dir: str, - name: str, - source: ShardedDataset[T], - processor: BatchProcessor[T], - rows_per_chunk: int, - ): - with log_failures_to(broker_ref): - pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) - self.logger = pylogging.getLogger(f"{__name__}.{name}") - self.broker_ref = broker_ref - self.shard_status: Dict[str, _ShardStatus] = dict() - self._current_round_robin = [] - self.source = source - self._metrics = InProgressCacheMetrics() - - self_ref = current_actor_handle() - - if len(source.shard_names) == 0: - self.logger.warning("No shards to index?!?") - self._finish() - else: - self.logger.info(f"Starting cache build for {len(source.shard_names)} shards") - - self._shard_writers = [] - self._shard_readers = [] - self._processor_actors = [] - - for shard_name in source.shard_names: - self._current_round_robin.append(shard_name) - self.shard_status[shard_name] = _ShardStatus() - - num_shards = len(source.shard_names) - num_worker_groups = len(ray.nodes()) - num_shard_groups = max(min(num_worker_groups, num_shards), 1) - - # if we have a bunch of caches to build with one shard, we don't want them all - # assigned to the same node, so we use an offset based on the hash of the name (for stability) - # in an attempt to spread them out - group_offset = int(hash(name) % num_worker_groups) - - shard_groups: list[list[str]] = [[] for _ in range(num_shard_groups)] - for i, shard_name in enumerate(source.shard_names): - shard_groups[i % num_shard_groups].append(shard_name) - - def priority_fn(shard_idx, chunk_idx): - return chunk_idx * num_shards + shard_idx - - for group_id, shard_group in enumerate(shard_groups): - writer = _GroupShardWriterWorker.remote(self_ref, cache_dir, shard_group) # type: ignore - self._shard_writers.append(writer) - - # TODO: would probably be better if we didn't create one of these per shard group - processor_actor = _BatchProcessorQueue.remote(processor) # type: ignore - self._processor_actors.append(processor_actor) - - work_item = ShardGroupToBeProcessed( - name=name, - builder_ref=self_ref, - writer=writer, - shard_source=source, - shard_names=shard_group, - priority_fn=priority_fn, - processor_actor=processor_actor, - batch_size=processor.batch_size, - num_rows_per_chunk=rows_per_chunk, - group_id=group_id, - ) - - # we want global names so that different tasks can coordinate priorities - worker_to_assign = (group_id + group_offset) % num_worker_groups - priority_actor_name = f"priority_processor.{worker_to_assign}" - - reader_actor = PriorityProcessorActor.options( # type: ignore - name=priority_actor_name, get_if_exists=True - ).remote() - - reader_actor.assign_work.remote(work_item) - - self._shard_readers.append(reader_actor) - - def new_chunk(self, shard_name: str, *chunks: ChunkMetadata): - """Callback method for when a shard worker has produced a new chunk.""" - self.shard_status[shard_name].current_buffer.extend(chunks) - - # if we have buffered chunks, we need to check if we can send them to the broker - self._attempt_to_flush_buffers() - - self._metrics.chunks_finished += len(chunks) - # update metrics - for chunk in chunks: - self._metrics.rows_finished += chunk.num_rows - for field, count in chunk.field_counts.items(): - self._metrics.field_counts[field] = self._metrics.field_counts.get(field, 0) + count - - if len(chunks) > 0: - ray.get(self.broker_ref._new_metrics.remote(self._metrics)) - - def shard_finished(self, shard_name: str, expected_num_chunks: int): - """Callback method for 
when a shard worker has finished.""" - shard_status = self.shard_status[shard_name] - assert ( - shard_status.expected_num_chunks is None - ), f"Shard {shard_name} already finished: {shard_status.expected_num_chunks} {expected_num_chunks}" - shard_status.expected_num_chunks = expected_num_chunks - - # we might still have buffered chunks, so we need to check if we can append them - self._attempt_to_flush_buffers() - self._metrics.shards_finished += 1 - ray.get(self.broker_ref._new_metrics.remote(self._metrics)) - - # if there are no more active shards, we're done - if self._all_shards_done(): - assert len(self._current_round_robin) == 0 - self._finish() - - def _all_shards_done(self): - return all(status.is_finished_and_buffer_empty for status in self.shard_status.values()) - - def shard_failed(self, shard_name: str, error: ExceptionInfo): - """Callback method for when a shard worker has failed.""" - ray.get(self.broker_ref._writer_exception.remote(shard_name, error)) - - def other_failed(self, error: ExceptionInfo): - """Callback method for when a shard worker has failed.""" - ray.get(self.broker_ref._writer_exception.remote(None, error)) - - def _attempt_to_flush_buffers(self): - # this is the most complex logic in this class. - # The global order on chunks is defined as a roundrobin over shards, until one shard is done. - # After that, that shard is removed from the roundrobin and the process continues. - # Roundrobin order is determined by self.source.shard_names - - # We are happy to release chunks that form a prefix of the global order so that they can be read. - # To do that, we maintain the roundrobin order in self._current_round_robin - # and we maintain the current buffer for each shard in self.shard_status. - # When we get a new chunk, we append it to the buffer for that shard. - # When we get a finished message, we mark that shard as finished. - # In either case, we check if we can send any chunks from the front of the roundrobin. - # If we can, we send them to the broker - - # here "finished" means that the shard has sent all of its chunks and has told us that it's done. - - chunks_to_send = [] - - while len(self._current_round_robin) > 0: - name = self._current_round_robin[0] - status = self.shard_status[name] - if status.is_finished_and_buffer_empty: - # we're done with this shard, so we can remove it from the roundrobin - self._current_round_robin.pop(0) - logger.debug(f"Shard {name} is finished, removing from round robin") - continue - - # now let's see if we can send a chunk from this shard - next_chunk = status.pop_chunk_to_send() - if next_chunk is not None: - # we can send a chunk from this shard - self._current_round_robin.pop(0) - self._current_round_robin.append(name) - chunks_to_send.append(next_chunk) - continue - else: - # we can't send a chunk from this shard, so we can't send any additional chunks - if self.logger.level <= pylogging.DEBUG: - chunks_waiting = [ - f"{n2} ({len(s2.current_buffer)})" - for n2, s2 in self.shard_status.items() - if len(s2.current_buffer) > 0 - ] - msg = ( - f"Shard {name} has no chunks to send and is not known to be finished. 
We have this many queued" - f" chunks: {chunks_waiting}" - ) - self.logger.debug(msg) - break - - if len(chunks_to_send) > 0: - logger.debug(f"Sending {len(chunks_to_send)} chunks to broker") - ray.get(self.broker_ref._append_chunks.remote(*chunks_to_send)) - - def _finish(self): - self._metrics.is_finished = True - ray.get(self.broker_ref._new_metrics.remote(self._metrics)) - ray.get(self.broker_ref._finalize.remote()) - # self._shard_writers = [] - # self._shard_readers = [] - - -@ray.remote(num_cpus=0) -class ChunkCacheBroker(SnitchRecipient): - """Actor that manages the global order on chunks and vends chunk metadata to readers.""" - - chunks: List[ChunkMetadata] - _reader_promises: Dict[int, asyncio.Future[ChunkMetadata]] - _finished_promise: asyncio.Future[None] - - def __init__( - self, - cache_dir: str, - source: ShardedDataset[T], - processor: BatchProcessor[T], - rows_per_chunk: int, - cache_config: Optional[Dict[str, Any]], - ): - pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) - self.chunks = [] - self._reader_promises = {} - self._is_finished = False - self._source = source - self._processor = processor - self._cache_dir = cache_dir - self._rows_per_chunk = rows_per_chunk - self._finished_promise = asyncio.Future() - # used to subscribe to metrics updates - self._latest_metrics = InProgressCacheMetrics() - self._metrics_condition = asyncio.Condition() - self._cache_config = cache_config - path_for_name = os.path.join(*self._cache_dir.split("/")[-2:]) - name = f"broker::{path_for_name}" - self.logger = pylogging.getLogger(f"{name}") - - # initialize writer task - # first see if we need to do anything: check the ledger for is_finished - try: - cache_ledger = _load_cache_ledger(self._cache_dir) - self._append_chunks(*cache_ledger.chunks) - self._is_finished = True - self._finished_promise.set_result(None) - except FileNotFoundError: - self_ref = ray.runtime_context.get_runtime_context().current_actor - # only use the last two components of the name since it gets kind of long - name = f"builder::{path_for_name}" - self._builder_actor = ChunkCacheBuilder.remote( # type: ignore - self_ref, - self._cache_dir, - name, - self._source, - self._processor, - rows_per_chunk, - ) # type: ignore - - def is_finished(self): - return self._is_finished - - async def finished_sentinel(self): - await self._finished_promise - - async def updated_metrics(self) -> InProgressCacheMetrics: - if self._finished_promise.done(): - if self._finished_promise.exception() is not None: - raise self._finished_promise.exception() # type: ignore - else: - return self._latest_metrics - - async with self._metrics_condition: - await self._metrics_condition.wait() - return self._latest_metrics - - async def get_chunk(self, chunk_idx: int) -> Optional[ChunkMetadata]: - assert isinstance(self.chunks, list), self.chunks - if chunk_idx < len(self.chunks): - return self.chunks[chunk_idx] - elif self._is_finished: - return None - elif self._finished_promise.exception() is not None: - raise self._finished_promise.exception() # type: ignore - else: - if chunk_idx not in self._reader_promises: - self._reader_promises[chunk_idx] = asyncio.Future() - return await self._reader_promises[chunk_idx] - - async def final_chunk_count(self) -> Optional[int]: - if self._is_finished: - return len(self.chunks) - else: - return None - - def _append_chunks(self, *chunks: ChunkMetadata): - for chunk in chunks: - self.chunks.append(chunk) - chunk_idx = len(self.chunks) - 1 - self.logger.debug(f"Received chunk {chunk_idx}") - if 
chunk_idx in self._reader_promises: - self.logger.debug(f"Resolving promise for chunk {chunk_idx}") - self._reader_promises[chunk_idx].set_result(chunk) - del self._reader_promises[chunk_idx] - - def _new_metrics(self, metrics): - self._latest_metrics = metrics - self._do_notify() - - def _do_notify(self): - async def _do_notify_async(): - async with self._metrics_condition: - self._metrics_condition.notify_all() - - asyncio.create_task(_do_notify_async()) - - def _child_failed(self, child: ray.actor.ActorHandle, exception: ExceptionInfo): - try: - super()._child_failed(child, exception) - except Exception as e: - logger.exception("Error in child_failed") - self._writer_exception(None, ser_exc_info(e)) - - def _writer_exception(self, shard_name, exc_info: ExceptionInfo): - info = exc_info.restore() - - logger.exception(f"Writer task {shard_name} failed with exception", exc_info=info) - for future in self._reader_promises.values(): - future.set_exception(info[1]) - - self._reader_promises = {} - - self._finished_promise.set_exception(info[1]) - self._do_notify() - - def _finalize(self): - logger.info(f"Finalizing cache {self._cache_dir}...") - self._is_finished = True - for k, future in self._reader_promises.items(): - future.set_result(None) - - # write ledger - _serialize_json_and_commit( - os.path.join(self._cache_dir, LEDGER_FILE_NAME), CacheLedger(self.chunks, self._cache_config) - ) - - self._reader_promises = {} - # TODO: For some reason this crashes other actors with weird reference counting assertion errors. - # pretty sure it's a ray bug - # self._builder_actor = None - self._finished_promise.set_result(None) - - # notify metrics subscribers - self._do_notify() - - -def _get_broker_actor( - cache_dir, - input_shards, - processor, - cache_config=None, - rows_per_chunk=DEFAULT_ROWS_PER_CHUNK, -): - return ChunkCacheBroker.options( - name="lev_cache_manager::" + cache_dir.replace("/", "--"), get_if_exists=True, lifetime="detached" - ).remote( - # type: ignore - cache_dir=cache_dir, - source=input_shards, - processor=processor, - cache_config=cache_config, - rows_per_chunk=rows_per_chunk, - ) - - -class DictCacheDataset(ShardableDataset[dict]): - """ - A Dataset that yields HF BatchEncodings from a ShardCache. - This basically yields a dict-of-arrays, just the HF BatchEncoding class version of dict. - """ - - def __init__(self, cache: "ShardCache", return_batches: bool = False): - self.cache = cache - self.return_batches = return_batches - - def __iter__(self) -> Iterator[dict]: - for batch in self.cache: - encoding = dict_from_record_batch(batch) - - if self.return_batches: - yield encoding - else: - batch_size = 0 - for v in encoding.values(): - batch_size = len(v) - break - - for i in range(batch_size): - yield {k: v[i] for k, v in encoding.items()} - - def shard(self, shard_id: int, num_shards: int) -> "DictCacheDataset": - return DictCacheDataset(self.cache.shard(shard_id, num_shards)) - - @staticmethod - def load(cache_dir: str, return_batches: bool = False, batch_size: Optional[int] = None) -> "DictCacheDataset": - if batch_size is None: - batch_size = 1 - cache = ShardCache.load(cache_dir, batch_size=batch_size) - return DictCacheDataset(cache, return_batches=return_batches) - - -class ShardCache(Iterable[pa.RecordBatch]): - """A cache which is backed by a collection of chunks of preprocessed documents. These chunks - are produced by tokenizing/preprocessing a ShardedDataset. - - This is the main interface for building and reading from a shard cache. 
- - ShardCache has the following objectives: - - 1) Deterministic ordering over the data - 2) Sharded reading - 3) Sharded writing - 4) Simultaneous reading and writing of shards - 5) Fast resumption of writing - 6) Fast resumption of reading - - ShardCache achieves (1), (2), and (3) maintaining a reproducible global ordering over "chunks" created from shards. - The global ordering is defined by taking chunks round-robin from each shard. This allows us to read shards - in parallel and deterministically. - - ShardCache achieves (4) also via the chunking mechanism. As soon as all shards have written a chunk, the next - chunk can be read. This allows us to read and write in parallel. - - ShardCache achieves (5) by writing chunks to disk as soon as they are completed and serializing a state - of the chunks that have been written for each shard. This allows us to resume from the last chunk that was written. - - # TODO (6) isn't implemented just yet - - ShardCache achieves (6) by storing metadata about the chunks that have been written in a state. In addition - to the global ordering, the state also stores the number of documents in each chunk as well as the number - of tokens. - """ - - ledger: Optional[CacheLedger] - _broker: Optional[ActorHandle] - # We use a thread here instead of an actor because we want to ensure it's on the same process as the ShardCache - # object. - _monitor_thread: Optional[threading.Thread] - _metrics_monitors: List[MetricsMonitor] - - def __init__( - self, - cache_dir: str, - batch_size: int, - ledger: Optional[CacheLedger], - _broker: Optional[ActorHandle], - reader_offset: int = 0, - num_readers: int = 1, - ): - self.cache_dir = cache_dir - self.ledger = ledger - self._broker = _broker - self._batch_size = batch_size - - self._metrics_monitors = [] - self._monitor_thread = None - - self._num_readers = num_readers - self._reader_offset = reader_offset - name = os.path.join(*cache_dir.split("/")[-2:]) - self.logger = pylogging.getLogger(f"ShardCache.{name}") - - @staticmethod - def load(cache_dir: str, batch_size: int) -> "ShardCache": - """Loads a cache from disk. 
Raises FileNotFoundError if the cache doesn't exist""" - logger.info(f"Loading cache from {cache_dir}") - ledger = _load_cache_ledger(cache_dir) - return ShardCache(cache_dir, batch_size, ledger, None) - - @staticmethod - def build_or_load( - cache_dir: str, - shard_source: ShardedDataset[T], - processor: BatchProcessor[T], - batch_size: int, - rows_per_chunk: int, - cache_config: Optional[Dict[str, Any]] = None, - ): - try: - return ShardCache.load(cache_dir, batch_size) - except FileNotFoundError: - broker = _get_broker_actor( - cache_dir=cache_dir, - input_shards=shard_source, - processor=processor, - cache_config=cache_config, - rows_per_chunk=rows_per_chunk, - ) - return ShardCache(cache_dir=cache_dir, batch_size=batch_size, ledger=None, _broker=broker) - - def finished_sentinel(self): - """Returns a Ray-awaitable object that will be set when the cache is finished""" - if self._broker is None: - return ray.remote(num_cpus=0)(lambda: None).remote() - else: - return self._broker.finished_sentinel.remote() - - @property - def is_finished(self): - """Returns whether the cache is finished""" - if self._broker is None: - return True - else: - return ray.get(self._broker.is_finished.remote()) - - def read_chunk(self, chunk_idx: int) -> Iterator[pa.RecordBatch]: - """Reads a chunk from the cache""" - chunk = self.get_chunk(chunk_idx) - yield from self._read_chunk(chunk) - - def _map_index(self, index): - return index * self._num_readers + self._reader_offset - - def get_chunk(self, index: int, *, timeout: Optional[float] = None) -> ChunkMetadata: - """Returns the metadata for a given chunk index""" - mapped_index = self._map_index(index) - return self._get_chunk_unmapped(mapped_index, timeout=timeout) - - def _get_chunk_unmapped(self, mapped_index: int, *, timeout: Optional[float] = None) -> ChunkMetadata: - if self.ledger is not None: - return self.ledger.chunks[mapped_index] - else: - assert self._broker is not None - time_in = time.time() - next_time = time_in - # we want to also log if we're waiting for a long time, so we do this in a loop - while timeout is None or next_time - time_in < timeout: - current_timeout = 20.0 - if timeout is not None: - current_timeout = min(current_timeout, timeout - (next_time - time_in)) - try: - chunk = ray.get(self._broker.get_chunk.remote(mapped_index), timeout=current_timeout) - except GetTimeoutError: - self.logger.warning(f"Waiting for chunk {mapped_index} for {int(next_time - time_in)} seconds") - next_time = time.time() - current_timeout *= 2 - current_timeout = min(current_timeout, 100) - continue - except asyncio.exceptions.InvalidStateError: - self.logger.warning( - f"Invalid state waiting for chunk {mapped_index} for {int(next_time - time_in)} seconds" - ) - next_time = time.time() - current_timeout *= 2 - current_timeout = min(current_timeout, 100) - time.sleep(current_timeout) - continue - - if chunk is None: - raise IndexError(f"Chunk index out of bounds. (Mapped index {mapped_index})") - - return chunk - - if timeout is not None: - raise TimeoutError(f"Timeout while waiting for chunk {mapped_index}") - - async def get_chunk_async(self, index: int) -> ChunkMetadata: - """Returns the metadata for a given chunk index""" - mapped_index = self._map_index(index) - if self.ledger is not None: - return self.ledger.chunks[mapped_index] - else: - assert self._broker is not None - chunk = await self._broker.get_chunk.remote(mapped_index) - if chunk is None: - raise IndexError(f"Chunk index {index} out of bounds. 
(Mapped index {mapped_index})") - return chunk - - def final_chunk_count(self) -> Optional[int]: - """Returns the number of chunks in the cache, if known""" - if self.ledger is not None: - return len(self.ledger.chunks) - else: - assert self._broker is not None - return ray.get(self._broker.final_chunk_count.remote()) - - def iter_batches_from_chunks(self, loop: bool = False): - shard_offset = self._reader_offset - - if self.ledger is not None: - num_chunks = len(self.ledger.chunks) - - if num_chunks == 0: - return - - while True: - i = 0 - for i in range(shard_offset, num_chunks, self._num_readers): - chunk = self.ledger.chunks[i] - yield from self._read_chunk(chunk) - - if not loop: - break - - shard_offset = i % len(self.ledger.chunks) - else: - assert self._broker is not None - i = shard_offset - while True: - try: - self.logger.debug(f"Reading chunk {i}") - chunk = self._get_chunk_unmapped(i) - i += self._num_readers - yield from self._read_chunk(chunk) - except IndexError: - if loop: - num_chunks = ray.get(self._broker.final_chunk_count.remote()) - assert num_chunks is not None - - i = i % num_chunks - else: - break - except Exception as e: - self.logger.exception("Error while reading from shard cache.") - raise e - - def __iter__(self): - return self.iter_batches_from_chunks() - - def shard(self, offset, num_readers): - """ - Returns a shard of this shard cache. This method shards w.r.t the current shard cache, not the base shard cache. - - Args: - offset: - num_readers: - - Returns: - (ShardCache): A shard of this shard cache. - """ - if offset >= num_readers: - raise ValueError(f"Shard index {offset} is out of range") - - if num_readers == 1: - return self - - new_offset = self._reader_offset * num_readers + offset - new_num_readers = self._num_readers * num_readers - return ShardCache(self.cache_dir, self._batch_size, self.ledger, self._broker, new_offset, new_num_readers) - - def unshard(self): - """ - Gets the "base" shard cache that this shard cache is a shard of. - """ - return ShardCache(self.cache_dir, self._batch_size, self.ledger, self._broker, 0, 1) - - def with_batch_size(self, batch_size): - return ShardCache( - self.cache_dir, batch_size, self.ledger, self._broker, self._reader_offset, self._num_readers - ) - - def _read_chunk(self, chunk): - reader = _ChunkReader.from_metadata(self.cache_dir, chunk, self._batch_size) - for batch in reader: - yield batch - - def await_finished(self, timeout: Optional[float] = None): - return ray.get(self.finished_sentinel(), timeout=timeout) - - def attach_metrics_monitor(self, monitor: MetricsMonitor): - if self._broker is None: - # TODO: decide what to do about attaching if the cache is already finished - # maybe get the final metrics? 
- return - - self._metrics_monitors.append(monitor) - if self._monitor_thread is None: - self._monitor_thread = threading.Thread(target=self._monitor_metrics) - self._monitor_thread.start() - - def _monitor_metrics(self): - while True: - try: - metrics = ray.get(self._broker.updated_metrics.remote()) - for monitor in self._metrics_monitors: - monitor(metrics) - if metrics.is_finished: - break - except Exception as e: - self.logger.exception("Error while reading metrics from shard cache.") - raise e - - -class _ChunkReader: - """Reads batches of documents from a chunk""" - - metadata: ChunkMetadata - file: pq.ParquetFile - batch_size: int - - # TODO: seek by doc - # TODO: seek by token etc - - def __init__(self, metadata: ChunkMetadata, file: pq.ParquetFile, batch_size: int): - self.metadata = metadata - self.file = file - self.batch_size = batch_size - - def with_batch_size(self, batch_size): - return _ChunkReader(self.metadata, self.file, batch_size) - - @property - def num_docs(self): - return self.metadata.num_rows - - def field_count(self, field, default=None): - return self.metadata.field_counts.get(field, default) - - @property - def __len__(self): - return (self.num_docs + self.batch_size - 1) // self.batch_size - - def __iter__(self) -> Iterator[pa.RecordBatch]: - for record_batch in self.file.iter_batches(batch_size=self.batch_size): - yield record_batch - - @staticmethod - def from_metadata(cache_dir, metadata: ChunkMetadata, batch_size: int) -> "_ChunkReader": - file = pq.ParquetFile(fsspec.open(os.path.join(cache_dir, f"{metadata.name}.parquet"), "rb").open()) - return _ChunkReader(metadata, file, batch_size) diff --git a/src/levanter/data/sharded_dataset.py b/src/levanter/data/sharded_datasource.py similarity index 89% rename from src/levanter/data/sharded_dataset.py rename to src/levanter/data/sharded_datasource.py index e16f5fce7..38682616d 100644 --- a/src/levanter/data/sharded_dataset.py +++ b/src/levanter/data/sharded_datasource.py @@ -2,7 +2,20 @@ import json import os import warnings -from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, List, Optional, Sequence, Sized, Tuple, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Iterable, + Iterator, + List, + Optional, + Sequence, + Sized, + Tuple, + TypeVar, +) import datasets import fsspec @@ -10,6 +23,7 @@ from levanter.utils import fsspec_utils +from ..data import AsyncDataset from ._preprocessor import ( BatchResult, _BatchMapTransform, @@ -17,7 +31,6 @@ _DatasetTransform, _MapTransform, ) -from .dataset import Dataset, ShardableDataset from .utils import batched @@ -30,7 +43,7 @@ U = TypeVar("U") -class ShardedDataset(Dataset[T_co]): +class ShardedDataSource(Generic[T_co]): """ A ShardedDataset is the main interface for reading data. It's basically a mapping from shard names to iterators, with the extra feature that it exposes the ability to skip to a particular row in a shard. @@ -66,10 +79,9 @@ def build_or_load_cache( self, path: str, *, - rows_per_chunk: Optional[int] = None, await_finished: bool = True, monitors: Optional[Sequence["MetricsMonitor"]] = None, - ) -> ShardableDataset[dict]: + ) -> AsyncDataset[T]: """ Constructs a shard cache version of this dataset using Ray. @@ -79,36 +91,30 @@ def build_or_load_cache( * interruptible and resumable * streaming results (no need to wait for everything to finish) - *Note that build_cache does not in general preserve the order of the data.* - Note that this is an experimental API and is subject to change. 
Returns: - A new dataset that is backed by the cache. + A new AsyncDataset that is backed by the cache. """ - from levanter.data.shard_cache import DEFAULT_ROWS_PER_CHUNK, DictCacheDataset, build_or_load_cache - - if rows_per_chunk is None: - rows_per_chunk = DEFAULT_ROWS_PER_CHUNK source, processor = _construct_composite_batch_processor(self) + from ..store.cache import build_or_load_cache cache = build_or_load_cache( path, source, processor, - rows_per_chunk=rows_per_chunk, await_finished=await_finished, monitors=monitors, ) - return DictCacheDataset(cache) + return cache - def map(self, fn: Callable[[T_co], U]) -> "ShardedDataset[U]": - return _MappedShardedDataset(self, fn) + def map(self, fn: Callable[[T_co], U]) -> "ShardedDataSource[U]": + return _MappedShardedDataSource(self, fn) def map_batches( self, fn: Callable[[list[T_co]], BatchResult], batch_size, *, num_cpus=1, num_gpus=0, **resources - ) -> "ShardedDataset[dict]": + ) -> "ShardedDataSource[dict]": """ **Lazily** map a function over batches of data. This is useful for doing things like batching data for a model, or for batched preprocessing. @@ -125,21 +131,25 @@ def map_batches( Returns: A new ShardedDataset. """ - return _BatchMappedShardedDataset(self, fn, batch_size, num_cpus=num_cpus, num_gpus=num_gpus, **resources) + return _BatchMappedShardedDataSource(self, fn, batch_size, num_cpus=num_cpus, num_gpus=num_gpus, **resources) -def dataset_from_hf(id: str, *, split, **kwargs) -> ShardedDataset[dict]: +def datasource_from_hf(id: str, *, split, **kwargs) -> ShardedDataSource[dict]: """ Create a ShardedDataset from a HuggingFace dataset. Arguments are passed to load_dataset. """ - return WrappedHFDataset(id, split=split, **kwargs) + return WrappedHFDataSource(id, split=split, **kwargs) + + +def datasource_from_jsonl(urls_or_paths: Sequence[str]) -> ShardedDataSource[dict]: + return JsonlDataSource(urls_or_paths) -def dataset_from_jsonl(urls_or_paths: Sequence[str]) -> ShardedDataset[dict]: - return JsonlDataset(urls_or_paths) +def datasource_from_json(urls_or_paths: Sequence[str]) -> ShardedDataSource[dict]: + return JsonDataSource(urls_or_paths) -class WrappedHFDataset(ShardedDataset[dict]): +class WrappedHFDataSource(ShardedDataSource[dict]): """ This class is responsible for loading a dataset from HuggingFace Datasets and returning the shards. Only (some) IterableDatasets are actually sharded in any meaningful way, so we just return a single shard @@ -189,7 +199,7 @@ def _load_dataset(self): return datasets.load_dataset(self.id, split=self.split, streaming=self.streaming, **self.kwargs) -class TextUrlDataset(ShardedDataset[str]): +class TextUrlDataSource(ShardedDataSource[str]): """ Dataset for various text formats. """ @@ -232,7 +242,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]: raise ValueError(f"Unknown format {format}") -class AudioTextUrlDataset(ShardedDataset[Tuple[np.ndarray, int, str]]): +class AudioTextUrlDataSource(ShardedDataSource[Tuple[np.ndarray, int, str]]): """ Dataset for various audio and text formats. 
""" @@ -267,6 +277,8 @@ def _load_audio_file(file_name, sampling_rate): audio = {"array": array, "sampling_rate": sr} elif "path" in audio_pointer: audio = _load_audio_file(audio_pointer["path"], sampling_rate) + else: + raise ValueError(f"Unsupported audio format {audio_pointer}") elif isinstance(audio_pointer, str): # This supports filename pointers to arbitrary audio types audio = _load_audio_file(audio_pointer, sampling_rate) @@ -287,14 +299,14 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[Tuple[np.ndar if i >= row: mat_json = json.loads(line) audio_pointer = mat_json[self.audio_key] - audio = AudioTextUrlDataset.resolve_audio_pointer(audio_pointer, self.sampling_rate) + audio = AudioTextUrlDataSource.resolve_audio_pointer(audio_pointer, self.sampling_rate) yield (audio["array"], audio["sampling_rate"], mat_json[self.text_key]) i += 1 case ".json": data = json.load(f) for doc in data[row:]: audio_pointer = doc[self.audio_key] - audio = AudioTextUrlDataset.resolve_audio_pointer(audio_pointer, self.sampling_rate) + audio = AudioTextUrlDataSource.resolve_audio_pointer(audio_pointer, self.sampling_rate) yield (audio["array"], audio["sampling_rate"], doc[self.text_key]) case _: raise ValueError(f"Unknown format {format}") @@ -348,7 +360,7 @@ def _sniff_format_for_dataset(url): return format_from_url -class JsonlDataset(ShardedDataset[dict]): +class JsonlDataSource(ShardedDataSource[dict]): def __init__(self, urls): self.urls = urls self._shard_name_to_url_mapping = _mk_shard_name_mapping(urls) @@ -369,7 +381,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: i += 1 -class TextDataset(ShardedDataset[dict]): +class TextDataSource(ShardedDataSource[dict]): def __init__(self, urls): self.urls = urls self._shard_name_to_url_mapping = _mk_shard_name_mapping(urls) @@ -388,7 +400,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: i += 1 -class JsonDataset(ShardedDataset[dict]): +class JsonDataSource(ShardedDataSource[dict]): def __init__(self, urls): self.urls = urls self._shard_name_to_url_mapping = _mk_shard_name_mapping(urls) @@ -440,12 +452,12 @@ def _mk_shard_name_mapping(urls): class _TransformedDataset: - source: ShardedDataset + source: ShardedDataSource _transform: _DatasetTransform -class _MappedShardedDataset(ShardedDataset[T], _TransformedDataset): - def __init__(self, source: ShardedDataset[T_co], fn: Callable[[T_co], T]): +class _MappedShardedDataSource(ShardedDataSource[T], _TransformedDataset): + def __init__(self, source: ShardedDataSource[T_co], fn: Callable[[T_co], T]): self.source = source self.fn = fn self._transform = _MapTransform(fn) @@ -458,10 +470,10 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[T]: return map(self.fn, self.source.open_shard_at_row(shard_name, row)) -class _BatchMappedShardedDataset(ShardedDataset[T], _TransformedDataset): +class _BatchMappedShardedDataSource(ShardedDataSource[T], _TransformedDataset): def __init__( self, - source: ShardedDataset[T_co], + source: ShardedDataSource[T_co], fn: Callable[[list[T_co]], Iterable[U]], batch_size, num_cpus=1, diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index c29e55e83..fc9ce8052 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -1,4 +1,5 @@ import abc +import asyncio import copy import dataclasses import functools @@ -7,7 +8,7 @@ from dataclasses import dataclass from functools import cached_property from itertools import chain -from typing import Dict, Iterator, 
List, Mapping, Optional, Sequence, Tuple, Union +from typing import Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union import braceexpand import datasets @@ -15,20 +16,28 @@ import fsspec import jax import numpy as np -import pyarrow as pa import regex +import tensorstore as ts from draccus import field +from jax._src.random import PRNGKey from jaxtyping import PRNGKeyArray +from tokenizers import normalizers import haliax as hax from haliax import Axis +from levanter.data import AsyncDataset +from levanter.data.dataset import MappedAsyncDataset from levanter.data.mixture import MixtureDataset, StopStrategy +from levanter.data.permutation import EraConfig # intercept the logging nonsense here from levanter.logging import silence_transformer_nag # noqa from levanter.models.attention import AttentionMask from levanter.models.lm_model import LmExample +from levanter.store.cache import TreeCache +from levanter.store.jagged_array import JaggedArrayStore +from levanter.store.tree_store import TreeStore from levanter.utils.hf_utils import num_cpus_used_by_tokenizer @@ -36,18 +45,16 @@ from transformers import BatchEncoding, PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast # noqa from levanter.compat.hf_checkpoints import load_tokenizer # noqa -from levanter.data._preprocessor import BatchProcessor, dict_from_record_batch # noqa -from levanter.data.dataset import ShardableDataset, ShuffleDataset # noqa +from levanter.data._preprocessor import BatchProcessor, U, dict_from_record_batch # noqa from levanter.data.metrics_monitor import LoggerMetricsMonitor, LoggingMetricsMonitor, MetricsMonitor # noqa -from levanter.data.shard_cache import DEFAULT_ROWS_PER_CHUNK # noqa -from levanter.data.shard_cache import CacheLedger # noqa -from levanter.data.shard_cache import LEDGER_FILE_NAME as NEW_LEDGER_FILE_NAME # noqa -from levanter.data.shard_cache import ChunkMetadata, ShardCache, build_or_load_cache # noqa -from levanter.data.sharded_dataset import ShardedDataset, TextUrlDataset, WrappedHFDataset # noqa +from levanter.data.sharded_datasource import ShardedDataSource, TextUrlDataSource, WrappedHFDataSource # noqa from levanter.shapes import NamedShapeSpec, ShapeSpec # noqa -from levanter.utils.jax_utils import use_cpu_device # noqa +from levanter.store.cache import build_or_load_cache # noqa +from levanter.utils.jax_utils import key_iterator, local_cpu_mesh, use_cpu_device # noqa +T_co = TypeVar("T_co", covariant=True) + logger = logging.getLogger("levanter.data.text") # TASKS: @@ -59,41 +66,110 @@ DEFAULT_IGNORE_INDEX = -100 # Mirrors pytorch's default ignore index -class CausalLmDataset(ShardableDataset[LmExample]): +class TokenSeqDataset(AsyncDataset[np.ndarray]): + """ + A dataset that yields sequences of tokens of fixed length from an underlying TreeCache. 
+ + :param doc_cache: the TreeCache to read from + :param seq_len: The max length of sequences to emit + """ + + def __init__(self, doc_cache: TreeCache[dict], seq_len: int): + self.doc_cache = doc_cache + self.seq_len = seq_len + self._store: Optional[TreeStore] = None + self._cached_len: Optional[int] = None + + async def async_len(self) -> int: + await self.doc_cache.finished() + token_arrays = await self._await_token_cache() + return token_arrays.data_size // self.seq_len + + async def _await_token_cache(self) -> JaggedArrayStore: + if self._store is None: + self._store = await self.doc_cache.store_async() + return self._store.tree["input_ids"] + + async def final_length_is_known(self) -> bool: + return await self.doc_cache.final_length_is_known() + + def is_finite(self) -> bool: + return True + + async def current_len(self) -> Optional[int]: + store = await self._await_token_cache() + return store.data_size // self.seq_len + + async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]: + token_arrays = await self._await_token_cache() + # logger.info(f"Time to get token cache: {time.time() - time_in}") + len = await self.wait_until_len_at_least(max(indices) + 1) + if len is not None and len < max(indices) + 1: + raise ValueError("Requested indices beyond the end of the dataset") + offsets = np.array(indices) * self.seq_len + with ts.Batch(): + out = [] + for offset in offsets: + out.append(token_arrays.data[offset : offset + self.seq_len].read()) + + out = await asyncio.gather(*out) + + return out + + def get_batch_sync(self, indices: Sequence[int]) -> Sequence[T_co]: + token_arrays = self.doc_cache.store.tree["input_ids"] + # logger.info(f"Time to get token cache: {time.time() - time_in}") + # len = await self.wait_until_len_at_least(max(indices) + 1) + # if len is not None and len < max(indices) + 1: + # raise ValueError("Requested indices beyond the end of the dataset") + offsets = np.array(indices) * self.seq_len + with ts.Batch(): + out = [] + for offset in offsets: + out.append(token_arrays.data[offset : offset + self.seq_len].read()) + # logger.info(f"Time to read token cache: {time.time() - time_in}") + + out = [x.result() for x in out] + # logger.info(f"Time to wait for token cache: {time.time() - time_in}") + return out + + async def wait_until_len_at_least(self, length: int) -> int: + # length is brutally slow to compute, so we cache it + if self._cached_len is not None and self._cached_len >= length: + return self._cached_len + + # TODO: would be better to listen for cache updates + length = await super().wait_until_len_at_least(length) + self._cached_len = length + return length + + +class CausalLmDataset(MappedAsyncDataset[np.ndarray, LmExample]): def __init__( self, - dataset: ShardableDataset[np.ndarray], + dataset: AsyncDataset[np.ndarray], QPos: Axis, KPos: Axis, fcm_prob: float = 0.0, - key: Optional[PRNGKeyArray] = None, + key: Optional[PRNGKey] = None, ignore_index: Optional[int] = None, ): self.dataset = dataset self.QPos = QPos self.KPos = KPos self.fcm_prob = fcm_prob - self.key = key self.ignore_id = ignore_index + self.key = key if self.fcm_prob > 0.0 and self.key is None: raise ValueError("must provide key if fcm_prob > 0.0") - def shard(self, shard_id: int, num_shards: int) -> "CausalLmDataset": - return CausalLmDataset( - self.dataset.shard(shard_id, num_shards), self.QPos, self.KPos, self.fcm_prob, self.key, self.ignore_id - ) - - def __iter__(self) -> Iterator[LmExample]: - key = self.key sharding = 
jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0]) - with use_cpu_device(): - - @functools.partial(eqx.filter_jit, out_shardings=sharding) - def _create_lm_example(tokens, key): + @functools.partial(eqx.filter_jit, out_shardings=sharding) + def _create_lm_example(tokens, key): + with local_cpu_mesh(): tokens = hax.named(tokens, self.QPos) - example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id) if self.fcm_prob > 0: @@ -109,216 +185,10 @@ def _create_lm_example(tokens, key): return example - for tokens in self.dataset: - example = _create_lm_example(tokens, key) - yield example + super().__init__(self.dataset, _create_lm_example, key=key) - -class TokenSeqDataset(ShardableDataset[np.ndarray]): - """ - A dataset that yields sequences of tokens of fixed length from a TokenizedDocumentCache. - - :param doc_cache: the TokenizedDocumentCache to draw from - :param seq_len: The max length of sequences to emit - """ - - def __init__(self, doc_cache, seq_len: int, stride: Optional[int] = None): - self.doc_cache = doc_cache - self.seq_len = seq_len - self.stride = stride - - def shard(self, shard_id: int, num_shards: int) -> "TokenSeqDataset": - """ - Split the dataset into num_processes shards. - """ - return TokenSeqDataset(self.doc_cache.shard(shard_id, num_shards), self.seq_len, self.stride) - - def __iter__(self) -> Iterator[np.ndarray]: - extra_tokens = None # BatchEncoding of the last tokens from the previous doc - for doc in self.doc_cache: - # TODO: we could be cleverer here, and avoid these expensive copies etc - # should run some benchmarks to see if it's worth it - if extra_tokens is not None: - doc = _stack_batch_encodings(extra_tokens, doc) - extra_tokens = None - - for encoded_slice in concatenate_and_group_texts(doc, self.seq_len, self.stride, drop_remainder=False): - if len(encoded_slice["input_ids"]) < self.seq_len: - assert extra_tokens is None - extra_tokens = encoded_slice - else: - extra_tokens = None - ids = encoded_slice["input_ids"] - yield ids - - @staticmethod - def load(seq_len: int, cache_dir: str, stride: Optional[int] = None) -> "TokenSeqDataset": - # Maybe force the cache to be built ahead of time? - doc_cache = TokenizedDocumentCache.load(cache_dir, True) - return TokenSeqDataset(doc_cache, seq_len, stride) - - -class BatchEncodingDataset(ShardableDataset[BatchEncoding]): - """ - A Dataset that yields HF BatchEncodings from a ShardCache. - This basically yields a dict-of-arrays, just the HF BatchEncoding class version of dict. 
- """ - - def __init__(self, cache: ShardCache, return_batches: bool = False): - self.cache = cache - self.return_batches = return_batches - - def __iter__(self) -> Iterator[BatchEncoding]: - for batch in self.cache: - encoding = _batch_encoding_from_record_batch(batch, flatten_docs=False) - if self.return_batches: - yield encoding - else: - batch_size = 0 - for v in encoding.values(): - batch_size = len(v) - break - - for i in range(batch_size): - # this doesn't work for reconstituted batches, so we have to do this - # I have no idea why this is the case - # yield encoding[i] - yield BatchEncoding({k: v[i] for k, v in encoding.items()}) - - def shard(self, shard_id: int, num_shards: int) -> "BatchEncodingDataset": - return BatchEncodingDataset(self.cache.shard(shard_id, num_shards)) - - @staticmethod - def load(cache_dir: str, return_batches: bool = False, batch_size: Optional[int] = None) -> "BatchEncodingDataset": - if batch_size is None: - batch_size = 1 - cache = ShardCache.load(cache_dir, batch_size=batch_size) - return BatchEncodingDataset(cache, return_batches=return_batches) - - -class TokenizedDocumentCache(ShardableDataset[BatchEncoding]): - """ - Represents a tokenized document cache, which is a directory of parquet files with a ledger file. - - The difference between this class and the TokenSeqDataset is that this class yields entire documents, - while the TokenSeqDataset yields tokens sequences of fixed length from concatenated documents. - """ - - def __init__(self, chunk_cache: ShardCache, flatten_docs): - self.chunk_cache = chunk_cache - self.flatten_docs = flatten_docs - - def __iter__(self): - """Reads the cache files produced by cache_and_group and yields tokenized sequences. - If flatten is false, this returns the docs as they were presented to the caching process. If flatten is True, - then the documents returned are actually concatenated documents, where the number is the number of documents - presented as a batch to the caching process.""" - for batch in self._chunks(): - yield _batch_encoding_from_record_batch(batch, self.flatten_docs) - - def _chunks(self): - return self.chunk_cache.iter_batches_from_chunks() - - @staticmethod - def build_or_load( - cache_dir, - source: ShardedDataset[str], - tokenizer: PreTrainedTokenizerBase, - *, - flatten_docs=True, - enforce_bos=True, - enforce_eos=True, - batch_size=128, - rows_per_chunk=DEFAULT_ROWS_PER_CHUNK, - monitors=None, - await_finished=True, - override_resources=None, - ) -> "TokenizedDocumentCache": - bt = BatchTokenizer( - tokenizer, enforce_bos=enforce_bos, enforce_eos=enforce_eos, override_resources=override_resources - ) - monitors = monitors or [] - cache = build_or_load_cache( - cache_dir, - source, - bt, - await_finished=await_finished, - batch_size=batch_size, - rows_per_chunk=rows_per_chunk, - monitors=monitors, - cache_config={ - "tokenizer": tokenizer.name_or_path, - "vocab_size": tokenizer.vocab_size, - }, - ) - - if cache.is_finished: - logger.info(f"Cache {cache_dir} is complete.") - else: - logger.info( - f"Cache {cache_dir} is incomplete. This will block until at least one chunk per process is complete." - ) - - if cache.ledger and "tokenizer" in cache.ledger.metadata: - cached_tokenizer = cache.ledger.metadata["tokenizer"] - cached_vocab_size = cache.ledger.metadata["vocab_size"] - if cached_tokenizer != tokenizer.name_or_path: - raise ValueError( - f"Cache {cache_dir} was built with tokenizer {cached_tokenizer}, but current tokenizer is" - f" {tokenizer.name_or_path}." 
- ) - if cached_vocab_size != tokenizer.vocab_size: - raise ValueError( - f"Cache {cache_dir} was built with vocab size {cached_vocab_size}, but current vocab size is" - f" {tokenizer.vocab_size}." - ) - - return TokenizedDocumentCache(cache, flatten_docs=flatten_docs) - - @staticmethod - def load(cache_dir, batch_size: int = 128, flatten_docs=True): - """ - Load a TokenizedDocumentCache from a directory. If the ledger file is not present, this will raise a - FileNotFoundError. - - NOTE: ATM this attempts to migrate old caches to the new format, but this will be removed in the future. - - :param cache_dir: - :param flatten_docs: If true, then multiple documents from a single batch (when the cache was built) will be - concatenated into a single document. Often one is concatenating documents anyway, so this is a useful option. - :return: - """ - - try: - cache = ShardCache.load(cache_dir, batch_size=batch_size) - return TokenizedDocumentCache(cache, flatten_docs=flatten_docs) - except FileNotFoundError: - raise FileNotFoundError(f"{cache_dir} is not a complete cache") - except Exception: - logger.exception("error loading cache") - raise - - def shard(self, shard_index, num_shards): - if num_shards <= shard_index: - raise ValueError(f"Shard index {shard_index} is out of range") - - if num_shards == 1: - return self - - return TokenizedDocumentCache(self.chunk_cache.shard(shard_index, num_shards), self.flatten_docs) - - -def _batch_encoding_from_record_batch(b: pa.RecordBatch, flatten_docs: bool): - if flatten_docs: - # insert a newaxis to the beginning so that it appears to be bs=1 - return BatchEncoding( - { - b.field(i).name: b.column(i).values.to_numpy(zero_copy_only=False)[np.newaxis, :] - for i in range(b.num_columns) - }, - ) - else: - return BatchEncoding(dict_from_record_batch(b)) + async def async_len(self) -> int: + return await self.dataset.async_len() def _maybe_force_tokenizer_parallelism(tokenizer: PreTrainedTokenizerBase): @@ -328,10 +198,12 @@ def _maybe_force_tokenizer_parallelism(tokenizer: PreTrainedTokenizerBase): os.environ["TOKENIZERS_PARALLELISM"] = "true" +LONG_STRING_WORKAROUND = 10_000 + ws = regex.compile(r"\s") -class BatchTokenizer(BatchProcessor[str]): +class BatchTokenizer(BatchProcessor[str, dict]): """ A batch processor that tokenizes a batch of strings using a tokenizer. By default, this will append eos to the end of the string, even if the tokenizer doesn't. 
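The hunks that follow change BatchTokenizer.__call__ to return one plain dict per input document (rather than a single BatchEncoding for the whole batch) and add a whitespace-splitting workaround for very long strings. As a quick standalone illustration of the debatching step, using a hand-written dict-of-lists in place of real tokenizer output (the values below are made up for illustration, not taken from this patch):

    encoding = {
        "input_ids": [[1, 2, 3], [4, 5]],
        "token_type_ids": [[0, 0, 0], [0, 0]],
    }
    # Same transpose idiom the patched __call__ uses: turn the column-oriented
    # mapping into one dict per document.
    unbatched = [dict(zip(encoding, row)) for row in zip(*(encoding[k] for k in encoding))]
    assert unbatched == [
        {"input_ids": [1, 2, 3], "token_type_ids": [0, 0, 0]},
        {"input_ids": [4, 5], "token_type_ids": [0, 0]},
    ]

Each per-document dict is what the processor now hands to the TreeStore-backed cache, which stores columns as jagged arrays instead of Arrow record batches.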
@@ -345,6 +217,7 @@ def __init__( *, batch_size=128, override_resources=None, + _workaround_len=LONG_STRING_WORKAROUND, return_attention_mask=False, padding=False, max_length=None, @@ -380,20 +253,64 @@ def __init__( self._need_to_add_eos = should_append_eos self._need_to_add_bos = should_append_bos + self._workaround_len = _workaround_len - def __call__(self, batch: Sequence[str]) -> BatchEncoding: + def __call__(self, batch: Sequence[str]) -> list[dict]: if self._need_to_add_bos: batch = [self.tokenizer.bos_token + " " + d for d in batch] if self._need_to_add_eos: batch = [d + " " + self.tokenizer.eos_token for d in batch] + if self._needs_long_sequence_workaround: + batch, needs_merge = self._break_for_long_sequences(batch) + else: + needs_merge = [] + if self.padding is not False: encoding = self.tokenizer(batch, return_attention_mask=self.return_attention_mask, verbose=False, padding=self.padding, max_length=self.max_length, truncation=True) # type: ignore else: encoding = self.tokenizer(batch, return_attention_mask=self.return_attention_mask, verbose=False) # type: ignore - return encoding + if needs_merge: + new_encoding = self._merge_split_encodings(batch, encoding, needs_merge) + encoding = BatchEncoding(new_encoding) + + # debatch the encoding + unbatched = [dict(zip(encoding, t)) for t in zip(*[encoding[k] for k in encoding])] + + return unbatched + + def _break_for_long_sequences(self, batch): + orig_lengths = [len(d) for d in batch] + # break any strings that are longer than LONG_STRING_WORKAROUND characters into smaller chunks + orig_batch = batch + batch = [] + needs_merge = [] + for i, d in enumerate(orig_batch): + needs_merge.append(False) + orig_len = orig_lengths[i] + while len(d) > self._workaround_len: + # we'd rather break strings at whitespace, so find the first whitespace + match = ws.search(d, self._workaround_len) + # this is vanishingly unlikely, but if we can't find a whitespace, just break it at the limit + if match is None: + split = len(d) + else: + split = match.start() + + batch.append(d[:split]) + needs_merge.append(True) + + d = d[split:] + orig_len -= split + + batch.append(d) + return batch, needs_merge + + @property + def output_exemplar(self) -> dict: + return dict(**self.tokenizer("hi there", return_attention_mask=self.return_attention_mask, verbose=False)) @property def name_or_path(self): @@ -403,6 +320,59 @@ def name_or_path(self): def vocab_size(self): return self.tokenizer.vocab_size + @staticmethod + def _merge_split_encodings(batch, encoding, needs_merge): + # merge the encodings back together + # we might need to merge multiple encodings together + # needs merge marks the first n-1 encodings that need to be merged for each document + new_encoding = {} + for k, v in encoding.items(): + if len(v) == 0: + continue + if isinstance(v[0], np.ndarray): + assert len(v) == len(batch) + v_out = [] + vs_to_merge = [] + for i in range(len(batch)): + if not needs_merge[i]: + v_out.append(np.concatenate(vs_to_merge)) + vs_to_merge = [] + vs_to_merge.append(v[i]) + + if len(vs_to_merge) > 0: + v_out.append(np.concatenate(vs_to_merge)) + + new_encoding[k] = v_out + elif isinstance(v[0], list): + v_out = [] + vs_to_merge = [] + for i in range(len(batch)): + if not needs_merge[i]: + if len(vs_to_merge) > 0: + v_out.append(list(chain(*vs_to_merge))) + vs_to_merge = [] + vs_to_merge.append(v[i]) + + if len(vs_to_merge) > 0: + v_out.append(list(chain(*vs_to_merge))) + new_encoding[k] = v_out + else: + raise ValueError(f"Unknown type {type(v[0])}") + return 
new_encoding + + # TODO remove this when it's resolved https://github.com/huggingface/tokenizers/issues/1495 + @cached_property + def _needs_long_sequence_workaround(self): + if isinstance(self.tokenizer, PreTrainedTokenizerFast): + normalizer = self.tokenizer.backend_tokenizer.normalizer + if normalizer is None: + return False + # if there's a "Replace" normalizer, then we need to do the workaround + # inexplicably there's no way to see inside a Sequence so we also have to assume it needs it + return isinstance(normalizer, (normalizers.Replace, normalizers.Sequence)) + else: + return False + @property def num_cpus(self) -> int: if self.override_resources is not None: @@ -511,10 +481,10 @@ class LMDatasetSourceConfig: train_urls: List[str] = () # type: ignore validation_urls: List[str] = () # type:ignore - def get_shard_source(self, split) -> Optional[ShardedDataset[str]]: + def get_shard_source(self, split) -> Optional[ShardedDataSource[str]]: if self.id is not None: try: - ds = WrappedHFDataset(self.id, split=split, name=self.name, streaming=self.stream) + ds = WrappedHFDataSource(self.id, split=split, name=self.name, streaming=self.stream) except ValueError as e: # if the message starts with Bad split, then just return None if str(e).startswith("Bad split"): @@ -531,7 +501,7 @@ def get_shard_source(self, split) -> Optional[ShardedDataset[str]]: split_urls = self.urls_for_split(split) if len(split_urls) == 0: return None - return TextUrlDataset(split_urls, self.text_key) + return TextUrlDataSource(split_urls, self.text_key) def doc_iterator(self, split: str): if self.id is not None: @@ -542,7 +512,7 @@ def doc_iterator(self, split: str): else: urls = self.urls_for_split(split) - yield from TextUrlDataset(urls, self.text_key) + yield from TextUrlDataSource(urls, self.text_key) def urls_for_split(self, split): if split == "train": @@ -575,11 +545,13 @@ class LMTaskConfig(abc.ABC): # config related to caching cache_dir: str = "cache/" - rows_per_chunk: int = DEFAULT_ROWS_PER_CHUNK # number of rows to process and cache per chunk + tokenizer_batch_size: int = 32 enforce_eos: bool = True # whether to append eos even if the tokenizer doesn't ignore_token_id: Optional[int] = None - shuffle_buffer_size: Optional[int] = None + shuffle: bool | EraConfig = False + """whether to shuffle the dataset. True means shuffle the whole dataset, False means don't shuffle. 
+ If you want to shuffle in eras, provide an EraConfig (which asks for an era_length)""" @cached_property def the_tokenizer(self) -> PreTrainedTokenizerBase: @@ -591,13 +563,13 @@ def the_tokenizer(self) -> PreTrainedTokenizerBase: @abc.abstractmethod def train_set( self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] - ) -> ShardableDataset[np.ndarray]: + ) -> AsyncDataset[np.ndarray]: pass @abc.abstractmethod def validation_sets( self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> Mapping[str, ShardableDataset[np.ndarray]]: + ) -> Mapping[str, AsyncDataset[np.ndarray]]: pass @property @@ -607,7 +579,7 @@ def sources(self) -> dict[str, LMDatasetSourceConfig]: def tagged_eval_sets( self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> list[Tuple[ShardableDataset[np.ndarray], List[str]]]: + ) -> list[Tuple[AsyncDataset[np.ndarray], List[str]]]: tags = {name: (config.tags or []) + [name] for name, config in self.sources.items()} eval_sets = self.validation_sets(seq_len, monitors) @@ -620,17 +592,17 @@ class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig): def train_set( self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] = None - ) -> ShardableDataset[np.ndarray]: + ) -> AsyncDataset[np.ndarray]: ds = self.token_seq_dataset("train", seq_len, monitors) if ds is None: raise ValueError("No training set!") - if self.shuffle_buffer_size is not None: - if key is None: - key = jax.random.PRNGKey(0) - return ShuffleDataset(ds, key, self.shuffle_buffer_size) + if self.shuffle is True: + ds = ds.shuffle(key) + elif isinstance(self.shuffle, EraConfig): + ds = ds.era_shuffle(self.shuffle.era_length, key=key) - return ds + return ds # type: ignore def validation_set( self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True @@ -639,7 +611,7 @@ def validation_set( def validation_sets( self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> Mapping[str, ShardableDataset[np.ndarray]]: + ) -> Mapping[str, AsyncDataset[np.ndarray]]: validation_set = self.validation_set(seq_len, monitors) if validation_set is not None: return {"": validation_set} @@ -675,12 +647,12 @@ def token_seq_dataset( def build_or_load_cache( self, split: str, monitors: Union[bool, List[MetricsMonitor]] = True, logger_name: Optional[str] = None - ) -> Optional[TokenizedDocumentCache]: + ) -> Optional[TreeCache[BatchEncoding]]: split_cache_dir = os.path.join(self.cache_dir, split) name = logger_name or os.path.basename(self.cache_dir) try: - return TokenizedDocumentCache.load(split_cache_dir, flatten_docs=True) + return TreeCache.load(split_cache_dir, exemplar={"input_ids": np.zeros(0, dtype=np.int32)}) except FileNotFoundError: pass @@ -699,16 +671,20 @@ def build_or_load_cache( elif monitors is False: monitors = [] - return TokenizedDocumentCache.build_or_load( + bt = BatchTokenizer( + self.the_tokenizer, enforce_bos=True, enforce_eos=self.enforce_eos, batch_size=self.tokenizer_batch_size + ) + + return build_or_load_cache( split_cache_dir, source, - self.the_tokenizer, - enforce_eos=self.enforce_eos, - flatten_docs=True, - rows_per_chunk=self.rows_per_chunk, + bt, + await_finished=False, monitors=monitors, - # TODO: it would be better if we could just prioritize validation higher (we typically want it after the first grad step) - await_finished=(split == "validation"), + cache_config={ + "tokenizer": self.the_tokenizer.name_or_path, + "vocab_size": 
self.the_tokenizer.vocab_size, + }, ) @@ -749,6 +725,8 @@ class LMMixtureDatasetConfig(LMTaskConfig): train_weights: Dict[str, float] = field(default_factory=dict) """ weights for each dataset source. They will be normalized to sum to 1. """ stop_strategy: str = field(default=StopStrategy.RESTART_STRATEGY) + mixture_block_size: int = 2048 + """ block size for the mixture dataset.""" def __post_init__(self): if len(self.configs) == 0: @@ -762,40 +740,59 @@ def __post_init__(self): def train_set( self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] - ) -> ShardableDataset[np.ndarray]: + ) -> AsyncDataset[np.ndarray]: doc_caches = self.build_caches("train", monitors=monitors) - token_datasets = {name: TokenSeqDataset(cache, seq_len, stride=None) for name, cache in doc_caches.items()} + token_datasets = {name: TokenSeqDataset(cache, seq_len) for name, cache in doc_caches.items()} + if key is None: key = jax.random.PRNGKey(0) mix_key, shuffle_key = jax.random.split(key) + # We shuffle the components and not the overall mixture because this lets us preserve + # the "stable batch" property of the mixture dataset. + def shuffle_ds(ds, key): + if self.shuffle is True: + ds = ds.shuffle(key) + elif isinstance(self.shuffle, EraConfig): + ds = ds.era_shuffle(self.shuffle.era_length, key=key) + + return ds + + if self.shuffle: + out_token_datasets = {} + key_iter = key_iterator(shuffle_key) + for name, ds in token_datasets.items(): + out_token_datasets[name] = shuffle_ds(ds, next(key_iter)) + token_datasets = out_token_datasets + mixture = MixtureDataset( - datasets=token_datasets, weights=self.train_weights, stop_strategy=self.stop_strategy, key=mix_key + datasets=token_datasets, + weights=self.train_weights, + stop_strategy=self.stop_strategy, + key=mix_key, + block_size=2048, ) - if self.shuffle_buffer_size is not None: - return ShuffleDataset(mixture, shuffle_key, self.shuffle_buffer_size) - return mixture def training_sets( self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> Mapping[str, ShardableDataset[np.ndarray]]: + ) -> Mapping[str, TokenSeqDataset]: doc_caches = self.build_caches("train", monitors=monitors) - token_datasets = {name: TokenSeqDataset(cache, seq_len, stride=None) for name, cache in doc_caches.items()} + token_datasets = {name: TokenSeqDataset(cache, seq_len) for name, cache in doc_caches.items()} return token_datasets def validation_sets( self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> Mapping[str, ShardableDataset[np.ndarray]]: + ) -> Mapping[str, AsyncDataset[np.ndarray]]: doc_caches = self.build_caches("validation", monitors=monitors) - token_datasets = {name: TokenSeqDataset(cache, seq_len, stride=None) for name, cache in doc_caches.items()} + token_datasets = {name: TokenSeqDataset(cache, seq_len) for name, cache in doc_caches.items()} return token_datasets def build_caches( self, split: str, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> Dict[str, TokenizedDocumentCache]: + ) -> Dict[str, TreeCache[dict]]: # this is a bit gross, but we want to forward all "Task" config fields to the LMDatasetConfig for building. # We do this by just grabbing all the fields from the LMDatasetConfig and forwarding them to the # LMDatasetConfig.build_or_load_cache method. We exclude the cache_dir field. 
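The train_set hunk above shuffles each component dataset and leaves the mixture itself unshuffled, so that the "stable batch" property of MixtureDataset is preserved. A minimal, self-contained numpy sketch of one way such a block-wise, stable assignment can work; this is illustrative only (block_assignments, the weights, and the seed are made up here), not Levanter's actual MixtureDataset implementation:

    import numpy as np

    def block_assignments(block_id, block_size, weights, seed=0):
        # Reproducible per-block draw of which component each global slot uses.
        rng = np.random.default_rng([seed, block_id])
        return rng.choice(len(weights), size=block_size, p=np.asarray(weights))

    weights = [0.7, 0.3]
    block = block_assignments(block_id=0, block_size=8, weights=weights)
    assert (block == block_assignments(0, 8, weights)).all()  # same slots, same components

    # Shuffling inside a component only permutes which of its examples fills a
    # slot; the component chosen for a given global index never changes, so the
    # composition of every batch stays stable across epochs and resumes.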
@@ -822,6 +819,13 @@ def build_caches( logger.warning(f"Skipping {name} for split {split} because no source was provided") else: caches[name] = cache + + # in practice it works best if we block on validation caches + if split == "validation": + logger.info("Waiting for validation caches to finish building...") + for cache in caches.values(): + cache.await_finished() + return caches @property diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index 63495d709..9d048b24f 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -1,6 +1,6 @@ import dataclasses import logging -from typing import Iterator, Optional, Tuple, TypeVar +from typing import Optional, Tuple, TypeVar import equinox as eqx import jax.numpy as jnp @@ -14,7 +14,7 @@ import levanter.tracker from levanter.callbacks import eval_loss_loop from levanter.checkpoint import load_checkpoint_or_initialize -from levanter.data import ShardableDataset +from levanter.data import AsyncDataset, MappedAsyncDataset from levanter.data.mixture import MixtureDataset from levanter.tracker import capture_time from levanter.trainer import M, StepInfo, Trainer, TrainerConfig, TrainerState @@ -56,10 +56,10 @@ def estimate_mixture_weights( loss_fn: ComputeLossFunction[M, T], initial_proxy: M, ref: M, - data_sources: dict[str, ShardableDataset[T]], + data_sources: dict[str, AsyncDataset[T]], sampling_weights: Optional[dict[str, float]] = None, *, - validation_sets: Optional[dict[str, ShardableDataset[T]]] = None, + validation_sets: Optional[dict[str, AsyncDataset[T]]] = None, trainer_config: TrainerConfig = DEFAULT_DOREMI_TRAINER_CONFIG, optimizer: optax.GradientTransformation = optax.adamw(1e-3), domain_weight_step_size: float = 1.0, @@ -107,7 +107,7 @@ def eval_loss(model, *batch, **batch_kwargs): loss = eval_loss_loop( eval_loss, ref, - trainer.replicated_loader(dataset, trainer.EvalBatch), + trainer.data_loader(dataset, trainer.EvalBatch), name=f"ref {domain}", max_batches=trainer_config.max_eval_batches, ) @@ -201,7 +201,7 @@ def doremi_step(state: DoremiState, ref, batch, domains): average_alpha=initial_alpha, ) del initial_proxy - train_loader = iter(trainer.sharded_loader(tagged_mixture, trainer.TrainBatch)) + train_loader = iter(trainer.data_loader(tagged_mixture, trainer.TrainBatch)) if state.step > 0: # step is after the batch, so we need to seek to step @@ -263,7 +263,7 @@ def _prepare_ref_model(ref, trainer): def domain_tagged_mixture( - data_sources: dict[str, ShardableDataset[T]], + data_sources: dict[str, AsyncDataset[T]], weights: dict[str, float], domain_to_index: dict[str, int], *, @@ -278,13 +278,13 @@ def domain_tagged_mixture( for domain, domain_index in domain_to_index.items() } - return MixtureDataset(tagged_datasets, weights, key=key) + return MixtureDataset(tagged_datasets, weights, key=key, block_size=2048) -class DomainTaggedDataset(ShardableDataset[Tuple[T, hax.NamedArray]]): # named array is a scalar int +class DomainTaggedDataset(MappedAsyncDataset[T, Tuple[T, hax.NamedArray]]): # named array is a scalar int def __init__( self, - dataset: ShardableDataset[T], + dataset: AsyncDataset[T], domain_index: int | hax.NamedArray, ): self.dataset = dataset @@ -294,9 +294,7 @@ def __init__( else: self.domain_index = domain_index - def shard(self, shard_id: int, num_shards: int) -> "DomainTaggedDataset[T]": - return DomainTaggedDataset(self.dataset.shard(shard_id, num_shards), self.domain_index) + def _transform(item): + return item, self.domain_index - def __iter__(self) -> Iterator[Tuple[T, hax.NamedArray]]: - 
for item in self.dataset: - yield item, self.domain_index + super().__init__(dataset, _transform) diff --git a/src/levanter/eval.py b/src/levanter/eval.py index 6a016f1f9..48fcb426c 100644 --- a/src/levanter/eval.py +++ b/src/levanter/eval.py @@ -1,7 +1,9 @@ +import asyncio import dataclasses import logging import warnings -from typing import Callable, Optional, Sequence, TypeVar +from collections import defaultdict +from typing import Callable, Mapping, Optional, Sequence, TypeVar import jax.numpy as jnp import jmp @@ -13,7 +15,7 @@ from haliax.partitioning import ResourceMapping import levanter.tracker -from levanter.data import Dataset, ReplicatedBatchLoader +from levanter.data import AsyncDataset, DataLoader from levanter.logging import LoadingTimeTrackerIterator from levanter.models.lm_model import LmExample, LmHeadModel, compute_next_token_loss from levanter.trainer import StepInfo @@ -37,15 +39,20 @@ class EvalResult: total_eval_loading_time: float -class DomainTaggedDataset(Dataset[tuple[T, hax.NamedArray]]): +# This class doesn't try to be async or work with incomplete datasets, because it's eval + + +class DomainTaggedDataset(AsyncDataset[tuple[T, hax.NamedArray]]): """Holds multiple datasets, each with its own domain tag. Also indexes the tags to enable easier aggregation.""" + tag_index: Mapping[str, int] + @property def tags(self): return self.tag_to_index.keys() def __init__( - self, datasets: Sequence[tuple[Dataset[T], Sequence[str]]], max_examples_per_dataset: Optional[int] = None + self, datasets: Sequence[tuple[AsyncDataset[T], Sequence[str]]], max_examples_per_dataset: Optional[int] = None ): self.datasets = [] tag_index: dict[str, int] = {} @@ -62,20 +69,78 @@ def __init__( self.tag_to_index = tag_index self.Tag = hax.Axis("tag", len(self.tag_to_index)) self.max_examples_per_dataset = max_examples_per_dataset + self._tag_arrays = self._compute_tag_arrays() + self._offsets: Optional[np.ndarray] = None + self._max_examples_per_dataset = max_examples_per_dataset + + async def _get_offsets(self) -> np.ndarray: + if self._offsets is None: + lengths = await asyncio.gather(*[dataset.async_len() for dataset, _ in self.datasets]) + if self._max_examples_per_dataset is not None: + lengths = [min(length, self._max_examples_per_dataset) for length in lengths] + self._offsets = np.cumsum([0] + lengths) - def __iter__(self): + return self._offsets # type: ignore + + def _compute_tag_arrays(self): + tag_arrays = [] for dataset, tags in self.datasets: indexed = [self.tag_to_index[tag] for tag in tags] tags = np.zeros(self.Tag.size, dtype=np.int32) tags[indexed] = 1 tags = hax.named(tags, self.Tag) - count = 0 - for example in dataset: - if self.max_examples_per_dataset is not None and count >= self.max_examples_per_dataset: - break - count += 1 - yield example, tags + tag_arrays.append(tags) + return tag_arrays + + async def async_len(self) -> int: + return int((await self._get_offsets())[-1]) + + async def getitem_async(self, item: int) -> tuple[T, hax.NamedArray]: + offsets = await self._get_offsets() + dataset_index = np.searchsorted(offsets, item, side="right") - 1 + offset = offsets[dataset_index] + dataset, tags = self.datasets[dataset_index] + return await dataset.getitem_async(int(item - offset)), self._tag_arrays[dataset_index] + + async def get_batch(self, indices: Sequence[int]) -> Sequence[tuple[T, hax.NamedArray]]: + # Chatgpt wrote this. 
pretty sure it's correct + offsets = await self._get_offsets() + original_order = np.argsort(indices) + sorted_indices = np.array(indices)[original_order] + dataset_indices = np.searchsorted(offsets, sorted_indices, side="right") - 1 + + # Group indices by the dataset they belong to + grouped_indices = defaultdict(list) + for idx, dataset_index in zip(sorted_indices, dataset_indices): + grouped_indices[dataset_index].append(idx - offsets[dataset_index]) + + # Retrieve the batch for each group + batch_futures: list = [] + for dataset_index, dataset_indices in grouped_indices.items(): + dataset, tags = self.datasets[dataset_index] + dataset_batch = dataset.get_batch(dataset_indices) + batch_futures.append(dataset_batch) + + batch_groups = await asyncio.gather(*batch_futures) + batch = [] + for dataset_index, dataset_batch in zip(grouped_indices.keys(), batch_groups): + batch.extend([(item, self._tag_arrays[dataset_index]) for item in dataset_batch]) + + # Reorder the batch to match the original order of indices + batch = [batch[i] for i in np.argsort(original_order)] + + return batch + + async def final_length_is_known(self) -> bool: + return all(await asyncio.gather(*[dataset.final_length_is_known() for dataset, _ in self.datasets])) + + def is_finite(self) -> bool: + return all(dataset.is_finite() for dataset, _ in self.datasets) + + async def current_len(self) -> Optional[int]: + # We currently require all datasets to be finished before we do anything with this dataset, so... + return await self.async_len() def _join_prefix(prefix: str, tag: str) -> str: @@ -86,7 +151,7 @@ def _join_prefix(prefix: str, tag: str) -> str: def cb_tagged_lm_evaluate( EvalBatch: hax.Axis, - tagged_eval_sets: Sequence[tuple[Dataset[LmExample], Sequence[str]]], + tagged_eval_sets: Sequence[tuple[AsyncDataset[LmExample], Sequence[str]]], device_mesh: Optional[Mesh] = None, axis_mapping: ResourceMapping = None, max_examples_per_dataset: Optional[int] = None, @@ -168,7 +233,7 @@ class TaggedEvaluator: def __init__( self, EvalBatch: hax.Axis, - tagged_eval_sets, + tagged_eval_sets: Sequence[tuple[AsyncDataset, Sequence[str]]], device_mesh=None, axis_mapping=None, max_examples_per_dataset=None, @@ -176,8 +241,12 @@ def __init__( ): self.EvalBatch = EvalBatch self.dataset = DomainTaggedDataset(tagged_eval_sets, max_examples_per_dataset) - self.loader = ReplicatedBatchLoader( - self.dataset, mesh=device_mesh, axis_resources=axis_mapping, Batch=EvalBatch + self.loader = DataLoader( + EvalBatch, + self.dataset.as_async_dataset(), + max_buffered_batches=100, + mesh=device_mesh, + axis_resources=axis_mapping, ) self.mp = mp @@ -229,9 +298,11 @@ def evaluate(self, m: LmHeadModel): state = hax.shard(state) iterator = LoadingTimeTrackerIterator(self.loader) + n = 0 for batch, tags in tqdm.tqdm(iterator, "eval"): state = self.accum_for_batch(m, state, batch, tags) + n += 1 total_loss, losses_per_tag = state diff --git a/src/levanter/logging.py b/src/levanter/logging.py index 66e9e9581..1d1e1bd7f 100644 --- a/src/levanter/logging.py +++ b/src/levanter/logging.py @@ -75,11 +75,13 @@ def __init__(self, items: Iterable[T]): start = time.perf_counter() self.items = iter(items) self.total_time += time.perf_counter() - start + self.this_load_time = 0.0 def __next__(self) -> T: start = time.perf_counter() item = next(self.items) - self.total_time += time.perf_counter() - start + self.this_load_time = time.perf_counter() - start + self.total_time += self.this_load_time return item diff --git a/src/levanter/main/cache_dataset.py 
b/src/levanter/main/cache_dataset.py index 5063c69e2..2483e9214 100644 --- a/src/levanter/main/cache_dataset.py +++ b/src/levanter/main/cache_dataset.py @@ -4,10 +4,10 @@ import levanter from levanter.data.metrics_monitor import LoggingMetricsMonitor, RichMetricsMonitor -from levanter.data.shard_cache import build_or_load_cache from levanter.data.text import BatchTokenizer, LMDatasetConfig from levanter.distributed import RayConfig from levanter.logging import init_logging +from levanter.store.cache import build_or_load_cache from levanter.tracker import NoopConfig, TrackerConfig @@ -46,10 +46,8 @@ def main(args: RayCachedLMDatasetConfig): cache_dir=split_cache_dir, input_shards=source, processor=batch_tokenizer, - rows_per_chunk=args.rows_per_chunk, await_finished=False, monitors=monitors, - batch_size=128, ) cache.await_finished() diff --git a/src/levanter/main/eval_lm.py b/src/levanter/main/eval_lm.py index df41750ab..116a08f18 100644 --- a/src/levanter/main/eval_lm.py +++ b/src/levanter/main/eval_lm.py @@ -16,7 +16,7 @@ from levanter import callbacks from levanter.checkpoint import load_checkpoint from levanter.compat.hf_checkpoints import HFCheckpointConverter, RepoRef -from levanter.data import ReplicatedBatchLoader +from levanter.data import DataLoader from levanter.data.text import CausalLmDataset, LMDatasetConfig from levanter.models.gpt2 import Gpt2Config from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel, compute_next_token_loss @@ -57,7 +57,9 @@ def main(config: EvalLmConfig): raw_dataset = CausalLmDataset(validation_set, Pos, KeyPos) # type: ignore - eval_loader = ReplicatedBatchLoader(raw_dataset, config.trainer.device_mesh, Batch) + eval_loader = DataLoader( + Batch, raw_dataset, None, config.trainer.device_mesh, config.trainer.parameter_axis_mapping + ) compute_axis_mapping = config.trainer.compute_axis_mapping parameter_axis_mapping = config.trainer.parameter_axis_mapping diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py index 9d7018c7e..9eee109fe 100644 --- a/src/levanter/main/lora_lm.py +++ b/src/levanter/main/lora_lm.py @@ -22,7 +22,6 @@ from levanter.optim import AdamConfig, OptimizerConfig from levanter.trainer import Trainer, TrainerConfig from levanter.utils.jax_utils import parameter_count -from levanter.utils.py_utils import non_caching_cycle logger = logging.getLogger(__name__) @@ -87,7 +86,7 @@ def main(config: LoraLmConfig): logger.warning("No evaluation datasets provided.") train_dataset = CausalLmDataset(config.data.train_set(Pos.size, key=data_key), Pos, KeyPos) - train_loader = trainer.sharded_loader(train_dataset, Batch) + train_loader = trainer.data_loader(train_dataset, Batch) # load the underlying hf model logger.info(f"Loading pretrained model from {converter.reference_checkpoint}") @@ -150,16 +149,7 @@ def loraize_hf_model(model): every=config.hf_save_steps, ) - # data loader. may need to seek to the right place if we're resuming - iter_data = non_caching_cycle(train_loader) - - if int(state.step) > 0: - # step is after the batch, so we need to seek to step - # TODO: implement iter_data.seek(resume_step +1) - import tqdm - - for _ in tqdm.tqdm(range(state.step), desc="seeking data for resume"): - next(iter_data) + iter_data = train_loader.iter_from_step(state.step) ## OK, actually run training! 
trainer.train(state, iter_data) diff --git a/src/levanter/main/train_asr.py b/src/levanter/main/train_asr.py index 2d0651198..72e6d5adb 100644 --- a/src/levanter/main/train_asr.py +++ b/src/levanter/main/train_asr.py @@ -113,7 +113,7 @@ def compute_loss( Pos = config.model.Pos KeyPos = config.model.KeyPos - eval_datasets = config.data.validation_sets(config.batch_size) + eval_datasets = config.data.validation_sets() train_dataset = AudioTextDataset( config.data.train_set(config.batch_size), Pos, @@ -189,16 +189,7 @@ def compute_log_probs(model, example): logprobs = hax.roll(logprobs, 1, Pos) return logprobs.rearrange((EvalBatch, Pos)).array - # data loader. may need to seek to the right place if we're resuming - train_loader = iter(trainer.sharded_loader(train_dataset, Batch)) - - if int(state.step) > 0: - # step is after the batch, so we need to seek to step - # TODO: implement iter_data.seek(resume_step +1) - import tqdm - - for _ in tqdm.tqdm(range(state.step), desc="seeking data for resume"): - next(train_loader) + train_loader = trainer.data_loader(train_dataset, Batch).iter_from_step(state.step) ## OK, actually run training! trainer.train(state, train_loader) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index e76f6bc5d..8e905b064 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -114,7 +114,8 @@ def main(config: TrainLmConfig): Pos = config.model.Pos KeyPos = config.model.KeyPos - tagged_eval_datasets = config.data.tagged_eval_sets(Pos.size) + # TODO: fix this + tagged_eval_datasets: list = config.data.tagged_eval_sets(Pos.size) train_dataset = CausalLmDataset( config.data.train_set(Pos.size, key=data_key), Pos, KeyPos, ignore_index=config.data.ignore_token_id ) @@ -161,13 +162,14 @@ def main(config: TrainLmConfig): if len(tagged_eval_datasets) == 0: logger.warning("No evaluation datasets provided.") else: + max_eval_examples_per_ds = config.trainer.max_eval_batches + if max_eval_examples_per_ds is not None: + max_eval_examples_per_ds *= config.trainer.eval_batch_size + causal_datasets = [ (CausalLmDataset(ds, Pos, KeyPos, ignore_index=config.data.ignore_token_id), tags) for ds, tags in tagged_eval_datasets ] - max_eval_examples_per_ds = config.trainer.max_eval_batches - if max_eval_examples_per_ds is not None: - max_eval_examples_per_ds *= config.trainer.eval_batch_size cb = levanter.eval.cb_tagged_lm_evaluate( EvalBatch, @@ -205,23 +207,11 @@ def compute_log_probs(model, example): logprobs = hax.roll(logprobs, 1, Pos) return logprobs.rearrange((EvalBatch, Pos)).array - # engine.add_hook( - # callbacks.compute_and_visualize_log_probs( - # eval_loader, tokenizer, compute_log_probs, os.path.join(config.trainer.run_dir, "log_probs") - # ), - # every=config.trainer.steps_per_eval, - # ) - # - # data loader. may need to seek to the right place if we're resuming - train_loader = iter(trainer.sharded_loader(train_dataset, Batch)) - - if int(state.step) > 0 and seek_dataloader: - # step is after the batch, so we need to seek to step - # TODO: implement iter_data.seek(resume_step +1) - import tqdm - - for _ in tqdm.tqdm(range(state.step), desc="seeking data for resume"): - next(train_loader) + train_loader = trainer.data_loader(train_dataset, Batch) + if seek_dataloader: + train_loader = train_loader.iter_from_step(state.step) + else: + train_loader = iter(train_loader) ## OK, actually run training! 
trainer.train(state, train_loader) diff --git a/src/levanter/main/viz_logprobs.py b/src/levanter/main/viz_logprobs.py index a95783c18..b00ba61d5 100644 --- a/src/levanter/main/viz_logprobs.py +++ b/src/levanter/main/viz_logprobs.py @@ -11,7 +11,7 @@ import levanter from levanter.checkpoint import load_checkpoint -from levanter.data import ReplicatedBatchLoader +from levanter.data import DataLoader from levanter.data.text import CausalLmDataset, LMDatasetConfig from levanter.models.gpt2 import Gpt2Config from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel, compute_next_token_loss @@ -44,10 +44,12 @@ def main(config: VizGpt2Config): Pos = config.model.Pos KeyPos = config.model.KeyPos - eval_loader = ReplicatedBatchLoader( + eval_loader = DataLoader( + EvalBatch, CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos), # type: ignore + 32, config.trainer.device_mesh, - EvalBatch, + config.trainer.compute_axis_mapping, ) # some axes we use outside the model proper diff --git a/src/levanter/store/__init__.py b/src/levanter/store/__init__.py new file mode 100644 index 000000000..d0f4ad96a --- /dev/null +++ b/src/levanter/store/__init__.py @@ -0,0 +1,6 @@ +from .cache import SerialCacheWriter, TreeCache, build_or_load_cache +from .jagged_array import JaggedArrayStore +from .tree_store import TreeStore + + +__all__ = ["TreeCache", "build_or_load_cache", "SerialCacheWriter", "JaggedArrayStore", "TreeStore"] diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py new file mode 100644 index 000000000..85b612f91 --- /dev/null +++ b/src/levanter/store/cache.py @@ -0,0 +1,1321 @@ +import asyncio +import concurrent +import dataclasses +import heapq +import logging as pylogging +import os +import threading +import time +from asyncio import InvalidStateError +from concurrent.futures import Future as threading_Future +from contextlib import AbstractContextManager +from dataclasses import dataclass +from typing import Any, Callable, Dict, Generic, Iterator, List, Optional, Sequence, TypeVar, Union + +import fsspec.core +import pyarrow as pa +import ray +from dataclasses_json import dataclass_json +from fsspec import AbstractFileSystem +from ray.actor import ActorHandle + +from levanter.data.dataset import AsyncDataset + +from ..data._preprocessor import BatchProcessor, BatchResult, dict_from_record_batch +from ..data._queue import ( + PriorityWorkItem, + PriorityWorkTaskGroup, + PriorityWorkTaskGroupSpec, + WorkQueueDispatcherActor, + _BatchProcessorQueue, +) +from ..data.metrics_monitor import InProgressCacheMetrics, LoggerMetricsMonitor, MetricsMonitor +from ..data.sharded_datasource import ShardedDataSource +from ..utils.ray_utils import ( + ExceptionInfo, + RefBox, + SnitchRecipient, + current_actor_handle, + log_failures_to, + ser_exc_info, +) +from .tree_store import TreeStore + + +T = TypeVar("T") +U = TypeVar("U") +T_co = TypeVar("T_co", covariant=True) + +logger = pylogging.getLogger(__name__) + +LEDGER_FILE_NAME = "shard_ledger.json" + +DEFAULT_LOG_LEVEL = pylogging.INFO +LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + +# TODO: should probably do this in terms of bytes +MIN_ITEMS_TO_WRITE = 8192 +MAX_TIME_BETWEEN_WRITES = 100.0 + + +def build_or_load_cache( + cache_dir: str, + input_shards: ShardedDataSource[T], + processor: BatchProcessor[T, U], + await_finished: bool = True, + monitors: Optional[Sequence["MetricsMonitor"]] = None, + cache_config: Optional[Dict[str, Any]] = None, + items_per_write: int = MIN_ITEMS_TO_WRITE, +) -> 
"TreeCache[U]": + """ + Produces a sharded cache of the dataset using Ray for distributed processing. The cache can be any path + on any file system understood by fsspec. + + This system is designed with tokenization and similar processes in mind, but it can potentially be used for any kind + of preprocessing that converts input batches to output batches. The main design goal is to make it easy to + parallelize preprocessing across multiple machines while maintaining reproducibility and fault tolerance. + Usually the machines in question are the ones doing the training, but they could be separate machines as well. + + See the [Dataloader Design Doc](https://github.com/stanford-crfm/levanter/blob/main/docs/design/Data-Loader-Design.md) + for a somewhat out of date overview of the design. + + Args: + cache_dir: The directory to write the cache to. This can be any path understood by fsspec. + input_shards: A ShardedDataset that will be used to read the input data. Conceptually, it's just a mapping + from shard names to iterators over the data in that shard. + processor: A BatchProcessor that will be used to process batches of data. This is the main place where + you can customize the preprocessing pipeline. + await_finished: If True, this function will block until the cache is finished. If False, it will return + immediately. + monitors: a list of MetricsMonitors to attach to the cache. These will be called periodically with + metrics about the cache build process. If None, will add a LoggerMetricsMonitor. + + cache_config: A dictionary of configuration options for the cache. This is passed to the cache writer. + + items_per_write: The number of items to write to the cache at a time. This is a performance tuning parameter, + and you probably don't need to change it. We mostly use it for testing. + + Returns: + (TreeCache) A TreeCache object that can be used to read the cache. + + """ + # first see if we need to do anything + cache = TreeCache.build_or_load( + cache_dir=cache_dir, + shard_source=input_shards, + processor=processor, + cache_config=cache_config, + items_per_write=items_per_write, + ) + + if cache.is_finished: + logger.info("Cache already finished. Skipping.") + return cache + + if monitors is None: + monitors = [LoggerMetricsMonitor()] + + for monitor in monitors: + cache.attach_metrics_monitor(monitor) + + while await_finished: + try: + cache.await_finished(4.0) + break + except TimeoutError: + pass + + return cache + + +@dataclass_json +@dataclass +class CacheLedger: + # NB: unlike the old cache, the mere existence of a ledger doesn't mean the cache is finished + total_num_rows: int + shard_rows: Dict[str, int] + is_finished: bool = False + finished_shards: List[str] = dataclasses.field(default_factory=list) + field_counts: Dict[str, int] = dataclasses.field(default_factory=dict) + metadata: Dict[str, Any] = dataclasses.field(default_factory=dict) + + +@dataclass +class ShardStatus: + shard_name: str + num_rows_committed: int + is_finished: bool + + +class SerialCacheWriter(AbstractContextManager): + """ + Writes TreeCache-compatible caches to disk. This is a serial version of TreeCacheWriter that doesn't use Ray. + Mostly for scripts and debugging. + + Examples: + >>> with SerialCacheWriter(cache_dir, exemplar) as writer: + ... for batch in process_batches(): + ... 
writer.write_batch(batch) + """ + + def __init__( + self, + cache_dir: str, + exemplar: T, + cache_config: Optional[Dict[str, Any]] = None, + ): + self.cache_dir = cache_dir + self.cache_config = cache_config + self._exemplar = exemplar + self._tree_store = TreeStore.open(exemplar, self.cache_dir, mode="w") # type: ignore + self._is_closed = False + + def __enter__(self) -> "SerialCacheWriter": + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # if successful, write the ledger + # TODO: store field counts in the ledger + ledger = CacheLedger( + total_num_rows=len(self._tree_store), + is_finished=True, + shard_rows={"": len(self._tree_store)}, + finished_shards=[""], + field_counts={}, + ) + + if exc_type is None: + _serialize_json_and_commit(os.path.join(self.cache_dir, LEDGER_FILE_NAME), ledger) + logger.info(f"Cache ledger written to {self.cache_dir}") + self._is_closed = True + + def result(self) -> "TreeCache": + if not self._is_closed: + raise RuntimeError("Cannot get result until TreeCacheWriter is closed") + return TreeCache.load(self.cache_dir, self._exemplar) + + def write_batch(self, batch: BatchResult): + if isinstance(batch, pa.RecordBatch): + raise NotImplementedError("Only non-RecordBatch batches are supported for now") + + batch = _canonicalize_batch(batch) # type: ignore + + self._tree_store.extend(batch) + + +def _load_or_initialize_ledger(path): + try: + with fsspec.open(path, "r") as file: + return CacheLedger.from_json(file.read()) + except FileNotFoundError: + return CacheLedger(0, {}) + + +@ray.remote(num_cpus=0.5) # type: ignore +class _OrderedCacheWriter: + """ + This cache writer receives examples from some number of shards (generally out of order) and writes them to the store + in a defined round-robin order. It also keeps track of the metadata for each shard. + + Once a shard finishes sending batches, it notifies this writer, which then updates the metadata and writes it to disk. 
+ """ + + def __init__( + self, + parent, + name, + exemplar, + batch_size, + cache_dir: str, + shards: Sequence[str], + min_items_to_write=MIN_ITEMS_TO_WRITE, + ): + pylogging.basicConfig(level=DEFAULT_LOG_LEVEL, format=LOG_FORMAT) + with log_failures_to(parent): + self._parent = parent + self.cache_dir = cache_dir + self.shards = shards + self.batch_size = batch_size + self._min_items_to_write = min_items_to_write + self._failed = False + self._logger = pylogging.getLogger(name) + + # these are batches that we've received but haven't ordered them for writing yet + self._batch_queue = GroupRoundRobinBuffer(shards) # type: ignore + self._total_queue_length = 0 + self._was_overwhelmed = False # whether the queue has gotten too big + # writes are very slow (~2s) so we want to batch them up + self._ordered_but_unwritten_items: list = [] + self._batches_in_next_write_by_shard: dict[str, int] = {shard: 0 for shard in shards} + # we also want to write every so often + self._last_write_time = time.time() + + self._ledger = _load_or_initialize_ledger(os.path.join(cache_dir, LEDGER_FILE_NAME)) + self._expected_num_rows: dict[str, Optional[int]] = {shard: None for shard in shards} + + self._tree_store = TreeStore.open(exemplar, self.cache_dir, mode="a") + # careful: trim the store to the total number of rows in the cache that we've committed to + self._tree_store.trim_to_size(self._ledger.total_num_rows) + # we also have to tell the queue how many rows for each shard we've already written + for shard, num_rows in self._ledger.shard_rows.items(): + if num_rows > 0: + self._logger.info(f"Already written {num_rows} rows for shard {shard}") + + # careful: this is in terms of batch size + # Have to round up to the nearest batch size + self._batch_queue.fast_forward(shard, div_round_up(num_rows, self.batch_size)) + if shard in self._ledger.finished_shards: + self._expected_num_rows[shard] = num_rows + self._batch_queue.group_total_known(shard, div_round_up(num_rows, self.batch_size)) + + # double check that we're not finished by committing the ledger + self._attempt_to_write_batches() + + def batch_finished(self, shard_name: str, shard_batch_idx: int, batch_result_box): + with log_failures_to(self._parent): + if self._failed: + self._logger.warning("Received batch after failure. 
Ignoring.") + return + + if isinstance(batch_result_box, RefBox): + batch_result = ray.get(batch_result_box.ref) + else: + batch_result = batch_result_box + + # we need to keep track of the order of the batches so that we can write them out in order + self._total_queue_length += len(batch_result) + self._batch_queue.append_to_group(shard_name, shard_batch_idx, batch_result) + self._attempt_to_write_batches() + next_missing_item = self._batch_queue.next_missing_item_index() + + overwhelmed = self.is_overwhelmed() + if overwhelmed: + if not self._was_overwhelmed: + self._logger.warning(f"Writer queue is getting long ({self._total_queue_length}).") + self._parent.signal_backpressure.remote(next_missing_item) + elif self._was_overwhelmed: + self._logger.info(f"Writer queue is no longer overwhelmed ({self._total_queue_length}).") + self._parent.signal_backpressure.remote(None) + + self._was_overwhelmed = overwhelmed + + def shard_failed(self, shard_name: str, batch_id: int, exc_info: ExceptionInfo): + with log_failures_to(self._parent): + self._failed = True + logger.error(f"Shard {shard_name} failed at batch {batch_id}", exc_info=exc_info.restore()) + self._parent.shard_failed.remote(shard_name, exc_info) + + def shard_finished_reading(self, shard_name: str, expected_num_rows: int): + with log_failures_to(self._parent): + # careful: this is in terms of batch size + self._batch_queue.group_total_known(shard_name, div_round_up(expected_num_rows, self.batch_size)) + self._expected_num_rows[shard_name] = expected_num_rows + logger.debug( + f"Attempting to write batches because {shard_name} finished reading with {expected_num_rows} batches." + ) + self._attempt_to_write_batches() + + def get_shard_status(self, shard_name: str): + with log_failures_to(self._parent): + rows = self._ledger.shard_rows.get(shard_name, 0) + is_finished = shard_name in self._ledger.finished_shards + return ShardStatus(shard_name, rows, is_finished) + + def get_ledger(self): + return self._ledger + + def _attempt_to_write_batches(self): + if self._ledger.is_finished: + raise RuntimeError("Trying to write batches after cache is finished") + + if self._failed: + logger.warning("Not writing batches because of failure.") + return + + self._dequeue_ready_batches() + updated_shards = self._write_available_batches() + + logger.debug(f"Updated shards: {updated_shards}") + + need_to_commit = len(updated_shards) > 0 + total_rows = self._ledger.total_num_rows + sum(updated_shards.values()) + + for shard, num_rows in updated_shards.items(): + self._ledger.shard_rows[shard] = self._ledger.shard_rows.get(shard, 0) + num_rows + + futures_to_await_shards, need_to_commit_for_shards = self._check_for_finished_shards() + + need_to_commit = need_to_commit or need_to_commit_for_shards + + futures_to_await = [] + if need_to_commit: + self._ledger.total_num_rows = total_rows + _serialize_json_and_commit(os.path.join(self.cache_dir, LEDGER_FILE_NAME), self._ledger) + + futures_to_await.append(self._parent._updated_ledger.remote(self._ledger)) + + if self._ledger.is_finished: + f = self._parent._finalize.remote() + futures_to_await.append(f) + + ray.wait(futures_to_await + futures_to_await_shards) + + def _dequeue_ready_batches(self): + for shard, batch in self._batch_queue.drain(): + logger.debug(f"Writing batch for {shard}") + batch = _canonicalize_batch(batch) + self._total_queue_length -= len(batch) + self._ordered_but_unwritten_items.extend(batch) + self._batches_in_next_write_by_shard[shard] = self._batches_in_next_write_by_shard.get(shard, 
0) + len( + batch + ) + + def _write_available_batches(self): + if len(self._ordered_but_unwritten_items) == 0: + return {} + + any_shard_finished_reading = any(num_rows is not None for num_rows in self._expected_num_rows.values()) + + if ( + len(self._ordered_but_unwritten_items) >= self._min_items_to_write + or (time.time() - self._last_write_time > MAX_TIME_BETWEEN_WRITES) + or any_shard_finished_reading + ): + time_in = time.time() + self._tree_store.extend(self._ordered_but_unwritten_items) + time_out = time.time() + logger.debug(f"Wrote {len(self._ordered_but_unwritten_items)} rows in {time_out - time_in:.2f} seconds") + self._ordered_but_unwritten_items = [] + + written_by_shard = self._batches_in_next_write_by_shard + self._batches_in_next_write_by_shard = {} + self._last_write_time = time.time() + return written_by_shard + else: + return {} + + def _check_for_finished_shards(self): + futures_to_await_shards = [] + need_to_commit_for_shards = False + for shard, expected_rows in self._expected_num_rows.items(): + if expected_rows is None: + continue + + current_rows = self._ledger.shard_rows.get(shard, 0) + if current_rows == expected_rows: + if shard not in self._ledger.finished_shards: + logger.info(f"Shard {shard} finished.") + self._ledger.finished_shards.append(shard) + futures_to_await_shards.append(self._parent.shard_finished.remote(shard)) + need_to_commit_for_shards = True + elif current_rows > expected_rows: + raise ValueError(f"Shard {shard} has more rows than expected: {current_rows} > {expected_rows}") + + if len(self._ledger.finished_shards) == len(self.shards) and set(self._ledger.finished_shards) == set( + self.shards + ): + self._ledger.is_finished = True + need_to_commit_for_shards = True + return futures_to_await_shards, need_to_commit_for_shards + + def is_overwhelmed(self) -> bool: + max_queue_size = self._min_items_to_write * 3 + return self._total_queue_length > max_queue_size + + +def _to_list_of_dicts(batch: dict) -> List[dict]: + """ + Convert a batch of dictionaries to a list of dictionaries, suitable for writing to a cache. + """ + keys = list(batch.keys()) + values = list(batch.values()) + num_rows = len(values[0]) + return [{key: values[i][j] for i, key in enumerate(keys)} for j in range(num_rows)] + + +def _canonicalize_batch(batch: Union[dict, List[dict]]) -> List[dict]: + if isinstance(batch, pa.RecordBatch): + batch = dict_from_record_batch(batch) + + if isinstance(batch, dict): + return _to_list_of_dicts(batch) + else: + return batch + + +# thinking through the design of the cache system + +# we decided to use Ray, which was maybe a mistake, but here we are. +# Ray doesn't like it when the number of actors gets too large, so we can't have one actor per shard. +# we have N nodes and K shards. + +# at a high level, we have 3 steps: +# 1. read batches from the shard source +# 2. process batches +# 3. write batches to the cache for that shard + +# The difficulty is that we want parallelism, and we want to control the order of the written data. +# Reading batches requires CPU and network. +# ==> This means we should limit the number of shard groups to roughly the number of nodes, maybe times 2. 
+# We ideally want to read from shards roughly evenly (at least within a group of shards) + + +def _shard_reader_generator(shard_source: ShardedDataSource[T], shard_name: str, start_row: int, batch_size: int): + shard_iter = shard_source.open_shard_at_row(shard_name, start_row) + batch = [] + for row in shard_iter: + batch.append(row) + + if len(batch) == batch_size: + yield batch + batch = [] + + if len(batch) > 0: + yield batch + + +@dataclass +class ShardGroupToBeProcessed(PriorityWorkTaskGroupSpec): + name: str + builder_ref: ray.actor.ActorHandle # _TreeStoreCacheBuilder + writer: ray.actor.ActorHandle # _GroupedShardWriter + shard_source: ShardedDataSource + shard_names: Sequence[str] + priority_fn: Callable[[int, int], float] + processor_actor: ray.actor.ActorHandle # BatchProcessorQueue + batch_size: int + group_id: int + + def build(self) -> "PriorityWorkTaskGroup": + return ShardGroupTaskGroup(self) + + +class ShardGroupTaskGroup(PriorityWorkTaskGroup): + def __init__(self, spec: ShardGroupToBeProcessed): + self.spec: ShardGroupToBeProcessed = spec + self.logger = pylogging.getLogger(f"shard_reader.{spec.group_id}.{spec.name}") + + current_shard_status: dict[str, ShardStatus] = {} + for shard_name in self.spec.shard_names: + try: + current_shard_status[shard_name] = ray.get(self.spec.writer.get_shard_status.remote(shard_name)) + except Exception as e: + self.spec.builder_ref.shard_failed.remote(shard_name, ser_exc_info()) + raise e + + batch_size = self.spec.batch_size + + self._items: list[PriorityWorkItem] = [] + + for shard_name in self.spec.shard_names: + try: + status = current_shard_status[shard_name] + if status.is_finished: + self.logger.info(f"Shard {shard_name} already finished. Skipping.") + continue + + reader = _shard_reader_generator( + self.spec.shard_source, shard_name, status.num_rows_committed, batch_size + ) + + task_name = f"shard_reader.{self.spec.name}.{shard_name}" + + batch_idx = status.num_rows_committed // batch_size + + shard_idx = self.spec.shard_source.shard_names.index(shard_name) + item = ShardReaderItem( + self, + task_name, + shard_name, + shard_idx, + batch_idx=batch_idx, + reader=reader, + current_row=status.num_rows_committed, + ) + + heapq.heappush(self._items, item) + except Exception as e: + self.logger.exception(f"Error while initializing shard {shard_name}") + self.spec.writer[shard_name].shard_failed.remote(ser_exc_info()) + raise e + + @property + def name(self): + return self.spec.name + + def items(self) -> Sequence["PriorityWorkItem"]: + return self._items + + +# NB This class is stateful +@dataclass +class ShardReaderItem(PriorityWorkItem): + """ + Each time execute is called, this class reads a batch of examples from the shard + and dispatches them to the processor. 
+ """ + + group: ShardGroupTaskGroup + name: str + shard_name: str + shard_idx: int + batch_idx: int + reader: Iterator[list] + current_row: int = 0 + + @property + def priority(self): + return self.group.spec.priority_fn(self.shard_idx, self.batch_idx) + + @property + def spec(self): + return self.group.spec + + def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: + writer = self.spec.writer + write_finished_ref = None + + self.group.logger.debug(f"Reading one batch of shard {self.shard_name}: {self.batch_idx}") + + try: + batch = next(self.reader, None) + exhausted_shard = batch is None or (len(batch) < self.spec.batch_size) + + if batch: + priority = self.spec.priority_fn(self.shard_idx, self.batch_idx) + try: + batch_result_ref = ray.get( + self.spec.processor_actor.submit.remote( + priority=priority, + desc=f"{self.shard_name}.{self.batch_idx}", + batch=RefBox(ray.put(batch)), + ) + ) + logger.debug(f"Got batch result: {batch_result_ref}") + write_finished_ref = writer.batch_finished.remote( + self.shard_name, self.batch_idx, RefBox(batch_result_ref) + ) + self.batch_idx += 1 + self.current_row += len(batch) + except Exception as e: + self.group.logger.exception(f"Error while processing batch {self.batch_idx}") + # fire and forget + writer.shard_failed.remote(self.shard_name, self.batch_idx, ser_exc_info()) + raise e + + if exhausted_shard: + logger.info(f"Shard {self.shard_name} exhausted. Expecting {self.current_row} rows.") + writer.shard_finished_reading.remote(self.shard_name, self.current_row) + + self.group.logger.debug(f"Finished reading one batch of shard {self.shard_name}: {self.batch_idx}") + + return exhausted_shard, write_finished_ref + except Exception as e: # noqa + self.group.logger.exception(f"Error while processing shard {self.shard_name}") + # fire and forget + writer.shard_failed.remote(self.shard_name, self.batch_idx, ser_exc_info()) + raise e + + +def _serialize_json_and_commit(path, obj): + # just to be paranoid, we write to a temp file and then rename it + # TODO: probably we could do better here + with fsspec.open(f"{path}.tmp", "w") as file: + file.write(obj.to_json()) + # now copy the old file to a backup + fs: AbstractFileSystem = fsspec.core.url_to_fs(path)[0] + fs.mkdirs(os.path.dirname(path), exist_ok=True) + if fs.exists(path): + fs.copy(path, f"{path}.bak") + fs.rename(f"{path}.tmp", path) + + +def _load_cache_ledger(cache_dir) -> CacheLedger: + try: + ledger_path = os.path.join(cache_dir, LEDGER_FILE_NAME) + logger.debug(f"Attempting to load cache ledger from {ledger_path}") + with fsspec.open(ledger_path) as file: + cache_ledger = CacheLedger.from_json(file.read()) # type: ignore + return cache_ledger + except FileNotFoundError: + raise FileNotFoundError(f"Cache ledger not found at {ledger_path}") + + +@ray.remote(num_cpus=0.1) # keep this small b/c it doesn't do a lot +class _TreeStoreCacheBuilder(SnitchRecipient): + """ + Actor that coordinates the building of a cache. It spins up a bunch of workers to read from each shard + and write to the cache. 
+ + """ + + def __init__( + self, + cache_dir: str, + name: str, + source: ShardedDataSource[T], + processor: BatchProcessor[T, U], + cache_config: Dict[str, Any], + min_items_to_write: int, + ): + pylogging.basicConfig(level=DEFAULT_LOG_LEVEL, format=LOG_FORMAT) + self.logger = pylogging.getLogger(f"{__name__}.{name}") + self.source = source + self._cache_dir = cache_dir + # self._metrics = InProgressCacheMetrics() + self._updated_ledger_condition = asyncio.Condition() + self._ledger = CacheLedger(0, {}) + self.shards_in_progress: set[str] = set() + exemplar = processor.output_exemplar + + self._finished_promise: asyncio.Future[None] = asyncio.Future() + # used to subscribe to metrics updates + self._cache_config = cache_config + path_for_name = os.path.join(*self._cache_dir.split("/")[-2:]) + name = f"broker::{path_for_name}" + self.logger = pylogging.getLogger(f"{name}") + self._cache_writer: Optional[ActorHandle] = _OrderedCacheWriter.remote( # type: ignore + current_actor_handle(), + f"writer::{path_for_name}", + exemplar, + processor.batch_size, + cache_dir, + source.shard_names, + min_items_to_write, + ) + + try: + cache_ledger = _load_cache_ledger(self._cache_dir) + self._ledger = cache_ledger + except FileNotFoundError: + pass + + if self._ledger.is_finished: + self._finished_promise.set_result(None) + self._start_workers(cache_dir, name, processor, source) + + def _start_workers(self, cache_dir, name, processor, source): + if len(source.shard_names) == 0: + self.logger.warning("No shards to index?!?") + self._finalize() + else: + self.logger.debug(f"Starting cache build for {source.shard_names}") + self.logger.info(f"Starting cache build for {len(source.shard_names)} shards") + + self_ref = current_actor_handle() + + self._shard_writers = [] + self._shard_readers = [] + self._processor_actors = [] + + for shard_name in source.shard_names: + self.shards_in_progress.add(shard_name) + + num_shards = len(source.shard_names) + num_worker_groups = len(ray.nodes()) + num_shard_groups = max(min(num_worker_groups, num_shards), 1) + + # if we have a bunch of caches to build with one shard, we don't want them all + # assigned to the same node, so we use an offset based on the hash of the name (for stability) + # in an attempt to spread them out + group_offset = int(hash(name) % num_worker_groups) + + shard_groups: list[list[str]] = [[] for _ in range(num_shard_groups)] + for i, shard_name in enumerate(source.shard_names): + shard_groups[i % num_shard_groups].append(shard_name) + + def priority_fn(shard_idx, batch_idx): + return batch_idx * num_shards + shard_idx + + for group_id, shard_group in enumerate(shard_groups): + # TODO: would probably be better if we didn't create one of these per shard group + processor_actor = _BatchProcessorQueue.remote(processor) # type: ignore + self._processor_actors.append(processor_actor) + + assert self._cache_writer is not None + + work_item = ShardGroupToBeProcessed( + name=name, + builder_ref=self_ref, + writer=self._cache_writer, + shard_source=source, + shard_names=shard_group, + priority_fn=priority_fn, + processor_actor=processor_actor, + batch_size=processor.batch_size, + group_id=group_id, + ) + + # we want global names so that different tasks can coordinate priorities + worker_to_assign = (group_id + group_offset) % num_worker_groups + priority_actor_name = f"priority_processor.{worker_to_assign}" + + reader_actor = WorkQueueDispatcherActor.options( # type: ignore + name=priority_actor_name, get_if_exists=True + ).remote() + + 
reader_actor.assign_work.remote(work_item) + self._shard_readers.append(reader_actor) + + def shard_finished(self, shard_name: str): + """Callback method for when a shard worker has finished.""" + self.shards_in_progress.remove(shard_name) + + def shard_failed(self, shard_name: str, error: ExceptionInfo): + """Callback method for when a shard worker has failed.""" + self._writer_exception(shard_name, error) + + def _updated_ledger(self, ledger: CacheLedger): + self._ledger = ledger + self._do_notify() + + def other_failed(self, error: ExceptionInfo): + """Callback method for when a shard worker has failed.""" + self._writer_exception(None, error) + + def _child_failed(self, child: ray.actor.ActorHandle, exception: ExceptionInfo): + self.logger.error(f"Child {child} failed with exception", exc_info=exception.restore()) + self._writer_exception(None, exception) + + def is_finished(self): + return self._ledger.is_finished + + async def finished_sentinel(self): + await self._finished_promise + + async def updated_ledger(self) -> CacheLedger: + if self._finished_promise.done(): + if self._finished_promise.exception() is not None: + raise self._finished_promise.exception() # type: ignore + else: + return self._ledger + + async with self._updated_ledger_condition: + await self._updated_ledger_condition.wait() + return self._ledger + + def _writer_exception(self, shard_name, exc_info: ExceptionInfo): + info = exc_info.restore() + + logger.exception(f"Writer task {shard_name} failed with exception", exc_info=info) + + try: + self._finished_promise.set_exception(info[1]) + except InvalidStateError: + pass + except concurrent.futures.InvalidStateError: + pass + self._do_notify() + + def _do_notify(self): + async def _do_notify_async(): + async with self._updated_ledger_condition: + self._updated_ledger_condition.notify_all() + + asyncio.create_task(_do_notify_async()) + + def current_ledger(self): + return self._ledger + + def _finalize(self): + logger.info(f"Finalizing cache {self._cache_dir}...") + + self._ledger.is_finished = True + self._finished_promise.set_result(None) + + # notify metrics subscribers + self._do_notify() + self._cache_writer = None + + def signal_backpressure(self, next_item_desired: Optional[int]): + # get the priority of the item we want + if next_item_desired is not None: + self.logger.debug(f"Signaling backpressure for {next_item_desired}") + # our priority function above is basically (batch_index, shard_index). 
We just ask we don't get more + # than one round of batches ahead + max_priority = (next_item_desired + 1) * len(self.source.shard_names) + + for reader in self._shard_readers: + reader.set_max_dispatch_priority.remote(max_priority) + else: + self.logger.debug("Signaling no backpressure") + for reader in self._shard_readers: + reader.set_max_dispatch_priority.remote(None) + + +def _get_builder_actor(cache_dir, input_shards, processor, cache_config=None, items_per_write=MIN_ITEMS_TO_WRITE): + name = f"lev_cache_manager::{cache_dir}" + path_for_name = os.path.join(*os.path.split(cache_dir)[-2:]) + name_for_display = f"builder::{path_for_name}" + + return _TreeStoreCacheBuilder.options(name=name, get_if_exists=True).remote( # type: ignore + name=name_for_display, + cache_dir=cache_dir, + source=input_shards, + processor=processor, + cache_config=cache_config, + min_items_to_write=items_per_write, + ) + + +class TreeCache(AsyncDataset[T_co]): + ledger: Optional[CacheLedger] + _broker: Optional[ActorHandle] + # monitor_thread waits for new metrics and also periodically reloads the cache + _monitor_thread: Optional[threading.Thread] + _metrics_monitors: List[MetricsMonitor] + + def __init__( + self, + cache_dir: str, + exemplar: T_co, + ledger: Optional[CacheLedger], + _broker, # handle of _TreeStoreCacheBuilder + ): + self.cache_dir = cache_dir + self.ledger = ledger + self._was_already_finished = ledger is not None and ledger.is_finished + self._broker = _broker + self._exemplar = exemplar + + self._metrics_monitors = [] + name = os.path.join(*cache_dir.split("/")[-2:]) + self.logger = pylogging.getLogger(f"TreeCache.{name}") + self._store_future: threading_Future[TreeStore] = threading_Future() + self._stop = False + + if self._broker is not None: + self._monitor_thread = threading.Thread(target=self._monitor_metrics, daemon=True) + self._monitor_thread.start() + else: + self._attempt_to_load_store() + assert self._store_future.done() + + @property + def store(self) -> TreeStore[T_co]: + return self._store_future.result() + + async def store_async(self) -> TreeStore[T_co]: + if self._broker is not None: + return await asyncio.wrap_future(self._store_future) + else: + return self.store + + async def async_len(self) -> int: + if self._broker is not None: + self.await_finished() + + return len(await self.store_async()) + + def __len__(self): + self.await_finished() + + return len(self.store) + + async def final_length_is_known(self) -> bool: + if self._broker is not None: + return await self._broker.is_finished.remote() + + return True + + def is_finite(self) -> bool: + return True + + async def current_len(self) -> int: + if not self._store_future.done(): + return 0 + + return len(await self.store_async()) + + async def get_batch(self, indices: Sequence[int] | slice): + # this is tricky: we want to wait until either the cache is finished or we have the max index + if isinstance(indices, slice): + start, step, stop = await self._get_start_stops_async(indices) + await self._wait_for_len(max(stop, start)) + indices = range(start, stop, step) + + max_index = max(indices) + await self._wait_for_len(max_index + 1) + + return await self.store.get_batch(indices) + + async def _wait_for_len(self, needed_len): + if self._broker is not None: + while needed_len > await self.current_len(): + new_ledger = await self._broker.updated_ledger.remote() + + if needed_len <= new_ledger.total_num_rows: + break + + if new_ledger.is_finished: + if needed_len >= new_ledger.rows_finished: + raise IndexError( + f"Index 
{needed_len} out of bounds for cache of size {new_ledger.total_num_rows}" + ) + break + else: + if needed_len > len(self.store): + raise IndexError(f"Index {needed_len} out of bounds for cache of size {len(self.store)}") + + def _wait_for_len_sync(self, needed_len, timeout: Optional[float] = None): + time_in = time.time() + t_max = time_in + (timeout or 1e6) + if self._broker is not None: + while needed_len > len(self.store): + cur_time = time.time() + if cur_time > t_max: + raise TimeoutError(f"Timed out waiting for cache to reach {needed_len}") + try: + new_ledger = ray.get(self._broker.updated_ledger.remote(), timeout=max(t_max - cur_time, 10)) + except TimeoutError: + continue + + if needed_len <= new_ledger.total_num_rows: + break + + if new_ledger.is_finished: + if needed_len >= new_ledger.rows_finished: + raise IndexError( + f"Index {needed_len} out of bounds for cache of size {new_ledger.total_num_rows}" + ) + break + else: + if needed_len > len(self.store): + raise IndexError(f"Index {needed_len} out of bounds for cache of size {len(self.store)}") + + @staticmethod + def load(cache_dir: str, exemplar: T) -> "TreeCache": + """Loads a cache from disk or an object store. Raises FileNotFoundError if the cache doesn't exist""" + logger.info(f"Loading cache from {cache_dir}") + ledger = _load_cache_ledger(cache_dir) + if not ledger.is_finished: + raise FileNotFoundError(f"Cache at {cache_dir} is not finished. Use build_or_load to build it.") + return TreeCache(cache_dir, exemplar, ledger, None) + + @staticmethod + def build_or_load( + cache_dir: str, + shard_source: ShardedDataSource[T], + processor: BatchProcessor[T, U], + cache_config: Optional[Dict[str, Any]] = None, + items_per_write: int = MIN_ITEMS_TO_WRITE, + ) -> "TreeCache[U]": + try: + return TreeCache.load(cache_dir, processor.output_exemplar) + except FileNotFoundError: + broker = _get_builder_actor( + cache_dir=cache_dir, + input_shards=shard_source, + processor=processor, + cache_config=cache_config, + items_per_write=items_per_write, + ) + return TreeCache(cache_dir=cache_dir, exemplar=processor.output_exemplar, ledger=None, _broker=broker) + + def finished_sentinel(self): + """Returns a Ray-awaitable object that will be set when the cache is finished""" + if self._broker is None: + return ray.remote(num_cpus=0)(lambda: None).remote() + else: + return self._broker.finished_sentinel.remote() + + @property + def is_finished(self): + if self._broker is None: + return True + else: + return ray.get(self._broker.is_finished.remote()) + + def __getitem__(self, item): + if isinstance(item, slice): + start, step, stop = self._get_start_stops(item) + # TODO: wait for store to be set + return self.store[start:stop:step] + else: + if item < 0: + item += len(self) + if item < 0 or item >= len(self): + raise IndexError(f"Index {item} out of bounds for cache of size {len(self)}") + return self.store[item] + + def get_batch_sync(self, indices_or_slice, *, timeout: Optional[float] = None): + store = self.store + if isinstance(indices_or_slice, slice): + start, step, stop = self._get_start_stops(indices_or_slice) + indices_or_slice = range(start, stop, step) + + max_index = max(indices_or_slice) + + self._wait_for_len_sync(max_index + 1, timeout=timeout) + + return store.get_batch_sync(indices_or_slice) + + def _get_start_stops(self, slice): + start = slice.start or 0 + if slice.stop is None: + stop = len(self) + elif slice.stop < 0: + stop = len(self) + slice.stop + else: + stop = slice.stop + if start < 0: + start = len(self) + 
slice.start + step = slice.step or 1 + return start, step, stop + + async def _get_start_stops_async(self, slice): + start = slice.start or 0 + if slice.stop is None: + stop = await self.async_len() + elif slice.stop < 0: + stop = (await self.async_len()) + slice.stop + else: + stop = slice.stop + if start < 0: + start = (await self.async_len()) + slice.start + + step = slice.step or 1 + return start, step, stop + + def await_finished(self, timeout: Optional[float] = None): + x = ray.get(self.finished_sentinel(), timeout=timeout) + self._attempt_to_load_store() + return x + + async def finished(self): + x = await self.finished_sentinel() + # TODO: make an async version of this + self._attempt_to_load_store() + return x + + def _attempt_to_load_store(self): + if self._store_future.done(): + return + + try: + store = TreeStore.open(self._exemplar, self.cache_dir, mode="r") + except FileNotFoundError: + logger.error(f"Cache at {self.cache_dir} not found.") + assert self._broker is not None + ledger = ray.get(self._broker.current_ledger.remote()) + metrics = _ledger_to_metrics(ledger) + if metrics.rows_finished == 0 and metrics.is_finished: + # this means we built an empty cache. go with it + store = TreeStore.open(self._exemplar, f"memory://{self.cache_dir}", mode="a") + else: + raise + try: + self._store_future.set_result(store) + except concurrent.futures.InvalidStateError: + pass + + def attach_metrics_monitor(self, monitor: MetricsMonitor): + if self._broker is None: + logger.warning("Cannot attach metrics monitor to finished cache.") + # TODO: decide what to do about attaching if the cache is already finished + # maybe get the final metrics? + return + + self._metrics_monitors.append(monitor) + + def _monitor_metrics(self): + while not self._stop: + try: + try: + ledger = ray.get(self._broker.updated_ledger.remote(), timeout=10.0) + metrics = _ledger_to_metrics(ledger) + for monitor in self._metrics_monitors: + monitor(metrics) + if metrics.is_finished: + break + except TimeoutError: + pass + except Exception as e: + if str(e).startswith("Failed to submit task to actor"): + logger.warning("Cache builder actor is gone. Stopping monitoring.") + break + try: + self._attempt_to_load_store() + except FileNotFoundError: + pass + except Exception as e: + if str(e).startswith("Failed to submit task to actor"): + logger.warning("Cache builder actor is gone. Stopping monitoring.") + break + else: + self.logger.exception("Error while reading metrics from shard cache.") + raise e + + +def _ledger_to_metrics(ledger: CacheLedger) -> InProgressCacheMetrics: + return InProgressCacheMetrics( + rows_finished=ledger.total_num_rows, + is_finished=ledger.is_finished, + # shard_rows=ledger.shard_rows, + # finished_shards=ledger.finished_shards, + field_counts=ledger.field_counts, + ) + + +class GroupRoundRobinBuffer(Generic[T]): + """ + A buffer that holds items from multiple groups and returns them in a round-robin fashion. + The groups need not have the same number of items. If a group is exhausted, it is removed from the rotation. 
+ """ + + def __init__(self, groups: Sequence[str]): + self.groups = groups + self._current_group = 0 + self.buffers: dict[str, list[tuple[int, T]]] = {group: [] for group in groups} + self._remaining_groups = set(groups) + self._totals_written: dict[str, int] = {group: 0 for group in groups} + self._totals_expected: dict[str, Optional[int]] = {group: None for group in groups} + + def __len__(self): + return sum(len(buffer) for buffer in self.buffers.values()) + + def append_to_group(self, group: str, item_serial: int, item: T): + if group not in self.groups: + raise ValueError(f"Group {group} not in {self.groups}") + + if group not in self._remaining_groups: + raise ValueError(f"Group {group} already finished") + + logger.debug(f"Appending item {item_serial} to {group}") + + heapq.heappush(self.buffers[group], (item_serial, item)) + + def group_total_known(self, group: str, total: int): + if group not in self.groups: + raise ValueError(f"Group {group} not in {self.groups}") + + if group not in self._remaining_groups: + raise ValueError(f"Group {group} already finished: {total} vs {self._totals_expected[group]}") + + self._totals_expected[group] = total + + if self._totals_written[group] == total: + assert len(self.buffers[group]) == 0 + self._remaining_groups.remove(group) + elif self._totals_written[group] > total: + raise ValueError(f"Group {group} has written more than expected: {self._totals_written[group]} > {total}") + + def is_finished(self): + return len(self._remaining_groups) == 0 + + def pop(self) -> Optional[tuple[str, T]]: + group = self._next_group_to_read_from() + if group is None: + return None + + if len(self.buffers[group]) == 0: + return None + + cur_serial, item = self.buffers[group][0] + + # logger.debug( + # f"group: {group}, cur_serial: {cur_serial}, totals_written: {self._totals_written[group]}," + # f" totals_expected: {self._totals_expected.get(group)}" + # ) + + if cur_serial > self._totals_written[group]: + return None + elif cur_serial < self._totals_written[group]: + raise ValueError(f"Duplicate serial {cur_serial} for group {group}") + + heapq.heappop(self.buffers[group]) + logger.debug(f"Read item {cur_serial} from {group}") + + self._totals_written[group] += 1 + + if self._totals_written[group] == self._totals_expected[group]: + assert len(self.buffers[group]) == 0 + assert group in self._remaining_groups + self._remaining_groups.remove(group) + + self._current_group = (self._current_group + 1) % len(self.groups) + + return group, item + + def drain(self) -> Iterator[tuple[str, T]]: + while True: + item = self.pop() + if item is None: + break + yield item + + def _next_group_to_read_from(self): + """ + Returns the next group to read from. This is always the group with the least that is not finished. + """ + if len(self._remaining_groups) == 0: + return None + + # careful: this is only correct if self._current_group is correct. whenever we fast forward, we have to + # recompute it + while True: + group = self.groups[self._current_group] + if group not in self._remaining_groups: + assert self._totals_written[group] == self._totals_expected[group] + assert len(self.buffers[group]) == 0 + self._current_group = (self._current_group + 1) % len(self.groups) + else: + break + return group + + def fast_forward(self, group, num_rows): + """ + Fast forwards the buffer for a group to a certain number of rows. This sets the "next" item to be the + num_rows-th item. 
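+
+        For example, after `fast_forward("a", 10)` (group name illustrative), the buffer
+        behaves as if the first 10 items of group "a" had already been yielded, so the
+        next serial it will accept from that group is 10.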
+ """ + if group not in self.groups: + raise ValueError(f"Group {group} not in {self.groups}") + + if self._totals_written[group] != 0: + raise ValueError(f"Group {group} already written to: {self._totals_written[group]}") + + self._totals_written[group] = num_rows + + self._fix_current_group() + + def _fix_current_group(self): + # This is always the minimum total written group that is not finished + self._current_group = 0 + min_total = None + + for i, group in enumerate(self.groups): + if group not in self._remaining_groups: + continue + total = self._totals_written[group] + if min_total is None or total < min_total: + min_total = total + self._current_group = i + + def next_missing_item_index(self): + """ + Returns the index of the next item that is not in the buffer + (i.e. what's stopping us from yielding the next item). + """ + if len(self._remaining_groups) == 0: + return None + + group = self.groups[self._current_group] + if group not in self._remaining_groups: + self._fix_current_group() + return self.next_missing_item_index() + + if len(self.buffers[group]) == 0: + return self._totals_written[group] + + cur_serial, _ = self.buffers[group][0] + + if cur_serial > self._totals_written[group]: + return self._totals_written[group] + elif cur_serial < self._totals_written[group]: + raise ValueError(f"Duplicate serial {cur_serial} for group {group}") + + return None + + +def div_round_up(x, y): + return (x + y - 1) // y diff --git a/src/levanter/store/jagged_array.py b/src/levanter/store/jagged_array.py new file mode 100644 index 000000000..8b3a26a54 --- /dev/null +++ b/src/levanter/store/jagged_array.py @@ -0,0 +1,508 @@ +import asyncio +import os +from dataclasses import dataclass +from typing import Optional, Sequence + +import fsspec.core +import jax +import jax.experimental.array_serialization.serialization as ser +import jax.numpy as jnp +import numpy as np +import tensorstore as ts + +from levanter.utils import fsspec_utils +from levanter.utils.thread_utils import future_from_value + + +# zarr suggests 1MB chunk size (in bytes, but whatever) +# at 4 bytes this is 256k elements +DEFAULT_CHUNK_SIZE = 256 * 1024 +DEFAULT_WRITE_CHUNK_SIZE = DEFAULT_CHUNK_SIZE * 512 + + +@dataclass +class JaggedArrayStore: + """ + A jagged array is a collection of arrays of varying lengths. + We represent this as a single array with an accompanying array of offsets. + + Note that JAX doesn't really support jagged arrays, so we have to be careful about how we use them. + Typically, we just use these for data loading. + + PERFORMANCE: accessing an individual row (or a single small slice of the underlying data) is very slow. + Where ever possible, use get_batch to get multiple rows at once for as large a batch as possible. + High latency, but high throughput. 
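+
+    A minimal usage sketch (the path and dtype here are illustrative):
+
+        store = JaggedArrayStore.open("/tmp/tokens", item_rank=1, dtype=jnp.int32)
+        store.extend([np.arange(10, dtype=np.int32), np.arange(5, dtype=np.int32)])
+        assert len(store) == 2
+        rows = store.get_batch_sync([0, 1])  # both rows fetched in one batched read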
+ """ + + offsets: ts.TensorStore # offsets of the start of each array, except that index[0] is the number of arrays + data: ts.TensorStore + shapes: Optional[ts.TensorStore] # (len(offsets), len(data.shape)-1) + item_rank: int = 1 + + @staticmethod + async def open_async(path: Optional[str], *, mode="a", item_rank=1, dtype) -> "JaggedArrayStore": + offset_path = _extend_path(path, "offsets") + offsets = _ts_open_async(offset_path, jnp.int64, [1], mode=mode) + + data_path = _extend_path(path, "data") + data = _ts_open_async(data_path, dtype, [0], mode=mode) + + if item_rank > 1: + shape_path = _extend_path(path, "shapes") + shapes = _ts_open_async(shape_path, jnp.int64, [0, item_rank - 1], mode=mode) + else: + shapes = None + + return JaggedArrayStore(await offsets, await data, await shapes if shapes is not None else None, item_rank) + + @staticmethod + def open(path: Optional[str], *, mode="a", item_rank=1, dtype) -> "JaggedArrayStore": + offset_path = _extend_path(path, "offsets") + offsets = _ts_open_sync(offset_path, jnp.int64, [1], mode=mode) + + data_path = _extend_path(path, "data") + data = _ts_open_sync(data_path, dtype, [0], mode=mode) + + if item_rank > 1: + shape_path = _extend_path(path, "shapes") + shapes = _ts_open_sync(shape_path, jnp.int64, [0, item_rank - 1], mode=mode) + else: + shapes = None + + return JaggedArrayStore(offsets, data, shapes, item_rank) + + @property + def num_rows(self): + return int(self.offsets[0].read().result()) + + async def num_rows_async(self): + return int(await self.offsets[0].read()) + + @property + def data_size(self): + return int(self.offsets[self.num_rows].read().result()) + + async def append_async(self, data: jax.Array): + await self.extend_async([data]) + + def append(self, data: jax.Array): + self.extend([data]) + + async def trim_to_size_async(self, size: int): + """ + Trims so we have exactly `size` rows in the jagged array. 
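+
+        For example, `await store.trim_to_size_async(2)` keeps rows 0 and 1 and discards
+        any later rows and their data; a `size` at or above the current length is a no-op.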
+ """ + if size >= len(self): + return + + current_data_size = self.data_size + current_num_rows = await self.num_rows_async() + + offsets_fut = self.offsets[size + 1 : current_num_rows + 1].write(0) + + if size == 0: + new_max = 0 + else: + new_max = int(await self.offsets[size].read()) + + f1 = self.offsets[0].write(size) + + # Trim the shapes + if self.shapes is not None: + shape_fut = self.shapes[size:current_num_rows].write( + np.zeros(self.shapes.shape[1:], dtype=self.shapes.dtype.name) + ) + else: + shape_fut = None + + data_fut = self.data[new_max:current_data_size].write(np.zeros((), dtype=self.data.dtype.name)) + await f1 + + await shape_fut if shape_fut is not None else None + await data_fut + await offsets_fut + + def trim_to_size(self, size: int): + if size >= self.num_rows: + return + + old_len = len(self) + old_data_size = self.data_size + + if self.shapes is not None: + shape_fut = self.shapes[size:old_len].write(np.zeros(self.shapes.shape[1:], dtype=self.shapes.dtype.name)) + else: + shape_fut = None + + f1 = self.offsets[0].write(size) + + if size == 0: + new_max = 0 + else: + new_max = int(self.offsets[size].read().result()) + data_fut = self.data[new_max:old_data_size].write(np.zeros((), dtype=self.data.dtype.name)) + + f1.result() + offsets_fut = self.offsets[size + 1 : old_data_size + 1].write(0) + + data_fut.result() + offsets_fut.result() + + if shape_fut is not None: + shape_fut.result() + + async def extend_async(self, arrays: Sequence[jax.Array]): + data, new_offsets, shapes = self._prepare_batch(arrays) + + num_rows = await self.num_rows_async() + num_added = len(arrays) + current_data_size = self.data_size + + # Write to resized arrays concurrently, adjusting offsets explicitly + write_tasks = [ + self.data[current_data_size : current_data_size + len(data)].write(data), + self.offsets[num_rows + 1 : num_rows + num_added + 1].write(new_offsets), + ] + if self.shapes is not None: + write_tasks.append(self.shapes[num_rows : num_rows + num_added].write(shapes)) + + await asyncio.gather(*write_tasks) + + # Update num_rows + int(self.offsets[self.num_rows].read().result()) + await self.offsets[0].write(num_rows + len(arrays)) + # print("done") + + def extend(self, arrays: Sequence[jax.Array]): + data, new_offsets, shapes = self._prepare_batch(arrays) + + num_rows = self.num_rows + num_added = len(arrays) + current_data_size = self.data_size + + write_tasks = [ + self.data[current_data_size : current_data_size + len(data)].write(data), + self.offsets[num_rows + 1 : num_rows + num_added + 1].write(new_offsets), + ] + if self.shapes is not None: + write_tasks.append(self.shapes[num_rows : num_rows + num_added].write(shapes)) + + for task in write_tasks: + task.result() + + # Update num_rows. 
We want to make sure this comes after the other data is committed to avoid a race + self.offsets[0].write(num_rows + len(arrays)).result() + + def _prepare_batch(self, arrays): + if self.shapes is not None: + for data in arrays: + if data.ndim != self.item_rank: + raise ValueError(f"Expected data to have rank {self.item_rank}, got {data.ndim}") + shapes = np.array([data.shape[:-1] for data in arrays], dtype=np.int64) + else: + for data in arrays: + if data.ndim > 1: + raise ValueError(f"Expected data to have rank 1, got {data.ndim}") + shapes = None + new_offsets = np.array([data.size for data in arrays], dtype=np.int64) + new_offsets = np.cumsum(new_offsets) + self.data_size + data = np.concatenate([data.reshape(-1) for data in arrays]) + return data, new_offsets, shapes + + async def reload_async(self) -> "JaggedArrayStore": + """ + Calls `resolve` on the underlying tensorstore objects, updating size information + + @return: new JaggedArrayStore with resolved tensorstores + """ + offsets = ts.open(_unshaped_spec(self.offsets)) + data = ts.open(_unshaped_spec(self.data)) + shapes = future_from_value(None) if self.shapes is None else ts.open(_unshaped_spec(self.shapes.spec())) + + offsets, data, shapes = await asyncio.gather(offsets, data, shapes) + + return JaggedArrayStore(offsets, data, shapes, self.item_rank) + + def reload(self) -> "JaggedArrayStore": + offsets = ts.open(_unshaped_spec(self.offsets)) + data = ts.open(_unshaped_spec(self.data)) + shapes = None if self.shapes is None else ts.open(_unshaped_spec(self.shapes.spec())).result() + + offsets = offsets.result() + data = data.result() + + return JaggedArrayStore(offsets, data, shapes, self.item_rank) + + def __len__(self): + return self.num_rows + + async def get_item_async(self, item): + if isinstance(item, slice): + raise NotImplementedError("Slicing not supported") + len_self = await self.num_rows_async() + start, stop, step = item.indices(len_self) + if step != 1: + raise ValueError("JaggedArrayStore doesn't support slicing with step != 1") + shapes = None if self.shapes is None else self.shapes[start:stop] + # NB: JaggedArray not JaggedArrayStore + # TODO: use a transformed TS? 
+ data_start, data_stop, offsets = await self._bounds_for_rows_async(start, stop) + new_offsets = offsets - offsets[0] + return JaggedArray(new_offsets, await self.data[data_start:data_stop].read(), shapes) + else: + try: + start, stop, _ = await self._bounds_for_rows_async(item, item + 1) + data = await self.data[start:stop].read() + + if self.shapes is not None: + shapes = np.array(self.shapes[item]) + data = data.reshape(*shapes, -1) + return data + except ValueError as e: + # ts raises a value error for an index out of bounds OUT_OF_RANGE + if "OUT_OF_RANGE" in str(e): + raise IndexError(f"JaggedArrayStore index out of range: {item}") from e + else: + raise e + + async def get_batch(self, indices: Sequence[int]) -> Sequence[jax.Array]: + # get indices + with ts.Batch(): + all_indices_futs = [self._bounds_for_rows_async(indices[i], indices[i] + 1) for i in range(len(indices))] + + # shapes, if applicable + if self.shapes is not None: + with ts.Batch(): + shapes_futs = [self.shapes[i].read() for i in indices] + + all_indices = [(start, stop) for start, stop, _ in await asyncio.gather(*all_indices_futs)] + + # get data + with ts.Batch(): + data_futs = [self.data[start:stop].read() for start, stop in all_indices] + + data = await asyncio.gather(*data_futs) + + if self.shapes is not None: + shapes = await asyncio.gather(*shapes_futs) + + data = [d.reshape(*s, -1) for d, s in zip(data, shapes)] + + return data + + def get_batch_sync(self, indices: Sequence[int]) -> Sequence[jax.Array]: + all_indices = self._bounds_for_rows_batch(indices) + + with ts.Batch(): + # shapes, if applicable + if self.shapes is not None: + shapes_futs = [self.shapes[i].read() for i in indices] + + data_futs = [self.data[start:stop].read() for start, stop in all_indices] + + data = [d.result() for d in data_futs] + + if self.shapes is not None: + shapes = [s.result() for s in shapes_futs] # noqa + data = [d.reshape(*s, -1) for d, s in zip(data, shapes)] + + return data + + def __getitem__(self, item): + if isinstance(item, slice): + # raise NotImplementedError("Slicing not supported") + # # TODO: do we need to avoid reading len(self)? + # start, stop, step = item.indices(len(self)) + # if step != 1: + # raise ValueError("JaggedArrayStore doesn't support slicing with step != 1") + # shapes = None if self.shapes is None else self.shapes[start:stop] + # # NB: JaggedArray not JaggedArrayStore + # # TODO: use a transformed TS? 
+ # data_start, data_stop, offsets = self._bounds_for_rows(start, stop) + # new_offsets = offsets - offsets[0] + # return JaggedArray(new_offsets, self.data[data_start:data_stop].read().result(), shapes) + start, stop, step = item.indices(len(self)) + # for now, just read the data into a list + + return [self[i] for i in range(start, stop, step)] + else: + try: + start, stop, _ = self._bounds_for_rows(item, item + 1) + data = self.data[start:stop].read().result() + + if self.shapes is not None: + shapes = np.array(self.shapes[item]) + data = data.reshape(*shapes, -1) + return data + except ValueError as e: + # ts raises a value error for an index out of bounds OUT_OF_RANGE + if "OUT_OF_RANGE" in str(e): + raise IndexError(f"JaggedArrayStore index out of range: {item}") from e + else: + raise e + + def _bounds_for_rows(self, start, stop): + num_rows = self.num_rows + if start >= num_rows or stop > num_rows: + raise IndexError("Index out of bounds") + start, stop, step = slice(start, stop).indices(num_rows) + offsets = self.offsets[start : stop + 1].read().result() + data_start, data_stop = offsets[0], offsets[-1] + if start == 0: + # The first offset is the number of rows + data_start = 0 + offsets[0] = 0 + + return data_start, data_stop, offsets + + def _bounds_for_rows_batch(self, indices): + num_rows = self.num_rows + offsets_futs: list = [] + + zero_pos = None + + with ts.Batch(): + for index in indices: + if index >= num_rows or index < 0: + raise IndexError("Index out of bounds") + offsets = self.offsets[index : index + 2].read() + offsets_futs.append(offsets) + + if index == 0: + zero_pos = len(offsets_futs) - 1 + + offsets = [fut.result() for fut in offsets_futs] + offsets = [(offset[0], offset[-1]) for offset in offsets] + + if zero_pos is not None: + offsets[zero_pos] = [0, offsets[zero_pos][1]] + + return offsets + + async def _bounds_for_rows_async(self, start, stop): + offsets = await self.offsets[start : stop + 1].read() + data_start, data_stop = offsets[0], offsets[-1] + if start == 0: + # The first offset is the number of rows + data_start = 0 + offsets[0] = 0 + + return data_start, data_stop, offsets + + +def _unshaped_spec(store: ts.TensorStore) -> ts.Spec: + spec = store.spec(retain_context=True) + return spec + + +def _ts_open_sync(path: Optional[str], dtype: jnp.dtype, shape, *, mode): + spec = _get_spec(path, shape) + mode = _mode_to_open_mode(mode) + + # Basically, we want to load the existing shape metadata if it exists + if not mode.get("delete_existing", False): + try: + return ts.open(spec, **mode).result() + except FileNotFoundError: + pass + except ValueError: + pass + + # TODO: groups? + # TODO: set chunk sizes + try: + return ts.open( + spec, + dtype=jnp.dtype(dtype).name, + shape=[2**54, *shape[1:]], + # chunk_layout=ts.ChunkLayout( + # read_chunk_shape=[DEFAULT_CHUNK_SIZE, *shape[1:]], + # write_chunk_shape=[DEFAULT_WRITE_CHUNK_SIZE, *shape[1:]] + # ), + # compression={"codec": "zstd", "compression_level": 5}, + **mode, + ).result() + except ValueError as e: + if "NOT_FOUND" in str(e): + raise FileNotFoundError(f"File not found: {path}") from e + else: + raise e + + +async def _ts_open_async(path: Optional[str], dtype: jnp.dtype, shape, *, mode): + spec = _get_spec(path, shape) + mode = _mode_to_open_mode(mode) + + # Basically, we want to load the existing shape metadata if it exists + if not mode.get("delete_existing", False): + try: + return await ts.open(spec, **mode) + except FileNotFoundError: + pass + except ValueError: + pass + + # TODO: groups? 
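+    # NOTE: the leading dimension is opened with a huge nominal size (2**54) so the array
+    # can keep growing; the logical extent is tracked by the caller (e.g. the offsets
+    # array in JaggedArrayStore).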
+ # TODO: set chunk sizes + return await ts.open( + spec, + dtype=jnp.dtype(dtype).name, + shape=[2**54, *shape[1:]], + # chunk_layout=ts.ChunkLayout( + # read_chunk_shape=[DEFAULT_CHUNK_SIZE, *shape[1:]], + # write_chunk_shape=[DEFAULT_WRITE_CHUNK_SIZE, *shape[1:]] + # ), + # compression={"codec": "zstd", "compression_level": 5}, + **mode, + ) + + +def _get_spec(path, shape): + if path is None: + import uuid + + random_name = str(uuid.uuid4()) + spec = ts.Spec({"driver": "zarr", "kvstore": f"memory://{random_name}"}) + else: + # make path absolute if it's not already + protocol, _ = fsspec.core.split_protocol(path) + if protocol is None: + path = os.path.abspath(path) + spec = ser.get_tensorstore_spec(path, ocdbt=False) + store = spec.get("kvstore") + spec = {"driver": "zarr3", "kvstore": store} + fsspec_utils.mkdirs(os.path.dirname(path)) + spec["metadata"] = { + "chunk_grid": { + "name": "regular", + "configuration": {"chunk_shape": [DEFAULT_WRITE_CHUNK_SIZE, *shape[1:]]}, + }, + "codecs": [ + { + "name": "sharding_indexed", + "configuration": { + "chunk_shape": [DEFAULT_CHUNK_SIZE, *shape[1:]], + "codecs": [{"name": "blosc", "configuration": {"clevel": 5}}], + }, + } + ], + } + return spec + + +def _mode_to_open_mode(mode: str): + if mode == "r": + return {"open_mode": ts.OpenMode(open=True)} + elif mode == "w": + return {"open_mode": ts.OpenMode(create=True, delete_existing=True)} + elif mode == "a": + return {"open_mode": ts.OpenMode(create=True, open=True, delete_existing=False)} + else: + raise ValueError(f"Invalid mode: {mode}") + + +def _extend_path(path: Optional[str], extra: str): + if path == "memory" or path is None: + return path + else: + return os.path.join(path, extra) diff --git a/src/levanter/store/stress_test_new_cache.py b/src/levanter/store/stress_test_new_cache.py new file mode 100644 index 000000000..c583ede56 --- /dev/null +++ b/src/levanter/store/stress_test_new_cache.py @@ -0,0 +1,149 @@ +# Reads an old-style ShardCache and compares to +import asyncio +import logging +import os + +import jax.random +import numpy as np +import tensorstore as ts + +from levanter.data import PermutationDataset +from levanter.data.text import TokenSeqDataset +from levanter.store.cache import LEDGER_FILE_NAME, CacheLedger, TreeCache, _serialize_json_and_commit +from levanter.store.tree_store import TreeStore +from levanter.tracker import capture_time +from levanter.utils import fsspec_utils + + +logging.basicConfig(level=logging.INFO) + + +SEQ_LEN = 1024 +BS = 8 +BATCHES = 1000 + +# want to test reading from: +# 1) old cache sequentially +# 2) new cache sequentially +# 3) new cache randomly + + +def bench_new_cache_serial(exemplar, new_cache_path): + jagged_array = TreeStore.open(exemplar, new_cache_path).tree["input_ids"] + len_cache = jagged_array.data_size + new_cache = jagged_array.data + num_batches = len_cache // SEQ_LEN + for b in range(BATCHES): + elems = [] + with ts.Batch(): + for j in range(BS): + idx = b * BS + j + idx = idx % num_batches + arr1 = new_cache[idx * SEQ_LEN : (idx + 1) * SEQ_LEN].read() + elems.append(arr1) + + for elem in elems: + elem.result() + + +def bench_new_cache_random(exemplar, new_cache_path): + jagged_array = TreeStore.open(exemplar, new_cache_path).tree["input_ids"] + len_cache = jagged_array.data_size + new_cache = jagged_array.data + num_batches = len_cache // SEQ_LEN + for b in range(BATCHES): + elems = [] + with ts.Batch(): + for j in range(BS): + idx = np.random.randint(0, num_batches) + arr1 = new_cache[idx * SEQ_LEN : (idx + 1) * 
SEQ_LEN].read() + elems.append(arr1) + + for elem in elems: + elem.result() + + +async def bench_new_cache_serial_tokenseq(exemplar, new_cache_path): + ensure_cache(new_cache_path) + cache = TreeCache.load(new_cache_path, exemplar) + + ds = TokenSeqDataset(cache, SEQ_LEN) + + num_batches = await ds.async_len() + + for b in range(BATCHES): + indices = [] + for j in range(BS): + idx = b * BS + j + idx = idx % num_batches + indices.append(idx) + elems = await ds.get_batch(indices) + del elems + + +async def bench_new_cache_permutation_random(exemplar, new_cache_path): + ensure_cache(new_cache_path) + cache = TreeCache.load(new_cache_path, exemplar) + + ds = TokenSeqDataset(cache, SEQ_LEN) + ds = PermutationDataset(ds, jax.random.PRNGKey(0)) + + num_batches = await ds.async_len() + + for b in range(BATCHES): + indices = [] + for j in range(BS): + idx = b * BS + j + idx = idx % num_batches + indices.append(idx) + elems = await ds.get_batch(indices) + del elems + + +def ensure_cache(new_cache_path): + if not fsspec_utils.exists(os.path.join(new_cache_path, LEDGER_FILE_NAME)): + ledger = CacheLedger(100000, {}, True) + _serialize_json_and_commit(os.path.join(new_cache_path, LEDGER_FILE_NAME), ledger) + + +if __name__ == "__main__": + import sys + + if not len(sys.argv) == 3: + print("Usage: convert_to_new_cache.py old_cache_path new_cache_path") + sys.exit(1) + + for split in ["validation", "train"]: + print(f"Split: {split}", flush=True) + in_path = os.path.join(sys.argv[1], split) + out_path = os.path.join(sys.argv[2], split) + # convert_to_new_cache(in_path, out_path) + # with capture_time() as time_fn: + # bench_old_cache(in_path) + # tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn() + # print(f"Old Cache: {time_fn()} ({tokens_per_second} tps)", flush=True) + + exemplar = {"input_ids": np.zeros((SEQ_LEN,), dtype=np.int32)} + + with capture_time() as time_fn: + bench_new_cache_serial(exemplar, out_path) + tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn() + print(f"New Cache Serial: {time_fn()} ({tokens_per_second} tps)", flush=True) + + with capture_time() as time_fn: + asyncio.run(bench_new_cache_serial_tokenseq(exemplar, out_path)) + tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn() + + print(f"New Cache Serial TokenSeq: {time_fn()} ({tokens_per_second} tps)", flush=True) + + with capture_time() as time_fn: + bench_new_cache_random(exemplar, out_path) + tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn() + + print(f"New Cache Random: {time_fn()} ({tokens_per_second} tps)", flush=True) + + with capture_time() as time_fn: + asyncio.run(bench_new_cache_permutation_random(exemplar, out_path)) + tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn() + + print(f"New Cache Permutation: {time_fn()} ({tokens_per_second} tps)", flush=True) diff --git a/src/levanter/store/tree_store.py b/src/levanter/store/tree_store.py new file mode 100644 index 000000000..0b1e93bff --- /dev/null +++ b/src/levanter/store/tree_store.py @@ -0,0 +1,237 @@ +import asyncio +import os +from typing import Generic, List, TypeVar + +import jax +import jax.numpy as jnp +import jax.tree_util as jtu +import numpy as np +from jaxtyping import PyTree + +from haliax.jax_utils import is_jax_array_like + +from .jagged_array import JaggedArrayStore + + +T = TypeVar("T", bound=PyTree) + + +# TODO at some point if we turn this into a real library, it would be nice to store the schema +# TODO: some data is probably best not stored as a jagged array, but as a flat array? 
+# TODO: also sometimes we might want a rowstore actually + + +def heuristic_is_leaf(x): + if isinstance(x, list): + return jnp.isscalar(x[0]) + else: + return False + + +def heuristic_is_leaf_batched(x): + if isinstance(x, list): + return jnp.isscalar(x[0]) or is_jax_array_like(x[0]) + else: + return False + + +class TreeStore(Generic[T]): + """ + A TreeStoreBuilder stores batched data as a tree of ragged arrays. + """ + + path: str + mode: str + tree: PyTree[JaggedArrayStore] + + def __init__(self, tree, path: str, mode: str): + self.path = path + self.mode = mode + self.tree = tree + + @staticmethod + def open(exemplar: T, path: str, *, mode="a") -> "TreeStore": + """ + Open a TreeStoreBuilder from a file. + """ + tree = _construct_builder_tree(exemplar, path, mode) + return TreeStore(tree, path, mode) + + def append(self, ex: T): + return self.extend([ex]) + + def extend(self, batch: List[T]): + """ + Append a batch of data to the store. + """ + # TODO: I do wish zarr supported async + jtu.tree_map( + lambda writer, *xs: writer.extend([np.asarray(x) for x in xs]), + self.tree, + *batch, + is_leaf=heuristic_is_leaf, + ) + + def extend_with_batch(self, batch: T): + """ + Append a batch of data (as a pytree with batched leaves) to the store. + + This method works only when the "leaves" are lists of numpy arrays or scalars. + For instance, HF's BatchEncoding is a dict of lists of numpy arrays. + """ + jtu.tree_map( + lambda writer, xs: writer.extend([np.asarray(x) for x in xs]), + self.tree, + batch, + is_leaf=heuristic_is_leaf_batched, + ) + + async def extend_with_batch_async(self, batch: T): + """ + Append a batch of data (as a pytree with batched leaves) to the store. + + This method works only when the "leaves" are lists of numpy arrays or scalars. + For instance, HF's BatchEncoding is a dict of lists of numpy arrays. + """ + futures = jtu.tree_map( + lambda writer, xs: writer.extend_async([np.asarray(x) for x in xs]), + self.tree, + batch, + is_leaf=heuristic_is_leaf_batched, + ) + + await asyncio.gather(*jax.tree_leaves(futures)) + + def trim_to_size(self, size: int): + """ + Trim the store to a given size. + """ + # TODO These all return ts Futures so in theory we could await them all at once + jtu.tree_map(lambda writer: writer.trim_to_size(size), self.tree, is_leaf=heuristic_is_leaf) + + async def trim_to_size_async(self, size: int): + """ + Trim the store to a given size. + """ + futures = jtu.tree_map(lambda writer: writer.trim_to_size_async(size), self.tree, is_leaf=heuristic_is_leaf) + leaves, structure = jax.tree_flatten(futures) + + await asyncio.gather(*leaves) + + def reload(self) -> "TreeStore": + """ + Close the builder and return a TreeStore. 
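+
+        Like `reload_async`, this re-opens the underlying TensorStores so that size
+        information written since the store was opened becomes visible.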
+ """ + tree = jtu.tree_map(lambda builder: builder.reload(), self.tree, is_leaf=heuristic_is_leaf) + return TreeStore(tree, self.path, self.mode) + + def __len__(self): + if self.tree is None: + return 0 + else: + return len(jax.tree.leaves(self.tree)[0]) + + async def get_batch(self, indices) -> List[T]: + grouped = jtu.tree_map(lambda reader: reader.get_batch(indices), self.tree, is_leaf=heuristic_is_leaf) + + leaves, structure = jtu.tree_flatten(grouped, is_leaf=heuristic_is_leaf) + + awaited_leaves = await asyncio.gather(*leaves) + return [jtu.tree_unflatten(structure, [leaf[i] for leaf in awaited_leaves]) for i in range(len(indices))] + + def __getitem__(self, item): + if self.tree is None: + raise IndexError("No data in store") + elif isinstance(item, slice): + # debatch + leaves, structure = jax.tree.flatten(self.tree, is_leaf=heuristic_is_leaf) + # batched_items = jtu.tree_map(lambda reader: reader[item], self.tree, is_leaf=heuristic_is_leaf) + batched_item_leaves = [leaf[item] for leaf in leaves] + num_items = len(leaves[0]) + return [jtu.tree_unflatten(structure, [leaf[i] for leaf in batched_item_leaves]) for i in range(num_items)] + else: + return jtu.tree_map(lambda reader: reader[item], self.tree, is_leaf=heuristic_is_leaf) + + def __iter__(self): + if self.tree is None: + return + else: + for i in range(len(self)): + yield self[i] + + def get_batch_sync(self, indices) -> List[T]: + # TODO: would be better to batch these up + grouped = jtu.tree_map(lambda reader: reader.get_batch_sync(indices), self.tree, is_leaf=heuristic_is_leaf) + + out = [jtu.tree_map(lambda _, leaf: leaf[i], self.tree, grouped) for i in range(len(indices))] + + return out + + +def _construct_builder_tree(exemplar, path, mode): + def open_builder(tree_path, item): + item = np.asarray(item) + rank = item.ndim + render_tree_path = "/".join(_render_path_elem(x) for x in tree_path) + return JaggedArrayStore.open(os.path.join(path, render_tree_path), mode=mode, item_rank=rank, dtype=item.dtype) + + return jtu.tree_map_with_path(open_builder, exemplar, is_leaf=heuristic_is_leaf) + + +def _render_path_elem(x): + match x: + case jtu.DictKey(key): + return f"{key}" + case jtu.GetAttrKey(key): + return f"{key}" + case jtu.SequenceKey(i): + return f"{i}" + case jtu.FlattenedIndexKey(i): + return f"{i}" + case _: + return str(x) + + +# class TokenSeqDataset: +# """ +# A dataset of sequences of tokens of fixed length, materialized from a collection of JaggedArrayStores, +# which have typically much longer sequences. This class takes consecutive sequences of tokens from the builders +# and slices/concats them to form the dataset. +# """ +# +# def __init__( +# self, token_arrays: Sequence[JaggedArrayStore], token_counts: Sequence[int], seq_len: int, pad_token: int +# ): +# self.token_arrays = token_arrays +# +# def _round_to_nearest_multiple(x, y): +# return x + y - x % y +# +# token_counts_padded = np.array([_round_to_nearest_multiple(x, seq_len) for x in token_counts]) +# seq_counts = token_counts_padded // seq_len +# self.seq_counts_cumsum = np.concatenate([np.asarray([0]), np.cumsum(seq_counts)]) +# +# self.seq_len = seq_len +# self.pad_token = pad_token +# +# def __len__(self): +# return self.seq_counts_cumsum[-1] +# +# def __getitem__(self, seq_id): +# return asyncio.run(self.get_item_async(seq_id)) +# +# async def get_item_async(self, seq_id): +# # TODO: accept slices and such? 
+# shard_id = np.searchsorted(self.seq_counts_cumsum, seq_id, side="right") - 1 +# shard_start = self.seq_counts_cumsum[shard_id] +# shard_end = self.seq_counts_cumsum[shard_id + 1] +# shard_seq_id = seq_id - shard_start +# +# shard_seq_start = shard_seq_id * self.seq_len +# shard_seq_end = min((shard_seq_id + 1) * self.seq_len, self.token_arrays[shard_id].data_size) +# +# shard_seq = await self.token_arrays[shard_id].data[shard_seq_start:shard_seq_end].read() +# pad_len = self.seq_len - (shard_seq_end - shard_seq_start) +# padded_seq = np.concatenate([shard_seq, np.full(pad_len, self.pad_token, dtype=shard_seq.dtype)]) +# +# return padded_seq diff --git a/src/levanter/tracker/wandb.py b/src/levanter/tracker/wandb.py index 1b0254261..1e95c0d3a 100644 --- a/src/levanter/tracker/wandb.py +++ b/src/levanter/tracker/wandb.py @@ -155,7 +155,7 @@ def init(self, run_id: Optional[str]) -> WandbTracker: if jax.process_count() > 1: # we need to share wandb run information across all hosts, because we use it for checkpoint paths and things metadata_to_share = dict( - entity=r.entity, + # entity=r.entity, project=r.project, name=r.name, tags=r.tags, @@ -166,10 +166,10 @@ def init(self, run_id: Optional[str]) -> WandbTracker: metadata_to_share, is_source=jax.process_index() == 0 ) - if jax.process_index() != 0: - assert r.mode == "disabled" - for k, v in metadata_to_share.items(): - setattr(r, k, v) + # if jax.process_index() != 0: + # assert r.mode == "disabled", f"Only the primary worker should be using wandb. Got {r.mode}" + # for k, v in metadata_to_share.items(): + # setattr(r, k, v) logger.info(f"Synced wandb run information from process 0: {r.name} {r.id}") diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index ef870382b..69c932cd9 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -48,7 +48,7 @@ from levanter import tracker from levanter.checkpoint import CheckpointerConfig, load_checkpoint_or_initialize from levanter.config import JsonAtom -from levanter.data import Dataset, ReplicatedBatchLoader, ShardableDataset, ShardedBatchLoader +from levanter.data import AsyncDataset, DataLoader from levanter.distributed import DistributedConfig, RayConfig from levanter.grad_accum import microbatched from levanter.tracker import TrackerConfig, capture_time @@ -433,7 +433,7 @@ def _add_default_hooks(self): def add_eval_hook(self, eval_dataset, name: Optional[str] = None): from levanter import callbacks - eval_loader = self.replicated_loader(eval_dataset, self.EvalBatch) + eval_loader = self.data_loader(eval_dataset, self.EvalBatch) if eval_loader and (self.config.max_eval_batches is None or self.config.max_eval_batches > 0): @@ -450,31 +450,24 @@ def eval_loss(model, *batch, **batch_kwargs): every=self.config.steps_per_eval, ) - def replicated_loader(self, dataset: Dataset[X], batch_axis: Axis) -> ReplicatedBatchLoader[X]: - """Creates a replicated batch loader for the given dataset. Generally you should use this - if you either be able to make a single pass over the dataset. + def data_loader(self, dataset: AsyncDataset[X], batch_axis: Axis) -> DataLoader[X]: + """Creates a data loader for the given dataset and batch axis. 
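+
+        For example (a sketch mirroring `add_eval_hook`):
+
+            loader = self.data_loader(eval_dataset, self.EvalBatch)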
Args: - dataset (Dataset): the dataset to load + dataset (AsyncDataset): the dataset to load batch_axis (Axis): the batch axis Returns: - ReplicatedBatchLoader: the batch loader + DataLoader: the data loader """ - return ReplicatedBatchLoader(dataset, self.device_mesh, batch_axis, self.compute_axis_mapping) - - def sharded_loader(self, dataset: ShardableDataset[X], batch_axis: Axis) -> ShardedBatchLoader[X]: - """Creates a sharded batch loader for the given dataset. Generally you should use this - for training and you don't care about epoch boundaries. - - Args: - dataset (Dataset): the dataset to load - batch_axis (Axis): the batch axis - - Returns: - ShardedBatchLoader: the batch loader - """ - return ShardedBatchLoader(dataset, self.device_mesh, batch_axis, self.compute_axis_mapping) + return DataLoader( + batch_axis, + dataset, + max_buffered_batches=128, + mesh=self.device_mesh, + axis_resources=self.compute_axis_mapping, + prefetch_size=32, + ) @cached_property def _jit_train_step_fn(self): diff --git a/src/levanter/utils/background_iterable.py b/src/levanter/utils/background_iterable.py index 6bb200873..84c5a7789 100644 --- a/src/levanter/utils/background_iterable.py +++ b/src/levanter/utils/background_iterable.py @@ -1,7 +1,8 @@ +import asyncio import queue import sys import threading -from typing import Callable, Iterable, Iterator, Optional, TypeVar +from typing import AsyncIterator, Callable, Iterable, Iterator, Optional, TypeVar, Union import tblib @@ -18,27 +19,41 @@ class BackgroundIterable(Iterable[Ex]): like running XLA kernels... """ - def __init__(self, producer_fn: Callable[[], Iterator[Ex]], max_capacity: Optional[int] = None): + def __init__( + self, + producer_fn: Callable[[], Union[Iterator[Ex], AsyncIterator[Ex]]], + max_capacity: Optional[int] = None, + ): self.max_capacity = max_capacity - self._stop_event = threading.Event() self._producer_fn = producer_fn def __iter__(self): - if self._stop_event.is_set(): - raise RuntimeError("Cannot iterate over a stopped BackgroundIterable") + return BackgroundIterator(self._producer_fn, self.max_capacity) + - q = queue.Queue(self.max_capacity) - thread = threading.Thread(target=self._fill_queue_with_batches, args=(q,)) - thread.daemon = True - thread.start() +class BackgroundIterator(Iterator[Ex]): + def __init__(self, producer_fn: Callable[[], Union[Iterator[Ex], AsyncIterator[Ex]]], max_capacity: Optional[int]): + self.max_capacity = max_capacity + self._producer_fn = producer_fn + self._stop_event = threading.Event() + self.q: queue.Queue = queue.Queue(self.max_capacity or 0) + self.thread = threading.Thread(target=self._fill_queue_with_batches) + self.thread.daemon = True + self.thread.start() + + def __iter__(self): + return self + def __next__(self): while not self._stop_event.is_set(): - batch = q.get() + batch = self.q.get() if batch is _SENTINEL: - break + raise StopIteration elif isinstance(batch, _ExceptionWrapper): batch.reraise() - yield batch + return batch + + raise StopIteration def __del__(self): self.stop() @@ -46,13 +61,44 @@ def __del__(self): def stop(self): self._stop_event.set() - def _fill_queue_with_batches(self, q): + def _fill_queue_with_batches(self): + try: + iterator = self._producer_fn() + if isinstance(iterator, Iterator): + self._produce_batches_sync(iterator) + else: + asyncio.run(self._produce_batches_async(iterator)) + except Exception: + self.q.put(_ExceptionWrapper(sys.exc_info())) + + def _produce_batches_sync(self, iterator): + try: + for batch in iterator: + while not 
self._stop_event.is_set(): + try: + self.q.put(batch, block=True, timeout=1) + break + except queue.Full: + pass + + if self._stop_event.is_set(): + break + + while not self._stop_event.is_set(): + try: + self.q.put(_SENTINEL, block=True, timeout=1) + break + except queue.Full: + pass + except Exception: + self.q.put(_ExceptionWrapper(sys.exc_info())) + + async def _produce_batches_async(self, iterator): try: - for batch in self._producer_fn(): - # we don't want to block forever because then we can't stop the thread + async for batch in iterator: while not self._stop_event.is_set(): try: - q.put(batch, block=True, timeout=1) + self.q.put(batch, block=True, timeout=1) break except queue.Full: pass @@ -62,13 +108,12 @@ def _fill_queue_with_batches(self, q): while not self._stop_event.is_set(): try: - q.put(_SENTINEL, block=True, timeout=1) + self.q.put(_SENTINEL, block=True, timeout=1) break except queue.Full: - # don't hold up the thread if we can't put the sentinel pass - except Exception: # flake8: noqa - q.put(_ExceptionWrapper(sys.exc_info())) + except Exception: + self.q.put(_ExceptionWrapper(sys.exc_info())) class _Sentinel: diff --git a/src/levanter/utils/fsspec_utils.py b/src/levanter/utils/fsspec_utils.py index c6adeb3e4..896ea8450 100644 --- a/src/levanter/utils/fsspec_utils.py +++ b/src/levanter/utils/fsspec_utils.py @@ -5,3 +5,9 @@ def exists(url, **kwargs) -> bool: """Check if a file exists on a remote filesystem.""" fs, path = fsspec.core.url_to_fs(url, **kwargs) return fs.exists(path) + + +def mkdirs(path): + """Create a directory and any necessary parent directories.""" + fs, path = fsspec.core.url_to_fs(path) + fs.makedirs(path, exist_ok=True) diff --git a/src/levanter/utils/index.py b/src/levanter/utils/index.py new file mode 100644 index 000000000..3e94ab9fb --- /dev/null +++ b/src/levanter/utils/index.py @@ -0,0 +1,46 @@ +from typing import Generic, Iterable, Iterator, TypeVar + + +T = TypeVar("T") + + +class Index(Generic[T]): + """ + Index is a bidirectional mapping from (incremental) integers to objects. + + Needs to be fast, so it exposes the underlying data structures. 
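+
+    A small usage sketch:
+
+        index = Index(["a", "b"])
+        assert index.append("c") == 2
+        assert index.get_index("b") == 1
+        assert index[2] == "c"
+        assert "a" in index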
+ """ + + def __init__(self, objs: Iterable[T] = ()): + self._index_to_obj: list[T] = [] + self._obj_to_index: dict[T, int] = {} + for obj in objs: + self.append(obj) + + def __len__(self): + return len(self._index_to_obj) + + def __getitem__(self, index: int) -> T: + return self._index_to_obj[index] + + def __setitem__(self, index: int, obj: T): + self._index_to_obj[index] = obj + self._obj_to_index[obj] = index + + def append(self, obj: T) -> int: + index = len(self) + self._index_to_obj.append(obj) + self._obj_to_index[obj] = index + return index + + def get_index(self, obj: T) -> int: + return self._obj_to_index[obj] + + def get_obj(self, index: int) -> T: + return self._index_to_obj[index] + + def __contains__(self, obj: T) -> bool: + return obj in self._obj_to_index + + def __iter__(self) -> Iterator[T]: + return iter(self._index_to_obj) diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index d159d7948..1d7205365 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -41,7 +41,9 @@ def use_cpu_device(): def local_cpu_mesh(): """Temporarily sets the default device to CPU""" cpu = jax.local_devices(backend="cpu")[0] - mesh = jax.sharding.Mesh(np.array([cpu]).reshape(1, 1), ("data", "model")) + mesh = jax.sharding.Mesh( + np.array([cpu]).reshape(1, 1, 1), (ResourceAxis.REPLICA, ResourceAxis.DATA, ResourceAxis.MODEL) + ) with use_cpu_device(), mesh: yield mesh diff --git a/src/levanter/utils/py_utils.py b/src/levanter/utils/py_utils.py index 5262aa75d..a796dd6af 100644 --- a/src/levanter/utils/py_utils.py +++ b/src/levanter/utils/py_utils.py @@ -1,4 +1,3 @@ -import asyncio import os import sys from dataclasses import dataclass @@ -182,9 +181,3 @@ def actual_sizeof(obj): need_to_see.extend(obj) objects = need_to_see return size - - -def future_from_value(value): - future = asyncio.Future() - future.set_result(value) - return future diff --git a/src/levanter/utils/ray_utils.py b/src/levanter/utils/ray_utils.py index 255968815..8a299720e 100644 --- a/src/levanter/utils/ray_utils.py +++ b/src/levanter/utils/ray_utils.py @@ -85,9 +85,11 @@ def _child_failed(self, child: ray.actor.ActorHandle, exception: ExceptionInfo): @contextlib.contextmanager -def log_failures_to(parent): +def log_failures_to(parent, suppress=False): # parent is actorref of SnitchRecipient try: yield except Exception as e: parent._child_failed.remote(current_actor_handle(), ser_exc_info(e)) + if not suppress: + raise e diff --git a/src/levanter/utils/thread_utils.py b/src/levanter/utils/thread_utils.py new file mode 100644 index 000000000..9c6e2ef36 --- /dev/null +++ b/src/levanter/utils/thread_utils.py @@ -0,0 +1,28 @@ +import asyncio +from concurrent.futures import ThreadPoolExecutor + + +# Create a ThreadPoolExecutor +_executor = ThreadPoolExecutor(max_workers=10) + + +def blocking_wait(coro): + """ + This will only work if there are fewer than 10 levels of nested coroutines... 
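+
+    A minimal sketch of the intended use (the coroutine here is illustrative):
+
+        async def answer():
+            return 42
+
+        value = blocking_wait(answer())  # 42, whether or not an event loop is already running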
+ """ + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None and loop.is_running(): + future = _executor.submit(lambda: asyncio.run(coro)) + return future.result() + else: + return asyncio.run(coro) + + +def future_from_value(value): + future = asyncio.Future() + future.set_result(value) + return future diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_audio.py b/tests/test_audio.py index c9ae0d494..8d3015431 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -1,7 +1,11 @@ +import tempfile + +import pytest from datasets import load_dataset -from transformers import AutoProcessor +from transformers import AutoProcessor, AutoTokenizer from levanter.data.audio import AudioDatasetSourceConfig, AudioIODatasetConfig, BatchAudioProcessor +from levanter.store.cache import SerialCacheWriter from test_utils import skip_if_hf_model_not_accessible, skip_if_no_soundlibs @@ -9,8 +13,9 @@ @skip_if_hf_model_not_accessible("openai/whisper-tiny") def test_whisper_batch_processor(): processor = AutoProcessor.from_pretrained("openai/whisper-tiny") + tokenizer = AutoTokenizer.from_pretrained("openai/whisper-tiny") ds = load_dataset("WillHeld/test_librispeech_parquet", split="validation").select_columns(["audio", "text"]) - batch_processor = BatchAudioProcessor(processor) + batch_processor = BatchAudioProcessor(processor, tokenizer) inputs = [(audio["array"], audio["sampling_rate"], text) for audio, text in zip(ds[:16]["audio"], ds[:16]["text"])] batch_processor(inputs) @@ -37,12 +42,41 @@ def test_hf_audio_loading_source(): @skip_if_no_soundlibs @skip_if_hf_model_not_accessible("openai/whisper-tiny") -def test_hf_audio_ray_pipeline(): +@pytest.mark.asyncio +async def test_hf_audio_ray_pipeline(): + # Use the Real Librispeech Valudation. Testing one doesn't support streaming. + with tempfile.TemporaryDirectory() as tmpdir: + ac = AudioIODatasetConfig( + cache_dir=str(tmpdir), id="WillHeld/test_librispeech_parquet", text_key="text", max_length=1024 + ) + validation = ac.validation_set() + for i in range(10): + t = (await validation.get_batch([i]))[0] + assert t["input_features"].shape == (80, 3000), t["input_features"].shape + assert t["input_ids"].shape == (1024,), t["input_ids"].shape + assert t["attention_mask"].shape == (1024,), t["attention_mask"].shape + + +@skip_if_no_soundlibs +@skip_if_hf_model_not_accessible("openai/whisper-tiny") +def test_hf_audio_serial_cache(): # Use the Real Librispeech Valudation. Testing one doesn't support streaming. 
ac = AudioIODatasetConfig(id="WillHeld/test_librispeech_parquet", text_key="text") - audio_iterator = iter(ac.validation_set(batch_size=10)) - for i in range(10): - t = next(audio_iterator) - assert t["input_features"].shape == (80, 3000), t["input_features"].shape - assert t["input_ids"].shape == (1024,), t["input_ids"].shape - assert t["attention_mask"].shape == (1024,), t["attention_mask"].shape + + processor = AutoProcessor.from_pretrained("openai/whisper-tiny") + tokenizer = AutoTokenizer.from_pretrained("openai/whisper-tiny") + batch_processor = BatchAudioProcessor(processor, tokenizer, max_length=1024) + + with tempfile.TemporaryDirectory() as tmpdir: + with SerialCacheWriter(tmpdir, batch_processor.output_exemplar) as writer: + for i, ex in enumerate(ac.get_shard_source("validation")): + writer.write_batch(batch_processor([ex])) + if i > 10: + break + + cache = writer.result() + + for ex in cache.get_batch_sync(list(range(10))): + assert ex["input_features"].shape == (80, 3000), ex["input_features"].shape + assert ex["input_ids"].shape == (1024,), ex["input_ids"].shape + assert ex["attention_mask"].shape == (1024,), ex["attention_mask"].shape diff --git a/tests/test_background_iterable.py b/tests/test_background_iterable.py index ad768288c..0da8d6ea6 100644 --- a/tests/test_background_iterable.py +++ b/tests/test_background_iterable.py @@ -1,3 +1,5 @@ +import asyncio + import pytest from levanter.utils.background_iterable import BackgroundIterable @@ -55,10 +57,76 @@ def ongoing_process(): for _ in range(5): next(iter1) - background_iterable.stop() + iter1.stop() # Try to get another item from the iterator (should raise StopIteration) # there's a bit of a race so we give it 2 tries, which is enough for the test with pytest.raises(StopIteration): next(iter1) next(iter1) + + +@pytest.mark.asyncio +async def test_async_reentrancy(): + async def async_producer(): + for i in range(1, 101): + yield i + await asyncio.sleep(0.01) + + background_iterable = BackgroundIterable(async_producer, max_capacity=10) + + iter1 = iter(background_iterable) + iter2 = iter(background_iterable) + + data1 = [item for item in iter1] + data2 = [item for item in iter2] + + assert data1 == data2 + assert data1 == list(range(1, 101)) + + +@pytest.mark.asyncio +async def test_async_empty_iteration(): + async def async_producer(): + if False: + yield + + background_iterable = BackgroundIterable(async_producer, max_capacity=10) + + data = list(background_iterable) + + assert data == [] + + +@pytest.mark.asyncio +async def test_async_exception_handling(): + async def async_producer_with_exception(): + raise ValueError("Something went wrong!") + yield 0 # have to make sure it's an async coroutine + + background_iterable = BackgroundIterable(async_producer_with_exception, max_capacity=10) + + with pytest.raises(ValueError): + for _ in background_iterable: + pass + + +@pytest.mark.asyncio +async def test_async_stop_event(): + async def ongoing_async_process(): + while True: + for item in range(1, 101): + yield item + + background_iterable = BackgroundIterable(ongoing_async_process, max_capacity=10) + + iter1 = iter(background_iterable) + + for _ in range(5): + next(iter1) + + iter1.stop() + + with pytest.raises(StopIteration): + await next(iter1) + await next(iter1) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 306bec9cd..f5ce0f774 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -1,7 +1,6 @@ import dataclasses import datetime import pathlib -import sys import tempfile 
from datetime import timedelta @@ -281,12 +280,6 @@ def init_fn(key): jax.tree_util.tree_leaves(arrays_only(loaded2)), ) - print(jax.tree_util.tree_leaves(loaded), file=sys.stderr) - print("M1", file=sys.stderr) - print(jax.tree_util.tree_leaves(model1), file=sys.stderr) - print("M0", file=sys.stderr) - print(jax.tree_util.tree_leaves(model0), file=sys.stderr) - assert_trees_all_equal( jax.tree_util.tree_leaves(arrays_only(eqx.filter(loaded, is_checkpointed))), jax.tree_util.tree_leaves(arrays_only(eqx.filter(model0, is_checkpointed))), diff --git a/tests/test_data_mixture.py b/tests/test_data_mixture.py deleted file mode 100644 index 2410a7d5f..000000000 --- a/tests/test_data_mixture.py +++ /dev/null @@ -1,126 +0,0 @@ -import tempfile - -import tiny_test_corpus -from levanter.data import Dataset -from levanter.data.mixture import MixtureDataset, StopStrategy -from levanter.data.text import TokenSeqDataset - - -class ListDataset(Dataset[list]): - def __init__(self, data: list): - self.data = data - - def __iter__(self): - return iter(self.data) - - -def test_stop_strategies(): - seq_len = 10 - - num_docs_1, num_docs_2 = 10, 20 - with tempfile.TemporaryDirectory() as tmpdir: - # source_1 = SingleShardDocumentSource(docs_1) - data_config, _ = tiny_test_corpus.construct_small_data_cache( - f"{tmpdir}/cache_1", num_shards=1, chunk_size=num_docs_1, doc_len=seq_len - ) - - data_config, _ = tiny_test_corpus.construct_small_data_cache( - f"{tmpdir}/cache_2", num_shards=1, chunk_size=num_docs_2, doc_len=seq_len - ) - - ds1 = TokenSeqDataset.load(seq_len, f"{tmpdir}/cache_1/cache/train") - ds2 = TokenSeqDataset.load(seq_len, f"{tmpdir}/cache_2/cache/train") - - # set reuseable config - datasets = {"1": ds1, "2": ds2} - # test mixture with all weights on one dataset - mixture_1_only = MixtureDataset( - datasets=datasets, - weights={"1": 1.0, "2": 0.0}, - stop_strategy=StopStrategy.FIRST_STOP_STRATEGY, - key=0, - ) - counter = 0 - for batch in mixture_1_only: - assert batch.shape == (seq_len,) - counter += 1 - assert counter == 10 - - # compare mixture with different strategies - mixture_balanced_first = MixtureDataset( - datasets=datasets, - weights={"1": 0.5, "2": 0.5}, - stop_strategy=StopStrategy.FIRST_STOP_STRATEGY, - key=0, - ) - counter_first = sum([1 for _ in mixture_balanced_first]) - - mixture_balanced_all = MixtureDataset( - datasets=datasets, - weights={"1": 0.5, "2": 0.5}, - stop_strategy=StopStrategy.ALL_STOP_STRATEGY, - key=0, - ) - counter_all = sum([1 for _ in mixture_balanced_all]) - assert counter_first < counter_all - - # test normalized weights - mixture_normalized = MixtureDataset( - datasets=datasets, - weights={"1": 2.0, "2": 2.0}, - stop_strategy=StopStrategy.FIRST_STOP_STRATEGY, - key=0, - ) - assert mixture_normalized.weights["1"] == mixture_normalized.weights["2"] == 0.5 - - -def test_restart_strategy_gets_the_right_average(): - - num_docs_1, num_docs_2 = 10, 20 - ds1 = ListDataset([1 for _ in range(num_docs_1)]) - ds2 = ListDataset([2 for _ in range(num_docs_2)]) - - datasets = {"1": ds1, "2": ds2} - mixture_balanced_restart = MixtureDataset( - datasets=datasets, # type: ignore - weights={"1": 0.6, "2": 0.4}, - stop_strategy=StopStrategy.RESTART_STRATEGY, - key=0, - ) - - # ensure we get the right long run average - NUM_SAMPLES = 2300 - - # variance of a bernoulli distribution is p(1-p) ≈ 0.24 - # to get a 95% confidence interval of 0.02, we need ~2300 samples - - # we expect to get roughly 60% 1s and 40% 2s - num_ones = 0 - for i, ex in 
enumerate(mixture_balanced_restart): - if ex == 1: - num_ones += 1 - if i >= NUM_SAMPLES: - break - - assert 0.58 < num_ones / NUM_SAMPLES < 0.62 - - # now just to verify, stop_first won't give us the same average - - num_total = 0 - num_ones = 0 - - mixture_balanced_first = MixtureDataset( - datasets=datasets, # type: ignore - weights={"1": 0.6, "2": 0.4}, - stop_strategy=StopStrategy.FIRST_STOP_STRATEGY, - key=0, - ) - - for i, ex in enumerate(mixture_balanced_first): - if ex == 1: - num_ones += 1 - num_total += 1 - - assert num_total < 30 - assert num_ones == num_docs_1 - assert num_ones / num_total < 0.55 or num_ones / num_total > 0.65 diff --git a/tests/test_doremi.py b/tests/test_doremi.py index 8f10139b0..8600c9c8b 100644 --- a/tests/test_doremi.py +++ b/tests/test_doremi.py @@ -1,6 +1,10 @@ +import functools +from typing import Optional, Sequence + import equinox import jax import jax.random +import numpy as np import optax import pytest @@ -8,7 +12,7 @@ import levanter.tracker from levanter.callbacks import eval_loss_loop -from levanter.data.dataset import ShardableDataset +from levanter.data import AsyncDataset from levanter.data.mixture import MixtureDataset from levanter.trainer import Trainer, TrainerConfig from levanter.utils.jax_utils import key_iterator @@ -23,7 +27,7 @@ class Example(equinox.Module): Block = hax.Axis("Block", 1024) -class LogitDataset(ShardableDataset[Example]): +class LogitDataset(AsyncDataset[Example]): def __init__(self, W, noise, x_mask, x_bias, *, key): self.W = W self.noise = noise @@ -31,18 +35,65 @@ def __init__(self, W, noise, x_mask, x_bias, *, key): self.x_bias = x_bias self.key = key + @equinox.filter_jit + def _make_example(x_block, y_block, offset): + return Example(x=x_block[Block, offset], y=y_block[Block, offset]) + + self._make_example = _make_example + + @functools.lru_cache + @equinox.filter_jit + def _gen_block_data(block_id): + key = jax.random.fold_in(self.key, block_id) + x_block = hax.random.normal(key, (Block, self.W.axes[0])) * self.x_mask + self.x_bias + noise = hax.random.normal(key, (Block,)) * self.noise + y_block = (hax.nn.sigmoid(hax.dot(x_block, self.W, axis=self.W.axes[0]) + noise) > 0.5).astype(float) + return x_block, y_block + + self._gen_block_data = _gen_block_data + def __iter__(self): key_iter = key_iterator(self.key) Dim = self.W.axes[0] while True: - x_block = hax.random.normal(next(key_iter), (Block, Dim)) * self.x_mask + self.x_bias - noise = hax.random.normal(next(key_iter), (Block,)) * self.noise + kk = next(key_iter) + this_key_iter = key_iterator(kk) + x_block = hax.random.normal(next(this_key_iter), (Block, Dim)) * self.x_mask + self.x_bias + noise = hax.random.normal(next(this_key_iter), (Block,)) * self.noise y_block = (hax.nn.sigmoid(hax.dot(x_block, self.W, axis=Dim) + noise) > 0.5).astype(float) for i in range(Block.size): - yield Example(x=x_block[Block, i], y=y_block[Block, i]) + yield self._make_example(x_block, y_block, i) + + async def async_len(self) -> int: + raise ValueError("Infinitely long dataset") + + async def final_length_is_known(self) -> bool: + return False + + def is_finite(self) -> bool: + return False - def shard(self, shard_id: int, num_shards: int): - return LogitDataset(self.W, self.noise, self.x_mask, self.x_bias, key=jax.random.fold_in(self.key, shard_id)) + async def current_len(self) -> Optional[int]: + return None + + async def get_batch(self, indices: Sequence[int]) -> Sequence[Example]: + blocks = set(i // Block.size for i in indices) + + block_data = {} + for block_id in 
blocks: + x_block, y_block = self._gen_block_data(block_id) + block_data[block_id] = (x_block, y_block) + + result: list[Example] = [] + indices = np.array(indices, dtype=int) + + for index in indices: + block_id = index // Block.size + block_offset = index % Block.size + x_block, y_block = block_data[block_id] + result.append(self._make_example(x_block, y_block, block_offset)) + + return result @pytest.mark.slow @@ -78,7 +129,7 @@ def compute_loss_fn(model, example, reduction=hax.mean, reduction_axis=None, key return hax.nn.binary_cross_entropy_loss(y_pred, example.y, reduction=reduction, reduction_axis=reduction_axis) tiny_trainer_config = TrainerConfig( - num_train_steps=600, + num_train_steps=300, train_batch_size=Batch.size, tracker=(), id="kmaklfmaf", @@ -89,11 +140,11 @@ def compute_loss_fn(model, example, reduction=hax.mean, reduction_axis=None, key trainer = Trainer(tiny_trainer_config, optimizer, compute_loss_fn) - def fit_to_dataset(dataset): + def fit_to_dataset(dataset: AsyncDataset): initial_model = init_model() with trainer: state = trainer.initial_state(next(keys), model=initial_model) - loader = trainer.replicated_loader(dataset, Batch) + loader = trainer.data_loader(dataset, Batch) loader = non_caching_cycle(loader) loss = 0.0 @@ -125,19 +176,13 @@ def init_model(): datasets = {"d1": ds1, "d2": ds2, "d3": ds3} ref_model, ref_loss = fit_to_dataset( - MixtureDataset(datasets, weights={k: 1 / 3.0 for k in datasets.keys()}, key=next(keys)) + MixtureDataset(datasets, weights={k: 1 / 3.0 for k in datasets.keys()}, key=next(keys), block_size=2048) ) # let's see the loss on each dataset - l1_ref = eval_loss_loop( - compute_loss_fn, ref_model, trainer.replicated_loader(ds1, Batch), max_batches=10, name="d1" - ) - l2_ref = eval_loss_loop( - compute_loss_fn, ref_model, trainer.replicated_loader(ds2, Batch), max_batches=10, name="d2" - ) - l3_ref = eval_loss_loop( - compute_loss_fn, ref_model, trainer.replicated_loader(ds3, Batch), max_batches=10, name="d3" - ) + l1_ref = eval_loss_loop(compute_loss_fn, ref_model, trainer.data_loader(ds1, Batch), max_batches=10, name="d1") + l2_ref = eval_loss_loop(compute_loss_fn, ref_model, trainer.data_loader(ds2, Batch), max_batches=10, name="d2") + l3_ref = eval_loss_loop(compute_loss_fn, ref_model, trainer.data_loader(ds3, Batch), max_batches=10, name="d3") assert l3_ref < l1_ref < l2_ref diff --git a/tests/test_in_progress_sequence.py b/tests/test_in_progress_sequence.py deleted file mode 100644 index 1b5b6711b..000000000 --- a/tests/test_in_progress_sequence.py +++ /dev/null @@ -1,124 +0,0 @@ -import pytest - -from levanter.data._process_interleave import InProgressSequence - - -@pytest.mark.asyncio -async def test_append(): - seq = InProgressSequence[int]() - seq.append(1) - assert seq.current_length() == 1 - assert await seq.get(0) == 1 - - -@pytest.mark.asyncio -async def test_set_item(): - seq = InProgressSequence[int]() - seq.set_item(2, 10) - assert seq.current_length() == 3 - assert await seq.get(2) == 10 - - -@pytest.mark.asyncio -async def test_set_item_out_of_range(): - seq = InProgressSequence[int]() - with pytest.raises(IndexError): - seq.set_item(-1, 10) - - -@pytest.mark.asyncio -async def test_item_exception(): - seq = InProgressSequence[int]() - seq.set_item(0, 5) - seq.item_exception(0, ValueError("Test Exception")) - with pytest.raises(ValueError, match="Test Exception"): - await seq.get(0) - - -@pytest.mark.asyncio -async def test_set_finished_length(): - seq = InProgressSequence[int]() - seq.append(1) - seq.append(2) - 
seq.set_finished_length(2) - assert seq.is_finished() - assert seq.to_list() == [1, 2] - - -@pytest.mark.asyncio -async def test_set_finished_length_first(): - seq = InProgressSequence[int]() - seq.set_finished_length(2) - seq.append(1) - seq.append(2) - assert seq.is_finished() - assert seq.to_list() == [1, 2] - - -@pytest.mark.asyncio -async def test_finalize(): - seq = InProgressSequence[int]() - seq.append(1) - seq.append(2) - seq.finalize() - assert seq.is_finished() - assert seq.to_list() == [1, 2] - - -@pytest.mark.asyncio -async def test_exception_handling(): - seq = InProgressSequence[int]() - seq.set_exception(ValueError("Test Exception")) - with pytest.raises(ValueError, match="Test Exception"): - await seq.finished_promise - - -@pytest.mark.asyncio -async def test_get_promise_immediate(): - seq = InProgressSequence[int]() - seq.append(1) - promise = seq.get_promise(0) - assert await promise == 1 - - -@pytest.mark.asyncio -async def test_get_promise_deferred(): - seq = InProgressSequence[int]() - promise = seq.get_promise(0) - seq.append(2) - assert await promise == 2 - - -@pytest.mark.asyncio -async def test_get_promise_out_of_range(): - seq = InProgressSequence[int]() - seq.set_finished_length(2) - with pytest.raises(IndexError): - seq.get_promise(3) - - -@pytest.mark.asyncio -async def test_get_promise_with_future_exception(): - seq = InProgressSequence[int]() - promise = seq.get_promise(0) - promise2 = seq.get_promise(0) - seq.item_exception(0, ValueError("Test Exception")) - - with pytest.raises(ValueError, match="Test Exception"): - await promise - - with pytest.raises(ValueError, match="Test Exception"): - await promise2 - - -@pytest.mark.asyncio -async def test_get_promise_with_past_exception(): - seq = InProgressSequence[int]() - seq.item_exception(0, ValueError("Test Exception")) - promise = seq.get_promise(0) - promise2 = seq.get_promise(0) - with pytest.raises(ValueError, match="Test Exception"): - await promise - - with pytest.raises(ValueError, match="Test Exception"): - await promise2 diff --git a/tests/test_jagged_array.py b/tests/test_jagged_array.py new file mode 100644 index 000000000..24ed24b08 --- /dev/null +++ b/tests/test_jagged_array.py @@ -0,0 +1,305 @@ +import math +import tempfile + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from levanter.store.jagged_array import JaggedArrayStore + + +class TestJaggedArrayStore: + def test_append_and_get(self): + with tempfile.TemporaryDirectory() as tmpdir: + builder = JaggedArrayStore.open(tmpdir, item_rank=2, dtype=jnp.float32) + + data1 = jnp.array([[1.0, 2.0], [3.0, 4.0]]) + data2 = jnp.array([[5.0]]) + + builder.append(data1) + builder.append(data2) + + assert len(builder) == 2 + + result1 = builder[0] + assert jnp.all(result1 == data1) + + result2 = builder[1] + assert jnp.all(result2 == data2) + + # result_slice = builder[0:2] + # assert isinstance(result_slice, JaggedArray) + + def test_extend_with_multiple(self): + with tempfile.TemporaryDirectory() as tmpdir: + builder = JaggedArrayStore.open(tmpdir, item_rank=2, dtype=jnp.float32) + + data1 = jnp.array([[1.0, 2.0], [3.0, 4.0]]) + data2 = jnp.array([[5.0]]) + + builder.extend([data1, data2]) + + assert len(builder) == 2 + + result1 = builder[0] + assert jnp.all(result1 == data1) + + result2 = builder[1] + assert jnp.all(result2 == data2) + + def test_append_error(self): + with tempfile.TemporaryDirectory() as tmpdir: + builder = JaggedArrayStore.open(tmpdir, item_rank=1, dtype=jnp.float32) + with pytest.raises(ValueError): + 
builder.append(jnp.array([[1.0, 2.0]])) + + def test_append_single_rank(self): + with tempfile.TemporaryDirectory() as tmpdir: + builder = JaggedArrayStore.open(tmpdir, item_rank=1, dtype=jnp.float32) + + data = jnp.array([1.0, 2.0, 3.0]) + builder.append(data) + + assert len(builder) == 1 + + result = builder[0] + assert jnp.all(result == data) + + def test_append_multi_rank(self): + with tempfile.TemporaryDirectory() as tmpdir: + builder = JaggedArrayStore.open(tmpdir, item_rank=2, dtype=jnp.float32) + + data1 = jnp.array([[1.0, 2.0], [3.0, 4.0]]) + data2 = jnp.array([[5.0, 6.0], [7.0, 8.0]]) + + builder.append(data1) + builder.append(data2) + + assert len(builder) == 2 + + result1 = builder[0] + assert jnp.all(result1 == data1) + + result2 = builder[1] + assert jnp.all(result2 == data2) + + def test_getitem_out_of_bounds(self): + with tempfile.TemporaryDirectory() as tmpdir: + builder = JaggedArrayStore.open(tmpdir, item_rank=2, dtype=jnp.float32) + + data = jnp.array([[1.0, 2.0], [3.0, 4.0]]) + builder.append(data) + + with pytest.raises(IndexError): + builder[2] + + def test_step_slicing(self): + with tempfile.TemporaryDirectory() as tmpdir: + builder = JaggedArrayStore.open(tmpdir, item_rank=2, dtype=jnp.float32) + + data = jnp.array([[1.0, 2.0], [3.0, 4.0]]) + builder.append(data) + + # with pytest.raises(ValueError): + # builder[::2] + + +async def create_builder_with_data(directory, num_sequences: int, sequence_length: int | tuple[int, ...]): + if isinstance(sequence_length, int): + sequence_length = (sequence_length,) + + """Helper function to create a JaggedArrayStore with specific data.""" + seed = jax.random.PRNGKey(num_sequences * math.prod(sequence_length)) + + builder = await JaggedArrayStore.open_async(directory, item_rank=len(sequence_length), dtype=jnp.int64) + for i in range(num_sequences): + key, seed = jax.random.split(seed) + data = jax.random.randint(key, sequence_length, 0, 100) + await builder.append_async(data) + + return builder + + +def create_builder_with_data_sync( + directory, num_sequences: int, sequence_length: int | tuple[int, ...] 
+) -> JaggedArrayStore: + if isinstance(sequence_length, int): + sequence_length = (sequence_length,) + + """Helper function to create a JaggedArrayStore with specific data.""" + seed = jax.random.PRNGKey(num_sequences * math.prod(sequence_length)) + + builder = JaggedArrayStore.open(directory, item_rank=len(sequence_length), dtype=jnp.int64) + for i in range(num_sequences): + key, seed = jax.random.split(seed) + data = jax.random.randint(key, sequence_length, 0, 100) + builder.append(data) + + return builder + + +@pytest.mark.asyncio +async def test_trim_to_size_async(): + tmpdir = tempfile.TemporaryDirectory().name + builder = await create_builder_with_data(tmpdir, num_sequences=10, sequence_length=1000) + + # Initial size + initial_size = len(builder) + assert initial_size == 10 + + expected_data = list([builder[i] for i in range(10)]) + + # Trim to smaller size + await builder.trim_to_size_async(5) + new_size = len(builder) + assert new_size == 5 + + # Verify the data integrity + trimmed_data = await builder.data[0:5000].read() + assert jnp.all(trimmed_data == jnp.concatenate(expected_data[:5])) + + # Trim to zero size + await builder.trim_to_size_async(0) + new_size = len(builder) + assert new_size == 0 + + # Verify the data integrity + trimmed_data = await builder.data[0:5000].read() + assert jnp.all(trimmed_data == 0) + + +@pytest.mark.asyncio +async def test_trim_to_size_larger_than_current(): + tmpdir = tempfile.TemporaryDirectory().name + builder = await create_builder_with_data(tmpdir, num_sequences=10, sequence_length=1000) + expected_data = list([builder[i] for i in range(10)]) + + # Initial size + initial_size = len(builder) + assert initial_size == 10 + + # Trim to a larger size than current (should not change) + await builder.trim_to_size_async(15) + new_size = len(builder) + assert new_size == 10 + + # Verify the data integrity + trimmed_data = await builder.data[0:10000].read() + assert np.array_equal(trimmed_data, jnp.concatenate(expected_data[:10])) + + +@pytest.mark.asyncio +async def test_trim_to_size_with_shapes_async(): + tmpdir = tempfile.TemporaryDirectory().name + builder = await create_builder_with_data(tmpdir, num_sequences=10, sequence_length=(10, 100)) + expected_shapes = list(await builder.shapes[0:10].read()) + + # Trim to smaller size + await builder.trim_to_size_async(5) + new_size = len(builder) + assert new_size == 5 + + # Verify the shapes integrity + trimmed_shapes = await builder.shapes[0:5].read() + assert np.array_equal(trimmed_shapes, jnp.stack(expected_shapes[:5])) + + +def test_trim_to_size(): + tmpdir = tempfile.TemporaryDirectory().name + builder = create_builder_with_data_sync(tmpdir, num_sequences=10, sequence_length=1000) + + # Initial size + initial_size = len(builder) + assert initial_size == 10 + + expected_data = list([builder[i] for i in range(10)]) + + # Trim to smaller size + builder.trim_to_size(5) + new_size = len(builder) + assert new_size == 5 + + # Verify the data integrity + trimmed_data = builder.data[0:5000].read().result() + assert jnp.all(trimmed_data == jnp.concatenate(expected_data[:5])) + + # Trim to zero size + builder.trim_to_size(0) + new_size = len(builder) + assert new_size == 0 + + # Verify the data integrity + trimmed_data = builder.data[0:10000].read().result() + assert jnp.all(trimmed_data == 0) + + +@pytest.mark.asyncio +async def test_get_batch_single_item(): + tmpdir = tempfile.TemporaryDirectory().name + builder = await create_builder_with_data(tmpdir, num_sequences=10, sequence_length=1000) + + # Retrieve a 
single item using get_batch + batch = await builder.get_batch([3]) + result = batch[0] + + expected_data = await builder.get_item_async(3) + + assert np.array_equal(result, expected_data) + + +@pytest.mark.asyncio +async def test_get_batch_multiple_items(): + tmpdir = tempfile.TemporaryDirectory().name + builder = await create_builder_with_data(tmpdir, num_sequences=10, sequence_length=1000) + + # Retrieve multiple items using get_batch + indices = [1, 4, 7] + batch = await builder.get_batch(indices) + + for idx, result in zip(indices, batch): + expected_data = await builder.get_item_async(idx) + assert np.array_equal(result, expected_data) + + +@pytest.mark.asyncio +async def test_get_batch_out_of_order(): + tmpdir = tempfile.TemporaryDirectory().name + builder = await create_builder_with_data(tmpdir, num_sequences=10, sequence_length=1000) + + # Retrieve items out of order using get_batch + indices = [7, 2, 5] + batch = await builder.get_batch(indices) + + for idx, result in zip(indices, batch): + expected_data = await builder.get_item_async(idx) + assert np.array_equal(result, expected_data) + + +@pytest.mark.asyncio +async def test_get_batch_with_shapes(): + tmpdir = tempfile.TemporaryDirectory().name + builder = await create_builder_with_data(tmpdir, num_sequences=10, sequence_length=(10, 100)) + + # Retrieve multiple items using get_batch + indices = [0, 3, 6] + batch = await builder.get_batch(indices) + + for idx, result in zip(indices, batch): + expected_data = await builder.get_item_async(idx) + assert np.array_equal(result, expected_data) + + +@pytest.mark.asyncio +async def test_get_batch_empty(): + tmpdir = tempfile.TemporaryDirectory().name + builder = await create_builder_with_data(tmpdir, num_sequences=10, sequence_length=1000) + + # Retrieve an empty batch + batch = await builder.get_batch([]) + + assert batch == [] + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/test_llama.py b/tests/test_llama.py index 3fc6a551e..4277150fe 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -51,7 +51,6 @@ def test_llama_flops(): llama_config = LlamaConfig.from_hf_config(hf_config) n_params = 6.738415616e9 ratio = llama_config.flops_per_token(hf_config.vocab_size) / (2 * n_params) - print(ratio) assert ratio > 1.1, f"ratio {ratio} < 1.1" assert ratio < 1.2, f"ratio {ratio} > 1.2" @@ -386,6 +385,4 @@ def test_state_dict_consistency(scan_layers, num_kv_heads): model = LlamaLMHeadModel.init(Vocab=Vocab, config=config, key=random.PRNGKey(0)) hf_config = config.to_hf_config(Vocab.size) hf_model = LlamaForCausalLM(hf_config) - print(hf_model.state_dict().keys()) - print(model.to_state_dict().keys()) assert set(hf_model.state_dict().keys()) == set(model.to_state_dict().keys()) diff --git a/tests/test_lora.py b/tests/test_lora.py index f9268d350..f7d852531 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -113,6 +113,7 @@ def test_lora_peft_integration(): hf_dict = get_peft_model_state_dict(model) converter = Gpt2Config().hf_checkpoint_converter() + lev_model = converter.load_pretrained(converter.default_config.model_type, "stanford-crfm/expanse-gpt2-small-x777") lora_lev_model = loraize(lev_model, LoraConfig(r=8, target_modules=["c_attn"]), key=jax.random.PRNGKey(0)) @@ -168,8 +169,8 @@ def replace_dot_general(x): return PreciseDotGeneralOp() return x - merged = jax.tree_map(replace_dot_general, merged, is_leaf=lambda x: isinstance(x, DefaultDotGeneralOp)) - loraized = jax.tree_map(replace_dot_general, loraized, is_leaf=lambda x: isinstance(x, 
DefaultDotGeneralOp)) + merged = jax.tree.map(replace_dot_general, merged, is_leaf=lambda x: isinstance(x, DefaultDotGeneralOp)) + loraized = jax.tree.map(replace_dot_general, loraized, is_leaf=lambda x: isinstance(x, DefaultDotGeneralOp)) input = hax.random.normal(k0, (In,)) # light tolerances for TPU diff --git a/tests/test_mixture.py b/tests/test_mixture.py new file mode 100644 index 000000000..e8821e24f --- /dev/null +++ b/tests/test_mixture.py @@ -0,0 +1,155 @@ +import jax +import numpy as np +import pytest + +from levanter.data import ListAsyncDataset, MixtureDataset +from levanter.data.mixture import StopStrategy + + +def datasets(): + ds1 = ListAsyncDataset([1, 2, 3, 4, 5]) + ds2 = ListAsyncDataset([10, 20, 30, 40, 50]) + ds3 = ListAsyncDataset([100, 200, 300, 400, 500]) + ds1.finalize() + ds2.finalize() + ds3.finalize() + return {"ds1": ds1, "ds2": ds2, "ds3": ds3} + + +def weights(): + return {"ds1": 0.5, "ds2": 0.3, "ds3": 0.2} + + +def block_size(): + return 10 + + +def key(): + return jax.random.PRNGKey(42) + + +@pytest.mark.asyncio +async def test_mixture_dataset_getitem(): + mixture_ds = MixtureDataset(datasets(), weights(), 10, key=key, randomize_blocks=False) + + item = await mixture_ds.getitem_async(0) + assert item in [1, 10, 100], f"Unexpected item: {item}" + + +@pytest.mark.asyncio +async def test_mixture_dataset_get_batch(): + mixture_ds = MixtureDataset(datasets(), weights(), 10, key=key(), randomize_blocks=False) + + batch = await mixture_ds.get_batch([0, 1, 2]) + assert len(batch) == 3 + assert all(item in [1, 2, 3, 10, 20, 30, 100, 200, 300] for item in batch) + + +@pytest.mark.asyncio +async def test_mixture_dataset_block_assignments(): + mixture_ds = MixtureDataset(datasets(), weights(), 10, key=key()) + + block_assignment = await mixture_ds._get_block(0) + assert block_assignment is not None + assert len(block_assignment) == 10 + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_mixture_dataset_stop_strategy_first(): + mixture_ds = MixtureDataset(datasets(), weights(), 10, key=key, stop_strategy=StopStrategy.FIRST_STOP_STRATEGY) + + with pytest.raises(NotImplementedError): + await mixture_ds.async_len() + + +@pytest.mark.asyncio +async def test_mixture_dataset_stop_strategy_restart(): + mixture_ds = MixtureDataset( + datasets(), weights(), block_size=10, key=key(), stop_strategy=StopStrategy.RESTART_STRATEGY + ) + + with pytest.raises(ValueError): + await mixture_ds.async_len() + + +@pytest.mark.asyncio +async def test_mixture_dataset_normalized_weights(): + weights = {"ds1": 0, "ds2": 0.5, "ds3": 0.5} + mixture_ds = MixtureDataset(datasets(), weights, block_size=10, key=key(), randomize_blocks=False) + + batch = await mixture_ds.get_batch([0, 1, 2]) + assert len(batch) == 3 + assert all(item in [10, 20, 30, 100, 200, 300] for item in batch) + + +@pytest.mark.asyncio +async def test_mixture_dataset_unpermuted_ids(): + mixture_ds = MixtureDataset(datasets(), weights(), block_size=10, key=key()) + + unpermuted_ids = mixture_ds._compute_unpermuted_ids(mixture_ds._counts_per_block) + assert len(unpermuted_ids) == 10 + assert unpermuted_ids[0] >> 32 in range(3) # Ensure the dataset ID is valid + + +@pytest.mark.asyncio +async def test_mixture_dataset_remap_indices(): + dses = datasets() + mixture_ds = MixtureDataset(dses, weights(), block_size=10, key=key()) + + remapped_indices = await mixture_ds._remap_indices(dses["ds1"], [0, 1, 2]) + assert len(remapped_indices) == 3 + assert remapped_indices == [0, 1, 2] + + # check wrap around + len_ds1 = await 
dses["ds1"].async_len() + remapped_indices = await mixture_ds._remap_indices(dses["ds1"], [len_ds1 - 1, len_ds1, len_ds1 + 1]) + assert len(remapped_indices) == 3 + + assert remapped_indices == [len_ds1 - 1, 0, 1] + + +@pytest.mark.asyncio +async def test_mixture_dataset_respects_weights(): + w = weights() + mixture_ds = MixtureDataset(datasets(), w, block_size(), key=key()) + + # Check that the dataset respects the weights + num_samples = 1000 + samples = await mixture_ds.get_batch(list(range(num_samples))) + + counts = {"ds1": 0, "ds2": 0, "ds3": 0} + for sample in samples: + if sample < 10: + counts["ds1"] += 1 + elif sample < 100: + counts["ds2"] += 1 + else: + counts["ds3"] += 1 + + for dataset, count in counts.items(): + assert abs(count / num_samples - w[dataset]) < 0.1, f"Dataset {dataset} has unexpected weight" + + +@pytest.mark.asyncio +async def test_mixture_dataset_randomizes_blocks(): + mixture_ds = MixtureDataset(datasets(), weights(), block_size=10, key=key()) + + block_assignment_1 = await mixture_ds._get_block(0) + block_assignment_2 = await mixture_ds._get_block(0) + + assert np.all(block_assignment_1 == block_assignment_2), "Block assignments should be randomized" + + block_assignment_3 = await mixture_ds._get_block(1) + assert not np.all(block_assignment_1 == block_assignment_3), "Block assignments should be randomized" + + +@pytest.mark.asyncio +async def test_mixture_dataset_samples_all_elements(): + mixture_ds = MixtureDataset(datasets(), weights(), block_size=10, key=key()) + + num_samples = 1000 + samples = await mixture_ds.get_batch(list(range(num_samples))) + + assert len(samples) == num_samples + assert set(samples) == {1, 2, 3, 4, 5, 10, 20, 30, 40, 50, 100, 200, 300, 400, 500} diff --git a/tests/test_new_cache.py b/tests/test_new_cache.py new file mode 100644 index 000000000..3302674de --- /dev/null +++ b/tests/test_new_cache.py @@ -0,0 +1,921 @@ +import asyncio +import logging +import tempfile +from typing import Iterator, Sequence +from unittest.mock import MagicMock + +import numpy as np +import pytest +import ray +from ray.exceptions import RayTaskError + +from levanter.data import BatchProcessor, ShardedDataSource, batched +from levanter.data.sharded_datasource import TextUrlDataSource +from levanter.store.cache import ( + SerialCacheWriter, + TreeStore, + _get_builder_actor, + _OrderedCacheWriter, + build_or_load_cache, +) +from levanter.utils.py_utils import logical_cpu_core_count +from levanter.utils.ray_utils import ExceptionInfo, SnitchRecipient, ser_exc_info + + +class TestProcessor(BatchProcessor[Sequence[int], dict[str, np.ndarray]]): + def __init__(self, batch_size: int = 8): + self._batch_size = batch_size + + def __call__(self, batch: Sequence[Sequence[int]]) -> Sequence[dict[str, np.ndarray]]: + # return pa.RecordBatch.from_arrays([pa.array(batch)], ["test"]) + return [{"test": np.asarray(x)} for x in batch] + + @property + def output_exemplar(self): + return {"test": np.array([0], dtype=np.int64)} + + @property + def batch_size(self) -> int: + return self._batch_size + + @property + def num_cpus(self) -> int: + return 1 + + +def simple_process(processor, source): + result = [] + for shard_name in source.shard_names: + for batch in source.open_shard(shard_name): + result.append(processor([batch])[0]) + + return result + + +def process_interleave(processor, source): + batch_size = processor.batch_size + shard_iterators = { + shard_name: batched(iter(source.open_shard(shard_name)), batch_size) for shard_name in source.shard_names + } + finished 
= 0 + + while finished < len(shard_iterators): + for shard_name, shard_iter in shard_iterators.items(): + if shard_iter is None: + continue + try: + batch = next(shard_iter) + yield from processor(batch) + except StopIteration: + shard_iterators[shard_name] = None + finished += 1 + + +def setup_module(module): + ray.init( + "local", num_cpus=max(2 * logical_cpu_core_count(), 8), ignore_reinit_error=True + ) # 2x cpu count is faster on my m1 + + +def teardown_module(module): + ray.shutdown() + + +class SimpleProcessor(BatchProcessor[Sequence[int], dict[str, np.ndarray]]): + def __init__(self, batch_size: int = 8): + self._batch_size = batch_size + + def __call__(self, batch: Sequence[Sequence[int]]) -> Sequence[dict[str, Sequence[int]]]: + return [{"data": x} for x in batch] + + @property + def batch_size(self) -> int: + return self._batch_size + + @property + def num_cpus(self) -> int: + return 1 + + @property + def output_exemplar(self) -> dict[str, np.ndarray]: + return {"data": np.array([0], dtype=np.int64)} + + +class SimpleShardSource(ShardedDataSource[list[int]]): + def __init__(self, num_shards: int = 4): + self._num_shards = num_shards + + @property + def shard_names(self) -> Sequence[str]: + return [f"shard_{i}" for i in range(self._num_shards)] + + def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: + # parse the shard name to get the shard number + shard_num = int(shard_name.split("_")[1]) + return ([shard_num * 10 + i] * 10 for i in range(row, 10)) + + +def test_serial_cache_writer(): + with tempfile.TemporaryDirectory() as tmpdir1: + source = SimpleShardSource(num_shards=4) + processor = SimpleProcessor() + + exemplar = {"data": np.array([0], dtype=np.int64)} + + with SerialCacheWriter(tmpdir1, exemplar) as writer: + for shard_name in source.shard_names: + for ex in batched(source.open_shard(shard_name), processor.batch_size): + writer.write_batch(processor(ex)) + + _ = writer.result() + data_path = writer._tree_store.path + + builder = TreeStore.open(exemplar, data_path, mode="r") + + assert len(builder) == 40 + + for i, x in enumerate(builder): + np.testing.assert_array_equal(x["data"], np.asarray([i % 10 + i // 10 * 10] * 10)) + + +def crappy_du(path): + import os + + total = 0 + for root, dirs, files in os.walk(path): + for f in files: + total += os.path.getsize(os.path.join(root, f)) + return total + + +@ray.remote +class PretendParent(SnitchRecipient): + def __init__(self): + self.logger = logging.getLogger("SnitchRecipient") + self.failure_received = asyncio.Event() + self.exception_info = None + self._finished_shards = set() + self._finished = False + self._ledger = None + self._desired_next_item = None + + def _child_failed(self, child: ray.actor.ActorHandle, exception: ExceptionInfo): + try: + self.logger.error(f"Child {child} failed with exception {exception}") + self.exception_info = exception + self.failure_received.set() + except Exception as e: + self.logger.error(f"Error in _child_failed: {e}") + + def shard_failed(self, shard_name, exc_info): + self.exception_info = exc_info + self.failure_received.set() + + async def wait_for_failure(self): + await self.failure_received.wait() + return self.exception_info + + def shard_finished(self, shard_name): + self._finished_shards.add(shard_name) + + def get_finished_shards(self): + return self._finished_shards + + def _updated_ledger(self, ledger): + if ledger.is_finished: + self._finished = True + + self._ledger = ledger + + def _finalize(self): + self._finished = True + + def 
is_finished(self): + return self._finished + + def signal_backpressure(self, desired_next_item: float): + self._desired_next_item = desired_next_item + + def desired_next_item(self): + return self._desired_next_item + + +@pytest.mark.asyncio +async def test_batch_finished(): + parent = PretendParent.remote() + exemplar = np.array([1, 2, 3]) + with tempfile.TemporaryDirectory() as cache_dir: + shards = ["shard1", "shard2", "shard3"] + writer = _OrderedCacheWriter.remote( + parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards, min_items_to_write=1 + ) + + try: + shard_idx = "shard1" + shard_batch_idx = 0 + batch_result = [np.array([1, 2, 3])] + + await writer.batch_finished.remote(shard_idx, shard_batch_idx, batch_result) + shard_status = await writer.get_shard_status.remote("shard1") + assert shard_status.num_rows_committed == 1 + finally: + ray.kill(parent) + ray.kill(writer) + + +@pytest.mark.asyncio +async def test_shard_finished_reading(): + parent = PretendParent.remote() + exemplar = MagicMock() + with tempfile.TemporaryDirectory() as cache_dir: + shards = ["shard1", "shard2", "shard3"] + writer = _OrderedCacheWriter.remote(parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards) + + try: + shard_name = "shard1" + expected_batches = 5 + + await writer.shard_finished_reading.remote(shard_name, expected_batches) + shard_status = await writer.get_shard_status.remote(shard_name) + assert shard_status.is_finished is False + finally: + ray.kill(parent) + ray.kill(writer) + + +@pytest.mark.asyncio +async def test_get_shard_status(): + parent = PretendParent.remote() + exemplar = np.array([1, 2, 3]) + with tempfile.TemporaryDirectory() as cache_dir: + shards = ["shard1", "shard2", "shard3"] + writer = _OrderedCacheWriter.remote(parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards) + + try: + shard_name = "shard1" + shard_status = await writer.get_shard_status.remote(shard_name) + + assert shard_status.shard_name == shard_name + assert shard_status.num_rows_committed == 0 + assert not shard_status.is_finished + finally: + ray.kill(parent) + ray.kill(writer) + + +@pytest.mark.asyncio +async def test_shard_failed(): + parent = PretendParent.remote() + exemplar = MagicMock() + with tempfile.TemporaryDirectory() as cache_dir: + shards = ["shard1", "shard2", "shard3"] + writer = _OrderedCacheWriter.remote(parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards) + + try: + shard_name = "shard1" + batch_id = 0 + try: + raise Exception("Test Exception") + except: # noqa + exc_info = ser_exc_info() + + await writer.shard_failed.remote(shard_name, batch_id, exc_info) + exception_received = await parent.wait_for_failure.remote() + assert str(exception_received.ex) == str(exc_info.ex) + finally: + ray.kill(parent) + ray.kill(writer) + + +DEFAULT_BATCH_SIZE = 128 + + +@pytest.mark.asyncio +async def test_attempt_to_write_batches(): + parent = PretendParent.remote() + exemplar = np.array([1, 2, 3]) + with tempfile.TemporaryDirectory() as cache_dir: + shards = ["shard1", "shard2", "shard3"] + writer = _OrderedCacheWriter.remote( + parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards, min_items_to_write=2 + ) + + try: + shard1_batch = [np.asarray([1, 2, 3])] + shard2_batch = [np.asarray([4, 5, 6, 7])] + + await writer.batch_finished.remote("shard1", 0, shard1_batch) + await writer.batch_finished.remote("shard2", 0, shard2_batch) + + ledger = await writer.get_ledger.remote() + assert ledger.is_finished is False + assert ledger.total_num_rows == 2 # Assuming each 
batch has 1 row for simplicity + + store = TreeStore.open(exemplar, cache_dir, mode="r") + assert len(store) == 2 + np.testing.assert_array_equal(store[0], shard1_batch[0]) + np.testing.assert_array_equal(store[1], shard2_batch[0]) + finally: + ray.kill(parent) + ray.kill(writer) + + +@pytest.mark.asyncio +async def test_finalize_cache(): + parent = PretendParent.remote() + exemplar = np.array([1, 2, 3]) + with tempfile.TemporaryDirectory() as cache_dir: + shards = ["shard1", "shard2", "shard3"] + writer = _OrderedCacheWriter.remote(parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards) + + try: + shard1_batch = [np.array([1, 2, 3])] + shard2_batch = [np.array([4, 5, 6, 7])] + + await writer.batch_finished.remote("shard1", 0, shard1_batch) + await writer.shard_finished_reading.remote("shard1", 1) + await writer.shard_finished_reading.remote("shard2", 1) + await writer.batch_finished.remote("shard2", 0, shard2_batch) + + ledger = await writer.get_ledger.remote() + assert ledger.is_finished is False + assert ledger.total_num_rows == 2 # Assuming each batch has 1 row for simplicity + + await writer.shard_finished_reading.remote("shard3", 0) + finished_shards = await parent.get_finished_shards.remote() + assert len(finished_shards) == 3 + + ledger = await writer.get_ledger.remote() + assert ledger.is_finished is True + assert ledger.total_num_rows == 2 + assert await parent.is_finished.remote() is True + finally: + ray.kill(parent) + ray.kill(writer) + + +@pytest.mark.asyncio +async def test_error_handling(): + parent = PretendParent.remote() + exemplar = np.array([1, 2, 3]) + with tempfile.TemporaryDirectory() as cache_dir: + shards = ["shard1", "shard2", "shard3"] + writer = _OrderedCacheWriter.remote(parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards) + + try: + with pytest.raises(TypeError): + await writer.batch_finished.remote("shard1", 0, None) + + exception_received = await parent.wait_for_failure.remote() + assert exception_received is not None + finally: + ray.kill(parent) + ray.kill(writer) + + +@pytest.mark.asyncio +async def test_out_of_order_batches_same_shard(): + parent = PretendParent.remote() + exemplar = np.array([1, 2, 3]) + with tempfile.TemporaryDirectory() as cache_dir: + shards = ["shard1"] + writer = _OrderedCacheWriter.remote( + parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards, min_items_to_write=2 + ) + + try: + # Sending batch 1 before batch 0 for shard1 + shard1_batch0 = [np.array([1, 2, 3])] + shard1_batch1 = [np.array([4, 5, 6])] + + await writer.batch_finished.remote("shard1", 1, shard1_batch1) + await writer.batch_finished.remote("shard1", 0, shard1_batch0) + + store = TreeStore.open(exemplar, cache_dir, mode="r") + assert len(store) == 2 + np.testing.assert_array_equal(store[0], shard1_batch0[0]) + np.testing.assert_array_equal(store[1], shard1_batch1[0]) + finally: + ray.kill(parent) + ray.kill(writer) + + +@pytest.mark.asyncio +async def test_out_of_order_batches_different_shards(): + parent = PretendParent.remote() + exemplar = np.array([1, 2, 3]) + with tempfile.TemporaryDirectory() as cache_dir: + shards = ["shard1", "shard2"] + writer = _OrderedCacheWriter.remote( + parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards, min_items_to_write=3 + ) + + try: + # Sending batches out of order across different shards + shard1_batch0 = [np.array([1, 2, 3])] + shard2_batch0 = [np.array([4, 5, 6])] + shard1_batch1 = [np.array([7, 8, 9])] + + await writer.batch_finished.remote("shard1", 1, shard1_batch1) + await 
writer.batch_finished.remote("shard2", 0, shard2_batch0) + await writer.batch_finished.remote("shard1", 0, shard1_batch0) + + store = TreeStore.open(exemplar, cache_dir, mode="r") + assert len(store) == 3 + np.testing.assert_array_equal(store[0], shard1_batch0[0]) + np.testing.assert_array_equal(store[1], shard2_batch0[0]) + np.testing.assert_array_equal(store[2], shard1_batch1[0]) + finally: + ray.kill(parent) + ray.kill(writer) + + +@pytest.mark.asyncio +async def test_batches_different_orders_all_shards(): + parent = PretendParent.remote() + exemplar = np.array([1, 2, 3]) + with tempfile.TemporaryDirectory() as cache_dir: + shards = ["shard1", "shard2", "shard3"] + writer = _OrderedCacheWriter.remote( + parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards, min_items_to_write=2 + ) + + try: + # Sending batches in different orders across all shards + shard1_batch0 = [np.array([1, 2, 3])] + shard1_batch1 = [np.array([4, 5, 6])] + shard2_batch0 = [np.array([7, 8, 9])] + shard3_batch0 = [np.array([10, 11, 12])] + + await writer.batch_finished.remote("shard2", 0, shard2_batch0) + await writer.batch_finished.remote("shard3", 0, shard3_batch0) + await writer.batch_finished.remote("shard1", 1, shard1_batch1) + await writer.batch_finished.remote("shard1", 0, shard1_batch0) + + store = TreeStore.open(exemplar, cache_dir, mode="r") + assert len(store) == 4 + np.testing.assert_array_equal(store[0], shard1_batch0[0]) + np.testing.assert_array_equal(store[1], shard2_batch0[0]) + np.testing.assert_array_equal(store[2], shard3_batch0[0]) + np.testing.assert_array_equal(store[3], shard1_batch1[0]) + finally: + ray.kill(parent) + ray.kill(writer) + + +@pytest.mark.asyncio +async def test_intermixed_batches_same_and_different_shards(): + parent = PretendParent.remote() + exemplar = np.array([1, 2, 3]) + with tempfile.TemporaryDirectory() as cache_dir: + shards = ["shard1", "shard2", "shard3"] + writer = _OrderedCacheWriter.remote( + parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards, min_items_to_write=1 + ) + + try: + # Sending intermixed batches from the same and different shards + shard1_batch0 = [np.array([1, 2, 3])] + shard2_batch0 = [np.array([4, 5, 6])] + shard1_batch1 = [np.array([7, 8, 9])] + shard3_batch0 = [np.array([10, 11, 12])] + shard2_batch1 = [np.array([13, 14, 15])] + + await writer.batch_finished.remote("shard2", 0, shard2_batch0) + await writer.batch_finished.remote("shard3", 0, shard3_batch0) + await writer.batch_finished.remote("shard1", 1, shard1_batch1) + await writer.batch_finished.remote("shard2", 1, shard2_batch1) + await writer.batch_finished.remote("shard1", 0, shard1_batch0) + + store = TreeStore.open(exemplar, cache_dir, mode="r") + assert len(store) == 5 + np.testing.assert_array_equal(store[0], shard1_batch0[0]) + np.testing.assert_array_equal(store[1], shard2_batch0[0]) + np.testing.assert_array_equal(store[2], shard3_batch0[0]) + np.testing.assert_array_equal(store[3], shard1_batch1[0]) + np.testing.assert_array_equal(store[4], shard2_batch1[0]) + finally: + ray.kill(parent) + ray.kill(writer) + + +@pytest.mark.asyncio +async def test_duplicate_batches_same_shard(): + parent = PretendParent.remote() + exemplar = np.array([1, 2, 3]) + with tempfile.TemporaryDirectory() as cache_dir: + shards = ["shard1"] + writer = _OrderedCacheWriter.remote(parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards) + + try: + # Sending duplicate batches for the same shard + shard1_batch0 = [np.array([1, 2, 3])] + + await writer.batch_finished.remote("shard1", 
0, shard1_batch0) + with pytest.raises(RayTaskError): + await writer.batch_finished.remote("shard1", 0, shard1_batch0) # Duplicate + finally: + ray.kill(parent) + ray.kill(writer) + + +@pytest.mark.asyncio +async def test_mixed_order_batches_multiple_shards(): + parent = PretendParent.remote() + exemplar = np.array([1, 2, 3]) + with tempfile.TemporaryDirectory() as cache_dir: + shards = ["shard1", "shard2", "shard3"] + writer = _OrderedCacheWriter.remote( + parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards, min_items_to_write=1 + ) + + try: + # Sending batches in mixed order for multiple shards + shard1_batch0 = [np.array([1, 2, 3])] + shard2_batch0 = [np.array([4, 5, 6])] + shard1_batch1 = [np.array([7, 8, 9])] + shard2_batch1 = [np.array([10, 11, 12])] + shard3_batch0 = [np.array([13, 14, 15])] + shard3_batch1 = [np.array([16, 17, 18])] + + await writer.batch_finished.remote("shard3", 0, shard3_batch0) + await writer.batch_finished.remote("shard1", 1, shard1_batch1) + await writer.batch_finished.remote("shard2", 0, shard2_batch0) + await writer.batch_finished.remote("shard2", 1, shard2_batch1) + await writer.batch_finished.remote("shard1", 0, shard1_batch0) + await writer.batch_finished.remote("shard3", 1, shard3_batch1) + + store = TreeStore.open(exemplar, cache_dir, mode="r") + assert len(store) == 6 + np.testing.assert_array_equal(store[0], shard1_batch0[0]) + np.testing.assert_array_equal(store[1], shard2_batch0[0]) + np.testing.assert_array_equal(store[2], shard3_batch0[0]) + np.testing.assert_array_equal(store[3], shard1_batch1[0]) + np.testing.assert_array_equal(store[4], shard2_batch1[0]) + np.testing.assert_array_equal(store[5], shard3_batch1[0]) + finally: + ray.kill(parent) + ray.kill(writer) + + +@pytest.mark.ray +def test_full_end_to_end_cache_simple(): + td = tempfile.TemporaryDirectory() + with td as tmpdir: + ray_ds = build_or_load_cache( + tmpdir, + SimpleShardSource(num_shards=1), + TestProcessor(), + await_finished=True, + ) + + simple_processed = simple_process(TestProcessor(), SimpleShardSource()) + + all_data = ray_ds[:] + + check_datasets_equal(all_data, simple_processed) + + +@pytest.mark.ray +def test_cache_remembers_its_cached(): + directory = tempfile.TemporaryDirectory() + with directory as tmpdir: + ds1 = build_or_load_cache(tmpdir, SimpleShardSource(), TestProcessor()) + + class ThrowingProcessor(BatchProcessor[Sequence[int], dict[str, np.ndarray]]): + def __call__(self, batch: Sequence[Sequence[int]]): + raise RuntimeError("This should not be called") + + @property + def output_exemplar(self) -> dict[str, np.ndarray]: + return {"test": np.array([0], dtype=np.int64)} + + @property + def batch_size(self) -> int: + return 8 + + @property + def num_cpus(self) -> int: + return 1 + + # testing this doesn't throw + ds2 = build_or_load_cache(tmpdir, SimpleShardSource(), ThrowingProcessor(), await_finished=True) + + check_datasets_equal(ds1, ds2) + + +def check_datasets_equal(ds1, ds2): + for r1, r2 in zip(ds1, ds2): + assert r1.keys() == r2.keys() + for key in r1.keys(): + np.testing.assert_array_equal(r1[key], r2[key]) + + +class _CustomException(Exception): + pass + + +@pytest.mark.ray +@pytest.mark.skip("This test segfaults in CI. 
I think a ray bug") +def test_cache_recover_from_crash(): + class CrashingShardSource(ShardedDataSource[list[int]]): + def __init__(self, crash_point: int): + self.crash_point = crash_point + + @property + def shard_names(self) -> Sequence[str]: + return [f"shard_{i}" for i in range(4)] + + def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: + # parse the shard name to get the shard number + shard_num = int(shard_name.split("_")[1]) + for i in range(10): + if shard_num * 10 + i == self.crash_point: + raise _CustomException(f"Crashing at {shard_num} {i} {self.crash_point}") + if i >= row: + yield [shard_num * 10 + i] * 10 + + with tempfile.TemporaryDirectory() as tmpdir, tempfile.TemporaryDirectory() as tmpdir2: + source = CrashingShardSource(4) + with pytest.raises(_CustomException): + build_or_load_cache(tmpdir, source, TestProcessor()) + + # kill the broker actor so that we can test recovery + ray.kill( + _get_builder_actor(tmpdir, source, TestProcessor()), + no_restart=True, + ) + + source = CrashingShardSource(5) + with pytest.raises(_CustomException): + build_or_load_cache(tmpdir, source, TestProcessor()) + + ray.kill( + _get_builder_actor(tmpdir, source, TestProcessor()), + no_restart=True, + ) + + # testing this doesn't throw + source = CrashingShardSource(1000) + reader1 = build_or_load_cache(tmpdir, source, TestProcessor(), await_finished=True) + + # compare to the original with no crash + reader2 = build_or_load_cache(tmpdir2, SimpleShardSource(), TestProcessor(), await_finished=True) + + assert len(list(reader1)) == 40 + check_datasets_equal(reader1, reader2) + + +@pytest.mark.ray +def test_no_hang_if_empty_shard_source(): + class EmptyShardSource(ShardedDataSource[list[int]]): + @property + def shard_names(self) -> Sequence[str]: + return [] + + def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: + raise RuntimeError("This should not be called") + + with tempfile.TemporaryDirectory() as tmpdir: + reader = build_or_load_cache(tmpdir, EmptyShardSource(), TestProcessor()) + assert list(reader) == [] + + +@pytest.mark.ray +def test_chunk_ordering_is_correct_with_slow_shards(): + class SlowShardSource(ShardedDataSource[list[int]]): + @property + def shard_names(self) -> Sequence[str]: + return ["shard_0", "shard_1"] + + def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: + max_count = 40 if shard_name == "shard_1" else 20 + shard_id = int(shard_name.split("_")[1]) + for i in range(0, max_count): + yield [i * 10 + shard_id] * 10 + + with tempfile.TemporaryDirectory() as tmpdir: + cache = build_or_load_cache( + tmpdir, + SlowShardSource(), + TestProcessor(1), + await_finished=False, + ) + + # now block until the cache is done + cache.await_finished(timeout=10) + + expected = process_interleave(TestProcessor(1), SlowShardSource()) + + check_datasets_equal(list(cache[:]), expected) + + +@pytest.mark.asyncio +@pytest.mark.ray +async def test_can_get_elems_before_finished(): + @ray.remote(num_cpus=0) + class Blocker: + def __init__(self): + self.future = asyncio.Future() + + async def block(self): + await self.future + + def unblock(self): + self.future.set_result(None) + + blocker_to_wait_on_test = Blocker.remote() + + class SlowShardSource(ShardedDataSource[list[int]]): + @property + def shard_names(self) -> Sequence[str]: + return ["shard_0"] + + def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: + for i in range(10): + yield [i] * 10 + 
ray.get(blocker_to_wait_on_test.block.remote()) + for i in range(10, 20): + yield [i] * 10 + + with tempfile.TemporaryDirectory() as tmpdir: + cache = build_or_load_cache( + tmpdir, SlowShardSource(), TestProcessor(5), await_finished=False, items_per_write=5 + ) + + # read the first 10 elements + # ensure the first 10 elements are [{"test": np.array([i] * 10)} for i in range(10)] + first_10 = list(await cache.get_batch(range(0, 10))) + + for i, x in enumerate(first_10): + np.testing.assert_array_equal(x["test"], np.array([i] * 10)) + + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(cache.get_batch(range(10, 20)), timeout=0.1) + + # then unblock: + ray.get(blocker_to_wait_on_test.unblock.remote()) + + # now ensure we can get the next 10 elements, which will be + # [{"test": np.array([i] * 10)} for i in range(10, 20)] + batch = await asyncio.wait_for(cache.get_batch(range(10, 20)), timeout=10) + + for i, x in enumerate(batch): + np.testing.assert_array_equal(x["test"], np.array([i + 10] * 10)) + + ray.get(blocker_to_wait_on_test.block.remote()) + + # now wait until the cache is finished. mostly so that the tempdir cleanup works + cache.await_finished(timeout=10) + + +@pytest.mark.skip("This test segfaults in CI. I think a ray bug") +@pytest.mark.ray +def test_shard_cache_crashes_if_processor_throws(): + class ThrowingProcessor(BatchProcessor[Sequence[int], dict[str, np.ndarray]]): + def __call__(self, batch: Sequence[Sequence[int]]): + raise RuntimeError("exc") + + @property + def output_exemplar(self) -> dict: + return {"test": np.array([0], dtype=np.int64)} + + @property + def batch_size(self) -> int: + return 8 + + @property + def num_cpus(self) -> int: + return 1 + + with tempfile.TemporaryDirectory() as tmpdir: + with pytest.raises(RuntimeError): + build_or_load_cache(tmpdir, SimpleShardSource(), ThrowingProcessor(), await_finished=True) + + +@pytest.mark.ray +@pytest.mark.skip("This test segfaults in CI. I think a ray bug") +def test_shard_cache_fails_with_multiple_shards_with_the_same_name(): + with tempfile.TemporaryDirectory() as tmpdir: + with open(f"{tmpdir}/data.txt", "w") as f: + f.write("") + + with pytest.raises(ValueError): + TextUrlDataSource( + [f"{tmpdir}/data.txt", f"{tmpdir}/data.txt"], + ) + + with open(f"{tmpdir}/data.txt.1", "w") as f: + f.write("") + + dataset = TextUrlDataSource( + [f"{tmpdir}/data.txt", f"{tmpdir}/data.txt.1"], + ) + + build_or_load_cache(tmpdir, dataset, TestProcessor(), await_finished=True) + + +@pytest.mark.skip("This test segfaults in CI. I think a ray bug") +@pytest.mark.ray +@pytest.mark.asyncio +async def test_shard_cache_fails_gracefully_with_unknown_file_type_async(): + with tempfile.TemporaryDirectory() as tmpdir: + with open(f"{tmpdir}/data.not_a_real_extension", "w") as f: + f.write("") + + dataset = TextUrlDataSource( + [f"{tmpdir}/data.not_a_real_extension"], + ) + + with pytest.raises(ValueError): + build_or_load_cache(tmpdir, dataset, TestProcessor(), await_finished=True) + + # now make sure it works in non-blocking mode + + cache = build_or_load_cache(tmpdir, dataset, TestProcessor(), await_finished=False) + + with pytest.raises(ValueError): + await cache.get_batch([0]) + + with pytest.raises(ValueError): + cache.await_finished(timeout=10) + + del cache + + +@pytest.mark.skip("This test segfaults in CI. 
I think a ray bug") +@pytest.mark.ray +def test_shard_cache_fails_gracefully_with_unknown_file_type(): + with tempfile.TemporaryDirectory() as tmpdir: + with open(f"{tmpdir}/data.not_a_real_extension", "w") as f: + f.write("") + + dataset = TextUrlDataSource( + [f"{tmpdir}/data.not_a_real_extension"], + ) + + with pytest.raises(ValueError): + build_or_load_cache(tmpdir, dataset, TestProcessor(), await_finished=True) + + # now make sure it works in non-blocking mode + + cache = build_or_load_cache(tmpdir, dataset, TestProcessor(), await_finished=False) + + with pytest.raises(ValueError): + cache.get_batch_sync([0]) + + with pytest.raises(ValueError): + cache.await_finished(timeout=10) + + del cache + + +@pytest.mark.ray +@pytest.mark.asyncio +async def test_backpressure_mechanism(): + parent = PretendParent.remote() + exemplar = np.array([1, 2, 3]) + with tempfile.TemporaryDirectory() as cache_dir: + shards = ["shard1", "shard2", "shard3"] + writer = _OrderedCacheWriter.remote( + parent, "test", exemplar, DEFAULT_BATCH_SIZE, cache_dir, shards, min_items_to_write=1 + ) + + # Simulate batches being processed + shard1_batch = [np.array([1, 2, 3])] + shard2_batch = [np.array([4, 5, 6])] + shard3_batch = [np.array([7, 8, 9])] + + # await writer.batch_finished.remote("shard1", 0, shard1_batch) + await writer.batch_finished.remote("shard2", 0, shard2_batch) + await writer.batch_finished.remote("shard3", 0, shard3_batch) + await writer.batch_finished.remote("shard1", 1, shard3_batch) + await writer.batch_finished.remote("shard1", 2, shard3_batch) + await writer.batch_finished.remote("shard1", 3, shard3_batch) + + # Check if backpressure is signaled + is_overwhelmed = await writer.is_overwhelmed.remote() + assert is_overwhelmed is True + + for i in range(4): + if (await parent.desired_next_item.remote()) == 0: + break + + await asyncio.sleep(0.1 * (i + 1) * (i + 1)) + else: + assert False, "Backpressure wasn't sent" + + await writer.batch_finished.remote("shard1", 0, shard1_batch) + + # Reduce the queue size to relieve backpressure + # Check if backpressure is relieved + is_overwhelmed = await writer.is_overwhelmed.remote() + assert is_overwhelmed is False + + for i in range(4): + if (await parent.desired_next_item.remote()) is None: + break + + await asyncio.sleep(0.1 * (i + 1) * (i + 1)) + else: + assert False, "Backpressure wasn't relieved" diff --git a/tests/test_replicated_loader.py b/tests/test_new_loader.py similarity index 62% rename from tests/test_replicated_loader.py rename to tests/test_new_loader.py index 431a1c0bb..e6f9a3dd7 100644 --- a/tests/test_replicated_loader.py +++ b/tests/test_new_loader.py @@ -1,5 +1,5 @@ -import itertools -from typing import Sequence +import asyncio +from typing import Optional, Sequence import jax import numpy as np @@ -9,26 +9,16 @@ from haliax import Axis from haliax.partitioning import ResourceAxis -import levanter.data -from levanter.data.loader import ReplicatedBatchLoader, check_sharded_consistency -from test_utils import skip_if_not_enough_devices +from levanter.data.dataset import AsyncDataset, ListAsyncDataset +from levanter.data.loader import DataLoader, check_sharded_consistency +from .test_utils import skip_if_not_enough_devices -def _small_dataset(seq_len=128, num_sequences=200) -> levanter.data.ShardableDataset[Sequence[int]]: - class SequenceDataset(levanter.data.ShardableDataset[np.ndarray]): - def __init__(self, sequences: Sequence[np.ndarray]): - self.sequences = sequences - def shard(self, shard_idx: int, num_shards: int) -> 
levanter.data.ShardableDataset[np.ndarray]: - return SequenceDataset(self.sequences[shard_idx::num_shards]) - - def __iter__(self): - yield from self.sequences - - # sequences = [list(range(i * 1000, i * 1000 + seq_len)) for i in range(num_sequences)] +def _small_dataset(seq_len=128, num_sequences=200) -> AsyncDataset[Sequence[int]]: sequences = [np.arange(seq_len) + 1000 * i for i in range(num_sequences)] - return SequenceDataset(sequences) + return ListAsyncDataset(sequences, is_complete=True) @skip_if_not_enough_devices(2) @@ -45,9 +35,9 @@ def test_local_batched_data_loading_model_axis_2(): seq_len = 128 cache = _small_dataset(seq_len) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(cache, mesh, Batch) + loader = DataLoader(Batch, cache, max_buffered_batches=10, mesh=mesh, axis_resources=None) - batches = list(itertools.islice(loader, 10)) + batches = list(loader) for batch in batches: check_sharded_consistency(batch, check_disjoint_indices_are_different=True) @@ -65,36 +55,46 @@ def test_local_batched_data_loading_model_axis_1(): seq_len = 128 cache = _small_dataset(seq_len) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(cache, mesh, Batch) + loader = DataLoader(Batch, cache, max_buffered_batches=10, mesh=mesh, axis_resources=None) - batches = list(itertools.islice(loader, 10)) + batches = list(loader) for batch in batches: check_sharded_consistency(batch, check_disjoint_indices_are_different=True) -class StructuredDataset(levanter.data.ShardableDataset): - def __init__(self, seq_len, begin, end, stride): +class StructuredDataset(AsyncDataset): + def __init__(self, seq_len): self.seq_len = seq_len - self.begin = begin - self.end = end - self.stride = stride + self.begin = 0 + self.end = 256 + self.stride = 1 + + async def async_len(self) -> int: + return (self.end - self.begin) // self.stride - def __getitem__(self, item): + async def getitem_async(self, index: int) -> dict: + index = self.begin + index * self.stride return { - "input_ids": np.arange(self.seq_len, dtype=np.int32) + item * 1000, - "labels": np.arange(self.seq_len, dtype=np.int32) + item * 1000, + "input_ids": np.arange(self.seq_len, dtype=np.int32) + index * 1000, + "labels": np.arange(self.seq_len, dtype=np.int32) + index * 1000, "extra": { - "input_ids": np.arange(self.seq_len, dtype=np.int32) + item * 1000, - "mask": np.arange(self.seq_len * 2, dtype=np.int32).reshape(-1, 2) + item * 1000, + "input_ids": np.arange(self.seq_len, dtype=np.int32) + index * 1000, + "mask": np.arange(self.seq_len * 2, dtype=np.int32).reshape(-1, 2) + index * 1000, }, } - def __iter__(self): - for i in range(self.begin, self.end, self.stride): - yield self[i] + async def final_length_is_known(self) -> bool: + return True + + def is_finite(self) -> bool: + return True - def shard(self, shard_id: int, num_shards: int): - return StructuredDataset(self.seq_len, self.begin + shard_id, self.end, self.stride * num_shards) + async def current_len(self) -> Optional[int]: + return await self.async_len() + + async def get_batch(self, indices: Sequence[int]): + out = await asyncio.gather(*(self.getitem_async(i) for i in indices)) + return out def test_structured_batches_model_axis_1(): @@ -107,11 +107,11 @@ def test_structured_batches_model_axis_1(): ) with mesh, haliax.axis_mapping({"batch": ResourceAxis.DATA}): seq_len = 128 - dataset = StructuredDataset(seq_len, 0, 256, 1) + dataset = StructuredDataset(seq_len) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(dataset, mesh, Batch) + 
loader = DataLoader(Batch, dataset, max_buffered_batches=10, mesh=mesh, axis_resources=None) - batches = list(itertools.islice(loader, 10)) + batches = list(loader) for batch in batches: check_sharded_consistency(batch, check_disjoint_indices_are_different=True) @@ -127,16 +127,16 @@ def test_structured_batches_model_axis_2(): ) with mesh, haliax.axis_mapping({"batch": ResourceAxis.DATA}): seq_len = 128 - dataset = StructuredDataset(seq_len, 0, 256, 1) + dataset = StructuredDataset(seq_len) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(dataset, mesh, Batch) + loader = DataLoader(Batch, dataset, max_buffered_batches=10, mesh=mesh, axis_resources=None) - batches = list(itertools.islice(loader, 10)) + batches = list(loader) for batch in batches: check_sharded_consistency(batch, check_disjoint_indices_are_different=True) -class StructuredDatasetWithNames(levanter.data.ShardableDataset): +class StructuredDatasetWithNames(AsyncDataset): def __init__(self, Height: Axis, Width: Axis, begin, end, stride): self.Height = Height self.Width = Width @@ -144,6 +144,33 @@ def __init__(self, Height: Axis, Width: Axis, begin, end, stride): self.end = end self.stride = stride + async def final_length_is_known(self) -> bool: + return True + + def is_finite(self) -> bool: + return True + + async def current_len(self) -> Optional[int]: + return True + + async def get_batch(self, indices: Sequence[int]): + out = await asyncio.gather(*(self.getitem_async(i) for i in indices)) + return out + + async def async_len(self) -> int: + return (self.end - self.begin) // self.stride + + async def getitem_async(self, index: int) -> dict: + index = self.begin + index * self.stride + return { + "input_ids": self._gen_image(index), + "labels": self._gen_image(index), + "extra": { + "input_ids": self._gen_image(index), + "mask": haliax.arange(self.Height) + index * 1000, + }, + } + def _gen_image(self, index): image = ( np.arange(self.Height.size * self.Width.size, dtype=np.int32).reshape(self.Height.size, self.Width.size) @@ -152,25 +179,10 @@ def _gen_image(self, index): return haliax.named(image, (self.Height, self.Width)) - def __getitem__(self, item): - return { - "input_ids": self._gen_image(item), - "labels": self._gen_image(item), - "extra": { - "input_ids": self._gen_image(item), - "mask": haliax.arange(self.Height) + item * 1000, - }, - } - def __iter__(self): for i in range(self.begin, self.end, self.stride): yield self[i] - def shard(self, shard_id: int, num_shards: int): - return StructuredDatasetWithNames( - self.Height, self.Width, self.begin + shard_id, self.end, self.stride * num_shards - ) - def test_structured_batches_model_axis_1_with_names(): devices = jax.devices() @@ -183,11 +195,11 @@ def test_structured_batches_model_axis_1_with_names(): with mesh, haliax.axis_mapping({"batch": ResourceAxis.DATA}): Height = Axis("Height", 16) Width = Axis("Width", 16) - dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) + dataset = StructuredDatasetWithNames(Height, Width, 0, len(devices) * 10, 1) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(dataset, mesh, Batch) + loader = DataLoader(Batch, dataset, max_buffered_batches=10, mesh=mesh, axis_resources=None) - batches = list(itertools.islice(loader, 10)) + batches = list(loader) for batch in batches: check_sharded_consistency(batch, check_disjoint_indices_are_different=True) @@ -208,9 +220,9 @@ def test_structured_batches_model_axis_2_with_names(): Width = Axis("Width", 16) dataset = 
StructuredDatasetWithNames(Height, Width, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(dataset, mesh, Batch) + loader = DataLoader(Batch, dataset, max_buffered_batches=10, mesh=mesh, axis_resources=None) - batches = list(itertools.islice(loader, 10)) + batches = list(loader) for batch in batches: check_sharded_consistency(batch, check_disjoint_indices_are_different=True) @@ -230,8 +242,7 @@ def test_structured_batches_model_axis_2_subsharded(): with mesh, haliax.axis_mapping({"batch": ResourceAxis.DATA, Height.name: ResourceAxis.MODEL}): dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(dataset, mesh, Batch) + loader = DataLoader(Batch, dataset, max_buffered_batches=10, mesh=mesh, axis_resources=None) - batches = list(itertools.islice(loader, 10)) - for batch in batches: + for batch in iter(loader): check_sharded_consistency(batch, check_disjoint_indices_are_different=True) diff --git a/tests/test_newdataset.py b/tests/test_newdataset.py new file mode 100644 index 000000000..030095b41 --- /dev/null +++ b/tests/test_newdataset.py @@ -0,0 +1,142 @@ +import asyncio + +import jax.random +import pytest + +from levanter.data import EraShufflingDataset, PermutationDataset +from levanter.data.dataset import ListAsyncDataset + + +@pytest.mark.asyncio +async def test_length_of_sequence_dataset_is_accurate(): + data = [1, 2, 3] + dataset = ListAsyncDataset(data) + assert (await dataset.current_len()) == 3 + assert not (await dataset.final_length_is_known()) + dataset.finalize() + assert (await dataset.current_len()) == 3 + assert await dataset.final_length_is_known() + assert (await dataset.async_len()) == 3 + + +@pytest.mark.asyncio +async def test_list_dataset_get_item_returns_correct_item(): + data = ["a", "b", "c"] + dataset = ListAsyncDataset(data) + assert await dataset.getitem_async(1) == "b" + + +@pytest.mark.asyncio +async def test_list_async_dataset_appends_and_finalizes_correctly(): + dataset = ListAsyncDataset([]) + dataset.append("a") + dataset.finalize() + assert await dataset.async_len() == 1 + assert await dataset.get_batch([0]) == ["a"] + + +@pytest.mark.asyncio +async def test_permutation_dataset_is_at_least_sometimes_permuted(): + for seed in range(10): + data = [1, 2, 3, 4] + dataset = ListAsyncDataset(data, is_complete=True) + permuted_dataset = PermutationDataset(dataset, jax.random.PRNGKey(seed)) + if await permuted_dataset.get_batch([0, 1, 2, 3]) != [1, 2, 3, 4]: + return + + pytest.fail("PermutationDataset did not permute the data") + + +@pytest.mark.asyncio +async def test_era_shuffling_dataset_returns_correct_length(): + data = list(range(100)) + dataset = ListAsyncDataset(data, is_complete=False) + era_length = 10 + key = jax.random.PRNGKey(0) + shuffling_dataset = EraShufflingDataset(dataset, era_length, key=key) + assert await shuffling_dataset.current_len() == 100 + assert not await shuffling_dataset.final_length_is_known() + + dataset.append(1) + assert await shuffling_dataset.current_len() == 100 + + +@pytest.mark.asyncio +async def test_era_shuffling_dataset_get_batch_returns_shuffled_batch(): + data = list(range(20)) + dataset = ListAsyncDataset(data) + dataset.finalize() + era_length = 5 + key = jax.random.PRNGKey(0) + shuffling_dataset = EraShufflingDataset(dataset, era_length, key=key) + batch_indices = [0, 1, 2, 3, 4] + batch = await shuffling_dataset.get_batch(batch_indices) + assert set(batch) == set([0, 1, 2, 3, 4]) # Ensures all elements are 
from the first era but does not assume order + assert batch != [0, 1, 2, 3, 4] # Ensures the batch is shuffled + + +@pytest.mark.asyncio +async def test_era_shuffling_can_grow(): + data = list(range(5)) + dataset = ListAsyncDataset(data) + era_length = 5 + key = jax.random.PRNGKey(0) + shuffling_dataset = EraShufflingDataset(dataset, era_length, key=key) + batch_indices = [0, 1, 2, 3, 4] + batch = await shuffling_dataset.get_batch(batch_indices) + assert set(batch) == set([0, 1, 2, 3, 4]) + + for i in range(5): + dataset.append(i + 5) + + assert await shuffling_dataset.current_len() == 10 + assert not await shuffling_dataset.final_length_is_known() + batch = await shuffling_dataset.get_batch(list(range(10))) + + assert set(batch) == set(range(10)) + assert set(batch[0:5]) == set([0, 1, 2, 3, 4]) + assert set(batch[5:10]) == set([5, 6, 7, 8, 9]) + + # now make sure that we can await data and it does get fulfilled + # this should timeout if we try to await it + coro = dataset.get_batch([11]) + try: + await asyncio.wait_for(coro, timeout=0.1) + pytest.fail("Should have timed out") + except asyncio.TimeoutError: + pass + + async def append_data(): + await asyncio.sleep(0.1) + for i in range(10, 15): + dataset.append(i) + + coro = dataset.getitem_async(11) + + _, r = await asyncio.gather(append_data(), coro) + assert r in range(10, 15) + + coro2 = shuffling_dataset.wait_until_len_at_least(20) + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(coro2, timeout=0.1) + + assert await shuffling_dataset.current_len() == 15 + + coro2 = shuffling_dataset.wait_until_len_at_least(20) + dataset.append(15) + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(coro2, timeout=0.1) + + assert await shuffling_dataset.current_len() == 15 + + coro2 = shuffling_dataset.wait_until_len_at_least(20) + dataset.finalize() + await asyncio.wait_for(coro2, timeout=0.1) + + assert await dataset.async_len() == 16 + assert await shuffling_dataset.current_len() == 16 + + coro = shuffling_dataset.get_batch(list(range(16))) + + batch = await coro + assert set(batch) == set(range(16)) diff --git a/tests/test_prp.py b/tests/test_prp.py new file mode 100644 index 000000000..6c549eabf --- /dev/null +++ b/tests/test_prp.py @@ -0,0 +1,87 @@ +import jax.numpy as jnp +import jax.random as jrandom +import pytest + +from levanter.data._prp import Permutation + + +def test_permutation_creates_valid_instance(): + length = 100 + prng_key = jrandom.PRNGKey(0) + permutation = Permutation(length, prng_key) + assert permutation.length == length + assert permutation._a > 0 and permutation._a < length + assert permutation._b >= 0 and permutation._b < length + + +def test_permutation_with_single_index_returns_correct_value(): + length = 10 + prng_key = jrandom.PRNGKey(0) + permutation = Permutation(length, prng_key) + index = 5 + result = permutation(index) + assert isinstance(result, int) + assert result != index # In most cases, result should not equal the input for a permutation + + +def test_permutation_with_array_returns_correct_values(): + length = 10 + prng_key = jrandom.PRNGKey(0) + permutation = Permutation(length, prng_key) + indices = jnp.arange(length) + results = permutation(indices) + assert isinstance(results, jnp.ndarray) + assert len(results) == length + assert jnp.sum(results == indices) <= 2 + + +def test_permutation_is_bijective_over_full_range(): + length = 10 + prng_key = jrandom.PRNGKey(0) + permutation = Permutation(length, prng_key) + indices = jnp.arange(length) + permuted = 
permutation(indices) + # Check if all elements are unique, which is a necessary condition for a bijective function + assert len(jnp.unique(permuted)) == length + + +def test_permutation_handles_edge_case_length_one(): + length = 1 + prng_key = jrandom.PRNGKey(0) + permutation = Permutation(length, prng_key) + result = permutation(0) + assert result == 0 # With length 1, the only valid output is the input it + + +def test_permutation_rejects_invalid_indices(): + length = 10 + prng_key = jrandom.PRNGKey(0) + permutation = Permutation(length, prng_key) + with pytest.raises(IndexError): + permutation(-1) # Test negative index + with pytest.raises(IndexError): + permutation(length) # Test index equal to length + + +def test_permutation_is_deterministic(): + length = 4 + prng_key = jrandom.PRNGKey(0) + permutation = Permutation(length, prng_key) + indices = jnp.arange(length) + results = permutation(indices) + prng_key = jrandom.PRNGKey(0) + permutation = Permutation(length, prng_key) + results2 = permutation(indices) + assert jnp.all(results == results2) + + +def test_permutation_is_deterministic1(): + length = 4 + prng_key = jrandom.PRNGKey(1) + permutation = Permutation(length, prng_key) + indices = jnp.arange(length) + results = permutation(indices) + prng_key = jrandom.PRNGKey(1) + permutation = Permutation(length, prng_key) + results2 = permutation(indices) + assert jnp.all(results == results2) diff --git a/tests/test_shard_cache.py b/tests/test_shard_cache.py deleted file mode 100644 index 7500307db..000000000 --- a/tests/test_shard_cache.py +++ /dev/null @@ -1,383 +0,0 @@ -import asyncio -import tempfile -from typing import Iterator, List, Sequence - -import pyarrow as pa -import pytest -import ray - -from levanter.data._preprocessor import BatchProcessor -from levanter.data.shard_cache import ChunkMetadata, SerialCacheWriter, _get_broker_actor, build_or_load_cache -from levanter.data.sharded_dataset import ShardedDataset, TextUrlDataset -from levanter.utils.py_utils import logical_cpu_core_count -from test_utils import skip_in_ci - - -def setup_module(module): - ray.init("local", num_cpus=max(2 * logical_cpu_core_count(), 8)) # 2x cpu count is faster on my m1 - - -def teardown_module(module): - ray.shutdown() - - -# tests to write: -# - test idempotency of writes - - -class TestProcessor(BatchProcessor[Sequence[int]]): - def __init__(self, batch_size: int = 8): - self._batch_size = batch_size - - def __call__(self, batch: Sequence[Sequence[int]]) -> pa.RecordBatch: - return pa.RecordBatch.from_arrays([pa.array(batch)], ["test"]) - - @property - def batch_size(self) -> int: - return self._batch_size - - @property - def num_cpus(self) -> int: - return 1 - - -class SimpleShardSource(ShardedDataset[List[int]]): - def __init__(self, num_shards: int = 4): - self._num_shards = num_shards - - @property - def shard_names(self) -> Sequence[str]: - return [f"shard_{i}" for i in range(self._num_shards)] - - def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[List[int]]: - # parse the shard name to get the shard number - shard_num = int(shard_name.split("_")[1]) - return ([shard_num * 10 + i] * 10 for i in range(row, 10)) - - -def simple_process(processor, source): - result = [] - for shard_name in source.shard_names: - for batch in source.open_shard(shard_name): - result.append(processor([batch])) - - return result - - -@pytest.mark.ray -@pytest.mark.parametrize("shards_to_read_at_once", [1, 2, 4]) -def test_cache_simple(shards_to_read_at_once): - td = tempfile.TemporaryDirectory() - 
with td as tmpdir: - ray_ds = build_or_load_cache( - tmpdir, - SimpleShardSource(), - TestProcessor(), - await_finished=True, - # shards_to_read_at_once=shards_to_read_at_once, - ) - - simple_processed = simple_process(TestProcessor(), SimpleShardSource()) - - assert list(ray_ds) == list(simple_processed) - - -@pytest.mark.ray -def test_cache_remembers_its_cached(): - directory = tempfile.TemporaryDirectory() - with directory as tmpdir: - ds1 = build_or_load_cache(tmpdir, SimpleShardSource(), TestProcessor()) - - class ThrowingProcessor(BatchProcessor[Sequence[int]]): - def __call__(self, batch: Sequence[Sequence[int]]) -> pa.RecordBatch: - raise RuntimeError("This should not be called") - - @property - def batch_size(self) -> int: - return 8 - - @property - def num_cpus(self) -> int: - return 1 - - # testing this doesn't throw - ds2 = build_or_load_cache(tmpdir, SimpleShardSource(), ThrowingProcessor(), await_finished=True) - - assert list(ds1) == list(ds2) - # ensure we delete tmpdir, since something is holding onto it - - -class _CustomException(Exception): - pass - - -@pytest.mark.ray -@skip_in_ci -def test_cache_recover_from_crash(): - class CrashingShardSource(ShardedDataset[List[int]]): - def __init__(self, crash_point: int): - self.crash_point = crash_point - - @property - def shard_names(self) -> Sequence[str]: - return [f"shard_{i}" for i in range(4)] - - def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[List[int]]: - # parse the shard name to get the shard number - shard_num = int(shard_name.split("_")[1]) - for i in range(10): - if shard_num * 10 + i == self.crash_point: - raise _CustomException(f"Crashing at {shard_num} {i} {self.crash_point}") - if i >= row: - yield [shard_num * 10 + i] * 10 - - with tempfile.TemporaryDirectory() as tmpdir, tempfile.TemporaryDirectory() as tmpdir2: - source = CrashingShardSource(4) - with pytest.raises(_CustomException): - build_or_load_cache(tmpdir, source, TestProcessor()) - - # kill the broker actor so that we can test recovery - ray.kill(_get_broker_actor(tmpdir, source, TestProcessor()), no_restart=True) - - source = CrashingShardSource(5) - with pytest.raises(_CustomException): - build_or_load_cache(tmpdir, source, TestProcessor()) - - ray.kill(_get_broker_actor(tmpdir, source, TestProcessor()), no_restart=True) - - # testing this doesn't throw - source = CrashingShardSource(1000) - reader1 = build_or_load_cache(tmpdir, source, TestProcessor(), batch_size=1, await_finished=True) - - # compare to the original with no crash - reader2 = build_or_load_cache(tmpdir2, SimpleShardSource(), TestProcessor(), batch_size=1, await_finished=True) - - assert list(reader1) == list(reader2) - assert len(list(reader1)) == 40 - - -@pytest.mark.ray -def test_no_hang_if_empty_shard_source(): - class EmptyShardSource(ShardedDataset[List[int]]): - @property - def shard_names(self) -> Sequence[str]: - return [] - - def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[List[int]]: - raise RuntimeError("This should not be called") - - with tempfile.TemporaryDirectory() as tmpdir: - reader = build_or_load_cache(tmpdir, EmptyShardSource(), TestProcessor(), batch_size=1) - assert list(reader) == [] - - -@skip_in_ci -@pytest.mark.ray -def test_chunk_ordering_is_correct_with_slow_shards(): - class SlowShardSource(ShardedDataset[List[int]]): - @property - def shard_names(self) -> Sequence[str]: - return ["shard_0", "shard_1"] - - def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[List[int]]: - max_count = 40 if shard_name 
== "shard_1" else 20 - for i in range(0, max_count): - yield [i] * 10 - - with tempfile.TemporaryDirectory() as tmpdir: - cache = build_or_load_cache( - tmpdir, - SlowShardSource(), - TestProcessor(1), - batch_size=1, - rows_per_chunk=10, - await_finished=False, - ) - - # now block until the cache is done - cache.await_finished(timeout=10) - - # now check that the chunks are in the right order - # TODO: this is a bit gross - chunks: List[ChunkMetadata] = ray.get([cache._broker.get_chunk.remote(i) for i in range(6)]) - assert chunks[0].name == "shard_0/chunk-0" - assert chunks[1].name == "shard_1/chunk-0" - assert chunks[2].name == "shard_0/chunk-1" - assert chunks[3].name == "shard_1/chunk-1" - assert chunks[4].name == "shard_1/chunk-2" - assert chunks[5].name == "shard_1/chunk-3" - - # make sure there's not a 7th chunk - chunk = ray.get(cache._broker.get_chunk.remote(6), timeout=0.5) - assert chunk is None - - -@skip_in_ci -@pytest.mark.ray -def test_can_get_chunk_before_finished(): - @ray.remote(num_cpus=0) - class Blocker: - def __init__(self): - self.future = asyncio.Future() - - async def block(self): - await self.future - - def unblock(self): - self.future.set_result(None) - - blocker_to_wait_on_test = Blocker.remote() - - class SlowShardSource(ShardedDataset[List[int]]): - @property - def shard_names(self) -> Sequence[str]: - return ["shard_0"] - - def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[List[int]]: - for i in range(10): - yield [i] * 10 - ray.get(blocker_to_wait_on_test.block.remote()) - for i in range(10, 20): - yield [i] * 10 - - with tempfile.TemporaryDirectory() as tmpdir: - cache = build_or_load_cache( - tmpdir, SlowShardSource(), TestProcessor(5), batch_size=1, rows_per_chunk=10, await_finished=False - ) - - def back_to_py(batch: pa.RecordBatch): - return list(batch["test"].values.to_numpy()) - - chunk = [back_to_py(batch) for batch in cache.read_chunk(0)] - - assert [list(x) for x in chunk] == [[i] * 10 for i in range(10)] - - with pytest.raises(TimeoutError): - cache.get_chunk(1, timeout=0.1) - - ray.get(blocker_to_wait_on_test.unblock.remote()) - - chunk = [back_to_py(batch) for batch in cache.read_chunk(1)] - - assert [list(x) for x in chunk] == [[i] * 10 for i in range(10, 20)] - - ray.get(blocker_to_wait_on_test.block.remote()) - - # now wait until the cache is finished. 
mostly so that the tempdir cleanup works - cache.await_finished(timeout=10) - - -@skip_in_ci -@pytest.mark.ray -def test_shard_cache_crashes_if_processor_throws(): - class ThrowingProcessor(BatchProcessor[Sequence[int]]): - def __call__(self, batch: Sequence[Sequence[int]]) -> pa.RecordBatch: - raise RuntimeError("exc") - - @property - def batch_size(self) -> int: - return 8 - - @property - def num_cpus(self) -> int: - return 1 - - with tempfile.TemporaryDirectory() as tmpdir: - with pytest.raises(RuntimeError): - build_or_load_cache(tmpdir, SimpleShardSource(), ThrowingProcessor(), await_finished=True) - - -@skip_in_ci -@pytest.mark.ray -def test_map_batches_and_map_shard_cache(): - td = tempfile.TemporaryDirectory() - with td as tmpdir: - ray_ds = ( - SimpleShardSource() - .map(lambda list: list * 2) - .map_batches(TestProcessor(), 8) - .map(lambda d: {"q": d["test"]}) - .build_or_load_cache(tmpdir, await_finished=True) - ) - - def composite_fn(list): - assert len(list) == 1 - return {"q": list[0] * 2} - - simple_processed = simple_process(composite_fn, SimpleShardSource()) - - # we internally change all the int lists in the ray_ds to np arrays, so we need to convert them back to lists - ray_entries = [] - for entry in ray_ds: - assert entry.keys() == {"q"} - ray_entries.append({"q": entry["q"].tolist()}) - - assert ray_entries == list(simple_processed) - - -@pytest.mark.ray -def test_serial_cache_writer(): - with tempfile.TemporaryDirectory() as tmpdir1, tempfile.TemporaryDirectory() as tmpdir2: - source = SimpleShardSource(num_shards=4) - processor = TestProcessor() - - with SerialCacheWriter(tmpdir1, rows_per_chunk=8) as writer: - for shard_name in source.shard_names: - for batch in source.open_shard(shard_name): - writer.write_batch(processor([batch])) - - serial = writer.result(batch_size=1) - ray_ds = build_or_load_cache(tmpdir2, source, processor, await_finished=True) - - def freeze_batch(batch): - # make it hashable - return tuple(batch["test"].values.to_numpy()) - - assert set(freeze_batch(batch) for batch in serial) == set(freeze_batch(batch) for batch in ray_ds) - - -@skip_in_ci -@pytest.mark.ray -def test_shard_cache_fails_with_multiple_shards_with_the_same_name(): - with tempfile.TemporaryDirectory() as tmpdir: - with open(f"{tmpdir}/data.txt", "w") as f: - f.write("") - - with pytest.raises(ValueError): - TextUrlDataset( - [f"{tmpdir}/data.txt", f"{tmpdir}/data.txt"], - ) - - with open(f"{tmpdir}/data.txt.1", "w") as f: - f.write("") - - dataset = TextUrlDataset( - [f"{tmpdir}/data.txt", f"{tmpdir}/data.txt.1"], - ) - - build_or_load_cache(tmpdir, dataset, TestProcessor(), await_finished=True) - - -@skip_in_ci -@pytest.mark.ray -def test_shard_cache_fails_gracefully_with_unknown_file_type(): - with tempfile.TemporaryDirectory() as tmpdir: - with open(f"{tmpdir}/data.not_a_real_extension", "w") as f: - f.write("") - - dataset = TextUrlDataset( - [f"{tmpdir}/data.not_a_real_extension"], - ) - - with pytest.raises(ValueError): - build_or_load_cache(tmpdir, dataset, TestProcessor(), await_finished=True) - - # now make sure it works in non-blocking mode - - cache = build_or_load_cache(tmpdir, dataset, TestProcessor(), await_finished=False) - - with pytest.raises(ValueError): - cache.get_chunk(0, timeout=5) - - with pytest.raises(ValueError): - cache.await_finished(timeout=10) diff --git a/tests/test_sharded_dataset.py b/tests/test_sharded_dataset.py index 5183c55a4..b3c8bcc8d 100644 --- a/tests/test_sharded_dataset.py +++ b/tests/test_sharded_dataset.py @@ -1,6 +1,6 @@ import 
tempfile -from levanter.data.sharded_dataset import AudioTextUrlDataset, _sniff_format_for_dataset +from levanter.data.sharded_datasource import AudioTextUrlDataSource, _sniff_format_for_dataset from test_utils import skip_if_no_soundlibs @@ -26,4 +26,4 @@ def test_sniff_format_for_json(): @skip_if_no_soundlibs def test_resolve_audio_pointer(): - AudioTextUrlDataset.resolve_audio_pointer("https://ccrma.stanford.edu/~jos/mp3/trumpet.mp3", 16_000) + AudioTextUrlDataSource.resolve_audio_pointer("https://ccrma.stanford.edu/~jos/mp3/trumpet.mp3", 16_000) diff --git a/tests/test_sharded_loader.py b/tests/test_sharded_loader.py deleted file mode 100644 index ec46fb6a6..000000000 --- a/tests/test_sharded_loader.py +++ /dev/null @@ -1,299 +0,0 @@ -import itertools -from typing import Sequence - -import jax -import jax.numpy as jnp -import numpy as np -from jax.sharding import Mesh - -import haliax as hax -from haliax import Axis -from haliax.partitioning import ResourceAxis - -import levanter.data -from levanter.data.loader import ShardedBatchLoader, check_sharded_consistency -from test_utils import skip_if_not_enough_devices - - -NUM_SHARDS_TINY = 16 - - -def _small_dataset(seq_len=128, num_sequences=200) -> levanter.data.ShardableDataset[Sequence[int]]: - class SequenceDataset(levanter.data.ShardableDataset[np.ndarray]): - def __init__(self, sequences: Sequence[np.ndarray]): - self.sequences = sequences - - def shard(self, shard_idx: int, num_shards: int) -> levanter.data.ShardableDataset[np.ndarray]: - return SequenceDataset(self.sequences[shard_idx::num_shards]) - - def __iter__(self): - yield from self.sequences - - # sequences = [list(range(i * 1000, i * 1000 + seq_len)) for i in range(num_sequences)] - sequences = [np.arange(seq_len) + 1000 * i for i in range(num_sequences)] - - return SequenceDataset(sequences) - - -@skip_if_not_enough_devices(2) -def test_sharded_data_loading_model_axis_2(): - devices = jax.devices() - model_axis_size = 2 - - mesh = Mesh( - np.array(devices).reshape(1, -1, model_axis_size), - (ResourceAxis.REPLICA, ResourceAxis.DATA, ResourceAxis.MODEL), - ) - with mesh, hax.axis_mapping({"batch": (ResourceAxis.REPLICA, ResourceAxis.DATA)}): - seq_len = 128 - cache = _small_dataset(seq_len) - Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(cache, mesh, Batch) - - batches = list(itertools.islice(loader, 10)) - for batch in batches: - check_sharded_consistency(batch, check_disjoint_indices_are_different=True) - - -def test_sharded_data_loading_model_axis_1(): - devices = jax.devices() - model_axis_size = 1 - - mesh = Mesh( - np.array(devices).reshape(1, -1, model_axis_size), - (ResourceAxis.REPLICA, ResourceAxis.DATA, ResourceAxis.MODEL), - ) - with mesh, hax.axis_mapping({"batch": (ResourceAxis.REPLICA, ResourceAxis.DATA)}): - seq_len = 128 - cache = _small_dataset(seq_len) - Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(cache, mesh, Batch) - - batches = list(itertools.islice(loader, 10)) - for batch in batches: - check_sharded_consistency(batch, check_disjoint_indices_are_different=True) - - -class StructuredDataset(levanter.data.ShardableDataset): - def __init__(self, seq_len, begin, end, stride): - self.seq_len = seq_len - self.begin = begin - self.end = end - self.stride = stride - - def __getitem__(self, item): - return { - "input_ids": np.arange(self.seq_len, dtype=np.int32) + item * 1000, - "labels": np.arange(self.seq_len, dtype=np.int32) + item * 1000, - "extra": { - "input_ids": np.arange(self.seq_len, dtype=np.int32) + item * 
1000, - "mask": np.arange(self.seq_len * 2, dtype=np.int32).reshape(-1, 2) + item * 1000, - }, - } - - def __iter__(self): - for i in range(self.begin, self.end, self.stride): - yield self[i] - - def shard(self, shard_id: int, num_shards: int): - return StructuredDataset(self.seq_len, self.begin + shard_id, self.end, self.stride * num_shards) - - -def test_structured_batches_model_axis_1(): - devices = jax.devices() - model_axis_size = 1 - - mesh = Mesh( - np.array(devices).reshape(1, -1, model_axis_size), - (ResourceAxis.REPLICA, ResourceAxis.DATA, ResourceAxis.MODEL), - ) - with mesh, hax.axis_mapping({"batch": (ResourceAxis.REPLICA, ResourceAxis.DATA)}): - seq_len = 128 - dataset = StructuredDataset(seq_len, 0, 256, 1) - Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(dataset, mesh, Batch) - - batches = list(itertools.islice(loader, 10)) - for batch in batches: - check_sharded_consistency(batch, check_disjoint_indices_are_different=True) - - -class ScalarDataset(levanter.data.ShardableDataset[hax.NamedArray]): - def __init__(self, begin, end, stride): - self.begin = begin - self.end = end - self.stride = stride - - def __getitem__(self, item): - return hax.named(jnp.array(item), ()) - - def __iter__(self): - for i in range(self.begin, self.end, self.stride): - yield self[i] - - def shard(self, shard_id: int, num_shards: int): - return ScalarDataset(self.begin + shard_id, self.end, self.stride * num_shards) - - -def test_can_batch_named_scalars(): - - devices = jax.devices() - model_axis_size = 1 - - mesh = Mesh( - np.array(devices).reshape(1, -1, model_axis_size), - (ResourceAxis.REPLICA, ResourceAxis.DATA, ResourceAxis.MODEL), - ) - with mesh, hax.axis_mapping({"batch": (ResourceAxis.REPLICA, ResourceAxis.DATA)}): - dataset = ScalarDataset(0, 256, 1) - Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(dataset, mesh, Batch) - - batches = list(itertools.islice(loader, 10)) - for batch in batches: - check_sharded_consistency(batch, check_disjoint_indices_are_different=True) - - -@skip_if_not_enough_devices(2) -def test_structured_batches_model_axis_2(): - devices = jax.devices() - model_axis_size = 2 - - mesh = Mesh( - np.array(devices).reshape(1, -1, model_axis_size), - (ResourceAxis.REPLICA, ResourceAxis.DATA, ResourceAxis.MODEL), - ) - with mesh, hax.axis_mapping({"batch": (ResourceAxis.REPLICA, ResourceAxis.DATA)}): - seq_len = 128 - dataset = StructuredDataset(seq_len, 0, 256, 1) - Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(dataset, mesh, Batch) - - batches = list(itertools.islice(loader, 10)) - for batch in batches: - check_sharded_consistency(batch, check_disjoint_indices_are_different=True) - - -class StructuredDatasetWithNames(levanter.data.ShardableDataset): - def __init__(self, Height: Axis, Width: Axis, begin, end, stride): - self.Height = Height - self.Width = Width - self.begin = begin - self.end = end - self.stride = stride - - def _gen_image(self, index): - image = ( - np.arange(self.Height.size * self.Width.size, dtype=np.int32).reshape(self.Height.size, self.Width.size) - + index * 1000 - ) - - return hax.named(image, (self.Height, self.Width)) - - def __getitem__(self, item): - return { - "input_ids": self._gen_image(item), - "labels": self._gen_image(item), - "extra": { - "input_ids": self._gen_image(item), - "mask": hax.arange(self.Height) + item * 1000, - }, - "id": hax.named(jnp.array(item), ()), - } - - def __iter__(self): - for i in range(self.begin, self.end, self.stride): - yield self[i] - - def 
shard(self, shard_id: int, num_shards: int): - return StructuredDatasetWithNames( - self.Height, self.Width, self.begin + shard_id, self.end, self.stride * num_shards - ) - - -def test_structured_batches_model_axis_1_with_names(): - devices = jax.devices() - model_axis_size = 1 - - mesh = Mesh( - np.array(devices).reshape(1, -1, model_axis_size), - (ResourceAxis.REPLICA, ResourceAxis.DATA, ResourceAxis.MODEL), - ) - with mesh, hax.axis_mapping({"batch": (ResourceAxis.REPLICA, ResourceAxis.DATA)}): - Height = Axis("Height", 16) - Width = Axis("Width", 16) - dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) - Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(dataset, mesh, Batch) - - batches = list(itertools.islice(loader, 10)) - for batch in batches: - check_sharded_consistency(batch, check_disjoint_indices_are_different=True) - - -@skip_if_not_enough_devices(2) -def test_structured_batches_model_axis_2_with_names(): - devices = jax.devices() - model_axis_size = 2 - - mesh = Mesh( - np.array(devices).reshape(1, -1, model_axis_size), - (ResourceAxis.REPLICA, ResourceAxis.DATA, ResourceAxis.MODEL), - ) - with mesh, hax.axis_mapping({"batch": (ResourceAxis.REPLICA, ResourceAxis.DATA)}): - Height = Axis("Height", 16) - Width = Axis("Width", 16) - dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) - Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(dataset, mesh, Batch) - - batches = list(itertools.islice(loader, 10)) - for batch in batches: - check_sharded_consistency(batch, check_disjoint_indices_are_different=True) - - -@skip_if_not_enough_devices(4) -def test_structured_batches_model_axis_2_subsharded(): - """This tests data loading if individual datums are sharded too""" - devices = jax.devices() - model_axis_size = 2 - - mesh = Mesh( - np.array(devices).reshape(1, -1, model_axis_size), - (ResourceAxis.REPLICA, ResourceAxis.DATA, ResourceAxis.MODEL), - ) - Height = Axis("Height", 16) - Width = Axis("Width", 16) - with mesh, hax.axis_mapping({"batch": (ResourceAxis.REPLICA, ResourceAxis.DATA), Height.name: ResourceAxis.MODEL}): - dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) - Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(dataset, mesh, Batch) - - batches = list(itertools.islice(loader, 10)) - for batch in batches: - check_sharded_consistency(batch, check_disjoint_indices_are_different=True) - - -def test_sharded_loader_doesnt_throw_away_data(): - devices = jax.devices() - model_axis_size = 1 - - mesh = Mesh( - np.array(devices).reshape(1, -1, model_axis_size), - (ResourceAxis.REPLICA, ResourceAxis.DATA, ResourceAxis.MODEL), - ) - with mesh, hax.axis_mapping({"batch": (ResourceAxis.REPLICA, ResourceAxis.DATA)}): - dataset = ScalarDataset(0, 256, 1) - Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(dataset, mesh, Batch) - - batches = list(itertools.islice(loader, 10)) - dataset_examples = list(itertools.islice(dataset, 10 * Batch.size)) - - def unbatch_example(example): - return example.unbind("batch") - - loader_examples = [ex for b in batches for ex in unbatch_example(b)] - - for ex_d, ex_l in zip(dataset_examples, loader_examples): - assert jnp.all(ex_d.array == ex_l.array) diff --git a/tests/test_shuffle_dataset.py b/tests/test_shuffle_dataset.py deleted file mode 100644 index 226986d14..000000000 --- a/tests/test_shuffle_dataset.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import Iterator - -from jax.random import PRNGKey - -from levanter.data import Dataset, ShuffleDataset - 
- -class RangeDataset(Dataset[int]): - def __init__(self, start: int, end: int): - self.start = start - self.end = end - - def __iter__(self) -> Iterator[int]: - yield from range(self.start, self.end) - - -def test_shuffle_dataset(): - dataset = RangeDataset(0, 100) - assert list(dataset) == list(range(100)) - - key = PRNGKey(0) - shuffle_dataset = ShuffleDataset(dataset, key, 10) - - assert set(shuffle_dataset) == set(range(100)) - - assert list(shuffle_dataset) != list(range(100)) - - key2 = PRNGKey(2) - shuffle_dataset2 = ShuffleDataset(dataset, key2, 10) - assert list(shuffle_dataset2) != list(shuffle_dataset) diff --git a/tests/test_text.py b/tests/test_text.py index a9d407b44..a2645c1f9 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -1,12 +1,14 @@ import tempfile import jax.numpy as jnp +from transformers import AutoTokenizer import haliax as hax -from levanter.data.text import LMDatasetConfig +from levanter.data.text import BatchTokenizer, LMDatasetConfig from levanter.models.lm_model import LmExample from levanter.models.loss import next_token_loss +from tests.test_utils import skip_if_hf_model_not_accessible def test_dont_blow_up_without_validation_set(): @@ -39,3 +41,29 @@ def test_lm_example_handles_ignore_id(): no_ignore_loss = next_token_loss(Pos, Vocab, distr, tokens, loss_mask=ex_no_ignore.loss_mask) assert no_ignore_loss.item() >= ignored_loss.item() + 100 / Pos.size + + +def test_merge_split_encodings(): + tokenizer = AutoTokenizer.from_pretrained("gpt2") + # make this very short for testing + + lorem = """Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. 
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.""" + + short_batch_tokenizer = BatchTokenizer(tokenizer, _workaround_len=len(lorem) // 3) + # force this + short_batch_tokenizer._needs_long_sequence_workaround = True + + batch_tokenizer = BatchTokenizer(tokenizer, _workaround_len=50000) + batch = [lorem] + + short_out = short_batch_tokenizer(batch) + reg_out = batch_tokenizer(batch) + + assert short_out == reg_out + + +@skip_if_hf_model_not_accessible("meta-llama/Llama-2-7b-hf") +def test_llama_tokenizer_needs_long_sequence_workaround(): + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + batch_tokenizer = BatchTokenizer(tokenizer) + assert batch_tokenizer._needs_long_sequence_workaround diff --git a/tests/test_tokenized_document_cache.py b/tests/test_tokenized_document_cache.py deleted file mode 100644 index d3b452937..000000000 --- a/tests/test_tokenized_document_cache.py +++ /dev/null @@ -1,216 +0,0 @@ -import tempfile -from typing import List, Sequence, TypeVar - -import pytest -import ray -from transformers import AutoTokenizer, BatchEncoding - -from levanter.data.shard_cache import build_or_load_cache -from levanter.data.sharded_dataset import ShardedDataset, TextUrlDataset -from levanter.data.text import TokenizedDocumentCache -from levanter.utils.py_utils import logical_cpu_core_count -from test_utils import IdentityProcessor, ShardsDataset, SingleShardDocumentSource, skip_in_ci - - -tokenizer = AutoTokenizer.from_pretrained("gpt2") - -T = TypeVar("T") - - -def setup_module(module): - ray_designated_cores = max(1, logical_cpu_core_count()) - ray.init("local", num_cpus=ray_designated_cores) - - -def teardown_module(module): - ray.shutdown() - - -@pytest.mark.ray -def test_index_empty_file(): - with tempfile.TemporaryDirectory() as tmpdir: - empty_dataset = [""] - source = SingleShardDocumentSource(empty_dataset) - cache = TokenizedDocumentCache.build_or_load( - f"{tmpdir}/cache", - source, - tokenizer, - flatten_docs=True, - enforce_bos=False, - enforce_eos=False, - override_resources={"num_cpus": 1}, - ) - - for chunk in cache: - assert chunk["input_ids"].size == 0 - - -@pytest.mark.ray -def test_index_no_files(): - with tempfile.TemporaryDirectory() as tmpdir: - empty_dataset = [] - source = SingleShardDocumentSource(empty_dataset) - cache = TokenizedDocumentCache.build_or_load( - f"{tmpdir}/cache", - source, - tokenizer, - flatten_docs=True, - enforce_eos=False, - override_resources={"num_cpus": 1}, - ) - - for chunk in cache: - pytest.fail("Should not have any chunks") - - -@skip_in_ci -@pytest.mark.ray -def test_doc_cache_reproduces_data_one_batch_per_shard(): - def doc_i(i: int): - return BatchEncoding(data=dict(input_ids=[list(range(10 * i, 10 * (i + 1)))])) - - num_docs = 10 - docs = [doc_i(j) for j in range(num_docs)] - - class OneDocPerShardSource(ShardedDataset[T]): - def __init__(self, docs: List[T]): - self.docs = docs - - @property - def shard_names(self) -> Sequence[str]: - return [str(i) for i in range(len(self.docs))] - - def open_shard_at_row(self, shard_name: str, row: int): - if row != 0: - raise ValueError(f"Expected row 0, got {row}") - - return [self.docs[int(shard_name)]] - - source = OneDocPerShardSource(docs) - - with tempfile.TemporaryDirectory() as tmpdir: - build_or_load_cache(f"{tmpdir}/cache", source, IdentityProcessor(), await_finished=True) - cache = TokenizedDocumentCache.load(f"{tmpdir}/cache", flatten_docs=False) - - result = list(cache) - - assert len(result) == 
num_docs - # sort the docs by input_ids b/c the order is not guaranteed - for i in range(len(result)): - as_listed = BatchEncoding(data={k: [vv.tolist() for vv in v] for k, v in result[i].items()}) - assert as_listed == docs[i] - - -@skip_in_ci -@pytest.mark.ray -@pytest.mark.parametrize("batch_size", list([1, 2, 3, 8])) -def test_doc_cache_reproduces_data_multi_docs_per_batch_sharded(batch_size): - def batch_docs(doc_ids): - return BatchEncoding(data=dict(input_ids=[list(range(10 * i, 10 * (i + 1))) for i in doc_ids])) - - num_docs = 10 - batches = [batch_docs([j, j + 1]) for j in range(0, num_docs, batch_size)] - - source = ShardsDataset([[b] for b in batches]) - with tempfile.TemporaryDirectory() as tmpdir: - build_or_load_cache(f"{tmpdir}/cache", source, IdentityProcessor()) - cache = TokenizedDocumentCache.load(f"{tmpdir}/cache", flatten_docs=True) - - result = list(cache) - - assert len(result) == len(batches) - - def list_in_list(a, b): - """checks if a is a contiguous sublist of b""" - n = len(a) - return any((list(a) == list(b[i : i + n])) for i in range(len(b) - n + 1)) - - # all we can really assert is that every doc from docs is in the result as a sublist - for i in range(len(batches)): - doc_tokens = batches[i]["input_ids"][0] - found = False - for j in range(len(result)): - # check if the doc is in this result doc - found = list_in_list(doc_tokens, result[j]["input_ids"][0]) - if found: - break - assert found - - -@skip_in_ci -@pytest.mark.ray -def test_doc_cache_sharding(): - def doc_i(i: int): - return BatchEncoding(data=dict(input_ids=[list(range(10 * i, 10 * (i + 1)))])) - - num_docs = 25 - num_shards = 12 - docs = [doc_i(j) for j in range(num_docs)] - # group into num_shards groups - doc_shards = [docs[i : i + num_docs // num_shards] for i in range(0, num_docs, num_docs // num_shards)] - - with tempfile.TemporaryDirectory() as tmpdir: - source = ShardsDataset(doc_shards) - build_or_load_cache(f"{tmpdir}/cache", source, IdentityProcessor()) - - # must evenly divide num_shards - num_shards_rebuild = [1, 2, 3, 4, 6, 12] - - for open_shards in num_shards_rebuild: - cache = TokenizedDocumentCache.load(f"{tmpdir}/cache", flatten_docs=False) - reconstructed = [] - - for shard_idx in range(0, open_shards): - # now we shard the cache - c = cache.shard(shard_idx, open_shards) - reconstructed.extend([d for b in c for d in _unbatch_encoding(b)]) - - assert len(reconstructed) == num_docs - - # sort the docs by input_ids b/c the order is not guaranteed - reconstructed.sort(key=lambda x: x["input_ids"][0][0]) # extra [0] for batchiness - for i in range(len(reconstructed)): - as_listed = BatchEncoding(data={k: [vv.tolist() for vv in v] for k, v in reconstructed[i].items()}) - assert as_listed == docs[i] - - -def _unbatch_encoding(enc: BatchEncoding): - docs = [] - for i in range(len(enc["input_ids"])): - docs.append(BatchEncoding(data={k: [v[i]] for k, v in enc.items()})) - return docs - - -@pytest.mark.ray -def test_cache_fails_with_different_tokenizer(): - with tempfile.TemporaryDirectory() as tmpdir: - with open(f"{tmpdir}/data.txt", "w") as f: - f.write("") - - dataset = TextUrlDataset( - [f"{tmpdir}/data.txt"], - ) - - tokenizer_a = AutoTokenizer.from_pretrained("microsoft/phi-2") - tokenizer_b = AutoTokenizer.from_pretrained("google/flan-t5-small") - - TokenizedDocumentCache.build_or_load( - tmpdir, - dataset, - tokenizer=tokenizer_a, - ) - - # Loading with the original tokenizer should be fine. 
- TokenizedDocumentCache.build_or_load( - tmpdir, - dataset, - tokenizer=tokenizer_a, - ) - - # Loading with a different tokenizer should error out. - with pytest.raises(ValueError): - TokenizedDocumentCache.build_or_load( - tmpdir, - dataset, - tokenizer=tokenizer_b, - ) diff --git a/tests/test_tree_store.py b/tests/test_tree_store.py new file mode 100644 index 000000000..e25ef7928 --- /dev/null +++ b/tests/test_tree_store.py @@ -0,0 +1,435 @@ +import tempfile +from typing import Iterator, List, Sequence + +import numpy as np +import pytest +import tensorstore as ts + +from levanter.data import BatchProcessor, ShardedDataSource +from levanter.data.utils import batched +from levanter.store.tree_store import TreeStore + + +class SimpleProcessor(BatchProcessor[Sequence[int], dict[str, np.ndarray]]): + def __init__(self, batch_size: int = 8): + self._batch_size = batch_size + + def __call__(self, batch: Sequence[Sequence[int]]) -> Sequence[dict[str, Sequence[int]]]: + return [{"data": x} for x in batch] + + @property + def output_exemplar(self) -> dict[str, Sequence[int]]: + return {"data": np.array([0], dtype=np.int64)} + + @property + def batch_size(self) -> int: + return self._batch_size + + @property + def num_cpus(self) -> int: + return 1 + + +class SimpleShardSource(ShardedDataSource[List[int]]): + def __init__(self, num_shards: int = 4): + self._num_shards = num_shards + + @property + def shard_names(self) -> Sequence[str]: + return [f"shard_{i}" for i in range(self._num_shards)] + + def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[List[int]]: + # parse the shard name to get the shard number + shard_num = int(shard_name.split("_")[1]) + return ([shard_num * 10 + i] * 10 for i in range(row, 10)) + + +def test_tree_builder_with_processor(): + with tempfile.TemporaryDirectory() as tempdir: + exemplar = {"data": np.array([0], dtype=np.int64)} + + builder = TreeStore.open(exemplar, tempdir, mode="w") + processor = SimpleProcessor() + source = SimpleShardSource() + + for batch in batched(source, processor.batch_size): + processed = processor(batch) + builder.extend(processed) + + assert len(builder) == 40 + + for i, x in enumerate(builder): + assert len(x) == 1 + + np.testing.assert_array_equal(x["data"], np.asarray([i % 10 + i // 10 * 10] * 10)) + + assert i == 39 + + # now test random access + for i in range(40): + x = builder[i] + assert len(x) == 1 + np.testing.assert_array_equal(x["data"], np.asarray([i % 10 + i // 10 * 10] * 10)) + + # double check columnar access + assert builder.tree["data"].data_size == 10 * 40 + + +def test_append_batch(): + with tempfile.TemporaryDirectory() as tmpdir: + exemplar = {"a": np.array([0], dtype=np.float64), "b": np.array([0], dtype=np.float64)} + builder = TreeStore.open(exemplar, tmpdir) + + batch1 = [ + {"a": np.array([1.0, 2.0]), "b": np.array([3.0, 4.0])}, + {"a": np.array([5.0, 6.0]), "b": np.array([7.0, 8.0])}, + ] + builder.extend(batch1) + + assert len(builder) == 2 + + result1 = builder[0] + assert np.all(result1["a"] == np.array([1.0, 2.0])) + assert np.all(result1["b"] == np.array([3.0, 4.0])) + + result2 = builder[1] + assert np.all(result2["a"] == np.array([5.0, 6.0])) + assert np.all(result2["b"] == np.array([7.0, 8.0])) + + +def test_append_batch_different_shapes(): + with tempfile.TemporaryDirectory() as tmpdir: + + def _f32(x): + return np.asarray(x, dtype=np.float32) + + exemplar = {"a": _f32([0]), "b": _f32([0])} + builder = TreeStore.open(exemplar, tmpdir) + batch1 = [ + {"a": _f32([1.0, 2.0]), "b": _f32([3.0, 
4.0])}, + {"a": _f32([5.0, 6.0]), "b": _f32([7.0, 8.0])}, + ] + builder.extend(batch1) + + batch2 = [ + {"a": _f32([9.0]), "b": _f32([10.0])}, + {"a": _f32([11.0, 12.0, 13.0]), "b": _f32([14.0, 15.0, 16.0])}, + ] + builder.extend(batch2) + + assert len(builder) == 4 + + result3 = builder[2] + assert np.all(result3["a"] == np.array([9.0])) + assert np.all(result3["b"] == np.array([10.0])) + + result4 = builder[3] + assert np.all(result4["a"] == np.array([11.0, 12.0, 13.0])) + assert np.all(result4["b"] == np.array([14.0, 15.0, 16.0])) + + +def test_extend_batch_different_shapes(): + with tempfile.TemporaryDirectory() as tmpdir: + exemplar = {"a": np.array([0], dtype=np.float64), "b": np.array([0], dtype=np.float64)} + builder = TreeStore.open(exemplar, tmpdir) + + batch1 = {"a": [np.array([1.0, 2.0]), np.array([5.0, 6.0])], "b": [np.array([3.0, 4.0]), np.array([7.0, 8.0])]} + builder.extend_with_batch(batch1) + + batch2 = { + "a": [np.array([9.0]), np.array([11.0, 12.0, 13.0])], + "b": [np.array([10.0]), np.array([14.0, 15.0, 16.0])], + } + builder.extend_with_batch(batch2) + + assert len(builder) == 4 + + result3 = builder[2] + assert np.all(result3["a"] == np.array([9.0])) + assert np.all(result3["b"] == np.array([10.0])) + + result4 = builder[3] + assert np.all(result4["a"] == np.array([11.0, 12.0, 13.0])) + assert np.all(result4["b"] == np.array([14.0, 15.0, 16.0])) + + +def test_len(): + with tempfile.TemporaryDirectory() as tmpdir: + exemplar = {"a": np.array([0], dtype=np.float64), "b": np.array([0], dtype=np.float64)} + builder = TreeStore.open(exemplar, tmpdir) + + assert len(builder) == 0 + + batch = [ + {"a": np.array([1.0, 2.0]), "b": np.array([3.0, 4.0])}, + {"a": np.array([5.0, 6.0]), "b": np.array([7.0, 8.0])}, + ] + builder.extend(batch) + + assert len(builder) == 2 + + +def test_getitem(): + with tempfile.TemporaryDirectory() as tmpdir: + exemplar = {"a": np.array([0], dtype=np.float64), "b": np.array([0], dtype=np.float64)} + builder = TreeStore.open(exemplar, tmpdir) + + batch = [ + {"a": np.array([1.0, 2.0]), "b": np.array([3.0, 4.0])}, + {"a": np.array([5.0, 6.0]), "b": np.array([7.0, 8.0])}, + ] + builder.extend(batch) + + result = builder[0] + assert np.all(result["a"] == np.array([1.0, 2.0])) + assert np.all(result["b"] == np.array([3.0, 4.0])) + + result = builder[1] + assert np.all(result["a"] == np.array([5.0, 6.0])) + assert np.all(result["b"] == np.array([7.0, 8.0])) + + # test slice + # result = builder[0:2] + # assert isinstance(result["a"], JaggedArray) + # assert isinstance(result["b"], JaggedArray) + + +def test_getitem_out_of_bounds(): + with tempfile.TemporaryDirectory() as tmpdir: + exemplar = {"a": np.array([0], dtype=np.float64), "b": np.array([0], dtype=np.float64)} + builder = TreeStore.open(exemplar, tmpdir) + + batch = [ + {"a": np.array([1.0, 2.0]), "b": np.array([3.0, 4.0])}, + {"a": np.array([5.0, 6.0]), "b": np.array([7.0, 8.0])}, + ] + builder.extend(batch) + + with pytest.raises(IndexError): + builder[2] + + +def test_iter(): + with tempfile.TemporaryDirectory() as tmpdir: + exemplar = {"a": np.array([0], dtype=np.float64), "b": np.array([0], dtype=np.float64)} + builder = TreeStore.open(exemplar, tmpdir) + + batch = [ + {"a": np.array([1.0, 2.0]), "b": np.array([3.0, 4.0])}, + {"a": np.array([5.0, 6.0]), "b": np.array([7.0, 8.0])}, + ] + builder.extend(batch) + + for i, result in enumerate(builder): + if i == 0: + assert np.all(result["a"] == np.array([1.0, 2.0])) + assert np.all(result["b"] == np.array([3.0, 4.0])) + elif i == 1: + assert 
np.all(result["a"] == np.array([5.0, 6.0])) + assert np.all(result["b"] == np.array([7.0, 8.0])) + else: + pytest.fail("Unexpected index") + + +def test_reading_from_written(): + with tempfile.TemporaryDirectory() as tmpdir: + exemplar = {"a": np.array([0], dtype=np.float64), "b": np.array([0], dtype=np.float64)} + builder = TreeStore.open(exemplar, tmpdir, mode="w") + + batch = [ + {"a": np.array([1.0, 2.0]), "b": np.array([3.0, 4.0])}, + {"a": np.array([5.0, 6.0]), "b": np.array([7.0, 8.0])}, + ] + builder.extend(batch) + + del builder + + builder2 = TreeStore.open(exemplar, tmpdir, mode="r") + + for i, result in enumerate(builder2): + if i == 0: + assert np.all(result["a"] == np.array([1.0, 2.0])) + assert np.all(result["b"] == np.array([3.0, 4.0])) + elif i == 1: + assert np.all(result["a"] == np.array([5.0, 6.0])) + assert np.all(result["b"] == np.array([7.0, 8.0])) + else: + pytest.fail("Unexpected index") + + +def test_resolve_changed_cache_size(): + with tempfile.TemporaryDirectory() as tmpdir: + exemplar = {"a": np.array([0], dtype=np.float64), "b": np.array([0], dtype=np.float64)} + builder = TreeStore.open(exemplar, tmpdir, mode="w") + follower = TreeStore.open(exemplar, tmpdir, mode="r") + + batch = [ + {"a": np.array([1.0, 2.0]), "b": np.array([3.0, 4.0])}, + {"a": np.array([5.0, 6.0]), "b": np.array([7.0, 8.0])}, + ] + builder.extend(batch) + + follower = follower.reload() + follower2 = TreeStore.open(exemplar, tmpdir, mode="r") + + assert len(follower2) == 2 + assert len(follower) == 2 + + builder.extend(batch) + follower = follower.reload() + + assert len(follower) == 4 + + +# this test mostly exists to help me remember the API + + +def test_simple_resize_bounds(): + with tempfile.TemporaryDirectory() as tmpdir: + store1 = ts.open( + { + "driver": "zarr", + "kvstore": { + "driver": "file", + "path": tmpdir, + }, + }, + create=True, + dtype=ts.int32, + shape=[1000, 2000, 3000], + chunk_layout=ts.ChunkLayout(inner_order=[2, 1, 0]), + ).result() + + store2 = ts.open( + { + "driver": "zarr", + "kvstore": { + "driver": "file", + "path": tmpdir, + }, + }, + dtype=ts.int32, + ).result() + + assert store2.shape == (1000, 2000, 3000) + assert store2.chunk_layout.inner_order == (2, 1, 0) + + store1 = store1.resize(exclusive_max=[2000, 3000, 4000]).result() + + assert store1.shape == (2000, 3000, 4000) + + # store2 = store2[ts.d[0].mark_bounds_implicit[True]].resolve().result() + spec = store2.spec(retain_context=True, minimal_spec=True) + # spec.update(transform={}) + store2 = ts.open(spec).result() + + # store2 = store2.resolve(fix_resizable_bounds=False).result() + + assert store2.shape == (2000, 3000, 4000) # nope? 
+ + +@pytest.mark.asyncio +async def test_get_batch_single_item(): + with tempfile.TemporaryDirectory() as tmpdir: + exemplar = {"a": np.array([0], dtype=np.float64), "b": np.array([0], dtype=np.float64)} + builder = TreeStore.open(exemplar, tmpdir) + + batch1 = [ + {"a": np.array([1.0, 2.0]), "b": np.array([3.0, 4.0])}, + {"a": np.array([5.0, 6.0]), "b": np.array([7.0, 8.0])}, + ] + builder.extend(batch1) + + # Retrieve a single item using get_batch + batch = await builder.get_batch([0]) + result = batch[0] + + expected_data = builder[0] + assert np.array_equal(result["a"], expected_data["a"]) + assert np.array_equal(result["b"], expected_data["b"]) + + +@pytest.mark.asyncio +async def test_get_batch_multiple_items(): + with tempfile.TemporaryDirectory() as tmpdir: + exemplar = {"a": np.array([0], dtype=np.float64), "b": np.array([0], dtype=np.float64)} + builder = TreeStore.open(exemplar, tmpdir) + + batch1 = [ + {"a": np.array([1.0, 2.0]), "b": np.array([3.0, 4.0])}, + {"a": np.array([5.0, 6.0]), "b": np.array([7.0, 8.0])}, + {"a": np.array([9.0, 10.0]), "b": np.array([11.0, 12.0])}, + ] + builder.extend(batch1) + + # Retrieve multiple items using get_batch + indices = [0, 2] + batch = await builder.get_batch(indices) + + for idx, result in zip(indices, batch): + expected_data = builder[idx] + assert np.array_equal(result["a"], expected_data["a"]) + assert np.array_equal(result["b"], expected_data["b"]) + + +@pytest.mark.asyncio +async def test_get_batch_out_of_order(): + with tempfile.TemporaryDirectory() as tmpdir: + exemplar = {"a": np.array([0], dtype=np.float64), "b": np.array([0], dtype=np.float64)} + builder = TreeStore.open(exemplar, tmpdir) + + batch1 = [ + {"a": np.array([1.0, 2.0]), "b": np.array([3.0, 4.0])}, + {"a": np.array([5.0, 6.0]), "b": np.array([7.0, 8.0])}, + {"a": np.array([9.0, 10.0]), "b": np.array([11.0, 12.0])}, + ] + builder.extend(batch1) + + # Retrieve items out of order using get_batch + indices = [2, 0, 1] + batch = await builder.get_batch(indices) + + for idx, result in zip(indices, batch): + expected_data = builder[idx] + assert np.array_equal(result["a"], expected_data["a"]) + assert np.array_equal(result["b"], expected_data["b"]) + + +@pytest.mark.asyncio +async def test_get_batch_with_shapes(): + with tempfile.TemporaryDirectory() as tmpdir: + exemplar = {"a": np.array([[0]], dtype=np.float64), "b": np.array([[0]], dtype=np.float64)} + builder = TreeStore.open(exemplar, tmpdir) + + batch1 = [ + {"a": np.array([[1.0, 2.0], [3.0, 4.0]]), "b": np.array([[5.0, 6.0], [7.0, 8.0]])}, + {"a": np.array([[9.0, 10.0], [11.0, 12.0]]), "b": np.array([[13.0, 14.0], [15.0, 16.0]])}, + ] + builder.extend(batch1) + + # Retrieve multiple items using get_batch + indices = [0, 1] + batch = await builder.get_batch(indices) + + for idx, result in zip(indices, batch): + expected_data = builder[idx] + assert np.array_equal(result["a"], expected_data["a"]) + assert np.array_equal(result["b"], expected_data["b"]) + + +@pytest.mark.asyncio +async def test_get_batch_empty(): + with tempfile.TemporaryDirectory() as tmpdir: + exemplar = {"a": np.array([0], dtype=np.float64), "b": np.array([0], dtype=np.float64)} + builder = TreeStore.open(exemplar, tmpdir) + + batch1 = [ + {"a": np.array([1.0, 2.0]), "b": np.array([3.0, 4.0])}, + {"a": np.array([5.0, 6.0]), "b": np.array([7.0, 8.0])}, + ] + builder.extend(batch1) + + # Retrieve an empty batch + batch = await builder.get_batch([]) + + assert batch == [] diff --git a/tests/test_utils.py b/tests/test_utils.py index 
53042826c..1bf03b624 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,7 +17,7 @@ from levanter.checkpoint import _get_fs_and_plain_path from levanter.data._preprocessor import BatchProcessor -from levanter.data.sharded_dataset import ShardedDataset +from levanter.data.sharded_datasource import ShardedDataSource from levanter.data.text import _stack_batch_encodings from levanter.models.attention import AttentionMask @@ -193,17 +193,21 @@ def decorator(fn): return pytest.mark.skipif("CI" in os.environ, reason="skipped in CI")(fn_or_msg) -class IdentityProcessor(BatchProcessor[BatchEncoding]): +class IdentityProcessor(BatchProcessor[BatchEncoding, BatchEncoding]): def __call__(self, batch: Sequence[BatchEncoding]) -> BatchEncoding: stacked = reduce(_stack_batch_encodings, batch) return stacked + @property + def output_exemplar(self): + return BatchEncoding({}) + @property def num_cpus(self) -> int: return 0 -class ShardsDataset(ShardedDataset[T]): +class ShardsDataSource(ShardedDataSource[T]): def __init__(self, docs: List[List[T]]): self.docs = docs @@ -215,7 +219,7 @@ def open_shard_at_row(self, shard_name: str, row: int): return self.docs[int(shard_name)][row:] -class SingleShardDocumentSource(ShardedDataset[T]): +class SingleShardDocumentSource(ShardedDataSource[T]): def __init__(self, docs: List[T]): self.docs = docs diff --git a/tests/tiny_test_corpus.py b/tests/tiny_test_corpus.py index 5cd0e8a70..91597c137 100644 --- a/tests/tiny_test_corpus.py +++ b/tests/tiny_test_corpus.py @@ -2,10 +2,11 @@ import os import numpy +import numpy as np from levanter.data.audio import AudioIODatasetConfig -from levanter.data.shard_cache import ShardCache from levanter.data.text import LMDatasetConfig +from levanter.store.cache import TreeCache def _write_tiny_corpus(path): @@ -43,17 +44,24 @@ def tiny_asr_corpus_config(path): def construct_small_data_cache( path, num_shards=8, chunk_size=512, doc_len=128, vocab_size=1024 -) -> tuple[LMDatasetConfig, dict[str, ShardCache]]: - from levanter.data.shard_cache import SerialCacheWriter +) -> tuple[LMDatasetConfig, dict[str, TreeCache]]: + from levanter.store.cache import SerialCacheWriter rng = numpy.random.default_rng(0) - caches = {} + caches: dict[str, TreeCache] = {} + + exemplar = {"input_ids": numpy.zeros((doc_len,), dtype=numpy.int32)} for split in ["train", "validation"]: - with SerialCacheWriter(f"{path}/cache/{split}", chunk_size) as writer: + with SerialCacheWriter(f"{path}/cache/{split}", exemplar) as writer: for shard in range(num_shards): - writer.write_batch({"input_ids": rng.integers(0, vocab_size, size=(chunk_size, doc_len))}) + writer.write_batch( + [ + {"input_ids": rng.integers(0, vocab_size, size=(doc_len,), dtype=np.int32)} + for _ in range(chunk_size) + ] + ) caches[split] = writer.result() config = LMDatasetConfig(