Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Data] Synthetic dataset for graph classificaiton #364

Merged
merged 5 commits into from
Jan 25, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion docs/source/api/python/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,10 @@ Stanford sentiment treebank dataset
For more information about the dataset, see `Sentiment Analysis <https://nlp.stanford.edu/sentiment/index.html>`__.

.. autoclass:: SST
:members: __getitem__, __len__
:members: __getitem__, __len__

Mini graph classification dataset
`````````````````````````````````

.. autoclass:: MiniGC
:members: __getitem__, __len__, num_classes
1 change: 1 addition & 0 deletions python/dgl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from . import citation_graph as citegrh
from .citation_graph import CoraBinary
from .minigc import *
from .tree import *
from .utils import *
from .sbm import SBMMixture
Expand Down
134 changes: 134 additions & 0 deletions python/dgl/data/minigc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""A mini synthetic dataset for graph classification benchmark."""

from collections.abc import Sequence
import math
import networkx as nx
import numpy as np

from ..graph import DGLGraph

__all__ = ['MiniGCDataset']

class MiniGCDataset(object):
"""The dataset class.

The datset contains 8 different types of graphs.
- class 0 : cycle graph
- class 1 : star graph
- class 2 : wheel graph
- class 3 : lollipop graph
- class 4 : hypercube graph
- class 5 : grid graph
- class 6 : clique graph
- class 7 : circular ladder graph
"""
def __init__(self, num_graphs, min_num_v, max_num_v):
"""
Parameters
----------
num_graphs: int
Number of graphs in this dataset.
min_num_v: int
Minimum number of nodes for graphs
max_num_v: int
Maximum number of nodes for graphs
"""
super(MiniGCDataset, self).__init__()
self.num_graphs = num_graphs
self.min_num_v = min_num_v
self.max_num_v = max_num_v
self.graphs = []
self.labels = []
self._generate()

def __len__(self):
return len(self.graphs)

def __getitem__(self, idx):
return self.graphs[idx], self.labels[idx]

@property
def num_classes(self):
"""Number of classes."""
return 8

def _generate(self):
self._gen_cycle(self.num_graphs // 8)
self._gen_star(self.num_graphs // 8)
self._gen_wheel(self.num_graphs // 8)
self._gen_lollipop(self.num_graphs // 8)
self._gen_hypercube(self.num_graphs // 8)
self._gen_grid(self.num_graphs // 8)
self._gen_clique(self.num_graphs // 8)
self._gen_circular_ladder(self.num_graphs - len(self.graphs))
# preprocess
for i in range(self.num_graphs):
self.graphs[i] = DGLGraph(self.graphs[i])
# add self edges
nodes = self.graphs[i].nodes()
self.graphs[i].add_edges(nodes, nodes)

def _gen_cycle(self, n):
for _ in range(n):
num_v = np.random.randint(self.min_num_v, self.max_num_v)
g = nx.cycle_graph(num_v)
self.graphs.append(g)
self.labels.append(0)

def _gen_star(self, n):
for _ in range(n):
num_v = np.random.randint(self.min_num_v, self.max_num_v)
# nx.star_graph(N) gives a star graph with N+1 nodes
g = nx.star_graph(num_v - 1)
self.graphs.append(g)
self.labels.append(1)

def _gen_wheel(self, n):
for _ in range(n):
num_v = np.random.randint(self.min_num_v, self.max_num_v)
g = nx.wheel_graph(num_v)
self.graphs.append(g)
self.labels.append(2)

def _gen_lollipop(self, n):
for _ in range(n):
num_v = np.random.randint(self.min_num_v, self.max_num_v)
path_len = np.random.randint(2, num_v // 2)
g = nx.lollipop_graph(m=num_v - path_len, n=path_len)
self.graphs.append(g)
self.labels.append(3)

def _gen_hypercube(self, n):
for _ in range(n):
num_v = np.random.randint(self.min_num_v, self.max_num_v)
g = nx.hypercube_graph(int(math.log(num_v, 2)))
g = nx.convert_node_labels_to_integers(g)
self.graphs.append(g)
self.labels.append(4)

def _gen_grid(self, n):
for _ in range(n):
num_v = np.random.randint(self.min_num_v, self.max_num_v)
assert num_v >= 4, 'We require a grid graph to contain at least two ' \
'rows and two columns, thus 4 nodes, got {:d} ' \
'nodes'.format(num_v)
n_rows = np.random.randint(2, num_v // 2)
n_cols = num_v // n_rows
g = nx.grid_graph([n_rows, n_cols])
g = nx.convert_node_labels_to_integers(g)
self.graphs.append(g)
self.labels.append(5)

def _gen_clique(self, n):
for _ in range(n):
num_v = np.random.randint(self.min_num_v, self.max_num_v)
g = nx.complete_graph(num_v)
self.graphs.append(g)
self.labels.append(6)

def _gen_circular_ladder(self, n):
for _ in range(n):
num_v = np.random.randint(self.min_num_v, self.max_num_v)
g = nx.circular_ladder_graph(num_v)
self.graphs.append(g)
self.labels.append(7)
9 changes: 9 additions & 0 deletions tests/compute/test_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import dgl.data as data

def test_minigc():
ds = data.MiniGCDataset(16, 10, 20)
g, l = list(zip(*ds))
print(g, l)

if __name__ == '__main__':
test_minigc()