* [Model] Update Attention Network
* [Model] Update Attention Network
* [Template] Create PR and issue templates
* [Model] Create model ieHGCN
* [Model] Update model ieHGCN
* [Model] Implement model HGAT
* [Model] Implement HGAT
* [Model] Update Attention Network
* [Model] Update init
* [Model] Update init
* [Model] Update init
* [Model] Update init
* [Docs] Update docs
* [Docs] Update docs
* [Model] Update model
* [Model] Update model
* [Model] Update init
* [Docs] Update docs
* [Model] Update init
* [Docs] Update docs
* [Model] Update init
* [Model] Update init
* [Model] Update init

Co-authored-by: dddg617 <996179900@qq.com>
Co-authored-by: dddg617 <75086617+dddg617@users.noreply.github.com>
Commit 89fd328 (parent ad77ad6): 14 changed files with 1,148 additions and 63 deletions.
@@ -0,0 +1,57 @@
---
name: "\U0001F41B Bug Report"
about: Submit a bug report to help us improve OpenHGNN

---

## 🐛 Bug

<!-- A clear and concise description of what the bug is. -->

## To Reproduce

Steps to reproduce the behavior:

1.
1.
1.

<!-- If you have a code sample, error messages, or stack traces, please provide them here as well -->

## Expected behavior

<!-- A clear and concise description of what you expected to happen. -->

## Environment

- OpenHGNN Version (e.g., 1.0):
- Backend Library & Version (e.g., PyTorch 0.4.1, DGL 0.7.0):
- OS (e.g., Linux):
- Command you ran (e.g., python main.py -m GTN -d imdb4GTN -t node_classification -g 0 --use_best_config):
- Model configuration you used (e.g., details of the model configuration you used in [config.ini](../../openhgnn/config.ini)):
<!--
[HGT]
seed = 0
learning_rate = 0.01
weight_decay = 0.0001
dropout = 0.4
batch_size = 5120
patience = 40
hidden_dim = 64
out_dim = 16
num_layers = 2
num_heads = 2
num_workers = 64
max_epoch = 200
mini_batch_flag = False
norm = True
-->
- Python version:
- CUDA/cuDNN version (if applicable):
- GPU models and configuration (e.g., V100):
- Any other relevant information:

## Additional context

<!-- Add any other context about the problem here. -->
@@ -0,0 +1,26 @@
---
name: "\U0001F680 Feature Request"
about: Submit a proposal/request for a new OpenHGNN feature

---

## 🚀 Feature
<!-- A brief description of the feature proposal -->

## Motivation

<!-- Please outline the motivation for the proposal. Is your feature request
related to a problem? e.g., I'm always frustrated when [...]. If this is
related to another GitHub issue, please link it here too -->

## Alternatives

<!-- A clear and concise description of any alternative solutions or features you've considered, if any. -->

## Pitch

<!-- A clear and concise description of what you want to happen. -->

## Additional context

<!-- Add any other context or screenshots about the feature request here. -->
@@ -0,0 +1,9 @@
---
name: "❓Questions/Help/Support"
about: Do you need support? We have resources.

---

## ❓ Questions and Help

<!-- If you have any questions, please feel free to ask. -->
@@ -0,0 +1,19 @@
## Description
<!-- Brief description. Refer to related issues if they exist.
It'll be great if relevant reviewers can be assigned as well. -->

## Checklist
Please feel free to remove inapplicable items for your PR.
- [ ] The PR title starts with [$CATEGORY] (such as [NN], [Model], [Doc], [Feature])
- [ ] Changes are complete (i.e. I finished coding on this PR)
- [ ] All changes have test coverage
- [ ] Code is well-documented
- [ ] To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change
- [ ] Related issues are referenced in this PR
- [ ] If the PR is for a new model/paper, I've updated the example index [here](../README.md)

## Changes
<!-- You could use the following template
- [ ] Feature 1, tests, (and when applicable, API doc)
- [ ] Feature 2, tests, (and when applicable, API doc)
-->
@@ -0,0 +1,162 @@
import dgl
import torch
import torch.nn as nn
import dgl.function as Fn
import torch.nn.functional as F

from dgl.ops import edge_softmax, segment_softmax
from dgl.nn import HeteroLinear, TypedLinear
from dgl.nn.pytorch.conv import GraphConv
from . import BaseModel, register_model
from ..utils import to_hetero_feat


@register_model('HGAT')
class HGAT(BaseModel):
    @classmethod
    def build_model_from_args(cls, args, hg):
        return cls(args.num_layers,
                   args.in_dim,
                   args.hidden_dim,
                   args.attn_dim,
                   args.num_classes,
                   hg.ntypes,
                   args.negative_slope)

    def __init__(self, num_layers, in_dim, hidden_dim, attn_dim,
                 num_classes, ntypes, negative_slope):
        super(HGAT, self).__init__()
        self.num_layers = num_layers
        self.activation = F.elu
        # attn_dim is kept for config compatibility; the attention layers
        # below take (in_dim, ntypes, slope) and (in_dim, out_dim, slope).

        # Each HGAT layer is a (TypeAttention, NodeAttention) pair, so the
        # ModuleList holds 2 * num_layers modules, matching forward().
        self.hgat_layers = nn.ModuleList()
        self.hgat_layers.append(
            TypeAttention(in_dim,
                          ntypes,
                          negative_slope))
        self.hgat_layers.append(
            NodeAttention(in_dim,
                          hidden_dim,
                          negative_slope)
        )
        # num_layers - 2 intermediate pairs; together with the input pair
        # above and the output pair below this yields num_layers pairs.
        for l in range(num_layers - 2):
            self.hgat_layers.append(
                TypeAttention(hidden_dim,
                              ntypes,
                              negative_slope))
            self.hgat_layers.append(
                NodeAttention(hidden_dim,
                              hidden_dim,
                              negative_slope)
            )

        self.hgat_layers.append(
            TypeAttention(hidden_dim,
                          ntypes,
                          negative_slope))
        self.hgat_layers.append(
            NodeAttention(hidden_dim,
                          num_classes,
                          negative_slope)
        )

    def forward(self, hg, h_dict):
        with hg.local_scope():
            hg.ndata['h'] = h_dict
            for l in range(self.num_layers):
                # Type-level attention yields per-edge coefficients on the
                # heterogeneous graph; node-level attention then consumes
                # them on the homogeneous view.
                attention = self.hgat_layers[2 * l](hg, hg.ndata['h'])
                hg.edata['alpha'] = attention
                g = dgl.to_homogeneous(hg, ndata=['h'], edata=['alpha'])
                h = self.hgat_layers[2 * l + 1](g, g.ndata['h'],
                                                g.ndata['_TYPE'],
                                                g.edata['_TYPE'],
                                                presorted=True)
                h_dict = to_hetero_feat(h, g.ndata['_TYPE'], hg.ntypes)
                hg.ndata['h'] = h_dict

        return h_dict


class TypeAttention(nn.Module):
    def __init__(self, in_dim, ntypes, slope):
        super(TypeAttention, self).__init__()
        attn_vector = {ntype: in_dim for ntype in ntypes}
        self.mu_l = HeteroLinear(attn_vector, in_dim)
        self.mu_r = HeteroLinear(attn_vector, in_dim)
        self.leakyrelu = nn.LeakyReLU(slope)

    def forward(self, hg, h_dict):
        h_t = {}
        attention = {}
        with hg.local_scope():
            hg.ndata['h'] = h_dict
            for srctype, etype, dsttype in hg.canonical_etypes:
                rel_graph = hg[srctype, etype, dsttype]
                if srctype not in h_dict:
                    continue
                with rel_graph.local_scope():
                    # Symmetrically normalized aggregation over srctype
                    # neighbours: the "type embedding" seen by each dst node.
                    degs = rel_graph.out_degrees().float().clamp(min=1)
                    norm = torch.pow(degs, -0.5)
                    feat_src = h_dict[srctype]
                    shp = norm.shape + (1,) * (feat_src.dim() - 1)
                    norm = torch.reshape(norm, shp)
                    feat_src = feat_src * norm
                    rel_graph.srcdata['h'] = feat_src
                    rel_graph.update_all(Fn.copy_src('h', 'm'),
                                         Fn.sum(msg='m', out='h'))
                    rst = rel_graph.dstdata['h']
                    degs = rel_graph.in_degrees().float().clamp(min=1)
                    norm = torch.pow(degs, -0.5)
                    shp = norm.shape + (1,) * (feat_src.dim() - 1)
                    norm = torch.reshape(norm, shp)
                    rst = rst * norm
                    h_t[srctype] = rst
                    # Unnormalized type-level score for each dst node.
                    h_l = self.mu_l(h_dict)[dsttype]
                    h_r = self.mu_r(h_t)[srctype]
                    edge_attention = F.elu(h_l + h_r)
                    rel_graph.ndata['m'] = {
                        dsttype: edge_attention,
                        srctype: torch.zeros((rel_graph.num_nodes(ntype=srctype),))}
                    # Reverse the relation so each dst node's score can be
                    # copied onto its incoming edges.
                    reverse_graph = dgl.reverse(rel_graph)
                    reverse_graph.apply_edges(Fn.copy_src('m', 'alpha'))

                hg.edata['alpha'] = {(srctype, etype, dsttype):
                                     reverse_graph.edata['alpha']}

            # Normalize the scores over each node's incoming edges.
            attention = edge_softmax(hg, hg.edata['alpha'])

        return attention


class NodeAttention(nn.Module):
    def __init__(self, in_dim, out_dim, slope):
        super(NodeAttention, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.Mu_l = nn.Linear(in_dim, in_dim)
        self.Mu_r = nn.Linear(in_dim, in_dim)
        self.leakyrelu = nn.LeakyReLU(slope)

    def forward(self, g, x, ntype, etype, presorted=False):
        # ntype, etype and presorted are accepted for interface
        # compatibility but are not used here.
        with g.local_scope():
            src = g.edges()[0]
            dst = g.edges()[1]
            # Node-level attention, modulated by the type-level scores
            # stored in g.edata['alpha'].
            h_l = self.Mu_l(x)[src]
            h_r = self.Mu_r(x)[dst]
            edge_attention = self.leakyrelu((h_l + h_r) * g.edata['alpha'])
            edge_attention = edge_softmax(g, edge_attention)
            g.edata['alpha'] = edge_attention
            g.srcdata['x'] = x
            g.update_all(Fn.u_mul_e('x', 'alpha', 'm'),
                         Fn.sum('m', 'x'))
            h = g.ndata['x']
        return h
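Reviewer note: the two modules above are easiest to sanity-check in isolation. The following smoke-test sketch is not part of this commit; it assumes TypeAttention and NodeAttention are importable (e.g. from openhgnn.models.HGAT), uses an invented toy two-type graph, and assumes a DGL build in which edge_softmax and HeteroLinear accept heterographs and feature dicts, as the code above already requires. The to_hetero_feat_sketch helper is a hypothetical stand-in for openhgnn's to_hetero_feat, written here only to keep the example self-contained.

import dgl
import torch

# Hypothetical stand-in for openhgnn.utils.to_hetero_feat: split a
# homogeneous feature tensor back into a per-node-type dict using the
# _TYPE vector that dgl.to_homogeneous produces (type IDs follow the
# order of hg.ntypes).
def to_hetero_feat_sketch(h, type_ids, ntypes):
    return {ntype: h[type_ids == i] for i, ntype in enumerate(ntypes)}

# Toy heterograph: 3 authors, 2 papers, edges in both directions.
hg = dgl.heterograph({
    ('author', 'writes', 'paper'):
        (torch.tensor([0, 1, 2]), torch.tensor([0, 0, 1])),
    ('paper', 'written-by', 'author'):
        (torch.tensor([0, 0, 1]), torch.tensor([0, 1, 2])),
})
in_dim = 8
h_dict = {'author': torch.randn(3, in_dim),
          'paper': torch.randn(2, in_dim)}

type_attn = TypeAttention(in_dim, hg.ntypes, slope=0.2)
node_attn = NodeAttention(in_dim, out_dim=in_dim, slope=0.2)

# Mirror one iteration of HGAT.forward: type-level attention on the
# heterograph, then node-level attention on its homogeneous view.
alpha = type_attn(hg, h_dict)          # dict: canonical etype -> edge scores
hg.ndata['h'] = h_dict
hg.edata['alpha'] = alpha
g = dgl.to_homogeneous(hg, ndata=['h'], edata=['alpha'])
h = node_attn(g, g.ndata['h'],
              g.ndata['_TYPE'], g.edata['_TYPE'], presorted=True)
h_dict = to_hetero_feat_sketch(h, g.ndata['_TYPE'], hg.ntypes)
print({k: v.shape for k, v in h_dict.items()})

If the shapes print without errors, the two modules compose the same way HGAT.forward chains them.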