Completely rework dataset/cache system: instant resume, perfect shuffle, stable mixtures and more (#716)

Introduces a massive rework of Levanter's cache system to support instant resume, perfect shuffling, stable mixtures, and more.

The basic idea is to use TensorStore to store all of our data as a kind of janky column store (implemented in JaggedArrayStore), plus pytrees of such stores (implemented in TreeStore).

TensorStore provides efficient storage of and access to very large arrays. We still support streaming from an in-progress cache via a new AsyncDataset class.
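For intuition, here is the column-store idea in miniature, with plain numpy standing in for TensorStore; the class and method names are illustrative, not the actual JaggedArrayStore API:

```python
import numpy as np

class JaggedArraySketch:
    """Variable-length rows stored as one flat token array plus an offsets array.

    Row i lives at data[offsets[i]:offsets[i+1]], so random access is a cheap slice.
    """

    def __init__(self):
        self.data = np.empty(0, dtype=np.int32)     # all tokens, concatenated
        self.offsets = np.zeros(1, dtype=np.int64)  # row boundaries; starts as [0]

    def append(self, row) -> None:
        row = np.asarray(row, dtype=np.int32)
        self.data = np.concatenate([self.data, row])
        self.offsets = np.append(self.offsets, self.offsets[-1] + len(row))

    def __getitem__(self, i: int) -> np.ndarray:
        return self.data[self.offsets[i]:self.offsets[i + 1]]
```

Backing `data` and `offsets` with TensorStore arrays instead of numpy is what lets the store scale past memory while keeping any row a cheap lookup.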

I've successfully tested this on the Pile and, modulo the usual issues with the Llama tokenizer on long documents/books, it behaves well.

Closes #626 #311 #119 #34
dlwh committed Sep 5, 2024
1 parent ea4ea25 commit fbe27bc
Showing 81 changed files with 6,691 additions and 4,455 deletions.
1 change: 1 addition & 0 deletions .dockerignore
@@ -2,6 +2,7 @@

scratch
cache
+new-cache
wandb
checkpoints

2 changes: 1 addition & 1 deletion .github/workflows/run_entry_tests.yaml
@@ -21,7 +21,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install flake8 pytest
pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
pip install .[test] "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
pip install soundfile librosa
- name: Run entry tests with pytest
run: |
4 changes: 2 additions & 2 deletions .github/workflows/run_ray_tests.yaml
@@ -21,8 +21,8 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install flake8 pytest
pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
pip install .[test] "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
pip install soundfile librosa
- name: Run ray tests with pytest
run: |
-XLA_FLAGS=--xla_force_host_platform_device_count=8 PYTHONPATH=$(pwd)/tests:$(pwd)/src:$(pwd):. pytest tests -m ray
+PYTHONPATH=$(pwd)/tests:$(pwd)/src:$(pwd):. pytest tests -m ray
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yaml
@@ -21,7 +21,7 @@
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
pip install .[test] "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
pip install -r ./tests/requirements.txt
- name: Test with pytest
run: |
1 change: 0 additions & 1 deletion config/data/redpajama_1b_source.yaml
@@ -3,4 +3,3 @@ cache_dir: gs://levanter-data/tokenized/redpajama-sample/
tokenizer: EleutherAI/gpt-neox-20b
splits:
- train
-rows_per_chunk: 32768
1 change: 0 additions & 1 deletion config/data/redpajama_1t_source.yaml
@@ -3,4 +3,3 @@ cache_dir: gs://levanter-data/tokenized/redpajama/
tokenizer: EleutherAI/gpt-neox-20b
splits:
- train
-rows_per_chunk: 4096
1 change: 0 additions & 1 deletion config/data/rpv1_llama.yaml
@@ -1,5 +1,4 @@
cache_dir: gs://levanter-data/tokenized/redpajama_v1_llama_mixture
-rows_per_chunk: 4096
tokenizer: "meta-llama/Llama-2-7b-hf"
configs:
arxiv:
2 changes: 1 addition & 1 deletion config/gpt2_nano_mixture.yaml
@@ -7,7 +7,7 @@ data:
id: dlwh/wikitext_103_detokenized
train_weights:
wikitext: 1.0
-w2: 0
+w2: 1.0
model:
type: gpt2
hidden_dim: 32
163 changes: 163 additions & 0 deletions config/llama2_small_fast_mix.yaml
@@ -0,0 +1,163 @@
data:
tokenizer: "meta-llama/Llama-2-7b-hf"
cache_dir: "gs://levanter-data/new-tokenized/pile_mix/"
shuffle:
era_length: 10000
configs:
arxiv:
train_urls:
- gs://levanter-data/pile-domains/arxiv/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/arxiv/val.jsonl.zst
books2:
train_urls:
- gs://levanter-data/pile-domains/books2/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/books2/val.jsonl.zst
books3:
train_urls:
- gs://levanter-data/pile-domains/books3/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/books3/val.jsonl.zst
dm_math:
train_urls:
- gs://levanter-data/pile-domains/dm_math/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/dm_math/val.jsonl.zst
enron:
train_urls:
- gs://levanter-data/pile-domains/enron/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/enron/val.jsonl.zst
europarl:
train_urls:
- gs://levanter-data/pile-domains/europarl/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/europarl/val.jsonl.zst
free_law:
train_urls:
- gs://levanter-data/pile-domains/freelaw/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/freelaw/val.jsonl.zst
github:
train_urls:
- gs://levanter-data/pile-domains/github/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/github/val.jsonl.zst
hackernews:
train_urls:
- gs://levanter-data/pile-domains/hackernews/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/hackernews/val.jsonl.zst
nih:
train_urls:
- gs://levanter-data/pile-domains/nih/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/nih/val.jsonl.zst
opensubtitles:
train_urls:
- gs://levanter-data/pile-domains/opensubtitles/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/opensubtitles/val.jsonl.zst
owt2:
train_urls:
- gs://levanter-data/pile-domains/owt2/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/owt2/val.jsonl.zst
pg_19:
train_urls:
- gs://levanter-data/pile-domains/pg_19/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/pg_19/val.jsonl.zst
philpapers:
train_urls:
- gs://levanter-data/pile-domains/philpapers/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/philpapers/val.jsonl.zst
pile_cc:
train_urls:
- gs://levanter-data/pile-domains/pile_cc/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/pile_cc/val.jsonl.zst
pubmed_abs:
train_urls:
- gs://levanter-data/pile-domains/pubmed_abs/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/pubmed_abs/val.jsonl.zst
pubmed_central:
train_urls:
- gs://levanter-data/pile-domains/pubmed_central/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/pubmed_central/val.jsonl.zst
stack_exchange:
train_urls:
- gs://levanter-data/pile-domains/stack_exchange/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/stack_exchange/val.jsonl.zst
ubuntu_irc:
train_urls:
- gs://levanter-data/pile-domains/ubuntu_irc/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/ubuntu_irc/val.jsonl.zst
uspto:
train_urls:
- gs://levanter-data/pile-domains/uspto/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/uspto/val.jsonl.zst
wiki_en:
train_urls:
- gs://levanter-data/pile-domains/wiki_en/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/wiki_en/val.jsonl.zst
youtube_subtitles:
train_urls:
- gs://levanter-data/pile-domains/youtube_subtitles/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/youtube_subtitles/val.jsonl.zst
train_weights:
# these weights come from the paper https://arxiv.org/pdf/2101.00027.pdf
pile_cc: 0.1811
pubmed_central: 0.1440
books3: 0.1207
owt2: 0.1001
arxiv: 0.0896
github: 0.0759
free_law: 0.0612
stack_exchange: 0.0513
uspto: 0.0365
pubmed_abs: 0.0307
pg_19: 0.0217
opensubtitles: 0.0155
wiki_en: 0.0153
dm_math: 0.0124
ubuntu_irc: 0.0088
books2: 0.0075
europarl: 0.0073
hackernews: 0.0062
youtube_subtitles: 0.0060
philpapers: 0.0038
nih: 0.0030
enron: 0.0014
model:
type: llama
hidden_dim: 768
intermediate_dim: 2048
num_heads: 6
num_kv_heads: 6
num_layers: 12
seq_len: 1024
gradient_checkpointing: true
trainer:
tracker:
project: "levanter"
tags: [ "pile", "llama", "itest"]

mp: p=f32,c=bfloat16
model_axis_size: 1

train_batch_size: 256
num_train_steps: 20000
optimizer:
learning_rate: 1E-3
weight_decay: 0.1
warmup: 0.01
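One detail in the new config above worth flagging: `shuffle.era_length` presumably enables era-based shuffling, where the stream is permuted within consecutive windows ("eras") of 10,000 examples rather than across the whole dataset, trading a fully global shuffle for cheap, deterministic, resumable access. A rough sketch of the idea, not Levanter's actual implementation:

```python
import numpy as np

def era_shuffled_index(step: int, era_length: int, seed: int) -> int:
    """Map a sequential training step to a data index, permuting within eras.

    Each era of `era_length` consecutive examples gets its own deterministic
    permutation, so any index is computable without global state -- which is
    what makes resuming from an arbitrary step cheap.
    """
    era, offset = divmod(step, era_length)
    rng = np.random.default_rng([seed, era])  # per-era deterministic RNG
    perm = rng.permutation(era_length)
    return era * era_length + int(perm[offset])

# e.g., step 12_345 with era_length=10_000 maps somewhere inside the second era
assert 10_000 <= era_shuffled_index(12_345, 10_000, seed=0) < 20_000
```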
2 changes: 1 addition & 1 deletion docs/Fine-Tuning.md
@@ -406,7 +406,7 @@ def _get_data_source(path_or_id):
    if fsspec_utils.exists(path_or_id):
        return JsonDataset([path_or_id])
    else:
-        return levanter.data.dataset_from_hf(path_or_id, split="train")
+        return levanter.data.datasource_from_hf(path_or_id, split="train")
```
Preprocessing in Levanter typically happens in two phases:
32 changes: 18 additions & 14 deletions docs/Training-On-Your-Data.md
@@ -106,10 +106,6 @@ data:

### Mixture of Sources

-!!! warning
-
-    This feature is experimental and may change in the future.

If you have multiple sources of data (e.g., multiple domains, or distinct subsets of data), you can use the `data` section of your training configuration to specify them:

```yaml
@@ -145,23 +141,22 @@ validation data.
## Data Preprocessing

Levanter supports both online and offline preprocessing. Online preprocessing is done on-the-fly
-during training. With online preprocessing, you don't need to think about preprocessing your data.
+during training. With online preprocessing, you don't need to think about preprocessing your data
+except to make sure it's in the right format and where you'd like to store the cached preprocessing
+results.

Our data loading pipeline will automatically break and concatenate documents into chunks equal
to the model's `seq_len` parameter. It will also automatically add special tokens to the
end of documents.

-We don't yet handle sequence-to-sequence tasks, but we plan to.
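To make the break-and-concatenate behavior above concrete, here's a sketch of the packing step under simplifying assumptions (eager rather than lazy, a single end-of-document token, trailing partial chunk dropped); the real pipeline works lazily over the cache:

```python
import numpy as np

def pack_documents(docs: list[list[int]], seq_len: int, eos_id: int) -> list[np.ndarray]:
    """Append an end-of-document token to each doc, concatenate everything,
    and slice the resulting stream into fixed-size seq_len chunks."""
    stream: list[int] = []
    for doc in docs:
        stream.extend(doc)
        stream.append(eos_id)  # special token marking the document boundary
    n_chunks = len(stream) // seq_len  # drop the trailing partial chunk
    return [
        np.asarray(stream[i * seq_len:(i + 1) * seq_len], dtype=np.int32)
        for i in range(n_chunks)
    ]
```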

### Online Preprocessing

We have a sophisticated caching mechanism using [Ray](https://docs.ray.io/en/latest/)
that builds a cache of preprocessed data on the fly. Online caching happens transparently
in the background, using the mostly-idle CPU-cores of the machine(s) you are training on.

The cache that is built is fully reproducible, and can be used for future training runs.
-Training will start as soon as each training machine has its first shard of data cached
-and once the validation data is cached.
+Training will start as soon as the system has the data it needs.

### Offline Preprocessing

@@ -190,19 +185,28 @@ python -m levanter.main.cache_dataset \
### Direct Cache Construction

As a final option, you can directly construct a cache of preprocessed data without using Ray. This is useful if you
-have custom preprocessing logic or Ray isn't working for you for some reason. To do so, you can use [levanter.data.SerialCacheWriter][]
+have custom preprocessing logic or Ray isn't working for you for some reason. To do so, you can use [levanter.store.SerialCacheWriter][]
to write batches directly. Here's an example:

```python
-from levanter.data import SerialCacheWriter
+import numpy as np
+from levanter.store import SerialCacheWriter
+
+exemplar = {
+    "input_ids": np.zeros((0), dtype=np.int32),
+    "attention_mask": np.zeros((0), dtype=np.int32),
+    "labels": np.zeros((0), dtype=np.int32),
+}

-with SerialCacheWriter(cache_dir, rows_per_chunk=1024) as writer:
+with SerialCacheWriter(cache_dir, exemplar) as writer:
    for batch in process_batches():
-        # batch should be a list of dicts, each with keys "input_ids", "attention_mask", and "labels"
        writer.write_batch(batch)
```

-`batch` can be a `list[dict]`, `dict[list]`, or `pyarrow.RecordBatch`. To work with `train_lm`, it should have an
-`input_ids` key that is a list of `int`s.
+In this case, `batch` should be a list of dicts, each with keys `"input_ids"`, `"attention_mask"`, and `"labels"`.
+To work with `train_lm`, it should have an `input_ids` key that is a list of `int`s.

To use a cache like this, you can use the `passthrough` tokenizer:
