diff --git a/docs/source/guide/minibatch-custom-sampler.rst b/docs/source/guide/minibatch-custom-sampler.rst
index 5ca5464ebfa9..80473122fbaa 100644
--- a/docs/source/guide/minibatch-custom-sampler.rst
+++ b/docs/source/guide/minibatch-custom-sampler.rst
@@ -79,11 +79,11 @@ can be used on heterogeneous graphs:
         {
             "user": gb.ItemSet(
                 (torch.arange(0, 5), torch.arange(5, 10)),
-                names=("seed_nodes", "labels"),
+                names=("seeds", "labels"),
             ),
             "item": gb.ItemSet(
                 (torch.arange(5, 10), torch.arange(10, 15)),
-                names=("seed_nodes", "labels"),
+                names=("seeds", "labels"),
             ),
         }
     )
diff --git a/docs/source/guide/minibatch-edge.rst b/docs/source/guide/minibatch-edge.rst
index ee7cd85c676b..ae1ad9f49b90 100644
--- a/docs/source/guide/minibatch-edge.rst
+++ b/docs/source/guide/minibatch-edge.rst
@@ -30,9 +30,9 @@ edges (namely, node pairs) in the training set instead of the nodes.
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     g = gb.SamplingGraph()
-    node_paris = torch.arange(0, 1000).reshape(-1, 2)
+    seeds = torch.arange(0, 1000).reshape(-1, 2)
     labels = torch.randint(0, 2, (5,))
-    train_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
+    train_set = gb.ItemSet((seeds, labels), names=("seeds", "labels"))
     datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)
     datapipe = datapipe.sample_neighbor(g, [10, 10])  # 2 layers.
     # Or equivalently:
@@ -83,9 +83,9 @@ You can use :func:`~dgl.graphbolt.exclude_seed_edges` alongside
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     g = gb.SamplingGraph()
-    node_paris = torch.arange(0, 1000).reshape(-1, 2)
+    seeds = torch.arange(0, 1000).reshape(-1, 2)
     labels = torch.randint(0, 2, (5,))
-    train_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
+    train_set = gb.ItemSet((seeds, labels), names=("seeds", "labels"))
     datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)
     datapipe = datapipe.sample_neighbor(g, [10, 10])  # 2 layers.
     exclude_seed_edges = partial(gb.exclude_seed_edges, include_reverse_edges=True)
@@ -138,9 +138,9 @@ concatenating the incident node features and projecting it with a dense layer.
             super().__init__()
             self.W = nn.Linear(2 * in_features, num_classes)

-        def forward(self, node_pairs, x):
-            src_x = x[node_pairs[0]]
-            dst_x = x[node_pairs[1]]
+        def forward(self, seeds, x):
+            src_x = x[seeds[:, 0]]
+            dst_x = x[seeds[:, 1]]
             data = torch.cat([src_x, dst_x], 1)
             return self.W(data)
@@ -157,9 +157,9 @@ loader, as well as the input node features as follows:
                 in_features, hidden_features, out_features)
             self.predictor = ScorePredictor(num_classes, out_features)

-        def forward(self, blocks, x, node_pairs):
+        def forward(self, blocks, x, seeds):
             x = self.gcn(blocks, x)
-            return self.predictor(node_pairs, x)
+            return self.predictor(seeds, x)

 DGL ensures that the nodes in the edge subgraph are the same as the
 output nodes of the last MFG in the generated list of MFGs.
@@ -182,7 +182,7 @@ their incident node representations.
     for data in dataloader:
         blocks = data.blocks
         x = data.edge_features("feat")
-        y_hat = model(data.blocks, x, data.positive_node_pairs)
+        y_hat = model(data.blocks, x, data.compacted_seeds)
         loss = F.cross_entropy(data.labels, y_hat)
         opt.zero_grad()
         loss.backward()
@@ -226,10 +226,10 @@ over the edge types.
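To make the field rename above concrete before the heterogeneous case below, here is a minimal sketch of the new `seeds` convention for edge-level tasks (the tensor sizes are illustrative, not taken from the docs):

```python
import torch
import dgl.graphbolt as gb

# Edge seeds are now a single (N, 2) tensor named "seeds" rather than a
# tuple of source/destination tensors named "node_pairs".
seeds = torch.arange(0, 1000).reshape(-1, 2)
labels = torch.randint(0, 2, (500,))
train_set = gb.ItemSet((seeds, labels), names=("seeds", "labels"))

# Column 0 holds sources and column 1 holds destinations, which is why the
# predictor above indexes with seeds[:, 0] and seeds[:, 1].
src, dst = seeds[:, 0], seeds[:, 1]
```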
            super().__init__()
            self.W = nn.Linear(2 * in_features, num_classes)

-        def forward(self, node_pairs, x):
+        def forward(self, seeds, x):
             scores = {}
-            for etype in node_pairs.keys():
-                src, dst = node_pairs[etype]
+            for etype in seeds.keys():
+                src, dst = seeds[etype].T
                 data = torch.cat([x[etype][src], x[etype][dst]], 1)
                 scores[etype] = self.W(data)
             return scores
@@ -242,9 +242,9 @@ over the edge types.
                 in_features, hidden_features, out_features, etypes)
             self.pred = ScorePredictor(num_classes, out_features)

-        def forward(self, node_pairs, blocks, x):
+        def forward(self, seeds, blocks, x):
             x = self.rgcn(blocks, x)
-            return self.pred(node_pairs, x)
+            return self.pred(seeds, x)

 Data loader definition is almost identical to that of a homogeneous graph. The
 only difference is that the train_set is now an instance of
@@ -256,17 +256,17 @@ only difference is that the train_set is now an instance of
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     g = gb.SamplingGraph()
-    node_pairs = torch.arange(0, 1000).reshape(-1, 2)
+    seeds = torch.arange(0, 1000).reshape(-1, 2)
     labels = torch.randint(0, 3, (1000,))
-    node_pairs_labels = {
+    seeds_labels = {
         "user:like:item": gb.ItemSet(
-            (node_pairs, labels), names=("node_pairs", "labels")
+            (seeds, labels), names=("seeds", "labels")
         ),
         "user:follow:user": gb.ItemSet(
-            (node_pairs, labels), names=("node_pairs", "labels")
+            (seeds, labels), names=("seeds", "labels")
         ),
     }
-    train_set = gb.ItemSetDict(node_pairs_labels)
+    train_set = gb.ItemSetDict(seeds_labels)
     datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)
     datapipe = datapipe.sample_neighbor(g, [10, 10])  # 2 layers.
     datapipe = datapipe.fetch_feature(
@@ -316,7 +316,7 @@ dictionaries of node types and predictions here.
     for data in dataloader:
         blocks = data.blocks
         x = data.edge_features(("user:like:item", "feat"))
-        y_hat = model(data.blocks, x, data.positive_node_pairs)
+        y_hat = model(data.blocks, x, data.compacted_seeds)
         loss = F.cross_entropy(data.labels, y_hat)
         opt.zero_grad()
         loss.backward()
diff --git a/docs/source/guide/minibatch-inference.rst b/docs/source/guide/minibatch-inference.rst
index 2f2303e60cf7..54446e26828b 100644
--- a/docs/source/guide/minibatch-inference.rst
+++ b/docs/source/guide/minibatch-inference.rst
@@ -106,7 +106,7 @@ and combined as well.
                     hidden_x = self.dropout(hidden_x)
                 # By design, our output nodes are contiguous.
                 y[
-                    data.seed_nodes[0] : data.seed_nodes[-1] + 1
+                    data.seeds[0] : data.seeds[-1] + 1
                 ] = hidden_x.to(device)
             feature = y
diff --git a/docs/source/guide/minibatch-link.rst b/docs/source/guide/minibatch-link.rst
index a3dbc341d742..ad1fc1d3d9f1 100644
--- a/docs/source/guide/minibatch-link.rst
+++ b/docs/source/guide/minibatch-link.rst
@@ -53,8 +53,8 @@ proportional to a power of degrees.
         self.weights = node_degrees ** 0.75
         self.k = k

-    def _sample_with_etype(node_pairs, etype=None):
-        src, _ = node_pairs
+    def _sample_with_etype(self, seeds, etype=None):
+        src, _ = seeds.T
         src = src.repeat_interleave(self.k)
         dst = self.weights.multinomial(len(src), replacement=True)
         return src, dst
@@ -95,7 +95,7 @@ Define a GraphSAGE model for minibatch training
 When a negative sampler is provided, the data loader will generate positive and
 negative node pairs for each minibatch in addition to the *Message Flow Graphs* (MFGs).
-Use `node_pairs_with_labels` to get compact node pairs with corresponding
+Use `compacted_seeds` and `labels` to get compact node pairs and corresponding
 labels.
@@ -116,7 +116,8 @@ above.
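Before the updated training loop below, a rough standalone sketch of the degree-proportional negative sampling shown in `_sample_with_etype` above (the function and argument names here are hypothetical, not part of the GraphBolt API):

```python
import torch

def sample_negatives_by_degree(seeds, node_degrees, k):
    """Draw k negative destinations per positive pair, with probability
    proportional to degree ** 0.75, mirroring the sampler above."""
    src, _ = seeds.T                # seeds: (N, 2) tensor of positive pairs.
    src = src.repeat_interleave(k)  # Each source repeated k times.
    weights = node_degrees.float() ** 0.75
    dst = weights.multinomial(len(src), replacement=True)
    return src, dst
```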
        start_epoch_time = time.time()
        for step, data in enumerate(dataloader):
            # Unpack MiniBatch.
-            compacted_pairs, labels = data.node_pairs_with_labels
+            compacted_seeds = data.compacted_seeds.T
+            labels = data.labels
            node_feature = data.node_features["feat"]
            # Convert sampled subgraphs to DGL blocks.
            blocks = data.blocks
@@ -124,7 +125,7 @@ above.
            # Get the embeddings of the input nodes.
            y = model(blocks, node_feature)
            logits = model.predictor(
-                y[compacted_pairs[0]] * y[compacted_pairs[1]]
+                y[compacted_seeds[0]] * y[compacted_seeds[1]]
            ).squeeze()
            # Compute loss.
@@ -217,8 +218,8 @@ If you want to give your own negative sampling function, just inherit from the
            }
            self.k = k

-    def _sample_with_etype(node_pairs, etype):
-        src, _ = node_pairs
+    def _sample_with_etype(self, seeds, etype):
+        src, _ = seeds.T
        src = src.repeat_interleave(self.k)
        dst = self.weights[etype].multinomial(len(src), replacement=True)
        return src, dst
@@ -241,7 +242,8 @@ loss on a specific edge type.
        start_epoch_time = time.time()
        for step, data in enumerate(dataloader):
            # Unpack MiniBatch.
-            compacted_pairs, labels = data.node_pairs_with_labels
+            compacted_seeds = data.compacted_seeds
+            labels = data.labels
            node_features = {
                ntype: data.node_features[(ntype, "feat")]
                for ntype in data.blocks[0].srctypes
@@ -251,8 +253,8 @@ loss on a specific edge type.
            # Get the embeddings of the input nodes.
            y = model(blocks, node_features)
            logits = model.predictor(
-                y[category][compacted_pairs[category][0]]
-                * y[category][compacted_pairs[category][1]]
+                y[category][compacted_seeds[category][:, 0]]
+                * y[category][compacted_seeds[category][:, 1]]
            ).squeeze()
            # Compute loss.
diff --git a/docs/source/stochastic_training/ondisk-dataset-specification.rst b/docs/source/stochastic_training/ondisk-dataset-specification.rst
index 0587b26a8806..96f72227da88 100644
--- a/docs/source/stochastic_training/ondisk-dataset-specification.rst
+++ b/docs/source/stochastic_training/ondisk-dataset-specification.rst
@@ -201,9 +201,8 @@ such as ``num_classes`` and all these fields will be passed to the
   The ``name`` field is used to specify the name of the data. It is mandatory
   and used to specify the data fields of ``MiniBatch`` for sampling. It can
-  be either ``seed_nodes``, ``labels``, ``node_pairs``, ``negative_srcs`` or
-  ``negative_dsts``. If any other name is used, it will be added into the
-  ``MiniBatch`` data fields.
+  be either ``seeds``, ``labels`` or ``indexes``. If any other name is used,
+  it will be added to the ``MiniBatch`` data fields.
- ``format``: ``string``
  The ``format`` field is used to specify the format of the data. It can be
diff --git a/graphbolt/include/graphbolt/fused_csc_sampling_graph.h b/graphbolt/include/graphbolt/fused_csc_sampling_graph.h
index 758c3204510c..37d03437f70e 100644
--- a/graphbolt/include/graphbolt/fused_csc_sampling_graph.h
+++ b/graphbolt/include/graphbolt/fused_csc_sampling_graph.h
@@ -415,6 +415,13 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
 private:
  template <typename NumPickFn, typename PickFn>
  c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighborsImpl(
+      const torch::Tensor& seeds,
+      torch::optional<std::vector<int64_t>>& seed_offsets,
+      const std::vector<int64_t>& fanouts, bool return_eids,
+      NumPickFn num_pick_fn, PickFn pick_fn) const;
+
+  template <typename NumPickFn, typename PickFn>
+  c10::intrusive_ptr<FusedSampledSubgraph> TemporalSampleNeighborsImpl(
      const torch::Tensor& nodes, bool return_eids,
      NumPickFn num_pick_fn, PickFn pick_fn) const;
@@ -498,13 +505,14 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
 * @param offset The starting edge ID for the connected neighbors of the given
 * node.
 * @param num_neighbors The number of neighbors of this node.
- *
- * @return The pick number of the given node.
+ * @param num_picked_ptr The pointer of the tensor which stores the pick
+ * numbers.
 */
-int64_t NumPick(
+template <typename PickedNumType>
+void NumPick(
    int64_t fanout, bool replace,
    const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
-    int64_t num_neighbors);
+    int64_t num_neighbors, PickedNumType* num_picked_ptr);

int64_t TemporalNumPick(
    torch::Tensor seed_timestamp, torch::Tensor csc_indics, int64_t fanout,
@@ -513,11 +521,13 @@ int64_t TemporalNumPick(
    const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset,
    int64_t offset, int64_t num_neighbors);

-int64_t NumPickByEtype(
-    const std::vector<int64_t>& fanouts, bool replace,
+template <typename PickedNumType>
+void NumPickByEtype(
+    bool with_seed_offsets, const std::vector<int64_t>& fanouts, bool replace,
    const torch::Tensor& type_per_edge,
    const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
-    int64_t num_neighbors);
+    int64_t num_neighbors, PickedNumType* num_picked_ptr, int64_t seed_index,
+    const std::vector<int64_t>& etype_id_to_num_picked_offset);

int64_t TemporalNumPickByEtype(
    torch::Tensor seed_timestamp, torch::Tensor csc_indices,
@@ -610,16 +620,24 @@ int64_t TemporalPick(
 * probabilities associated with each neighboring edge of a node in the original
 * graph. It must be a 1D floating-point tensor with the number of elements
 * equal to the number of edges in the graph.
- * @param picked_data_ptr The destination address where the picked neighbors
+ * @param picked_data_ptr The pointer of the tensor where the picked neighbors
 * should be put. Enough memory space should be allocated in advance.
+ * @param seed_offset The offset (index) of the seed among the group of seeds
+ * which share the same node type.
+ * @param subgraph_indptr_ptr The pointer of the tensor which stores the indptr
+ * of the sampled subgraph.
+ * @param etype_id_to_num_picked_offset A vector storing the mappings from each
+ * etype_id to the offset of its pick numbers in the tensor.
 */
template <SamplerType S, typename PickedType>
int64_t PickByEtype(
-    int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
-    bool replace, const torch::TensorOptions& options,
-    const torch::Tensor& type_per_edge,
+    bool with_seed_offsets, int64_t offset, int64_t num_neighbors,
+    const std::vector<int64_t>& fanouts, bool replace,
+    const torch::TensorOptions& options, const torch::Tensor& type_per_edge,
    const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
-    PickedType* picked_data_ptr);
+    PickedType* picked_data_ptr, int64_t seed_offset,
+    PickedType* subgraph_indptr_ptr,
+    const std::vector<int64_t>& etype_id_to_num_picked_offset);

template <SamplerType S, typename PickedType>
int64_t TemporalPickByEtype(
diff --git a/graphbolt/src/fused_csc_sampling_graph.cc b/graphbolt/src/fused_csc_sampling_graph.cc
index 87f5cbeca3dc..8e1535df59a2 100644
--- a/graphbolt/src/fused_csc_sampling_graph.cc
+++ b/graphbolt/src/fused_csc_sampling_graph.cc
@@ -18,6 +18,7 @@
 #include
 #include
+#include "./expand_indptr.h"
 #include "./macro.h"
 #include "./random.h"
 #include "./shared_memory_helper.h"
@@ -355,17 +356,23 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
 auto GetNumPickFn(
     const std::vector<int64_t>& fanouts, bool replace,
     const torch::optional<torch::Tensor>& type_per_edge,
-    const torch::optional<torch::Tensor>& probs_or_mask) {
+    const torch::optional<torch::Tensor>& probs_or_mask,
+    bool with_seed_offsets) {
  // If fanouts.size() > 1, returns the total number of all edge types of the
  // given node.
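Before the `GetNumPickFn` lambda below, it may help to see what `NumPick` computes per seed. A rough Python analogue of the counting rule (purely illustrative; the real C++ routine additionally writes the result through `num_picked_ptr` so that a later prefix sum can build the subgraph indptr):

```python
def num_pick(fanout, replace, num_valid_neighbors):
    # fanout == -1 means "pick every valid neighbor".
    if num_valid_neighbors == 0 or fanout == -1:
        return num_valid_neighbors
    # With replacement we always draw `fanout` samples; without replacement
    # we are capped by the number of valid neighbors.
    return fanout if replace else min(fanout, num_valid_neighbors)
```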
- return [&fanouts, replace, &probs_or_mask, &type_per_edge]( - int64_t seed_offset, int64_t offset, int64_t num_neighbors) { + return [&fanouts, replace, &probs_or_mask, &type_per_edge, with_seed_offsets]( + int64_t offset, int64_t num_neighbors, auto num_picked_ptr, + int64_t seed_index, + const std::vector& etype_id_to_num_picked_offset) { if (fanouts.size() > 1) { - return NumPickByEtype( - fanouts, replace, type_per_edge.value(), probs_or_mask, offset, - num_neighbors); + NumPickByEtype( + with_seed_offsets, fanouts, replace, type_per_edge.value(), + probs_or_mask, offset, num_neighbors, num_picked_ptr, seed_index, + etype_id_to_num_picked_offset); } else { - return NumPick(fanouts[0], replace, probs_or_mask, offset, num_neighbors); + NumPick( + fanouts[0], replace, probs_or_mask, offset, num_neighbors, + num_picked_ptr + seed_index); } }; } @@ -423,21 +430,25 @@ auto GetPickFn( const std::vector& fanouts, bool replace, const torch::TensorOptions& options, const torch::optional& type_per_edge, - const torch::optional& probs_or_mask, SamplerArgs args) { - return [&fanouts, replace, &options, &type_per_edge, &probs_or_mask, args]( - int64_t seed_offset, int64_t offset, int64_t num_neighbors, - auto picked_data_ptr) { + const torch::optional& probs_or_mask, bool with_seed_offsets, + SamplerArgs args) { + return [&fanouts, replace, &options, &type_per_edge, &probs_or_mask, args, + with_seed_offsets]( + int64_t offset, int64_t num_neighbors, auto picked_data_ptr, + int64_t seed_offset, auto subgraph_indptr_ptr, + const std::vector& etype_id_to_num_picked_offset) { // If fanouts.size() > 1, perform sampling for each edge type of each // node; otherwise just sample once for each node with no regard of edge // types. if (fanouts.size() > 1) { return PickByEtype( - offset, num_neighbors, fanouts, replace, options, - type_per_edge.value(), probs_or_mask, args, picked_data_ptr); + with_seed_offsets, offset, num_neighbors, fanouts, replace, options, + type_per_edge.value(), probs_or_mask, args, picked_data_ptr, + seed_offset, subgraph_indptr_ptr, etype_id_to_num_picked_offset); } else { int64_t num_sampled = Pick( offset, num_neighbors, fanouts[0], replace, options, probs_or_mask, - args, picked_data_ptr); + args, picked_data_ptr + subgraph_indptr_ptr[seed_offset]); if (type_per_edge) { std::sort(picked_data_ptr, picked_data_ptr + num_sampled); } @@ -484,6 +495,304 @@ auto GetTemporalPickFn( template c10::intrusive_ptr FusedCSCSamplingGraph::SampleNeighborsImpl( + const torch::Tensor& seeds, + torch::optional>& seed_offsets, + const std::vector& fanouts, bool return_eids, + NumPickFn num_pick_fn, PickFn pick_fn) const { + const int64_t num_seeds = seeds.size(0); + const auto indptr_options = indptr_.options(); + + // Calculate GrainSize for parallel_for. + // Set the default grain size to 64. + const int64_t grain_size = 64; + torch::Tensor picked_eids; + torch::Tensor subgraph_indptr; + torch::Tensor subgraph_indices; + torch::optional subgraph_type_per_edge = torch::nullopt; + torch::optional edge_offsets = torch::nullopt; + + bool with_seed_offsets = seed_offsets.has_value(); + bool hetero_with_seed_offsets = with_seed_offsets && fanouts.size() > 1; + + // Get the number of edge types. If it's homo or if the size of fanouts is 1 + // (hetero graph but sampled as a homo graph), set num_etypes as 1. + // In temporal sampling, this will not be used for now since the logic hasn't + // been adopted for temporal sampling. 
+  const int64_t num_etypes =
+      (edge_type_to_id_.has_value() && hetero_with_seed_offsets)
+          ? edge_type_to_id_->size()
+          : 1;
+  std::vector<int64_t> etype_id_to_src_ntype_id(num_etypes);
+  std::vector<int64_t> etype_id_to_dst_ntype_id(num_etypes);
+  torch::optional<torch::Tensor> subgraph_indptr_substract = torch::nullopt;
+  // The pick numbers are stored in a single tensor by the order of etype. Each
+  // etype corresponds to a group of seeds whose ntype are the same as the
+  // dst_type. `etype_id_to_num_picked_offset` indicates the beginning offset
+  // where each etype's corresponding seeds' pick numbers are stored in the pick
+  // number tensor.
+  std::vector<int64_t> etype_id_to_num_picked_offset(num_etypes + 1);
+  if (hetero_with_seed_offsets) {
+    for (auto& etype_and_id : edge_type_to_id_.value()) {
+      auto etype = etype_and_id.key();
+      auto id = etype_and_id.value();
+      auto [src_type, dst_type] = utils::parse_src_dst_ntype_from_etype(etype);
+      auto dst_ntype_id = node_type_to_id_->at(dst_type);
+      etype_id_to_src_ntype_id[id] = node_type_to_id_->at(src_type);
+      etype_id_to_dst_ntype_id[id] = dst_ntype_id;
+      etype_id_to_num_picked_offset[id + 1] =
+          seed_offsets->at(dst_ntype_id + 1) - seed_offsets->at(dst_ntype_id) +
+          1;
+    }
+    std::partial_sum(
+        etype_id_to_num_picked_offset.begin(),
+        etype_id_to_num_picked_offset.end(),
+        etype_id_to_num_picked_offset.begin());
+  } else {
+    etype_id_to_dst_ntype_id[0] = 0;
+    etype_id_to_num_picked_offset[1] = num_seeds + 1;
+  }
+  // `num_rows` indicates the length of `num_picked_neighbors_per_node`, which
+  // is used for storing pick numbers. In non-temporal hetero sampling, it
+  // equals sum_{etype} #seeds with ntype=dst_type(etype). In homo sampling,
+  // it equals `num_seeds`.
+  const int64_t num_rows = etype_id_to_num_picked_offset[num_etypes];
+  torch::Tensor num_picked_neighbors_per_node =
+      torch::empty({num_rows}, indptr_options);
+
+  AT_DISPATCH_INDEX_TYPES(
+      indptr_.scalar_type(), "SampleNeighborsImplWrappedWithIndptr", ([&] {
+        using indptr_t = index_t;
+        AT_DISPATCH_INDEX_TYPES(
+            seeds.scalar_type(), "SampleNeighborsImplWrappedWithSeeds", ([&] {
+              using seeds_t = index_t;
+              const auto indptr_data = indptr_.data_ptr<indptr_t>();
+              const auto num_picked_neighbors_data_ptr =
+                  num_picked_neighbors_per_node.data_ptr<indptr_t>();
+              num_picked_neighbors_data_ptr[0] = 0;
+              const auto seeds_data_ptr = seeds.data_ptr<seeds_t>();
+
+              // Initialize the empty spots in `num_picked_neighbors_per_node`.
+              if (hetero_with_seed_offsets) {
+                for (auto i = 0; i < num_etypes; ++i) {
+                  num_picked_neighbors_data_ptr
+                      [etype_id_to_num_picked_offset[i]] = 0;
+                }
+              }
+
+              // Step 1. Calculate pick number of each node.
+              torch::parallel_for(
+                  0, num_seeds, grain_size, [&](int64_t begin, int64_t end) {
+                    for (int64_t i = begin; i < end; ++i) {
+                      const auto nid = seeds_data_ptr[i];
+                      TORCH_CHECK(
+                          nid >= 0 && nid < NumNodes(),
+                          "The seed nodes' IDs should fall within the range of "
+                          "the graph's node IDs.");
+                      const auto offset = indptr_data[nid];
+                      const auto num_neighbors = indptr_data[nid + 1] - offset;
+
+                      const auto seed_type_id =
+                          (hetero_with_seed_offsets)
+                              ? std::upper_bound(
+                                    seed_offsets->begin(), seed_offsets->end(),
+                                    i) -
+                                    seed_offsets->begin() - 1
+                              : 0;
+                      // `seed_index` indicates the index of the current
+                      // seed within the group of seeds which have the same
+                      // node type.
+                      const auto seed_index =
+                          (hetero_with_seed_offsets)
+                              ?
i - seed_offsets->at(seed_type_id) + : i; + num_pick_fn( + offset, num_neighbors, + num_picked_neighbors_data_ptr + 1, seed_index, + etype_id_to_num_picked_offset); + } + }); + + if (hetero_with_seed_offsets) { + torch::Tensor num_picked_offset_tensor = + torch::zeros({num_etypes + 1}, indptr_options); + torch::Tensor substract_offset = + torch::zeros({num_etypes}, indptr_options); + const auto substract_offset_data_ptr = + substract_offset.data_ptr(); + const auto num_picked_offset_data_ptr = + num_picked_offset_tensor.data_ptr(); + for (auto i = 0; i < num_etypes; ++i) { + num_picked_offset_data_ptr[i + 1] = + etype_id_to_num_picked_offset[i + 1]; + // Collect the total pick number for each edge type. + if (i + 1 < num_etypes) + substract_offset_data_ptr[i + 1] = + num_picked_neighbors_data_ptr + [etype_id_to_num_picked_offset[i]]; + num_picked_neighbors_data_ptr + [etype_id_to_num_picked_offset[i]] = 0; + } + substract_offset = + substract_offset.cumsum(0, indptr_.scalar_type()); + subgraph_indptr_substract = ops::ExpandIndptr( + num_picked_offset_tensor, indptr_.scalar_type(), + substract_offset); + } + + // Step 2. Calculate prefix sum to get total length and offsets of + // each node. It's also the indptr of the generated subgraph. + subgraph_indptr = num_picked_neighbors_per_node.cumsum( + 0, indptr_.scalar_type()); + auto subgraph_indptr_data_ptr = + subgraph_indptr.data_ptr(); + + // When doing non-temporal hetero sampling, we generate an + // edge_offsets tensor. + if (hetero_with_seed_offsets) { + edge_offsets = torch::empty({num_etypes + 1}, indptr_options); + AT_DISPATCH_INTEGRAL_TYPES( + edge_offsets.value().scalar_type(), "CalculateEdgeOffsets", + ([&] { + auto edge_offsets_data_ptr = + edge_offsets.value().data_ptr(); + edge_offsets_data_ptr[0] = 0; + for (auto i = 0; i < num_etypes; ++i) { + edge_offsets_data_ptr[i + 1] = subgraph_indptr_data_ptr + [etype_id_to_num_picked_offset[i + 1] - 1]; + } + })); + } + + // Step 3. Allocate the tensor for picked neighbors. + const auto total_length = + subgraph_indptr.data_ptr()[num_rows - 1]; + picked_eids = torch::empty({total_length}, indptr_options); + subgraph_indices = + torch::empty({total_length}, indices_.options()); + if (!hetero_with_seed_offsets && type_per_edge_.has_value()) { + subgraph_type_per_edge = torch::empty( + {total_length}, type_per_edge_.value().options()); + } + + auto picked_eids_data_ptr = picked_eids.data_ptr(); + torch::parallel_for( + 0, num_seeds, grain_size, [&](int64_t begin, int64_t end) { + for (int64_t i = begin; i < end; ++i) { + const auto nid = seeds_data_ptr[i]; + const auto offset = indptr_data[nid]; + const auto num_neighbors = indptr_data[nid + 1] - offset; + auto picked_number = 0; + const auto seed_type_id = + (hetero_with_seed_offsets) + ? std::upper_bound( + seed_offsets->begin(), seed_offsets->end(), + i) - + seed_offsets->begin() - 1 + : 0; + const auto seed_index = + (hetero_with_seed_offsets) + ? i - seed_offsets->at(seed_type_id) + : i; + + // Step 4. Pick neighbors for each node. + picked_number = pick_fn( + offset, num_neighbors, picked_eids_data_ptr, + seed_index, subgraph_indptr_data_ptr, + etype_id_to_num_picked_offset); + if (!hetero_with_seed_offsets) { + TORCH_CHECK( + num_picked_neighbors_data_ptr[i + 1] == + picked_number, + "Actual picked count doesn't match the calculated " + "pick number."); + } + + // Step 5. Calculate other attributes and return the + // subgraph. 
+                      if (picked_number > 0) {
+                        AT_DISPATCH_INDEX_TYPES(
+                            subgraph_indices.scalar_type(),
+                            "IndexSelectSubgraphIndices", ([&] {
+                              auto subgraph_indices_data_ptr =
+                                  subgraph_indices.data_ptr<index_t>();
+                              auto indices_data_ptr =
+                                  indices_.data_ptr<index_t>();
+                              for (auto i = 0; i < num_etypes; ++i) {
+                                if (etype_id_to_dst_ntype_id[i] != seed_type_id)
+                                  continue;
+                                const auto indptr_offset =
+                                    with_seed_offsets
+                                        ? etype_id_to_num_picked_offset[i] +
+                                              seed_index
+                                        : seed_index;
+                                const auto picked_begin =
+                                    subgraph_indptr_data_ptr[indptr_offset];
+                                const auto picked_end =
+                                    subgraph_indptr_data_ptr[indptr_offset + 1];
+                                for (auto j = picked_begin; j < picked_end;
+                                     ++j) {
+                                  subgraph_indices_data_ptr[j] =
+                                      indices_data_ptr[picked_eids_data_ptr[j]];
+                                  if (hetero_with_seed_offsets &&
+                                      node_type_offset_.has_value()) {
+                                    // Subtract the node type offset from
+                                    // subgraph indices. Assuming
+                                    // node_type_offset has the same dtype as
+                                    // indices.
+                                    auto node_type_offset_data =
+                                        node_type_offset_.value()
+                                            .data_ptr<index_t>();
+                                    subgraph_indices_data_ptr[j] -=
+                                        node_type_offset_data
+                                            [etype_id_to_src_ntype_id[i]];
+                                  }
+                                }
+                              }
+                            }));
+
+                        if (!hetero_with_seed_offsets &&
+                            type_per_edge_.has_value()) {
+                          // When a hetero graph is sampled as a homo graph, we
+                          // still generate the type_per_edge tensor for this
+                          // situation.
+                          AT_DISPATCH_INTEGRAL_TYPES(
+                              subgraph_type_per_edge.value().scalar_type(),
+                              "IndexSelectTypePerEdge", ([&] {
+                                auto subgraph_type_per_edge_data_ptr =
+                                    subgraph_type_per_edge.value()
+                                        .data_ptr<scalar_t>();
+                                auto type_per_edge_data_ptr =
+                                    type_per_edge_.value().data_ptr<scalar_t>();
+                                const auto picked_offset =
+                                    subgraph_indptr_data_ptr[seed_index];
+                                for (auto j = picked_offset;
+                                     j < picked_offset + picked_number; ++j)
+                                  subgraph_type_per_edge_data_ptr[j] =
+                                      type_per_edge_data_ptr
+                                          [picked_eids_data_ptr[j]];
+                              }));
+                        }
+                      }
+                    }
+                  });
+            }));
+      }));
+
+  torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
+  if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
+
+  if (subgraph_indptr_substract.has_value()) {
+    subgraph_indptr -= subgraph_indptr_substract.value();
+  }
+
+  return c10::make_intrusive<FusedSampledSubgraph>(
+      subgraph_indptr, subgraph_indices, seeds, torch::nullopt,
+      subgraph_reverse_edge_ids, subgraph_type_per_edge, edge_offsets);
+}
+
+template <typename NumPickFn, typename PickFn>
+c10::intrusive_ptr<FusedSampledSubgraph>
+FusedCSCSamplingGraph::TemporalSampleNeighborsImpl(
    const torch::Tensor& nodes, bool return_eids,
    NumPickFn num_pick_fn, PickFn pick_fn) const {
  const int64_t num_nodes = nodes.size(0);
@@ -663,6 +972,8 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
    }
  }

+  bool with_seed_offsets = seed_offsets.has_value();
+
  if (layer) {
    if (random_seed.has_value() && random_seed->numel() >= 2) {
      SamplerArgs<SamplerType::LABOR> args{
          indices_,
          {random_seed.value(), static_cast<float>(seed2_contribution)},
          NumNodes()};
      return SampleNeighborsImpl(
-          seeds.value(), return_eids,
-          GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
+          seeds.value(), seed_offsets, fanouts, return_eids,
+          GetNumPickFn(
+              fanouts, replace, type_per_edge_, probs_or_mask,
+              with_seed_offsets),
          GetPickFn(
              fanouts, replace, indptr_.options(), type_per_edge_,
-              probs_or_mask, args));
+              probs_or_mask, with_seed_offsets, args));
    } else {
      auto args = [&] {
        if (random_seed.has_value() && random_seed->numel() == 1) {
@@ -689,20 +1002,23 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
        }
      }();
      return SampleNeighborsImpl(
-          seeds.value(), return_eids,
-          GetNumPickFn(fanouts, replace, type_per_edge_,
probs_or_mask), + seeds.value(), seed_offsets, fanouts, return_eids, + GetNumPickFn( + fanouts, replace, type_per_edge_, probs_or_mask, + with_seed_offsets), GetPickFn( fanouts, replace, indptr_.options(), type_per_edge_, - probs_or_mask, args)); + probs_or_mask, with_seed_offsets, args)); } } else { SamplerArgs args; return SampleNeighborsImpl( - seeds.value(), return_eids, - GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask), + seeds.value(), seed_offsets, fanouts, return_eids, + GetNumPickFn( + fanouts, replace, type_per_edge_, probs_or_mask, with_seed_offsets), GetPickFn( fanouts, replace, indptr_.options(), type_per_edge_, probs_or_mask, - args)); + with_seed_offsets, args)); } } @@ -734,7 +1050,7 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors( const int64_t random_seed = RandomEngine::ThreadLocal()->RandInt( static_cast(0), std::numeric_limits::max()); SamplerArgs args{indices_, random_seed, NumNodes()}; - return SampleNeighborsImpl( + return TemporalSampleNeighborsImpl( input_nodes, return_eids, GetTemporalNumPickFn( input_nodes_timestamp, this->indices_, fanouts, replace, @@ -745,7 +1061,7 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors( edge_timestamp, args)); } else { SamplerArgs args; - return SampleNeighborsImpl( + return TemporalSampleNeighborsImpl( input_nodes, return_eids, GetTemporalNumPickFn( input_nodes_timestamp, this->indices_, fanouts, replace, @@ -806,12 +1122,13 @@ void FusedCSCSamplingGraph::HoldSharedMemoryObject( tensor_data_shm_ = std::move(tensor_data_shm); } -int64_t NumPick( +template +void NumPick( int64_t fanout, bool replace, const torch::optional& probs_or_mask, int64_t offset, - int64_t num_neighbors) { + int64_t num_neighbors, PickedNumType* picked_num_ptr) { int64_t num_valid_neighbors = num_neighbors; - if (probs_or_mask.has_value()) { + if (probs_or_mask.has_value() && num_neighbors > 0) { // Subtract the count of zeros in probs_or_mask. AT_DISPATCH_ALL_TYPES( probs_or_mask.value().scalar_type(), "CountZero", ([&] { @@ -821,8 +1138,11 @@ int64_t NumPick( 0); })); } - if (num_valid_neighbors == 0 || fanout == -1) return num_valid_neighbors; - return replace ? fanout : std::min(fanout, num_valid_neighbors); + if (num_valid_neighbors == 0 || fanout == -1) { + *picked_num_ptr = num_valid_neighbors; + } else { + *picked_num_ptr = replace ? fanout : std::min(fanout, num_valid_neighbors); + } } torch::Tensor TemporalMask( @@ -949,14 +1269,16 @@ int64_t TemporalNumPick( return replace ? fanout : std::min(fanout, num_valid_neighbors); } -int64_t NumPickByEtype( - const std::vector& fanouts, bool replace, +template +void NumPickByEtype( + bool with_seed_offsets, const std::vector& fanouts, bool replace, const torch::Tensor& type_per_edge, const torch::optional& probs_or_mask, int64_t offset, - int64_t num_neighbors) { + int64_t num_neighbors, PickedNumType* num_picked_ptr, int64_t seed_index, + const std::vector& etype_id_to_num_picked_offset) { int64_t etype_begin = offset; const int64_t end = offset + num_neighbors; - int64_t total_count = 0; + PickedNumType total_count = 0; AT_DISPATCH_INTEGRAL_TYPES( type_per_edge.scalar_type(), "NumPickFnByEtype", ([&] { const scalar_t* type_per_edge_data = type_per_edge.data_ptr(); @@ -970,13 +1292,32 @@ int64_t NumPickByEtype( etype); int64_t etype_end = etype_end_it - type_per_edge_data; // Do sampling for one etype. 
- total_count += NumPick( - fanouts[etype], replace, probs_or_mask, etype_begin, - etype_end - etype_begin); + if (with_seed_offsets) { + // The pick numbers aren't stored continuously, but separately for + // each different etype. + const auto offset = + etype_id_to_num_picked_offset[etype] + seed_index; + NumPick( + fanouts[etype], replace, probs_or_mask, etype_begin, + etype_end - etype_begin, num_picked_ptr + offset); + // Use the skipped position of each edge type in the + // num_picked_tensor to sum up the total pick number for each edge + // type. + num_picked_ptr[etype_id_to_num_picked_offset[etype] - 1] += + num_picked_ptr[offset]; + } else { + PickedNumType picked_count = 0; + NumPick( + fanouts[etype], replace, probs_or_mask, etype_begin, + etype_end - etype_begin, &picked_count); + total_count += picked_count; + } etype_begin = etype_end; } })); - return total_count; + if (!with_seed_offsets) { + num_picked_ptr[seed_index] = total_count; + } } int64_t TemporalNumPickByEtype( @@ -1288,6 +1629,7 @@ int64_t Pick( const torch::TensorOptions& options, const torch::optional& probs_or_mask, SamplerArgs args, PickedType* picked_data_ptr) { + if (fanout == 0 || num_neighbors == 0) return 0; if (probs_or_mask.has_value()) { return NonUniformPick( offset, num_neighbors, fanout, replace, options, probs_or_mask.value(), @@ -1349,14 +1691,16 @@ int64_t TemporalPick( template int64_t PickByEtype( - int64_t offset, int64_t num_neighbors, const std::vector& fanouts, - bool replace, const torch::TensorOptions& options, - const torch::Tensor& type_per_edge, + bool with_seed_offsets, int64_t offset, int64_t num_neighbors, + const std::vector& fanouts, bool replace, + const torch::TensorOptions& options, const torch::Tensor& type_per_edge, const torch::optional& probs_or_mask, SamplerArgs args, - PickedType* picked_data_ptr) { + PickedType* picked_data_ptr, int64_t seed_index, + PickedType* subgraph_indptr_ptr, + const std::vector& etype_id_to_num_picked_offset) { int64_t etype_begin = offset; int64_t etype_end = offset; - int64_t pick_offset = 0; + int64_t picked_total_count = 0; AT_DISPATCH_INTEGRAL_TYPES( type_per_edge.scalar_type(), "PickByEtype", ([&] { const scalar_t* type_per_edge_data = type_per_edge.data_ptr(); @@ -1371,17 +1715,36 @@ int64_t PickByEtype( type_per_edge_data + etype_begin, type_per_edge_data + end, etype); etype_end = etype_end_it - type_per_edge_data; - // Do sampling for one etype. + // Do sampling for one etype. The picked nodes aren't stored + // continuously, but separately for each different etype. 
if (fanout != 0) { - int64_t picked_count = Pick( - etype_begin, etype_end - etype_begin, fanout, replace, options, - probs_or_mask, args, picked_data_ptr + pick_offset); - pick_offset += picked_count; + auto picked_count = 0; + if (with_seed_offsets) { + const auto indptr_offset = + etype_id_to_num_picked_offset[etype] + seed_index; + picked_count = Pick( + etype_begin, etype_end - etype_begin, fanout, replace, + options, probs_or_mask, args, + picked_data_ptr + subgraph_indptr_ptr[indptr_offset]); + TORCH_CHECK( + subgraph_indptr_ptr[indptr_offset + 1] - + subgraph_indptr_ptr[indptr_offset] == + picked_count, + "Actual picked count doesn't match the calculated " + "pick number."); + } else { + picked_count = Pick( + etype_begin, etype_end - etype_begin, fanout, replace, + options, probs_or_mask, args, + picked_data_ptr + subgraph_indptr_ptr[seed_index] + + picked_total_count); + } + picked_total_count += picked_count; } etype_begin = etype_end; } })); - return pick_offset; + return picked_total_count; } template @@ -1432,7 +1795,7 @@ std::enable_if_t Pick( const torch::TensorOptions& options, const torch::optional& probs_or_mask, SamplerArgs args, PickedType* picked_data_ptr) { - if (fanout == 0) return 0; + if (fanout == 0 || num_neighbors == 0) return 0; if (probs_or_mask.has_value()) { if (fanout < 0) { return NonUniformPick( diff --git a/notebooks/graphbolt/walkthrough.ipynb b/notebooks/graphbolt/walkthrough.ipynb index 137a2ba4d3f4..2500b7871fb1 100644 --- a/notebooks/graphbolt/walkthrough.ipynb +++ b/notebooks/graphbolt/walkthrough.ipynb @@ -61,12 +61,12 @@ }, "outputs": [], "source": [ - "node_pairs = torch.tensor(\n", + "seeds = torch.tensor(\n", " [[7, 0], [6, 0], [1, 3], [3, 3], [2, 4], [8, 4], [1, 4], [2, 4], [1, 5],\n", " [9, 6], [0, 6], [8, 6], [7, 7], [7, 7], [4, 7], [6, 8], [5, 8], [9, 9],\n", " [4, 9], [4, 9], [5, 9], [9, 9], [5, 9], [9, 9], [7, 9]]\n", ")\n", - "item_set = gb.ItemSet(node_pairs, names=\"node_pairs\")\n", + "item_set = gb.ItemSet(seeds, names=\"seeds\")\n", "print(list(item_set))" ] }, @@ -262,7 +262,7 @@ "num_nodes = 10\n", "nodes = torch.arange(num_nodes)\n", "labels = torch.tensor([1, 2, 0, 2, 2, 0, 2, 2, 2, 2])\n", - "item_set = gb.ItemSet((nodes, labels), names=(\"seed_nodes\", \"labels\"))\n", + "item_set = gb.ItemSet((nodes, labels), names=(\"seeds\", \"labels\"))\n", "\n", "indptr = torch.tensor([0, 2, 2, 2, 4, 8, 9, 12, 15, 17, 25])\n", "indices = torch.tensor(\n", @@ -311,4 +311,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} \ No newline at end of file +} diff --git a/python/dgl/backend/pytorch/tensor.py b/python/dgl/backend/pytorch/tensor.py index 3f7147ecfd91..a0f5af3bc6fe 100644 --- a/python/dgl/backend/pytorch/tensor.py +++ b/python/dgl/backend/pytorch/tensor.py @@ -12,8 +12,8 @@ from ...function.base import TargetCode from ...utils import version -if version.parse(th.__version__) < version.parse("1.13.0"): - raise RuntimeError("DGL requires PyTorch >= 1.13.0") +if version.parse(th.__version__) < version.parse("2.0.0"): + raise RuntimeError("DGL requires PyTorch >= 2.0.0") def data_type_dict(): diff --git a/python/dgl/graphbolt/base.py b/python/dgl/graphbolt/base.py index 43999f307ead..398ac31e5290 100644 --- a/python/dgl/graphbolt/base.py +++ b/python/dgl/graphbolt/base.py @@ -25,6 +25,7 @@ "expand_indptr", "CSCFormatBase", "seed", + "seed_type_str_to_ntypes", ] CANONICAL_ETYPE_DELIMITER = ":" @@ -185,6 +186,37 @@ def etype_str_to_tuple(c_etype): return ret +def seed_type_str_to_ntypes(seed_type, seed_size): + """Convert seeds type to node 
types from string to list. + + Examples + -------- + 1. node pairs + + >>> seed_type = "user:like:item" + >>> seed_size = 2 + >>> node_type = seed_type_str_to_ntypes(seed_type, seed_size) + >>> print(node_type) + ["user", "item"] + + 2. hyperlink + + >>> seed_type = "query:user:item" + >>> seed_size = 3 + >>> node_type = seed_type_str_to_ntypes(seed_type, seed_size) + >>> print(node_type) + ["query", "user", "item"] + """ + assert isinstance( + seed_type, str + ), f"Passed-in seed type should be string, but got {type(seed_type)}" + ntypes = seed_type.split(CANONICAL_ETYPE_DELIMITER) + is_hyperlink = len(ntypes) == seed_size + if not is_hyperlink: + ntypes = ntypes[::2] + return ntypes + + def apply_to(x, device): """Apply `to` function to object x only if it has `to`.""" diff --git a/python/dgl/graphbolt/impl/in_subgraph_sampler.py b/python/dgl/graphbolt/impl/in_subgraph_sampler.py index d3be7b81e59d..5ccaebd07847 100644 --- a/python/dgl/graphbolt/impl/in_subgraph_sampler.py +++ b/python/dgl/graphbolt/impl/in_subgraph_sampler.py @@ -34,7 +34,7 @@ class InSubgraphSampler(SubgraphSampler): >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14]) >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4]) >>> graph = gb.fused_csc_sampling_graph(indptr, indices) - >>> item_set = gb.ItemSet(len(indptr) - 1, names="seed_nodes") + >>> item_set = gb.ItemSet(len(indptr) - 1, names="seeds") >>> item_sampler = gb.ItemSampler(item_set, batch_size=2) >>> insubgraph_sampler = gb.InSubgraphSampler(item_sampler, graph) >>> for _, data in enumerate(insubgraph_sampler): diff --git a/python/dgl/graphbolt/impl/neighbor_sampler.py b/python/dgl/graphbolt/impl/neighbor_sampler.py index 64a30c1c0329..7396dbe197db 100644 --- a/python/dgl/graphbolt/impl/neighbor_sampler.py +++ b/python/dgl/graphbolt/impl/neighbor_sampler.py @@ -407,8 +407,8 @@ class NeighborSampler(NeighborSamplerImpl): >>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8]) >>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5]) >>> graph = gb.fused_csc_sampling_graph(indptr, indices) - >>> node_pairs = torch.LongTensor([[0, 1], [1, 2]]) - >>> item_set = gb.ItemSet(node_pairs, names="node_pairs") + >>> seeds = torch.LongTensor([[0, 1], [1, 2]]) + >>> item_set = gb.ItemSet(seeds, names="seeds") >>> datapipe = gb.ItemSampler(item_set, batch_size=1) >>> datapipe = datapipe.sample_uniform_negative(graph, 2) >>> datapipe = datapipe.sample_neighbor(graph, [5, 10, 15]) @@ -534,8 +534,8 @@ class LayerNeighborSampler(NeighborSamplerImpl): >>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8]) >>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5]) >>> graph = gb.fused_csc_sampling_graph(indptr, indices) - >>> node_pairs = torch.LongTensor([[0, 1], [1, 2]]) - >>> item_set = gb.ItemSet(node_pairs, names="node_pairs") + >>> seeds = torch.LongTensor([[0, 1], [1, 2]]) + >>> item_set = gb.ItemSet(seeds, names="seeds") >>> item_sampler = gb.ItemSampler(item_set, batch_size=1,) >>> neg_sampler = gb.UniformNegativeSampler(item_sampler, graph, 2) >>> fanouts = [torch.LongTensor([5]), @@ -566,8 +566,12 @@ class LayerNeighborSampler(NeighborSamplerImpl): original_edge_ids=None, original_column_node_ids=tensor([0, 1, 5, 2]), )] - >>> next(iter(subgraph_sampler)).compacted_node_pairs - (tensor([0]), tensor([1])) + >>> next(iter(subgraph_sampler)).compacted_seeds + tensor([[0, 1], [0, 2], [0, 3]]) + >>> next(iter(subgraph_sampler)).labels + tensor([1., 0., 0.]) + >>> next(iter(subgraph_sampler)).indexes + tensor([0, 0, 0]) """ def __init__( diff --git 
a/python/dgl/graphbolt/impl/ondisk_dataset.py b/python/dgl/graphbolt/impl/ondisk_dataset.py index 0bcf9451244e..31fcf804f30c 100644 --- a/python/dgl/graphbolt/impl/ondisk_dataset.py +++ b/python/dgl/graphbolt/impl/ondisk_dataset.py @@ -42,11 +42,7 @@ __all__ = ["OnDiskDataset", "preprocess_ondisk_dataset", "BuiltinDataset"] NAMES_INDICATING_NODE_IDS = [ - "seed_nodes", - "node_pairs", "seeds", - "negative_srcs", - "negative_dsts", ] diff --git a/python/dgl/graphbolt/impl/uniform_negative_sampler.py b/python/dgl/graphbolt/impl/uniform_negative_sampler.py index 6a9a661dd426..15f3447abe87 100644 --- a/python/dgl/graphbolt/impl/uniform_negative_sampler.py +++ b/python/dgl/graphbolt/impl/uniform_negative_sampler.py @@ -36,20 +36,20 @@ class UniformNegativeSampler(NegativeSampler): >>> indptr = torch.LongTensor([0, 1, 2, 3, 4]) >>> indices = torch.LongTensor([1, 2, 3, 0]) >>> graph = gb.fused_csc_sampling_graph(indptr, indices) - >>> node_pairs = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]]) - >>> item_set = gb.ItemSet(node_pairs, names="node_pairs") + >>> seeds = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]]) + >>> item_set = gb.ItemSet(seeds, names="seeds") >>> item_sampler = gb.ItemSampler( ... item_set, batch_size=4,) >>> neg_sampler = gb.UniformNegativeSampler( ... item_sampler, graph, 2) >>> for minibatch in neg_sampler: - ... print(minibatch.negative_srcs) - ... print(minibatch.negative_dsts) - None - tensor([[2, 1], - [2, 1], - [3, 2], - [1, 3]]) + ... print(minibatch.seeds) + ... print(minibatch.labels) + ... print(minibatch.indexes) + tensor([[0, 1], [1, 2], [2, 3], [3, 0], [0, 1], [0, 3], [1, 1], [1, 2], + [2, 1], [2, 0], [3, 0], [3, 2]]) + tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.]) + tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3]) """ def __init__( diff --git a/python/dgl/graphbolt/item_sampler.py b/python/dgl/graphbolt/item_sampler.py index 79d05db02bc4..f92798e6b5cf 100644 --- a/python/dgl/graphbolt/item_sampler.py +++ b/python/dgl/graphbolt/item_sampler.py @@ -50,7 +50,7 @@ def minibatcher_default(batch, names): return batch if len(names) == 1: # Handle the case of single item: batch = tensor([0, 1, 2, 3]), names = - # ("seed_nodes",) as `zip(batch, names)` will iterate over the tensor + # ("seeds",) as `zip(batch, names)` will iterate over the tensor # instead of the batch. init_data = {names[0]: batch} else: @@ -313,68 +313,61 @@ class ItemSampler(IterDataPipe): >>> import torch >>> from dgl import graphbolt as gb - >>> item_set = gb.ItemSet(torch.arange(0, 10), names="seed_nodes") + >>> item_set = gb.ItemSet(torch.arange(0, 10), names="seeds") >>> item_sampler = gb.ItemSampler( ... item_set, batch_size=4, shuffle=False, drop_last=False ... ) >>> next(iter(item_sampler)) - MiniBatch(seed_nodes=tensor([0, 1, 2, 3]), node_pairs=None, labels=None, - negative_srcs=None, negative_dsts=None, sampled_subgraphs=None, - input_nodes=None, node_features=None, edge_features=None, - compacted_node_pairs=None, compacted_negative_srcs=None, - compacted_negative_dsts=None) + MiniBatch(seeds=tensor([0, 1, 2, 3]), sampled_subgraphs=None, + node_features=None, labels=None, input_nodes=None, + indexes=None, edge_features=None, compacted_seeds=None, + blocks=None,) 2. Node pairs. >>> item_set = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), - ... names="node_pairs") + ... names="seeds") >>> item_sampler = gb.ItemSampler( ... item_set, batch_size=4, shuffle=False, drop_last=False ... 
) >>> next(iter(item_sampler)) - MiniBatch(seed_nodes=None, - node_pairs=(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7])), - labels=None, negative_srcs=None, negative_dsts=None, - sampled_subgraphs=None, input_nodes=None, node_features=None, - edge_features=None, compacted_node_pairs=None, - compacted_negative_srcs=None, compacted_negative_dsts=None) + MiniBatch(seeds=tensor([[0, 1], [2, 3], [4, 5], [6, 7]]), + sampled_subgraphs=None, node_features=None, labels=None, + input_nodes=None, indexes=None, edge_features=None, + compacted_seeds=None, blocks=None,) 3. Node pairs and labels. >>> item_set = gb.ItemSet( ... (torch.arange(0, 20).reshape(-1, 2), torch.arange(10, 20)), - ... names=("node_pairs", "labels") + ... names=("seeds", "labels") ... ) >>> item_sampler = gb.ItemSampler( ... item_set, batch_size=4, shuffle=False, drop_last=False ... ) >>> next(iter(item_sampler)) - MiniBatch(seed_nodes=None, - node_pairs=(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7])), - labels=tensor([10, 11, 12, 13]), negative_srcs=None, - negative_dsts=None, sampled_subgraphs=None, input_nodes=None, - node_features=None, edge_features=None, compacted_node_pairs=None, - compacted_negative_srcs=None, compacted_negative_dsts=None) - - 4. Node pairs and negative destinations. - - >>> node_pairs = torch.arange(0, 20).reshape(-1, 2) - >>> negative_dsts = torch.arange(10, 30).reshape(-1, 2) - >>> item_set = gb.ItemSet((node_pairs, negative_dsts), names=("node_pairs", - ... "negative_dsts")) + MiniBatch(seeds=tensor([[0, 1], [2, 3], [4, 5], [6, 7]]), + sampled_subgraphs=None, node_features=None, + labels=tensor([10, 11, 12, 13]), input_nodes=None, + indexes=None, edge_features=None, compacted_seeds=None, + blocks=None,) + + 4. Node pairs, labels and indexes. + + >>> seeds = torch.arange(0, 20).reshape(-1, 2) + >>> labels = torch.tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0]) + >>> indexes = torch.tensor([0, 1, 0, 0, 0, 0, 1, 1, 1, 1]) + >>> item_set = gb.ItemSet((seeds, labels, indexes), names=("seeds", + ... "labels", "indexes")) >>> item_sampler = gb.ItemSampler( ... item_set, batch_size=4, shuffle=False, drop_last=False ... ) >>> next(iter(item_sampler)) - MiniBatch(seed_nodes=None, - node_pairs=(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7])), - labels=None, negative_srcs=None, - negative_dsts=tensor([[10, 11], - [12, 13], - [14, 15], - [16, 17]]), sampled_subgraphs=None, input_nodes=None, - node_features=None, edge_features=None, compacted_node_pairs=None, - compacted_negative_srcs=None, compacted_negative_dsts=None) + MiniBatch(seeds=tensor([[0, 1], [2, 3], [4, 5], [6, 7]]), + sampled_subgraphs=None, node_features=None, + labels=tensor([1, 1, 0, 0]), input_nodes=None, + indexes=tensor([0, 1, 0, 0]), edge_features=None, + compacted_seeds=None, blocks=None,) 5. DGLGraphs. @@ -404,85 +397,74 @@ class ItemSampler(IterDataPipe): 7. Heterogeneous node IDs. >>> ids = { - ... "user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"), - ... "item": gb.ItemSet(torch.arange(0, 6), names="seed_nodes"), + ... "user": gb.ItemSet(torch.arange(0, 5), names="seeds"), + ... "item": gb.ItemSet(torch.arange(0, 6), names="seeds"), ... 
} >>> item_set = gb.ItemSetDict(ids) >>> item_sampler = gb.ItemSampler(item_set, batch_size=4) >>> next(iter(item_sampler)) - MiniBatch(seed_nodes={'user': tensor([0, 1, 2, 3])}, node_pairs=None, - labels=None, negative_srcs=None, negative_dsts=None, sampled_subgraphs=None, - input_nodes=None, node_features=None, edge_features=None, - compacted_node_pairs=None, compacted_negative_srcs=None, - compacted_negative_dsts=None) + MiniBatch(seeds={'user': tensor([0, 1, 2, 3])}, sampled_subgraphs=None, + node_features=None, labels=None, input_nodes=None, indexes=None, + edge_features=None, compacted_seeds=None, blocks=None,) 8. Heterogeneous node pairs. - >>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2) - >>> node_pairs_follow = torch.arange(10, 20).reshape(-1, 2) + >>> seeds_like = torch.arange(0, 10).reshape(-1, 2) + >>> seeds_follow = torch.arange(10, 20).reshape(-1, 2) >>> item_set = gb.ItemSetDict({ ... "user:like:item": gb.ItemSet( - ... node_pairs_like, names="node_pairs"), + ... seeds_like, names="seeds"), ... "user:follow:user": gb.ItemSet( - ... node_pairs_follow, names="node_pairs"), + ... seeds_follow, names="seeds"), ... }) >>> item_sampler = gb.ItemSampler(item_set, batch_size=4) >>> next(iter(item_sampler)) - MiniBatch(seed_nodes=None, - node_pairs={'user:like:item': - (tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))}, - labels=None, negative_srcs=None, negative_dsts=None, - sampled_subgraphs=None, input_nodes=None, node_features=None, - edge_features=None, compacted_node_pairs=None, - compacted_negative_srcs=None, compacted_negative_dsts=None) + MiniBatch(seeds={'user:like:item': + tensor([[0, 1], [2, 3], [4, 5], [6, 7]])}, sampled_subgraphs=None, + node_features=None, labels=None, input_nodes=None, indexes=None, + edge_features=None, compacted_seeds=None, blocks=None,) 9. Heterogeneous node pairs and labels. - >>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2) - >>> labels_like = torch.arange(0, 10) - >>> node_pairs_follow = torch.arange(10, 20).reshape(-1, 2) - >>> labels_follow = torch.arange(10, 20) + >>> seeds_like = torch.arange(0, 10).reshape(-1, 2) + >>> labels_like = torch.arange(0, 5) + >>> seeds_follow = torch.arange(10, 20).reshape(-1, 2) + >>> labels_follow = torch.arange(5, 10) >>> item_set = gb.ItemSetDict({ - ... "user:like:item": gb.ItemSet((node_pairs_like, labels_like), - ... names=("node_pairs", "labels")), - ... "user:follow:user": gb.ItemSet((node_pairs_follow, labels_follow), - ... names=("node_pairs", "labels")), + ... "user:like:item": gb.ItemSet((seeds_like, labels_like), + ... names=("seeds", "labels")), + ... "user:follow:user": gb.ItemSet((seeds_follow, labels_follow), + ... names=("seeds", "labels")), ... }) >>> item_sampler = gb.ItemSampler(item_set, batch_size=4) >>> next(iter(item_sampler)) - MiniBatch(seed_nodes=None, - node_pairs={'user:like:item': - (tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))}, - labels={'user:like:item': tensor([0, 1, 2, 3])}, - negative_srcs=None, negative_dsts=None, sampled_subgraphs=None, - input_nodes=None, node_features=None, edge_features=None, - compacted_node_pairs=None, compacted_negative_srcs=None, - compacted_negative_dsts=None) - - 10. Heterogeneous node pairs and negative destinations. 
-
-    >>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
-    >>> negative_dsts_like = torch.arange(10, 20).reshape(-1, 2)
-    >>> node_pairs_follow = torch.arange(20, 30).reshape(-1, 2)
-    >>> negative_dsts_follow = torch.arange(30, 40).reshape(-1, 2)
+    MiniBatch(seeds={'user:like:item':
+        tensor([[0, 1], [2, 3], [4, 5], [6, 7]])}, sampled_subgraphs=None,
+        node_features=None, labels={'user:like:item': tensor([0, 1, 2, 3])},
+        input_nodes=None, indexes=None, edge_features=None,
+        compacted_seeds=None, blocks=None,)
+
+    10. Heterogeneous node pairs, labels and indexes.
+
+    >>> seeds_like = torch.arange(0, 10).reshape(-1, 2)
+    >>> labels_like = torch.tensor([1, 1, 0, 0, 0])
+    >>> indexes_like = torch.tensor([0, 1, 0, 0, 1])
+    >>> seeds_follow = torch.arange(20, 30).reshape(-1, 2)
+    >>> labels_follow = torch.tensor([1, 1, 0, 0, 0])
+    >>> indexes_follow = torch.tensor([0, 1, 0, 0, 1])
     >>> item_set = gb.ItemSetDict({
-    ...     "user:like:item": gb.ItemSet((node_pairs_like, negative_dsts_like),
-    ...         names=("node_pairs", "negative_dsts")),
-    ...     "user:follow:user": gb.ItemSet((node_pairs_follow,
-    ...         negative_dsts_follow), names=("node_pairs", "negative_dsts")),
+    ...     "user:like:item": gb.ItemSet((seeds_like, labels_like,
+    ...         indexes_like), names=("seeds", "labels", "indexes")),
+    ...     "user:follow:user": gb.ItemSet((seeds_follow, labels_follow,
+    ...         indexes_follow), names=("seeds", "labels", "indexes")),
     ... })
     >>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
     >>> next(iter(item_sampler))
-    MiniBatch(seed_nodes=None,
-        node_pairs={'user:like:item':
-        (tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))},
-        labels=None, negative_srcs=None,
-        negative_dsts={'user:like:item': tensor([[10, 11],
-        [12, 13],
-        [14, 15],
-        [16, 17]])}, sampled_subgraphs=None, input_nodes=None,
-        node_features=None, edge_features=None, compacted_node_pairs=None,
-        compacted_negative_srcs=None, compacted_negative_dsts=None)
+    MiniBatch(seeds={'user:like:item':
+        tensor([[0, 1], [2, 3], [4, 5], [6, 7]])}, sampled_subgraphs=None,
+        node_features=None, labels={'user:like:item': tensor([1, 1, 0, 0])},
+        input_nodes=None, indexes={'user:like:item': tensor([0, 1, 0, 0])},
+        edge_features=None, compacted_seeds=None, blocks=None,)
    """

    def __init__(
diff --git a/python/dgl/graphbolt/itemset.py b/python/dgl/graphbolt/itemset.py
index afe9e392e492..ce96eb60e96e 100644
--- a/python/dgl/graphbolt/itemset.py
+++ b/python/dgl/graphbolt/itemset.py
@@ -110,6 +110,21 @@ class requires each input itemset to be iterable.
         tensor([1, 1, 0, 0, 0]))
     >>> item_set.names
     ('seeds', 'labels')
+
+    6. Tuple of iterables with different shape: hyperlink and labels.
+
+    >>> seeds = torch.arange(0, 10).reshape(-1, 5)
+    >>> labels = torch.tensor([1, 0])
+    >>> item_set = gb.ItemSet(
+    ...     (seeds, labels), names=("seeds", "labels"))
+    >>> list(item_set)
+    [(tensor([0, 1, 2, 3, 4]), tensor([1])),
+     (tensor([5, 6, 7, 8, 9]), tensor([0]))]
+    >>> item_set[:]
+    (tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
+     tensor([1, 0]))
+    >>> item_set.names
+    ('seeds', 'labels')
    """

    def __init__(
@@ -315,6 +330,31 @@ class ItemSetDict:
         tensor([1, 1, 0]))}
     >>> item_set.names
     ('seeds', 'labels')
+
+    4. Tuple of iterables with different shape: hyperlink and labels.
+
+    >>> first_seeds = torch.arange(0, 6).reshape(-1, 3)
+    >>> first_labels = torch.tensor([1, 0])
+    >>> second_seeds = torch.arange(0, 2).reshape(-1, 1)
+    >>> second_labels = torch.tensor([1, 0])
+    >>> item_set = gb.ItemSetDict({
+    ...     "query:user:item": gb.ItemSet(
+    ...         (first_seeds, first_labels),
+    ...         names=("seeds", "labels")),
+    ...     "user": gb.ItemSet(
+    ...         (second_seeds, second_labels),
+    ...         names=("seeds", "labels"))})
+    >>> list(item_set)
+    [{'query:user:item': (tensor([0, 1, 2]), tensor(1))},
+     {'query:user:item': (tensor([3, 4, 5]), tensor(0))},
+     {'user': (tensor([0]), tensor(1))},
+     {'user': (tensor([1]), tensor(0))}]
+    >>> item_set[:]
+    {'query:user:item': (tensor([[0, 1, 2], [3, 4, 5]]),
+     tensor([1, 0])),
+     'user': (tensor([[0], [1]]), tensor([1, 0]))}
+    >>> item_set.names
+    ('seeds', 'labels')
    """

    def __init__(self, itemsets: Dict[str, ItemSet]) -> None:
diff --git a/python/dgl/graphbolt/minibatch.py b/python/dgl/graphbolt/minibatch.py
index 61883d043449..726010453333 100644
--- a/python/dgl/graphbolt/minibatch.py
+++ b/python/dgl/graphbolt/minibatch.py
@@ -26,11 +26,11 @@ class MiniBatch:
    labels: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
    """
-    Labels associated with seed nodes / node pairs in the graph.
+    Labels associated with seeds in the graph.
    - If `labels` is a tensor: It indicates the graph is homogeneous. The value
-      should be corresponding labels to given 'seed_nodes' or 'node_pairs'.
+      should be the labels corresponding to the given 'seeds'.
    - If `labels` is a dictionary: The keys should be node or edge type and the
-      value should be corresponding labels to given 'seed_nodes' or 'node_pairs'.
+      value should be the labels corresponding to the given 'seeds'.
    """

    seeds: Union[
@@ -61,15 +61,14 @@ class MiniBatch:
    indexes: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
    """
-    Indexes associated with seed nodes / node pairs in the graph, which
-    indicates to which query a seed node / node pair belongs.
+    Indexes associated with seeds in the graph, which
+    indicate to which query each seed belongs.
    - If `indexes` is a tensor: It indicates the graph is homogeneous. The
-      value should be corresponding query to given 'seed_nodes' or
-      'node_pairs'.
+      value should be the query corresponding to the given 'seeds'.
-    - If `indexes` is a dictionary: It indicates the graph is
-      heterogeneous. The keys should be node or edge type and the value should
-      be corresponding query to given 'seed_nodes' or 'node_pairs'. For each
-      key, indexes are consecutive integers starting from zero.
+    - If `indexes` is a dictionary: It indicates the graph is heterogeneous.
+      The keys should be node or edge type and the value should be
+      the query corresponding to the given 'seeds'. For each key, indexes are
+      consecutive integers starting from zero.
    """

    sampled_subgraphs: List[SampledSubgraph] = None
diff --git a/python/dgl/graphbolt/subgraph_sampler.py b/python/dgl/graphbolt/subgraph_sampler.py
index 895658cd5c14..2221b41ed3dc 100644
--- a/python/dgl/graphbolt/subgraph_sampler.py
+++ b/python/dgl/graphbolt/subgraph_sampler.py
@@ -6,7 +6,7 @@ import torch

 from torch.utils.data import functional_datapipe

-from .base import etype_str_to_tuple
+from .base import seed_type_str_to_ntypes
 from .internal import compact_temporal_nodes, unique_and_compact
 from .minibatch_transformer import MiniBatchTransformer

@@ -93,7 +93,8 @@ def _seeds_preprocess(minibatch):
    """Preprocess `seeds` in a minibatch to construct `unique_seeds`,
    `node_timestamp` and `compacted_seeds` for further sampling. It
    optionally incorporates timestamps for temporal graphs, organizing and
-    compacting seeds based on their types and timestamps.
+    compacting seeds based on their types and timestamps. In a heterogeneous
+    graph, `seeds` with the same node type will be uniqued together.

     Parameters
     ----------
@@ -121,7 +122,7 @@ def _seeds_preprocess(minibatch):
             nodes_timestamp = None
             if use_timestamp:
                 nodes_timestamp = defaultdict(list)
-            for etype, typed_seeds in seeds.items():
+            for seed_type, typed_seeds in seeds.items():
                 # When typed_seeds is a one-dimensional tensor, it represents
                 # seed nodes, which does not need to do unique and compact.
                 if typed_seeds.ndim == 1:
@@ -131,25 +132,27 @@ def _seeds_preprocess(minibatch):
                         else None
                     )
                     return seeds, nodes_timestamp, None
-                assert typed_seeds.ndim == 2 and typed_seeds.shape[1] == 2, (
-                    "Only tensor with shape 1*N and N*2 is "
+                assert typed_seeds.ndim == 2, (
+                    "Only tensors of shape 1*N and N*M are "
                     + f"supported now, but got {typed_seeds.shape}."
                 )
-                ntypes = etype[:].split(":")[::2]
+                ntypes = seed_type_str_to_ntypes(
+                    seed_type, typed_seeds.shape[1]
+                )
                 if use_timestamp:
                     negative_ratio = (
                         typed_seeds.shape[0]
-                        // minibatch.timestamp[etype].shape[0]
+                        // minibatch.timestamp[seed_type].shape[0]
                         - 1
                     )
                     neg_timestamp = minibatch.timestamp[
-                        etype
+                        seed_type
                     ].repeat_interleave(negative_ratio)
                 for i, ntype in enumerate(ntypes):
                     nodes[ntype].append(typed_seeds[:, i])
                     if use_timestamp:
                         nodes_timestamp[ntype].append(
-                            minibatch.timestamp[etype]
+                            minibatch.timestamp[seed_type]
                         )
                         nodes_timestamp[ntype].append(neg_timestamp)
             # Unique and compact the collected nodes.
@@ -164,11 +167,16 @@ def _seeds_preprocess(minibatch):
                 nodes_timestamp = None
             compacted_seeds = {}
             # Map back in same order as collect.
-            for etype, typed_seeds in seeds.items():
-                src_type, _, dst_type = etype_str_to_tuple(etype)
-                src = compacted[src_type].pop(0)
-                dst = compacted[dst_type].pop(0)
-                compacted_seeds[etype] = torch.cat((src, dst)).view(2, -1).T
+            for seed_type, typed_seeds in seeds.items():
+                ntypes = seed_type_str_to_ntypes(
+                    seed_type, typed_seeds.shape[1]
+                )
+                compacted_seed = []
+                for ntype in ntypes:
+                    compacted_seed.append(compacted[ntype].pop(0))
+                compacted_seeds[seed_type] = (
+                    torch.cat(compacted_seed).view(len(ntypes), -1).T
+                )
         else:
             # When seeds is a one-dimensional tensor, it represents seed nodes,
             # which does not need to do unique and compact.
@@ -193,7 +201,9 @@ def _seeds_preprocess(minibatch):
                 seeds_timestamp = torch.cat(
                     (minibatch.timestamp, neg_timestamp)
                 )
-            nodes_timestamp = [seeds_timestamp for _ in range(seeds.ndim)]
+            nodes_timestamp = [
+                seeds_timestamp for _ in range(seeds.shape[1])
+            ]
         # Unique and compact the collected nodes.
         if use_timestamp:
             (
diff --git a/script/create_dev_conda_env.sh b/script/create_dev_conda_env.sh
index 2d4cd0e68761..8b18037d06ac 100644
--- a/script/create_dev_conda_env.sh
+++ b/script/create_dev_conda_env.sh
@@ -1,8 +1,8 @@
 #!/bin/bash

-readonly CUDA_VERSIONS="11.6,11.7,11.8,12.1"
-readonly TORCH_VERSION="1.13.0"
-readonly PYTHON_VERSION="3.8"
+readonly CUDA_VERSIONS="11.7,11.8,12.1"
+readonly TORCH_VERSION="2.0.0"
+readonly PYTHON_VERSION="3.10"

 usage() {
   cat << EOF
@@ -10,8 +10,8 @@ usage: bash $0 OPTIONS
 examples:
   bash $0 -c
   bash $0 -g 11.7
-  bash $0 -g 11.7 -p 3.8
-  bash $0 -g 11.7 -p 3.8 -t 1.13.0
+  bash $0 -g 11.7 -p 3.10
+  bash $0 -g 11.7 -p 3.10 -t 2.0.0
   bash $0 -c -n dgl-dev-cpu

 Create a developement environment for DGL developers.
@@ -29,7 +29,7 @@ OPTIONS:
   -p  Create dev environment based on specified python version.
   -s  Run silently which indicates always 'yes' for any confirmation.
   -t  Create dev environment based on specified PyTorch version such
-      as '1.13.0'.
+      as '2.0.0'.
EOF } diff --git a/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py b/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py index e21bf5d9c869..e4622deef010 100644 --- a/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py +++ b/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py @@ -2219,10 +2219,13 @@ def test_sample_neighbors_hetero_pick_number( type_per_edge=type_per_edge, node_type_to_id=ntypes, edge_type_to_id=etypes, - ) + ).to(F.ctx()) # Generate subgraph via sample neighbors. - nodes = torch.LongTensor([0, 1]) + nodes = { + "N0": torch.LongTensor([0]).to(F.ctx()), + "N1": torch.LongTensor([1]).to(F.ctx()), + } sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors diff --git a/tests/python/pytorch/graphbolt/impl/test_in_subgraph_sampler.py b/tests/python/pytorch/graphbolt/impl/test_in_subgraph_sampler.py index 33f87bee3ef6..fe3b21ccd33f 100644 --- a/tests/python/pytorch/graphbolt/impl/test_in_subgraph_sampler.py +++ b/tests/python/pytorch/graphbolt/impl/test_in_subgraph_sampler.py @@ -81,7 +81,7 @@ def test_InSubgraphSampler_homo(): graph = gb.fused_csc_sampling_graph(indptr, indices).to(F.ctx()) seed_nodes = torch.LongTensor([0, 5, 3]) - item_set = gb.ItemSet(seed_nodes, names="seed_nodes") + item_set = gb.ItemSet(seed_nodes, names="seeds") batch_size = 1 item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to( F.ctx() @@ -162,8 +162,8 @@ def test_InSubgraphSampler_hetero(): item_set = gb.ItemSetDict( { - "N0": gb.ItemSet(torch.LongTensor([1, 0, 2]), names="seed_nodes"), - "N1": gb.ItemSet(torch.LongTensor([0, 2, 1]), names="seed_nodes"), + "N0": gb.ItemSet(torch.LongTensor([1, 0, 2]), names="seeds"), + "N1": gb.ItemSet(torch.LongTensor([0, 2, 1]), names="seeds"), } ) batch_size = 2 diff --git a/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py b/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py index 3328ea66e816..52f9d86000f0 100644 --- a/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py +++ b/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py @@ -47,7 +47,7 @@ def test_NeighborSampler_GraphFetch(hetero, prob_name, sorted): items = torch.arange(3) else: items = torch.tensor([2, 0, 1]) - names = "seed_nodes" + names = "seeds" itemset = gb.ItemSet(items, names=names) graph = get_hetero_graph().to(F.ctx()) if hetero: @@ -94,9 +94,7 @@ def test_labor_dependent_minibatching(layer_dependency, overlap_graph_fetch): ).to(F.ctx()) torch.random.set_rng_state(torch.manual_seed(123).get_state()) batch_dependency = 100 - itemset = gb.ItemSet( - torch.zeros(batch_dependency + 1).int(), names="seed_nodes" - ) + itemset = gb.ItemSet(torch.zeros(batch_dependency + 1).int(), names="seeds") datapipe = gb.ItemSampler(itemset, batch_size=1).copy_to(F.ctx()) fanouts = [5, 5] datapipe = datapipe.sample_layer_neighbor( diff --git a/tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py b/tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py index 1d33c934b98b..ebca92e97fbc 100644 --- a/tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py +++ b/tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py @@ -96,7 +96,7 @@ def test_OnDiskDataset_multiple_tasks(): train_set: - type: null data: - - name: seed_nodes + - name: seeds format: numpy in_memory: true path: {train_ids_path} @@ -112,7 +112,7 @@ def test_OnDiskDataset_multiple_tasks(): train_set: - type: null data: - - name: seed_nodes + - name: seeds format: numpy in_memory: true path: 
{train_ids_path} @@ -140,7 +140,7 @@ def test_OnDiskDataset_multiple_tasks(): for i, (id, label, _) in enumerate(train_set): assert id == train_ids[i] assert label == train_labels[i] - assert train_set.names == ("seed_nodes", "labels", None) + assert train_set.names == ("seeds", "labels", None) train_set = None dataset = None @@ -162,7 +162,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_names(): train_set: - type: null data: - - name: seed_nodes + - name: seeds format: numpy in_memory: true path: {train_ids_path} @@ -183,7 +183,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_names(): for i, (id, label, _) in enumerate(train_set): assert id == train_ids[i] assert label == train_labels[i] - assert train_set.names == ("seed_nodes", "labels", None) + assert train_set.names == ("seeds", "labels", None) train_set = None @@ -204,7 +204,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_names(): train_set: - type: "author:writes:paper" data: - - name: seed_nodes + - name: seeds format: numpy in_memory: true path: {train_ids_path} @@ -228,7 +228,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_names(): id, label, _ = item["author:writes:paper"] assert id == train_ids[i] assert label == train_labels[i] - assert train_set.names == ("seed_nodes", "labels", None) + assert train_set.names == ("seeds", "labels", None) train_set = None @@ -267,7 +267,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label(): train_set: - type: null data: - - name: seed_nodes + - name: seeds format: numpy in_memory: true path: {train_ids_path} @@ -277,7 +277,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label(): path: {train_labels_path} validation_set: - data: - - name: seed_nodes + - name: seeds format: numpy in_memory: true path: {validation_ids_path} @@ -288,7 +288,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label(): test_set: - type: null data: - - name: seed_nodes + - name: seeds format: numpy in_memory: true path: {test_ids_path} @@ -311,7 +311,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label(): for i, (id, label) in enumerate(train_set): assert id == train_ids[i] assert label == train_labels[i] - assert train_set.names == ("seed_nodes", "labels") + assert train_set.names == ("seeds", "labels") train_set = None # Verify validation set. @@ -321,7 +321,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label(): for i, (id, label) in enumerate(validation_set): assert id == validation_ids[i] assert label == validation_labels[i] - assert validation_set.names == ("seed_nodes", "labels") + assert validation_set.names == ("seeds", "labels") validation_set = None # Verify test set. 
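
For orientation while reviewing the renamed on-disk fields above, here is a minimal sketch (not part of the patch) of the in-memory ItemSet that a `seeds`/`labels` declaration resolves to. It assumes `dgl.graphbolt` is importable as `gb`; the IDs and labels are synthetic stand-ins for the arrays the test saves to disk.

import torch
import dgl.graphbolt as gb

# Synthetic stand-ins for the arrays behind `train_ids_path` and
# `train_labels_path` in the YAML fixtures above.
train_ids = torch.arange(1000)
train_labels = torch.randint(0, 10, (1000,))

# The on-disk fields named "seeds" and "labels" load into an ItemSet
# with matching names; the old "seed_nodes" name is no longer used.
train_set = gb.ItemSet((train_ids, train_labels), names=("seeds", "labels"))
assert train_set.names == ("seeds", "labels")
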
@@ -331,7 +331,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label(): for i, (id, label) in enumerate(test_set): assert id == test_ids[i] assert label == test_labels[i] - assert test_set.names == ("seed_nodes", "labels") + assert test_set.names == ("seeds", "labels") test_set = None dataset = None @@ -355,25 +355,23 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label(): def test_OnDiskDataset_TVTSet_ItemSet_node_pairs_labels(): """Test TVTSet which returns ItemSet with node pairs and labels.""" with tempfile.TemporaryDirectory() as test_dir: - train_node_pairs = np.arange(2000).reshape(1000, 2) - train_node_pairs_path = os.path.join(test_dir, "train_node_pairs.npy") - np.save(train_node_pairs_path, train_node_pairs) + train_seeds = np.arange(2000).reshape(1000, 2) + train_seeds_path = os.path.join(test_dir, "train_seeds.npy") + np.save(train_seeds_path, train_seeds) train_labels = np.random.randint(0, 10, size=1000) train_labels_path = os.path.join(test_dir, "train_labels.npy") np.save(train_labels_path, train_labels) - validation_node_pairs = np.arange(2000, 4000).reshape(1000, 2) - validation_node_pairs_path = os.path.join( - test_dir, "validation_node_pairs.npy" - ) - np.save(validation_node_pairs_path, validation_node_pairs) + validation_seeds = np.arange(2000, 4000).reshape(1000, 2) + validation_seeds_path = os.path.join(test_dir, "validation_seeds.npy") + np.save(validation_seeds_path, validation_seeds) validation_labels = np.random.randint(0, 10, size=1000) validation_labels_path = os.path.join(test_dir, "validation_labels.npy") np.save(validation_labels_path, validation_labels) - test_node_pairs = np.arange(4000, 6000).reshape(1000, 2) - test_node_pairs_path = os.path.join(test_dir, "test_node_pairs.npy") - np.save(test_node_pairs_path, test_node_pairs) + test_seeds = np.arange(4000, 6000).reshape(1000, 2) + test_seeds_path = os.path.join(test_dir, "test_seeds.npy") + np.save(test_seeds_path, test_seeds) test_labels = np.random.randint(0, 10, size=1000) test_labels_path = os.path.join(test_dir, "test_labels.npy") np.save(test_labels_path, test_labels) @@ -384,20 +382,20 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pairs_labels(): train_set: - type: null data: - - name: node_pairs + - name: seeds format: numpy in_memory: true - path: {train_node_pairs_path} + path: {train_seeds_path} - name: labels format: numpy in_memory: true path: {train_labels_path} validation_set: - data: - - name: node_pairs + - name: seeds format: numpy in_memory: true - path: {validation_node_pairs_path} + path: {validation_seeds_path} - name: labels format: numpy in_memory: true @@ -405,10 +403,10 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pairs_labels(): test_set: - type: null data: - - name: node_pairs + - name: seeds format: numpy in_memory: true - path: {test_node_pairs_path} + path: {test_seeds_path} - name: labels format: numpy in_memory: true @@ -421,10 +419,10 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pairs_labels(): assert len(train_set) == 1000 assert isinstance(train_set, gb.ItemSet) for i, (node_pair, label) in enumerate(train_set): - assert node_pair[0] == train_node_pairs[i][0] - assert node_pair[1] == train_node_pairs[i][1] + assert node_pair[0] == train_seeds[i][0] + assert node_pair[1] == train_seeds[i][1] assert label == train_labels[i] - assert train_set.names == ("node_pairs", "labels") + assert train_set.names == ("seeds", "labels") train_set = None # Verify validation set. 
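
The node-pair tests that follow rely on the new convention that edge-level items are a single `seeds` tensor of shape (N, 2), one (src, dst) pair per row, rather than a separate `node_pairs` field. A small sketch under that assumption (synthetic data, `gb = dgl.graphbolt`):

import torch
import dgl.graphbolt as gb

seeds = torch.arange(2000).reshape(1000, 2)   # 1000 (src, dst) pairs
labels = torch.randint(0, 10, (1000,))
item_set = gb.ItemSet((seeds, labels), names=("seeds", "labels"))

node_pair, label = item_set[0]
assert node_pair.shape == (2,)  # each item yields one (src, dst) row
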
@@ -432,10 +430,10 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pairs_labels(): assert len(validation_set) == 1000 assert isinstance(validation_set, gb.ItemSet) for i, (node_pair, label) in enumerate(validation_set): - assert node_pair[0] == validation_node_pairs[i][0] - assert node_pair[1] == validation_node_pairs[i][1] + assert node_pair[0] == validation_seeds[i][0] + assert node_pair[1] == validation_seeds[i][1] assert label == validation_labels[i] - assert validation_set.names == ("node_pairs", "labels") + assert validation_set.names == ("seeds", "labels") validation_set = None # Verify test set. @@ -443,43 +441,69 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pairs_labels(): assert len(test_set) == 1000 assert isinstance(test_set, gb.ItemSet) for i, (node_pair, label) in enumerate(test_set): - assert node_pair[0] == test_node_pairs[i][0] - assert node_pair[1] == test_node_pairs[i][1] + assert node_pair[0] == test_seeds[i][0] + assert node_pair[1] == test_seeds[i][1] assert label == test_labels[i] - assert test_set.names == ("node_pairs", "labels") + assert test_set.names == ("seeds", "labels") test_set = None dataset = None -def test_OnDiskDataset_TVTSet_ItemSet_node_pairs_negs(): +def test_OnDiskDataset_TVTSet_ItemSet_node_pairs_labels_indexes(): """Test TVTSet which returns ItemSet with node pairs and negative ones.""" with tempfile.TemporaryDirectory() as test_dir: - train_node_pairs = np.arange(2000).reshape(1000, 2) - train_node_pairs_path = os.path.join(test_dir, "train_node_pairs.npy") - np.save(train_node_pairs_path, train_node_pairs) - train_neg_dst = np.random.choice(1000 * 10, size=1000 * 10).reshape( - 1000, 10 - ) - train_neg_dst_path = os.path.join(test_dir, "train_neg_dst.npy") - np.save(train_neg_dst_path, train_neg_dst) - - validation_node_pairs = np.arange(2000, 4000).reshape(1000, 2) - validation_node_pairs_path = os.path.join( - test_dir, "validation_node_pairs.npy" - ) - np.save(validation_node_pairs_path, validation_node_pairs) - validation_neg_dst = train_neg_dst + 1 - validation_neg_dst_path = os.path.join( - test_dir, "validation_neg_dst.npy" - ) - np.save(validation_neg_dst_path, validation_neg_dst) - - test_node_pairs = np.arange(4000, 6000).reshape(1000, 2) - test_node_pairs_path = os.path.join(test_dir, "test_node_pairs.npy") - np.save(test_node_pairs_path, test_node_pairs) - test_neg_dst = train_neg_dst + 2 - test_neg_dst_path = os.path.join(test_dir, "test_neg_dst.npy") - np.save(test_neg_dst_path, test_neg_dst) + train_seeds = np.arange(2000).reshape(1000, 2) + train_neg_dst = np.random.choice(1000 * 10, size=1000 * 10) + train_neg_src = train_seeds[:, 0].repeat(10) + train_neg_seeds = ( + np.concatenate((train_neg_dst, train_neg_src)).reshape(2, -1).T + ) + train_seeds = np.concatenate((train_seeds, train_neg_seeds)) + train_seeds_path = os.path.join(test_dir, "train_seeds.npy") + np.save(train_seeds_path, train_seeds) + + train_labels = torch.empty(1000 * 11) + train_labels[:1000] = 1 + train_labels[1000:] = 0 + train_labels_path = os.path.join(test_dir, "train_labels.pt") + torch.save(train_labels, train_labels_path) + + train_indexes = torch.arange(0, 1000) + train_indexes = np.concatenate( + (train_indexes, train_indexes.repeat_interleave(10)) + ) + train_indexes_path = os.path.join(test_dir, "train_indexes.pt") + torch.save(train_indexes, train_indexes_path) + + validation_seeds = np.arange(2000, 4000).reshape(1000, 2) + validation_neg_seeds = train_neg_seeds + 1 + validation_seeds = np.concatenate( + (validation_seeds, validation_neg_seeds) + ) 
+ validation_seeds_path = os.path.join(test_dir, "validation_seeds.npy") + np.save(validation_seeds_path, validation_seeds) + validation_labels = train_labels + validation_labels_path = os.path.join(test_dir, "validation_labels.pt") + torch.save(validation_labels, validation_labels_path) + + validation_indexes = train_indexes + validation_indexes_path = os.path.join( + test_dir, "validation_indexes.pt" + ) + torch.save(validation_indexes, validation_indexes_path) + + test_seeds = np.arange(4000, 6000).reshape(1000, 2) + test_neg_seeds = train_neg_seeds + 2 + test_seeds = np.concatenate((test_seeds, test_neg_seeds)) + test_seeds_path = os.path.join(test_dir, "test_seeds.npy") + np.save(test_seeds_path, test_seeds) + test_labels = train_labels + test_labels_path = os.path.join(test_dir, "test_labels.pt") + torch.save(test_labels, test_labels_path) + + test_indexes = train_indexes + test_indexes_path = os.path.join(test_dir, "test_indexes.pt") + torch.save(test_indexes, test_indexes_path) yaml_content = f""" tasks: @@ -487,69 +511,83 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pairs_negs(): train_set: - type: null data: - - name: node_pairs + - name: seeds format: numpy in_memory: true - path: {train_node_pairs_path} - - name: negative_dsts - format: numpy + path: {train_seeds_path} + - name: labels + format: torch + in_memory: true + path: {train_labels_path} + - name: indexes + format: torch in_memory: true - path: {train_neg_dst_path} + path: {train_indexes_path} validation_set: - data: - - name: node_pairs + - name: seeds format: numpy in_memory: true - path: {validation_node_pairs_path} - - name: negative_dsts - format: numpy + path: {validation_seeds_path} + - name: labels + format: torch in_memory: true - path: {validation_neg_dst_path} + path: {validation_labels_path} + - name: indexes + format: torch + in_memory: true + path: {validation_indexes_path} test_set: - type: null data: - - name: node_pairs + - name: seeds format: numpy in_memory: true - path: {test_node_pairs_path} - - name: negative_dsts - format: numpy + path: {test_seeds_path} + - name: labels + format: torch in_memory: true - path: {test_neg_dst_path} + path: {test_labels_path} + - name: indexes + format: torch + in_memory: true + path: {test_indexes_path} """ dataset = write_yaml_and_load_dataset(yaml_content, test_dir) # Verify train set. train_set = dataset.tasks[0].train_set - assert len(train_set) == 1000 + assert len(train_set) == 1000 * 11 assert isinstance(train_set, gb.ItemSet) - for i, (node_pair, negs) in enumerate(train_set): - assert node_pair[0] == train_node_pairs[i][0] - assert node_pair[1] == train_node_pairs[i][1] - assert torch.equal(negs, torch.from_numpy(train_neg_dst[i])) - assert train_set.names == ("node_pairs", "negative_dsts") + for i, (node_pair, label, index) in enumerate(train_set): + assert node_pair[0] == train_seeds[i][0] + assert node_pair[1] == train_seeds[i][1] + assert label == train_labels[i] + assert index == train_indexes[i] + assert train_set.names == ("seeds", "labels", "indexes") train_set = None # Verify validation set. 
         validation_set = dataset.tasks[0].validation_set
-        assert len(validation_set) == 1000
+        assert len(validation_set) == 1000 * 11
         assert isinstance(validation_set, gb.ItemSet)
-        for i, (node_pair, negs) in enumerate(validation_set):
-            assert node_pair[0] == validation_node_pairs[i][0]
-            assert node_pair[1] == validation_node_pairs[i][1]
-            assert torch.equal(negs, torch.from_numpy(validation_neg_dst[i]))
-        assert validation_set.names == ("node_pairs", "negative_dsts")
+        for i, (node_pair, label, index) in enumerate(validation_set):
+            assert node_pair[0] == validation_seeds[i][0]
+            assert node_pair[1] == validation_seeds[i][1]
+            assert label == validation_labels[i]
+            assert index == validation_indexes[i]
+        assert validation_set.names == ("seeds", "labels", "indexes")
         validation_set = None

         # Verify test set.
         test_set = dataset.tasks[0].test_set
-        assert len(test_set) == 1000
+        assert len(test_set) == 1000 * 11
         assert isinstance(test_set, gb.ItemSet)
-        for i, (node_pair, negs) in enumerate(test_set):
-            assert node_pair[0] == test_node_pairs[i][0]
-            assert node_pair[1] == test_node_pairs[i][1]
-            assert torch.equal(negs, torch.from_numpy(test_neg_dst[i]))
-        assert test_set.names == ("node_pairs", "negative_dsts")
+        for i, (node_pair, label, index) in enumerate(test_set):
+            assert torch.equal(node_pair, torch.from_numpy(test_seeds[i]))
+            assert label == test_labels[i]
+            assert index == test_indexes[i]
+        assert test_set.names == ("seeds", "labels", "indexes")
         test_set = None
         dataset = None

@@ -581,36 +619,36 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
             train_set:
               - type: paper
                 data:
-                  - name: seed_nodes
+                  - name: seeds
                     format: numpy
                     in_memory: true
                    path: {train_path}
               - type: author
                 data:
-                  - name: seed_nodes
+                  - name: seeds
                     format: numpy
                    path: {train_path}
             validation_set:
               - type: paper
                 data:
-                  - name: seed_nodes
+                  - name: seeds
                     format: numpy
                    path: {validation_path}
               - type: author
                 data:
-                  - name: seed_nodes
+                  - name: seeds
                     format: numpy
                    path: {validation_path}
             test_set:
               - type: paper
                 data:
-                  - name: seed_nodes
+                  - name: seeds
                     format: numpy
                     in_memory: false
                    path: {test_path}
               - type: author
                 data:
-                  - name: seed_nodes
+                  - name: seeds
                     format: numpy
                    path: {test_path}
        """
@@ -628,7 +666,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
             id, label = item[key]
             assert id == train_ids[i % 1000]
             assert label == train_labels[i % 1000]
-        assert train_set.names == ("seed_nodes",)
+        assert train_set.names == ("seeds",)
         train_set = None

         # Verify validation set.
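
The `seeds`/`labels`/`indexes` triple verified above replaces the old `node_pairs`/`negative_dsts` pair: negatives are flattened into the same seeds tensor, `labels` marks positive versus negative rows, and `indexes` maps every row back to the query that produced it. A minimal sketch of that layout with synthetic data (assuming `gb = dgl.graphbolt`):

import torch
import dgl.graphbolt as gb

# Two queries, each one positive pair followed by two negatives.
seeds = torch.tensor([[0, 1], [0, 7], [0, 9],
                      [2, 3], [2, 8], [2, 5]])
labels = torch.tensor([1, 0, 0, 1, 0, 0])    # 1 = positive, 0 = negative
indexes = torch.tensor([0, 0, 0, 1, 1, 1])   # row -> originating query
item_set = gb.ItemSet(
    (seeds, labels, indexes), names=("seeds", "labels", "indexes")
)
assert item_set.names == ("seeds", "labels", "indexes")
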
@@ -658,7 +696,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label(): id, label = item[key] assert id == test_ids[i % 1000] assert label == test_labels[i % 1000] - assert test_set.names == ("seed_nodes",) + assert test_set.names == ("seeds",) test_set = None dataset = None @@ -666,25 +704,23 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label(): def test_OnDiskDataset_TVTSet_ItemSetDict_node_pairs_labels(): """Test TVTSet which returns ItemSetDict with node pairs and labels.""" with tempfile.TemporaryDirectory() as test_dir: - train_node_pairs = np.arange(2000).reshape(1000, 2) - train_node_pairs_path = os.path.join(test_dir, "train_node_pairs.npy") - np.save(train_node_pairs_path, train_node_pairs) + train_seeds = np.arange(2000).reshape(1000, 2) + train_seeds_path = os.path.join(test_dir, "train_seeds.npy") + np.save(train_seeds_path, train_seeds) train_labels = np.random.randint(0, 10, size=1000) train_labels_path = os.path.join(test_dir, "train_labels.npy") np.save(train_labels_path, train_labels) - validation_node_pairs = np.arange(2000, 4000).reshape(1000, 2) - validation_node_pairs_path = os.path.join( - test_dir, "validation_node_pairs.npy" - ) - np.save(validation_node_pairs_path, validation_node_pairs) + validation_seeds = np.arange(2000, 4000).reshape(1000, 2) + validation_seeds_path = os.path.join(test_dir, "validation_seeds.npy") + np.save(validation_seeds_path, validation_seeds) validation_labels = np.random.randint(0, 10, size=1000) validation_labels_path = os.path.join(test_dir, "validation_labels.npy") np.save(validation_labels_path, validation_labels) - test_node_pairs = np.arange(4000, 6000).reshape(1000, 2) - test_node_pairs_path = os.path.join(test_dir, "test_node_pairs.npy") - np.save(test_node_pairs_path, test_node_pairs) + test_seeds = np.arange(4000, 6000).reshape(1000, 2) + test_seeds_path = os.path.join(test_dir, "test_seeds.npy") + np.save(test_seeds_path, test_seeds) test_labels = np.random.randint(0, 10, size=1000) test_labels_path = os.path.join(test_dir, "test_labels.npy") np.save(test_labels_path, test_labels) @@ -695,56 +731,56 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pairs_labels(): train_set: - type: paper:cites:paper data: - - name: node_pairs + - name: seeds format: numpy in_memory: true - path: {train_node_pairs_path} + path: {train_seeds_path} - name: labels format: numpy in_memory: true path: {train_labels_path} - type: author:writes:paper data: - - name: node_pairs + - name: seeds format: numpy - path: {train_node_pairs_path} + path: {train_seeds_path} - name: labels format: numpy path: {train_labels_path} validation_set: - type: paper:cites:paper data: - - name: node_pairs + - name: seeds format: numpy - path: {validation_node_pairs_path} + path: {validation_seeds_path} - name: labels format: numpy path: {validation_labels_path} - type: author:writes:paper data: - - name: node_pairs + - name: seeds format: numpy - path: {validation_node_pairs_path} + path: {validation_seeds_path} - name: labels format: numpy path: {validation_labels_path} test_set: - type: paper:cites:paper data: - - name: node_pairs + - name: seeds format: numpy in_memory: true - path: {test_node_pairs_path} + path: {test_seeds_path} - name: labels format: numpy in_memory: true path: {test_labels_path} - type: author:writes:paper data: - - name: node_pairs + - name: seeds format: numpy in_memory: true - path: {test_node_pairs_path} + path: {test_seeds_path} - name: labels format: numpy in_memory: true @@ -762,10 +798,10 @@ def 
test_OnDiskDataset_TVTSet_ItemSetDict_node_pairs_labels(): key = list(item.keys())[0] assert key in ["paper:cites:paper", "author:writes:paper"] node_pair, label = item[key] - assert node_pair[0] == train_node_pairs[i % 1000][0] - assert node_pair[1] == train_node_pairs[i % 1000][1] + assert node_pair[0] == train_seeds[i % 1000][0] + assert node_pair[1] == train_seeds[i % 1000][1] assert label == train_labels[i % 1000] - assert train_set.names == ("node_pairs", "labels") + assert train_set.names == ("seeds", "labels") train_set = None # Verify validation set. @@ -778,10 +814,10 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pairs_labels(): key = list(item.keys())[0] assert key in ["paper:cites:paper", "author:writes:paper"] node_pair, label = item[key] - assert node_pair[0] == validation_node_pairs[i % 1000][0] - assert node_pair[1] == validation_node_pairs[i % 1000][1] + assert node_pair[0] == validation_seeds[i % 1000][0] + assert node_pair[1] == validation_seeds[i % 1000][1] assert label == validation_labels[i % 1000] - assert validation_set.names == ("node_pairs", "labels") + assert validation_set.names == ("seeds", "labels") validation_set = None # Verify test set. @@ -794,10 +830,10 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pairs_labels(): key = list(item.keys())[0] assert key in ["paper:cites:paper", "author:writes:paper"] node_pair, label = item[key] - assert node_pair[0] == test_node_pairs[i % 1000][0] - assert node_pair[1] == test_node_pairs[i % 1000][1] + assert node_pair[0] == test_seeds[i % 1000][0] + assert node_pair[1] == test_seeds[i % 1000][1] assert label == test_labels[i % 1000] - assert test_set.names == ("node_pairs", "labels") + assert test_set.names == ("seeds", "labels") test_set = None dataset = None @@ -1294,21 +1330,21 @@ def test_OnDiskDataset_preprocess_homogeneous_hardcode( f" train_set:\n" f" - type: null\n" f" data:\n" - f" - name: node_pairs\n" + f" - name: seeds\n" f" format: numpy\n" f" in_memory: true\n" f" path: {train_path}\n" f" validation_set:\n" f" - type: null\n" f" data:\n" - f" - name: node_pairs\n" + f" - name: seeds\n" f" format: numpy\n" f" in_memory: true\n" f" path: {valid_path}\n" f" test_set:\n" f" - type: null\n" f" data:\n" - f" - name: node_pairs\n" + f" - name: seeds\n" f" format: numpy\n" f" in_memory: true\n" f" path: {test_path}\n" @@ -2856,22 +2892,22 @@ def test_OnDiskDataset_auto_force_preprocess(capsys): def test_OnDiskTask_repr_homogeneous(): item_set = gb.ItemSet( (torch.arange(0, 5), torch.arange(5, 10)), - names=("seed_nodes", "labels"), + names=("seeds", "labels"), ) metadata = {"name": "node_classification"} task = gb.OnDiskTask(metadata, item_set, item_set, item_set) expected_str = ( "OnDiskTask(validation_set=ItemSet(\n" " items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n" - " names=('seed_nodes', 'labels'),\n" + " names=('seeds', 'labels'),\n" " ),\n" " train_set=ItemSet(\n" " items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n" - " names=('seed_nodes', 'labels'),\n" + " names=('seeds', 'labels'),\n" " ),\n" " test_set=ItemSet(\n" " items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n" - " names=('seed_nodes', 'labels'),\n" + " names=('seeds', 'labels'),\n" " ),\n" " metadata={'name': 'node_classification'},)" ) @@ -2908,8 +2944,8 @@ def test_OnDiskDataset_not_include_eids(): def test_OnDiskTask_repr_heterogeneous(): item_set = gb.ItemSetDict( { - "user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"), - "item": gb.ItemSet(torch.arange(5, 10), names="seed_nodes"), + "user": 
gb.ItemSet(torch.arange(0, 5), names="seeds"),
+        "item": gb.ItemSet(torch.arange(5, 10), names="seeds"),
         }
     )
     metadata = {"name": "node_classification"}
@@ -2918,32 +2954,32 @@ def test_OnDiskTask_repr_heterogeneous():
         "OnDiskTask(validation_set=ItemSetDict(\n"
         "    itemsets={'user': ItemSet(\n"
         "        items=(tensor([0, 1, 2, 3, 4]),),\n"
-        "        names=('seed_nodes',),\n"
+        "        names=('seeds',),\n"
         "    ), 'item': ItemSet(\n"
         "        items=(tensor([5, 6, 7, 8, 9]),),\n"
-        "        names=('seed_nodes',),\n"
+        "        names=('seeds',),\n"
         "    )},\n"
-        "    names=('seed_nodes',),\n"
+        "    names=('seeds',),\n"
         "),\n"
         "train_set=ItemSetDict(\n"
         "    itemsets={'user': ItemSet(\n"
         "        items=(tensor([0, 1, 2, 3, 4]),),\n"
-        "        names=('seed_nodes',),\n"
+        "        names=('seeds',),\n"
         "    ), 'item': ItemSet(\n"
         "        items=(tensor([5, 6, 7, 8, 9]),),\n"
-        "        names=('seed_nodes',),\n"
+        "        names=('seeds',),\n"
         "    )},\n"
-        "    names=('seed_nodes',),\n"
+        "    names=('seeds',),\n"
         "),\n"
         "test_set=ItemSetDict(\n"
         "    itemsets={'user': ItemSet(\n"
         "        items=(tensor([0, 1, 2, 3, 4]),),\n"
-        "        names=('seed_nodes',),\n"
+        "        names=('seeds',),\n"
         "    ), 'item': ItemSet(\n"
         "        items=(tensor([5, 6, 7, 8, 9]),),\n"
-        "        names=('seed_nodes',),\n"
+        "        names=('seeds',),\n"
         "    )},\n"
-        "    names=('seed_nodes',),\n"
+        "    names=('seeds',),\n"
         "),\n"
         "metadata={'name': 'node_classification'},)"
     )
diff --git a/tests/python/pytorch/graphbolt/test_base.py b/tests/python/pytorch/graphbolt/test_base.py
index 9ab0577f6d02..e8e34257d641 100644
--- a/tests/python/pytorch/graphbolt/test_base.py
+++ b/tests/python/pytorch/graphbolt/test_base.py
@@ -169,6 +169,31 @@ def test_etype_str_to_tuple():
         _ = gb.etype_str_to_tuple(c_etype_str)

+
+def test_seed_type_str_to_ntypes():
+    """Convert a seed type string to a list of node types."""
+    # Test for node pairs.
+    seed_type_str = "user:like:item"
+    seed_size = 2
+    node_type = gb.seed_type_str_to_ntypes(seed_type_str, seed_size)
+    assert node_type == ["user", "item"]
+
+    # Test for a hyperlink covering three node types.
+    seed_type_str = "user:item:user"
+    seed_size = 3
+    node_type = gb.seed_type_str_to_ntypes(seed_type_str, seed_size)
+    assert node_type == ["user", "item", "user"]
+
+    # Test for unexpected input: list.
+ seed_type_str = ["user", "item"] + with pytest.raises( + AssertionError, + match=re.escape( + "Passed-in seed type should be string, but got " + ), + ): + _ = gb.seed_type_str_to_ntypes(seed_type_str, 2) + + def test_isin(): elements = torch.tensor([2, 3, 5, 5, 20, 13, 11], device=F.ctx()) test_elements = torch.tensor([2, 5], device=F.ctx()) diff --git a/tests/python/pytorch/graphbolt/test_feature_fetcher.py b/tests/python/pytorch/graphbolt/test_feature_fetcher.py index 7e6bed8e3bbb..b1944f06bc44 100644 --- a/tests/python/pytorch/graphbolt/test_feature_fetcher.py +++ b/tests/python/pytorch/graphbolt/test_feature_fetcher.py @@ -25,7 +25,7 @@ def test_FeatureFetcher_invoke(): features[keys[1]] = gb.TorchBasedFeature(b) feature_store = gb.BasicFeatureStore(features) - itemset = gb.ItemSet(torch.arange(10), names="seed_nodes") + itemset = gb.ItemSet(torch.arange(10), names="seeds") item_sampler = gb.ItemSampler(itemset, batch_size=2) num_layer = 2 fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] @@ -58,7 +58,7 @@ def test_FeatureFetcher_homo(): features[keys[1]] = gb.TorchBasedFeature(b) feature_store = gb.BasicFeatureStore(features) - itemset = gb.ItemSet(torch.arange(10), names="seed_nodes") + itemset = gb.ItemSet(torch.arange(10), names="seeds") item_sampler = gb.ItemSampler(itemset, batch_size=2) num_layer = 2 fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] @@ -104,7 +104,7 @@ def add_node_and_edge_ids(minibatch): features[keys[1]] = gb.TorchBasedFeature(b) feature_store = gb.BasicFeatureStore(features) - itemset = gb.ItemSet(torch.arange(10), names="seed_nodes") + itemset = gb.ItemSet(torch.arange(10), names="seeds") item_sampler_dp = gb.ItemSampler(itemset, batch_size=2) converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids) fetcher_dp = gb.FeatureFetcher(converter_dp, feature_store, ["a"], ["b"]) @@ -152,8 +152,8 @@ def test_FeatureFetcher_hetero(): itemset = gb.ItemSetDict( { - "n1": gb.ItemSet(torch.LongTensor([0, 1]), names="seed_nodes"), - "n2": gb.ItemSet(torch.LongTensor([0, 1, 2]), names="seed_nodes"), + "n1": gb.ItemSet(torch.LongTensor([0, 1]), names="seeds"), + "n2": gb.ItemSet(torch.LongTensor([0, 1, 2]), names="seeds"), } ) item_sampler = gb.ItemSampler(itemset, batch_size=2) @@ -215,7 +215,7 @@ def add_node_and_edge_ids(minibatch): itemset = gb.ItemSetDict( { - "n1": gb.ItemSet(torch.randint(0, 20, (10,)), names="seed_nodes"), + "n1": gb.ItemSet(torch.randint(0, 20, (10,)), names="seeds"), } ) item_sampler_dp = gb.ItemSampler(itemset, batch_size=2) diff --git a/tests/python/pytorch/graphbolt/test_item_sampler.py b/tests/python/pytorch/graphbolt/test_item_sampler.py index a169d06ce850..127480d9aa15 100644 --- a/tests/python/pytorch/graphbolt/test_item_sampler.py +++ b/tests/python/pytorch/graphbolt/test_item_sampler.py @@ -1161,7 +1161,7 @@ def test_DistributedItemSampler( ): nprocs = 4 batch_size = 4 - item_set = gb.ItemSet(torch.arange(0, num_ids), names="seed_nodes") + item_set = gb.ItemSet(torch.arange(0, num_ids), names="seeds") # On Windows, if the process group initialization file already exists, # the program may hang. So we need to delete it if it exists. diff --git a/tests/python/pytorch/graphbolt/test_itemset.py b/tests/python/pytorch/graphbolt/test_itemset.py index 91dac881e61e..e41efcb2a2df 100644 --- a/tests/python/pytorch/graphbolt/test_itemset.py +++ b/tests/python/pytorch/graphbolt/test_itemset.py @@ -8,15 +8,15 @@ def test_ItemSet_names(): # ItemSet with single name. 
- item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes") - assert item_set.names == ("seed_nodes",) + item_set = gb.ItemSet(torch.arange(0, 5), names="seeds") + assert item_set.names == ("seeds",) # ItemSet with multiple names. item_set = gb.ItemSet( (torch.arange(0, 5), torch.arange(5, 10)), - names=("seed_nodes", "labels"), + names=("seeds", "labels"), ) - assert item_set.names == ("seed_nodes", "labels") + assert item_set.names == ("seeds", "labels") # ItemSet without name. item_set = gb.ItemSet(torch.arange(0, 5)) @@ -27,19 +27,19 @@ def test_ItemSet_names(): AssertionError, match=re.escape("Number of items (1) and names (2) must match."), ): - _ = gb.ItemSet(5, names=("seed_nodes", "labels")) + _ = gb.ItemSet(5, names=("seeds", "labels")) # ItemSet with mismatched items and names. with pytest.raises( AssertionError, match=re.escape("Number of items (1) and names (2) must match."), ): - _ = gb.ItemSet(torch.arange(0, 5), names=("seed_nodes", "labels")) + _ = gb.ItemSet(torch.arange(0, 5), names=("seeds", "labels")) @pytest.mark.parametrize("dtype", [torch.int32, torch.int64]) def test_ItemSet_scalar_dtype(dtype): - item_set = gb.ItemSet(torch.tensor(5, dtype=dtype), names="seed_nodes") + item_set = gb.ItemSet(torch.tensor(5, dtype=dtype), names="seeds") for i, item in enumerate(item_set): assert i == item assert item.dtype == dtype @@ -106,8 +106,8 @@ def __iter__(self): def test_ItemSet_seed_nodes(): # Node IDs with tensor. - item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes") - assert item_set.names == ("seed_nodes",) + item_set = gb.ItemSet(torch.arange(0, 5), names="seeds") + assert item_set.names == ("seeds",) # Iterating over ItemSet and indexing one by one. for i, item in enumerate(item_set): assert i == item.item() @@ -118,8 +118,8 @@ def test_ItemSet_seed_nodes(): assert torch.equal(item_set[torch.arange(0, 5)], torch.arange(0, 5)) # Node IDs with single integer. - item_set = gb.ItemSet(5, names="seed_nodes") - assert item_set.names == ("seed_nodes",) + item_set = gb.ItemSet(5, names="seeds") + assert item_set.names == ("seeds",) # Iterating over ItemSet and indexing one by one. for i, item in enumerate(item_set): assert i == item.item() @@ -145,8 +145,8 @@ def test_ItemSet_seed_nodes_labels(): # Node IDs and labels. seed_nodes = torch.arange(0, 5) labels = torch.randint(0, 3, (5,)) - item_set = gb.ItemSet((seed_nodes, labels), names=("seed_nodes", "labels")) - assert item_set.names == ("seed_nodes", "labels") + item_set = gb.ItemSet((seed_nodes, labels), names=("seeds", "labels")) + assert item_set.names == ("seeds", "labels") # Iterating over ItemSet and indexing one by one. for i, (seed_node, label) in enumerate(item_set): assert seed_node == seed_nodes[i] @@ -164,8 +164,8 @@ def test_ItemSet_seed_nodes_labels(): def test_ItemSet_node_pairs(): # Node pairs. node_pairs = torch.arange(0, 10).reshape(-1, 2) - item_set = gb.ItemSet(node_pairs, names="node_pairs") - assert item_set.names == ("node_pairs",) + item_set = gb.ItemSet(node_pairs, names="seeds") + assert item_set.names == ("seeds",) # Iterating over ItemSet and indexing one by one. 
for i, (src, dst) in enumerate(item_set): assert node_pairs[i][0] == src @@ -182,8 +182,8 @@ def test_ItemSet_node_pairs_labels(): # Node pairs and labels node_pairs = torch.arange(0, 10).reshape(-1, 2) labels = torch.randint(0, 3, (5,)) - item_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels")) - assert item_set.names == ("node_pairs", "labels") + item_set = gb.ItemSet((node_pairs, labels), names=("seeds", "labels")) + assert item_set.names == ("seeds", "labels") # Iterating over ItemSet and indexing one by one. for i, (node_pair, label) in enumerate(item_set): assert torch.equal(node_pairs[i], node_pair) @@ -198,26 +198,31 @@ def test_ItemSet_node_pairs_labels(): assert torch.equal(item_set[torch.arange(0, 5)][1], labels) -def test_ItemSet_node_pairs_neg_dsts(): +def test_ItemSet_node_pairs_labels_indexes(): # Node pairs and negative destinations. node_pairs = torch.arange(0, 10).reshape(-1, 2) - neg_dsts = torch.arange(10, 25).reshape(-1, 3) + labels = torch.tensor([1, 1, 0, 0, 0]) + indexes = torch.tensor([0, 1, 0, 0, 1]) item_set = gb.ItemSet( - (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts") + (node_pairs, labels, indexes), names=("seeds", "labels", "indexes") ) - assert item_set.names == ("node_pairs", "negative_dsts") + assert item_set.names == ("seeds", "labels", "indexes") # Iterating over ItemSet and indexing one by one. - for i, (node_pair, neg_dst) in enumerate(item_set): + for i, (node_pair, label, index) in enumerate(item_set): assert torch.equal(node_pairs[i], node_pair) - assert torch.equal(neg_dsts[i], neg_dst) + assert torch.equal(labels[i], label) + assert torch.equal(indexes[i], index) assert torch.equal(node_pairs[i], item_set[i][0]) - assert torch.equal(neg_dsts[i], item_set[i][1]) + assert torch.equal(labels[i], item_set[i][1]) + assert torch.equal(indexes[i], item_set[i][2]) # Indexing with a slice. assert torch.equal(item_set[:][0], node_pairs) - assert torch.equal(item_set[:][1], neg_dsts) + assert torch.equal(item_set[:][1], labels) + assert torch.equal(item_set[:][2], indexes) # Indexing with an Iterable. assert torch.equal(item_set[torch.arange(0, 5)][0], node_pairs) - assert torch.equal(item_set[torch.arange(0, 5)][1], neg_dsts) + assert torch.equal(item_set[torch.arange(0, 5)][1], labels) + assert torch.equal(item_set[torch.arange(0, 5)][2], indexes) def test_ItemSet_graphs(): @@ -237,26 +242,26 @@ def test_ItemSetDict_names(): # ItemSetDict with single name. item_set = gb.ItemSetDict( { - "user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"), - "item": gb.ItemSet(torch.arange(5, 10), names="seed_nodes"), + "user": gb.ItemSet(torch.arange(0, 5), names="seeds"), + "item": gb.ItemSet(torch.arange(5, 10), names="seeds"), } ) - assert item_set.names == ("seed_nodes",) + assert item_set.names == ("seeds",) # ItemSetDict with multiple names. item_set = gb.ItemSetDict( { "user": gb.ItemSet( (torch.arange(0, 5), torch.arange(5, 10)), - names=("seed_nodes", "labels"), + names=("seeds", "labels"), ), "item": gb.ItemSet( (torch.arange(5, 10), torch.arange(10, 15)), - names=("seed_nodes", "labels"), + names=("seeds", "labels"), ), } ) - assert item_set.names == ("seed_nodes", "labels") + assert item_set.names == ("seeds", "labels") # ItemSetDict with no name. 
item_set = gb.ItemSetDict( @@ -276,11 +281,9 @@ def test_ItemSetDict_names(): { "user": gb.ItemSet( (torch.arange(0, 5), torch.arange(5, 10)), - names=("seed_nodes", "labels"), - ), - "item": gb.ItemSet( - (torch.arange(5, 10),), names=("seed_nodes",) + names=("seeds", "labels"), ), + "item": gb.ItemSet((torch.arange(5, 10),), names=("seeds",)), } ) @@ -354,14 +357,14 @@ def test_ItemSetDict_iteration_seed_nodes(): user_ids = torch.arange(0, 5) item_ids = torch.arange(5, 10) ids = { - "user": gb.ItemSet(user_ids, names="seed_nodes"), - "item": gb.ItemSet(item_ids, names="seed_nodes"), + "user": gb.ItemSet(user_ids, names="seeds"), + "item": gb.ItemSet(item_ids, names="seeds"), } chained_ids = [] for key, value in ids.items(): chained_ids += [(key, v) for v in value] item_set = gb.ItemSetDict(ids) - assert item_set.names == ("seed_nodes",) + assert item_set.names == ("seeds",) # Iterating over ItemSetDict and indexing one by one. for i, item in enumerate(item_set): assert len(item) == 1 @@ -413,18 +416,14 @@ def test_ItemSetDict_iteration_seed_nodes_labels(): item_ids = torch.arange(5, 10) item_labels = torch.randint(0, 3, (5,)) ids_labels = { - "user": gb.ItemSet( - (user_ids, user_labels), names=("seed_nodes", "labels") - ), - "item": gb.ItemSet( - (item_ids, item_labels), names=("seed_nodes", "labels") - ), + "user": gb.ItemSet((user_ids, user_labels), names=("seeds", "labels")), + "item": gb.ItemSet((item_ids, item_labels), names=("seeds", "labels")), } chained_ids = [] for key, value in ids_labels.items(): chained_ids += [(key, v) for v in value] item_set = gb.ItemSetDict(ids_labels) - assert item_set.names == ("seed_nodes", "labels") + assert item_set.names == ("seeds", "labels") # Iterating over ItemSetDict and indexing one by one. for i, item in enumerate(item_set): assert len(item) == 1 @@ -443,14 +442,14 @@ def test_ItemSetDict_iteration_node_pairs(): # Node pairs. node_pairs = torch.arange(0, 10).reshape(-1, 2) node_pairs_dict = { - "user:like:item": gb.ItemSet(node_pairs, names="node_pairs"), - "user:follow:user": gb.ItemSet(node_pairs, names="node_pairs"), + "user:like:item": gb.ItemSet(node_pairs, names="seeds"), + "user:follow:user": gb.ItemSet(node_pairs, names="seeds"), } expected_data = [] for key, value in node_pairs_dict.items(): expected_data += [(key, v) for v in value] item_set = gb.ItemSetDict(node_pairs_dict) - assert item_set.names == ("node_pairs",) + assert item_set.names == ("seeds",) # Iterating over ItemSetDict and indexing one by one. for i, item in enumerate(item_set): assert len(item) == 1 @@ -471,17 +470,17 @@ def test_ItemSetDict_iteration_node_pairs_labels(): labels = torch.randint(0, 3, (5,)) node_pairs_labels = { "user:like:item": gb.ItemSet( - (node_pairs, labels), names=("node_pairs", "labels") + (node_pairs, labels), names=("seeds", "labels") ), "user:follow:user": gb.ItemSet( - (node_pairs, labels), names=("node_pairs", "labels") + (node_pairs, labels), names=("seeds", "labels") ), } expected_data = [] for key, value in node_pairs_labels.items(): expected_data += [(key, v) for v in value] item_set = gb.ItemSetDict(node_pairs_labels) - assert item_set.names == ("node_pairs", "labels") + assert item_set.names == ("seeds", "labels") # Iterating over ItemSetDict and indexing one by one. 
for i, item in enumerate(item_set): assert len(item) == 1 @@ -501,23 +500,24 @@ def test_ItemSetDict_iteration_node_pairs_labels(): assert torch.equal(item_set[:]["user:follow:user"][1], labels) -def test_ItemSetDict_iteration_node_pairs_neg_dsts(): +def test_ItemSetDict_iteration_node_pairs_labels_indexes(): # Node pairs and negative destinations. node_pairs = torch.arange(0, 10).reshape(-1, 2) - neg_dsts = torch.arange(10, 25).reshape(-1, 3) + labels = torch.tensor([1, 1, 0, 0, 0]) + indexes = torch.tensor([0, 1, 0, 0, 1]) node_pairs_neg_dsts = { "user:like:item": gb.ItemSet( - (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts") + (node_pairs, labels, indexes), names=("seeds", "labels", "indexes") ), "user:follow:user": gb.ItemSet( - (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts") + (node_pairs, labels, indexes), names=("seeds", "labels", "indexes") ), } expected_data = [] for key, value in node_pairs_neg_dsts.items(): expected_data += [(key, v) for v in value] item_set = gb.ItemSetDict(node_pairs_neg_dsts) - assert item_set.names == ("node_pairs", "negative_dsts") + assert item_set.names == ("seeds", "labels", "indexes") # Iterating over ItemSetDict and indexing one by one. for i, item in enumerate(item_set): assert len(item) == 1 @@ -526,24 +526,28 @@ def test_ItemSetDict_iteration_node_pairs_neg_dsts(): assert key in item assert torch.equal(item[key][0], value[0]) assert torch.equal(item[key][1], value[1]) + assert torch.equal(item[key][2], value[2]) assert item_set[i].keys() == item.keys() key = list(item.keys())[0] assert torch.equal(item_set[i][key][0], item[key][0]) assert torch.equal(item_set[i][key][1], item[key][1]) + assert torch.equal(item_set[i][key][2], item[key][2]) # Indexing with a slice. assert torch.equal(item_set[:]["user:like:item"][0], node_pairs) - assert torch.equal(item_set[:]["user:like:item"][1], neg_dsts) + assert torch.equal(item_set[:]["user:like:item"][1], labels) + assert torch.equal(item_set[:]["user:like:item"][2], indexes) assert torch.equal(item_set[:]["user:follow:user"][0], node_pairs) - assert torch.equal(item_set[:]["user:follow:user"][1], neg_dsts) + assert torch.equal(item_set[:]["user:follow:user"][1], labels) + assert torch.equal(item_set[:]["user:follow:user"][2], indexes) def test_ItemSet_repr(): # ItemSet with single name. - item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes") + item_set = gb.ItemSet(torch.arange(0, 5), names="seeds") expected_str = ( "ItemSet(\n" " items=(tensor([0, 1, 2, 3, 4]),),\n" - " names=('seed_nodes',),\n" + " names=('seeds',),\n" ")" ) @@ -552,12 +556,12 @@ def test_ItemSet_repr(): # ItemSet with multiple names. item_set = gb.ItemSet( (torch.arange(0, 5), torch.arange(5, 10)), - names=("seed_nodes", "labels"), + names=("seeds", "labels"), ) expected_str = ( "ItemSet(\n" " items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n" - " names=('seed_nodes', 'labels'),\n" + " names=('seeds', 'labels'),\n" ")" ) assert str(item_set) == expected_str, item_set @@ -567,20 +571,20 @@ def test_ItemSetDict_repr(): # ItemSetDict with single name. 
item_set = gb.ItemSetDict( { - "user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"), - "item": gb.ItemSet(torch.arange(5, 10), names="seed_nodes"), + "user": gb.ItemSet(torch.arange(0, 5), names="seeds"), + "item": gb.ItemSet(torch.arange(5, 10), names="seeds"), } ) expected_str = ( "ItemSetDict(\n" " itemsets={'user': ItemSet(\n" " items=(tensor([0, 1, 2, 3, 4]),),\n" - " names=('seed_nodes',),\n" + " names=('seeds',),\n" " ), 'item': ItemSet(\n" " items=(tensor([5, 6, 7, 8, 9]),),\n" - " names=('seed_nodes',),\n" + " names=('seeds',),\n" " )},\n" - " names=('seed_nodes',),\n" + " names=('seeds',),\n" ")" ) assert str(item_set) == expected_str, item_set @@ -590,11 +594,11 @@ def test_ItemSetDict_repr(): { "user": gb.ItemSet( (torch.arange(0, 5), torch.arange(5, 10)), - names=("seed_nodes", "labels"), + names=("seeds", "labels"), ), "item": gb.ItemSet( (torch.arange(5, 10), torch.arange(10, 15)), - names=("seed_nodes", "labels"), + names=("seeds", "labels"), ), } ) @@ -602,12 +606,12 @@ def test_ItemSetDict_repr(): "ItemSetDict(\n" " itemsets={'user': ItemSet(\n" " items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n" - " names=('seed_nodes', 'labels'),\n" + " names=('seeds', 'labels'),\n" " ), 'item': ItemSet(\n" " items=(tensor([5, 6, 7, 8, 9]), tensor([10, 11, 12, 13, 14])),\n" - " names=('seed_nodes', 'labels'),\n" + " names=('seeds', 'labels'),\n" " )},\n" - " names=('seed_nodes', 'labels'),\n" + " names=('seeds', 'labels'),\n" ")" ) assert str(item_set) == expected_str, item_set diff --git a/tests/python/pytorch/graphbolt/test_minibatch.py b/tests/python/pytorch/graphbolt/test_minibatch.py index 5d7fe61a2bd3..91522722430a 100644 --- a/tests/python/pytorch/graphbolt/test_minibatch.py +++ b/tests/python/pytorch/graphbolt/test_minibatch.py @@ -563,11 +563,10 @@ def test_dgl_link_predication_homo(): check_dgl_blocks_homo(minibatch, dgl_blocks) -@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"]) -def test_dgl_link_predication_hetero(mode): +def test_dgl_link_predication_hetero(): # Arrange minibatch = create_hetero_minibatch() - minibatch.compacted_node_pairs = { + minibatch.compacted_seeds = { relation: (torch.tensor([[1, 1, 2, 0, 1, 2], [1, 0, 1, 1, 0, 0]]).T,), reverse_relation: ( torch.tensor([[0, 1, 1, 2, 0, 2], [1, 0, 1, 1, 0, 0]]).T, diff --git a/tests/python/pytorch/graphbolt/test_subgraph_sampler.py b/tests/python/pytorch/graphbolt/test_subgraph_sampler.py index 6d5034859057..49d8713beb10 100644 --- a/tests/python/pytorch/graphbolt/test_subgraph_sampler.py +++ b/tests/python/pytorch/graphbolt/test_subgraph_sampler.py @@ -265,6 +265,38 @@ def test_SubgraphSampler_Link_With_Negative(sampler_type): _check_sampler_len(datapipe, 5) +@pytest.mark.parametrize( + "sampler_type", + [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal], +) +def test_SubgraphSampler_HyperLink(sampler_type): + _check_sampler_type(sampler_type) + graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to( + F.ctx() + ) + items = torch.arange(20).reshape(-1, 5) + names = "seeds" + if sampler_type == SamplerType.Temporal: + graph.node_attributes = {"timestamp": torch.arange(20).to(F.ctx())} + graph.edge_attributes = { + "timestamp": torch.arange(len(graph.indices)).to(F.ctx()) + } + items = (items, torch.arange(4)) + names = (names, "timestamp") + itemset = gb.ItemSet(items, names=names) + datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx()) + num_layer = 2 + fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] + sampler = 
_get_sampler(sampler_type) + datapipe = sampler(datapipe, graph, fanouts) + _check_sampler_len(datapipe, 2) + for data in datapipe: + assert torch.equal( + data.compacted_seeds, + torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]).to(F.ctx()), + ) + + @pytest.mark.parametrize( "sampler_type", [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal], @@ -487,6 +519,57 @@ def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype(sampler_type): _check_sampler_len(datapipe, 5) +@pytest.mark.parametrize( + "sampler_type", + [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal], +) +def test_SubgraphSampler_HyperLink_Hetero(sampler_type): + _check_sampler_type(sampler_type) + graph = get_hetero_graph().to(F.ctx()) + items = torch.LongTensor([[2, 0, 1, 1, 2], [0, 1, 1, 0, 0]]) + names = "seeds" + if sampler_type == SamplerType.Temporal: + graph.node_attributes = { + "timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx()) + } + graph.edge_attributes = { + "timestamp": torch.arange(graph.indices.numel()).to(F.ctx()) + } + items = (items, torch.randint(0, 10, (2,))) + names = (names, "timestamp") + itemset = gb.ItemSetDict( + { + "n2:n1:n2:n1:n2": gb.ItemSet( + items, + names=names, + ), + } + ) + + datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx()) + num_layer = 2 + fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] + sampler = _get_sampler(sampler_type) + datapipe = sampler(datapipe, graph, fanouts) + _check_sampler_len(datapipe, 1) + for data in datapipe: + for compacted_seeds in data.compacted_seeds.values(): + if sampler_type == SamplerType.Temporal: + assert torch.equal( + compacted_seeds, + torch.tensor([[0, 0, 2, 2, 4], [1, 1, 3, 3, 5]]).to( + F.ctx() + ), + ) + else: + assert torch.equal( + compacted_seeds, + torch.tensor([[0, 0, 2, 1, 0], [1, 1, 2, 0, 1]]).to( + F.ctx() + ), + ) + + @pytest.mark.parametrize( "sampler_type", [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal], @@ -1353,3 +1436,374 @@ def test_SubgraphSampler_unique_csc_format_Hetero_Link(labor): sampled_subgraph.sampled_csc[etype].indptr, csc_formats[step][etype].indptr.to(F.ctx()), ) + + +@pytest.mark.parametrize( + "sampler_type", + [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal], +) +def test_SubgraphSampler_without_deduplication_Homo_HyperLink(sampler_type): + _check_sampler_type(sampler_type) + graph = dgl.graph( + ([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4]) + ) + graph = gb.from_dglgraph(graph, True).to(F.ctx()) + items = torch.LongTensor([[0, 1, 4], [3, 5, 6]]) + names = "seeds" + if sampler_type == SamplerType.Temporal: + graph.node_attributes = { + "timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx()) + } + graph.edge_attributes = { + "timestamp": torch.zeros(graph.indices.numel()).to(F.ctx()) + } + items = (items, torch.randint(1, 10, (2,))) + names = (names, "timestamp") + + itemset = gb.ItemSet(items, names=names) + item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx()) + num_layer = 2 + fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] + + sampler = _get_sampler(sampler_type) + if sampler_type == SamplerType.Temporal: + datapipe = sampler(item_sampler, graph, fanouts) + else: + datapipe = sampler(item_sampler, graph, fanouts, deduplicate=False) + + length = [23, 11] + compacted_indices = [ + (torch.arange(0, 12) + 11).to(F.ctx()), + (torch.arange(0, 5) + 6).to(F.ctx()), + ] + indptr = [ + torch.tensor([0, 1, 2, 4, 5, 5, 5, 5, 6, 8, 10, 12]).to(F.ctx()), + torch.tensor([0, 1, 2, 4, 
5, 5, 5]).to(F.ctx()), + ] + seeds = [ + torch.tensor([0, 0, 1, 2, 2, 3, 4, 4, 5, 5, 6]).to(F.ctx()), + torch.tensor([0, 1, 3, 4, 5, 6]).to(F.ctx()), + ] + for data in datapipe: + for step, sampled_subgraph in enumerate(data.sampled_subgraphs): + assert len(sampled_subgraph.original_row_node_ids) == length[step] + assert torch.equal( + sampled_subgraph.sampled_csc.indices, compacted_indices[step] + ) + assert torch.equal( + sampled_subgraph.sampled_csc.indptr, indptr[step] + ) + assert torch.equal( + torch.sort(sampled_subgraph.original_column_node_ids)[0], + seeds[step], + ) + + +@pytest.mark.parametrize( + "sampler_type", + [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal], +) +def test_SubgraphSampler_without_deduplication_Hetero_HyperLink(sampler_type): + _check_sampler_type(sampler_type) + graph = get_hetero_graph().to(F.ctx()) + items = torch.arange(3).view(1, 3) + names = "seeds" + if sampler_type == SamplerType.Temporal: + graph.node_attributes = { + "timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx()) + } + graph.edge_attributes = { + "timestamp": torch.zeros(graph.indices.numel()).to(F.ctx()) + } + items = (items, torch.randint(1, 10, (1,))) + names = (names, "timestamp") + itemset = gb.ItemSetDict({"n2:n1:n2": gb.ItemSet(items, names=names)}) + item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx()) + num_layer = 2 + fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] + sampler = _get_sampler(sampler_type) + if sampler_type == SamplerType.Temporal: + datapipe = sampler(item_sampler, graph, fanouts) + else: + datapipe = sampler(item_sampler, graph, fanouts, deduplicate=False) + csc_formats = [ + { + "n1:e1:n2": gb.CSCFormatBase( + indptr=torch.tensor([0, 2, 4, 6, 8]), + indices=torch.tensor([5, 6, 7, 8, 9, 10, 11, 12]), + ), + "n2:e2:n1": gb.CSCFormatBase( + indptr=torch.tensor([0, 2, 4, 6, 8, 10]), + indices=torch.tensor([4, 5, 6, 7, 8, 9, 10, 11, 12, 13]), + ), + }, + { + "n1:e1:n2": gb.CSCFormatBase( + indptr=torch.tensor([0, 2, 4]), + indices=torch.tensor([1, 2, 3, 4]), + ), + "n2:e2:n1": gb.CSCFormatBase( + indptr=torch.tensor([0, 2]), + indices=torch.tensor([2, 3], dtype=torch.int64), + ), + }, + ] + original_column_node_ids = [ + { + "n1": torch.tensor([1, 0, 1, 0, 1]), + "n2": torch.tensor([0, 2, 0, 1]), + }, + { + "n1": torch.tensor([1]), + "n2": torch.tensor([0, 2]), + }, + ] + original_row_node_ids = [ + { + "n1": torch.tensor([1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0]), + "n2": torch.tensor([0, 2, 0, 1, 0, 1, 0, 2, 0, 1, 0, 2, 0, 1]), + }, + { + "n1": torch.tensor([1, 0, 1, 0, 1]), + "n2": torch.tensor([0, 2, 0, 1]), + }, + ] + + for data in datapipe: + for step, sampled_subgraph in enumerate(data.sampled_subgraphs): + for ntype in ["n1", "n2"]: + assert torch.equal( + sampled_subgraph.original_row_node_ids[ntype], + original_row_node_ids[step][ntype].to(F.ctx()), + ) + assert torch.equal( + sampled_subgraph.original_column_node_ids[ntype], + original_column_node_ids[step][ntype].to(F.ctx()), + ) + for etype in ["n1:e1:n2", "n2:e2:n1"]: + assert torch.equal( + sampled_subgraph.sampled_csc[etype].indices, + csc_formats[step][etype].indices.to(F.ctx()), + ) + assert torch.equal( + sampled_subgraph.sampled_csc[etype].indptr, + csc_formats[step][etype].indptr.to(F.ctx()), + ) + + +@unittest.skipIf( + F._default_context_str == "gpu", + reason="Fails due to different result on the GPU.", +) +@pytest.mark.parametrize("labor", [False, True]) +def test_SubgraphSampler_unique_csc_format_Homo_HyperLink_cpu(labor): + 
+    torch.manual_seed(1205)
+    graph = dgl.graph(([5, 0, 6, 7, 2, 2, 4], [0, 1, 2, 2, 3, 4, 4]))
+    graph = gb.from_dglgraph(graph, True).to(F.ctx())
+    seed_nodes = torch.LongTensor([[0, 3, 3], [4, 4, 4]])
+
+    itemset = gb.ItemSet(seed_nodes, names="seeds")
+    item_sampler = gb.ItemSampler(itemset, batch_size=4).copy_to(F.ctx())
+    num_layer = 2
+    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
+
+    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
+    datapipe = Sampler(
+        item_sampler,
+        graph,
+        fanouts,
+        deduplicate=True,
+    )
+
+    original_row_node_ids = [
+        torch.tensor([0, 3, 4, 5, 2, 6, 7]).to(F.ctx()),
+        torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
+    ]
+    compacted_indices = [
+        torch.tensor([3, 4, 4, 2, 5, 6]).to(F.ctx()),
+        torch.tensor([3, 4, 4, 2]).to(F.ctx()),
+    ]
+    indptr = [
+        torch.tensor([0, 1, 2, 4, 4, 6]).to(F.ctx()),
+        torch.tensor([0, 1, 2, 4]).to(F.ctx()),
+    ]
+    seeds = [
+        torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
+        torch.tensor([0, 3, 4]).to(F.ctx()),
+    ]
+    for data in datapipe:
+        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
+            assert torch.equal(
+                sampled_subgraph.original_row_node_ids,
+                original_row_node_ids[step],
+            )
+            assert torch.equal(
+                sampled_subgraph.sampled_csc.indices, compacted_indices[step]
+            )
+            assert torch.equal(
+                sampled_subgraph.sampled_csc.indptr, indptr[step]
+            )
+            assert torch.equal(
+                sampled_subgraph.original_column_node_ids, seeds[step]
+            )
+
+
+@unittest.skipIf(
+    F._default_context_str == "cpu",
+    reason="Fails due to different result on the CPU.",
+)
+@pytest.mark.parametrize("labor", [False, True])
+def test_SubgraphSampler_unique_csc_format_Homo_HyperLink_gpu(labor):
+    torch.manual_seed(1205)
+    graph = dgl.graph(([5, 0, 7, 7, 2, 4], [0, 1, 2, 2, 3, 4]))
+    graph = gb.from_dglgraph(graph, is_homogeneous=True).to(F.ctx())
+    seed_nodes = torch.LongTensor([[0, 3, 4], [4, 4, 3]])
+
+    itemset = gb.ItemSet(seed_nodes, names="seeds")
+    item_sampler = gb.ItemSampler(itemset, batch_size=4).copy_to(F.ctx())
+    num_layer = 2
+    fanouts = [torch.LongTensor([-1]) for _ in range(num_layer)]
+
+    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
+    datapipe = Sampler(
+        item_sampler,
+        graph,
+        fanouts,
+        deduplicate=True,
+    )
+
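+    # The sampled neighbor ordering differs across GPU architectures, so
+    # expected tensors are kept for both pre-Volta (compute capability < 7)
+    # and newer devices.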
+    if torch.cuda.get_device_capability()[0] < 7:
+        original_row_node_ids = [
+            torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()),
+            torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
+        ]
+        compacted_indices = [
+            torch.tensor([4, 3, 2, 5, 5]).to(F.ctx()),
+            torch.tensor([4, 3, 2]).to(F.ctx()),
+        ]
+        indptr = [
+            torch.tensor([0, 1, 2, 3, 5, 5]).to(F.ctx()),
+            torch.tensor([0, 1, 2, 3]).to(F.ctx()),
+        ]
+        seeds = [
+            torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
+            torch.tensor([0, 3, 4]).to(F.ctx()),
+        ]
+    else:
+        original_row_node_ids = [
+            torch.tensor([0, 3, 4, 5, 2, 7]).to(F.ctx()),
+            torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
+        ]
+        compacted_indices = [
+            torch.tensor([3, 4, 2, 5, 5]).to(F.ctx()),
+            torch.tensor([3, 4, 2]).to(F.ctx()),
+        ]
+        indptr = [
+            torch.tensor([0, 1, 2, 3, 3, 5]).to(F.ctx()),
+            torch.tensor([0, 1, 2, 3]).to(F.ctx()),
+        ]
+        seeds = [
+            torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
+            torch.tensor([0, 3, 4]).to(F.ctx()),
+        ]
+
+    for data in datapipe:
+        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
+            assert torch.equal(
+                sampled_subgraph.original_row_node_ids,
+                original_row_node_ids[step],
+            )
+            assert torch.equal(
+                sampled_subgraph.sampled_csc.indices, compacted_indices[step]
+            )
+            assert torch.equal(
+                sampled_subgraph.sampled_csc.indptr, indptr[step]
+            )
+            assert torch.equal(
+                sampled_subgraph.original_column_node_ids, seeds[step]
+            )
+
+
+@pytest.mark.parametrize("labor", [False, True])
+def test_SubgraphSampler_unique_csc_format_Hetero_HyperLink(labor):
+    graph = get_hetero_graph().to(F.ctx())
+    itemset = gb.ItemSetDict(
+        {"n1:n2:n1": gb.ItemSet(torch.tensor([[0, 1, 0]]), names="seeds")}
+    )
+    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
+    num_layer = 2
+    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
+    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
+    datapipe = Sampler(
+        item_sampler,
+        graph,
+        fanouts,
+        deduplicate=True,
+    )
+    csc_formats = [
+        {
+            "n1:e1:n2": gb.CSCFormatBase(
+                indptr=torch.tensor([0, 2, 4, 6]),
+                indices=torch.tensor([1, 0, 0, 1, 0, 1]),
+            ),
+            "n2:e2:n1": gb.CSCFormatBase(
+                indptr=torch.tensor([0, 2, 4]),
+                indices=torch.tensor([1, 2, 1, 0]),
+            ),
+        },
+        {
+            "n1:e1:n2": gb.CSCFormatBase(
+                indptr=torch.tensor([0, 2]),
+                indices=torch.tensor([1, 0]),
+            ),
+            "n2:e2:n1": gb.CSCFormatBase(
+                indptr=torch.tensor([0, 2]),
+                indices=torch.tensor([1, 2], dtype=torch.int64),
+            ),
+        },
+    ]
+    original_column_node_ids = [
+        {
+            "n1": torch.tensor([0, 1]),
+            "n2": torch.tensor([0, 1, 2]),
+        },
+        {
+            "n1": torch.tensor([0]),
+            "n2": torch.tensor([1]),
+        },
+    ]
+    original_row_node_ids = [
+        {
+            "n1": torch.tensor([0, 1]),
+            "n2": torch.tensor([0, 1, 2]),
+        },
+        {
+            "n1": torch.tensor([0, 1]),
+            "n2": torch.tensor([0, 1, 2]),
+        },
+    ]
+
+    for data in datapipe:
+        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
+            for ntype in ["n1", "n2"]:
+                assert torch.equal(
+                    torch.sort(sampled_subgraph.original_row_node_ids[ntype])[
+                        0
+                    ],
+                    original_row_node_ids[step][ntype].to(F.ctx()),
+                )
+                assert torch.equal(
+                    torch.sort(
+                        sampled_subgraph.original_column_node_ids[ntype]
+                    )[0],
+                    original_column_node_ids[step][ntype].to(F.ctx()),
+                )
+            for etype in ["n1:e1:n2", "n2:e2:n1"]:
+                assert torch.equal(
+                    sampled_subgraph.sampled_csc[etype].indices,
+                    csc_formats[step][etype].indices.to(F.ctx()),
+                )
+                assert torch.equal(
+                    sampled_subgraph.sampled_csc[etype].indptr,
+                    csc_formats[step][etype].indptr.to(F.ctx()),
+                )
diff --git a/tests/scripts/build_dgl.bat b/tests/scripts/build_dgl.bat
index b445a717e93c..c8996fbd4292 100644
--- a/tests/scripts/build_dgl.bat
+++ b/tests/scripts/build_dgl.bat
@@ -1,6 +1,10 @@
 @ECHO OFF
 SETLOCAL EnableDelayedExpansion
 
+ECHO "Current user: %USERNAME%"
+
+python --version
+
 CALL "C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools\VC\Auxiliary\Build\vcvars64.bat"
 CALL mkvirtualenv --system-site-packages %BUILD_TAG%
 DEL /S /Q build