[NN] nn modules & examples update #890

Merged: 34 commits, Nov 3, 2019

Changes shown are from 18 of the 34 commits.

Commits
c4b43b6
upd
yzh119 Sep 30, 2019
9c9ec79
damn it
yzh119 Sep 30, 2019
83fc4df
fuck
yzh119 Sep 30, 2019
bdeb396
fuck pylint
yzh119 Sep 30, 2019
c2b9ec6
fudge
yzh119 Sep 30, 2019
748d8c1
Merge remote-tracking branch 'upstream/master' into shit
yzh119 Sep 30, 2019
f13e7b4
remove some comments about MXNet
yzh119 Sep 30, 2019
392b9f5
Merge remote-tracking branch 'upstream/master' into shit
yzh119 Oct 17, 2019
bd2e64c
upd
yzh119 Oct 17, 2019
abf90a0
Merge branch 'master' into shit
yzh119 Oct 21, 2019
837fb7e
upd
yzh119 Oct 23, 2019
0f5ee2f
Merge branch 'shit' of https://github.com/yzh119/dgl into shit
yzh119 Oct 23, 2019
71614a6
damn it
yzh119 Oct 23, 2019
b4b3ccd
damn it
yzh119 Oct 23, 2019
15c20cb
fuck
yzh119 Oct 29, 2019
64011b4
Merge remote-tracking branch 'upstream/master' into shit
yzh119 Oct 29, 2019
f967eb8
fuck
yzh119 Oct 29, 2019
4514743
upd
yzh119 Oct 29, 2019
e060817
Merge branch 'master' into shit
yzh119 Oct 30, 2019
e7dc899
upd
yzh119 Oct 31, 2019
171a84b
pylint bastard
yzh119 Oct 31, 2019
647f9e2
Merge branch 'master' into shit
jermainewang Nov 1, 2019
959a5f5
upd
yzh119 Nov 2, 2019
bef6276
Merge branch 'shit' of https://github.com/yzh119/dgl into shit
yzh119 Nov 2, 2019
222a7fb
upd
yzh119 Nov 2, 2019
31e7689
upd
yzh119 Nov 2, 2019
705e5d4
upd
yzh119 Nov 2, 2019
d803dfd
upd
yzh119 Nov 2, 2019
c6026e2
upd
yzh119 Nov 3, 2019
4894125
upd
yzh119 Nov 3, 2019
0d3e639
upd
yzh119 Nov 3, 2019
1b77e92
Merge branch 'shit' of https://github.com/yzh119/dgl into shit
yzh119 Nov 3, 2019
2a4d8f8
upd
yzh119 Nov 3, 2019
5da1ea6
Merge branch 'master' into shit
jermainewang Nov 3, 2019
100 changes: 100 additions & 0 deletions docs/source/api/python/nn.mxnet.rst
@@ -38,6 +38,106 @@ TAGConv
:members: forward
:show-inheritance:

GATConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.mxnet.conv.GATConv
:members: forward
:show-inheritance:

EdgeConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.mxnet.conv.EdgeConv
:members: forward
:show-inheritance:

SAGEConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.mxnet.conv.SAGEConv
:members: forward
:show-inheritance:

SGConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.mxnet.conv.SGConv
:members: forward
:show-inheritance:

APPNPConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.mxnet.conv.APPNPConv
:members: forward
:show-inheritance:

GINConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.mxnet.conv.GINConv
:members: forward
:show-inheritance:

GatedGraphConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.mxnet.conv.GatedGraphConv
:members: forward
:show-inheritance:

GMMConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.mxnet.conv.GMMConv
:members: forward
:show-inheritance:

ChebConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.mxnet.conv.ChebConv
:members: forward
:show-inheritance:

AGNNConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.mxnet.conv.AGNNConv
:members: forward
:show-inheritance:

NNConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.mxnet.conv.NNConv
:members: forward
:show-inheritance:

Dense Conv Layers
----------------------------------------

DenseGraphConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.mxnet.conv.DenseGraphConv
:members: forward
:show-inheritance:

DenseSAGEConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.mxnet.conv.DenseSAGEConv
:members: forward
:show-inheritance:

DenseChebConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.mxnet.conv.DenseChebConv
:members: forward
:show-inheritance:

Global Pooling Layers
----------------------------------------
97 changes: 9 additions & 88 deletions examples/mxnet/gat/gat.py
@@ -7,87 +7,8 @@
Pytorch implementation: https://github.com/Diego999/pyGAT
"""

import mxnet as mx
from mxnet import gluon
import mxnet.ndarray as nd
import mxnet.gluon.nn as nn
import dgl.function as fn
from dgl.nn.mxnet import edge_softmax


class GraphAttention(gluon.Block):
def __init__(self,
g,
in_dim,
out_dim,
num_heads,
feat_drop,
attn_drop,
alpha,
residual=False):
super(GraphAttention, self).__init__()
self.g = g
self.num_heads = num_heads
self.fc = nn.Dense(num_heads * out_dim, use_bias=False,
weight_initializer=mx.init.Xavier())
if feat_drop:
self.feat_drop = nn.Dropout(feat_drop)
else:
self.feat_drop = lambda x : x
if attn_drop:
self.attn_drop = nn.Dropout(attn_drop)
else:
self.attn_drop = lambda x : x
self.attn_l = self.params.get("left_att", grad_req="add",
shape=(1, num_heads, out_dim),
init=mx.init.Xavier())
self.attn_r = self.params.get("right_att", grad_req="add",
shape=(1, num_heads, out_dim),
init=mx.init.Xavier())
self.alpha = alpha
self.softmax = edge_softmax
self.residual = residual
if residual:
if in_dim != out_dim:
self.res_fc = nn.Dense(num_heads * out_dim, use_bias=False,
weight_initializer=mx.init.Xavier())
else:
self.res_fc = None

def forward(self, inputs):
# prepare
h = self.feat_drop(inputs) # NxD
ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1)) # NxHxD'
a1 = (ft * self.attn_l.data(ft.context)).sum(axis=-1).expand_dims(-1) # N x H x 1
a2 = (ft * self.attn_r.data(ft.context)).sum(axis=-1).expand_dims(-1) # N x H x 1
self.g.ndata.update({'ft' : ft, 'a1' : a1, 'a2' : a2})
# 1. compute edge attention
self.g.apply_edges(self.edge_attention)
# 2. compute softmax
self.edge_softmax()
# 3. compute the aggregated node features
self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'),
fn.sum('ft', 'ft'))
ret = self.g.ndata['ft']
# 4. residual
if self.residual:
if self.res_fc is not None:
resval = self.res_fc(h).reshape(
(h.shape[0], self.num_heads, -1)) # NxHxD'
else:
resval = nd.expand_dims(h, axis=1) # Nx1xD'
ret = resval + ret
return ret

def edge_attention(self, edges):
# an edge UDF to compute unnormalized attention values from src and dst
a = nd.LeakyReLU(edges.src['a1'] + edges.dst['a2'], slope=self.alpha)
return {'a' : a}

def edge_softmax(self):
attention = self.softmax(self.g, self.g.edata.pop('a'))
# Dropout attention scores and save them
self.g.edata['a_drop'] = self.attn_drop(attention)
from dgl.nn.mxnet.conv import GATConv


class GAT(nn.Block):
@@ -109,27 +30,27 @@ def __init__(self,
self.gat_layers = []
self.activation = activation
# input projection (no residual)
self.gat_layers.append(GraphAttention(
g, in_dim, num_hidden, heads[0],
self.gat_layers.append(GATConv(
in_dim, num_hidden, heads[0],
feat_drop, attn_drop, alpha, False))
# hidden layers
for l in range(1, num_layers):
# due to multi-head, the in_dim = num_hidden * num_heads
self.gat_layers.append(GraphAttention(
g, num_hidden * heads[l-1], num_hidden, heads[l],
self.gat_layers.append(GATConv(
num_hidden * heads[l-1], num_hidden, heads[l],
feat_drop, attn_drop, alpha, residual))
# output projection
self.gat_layers.append(GraphAttention(
g, num_hidden * heads[-2], num_classes, heads[-1],
self.gat_layers.append(GATConv(
num_hidden * heads[-2], num_classes, heads[-1],
feat_drop, attn_drop, alpha, residual))
for i, layer in enumerate(self.gat_layers):
self.register_child(layer, "gat_layer_{}".format(i))

def forward(self, inputs):
h = inputs
for l in range(self.num_layers):
h = self.gat_layers[l](h).flatten()
h = self.gat_layers[l](self.g, h).flatten()
h = self.activation(h)
# output projection
logits = self.gat_layers[-1](h).mean(1)
logits = self.gat_layers[-1](self.g, h).mean(1)
return logits
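
For reference, a minimal usage sketch of the new dgl.nn.mxnet.conv.GATConv (not part of this PR; the toy graph and feature sizes are illustrative assumptions, and the arguments follow the positional order used in the updated example above):

# Sketch only: a 3-node toy graph with self-loops so every node has an in-edge.
import mxnet as mx
import dgl
from dgl.nn.mxnet.conv import GATConv

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2], [1, 2, 0])
g.add_edges([0, 1, 2], [0, 1, 2])   # self-loops

feat = mx.nd.random.randn(3, 5)     # 3 nodes, 5 input features each
conv = GATConv(5, 4, 2)             # 5 in-features, 4 out-features per head, 2 heads
conv.initialize(ctx=mx.cpu())
out = conv(g, feat)                 # note the new (graph, feat) calling convention
print(out.shape)                    # expected (3, 2, 4): nodes x heads x out-features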
19 changes: 18 additions & 1 deletion python/dgl/nn/mxnet/conv/__init__.py
@@ -4,5 +4,22 @@
from .graphconv import GraphConv
from .relgraphconv import RelGraphConv
from .tagconv import TAGConv
from .gatconv import GATConv
from .sageconv import SAGEConv
from .gatedgraphconv import GatedGraphConv
from .chebconv import ChebConv
from .agnnconv import AGNNConv
from .appnpconv import APPNPConv
from .densegraphconv import DenseGraphConv
from .densesageconv import DenseSAGEConv
from .densechebconv import DenseChebConv
from .edgeconv import EdgeConv
from .ginconv import GINConv
from .gmmconv import GMMConv
from .nnconv import NNConv
from .sgconv import SGConv

__all__ = ['GraphConv', 'TAGConv', 'RelGraphConv']
__all__ = ['GraphConv', 'TAGConv', 'RelGraphConv', 'GATConv',
'SAGEConv', 'GatedGraphConv', 'ChebConv', 'AGNNConv',
'APPNPConv', 'DenseGraphConv', 'DenseSAGEConv', 'DenseChebConv',
'EdgeConv', 'GINConv', 'GMMConv', 'NNConv', 'SGConv']
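
As a quick sanity check (a sketch, not part of the diff), the expanded __all__ can be verified against the module's attributes:

# Sketch only: every exported name should resolve on dgl.nn.mxnet.conv after this change.
import dgl.nn.mxnet.conv as conv

print(sorted(conv.__all__))
assert all(hasattr(conv, name) for name in conv.__all__)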
66 changes: 66 additions & 0 deletions python/dgl/nn/mxnet/conv/agnnconv.py
@@ -0,0 +1,66 @@
"""MXNet Module for Attention-based Graph Neural Network layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import mxnet as mx
from mxnet.gluon import nn

from .... import function as fn
from ..softmax import edge_softmax
from ..utils import normalize

class AGNNConv(nn.Block):
r"""Attention-based Graph Neural Network layer from paper `Attention-based
Graph Neural Network for Semi-Supervised Learning
<https://arxiv.org/abs/1803.03735>`__.

.. math::
H^{l+1} = P H^{l}

where :math:`P` is computed as:

.. math::
P_{ij} = \mathrm{softmax}_i ( \beta \cdot \cos(h_i^l, h_j^l))

Parameters
----------
init_beta : float, optional
The :math:`\beta` in the formula.
learn_beta : bool, optional
If True, :math:`\beta` will be a learnable parameter.
"""
def __init__(self,
init_beta=1.,
learn_beta=True):
super(AGNNConv, self).__init__()
with self.name_scope():
self.beta = self.params.get('beta',
shape=(1,),
grad_req='write' if learn_beta else 'null',
init=mx.init.Constant(init_beta))

def forward(self, graph, feat):
r"""Compute AGNN Layer.

Parameters
----------
graph : DGLGraph
The graph.
feat : mxnet.NDArray
The input feature of shape :math:`(N, *)`, where :math:`N` is the
number of nodes and :math:`*` can be any shape.

Returns
-------
mxnet.NDArray
The output feature of shape :math:`(N, *)`, where :math:`*`
is the same as the input shape.
"""
graph = graph.local_var()
graph.ndata['h'] = feat
graph.ndata['norm_h'] = normalize(feat, p=2, axis=-1)
# compute cosine distance
graph.apply_edges(fn.u_dot_v('norm_h', 'norm_h', 'cos'))
cos = graph.edata.pop('cos')
e = self.beta.data(feat.context) * cos
graph.edata['p'] = edge_softmax(graph, e)
graph.update_all(fn.u_mul_e('h', 'p', 'm'), fn.sum('m', 'h'))
return graph.ndata.pop('h')
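
A minimal usage sketch for the new layer (illustrative only, not part of the diff; the toy graph and feature size are assumptions). AGNNConv applies no feature transform, so the output shape matches the input:

# Sketch only: one AGNN propagation step on a 4-node directed cycle.
import mxnet as mx
import dgl
from dgl.nn.mxnet.conv import AGNNConv

g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])

feat = mx.nd.random.randn(4, 8)            # 4 nodes, 8-dimensional features
conv = AGNNConv(init_beta=1.0, learn_beta=True)
conv.initialize(ctx=mx.cpu())
out = conv(g, feat)
print(out.shape)                           # same as the input: (4, 8)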
75 changes: 75 additions & 0 deletions python/dgl/nn/mxnet/conv/appnpconv.py
@@ -0,0 +1,75 @@
"""MXNet Module for APPNPConv"""
# pylint: disable= no-member, arguments-differ, invalid-name
import mxnet as mx
from mxnet import nd
from mxnet.gluon import nn

from .... import function as fn

class APPNPConv(nn.Block):
r"""Approximate Personalized Propagation of Neural Predictions
layer from paper `Predict then Propagate: Graph Neural Networks
meet Personalized PageRank <https://arxiv.org/pdf/1810.05997.pdf>`__.

.. math::
H^{0} & = X

H^{t+1} & = (1-\alpha)\left(\hat{D}^{-1/2}
\hat{A} \hat{D}^{-1/2} H^{t}\right) + \alpha H^{0}

Parameters
----------
k : int
Number of iterations :math:`K`.
alpha : float
The teleport probability :math:`\alpha`.
edge_drop : float, optional
Dropout rate on edges that controls the
messages received by each node. Default: ``0``.
"""
def __init__(self,
k,
alpha,
edge_drop=0.):
super(APPNPConv, self).__init__()
self._k = k
self._alpha = alpha
with self.name_scope():
self.edge_drop = nn.Dropout(edge_drop)

def forward(self, graph, feat):
r"""Compute APPNP layer.

Parameters
----------
graph : DGLGraph
The graph.
feat : mx.NDArray
The input feature of shape :math:`(N, *)`, where :math:`N` is the
number of nodes and :math:`*` can be any shape.

Returns
-------
mx.NDArray
The output feature of shape :math:`(N, *)`, where :math:`*`
is the same as the input shape.
"""
graph = graph.local_var()
norm = mx.nd.power(mx.nd.clip(
graph.in_degrees().astype(feat.dtype), a_min=1, a_max=float("inf")), -0.5)
shp = norm.shape + (1,) * (feat.ndim - 1)
norm = norm.reshape(shp).as_in_context(feat.context)
feat_0 = feat
for _ in range(self._k):
# normalization by src node
feat = feat * norm
graph.ndata['h'] = feat
graph.edata['w'] = self.edge_drop(
nd.ones((graph.number_of_edges(), 1), ctx=feat.context))
graph.update_all(fn.u_mul_e('h', 'w', 'm'),
fn.sum('m', 'h'))
feat = graph.ndata.pop('h')
# normalization by dst node
feat = feat * norm
feat = (1 - self._alpha) * feat + self._alpha * feat_0
return feat
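
A minimal usage sketch (illustrative only, not part of the diff; the graph, sizes, and hyper-parameters are assumptions). APPNPConv only propagates features, so the feature shape is unchanged:

# Sketch only: APPNP propagation with k=3 hops and teleport probability alpha=0.1.
import mxnet as mx
import dgl
from dgl.nn.mxnet.conv import APPNPConv

g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])

feat = mx.nd.random.randn(4, 16)           # 4 nodes, 16-dimensional features
appnp = APPNPConv(k=3, alpha=0.1)
appnp.initialize(ctx=mx.cpu())             # only the parameter-free dropout child to set up
out = appnp(g, feat)
print(out.shape)                           # unchanged: (4, 16)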