Fix comment formatting
schmidt-ju committed Mar 3, 2023
1 parent c079d1b commit 35d9ca1
Showing 2 changed files with 39 additions and 31 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -295,7 +295,7 @@ Take the survey [here](https://forms.gle/Ej3jHCocACmb49Gp8) and leave any feedba

1. [**How Attentive are Graph Attention Networks?**](https://arxiv.org/pdf/2105.14491.pdf), *Shaked Brody, Uri Alon, Eran Yahav*, [code](https://github.com/tech-srl/how_attentive_are_gats)

- 1. [**SCENE: Reasoning about Traffic Scenes using Heterogeneous Graph Neural Networks**](https://arxiv.org/pdf/2301.03512.pdf), *Thomas Monninger\*, Julian Schmidt\*, Jan Rupprecht, David Raba, Julian Jordan, Daniel Frank, Steffen Staab, Klaus Dietmayer*, [code](https://github.com/schmidt-ju/scene)
+ 1. [**SCENE: Reasoning about Traffic Scenes using Heterogeneous Graph Neural Networks**](https://arxiv.org/pdf/2301.03512.pdf), *Thomas Monninger\*, Julian Schmidt\*, Jan Rupprecht, David Raba, Julian Jordan, Daniel Frank, Steffen Staab, Klaus Dietmayer*, [code](https://github.com/schmidt-ju/scene), \*co-first authors

</details>

68 changes: 38 additions & 30 deletions python/dgl/nn/pytorch/conv/edgegatconv.py
@@ -14,6 +14,7 @@ class EdgeGATConv(nn.Module):
<https://arxiv.org/pdf/2301.03512.pdf>`__
.. math::
\mathbf{v}_i^\prime = \mathbf{\Theta}_\mathrm{s} \cdot \mathbf{v}_i +
\sum\limits_{j \in \mathcal{N}(v_i)} \alpha_{j, i} \left( \mathbf{\Theta}_\mathrm{n}
\cdot \mathbf{v}_j + \mathbf{\Theta}_\mathrm{e} \cdot \mathbf{e}_{j,i} \right)
@@ -22,10 +23,13 @@ class EdgeGATConv(nn.Module):
for the transformation of features of the node to update (s=self),
neighboring nodes (n=neighbor) and edge features (e=edge).
Attention weights are obtained by
.. math::
\alpha_{j, i} = \mathrm{softmax}_i \Big( \mathrm{LeakyReLU} \big( \mathbf{a}^T
[ \mathbf{\Theta}_\mathrm{n} \cdot \mathbf{v}_i || \mathbf{\Theta}_\mathrm{n}
\cdot \mathbf{v}_j || \mathbf{\Theta}_\mathrm{e} \cdot \mathbf{e}_{j,i} ] \big) \Big)
with :math:`\mathbf{a}` corresponding to a learnable vector.
:math:`\mathrm{softmax_i}` stands for the normalization by all incoming edges of node :math:`i`.
@@ -86,24 +90,26 @@ class EdgeGATConv(nn.Module):
>>> import dgl
>>> import numpy as np
>>> import torch as th
- >>> from from dgl.nn import EdgeGATConv
+ >>> from dgl.nn import EdgeGATConv
- >>> # Case 1: Homogeneous graph
+ >>> # Case 1: Homogeneous graph.
>>> num_nodes, num_edges = 8, 30
- >>> # generate a graph
+ >>> # Generate a graph.
>>> graph = dgl.rand_graph(num_nodes, num_edges)
>>> node_feats = th.rand((num_nodes, 20))
>>> edge_feats = th.rand((num_edges, 12))
- >>> edge_gat = EdgeGATConv(in_feats=20,
- ...                        edge_feats=12,
- ...                        out_feats=15,
- ...                        num_heads=3)
- >>> #forward pass
+ >>> edge_gat = EdgeGATConv(
+ ...     in_feats=20,
+ ...     edge_feats=12,
+ ...     out_feats=15,
+ ...     num_heads=3,
+ ... )
+ >>> # Forward pass.
>>> new_node_feats = edge_gat(graph, node_feats, edge_feats)
>>> new_node_feats.shape
torch.Size([8, 3, 15])
- >>> # Case 2: Unidirectional bipartite graph
+ >>> # Case 2: Unidirectional bipartite graph.
>>> u = [0, 1, 0, 0, 1]
>>> v = [0, 1, 2, 3, 2]
>>> g = dgl.heterograph({('A', 'r', 'B'): (u, v)})
@@ -115,11 +121,13 @@ class EdgeGATConv(nn.Module):
>>> edge_feats = 15
>>> out_feats = 10
>>> num_heads = 3
- >>> egat_model = EdgeGATConv(in_feats,
- ...                          edge_feats,
- ...                          out_feats,
- ...                          num_heads)
- >>> #forward pass
+ >>> egat_model = EdgeGATConv(
+ ...     in_feats,
+ ...     edge_feats,
+ ...     out_feats,
+ ...     num_heads,
+ ... )
+ >>> # Forward pass.
>>> new_node_feats, attention_weights = egat_model(g, nfeats, efeats, get_attention=True)
>>> new_node_feats.shape, attention_weights.shape
(torch.Size([4, 3, 10]), torch.Size([5, 3, 1]))
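Because the weights come from an edge softmax, the attention over the incoming edges of each destination node sums to one. Below is a quick doctest-style sanity check, not part of this commit, that reuses the Case 2 variables from above (fn is dgl.function):

>>> import dgl.function as fn
>>> g.edata["a"] = attention_weights
>>> g.update_all(fn.copy_e("a", "m"), fn.sum("m", "a_sum"))
>>> th.allclose(g.nodes["B"].data["a_sum"], th.ones(4, 3, 1))
True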
@@ -314,64 +322,64 @@ def forward(self, graph, feat, edge_feat, get_attention=False):
graph.number_of_dst_nodes(),
) + dst_prefix_shape[1:]

- # linearly tranform the edge features
+ # Linearly transform the edge features.
n_edges = edge_feat.shape[:-1]
feat_edge = self.fc_edge(edge_feat).view(
*n_edges, self._num_heads, self._out_feats
)

- # add edge features to graph
+ # Add edge features to graph.
graph.edata["ft_edge"] = feat_edge

el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)

- # calculate scalar for each edge
+ # Calculate scalar for each edge.
ee = (feat_edge * self.attn_edge).sum(dim=-1).unsqueeze(-1)
graph.edata["ee"] = ee

graph.srcdata.update({"ft": feat_src, "el": el})
graph.dstdata.update({"er": er})
- # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively
+ # Compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
graph.apply_edges(fn.u_add_v("el", "er", "e_tmp"))

- # e_tmp combines attention weights of source and destination node
- # we also add the attention weight of the edge
+ # e_tmp combines attention weights of source and destination node.
+ # Add the attention weight of the edge.
graph.edata["e"] = graph.edata["e_tmp"] + graph.edata["ee"]

- # create new edges features that combine the
- # features of the source node and the edge features
+ # Create new edge features that combine the
+ # features of the source node and the edge features.
graph.apply_edges(fn.u_add_e("ft", "ft_edge", "ft_combined"))

e = self.leaky_relu(graph.edata.pop("e"))
- # compute softmax
+ # Compute softmax.
graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))

- # for each edge, element-wise multiply the combined features with
- # the attention coefficient
+ # For each edge, element-wise multiply the combined features with
+ # the attention coefficient.
graph.edata["m_combined"] = (
graph.edata["ft_combined"] * graph.edata["a"]
)

- # first copy the edge features and then sum them up
+ # First copy the edge features and then sum them up.
graph.update_all(fn.copy_e("m_combined", "m"), fn.sum("m", "ft"))

rst = graph.dstdata["ft"]
- # residual
+ # Residual.
if self.res_fc is not None:
- # Use -1 rather than self._num_heads to handle broadcasting
+ # Use -1 rather than self._num_heads to handle broadcasting.
resval = self.res_fc(h_dst).view(
*dst_prefix_shape, -1, self._out_feats
)
rst = rst + resval
- # bias
+ # Bias.
if self.bias is not None:
rst = rst + self.bias.view(
*((1,) * len(dst_prefix_shape)),
self._num_heads,
self._out_feats
)
- # activation
+ # Activation.
if self.activation:
rst = self.activation(rst)

