diff --git a/examples/graphbolt/disk_based_feature/node_classification_offline.py b/examples/graphbolt/disk_based_feature/node_classification_offline.py
new file mode 100644
index 000000000000..26338f8a576b
--- /dev/null
+++ b/examples/graphbolt/disk_based_feature/node_classification_offline.py
@@ -0,0 +1,644 @@
+"""
+This example references examples/graphbolt/pyg/labor/node_classification.py
+"""
+
+import argparse
+
+import os
+import time
+
+from copy import deepcopy
+
+import dgl.graphbolt as gb
+import dgl.nn as dglnn
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from tqdm import tqdm
+
+
+def accuracy(out, labels):
+    assert out.ndim == 2
+    assert out.size(0) == labels.size(0)
+    assert labels.ndim == 1 or (labels.ndim == 2 and labels.size(1) == 1)
+    labels = labels.flatten()
+    predictions = torch.argmax(out, 1)
+    return (labels == predictions).sum(dtype=torch.float64) / labels.size(0)
+
+
+class SAGE(nn.Module):
+    def __init__(self, in_size, hidden_size, out_size, num_layers, dropout):
+        super().__init__()
+        self.layers = nn.ModuleList()
+        # Three-layer GraphSAGE-mean.
+        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
+        for _ in range(num_layers - 2):
+            self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
+        self.layers.append(dglnn.SAGEConv(hidden_size, out_size, "mean"))
+        self.dropout = nn.Dropout(dropout)
+        self.hidden_size = hidden_size
+        self.out_size = out_size
+        # Set the dtype for the layers manually.
+        self.set_layer_dtype(torch.float32)
+
+    def set_layer_dtype(self, _dtype):
+        for layer in self.layers:
+            for param in layer.parameters():
+                param.data = param.data.to(_dtype)
+
+    def forward(self, blocks, x):
+        hidden_x = x
+        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
+            hidden_x = layer(block, hidden_x)
+            is_last_layer = layer_idx == len(self.layers) - 1
+            if not is_last_layer:
+                hidden_x = F.relu(hidden_x)
+                hidden_x = self.dropout(hidden_x)
+        return hidden_x
+
+    def inference(self, graph, features, dataloader, storage_device):
+        """Conduct layer-wise inference to get all the node embeddings."""
+        pin_memory = storage_device == "pinned"
+        buffer_device = torch.device("cpu" if pin_memory else storage_device)
+
+        for layer_idx, layer in enumerate(self.layers):
+            is_last_layer = layer_idx == len(self.layers) - 1
+
+            y = torch.empty(
+                graph.total_num_nodes,
+                self.out_size if is_last_layer else self.hidden_size,
+                dtype=torch.float32,
+                device=buffer_device,
+                pin_memory=pin_memory,
+            )
+            for data in tqdm(dataloader):
+                # len(blocks) = 1
+                hidden_x = layer(data.blocks[0], data.node_features["feat"])
+                if not is_last_layer:
+                    hidden_x = F.relu(hidden_x)
+                    hidden_x = self.dropout(hidden_x)
+                # By design, our output nodes are contiguous.
+                y[data.seeds[0] : data.seeds[-1] + 1] = hidden_x.to(
+                    buffer_device
+                )
+            if not is_last_layer:
+                features.update("node", None, "feat", y)
+
+        return y
+
+
+def create_sampler(graph, itemset, batch_size, fanout, device, job):
+    # Initialize an ItemSampler to sample mini-batches from the dataset.
+    datapipe = gb.ItemSampler(
+        itemset,
+        batch_size=batch_size,
+        shuffle=(job == "train"),
+        drop_last=False,
+    )
+    # Copy the data to the specified device.
+    if args.graph_device != "cpu":
+        datapipe = datapipe.copy_to(device=device)
+    # Sample neighbors for each node in the mini-batch.
+    kwargs = (
+        {
+            # Layer dependency makes it so that the sampled neighborhoods across layers
+            # become correlated, reducing the total number of sampled unique nodes in a
+            # minibatch, thus reducing the amount of feature data requested.
+            "layer_dependency": args.layer_dependency,
+            # Batch dependency makes it so that the sampled neighborhoods across minibatches
+            # become correlated, reducing the total number of sampled unique nodes across
+            # minibatches, thus increasing temporal locality and reducing cache miss rates.
+            "batch_dependency": args.batch_dependency,
+        }
+        if args.sample_mode == "sample_layer_neighbor"
+        else {}
+    )
+    datapipe = getattr(datapipe, args.sample_mode)(
+        graph, fanout if job != "infer" else [-1], **kwargs
+    )
+    return gb.DataLoader(
+        datapipe,
+        num_workers=args.num_workers,
+        overlap_graph_fetch=args.overlap_graph_fetch,
+    )
+
+
+def create_trainloader(features, itemset, device, job):
+    assert job == "train"
+    # Initialize an ItemSampler to iterate over the indices of the
+    # pre-sampled mini-batches, one index at a time.
+    datapipe = gb.ItemSampler(
+        itemset,
+        batch_size=1,
+        shuffle=(job == "train"),
+        drop_last=False,
+    )
+    # Load the pre-sampled subgraph for this index from disk.
+    datapipe = datapipe.load_minibatch(subgraph_dir=args.subgraph_dir)
+    # Copy the data to the specified device.
+    if args.feature_device != "cpu":
+        datapipe = datapipe.copy_to(device=device)
+    # Fetch node features for the sampled subgraph.
+    datapipe = datapipe.fetch_feature(
+        features,
+        node_feature_keys=["feat"],
+        overlap_fetch=args.overlap_feature_fetch,
+    )
+    # Copy the data to the specified device.
+    if args.feature_device == "cpu":
+        datapipe = datapipe.copy_to(device=device)
+    # Create and return a DataLoader to handle data loading.
+    return gb.DataLoader(
+        datapipe,
+        num_workers=args.num_workers,
+        overlap_graph_fetch=args.overlap_graph_fetch,
+    )
+
+
+def create_dataloader(
+    graph, features, itemset, batch_size, fanout, device, job
+):
+
+    # Initialize an ItemSampler to sample mini-batches from the dataset.
+    datapipe = gb.ItemSampler(
+        itemset,
+        batch_size=batch_size,
+        shuffle=(job == "train"),
+        drop_last=False,
+    )
+    # Copy the data to the specified device.
+    if args.graph_device != "cpu":
+        datapipe = datapipe.copy_to(device=device)
+    # Sample neighbors for each node in the mini-batch.
+    kwargs = (
+        {
+            # Layer dependency makes it so that the sampled neighborhoods across layers
+            # become correlated, reducing the total number of sampled unique nodes in a
+            # minibatch, thus reducing the amount of feature data requested.
+            "layer_dependency": args.layer_dependency,
+            # Batch dependency makes it so that the sampled neighborhoods across minibatches
+            # become correlated, reducing the total number of sampled unique nodes across
+            # minibatches, thus increasing temporal locality and reducing cache miss rates.
+            "batch_dependency": args.batch_dependency,
+        }
+        if args.sample_mode == "sample_layer_neighbor"
+        else {}
+    )
+    datapipe = getattr(datapipe, args.sample_mode)(
+        graph, fanout if job != "infer" else [-1], **kwargs
+    )
+    # Copy the data to the specified device.
+    if args.feature_device != "cpu":
+        datapipe = datapipe.copy_to(device=device)
+    # Fetch node features for the sampled subgraph.
+    datapipe = datapipe.fetch_feature(
+        features,
+        node_feature_keys=["feat"],
+        overlap_fetch=args.overlap_feature_fetch,
+    )
+    # Copy the data to the specified device.
+    if args.feature_device == "cpu":
+        datapipe = datapipe.copy_to(device=device)
+    # Create and return a DataLoader to handle data loading.
+    return gb.DataLoader(
+        datapipe,
+        num_workers=args.num_workers,
+        overlap_graph_fetch=args.overlap_graph_fetch,
+    )
+
+
+def train_step(minibatch, optimizer, model, loss_fn):
+    node_features = minibatch.node_features["feat"]
+    labels = minibatch.labels
+    optimizer.zero_grad()
+    out = model(minibatch.blocks, node_features)
+    loss = loss_fn(out, labels)
+    num_correct = accuracy(out, labels) * labels.size(0)
+    loss.backward()
+    optimizer.step()
+    return loss.detach(), num_correct, labels.size(0)
+
+
+def train_helper(
+    dataloader,
+    model,
+    optimizer,
+    loss_fn,
+    gpu_cache_miss_rate_fn,
+    cpu_cache_miss_rate_fn,
+    device,
+):
+    model.train()  # Set the model to training mode
+    total_loss = torch.zeros(1, device=device)  # Accumulator for the total loss
+    # Accumulator for the total number of correct predictions
+    total_correct = torch.zeros(1, dtype=torch.float64, device=device)
+    total_samples = 0  # Accumulator for the total number of samples processed
+    num_batches = 0  # Counter for the number of mini-batches processed
+    start = time.time()
+    dataloader = tqdm(dataloader, "Training")
+    for step, minibatch in enumerate(dataloader):
+        loss, num_correct, num_samples = train_step(
+            minibatch, optimizer, model, loss_fn
+        )
+        total_loss += loss
+        total_correct += num_correct
+        total_samples += num_samples
+        num_batches += 1
+        if step % 25 == 0:
+            # log every 25 steps for performance.
+            dataloader.set_postfix(
+                {
+                    "num_nodes": minibatch.node_ids().size(0),
+                    "gpu_cache_miss": gpu_cache_miss_rate_fn(),
+                    "cpu_cache_miss": cpu_cache_miss_rate_fn(),
+                }
+            )
+    train_loss = total_loss / num_batches
+    train_acc = total_correct / total_samples
+    end = time.time()
+    return train_loss, train_acc, end - start
+
+
+def train(
+    train_dataloader,
+    valid_dataloader,
+    model,
+    gpu_cache_miss_rate_fn,
+    cpu_cache_miss_rate_fn,
+    device,
+):
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+    loss_fn = nn.CrossEntropyLoss()
+
+    best_model = None
+    best_model_acc = 0
+    best_model_epoch = -1
+
+    for epoch in range(args.epochs):
+        train_loss, train_acc, duration = train_helper(
+            train_dataloader,
+            model,
+            optimizer,
+            loss_fn,
+            gpu_cache_miss_rate_fn,
+            cpu_cache_miss_rate_fn,
+            device,
+        )
+        val_acc = evaluate(
+            model,
+            valid_dataloader,
+            gpu_cache_miss_rate_fn,
+            cpu_cache_miss_rate_fn,
+            device,
+        )
+        if val_acc > best_model_acc:
+            best_model_acc = val_acc
+            best_model = deepcopy(model.state_dict())
+            best_model_epoch = epoch
+        print(
+            f"Epoch {epoch:02d}, Loss: {train_loss.item():.4f}, "
+            f"Approx. Train: {train_acc.item():.4f}, "
+            f"Approx. Val: {val_acc.item():.4f}, "
+            f"Time: {duration}s"
+        )
+        if best_model_epoch + args.early_stopping_patience < epoch:
+            break
+    return best_model
+
+
+@torch.no_grad()
+def layerwise_infer(
+    args,
+    graph,
+    features,
+    itemsets,
+    all_nodes_set,
+    model,
+):
+    model.eval()
+    dataloader = create_dataloader(
+        graph=graph,
+        features=features,
+        itemset=all_nodes_set,
+        batch_size=args.batch_size,
+        fanout=[-1],
+        device=args.device,
+        job="infer",
+    )
+    pred = model.inference(graph, features, dataloader, args.feature_device)
+
+    metrics = {}
+    for split_name, itemset in itemsets.items():
+        nid, labels = itemset[:]
+        acc = accuracy(
+            pred[nid.to(pred.device)],
+            labels.to(pred.device),
+        )
+        metrics[split_name] = acc.item()
+
+    return metrics
+
+
+def evaluate_step(minibatch, model):
+    node_features = minibatch.node_features["feat"]
+    labels = minibatch.labels
+    out = model(minibatch.blocks, node_features)
+    num_correct = accuracy(out, labels) * labels.size(0)
+    return num_correct, labels.size(0)
+
+
+@torch.no_grad()
+def evaluate(
+    model,
+    dataloader,
+    gpu_cache_miss_rate_fn,
+    cpu_cache_miss_rate_fn,
+    device,
+):
+    model.eval()
+    total_correct = torch.zeros(1, dtype=torch.float64, device=device)
+    total_samples = 0
+    val_dataloader_tqdm = tqdm(dataloader, "Evaluating")
+    for step, minibatch in enumerate(val_dataloader_tqdm):
+        num_correct, num_samples = evaluate_step(minibatch, model)
+        total_correct += num_correct
+        total_samples += num_samples
+        if step % 25 == 0:
+            val_dataloader_tqdm.set_postfix(
+                {
+                    "num_nodes": minibatch.node_ids().size(0),
+                    "gpu_cache_miss": gpu_cache_miss_rate_fn(),
+                    "cpu_cache_miss": cpu_cache_miss_rate_fn(),
+                }
+            )
+
+    return total_correct / total_samples
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Which dataset are you going to use?"
+    )
+    parser.add_argument(
+        "--epochs", type=int, default=9999999, help="Number of training epochs."
+    )
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=0.001,
+        help="Learning rate for optimization.",
+    )
+    parser.add_argument("--num-hidden", type=int, default=256)
+    parser.add_argument("--dropout", type=float, default=0.2)
+    parser.add_argument(
+        "--batch-size", type=int, default=1024, help="Batch size for training."
+    )
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=0,
+        help="Number of workers for data loading.",
+    )
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="ogbn-products",
+        choices=[
+            "ogbn-arxiv",
+            "ogbn-products",
+            "ogbn-papers100M",
+            "reddit",
+            "yelp",
+            "flickr",
+        ],
+    )
+    parser.add_argument("--root", type=str, default="datasets")
+    parser.add_argument(
+        "--fanout",
+        type=str,
+        default="10,10,10",
+        help="Fan-out of neighbor sampling. len(fanout) determines the number of"
+        " GNN layers in your model. Default: 10,10,10",
+    )
+    parser.add_argument(
+        "--mode",
+        default="pinned-pinned-cuda",
+        choices=[
+            "cpu-cpu-cpu",
+            "cpu-cpu-cuda",
+            "cpu-pinned-cuda",
+            "pinned-pinned-cuda",
+            "cuda-pinned-cuda",
+            "cuda-cuda-cuda",
+        ],
+        help="Graph storage - feature storage - Train device: 'cpu' for CPU and"
+        " RAM, 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.",
+    )
+    parser.add_argument("--layer-dependency", action="store_true")
+    parser.add_argument("--batch-dependency", type=int, default=1)
+    parser.add_argument(
+        "--cpu-feature-cache-policy",
+        type=str,
+        default=None,
+        choices=["s3-fifo", "sieve", "lru", "clock"],
+        help="The cache policy for the CPU feature cache.",
+    )
+    parser.add_argument(
+        "--cpu-cache-size-in-gigabytes",
+        type=float,
+        default=0,
+        help="The capacity of the CPU feature cache in gigabytes.",
+    )
+    parser.add_argument(
+        "--gpu-cache-size-in-gigabytes",
+        type=float,
+        default=0,
+        help="The capacity of the GPU feature cache in gigabytes.",
+    )
+    parser.add_argument("--early-stopping-patience", type=int, default=25)
+    parser.add_argument(
+        "--sample-mode",
+        default="sample_neighbor",
+        choices=["sample_neighbor", "sample_layer_neighbor"],
+        help="The sampling function when doing layerwise sampling.",
+    )
+    parser.add_argument("--precision", type=str, default="high")
+    parser.add_argument("--enable-inference", action="store_true")
+    return parser.parse_args()
+
+
+def main():
+    start = time.time()
+    torch.set_float32_matmul_precision(args.precision)
+    if not torch.cuda.is_available():
+        args.mode = "cpu-cpu-cpu"
+    print(f"Training in {args.mode} mode.")
+    args.graph_device, args.feature_device, args.device = args.mode.split("-")
+    args.overlap_feature_fetch = args.feature_device == "pinned"
+    # For now, only sample_layer_neighbor is faster with this option
+    args.overlap_graph_fetch = (
+        args.sample_mode == "sample_layer_neighbor"
+        and args.graph_device == "pinned"
+    )
+
+    """
+    Load and preprocess the on-disk dataset.
+    We inspect the in_memory field of the feature_data in the YAML file and set
+    it to False. This makes sure the feature data is loaded as a DiskBasedFeature.
+    """
+    print("Loading data...")
+    disk_based_feature_keys = None
+    if args.cpu_cache_size_in_gigabytes > 0:
+        disk_based_feature_keys = [("node", None, "feat")]
+
+    dataset = gb.BuiltinDataset(args.dataset, root=args.root)
+    if disk_based_feature_keys is None:
+        disk_based_feature_keys = set()
+    for feature in dataset.yaml_data["feature_data"]:
+        feature_key = (feature["domain"], feature["type"], feature["name"])
+        # Set the in_memory setting to False without modifying the YAML file.
+        if feature_key in disk_based_feature_keys:
+            feature["in_memory"] = False
+    dataset = dataset.load()
+
+    # Move the dataset to the selected storage.
+    graph = (
+        dataset.graph.pin_memory_()
+        if args.graph_device == "pinned"
+        else dataset.graph.to(args.graph_device)
+    )
+    features = (
+        dataset.feature.pin_memory_()
+        if args.feature_device == "pinned"
+        else dataset.feature.to(args.feature_device)
+    )
+
+    train_set = dataset.tasks[0].train_set
+    valid_set = dataset.tasks[0].validation_set
+    test_set = dataset.tasks[0].test_set
+    all_nodes_set = dataset.all_nodes_set
+    args.fanout = list(map(int, args.fanout.split(",")))
+    num_classes = dataset.tasks[0].metadata["num_classes"]
+
+    """
+    If the CPU cache size is greater than 0, we wrap the DiskBasedFeature in
+    a CPUCachedFeature. This internally manages the CPU feature cache with the
+    specified cache replacement policy. This will reduce the amount of data
+    transferred during disk read operations for this feature.
+
+    Note: It is advised to set the CPU cache size to be at least 4 times the
+    number of sampled nodes in a mini-batch; otherwise the feature fetcher
+    might get into a deadlock, causing a hang.
+    """
+    if args.cpu_cache_size_in_gigabytes > 0 and isinstance(
+        features[("node", None, "feat")], gb.DiskBasedFeature
+    ):
+        features[("node", None, "feat")] = gb.CPUCachedFeature(
+            features[("node", None, "feat")],
+            int(args.cpu_cache_size_in_gigabytes * 1000 * 1000 * 1000),
+            args.cpu_feature_cache_policy,
+            args.feature_device == "pinned",
+        )
+        cpu_cached_feature = features[("node", None, "feat")]
+        cpu_cache_miss_rate_fn = lambda: cpu_cached_feature._feature.miss_rate
+    else:
+        cpu_cache_miss_rate_fn = lambda: 1
+
+    """
+    If the GPU cache size is greater than 0, we wrap the underlying feature
+    store in a GPUCachedFeature. This will reduce the amount of data transferred
+    during host-to-device copy operations for this feature.
+    """
+    if args.gpu_cache_size_in_gigabytes > 0 and args.feature_device != "cuda":
+        features[("node", None, "feat")] = gb.GPUCachedFeature(
+            features[("node", None, "feat")],
+            int(args.gpu_cache_size_in_gigabytes * 1000 * 1000 * 1000),
+        )
+        gpu_cached_feature = features[("node", None, "feat")]
+        gpu_cache_miss_rate_fn = lambda: gpu_cached_feature._feature.miss_rate
+    else:
+        gpu_cache_miss_rate_fn = lambda: 1
+
+    num_minibatch = (len(train_set) + args.batch_size - 1) // args.batch_size
+    train_sampler = create_sampler(
+        graph=graph,
+        itemset=train_set,
+        batch_size=args.batch_size,
+        fanout=args.fanout,
+        device=args.device,
+        job="train",
+    )
+    train_dataloader = create_trainloader(
+        features=features,
+        itemset=gb.ItemSet(num_minibatch, "seeds"),
+        device=args.device,
+        job="train",
+    )
+    valid_dataloader = create_dataloader(
+        graph=graph,
+        features=features,
+        itemset=valid_set,
+        batch_size=args.batch_size,
+        fanout=args.fanout,
+        device=args.device,
+        job="evaluate",
+    )
+
+    in_channels = features.size("node", None, "feat")[0]
+    model = SAGE(
+        in_channels,
+        args.num_hidden,
+        num_classes,
+        len(args.fanout),
+        args.dropout,
+    ).to(args.device)
+    assert len(args.fanout) == len(model.layers)
+
+    prepare_time = time.time() - start
+    print(f"Prepare time: {prepare_time:.2f}s")
+
+    start = time.time()
+    if not os.path.exists(args.subgraph_dir):
+        os.makedirs(args.subgraph_dir)
+    # Offline sampling: run the sampler once and cache every mini-batch to disk.
+    for it, minibatch in enumerate(tqdm(train_sampler, "Sampling")):
+        data = [
+            minibatch.seeds.cpu(),
+            minibatch.input_nodes.cpu(),
+            minibatch.labels.cpu(),
+            [block.cpu() for block in minibatch.blocks],
+        ]
+        torch.save(data, f"{args.subgraph_dir}/train-{it}.pt")
+    print(f"Sampling time: {time.time() - start:.2f}s")
+
+    best_model = train(
+        train_dataloader,
+        valid_dataloader,
+        model,
+        gpu_cache_miss_rate_fn,
+        cpu_cache_miss_rate_fn,
+        args.device,
+    )
+    model.load_state_dict(best_model)
+
+    if args.enable_inference:
+        # Test the model.
+        print("Testing...")
+        itemsets = {"train": train_set, "val": valid_set, "test": test_set}
+        final_acc = layerwise_infer(
+            args,
+            graph,
+            features,
+            itemsets,
+            all_nodes_set,
+            model,
+        )
+        print("Final accuracy values:")
+        print(final_acc)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    args.subgraph_dir = os.path.join(
+        args.root,
+        f"{args.dataset}-{args.batch_size}-{args.fanout}",
+    )
+    print(args)
+    main()
diff --git a/python/dgl/graphbolt/__init__.py b/python/dgl/graphbolt/__init__.py
index 2c980dd532da..ccdadaaccf5b 100644
--- a/python/dgl/graphbolt/__init__.py
+++ b/python/dgl/graphbolt/__init__.py
@@ -89,6 +89,7 @@ def load_graphbolt():
 from .negative_sampler import *
 from .sampled_subgraph import *
 from .subgraph_sampler import *
+from .minibatch_provider import *
 from .external_utils import add_reverse_edges, exclude_seed_edges
 from .internal import (
     compact_csc_format,
diff --git a/python/dgl/graphbolt/minibatch.py b/python/dgl/graphbolt/minibatch.py
index 7b34a8d5f1d3..1fef89416b21 100644
--- a/python/dgl/graphbolt/minibatch.py
+++ b/python/dgl/graphbolt/minibatch.py
@@ -173,10 +173,7 @@ def blocks(self) -> list:
         """DGL blocks extracted from `MiniBatch` containing graphical structures
         and ID mappings.
         """
-        if not self.sampled_subgraphs:
-            return None
-
-        if self._blocks is None:
+        if self._blocks is None and self.sampled_subgraphs:
             self._blocks = self.compute_blocks()
         return self._blocks
 
diff --git a/python/dgl/graphbolt/minibatch_provider.py b/python/dgl/graphbolt/minibatch_provider.py
new file mode 100644
index 000000000000..3e533c646a75
--- /dev/null
+++ b/python/dgl/graphbolt/minibatch_provider.py
@@ -0,0 +1,41 @@
+"""Minibatch provider that loads pre-sampled mini-batches from disk."""
+
+import torch
+from torch.utils.data import functional_datapipe
+
+from .minibatch_transformer import MiniBatchTransformer
+
+__all__ = [
+    "MinibatchLoader",
+]
+
+
+@functional_datapipe("load_minibatch")
+class MinibatchLoader(MiniBatchTransformer):
+    """Load pre-sampled mini-batches that were cached on disk.
+
+    Each incoming mini-batch carries the index of a cached subgraph; the
+    corresponding ``train-{index}.pt`` file is read from ``subgraph_dir`` and
+    its contents (seeds, input nodes, labels and blocks) are placed back onto
+    the mini-batch.
+    """
+
+    def __init__(self, datapipe, subgraph_dir: str):
+        self._subgraph_dir = subgraph_dir
+        datapipe = datapipe.transform(self._load_minibatch)
+        super().__init__(datapipe)
+
+    def _load_minibatch(self, minibatch):
+        # Wait for any in-flight GPU work before blocking on disk I/O.
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        nid = minibatch.seeds.item()
+        seeds, input_nodes, labels, blocks = torch.load(
+            f"{self._subgraph_dir}/train-{nid}.pt"
+        )
+        minibatch.seeds = seeds
+        minibatch.input_nodes = input_nodes
+        minibatch.labels = labels
+        minibatch._blocks = blocks
+        return minibatch
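
A minimal sketch of how the pieces above are meant to fit together (names such as num_cached_minibatches, features, and the cached_minibatches directory are placeholders, not part of the patch): because MinibatchLoader is registered via @functional_datapipe("load_minibatch"), it can be chained onto a GraphBolt datapipe by name, which is exactly what create_trainloader does once the offline sampling pass has written the train-{i}.pt files to disk.

    import dgl.graphbolt as gb

    # One item per cached mini-batch; each item's value is the file index.
    indices = gb.ItemSet(num_cached_minibatches, names="seeds")
    datapipe = gb.ItemSampler(indices, batch_size=1, shuffle=True)
    # Restore seeds, input_nodes, labels and blocks from train-{index}.pt.
    datapipe = datapipe.load_minibatch(subgraph_dir="cached_minibatches")
    # Feature fetching and device copies then proceed as with online sampling.
    datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
    dataloader = gb.DataLoader(datapipe)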