[NN] Support scalar edge weight for GraphConv, SAGEConv and GINConv #2557
Conversation
If a weight tensor on each edge is provided, the weighted graph convolution is defined as:

.. math::
   h_i^{(l+1)} = \sigma(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{e_{ji}}{c_{ij}}h_j^{(l)}W^{(l)})
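As a sanity check of this formula, here is a minimal NumPy sketch. The toy graph, the choice of ReLU for `\sigma`, and using the weighted in-degree as `c_{ij}` are illustrative assumptions, not DGL's actual implementation:

```python
import numpy as np

# Hypothetical toy graph: edges j -> i with a scalar weight e_ji per edge.
src = np.array([0, 1, 1])          # source nodes j
dst = np.array([1, 1, 2])          # destination nodes i
e = np.array([0.5, 2.0, 1.0])      # scalar edge weights e_ji

h = np.ones((3, 2))                # node features h_j^{(l)}, 3 nodes, dim 2
W = np.eye(2)                      # layer weight W^{(l)}
b = np.zeros(2)                    # bias b^{(l)}

# One common choice for c_ij: the weighted in-degree of node i.
c = np.zeros(3)
np.add.at(c, dst, e)
c[c == 0] = 1.0                    # avoid division by zero for isolated nodes

# h_i^{(l+1)} = sigma(b + sum_{j in N(i)} (e_ji / c_ij) h_j W)
out = np.zeros_like(h)
np.add.at(out, dst, (e / c[dst])[:, None] * h[src])
out = np.maximum(out @ W + b, 0)   # sigma = ReLU, purely for illustration
```

With all-ones features, each node's output is the normalized sum of its incoming edge weights, so node 1 receives (0.5 + 2.0) / 2.5 = 1.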
Is this formulation correct? The definition of graph convolution for weighted graphs is D^{-1/2} A D^{-1/2}, where D[i, i] = sum(A[i, :]). I will extend the formulation above to bipartite graphs (blocks) with D_{out}^{-1/2} A D_{in}^{-1/2}, where D_out[i, i] = sum(A[i, :]) and D_in[j, j] = sum(A[:, j]).
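The symmetric normalization described here can be sketched in NumPy (the weighted adjacency matrix below is an illustrative example):

```python
import numpy as np

# Hypothetical weighted adjacency matrix A of a 3-node undirected graph.
A = np.array([[0.0, 2.0, 0.0],
              [2.0, 0.0, 1.0],
              [0.0, 1.0, 0.0]])

# D[i, i] = sum(A[i, :]) for weighted graphs, as stated above.
d = A.sum(axis=1)
D_inv_sqrt = np.diag(1.0 / np.sqrt(d))

# D^{-1/2} A D^{-1/2}: entry (i, j) is A[i, j] / sqrt(d_i * d_j).
A_norm = D_inv_sqrt @ A @ D_inv_sqrt
```

Since A is symmetric and D is diagonal, the normalized matrix stays symmetric; for example entry (0, 1) is 2 / sqrt(2 * 3).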
The reasons I implemented it this way are:
- For scalar edge weights, the correct normalization requires a reversed block to compute the g.out_degrees counterpart.
- For multi-dimensional edge weights, there is no broadly accepted definition yet.

Therefore, with the current form, users can insert an edge weight for simple edge-weight computation (like masking edges), and pre-normalize the edge weight to customize the computation if necessary.
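The edge-masking use-case mentioned here can be sketched in NumPy: a scalar weight of 0 on an edge removes its contribution from the aggregation (toy graph and values are illustrative):

```python
import numpy as np

# Toy graph: three directed edges.
src = np.array([0, 1, 2])
dst = np.array([1, 2, 0])
h = np.array([[1.0], [2.0], [3.0]])   # node features

# Scalar edge weights used as a mask: weight 0 drops the middle edge.
mask = np.array([1.0, 0.0, 1.0])

# Weighted sum aggregation: masked edges contribute nothing.
agg = np.zeros_like(h)
np.add.at(agg, dst, mask[:, None] * h[src])
```

Node 2 receives nothing because its only incoming edge is masked out.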
The unit tests have passed. Could you please take a second look? @BarclayII
reversed_g = reverse(graph)
reversed_g.edata['_edge_w'] = edge_weight
reversed_g.update_all(fn.copy_edge('_edge_w', 'm'), fn.sum('m', 'out_weight'))
degs = reversed_g.dstdata['out_weight'] + 1
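A NumPy sketch of what this reversed-graph computation produces (toy graph is illustrative): summing `'_edge_w'` into the destination nodes of the reversed graph is the same as summing each node's outgoing edge weights on the original graph.

```python
import numpy as np

# Toy original graph: edges src -> dst with scalar weights.
src = np.array([0, 0, 1])
dst = np.array([1, 2, 2])
edge_weight = np.array([0.5, 1.5, 2.0])

# On the reversed graph, each original source node collects its own
# outgoing edge weights, giving the weighted out-degree.
out_weight = np.zeros(3)
np.add.at(out_weight, src, edge_weight)

degs = out_weight + 1.0   # the "+ 1" discussed in the comments below
```

Node 0 has two outgoing edges with weights 0.5 and 1.5, so its weighted out-degree is 2.0 and `degs[0]` is 3.0.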
I'm a bit unsure whether "1" will be small enough, since we have no restriction on the range of edge weights. Perhaps we can make this an argument in __init__ and set 1 as the default value.
This works as the I in the form \tilde{A} = A + I.
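The equivalence being claimed here can be checked numerically: adding 1 to each weighted degree is exactly what a weight-1 self-loop (the I in \tilde{A} = A + I) contributes. A minimal NumPy check with an illustrative adjacency matrix:

```python
import numpy as np

# Toy weighted adjacency matrix.
A = np.array([[0.0, 2.0],
              [2.0, 0.0]])

# \tilde{A} = A + I adds a weight-1 self-loop to every node.
A_tilde = A + np.eye(2)

# The row sums (weighted degrees) of A + I equal deg(A) + 1,
# matching the "degs = ... + 1" in the snippet above.
deg_tilde = A_tilde.sum(axis=1)
```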
- The current implementation of GraphConv does not add "1" after degree computation.
- This module can in principle be quite flexible, and is not necessarily restricted to the case of GCN.
@@ -205,6 +349,9 @@ def forward(self, graph, feat, weight=None):
        :math:`(N_{out}, D_{in_{dst}})`.
    weight : torch.Tensor, optional
        Optional external weight tensor.
    edge_weight : torch.Tensor, optional
        Optional tensor on the edge. If given, the convolution will weight
        with regard to the edge feature.
If you only want to support scalar edge weights, mention that.
We support all broadcastable dimensions.
In principle this can support broadcastable multidimensional edge weights, like below.
import dgl
import dgl.function as fn
import torch
g = dgl.graph(([0, 1], [1, 2]))
g.ndata['h'] = torch.randn(3, 2)
g.edata['w'] = torch.randn(2, 3, 1)
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h2'))
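The shapes in the snippet above follow standard NumPy/PyTorch broadcasting rules. A NumPy sketch of the per-edge message shape that `u_mul_e` produces there (shapes copied from the example; the data is random and illustrative):

```python
import numpy as np

# Per edge, u_mul_e('h', 'w', 'm') multiplies the source feature
# (shape (2,)) by the edge weight (shape (3, 1)); the feature shapes
# broadcast to (3, 2) on each of the 2 edges.
h_src = np.random.randn(2, 2)   # the 2 edges' source features
w = np.random.randn(2, 3, 1)    # per-edge weights, feature shape (3, 1)

# Align the feature dims and broadcast: (2, 1, 2) * (2, 3, 1) -> (2, 3, 2).
m = h_src[:, None, :] * w
```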
I think this is exactly what I have in the code.
@@ -20,6 +20,16 @@ class GINConv(nn.Module):
        \mathrm{aggregate}\left(\left\{h_j^{l}, j\in\mathcal{N}(i)
        \right\}\right)\right)

    If a weight tensor on each edge is provided, the weighted graph convolution is defined as:
"scalar weight"?
@@ -98,6 +108,9 @@ def forward(self, graph, feat):
        :math:`(N_{in}, D_{in})` and :math:`(N_{out}, D_{in})`.
        If ``apply_func`` is not None, :math:`D_{in}` should
        fit the input dimensionality requirement of ``apply_func``.
    edge_weight : torch.Tensor, optional
        Optional tensor on the edge. If given, the convolution will weight
"scalar weight"?
@@ -98,6 +108,9 @@ def forward(self, graph, feat):
        :math:`(N_{in}, D_{in})` and :math:`(N_{out}, D_{in})`.
        If ``apply_func`` is not None, :math:`D_{in}` should
        fit the input dimensionality requirement of ``apply_func``.
    edge_weight : torch.Tensor, optional
        Optional tensor on the edge. If given, the convolution will weight
        with regard to the edge feature.
"edge feature" -> "message"?
@@ -164,6 +173,9 @@ def forward(self, graph, feat):
        where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
        If a pair of torch.Tensor is given, the pair must contain two tensors of shape
        :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
    edge_weight : torch.Tensor, optional
        Optional tensor on the edge. If given, the convolution will weight
scalar weight?
@@ -164,6 +173,9 @@ def forward(self, graph, feat):
        where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
        If a pair of torch.Tensor is given, the pair must contain two tensors of shape
        :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
    edge_weight : torch.Tensor, optional
        Optional tensor on the edge. If given, the convolution will weight
        with regard to the edge feature.
"edge feature" -> "message"
@yzh119 Do we now recommend using
out : str
    The output message field.

Examples
--------
-    >>> message_func = dgl.function.src_mul_edge('h', 'h', 'm')
+    >>> message_func = dgl.function.src_mul_edge('h', 'e', 'm')
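For reference, both `src_mul_edge` (the deprecated name) and `u_mul_e` compute the same per-edge message: the source node feature times the edge feature. A NumPy sketch with illustrative toy data:

```python
import numpy as np

# Sketch of message_func('h', 'e', 'm'): per edge, m = h[src] * e.
src = np.array([0, 1])                               # source node of each edge
h = np.array([[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]])   # node feature 'h'
e = np.array([[2.0], [0.5]])                         # edge feature 'e'

# Broadcasting (2, 2) * (2, 1) gives one message of shape (2,) per edge.
m = h[src] * e
```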
"src_mul_edge" -> "u_mul_e"
Wait... shouldn't we at least keep the docstring consistent with the function, even if it is deprecated?
I didn't realize that :)
LGTM
Description
This PR is for #1281
I might need help on adding proper unit tests.