Skip to content

Commit

Permalink
fix the bug.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Mar 24, 2024
1 parent 7bd258f commit bf587d4
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 9 deletions.
2 changes: 1 addition & 1 deletion python/dgl/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ def create_block(
data,
idtype,
bipartite=True,
infer_node_count=False,
infer_node_count=need_infer,
)
node_tensor_dict[(sty, ety, dty)] = (sparse_fmt, arrays)
if need_infer:
Expand Down
12 changes: 4 additions & 8 deletions python/dgl/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,34 +199,30 @@ def graphdata2tensors(
num_src, num_dst = (
infer_num_nodes(data, bipartite=bipartite)
if infer_node_count
else None,
None,
else (None, None)
)
elif isinstance(data, list):
src, dst = elist2tensor(data, idtype)
data = SparseAdjTuple("coo", (src, dst))
num_src, num_dst = (
infer_num_nodes(data, bipartite=bipartite)
if infer_node_count
else None,
None,
else (None, None)
)
elif isinstance(data, sp.sparse.spmatrix):
# We can get scipy matrix's number of rows and columns easily.
num_src, num_dst = (
infer_num_nodes(data, bipartite=bipartite)
if infer_node_count
else None,
None,
else (None, None)
)
data = scipy2tensor(data, idtype)
elif isinstance(data, nx.Graph):
# We can get networkx graph's number of sources and destinations easily.
num_src, num_dst = (
infer_num_nodes(data, bipartite=bipartite)
if infer_node_count
else None,
None,
else (None, None)
)
edge_id_attr_name = kwargs.get("edge_id_attr_name", None)
if bipartite:
Expand Down

0 comments on commit bf587d4

Please sign in to comment.