diff --git a/docs/source/api/python/dgl.data.rst b/docs/source/api/python/dgl.data.rst index f4a7116b9dd5..9671301c8fb9 100644 --- a/docs/source/api/python/dgl.data.rst +++ b/docs/source/api/python/dgl.data.rst @@ -93,6 +93,7 @@ Datasets for graph classification/regression tasks GINDataset FakeNewsDataset BA2MotifDataset + ZINCDataset Dataset adapters ------------------- diff --git a/python/dgl/data/__init__.py b/python/dgl/data/__init__.py index d8b1b116d56d..c42e553c678f 100644 --- a/python/dgl/data/__init__.py +++ b/python/dgl/data/__init__.py @@ -56,6 +56,7 @@ from .pattern import PATTERNDataset from .wikics import WikiCSDataset from .yelp import YelpDataset +from .zinc import ZINCDataset def register_data_args(parser): diff --git a/python/dgl/data/cluster.py b/python/dgl/data/cluster.py index df75ddd0d6e9..264eb352584b 100644 --- a/python/dgl/data/cluster.py +++ b/python/dgl/data/cluster.py @@ -49,7 +49,7 @@ class CLUSTERDataset(DGLBuiltinDataset): Number of classes for each node. Examples - —------- + -------- >>> from dgl.data import CLUSTERDataset >>> >>> trainset = CLUSTERDataset(mode='train') diff --git a/python/dgl/data/pattern.py b/python/dgl/data/pattern.py index 240738501555..59894cd09217 100644 --- a/python/dgl/data/pattern.py +++ b/python/dgl/data/pattern.py @@ -50,7 +50,7 @@ class PATTERNDataset(DGLBuiltinDataset): Number of classes for each node. Examples - —------- + -------- >>> from dgl.data import PATTERNDataset >>> data = PATTERNDataset(mode='train') >>> data.num_classes diff --git a/python/dgl/data/zinc.py b/python/dgl/data/zinc.py new file mode 100644 index 000000000000..715e2a482417 --- /dev/null +++ b/python/dgl/data/zinc.py @@ -0,0 +1,139 @@ +import os + +from .dgl_dataset import DGLBuiltinDataset +from .utils import _get_dgl_url, load_graphs + + +class ZINCDataset(DGLBuiltinDataset): + r"""ZINC dataset for the graph regression task. 
+ + A subset (12K) of ZINC molecular graphs (250K) dataset is used to + regress a molecular property known as the constrained solubility. + For each molecular graph, the node features are the types of heavy + atoms, between which the edge features are the types of bonds. + Each graph contains 9-37 nodes and 16-84 edges. + + Reference `<https://arxiv.org/abs/2003.00982>`_ + + Statistics: + + Train examples: 10,000 + Valid examples: 1,000 + Test examples: 1,000 + Average number of nodes: 23.16 + Average number of edges: 39.83 + Number of atom types: 28 + Number of bond types: 4 + + Parameters + ---------- + mode : str, optional + Should be chosen from ["train", "valid", "test"] + Default: "train". + raw_dir : str + Raw file directory to download/contains the input data directory. + Default: "~/.dgl/". + force_reload : bool + Whether to reload the dataset. + Default: False. + verbose : bool + Whether to print out progress information. + Default: False. + transform : callable, optional + A transform that takes in a :class:`~dgl.DGLGraph` object and returns + a transformed version. The :class:`~dgl.DGLGraph` object will be + transformed before every access. + + Attributes + ---------- + num_atom_types : int + Number of atom types. + num_bond_types : int + Number of bond types. 
+ + Examples + -------- + >>> from dgl.data import ZINCDataset + + >>> training_set = ZINCDataset(mode="train") + >>> training_set.num_atom_types + 28 + >>> len(training_set) + 10000 + >>> graph, label = training_set[0] + >>> graph + Graph(num_nodes=29, num_edges=64, + ndata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)} + edata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)}) + """ + + def __init__( + self, + mode="train", + raw_dir=None, + force_reload=False, + verbose=False, + transform=None, + ): + self._url = _get_dgl_url("dataset/ZINC12k.zip") + self.mode = mode + + super(ZINCDataset, self).__init__( + name="zinc", + url=self._url, + raw_dir=raw_dir, + force_reload=force_reload, + verbose=verbose, + transform=transform, + ) + + def process(self): + self.load() + + def has_cache(self): + graph_path = os.path.join( + self.save_path, "ZincDGL_{}.bin".format(self.mode) + ) + return os.path.exists(graph_path) + + def load(self): + graph_path = os.path.join( + self.save_path, "ZincDGL_{}.bin".format(self.mode) + ) + self._graphs, self._labels = load_graphs(graph_path) + + @property + def num_atom_types(self): + return 28 + + @property + def num_bond_types(self): + return 4 + + def __len__(self): + return len(self._graphs) + + def __getitem__(self, idx): + r"""Get one example by index. + + Parameters + ---------- + idx : int + The sample index. 
+ + Returns + ------- + dgl.DGLGraph + Each graph contains: + + - ``ndata['feat']``: Types of heavy atoms as node features + - ``edata['feat']``: Types of bonds as edge features + + Tensor + Constrained solubility as graph label + """ + labels = self._labels["g_label"] + if self._transform is None: + return self._graphs[idx], labels[idx] + else: + return self._transform(self._graphs[idx]), labels[idx] diff --git a/tests/python/common/data/test_data.py b/tests/python/common/data/test_data.py index 3125aa546d13..2e487b96a44c 100644 --- a/tests/python/common/data/test_data.py +++ b/tests/python/common/data/test_data.py @@ -408,6 +408,63 @@ def test_cluster(): assert ds.num_classes == 6 +@unittest.skipIf( + F._default_context_str == "gpu", + reason="Datasets don't need to be tested on GPU.", +) +@unittest.skipIf( + dgl.backend.backend_name != "pytorch", reason="only supports pytorch" +) +def test_zinc(): + mode_n_graphs = { + "train": 10000, + "valid": 1000, + "test": 1000, + } + transform = dgl.AddSelfLoop(allow_duplicate=True) + for mode, n_graphs in mode_n_graphs.items(): + dataset1 = data.ZINCDataset(mode=mode) + g1, label = dataset1[0] + dataset2 = data.ZINCDataset(mode=mode, transform=transform) + g2, _ = dataset2[0] + + assert g2.num_edges() - g1.num_edges() == g1.num_nodes() + # return a scalar tensor + assert not label.shape + + +@unittest.skipIf( + F._default_context_str == "gpu", + reason="Datasets don't need to be tested on GPU.", +) +@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet") +def test_extract_archive(): + # gzip + with tempfile.TemporaryDirectory() as src_dir: + gz_file = "gz_archive" + gz_path = os.path.join(src_dir, gz_file + ".gz") + content = b"test extract archive gzip" + with gzip.open(gz_path, "wb") as f: + f.write(content) + with tempfile.TemporaryDirectory() as dst_dir: + data.utils.extract_archive(gz_path, dst_dir, overwrite=True) + assert os.path.exists(os.path.join(dst_dir, gz_file)) + + # tar + with 
tempfile.TemporaryDirectory() as src_dir: + tar_file = "tar_archive" + tar_path = os.path.join(src_dir, tar_file + ".tar") + # default encode to utf8 + content = "test extract archive tar\n".encode() + info = tarfile.TarInfo(name="tar_archive") + info.size = len(content) + with tarfile.open(tar_path, "w") as f: + f.addfile(info, io.BytesIO(content)) + with tempfile.TemporaryDirectory() as dst_dir: + data.utils.extract_archive(tar_path, dst_dir, overwrite=True) + assert os.path.exists(os.path.join(dst_dir, tar_file)) + + def _test_construct_graphs_node_ids(): from dgl.data.csv_dataset_base import ( DGLGraphConstructor,