Skip to content

Commit

Permalink
make the code simpler.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Mar 24, 2024
1 parent bf587d4 commit 47e46c2
Showing 1 changed file with 9 additions and 20 deletions.
29 changes: 9 additions & 20 deletions python/dgl/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,39 +191,28 @@ def graphdata2tensors(
data.format, tuple(F.tensor(a) for a in data.arrays)
)

num_src, num_dst = None, None
if isinstance(data, SparseAdjTuple):
if idtype is not None:
data = SparseAdjTuple(
data.format, tuple(F.astype(a, idtype) for a in data.arrays)
)
num_src, num_dst = (
infer_num_nodes(data, bipartite=bipartite)
if infer_node_count
else (None, None)
)
if infer_node_count:
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)
elif isinstance(data, list):
src, dst = elist2tensor(data, idtype)
data = SparseAdjTuple("coo", (src, dst))
num_src, num_dst = (
infer_num_nodes(data, bipartite=bipartite)
if infer_node_count
else (None, None)
)
if infer_node_count:
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)
elif isinstance(data, sp.sparse.spmatrix):
# We can get scipy matrix's number of rows and columns easily.
num_src, num_dst = (
infer_num_nodes(data, bipartite=bipartite)
if infer_node_count
else (None, None)
)
if infer_node_count:
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)
data = scipy2tensor(data, idtype)
elif isinstance(data, nx.Graph):
# We can get networkx graph's number of sources and destinations easily.
num_src, num_dst = (
infer_num_nodes(data, bipartite=bipartite)
if infer_node_count
else (None, None)
)
if infer_node_count:
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)
edge_id_attr_name = kwargs.get("edge_id_attr_name", None)
if bipartite:
top_map = kwargs.get("top_map")
Expand Down

0 comments on commit 47e46c2

Please sign in to comment.