Skip to content

Commit

Permalink
Improving the MLP example.
Browse files Browse the repository at this point in the history
  • Loading branch information
drivanov committed Nov 21, 2023
1 parent 34da58d commit 7a350af
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions examples/pytorch/ogb/ogbn-products/mlp/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,23 +99,24 @@ def train(

preds = torch.zeros(labels.shape[0], n_classes)

for _input_nodes, output_nodes, subgraphs in dataloader:
subgraphs = [b.to(device) for b in subgraphs]
new_train_idx = list(range(len(output_nodes)))
with dataloader.enable_cpu_affinity():
for _input_nodes, output_nodes, subgraphs in dataloader:
subgraphs = [b.to(device) for b in subgraphs]
new_train_idx = list(range(len(output_nodes)))

pred = model(subgraphs[0].srcdata["feat"])
preds[output_nodes] = pred.cpu().detach()
pred = model(subgraphs[0].srcdata["feat"])
preds[output_nodes] = pred.cpu().detach()

loss = criterion(
pred[new_train_idx], labels[output_nodes][new_train_idx]
)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss = criterion(
pred[new_train_idx], labels[output_nodes][new_train_idx]
)
optimizer.zero_grad()
loss.backward()
optimizer.step()

count = len(new_train_idx)
loss_sum += loss.item() * count
total += count
count = len(new_train_idx)
loss_sum += loss.item() * count
total += count

preds = preds.to(train_idx.device)
return (
Expand Down Expand Up @@ -143,11 +144,12 @@ def evaluate(
eval_times = 1 # Due to the limitation of memory capacity, we calculate the average of logits 'eval_times' times.

for _ in range(eval_times):
for _input_nodes, output_nodes, subgraphs in dataloader:
subgraphs = [b.to(device) for b in subgraphs]
with dataloader.enable_cpu_affinity():
for _input_nodes, output_nodes, subgraphs in dataloader:
subgraphs = [b.to(device) for b in subgraphs]

pred = model(subgraphs[0].srcdata["feat"])
preds[output_nodes] = pred
pred = model(subgraphs[0].srcdata["feat"])
preds[output_nodes] = pred

preds /= eval_times

Expand Down

0 comments on commit 7a350af

Please sign in to comment.