Skip to content

Commit

Permalink
Merge branch 'master' into spatial_fix
Browse files Browse the repository at this point in the history
  • Loading branch information
rudongyu committed Jun 28, 2023
2 parents d6ffe2b + 2c03fe9 commit 8b30a96
Show file tree
Hide file tree
Showing 12 changed files with 69 additions and 35 deletions.
63 changes: 34 additions & 29 deletions graphbolt/src/csc_sampling_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,37 +141,42 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
torch::Tensor num_picked_neighbors_per_node =
torch::zeros({num_nodes + 1}, indptr_.options());

torch::parallel_for(0, num_nodes, 32, [&](size_t b, size_t e) {
for (size_t i = b; i < e; ++i) {
const auto nid = nodes[i].item<int64_t>();
TORCH_CHECK(
nid >= 0 && nid < NumNodes(),
"The seed nodes' IDs should fall within the range of the graph's "
"node IDs.");
const auto offset = indptr_[nid].item<int64_t>();
const auto num_neighbors = indptr_[nid + 1].item<int64_t>() - offset;
AT_DISPATCH_INTEGRAL_TYPES(
indptr_.scalar_type(), "parallel_for", ([&] {
torch::parallel_for(0, num_nodes, 32, [&](scalar_t b, scalar_t e) {
const scalar_t* indptr_data = indptr_.data_ptr<scalar_t>();
for (scalar_t i = b; i < e; ++i) {
const auto nid = nodes[i].item<int64_t>();
TORCH_CHECK(
nid >= 0 && nid < NumNodes(),
"The seed nodes' IDs should fall within the range of the "
"graph's node IDs.");
const auto offset = indptr_data[nid];
const auto num_neighbors = indptr_data[nid + 1] - offset;

if (num_neighbors == 0) {
// Initialization is performed here because all tensors will be
// concatenated in the master thread, and having an undefined tensor
// during concatenation can result in a crash.
picked_neighbors_per_node[i] = torch::tensor({}, indptr_.options());
continue;
}
if (num_neighbors == 0) {
// Initialization is performed here because all tensors will be
// concatenated in the master thread, and having an undefined
// tensor during concatenation can result in a crash.
picked_neighbors_per_node[i] =
torch::tensor({}, indptr_.options());
continue;
}

if (consider_etype) {
picked_neighbors_per_node[i] = PickByEtype(
offset, num_neighbors, fanouts, replace, indptr_.options(),
type_per_edge_.value(), probs_or_mask);
} else {
picked_neighbors_per_node[i] = Pick(
offset, num_neighbors, fanouts[0], replace, indptr_.options(),
probs_or_mask);
}
num_picked_neighbors_per_node[i + 1] =
picked_neighbors_per_node[i].size(0);
}
}); // End of the thread.
if (consider_etype) {
picked_neighbors_per_node[i] = PickByEtype(
offset, num_neighbors, fanouts, replace, indptr_.options(),
type_per_edge_.value(), probs_or_mask);
} else {
picked_neighbors_per_node[i] = Pick(
offset, num_neighbors, fanouts[0], replace, indptr_.options(),
probs_or_mask);
}
num_picked_neighbors_per_node[i + 1] =
picked_neighbors_per_node[i].size(0);
}
}); // End of the thread.
}));

torch::Tensor subgraph_indptr =
torch::cumsum(num_picked_neighbors_per_node, 0);
Expand Down
4 changes: 2 additions & 2 deletions python/dgl/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def add_nodepred_split(dataset, ratio, ntype=None):
def mask_nodes_by_property(property_values, part_ratios, random_seed=None):
"""Provide the split masks for a node split with distributional shift based on a given
node property, as proposed in `Evaluating Robustness and Uncertainty of Graph Models
Under Structural Distributional Shifts <https://arxiv.org/abs/2302.13875v1>`__
Under Structural Distributional Shifts <https://arxiv.org/abs/2302.13875>`__
It considers the in-distribution (ID) and out-of-distribution (OOD) subsets of nodes.
The ID subset includes training, validation and testing parts, while the OOD subset
Expand Down Expand Up @@ -569,7 +569,7 @@ def add_node_property_split(
):
"""Create a node split with distributional shift based on a given node property,
as proposed in `Evaluating Robustness and Uncertainty of Graph Models Under
Structural Distributional Shifts <https://arxiv.org/abs/2302.13875v1>`__
Structural Distributional Shifts <https://arxiv.org/abs/2302.13875>`__
It splits the nodes of each graph in the given dataset into 5 non-intersecting
parts based on their structural properties. This can be used for transductive node
Expand Down
5 changes: 5 additions & 0 deletions tests/backend/backend_unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ def abs(a):
pass


