[Utils] Edge and LINKX homophily measure #5382
python/dgl/homophily.py
Outdated
)
return graph.ndata["node_value"].mean().item()
return F.as_scalar(F.mean(graph.ndata["same_class_deg"], dim=0))
I want to point out that the implementation is awkward due to the constraints of our current APIs: (1) we need to use the framework-agnostic backend, (2) integer-type aggregation is not supported, etc.
Ideally, it should be as simple as:
u, v = graph.edges()
graph.edata['same_class'] = (y[u.long()] == y[v.long()]).float()
graph.update_all(...)
return graph.ndata["same_class_deg"].mean()
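For readers following the thread, the quantity being computed here can be sketched without DGL at all. The snippet below is only an illustration in plain Python, not the code in this PR: the function name, the edge-list input, and the choice to count nodes without in-edges as 0 are assumptions made for the sketch.

```python
def node_homophily(num_nodes, edges, labels):
    """Mean over nodes of the fraction of in-neighbors sharing the node's label."""
    same = [0] * num_nodes  # same-label in-edges per destination node
    deg = [0] * num_nodes   # in-degree per destination node
    for u, v in edges:
        deg[v] += 1
        if labels[u] == labels[v]:
            same[v] += 1
    # Assumption for this sketch: a node with no in-edges contributes 0.
    fractions = [s / d if d > 0 else 0.0 for s, d in zip(same, deg)]
    return sum(fractions) / num_nodes


# Undirected path 0 - 1 - 2, stored as directed edges in both directions.
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
labels = [0, 0, 1]
score = node_homophily(3, edges, labels)
```

The per-edge comparison corresponds to the `same_class` edge feature above, and the per-node average corresponds to the aggregation step.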
__all__ = ["node_homophily", "edge_homophily", "linkx_homophily"]


def get_long_edges(graph):
Sounds good.
nit: Maybe rename it to get_edges_long, which reads more naturally.
python/dgl/homophily.py
Outdated
----------
graph : DGLGraph
    The graph.
y : Tensor
torch.Tensor, and likewise for the others.
done
for k in range(num_classes):
    # Get the nodes that belong to class k.
    class_mask = y == k
nit: class_mask = (y == k)
I initially did what you suggested, and then the lint check failed.
dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@parametrize_idtype
def test_linkx_homophily(idtype):
Is there any corner case you need to handle?
E.g., there was a max(0, xxxx).
Should we check the 0 cases?
+1
I think the current cases are sufficient.
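For context on the zero case raised in this thread, here is a rough plain-Python sketch of a class-wise measure with a max(0, ...) clamp, following the per-class loop shown in the diff. It is not the DGL implementation: the function name, the edge-list input, inferring the number of classes as max(labels) + 1, and skipping classes whose nodes have no in-edges are all assumptions of the sketch.

```python
def linkx_homophily(num_nodes, edges, labels):
    """Class-wise homophily with a max(0, .) clamp, averaged over C - 1."""
    num_classes = max(labels) + 1  # assumption: labels are 0..C-1
    same = [0] * num_nodes  # same-label in-edges per destination node
    deg = [0] * num_nodes   # in-degree per destination node
    for u, v in edges:
        deg[v] += 1
        if labels[u] == labels[v]:
            same[v] += 1

    total = 0.0
    for k in range(num_classes):
        # Get the nodes that belong to class k.
        members = [v for v in range(num_nodes) if labels[v] == k]
        deg_k = sum(deg[v] for v in members)
        if deg_k == 0:
            continue  # assumption: a class without in-edges contributes 0
        same_k = sum(same[v] for v in members)
        # Clamp at 0 so a class below its size baseline cannot go negative.
        total += max(0.0, same_k / deg_k - len(members) / num_nodes)
    return total / (num_classes - 1)


labels = [0, 0, 1, 1]
# Every edge crosses classes, so each class term is clamped to 0.
cross_edges = [(0, 2), (2, 0), (1, 3), (3, 1)]
clamped = linkx_homophily(4, cross_edges, labels)
# Every edge stays within a class, so each class term is 1 - 0.5.
within_edges = [(0, 1), (1, 0), (2, 3), (3, 2)]
unclamped = linkx_homophily(4, within_edges, labels)
```

In the first graph each class term would be -0.5 without the clamp, which is the kind of input a dedicated zero-case test would exercise.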
def get_long_edges(graph):
    """Internal function for getting the edges of a graph as long tensors."""
    src, dst = graph.edges()
    return src.long(), dst.long()
If there are only two lines, consider just inlining them.
I'm fine either way. Maybe you two can start a fight. :) @frozenbugs
graph.edata["same_class"] = (y[src] == y[dst]).float()
graph.update_all(
    fn.copy_e("same_class", "m"), fn.sum("m", "same_class_deg")
)
OK, now I'm pushing this further. Would using the sparse API make the code more readable?
How so? You convert the graph to a sparse matrix and call AX. I don't think there are significant differences.
with graph.local_scope():
    # Handle the case where graph is of dtype int32.
    src, dst = get_long_edges(graph)
    # Compute y_v = y_u for all edges.
    graph.edata["same_class"] = (y[src] == y[dst]).float()
    graph.update_all(
        fn.copy_e("same_class", "m"), fn.mean("m", "same_class_deg")
    )
    return graph.ndata["same_class_deg"].mean(dim=0).item()
vs.
A = graph.adj
same_class = (y[A.row] == y[A.col]).float()
same_class_avg = dglsp.val_like(A, same_class).smean(dim=1)
return same_class_avg.mean(dim=0).item()
vs. the new message passing API style:
src, dst = get_long_edges(graph)
same_class = (y[src] == y[dst]).float()
same_class_avg = dgl.mpops.copy_e_mean(graph, same_class)
return same_class_avg.mean(dim=0).item()
Still, it's quite subtle. I'm fine either way. The question is more about when we should encourage the use of message passing APIs versus sparse APIs.
My opinion is to go with the math formulation: if the model is described as node-wise/edge-wise computation, then message passing is the way to go; otherwise, use sparse. In this case, the definition is in terms of nodes and edges, so message passing is more suitable. You can see that although the sparse version is shorter, it doesn't align well with the definition; e.g., the use of val_like and smean is not straightforward.
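To make the comparison concrete without depending on either DGL API, the two styles can be mimicked in plain Python. This is only a sketch: both function names are hypothetical, the matrix orientation (row = destination, column = source) is chosen for the sketch rather than taken from DGL, and the matrix version assumes there are no duplicate edges.

```python
def mean_by_message_passing(num_nodes, edges, labels):
    """Edge-wise view: each edge sends a 0/1 message, destinations average them."""
    inbox = [[] for _ in range(num_nodes)]
    for u, v in edges:
        inbox[v].append(1.0 if labels[u] == labels[v] else 0.0)
    per_node = [sum(m) / len(m) if m else 0.0 for m in inbox]
    return sum(per_node) / num_nodes


def mean_by_matrix(num_nodes, edges, labels):
    """Matrix view: row-wise mean of 0/1 values stored on the adjacency nonzeros."""
    adj = [[0.0] * num_nodes for _ in range(num_nodes)]
    val = [[0.0] * num_nodes for _ in range(num_nodes)]
    for u, v in edges:
        adj[v][u] = 1.0  # row = destination, column = source (sketch convention)
        if labels[u] == labels[v]:
            val[v][u] = 1.0
    per_node = [
        sum(val[v]) / sum(adj[v]) if sum(adj[v]) > 0 else 0.0
        for v in range(num_nodes)
    ]
    return sum(per_node) / num_nodes


# Undirected path 0 - 1 - 2 - 3, stored as directed edges in both directions.
edges = [(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2)]
labels = [0, 0, 1, 1]
mp_score = mean_by_message_passing(4, edges, labels)
mat_score = mean_by_matrix(4, edges, labels)
```

Both functions return the same value on this graph. The first mirrors the node/edge definition directly, while the second has to attach values to matrix nonzeros first, which is the alignment point made in the comment above.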
dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@parametrize_idtype
def test_linkx_homophily(idtype):
+1
LGTM. Do you want to remove the [DoNotMerge] tag?
* Update * lint * lint * r prefix * CI * lint * skip TF * Update * edge homophily * linkx homophily * format * skip TF * fix test * update * lint * lint * review * lint * update * lint * update * CI --------- Co-authored-by: Ubuntu <ubuntu@ip-172-31-36-188.ap-northeast-1.compute.internal>