diff --git a/src/array/cpu/labor_pick.h b/src/array/cpu/labor_pick.h
index 23f99308ccae..27f2a5acabed 100644
--- a/src/array/cpu/labor_pick.h
+++ b/src/array/cpu/labor_pick.h
@@ -31,7 +31,6 @@
 #include
 #include
 #include
-#include
 #include
 #include
 #include
@@ -124,7 +123,10 @@ auto compute_importance_sampling_probabilities(
     var_1 = 0;
     if (weighted) {
       for (auto j = indptr[rid]; j < indptr[rid + 1]; j++)
-        var_1 += A[j] * A[j] / std::min(ONE, c * ps[j - indptr[rid]]);
+        // The check for zero is necessary for numerical stability
+        var_1 += A[j] > 0
+            ? A[j] * A[j] / std::min(ONE, c * ps[j - indptr[rid]])
+            : 0;
     } else {
       for (auto j = indptr[rid]; j < indptr[rid + 1]; j++)
         var_1 += ONE / std::min(ONE, c * ps[j - indptr[rid]]);
@@ -208,8 +210,9 @@ std::pair CSRLaborPick(
   IdArray picked_row = NDArray::Empty({hop_size}, vidtype, ctx);
   IdArray picked_col = NDArray::Empty({hop_size}, vidtype, ctx);
   IdArray picked_idx = NDArray::Empty({hop_size}, vidtype, ctx);
-  FloatArray picked_imp =
-      NDArray::Empty({importance_sampling ? hop_size : 0}, dtype, ctx);
+  FloatArray picked_imp = importance_sampling
+                              ? NDArray::Empty({hop_size}, dtype, ctx)
+                              : NullArray();
   IdxType* picked_rdata = picked_row.Ptr();
   IdxType* picked_cdata = picked_col.Ptr();
   IdxType* picked_idata = picked_idx.Ptr();
diff --git a/src/array/cuda/labor_sampling.cu b/src/array/cuda/labor_sampling.cu
index 48aeb76c114e..43ac0e62ef4d 100644
--- a/src/array/cuda/labor_sampling.cu
+++ b/src/array/cuda/labor_sampling.cu
@@ -66,11 +66,12 @@ struct TransformOp {
   const IdType* subindptr;
   const IdType* indices;
   const IdType* data_arr;
+  bool is_pinned;
   __host__ __device__ auto operator()(IdType idx) {
     const auto in_row = idx_coo[idx];
     const auto row = rows[in_row];
-    const auto in_idx = indptr[row] + idx - subindptr[in_row];
-    const auto u = indices[in_idx];
+    const auto in_idx = indptr[in_row] + idx - subindptr[in_row];
+    const auto u = indices[is_pinned ? idx : in_idx];
     const auto data = data_arr ? data_arr[in_idx] : in_idx;
     return thrust::make_tuple(row, u, data);
   }
@@ -90,13 +91,14 @@ struct TransformOpImp {
   const IdType* subindptr;
   const IdType* indices;
   const IdType* data_arr;
+  bool is_pinned;
   __host__ __device__ auto operator()(IdType idx) {
     const auto ps = probs[idx];
     const auto in_row = idx_coo[idx];
     const auto c = cs[in_row];
     const auto row = rows[in_row];
-    const auto in_idx = indptr[row] + idx - subindptr[in_row];
-    const auto u = indices[in_idx];
+    const auto in_idx = indptr[in_row] + idx - subindptr[in_row];
+    const auto u = indices[is_pinned ? idx : in_idx];
     const auto w = A[in_idx];
     const auto w2 = B[in_idx];
     const auto data = data_arr ? data_arr[in_idx] : in_idx;
@@ -123,17 +125,16 @@ struct StencilOpFused {
   const ps_t probs;
   const A_t A;
   const IdType* subindptr;
-  const IdType* rows;
   const IdType* indptr;
   const IdType* indices;
   const IdType* nids;
+  bool is_pinned;
   __device__ auto operator()(IdType idx) {
     const auto in_row = idx_coo[idx];
     const auto ps = probs[idx];
     IdType rofs = idx - subindptr[in_row];
-    const IdType row = rows[in_row];
-    const auto in_idx = indptr[row] + rofs;
-    const auto u = indices[in_idx];
+    const auto in_idx = indptr[in_row] + rofs;
+    const auto u = indices[is_pinned ? idx : in_idx];
     const auto t = nids ? nids[u] : u;  // t in the paper
     curandStatePhilox4_32_10_t rng;
     // rolled random number r_t is a function of the random_seed and t
@@ -162,7 +163,10 @@ struct TransformOpMinWith1 {
 template
 struct IndptrFunc {
   const IdType* indptr;
-  __host__ __device__ auto operator()(IdType row) { return indptr[row]; }
+  const IdType* in_deg;
+  __host__ __device__ auto operator()(IdType row) {
+    return indptr[row] + (in_deg ? in_deg[row] : 0);
+  }
 };
 
 template
@@ -186,24 +190,26 @@ struct DegreeFunc {
   const IdType num_picks;
   const IdType* rows;
   const IdType* indptr;
-  const FloatType* ds;
   IdType* in_deg;
+  IdType* inrow_indptr;
   FloatType* cs;
   __host__ __device__ auto operator()(IdType tIdx) {
     const auto out_row = rows[tIdx];
-    const auto d = indptr[out_row + 1] - indptr[out_row];
+    const auto indptr_val = indptr[out_row];
+    const auto d = indptr[out_row + 1] - indptr_val;
     in_deg[tIdx] = d;
-    cs[tIdx] = num_picks / (ds ? ds[tIdx] : (FloatType)d);
+    inrow_indptr[tIdx] = indptr_val;
+    cs[tIdx] = num_picks / (FloatType)d;
   }
 };
 
 template
 __global__ void _CSRRowWiseOneHopExtractorKernel(
-    const uint64_t rand_seed, const IdType hop_size, const IdType* const rows,
-    const IdType* const indptr, const IdType* const subindptr,
-    const IdType* const indices, const IdType* const idx_coo,
-    const IdType* const nids, const FloatType* const A, FloatType* const rands,
-    IdType* const hop, FloatType* const A_l) {
+    const uint64_t rand_seed, const IdType hop_size, const IdType* const indptr,
+    const IdType* const subindptr, const IdType* const indices,
+    const IdType* const idx_coo, const IdType* const nids,
+    const FloatType* const A, FloatType* const rands, IdType* const hop,
+    FloatType* const A_l) {
   IdType tx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x;
   const int stride_x = gridDim.x * blockDim.x;
 
@@ -212,10 +218,10 @@ __global__ void _CSRRowWiseOneHopExtractorKernel(
   while (tx < hop_size) {
     IdType rpos = idx_coo[tx];
     IdType rofs = tx - subindptr[rpos];
-    const IdType row = rows[rpos];
-    const auto in_idx = indptr[row] + rofs;
-    const auto u = indices[in_idx];
-    hop[tx] = u;
+    const auto in_idx = indptr[rpos] + rofs;
+    const auto not_pinned = indices != hop;
+    const auto u = indices[not_pinned ? in_idx : tx];
+    if (not_pinned) hop[tx] = u;
     const auto v = nids ? nids[u] : u;
     // 123123 is just a number with no significance.
     curand_init(123123, rand_seed, v, &rng);
@@ -226,10 +232,53 @@
   }
 }
 
+constexpr int CACHE_LINE_SIZE = 128;
+
+template
+struct AlignmentFunc {
+  static_assert(CACHE_LINE_SIZE % sizeof(IdType) == 0);
+  const IdType* in_deg;
+  const int64_t* perm;
+  IdType num_rows;
+  __host__ __device__ auto operator()(IdType row) {
+    constexpr int num_elements = CACHE_LINE_SIZE / sizeof(IdType);
+    return in_deg[perm ? perm[row % num_rows] : row] + num_elements - 1;
+  }
+};
+
+template
+__global__ void _CSRRowWiseOneHopExtractorAlignedKernel(
+    const IdType hop_size, const IdType num_rows, const IdType* const indptr,
+    const IdType* const subindptr, const IdType* const subindptr_aligned,
+    const IdType* const indices, IdType* const hop, const int64_t* const perm) {
+  IdType tx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x;
+  const int stride_x = gridDim.x * blockDim.x;
+
+  while (tx < hop_size) {
+    const IdType rpos_ =
+        dgl::cuda::_UpperBound(subindptr_aligned, num_rows, tx) - 1;
+    const IdType rpos = perm ? perm[rpos_] : rpos_;
+    const auto out_row = subindptr[rpos];
+    const auto d = subindptr[rpos + 1] - out_row;
+    const int offset =
+        ((uint64_t)(indices + indptr[rpos] - subindptr_aligned[rpos_]) %
+         CACHE_LINE_SIZE) /
+        sizeof(IdType);
+    const IdType rofs = tx - subindptr_aligned[rpos_] - offset;
+    if (rofs >= 0 && rofs < d) {
+      const auto in_idx = indptr[rpos] + rofs;
+      assert((uint64_t)(indices + in_idx - tx) % CACHE_LINE_SIZE == 0);
+      const auto u = indices[in_idx];
+      hop[out_row + rofs] = u;
+    }
+    tx += stride_x;
+  }
+}
+
 template
 __global__ void _CSRRowWiseLayerSampleDegreeKernel(
-    const IdType num_picks, const IdType num_rows, const IdType* const rows,
-    FloatType* const cs, const FloatType* const ds, const FloatType* const d2s,
+    const IdType num_picks, const IdType num_rows, FloatType* const cs,
+    const FloatType* const ds, const FloatType* const d2s,
     const IdType* const indptr, const FloatType* const probs,
     const FloatType* const A, const IdType* const subindptr) {
   typedef cub::BlockReduce BlockReduce;
@@ -247,21 +296,19 @@ __global__ void _CSRRowWiseLayerSampleDegreeKernel(
   constexpr FloatType ONE = 1;
 
   while (out_row < last_row) {
-    const auto row = rows[out_row];
-
-    const auto in_row_start = indptr[row];
+    const auto in_row_start = indptr[out_row];
     const auto out_row_start = subindptr[out_row];
-    const IdType degree = indptr[row + 1] - in_row_start;
+    const IdType degree = subindptr[out_row + 1] - out_row_start;
 
     if (degree > 0) {
       // stands for k in in arXiv:2210.13339, i.e. fanout
       const auto k = min(num_picks, degree);
       // slightly better than NS
-      const FloatType d_ = ds ? ds[row] : degree;
+      const FloatType d_ = ds ? ds[out_row] : degree;
       // stands for right handside of Equation (22) in arXiv:2210.13339
       FloatType var_target =
-          d_ * d_ / k + (ds ? d2s[row] - d_ * d_ / degree : 0);
+          d_ * d_ / k + (ds ? d2s[out_row] - d_ * d_ / degree : 0);
 
       auto c = cs[out_row];
       const int num_valid = min(degree, (IdType)CTA_SIZE);
 
@@ -273,7 +320,7 @@
         for (int idx = threadIdx.x; idx < degree; idx += CTA_SIZE) {
           const auto w = A[in_row_start + idx];
           const auto ps = probs ? probs[out_row_start + idx] : w;
-          var_1 += w * w / min(ONE, c * ps);
+          var_1 += w > 0 ? w * w / min(ONE, c * ps) : 0;
         }
       } else {
         for (int idx = threadIdx.x; idx < degree; idx += CTA_SIZE) {
@@ -298,12 +345,20 @@
 
 }  // namespace
 
+template
+int log_size(const IdType size) {
+  if (size <= 0) return 0;
+  for (int i = 0; i < static_cast(sizeof(IdType)) * 8; i++)
+    if (((size - 1) >> i) == 0) return i;
+  return sizeof(IdType) * 8;
+}
+
 template
 void compute_importance_sampling_probabilities(
     CSRMatrix mat, const IdType hop_size, cudaStream_t stream,
-    const uint64_t random_seed, const IdType num_rows, const IdType* rows,
-    const IdType* indptr, const IdType* subindptr, const IdType* indices,
-    IdArray idx_coo_arr, const IdType* nids,
+    const uint64_t random_seed, const IdType num_rows, const IdType* indptr,
+    const IdType* subindptr, const IdType* indices, IdArray idx_coo_arr,
+    const IdType* nids,
     FloatArray cs_arr,  // holds the computed cs values, has size num_rows
     const bool weighted, const FloatType* A, const FloatType* ds,
     const FloatType* d2s, const IdType num_picks, DGLContext ctx,
@@ -322,19 +377,15 @@
           : NullArray();
   auto A_l = A_l_arr.Ptr();
 
-  const uint64_t max_log_num_vertices = [&]() -> int {
-    for (int i = 0; i < static_cast(sizeof(IdType)) * 8; i++)
-      if (mat.num_cols <= ((IdType)1) << i) return i;
-    return sizeof(IdType) * 8;
-  }();
+  const int max_log_num_vertices = log_size(mat.num_cols);
 
   {  // extracts the onehop neighborhood cols to a contiguous range into hop_1
     const dim3 block(BLOCK_SIZE);
     const dim3 grid((hop_size + BLOCK_SIZE - 1) / BLOCK_SIZE);
     CUDA_KERNEL_CALL(
         (_CSRRowWiseOneHopExtractorKernel), grid, block, 0,
-        stream, random_seed, hop_size, rows, indptr, subindptr, indices,
-        idx_coo, nids, weighted ? A : nullptr, rands, hop_1, A_l);
+        stream, random_seed, hop_size, indptr, subindptr, indices, idx_coo,
+        nids, weighted ? A : nullptr, rands, hop_1, A_l);
   }
   int64_t hop_uniq_size = 0;
   IdArray hop_new_arr = NewIdArray(hop_size, ctx, sizeof(IdType) * 8);
@@ -445,7 +496,7 @@
     CUDA_KERNEL_CALL(
         (_CSRRowWiseLayerSampleDegreeKernel<
             IdType, FloatType, BLOCK_CTAS, TILE_SIZE>),
-        grid, block, 0, stream, (IdType)num_picks, num_rows, rows, cs,
+        grid, block, 0, stream, (IdType)num_picks, num_rows, cs,
         weighted ? ds : nullptr, weighted ? d2s : nullptr, indptr,
         probs_found, A, subindptr);
   }
@@ -484,10 +535,12 @@ std::pair CSRLaborSampling(
   IdType* const nids = IsNullArray(NIDs) ? nullptr : NIDs.Ptr();
   FloatType* const A = prob_arr.Ptr();
 
-  IdType* const indptr = mat.indptr.Ptr();
-  IdType* const indices = mat.indices.Ptr();
+  IdType* const indptr_ = mat.indptr.Ptr();
+  IdType* const indices_ = mat.indices.Ptr();
   IdType* const data = CSRHasData(mat) ? mat.data.Ptr() : nullptr;
 
+  // Read indptr only once in case it is pinned and access is slow.
+  auto indptr = allocator.alloc_unique(num_rows);
   // compute in-degrees
   auto in_deg = allocator.alloc_unique(num_rows + 1);
   // cs stands for c_s in arXiv:2210.13339
@@ -504,11 +557,17 @@
           : NullArray();
   auto d2s = d2s_arr.Ptr();
 
+  thrust::counting_iterator iota(0);
+  thrust::for_each(
+      exec_policy, iota, iota + num_rows,
+      DegreeFunc{
+          (IdType)num_picks, rows, indptr_, in_deg.get(), indptr.get(), cs});
+
   if (weighted) {
-    auto b_offsets =
-        thrust::make_transform_iterator(rows, IndptrFunc{indptr});
-    auto e_offsets =
-        thrust::make_transform_iterator(rows, IndptrFunc{indptr + 1});
+    auto b_offsets = thrust::make_transform_iterator(
+        iota, IndptrFunc{indptr.get(), nullptr});
+    auto e_offsets = thrust::make_transform_iterator(
+        iota, IndptrFunc{indptr.get(), in_deg.get()});
     auto A_A2 = thrust::make_transform_iterator(A, SquareFunc{});
     auto ds_d2s = thrust::make_zip_iterator(ds, d2s);
 
@@ -524,13 +583,6 @@
         stream));
   }
 
-  thrust::counting_iterator iota(0);
-  thrust::for_each(
-      exec_policy, iota, iota + num_rows,
-      DegreeFunc{
-          (IdType)num_picks, rows, indptr, weighted ? ds : nullptr,
-          in_deg.get(), cs});
-
   // fill subindptr
   IdArray subindptr_arr = NewIdArray(num_rows + 1, ctx, sizeof(IdType) * 8);
   auto subindptr = subindptr_arr.Ptr();
@@ -560,6 +612,38 @@
   auto idx_coo = idx_coo_arr.Ptr();
   auto hop_1 = hop_arr.Ptr();
 
+  const bool is_pinned = mat.indices.IsPinned();
+  if (is_pinned) {
+    const auto res = Sort(rows_arr, log_size(mat.num_rows));
+    const int64_t* perm = static_cast(res.second->data);
+
+    IdType hop_size;  // Shadows the original one as this is temporary
+    auto subindptr_aligned = allocator.alloc_unique(num_rows + 1);
+    {
+      auto modified_in_deg = thrust::make_transform_iterator(
+          iota, AlignmentFunc{in_deg.get(), perm, num_rows});
+      size_t prefix_temp_size = 0;
+      CUDA_CALL(cub::DeviceScan::ExclusiveSum(
+          nullptr, prefix_temp_size, modified_in_deg, subindptr_aligned.get(),
+          num_rows + 1, stream));
+      auto temp = allocator.alloc_unique(prefix_temp_size);
+      CUDA_CALL(cub::DeviceScan::ExclusiveSum(
+          temp.get(), prefix_temp_size, modified_in_deg,
+          subindptr_aligned.get(), num_rows + 1, stream));
+
+      device->CopyDataFromTo(
+          subindptr_aligned.get(), num_rows * sizeof(hop_size), &hop_size, 0,
+          sizeof(hop_size), ctx, DGLContext{kDGLCPU, 0}, mat.indptr->dtype);
+    }
+    const dim3 block(BLOCK_SIZE);
+    const dim3 grid((hop_size + BLOCK_SIZE - 1) / BLOCK_SIZE);
+    CUDA_KERNEL_CALL(
+        (_CSRRowWiseOneHopExtractorAlignedKernel), grid, block, 0,
+        stream, hop_size, num_rows, indptr.get(), subindptr,
+        subindptr_aligned.get(), indices_, hop_1, perm);
+  }
+  const auto indices = is_pinned ? hop_1 : indices_;
+
   auto rands =
       allocator.alloc_unique(importance_sampling ? hop_size : 1);
   auto probs_found =
@@ -575,8 +659,8 @@
     CUDA_KERNEL_CALL(
         (_CSRRowWiseLayerSampleDegreeKernel<
             IdType, FloatType, BLOCK_CTAS, TILE_SIZE>),
-        grid, block, 0, stream, (IdType)num_picks, num_rows, rows, cs, ds, d2s,
-        indptr, nullptr, A, subindptr);
+        grid, block, 0, stream, (IdType)num_picks, num_rows, cs, ds, d2s,
+        indptr.get(), nullptr, A, subindptr);
   }
 
   const uint64_t random_seed =
@@ -587,7 +671,7 @@
   if (importance_sampling)
     compute_importance_sampling_probabilities<
        IdType, FloatType, decltype(exec_policy)>(
-        mat, hop_size, stream, random_seed, num_rows, rows, indptr, subindptr,
+        mat, hop_size, stream, random_seed, num_rows, indptr.get(), subindptr,
        indices, idx_coo_arr, nids, cs_arr, weighted, A, ds, d2s,
        (IdType)num_picks, ctx, allocator, exec_policy, importance_sampling,
        hop_1, rands.get(), probs_found.get());
@@ -621,8 +705,8 @@
           output,
           TransformOpImp<
               IdType, FloatType, FloatType*, FloatType*, decltype(one)>{
-              probs_found.get(), A, one, idx_coo, rows, cs, indptr, subindptr,
-              indices, data});
+              probs_found.get(), A, one, idx_coo, rows, cs, indptr.get(),
+              subindptr, indices, data, is_pinned});
       auto stencil =
           thrust::make_zip_iterator(idx_coo, probs_found.get(), rands.get());
       num_edges =
@@ -635,8 +719,8 @@
           output,
           TransformOpImp<
               IdType, FloatType, FloatType*, decltype(one), decltype(one)>{
-              probs_found.get(), one, one, idx_coo, rows, cs, indptr,
-              subindptr, indices, data});
+              probs_found.get(), one, one, idx_coo, rows, cs, indptr.get(),
+              subindptr, indices, data, is_pinned});
       auto stencil =
           thrust::make_zip_iterator(idx_coo, probs_found.get(), rands.get());
       num_edges =
@@ -654,12 +738,12 @@
           output,
           TransformOpImp<
               IdType, FloatType, decltype(one), FloatType*, FloatType*>{
-              one, A, A, idx_coo, rows, cs, indptr, subindptr, indices,
-              data});
+              one, A, A, idx_coo, rows, cs, indptr.get(), subindptr, indices,
+              data, is_pinned});
       const auto pred = StencilOpFused{
-          random_seed, idx_coo, cs, one, A,
-          subindptr, rows, indptr, indices, nids};
+          random_seed, idx_coo, cs, one, A,
+          subindptr, indptr.get(), indices, nids, is_pinned};
       num_edges =
           thrust::copy_if(
              exec_policy, iota, iota + hop_size, iota, transformed_output,
              pred) -
