[Dataset] Chameleon #5477

Merged · 10 commits · Mar 23, 2023 · Changes from 8 commits
1 change: 1 addition & 0 deletions docs/source/api/python/dgl.data.rst
@@ -56,6 +56,7 @@ Datasets for node classification/regression tasks
 YelpDataset
 PATTERNDataset
 CLUSTERDataset
+ChameleonDataset

 Edge Prediction Datasets
 ---------------------------------------
1 change: 1 addition & 0 deletions python/dgl/data/__init__.py
@@ -54,6 +54,7 @@
 from .utils import *
 from .cluster import CLUSTERDataset
 from .pattern import PATTERNDataset
+from .wiki_network import ChameleonDataset
 from .wikics import WikiCSDataset
 from .yelp import YelpDataset
 from .zinc import ZINCDataset
18 changes: 15 additions & 3 deletions python/dgl/data/dgl_dataset.py
@@ -6,7 +6,6 @@
 import abc
 import hashlib
 import os
-import sys
 import traceback

 from ..utils import retry_method_with_fix
@@ -221,6 +220,15 @@ def _get_hash(self):
         hash_func.update(str(self._hash_key).encode("utf-8"))
         return hash_func.hexdigest()[:8]

+    def _get_hash_url_suffix(self):
+        """Get the suffix based on the hash value of the url."""
+        if self._url is None:
+            return ""
+        else:
+            hash_func = hashlib.sha1()
+            hash_func.update(str(self._url).encode("utf-8"))
+            return "_" + hash_func.hexdigest()[:8]
+
     @property
     def url(self):
         r"""Get url to download the raw dataset."""
@@ -241,7 +249,9 @@ def raw_path(self):
r"""Directory contains the input data files.
By default raw_path = os.path.join(self.raw_dir, self.name)
"""
return os.path.join(self.raw_dir, self.name)
return os.path.join(
self.raw_dir, self.name + self._get_hash_url_suffix()
)

@property
def save_dir(self):
@@ -251,7 +261,9 @@ def save_dir(self):
     @property
     def save_path(self):
         r"""Path to save the processed dataset."""
-        return os.path.join(self._save_dir, self.name)
+        return os.path.join(
+            self.save_dir, self.name + self._get_hash_url_suffix()
+        )

     @property
     def verbose(self):
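To see what the new suffix buys, here is a minimal standalone sketch mirroring `_get_hash_url_suffix` above (the example URL is hypothetical, and the printed hash depends on it):

```python
import hashlib
import os

def hash_url_suffix(url):
    # Mirrors DGLDataset._get_hash_url_suffix above: no URL gives an empty
    # suffix; otherwise "_" plus the first 8 hex digits of the URL's SHA-1.
    if url is None:
        return ""
    hash_func = hashlib.sha1()
    hash_func.update(str(url).encode("utf-8"))
    return "_" + hash_func.hexdigest()[:8]

# Two datasets with the same name but different URLs now cache under
# different directories, and changing a dataset's URL invalidates the
# old cache automatically.
print(os.path.join("~/.dgl", "chameleon" + hash_url_suffix(
    "https://example.com/chameleon.zip")))   # e.g. ~/.dgl/chameleon_1a2b3c4d
print(os.path.join("~/.dgl", "chameleon" + hash_url_suffix(None)))  # ~/.dgl/chameleon
```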
11 changes: 2 additions & 9 deletions python/dgl/data/qm7b.py
@@ -1,20 +1,13 @@
"""QM7b dataset for graph property prediction (regression)."""
import os

import numpy as np
from scipy import io

from .. import backend as F
from ..convert import graph as dgl_graph

from .dgl_dataset import DGLDataset
from .utils import (
check_sha1,
deprecate_property,
download,
load_graphs,
save_graphs,
)
from .utils import check_sha1, download, load_graphs, save_graphs


 class QM7bDataset(DGLDataset):
@@ -93,7 +86,7 @@ def __init__(
         )

     def process(self):
-        mat_path = self.raw_path + ".mat"
+        mat_path = os.path.join(self.raw_dir, self.name + ".mat")
         self.graphs, self.label = self._load_graph(mat_path)

     def _load_graph(self, filename):
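The motivation, as far as the diff shows: `raw_path` now carries the URL-hash suffix (see the `dgl_dataset.py` change above), so `self.raw_path + ".mat"` would no longer point at the downloaded file, which sits directly under `raw_dir`. A rough sketch of the before/after paths (the hash value is hypothetical):

```python
import os

raw_dir, name = os.path.expanduser("~/.dgl"), "qm7b"

# Before this PR: raw_path == raw_dir/name, so raw_path + ".mat" worked.
# After this PR: raw_path == raw_dir/name + "_<url-hash>", so the old
# expression would resolve to e.g. ~/.dgl/qm7b_1a2b3c4d.mat (nonexistent).
mat_path = os.path.join(raw_dir, name + ".mat")
print(mat_path)  # ~/.dgl/qm7b.mat, where the downloaded file actually lives
```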
182 changes: 182 additions & 0 deletions python/dgl/data/wiki_network.py
@@ -0,0 +1,182 @@
"""
Wikipedia page-page networks on the chameleon topic.
"""
import os

import numpy as np

from ..convert import graph
from .dgl_dataset import DGLBuiltinDataset
from .utils import _get_dgl_url


class WikiNetworkDataset(DGLBuiltinDataset):
r"""Wikipedia page-page networks from `Multi-scale Attributed Node
Embedding <https://arxiv.org/abs/1909.13021>`__

Parameters
----------
name : str
Name of the dataset.
raw_dir : str
Raw file directory to store the processed data.
force_reload : bool
Whether to always generate the data from scratch rather than load a
cached version.
verbose : bool
Whether to print progress information.
transform : callable
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
"""

def __init__(self, name, raw_dir, force_reload, verbose, transform):
url = _get_dgl_url(f"dataset/{name}.zip")
super(WikiNetworkDataset, self).__init__(
name=name,
url=url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)

def process(self):
"""Load and process the data."""
try:
import torch
except ImportError:
raise ModuleNotFoundError(
"This dataset requires PyTorch to be the backend."
)

# Process node features and labels.
with open(f"{self.raw_path}/out1_node_feature_label.txt", "r") as f:
data = f.read().split("\n")[1:-1]
features = [
[float(v) for v in r.split("\t")[1].split(",")] for r in data
]
features = torch.tensor(features, dtype=torch.float)
labels = [int(r.split("\t")[2]) for r in data]
self._num_classes = max(labels) + 1
labels = torch.tensor(labels, dtype=torch.long)

# Process graph structure.
with open(f"{self.raw_path}/out1_graph_edges.txt", "r") as f:
data = f.read().split("\n")[1:-1]
data = [[int(v) for v in r.split("\t")] for r in data]
dst, src = torch.tensor(data, dtype=torch.long).t().contiguous()

self._g = graph((src, dst), num_nodes=features.size(0))
self._g.ndata["feat"] = features
self._g.ndata["label"] = labels

# Process 10 train/val/test node splits.
train_masks, val_masks, test_masks = [], [], []
for i in range(10):
filepath = f"{self.raw_path}/{self.name}_split_0.6_0.2_{i}.npz"
f = np.load(filepath)
train_masks += [torch.from_numpy(f["train_mask"])]
val_masks += [torch.from_numpy(f["val_mask"])]
test_masks += [torch.from_numpy(f["test_mask"])]
self._g.ndata["train_mask"] = torch.stack(train_masks, dim=1).bool()
self._g.ndata["val_mask"] = torch.stack(val_masks, dim=1).bool()
self._g.ndata["test_mask"] = torch.stack(test_masks, dim=1).bool()

def has_cache(self):
return os.path.exists(self.raw_path)

def load(self):
self.process()

def __getitem__(self, idx):
assert idx == 0, "This dataset has only one graph."
if self._transform is None:
return self._g
else:
return self._transform(self._g)

def __len__(self):
return 1

@property
def num_classes(self):
return self._num_classes


class ChameleonDataset(WikiNetworkDataset):
"""Wikipedia page-page network on chameleons from `Multi-scale Attributed
Node Embedding <https://arxiv.org/abs/1909.13021>`__ and later processed by
`Geom-GCN: Geometric Graph Convolutional Networks
Review thread on the docstring above:

Collaborator: Why is Geom-GCN related? Is this dataset designated for Geom-GCN?

@mufeili (Member, Author, Mar 23, 2023): Geom-GCN introduced this variant of the dataset, including turning the task from node regression into node classification, modifying node features, and introducing these dataset splits.

Collaborator: Maybe reword it. Currently the wording sounds like this dataset is just used by Geom-GCN, e.g. "later processed by" --> "introduced by":

    Wikipedia page-page network on chameleons from `Multi-scale Attributed
    Node Embedding <https://arxiv.org/abs/1909.13021>`__, introduced by
    `Geom-GCN: Geometric Graph Convolutional Networks

@mufeili (Member, Author, Mar 23, 2023): did "later processed by" -> "later modified by"

    <https://arxiv.org/abs/2002.05287>`__

    Nodes represent articles from the English Wikipedia, and edges reflect
    mutual links between them. Node features indicate the presence of
    particular nouns in the articles. The nodes were classified into 5
    classes in terms of their average monthly traffic.

    Statistics:

    - Nodes: 2277
    - Edges: 36101
    - Number of Classes: 5
    - 10 splits with 60/20/20 train/val/test ratio

      - Train: 1092
      - Val: 729
      - Test: 456

    Parameters
    ----------
    raw_dir : str, optional
        Raw file directory to store the processed data. Default: ~/.dgl/
    force_reload : bool, optional
        Whether to always generate the data from scratch rather than load a
        cached version. Default: False
    verbose : bool, optional
        Whether to print progress information. Default: True
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access. Default: None

    Attributes
    ----------
    num_classes : int
        Number of node classes

    Notes
    -----
    The graph does not come with edges for both directions.

    Examples
    --------

    >>> from dgl.data import ChameleonDataset
    >>> dataset = ChameleonDataset()
    >>> g = dataset[0]
    >>> num_classes = dataset.num_classes

    >>> # get node features
    >>> feat = g.ndata["feat"]

    >>> # get data split
    >>> train_mask = g.ndata["train_mask"]
    >>> val_mask = g.ndata["val_mask"]
    >>> test_mask = g.ndata["test_mask"]

    >>> # get labels
    >>> label = g.ndata['label']
    """

    def __init__(
        self, raw_dir=None, force_reload=False, verbose=True, transform=None
    ):
        super(ChameleonDataset, self).__init__(
            name="chameleon",
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )
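For reference, `process()` above implies the raw text formats sketched below. This is a fabricated two-node sample run through the same parsing expressions, not the real files from the downloaded zip:

```python
# Fabricated sample in the layout process() expects:
# header line, then "node_id<TAB>comma-separated features<TAB>label".
node_file = "node_id\tfeature\tlabel\n0\t1.0,0.0,2.5\t2\n1\t0.5,1.5,0.0\t0\n"
data = node_file.split("\n")[1:-1]  # drop header and trailing empty string
features = [[float(v) for v in r.split("\t")[1].split(",")] for r in data]
labels = [int(r.split("\t")[2]) for r in data]
print(features)          # [[1.0, 0.0, 2.5], [0.5, 1.5, 0.0]]
print(max(labels) + 1)   # 3 -> inferred number of classes

# Edge file: header line, then one "id1<TAB>id2" pair per line.
edge_file = "id1\tid2\n0\t1\n1\t0\n"
edges = [[int(v) for v in r.split("\t")] for r in edge_file.split("\n")[1:-1]]
print(edges)             # [[0, 1], [1, 0]]

# The 10 split files ({name}_split_0.6_0.2_{i}.npz) each hold boolean
# "train_mask" / "val_mask" / "test_mask" arrays with one entry per node.
```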
23 changes: 23 additions & 0 deletions tests/python/common/data/test_wiki_network.py
@@ -0,0 +1,23 @@
import unittest

import backend as F

import dgl


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="only supports pytorch"
)
def test_chameleon():
    transform = dgl.AddSelfLoop(allow_duplicate=True)

    # chameleon
    g = dgl.data.ChameleonDataset(force_reload=True)[0]
    assert g.num_nodes() == 2277
    assert g.num_edges() == 36101
    g2 = dgl.data.ChameleonDataset(force_reload=True, transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()
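A quick check of the final assertion: `AddSelfLoop(allow_duplicate=True)` appends one self-loop per node, so the transformed graph gains exactly `num_nodes` edges. A minimal standalone sketch (assumes the PyTorch backend):

```python
import dgl
import torch

# A toy 3-node graph with 2 edges: 0 -> 1 and 1 -> 2.
g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])), num_nodes=3)
g2 = dgl.AddSelfLoop(allow_duplicate=True)(g)
assert g2.num_edges() - g.num_edges() == g.num_nodes()  # 5 - 2 == 3
```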