
[NN] Fix GATConv for Broadcasting with Residual Connections #2867

Merged

merged 3 commits into dmlc:master from fix_gat on Apr 27, 2021

Conversation

@mufeili (Member) commented on Apr 25, 2021

Description

Previously, the forward computation of GATConv failed when in_feats=out_feats and residual=True.
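For context, here is a minimal sketch that reproduces the failing configuration. The graph, feature size, and num_heads=4 are illustrative assumptions; only the in_feats=out_feats, residual=True combination comes from this PR.

```python
import dgl
import torch
from dgl.nn import GATConv

# Small graph with self-loops so every node has a nonzero in-degree.
g = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 3])))
feat = torch.randn(g.num_nodes(), 8)

# in_feats == out_feats == 8 with residual=True: before this fix, the
# residual branch tried to view an (N, 8) tensor as (N, num_heads, 8),
# which fails whenever num_heads > 1.
conv = GATConv(in_feats=8, out_feats=8, num_heads=4, residual=True)
out = conv(g, feat)
print(out.shape)  # torch.Size([4, 4, 8]) with the fix applied
```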

Checklist

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [$CATEGORY] (such as [NN], [Model], [Doc], [Feature])
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented
  • To the best of my knowledge, examples are either not affected by this change,
    or have been fixed to be compatible with this change
  • The related issue is referenced in this PR
  • If the PR is for a new model/paper, I've updated the example index here.

@dgl-bot (Collaborator) commented on Apr 25, 2021

To trigger regression tests:

  • @dgl-bot run [instance-type] [which tests] [compare-with-branch];
    For example: @dgl-bot run g4dn.4xlarge all dmlc/master or @dgl-bot run c5.9xlarge kernel,api dmlc/master

@@ -304,7 +304,8 @@ def forward(self, graph, feat, get_attention=False):
         rst = graph.dstdata['ft']
         # residual
         if self.res_fc is not None:
-            resval = self.res_fc(h_dst).view(h_dst.shape[0], self._num_heads, self._out_feats)
+            # Use -1 rather than self._num_heads to handle broadcasting
+            resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
A Member commented on this change:

Could you please explain why this works?

@mufeili (Member, Author) replied:

If we use self._num_heads, view raises an error, because the output of res_fc has only out_feats features per node in this case. If we use -1, the inferred dimension is 1, and a tensor of shape (A, 1, C) can be added to a tensor of shape (A, B, C) via broadcasting.
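A standalone PyTorch sketch of this broadcasting argument (the sizes A=4, B=3, C=8 are made up for illustration, not taken from the PR):

```python
import torch

A, B, C = 4, 3, 8          # nodes, heads, out_feats (illustrative sizes)
h_dst = torch.randn(A, C)  # res_fc output when in_feats == out_feats

# h_dst.view(A, B, C) would raise: the tensor has A*C elements, not A*B*C.
resval = h_dst.view(A, -1, C)  # -1 is inferred as 1 -> shape (A, 1, C)
rst = torch.randn(A, B, C)     # per-head output, shape (A, B, C)
out = rst + resval             # (A, 1, C) broadcasts over the head dimension
print(out.shape)               # torch.Size([4, 3, 8])
```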

@BarclayII merged commit e18c2ab into dmlc:master on Apr 27, 2021
@mufeili deleted the fix_gat branch on April 27, 2021 at 06:47