[Utils] Edge and LINKX homophily measure #5382

Merged
merged 29 commits into from
Mar 3, 2023
Changes from 19 commits
2 changes: 2 additions & 0 deletions docs/source/api/python/dgl.rst
@@ -212,7 +212,9 @@ Utilities for measuring homophily of a graph
.. autosummary::
    :toctree: ../../generated/

    edge_homophily
    node_homophily
    linkx_homophily

Utilities
-----------------------------------------------
160 changes: 150 additions & 10 deletions python/dgl/homophily.py
@@ -1,12 +1,25 @@
"""Utils for tackling graph homophily and heterophily"""
from . import backend as F, function as fn
from .convert import graph as create_graph

-__all__ = ["node_homophily"]
try:
    import torch
except ImportError:
    pass

+__all__ = ["node_homophily", "edge_homophily", "linkx_homophily"]


def get_long_edges(graph):
Collaborator

Should internal helper functions start with _?

Member Author

I'm fine either way. For this file, it should be clear that only the functions included in __all__ are external.

Collaborator

Sounds good.

nit: Maybe rename to get_edges_long, which reads more naturally.

    src, dst = graph.edges()
    src = F.astype(src, F.int64)
    dst = F.astype(dst, F.int64)
    return src, dst


def node_homophily(graph, y):
-    r"""Homophily measure from `Geom-GCN: Geometric Graph Convolutional Networks
-    <https://arxiv.org/abs/2002.05287>`__
+    r"""Homophily measure from `Geom-GCN: Geometric Graph Convolutional
+    Networks <https://arxiv.org/abs/2002.05287>`__

    We follow the practice of a later paper `Large Scale Learning on
    Non-Homophilous Graphs: New Benchmarks and Strong Simple Methods
@@ -15,8 +28,8 @@ def node_homophily(graph, y):
    Mathematically it is defined as follows:

    .. math::
-      \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \frac{ | \{ (u,v) : u
-      \in \mathcal{N}(v) \wedge y_v = y_u \} | } { |\mathcal{N}(v)| }
+      \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \frac{ | \{u
+      \in \mathcal{N}(v): y_v = y_u \} | } { |\mathcal{N}(v)| }

    where :math:`\mathcal{V}` is the set of nodes, :math:`\mathcal{N}(v)` is
    the predecessors of node :math:`v`, and :math:`y_v` is the class of node
@@ -45,13 +58,140 @@ def node_homophily(graph, y):
    0.6000000238418579
    """
    with graph.local_scope():
-        src, dst = graph.edges()
        # Handle the case where graph is of dtype int32.
-        src = F.astype(src, F.int64)
-        dst = F.astype(dst, F.int64)
+        src, dst = get_long_edges(graph)
        # Compute y_v = y_u for all edges.
        graph.edata["same_class"] = F.astype(y[src] == y[dst], F.float32)
        graph.update_all(
-            fn.copy_e("same_class", "m"), fn.mean("m", "node_value")
+            fn.copy_e("same_class", "m"), fn.mean("m", "same_class_deg")
        )
-        return graph.ndata["node_value"].mean().item()
+        return F.as_scalar(F.mean(graph.ndata["same_class_deg"], dim=0))
Member

I want to point out that the implementation is awkward due to the constraints of our current APIs: (1) we need to use the framework-agnostic backend, (2) integer-type aggregation is not supported, etc.

Ideally, it should be as simple as:

u, v = graph.edges()
graph.edata['same_class'] = (y[u.long()] == y[v.long()]).float()
graph.update_all(...)
return graph.ndata["same_class_deg"].mean()
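
For readers without DGL at hand, the node homophily definition can also be checked with a few lines of plain Python. This is only a sketch under stated assumptions: `node_homophily_sketch` is a hypothetical helper (not part of DGL), edges are given as parallel `src`/`dst` lists, and nodes without incoming edges are assumed to contribute 0 to the average. The graph and labels are borrowed from the `edge_homophily` docstring example in this PR.

```python
def node_homophily_sketch(src, dst, y):
    """Average over all nodes of the fraction of in-neighbors with the same label.

    Assumption of this sketch: a node with no incoming edges contributes 0.
    """
    num_nodes = len(y)
    same = [0] * num_nodes    # number of same-label in-neighbors per node
    in_deg = [0] * num_nodes  # in-degree per node
    for u, v in zip(src, dst):
        in_deg[v] += 1
        same[v] += int(y[u] == y[v])
    ratios = [s / d if d > 0 else 0.0 for s, d in zip(same, in_deg)]
    return sum(ratios) / num_nodes


# Edges 1->0, 2->1, 0->2, 4->3 with labels [0, 0, 0, 0, 1].
value = node_homophily_sketch([1, 2, 0, 4], [0, 1, 2, 3], [0, 0, 0, 0, 1])
print(value)
```

Three of the five nodes have an in-neighbor of the same class, so the result should be close to 0.6, up to floating-point rounding.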



def edge_homophily(graph, y):
    r"""Homophily measure from `Beyond Homophily in Graph Neural Networks:
    Current Limitations and Effective Designs
    <https://arxiv.org/abs/2006.11468>`__

    Mathematically it is defined as follows:

    .. math::
      \frac{| \{ (u,v) : (u,v) \in \mathcal{E} \wedge y_u = y_v \} | }
      {|\mathcal{E}|}

    where :math:`\mathcal{E}` is the set of edges, and :math:`y_u` is the class
    of node :math:`u`.

    Parameters
    ----------
    graph : DGLGraph
        The graph
    y : Tensor
        The node labels, which is a tensor of shape (|V|)

    Returns
    -------
    float
        The edge homophily ratio value

    Examples
    --------
    >>> import dgl
    >>> import torch

    >>> graph = dgl.graph(([1, 2, 0, 4], [0, 1, 2, 3]))
    >>> y = torch.tensor([0, 0, 0, 0, 1])
    >>> dgl.edge_homophily(graph, y)
    0.75
    """
    with graph.local_scope():
        # Handle the case where graph is of dtype int32.
        src, dst = get_long_edges(graph)
        # Compute y_v = y_u for all edges.
        edge_indicator = F.astype(y[src] == y[dst], F.float32)
        return F.as_scalar(F.mean(edge_indicator, dim=0))
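
The edge homophily ratio is simple enough to verify by hand. The following is a minimal plain-Python sketch, not the DGL implementation: `edge_homophily_sketch` is a hypothetical name, and raising on an empty edge list is a choice made for this sketch only.

```python
def edge_homophily_sketch(src, dst, y):
    """Fraction of edges whose two endpoints carry the same label."""
    if len(src) == 0:
        # Assumption of this sketch: the ratio is undefined without edges.
        raise ValueError("edge homophily is undefined for a graph without edges")
    matches = sum(1 for u, v in zip(src, dst) if y[u] == y[v])
    return matches / len(src)


# Docstring example: edges 1->0, 2->1, 0->2 connect nodes of class 0,
# while 4->3 connects class 1 to class 0, so 3 of 4 edges match.
print(edge_homophily_sketch([1, 2, 0, 4], [0, 1, 2, 3], [0, 0, 0, 0, 1]))  # 0.75
```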


def linkx_homophily(graph, y):
    r"""Homophily measure from `Large Scale Learning on Non-Homophilous Graphs:
    New Benchmarks and Strong Simple Methods
    <https://arxiv.org/abs/2110.14446>`__

    Mathematically it is defined as follows:

    .. math::
      \frac{1}{C-1} \sum_{k=1}^{C} \max \left(0, \frac{\sum_{v\in C_k}|\{u\in
      \mathcal{N}(v): y_v = y_u \}|}{\sum_{v\in C_k}|\mathcal{N}(v)|} -
      \frac{|C_k|}{|\mathcal{V}|} \right)

    where :math:`C` is the number of node classes, :math:`C_k` is the set of
    nodes that belong to class k, :math:`\mathcal{N}(v)` are the predecessors
    of node :math:`v`, :math:`y_v` is the class of node :math:`v`, and
    :math:`\mathcal{V}` is the set of nodes.

    Parameters
    ----------
    graph : DGLGraph
        The graph
    y : Tensor
Collaborator

torch.Tensor

and others.

Member Author

done

        The node labels, which is a tensor of shape (|V|)

    Returns
    -------
    float
        The homophily value

    Examples
    --------
    >>> import dgl
    >>> import torch

    >>> graph = dgl.graph(([0, 1, 2, 3], [1, 2, 0, 4]))
    >>> y = torch.tensor([0, 0, 0, 0, 1])
    >>> dgl.linkx_homophily(graph, y)
    0.19999998807907104
    """
    with graph.local_scope():
        # Compute |{u\in N(v): y_v = y_u}| for each node v.
        # Handle the case where graph is of dtype int32.
        src, dst = get_long_edges(graph)
        # Compute y_v = y_u for all edges.
        graph.edata["same_class"] = (y[src] == y[dst]).float()
        graph.update_all(
            fn.copy_e("same_class", "m"), fn.sum("m", "same_class_deg")
        )
Member

OK, now I'm pushing this further. Would using the sparse API make the code more readable?

Member Author

How so? You convert the graph to a sparse matrix and call AX. I don't think there are significant differences.

Member

with graph.local_scope():
    # Handle the case where graph is of dtype int32.
    src, dst = get_long_edges(graph)
    # Compute y_v = y_u for all edges.
    graph.edata["same_class"] = (y[src] == y[dst]).float()
    graph.update_all(
        fn.copy_e("same_class", "m"), fn.mean("m", "same_class_deg")
    )
    return graph.ndata["same_class_deg"].mean(dim=0).item()

vs.

A = graph.adj
same_class = (y[A.row] == y[A.col]).float()
same_class_avg = dglsp.val_like(A, same_class).smean(dim=1)
return same_class_avg.mean(dim=0).item()

Member

vs. the new message passing API style:

src, dst = get_long_edges(graph)
same_class = (y[src] == y[dst]).float()
same_class_avg = dgl.mpops.copy_e_mean(graph, same_class)
return same_class_avg.mean(dim=0).item()

Member Author

@mufeili mufeili Mar 3, 2023

Still, it's quite subtle. I'm fine either way. The question is more about when we should encourage the use of message passing APIs versus sparse APIs.

Member

My opinion is to go with the math formulation: if the model is described as node-wise/edge-wise computation, then message passing is the way to go; otherwise, use sparse. In this case, the definition is in terms of nodes and edges, so message passing is more suitable. You can see that although the sparse version is shorter, it doesn't align well with the definition, e.g., the use of val_like and smean is not straightforward.


        # Compute |N(v)| for each node v.
        deg = graph.in_degrees().float()

        # To compute \sum_{v\in C_k}|{u\in N(v): y_v = y_u}| for all k
        # efficiently, construct a directed graph from nodes to their class.
        num_classes = F.max(y, dim=0).item() + 1
        src = graph.nodes().to(dtype=y.dtype)
        dst = y + graph.num_nodes()
        class_graph = create_graph((src, dst))
        # Add placeholder values for the class nodes.
        class_placeholder = torch.zeros(
            (num_classes), dtype=deg.dtype, device=class_graph.device
        )
        class_graph.ndata["same_class_deg"] = torch.cat(
            [graph.ndata["same_class_deg"], class_placeholder], dim=0
        )
        class_graph.update_all(
            fn.copy_u("same_class_deg", "m"), fn.sum("m", "class_deg_aggr")
        )

        # Similarly, compute \sum_{v\in C_k}|N(v)| for all k in parallel.
        class_graph.ndata["deg"] = torch.cat([deg, class_placeholder], dim=0)
        class_graph.update_all(fn.copy_u("deg", "m"), fn.sum("m", "deg_aggr"))

        # Compute class_deg_aggr / deg_aggr for all classes.
        num_nodes = graph.num_nodes()
        class_deg_aggr = class_graph.ndata["class_deg_aggr"][num_nodes:]
        deg_aggr = torch.clamp(class_graph.ndata["deg_aggr"][num_nodes:], min=1)
        fraction = (
            class_deg_aggr / deg_aggr - torch.bincount(y).float() / num_nodes
        )
        fraction = torch.clamp(fraction, min=0)

        return fraction.sum().item() / (num_classes - 1)
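
The class-graph construction above amounts to a per-class scatter-add of two node-level quantities. As a cross-check of the docstring formula, here is a plain-Python sketch that accumulates the same sums directly into per-class lists. It is an illustration only: `linkx_homophily_sketch` is a hypothetical helper, labels are assumed to be integers in `0..C-1` with at least two classes, and the denominator is clamped to at least 1 to mirror the `torch.clamp(..., min=1)` call.

```python
def linkx_homophily_sketch(src, dst, y):
    """LINKX homophily following the docstring formula, edge by edge."""
    num_nodes = len(y)
    num_classes = max(y) + 1      # assumes labels are 0..C-1 and C >= 2
    same = [0] * num_classes      # sum over v in C_k of same-label in-neighbors
    deg = [0] * num_classes       # sum over v in C_k of |N(v)|
    size = [0] * num_classes      # |C_k|
    for label in y:
        size[label] += 1
    for u, v in zip(src, dst):
        deg[y[v]] += 1
        same[y[v]] += int(y[u] == y[v])
    total = 0.0
    for k in range(num_classes):
        ratio = same[k] / max(deg[k], 1)  # avoid division by zero
        total += max(0.0, ratio - size[k] / num_nodes)
    return total / (num_classes - 1)


# Docstring example: class 0 contributes 3/3 - 4/5, class 1 is clamped to 0.
print(linkx_homophily_sketch([0, 1, 2, 3], [1, 2, 0, 4], [0, 0, 0, 0, 1]))
```

The printed value should agree with the docstring result of roughly 0.2, up to floating-point rounding (the DGL version computes in float32).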
26 changes: 26 additions & 0 deletions tests/python/common/test_homophily.py
@@ -17,3 +17,29 @@ def test_node_homophily(idtype):
    )
    y = F.tensor([0, 0, 0, 0, 1])
    assert dgl.node_homophily(graph, y) == 0.6000000238418579


@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="Skip TF")
@parametrize_idtype
def test_edge_homophily(idtype):
    # IfChangeThenChange: python/dgl/homophily.py
    # Update the docstring example.
    device = F.ctx()
    graph = dgl.graph(
        ([1, 2, 0, 4], [0, 1, 2, 3]), idtype=idtype, device=device
    )
    y = F.tensor([0, 0, 0, 0, 1])
    assert dgl.edge_homophily(graph, y) == 0.75


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@parametrize_idtype
def test_linkx_homophily(idtype):
Collaborator

Is there any corner case you need to handle?
E.g., there is a max(0, xxxx).
Should we check the 0 cases?

Member

+1

Member Author

I think the current cases are sufficient.

    # IfChangeThenChange: python/dgl/homophily.py
    # Update the docstring example.
    device = F.ctx()
    graph = dgl.graph(([0, 1, 2, 3], [1, 2, 0, 4]), device=device)
    y = F.tensor([0, 0, 0, 0, 1])
    assert dgl.linkx_homophily(graph, y) == 0.19999998807907104