From 938815041178c58648c135d0f9bff43c62853662 Mon Sep 17 00:00:00 2001 From: rudongyu Date: Wed, 21 Jun 2023 08:53:43 +0000 Subject: [PATCH] update doc --- python/dgl/nn/pytorch/gt/spatial_encoder.py | 38 ++++++++++----------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/python/dgl/nn/pytorch/gt/spatial_encoder.py b/python/dgl/nn/pytorch/gt/spatial_encoder.py index 33d4b0ff9a98..580f51814829 100644 --- a/python/dgl/nn/pytorch/gt/spatial_encoder.py +++ b/python/dgl/nn/pytorch/gt/spatial_encoder.py @@ -1,17 +1,14 @@ """Spatial Encoder""" -import math - import torch as th import torch.nn as nn import torch.nn.functional as F -from ....batch import unbatch - def gaussian(x, mean, std): - pi = 3.14159 - a = (2 * pi) ** 0.5 + """compute gaussian basis kernel function""" + const_pi = 3.14159 + a = (2 * const_pi) ** 0.5 return th.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std) @@ -93,7 +90,7 @@ class SpatialEncoder3d(nn.Module): `One Transformer Can Understand Both 2D & 3D Molecular Data `__ - This module encodes pair-wise relation between atom pair :math:`(i,j)` in + This module encodes pair-wise relation between node pair :math:`(i,j)` in the 3D geometric space, according to the Gaussian Basis Kernel function: :math:`\psi _{(i,j)} ^k = \frac{1}{\sqrt{2\pi} \lvert \sigma^k \rvert} @@ -101,22 +98,24 @@ class SpatialEncoder3d(nn.Module): r_j \rvert \rvert + \beta_{(i,j)} - \mu^k}{\lvert \sigma^k \rvert} \right) ^2 \right)},k=1,...,K,` - where :math:`K` is the number of Gaussian Basis kernels. - :math:`r_i` is the Cartesian coordinate of atom :math:`i`. - :math:`\gamma_{(i,j)}, \beta_{(i,j)}` are learnable scaling factors of - the Gaussian Basis kernels. + where :math:`K` is the number of Gaussian Basis kernels. :math:`r_i` is the + Cartesian coordinate of node :math:`i`. + :math:`\gamma_{(i,j)}, \beta_{(i,j)}` are learnable scaling factors and + biases determined by node types. :math:`\mu^k, \sigma^k` are learnable + centers and standard deviations of the Gaussian Basis kernels. Parameters ---------- num_kernels : int Number of Gaussian Basis Kernels to be applied. Each Gaussian Basis - Kernel contains a learnable kernel center and a learnable scaling - factor. + Kernel contains a learnable kernel center and a learnable standard + deviation. num_heads : int, optional Number of attention heads if multi-head attention mechanism is applied. Default : 1. max_node_type : int, optional - Maximum number of node types. Default : 100. + Maximum number of node types. Each node type has a corresponding + learnable scaling factor and a bias. Default : 100. Examples -------- @@ -144,8 +143,8 @@ def __init__(self, num_kernels, num_heads=1, max_node_type=100): self.linear_layer_1 = nn.Linear(num_kernels, num_kernels) self.linear_layer_2 = nn.Linear(num_kernels, num_heads) # default mul/bias at position 0 (no node type given) - # src mul/bias at position 1 (pad) ~ max_node_type+1 - # tgt mul/bias at position max_node_type+2 (pad) ~ 2*max_node_type+2 + # src mul/bias at position [1 (pad) ~ max_node_type+1] + # tgt mul/bias at position [max_node_type+2 (pad) ~ 2*max_node_type+2] self.mul = nn.Embedding(2 * max_node_type + 3, 1, padding_idx=0) self.bias = nn.Embedding(2 * max_node_type + 3, 1, padding_idx=0) @@ -159,15 +158,16 @@ def forward(self, coord, node_type=None): Parameters ---------- coord : torch.Tensor - 3D coordinates of nodes in shape :math:`(B, N, 3)`, where :math:`B` is the batch size, :math:`N`: is the maximum number of nodes. + 3D coordinates of nodes in shape :math:`(B, N, 3)`, where :math:`B` + is the batch size, :math:`N`: is the maximum number of nodes. node_type : torch.Tensor, optional Node type ids of nodes. Default : None. * If specified, :attr:`node_type` should be a tensor in shape :math:`(B, N,)`. The scaling factors in gaussian kernels of each pair of nodes are determined by their node types. - * Otherwise, :attr:`node_type` will be set to zeros in the same - shape. + * Otherwise, :attr:`node_type` will be set to zeros of the same + shape by default. Returns -------