@@ -669,11 +753,12 @@
           picked_row_data, picked_col_data, picked_idx_data);
       auto transformed_output = thrust::make_transform_output_iterator(
           output, TransformOp{
-                      idx_coo, rows, indptr, subindptr, indices, data});
+                      idx_coo, rows, indptr.get(), subindptr, indices, data,
+                      is_pinned});
       const auto pred = StencilOpFused{
-          random_seed, idx_coo, cs, one, one,
-          subindptr, rows, indptr, indices, nids};
+          random_seed, idx_coo, cs, one, one,
+          subindptr, indptr.get(), indices, nids, is_pinned};
       num_edges =
           thrust::copy_if(
              exec_policy, iota, iota + hop_size, iota, transformed_output,
              pred) -
diff --git a/src/graph/sampling/neighbor/neighbor.cc b/src/graph/sampling/neighbor/neighbor.cc
index 5b31c14c8a3c..ec586e2ba91e 100644
--- a/src/graph/sampling/neighbor/neighbor.cc
+++ b/src/graph/sampling/neighbor/neighbor.cc
@@ -115,47 +115,44 @@ std::pair> SampleLabors(
           hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype),
           hg->DataType(), ctx);
       induced_edges[etype] = aten::NullArray(hg->DataType(), ctx);
