-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Example] Temporal Graph Neural Network #2636
Conversation
examples/pytorch/tgn/data.py
Outdated
new_df.i += 1 | ||
new_df.idx += 1 | ||
|
||
return new_df |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the code is adapted from original code, please add a commet: https://github.com/dmlc/dgl/blob/master/examples/pytorch/pointcloud/pointnet/pointnet2.py#L10
examples/pytorch/tgn/data.py
Outdated
train_div = int(0.7*num_edges) | ||
valid_div = int(0.85*num_edges) | ||
|
||
self.nn_test_g = dgl.edge_subgraph(self.g,range(valid_div,num_edges)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nn -> new_node
examples/pytorch/tgn/data.py
Outdated
edge_ids = range(self.batch_cnt*self.batch_size,min((self.batch_cnt+1)*self.batch_size,graph.num_edges())) | ||
subgraph = dgl.edge_subgraph(graph,edge_ids) | ||
working_d = DictNode(parent=p_dict,NIDdict=subgraph.ndata[dgl.NID]) | ||
subgraph.ndata[dgl.NID] = torch.from_numpy(working_d.GetRootID(range(subgraph.num_nodes()))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't use dgl.NID. Currently the behavior of dgl.NID is not clearly defined and this is a dgl system setting. After we get clear the behavior of dgl.NID in system level, we may revisit here.
examples/pytorch/tgn/data.py
Outdated
map_index = self.NIDdict[index] | ||
return self.parent.GetRootID(map_index) | ||
|
||
class TemporalDataLoader: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Try to inherit from pytorch dataloader.
examples/pytorch/tgn/tgn.py
Outdated
|
||
# Fake link embedding and compute score | ||
new_neg,neg_subg = self.node_sampler(neg,ts,mode) | ||
n_ts = ts.repeat(neg_subg.num_nodes()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use the correct timestamp.
examples/pytorch/tgn/tgn.py
Outdated
pred_pos = self.linkpredictor(embedding[pos_src_id],embedding[pos_dst_id]) | ||
pred_neg = self.linkpredictor(embedding[pos_src_id],embedding[neg_id]) | ||
subg = dgl.remove_self_loop(subg) | ||
subg = dgl.add_self_loop(subg) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why self_loop?
Add a simple temporal graph sampling method. This method can support k-hop sampling while others can not. This method also achieve good performance and speed.
Description
Checklist
Please feel free to remove inapplicable items for your PR.
or have been fixed to be compatible with this change
Changes