ok it runs!
dlwh committed Jun 3, 2024
1 parent 30ca458 commit 6b73d28
Showing 5 changed files with 154 additions and 43 deletions.
38 changes: 36 additions & 2 deletions src/levanter/data/loader.py
@@ -18,6 +18,7 @@

from levanter.data import Dataset
from levanter.data.dataset import ShardableDataset
from levanter.data.sampler import ItemSampler
from levanter.mesh import local_devices_mapping, process_mesh_mapping
from levanter.shapes import NamedShapeSpec, ShapeSpec, to_raw_shape
from levanter.utils.background_iterable import BackgroundIterable
@@ -158,9 +159,11 @@ class ShardedBatchLoader(BatchLoader[Ex]):
:param max_capacity: if not None, the maximum number of batches to keep in memory at once. If <0 then load in the main thread
"""

# TODO: clean this up

def __init__(
self,
local_dataset: ShardableDataset[Ex],
local_dataset: ShardableDataset[Ex] | ItemSampler[Ex],
mesh: Mesh,
Batch: hax.Axis,
axis_resources: Optional[ResourceMapping] = None,
@@ -188,10 +191,21 @@ def __init__(
self.local_devices_map = local_devices_map
self.per_device_batch_size = self.batch_size // self.mesh.devices.shape[0] // self.mesh.devices.shape[1]

self.item_dataset = local_dataset.shard(process_data_pos, num_data_process_groups)
if isinstance(local_dataset, ShardableDataset):
self.item_dataset = local_dataset.shard(process_data_pos, num_data_process_groups)
self.is_dataset = True
else:
self.item_dataset = local_dataset # type: ignore
self.is_dataset = False

super().__init__(max_capacity, axis_resources)

def _produce_batches(self) -> Iterator[PyTree]:

if isinstance(self.item_dataset, ItemSampler):
yield from self._produce_batches_from_sampler()
return

one_item_generator = non_caching_cycle(self.item_dataset)
batched = _batched(one_item_generator, self.local_batch_size)

@@ -215,6 +229,26 @@ def batch_callback(global_begin, _):

yield batch

def _produce_batches_from_sampler(self) -> Iterator[PyTree]:
def batch_callback(step, global_begin, global_end):
key = jax.random.PRNGKey(step)
elems = []
for i in range(global_begin, global_end):
elems.append(self.item_dataset.sample(step * self.batch_size + i, key=key))

return elems

exemplar = self.item_dataset.sample(0, key=jax.random.PRNGKey(0)) # type: ignore

step = 0
while True:
batch = self._construct_global_array_for_tree(
item_exemplar=exemplar,
get_batch_items=lambda global_begin, global_end: batch_callback(step, global_begin, global_end),
)

yield batch
step += 1  # advance so the next batch draws new sampler indices and a fresh key

@property
def batch_size(self) -> int:
"""Returns the 'global' batch size: the effective number of examples in a batch across all devices/hosts"""
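As a rough sketch of the sampler-backed path added above (illustrative only; the helper name is hypothetical), each step derives a single PRNG key, and batch slot i maps to sampler index step * batch_size + i, so every process can reconstruct the same global batch deterministically:

    import jax

    def sampler_batch(sampler, step: int, batch_size: int):
        # one key per step, mirroring _produce_batches_from_sampler
        key = jax.random.PRNGKey(step)
        # slot i of the step's batch maps to a globally unique sampler index
        return [sampler.sample(step * batch_size + i, key=key) for i in range(batch_size)]
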
46 changes: 46 additions & 0 deletions src/levanter/data/sampler.py
@@ -0,0 +1,46 @@
import abc
from typing import Generic, TypeVar

from jax.random import PRNGKey


# from levanter.data import ShardCache

T = TypeVar("T")


class ItemSampler(Generic[T], abc.ABC):
"""
Samples individual items from a dataset by index.
"""

# TODO: getstate/setstate

@abc.abstractmethod
def sample(self, index: int, *, key: PRNGKey) -> T:
"""
Samples a single item from the dataset.
Args:
index: The index of the item to sample. This can be any nonnegative integer.
key: A random key, for samplers that need additional randomness.
Returns:
The sampled item.
"""
raise NotImplementedError


class RowSampler(ItemSampler[T]):
"""
Samples rows from a shard cache, wrapping the index around the total row count.
"""

def __init__(self, cache):
self.cache = cache

def sample(self, index, *, key: PRNGKey) -> T:
max_index = self.cache.final_row_count()
index = index % max_index

return self.cache.get_row(index)
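
For illustration, a minimal ItemSampler over an in-memory token array might look like the sketch below (hypothetical, not part of this commit; it assumes only the interface defined above):

    import numpy as np
    from jax.random import PRNGKey

    from levanter.data.sampler import ItemSampler

    class InMemoryTokenSampler(ItemSampler[np.ndarray]):
        """Hypothetical sampler that returns fixed-length windows from a flat token array."""

        def __init__(self, tokens: np.ndarray, seq_len: int):
            self.tokens = tokens
            self.seq_len = seq_len

        def sample(self, index: int, *, key: PRNGKey) -> np.ndarray:
            # wrap the index so any nonnegative integer is a valid request
            start = index % (len(self.tokens) - self.seq_len)
            return self.tokens[start : start + self.seq_len]
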
37 changes: 23 additions & 14 deletions src/levanter/data/shard_cache.py
@@ -156,7 +156,6 @@ class ChunkMetadata:
# these are running totals into row groups
row_offsets: Optional[np.ndarray] = dataclasses.field(metadata=NP_INT64_CODEC, default=None)
field_offsets: Optional[Dict[str, list[int]]] = dataclasses.field(metadata=DICT_NP_INT64_CODEC, default=None)
version: str = "2"

@property
def num_record_batches(self):
@@ -2243,7 +2242,7 @@ def _migrate_cache_to_add_offsets(cache_dir: str):
try:
ledger = _load_cache_ledger(cache_dir)
# if we already have ledger metadata, see if we have offsets
if ledger.chunks[0].field_offsets is not None:
if len(ledger.field_offsets):
logger.info("Offsets already present in ledger, skipping migration")
return

@@ -2256,26 +2255,36 @@ def _migrate_cache_to_add_offsets(cache_dir: str):
# now we have to do the same for the shard metadata, which are just lists of chunk metadatas
# these are named <shard_name>.json. We find them because they're not chunks and not the ledger
all_chunks = _migrate_shard_metadatas(cache_dir)
_migrate_ledger(all_chunks, cache_dir, ledger)

if ledger is not None:
_migrate_ledger(all_chunks, cache_dir, ledger)


def _migrate_ledger(all_chunks, cache_dir, ledger):
found = set()
if ledger is not None:
for chunk in ledger.chunks:
if chunk.name not in all_chunks:
raise ValueError(f"Chunk {chunk.name} in ledger but not found in cache")
global_row_offsets = [0]
global_field_offsets = {name: [0] for name in ledger.chunks[0].field_counts.keys()}
for chunk in ledger.chunks:
if chunk.name not in all_chunks:
raise ValueError(f"Chunk {chunk.name} in ledger but not found in cache")

chunk.field_offsets = all_chunks[chunk.name].field_offsets
chunk.row_offsets = all_chunks[chunk.name].row_offsets
global_row_offsets.append(global_row_offsets[-1] + chunk.num_rows)
for field, count in chunk.field_counts.items():
global_field_offsets[field].append(global_field_offsets[field][-1] + count)

found.add(chunk.name)

chunk.field_offsets = all_chunks[chunk.name].field_offsets
chunk.row_offsets = all_chunks[chunk.name].row_offsets
found.add(chunk.name)
ledger.row_offsets = np.asarray(global_row_offsets, dtype=np.int64)
ledger.field_offsets = {k: np.asarray(v, dtype=np.int64) for k, v in global_field_offsets.items()}

missing_chunks = set(all_chunks.keys()) - found
missing_chunks = set(all_chunks.keys()) - found

if len(missing_chunks) > 0:
raise ValueError(f"Found chunks in cache but not in ledger: {missing_chunks}")
if len(missing_chunks) > 0:
raise ValueError(f"Found chunks in cache but not in ledger: {missing_chunks}")

_serialize_json_and_commit(os.path.join(cache_dir, LEDGER_FILE_NAME), ledger)
_serialize_json_and_commit(os.path.join(cache_dir, LEDGER_FILE_NAME), ledger)


def _migrate_shard_metadatas(cache_dir):
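To make the offset bookkeeping in _migrate_ledger concrete, here is a sketch with made-up counts (not data from any real cache): per-chunk counts become cumulative offsets that start at 0, so offset[i] is the number of rows (or field entries) preceding chunk i.

    import numpy as np

    # hypothetical per-chunk counts for a three-chunk cache
    row_counts = [10, 5, 7]
    field_counts = {"input_ids": [1000, 480, 700]}

    row_offsets = np.cumsum([0] + row_counts)      # array([ 0, 10, 15, 22])
    field_offsets = {k: np.cumsum([0] + v) for k, v in field_counts.items()}
    # field_offsets["input_ids"] -> array([   0, 1000, 1480, 2180])
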
70 changes: 47 additions & 23 deletions src/levanter/data/text.py
@@ -7,23 +7,26 @@
from dataclasses import dataclass
from functools import cached_property
from itertools import chain
from typing import Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union

import braceexpand
import datasets
import equinox as eqx
import fsspec
import jax
import jax.numpy as jnp
import numpy as np
import pyarrow as pa
import regex
from draccus import field
from jax.random import PRNGKey
from jaxtyping import PRNGKeyArray

import haliax as hax
from haliax import Axis

from levanter.data.mixture import MixtureDataset, StopStrategy
from levanter.data.sampler import ItemSampler

# intercept the logging nonsense here
from levanter.logging import silence_transformer_nag # noqa
@@ -58,7 +61,6 @@
logger = logging.getLogger("levanter.data.text")

# TASKS:
# TODO: consider adding indexing a la Map-style datasets
# TODO: support seeking/serialization/restore in the dataset

LEDGER_FILE = "ledger.json"
@@ -121,37 +123,44 @@ def _create_lm_example(tokens, key):
yield example


class TokenSeqSampler:
def __init__(self, doc_cache: ShardCache, seq_len, seed: int, field_name: str = "input_ids"):
class CausalLmSampler(ItemSampler[LmExample]):
def __init__(self, token_sampler, QPos: Axis, field_name: str = "input_ids"):
self.token_sampler = token_sampler
self.QPos = QPos

def sample(self, index, *, key: PRNGKey) -> LmExample:
tokens = self.token_sampler.sample(index, key=key)
return LmExample.causal(tokens=hax.named(tokens, self.QPos), ignore_id=DEFAULT_IGNORE_INDEX)


class TokenSeqSampler(ItemSampler[np.ndarray]):
def __init__(self, doc_cache: ShardCache, seq_len, field_name: str = "input_ids"):
self.doc_cache = doc_cache
self.seq_len = seq_len
self.seed = seed
self.field_name = field_name

self.num_tokens_in_dataset = self.doc_cache.final_field_count(field_name)

def sample(self, step: int) -> np.ndarray:
# mix the seed with the step
rng = np.random.default_rng(self.seed + step)
out: list = []
while len(out) < self.seq_len:
remaining = self.seq_len - len(out)
idx = int(rng.integers(0, self.num_tokens_in_dataset - remaining, size=1)[0])
out.extend(self.doc_cache.get_field_slice(self.field_name, idx, idx + remaining))
def sample(self, index, *, key: PRNGKey) -> np.ndarray:
max_index = self.num_tokens_in_dataset - self.seq_len
index = index % max_index

return np.array(out)
return self.doc_cache.get_field_slice(self.field_name, index, index + self.seq_len)


class RowSampler:
def __init__(self, doc_cache: ShardCache, seed: int):
self.doc_cache = doc_cache
self.seed = seed
self.num_rows = self.doc_cache.final_row_count()
T = TypeVar("T")


class MixtureSampler(ItemSampler[T]):
def __init__(self, samplers: List[ItemSampler[T]], weights: List[float], key: PRNGKey):
self.samplers = samplers
self.weights = jnp.array(weights, dtype=jnp.float32)
self.key = key

def sample(self, step: int):
rng = np.random.default_rng(self.seed + step)
idx = rng.integers(0, self.num_rows, size=1)[0]
return self.doc_cache.get_row(idx)
def sample(self, index, *, key: PRNGKey) -> T:
key, subkey = jax.random.split(key)
i = jax.random.choice(subkey, len(self.samplers), shape=(), p=self.weights)
return self.samplers[i].sample(index, key=key)


class TokenSeqDataset(ShardableDataset[np.ndarray]):
@@ -605,6 +614,10 @@ def train_set(
) -> ShardableDataset[np.ndarray]:
pass

@abc.abstractmethod
def train_sampler(self, seq_len, monitors: Union[bool, List[MetricsMonitor]] = True) -> ItemSampler[np.ndarray]:
pass

@abc.abstractmethod
def validation_sets(
self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
@@ -637,6 +650,10 @@ def train_set(
raise ValueError("No training set!")
return ds

def train_sampler(self, seq_len, monitors: Union[bool, List[MetricsMonitor]] = True) -> ItemSampler[np.ndarray]:
cache = self.build_or_load_cache("train", monitors=monitors)
return TokenSeqSampler(cache.chunk_cache, seq_len) # type: ignore

def validation_set(
self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
) -> Optional[TokenSeqDataset]:
@@ -766,6 +783,13 @@ def __post_init__(self):
f" {self.train_weights.keys()}"
)

def train_sampler(self, seq_len, monitors: Union[bool, List[MetricsMonitor]] = True) -> ItemSampler[np.ndarray]:
doc_caches = self.build_caches("train", monitors=monitors)
token_datasets = {name: TokenSeqSampler(cache.chunk_cache, seq_len) for name, cache in doc_caches.items()}
return MixtureSampler(
list(token_datasets.values()), list(self.train_weights.values()), key=jax.random.key(self.seed)
)

def train_set(
self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
) -> ShardableDataset[np.ndarray]:
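A hedged end-to-end sketch of how the new samplers compose (the stub sampler and all constants here are hypothetical; only ItemSampler, MixtureSampler, and CausalLmSampler come from this commit): per-source token samplers are mixed by MixtureSampler and wrapped in CausalLmSampler, which is what train_lm.py below assigns as the training dataset.

    import jax
    import numpy as np
    from jax.random import PRNGKey

    import haliax as hax

    from levanter.data.sampler import ItemSampler
    from levanter.data.text import CausalLmSampler, MixtureSampler

    Pos = hax.Axis("position", 8)

    class ConstantTokenSampler(ItemSampler[np.ndarray]):
        """Hypothetical stand-in for TokenSeqSampler; always returns the same tokens."""

        def __init__(self, tokens: np.ndarray):
            self.tokens = tokens

        def sample(self, index, *, key: PRNGKey) -> np.ndarray:
            return self.tokens

    sources = [
        ConstantTokenSampler(np.arange(Pos.size, dtype=np.int32)),
        ConstantTokenSampler(np.ones(Pos.size, dtype=np.int32)),
    ]
    mixed = MixtureSampler(sources, weights=[0.7, 0.3], key=jax.random.PRNGKey(0))
    lm_sampler = CausalLmSampler(mixed, Pos)

    example = lm_sampler.sample(3, key=jax.random.PRNGKey(3))  # an LmExample over Pos
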
6 changes: 2 additions & 4 deletions src/levanter/main/train_lm.py
@@ -14,7 +14,7 @@
import levanter
from levanter import callbacks
from levanter.compat.hf_checkpoints import HFCompatConfig, save_hf_checkpoint_callback
from levanter.data.text import CausalLmDataset, LMDatasetConfig, LMMixtureDatasetConfig
from levanter.data.text import CausalLmDataset, CausalLmSampler, LMDatasetConfig, LMMixtureDatasetConfig
from levanter.models.gpt2 import Gpt2Config
from levanter.models.lm_model import LmConfig
from levanter.optim import AdamConfig, OptimizerConfig
@@ -103,9 +103,7 @@ def main(config: TrainLmConfig):
KeyPos = config.model.KeyPos

tagged_eval_datasets = config.data.tagged_eval_sets(Pos.size)
train_dataset = CausalLmDataset(
config.data.train_set(Pos.size), Pos, KeyPos, ignore_index=config.data.ignore_token_id
)
train_dataset = CausalLmSampler(config.data.train_sampler(Pos.size), Pos)

# to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to
# For most things, we just insist you specify the config right, but tokenizers often have strange numbers of
