Commit

Fix linting with lintrunner -a
schmidt-ju committed Feb 14, 2023
1 parent 66165d4 commit 6d6a24b
Showing 2 changed files with 85 additions and 63 deletions.
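
The commit title names the command that produced this diff. Purely as a reproduction sketch (assuming `lintrunner` is installed and initialized for this checkout), the same auto-fix can be invoked from the repository root, for example from Python:

import subprocess

# Run the linter in auto-apply mode; "-a" applies the suggested fixes,
# matching the command named in the commit title. Assumes a working
# lintrunner configuration for this repository; raises on failure.
subprocess.run(["lintrunner", "-a"], check=True)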
2 changes: 1 addition & 1 deletion python/dgl/nn/pytorch/conv/__init__.py

@@ -13,8 +13,8 @@
 from .dgnconv import DGNConv
 from .dotgatconv import DotGatConv
 from .edgeconv import EdgeConv
-from .egatconv import EGATConv
 from .edgegatconv import EdgeGATConv
+from .egatconv import EGATConv
 from .egnnconv import EGNNConv
 from .gatconv import GATConv
 from .gatedgraphconv import GatedGraphConv
146 changes: 84 additions & 62 deletions python/dgl/nn/pytorch/conv/edgegatconv.py

@@ -4,9 +4,9 @@
 from torch import nn

 from .... import function as fn
-from ...functional import edge_softmax
 from ....base import DGLError
 from ....utils import expand_as_pair
+from ...functional import edge_softmax

 # pylint: enable=W0235
 class EdgeGATConv(nn.Module):
@@ -125,54 +125,63 @@ class EdgeGATConv(nn.Module):
     (torch.Size([4, 3, 10]), torch.Size([5, 3, 1]))
     """

-    def __init__(self,
-                 in_feats,
-                 edge_feats,
-                 out_feats,
-                 num_heads,
-                 feat_drop=0.,
-                 attn_drop=0.,
-                 negative_slope=0.2,
-                 residual=True,
-                 activation=None,
-                 allow_zero_in_degree=False,
-                 bias=True):
+    def __init__(
+        self,
+        in_feats,
+        edge_feats,
+        out_feats,
+        num_heads,
+        feat_drop=0.0,
+        attn_drop=0.0,
+        negative_slope=0.2,
+        residual=True,
+        activation=None,
+        allow_zero_in_degree=False,
+        bias=True,
+    ):
         super(EdgeGATConv, self).__init__()
         self._num_heads = num_heads
         self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
         self._out_feats = out_feats
         self._allow_zero_in_degree = allow_zero_in_degree
         if isinstance(in_feats, tuple):
             self.fc_src = nn.Linear(
-                self._in_src_feats, out_feats * num_heads, bias=False)
+                self._in_src_feats, out_feats * num_heads, bias=False
+            )
             self.fc_dst = nn.Linear(
-                self._in_dst_feats, out_feats * num_heads, bias=False)
+                self._in_dst_feats, out_feats * num_heads, bias=False
+            )
         else:
             self.fc = nn.Linear(
-                self._in_src_feats, out_feats * num_heads, bias=False)
+                self._in_src_feats, out_feats * num_heads, bias=False
+            )
         self.attn_l = nn.Parameter(
-            th.FloatTensor(size=(1, num_heads, out_feats)))
+            th.FloatTensor(size=(1, num_heads, out_feats))
+        )
         self.attn_r = nn.Parameter(
-            th.FloatTensor(size=(1, num_heads, out_feats)))
+            th.FloatTensor(size=(1, num_heads, out_feats))
+        )
         self.feat_drop = nn.Dropout(feat_drop)
         self.attn_drop = nn.Dropout(attn_drop)
         self.leaky_relu = nn.LeakyReLU(negative_slope)
         if bias:
             self.bias = nn.Parameter(
-                th.FloatTensor(size=(num_heads * out_feats,)))
+                th.FloatTensor(size=(num_heads * out_feats,))
+            )
         else:
-            self.register_buffer('bias', None)
+            self.register_buffer("bias", None)
         if residual:
             self.res_fc = nn.Linear(
-                self._in_dst_feats, num_heads * out_feats, bias=False)
+                self._in_dst_feats, num_heads * out_feats, bias=False
+            )
         else:
-            self.register_buffer('res_fc', None)
+            self.register_buffer("res_fc", None)

         self._edge_feats = edge_feats
-        self.fc_edge = nn.Linear(
-            edge_feats, out_feats * num_heads, bias=False)
+        self.fc_edge = nn.Linear(edge_feats, out_feats * num_heads, bias=False)
         self.attn_edge = nn.Parameter(
-            th.FloatTensor(size=(1, num_heads, out_feats)))
+            th.FloatTensor(size=(1, num_heads, out_feats))
+        )

         self.reset_parameters()
         self.activation = activation
@@ -189,8 +198,8 @@ def reset_parameters(self):
         The fc weights :math:`\mathbf{\Theta}` are and the
         attention weights are using xavier initialization method.
         """
-        gain = nn.init.calculate_gain('relu')
-        if hasattr(self, 'fc'):
+        gain = nn.init.calculate_gain("relu")
+        if hasattr(self, "fc"):
             nn.init.xavier_normal_(self.fc.weight, gain=gain)
         else:
             nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
@@ -261,46 +270,55 @@ def forward(self, graph, feat, edge_feat, get_attention=False):
         with graph.local_scope():
             if not self._allow_zero_in_degree:
                 if (graph.in_degrees() == 0).any():
-                    raise DGLError('There are 0-in-degree nodes in the graph, '
-                                   'output for those nodes will be invalid. '
-                                   'This is harmful for some applications, '
-                                   'causing silent performance regression. '
-                                   'Adding self-loop on the input graph by '
-                                   'calling `g = dgl.add_self_loop(g)` will resolve '
-                                   'the issue. Setting ``allow_zero_in_degree`` '
-                                   'to be `True` when constructing this module will '
-                                   'suppress the check and let the code run.')
+                    raise DGLError(
+                        "There are 0-in-degree nodes in the graph, "
+                        "output for those nodes will be invalid. "
+                        "This is harmful for some applications, "
+                        "causing silent performance regression. "
+                        "Adding self-loop on the input graph by "
+                        "calling `g = dgl.add_self_loop(g)` will resolve "
+                        "the issue. Setting ``allow_zero_in_degree`` "
+                        "to be `True` when constructing this module will "
+                        "suppress the check and let the code run."
+                    )

             if isinstance(feat, tuple):
                 src_prefix_shape = feat[0].shape[:-1]
                 dst_prefix_shape = feat[1].shape[:-1]
                 h_src = self.feat_drop(feat[0])
                 h_dst = self.feat_drop(feat[1])
-                if not hasattr(self, 'fc_src'):
+                if not hasattr(self, "fc_src"):
                     feat_src = self.fc(h_src).view(
-                        *src_prefix_shape, self._num_heads, self._out_feats)
+                        *src_prefix_shape, self._num_heads, self._out_feats
+                    )
                     feat_dst = self.fc(h_dst).view(
-                        *dst_prefix_shape, self._num_heads, self._out_feats)
+                        *dst_prefix_shape, self._num_heads, self._out_feats
+                    )
                 else:
                     feat_src = self.fc_src(h_src).view(
-                        *src_prefix_shape, self._num_heads, self._out_feats)
+                        *src_prefix_shape, self._num_heads, self._out_feats
+                    )
                     feat_dst = self.fc_dst(h_dst).view(
-                        *dst_prefix_shape, self._num_heads, self._out_feats)
+                        *dst_prefix_shape, self._num_heads, self._out_feats
+                    )
             else:
                 src_prefix_shape = dst_prefix_shape = feat.shape[:-1]
                 h_src = h_dst = self.feat_drop(feat)
                 feat_src = feat_dst = self.fc(h_src).view(
-                    *src_prefix_shape, self._num_heads, self._out_feats)
+                    *src_prefix_shape, self._num_heads, self._out_feats
+                )
                 if graph.is_block:
-                    feat_dst = feat_src[:graph.number_of_dst_nodes()]
-                    h_dst = h_dst[:graph.number_of_dst_nodes()]
+                    feat_dst = feat_src[: graph.number_of_dst_nodes()]
+                    h_dst = h_dst[: graph.number_of_dst_nodes()]
                     dst_prefix_shape = (
-                        graph.number_of_dst_nodes(),) + dst_prefix_shape[1:]
+                        graph.number_of_dst_nodes(),
+                    ) + dst_prefix_shape[1:]

             # linearly tranform the edge features
             n_edges = edge_feat.shape[:-1]
             feat_edge = self.fc_edge(edge_feat).view(
-                *n_edges, self._num_heads, self._out_feats)
+                *n_edges, self._num_heads, self._out_feats
+            )

             # add edge features to graph
             graph.edata["ft_edge"] = feat_edge
@@ -310,50 +328,54 @@ def forward(self, graph, feat, edge_feat, get_attention=False):

             # calculate scalar for each edge
             ee = (feat_edge * self.attn_edge).sum(dim=-1).unsqueeze(-1)
-            graph.edata['ee'] = ee
+            graph.edata["ee"] = ee

-            graph.srcdata.update({'ft': feat_src, 'el': el})
-            graph.dstdata.update({'er': er})
+            graph.srcdata.update({"ft": feat_src, "el": el})
+            graph.dstdata.update({"er": er})
             # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively
-            graph.apply_edges(fn.u_add_v('el', 'er', 'e_tmp'))
+            graph.apply_edges(fn.u_add_v("el", "er", "e_tmp"))

             # e_tmp combines attention weights of source and destination node
             # we also add the attention weight of the edge
-            graph.edata['e'] = graph.edata['e_tmp'] + graph.edata['ee']
+            graph.edata["e"] = graph.edata["e_tmp"] + graph.edata["ee"]

             # create new edges features that combine the
             # features of the source node and the edge features
-            graph.apply_edges(fn.u_add_e('ft', 'ft_edge', 'ft_combined'))
+            graph.apply_edges(fn.u_add_e("ft", "ft_edge", "ft_combined"))

-            e = self.leaky_relu(graph.edata.pop('e'))
+            e = self.leaky_relu(graph.edata.pop("e"))
             # compute softmax
-            graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
+            graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))

             # for each edge, element-wise multiply the combined features with
             # the attention coefficient
-            graph.edata['m_combined'] = graph.edata['ft_combined'] * \
-                graph.edata['a']
+            graph.edata["m_combined"] = (
+                graph.edata["ft_combined"] * graph.edata["a"]
+            )

             # first copy the edge features and then sum them up
-            graph.update_all(fn.copy_e('m_combined', 'm'),
-                             fn.sum('m', 'ft'))
+            graph.update_all(fn.copy_e("m_combined", "m"), fn.sum("m", "ft"))

-            rst = graph.dstdata['ft']
+            rst = graph.dstdata["ft"]
             # residual
             if self.res_fc is not None:
                 # Use -1 rather than self._num_heads to handle broadcasting
                 resval = self.res_fc(h_dst).view(
-                    *dst_prefix_shape, -1, self._out_feats)
+                    *dst_prefix_shape, -1, self._out_feats
+                )
                 rst = rst + resval
             # bias
             if self.bias is not None:
                 rst = rst + self.bias.view(
-                    *((1,) * len(dst_prefix_shape)), self._num_heads, self._out_feats)
+                    *((1,) * len(dst_prefix_shape)),
+                    self._num_heads,
+                    self._out_feats
+                )
             # activation
             if self.activation:
                 rst = self.activation(rst)

             if get_attention:
-                return rst, graph.edata['a']
+                return rst, graph.edata["a"]
             else:
                 return rst
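
For reference, a minimal usage sketch of the EdgeGATConv layer reformatted above (not part of the commit). It assumes DGL with this module available plus PyTorch; the import path follows the `__init__.py` change shown earlier, and the toy graph gives every node at least one incoming edge so the 0-in-degree check in `forward` does not trigger:

import dgl
import torch as th

from dgl.nn.pytorch.conv import EdgeGATConv

# Toy graph with 4 nodes and 5 edges; every node has in-degree >= 1.
src = th.tensor([0, 1, 2, 3, 2])
dst = th.tensor([1, 2, 3, 0, 0])
g = dgl.graph((src, dst))

node_feat = th.randn(4, 8)  # in_feats = 8
edge_feat = th.randn(5, 4)  # edge_feats = 4

# 3 attention heads with 10 output features per head.
conv = EdgeGATConv(in_feats=8, edge_feats=4, out_feats=10, num_heads=3)
out, attn = conv(g, node_feat, edge_feat, get_attention=True)

# Expected shapes: (num_nodes, num_heads, out_feats) for the node output
# and (num_edges, num_heads, 1) for the per-edge attention scores.
print(out.shape, attn.shape)  # torch.Size([4, 3, 10]) torch.Size([5, 3, 1])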
