diff --git a/.github/ISSUE_TEMPLATE/bug-report.md b/.github/ISSUE_TEMPLATE/bug-report.md
new file mode 100644
index 00000000..97a2cf14
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug-report.md
@@ -0,0 +1,57 @@
+---
+name: "\U0001F41B Bug Report"
+about: Submit a bug report to help us improve OpenHGNN
+
+---
+
+## 🐛 Bug
+
+
+
+## To Reproduce
+
+Steps to reproduce the behavior:
+
+1.
+1.
+1.
+
+
+
+## Expected behavior
+
+
+
+## Environment
+
+ - OpenHGNN Version (e.g., 1.0):
+ - Backend Library & Version (e.g., PyTorch 0.4.1, DGL 0.7.0):
+ - OS (e.g., Linux):
+ - Running command you used (e.g., python main.py -m GTN -d imdb4GTN -t node_classification -g 0 --use_best_config):
+ - Model configuration you used (e.g., details of the model configuration you used in [config.ini](../../openhgnn/config.ini)):
+
+ - Python version:
+ - CUDA/cuDNN version (if applicable):
+ - GPU models and configuration (e.g. V100):
+ - Any other relevant information:
+
+## Additional context
+
+
diff --git a/.github/ISSUE_TEMPLATE/feature-request.md b/.github/ISSUE_TEMPLATE/feature-request.md
new file mode 100644
index 00000000..4f3b2aee
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature-request.md
@@ -0,0 +1,26 @@
+---
+name: "\U0001F680Feature Request"
+about: Submit a proposal/request for a new OpenHGNN feature
+
+---
+
+## 🚀 Feature
+
+
+## Motivation
+
+
+
+## Alternatives
+
+
+
+## Pitch
+
+
+
+## Additional context
+
+
diff --git a/.github/ISSUE_TEMPLATE/questions-help-support.md b/.github/ISSUE_TEMPLATE/questions-help-support.md
new file mode 100644
index 00000000..5028d831
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/questions-help-support.md
@@ -0,0 +1,9 @@
+---
+name: "❓Questions/Help/Support"
+about: Do you need support? We have resources.
+
+---
+
+## ❓ Questions and Help
+
+
\ No newline at end of file
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
new file mode 100644
index 00000000..b0d2d052
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,19 @@
+## Description
+
+
+## Checklist
+Please feel free to remove inapplicable items for your PR.
+- [ ] The PR title starts with [$CATEGORY] (such as [NN], [Model], [Doc], [Feature])
+- [ ] Changes are complete (i.e. I finished coding on this PR)
+- [ ] All changes have test coverage
+- [ ] Code is well-documented
+- [ ] To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change
+- [ ] Related issue is referenced in this PR
+- [ ] If the PR is for a new model/paper, I've updated the example index [here](../README.md).
+
+## Changes
+
diff --git a/README.md b/README.md
index ee6e4a2d..1a40ac4f 100644
--- a/README.md
+++ b/README.md
@@ -148,10 +148,12 @@ python main.py -m GTN -d imdb4GTN -t node_classification -g 0 --use_best_config
 | [HGSL](./openhgnn/output/HGSL)[AAAI 2021] | :heavy_check_mark: | | |
 | [HGNN-AC](./openhgnn/output/HGNN_AC)[WWW 2021] | :heavy_check_mark: | | |
 | [HeCo](./openhgnn/output/HeCo)[KDD 2021] | :heavy_check_mark: | | |
-| [SimpleHGN](./openhgnn/output/SimpleHGN)[KDD 2021] | :heavy_check_mark: | | |
+| [SimpleHGN](./openhgnn/output/HGT)[KDD 2021] | :heavy_check_mark: | | |
 | [HPN](./openhgnn/output/HPN)[TKDE 2021] | :heavy_check_mark: | :heavy_check_mark: | |
 | [RHGNN](./openhgnn/output/RHGNN)[arxiv] | :heavy_check_mark: | | |
 | [HDE](./openhgnn/output/HDE)[ICDM 2021] | | :heavy_check_mark: | |
+| [HetSANN](./openhgnn/output/HGT)[AAAI 2020] | :heavy_check_mark: | | |
+| [ieHGCN](./openhgnn/output/HGT)[TKDE 2021] | :heavy_check_mark: | | |
 
 ### 候选模型
 
diff --git a/README_EN.md b/README_EN.md
index bab4def4..5390b122 100644
--- a/README_EN.md
+++ b/README_EN.md
@@ -156,6 +156,8 @@ The link will give some basic usage.
 | [HPN](../openhgnn/output/HPN)[TKDE 2021] | :heavy_check_mark: | :heavy_check_mark: | |
 | [RHGNN](../openhgnn/output/RHGNN)[arxiv] | :heavy_check_mark: | | |
 | [HDE](../openhgnn/output/HDE)[ICDM 2021] | | :heavy_check_mark: | |
+| [HetSANN](./openhgnn/output/HGT)[AAAI 2020] | :heavy_check_mark: | | |
+| [ieHGCN](./openhgnn/output/HGT)[TKDE 2021] | :heavy_check_mark: | | |
 
 ### Candidate models
 
diff --git a/openhgnn/config.ini b/openhgnn/config.ini
index ba070ca0..3bcd617b 100644
--- a/openhgnn/config.ini
+++ b/openhgnn/config.ini
@@ -477,6 +477,31 @@ patience = 100
 slope = 0.2
 residual = True
 
+[ieHGCN]
+in_dim = 64
+num_layers = 5
+hidden_dim = 64
+attn_dim = 32
+out_dim = 16
+patience = 100
+seed = 0
+lr = 0.01
+weight_decay = 5e-4
+max_epoch = 350
+
+[HGAT]
+in_dim = 64
+num_layers = 3
+hidden_dim = 64
+attn_dim = 32
+num_classes = 16
+negative_slope = 0.2
+patience = 100
+seed = 0
+lr = 0.01
+weight_decay = 5e-4
+max_epoch = 350
+
 [TransE]
 seed = 0
 patience = 3
diff --git a/openhgnn/config.py b/openhgnn/config.py
index 11fc663e..6912b6c7 100644
--- a/openhgnn/config.py
+++ b/openhgnn/config.py
@@ -557,7 +557,34 @@ def __init__(self, file_path, model, dataset, task, gpu):
             self.residual = conf.getboolean("HetSANN", "residual")
             self.mini_batch_flag = False
             self.hidden_dim = self.h_dim * self.num_heads
-
+            self.mini_batch_flag = False
+            self.hidden_dim = self.h_dim * self.num_heads
+        elif self.model_name == 'ieHGCN':
+            self.weight_decay = conf.getfloat("ieHGCN", "weight_decay")
+            self.lr = conf.getfloat("ieHGCN", "lr")
+            self.max_epoch = conf.getint("ieHGCN", "max_epoch")
+            self.seed = conf.getint("ieHGCN", "seed")
+            self.attn_dim = conf.getint("ieHGCN", "attn_dim")
+            self.num_layers = conf.getint("ieHGCN","num_layers")
+            self.mini_batch_flag = False
+            self.hidden_dim = conf.getint("ieHGCN", "hidden_dim")
+            self.in_dim = conf.getint("ieHGCN", "in_dim")
+            self.out_dim = conf.getint("ieHGCN", "out_dim")
+            self.patience = conf.getint("ieHGCN", "patience")
+        elif self.model_name == 'HGAT':
+            self.weight_decay = conf.getfloat("HGAT", "weight_decay")
+            self.lr = conf.getfloat("HGAT", "lr")
+            self.max_epoch = conf.getint("HGAT", "max_epoch")
+            self.seed = conf.getint("HGAT", "seed")
+            self.attn_dim = conf.getint("HGAT", "attn_dim")
+            self.num_layers = conf.getint("HGAT","num_layers")
+            self.mini_batch_flag = False
+            self.hidden_dim = conf.getint("HGAT", "hidden_dim")
+            self.in_dim = conf.getint("HGAT", "in_dim")
+            self.num_classes = conf.getint("HGAT", "num_classes")
+            self.patience = conf.getint("HGAT", "patience")
+            self.negative_slope = conf.getfloat("HGAT", "negative_slope")
+
         elif self.model_name == 'TransE':
             self.seed = conf.getint("TransE", "seed")
             self.patience = conf.getint("TransE", "patience")
diff --git a/openhgnn/models/HGAT.py b/openhgnn/models/HGAT.py
new file mode 100644
index 00000000..df58c0ac
--- /dev/null
+++ b/openhgnn/models/HGAT.py
@@ -0,0 +1,162 @@
+import dgl
+import torch
+import torch.nn as nn
+import dgl.function as Fn
+import torch.nn.functional as F
+
+from dgl.ops import edge_softmax, segment_softmax
+from dgl.nn import HeteroLinear, TypedLinear
+from dgl.nn.pytorch.conv import GraphConv
+from . import BaseModel, register_model
+from ..utils import to_hetero_feat
+
+@register_model('HGAT')
+class HGAT(BaseModel):
+    @classmethod
+    def build_model_from_args(cls, args, hg):
+        return cls(args.num_layers,
+                   args.in_dim,
+                   args.hidden_dim,
+                   args.attn_dim,
+                   args.num_classes,
+                   hg.ntypes,
+                   args.negative_slope)
+
+    def __init__(self, num_layers, in_dim, hidden_dim, attn_dim,
+                 num_classes, ntypes, negative_slope):
+        super(HGAT, self).__init__()
+        self.num_layers = num_layers
+        self.activation = F.elu
+
+
+        self.hgat_layers = nn.ModuleList()
+        self.hgat_layers.append(
+            TypeAttention(in_dim,
+                          attn_dim,
+                          ntypes,
+                          negative_slope))
+        self.hgat_layers.append(
+            NodeAttention(in_dim,
+                          attn_dim,
+                          hidden_dim,
+                          negative_slope)
+        )
+        for l in range(num_layers - 1):
+            self.hgat_layers.append(
+                TypeAttention(hidden_dim,
+                              attn_dim,
+                              ntypes,
+                              negative_slope))
+            self.hgat_layers.append(
+                NodeAttention(hidden_dim,
+                              attn_dim,
+                              hidden_dim,
+                              negative_slope)
+            )
+
+        self.hgat_layers.append(
+            TypeAttention(hidden_dim,
+                          attn_dim,
+                          ntypes,
+                          negative_slope))
+        self.hgat_layers.append(
+            NodeAttention(hidden_dim,
+                          attn_dim,
+                          num_classes,
+                          negative_slope)
+        )
+
+
+    def forward(self, hg, h_dict):
+        with hg.local_scope():
+            hg.ndata['h'] = h_dict
+            for l in range(self.num_layers):
+                attention = self.hgat_layers[2 * l](hg, hg.ndata['h'])
+                hg.edata['alpha'] = attention
+                g = dgl.to_homogeneous(hg, ndata = 'h', edata = ['alpha'])
+                h = self.hgat_layers[2 * l + 1](g, g.ndata['h'], g.ndata['_TYPE'], g.ndata['_TYPE'], presorted = True)
+                h_dict = to_hetero_feat(h, g.ndata['_TYPE'], hg.ntypes)
+                hg.ndata['h'] = h_dict
+
+        return h_dict
+
+class TypeAttention(nn.Module):
+    def __init__(self, in_dim, attn_dim, ntypes, slope):  # attn_dim: accepted to match the calls in HGAT.__init__, unused here
+        super(TypeAttention, self).__init__()
+        attn_vector = {}
+        for ntype in ntypes:
+            attn_vector[ntype] = in_dim
+        self.mu_l = HeteroLinear(attn_vector, in_dim)
+        self.mu_r = HeteroLinear(attn_vector, in_dim)
+        self.leakyrelu = nn.LeakyReLU(slope)
+
+    def forward(self, hg, h_dict):
+        h_t = {}
+        attention = {}
+        with hg.local_scope():
+            hg.ndata['h'] = h_dict
+            for srctype, etype, dsttype in hg.canonical_etypes:
+                rel_graph = hg[srctype, etype, dsttype]
+                if srctype not in h_dict:
+                    continue
+                with rel_graph.local_scope():
+                    degs = rel_graph.out_degrees().float().clamp(min = 1)
+                    norm = torch.pow(degs, -0.5)
+                    feat_src = h_dict[srctype]
+                    shp = norm.shape + (1,) * (feat_src.dim() - 1)
+                    norm = torch.reshape(norm, shp)
+                    feat_src = feat_src * norm
+                    rel_graph.srcdata['h'] = feat_src
+                    rel_graph.update_all(Fn.copy_src('h', 'm'), Fn.sum(msg='m', out='h'))
+                    rst = rel_graph.dstdata['h']
+                    degs = rel_graph.in_degrees().float().clamp(min=1)
+                    norm = torch.pow(degs, -0.5)
+                    shp = norm.shape + (1,) * (feat_src.dim() - 1)
+                    norm = torch.reshape(norm, shp)
+                    rst = rst * norm
+                    h_t[srctype] = rst
+                    h_l = self.mu_l(h_dict)[dsttype]
+                    h_r = self.mu_r(h_t)[srctype]
+                    edge_attention = F.elu(h_l + h_r)
+                    # edge_attention = F.elu(h_l + h_r).unsqueeze(0)
+                    rel_graph.ndata['m'] = {dsttype: edge_attention,
+                                    srctype: torch.zeros((rel_graph.num_nodes(ntype = srctype),))}
+                    # print(rel_graph.ndata)
+                    reverse_graph = dgl.reverse(rel_graph)
+                    reverse_graph.apply_edges(Fn.copy_src('m', 'alpha'))
+
+                hg.edata['alpha'] = {(srctype, etype, dsttype): reverse_graph.edata['alpha']}
+
+                # if dsttype not in attention.keys():
+                #     attention[dsttype] = edge_attention
+                # else:
+                #     attention[dsttype] = torch.cat((attention[dsttype], edge_attention))
+            attention = edge_softmax(hg, hg.edata['alpha'])
+            # for ntype in hg.dsttypes:
+            #     attention[ntype] = F.softmax(attention[ntype], dim = 0)
+
+        return attention
+
+class NodeAttention(nn.Module):
+    def __init__(self, in_dim, attn_dim, out_dim, slope):  # attn_dim: accepted to match the calls in HGAT.__init__, unused here
+        super(NodeAttention, self).__init__()
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.Mu_l = nn.Linear(in_dim, in_dim)
+        self.Mu_r = nn.Linear(in_dim, in_dim)
+        self.leakyrelu = nn.LeakyReLU(slope)
+
+    def forward(self, g, x, ntype, etype, presorted = False):
+        with g.local_scope():
+            src = g.edges()[0]
+            dst = g.edges()[1]
+            h_l = self.Mu_l(x)[src]
+            h_r = self.Mu_r(x)[dst]
+            edge_attention = self.leakyrelu((h_l + h_r) * g.edata['alpha'])
+            edge_attention = edge_softmax(g, edge_attention)
+            g.edata['alpha'] = edge_attention
+            g.srcdata['x'] = x
+            g.update_all(Fn.u_mul_e('x', 'alpha', 'm'),
+                         Fn.sum('m', 'x'))
+            h = g.ndata['x']
+        return h
\ No newline at end of file
diff --git a/openhgnn/models/HetSANN.py b/openhgnn/models/HetSANN.py
new file mode 100644
index 00000000..e222351f
--- /dev/null
+++ b/openhgnn/models/HetSANN.py
@@ -0,0 +1,313 @@
+import dgl
+import torch
+import torch.nn as nn
+import dgl.function as Fn
+import torch.nn.functional as F
+
+from dgl.ops import edge_softmax
+from dgl.nn import TypedLinear
+from ..utils import to_hetero_feat
+from . import BaseModel, register_model
+
+@register_model('HetSANN')
+class HetSANN(BaseModel):
+    @classmethod
+    def build_model_from_args(cls, args, hg):
+        return cls(
+            args.num_heads,
+            args.num_layers,
+            args.hidden_dim,
+            args.h_dim,
+            args.out_dim,
+            len(hg.etypes),
+            args.dropout,
+            args.slope,
+            args.residual,
+        )
+
+    def __init__(self, num_heads, num_layers, in_dim, hidden_dim,
+                 num_classes, num_etypes, dropout, negative_slope, residual):
+        r"""
+        This is a model HetSANN from `An Attention-Based Graph Neural Network for Heterogeneous Structural Learning
+        `__
+
+        It contains the following part:
+
+        Apply a linear transformation:
+
+        .. math::
+            h^{(l+1, m)}_{\phi(j),i} = W^{(l+1, m)}_{\phi(j),\phi(i)} h^{(l)}_i (1)
+
+        And return the new embeddings.
+
+        You may refer to the paper HetSANN-Section 2.1-Type-aware Attention Layer-(1)
+
+        Aggregation of Neighborhood:
+
+        Computing the attention coefficient:
+
+        .. math::
+            o^{(l+1,m)}_e = \sigma(f^{(l+1,m)}_r(h^{(l+1, m)}_{\phi(j),j}, h^{(l+1, m)}_{\phi(j),i})) (2)
+
+        .. math::
+            f^{(l+1,m)}_r(e) = [h^{(l+1, m)^T}_{\phi(j),j}||h^{(l+1, m)^T}_{\phi(j),i}]a^{(l+1, m)}_r (3)
+
+        .. math::
+            \alpha^{(l+1,m)}_e = exp(o^{(l+1,m)}_e) / \sum_{k\in \varepsilon_j} exp(o^{(l+1,m)}_k) (4)
+
+        Getting new embeddings with multi-head and residual
+
+        .. math::
+            h^{(l + 1, m)}_j = \sigma(\sum_{e = (i,j,r)\in \varepsilon_j} \alpha^{(l+1,m)}_e h^{(l+1, m)}_{\phi(j),i}) (5)
+
+        Multi-heads:
+
+        .. math::
+            h^{(l+1)}_j = \parallel^M_{m = 1}h^{(l + 1, m)}_j (6)
+
+        Residual:
+
+        .. math::
+            h^{(l+1)}_j = h^{(l)}_j + \parallel^M_{m = 1}h^{(l + 1, m)}_j (7)
+
+        Parameters
+        ----------
+        num_heads: int
+            the number of heads in the attention computing
+        num_layers: int
+            the number of layers we used in the computing
+        in_dim: int
+            the input dimension
+        hidden_dim: int
+            the hidden dimension
+        num_classes: int
+            the number of the output classes
+        num_etypes: int
+            the number of the edge types
+        dropout: float
+            the dropout rate
+        negative_slope: float
+            the negative slope used in the LeakyReLU
+        residual: boolean
+            if we need the residual operation
+
+        """
+        super(HetSANN, self).__init__()
+        self.num_heads = num_heads
+        self.num_layers = num_layers
+        # self.dropout = nn.Dropout(dropout)
+        self.residual = residual
+        self.activation = F.elu
+
+        self.het_layers = nn.ModuleList()
+
+        # input projection
+        self.het_layers.append(
+            HetSANNConv(
+                num_heads,
+                in_dim,
+                hidden_dim,
+                num_etypes,
+                dropout,
+                negative_slope,
+                False,
+                self.activation,
+            )
+        )
+
+        # hidden layer
+        for i in range(1, num_layers - 1):
+            self.het_layers.append(
+                HetSANNConv(
+                    num_heads,
+                    hidden_dim * num_heads,
+                    hidden_dim,
+                    num_etypes,
+                    dropout,
+                    negative_slope,
+                    residual,
+                    self.activation
+                )
+            )
+
+        # output projection
+        self.het_layers.append(
+            HetSANNConv(
+                1,
+                hidden_dim * num_heads,
+                num_classes,
+                num_etypes,
+                dropout,
+                negative_slope,
+                residual,
+                None,
+            )
+        )
+
+    def forward(self, hg, h_dict):
+        """
+        The forward part of the HetSANN.
+
+        Parameters
+        ----------
+        hg : object
+            the dgl heterogeneous graph
+        h_dict: dict
+            the feature dict of different node types
+
+        Returns
+        -------
+        dict
+            The embeddings after the output projection.
+        """
+        with hg.local_scope():
+            # input layer and hidden layers
+            hg.ndata['h'] = h_dict
+            g = dgl.to_homogeneous(hg, ndata = 'h')
+            h = g.ndata['h']
+            for i in range(self.num_layers - 1):
+                h = self.het_layers[i](g, h, g.ndata['_TYPE'], g.edata['_TYPE'], True)
+
+            # output layer
+            h = self.het_layers[-1](g, h, g.ndata['_TYPE'], g.edata['_TYPE'], True)
+
+            h_dict = to_hetero_feat(h, g.ndata['_TYPE'], hg.ntypes)
+            # g.ndata['h'] = h
+            # hg = dgl.to_heterogeneous(g, hg.ntypes, hg.etypes)
+            # h_dict = hg.ndata['h']
+
+            # for etype in hg.etypes:
+            #     source = etype.split('-')[0]
+            #     h[source] = self.W_out[etype](h_dict[source])
+            # pre_h = dgl.to_homogeneous(hg, ndata = 'h').ndata['h']
+            # hg.ndata['h'] = h
+            # g = dgl.to_homogeneous(hg, ndata = 'h')
+            # h = self.het_layers[-1](g, pre_h)
+            # hg = dgl.to_heterogeneous(g, hg.ntypes, hg.etypes)
+            # h_dict = hg.ndata['h']
+
+        return h_dict
+
+class HetSANNConv(nn.Module):
+    def __init__(self, num_heads, in_dim, hidden_dim, num_etypes,
+                 dropout, negative_slope, residual, activation):
+        """
+        The HetSANN convolution layer.
+
+        Parameters
+        ----------
+        num_heads: int
+            the number of heads in the attention computing
+        in_dim: int
+            the input dimension of the feature
+        hidden_dim: int
+            the hidden dimension
+        num_etypes: int
+            the number of the edge types
+        dropout: float
+            the dropout rate
+        negative_slope: float
+            the negative slope used in the LeakyReLU
+        residual: boolean
+            if we need the residual operation
+        activation: callable or None
+            the activation function
+        """
+        super(HetSANNConv, self).__init__()
+        self.num_heads = num_heads
+        self.in_dim = in_dim
+        self.hidden_dim = hidden_dim
+
+        self.W = TypedLinear(in_dim, hidden_dim * num_heads, num_etypes)
+        # self.W_out = TypedLinear(hidden_dim * num_heads, num_classes, num_etypes)
+
+        # self.W_hidden = nn.ModuleDict()
+        # self.W_out = nn.ModuleDict()
+
+        # for etype in etypes:
+        #     self.W_hidden[etype] = nn.Linear(in_dim, hidden_dim * num_heads)
+
+        # for etype in etypes:
+        #     self.W_out[etype] = nn.Linear(hidden_dim * num_heads, num_classes)
+
+        self.a_l = TypedLinear(self.hidden_dim, self.hidden_dim, num_etypes)
+        self.a_r = TypedLinear(self.hidden_dim, self.hidden_dim, num_etypes)
+
+        self.dropout = nn.Dropout(dropout)
+        self.leakyrelu = nn.LeakyReLU(negative_slope)
+        if residual:
+            self.residual = nn.Linear(in_dim, hidden_dim * num_heads)
+        else:
+            self.register_buffer("residual", None)
+
+        self.activation = activation
+
+
+    def forward(self, g, x, ntype, etype, presorted = False):
+        """
+        The forward part of the HetSANNConv.
+
+        Parameters
+        ----------
+        g : object
+            the dgl homogeneous graph
+        x: tensor
+            the original features of the graph
+        ntype: tensor
+            the node type of the graph
+        etype: tensor
+            the edge type of the graph
+        presorted: boolean
+            if the ntype and etype are preordered, default: ``False``
+
+        Returns
+        -------
+        tensor
+            The embeddings after aggregation.
+        """
+        # formula (1)
+        feat = self.W(x, ntype, presorted)
+        h = self.dropout(feat)
+        h = h.view(-1, self.num_heads, self.hidden_dim)
+
+        src = g.edges()[0]
+        dst = g.edges()[1]
+
+        # formula (2) (3) (4)
+        h_l = self.a_l(h.view(-1, self.hidden_dim), ntype, presorted) \
+                .view(-1, self.num_heads, self.hidden_dim).sum(dim = -1)[src]
+
+        h_r = self.a_r(h.view(-1, self.hidden_dim), ntype, presorted) \
+                .view(-1, self.num_heads, self.hidden_dim).sum(dim = -1)[dst]
+
+        attention = self.leakyrelu(h_l + h_r)
+        attention = edge_softmax(g, attention)
+
+        # formula (5) (6)
+        with g.local_scope():
+            h = h.permute(0, 2, 1).contiguous()
+            g.edata['alpha'] = attention
+            g.srcdata['emb'] = h
+            g.update_all(Fn.u_mul_e('emb', 'alpha', 'm'),
+                         Fn.sum('m', 'emb'))
+            h_output = g.ndata['emb'].view(-1, self.hidden_dim * self.num_heads)
+
+        # h_prime = []
+        # h = h.permute(1, 0, 2).contiguous()
+        # for i in range(self.num_heads):
+        #     g.edata['alpha'] = attention[:, i]
+        #     g.srcdata.update({'emb': h[i]})
+        #     g.update_all(Fn.u_mul_e('emb', 'alpha', 'm'),
+        #                  Fn.sum('m', 'emb'))
+        #     h_prime.append(g.ndata['emb'])
+        # h_output = torch.cat(h_prime, dim=1)
+
+        # formula (7)
+        if self.residual:
+            res = self.residual(x)
+            h_output += res
+
+        if self.activation is not None:
+            h_output = self.activation(h_output)
+
+        return h_output
\ No newline at end of file
diff --git a/openhgnn/models/SimpleHGN.py b/openhgnn/models/SimpleHGN.py
index c0e4cd6c..ef04718c 100644
--- a/openhgnn/models/SimpleHGN.py
+++ b/openhgnn/models/SimpleHGN.py
@@ -172,36 +172,36 @@ def forward(self, hg, h_dict):
         return h_dict
 
 class SimpleHGNConv(nn.Module):
+    r"""
+    The SimpleHGN convolution layer.
+
+    Parameters
+    ----------
+    edge_dim: int
+        the edge dimension
+    num_etypes: int
+        the number of the edge type
+    in_dim: int
+        the input dimension
+    out_dim: int
+        the output dimension
+    num_heads: int
+        the number of heads
+    num_etypes: int
+        the number of edge type
+    feat_drop: float
+        the feature drop rate
+    negative_slope: float
+        the negative slope used in the LeakyReLU
+    residual: boolean
+        if we need the residual operation
+    activation: str
+        the activation function
+    beta: float
+        the hyperparameter used in edge residual
+    """
     def __init__(self, edge_dim, in_dim, out_dim, num_heads, num_etypes, feat_drop=0.0,
                  negative_slope=0.2, residual=True, activation=F.elu, beta=0.0):
-        """
-        The SimpleHGN convolution layer.
-
-        Parameters
-        ----------
-        edge_dim: int
-            the edge dimension
-        num_etypes: int
-            the number of the edge type
-        in_dim: int
-            the input dimension
-        out_dim: int
-            the output dimension
-        num_heads: int
-            the number of heads
-        num_etypes: int
-            the number of edge type
-        feat_drop: float
-            the feature drop rate
-        negative_slope: float
-            the negative slope used in the LeakyReLU
-        residual: boolean
-            if we need the residual operation
-        activation: str
-            the activation function
-        beta: float
-            the hyperparameter used in edge residual
-        """
         super(SimpleHGNConv, self).__init__()
         self.edge_dim = edge_dim
         self.in_dim = in_dim
diff --git a/openhgnn/models/__init__.py b/openhgnn/models/__init__.py
index 97f5332e..0419062c 100644
--- a/openhgnn/models/__init__.py
+++ b/openhgnn/models/__init__.py
@@ -89,6 +89,7 @@ def build_model_from_args(args, hg):
     'GATNE-T': 'openhgnn.models.GATNE',
     'HetSANN': 'openhgnn.models.HetSANN',
     'HGAT': 'openhgnn.models.HGAT',
+    'ieHGCN': 'openhgnn.models.ieHGCN',
     'TransE': 'openhgnn.models.TransE',
     'TransH': 'openhgnn.models.TransH',
     'TransR': 'openhgnn.models.TransR',
@@ -120,6 +121,9 @@ def build_model_from_args(args, hg):
 from .general_HGNN import general_HGNN
 from .HDE import HDE
 from .SimpleHGN import SimpleHGN
+from .HetSANN import HetSANN
+from .ieHGCN import ieHGCN
+from .HGAT import HGAT
 from .GATNE import GATNE
 
 __all__ = [
diff --git a/openhgnn/models/ieHGCN.py b/openhgnn/models/ieHGCN.py
new file mode 100644
index 00000000..43cbeb08
--- /dev/null
+++ b/openhgnn/models/ieHGCN.py
@@ -0,0 +1,276 @@
+import dgl
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import dgl.nn.pytorch as dglnn
+import numpy as np
+
+from . import BaseModel, register_model
+from ..utils import to_hetero_feat
+
+import sys
+
+
+@register_model('ieHGCN')
+class ieHGCN(BaseModel):
+    r"""
+    Description
+    -----------
+    ie-HGCN from paper `Interpretable and Efficient Heterogeneous Graph Convolutional Network
+    `__.
+
+    `Source Code Link `_
+
+    Description
+    -----------
+    The core part of ie-HGCN, the calculating flow of projection, object-level aggregation and type-level aggregation in
+    a specific type block.
+
+    Projection
+    .. math::
+        Y^{Self-\Omega }=H^{\Omega} \cdot W^{Self-\Omega} (1)-1
+
+        Y^{\Gamma - \Omega}=H^{\Gamma} \cdot W^{\Gamma - \Omega} , \Gamma \in N_{\Omega} (1)-2
+
+    Object-level Aggregation
+    .. math::
+        Z^{ Self - \Omega } = Y^{ Self - \Omega}=H^{\Omega} \cdot W^{Self - \Omega} (2)-1
+
+        Z^{\Gamma - \Omega}=\hat{A}^{\Omega-\Gamma} \cdot Y^{\Gamma - \Omega} = \hat{A}^{\Omega-\Gamma} \cdot H^{\Gamma} \cdot W^{\Gamma - \Omega} (2)-2
+
+    Type-level Aggregation
+    .. math::
+        Q^{\Omega}=Z^{Self-\Omega} \cdot W_q^{\Omega} (3)-1
+
+        K^{Self-\Omega}=Z^{Self -\Omega} \cdot W_{k}^{\Omega} (3)-2
+
+        K^{\Gamma - \Omega}=Z^{\Gamma - \Omega} \cdot W_{k}^{\Omega}, \quad \Gamma \in N_{\Omega} (3)-3
+
+    .. math::
+        e^{Self-\Omega}={ELU} ([K^{ Self-\Omega} \| Q^{\Omega}] \cdot w_{a}^{\Omega}) (4)-1
+
+        e^{\Gamma - \Omega}={ELU} ([K^{\Gamma - \Omega} \| Q^{\Omega}] \cdot w_{a}^{\Omega}), \Gamma \in N_{\Omega} (4)-2
+
+    .. math::
+        [a^{Self-\Omega}\|a^{1 - \Omega}\| \ldots \| a^{\Gamma - \Omega}\|\ldots\| a^{|N_{\Omega}| - \Omega}]=
+        {softmax}([e^{Self - \Omega}\|e^{1 - \Omega}\| \ldots\|e^{\Gamma - \Omega}\| \ldots \| e^{|N_{\Omega}| - \Omega}]) (5)
+
+    .. math::
+        H_{i,:}^{\Omega \prime}=\sigma(a_{i}^{Self-\Omega} \cdot Z_{i,:}^{Self-\Omega}+\sum_{\Gamma \in N_{\Omega}} a_{i}^{\Gamma - \Omega} \cdot Z_{i,:}^{\Gamma - \Omega}) (6)
+
+    Parameters
+    ----------
+    num_layers: int
+        the number of layers
+    in_dim: int
+        the input dimension
+    hidden_dim: int
+        the hidden dimension
+    out_dim: int
+        the output dimension
+    attn_dim: int
+        the dimension of attention vector
+    ntypes: list
+        the node type of a heterogeneous graph
+    etypes: list
+        the edge type of a heterogeneous graph
+    """
+    @classmethod
+    def build_model_from_args(cls, args, hg:dgl.DGLGraph):
+        return cls(args.num_layers,
+                   args.in_dim,
+                   args.hidden_dim,
+                   args.out_dim,
+                   args.attn_dim,
+                   hg.ntypes,
+                   hg.etypes
+                   )
+
+    def __init__(self, num_layers, in_dim, hidden_dim, out_dim, attn_dim, ntypes, etypes):
+        super(ieHGCN, self).__init__()
+        self.num_layers = num_layers
+        self.activation = F.elu
+        self.hgcn_layers = nn.ModuleList()
+
+        self.hgcn_layers.append(
+            ieHGCNConv(
+                in_dim,
+                hidden_dim,
+                attn_dim,
+                ntypes,
+                etypes,
+                self.activation,
+            )
+        )
+
+        for i in range(1, num_layers - 1):
+            self.hgcn_layers.append(
+                ieHGCNConv(
+                    hidden_dim,
+                    hidden_dim,
+                    attn_dim,
+                    ntypes,
+                    etypes,
+                    self.activation
+                )
+            )
+
+        self.hgcn_layers.append(
+            ieHGCNConv(
+                hidden_dim,
+                out_dim,
+                attn_dim,
+                ntypes,
+                etypes,
+                None,
+            )
+        )
+
+    def forward(self, hg, h_dict):
+        """
+        The forward part of the ieHGCN.
+
+        Parameters
+        ----------
+        hg : object
+            the dgl heterogeneous graph
+        h_dict: dict
+            the feature dict of different node types
+
+        Returns
+        -------
+        dict
+            The embeddings after the output projection.
+        """
+        with hg.local_scope():
+            hg.ndata['h'] = h_dict
+            for l in range(self.num_layers):
+                h_dict = self.hgcn_layers[l](hg, h_dict)
+
+        return h_dict
+
+class ieHGCNConv(nn.Module):
+    r"""
+    The ieHGCN convolution layer.
+
+    Parameters
+    ----------
+    in_size: int
+        the input dimension
+    out_size: int
+        the output dimension
+    attn_size: int
+        the dimension of attention vector
+    ntypes: list
+        the node type list of a heterogeneous graph
+    etypes: list
+        the edge type list of a heterogeneous graph
+    activation: callable or None
+        the activation function
+    """
+    def __init__(self, in_size, out_size, attn_size, ntypes, etypes, activation = F.elu):
+        super(ieHGCNConv, self).__init__()
+        node_size = {}
+        for ntype in ntypes:
+            node_size[ntype] = in_size
+        attn_vector = {}
+        for ntype in ntypes:
+            attn_vector[ntype] = attn_size
+        self.W_self = dglnn.HeteroLinear(node_size, out_size)
+        self.W_al = dglnn.HeteroLinear(attn_vector, 1)
+        self.W_ar = dglnn.HeteroLinear(attn_vector, 1)
+
+        # self.conv = dglnn.HeteroGraphConv({
+        #     etype: dglnn.GraphConv(in_size, out_size, norm = 'right', weight = True, bias = True)
+        #     for etype in etypes
+        # })
+        self.in_size = in_size
+        self.out_size = out_size
+        self.attn_size = attn_size
+        mods = {
+            etype: dglnn.GraphConv(in_size, out_size, norm = 'right',
+                                   weight = True, bias = True, allow_zero_in_degree = True)
+            for etype in etypes
+        }
+        self.mods = nn.ModuleDict(mods)
+
+        self.linear_q = nn.ModuleDict({ntype: nn.Linear(out_size, attn_size) for ntype in ntypes})
+        self.linear_k = nn.ModuleDict({ntype: nn.Linear(out_size, attn_size) for ntype in ntypes})
+
+        self.activation = activation
+
+
+    def forward(self, hg, h_dict):
+        """
+        The forward part of the ieHGCNConv.
+
+        Parameters
+        ----------
+        hg : object
+            the dgl heterogeneous graph
+        h_dict: dict
+            the feature dict of different node types
+
+        Returns
+        -------
+        dict
+            The embeddings after final aggregation.
+ """ + outputs = {ntype: [] for ntype in hg.dsttypes} + with hg.local_scope(): + hg.ndata['h'] = h_dict + # formulas (2)-1 + hg.ndata['z'] = self.W_self(hg.ndata['h']) + query = {} + key = {} + attn = {} + attention = {} + + # formulas (3)-1 and (3)-2 + for ntype in hg.dsttypes: + query[ntype] = self.linear_q[ntype](hg.ndata['z'][ntype]) + key[ntype] = self.linear_k[ntype](hg.ndata['z'][ntype]) + # formulas (4)-1 + h_l = self.W_al(key) + h_r = self.W_ar(query) + for ntype in hg.dsttypes: + attention[ntype] = F.elu(h_l[ntype] + h_r[ntype]) + attention[ntype] = attention[ntype].unsqueeze(0) + + for srctype, etype, dsttype in hg.canonical_etypes: + rel_graph = hg[srctype, etype, dsttype] + if srctype not in h_dict: + continue + # formulas (2)-2 + dstdata = self.mods[etype]( + rel_graph, + (h_dict[srctype], h_dict[dsttype]) + ) + outputs[dsttype].append(dstdata) + # formulas (3)-3 + attn[dsttype] = self.linear_k[dsttype](dstdata) + # formulas (4)-2 + h_attn = self.W_al(attn) + attn.clear() + edge_attention = F.elu(h_attn[dsttype] + h_r[dsttype]) + attention[dsttype] = torch.cat((attention[dsttype], edge_attention.unsqueeze(0))) + + # formulas (5) + for ntype in hg.dsttypes: + attention[ntype] = F.softmax(attention[ntype], dim = 0) + + # formulas (6) + rst = {ntype: 0 for ntype in hg.dsttypes} + for ntype, data in outputs.items(): + data = [hg.ndata['z'][ntype]] + data + if len(data) != 0: + for i in range(len(data)): + aggregation = torch.mul(data[i], attention[ntype][i]) + rst[ntype] = aggregation + rst[ntype] + + # h = self.conv(hg, hg.ndata['h'], aggregate = self.my_agg_func) + if self.activation is not None: + for ntype in rst.keys(): + rst[ntype] = self.activation(rst[ntype]) + + return rst \ No newline at end of file diff --git a/openhgnn/output/HGT/README.md b/openhgnn/output/HGT/README.md index 4e32c147..1fe7621d 100644 --- a/openhgnn/output/HGT/README.md +++ b/openhgnn/output/HGT/README.md @@ -1,44 +1,147 @@ -# HGT[WWW 2020] +# Attention Network -- paper: 
[Heterogeneous Graph Transformer](https://arxiv.org/abs/2003.01332) +| Model| Paper| +|:-----:|:-----:| +|HGT(WWW 2019)| [Heterogeneous Graph Transformer](https://arxiv.org/abs/2003.01332)| +|SimpleHGN(KDD 2021)|[Are we really making much progress? Revisiting, benchmarking,and refining heterogeneous graph neural networks](https://dl.acm.org/doi/pdf/10.1145/3447548.3467350)| +|HetSANN(AAAI 2020)|[An Attention-Based Graph Neural Network for Heterogeneous Structural Learning](https://arxiv.org/abs/1912.10832)| +|ieHGCN(TKDE 2021)|[Interpretable and Efficient Heterogeneous Graph Convolutional Network](https://arxiv.org/pdf/2005.13183.pdf)| -## Basic Idea +## Attention mechanism +This part, we will give the definition of attention methanism based on **GAT** and **Transformer**. -- The model designed node-type and edge-type dependent parameters to characterize the heterogeneous attention over each edge, empowering -HGT to maintain dedicated representations for different types of nodes and edges. -- At each layer, Compute a multi-head attention score for each edge $(s, e, t)$ in the graph: +- In [GAT](https://arxiv.org/abs/1710.10903), it defined the attentional mechanism. A shared linear transformation, parametrized by a weight matrix, $W\in\mathcal{R}^{F^{'}\times F}$, is applied to every node. Then use a shared attentional mechanism $a: \mathcal{R}^{F^{'}}\times \mathcal{R}^{F}\rightarrow \mathcal{R}$ to compute *attention coefficients*: $$ -Attention(s, e, t) = \text{Softmax}\left(||_{i\in[1,h]}ATT-head^i(s, e, t)\right) \\ -ATT-head^i(s, e, t) = \left(K^i(s)W^{ATT}_{\phi(e)}Q^i(t)^{\top}\right)\cdot -\frac{\mu_{(\tau(s),\phi(e),\tau(t)}}{\sqrt{d}} \\ -K^i(s) = \text{K-Linear}^i_{\tau(s)}(H^{(l-1)}[s]) \\ -Q^i(t) = \text{Q-Linear}^i_{\tau(t)}(H^{(l-1)}[t]) \\ +e_{ij} = a(Wh_i, Wh_j) $$ -- Compute the message to send on each edge $(s, e, t)$ : +- this indicate the importance of node $j$'s features to node $i$. $a$ is a single-layer feedforward neural network. 
Finally we can normalize them across all choices of $j$ using the softmax function: $$ -Message(s, e, t) = ||_{i\in[1, h]} MSG-head^i(s, e, t) \\ -MSG-head^i(s, e, t) = \text{M-Linear}^i_{\tau(s)}(H^{(l-1)}[s])W^{MSG}_{\phi(e)} \\ +\alpha_{ij} = softmax_j(e_{ij}) = \frac{\text{exp}(e_{ij})}{\sum_{k\in \mathcal{N}_i} \text{exp}(e_{ik})} $$ -- Send messages to target nodes $t$ and aggregate: + +- In [Transformer](https://arxiv.org/abs/1706.03762), an attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key. e.g. Scaled Dot-Product Attention: $$ -\tilde{H}^{(l)}[t] = \sum_{\forall s\in \mathcal{N}(t)}\left( Attention(s,e,t) -\cdot Message(s,e,t)\right) -$$ -- Compute new node features: -$$ -H^{(l)}[t]=\text{A-Linear}_{\tau(t)}(\sigma(\tilde(H)^{(l)}[t])) + H^{(l-1)}[t] +Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V $$ +## DGL API +This part, we will give DGL API we used. As DGL released 0.8.0 version, more API can support heterogeneous graph such as TypedLinear, HeteroLinear. So we will give some details of these APIs. + +### [TypedLinear](https://docs.dgl.ai/generated/dgl.nn.pytorch.TypedLinear.html) + +```python +class TypedLinear(in_size, out_size, num_types, regularizer=None, num_bases=None) +``` +Apply linear transformation according to types. + +Parameters: +- in_size(int): Input feature size. +- out_size(int): Output feature size. +- num_types(int): Number of types(node or edge). +- regularizer(str, optional): Which weight regularizer to use “basis” or “bdd”, default is **None**: + + - basis: basis-decomposition. + - bdd: block-diagonal-decomposition. + +- num_bases(int, optional): Number of bases. Needed when **regularizer** is specified. Typically smaller than **num_types**. 
Default: **None**.
+
+```python
+forward(x, x_type, sorted_by_type=False)
+```
+Parameters:
+- x(tensor): The input tensor.
+- x_type(tensor): 1D tensor storing the type of each element in **x**.
+- sorted_by_type(boolean): Whether the inputs have been sorted by the types. Forward on pre-sorted inputs may be faster.
+
+This API is useful after **to_homogeneous** has converted a heterogeneous graph to a homogeneous graph.
+
+### [HeteroLinear](https://docs.dgl.ai/generated/dgl.nn.pytorch.HeteroLinear.html)
+
+```python
+class HeteroLinear(in_size, out_size, bias=True)
+```
+Apply linear transformations on heterogeneous inputs.
+
+Parameters:
+- in_size(dict[key, int]): Input feature size for heterogeneous inputs. A key can be a string or a tuple of strings.
+- out_size(int): Output feature size.
+- bias(boolean): If **True**, learns a bias term.
+
+```python
+forward(feat)
+```
+Parameters:
+- feat(dict[key, tensor]): Heterogeneous input features.
+
+This API is useful when we want to apply different linear transformations to different types.
+
+### [HeteroGraphConv](https://docs.dgl.ai/generated/dgl.nn.pytorch.HeteroGraphConv.html)
+
+```python
+class HeteroGraphConv(mods, aggregate='sum')
+```
+The heterograph convolution applies sub-modules on their associated relation graphs, reading the features from source nodes and writing the updated ones to destination nodes. If multiple relations have the same destination node type, their results are aggregated by the specified method. If a relation graph has no edge, the corresponding module will not be called.
+
+Parameters:
+- mods(dict[str, nn.Module]): Modules associated with every edge type.
+- aggregate(str, callable, optional): Method for aggregating node features generated by different relations. Allowed string values are ‘sum’, ‘max’, ‘min’, ‘mean’, ‘stack’. Users can also customize the aggregator by providing a callable instance.
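To make the cross-relation aggregation described above concrete, here is a minimal pure-Python sketch of what happens when several relations write to the same destination node type. It does not use DGL and is not DGL's implementation: `aggregate_relations`, the toy relation names, and the list-of-lists feature layout are all assumptions for illustration, and only the 'sum', 'max', 'min', and 'mean' reducers are covered ('stack' and custom callables are left out).

```python
def aggregate_relations(outputs, aggregate="sum"):
    """Combine per-relation outputs that target the same destination node type.

    outputs: dict mapping (src_type, relation, dst_type) to a list of feature
             vectors, one per destination node (hypothetical layout).
    """
    reducers = {
        "sum": sum,
        "max": max,
        "min": min,
        "mean": lambda values: sum(values) / len(values),
    }
    reduce_fn = reducers[aggregate]

    # Group the per-relation results by destination node type.
    by_dst = {}
    for (_, _, dst_type), feats in outputs.items():
        by_dst.setdefault(dst_type, []).append(feats)

    result = {}
    for dst_type, feat_list in by_dst.items():
        # Outer zip lines up the same node across relations,
        # inner zip lines up the same feature dimension.
        result[dst_type] = [
            [reduce_fn(values) for values in zip(*node_rows)]
            for node_rows in zip(*feat_list)
        ]
    return result


# Two relations end at "paper", one ends at "author" (toy data).
outputs = {
    ("author", "writes", "paper"): [[1.0, 2.0], [3.0, 4.0]],
    ("paper", "cites", "paper"): [[10.0, 20.0], [30.0, 40.0]],
    ("paper", "written_by", "author"): [[5.0, 6.0]],
}
summed = aggregate_relations(outputs, "sum")
averaged = aggregate_relations(outputs, "mean")
```

With `aggregate="sum"`, the two relations ending at `paper` are added element-wise, while `author` keeps the single result it received.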
+
+```python
+forward(g, inputs, mod_args=None, mod_kwargs=None)
+```
+Parameters:
+- g(DGLHeteroGraph): Graph data.
+- inputs(dict[str, Tensor] or pair of dict[str, Tensor]): Input node features.
+
+This API is useful when we need to extract relation subgraphs and apply an nn.Module to each subgraph.
+
+## Typical models
+
+Based on HeteroGraphConv, we divide the attention models into two categories: Direct-Aggregation models and Dual-Aggregation models.
+### Direct-Aggregation models
+
+| Model| Attention coefficient |
+|:-----:|:-----:|
+|HGT|$W_{Q_{\phi{(s)}}}h_s W^{ATT}_{\psi{(r)}}(W_{K_{\phi{(t)}}}h_t)^T$|
+|SimpleHGN|$\text{LeakyReLU}(a^T[Wh_s \parallel Wh_t \parallel W_r r_{\psi()}])$|
+|HetSANN|$\text{LeakyReLU}([W_{\phi(t),\phi(s)} h_s\parallel W_{\phi(t),\phi(s)} h_t]a_r)$|
+
+These models have only one aggregation process and do not distinguish between edge types when aggregating, so they are not suitable for HeteroGraphConv.
+
+
+### Dual-Aggregation model
+
+#### ieHGCN
+
+| Model| Attention coefficient |
+|:-----:|:-----:|
+|ieHGCN|$\text{ELU}(a^T_{\phi(s)}[W_{Q_{\phi(s)}}h_s\parallel W_{K_{\phi(t)}}h_t])$|
+
+This model has two aggregation processes and distinguishes between edge types when aggregating, so it is suitable for HeteroGraphConv.
+## Implementation Details
+
+### Direct-Aggregation models
+- We first implement the convolution layers of SimpleHGN and HetSANN. For HGT, we use the [HGTConv](https://docs.dgl.ai/generated/dgl.nn.pytorch.conv.HGTConv.html?highlight=hgtconv#dgl.nn.pytorch.conv.HGTConv) layer provided by DGL. The **\_\_init\_\_** parameters differ because the models need different parameters.
+The parameters of **forward** are the same for all of them: `g` is the homogeneous graph, `h` is the node features, `ntype` gives the type of each node, `etype` gives the type of each edge, and `presorted` tells whether `ntype` or `etype` is already sorted, so that [TypedLinear](https://docs.dgl.ai/generated/dgl.nn.pytorch.TypedLinear.html) in **dgl.nn** can be used conveniently. If we use [dgl.to_homogeneous](https://docs.dgl.ai/generated/dgl.to_homogeneous.html?highlight=to_homogeneous#dgl.to_homogeneous) to get the features, the features are presorted.
+
+- Then, we use the convolution layers to implement the corresponding models. We need [dgl.to_homogeneous](https://docs.dgl.ai/generated/dgl.to_homogeneous.html?highlight=to_homogeneous#dgl.to_homogeneous) to get a homogeneous graph because, when we use [edge_softmax](https://docs.dgl.ai/generated/dgl.nn.functional.edge_softmax.html?highlight=edge_softmax), we put all the edges together to calculate the attention coefficients instead of distinguishing between edge types.
+- After the convolution layers, we need to convert the output features back to a feature dictionary for the heterogeneous graph. For this we wrote a helper named **to_hetero_feat** in **openhgnn.utils.utils.py**, because we have not found a better way to get a feature dictionary with **dgl**. The only alternative is [dgl.to_heterogeneous](https://docs.dgl.ai/generated/dgl.to_heterogeneous.html), but its many additional operations slow the program down. Once we have the feature dictionary, the model is complete.
+
+### Dual-Aggregation model
+
+- We follow the idea behind the implementation of [dgl.nn.HeteroGraphConv](https://docs.dgl.ai/generated/dgl.nn.pytorch.HeteroGraphConv.html?highlight=heterographconv#dgl.nn.pytorch.HeteroGraphConv). We extract the relation subgraph for each edge type and complete the first aggregation using the convolution layers. Then, to aggregate type-specific features across different relations, we compute the attention coefficients step by step.
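The two Direct-Aggregation steps above (normalizing attention over all incoming edges regardless of edge type, then splitting the homogeneous output back into a per-type dictionary) can be sketched in plain Python. This is an illustration under assumptions, not the OpenHGNN or DGL code: `edge_softmax_sketch` and `to_hetero_feat_sketch` are hypothetical names, features are plain lists rather than tensors, and the real `to_hetero_feat` signature may differ.

```python
import math
from collections import defaultdict


def edge_softmax_sketch(dst, scores):
    """Softmax over edge scores, grouped by destination node only.

    Edge types are deliberately ignored: every edge pointing at the same
    node competes in one normalization, as in the Direct-Aggregation models.
    """
    # Subtract the per-destination maximum for numerical stability.
    max_per_dst = {}
    for d, s in zip(dst, scores):
        max_per_dst[d] = max(s, max_per_dst.get(d, s))
    exp_scores = [math.exp(s - max_per_dst[d]) for d, s in zip(dst, scores)]

    denom = defaultdict(float)
    for d, e in zip(dst, exp_scores):
        denom[d] += e
    return [e / denom[d] for d, e in zip(dst, exp_scores)]


def to_hetero_feat_sketch(h, ntype_ids, ntype_names):
    """Split homogeneous node features into a dict keyed by node type name."""
    h_dict = {name: [] for name in ntype_names}
    for row, type_id in zip(h, ntype_ids):
        h_dict[ntype_names[type_id]].append(row)
    return h_dict


# Three edges (of possibly different types) end at node 0, one at node 1.
alpha = edge_softmax_sketch(dst=[0, 0, 0, 1], scores=[1.0, 2.0, 3.0, 0.5])

# Nodes 0 and 2 are authors, node 1 is a paper (toy data).
feats = to_hetero_feat_sketch([[0.1], [0.2], [0.3]], [0, 1, 0], ["author", "paper"])
```

The coefficients of the three edges ending at node 0 sum to 1, and the single edge ending at node 1 receives a coefficient of exactly 1.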
+
 ## How to run

 - Clone the Openhgnn-DGL

 ```bash
 # For node classification task
+ # You may select model HGT, SimpleHGN, HetSANN
 python main.py -m HGT -t node_classification -d imdb4MAGNN -g 0 --use_best_config
 ```

@@ -48,21 +151,81 @@ $$

 #### Task: Node classification

-Evaluation metric: accuracy
-
-| Dataset | HGBn-ACM | acm4GTN | imdb4MAGNN | dblp4MAGNN |
-| -------- | ----- | ----- | ----- | ----- |
-| Macro_f1 | 89.18 | 90.24 | 49.18 | 85.35 |
-| Micro_f1 | 88.95 | 90.21 | 49.37 | 87.20 |
-
-
-
-## TrainerFlow: [node classification flow](../../trainerflow/#Node_classification_flow)
+Evaluation metric: Micro/Macro-F1
+
+<table>
+<tr><th></th><th colspan="2">HGBn-ACM</th><th colspan="2">acm4GTN</th><th colspan="2">imdb4MAGNN</th><th colspan="2">dblp4MAGNN</th></tr>
+<tr><th>Model</th><th>Micro-F1</th><th>Macro-F1</th><th>Micro-F1</th><th>Macro-F1</th><th>Micro-F1</th><th>Macro-F1</th><th>Micro-F1</th><th>Macro-F1</th></tr>
+<tr><td>HGT</td><td>88.95</td><td>89.18</td><td>90.21</td><td>90.24</td><td>49.37</td><td>49.18</td><td>87.23</td><td>86.46</td></tr>
+<tr><td>SimpleHGN</td><td>92.27</td><td>92.36</td><td>89.27</td><td>89.28</td><td>52.25</td><td>48.78</td><td>87.72</td><td>87.08</td></tr>
+<tr><td>HetSANN</td><td>88.4</td><td>88.7</td><td>92.24</td><td>92.31</td><td>52.88</td><td>47.44</td><td>89.54</td><td>90.24</td></tr>
+<tr><td>ie-HGCN</td><td>91.71</td><td>91.99</td><td>92.47</td><td>92.56</td><td>55.03</td><td>52.18</td><td>88.36</td><td>87.37</td></tr>
+</table>
+
+## TrainerFlow: [node classification flow](../../trainerflow/node_classification.py)

 ## Hyper-parameter specific to the model

-You can modify the parameters[HGT] in openhgnn/config.ini.
-
+You can modify the parameters in the [HGT], [SimpleHGN], [HetSANN], and [ieHGCN] sections of openhgnn/config.ini.

 ## More

 #### Contributor