diff --git a/pytorch3d/implicitron/dataset/data_loader_map_provider.py b/pytorch3d/implicitron/dataset/data_loader_map_provider.py index 8c0841ccd..50a792183 100644 --- a/pytorch3d/implicitron/dataset/data_loader_map_provider.py +++ b/pytorch3d/implicitron/dataset/data_loader_map_provider.py @@ -12,7 +12,7 @@ from pytorch3d.implicitron.tools.config import registry, ReplaceableBase from torch.utils.data import ( BatchSampler, - ChainDataset, + ConcatDataset, DataLoader, RandomSampler, Sampler, @@ -482,7 +482,7 @@ def _train_loader( num_batches=num_batches, ) return DataLoader( - ChainDataset([dataset, train_dataset]), + ConcatDataset([dataset, train_dataset]), batch_sampler=sampler, **data_loader_kwargs, )