Skip to content

Commit

Permalink
add test_data & refine format
Browse files Browse the repository at this point in the history
  • Loading branch information
paoxiaode committed Sep 19, 2023
1 parent e3a5cba commit a9658dc
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 17 deletions.
9 changes: 8 additions & 1 deletion python/dgl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,18 @@
RomanEmpireDataset,
TolokersDataset,
)

# RDKit is required for OGB-LSC PCQM4Mv2 and the datasets derived from it.
# The import below is wrapped in try/except ImportError so that a missing
# dependency does not crash users who only need the other datasets.
try:
from .lrgb import PeptidesStructuralDataset
except ImportError:
pass
from .pattern import PATTERNDataset
from .wikics import WikiCSDataset
from .yelp import YelpDataset
from .zinc import ZINCDataset
from .lrgb import PeptidesStructuralDataset


def register_data_args(parser):
Expand Down
31 changes: 15 additions & 16 deletions python/dgl/data/lrgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from ogb.utils import smiles2graph
from tqdm import tqdm

import dgl
from .. import backend as F

from ..convert import graph as dgl_graph
from .dgl_dataset import DGLDataset
from .utils import download, load_graphs, save_graphs, Subset

Expand Down Expand Up @@ -57,7 +58,7 @@ class PeptidesStructuralDataset(DGLDataset):
Whether to print out progress information.
Default: False.
smiles2graph (callable):
A callable function that converts a SMILES string into a graph object. We use the OGB featurization.
A callable function that converts a SMILES string into a graph object.
* The default smiles2graph requires rdkit to be installed *
Examples
Expand Down Expand Up @@ -92,14 +93,19 @@ def __init__(
smiles2graph=smiles2graph,
):
self.smiles2graph = smiles2graph
self.version = "9786061a34298a0684150f2e4ff13f47" # MD5 hash of the intended dataset file
self.url_stratified_split = "https://www.dropbox.com/s/9dfifzft1hqgow6/splits_random_stratified_peptide_structure.pickle?dl=1"
# MD5 hash of the dataset file.
self.version = "9786061a34298a0684150f2e4ff13f47"
self.url_stratified_split = """
https://www.dropbox.com/s/9dfifzft1hqgow6/splits_random_stratified_peptide_structure.pickle?dl=1
"""
self.md5sum_stratified_split = "5a0114bdadc80b94fc7ae974f13ef061"

super(PeptidesStructuralDataset, self).__init__(
name="Peptides-struc",
raw_dir=raw_dir,
url="https://www.dropbox.com/s/464u3303eu2u4zp/peptide_structure_dataset.csv.gz?dl=1",
url="""
https://www.dropbox.com/s/464u3303eu2u4zp/peptide_structure_dataset.csv.gz?dl=1
""",
force_reload=force_reload,
verbose=verbose,
)
Expand Down Expand Up @@ -147,13 +153,6 @@ def download(self):
assert self._md5sum(path_split) == self.md5sum_stratified_split

def process(self):
try:
import torch
except ImportError:
raise ModuleNotFoundError(
"This dataset requires PyTorch to be the backend."
)

data_df = pd.read_csv(self.raw_data_path)
smiles_list = data_df["smiles"]
target_names = [
Expand Down Expand Up @@ -184,18 +183,18 @@ def process(self):

assert len(graph["edge_feat"]) == graph["edge_index"].shape[1]
assert len(graph["node_feat"]) == graph["num_nodes"]
dgl_graph = dgl.graph(
DGLgraph = dgl_graph(
(graph["edge_index"][0], graph["edge_index"][1]),
num_nodes=graph["num_nodes"],
)
dgl_graph.edata["feat"] = F.zerocopy_from_numpy(
DGLgraph.edata["feat"] = F.zerocopy_from_numpy(
graph["edge_feat"]
).to(F.int64)
dgl_graph.ndata["feat"] = F.zerocopy_from_numpy(
DGLgraph.ndata["feat"] = F.zerocopy_from_numpy(
graph["node_feat"]
).to(F.int64)

self.graphs.append(dgl_graph)
self.graphs.append(DGLgraph)
self.labels.append(y)

self.labels = F.tensor(self.labels, dtype=F.float32)
Expand Down
19 changes: 19 additions & 0 deletions tests/integration/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,25 @@ def test_fakenews():
assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="only supports pytorch"
)
def test_peptides_structural():
    """Smoke-test ``PeptidesStructuralDataset`` with and without a transform.

    The first sample is fetched from a plain dataset and from one built with
    an ``AddSelfLoop(allow_duplicate=True)`` transform. The test expects the
    transformed graph to have exactly one more edge per node than the
    untransformed one, and the label of the first sample to have an empty
    (falsy) shape.
    """
    # Baseline: dataset constructed with default arguments.
    plain_dataset = data.PeptidesStructuralDataset()
    plain_graph, label = plain_dataset[0]

    # Same dataset, but every returned graph goes through the transform.
    add_self_loop = dgl.AddSelfLoop(allow_duplicate=True)
    looped_dataset = data.PeptidesStructuralDataset(transform=add_self_loop)
    looped_graph, _ = looped_dataset[0]

    # The transform is expected to contribute one extra edge per node.
    extra_edges = looped_graph.num_edges() - plain_graph.num_edges()
    assert extra_edges == plain_graph.num_nodes()
    # An empty shape is falsy, i.e. the label is expected to be a scalar
    # tensor.
    assert not label.shape


@unittest.skipIf(
F._default_context_str == "gpu",
reason="Datasets don't need to be tested on GPU.",
Expand Down

0 comments on commit a9658dc

Please sign in to comment.