From 92cf819131774654b74b7cf36ae6135cd9d4a0bc Mon Sep 17 00:00:00 2001
From: Ivan-Zhou
Date: Mon, 27 May 2024 20:36:14 -0700
Subject: [PATCH] simulate data mixture

---
 analysis/count_data_mixture.py    | 16 +++++++
 analysis/simulate_data_mixture.py | 73 +++++++++++++++++++++++++++++++
 2 files changed, 89 insertions(+)
 create mode 100644 analysis/count_data_mixture.py
 create mode 100644 analysis/simulate_data_mixture.py

diff --git a/analysis/count_data_mixture.py b/analysis/count_data_mixture.py
new file mode 100644
index 000000000..ad051d914
--- /dev/null
+++ b/analysis/count_data_mixture.py
@@ -0,0 +1,16 @@
+from collections import Counter
+
+
+counter = Counter()
+log_file = "log.txt"
+with open(log_file, "r") as f:
+    for line in f:
+        counter.update(line.strip().split())
+
+# normalize the counts
+total = sum(counter.values())
+for key in counter:
+    counter[key] /= total
+
+for key, value in counter.items():
+    print(f"{key}: {value:.3f}")
\ No newline at end of file
diff --git a/analysis/simulate_data_mixture.py b/analysis/simulate_data_mixture.py
new file mode 100644
index 000000000..6e4f19427
--- /dev/null
+++ b/analysis/simulate_data_mixture.py
@@ -0,0 +1,73 @@
+import yaml
+import numpy
+from pathlib import Path
+
+from levanter.data.mixture import MixtureDataset, StopStrategy
+from levanter.data.shard_cache import ShardCache
+from levanter.data.text import TokenSeqDataset, LMDatasetConfig
+
+
+DATA_CONFIG = "config/data/dolma_olmo_paloma.yaml"
+CACHE_DIR = "scratch/cache"
+
+
+
+def construct_small_data_cache(
+    path, num_shards=8, chunk_size=512, doc_len=128, vocab_size=1024
+) -> tuple[LMDatasetConfig, dict[str, ShardCache]]:
+    from levanter.data.shard_cache import SerialCacheWriter
+
+    rng = numpy.random.default_rng(0)
+
+    caches = {}
+
+    for split in ["train", "validation"]:
+        with SerialCacheWriter(f"{path}/cache/{split}", chunk_size) as writer:
+            for shard in range(num_shards):
+                writer.write_batch({"input_ids": rng.integers(0, vocab_size, size=(chunk_size, doc_len))})
+        caches[split] = writer.result()
+
+    config = LMDatasetConfig(
+        train_urls=[f"file://{path}/train/docs.jsonl"],
+        validation_urls=[f"file://{path}/validation/docs.jsonl"],
+        cache_dir=f"{path}/cache",
+        vocab_size=vocab_size,
+        tokenizer="passthrough",
+    )
+
+    return config, caches
+
+
+def simulate_olmo():
+    seq_len = 10
+    num_docs = 1000
+    # load data config
+    with open(DATA_CONFIG, "r") as f:
+        data_config = yaml.safe_load(f)
+    weights_config = data_config["train_weights"]
+
+    # prepare data cache
+    datasets = {}
+    Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
+    for data_name in weights_config.keys():
+        data_name = data_name.replace(" ", "_")
+        construct_small_data_cache(
+            f"{CACHE_DIR}/{data_name}", num_shards=1, chunk_size=num_docs, doc_len=seq_len
+        )
+        ds = TokenSeqDataset.load(seq_len, f"{CACHE_DIR}/{data_name}/cache/train")
+        datasets[data_name] = ds
+
+    # compare mixture with different strategies
+    dataset = MixtureDataset(
+        datasets=datasets,
+        weights=weights_config,
+        stop_strategy=StopStrategy.FIRST_STOP_STRATEGY,
+    )
+    for idx, content in enumerate(dataset):
+        # print(f"idx: {idx}, content: {content}")
+        if idx > 10000:
+            break
+
+
+if __name__ == "__main__":
+    simulate_olmo()