Skip to content

Commit

Permalink
update doc
Browse files Browse the repository at this point in the history
  • Loading branch information
rudongyu committed Jun 21, 2023
1 parent 66fa247 commit 9388150
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions python/dgl/nn/pytorch/gt/spatial_encoder.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
"""Spatial Encoder"""

import math

import torch as th
import torch.nn as nn
import torch.nn.functional as F

from ....batch import unbatch


def gaussian(x, mean, std):
pi = 3.14159
a = (2 * pi) ** 0.5
"""compute gaussian basis kernel function"""
const_pi = 3.14159
a = (2 * const_pi) ** 0.5
return th.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std)


Expand Down Expand Up @@ -93,30 +90,32 @@ class SpatialEncoder3d(nn.Module):
`One Transformer Can Understand Both 2D & 3D Molecular Data
<https://arxiv.org/pdf/2210.01765.pdf>`__
This module encodes pair-wise relation between atom pair :math:`(i,j)` in
This module encodes pair-wise relation between node pair :math:`(i,j)` in
the 3D geometric space, according to the Gaussian Basis Kernel function:
:math:`\psi _{(i,j)} ^k = \frac{1}{\sqrt{2\pi} \lvert \sigma^k \rvert}
\exp{\left ( -\frac{1}{2} \left( \frac{\gamma_{(i,j)} \lvert \lvert r_i -
r_j \rvert \rvert + \beta_{(i,j)} - \mu^k}{\lvert \sigma^k \rvert} \right)
^2 \right)},k=1,...,K,`
where :math:`K` is the number of Gaussian Basis kernels.
:math:`r_i` is the Cartesian coordinate of atom :math:`i`.
:math:`\gamma_{(i,j)}, \beta_{(i,j)}` are learnable scaling factors of
the Gaussian Basis kernels.
where :math:`K` is the number of Gaussian Basis kernels. :math:`r_i` is the
Cartesian coordinate of node :math:`i`.
:math:`\gamma_{(i,j)}, \beta_{(i,j)}` are learnable scaling factors and
biases determined by node types. :math:`\mu^k, \sigma^k` are learnable
centers and standard deviations of the Gaussian Basis kernels.
Parameters
----------
num_kernels : int
Number of Gaussian Basis Kernels to be applied. Each Gaussian Basis
Kernel contains a learnable kernel center and a learnable scaling
factor.
Kernel contains a learnable kernel center and a learnable standard
deviation.
num_heads : int, optional
Number of attention heads if multi-head attention mechanism is applied.
Default : 1.
max_node_type : int, optional
Maximum number of node types. Default : 100.
Maximum number of node types. Each node type has a corresponding
learnable scaling factor and a bias. Default : 100.
Examples
--------
Expand Down Expand Up @@ -144,8 +143,8 @@ def __init__(self, num_kernels, num_heads=1, max_node_type=100):
self.linear_layer_1 = nn.Linear(num_kernels, num_kernels)
self.linear_layer_2 = nn.Linear(num_kernels, num_heads)
# default mul/bias at position 0 (no node type given)
# src mul/bias at position 1 (pad) ~ max_node_type+1
# tgt mul/bias at position max_node_type+2 (pad) ~ 2*max_node_type+2
# src mul/bias at position [1 (pad) ~ max_node_type+1]
# tgt mul/bias at position [max_node_type+2 (pad) ~ 2*max_node_type+2]
self.mul = nn.Embedding(2 * max_node_type + 3, 1, padding_idx=0)
self.bias = nn.Embedding(2 * max_node_type + 3, 1, padding_idx=0)

Expand All @@ -159,15 +158,16 @@ def forward(self, coord, node_type=None):
Parameters
----------
coord : torch.Tensor
3D coordinates of nodes in shape :math:`(B, N, 3)`, where :math:`B` is the batch size, :math:`N`: is the maximum number of nodes.
3D coordinates of nodes in shape :math:`(B, N, 3)`, where :math:`B`
is the batch size, :math:`N`: is the maximum number of nodes.
node_type : torch.Tensor, optional
Node type ids of nodes. Default : None.
* If specified, :attr:`node_type` should be a tensor in shape
:math:`(B, N,)`. The scaling factors in gaussian kernels of each
pair of nodes are determined by their node types.
* Otherwise, :attr:`node_type` will be set to zeros in the same
shape.
* Otherwise, :attr:`node_type` will be set to zeros of the same
shape by default.
Returns
-------
Expand Down

0 comments on commit 9388150

Please sign in to comment.