Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Data] new functional for creating data splits in graph #5418

Merged
merged 21 commits into dmlc:master from gvbazhenov:structural-shifts
Mar 9, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
cbd793b
new functional for creating data splits in graph
gvbazhenov Mar 2, 2023
f8e3460
Merge branch 'dmlc:master' into structural-shifts
gvbazhenov Mar 3, 2023
cdd4a17
minor fix in data split implementation
gvbazhenov Mar 3, 2023
a1e8c55
Merge branch 'structural-shifts' of https://github.com/gvbazhenov/dgl…
gvbazhenov Mar 3, 2023
845274d
Merge branch 'dmlc:master' into structural-shifts
gvbazhenov Mar 6, 2023
2aab8e5
apply suggestions from code review
gvbazhenov Mar 7, 2023
09af927
Merge branch 'dmlc:master' into structural-shifts
gvbazhenov Mar 7, 2023
910d757
refactoring + unit tests
gvbazhenov Mar 7, 2023
37e933f
Merge branch 'master' into structural-shifts
mufeili Mar 8, 2023
126323e
fix test file name
gvbazhenov Mar 8, 2023
7044854
Merge branch 'structural-shifts' of https://github.com/gvbazhenov/dgl…
gvbazhenov Mar 8, 2023
6dfe6de
Merge branch 'dmlc:master' into structural-shifts
gvbazhenov Mar 8, 2023
932cca0
move imports to the top
gvbazhenov Mar 8, 2023
2fa397e
Revert "fix test file name"
gvbazhenov Mar 8, 2023
c7ddd2b
Merge branch 'master' into structural-shifts
gvbazhenov Mar 8, 2023
00c842b
remove nccl submodule
gvbazhenov Mar 8, 2023
2f9677c
Merge branch 'master' into structural-shifts
frozenbugs Mar 9, 2023
d60cf9a
address linter issues
gvbazhenov Mar 9, 2023
0a622a9
Merge branch 'structural-shifts' of https://github.com/gvbazhenov/dgl…
gvbazhenov Mar 9, 2023
6d98537
Merge branch 'dmlc:master' into structural-shifts
gvbazhenov Mar 9, 2023
57fd348
Merge branch 'structural-shifts' of https://github.com/gvbazhenov/dgl…
gvbazhenov Mar 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/api/python/dgl.data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,6 @@ Utilities
utils.save_info
utils.load_info
utils.add_nodepred_split
utils.mask_nodes_by_property
utils.add_node_property_split
utils.Subset
197 changes: 197 additions & 0 deletions python/dgl/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
"save_tensors",
"load_tensors",
"add_nodepred_split",
"add_node_property_split",
"mask_nodes_by_property",
]


Expand Down Expand Up @@ -482,3 +484,198 @@ def add_nodepred_split(dataset, ratio, ntype=None):
g.nodes[ntype].data["train_mask"] = train_mask
g.nodes[ntype].data["val_mask"] = val_mask
g.nodes[ntype].data["test_mask"] = test_mask


