Skip to content

Commit

Permalink
Expose sorted to python interface
Browse files Browse the repository at this point in the history
  • Loading branch information
nv-dlasalle committed Feb 19, 2021
1 parent d824573 commit 9ab524b
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 12 deletions.
34 changes: 29 additions & 5 deletions python/dgl/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def graph(data,
num_nodes=None,
idtype=None,
device=None,
row_sorted=False,
col_sorted=False,
check_sorted=True,
**deprecated_kwargs):
"""Create a graph and return.
Expand Down Expand Up @@ -72,7 +75,11 @@ def graph(data,
the :attr:`data` argument. If :attr:`data` is not a tuple of node-tensors, the
returned graph is on CPU. If the specified :attr:`device` differs from that of the
provided tensors, it casts the given tensors to the specified device first.
row_sorted : bool, optional
Whether or not the rows of the COO are in ascending order.
col_sorted : bool, optional
Whether or not the columns of the COO are in ascending order within
each row. This only has an effect when ``row_sorted`` is True.
Returns
-------
DGLGraph
Expand Down Expand Up @@ -158,7 +165,9 @@ def graph(data,
' but got {} and {}.'.format(num_nodes, max(urange, vrange) - 1))
urange, vrange = num_nodes, num_nodes

g = create_from_edges(u, v, '_N', '_E', '_N', urange, vrange)
g = create_from_edges(u, v, '_N', '_E', '_N', urange, vrange,
row_sorted=row_sorted, col_sorted=col_sorted,
check_sorted=check_sorted)

return g.to(device)

Expand Down Expand Up @@ -926,7 +935,7 @@ def to_homogeneous(G, ndata=None, edata=None, store_type=True, return_count=Fals
eids.append(F.arange(0, num_edges, G.idtype, G.device))

retg = graph((F.cat(srcs, 0), F.cat(dsts, 0)), num_nodes=total_num_nodes,
idtype=G.idtype, device=G.device)
idtype=G.idtype, device=G.device, check_sorted=False)

# copy features
if ndata is None:
Expand Down Expand Up @@ -1590,7 +1599,10 @@ def to_networkx(g, node_attrs=None, edge_attrs=None):
def create_from_edges(u, v,
utype, etype, vtype,
urange, vrange,
validate=True):
validate=True,
row_sorted=False,
col_sorted=False,
check_sorted=False):
"""Internal function to create a graph from incident nodes with types.
utype could be equal to vtype
Expand All @@ -1615,6 +1627,17 @@ def create_from_edges(u, v,
maximum of the destination node IDs in the edge list plus 1. (Default: None)
validate : bool, optional
If True, checks if node IDs are within range.
row_sorted : bool, optional
Whether or not the rows of the COO are in ascending order.
col_sorted : bool, optional
Whether or not the columns of the COO are in ascending order within
each row. This only has an effect when ``row_sorted`` is True.
check_sorted : bool, optional
If this is ``True`` and ``row_sorted`` is ``False``, the edge list will
be scanned to see if it is in ascending order, and the resulting graph
will be marked accordingly.
Returns
-------
Expand All @@ -1636,7 +1659,8 @@ def create_from_edges(u, v,
num_ntypes = 2

hgidx = heterograph_index.create_unitgraph_from_coo(
num_ntypes, urange, vrange, u, v, ['coo', 'csr', 'csc'])
num_ntypes, urange, vrange, u, v, ['coo', 'csr', 'csc'],
row_sorted, col_sorted, check_sorted)
if utype == vtype:
return DGLHeteroGraph(hgidx, [utype], [etype])
else:
Expand Down
2 changes: 1 addition & 1 deletion python/dgl/distributed/graph_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def merge_graphs(res_list, num_nodes):
src_tensor = res_list[0].global_src
dst_tensor = res_list[0].global_dst
eid_tensor = res_list[0].global_eids
g = graph((src_tensor, dst_tensor), num_nodes=num_nodes)
g = graph((src_tensor, dst_tensor), num_nodes=num_nodes, check_sorted=False)
g.edata[EID] = eid_tensor
return g

Expand Down
14 changes: 12 additions & 2 deletions python/dgl/heterograph_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,7 +970,8 @@ def induced_edges(self):
#################################################################

def create_unitgraph_from_coo(num_ntypes, num_src, num_dst, row, col,
formats):
formats, row_sorted=False, col_sorted=False,
check_sorted=False):
"""Create a unitgraph graph index from COO format
Parameters
Expand All @@ -987,6 +988,15 @@ def create_unitgraph_from_coo(num_ntypes, num_src, num_dst, row, col,
Col index.
formats : list of str.
Restrict the storage formats allowed for the unit graph.
row_sorted : bool, optional
Whether or not the rows of the COO are in ascending order.
col_sorted : bool, optional
Whether or not the columns of the COO are in ascending order within
each row. This only has an effect when ``row_sorted`` is True.
check_sorted : bool, optional
If this is ``True`` and ``row_sorted`` is ``False``, the edge list will
be scanned to see if it is in ascending order, and the resulting graph
will be marked accordingly.
Returns
-------
Expand All @@ -997,7 +1007,7 @@ def create_unitgraph_from_coo(num_ntypes, num_src, num_dst, row, col,
return _CAPI_DGLHeteroCreateUnitGraphFromCOO(
int(num_ntypes), int(num_src), int(num_dst),
F.to_dgl_nd(row), F.to_dgl_nd(col),
formats)
formats, row_sorted, col_sorted, check_sorted)

def create_unitgraph_from_csr(num_ntypes, num_src, num_dst, indptr, indices, edge_ids,
formats):
Expand Down
12 changes: 8 additions & 4 deletions src/graph/heterograph_capi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,21 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO")
IdArray row = args[3];
IdArray col = args[4];
List<Value> formats = args[5];
bool row_sorted = args[6];
bool col_sorted = args[7];
bool check_sorted = args[8];
std::vector<SparseFormat> formats_vec;
for (Value val : formats) {
std::string fmt = val->data;
formats_vec.push_back(ParseSparseFormat(fmt));
}
auto code = SparseFormatsToCode(formats_vec);

// setup sorted flags
bool row_sorted, col_sorted;
std::tie(row_sorted, col_sorted) = COOIsSorted(
aten::COOMatrix(num_src, num_dst, row, col));
if (!row_sorted && check_sorted) {
// setup sorted flags
std::tie(row_sorted, col_sorted) = COOIsSorted(
aten::COOMatrix(num_src, num_dst, row, col));
}

auto hgptr = CreateFromCOO(nvtypes, num_src, num_dst, row, col,
row_sorted, col_sorted, code);
Expand Down

0 comments on commit 9ab524b

Please sign in to comment.