-    } else if (fanouts[etype] == -1) {
-      const auto& earr = (dir == EdgeDir::kOut) ? hg->OutEdges(etype, nodes_ntype) : hg->InEdges(etype, nodes_ntype);
-      subrels[etype] = UnitGraph::CreateFromCOO(
-          hg->GetRelationGraph(etype)->NumVertexTypes(),
-          hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype), earr.src,
-          earr.dst);
-      induced_edges[etype] = earr.id;
+      subimportances[etype] = NullArray();
     } else {  // sample from one relation graph
       auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
       auto avail_fmt = hg->SelectFormat(etype, req_fmt);
       COOMatrix sampled_coo;
       FloatArray importances;
+      const int64_t fanout =
+          fanouts[etype] >= 0
+              ? fanouts[etype]
+              : std::max(
+                    hg->NumVertices(dst_vtype), hg->NumVertices(src_vtype));
       switch (avail_fmt) {
         case SparseFormat::kCOO:
           if (dir == EdgeDir::kIn) {
             auto fs = aten::COOLaborSampling(
                 aten::COOTranspose(hg->GetCOOMatrix(etype)), nodes_ntype,
-                fanouts[etype], prob[etype], importance_sampling, random_seed,
+                fanout, prob[etype], importance_sampling, random_seed,
                 NIDs_ntype);
             sampled_coo = aten::COOTranspose(fs.first);
             importances = fs.second;
           } else {
             std::tie(sampled_coo, importances) = aten::COOLaborSampling(
-                hg->GetCOOMatrix(etype), nodes_ntype, fanouts[etype],
-                prob[etype], importance_sampling, random_seed, NIDs_ntype);
+                hg->GetCOOMatrix(etype), nodes_ntype, fanout, prob[etype],
+                importance_sampling, random_seed, NIDs_ntype);
           }
           break;
         case SparseFormat::kCSR:
           CHECK(dir == EdgeDir::kOut)
              << "Cannot sample out edges on CSC matrix.";
           std::tie(sampled_coo, importances) = aten::CSRLaborSampling(
-              hg->GetCSRMatrix(etype), nodes_ntype, fanouts[etype], prob[etype],
+              hg->GetCSRMatrix(etype), nodes_ntype, fanout, prob[etype],
              importance_sampling, random_seed, NIDs_ntype);
           break;
         case SparseFormat::kCSC:
           CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix.";
           std::tie(sampled_coo, importances) = aten::CSRLaborSampling(
-              hg->GetCSCMatrix(etype), nodes_ntype, fanouts[etype], prob[etype],
+              hg->GetCSCMatrix(etype), nodes_ntype, fanout, prob[etype],
              importance_sampling, random_seed, NIDs_ntype);
           sampled_coo = aten::COOTranspose(sampled_coo);
           break;
diff --git a/tests/python/common/sampling/test_sampling.py b/tests/python/common/sampling/test_sampling.py
index 4e2195a502fe..f62ba858afc6 100644
--- a/tests/python/common/sampling/test_sampling.py
+++ b/tests/python/common/sampling/test_sampling.py
@@ -681,6 +681,120 @@ def _test3(p, replace):
         assert subg["flips"].num_edges() == 4
 
+
+def _test_sample_labors(hypersparse, prob):
+    g, hg = _gen_neighbor_sampling_test_graph(hypersparse, False)
+
+    # test with seed nodes [0, 1]
+    def _test1(p):
+        subg = dgl.sampling.sample_labors(g, [0, 1], -1, prob=p)[0]
+        assert subg.num_nodes() == g.num_nodes()
+        u, v = subg.edges()
+        u_ans, v_ans, e_ans = g.in_edges([0, 1], form="all")
+        if p is not None:
+            emask = F.gather_row(g.edata[p], e_ans)
+            if p == "prob":
+                emask = emask != 0
+            u_ans = F.boolean_mask(u_ans, emask)
+            v_ans = F.boolean_mask(v_ans, emask)
+        uv = set(zip(F.asnumpy(u), F.asnumpy(v)))
+        uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))
+        assert uv == uv_ans
+
+        for i in range(10):
+            subg = dgl.sampling.sample_labors(g, [0, 1], 2, prob=p)[0]
+            assert subg.num_nodes() == g.num_nodes()
+            assert subg.num_edges() >= 0
+            u, v = subg.edges()
+            assert set(F.asnumpy(F.unique(v))).issubset({0, 1})
+            assert F.array_equal(
+                F.astype(g.has_edges_between(u, v), F.int64),
+                F.ones((subg.num_edges(),), dtype=F.int64),
+            )
+            assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])
+            edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
+            # check no duplication
+            assert len(edge_set) == subg.num_edges()
+            if p is not None:
+                assert not (3, 0) in edge_set
+                assert not (3, 1) in edge_set
+
+    _test1(prob)
+
+    # test with seed nodes [0, 2]
+    def _test2(p):
+        subg = dgl.sampling.sample_labors(g, [0, 2], -1, prob=p)[0]
+        assert subg.num_nodes() == g.num_nodes()
+        u, v = subg.edges()
+        u_ans, v_ans, e_ans = g.in_edges([0, 2], form="all")
+        if p is not None:
+            emask = F.gather_row(g.edata[p], e_ans)
+            if p == "prob":
+                emask = emask != 0
+            u_ans = F.boolean_mask(u_ans, emask)
+            v_ans = F.boolean_mask(v_ans, emask)
+        uv = set(zip(F.asnumpy(u), F.asnumpy(v)))
+        uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))
+        assert uv == uv_ans
+
+        for i in range(10):
+            subg = dgl.sampling.sample_labors(g, [0, 2], 2, prob=p)[0]
+            assert subg.num_nodes() == g.num_nodes()
+            assert subg.num_edges() >= 0
+            u, v = subg.edges()
+            assert set(F.asnumpy(F.unique(v))).issubset({0, 2})
+            assert F.array_equal(
+                F.astype(g.has_edges_between(u, v), F.int64),
+                F.ones((subg.num_edges(),), dtype=F.int64),
+            )
+            assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])
+            edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
+            # check no duplication
+            assert len(edge_set) == subg.num_edges()
+            if p is not None:
+                assert not (3, 0) in edge_set
+
+    _test2(prob)
+
+    # test with heterogenous seed nodes
+    def _test3(p):
+        subg = dgl.sampling.sample_labors(
+            hg, {"user": [0, 1], "game": 0}, -1, prob=p
+        )[0]
+        assert len(subg.ntypes) == 3
+        assert len(subg.etypes) == 4
+        assert subg["follow"].num_edges() == 6 if p is None else 4
+        assert subg["play"].num_edges() == 1
+        assert subg["liked-by"].num_edges() == 4
+        assert subg["flips"].num_edges() == 0
+
+        for i in range(10):
+            subg = dgl.sampling.sample_labors(
+                hg, {"user": [0, 1], "game": 0}, 2, prob=p
+            )[0]
+            assert len(subg.ntypes) == 3
+            assert len(subg.etypes) == 4
+            assert subg["follow"].num_edges() >= 0
+            assert subg["play"].num_edges() >= 0
+            assert subg["liked-by"].num_edges() >= 0
+            assert subg["flips"].num_edges() >= 0
+
+    _test3(prob)
+
+    # test different fanouts for different relations
+    for i in range(10):
+        subg = dgl.sampling.sample_labors(
+            hg,
+            {"user": [0, 1], "game": 0, "coin": 0},
+            {"follow": 1, "play": 2, "liked-by": 0, "flips": g.num_nodes()},
+        )[0]
+        assert len(subg.ntypes) == 3
+        assert len(subg.etypes) == 4
+        assert subg["follow"].num_edges() >= 0
+        assert subg["play"].num_edges() >= 0
+        assert subg["liked-by"].num_edges() == 0
+        assert subg["flips"].num_edges() == 4
+
+
 def _test_sample_neighbors_outedge(hypersparse):
     g, hg = _gen_neighbor_sampling_test_graph(hypersparse, True)
