diff --git a/examples/pytorch/ogb/ogbn-products/mlp/mlp.py b/examples/pytorch/ogb/ogbn-products/mlp/mlp.py
index 244ed78cd5f8..4c1fa8228ded 100755
--- a/examples/pytorch/ogb/ogbn-products/mlp/mlp.py
+++ b/examples/pytorch/ogb/ogbn-products/mlp/mlp.py
@@ -99,23 +99,24 @@ def train(
     preds = torch.zeros(labels.shape[0], n_classes)
 
-    for _input_nodes, output_nodes, subgraphs in dataloader:
-        subgraphs = [b.to(device) for b in subgraphs]
-        new_train_idx = list(range(len(output_nodes)))
+    with dataloader.enable_cpu_affinity():
+        for _input_nodes, output_nodes, subgraphs in dataloader:
+            subgraphs = [b.to(device) for b in subgraphs]
+            new_train_idx = list(range(len(output_nodes)))
 
-        pred = model(subgraphs[0].srcdata["feat"])
-        preds[output_nodes] = pred.cpu().detach()
+            pred = model(subgraphs[0].srcdata["feat"])
+            preds[output_nodes] = pred.cpu().detach()
 
-        loss = criterion(
-            pred[new_train_idx], labels[output_nodes][new_train_idx]
-        )
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
+            loss = criterion(
+                pred[new_train_idx], labels[output_nodes][new_train_idx]
+            )
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
 
-        count = len(new_train_idx)
-        loss_sum += loss.item() * count
-        total += count
+            count = len(new_train_idx)
+            loss_sum += loss.item() * count
+            total += count
 
     preds = preds.to(train_idx.device)
 
     return (
@@ -143,11 +144,12 @@ def evaluate(
     eval_times = 1  # Due to the limitation of memory capacity, we calculate the average of logits 'eval_times' times.
 
     for _ in range(eval_times):
-        for _input_nodes, output_nodes, subgraphs in dataloader:
-            subgraphs = [b.to(device) for b in subgraphs]
+        with dataloader.enable_cpu_affinity():
+            for _input_nodes, output_nodes, subgraphs in dataloader:
+                subgraphs = [b.to(device) for b in subgraphs]
 
-            pred = model(subgraphs[0].srcdata["feat"])
-            preds[output_nodes] = pred
+                pred = model(subgraphs[0].srcdata["feat"])
+                preds[output_nodes] = pred
 
     preds /= eval_times
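
For context, below is a minimal, self-contained sketch of the pattern this patch introduces: wrapping dataloader iteration in DGL's `enable_cpu_affinity()` context manager, which pins dataloader worker processes to dedicated CPU cores for the duration of the `with` block. The toy graph, feature size, batch size, and worker count here are illustrative and not taken from the example; `enable_cpu_affinity()` itself is DGL's API, and it requires `num_workers > 0` on the dataloader.

```python
import torch
import dgl
from dgl.dataloading import DataLoader, MultiLayerFullNeighborSampler

# Illustrative toy graph with random node features (not the ogbn-products data).
g = dgl.rand_graph(1000, 5000)
g.ndata["feat"] = torch.randn(g.num_nodes(), 16)
train_nids = torch.arange(100)

sampler = MultiLayerFullNeighborSampler(1)
dataloader = DataLoader(
    g,
    train_nids,
    sampler,
    batch_size=32,
    shuffle=True,
    num_workers=4,  # enable_cpu_affinity() needs worker processes to pin
)

# Worker-core pinning is active only inside the `with` block, mirroring
# how the patch wraps the train/evaluate loops.
with dataloader.enable_cpu_affinity():
    for input_nodes, output_nodes, blocks in dataloader:
        feat = blocks[0].srcdata["feat"]  # per-batch input features
```

Pinning workers this way can reduce sampling jitter from OS-level thread migration on multi-core CPU hosts, which is why the patch applies it to both the training and evaluation loops rather than changing any model logic.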