Skip to content

Commit

Permalink
oopsy
Browse files (browse the repository at this point in the history)
  • Loading branch information
dlwh committed May 29, 2024
1 parent 11b1906 commit 4e1c2a5
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/levanter/data/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,7 @@ class LMMixtureDatasetConfig(LMTaskConfig):
train_weights: Dict[str, float] = field(default_factory=dict)
""" weights for each dataset source. They will be normalized to sum to 1. """
stop_strategy: str = field(default=StopStrategy.RESTART_STRATEGY)
seed: int = field(default=0)

def __post_init__(self):
if len(self.configs) == 0:
Expand All @@ -737,7 +738,9 @@ def train_set(
) -> ShardableDataset[np.ndarray]:
doc_caches = self.build_caches("train", monitors=monitors)
token_datasets = {name: TokenSeqDataset(cache, seq_len, stride=None) for name, cache in doc_caches.items()}
return MixtureDataset(datasets=token_datasets, weights=self.train_weights, stop_strategy=self.stop_strategy)
return MixtureDataset(
datasets=token_datasets, weights=self.train_weights, stop_strategy=self.stop_strategy, key=self.seed
)

def training_sets(
self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
Expand Down

0 comments on commit 4e1c2a5

Please sign in to comment.