-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[GraphBolt][CUDA] sampled_edge_ids
to fetch indices later.
#7664
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,6 +45,8 @@ class SampledSubgraphImpl(SampledSubgraph): | |
] = None | ||
original_row_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None | ||
original_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None | ||
# Used to fetch sampled_csc.indices if it is missing. | ||
_sampled_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto: _edge_ids_in_fused_csc_sampling_graph The name has to be clear since it is indeed non-intuitive. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
def __post_init__(self): | ||
if isinstance(self.sampled_csc, dict): | ||
|
@@ -53,22 +55,34 @@ def __post_init__(self): | |
isinstance(etype, str) | ||
and len(etype_str_to_tuple(etype)) == 3 | ||
), "Edge type should be a string in format of str:str:str." | ||
assert ( | ||
pair.indptr is not None and pair.indices is not None | ||
), "Node pair should be have indptr and indice." | ||
assert isinstance(pair.indptr, torch.Tensor) and isinstance( | ||
pair.indices, torch.Tensor | ||
), "Nodes in pairs should be of type torch.Tensor." | ||
assert pair.indptr is not None and isinstance( | ||
pair.indptr, torch.Tensor | ||
), "Node pair should be have indptr of type torch.Tensor." | ||
# For CUDA, indices may be None because it will be fetched later. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am wondering how much performance we can gain by asynchronously fetch indices, esp, we fetch original_edge_ids in the sync way. Please make balance on performance gain and code structure complexity. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 25% There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 25% is a lot so we needed to do it. |
||
if not pair.indptr.is_cuda or pair.indices is not None: | ||
assert isinstance( | ||
pair.indices, torch.Tensor | ||
), "Node pair should be have indices of type torch.Tensor." | ||
else: | ||
assert isinstance( | ||
self._sampled_edge_ids.get(etype, None), torch.Tensor | ||
), "When indices is missing, sampled edge ids needs to be provided." | ||
else: | ||
assert ( | ||
self.sampled_csc.indptr is not None | ||
and self.sampled_csc.indices is not None | ||
), "Node pair should be have indptr and indice." | ||
assert isinstance( | ||
assert self.sampled_csc.indptr is not None and isinstance( | ||
self.sampled_csc.indptr, torch.Tensor | ||
) and isinstance( | ||
self.sampled_csc.indices, torch.Tensor | ||
), "Nodes in pairs should be of type torch.Tensor." | ||
), "Node pair should be have torch.Tensor indptr." | ||
# For CUDA, indices may be None because it will be fetched later. | ||
if ( | ||
not self.sampled_csc.indptr.is_cuda | ||
or self.sampled_csc.indices is not None | ||
): | ||
assert isinstance( | ||
self.sampled_csc.indices, torch.Tensor | ||
), "Node pair should have a torch.Tensor indices." | ||
else: | ||
assert isinstance( | ||
self._sampled_edge_ids, torch.Tensor | ||
), "When indices is missing, sampled edge ids needs to be provided." | ||
|
||
def __repr__(self) -> str: | ||
return _sampled_subgraph_str(self, "SampledSubgraphImpl") | ||
|
@@ -81,6 +95,8 @@ def _sampled_subgraph_str(sampled_subgraph: SampledSubgraph, classname) -> str: | |
attributes.reverse() | ||
|
||
for name in attributes: | ||
if name in "_sampled_edge_ids": | ||
continue | ||
val = getattr(sampled_subgraph, name) | ||
|
||
def _add_indent(_str, indent): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we also rename the original_edge_ids in C_sampled_subgraph?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#7703 I updated the doc string to clarify. It is indeed hard to rename.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sampled_edge_ids -> edge_ids_in_fused_csc_sampling_graph.