Merge branch 'master' into async_gpu_sampling
nv-dlasalle committed May 16, 2023
2 parents 1bbf3ea + 29df6ec commit dcdefbc
Showing 44 changed files with 1,775 additions and 142 deletions.
41 changes: 41 additions & 0 deletions CMakeLists.txt
@@ -35,6 +35,7 @@ dgl_option(USE_EPOLL "Build with epoll for socket communicator" ON)
dgl_option(TP_BUILD_LIBUV "Build libuv together with tensorpipe (only impacts Linux)" ON)
dgl_option(BUILD_TORCH "Build the PyTorch plugin" OFF)
dgl_option(BUILD_SPARSE "Build DGL sparse library" ON)
dgl_option(BUILD_GRAPHBOLT "Build Graphbolt library" OFF)
dgl_option(TORCH_PYTHON_INTERPS "Python interpreter used to build tensoradapter and DGL sparse library" python3)

# Set debug compile option for gdb, only happens when -DCMAKE_BUILD_TYPE=DEBUG
@@ -103,6 +104,11 @@ if(USE_OPENMP)
message(STATUS "Build with OpenMP.")
endif(USE_OPENMP)

if(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)")
message(STATUS "Disabling LIBXSMM on ${CMAKE_SYSTEM_PROCESSOR}.")
set(USE_LIBXSMM OFF)
endif()

if(USE_LIBXSMM)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DUSE_LIBXSMM -DDGL_CPU_LLC_SIZE=40000000")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_LIBXSMM -DDGL_CPU_LLC_SIZE=40000000")
@@ -386,3 +392,38 @@ if(BUILD_SPARSE)
endif(MSVC)
add_dependencies(dgl_sparse dgl)
endif(BUILD_SPARSE)

if(BUILD_GRAPHBOLT)
message(STATUS "Configuring graphbolt library")
file(TO_NATIVE_PATH ${CMAKE_CURRENT_BINARY_DIR} BINDIR)
file(TO_NATIVE_PATH ${CMAKE_COMMAND} CMAKE_CMD)
if(MSVC)
file(TO_NATIVE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/graphbolt/build.bat BUILD_SCRIPT)
add_custom_target(
graphbolt
ALL
${CMAKE_COMMAND} -E env
CMAKE_COMMAND=${CMAKE_CMD}
BINDIR=${BINDIR}
CFLAGS=${CMAKE_C_FLAGS}
CXXFLAGS=${CMAKE_CXX_FLAGS}
LDFLAGS=${CMAKE_SHARED_LINKER_FLAGS}
cmd /e:on /c ${BUILD_SCRIPT} ${TORCH_PYTHON_INTERPS}
DEPENDS ${BUILD_SCRIPT}
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/graphbolt)
else(MSVC)
file(TO_NATIVE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/graphbolt/build.sh BUILD_SCRIPT)
add_custom_target(
graphbolt
ALL
${CMAKE_COMMAND} -E env
CMAKE_COMMAND=${CMAKE_CMD}
BINDIR=${CMAKE_CURRENT_BINARY_DIR}
CFLAGS=${CMAKE_C_FLAGS}
CXXFLAGS=${CMAKE_CXX_FLAGS}
LDFLAGS=${CMAKE_SHARED_LINKER_FLAGS}
bash ${BUILD_SCRIPT} ${TORCH_PYTHON_INTERPS}
DEPENDS ${BUILD_SCRIPT}
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/graphbolt)
endif(MSVC)
endif(BUILD_GRAPHBOLT)
8 changes: 7 additions & 1 deletion benchmarks/benchmarks/api/bench_format_conversion.py
@@ -9,7 +9,9 @@


@utils.benchmark("time", timeout=600)
@utils.parametrize_cpu("graph_name", ["cora", "livejournal", "friendster"])
@utils.parametrize_cpu(
"graph_name", ["cora", "pubmed", "ogbn-arxiv", "livejournal", "friendster"]
)
@utils.parametrize_gpu("graph_name", ["cora", "livejournal"])
@utils.parametrize(
"format",
@@ -27,6 +29,10 @@ def track_time(graph_name, format):
device = utils.get_bench_device()
graph = utils.get_graph(graph_name, from_format)
graph = graph.to(device)
if format == ("coo", "csr") and graph_name == "friendster":
        # Mark the graph as sorted to measure conversion performance for a COO
        # matrix flagged as sorted; the friendster dataset is already sorted.
graph = dgl.graph(graph.edges(), row_sorted=True)
graph = graph.formats([from_format])
# dry run
graph.formats([to_format])
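
For context, a minimal sketch of the `row_sorted` flag exercised above, on a toy graph rather than the benchmark's friendster dataset: flagging a pre-sorted COO edge list lets DGL skip the sort when materializing CSR.

```python
import torch
import dgl

# Toy COO edge list, already sorted by source ("row") node.
src = torch.tensor([0, 0, 1, 2])
dst = torch.tensor([1, 2, 2, 0])

# row_sorted=True asserts the input is pre-sorted by row, so the
# COO->CSR conversion can skip its sorting step.
g = dgl.graph((src, dst), row_sorted=True)
g = g.formats(["coo"])  # restrict to COO, mirroring the benchmark setup
g.formats(["csr"])      # request CSR; this conversion is the timed step
```
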
2 changes: 1 addition & 1 deletion docs/source/_templates/classtemplate.rst
@@ -7,4 +7,4 @@

.. autoclass:: {{ name }}
:show-inheritance:
:members: __getitem__, __len__, collate_fn, forward, reset_parameters, rel_emb, rel_project, explain_node, explain_graph
:members: __getitem__, __len__, collate_fn, forward, reset_parameters, rel_emb, rel_project, explain_node, explain_graph, train_step
1 change: 1 addition & 0 deletions docs/source/api/python/nn-pytorch.rst
@@ -129,6 +129,7 @@ Utility Modules
~dgl.nn.pytorch.explain.HeteroGNNExplainer
~dgl.nn.pytorch.explain.SubgraphX
~dgl.nn.pytorch.explain.HeteroSubgraphX
~dgl.nn.pytorch.explain.PGExplainer
~dgl.nn.pytorch.utils.LabelPropagation
~dgl.nn.pytorch.graph_transformer.DegreeEncoder
~dgl.nn.pytorch.utils.LaplacianPosEnc
24 changes: 24 additions & 0 deletions examples/core/gat/README.md
@@ -0,0 +1,24 @@
Graph Attention Networks (GAT)
============

- Paper link: [https://arxiv.org/abs/1710.10903](https://arxiv.org/abs/1710.10903)
- Author's code repo (TensorFlow implementation):
  [https://github.com/PetarV-/GAT](https://github.com/PetarV-/GAT).
- Popular PyTorch implementation:
  [https://github.com/Diego999/pyGAT](https://github.com/Diego999/pyGAT).

How to run
-------

Run the following command for multiclass node classification (available datasets: "cora", "citeseer", "pubmed"):
```bash
python3 train.py --dataset cora
```

> **_NOTE:_** Users may occasionally run into a low-accuracy issue (e.g., test accuracy < 0.8) due to overfitting. This can be resolved by adding early stopping (see the sketch below) or by reducing the maximum number of training epochs.
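
A minimal early-stopping sketch along those lines, written as a variant of the `train()` loop in `train.py`; it assumes that script's imports and reuses its `evaluate` helper, and the `patience` threshold is an illustrative choice, not a knob the script exposes:

```python
def train_with_early_stopping(
    g, features, labels, masks, model, max_epochs=200, patience=20
):
    # patience: number of epochs without validation improvement to tolerate
    train_mask, val_mask = masks[0], masks[1]
    loss_fcn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=5e-4)
    best_acc, wait = 0.0, 0
    for epoch in range(max_epochs):
        model.train()
        logits = model(g, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = evaluate(g, features, labels, val_mask, model)  # from train.py
        if acc > best_acc:
            best_acc, wait = acc, 0
        else:
            wait += 1
            if wait >= patience:
                break  # validation accuracy has plateaued; stop early
    return best_acc
```
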
Summary
-------
* cora: ~0.821
* citeseer: ~0.710
* pubmed: ~0.780
140 changes: 140 additions & 0 deletions examples/core/gat/train.py
@@ -0,0 +1,140 @@
import argparse
import time

import dgl.nn as dglnn

import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import AddSelfLoop
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset


class GAT(nn.Module):
def __init__(self, in_size, hid_size, out_size, heads):
super().__init__()
self.gat_layers = nn.ModuleList()
# two-layer GAT
self.gat_layers.append(
dglnn.GATConv(
in_size,
hid_size,
heads[0],
feat_drop=0.6,
attn_drop=0.6,
activation=F.elu,
)
)
self.gat_layers.append(
dglnn.GATConv(
hid_size * heads[0],
out_size,
heads[1],
feat_drop=0.6,
attn_drop=0.6,
activation=None,
)
)

def forward(self, g, inputs):
h = inputs
for i, layer in enumerate(self.gat_layers):
h = layer(g, h)
            if i == len(self.gat_layers) - 1:  # last layer: average the heads
                h = h.mean(1)
            else:  # hidden layers: concatenate the heads
                h = h.flatten(1)
return h


def evaluate(g, features, labels, mask, model):
model.eval()
with torch.no_grad():
logits = model(g, features)
logits = logits[mask]
labels = labels[mask]
_, indices = torch.max(logits, dim=1)
correct = torch.sum(indices == labels)
return correct.item() * 1.0 / len(labels)


def train(g, features, labels, masks, model, num_epochs):
# Define train/val samples, loss function and optimizer
train_mask = masks[0]
val_mask = masks[1]
loss_fcn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=5e-4)

for epoch in range(num_epochs):
t0 = time.time()
model.train()
logits = model(g, features)
loss = loss_fcn(logits[train_mask], labels[train_mask])
optimizer.zero_grad()
loss.backward()
optimizer.step()
acc = evaluate(g, features, labels, val_mask, model)
t1 = time.time()
print(
"Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} | Time {:.4f}".format(
epoch, loss.item(), acc, t1 - t0
)
)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--dataset",
type=str,
default="cora",
help="Dataset name ('cora', 'citeseer', 'pubmed').",
)
parser.add_argument(
"--num_epochs",
type=int,
default=200,
help="Number of epochs for train.",
)
parser.add_argument(
"--num_gpus",
type=int,
default=0,
help="Number of GPUs used for train and evaluation.",
)
args = parser.parse_args()
print(f"Training with DGL built-in GATConv module.")

# Load and preprocess dataset
transform = (
AddSelfLoop()
    )  # by default, AddSelfLoop first removes existing self-loops to prevent duplicates
if args.dataset == "cora":
data = CoraGraphDataset(transform=transform)
elif args.dataset == "citeseer":
data = CiteseerGraphDataset(transform=transform)
elif args.dataset == "pubmed":
data = PubmedGraphDataset(transform=transform)
else:
raise ValueError("Unknown dataset: {}".format(args.dataset))
g = data[0]
if args.num_gpus > 0 and torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
g = g.int().to(device)
features = g.ndata["feat"]
labels = g.ndata["label"]
masks = g.ndata["train_mask"], g.ndata["val_mask"], g.ndata["test_mask"]

# Create GAT model
in_size = features.shape[1]
out_size = data.num_classes
model = GAT(in_size, 8, out_size, heads=[8, 1]).to(device)

print("Training...")
train(g, features, labels, masks, model, args.num_epochs)

print("Testing...")
acc = evaluate(g, features, labels, masks[2], model)
print("Test accuracy {:.4f}".format(acc))
23 changes: 23 additions & 0 deletions examples/core/gated_gcn/README.md
@@ -0,0 +1,23 @@
Gated Graph ConvNet (GatedGCN)
==============================

* Paper link: [https://arxiv.org/abs/2003.00982](https://arxiv.org/abs/2003.00982)

## Dataset

Task: Graph Property Prediction

| Dataset | #Graphs | #Node Feats | #Edge Feats | Metric |
| :---------: | :-----: | :---------: | :---------: | :-----: |
| ogbg-molhiv | 41,127 | 9 | 3 | ROC-AUC |

How to run
----------

```bash
python train.py
```

## Summary

* ogbg-molhiv: ~0.781
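
For reference, a minimal sketch of loading ogbg-molhiv with the `ogb` package's DGL loader (assuming `ogb` is installed; this is illustrative, not necessarily how `train.py` wires it up):

```python
from ogb.graphproppred import DglGraphPropPredDataset, collate_dgl
from torch.utils.data import DataLoader

dataset = DglGraphPropPredDataset(name="ogbg-molhiv")
split_idx = dataset.get_idx_split()  # "train" / "valid" / "test" index tensors

train_loader = DataLoader(
    dataset[split_idx["train"]],
    batch_size=32,
    shuffle=True,
    collate_fn=collate_dgl,  # batches DGL graphs and stacks labels
)

batched_graph, labels = next(iter(train_loader))
print(batched_graph.ndata["feat"].shape)  # 9-dimensional node features
print(batched_graph.edata["feat"].shape)  # 3-dimensional edge features
```
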