def mask_nodes_by_property(property_values, part_ratios, random_seed=None):
    """Provide the split masks for training, ID and OOD validation, ID and OOD
    testing according to the node property values.

    It sorts the nodes in the ascending order of their property values, splits
    them into 5 non-intersecting parts, and creates 5 associated node mask arrays:

    - 3 for the ID nodes: ``'in_train_mask'``, ``'in_valid_mask'``, ``'in_test_mask'``,
    - and 2 for the OOD nodes: ``'out_valid_mask'``, ``'out_test_mask'``.

    The nodes that fall into the three ID parts are shuffled among themselves,
    so the ID parts are a random partition of the lowest-valued nodes, while
    the two OOD parts keep the sorted order.

    As described in `Evaluating Robustness and Uncertainty of Graph Models Under
    Structural Distributional Shifts <https://arxiv.org/abs/2302.13875v1>`__,
    this approach allows to create data splits with distributional shifts.

    Parameters
    ----------
    property_values : numpy ndarray
        The node property (float) values to split the dataset by.
        The length of array must be equal to the number of nodes in graph.
        Other 1-D array-likes (e.g. a list) are converted with ``np.asarray``.
    part_ratios : list
        A list of 5 ratios for training, ID validation, ID test,
        OOD validation, OOD testing parts. The values in list must sum to one.
    random_seed : int, optional
        Random seed to fix for the initial permutation of nodes. It is
        used to create a random order for the nodes that have the same
        property values. (default: None)

    Returns
    -------
    split_masks : dict
        A python dict storing the mask names as keys and the corresponding
        node masks (as produced by ``generate_mask_tensor``) as values.
    """
    # Fancy indexing below needs an ndarray; a plain list would otherwise
    # fail with an unhelpful TypeError.
    property_values = np.asarray(property_values)

    num_nodes = len(property_values)
    part_sizes = np.round(num_nodes * np.array(part_ratios)).astype(int)
    # Rounding can make the sizes sum to slightly more or less than
    # ``num_nodes``; the last part absorbs the difference.
    part_sizes[-1] -= np.sum(part_sizes) - num_nodes

    generator = np.random.RandomState(random_seed)
    permutation = generator.permutation(num_nodes)

    # Shuffle the nodes before sorting so that the relative order of nodes
    # with equal property values depends on the seed, not on the node ids.
    shuffled_indices = np.arange(num_nodes)[permutation]
    shuffled_values = property_values[permutation]
    in_distribution_size = np.sum(part_sizes[:3])

    node_indices_ordered = shuffled_indices[np.argsort(shuffled_values)]
    # Mix the ID nodes so that train / ID-valid / ID-test are drawn from the
    # same range of property values.
    node_indices_ordered[:in_distribution_size] = generator.permutation(
        node_indices_ordered[:in_distribution_size]
    )

    # ``sections`` ends at ``num_nodes``, so ``np.split`` yields a trailing
    # empty chunk, which is dropped.
    sections = np.cumsum(part_sizes)
    node_split = np.split(node_indices_ordered, sections)[:-1]
    mask_names = [
        "in_train_mask",
        "in_valid_mask",
        "in_test_mask",
        "out_valid_mask",
        "out_test_mask",
    ]

    split_masks = {}
    for mask_name, part_indices in zip(mask_names, node_split):
        split_mask = idx2mask(part_indices, num_nodes)
        split_masks[mask_name] = generate_mask_tensor(split_mask)

    return split_masks


