Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into small_fast_ci
Browse files (browse the repository at this point in the history)
  • Loading branch information
dlwh committed Sep 13, 2024
2 parents e0eda64 + 2645efb commit 7126de8
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 26 deletions.
3 changes: 1 addition & 2 deletions config/llama2_small_fast_mix.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
data:
tokenizer: "meta-llama/Llama-2-7b-hf"
cache_dir: "gs://levanter-data/new-tokenized/pile_mix/"
shuffle:
era_length: 10000
shuffle: 10000
configs:
arxiv:
train_urls:
Expand Down
17 changes: 9 additions & 8 deletions src/levanter/data/permutation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import dataclasses
from typing import Optional, Sequence

import jax.random
Expand Down Expand Up @@ -29,7 +28,11 @@ def is_finite(self) -> bool:
return self.dataset.is_finite()

async def current_len(self) -> Optional[int]:
return await self.dataset.current_len()
if await self.final_length_is_known():
return await self.async_len()
# In general, we can't know the current length until we know the entire length
return None
# return await self.dataset.current_len()

async def getitem_async(self, index: int) -> T_co:
permutation = await self._get_permutation()
Expand All @@ -41,9 +44,12 @@ async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:

async def _get_permutation(self):
if self._permutation is None:
self._permutation = Permutation(await self.dataset.async_len(), self.key)
self._permutation = Permutation(await self.async_len(), self.key)
return self._permutation

async def wait_until_len_at_least(self, length: int) -> int:
return await self.async_len()


class EraShufflingDataset(AsyncDataset[T_co]):
"""
Expand Down Expand Up @@ -128,8 +134,3 @@ async def wait_until_len_at_least(self, length: int) -> int:
# wait until we hit the next era
next_era_end = (length // self.era_length + 1) * self.era_length
return await self.dataset.wait_until_len_at_least(next_era_end)


@dataclasses.dataclass
class EraConfig:
era_length: int
14 changes: 6 additions & 8 deletions src/levanter/data/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from levanter.data import AsyncDataset
from levanter.data.dataset import MappedAsyncDataset
from levanter.data.mixture import MixtureDataset, StopStrategy
from levanter.data.permutation import EraConfig

# intercept the logging nonsense here
from levanter.logging import silence_transformer_nag # noqa
Expand Down Expand Up @@ -113,7 +112,6 @@ async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
out.append(token_arrays.data[offset : offset + self.seq_len].read())

out = await asyncio.gather(*out)

return out

def get_batch_sync(self, indices: Sequence[int]) -> Sequence[T_co]:
Expand Down Expand Up @@ -549,9 +547,9 @@ class LMTaskConfig(abc.ABC):
enforce_eos: bool = True # whether to append eos even if the tokenizer doesn't

ignore_token_id: Optional[int] = None
shuffle: bool | EraConfig = False
shuffle: bool | int = False
"""whether to shuffle the dataset. True means shuffle the whole dataset, False means don't shuffle.
If you want to shuffle in eras, provide an EraConfig (which asks for an era_length)"""
If you want to shuffle in eras, set this to the era length"""

@cached_property
def the_tokenizer(self) -> PreTrainedTokenizerBase:
Expand Down Expand Up @@ -599,8 +597,8 @@ def train_set(

if self.shuffle is True:
ds = ds.shuffle(key)
elif isinstance(self.shuffle, EraConfig):
ds = ds.era_shuffle(self.shuffle.era_length, key=key)
elif isinstance(self.shuffle, int) and self.shuffle > 0:
ds = ds.era_shuffle(self.shuffle, key=key)

return ds # type: ignore

Expand Down Expand Up @@ -754,8 +752,8 @@ def train_set(
def shuffle_ds(ds, key):
if self.shuffle is True:
ds = ds.shuffle(key)
elif isinstance(self.shuffle, EraConfig):
ds = ds.era_shuffle(self.shuffle.era_length, key=key)
elif isinstance(self.shuffle, int):
ds = ds.era_shuffle(self.shuffle, key=key)

return ds

Expand Down
15 changes: 7 additions & 8 deletions src/levanter/store/stress_test_new_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,13 @@ def ensure_cache(new_cache_path):
if __name__ == "__main__":
import sys

if not len(sys.argv) == 3:
print("Usage: convert_to_new_cache.py old_cache_path new_cache_path")
if not len(sys.argv) == 2:
print("Usage: convert_to_new_cache.py new_cache_path")
sys.exit(1)

for split in ["validation", "train"]:
print(f"Split: {split}", flush=True)
in_path = os.path.join(sys.argv[1], split)
out_path = os.path.join(sys.argv[2], split)
cache_path = os.path.join(sys.argv[1], split)
# convert_to_new_cache(in_path, out_path)
# with capture_time() as time_fn:
# bench_old_cache(in_path)
Expand All @@ -126,24 +125,24 @@ def ensure_cache(new_cache_path):
exemplar = {"input_ids": np.zeros((SEQ_LEN,), dtype=np.int32)}

with capture_time() as time_fn:
bench_new_cache_serial(exemplar, out_path)
bench_new_cache_serial(exemplar, cache_path)
tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn()
print(f"New Cache Serial: {time_fn()} ({tokens_per_second} tps)", flush=True)

with capture_time() as time_fn:
asyncio.run(bench_new_cache_serial_tokenseq(exemplar, out_path))
asyncio.run(bench_new_cache_serial_tokenseq(exemplar, cache_path))
tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn()

print(f"New Cache Serial TokenSeq: {time_fn()} ({tokens_per_second} tps)", flush=True)

with capture_time() as time_fn:
bench_new_cache_random(exemplar, out_path)
bench_new_cache_random(exemplar, cache_path)
tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn()

print(f"New Cache Random: {time_fn()} ({tokens_per_second} tps)", flush=True)

with capture_time() as time_fn:
asyncio.run(bench_new_cache_permutation_random(exemplar, out_path))
asyncio.run(bench_new_cache_permutation_random(exemplar, cache_path))
tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn()

print(f"New Cache Permutation: {time_fn()} ({tokens_per_second} tps)", flush=True)

0 comments on commit 7126de8

Please sign in to comment.