def seed(a):
"""Set the seed for the random generator."""
pass


###############################################################################
# Tensor functions used *only* on index tensor
# ----------------
Expand Down
4 changes: 4 additions & 0 deletions tests/backend/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,7 @@ def dot(a, b):

def abs(a):
return nd.abs(a)


def seed(a):
# Forward the seed value `a` to mx.random.seed and return whatever it returns.
return mx.random.seed(a)
4 changes: 4 additions & 0 deletions tests/backend/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,7 @@ def dot(a, b):

def abs(a):
return a.abs()


def seed(a):
# Forward the seed value `a` to th.manual_seed and return whatever it returns.
# NOTE(review): `th` is presumably the torch module alias -- confirm against this file's imports.
return th.manual_seed(a)
4 changes: 4 additions & 0 deletions tests/backend/tensorflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,7 @@ def dot(a, b):

def abs(a):
return tf.abs(a)


def seed(a):
# Forward the seed value `a` to tf.random.set_seed and return whatever it returns.
return tf.random.set_seed(a)
3 changes: 3 additions & 0 deletions tests/python/common/data/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1789,6 +1789,7 @@ def test_as_nodepred2():
assert len(ds.train_idx) == int(ds[0].num_nodes("Personen") * 0.1)


@unittest.skip("ogb is not available")
@unittest.skipIf(
dgl.backend.backend_name != "pytorch", reason="ogb only supports pytorch"
)
Expand Down Expand Up @@ -1841,6 +1842,7 @@ def test_as_linkpred():
assert 4000 < ds.test_edges[1][0].shape[0] <= 4224


@unittest.skip("ogb is not available")
@unittest.skipIf(
dgl.backend.backend_name != "pytorch", reason="ogb only supports pytorch"
)
Expand Down Expand Up @@ -2074,6 +2076,7 @@ def test_as_graphpred_reprocess():
assert len(ds.train_idx) == int(len(ds) * 0.1)


@unittest.skip("ogb is not available")
@unittest.skipIf(
dgl.backend.backend_name != "pytorch", reason="ogb only supports pytorch"
)
Expand Down
9 changes: 9 additions & 0 deletions tests/python/pytorch/nn/test_nn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import io
import pickle
import random
from copy import deepcopy

import backend as F
Expand All @@ -8,6 +9,7 @@
import dgl.function as fn
import dgl.nn.pytorch as nn
import networkx as nx
import numpy as np # For setting seed for scipy
import pytest
import scipy as sp
import torch
Expand All @@ -24,6 +26,13 @@
random_graph,
)

# Set seeds to make tests fully reproducible.
# A fixed value is used; the commented-out randint call is the randomized
# alternative that was left disabled.
SEED = 12345 # random.randint(1, 99999)
random.seed(SEED) # For networkx
np.random.seed(SEED) # For scipy
dgl.seed(SEED) # Seed passed to DGL itself
F.seed(SEED) # Seed passed to the test backend module imported as F

tmp_buffer = io.BytesIO()


Expand Down
2 changes: 1 addition & 1 deletion tutorials/large/L1_large_node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#
# OGB already prepared the data as DGL graph.
#

exit(0)
import os

os.environ["DGLBACKEND"] = "pytorch"
Expand Down
2 changes: 1 addition & 1 deletion tutorials/large/L2_large_link_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
# This tutorial loads the dataset from the ``ogb`` package as in the
# :doc:`previous tutorial <L1_large_node_classification>`.
#

exit(0)
import os

os.environ["DGLBACKEND"] = "pytorch"
Expand Down
2 changes: 1 addition & 1 deletion tutorials/large/L4_message_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
works <L1_large_node_classification>`.
"""

exit(0)
import os

os.environ["DGLBACKEND"] = "pytorch"
Expand Down
2 changes: 1 addition & 1 deletion tutorials/multi/2_node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
# Classification <../large/L1_large_node_classification>`
# tutorial.
#

exit(0)
import os

os.environ["DGLBACKEND"] = "pytorch"
Expand Down

0 comments on commit 8b30a96

Please sign in to comment.