def add_node_property_split(
    dataset, part_ratios, property_name, ascending=True, random_seed=None
):
    """Create a data split with a distributional shift based on some node property.

    It splits each graph in the given dataset into training, ID and OOD validation, ID and OOD
    testing parts for transductive node prediction task with structural distributional shifts.
    As a result, it creates 5 associated node mask arrays for each graph:

    - 3 for the ID nodes: ``'in_train_mask'``, ``'in_valid_mask'``, ``'in_test_mask'``,
    - and 2 for the OOD nodes: ``'out_valid_mask'``, ``'out_test_mask'``.

    The masks are written into ``ndata`` of every graph in the dataset.

    Following `Evaluating Robustness and Uncertainty of Graph Models Under Structural
    Distributional Shifts <https://arxiv.org/abs/2302.13875v1>`__, this function implements
    3 particular strategies for inducing distributional shifts in graph — based on
    **popularity**, **locality** or **density**.

    Each graph is converted to an undirected graph-tool graph with parallel
    edges removed before the property is computed:

    - ``'popularity'`` uses PageRank,
    - ``'locality'`` uses PageRank personalized to the single node with the
      highest PageRank,
    - ``'density'`` uses the local clustering coefficient.

    Parameters
    ----------
    dataset : DGLDataset or list of :class:`DGLGraph`
        The dataset to induce structural distributional shift in.
    part_ratios : list
        A list of 5 ratio values for training, ID validation, ID test,
        OOD validation and OOD test parts. The values must sum to 1.0.
    property_name : str
        The node property name to split the dataset by. Must be
        ``'popularity'``, ``'locality'`` or ``'density'``.
    ascending : bool, optional
        Controls the sign applied to the computed property before the nodes
        are sorted in ascending order by :func:`mask_nodes_by_property`.
        When True, the property values are negated, so the nodes with the
        largest raw property values end up in the ID parts and the nodes with
        the smallest raw values end up in the OOD parts. When False, the raw
        values are used and the assignment is reversed. (default: True)
        NOTE(review): this is what the code does; confirm that it matches the
        intended meaning of the parameter name.
    random_seed : int, optional
        Random seed to fix for the initial permutation of nodes. It is
        used to create a random order for the nodes that have the same
        property values. (default: None)

    Raises
    ------
    AssertionError
        If ``property_name`` is not one of the supported names or
        ``part_ratios`` does not contain exactly 5 values.
    ImportError
        If graph-tool is not installed.

    Examples
    --------
    >>> dataset = dgl.data.AmazonCoBuyComputerDataset()
    >>> print('in_valid_mask' in dataset[0].ndata)
    False
    >>> part_sizes = [0.3, 0.1, 0.1, 0.3, 0.2]
    >>> property_name = 'popularity'
    >>> dgl.data.utils.add_node_property_split(dataset, part_sizes, property_name)
    >>> print('in_valid_mask' in dataset[0].ndata)
    True
    """

    assert property_name in [
        "popularity",
        "locality",
        "density",
    ], "The name of property has to be 'popularity', 'locality', or 'density'"

    assert (
        len(part_ratios) == 5
    ), "The list of part ratios must contain 5 values"

    # graph-tool is an optional dependency. Fail right away with a clear
    # message instead of warning and then hitting an unbound ``gt`` below.
    try:
        import graph_tool as gt
        from graph_tool import centrality, clustering, generation
    except ImportError as err:
        raise ImportError(
            "graph-tool is required to compute the node property values: "
            "https://graph-tool.skewed.de/"
        ) from err

    # Every supported property has a default direction of -1, which is then
    # multiplied by the user direction (+1 for ascending, -1 otherwise).
    direction = -1 if ascending else 1

    for idx in range(len(dataset)):
        graph_dgl = dataset[idx]
        num_nodes = graph_dgl.num_nodes()

        # Build an (E, 2) numpy edge list from the (src, dst) pair of tensors.
        edge_list = F.stack(graph_dgl.edges(), dim=0)
        if F.backend_name == "mxnet":
            edge_list = edge_list.asnumpy().T
        else:
            edge_list = edge_list.numpy().T

        # Mirror the graph in graph-tool as a simple undirected graph.
        graph_gt = gt.Graph()
        graph_gt.add_vertex(num_nodes)
        graph_gt.add_edge_list(edge_list)
        graph_gt.set_directed(False)
        generation.remove_parallel_edges(graph_gt)

        if property_name == "popularity":
            raw_values = np.array(centrality.pagerank(graph_gt).get_array())
        elif property_name == "locality":
            pagerank_values = np.array(
                centrality.pagerank(graph_gt).get_array()
            )

            # One-hot personalization vector on the highest-PageRank node.
            ohe_mask = np.zeros_like(pagerank_values)
            ohe_mask[np.argmax(pagerank_values)] = 1.0

            property_gt = graph_gt.new_vertex_property("double")
            property_gt.a = ohe_mask

            raw_values = np.array(
                centrality.pagerank(graph_gt, pers=property_gt).get_array()
            )
        else:  # property_name == "density", guaranteed by the assert above
            raw_values = np.array(
                clustering.local_clustering(graph_gt).get_array()
            )

        property_values = direction * raw_values

        node_masks = mask_nodes_by_property(
            property_values, part_ratios, random_seed
        )

        for mask_name, node_mask in node_masks.items():
            graph_dgl.ndata[mask_name] = node_mask