diff --git a/python/dgl/convert.py b/python/dgl/convert.py index a5ea39a84321..6f16a0aef95a 100644 --- a/python/dgl/convert.py +++ b/python/dgl/convert.py @@ -37,6 +37,9 @@ def graph(data, num_nodes=None, idtype=None, device=None, + row_sorted=False, + col_sorted=False, + check_sorted=True, **deprecated_kwargs): """Create a graph and return. @@ -72,7 +75,11 @@ def graph(data, the :attr:`data` argument. If :attr:`data` is not a tuple of node-tensors, the returned graph is on CPU. If the specified :attr:`device` differs from that of the provided tensors, it casts the given tensors to the specified device first. - + row_sorted : bool, optional + Whether or not the rows of the COO are in ascending order. + col_sorted : bool, optional + Whether or not the columns of the COO are in ascending order within + each row. This only has an effect when ``row_sorted`` is True. Returns ------- DGLGraph @@ -158,7 +165,9 @@ def graph(data, ' but got {} and {}.'.format(num_nodes, max(urange, vrange) - 1)) urange, vrange = num_nodes, num_nodes - g = create_from_edges(u, v, '_N', '_E', '_N', urange, vrange) + g = create_from_edges(u, v, '_N', '_E', '_N', urange, vrange, + row_sorted=row_sorted, col_sorted=col_sorted, + check_sorted=check_sorted) return g.to(device) @@ -926,7 +935,7 @@ def to_homogeneous(G, ndata=None, edata=None, store_type=True, return_count=Fals eids.append(F.arange(0, num_edges, G.idtype, G.device)) retg = graph((F.cat(srcs, 0), F.cat(dsts, 0)), num_nodes=total_num_nodes, - idtype=G.idtype, device=G.device) + idtype=G.idtype, device=G.device, check_sorted=False) # copy features if ndata is None: @@ -1590,7 +1599,10 @@ def to_networkx(g, node_attrs=None, edge_attrs=None): def create_from_edges(u, v, utype, etype, vtype, urange, vrange, - validate=True): + validate=True, + row_sorted=False, + col_sorted=False, + check_sorted=False): """Internal function to create a graph from incident nodes with types. utype could be equal to vtype @@ -1615,6 +1627,17 @@ def create_from_edges(u, v, maximum of the destination node IDs in the edge list plus 1. (Default: None) validate : bool, optional If True, checks if node IDs are within range. + row_sorted : bool, optional + Whether or not the rows of the COO are in ascending order. + col_sorted : bool, optional + Whether or not the columns of the COO are in ascending order within + each row. This only has an effect when ``row_sorted`` is True. + check_sorted : bool, optional + If this is ``True`` and ``row_sorted`` is ``False``, the edge list will + be scanned to see if it is in ascending order, and the resulting graph + will be marked accordingly. + + Returns ------- @@ -1636,7 +1659,8 @@ def create_from_edges(u, v, num_ntypes = 2 hgidx = heterograph_index.create_unitgraph_from_coo( - num_ntypes, urange, vrange, u, v, ['coo', 'csr', 'csc']) + num_ntypes, urange, vrange, u, v, ['coo', 'csr', 'csc'], + row_sorted, col_sorted, check_sorted) if utype == vtype: return DGLHeteroGraph(hgidx, [utype], [etype]) else: diff --git a/python/dgl/distributed/graph_services.py b/python/dgl/distributed/graph_services.py index f24e701d9d7e..aa62fd0d9de9 100644 --- a/python/dgl/distributed/graph_services.py +++ b/python/dgl/distributed/graph_services.py @@ -177,7 +177,7 @@ def merge_graphs(res_list, num_nodes): src_tensor = res_list[0].global_src dst_tensor = res_list[0].global_dst eid_tensor = res_list[0].global_eids - g = graph((src_tensor, dst_tensor), num_nodes=num_nodes) + g = graph((src_tensor, dst_tensor), num_nodes=num_nodes, check_sorted=False) g.edata[EID] = eid_tensor return g diff --git a/python/dgl/heterograph_index.py b/python/dgl/heterograph_index.py index 6bad76907068..c82eabb80207 100644 --- a/python/dgl/heterograph_index.py +++ b/python/dgl/heterograph_index.py @@ -970,7 +970,8 @@ def induced_edges(self): ################################################################# def create_unitgraph_from_coo(num_ntypes, num_src, num_dst, row, col, - formats): + formats, row_sorted=False, col_sorted=False, + check_sorted=False): """Create a unitgraph graph index from COO format Parameters @@ -987,6 +988,15 @@ def create_unitgraph_from_coo(num_ntypes, num_src, num_dst, row, col, Col index. formats : list of str. Restrict the storage formats allowed for the unit graph. + row_sorted : bool, optional + Whether or not the rows of the COO are in ascending order. + col_sorted : bool, optional + Whether or not the columns of the COO are in ascending order within + each row. This only has an effect when ``row_sorted`` is True. + check_sorted : bool, optional + If this is ``True`` and ``row_sorted`` is ``False``, the edge list will + be scanned to see if it is in ascending order, and the resulting graph + will be marked accordingly. Returns ------- @@ -997,7 +1007,7 @@ def create_unitgraph_from_coo(num_ntypes, num_src, num_dst, row, col, return _CAPI_DGLHeteroCreateUnitGraphFromCOO( int(num_ntypes), int(num_src), int(num_dst), F.to_dgl_nd(row), F.to_dgl_nd(col), - formats) + formats, row_sorted, col_sorted, check_sorted) def create_unitgraph_from_csr(num_ntypes, num_src, num_dst, indptr, indices, edge_ids, formats): diff --git a/src/graph/heterograph_capi.cc b/src/graph/heterograph_capi.cc index f630d93b80ba..c78ef7df73db 100644 --- a/src/graph/heterograph_capi.cc +++ b/src/graph/heterograph_capi.cc @@ -30,6 +30,9 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO") IdArray row = args[3]; IdArray col = args[4]; List formats = args[5]; + bool row_sorted = args[6]; + bool col_sorted = args[7]; + bool check_sorted = args[8]; std::vector formats_vec; for (Value val : formats) { std::string fmt = val->data; @@ -37,10 +40,11 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO") } auto code = SparseFormatsToCode(formats_vec); - // setup sorted flags - bool row_sorted, col_sorted; - std::tie(row_sorted, col_sorted) = COOIsSorted( - aten::COOMatrix(num_src, num_dst, row, col)); + if (!row_sorted && check_sorted) { + // setup sorted flags + std::tie(row_sorted, col_sorted) = COOIsSorted( + aten::COOMatrix(num_src, num_dst, row, col)); + } auto hgptr = CreateFromCOO(nvtypes, num_src, num_dst, row, col, row_sorted, col_sorted, code);