Skip to content

Commit

Permalink
Merge branch 'master' into 0.5.x
Browse files (browse the repository at this point in the history)
  • Loading branch information
BarclayII committed Sep 13, 2020
2 parents 1998335 + 76d66fd commit 76de9ae
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
2 changes: 1 addition & 1 deletion examples/pytorch/rgcn/experimental/partition_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def load_ogb(dataset, global_norm):
if ntype == category:
category_id = i

g = dgl.to_homo(hg)
g = dgl.to_homogeneous(hg, edata=['norm'])
if global_norm:
u, v, eid = g.all_edges(form='all')
_, inverse_index, count = th.unique(v, return_inverse=True, return_counts=True)
Expand Down
22 changes: 18 additions & 4 deletions python/dgl/dataloading/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .. import backend as F
from .. import utils
from ..convert import heterograph
from ..distributed.dist_graph import DistGraph

# pylint: disable=unused-argument
def assign_block_eids(block, frontier):
Expand Down Expand Up @@ -244,6 +245,7 @@ def sample_blocks(self, g, seed_nodes, exclude_eids=None):
assign_block_eids(block, frontier)

seed_nodes = {ntype: block.srcnodes[ntype].data[NID] for ntype in block.srctypes}

# Pre-generate CSR format so that it can be used in training directly
block.create_formats_()
blocks.insert(0, block)
Expand Down Expand Up @@ -309,6 +311,7 @@ class NodeCollator(Collator):
"""
def __init__(self, g, nids, block_sampler):
self.g = g
self._is_distributed = isinstance(g, DistGraph)
if not isinstance(nids, Mapping):
assert len(g.ntypes) == 1, \
"nids should be a dict of node type and ids for graph with multiple node types"
Expand Down Expand Up @@ -352,6 +355,15 @@ def collate(self, items):
if isinstance(items[0], tuple):
# returns a list of pairs: group them by node types into a dict
items = utils.group_as_dict(items)

# TODO(BarclayII): DistGraph does not implement idtype or device yet, so the
# tensor preparation below (utils.prepare_tensor / utils.prepare_tensor_dict)
# does not work on it. As a workaround, this step is skipped again for
# distributed graphs. This needs a proper fix.
if not self._is_distributed:
if isinstance(items, dict):
items = utils.prepare_tensor_dict(self.g, items, 'items')
else:
items = utils.prepare_tensor(self.g, items, 'items')
blocks = self.block_sampler.sample_blocks(self.g, items)
output_nodes = blocks[-1].dstdata[NID]
input_nodes = blocks[0].srcdata[NID]
Expand Down Expand Up @@ -559,10 +571,11 @@ def dataset(self):

def _collate(self, items):
if isinstance(items[0], tuple):
# returns a list of pairs: group them by node types into a dict
items = utils.group_as_dict(items)
items = {k: F.zerocopy_from_numpy(np.asarray(v)) for k, v in items.items()}
items = utils.prepare_tensor_dict(self.g_sampling, items, 'items')
else:
items = F.zerocopy_from_numpy(np.asarray(items))
items = utils.prepare_tensor(self.g_sampling, items, 'items')

pair_graph = self.g.edge_subgraph(items)
seed_nodes = pair_graph.ndata[NID]
Expand All @@ -582,10 +595,11 @@ def _collate(self, items):

def _collate_with_negative_sampling(self, items):
if isinstance(items[0], tuple):
# returns a list of pairs: group them by node types into a dict
items = utils.group_as_dict(items)
items = {k: F.zerocopy_from_numpy(np.asarray(v)) for k, v in items.items()}
items = utils.prepare_tensor_dict(self.g_sampling, items, 'items')
else:
items = F.zerocopy_from_numpy(np.asarray(items))
items = utils.prepare_tensor(self.g_sampling, items, 'items')

pair_graph = self.g.edge_subgraph(items, preserve_nodes=True)
induced_edges = pair_graph.edata[EID]
Expand Down

0 comments on commit 76de9ae

Please sign in to comment.