Skip to content

Commit

Permalink
Improving HARD_GAT example.
Browse files Browse the repository at this point in the history
  • Loading branch information
drivanov committed Jul 24, 2023
1 parent 766a73b commit 07e0355
Showing 1 changed file with 23 additions and 24 deletions.
47 changes: 23 additions & 24 deletions examples/pytorch/hardgat/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def main(args):
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]
num_feats = features.shape[1]
n_classes = data.num_labels
n_classes = data.num_classes
n_edges = g.num_edges()
print(
"""----Data statistics------'
Expand Down Expand Up @@ -115,7 +115,7 @@ def main(args):
)

# initialize graph

Check warning on line 117 in examples/pytorch/hardgat/train.py

View workflow job for this annotation

GitHub Actions / lintrunner

UFMT format

Run `lintrunner -a` to apply this patch.
dur = []
mean = 0
for epoch in range(args.epochs):
model.train()
if epoch >= 3:
Expand All @@ -129,29 +129,28 @@ def main(args):
optimizer.step()

if epoch >= 3:
dur.append(time.time() - t0)

train_acc = accuracy(logits[train_mask], labels[train_mask])

if args.fastmode:
val_acc = accuracy(logits[val_mask], labels[val_mask])
else:
val_acc = evaluate(model, features, labels, val_mask)
if args.early_stop:
if stopper.step(val_acc, model):
break

print(
"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |"
" ValAcc {:.4f} | ETputs(KTEPS) {:.2f}".format(
epoch,
np.mean(dur),
loss.item(),
train_acc,
val_acc,
n_edges / np.mean(dur) / 1000,
mean = (mean * (epoch - 3) + (time.time() - t0)) / (epoch - 2)
train_acc = accuracy(logits[train_mask], labels[train_mask])

if args.fastmode:
val_acc = accuracy(logits[val_mask], labels[val_mask])
else:
val_acc = evaluate(model, features, labels, val_mask)
if args.early_stop:
if stopper.step(val_acc, model):
break

print(
"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |"
" ValAcc {:.4f} | ETputs(KTEPS) {:.2f}".format(
epoch,
mean,
loss.item(),
train_acc,
val_acc,
n_edges / mean / 1000,
)
)
)

print()
if args.early_stop:
Expand Down

0 comments on commit 07e0355

Please sign in to comment.