[Dataset] Add ZINC Dataset #5428

Merged: 20 commits, Mar 22, 2023
1 change: 1 addition & 0 deletions docs/source/api/python/dgl.data.rst
@@ -93,6 +93,7 @@ Datasets for graph classification/regression tasks
GINDataset
FakeNewsDataset
BA2MotifDataset
ZINCDataset

Dataset adapters
-------------------
1 change: 1 addition & 0 deletions python/dgl/data/__init__.py
@@ -56,6 +56,7 @@
from .pattern import PATTERNDataset
from .wikics import WikiCSDataset
from .yelp import YelpDataset
from .zinc import ZINCDataset


def register_data_args(parser):
2 changes: 1 addition & 1 deletion python/dgl/data/cluster.py
@@ -49,7 +49,7 @@ class CLUSTERDataset(DGLBuiltinDataset):
Number of classes for each node.
Examples
-------
--------
>>> from dgl.data import CLUSTERDataset
>>>
>>> trainset = CLUSTERDataset(mode='train')
2 changes: 1 addition & 1 deletion python/dgl/data/pattern.py
@@ -50,7 +50,7 @@ class PATTERNDataset(DGLBuiltinDataset):
Number of classes for each node.
Examples
-------
--------
>>> from dgl.data import PATTERNDataset
>>> data = PATTERNDataset(mode='train')
>>> data.num_classes
139 changes: 139 additions & 0 deletions python/dgl/data/zinc.py
@@ -0,0 +1,139 @@
import os

from .dgl_dataset import DGLBuiltinDataset
from .utils import _get_dgl_url, load_graphs


class ZINCDataset(DGLBuiltinDataset):
r"""ZINC dataset for the graph regression task.
A subset (12K) of ZINC molecular graphs (250K) dataset is used to
regress a molecular property known as the constrained solubility.
For each molecular graph, the node features are the types of heavy
atoms, between which the edge features are the types of bonds.
Each graph contains 9-37 nodes and 16-84 edges.
Reference `<https://arxiv.org/pdf/2003.00982.pdf>`_
Statistics:
Train examples: 10,000
Valid examples: 1,000
Test examples: 1,000
Average number of nodes: 23.16
Average number of edges: 39.83
Number of atom types: 28
Number of bond types: 4
Parameters
----------
mode : str, optional
Should be chosen from ["train", "valid", "test"]
Default: "train".
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: "~/.dgl/".
force_reload : bool
Whether to reload the dataset.
Default: False.
verbose : bool
Whether to print out progress information.
Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
num_atom_types : int
Number of atom types.
num_bond_types : int
Number of bond types.
Examples
---------
>>> from dgl.data import ZINCDataset
>>> training_set = ZINCDataset(mode="train")
>>> training_set.num_atom_types
28
>>> len(training_set)
10000
>>> graph, label = training_set[0]
>>> graph
Graph(num_nodes=29, num_edges=64,
ndata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)}
edata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)})
"""

def __init__(
self,
mode="train",
raw_dir=None,
force_reload=False,
verbose=False,
transform=None,
):
self._url = _get_dgl_url("dataset/ZINC12k.zip")
self.mode = mode

super(ZINCDataset, self).__init__(
name="zinc",
url=self._url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)

def process(self):
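        # The downloaded archive already ships graphs preprocessed into DGL's
        # binary format, so processing reduces to loading them from disk.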
self.load()

def has_cache(self):
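        # Each split is cached as ZincDGL_<mode>.bin under ``save_path``;
        # ``load`` below reads from the same path.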
graph_path = os.path.join(
self.save_path, "ZincDGL_{}.bin".format(self.mode)
)
return os.path.exists(graph_path)

def load(self):
graph_path = os.path.join(
self.save_path, "ZincDGL_{}.bin".format(self.mode)
)
self._graphs, self._labels = load_graphs(graph_path)

@property
def num_atom_types(self):
return 28

@property
def num_bond_types(self):
return 4

def __len__(self):
return len(self._graphs)

def __getitem__(self, idx):
r"""Get one example by index.
Parameters
----------
idx : int
The sample index.
Returns
-------
dgl.DGLGraph
Each graph contains:
- ``ndata['feat']``: Types of heavy atoms as node features
- ``edata['feat']``: Types of bonds as edge features
Tensor
Constrained solubility as graph label
"""
labels = self._labels["g_label"]
if self._transform is None:
return self._graphs[idx], labels[idx]
else:
return self._transform(self._graphs[idx]), labels[idx]
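
For reference, a minimal usage sketch (not part of this PR's diff): it loads the training split, batches graphs with DGL's GraphDataLoader, and embeds the integer atom/bond type features before message passing. The batch size and embedding dimension are illustrative assumptions; a PyTorch backend and a first-time download of the ZINC12k archive are assumed.

import torch
from dgl.data import ZINCDataset
from dgl.dataloading import GraphDataLoader

# Load the training split (downloads and extracts the archive on first use).
train_set = ZINCDataset(mode="train")

# GraphDataLoader batches (graph, label) pairs into (batched_graph, label_tensor).
loader = GraphDataLoader(train_set, batch_size=128, shuffle=True)

# Node/edge features are integer type indices, so embed them before message passing.
atom_embed = torch.nn.Embedding(train_set.num_atom_types, 64)
bond_embed = torch.nn.Embedding(train_set.num_bond_types, 64)

for batched_graph, labels in loader:
    h_nodes = atom_embed(batched_graph.ndata["feat"])  # (total_nodes, 64)
    h_edges = bond_embed(batched_graph.edata["feat"])  # (total_edges, 64)
    # labels holds the constrained solubility per graph, shape (batch_size,).
    print(batched_graph.batch_size, h_nodes.shape, h_edges.shape, labels.shape)
    break
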
57 changes: 57 additions & 0 deletions tests/python/common/data/test_data.py
@@ -408,6 +408,63 @@ def test_cluster():
assert ds.num_classes == 6


@unittest.skipIf(
F._default_context_str == "gpu",
reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(
dgl.backend.backend_name != "pytorch", reason="only supports pytorch"
)
def test_zinc():
mode_n_graphs = {
"train": 10000,
"valid": 1000,
"test": 1000,
}
transform = dgl.AddSelfLoop(allow_duplicate=True)
for mode, n_graphs in mode_n_graphs.items():
dataset1 = data.ZINCDataset(mode=mode)
g1, label = dataset1[0]
dataset2 = data.ZINCDataset(mode=mode, transform=transform)
g2, _ = dataset2[0]

        # AddSelfLoop(allow_duplicate=True) adds exactly one self-loop per node,
        # so the transformed graph has num_nodes additional edges.
        assert g2.num_edges() - g1.num_edges() == g1.num_nodes()
# return a scalar tensor
assert not label.shape


@unittest.skipIf(
F._default_context_str == "gpu",
reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_extract_archive():
# gzip
with tempfile.TemporaryDirectory() as src_dir:
gz_file = "gz_archive"
gz_path = os.path.join(src_dir, gz_file + ".gz")
content = b"test extract archive gzip"
with gzip.open(gz_path, "wb") as f:
f.write(content)
with tempfile.TemporaryDirectory() as dst_dir:
data.utils.extract_archive(gz_path, dst_dir, overwrite=True)
assert os.path.exists(os.path.join(dst_dir, gz_file))

# tar
with tempfile.TemporaryDirectory() as src_dir:
tar_file = "tar_archive"
tar_path = os.path.join(src_dir, tar_file + ".tar")
# default encode to utf8
content = "test extract archive tar\n".encode()
info = tarfile.TarInfo(name="tar_archive")
info.size = len(content)
with tarfile.open(tar_path, "w") as f:
f.addfile(info, io.BytesIO(content))
with tempfile.TemporaryDirectory() as dst_dir:
data.utils.extract_archive(tar_path, dst_dir, overwrite=True)
assert os.path.exists(os.path.join(dst_dir, tar_file))


def _test_construct_graphs_node_ids():
from dgl.data.csv_dataset_base import (
DGLGraphConstructor,