Skip to content

Commit

Permalink
change partition
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Sep 12, 2024
1 parent b3c1be5 commit 3834358
Showing 1 changed file with 24 additions and 31 deletions.
55 changes: 24 additions & 31 deletions python/dgl/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,27 +93,17 @@ def process_partitions(g, formats=None, sort_etypes=False):
1. format data types.
2. sort csc/csr by tag.
"""
ndata = (
g.node_attributes
if isinstance(g, gb.FusedCSCSamplingGraph)
else g.ndata
)
edata = (
g.edge_attributes
if isinstance(g, gb.FusedCSCSamplingGraph)
else g.edata
)
for k, dtype in RESERVED_FIELD_DTYPE.items():
if k in g.ndata:
ndata[k] = F.astype(ndata[k], dtype)
g.ndata[k] = F.astype(g.ndata[k], dtype)
if k in g.edata:
edata[k] = F.astype(edata[k], dtype)
g.edata[k] = F.astype(g.edata[k], dtype)

if (sort_etypes) and (formats is not None):
if "csr" in formats:
g = sort_csr_by_tag(g, tag=edata[ETYPE], tag_type="edge")
g = sort_csr_by_tag(g, tag=g.edata[ETYPE], tag_type="edge")
if "csc" in formats:
g = sort_csc_by_tag(g, tag=edata[ETYPE], tag_type="edge")
g = sort_csc_by_tag(g, tag=g.edata[ETYPE], tag_type="edge")
return g


Expand Down Expand Up @@ -1506,6 +1496,8 @@ def get_homogeneous(g, balance_ntypes):
**kwargs,
)
)
part_metadata["node_map_dtype"] = "int64"
part_metadata["edge_map_dtype"] = "int64"
else:
for part_id, part in parts.items():
part_dir = os.path.join(out_path, "part" + str(part_id))
Expand Down Expand Up @@ -1698,12 +1690,12 @@ def gb_convert_single_dgl_partition(
ntypes,
etypes,
gpb,
graph_formats,
store_eids,
store_inner_node,
store_inner_edge,
part_meta,
graph,
graph_formats=None,
store_eids=False,
store_inner_node=False,
store_inner_edge=False,
):
"""Converts a single DGL partition to GraphBolt.
Expand All @@ -1715,6 +1707,10 @@ def gb_convert_single_dgl_partition(
The edge types
gpb : GraphPartitionBook
The global partition information.
part_meta : dict
Contain the meta data of the partition.
graph : DGLGraph
The graph to be converted to graphbolt graph.
graph_formats : str or list[str], optional
Save partitions in specified formats. It could be any combination of
`coo`, `csc`. As `csc` format is mandatory for `FusedCSCSamplingGraph`,
Expand All @@ -1728,10 +1724,6 @@ def gb_convert_single_dgl_partition(
Whether to store inner node mask in the new graph. Default: False.
store_inner_edge : bool, optional
Whether to store inner edge mask in the new graph. Default: False.
part_meta : dict
Contain the meta data of the partition.
graph : DGLGraph
The graph to be converted to graphbolt graph.
"""
debug_mode = "DGL_DIST_DEBUG" in os.environ
if debug_mode:
Expand Down Expand Up @@ -1812,16 +1804,17 @@ def gb_convert_single_dgl_partition(
return csc_graph


def convert_partition_to_graphbolt_multi_process(
def _convert_partition_to_graphbolt(
part_config,
part_id,
graph_formats,
store_eids,
store_inner_node,
store_inner_edge,
graph_formats=None,
store_eids=False,
store_inner_node=False,
store_inner_edge=False,
):
"""
Convert signle partition to graphbolt, which is used for multiple process.
The pipeline converting signle partition to graphbolt.
Parameters
----------
part_config : str
Expand Down Expand Up @@ -1867,7 +1860,7 @@ def convert_partition_to_graphbolt_multi_process(
return rel_path


def _convert_partition_to_graphbolt(
def _convert_partition_to_graphbolt_wrapper(
graph_formats,
part_config,
store_eids,
Expand All @@ -1887,7 +1880,7 @@ def _convert_partition_to_graphbolt(

# Iterate over partitions.
convert_with_format = partial(
convert_partition_to_graphbolt_multi_process,
_convert_partition_to_graphbolt,
part_config=part_config,
graph_formats=graph_formats,
store_eids=store_eids,
Expand Down Expand Up @@ -1979,7 +1972,7 @@ def dgl_partition_to_graphbolt(
)
part_meta = _load_part_config(part_config)
num_parts = part_meta["num_parts"]
part_meta = _convert_partition_to_graphbolt(
part_meta = _convert_partition_to_graphbolt_wrapper(
graph_formats=graph_formats,
part_config=part_config,
store_eids=store_eids,
Expand Down

0 comments on commit 3834358

Please sign in to comment.