@@ -967,11 +1081,19 @@ def test_sample_neighbors_noprob():
     # _test_sample_neighbors(True)
 
 
+def test_sample_labors_noprob():
+    _test_sample_labors(False, None)
+
+
 def test_sample_neighbors_prob():
     _test_sample_neighbors(False, "prob")
     # _test_sample_neighbors(True)
 
 
+def test_sample_labors_prob():
+    _test_sample_labors(False, "prob")
+
+
 def test_sample_neighbors_outedge():
     _test_sample_neighbors_outedge(False)
     # _test_sample_neighbors_outedge(True)
@@ -1575,7 +1697,9 @@ def test_global_uniform_negative_sampling(dtype):
     from itertools import product
 
     test_sample_neighbors_noprob()
+    test_sample_labors_noprob()
     test_sample_neighbors_prob()
+    test_sample_labors_prob()
     test_sample_neighbors_mask()
     for args in product(["coo", "csr", "csc"], ["in", "out"], [False, True]):
         test_sample_neighbors_etype_homogeneous(*args)
diff --git a/tests/python/pytorch/dataloading/test_dataloader.py b/tests/python/pytorch/dataloading/test_dataloader.py
index 414953b3a766..cae8ba49ea65 100644
--- a/tests/python/pytorch/dataloading/test_dataloader.py
+++ b/tests/python/pytorch/dataloading/test_dataloader.py
@@ -320,7 +320,9 @@ def _ddp_runner(proc_id, nprocs, g, data, args):
 
 
 @parametrize_idtype
