diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index 89fb4a192..5ced3d315 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -721,6 +721,7 @@ class LMMixtureDatasetConfig(LMTaskConfig):
     train_weights: Dict[str, float] = field(default_factory=dict)
     """ weights for each dataset source. They will be normalized to sum to 1. """
     stop_strategy: str = field(default=StopStrategy.RESTART_STRATEGY)
+    seed: int = field(default=0)
 
     def __post_init__(self):
         if len(self.configs) == 0:
@@ -737,7 +738,9 @@ def train_set(
     ) -> ShardableDataset[np.ndarray]:
         doc_caches = self.build_caches("train", monitors=monitors)
         token_datasets = {name: TokenSeqDataset(cache, seq_len, stride=None) for name, cache in doc_caches.items()}
-        return MixtureDataset(datasets=token_datasets, weights=self.train_weights, stop_strategy=self.stop_strategy)
+        return MixtureDataset(
+            datasets=token_datasets, weights=self.train_weights, stop_strategy=self.stop_strategy, key=self.seed
+        )
 
     def training_sets(
         self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True