Skip to content

Commit

Permalink
Merge branch 'my-first-pr' of github.com:pyynb/dgl-pyy-dev into my-fi…
Browse files Browse the repository at this point in the history
…rst-pr
  • Loading branch information
Ubuntu committed Apr 3, 2024
2 parents c00356b + 41f83f2 commit 904a12f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
12 changes: 7 additions & 5 deletions notebooks/stochastic_training/link_prediction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
},
"outputs": [],
"source": [
"dataset = gb.BuiltinDataset(\"cora\").load()"
"dataset = gb.BuiltinDataset(\"cora-seeds\").load()"
]
},
{
Expand Down Expand Up @@ -255,15 +255,16 @@
" total_loss = 0\n",
" for step, data in tqdm(enumerate(create_train_dataloader())):\n",
" # Get node pairs with labels for loss calculation.\n",
" compacted_pairs, labels = data.node_pairs_with_labels\n",
" compacted_seeds = data.compacted_seeds.T\n",
" labels = data.labels\n",
" node_feature = data.node_features[\"feat\"]\n",
" # Convert sampled subgraphs to DGL blocks.\n",
" blocks = data.blocks\n",
"\n",
" # Get the embeddings of the input nodes.\n",
" y = model(blocks, node_feature)\n",
" logits = model.predictor(\n",
" y[compacted_pairs[0]] * y[compacted_pairs[1]]\n",
" y[compacted_seeds[0]] * y[compacted_seeds[1]]\n",
" ).squeeze()\n",
"\n",
" # Compute loss.\n",
Expand Down Expand Up @@ -308,15 +309,16 @@
"labels = []\n",
"for step, data in tqdm(enumerate(eval_dataloader)):\n",
" # Get node pairs with labels for loss calculation.\n",
" compacted_pairs, label = data.node_pairs_with_labels\n",
" compacted_seeds = data.compacted_seeds.T\n",
" label = data.labels\n",
"\n",
" # The features of sampled nodes.\n",
" x = data.node_features[\"feat\"]\n",
"\n",
" # Forward.\n",
" y = model(data.blocks, x)\n",
" logit = (\n",
" model.predictor(y[compacted_pairs[0]] * y[compacted_pairs[1]])\n",
" model.predictor(y[compacted_seeds[0]] * y[compacted_seeds[1]])\n",
" .squeeze()\n",
" .detach()\n",
" )\n",
Expand Down
6 changes: 3 additions & 3 deletions notebooks/stochastic_training/node_classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
},
"outputs": [],
"source": [
"dataset = gb.BuiltinDataset(\"ogbn-arxiv\").load()"
"dataset = gb.BuiltinDataset(\"ogbn-arxiv-seeds\").load()"
]
},
{
Expand Down Expand Up @@ -143,7 +143,7 @@
"source": [
"def create_dataloader(itemset, shuffle):\n",
" datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=shuffle)\n",
" datapipe = datapipe.copy_to(device, extra_attrs=[\"seed_nodes\"])\n",
" datapipe = datapipe.copy_to(device, extra_attrs=[\"seeds\"])\n",
" datapipe = datapipe.sample_neighbor(graph, [4, 4])\n",
" datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n",
" return gb.DataLoader(datapipe)"
Expand Down Expand Up @@ -375,4 +375,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}

0 comments on commit 904a12f

Please sign in to comment.