Skip to content

Commit

Permalink
[Utils] Edge and LINKX homophily measure (dmlc#5382)
Browse files Browse the repository at this point in the history
* Update

* lint

* lint

* r prefix

* CI

* lint

* skip TF

* Update

* edge homophily

* linkx homophily

* format

* skip TF

* fix test

* update

* lint

* lint

* review

* lint

* update

* lint

* update

* CI

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-36-188.ap-northeast-1.compute.internal>
  • Loading branch information
2 people authored and DominikaJedynak committed Mar 12, 2024
1 parent ccf692d commit 6cdbd37
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 18 deletions.
2 changes: 2 additions & 0 deletions docs/source/api/python/dgl.rst
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,9 @@ Utilities for measuring homophily of a graph
.. autosummary::
:toctree: ../../generated/

edge_homophily
node_homophily
linkx_homophily

Utilities
-----------------------------------------------
Expand Down
164 changes: 148 additions & 16 deletions python/dgl/homophily.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,34 @@
"""Utils for tacking graph homophily and heterophily"""
from . import backend as F, function as fn
# pylint: disable=W0611
from . import function as fn

__all__ = ["node_homophily"]
try:
import torch
except ImportError:
HAS_TORCH = False
else:
HAS_TORCH = True

__all__ = ["node_homophily", "edge_homophily", "linkx_homophily"]


def check_pytorch():
"""Check if PyTorch is the backend."""
if HAS_TORCH is False:
raise ModuleNotFoundError(
"This function requires PyTorch to be the backend."
)


def get_long_edges(graph):
"""Internal function for getting the edges of a graph as long tensors."""
src, dst = graph.edges()
return src.long(), dst.long()


def node_homophily(graph, y):
r"""Homophily measure from `Geom-GCN: Geometric Graph Convolutional Networks
<https://arxiv.org/abs/2002.05287>`__
r"""Homophily measure from `Geom-GCN: Geometric Graph Convolutional
Networks <https://arxiv.org/abs/2002.05287>`__
We follow the practice of a later paper `Large Scale Learning on
Non-Homophilous Graphs: New Benchmarks and Strong Simple Methods
Expand All @@ -15,8 +37,8 @@ def node_homophily(graph, y):
Mathematically it is defined as follows:
.. math::
\frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \frac{ | \{ (u,v) : u
\in \mathcal{N}(v) \wedge y_v = y_u \} | } { |\mathcal{N}(v)| }
\frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \frac{ | \{u
\in \mathcal{N}(v): y_v = y_u \} | } { |\mathcal{N}(v)| },
where :math:`\mathcal{V}` is the set of nodes, :math:`\mathcal{N}(v)` is
the predecessors of node :math:`v`, and :math:`y_v` is the class of node
Expand All @@ -25,14 +47,14 @@ def node_homophily(graph, y):
Parameters
----------
graph : DGLGraph
The graph
y : Tensor
The node labels, which is a tensor of shape (|V|)
The graph.
y : torch.Tensor
The node labels, which is a tensor of shape (|V|).
Returns
-------
float
The node homophily value
The node homophily value.
Examples
--------
Expand All @@ -44,14 +66,124 @@ def node_homophily(graph, y):
>>> dgl.node_homophily(graph, y)
0.6000000238418579
"""
check_pytorch()
with graph.local_scope():
# Handle the case where graph is of dtype int32.
src, dst = get_long_edges(graph)
# Compute y_v = y_u for all edges.
graph.edata["same_class"] = (y[src] == y[dst]).float()
graph.update_all(
fn.copy_e("same_class", "m"), fn.mean("m", "same_class_deg")
)
return graph.ndata["same_class_deg"].mean(dim=0).item()


def edge_homophily(graph, y):
r"""Homophily measure from `Beyond Homophily in Graph Neural Networks:
Current Limitations and Effective Designs
<https://arxiv.org/abs/2006.11468>`__
Mathematically it is defined as follows:
.. math::
\frac{| \{ (u,v) : (u,v) \in \mathcal{E} \wedge y_u = y_v \} | }
{|\mathcal{E}|},
where :math:`\mathcal{E}` is the set of edges, and :math:`y_u` is the class
of node :math:`u`.
Parameters
----------
graph : DGLGraph
The graph.
y : torch.Tensor
The node labels, which is a tensor of shape (|V|).
Returns
-------
float
The edge homophily ratio value.
Examples
--------
>>> import dgl
>>> import torch
>>> graph = dgl.graph(([1, 2, 0, 4], [0, 1, 2, 3]))
>>> y = torch.tensor([0, 0, 0, 0, 1])
>>> dgl.edge_homophily(graph, y)
0.75
"""
check_pytorch()
with graph.local_scope():
# Handle the case where graph is of dtype int32.
src, dst = get_long_edges(graph)
# Compute y_v = y_u for all edges.
edge_indicator = (y[src] == y[dst]).float()
return edge_indicator.mean(dim=0).item()


def linkx_homophily(graph, y):
r"""Homophily measure from `Large Scale Learning on Non-Homophilous Graphs:
New Benchmarks and Strong Simple Methods
<https://arxiv.org/abs/2110.14446>`__
Mathematically it is defined as follows:
.. math::
\frac{1}{C-1} \sum_{k=1}^{C} \max \left(0, \frac{\sum_{v\in C_k}|\{u\in
\mathcal{N}(v): y_v = y_u \}|}{\sum_{v\in C_k}|\mathcal{N}(v)|} -
\frac{|\mathcal{C}_k|}{|\mathcal{V}|} \right),
where :math:`C` is the number of node classes, :math:`C_k` is the set of
nodes that belong to class k, :math:`\mathcal{N}(v)` are the predecessors
of node :math:`v`, :math:`y_v` is the class of node :math:`v`, and
:math:`\mathcal{V}` is the set of nodes.
Parameters
----------
graph : DGLGraph
The graph.
y : torch.Tensor
The node labels, which is a tensor of shape (|V|).
Returns
-------
float
The homophily value.
Examples
--------
>>> import dgl
>>> import torch
>>> graph = dgl.graph(([0, 1, 2, 3], [1, 2, 0, 4]))
>>> y = torch.tensor([0, 0, 0, 0, 1])
>>> dgl.linkx_homophily(graph, y)
0.19999998807907104
"""
check_pytorch()
with graph.local_scope():
src, dst = graph.edges()
# Compute |{u\in N(v): y_v = y_u}| for each node v.
# Handle the case where graph is of dtype int32.
src = F.astype(src, F.int64)
dst = F.astype(dst, F.int64)
src, dst = get_long_edges(graph)
# Compute y_v = y_u for all edges.
graph.edata["same_class"] = F.astype(y[src] == y[dst], F.float32)
graph.edata["same_class"] = (y[src] == y[dst]).float()
graph.update_all(
fn.copy_e("same_class", "m"), fn.mean("m", "node_value")
fn.copy_e("same_class", "m"), fn.sum("m", "same_class_deg")
)
return graph.ndata["node_value"].mean().item()

deg = graph.in_degrees().float()
num_nodes = graph.num_nodes()
num_classes = y.max(dim=0).values.item() + 1

value = 0
for k in range(num_classes):
# Get the nodes that belong to class k.
class_mask = y == k
same_class_deg_k = graph.ndata["same_class_deg"][class_mask].sum()
deg_k = deg[class_mask].sum()
num_nodes_k = class_mask.sum()
value += max(0, same_class_deg_k / deg_k - num_nodes_k / num_nodes)

return value.item() / (num_classes - 1)
35 changes: 33 additions & 2 deletions tests/python/common/test_homophily.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import unittest

import backend as F
Expand All @@ -6,7 +7,9 @@
from test_utils import parametrize_idtype


@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="Skip TF")
@unittest.skipIf(
dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@parametrize_idtype
def test_node_homophily(idtype):
# IfChangeThenChange: python/dgl/homophily.py
Expand All @@ -16,4 +19,32 @@ def test_node_homophily(idtype):
([1, 2, 0, 4], [0, 1, 2, 3]), idtype=idtype, device=device
)
y = F.tensor([0, 0, 0, 0, 1])
assert dgl.node_homophily(graph, y) == 0.6000000238418579
assert math.isclose(dgl.node_homophily(graph, y), 0.6000000238418579)


@unittest.skipIf(
dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@parametrize_idtype
def test_edge_homophily(idtype):
# IfChangeThenChange: python/dgl/homophily.py
# Update the docstring example.
device = F.ctx()
graph = dgl.graph(
([1, 2, 0, 4], [0, 1, 2, 3]), idtype=idtype, device=device
)
y = F.tensor([0, 0, 0, 0, 1])
assert math.isclose(dgl.edge_homophily(graph, y), 0.75)


@unittest.skipIf(
dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@parametrize_idtype
def test_linkx_homophily(idtype):
# IfChangeThenChange: python/dgl/homophily.py
# Update the docstring example.
device = F.ctx()
graph = dgl.graph(([0, 1, 2, 3], [1, 2, 0, 4]), device=device)
y = F.tensor([0, 0, 0, 0, 1])
assert math.isclose(dgl.linkx_homophily(graph, y), 0.19999998807907104)

0 comments on commit 6cdbd37

Please sign in to comment.