[Dataset] Add peptides functional dataset in LRGB (#6363)
paoxiaode committed Sep 26, 2023
1 parent 88e9422 commit 614401f
Showing 4 changed files with 231 additions and 5 deletions.
2 changes: 2 additions & 0 deletions docs/source/api/python/dgl.data.rst
@@ -68,6 +68,8 @@ Datasets for node classification/regression tasks
TolokersDataset
QuestionsDataset
MovieLensDataset
PeptidesStructuralDataset
PeptidesFunctionalDataset

Edge Prediction Datasets
---------------------------------------
4 changes: 2 additions & 2 deletions python/dgl/data/__init__.py
@@ -71,11 +71,11 @@
TolokersDataset,
)

# RDKit is required for Peptides-Structural dataset.
# RDKit is required for the Peptides-Structural and Peptides-Functional datasets.
# Exception handling was added to prevent crashes for users who are using other
# datasets.
try:
from .lrgb import PeptidesStructuralDataset
from .lrgb import PeptidesFunctionalDataset, PeptidesStructuralDataset
except ImportError:
pass
from .pattern import PATTERNDataset
208 changes: 208 additions & 0 deletions python/dgl/data/lrgb.py
@@ -242,3 +242,211 @@ def __getitem__(self, idx):
return self.graphs[idx], self.labels[idx]
else:
return self._transform(self.graphs[idx]), self.labels[idx]


class PeptidesFunctionalDataset(DGLDataset):
r"""Peptides functional dataset for the graph classification task.
DGL dataset of 15,535 peptides represented as their molecular graph
(SMILES) with 10-way multi-task binary classification of their
functional classes.
The 10 classes represent the following functional classes (in order):
['antifungal', 'cell_cell_communication', 'anticancer',
'drug_delivery_vehicle', 'antimicrobial', 'antiviral',
'antihypertensive', 'antibacterial', 'antiparasitic', 'toxic']
Reference `<https://arxiv.org/abs/2206.08164.pdf>`_
Statistics:
- Train examples: 10,873
- Valid examples: 2,331
- Test examples: 2,331
- Average number of nodes: 150.94
- Average number of edges: 307.30
- Number of atom types: 9
- Number of bond types: 3
Parameters
----------
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: "~/.dgl/".
force_reload : bool
Whether to reload the dataset.
Default: False.
verbose : bool
Whether to print out progress information.
Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
smiles2graph : callable
A callable function that converts a SMILES string into a graph object.
*The default smiles2graph requires RDKit to be installed.*
Examples
---------
>>> from dgl.data import PeptidesFunctionalDataset
>>> dataset = PeptidesFunctionalDataset()
>>> len(dataset)
15535
>>> dataset.num_classes
10
>>> graph, label = dataset[0]
>>> graph
Graph(num_nodes=119, num_edges=244,
ndata_schemes={'feat': Scheme(shape=(9,), dtype=torch.int64)}
edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)})
>>> split_dict = dataset.get_idx_split()
>>> trainset = dataset[split_dict["train"]]
>>> graph, label = trainset[0]
>>> graph
Graph(num_nodes=338, num_edges=682,
ndata_schemes={'feat': Scheme(shape=(9,), dtype=torch.int64)}
edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)})
"""

def __init__(
self,
raw_dir=None,
force_reload=None,
verbose=None,
transform=None,
smiles2graph=smiles2graph,
):
self.smiles2graph = smiles2graph
# MD5 hash of the dataset file.
self.md5sum_data = "701eb743e899f4d793f0e13c8fa5a1b4"
self.url_stratified_split = """
https://www.dropbox.com/s/j4zcnx2eipuo0xz/splits_random_stratified_peptide.pickle?dl=1
"""
self.md5sum_stratified_split = "5a0114bdadc80b94fc7ae974f13ef061"

super(PeptidesFunctionalDataset, self).__init__(
name="Peptides-func",
raw_dir=raw_dir,
url="""
https://www.dropbox.com/s/ol2v01usvaxbsr8/peptide_multi_class_dataset.csv.gz?dl=1
""",
force_reload=force_reload,
verbose=verbose,
transform=transform,
)

@property
def raw_data_path(self):
return os.path.join(self.raw_path, "peptide_multi_class_dataset.csv.gz")

@property
def split_data_path(self):
return os.path.join(
self.raw_path, "splits_random_stratified_peptide.pickle"
)

@property
def graph_path(self):
return os.path.join(self.save_path, "Peptides-func.bin")

@property
def num_atom_types(self):
return 9

@property
def num_bond_types(self):
return 3

@property
def num_classes(self):
return 10

def _md5sum(self, path):
hash_md5 = hashlib.md5()
with open(path, "rb") as f:
buffer = f.read()
hash_md5.update(buffer)
return hash_md5.hexdigest()

def download(self):
path = download(self.url, path=self.raw_data_path)
# Save to disk the MD5 hash of the downloaded file.
hash = self._md5sum(path)
if hash != self.md5sum_data:
raise ValueError("Unexpected MD5 hash of the downloaded file")
open(os.path.join(self.raw_path, hash), "w").close()
# Download train/val/test splits.
path_split = download(
self.url_stratified_split, path=self.split_data_path
)
hash_split = self._md5sum(path_split)
if hash_split != self.md5sum_stratified_split:
raise ValueError("Unexpected MD5 hash of the split file")

def process(self):
data_df = pd.read_csv(self.raw_data_path)
smiles_list = data_df["smiles"]
if self.verbose:
print("Converting SMILES strings into graphs...")
self.graphs = []
self.labels = []
for i in tqdm(range(len(smiles_list))):
smiles = smiles_list[i]
graph = self.smiles2graph(smiles)
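# `graph` is expected to be a dict with "edge_index", "edge_feat",
# "node_feat" and "num_nodes" entries, validated by the asserts below.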

assert len(graph["edge_feat"]) == graph["edge_index"].shape[1]
assert len(graph["node_feat"]) == graph["num_nodes"]
DGLgraph = dgl_graph(
(graph["edge_index"][0], graph["edge_index"][1]),
num_nodes=graph["num_nodes"],
)
DGLgraph.edata["feat"] = F.zerocopy_from_numpy(
graph["edge_feat"]
).to(F.int64)
DGLgraph.ndata["feat"] = F.zerocopy_from_numpy(
graph["node_feat"]
).to(F.int64)
self.graphs.append(DGLgraph)
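# Each entry of the "labels" column is a Python list literal with one
# binary target per functional class; eval parses it into a list.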
self.labels.append(eval(data_df["labels"].iloc[i]))
self.labels = F.tensor(self.labels, dtype=F.float32)

def load(self):
self.graphs, label_dict = load_graphs(self.graph_path)
self.labels = label_dict["labels"]

def save(self):
save_graphs(
self.graph_path, self.graphs, labels={"labels": self.labels}
)

def has_cache(self):
return os.path.exists(self.graph_path)

def get_idx_split(self):
"""Get dataset splits.
Returns:
Dict with 'train', 'val', and 'test' split indices.
"""
with open(self.split_data_path, "rb") as f:
split_dict = pickle.load(f)
for key in split_dict.keys():
split_dict[key] = F.zerocopy_from_numpy(split_dict[key])
return split_dict

def __len__(self):
return len(self.graphs)

def __getitem__(self, idx):
"""Get datapoint with index"""
if F.is_tensor(idx) and idx.dim() == 1:
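# A 1-D index tensor (e.g. a split from get_idx_split) selects a Subset view of the dataset.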
return Subset(self, idx.cpu())

if self._transform is None:
return self.graphs[idx], self.labels[idx]
else:
return self._transform(self.graphs[idx]), self.labels[idx]
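
A minimal usage sketch of the new dataset, assuming RDKit is installed; dgl.dataloading.GraphDataLoader handles batching, and the batch size and loop body here are illustrative only:

from dgl.data import PeptidesFunctionalDataset
from dgl.dataloading import GraphDataLoader

# Downloads and processes the data on first use, then loads the cached graphs.
dataset = PeptidesFunctionalDataset()
split = dataset.get_idx_split()  # dict with "train" / "val" / "test" index tensors
train_loader = GraphDataLoader(dataset[split["train"]], batch_size=32, shuffle=True)
for batched_graph, labels in train_loader:
    # labels has shape (batch_size, 10): one binary target per functional class.
    pass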
22 changes: 19 additions & 3 deletions tests/integration/test_data.py
@@ -65,13 +65,29 @@ def test_fakenews():
def test_peptides_structural():
transform = dgl.AddSelfLoop(allow_duplicate=True)
dataset1 = data.PeptidesStructuralDataset()
g1, label = dataset1[0]
g1 = dataset1[0][0]
dataset2 = data.PeptidesStructuralDataset(transform=transform)
g2 = dataset2[0][0]

assert g2.num_edges() - g1.num_edges() == g1.num_nodes()


@unittest.skipIf(
F._default_context_str == "gpu",
reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(
dgl.backend.backend_name != "pytorch", reason="only supports pytorch"
)
def test_peptides_functional():
transform = dgl.AddSelfLoop(allow_duplicate=True)
dataset1 = data.PeptidesFunctionalDataset()
g1, label = dataset1[0]
dataset2 = data.PeptidesFunctionalDataset(transform=transform)
g2, _ = dataset2[0]

assert g2.num_edges() - g1.num_edges() == g1.num_nodes()
# Each label is a vector with one binary target per functional class.
assert dataset1.num_classes == label.shape[0]


@unittest.skipIf(