Skip to content

Commit

Permalink
Merge branch 'master' into exclude_seed_edges
Browse files Browse the repository at this point in the history
  • Loading branch information
drivanov committed Jul 31, 2024
2 parents 34db402 + 65f85b5 commit 578b497
Show file tree
Hide file tree
Showing 46 changed files with 2,152 additions and 835 deletions.
217 changes: 138 additions & 79 deletions examples/graphbolt/pyg/labor/node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,34 @@
import time

from copy import deepcopy
from functools import partial

import dgl.graphbolt as gb
import torch

# Needed until https://github.com/pytorch/pytorch/issues/121197 is resolved to
# use the `--torch-compile` cmdline option reliably.
# For torch.compile until https://github.com/pytorch/pytorch/issues/121197 is
# resolved.
import torch._inductor.codecache

torch._dynamo.config.cache_size_limit = 32

import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
from load_dataset import load_dataset
from sage_conv import SAGEConv
from torch.torch_version import TorchVersion
from sage_conv import SAGEConv as CustomSAGEConv
from torch_geometric.nn import SAGEConv
from tqdm import tqdm


def accuracy(out, labels):
assert out.ndim == 2
assert out.size(0) == labels.size(0)
assert labels.ndim == 1 or (labels.ndim == 2 and labels.size(1) == 1)
labels = labels.flatten()
predictions = torch.argmax(out, 1)
return (labels == predictions).sum(dtype=torch.float64) / labels.size(0)


def convert_to_pyg(h, subgraph):
#####################################################################
# (HIGHLIGHT) Convert given features to be consumed by a PyG layer.
Expand All @@ -38,25 +50,39 @@ def convert_to_pyg(h, subgraph):


class GraphSAGE(torch.nn.Module):
def __init__(self, in_size, hidden_size, out_size, n_layers, dropout):
def __init__(
self, in_size, hidden_size, out_size, n_layers, dropout, variant
):
super().__init__()
assert variant in ["original", "custom"]
self.layers = torch.nn.ModuleList()
sizes = [in_size] + [hidden_size] * n_layers
for i in range(n_layers):
self.layers.append(SAGEConv(sizes[i], sizes[i + 1]))
self.linear = nn.Linear(hidden_size, out_size)
if variant == "custom":
sizes = [in_size] + [hidden_size] * n_layers
for i in range(n_layers):
self.layers.append(CustomSAGEConv(sizes[i], sizes[i + 1]))
self.linear = nn.Linear(hidden_size, out_size)
self.activation = nn.GELU()
else:
sizes = [in_size] + [hidden_size] * (n_layers - 1) + [out_size]
for i in range(n_layers):
self.layers.append(SAGEConv(sizes[i], sizes[i + 1]))
self.activation = nn.ReLU()
self.dropout = nn.Dropout(dropout)
self.hidden_size = hidden_size
self.out_size = out_size
self.variant = variant

def forward(self, subgraphs, x):
h = x
for layer, subgraph in zip(self.layers, subgraphs):
for i, (layer, subgraph) in enumerate(zip(self.layers, subgraphs)):
h, edge_index, size = convert_to_pyg(h, subgraph)
h = layer(h, edge_index, size=size)
h = F.gelu(h)
h = self.dropout(h)
return self.linear(h)
if self.variant == "custom":
h = self.activation(h)
h = self.dropout(h)
elif i != len(subgraphs) - 1:
h = self.activation(h)
return self.linear(h) if self.variant == "custom" else h

