simulate data mixture
Ivan-Zhou committed May 28, 2024
1 parent 46fe0b7 commit 92cf819
Showing 2 changed files with 89 additions and 0 deletions.
16 changes: 16 additions & 0 deletions analysis/count_data_mixture.py
@@ -0,0 +1,16 @@
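# Count how often each whitespace-separated token appears in log.txt and print the
# normalized frequencies. The exact log format is not shown in this commit; it is
# assumed to contain one record per sampled example (e.g. a dataset name).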
from collections import Counter


counter = Counter()
log_file = "log.txt"
with open(log_file, "r") as f:
    for line in f:
        counter.update(line.strip().split())

# normalize the counts
total = sum(counter.values())
for key in counter:
    counter[key] /= total

for key, value in counter.items():
print(f"{key}: {value:.3f}")
73 changes: 73 additions & 0 deletions analysis/simulate_data_mixture.py
@@ -0,0 +1,73 @@
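# Simulate the data mixture defined by config/data/dolma_olmo_paloma.yaml: build a tiny
# synthetic token cache for each dataset listed under train_weights, combine the caches
# in a MixtureDataset with those weights, and draw ~10k samples to exercise the sampling.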
import yaml
import numpy
from pathlib import Path

from levanter.data.mixture import MixtureDataset, StopStrategy
from levanter.data.shard_cache import ShardCache
from levanter.data.text import TokenSeqDataset, LMDatasetConfig


DATA_CONFIG = "config/data/dolma_olmo_paloma.yaml"
CACHE_DIR = "scratch/cache"



def construct_small_data_cache(
    path, num_shards=8, chunk_size=512, doc_len=128, vocab_size=1024
) -> tuple[LMDatasetConfig, dict[str, ShardCache]]:
    from levanter.data.shard_cache import SerialCacheWriter

    rng = numpy.random.default_rng(0)

    caches = {}

    # write one cache per split, filled with random token ids (no real text or tokenizer involved)
    for split in ["train", "validation"]:
        with SerialCacheWriter(f"{path}/cache/{split}", chunk_size) as writer:
            for shard in range(num_shards):
                writer.write_batch({"input_ids": rng.integers(0, vocab_size, size=(chunk_size, doc_len))})
        caches[split] = writer.result()

    # the URLs only satisfy the config; reading goes through the cache written above
    config = LMDatasetConfig(
        train_urls=[f"file://{path}/train/docs.jsonl"],
        validation_urls=[f"file://{path}/validation/docs.jsonl"],
        cache_dir=f"{path}/cache",
        vocab_size=vocab_size,
        tokenizer="passthrough",
    )

    return config, caches


def simulate_olmo():
    seq_len = 10
    num_docs = 1000
    # load the mixture weights from the data config
    with open(DATA_CONFIG, "r") as f:
        data_config = yaml.safe_load(f)
    weights_config = data_config["train_weights"]
    # normalize dataset names the same way the cache paths below are normalized,
    # so the weights and the datasets dict share identical keys
    weights = {name.replace(" ", "_"): weight for name, weight in weights_config.items()}

    # prepare a small synthetic data cache for every dataset in the mixture
    datasets = {}
    Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
    for data_name in weights.keys():
        construct_small_data_cache(
            f"{CACHE_DIR}/{data_name}", num_shards=1, chunk_size=num_docs, doc_len=seq_len
        )
        ds = TokenSeqDataset.load(seq_len, f"{CACHE_DIR}/{data_name}/cache/train")
        datasets[data_name] = ds

    # build the mixture with the first-stop strategy and draw samples from it
    dataset = MixtureDataset(
        datasets=datasets,
        weights=weights,
        stop_strategy=StopStrategy.FIRST_STOP_STRATEGY,
    )
    for idx, content in enumerate(dataset):
        # uncomment to inspect individual samples:
        # print(f"idx: {idx}, content: {content}")
        if idx > 10000:
            break


if __name__ == "__main__":
    simulate_olmo()
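
The commit does not show how the log.txt consumed by analysis/count_data_mixture.py is produced. A minimal sketch, not part of this commit and relying only on numpy and the configured weights (the log.txt target and the direct-sampling shortcut that bypasses MixtureDataset are assumptions), would be to draw dataset names with the configured probabilities and write them out for the counting script to tally:

import numpy
import yaml

# load the configured mixture weights (same config used by simulate_data_mixture.py)
with open("config/data/dolma_olmo_paloma.yaml", "r") as f:
    weights = yaml.safe_load(f)["train_weights"]

names = list(weights.keys())
probs = numpy.array([weights[name] for name in names], dtype=float)
probs /= probs.sum()  # the weights in the config need not sum to 1

# draw dataset names with the configured probabilities and write one per line,
# so analysis/count_data_mixture.py can report the realized mixture
rng = numpy.random.default_rng(0)
samples = rng.choice(names, size=10_000, p=probs)
with open("log.txt", "w") as f:
    for name in samples:
        f.write(name.replace(" ", "_") + "\n")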
