Commit

fix C0103
rudongyu committed Jan 30, 2023
1 parent 4901d46 · commit 45f756b
Showing 1 changed file with 25 additions and 23 deletions.
48 changes: 25 additions & 23 deletions python/dgl/nn/pytorch/gt/lap_enc.py
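C0103 is pylint's "invalid-name" message: it flags identifiers that do not match the configured naming style, which by default is snake_case for arguments and local variables. A minimal sketch of the kind of rename the warning asks for, using hypothetical names that are not taken from lap_enc.py:

    # Hypothetical illustration of pylint C0103 (invalid-name), assuming default naming rules.
    # CamelCase locals/arguments such as EigVals or PosEnc would be flagged;
    # renaming them to snake_case silences the warning without changing behavior.
    def normalize(eig_vals):
        total = sum(eig_vals)                    # "Total" would be flagged; "total" is fine
        return [v / total for v in eig_vals]

    print(normalize([1.0, 2.0, 3.0]))            # [0.1666..., 0.3333..., 0.5]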
@@ -9,7 +9,8 @@ class LapPosEncoder(nn.Module):
`GraphGPS: General Powerful Scalable Graph Transformers
<https://arxiv.org/abs/2205.12454>`__
- This module is a learned laplacian positional encoding module using Transformer or DeepSet.
+ This module is a learned laplacian positional encoding module using
+ Transformer or DeepSet.
Parameters
----------
@@ -28,7 +29,8 @@ class LapPosEncoder(nn.Module):
If True, apply batch normalization on raw LaplacianPE.
Default : False.
num_post_layer : int, optional
- If num_post_layer > 0, apply an MLP of ``num_post_layer`` layers after pooling.
+ If num_post_layer > 0, apply an MLP of ``num_post_layer`` layers after
+ pooling.
Default : 0.
Example
@@ -40,13 +42,13 @@ class LapPosEncoder(nn.Module):
>>> transform = LaplacianPE(k=5, feat_name='eigvec', eigval_name='eigval', padding=True)
>>> g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4]))
>>> g = transform(g)
- >>> EigVals, EigVecs = g.ndata['eigval'], g.ndata['eigvec']
+ >>> eig_vals, eig_vecs = g.ndata['eigval'], g.ndata['eigvec']
>>> TransformerLPE = LapPosEncoder(model_type="Transformer", num_layer=3, k=5,
lpe_dim=16, n_head=4)
- >>> PosEnc = TransformerLPE(EigVals, EigVecs)
+ >>> pos_enc = TransformerLPE(eig_vals, eig_vecs)
>>> DeepSetLPE = LapPosEncoder(model_type="DeepSet", num_layer=3, k=5,
lpe_dim=16, num_post_layer=2)
- >>> PosEnc = DeepSetLPE(EigVals, EigVecs)
+ >>> pos_enc = DeepSetLPE(eig_vals, eig_vecs)
"""

def __init__(
@@ -85,8 +87,8 @@ def __init__(
self.pe_encoder = nn.Sequential(*layers)
else:
raise ValueError(
f"model_type '{model_type}' is not allowed, must be 'Transformer'"
"or 'DeepSet'."
f"model_type '{model_type}' is not allowed, must be "
"'Transformer' or 'DeepSet'."
)

if batch_norm:
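Aside on the ValueError hunk above: Python concatenates adjacent string literals, so the old split rendered the message as "... must be 'Transformer'or 'DeepSet'." with no space before "or"; the re-split keeps the space. A small standalone sketch of that behavior, not code from lap_enc.py:

    # Adjacent string literals are concatenated at compile time.
    old_msg = "must be 'Transformer'" "or 'DeepSet'."
    new_msg = "must be " "'Transformer' or 'DeepSet'."
    print(old_msg)  # must be 'Transformer'or 'DeepSet'.   <- missing space
    print(new_msg)  # must be 'Transformer' or 'DeepSet'.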
@@ -111,14 +113,14 @@ def __init__(
else:
self.post_mlp = None

- def forward(self, EigVals, EigVecs):
+ def forward(self, eig_vals, eig_vecs):
r"""
Parameters
----------
- EigVals : Tensor
+ eig_vals : Tensor
Laplacian Eigenvalues of shape :math:`(N, k)`, k different eigenvalues repeat N times,
can be obtained by using `LaplacianPE`.
- EigVecs : Tensor
+ eig_vecs : Tensor
Laplacian Eigenvectors of shape :math:`(N, k)`, can be obtained by using `LaplacianPE`.
Returns
@@ -128,31 +130,31 @@ def forward(self, EigVals, EigVecs):
where :math:`N` is the number of nodes in the input graph,
:math:`d` is :attr:`lpe_dim`.
"""
- PosEnc = th.cat(
-     (EigVecs.unsqueeze(2), EigVals.unsqueeze(2)), dim=2
+ pos_enc = th.cat(
+     (eig_vecs.unsqueeze(2), eig_vals.unsqueeze(2)), dim=2
).float()
- empty_mask = th.isnan(PosEnc)
+ empty_mask = th.isnan(pos_enc)

- PosEnc[empty_mask] = 0
+ pos_enc[empty_mask] = 0
if self.raw_norm:
-     PosEnc = self.raw_norm(PosEnc)
- PosEnc = self.linear(PosEnc)
+     pos_enc = self.raw_norm(pos_enc)
+ pos_enc = self.linear(pos_enc)

if self.model_type == "Transformer":
-     PosEnc = self.pe_encoder(
-         src=PosEnc, src_key_padding_mask=empty_mask[:, :, 1]
+     pos_enc = self.pe_encoder(
+         src=pos_enc, src_key_padding_mask=empty_mask[:, :, 1]
)
else:
-     PosEnc = self.pe_encoder(PosEnc)
+     pos_enc = self.pe_encoder(pos_enc)

# Remove masked sequences
- PosEnc[empty_mask[:, :, 1]] = 0
+ pos_enc[empty_mask[:, :, 1]] = 0

# Sum pooling
- PosEnc = th.sum(PosEnc, 1, keepdim=False)
+ pos_enc = th.sum(pos_enc, 1, keepdim=False)

# MLP post pooling
if self.post_mlp:
-     PosEnc = self.post_mlp(PosEnc)
+     pos_enc = self.post_mlp(pos_enc)

- return PosEnc
+ return pos_enc