def inference(self, graph, features, dataloader, storage_device):
"""Conduct layer-wise inference to get all the node embeddings."""
Expand All @@ -79,9 +105,12 @@ def inference(self, graph, features, dataloader, storage_device):
data.node_features["feat"], data.sampled_subgraphs[0]
)
hidden_x = layer(h, edge_index, size=size)
hidden_x = F.gelu(hidden_x)
if is_last_layer:
hidden_x = self.linear(hidden_x)
if self.variant == "custom":
hidden_x = self.activation(hidden_x)
if is_last_layer:
hidden_x = self.linear(hidden_x)
elif not is_last_layer:
hidden_x = self.activation(hidden_x)
# By design, our output nodes are contiguous.
y[data.seeds[0] : data.seeds[-1] + 1] = hidden_x.to(
buffer_device
Expand Down Expand Up @@ -138,44 +167,56 @@ def create_dataloader(
)


@torch.compile
def train_step(minibatch, optimizer, model, loss_fn, multilabel, eval_fn):
node_features = minibatch.node_features["feat"]
labels = minibatch.labels
optimizer.zero_grad()
out = model(minibatch.sampled_subgraphs, node_features)
label_dtype = out.dtype if multilabel else None
loss = loss_fn(out, labels.to(label_dtype))
num_correct = eval_fn(out, labels) * labels.size(0)
loss.backward()
optimizer.step()
return loss.detach(), num_correct, labels.size(0)


def train_helper(
dataloader,
model,
optimizer,
loss_fn,
multilabel,
kwargs,
eval_fn,
gpu_cache_miss_rate_fn,
cpu_cache_miss_rate_fn,
device,
):
model.train() # Set the model to training mode
total_loss = torch.zeros(1, device=device) # Accumulator for the total loss
total_correct = 0 # Accumulator for the total number of correct predictions
# Accumulator for the total number of correct predictions
total_correct = torch.zeros(1, dtype=torch.float64, device=device)
total_samples = 0 # Accumulator for the total number of samples processed
num_batches = 0 # Counter for the number of mini-batches processed
start = time.time()
dataloader = tqdm(dataloader, "Training")
for minibatch in dataloader:
node_features = minibatch.node_features["feat"]
labels = minibatch.labels
optimizer.zero_grad()
out = model(minibatch.sampled_subgraphs, node_features)
label_dtype = out.dtype if multilabel else None
loss = loss_fn(out, labels.to(label_dtype))
total_loss += loss.detach()
total_correct += MF.f1_score(out, labels, **kwargs) * labels.size(0)
total_samples += labels.size(0)
loss.backward()
optimizer.step()
num_batches += 1
dataloader.set_postfix(
{
"num_nodes": node_features.size(0),
"gpu_cache_miss": gpu_cache_miss_rate_fn(),
"cpu_cache_miss": cpu_cache_miss_rate_fn(),
}
for step, minibatch in enumerate(dataloader):
loss, num_correct, num_samples = train_step(
minibatch, optimizer, model, loss_fn, multilabel, eval_fn
)
total_loss += loss
total_correct += num_correct
total_samples += num_samples
num_batches += 1
if step % 25 == 0:
# log every 25 steps for performance.
dataloader.set_postfix(
{
"num_nodes": minibatch.node_ids().size(0),
"gpu_cache_miss": gpu_cache_miss_rate_fn(),
"cpu_cache_miss": cpu_cache_miss_rate_fn(),
}
)
train_loss = total_loss / num_batches
train_acc = total_correct / total_samples
end = time.time()
Expand All @@ -187,7 +228,7 @@ def train(
valid_dataloader,
model,
multilabel,
kwargs,
eval_fn,
gpu_cache_miss_rate_fn,
cpu_cache_miss_rate_fn,
device,
Expand All @@ -206,25 +247,27 @@ def train(
optimizer,
loss_fn,
multilabel,
kwargs,
eval_fn,
gpu_cache_miss_rate_fn,
cpu_cache_miss_rate_fn,
device,
)
val_acc = evaluate(
model,
valid_dataloader,
kwargs,
eval_fn,
gpu_cache_miss_rate_fn,
cpu_cache_miss_rate_fn,
device,
)
if val_acc > best_model_acc:
best_model_acc = val_acc
best_model = deepcopy(model.state_dict())
best_model_epoch = epoch
print(
f"Epoch {epoch:02d}, Loss: {train_loss.item():.4f}, "
f"Approx. Train: {train_acc:.4f}, Approx. Val: {val_acc:.4f}, "
f"Approx. Train: {train_acc.item():.4f}, "
f"Approx. Val: {val_acc.item():.4f}, "
f"Time: {duration}s"
)
if best_model_epoch + args.early_stopping_patience < epoch:
Expand All @@ -240,7 +283,7 @@ def layerwise_infer(
itemsets,
all_nodes_set,
model,
kwargs,
eval_fn,
):
model.eval()
dataloader = create_dataloader(
Expand All @@ -257,39 +300,51 @@ def layerwise_infer(
metrics = {}
for split_name, itemset in itemsets.items():
nid, labels = itemset[:]
acc = MF.f1_score(
acc = eval_fn(
pred[nid.to(pred.device)],
labels.to(pred.device),
**kwargs,
)
metrics[split_name] = acc.item()

return metrics


@torch.compile
def evaluate_step(minibatch, model, eval_fn):
node_features = minibatch.node_features["feat"]
labels = minibatch.labels
out = model(minibatch.sampled_subgraphs, node_features)
num_correct = eval_fn(out, labels) * labels.size(0)
return num_correct, labels.size(0)


@torch.no_grad()
def evaluate(
model, dataloader, kwargs, gpu_cache_miss_rate_fn, cpu_cache_miss_rate_fn
model,
dataloader,
eval_fn,
gpu_cache_miss_rate_fn,
cpu_cache_miss_rate_fn,
device,
):
model.eval()
y_hats = []
ys = []
total_correct = torch.zeros(1, dtype=torch.float64, device=device)
total_samples = 0
val_dataloader_tqdm = tqdm(dataloader, "Evaluating")
for minibatch in val_dataloader_tqdm:
node_features = minibatch.node_features["feat"]
labels = minibatch.labels
out = model(minibatch.sampled_subgraphs, node_features)
y_hats.append(out)
ys.append(labels)
val_dataloader_tqdm.set_postfix(
{
"num_nodes": node_features.size(0),
"gpu_cache_miss": gpu_cache_miss_rate_fn(),
"cpu_cache_miss": cpu_cache_miss_rate_fn(),
}
)
for step, minibatch in enumerate(val_dataloader_tqdm):
num_correct, num_samples = evaluate_step(minibatch, model, eval_fn)
total_correct += num_correct
total_samples += num_samples
if step % 25 == 0:
val_dataloader_tqdm.set_postfix(
{
"num_nodes": minibatch.node_ids().size(0),
"gpu_cache_miss": gpu_cache_miss_rate_fn(),
"cpu_cache_miss": cpu_cache_miss_rate_fn(),
}
)

return MF.f1_score(torch.cat(y_hats), torch.cat(ys), **kwargs)
return total_correct / total_samples


def parse_args():
Expand Down Expand Up @@ -347,8 +402,8 @@ def parse_args():
"cuda-pinned-cuda",
"cuda-cuda-cuda",
],
help="Graph storage - feature storage - Train device: 'cpu' for CPU and RAM,"
" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.",
help="Graph storage - feature storage - Train device: 'cpu' for CPU and"
" RAM, 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.",
)
parser.add_argument("--layer-dependency", action="store_true")
parser.add_argument("--batch-dependency", type=int, default=1)
Expand Down Expand Up @@ -379,11 +434,11 @@ def parse_args():
help="The sampling function when doing layerwise sampling.",
)
parser.add_argument(
"--disable-torch-compile",
action="store_true",
default=TorchVersion(torch.__version__) < TorchVersion("2.2.0a0"),
help="Disables torch.compile() on the trained GNN model because it is "
"enabled by default for torch>=2.2.0 without this option.",
"--sage-model-variant",
default="custom",
choices=["custom", "original"],
help="The custom SAGE GNN model provides higher accuracy with lower"
" runtime performance.",
)
parser.add_argument("--precision", type=str, default="high")
return parser.parse_args()
Expand Down Expand Up @@ -480,24 +535,28 @@ def main():
num_classes,
len(args.fanout),
args.dropout,
args.sage_model_variant,
).to(args.device)
assert len(args.fanout) == len(model.layers)
if not args.disable_torch_compile:
torch._dynamo.config.cache_size_limit = 32
model = torch.compile(model, fullgraph=True, dynamic=True)

kwargs = {
"num_labels" if multilabel else "num_classes": num_classes,
"task": "multilabel" if multilabel else "multiclass",
"validate_args": False,
}
eval_fn = (
partial(
# TODO @mfbalin: Find an implementation that does not synchronize.
MF.f1_score,
task="multilabel",
num_labels=num_classes,
validate_args=False,
)
if multilabel
else accuracy
)

best_model = train(
train_dataloader,
valid_dataloader,
model,
multilabel,
kwargs,
eval_fn,
gpu_cache_miss_rate_fn,
cpu_cache_miss_rate_fn,
args.device,
Expand All @@ -514,7 +573,7 @@ def main():
itemsets,
all_nodes_set,
model,
kwargs,
eval_fn,
)
print("Final accuracy values:")
print(final_acc)
Expand Down
Loading

0 comments on commit 578b497

Please sign in to comment.