-@pytest.mark.parametrize("sampler_name", ["full", "neighbor", "neighbor2"])
+@pytest.mark.parametrize(
+    "sampler_name", ["full", "neighbor", "neighbor2", "labor"]
+)
 @pytest.mark.parametrize(
     "mode", ["cpu", "uva_cuda_indices", "uva_cpu_indices", "pure_gpu"]
 )
@@ -353,6 +355,7 @@ def test_node_dataloader(idtype, sampler_name, mode, use_ddp):
         "full": dgl.dataloading.MultiLayerFullNeighborSampler(2),
         "neighbor": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
         "neighbor2": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
+        "labor": dgl.dataloading.LaborSampler([3, 3]),
     }[sampler_name]
     for num_workers in [0, 1, 2] if mode == "cpu" else [0]:
         dataloader = dgl.dataloading.DataLoader(
@@ -405,6 +408,7 @@ def test_node_dataloader(idtype, sampler_name, mode, use_ddp):
             [{etype: 3 for etype in g2.etypes}] * 2
         ),
         "neighbor2": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
+        "labor": dgl.dataloading.LaborSampler([3, 3]),
     }[sampler_name]
     for num_workers in [0, 1, 2] if mode == "cpu" else [0]:
         dataloader = dgl.dataloading.DataLoader(