add caching support.
mfbalin committed Aug 21, 2024
1 parent c739063 commit 26841f3
Showing 1 changed file with 170 additions and 57 deletions.
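In brief, the change wires optional CPU and GPU feature caches into the example and reports their miss rates on the tqdm progress bar during training and evaluation. A possible invocation exercising the new flags (the cache sizes, in GiB, are illustrative):

    python node_classification.py --mode pinned-pinned-cuda \
        --cpu-cache-size 64 --cpu-feature-cache-policy sieve \
        --gpu-cache-size 8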
227 changes: 170 additions & 57 deletions examples/graphbolt/pyg/hetero/node_classification.py
@@ -234,14 +234,28 @@ def evaluate_step(minibatch, model):


@torch.no_grad()
def evaluate(model, dataloader, device):
def evaluate(
model,
dataloader,
gpu_cache_miss_rate_fn,
cpu_cache_miss_rate_fn,
device,
):
model.eval()
total_correct = torch.zeros(1, dtype=torch.float64, device=device)
total_samples = 0
for minibatch in tqdm(dataloader, desc="Evaluating"):
dataloader = tqdm(dataloader, desc="Evaluating")
for step, minibatch in enumerate(dataloader):
num_correct, num_samples = evaluate_step(minibatch, model)
total_correct += num_correct
total_samples += num_samples
if step % 15 == 0:
dataloader.set_postfix(
{
"gpu_cache_miss": gpu_cache_miss_rate_fn(),
"cpu_cache_miss": cpu_cache_miss_rate_fn(),
}
)

return total_correct / total_samples

@@ -266,34 +280,70 @@ def train_step(minibatch, optimizer, model, loss_fn):
return loss.detach(), num_correct, labels.size(0)


def train_helper(dataloader, model, optimizer, loss_fn, device):
def train_helper(
dataloader,
model,
optimizer,
loss_fn,
gpu_cache_miss_rate_fn,
cpu_cache_miss_rate_fn,
device,
):
model.train()
total_loss = torch.zeros(1, device=device)
total_correct = torch.zeros(1, dtype=torch.float64, device=device)
total_samples = 0
start = time.time()
for minibatch in tqdm(dataloader, "Training"):
dataloader = tqdm(dataloader, "Training")
for step, minibatch in enumerate(dataloader):
loss, num_correct, num_samples = train_step(
minibatch, optimizer, model, loss_fn
)
total_loss += loss * num_samples
total_correct += num_correct
total_samples += num_samples
if step % 15 == 0:
# Log only every 15 steps to keep the progress-bar overhead low.
dataloader.set_postfix(
{
"gpu_cache_miss": gpu_cache_miss_rate_fn(),
"cpu_cache_miss": cpu_cache_miss_rate_fn(),
}
)
loss = total_loss / total_samples
acc = total_correct / total_samples
end = time.time()
return loss, acc, end - start


def train(train_dataloader, valid_dataloader, model, device):
def train(
train_dataloader,
valid_dataloader,
model,
gpu_cache_miss_rate_fn,
cpu_cache_miss_rate_fn,
device,
):
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(args.epochs):
train_loss, train_acc, duration = train_helper(
train_dataloader, model, optimizer, loss_fn, device
train_dataloader,
model,
optimizer,
loss_fn,
gpu_cache_miss_rate_fn,
cpu_cache_miss_rate_fn,
device,
)
val_acc = evaluate(
model,
valid_dataloader,
gpu_cache_miss_rate_fn,
cpu_cache_miss_rate_fn,
device,
)
val_acc = evaluate(model, valid_dataloader, device)
print(
f"Epoch: {epoch:02d}, Loss: {train_loss.item():.4f}, "
f"Approx. Train: {train_acc.item():.4f}, "
@@ -302,6 +352,73 @@ def train(train_dataloader, valid_dataloader, model, device):
)


def parse_args():
parser = argparse.ArgumentParser(description="GraphBolt PyG R-SAGE")
parser.add_argument(
"--epochs", type=int, default=10, help="Number of training epochs."
)
parser.add_argument(
"--lr",
type=float,
default=0.003,
help="Learning rate for optimization.",
)
parser.add_argument(
"--batch-size", type=int, default=1024, help="Batch size for training."
)
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument(
"--dataset",
type=str,
default="ogb-lsc-mag240m",
choices=["ogb-lsc-mag240m"],
help="Dataset name. Possible values: ogb-lsc-mag240m",
)
parser.add_argument(
"--fanout",
type=str,
default="10,25",
help="Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)"
" identical with the number of layers in your model. Default: ",
)
parser.add_argument(
"--mode",
default="pinned-pinned-cuda",
choices=[
"cpu-cpu-cpu",
"cpu-cpu-cuda",
"cpu-pinned-cuda",
"pinned-pinned-cuda",
"cuda-pinned-cuda",
"cuda-cuda-cuda",
],
help="Graph storage - feature storage - Train device: 'cpu' for CPU and RAM,"
" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.",
)
parser.add_argument(
"--cpu-feature-cache-policy",
type=str,
default=None,
choices=["s3-fifo", "sieve", "lru", "clock"],
help="The cache policy for the CPU feature cache.",
)
parser.add_argument(
"--cpu-cache-size",
type=float,
default=0,
help="The capacity of the CPU cache in GiB.",
)
parser.add_argument(
"--gpu-cache-size",
type=float,
default=0,
help="The capacity of the GPU cache in GiB.",
)

parser.add_argument("--precision", type=str, default="high")
return parser.parse_args()


def main():
torch.set_float32_matmul_precision(args.precision)
if not torch.cuda.is_available():
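The part of main() folded at this hunk boundary is presumably where --mode is unpacked: the cache setup below reads args.feature_device, and training reads args.device. A hypothetical reading consistent with the flag's help text (the hidden code may differ):

    # Hypothetical unpacking of --mode, e.g. "pinned-pinned-cuda", into
    # graph storage, feature storage, and train device.
    args.graph_device, args.feature_device, args.device = args.mode.split("-")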
@@ -335,6 +452,37 @@ def main():
num_classes = dataset.tasks[0].metadata["num_classes"]
num_etypes = len(graph.num_edges)

feats_on_disk = {
k: features[k]
for k in features.keys()
if k[2] == "feat" and isinstance(features[k], gb.DiskBasedFeature)
}

if args.cpu_cache_size > 0 and len(feats_on_disk) > 0:
cached_features = gb.cpu_cached_feature(
feats_on_disk,
int(args.cpu_cache_size * (2**30)),
args.cpu_feature_cache_policy,
args.feature_device == "pinned",
)
for k, feature in cached_features.items():
features[k] = feature
# The features wrapped by one cpu_cached_feature call share a single
# underlying cache, so reading miss_rate from any of them (here the last
# `feature` left over from the loop) reports the shared cache's miss rate.
cpu_cache_miss_rate_fn = lambda: feature.miss_rate
else:
# No CPU cache configured; report a constant 100% miss rate.
cpu_cache_miss_rate_fn = lambda: 1

if args.gpu_cache_size > 0 and args.feature_device != "cuda":
feats = {k: features[k] for k in features.keys() if k[2] == "feat"}
cached_features = gb.gpu_cached_feature(
feats,
int(args.gpu_cache_size * (2**30)),
)
for k, feature in cached_features.items():
features[k] = feature
# As on the CPU side, the GPU-cached features share one cache; the last
# `feature` from the loop stands in for all of them.
gpu_cache_miss_rate_fn = lambda: feature.miss_rate
else:
# No GPU cache configured; report a constant 100% miss rate.
gpu_cache_miss_rate_fn = lambda: 1

train_dataloader, valid_dataloader, test_dataloader = (
create_dataloader(
graph=graph,
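Pulling the cache setup together: select the wrappable feature stores, wrap them, swap the wrapped versions back into features, and keep a closure that reads the cache's miss rate. A condensed sketch of the CPU side, assuming (as the code above does) that the features wrapped by one cpu_cached_feature call share a single underlying cache:

    import dgl.graphbolt as gb

    def setup_cpu_cache(features, cache_size_gib, policy, pin_memory):
        # Only on-disk "feat" entries benefit from a CPU cache.
        on_disk = {
            k: features[k]
            for k in features.keys()
            if k[2] == "feat" and isinstance(features[k], gb.DiskBasedFeature)
        }
        if cache_size_gib <= 0 or not on_disk:
            return lambda: 1  # no cache: constant 100% miss rate
        cached = gb.cpu_cached_feature(
            on_disk, int(cache_size_gib * 2**30), policy, pin_memory
        )
        probe = None
        for key, feature in cached.items():
            features[key] = feature  # swap the cached wrapper back in
            probe = feature
        # Shared cache: any wrapped feature reports the cache-wide miss rate.
        return lambda: probe.miss_rate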
@@ -365,59 +513,24 @@ def main():
f"{sum(p.numel() for p in model.parameters())}"
)

train(train_dataloader, valid_dataloader, model, args.device)
train(
train_dataloader,
valid_dataloader,
model,
gpu_cache_miss_rate_fn,
cpu_cache_miss_rate_fn,
args.device,
)

print("Testing...")
test_acc = evaluate(model, test_dataloader, args.device)
print(f"Test accuracy {test_acc.item():.4f}")


def parse_args():
parser = argparse.ArgumentParser(description="GraphBolt PyG R-SAGE")
parser.add_argument(
"--epochs", type=int, default=10, help="Number of training epochs."
)
parser.add_argument(
"--lr",
type=float,
default=0.003,
help="Learning rate for optimization.",
)
parser.add_argument(
"--batch-size", type=int, default=1024, help="Batch size for training."
)
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument(
"--dataset",
type=str,
default="ogb-lsc-mag240m",
choices=["ogb-lsc-mag240m"],
help="Dataset name. Possible values: ogb-lsc-mag240m",
)
parser.add_argument(
"--fanout",
type=str,
default="10,25",
help="Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)"
" identical with the number of layers in your model. Default: ",
)
parser.add_argument(
"--mode",
default="pinned-pinned-cuda",
choices=[
"cpu-cpu-cpu",
"cpu-cpu-cuda",
"cpu-pinned-cuda",
"pinned-pinned-cuda",
"cuda-pinned-cuda",
"cuda-cuda-cuda",
],
help="Graph storage - feature storage - Train device: 'cpu' for CPU and RAM,"
" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.",
test_acc = evaluate(
model,
test_dataloader,
gpu_cache_miss_rate_fn,
cpu_cache_miss_rate_fn,
args.device,
)

parser.add_argument("--precision", type=str, default="high")
return parser.parse_args()
print(f"Test accuracy {test_acc.item():.4f}")


if __name__ == "__main__":
