
heterograph.set_batch_num_edges could run automatically if batch_num_nodes is set #7498

Open
anlutfi opened this issue Jul 2, 2024 · 0 comments
Labels
feature request Feature request


anlutfi commented Jul 2, 2024

Since there are no edges between the individual graphs of a graph batch (GB), batch_num_edges could be calculated automatically once batch_num_nodes has been set for the GB, for example by calling set_batch_num_edges with no arguments, or with something like set_batch_num_edges(auto=True).

Motivation

I needed to make subgraphs of a GB while keeping the batch information consistent. I realized that, once batch_num_nodes has been calculated for the subgraph, the edges of each graph are exactly the edges whose source and destination both fall in the same group of batch_num_nodes.

Example:
Graph g is a new subgraph of a GB with batch_num_nodes = [100, 100, 100].
To get the node IDs of each individual graph in g, we take the cumulative sum (CS), here CS = [100, 200, 300]. Nodes with index < 100 are in the first graph, those with 100 <= index < 200 are in the second, and those with 200 <= index < 300 are in the third.

With these index ranges in hand, and the guarantee that there are no edges between nodes of different graphs in a batch, one can simply look at the source and destination of each edge to determine which graph it belongs to. So batch_num_edges comes for free.

I believe this feature would be a good quality-of-life (QOL) improvement, as it removes one source of user error when calculating batch_num_edges by hand.

Code that I'm using

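import torch

# sg: a subgraph of a batched graph whose batch_num_nodes has already been set
bnn = sg.batch_num_nodes()

# Inclusive node-id range [e_head[i], e_tail[i]] of graph i in the batch
e_tail = torch.cumsum(bnn, dim=0) - 1
e_head = torch.cat([e_tail.new_zeros(1), e_tail[:-1] + 1])

source, dest = sg.edges()
# Shape (E, 1): broadcasting against the (B,) range tensors yields an (E, B) mask
source = source.unsqueeze(1)
dest = dest.unsqueeze(1)
mask = (source >= e_head) & (source <= e_tail) & (dest >= e_head) & (dest <= e_tail)
# Column i counts the edges whose endpoints both lie in graph i
bne = torch.count_nonzero(mask, dim=0)
sg.set_batch_num_edges(bne)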
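Since the title mentions heterograph.set_batch_num_edges, here is a sketch of how the requested automatic behaviour might look for graphs with several node and edge types. `set_batch_num_edges_auto` is a hypothetical name, not part of DGL. It assumes the DGLGraph methods `canonical_etypes`, `batch_num_nodes(ntype)`, `edges(etype=...)`, and that `set_batch_num_edges` accepts a dict keyed by canonical edge type; that last point is my reading of the DGL docs and should be checked against the installed version.

```python
import torch


def set_batch_num_edges_auto(g):
    """Hypothetical stand-in for the requested set_batch_num_edges(auto=True).

    Assumes g exposes DGL-style canonical_etypes, batch_num_nodes(ntype),
    edges(etype=...) and set_batch_num_edges(dict), and that no edge
    crosses graph boundaries.
    """
    bne = {}
    for cetype in g.canonical_etypes:
        srctype, _, dsttype = cetype
        src, dst = g.edges(etype=cetype)
        # Source ids live in srctype's id space, destination ids in dsttype's.
        src_bnn = g.batch_num_nodes(srctype)
        dst_bnn = g.batch_num_nodes(dsttype)
        src_graph = torch.bucketize(src, torch.cumsum(src_bnn, 0), right=True)
        dst_graph = torch.bucketize(dst, torch.cumsum(dst_bnn, 0), right=True)
        if not torch.equal(src_graph, dst_graph):
            raise ValueError(f"edge type {cetype} has edges crossing graphs")
        # One count per graph in the batch, including graphs with no edges.
        bne[cetype] = torch.bincount(src_graph, minlength=src_bnn.numel())
    g.set_batch_num_edges(bne)
```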
rudongyu added the feature request label on Jul 4, 2024
Projects
Status: 🏠 Backlog
Development

No branches or pull requests

2 participants