Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use torch core instead of torchdata modules. #7609

Merged
Merged 4 commits on Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions python/dgl/graphbolt/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
)

# pylint: disable=wrong-import-position
from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
from torch.utils.data import functional_datapipe, IterDataPipe

from .internal_utils import (
get_nonproperty_attributes,
Expand Down
9 changes: 4 additions & 5 deletions python/dgl/graphbolt/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
from concurrent.futures import ThreadPoolExecutor

import torch
import torch.utils.data
import torch.utils.data as torch_data
import torchdata.dataloader2.graph as dp_utils
import torchdata.datapipes as dp

from .base import CopyTo, get_host_to_device_uva_stream
from .feature_fetcher import FeatureFetcher, FeatureFetcherStartMarker
Expand Down Expand Up @@ -70,7 +69,7 @@ def _set_worker_id(worked_id):
torch.ops.graphbolt.set_worker_id(worked_id)


class MultiprocessingWrapper(dp.iter.IterDataPipe):
class MultiprocessingWrapper(torch_data.IterDataPipe):
"""Wraps a datapipe with multiprocessing.

Parameters
Expand All @@ -88,7 +87,7 @@ class MultiprocessingWrapper(dp.iter.IterDataPipe):

def __init__(self, datapipe, num_workers=0, persistent_workers=True):
self.datapipe = datapipe
self.dataloader = torch.utils.data.DataLoader(
self.dataloader = torch_data.DataLoader(
datapipe,
batch_size=None,
num_workers=num_workers,
Expand All @@ -100,7 +99,7 @@ def __iter__(self):
yield from self.dataloader


class DataLoader(torch.utils.data.DataLoader):
class DataLoader(torch_data.DataLoader):
"""Multiprocessing DataLoader.

Iterates over the data pipeline with everything before feature fetching
Expand Down
2 changes: 1 addition & 1 deletion python/dgl/graphbolt/impl/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch
from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import Mapper
from torch.utils.data.datapipes.iter import Mapper

from ..base import ORIGINAL_EDGE_ID
from ..internal import compact_csc_format, unique_and_compact_csc_formats
Expand Down
2 changes: 1 addition & 1 deletion python/dgl/graphbolt/minibatch_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from torch.utils.data import functional_datapipe

from torchdata.datapipes.iter import Mapper
from torch.utils.data.datapipes.iter import Mapper

from .minibatch import MiniBatch

Expand Down
2 changes: 1 addition & 1 deletion tests/python/pytorch/graphbolt/test_feature_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import dgl.graphbolt as gb
import pytest
import torch
from torchdata.datapipes.iter import Mapper
from torch.utils.data.datapipes.iter import Mapper

from . import gb_test_utils

Expand Down
Loading