diff --git a/paddle/fluid/operators/graph_khop_sampler_imp.h b/paddle/fluid/operators/graph_khop_sampler_imp.h new file mode 100644 index 0000000000000..93ebbe46e9315 --- /dev/null +++ b/paddle/fluid/operators/graph_khop_sampler_imp.h @@ -0,0 +1,135 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +namespace paddle { +namespace operators { + +template +inline __device__ size_t Hash(IdType id, int64_t size) { + return id % size; +} + +template +inline __device__ bool AttemptInsert(size_t pos, IdType id, int64_t index, + IdType* keys, int64_t* key_index) { + if (sizeof(IdType) == 4) { + const IdType key = + atomicCAS(reinterpret_cast(&keys[pos]), + static_cast(-1), static_cast(id)); + if (key == -1 || key == id) { + atomicMin( + reinterpret_cast(&key_index[pos]), // NOLINT + static_cast(index)); // NOLINT + return true; + } else { + return false; + } + } else if (sizeof(IdType) == 8) { + const IdType key = atomicCAS( + reinterpret_cast(&keys[pos]), // NOLINT + static_cast(-1), // NOLINT + static_cast(id)); // NOLINT + if (key == -1 || key == id) { + atomicMin( + reinterpret_cast(&key_index[pos]), // NOLINT + static_cast(index)); // NOLINT + return true; + } else { + return false; + } + } +} + +template +inline __device__ void Insert(IdType id, int64_t index, int64_t size, + IdType* keys, int64_t* key_index) { + size_t pos = Hash(id, size); + size_t delta = 1; + while (!AttemptInsert(pos, id, index, keys, key_index)) { + pos = Hash(pos + delta, size); + delta += 1; + } +} + +template +inline __device__ int64_t Search(IdType id, const IdType* keys, int64_t size) { + int64_t pos = Hash(id, size); + + int64_t delta = 1; + while (keys[pos] != id) { + pos = Hash(pos + delta, size); + delta += 1; + } + + return pos; +} + +template +__global__ void BuildHashTable(const IdType* items, int64_t num_items, + int64_t size, IdType* keys, int64_t* key_index) { + CUDA_KERNEL_LOOP_TYPE(index, num_items, int64_t) { + Insert(items[index], index, size, keys, key_index); + } +} + +template +__global__ void GetItemIndexCount(const IdType* items, int* item_count, + int64_t num_items, int64_t size, + const IdType* keys, int64_t* key_index) { + CUDA_KERNEL_LOOP_TYPE(i, num_items, int64_t) { + int64_t pos = Search(items[i], keys, size); + if (key_index[pos] == i) { + item_count[i] = 1; + } + } +} + +template +__global__ void FillUniqueItems(const IdType* items, int64_t num_items, + int64_t size, IdType* unique_items, + const int* item_count, const IdType* keys, + IdType* values, int64_t* key_index) { + CUDA_KERNEL_LOOP_TYPE(i, num_items, int64_t) { + int64_t pos = Search(items[i], keys, size); + if (key_index[pos] == i) { + values[pos] = item_count[i]; + unique_items[item_count[i]] = items[i]; + } + } +} + +template +__global__ void ReindexSrcOutput(IdType* src_output, int64_t num_items, + int64_t size, const IdType* keys, + const IdType* values) { + CUDA_KERNEL_LOOP_TYPE(i, num_items, int64_t) { + int64_t pos = Search(src_output[i], keys, size); + src_output[i] = values[pos]; + } +} + +template +__global__ void ReindexInputNodes(const IdType* nodes, int64_t num_items, + IdType* reindex_nodes, int64_t size, + const IdType* keys, const IdType* values) { + CUDA_KERNEL_LOOP_TYPE(i, num_items, int64_t) { + int64_t pos = Search(nodes[i], keys, size); + reindex_nodes[i] = values[pos]; + } +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/graph_khop_sampler_op.cc b/paddle/fluid/operators/graph_khop_sampler_op.cc new file mode 100644 index 0000000000000..c83ee25840605 --- /dev/null +++ b/paddle/fluid/operators/graph_khop_sampler_op.cc @@ -0,0 +1,133 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/graph_khop_sampler_op.h" + +namespace paddle { +namespace operators { + +void InputShapeCheck(const framework::DDim& dims, std::string tensor_name) { + if (dims.size() == 2) { + PADDLE_ENFORCE_EQ(dims[1], 1, platform::errors::InvalidArgument( + "The last dim of %s should be 1 when it " + "is 2D, but we get %d", + tensor_name, dims[1])); + } else { + PADDLE_ENFORCE_EQ( + dims.size(), 1, + platform::errors::InvalidArgument( + "The %s should be 1D, when it is not 2D, but we get %d", + tensor_name, dims.size())); + } +} + +class GraphKhopSamplerOP : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("Row"), "Input", "Row", "GraphKhopSampler"); + OP_INOUT_CHECK(ctx->HasInput("Col_Ptr"), "Input", "Col_Ptr", + "GraphKhopSampler"); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "GraphKhopSampler"); + OP_INOUT_CHECK(ctx->HasOutput("Out_Src"), "Output", "Out_Src", + "GraphKhopSampler"); + OP_INOUT_CHECK(ctx->HasOutput("Out_Dst"), "Output", "Out_Dst", + "GraphKhopSampler"); + OP_INOUT_CHECK(ctx->HasOutput("Sample_Index"), "Output", "Sample_Index", + "GraphKhopSampler"); + OP_INOUT_CHECK(ctx->HasOutput("Reindex_X"), "Output", "Reindex_X", + "GraphKhopSampler"); + + // Restrict all the inputs as 1-dim tensor, or 2-dim tensor with the second + // dim as 1. + InputShapeCheck(ctx->GetInputDim("Row"), "Row"); + InputShapeCheck(ctx->GetInputDim("Col_Ptr"), "Col_Ptr"); + InputShapeCheck(ctx->GetInputDim("X"), "X"); + + const std::vector& sample_sizes = + ctx->Attrs().Get>("sample_sizes"); + PADDLE_ENFORCE_EQ( + !sample_sizes.empty(), true, + platform::errors::InvalidArgument( + "The parameter 'sample_sizes' in GraphSampleOp must be set. " + "But received 'sample_sizes' is empty.")); + const bool& return_eids = ctx->Attrs().Get("return_eids"); + if (return_eids) { + OP_INOUT_CHECK(ctx->HasInput("Eids"), "Input", "Eids", + "GraphKhopSampler"); + InputShapeCheck(ctx->GetInputDim("Eids"), "Eids"); + OP_INOUT_CHECK(ctx->HasOutput("Out_Eids"), "Output", "Out_Eids", + "GraphKhopSampler"); + ctx->SetOutputDim("Out_Eids", {-1}); + } + + ctx->SetOutputDim("Out_Src", {-1, 1}); + ctx->SetOutputDim("Out_Dst", {-1, 1}); + ctx->SetOutputDim("Sample_Index", {-1}); + + auto dims = ctx->GetInputDim("X"); + ctx->SetOutputDim("Reindex_X", dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Row"), + ctx.device_context()); + } +}; + +class GraphKhopSamplerOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Row", "The src index tensor of graph edges after sorted by dst."); + AddInput("Eids", "The eids of the input graph edges.").AsDispensable(); + AddInput("Col_Ptr", + "The cumulative sum of the number of src neighbors of dst index, " + "starts from 0, end with number of edges"); + AddInput("X", "The input center nodes index tensor."); + AddOutput("Out_Src", + "The output src edges tensor after sampling and reindex."); + AddOutput("Out_Dst", + "The output dst edges tensor after sampling and reindex."); + AddOutput("Sample_Index", + "The original index of the center nodes and sampling nodes"); + AddOutput("Reindex_X", "The reindex node id of the input nodes."); + AddOutput("Out_Eids", "The eids of the sample edges.").AsIntermediate(); + AddAttr>( + "sample_sizes", "The sample sizes of graph sample neighbors method.") + .SetDefault({}); + AddAttr("return_eids", + "Whether to return the eids of the sample edges.") + .SetDefault(false); + AddComment(R"DOC( +Graph Learning Sampling Neighbors operator, for graphsage sampling method. + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +using CPU = paddle::platform::CPUDeviceContext; + +REGISTER_OPERATOR(graph_khop_sampler, ops::GraphKhopSamplerOP, + ops::GraphKhopSamplerOpMaker); +REGISTER_OP_CPU_KERNEL(graph_khop_sampler, + ops::GraphKhopSamplerOpKernel, + ops::GraphKhopSamplerOpKernel); diff --git a/paddle/fluid/operators/graph_khop_sampler_op.cu b/paddle/fluid/operators/graph_khop_sampler_op.cu new file mode 100644 index 0000000000000..777ec64f6e008 --- /dev/null +++ b/paddle/fluid/operators/graph_khop_sampler_op.cu @@ -0,0 +1,566 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +This file is inspired by + + https://github.com/quiver-team/torch-quiver/blob/main/srcs/cpp/src/quiver/cuda/quiver_sample.cu + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef PADDLE_WITH_HIP +#include +#include +#else +#include +#include +#endif + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/graph_khop_sampler_imp.h" +#include "paddle/fluid/operators/graph_khop_sampler_op.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/fluid/platform/place.h" + +constexpr int WARP_SIZE = 32; + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +struct MaxFunctor { + T cap; + HOSTDEVICE explicit inline MaxFunctor(T cap) { this->cap = cap; } + HOSTDEVICE inline T operator()(T x) const { + if (x > cap) { + return cap; + } + return x; + } +}; + +template +struct DegreeFunctor { + const T* dst_count; + HOSTDEVICE explicit inline DegreeFunctor(const T* x) { this->dst_count = x; } + HOSTDEVICE inline T operator()(T i) const { + return dst_count[i + 1] - dst_count[i]; + } +}; + +template +__global__ void GraphSampleNeighborsCUDAKernel( + const uint64_t rand_seed, int k, const int64_t num_rows, const T* in_rows, + const T* src, const T* dst_count, const T* src_eids, T* outputs, + T* outputs_eids, T* output_ptr, T* output_idxs, bool return_eids) { + assert(blockDim.x == WARP_SIZE); + assert(blockDim.y == BLOCK_WARPS); + + int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y; + const int64_t last_row = + min(static_cast(blockIdx.x + 1) * TILE_SIZE, num_rows); +#ifdef PADDLE_WITH_HIP + hiprandState rng; + hiprand_init(rand_seed * gridDim.x + blockIdx.x, + threadIdx.y * WARP_SIZE + threadIdx.x, 0, &rng); +#else + curandState rng; + curand_init(rand_seed * gridDim.x + blockIdx.x, + threadIdx.y * WARP_SIZE + threadIdx.x, 0, &rng); +#endif + + while (out_row < last_row) { + const int64_t row = in_rows[out_row]; + const int64_t in_row_start = dst_count[row]; + const int64_t deg = dst_count[row + 1] - in_row_start; + const int64_t out_row_start = output_ptr[out_row]; + + if (deg <= k) { + for (int idx = threadIdx.x; idx < deg; idx += WARP_SIZE) { + const T in_idx = in_row_start + idx; + outputs[out_row_start + idx] = src[in_idx]; + if (return_eids) { + outputs_eids[out_row_start + idx] = src_eids[in_idx]; + } + } + } else { + for (int idx = threadIdx.x; idx < k; idx += WARP_SIZE) { + output_idxs[out_row_start + idx] = idx; + } +#ifdef PADDLE_WITH_CUDA + __syncwarp(); +#endif + + for (int idx = k + threadIdx.x; idx < deg; idx += WARP_SIZE) { +#ifdef PADDLE_WITH_HIP + const int num = hiprand(&rng) % (idx + 1); +#else + const int num = curand(&rng) % (idx + 1); +#endif + if (num < k) { + paddle::platform::CudaAtomicMax(output_idxs + out_row_start + num, + idx); + } + } +#ifdef PADDLE_WITH_CUDA + __syncwarp(); +#endif + + for (int idx = threadIdx.x; idx < k; idx += WARP_SIZE) { + const T perm_idx = output_idxs[out_row_start + idx] + in_row_start; + outputs[out_row_start + idx] = src[perm_idx]; + if (return_eids) { + outputs_eids[out_row_start + idx] = src_eids[perm_idx]; + } + } + } + + out_row += BLOCK_WARPS; + } +} + +template +__global__ void GetDstEdgeCUDAKernel(const int64_t num_rows, const T* in_rows, + const T* dst_sample_counts, + const T* dst_ptr, T* outputs) { + assert(blockDim.x == WARP_SIZE); + assert(blockDim.y == BLOCK_WARPS); + + int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y; + const int64_t last_row = + min(static_cast(blockIdx.x + 1) * TILE_SIZE, num_rows); + + while (out_row < last_row) { + const int64_t row = in_rows[out_row]; + const int64_t dst_sample_size = dst_sample_counts[out_row]; + const int64_t out_row_start = dst_ptr[out_row]; + for (int idx = threadIdx.x; idx < dst_sample_size; idx += WARP_SIZE) { + outputs[out_row_start + idx] = row; + } +#ifdef PADDLE_WITH_CUDA + __syncwarp(); +#endif + + out_row += BLOCK_WARPS; + } +} + +template +void SampleNeighbors(const framework::ExecutionContext& ctx, const T* src, + const T* dst_count, const T* src_eids, + thrust::device_vector* inputs, + thrust::device_vector* outputs, + thrust::device_vector* output_counts, + thrust::device_vector* outputs_eids, int k, + bool is_first_layer, bool is_last_layer, + bool return_eids) { + const size_t bs = inputs->size(); + output_counts->resize(bs); + + // 1. Get input nodes' degree. + thrust::transform(inputs->begin(), inputs->end(), output_counts->begin(), + DegreeFunctor(dst_count)); + + // 2. Apply sample size k to get final sample size. + if (k >= 0) { + thrust::transform(output_counts->begin(), output_counts->end(), + output_counts->begin(), MaxFunctor(k)); + } + + // 3. Get the number of total sample neighbors and some necessary datas. + T total_sample_num = + thrust::reduce(output_counts->begin(), output_counts->end()); + if (is_first_layer) { + PADDLE_ENFORCE_GT( + total_sample_num, 0, + platform::errors::InvalidArgument( + "The input nodes `X` should have at least one neighbor, " + "but none of the input nodes have neighbors.")); + } + outputs->resize(total_sample_num); + if (return_eids) { + outputs_eids->resize(total_sample_num); + } + + thrust::device_vector output_ptr; + thrust::device_vector output_idxs; + output_ptr.resize(bs); + output_idxs.resize(total_sample_num); + thrust::exclusive_scan(output_counts->begin(), output_counts->end(), + output_ptr.begin(), 0); + + // 4. Run graph sample kernel. + constexpr int BLOCK_WARPS = 128 / WARP_SIZE; + constexpr int TILE_SIZE = BLOCK_WARPS * 16; + const dim3 block(WARP_SIZE, BLOCK_WARPS); + const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE); + GraphSampleNeighborsCUDAKernel<<< + grid, block, 0, + reinterpret_cast(ctx.device_context()) + .stream()>>>( + 0, k, bs, thrust::raw_pointer_cast(inputs->data()), src, dst_count, + src_eids, thrust::raw_pointer_cast(outputs->data()), + thrust::raw_pointer_cast(outputs_eids->data()), + thrust::raw_pointer_cast(output_ptr.data()), + thrust::raw_pointer_cast(output_idxs.data()), return_eids); + + // 5. Get inputs = outputs - inputs: + if (!is_last_layer) { + thrust::sort(inputs->begin(), inputs->end()); + thrust::device_vector outputs_sort(outputs->size()); + thrust::copy(outputs->begin(), outputs->end(), outputs_sort.begin()); + thrust::sort(outputs_sort.begin(), outputs_sort.end()); + auto outputs_sort_end = + thrust::unique(outputs_sort.begin(), outputs_sort.end()); + outputs_sort.resize( + thrust::distance(outputs_sort.begin(), outputs_sort_end)); + thrust::device_vector unique_outputs(outputs_sort.size()); + auto unique_outputs_end = thrust::set_difference( + outputs_sort.begin(), outputs_sort.end(), inputs->begin(), + inputs->end(), unique_outputs.begin()); + inputs->resize( + thrust::distance(unique_outputs.begin(), unique_outputs_end)); + thrust::copy(unique_outputs.begin(), unique_outputs_end, inputs->begin()); + } +} + +template +void FillHashTable(const framework::ExecutionContext& ctx, const T* input, + int64_t num_input, int64_t len_hashtable, + thrust::device_vector* unique_items, + thrust::device_vector* keys, + thrust::device_vector* values, + thrust::device_vector* key_index) { +#ifdef PADDLE_WITH_HIP + int block = 256; +#else + int block = 1024; +#endif + const auto& dev_ctx = ctx.cuda_device_context(); + int max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize().x; + int grid_tmp = (num_input + block - 1) / block; + int grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx; + // 1. Insert data into keys and values. + BuildHashTable< + T><<( + ctx.device_context()) + .stream()>>>( + input, num_input, len_hashtable, thrust::raw_pointer_cast(keys->data()), + thrust::raw_pointer_cast(key_index->data())); + + // 2. Get item index count. + thrust::device_vector item_count(num_input + 1, 0); + GetItemIndexCount< + T><<( + ctx.device_context()) + .stream()>>>( + input, thrust::raw_pointer_cast(item_count.data()), num_input, + len_hashtable, thrust::raw_pointer_cast(keys->data()), + thrust::raw_pointer_cast(key_index->data())); + + thrust::exclusive_scan(item_count.begin(), item_count.end(), + item_count.begin()); + size_t total_unique_items = item_count[num_input]; + unique_items->resize(total_unique_items); + + // 3. Get unique items. + FillUniqueItems< + T><<( + ctx.device_context()) + .stream()>>>( + input, num_input, len_hashtable, + thrust::raw_pointer_cast(unique_items->data()), + thrust::raw_pointer_cast(item_count.data()), + thrust::raw_pointer_cast(keys->data()), + thrust::raw_pointer_cast(values->data()), + thrust::raw_pointer_cast(key_index->data())); +} + +template +void ReindexFunc(const framework::ExecutionContext& ctx, + thrust::device_vector* inputs, + thrust::device_vector* outputs, + thrust::device_vector* subset, + thrust::device_vector* orig_nodes, + thrust::device_vector* reindex_nodes, int bs) { + subset->resize(inputs->size() + outputs->size()); + thrust::copy(inputs->begin(), inputs->end(), subset->begin()); + thrust::copy(outputs->begin(), outputs->end(), + subset->begin() + inputs->size()); + thrust::device_vector unique_items; + unique_items.clear(); + + // Fill hash table. + int64_t num = subset->size(); + int64_t log_num = 1 << static_cast(1 + std::log2(num >> 1)); + int64_t size = log_num << 1; + thrust::device_vector keys(size, -1); + thrust::device_vector values(size, -1); + thrust::device_vector key_index(size, -1); + FillHashTable(ctx, thrust::raw_pointer_cast(subset->data()), + subset->size(), size, &unique_items, &keys, &values, + &key_index); + + subset->resize(unique_items.size()); + thrust::copy(unique_items.begin(), unique_items.end(), subset->begin()); + +// Fill outputs with reindex result. +#ifdef PADDLE_WITH_HIP + int block = 256; +#else + int block = 1024; +#endif + const auto& dev_ctx = ctx.cuda_device_context(); + int64_t max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize().x; + int64_t grid_tmp = (outputs->size() + block - 1) / block; + int64_t grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx; + ReindexSrcOutput< + T><<( + ctx.device_context()) + .stream()>>>( + thrust::raw_pointer_cast(outputs->data()), outputs->size(), size, + thrust::raw_pointer_cast(keys.data()), + thrust::raw_pointer_cast(values.data())); + + int grid_ = (bs + block - 1) / block; + ReindexInputNodes<<( + ctx.device_context()) + .stream()>>>( + thrust::raw_pointer_cast(orig_nodes->data()), bs, + thrust::raw_pointer_cast(reindex_nodes->data()), size, + thrust::raw_pointer_cast(keys.data()), + thrust::raw_pointer_cast(values.data())); +} + +template +class GraphKhopSamplerOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + // 1. Get sample neighbors operators' inputs. + auto* src = ctx.Input("Row"); + auto* dst_count = ctx.Input("Col_Ptr"); + auto* vertices = ctx.Input("X"); + std::vector sample_sizes = ctx.Attr>("sample_sizes"); + bool return_eids = ctx.Attr("return_eids"); + + const T* src_data = src->data(); + const T* dst_count_data = dst_count->data(); + const T* p_vertices = vertices->data(); + const int bs = vertices->dims()[0]; + + // 2. Get unique input nodes(X). + thrust::device_vector inputs(bs); + thrust::copy(p_vertices, p_vertices + bs, inputs.begin()); + auto unique_inputs_end = thrust::unique(inputs.begin(), inputs.end()); + inputs.resize(thrust::distance(inputs.begin(), unique_inputs_end)); + + // 3. Sample neighbors. We should distinguish w/o "Src_Eids". + thrust::device_vector outputs; + thrust::device_vector output_counts; + thrust::device_vector outputs_eids; + std::vector> dst_vec; + dst_vec.emplace_back(inputs); + std::vector> outputs_vec; + std::vector> output_counts_vec; + std::vector> outputs_eids_vec; + + const size_t num_layers = sample_sizes.size(); + bool is_last_layer = false, is_first_layer = true; + + if (return_eids) { + auto* src_eids = ctx.Input("Eids"); + const T* src_eids_data = src_eids->data(); + for (int i = 0; i < num_layers; i++) { + if (i == num_layers - 1) { + is_last_layer = true; + } + if (inputs.size() == 0) { + break; + } + if (i > 0) { + is_first_layer = false; + dst_vec.emplace_back(inputs); + } + SampleNeighbors(ctx, src_data, dst_count_data, src_eids_data, + &inputs, &outputs, &output_counts, &outputs_eids, + sample_sizes[i], is_first_layer, is_last_layer, + return_eids); + outputs_vec.emplace_back(outputs); + output_counts_vec.emplace_back(output_counts); + outputs_eids_vec.emplace_back(outputs_eids); + } + } else { + for (int i = 0; i < num_layers; i++) { + if (i == num_layers - 1) { + is_last_layer = true; + } + if (inputs.size() == 0) { + break; + } + if (i > 0) { + is_first_layer = false; + dst_vec.emplace_back(inputs); + } + SampleNeighbors(ctx, src_data, dst_count_data, nullptr, &inputs, + &outputs, &output_counts, &outputs_eids, + sample_sizes[i], is_first_layer, is_last_layer, + return_eids); + outputs_vec.emplace_back(outputs); + output_counts_vec.emplace_back(output_counts); + outputs_eids_vec.emplace_back(outputs_eids); + } + } + + // 4. Concat intermediate sample results + // Including src_merge, unique_dst_merge and dst_sample_counts_merge. + thrust::device_vector unique_dst_merge; // unique dst + thrust::device_vector src_merge; // src + thrust::device_vector dst_sample_counts_merge; // dst degree + int64_t unique_dst_size = 0, src_size = 0; + for (int i = 0; i < num_layers; i++) { + unique_dst_size += dst_vec[i].size(); + src_size += outputs_vec[i].size(); + } + unique_dst_merge.resize(unique_dst_size); + src_merge.resize(src_size); + dst_sample_counts_merge.resize(unique_dst_size); + auto unique_dst_merge_ptr = unique_dst_merge.begin(); + auto src_merge_ptr = src_merge.begin(); + auto dst_sample_counts_merge_ptr = dst_sample_counts_merge.begin(); + for (int i = 0; i < num_layers; i++) { + if (i == 0) { + unique_dst_merge_ptr = thrust::copy( + dst_vec[i].begin(), dst_vec[i].end(), unique_dst_merge.begin()); + src_merge_ptr = thrust::copy(outputs_vec[i].begin(), + outputs_vec[i].end(), src_merge.begin()); + dst_sample_counts_merge_ptr = thrust::copy( + output_counts_vec[i].begin(), output_counts_vec[i].end(), + dst_sample_counts_merge.begin()); + } else { + unique_dst_merge_ptr = thrust::copy( + dst_vec[i].begin(), dst_vec[i].end(), unique_dst_merge_ptr); + src_merge_ptr = thrust::copy(outputs_vec[i].begin(), + outputs_vec[i].end(), src_merge_ptr); + dst_sample_counts_merge_ptr = thrust::copy(output_counts_vec[i].begin(), + output_counts_vec[i].end(), + dst_sample_counts_merge_ptr); + } + } + + // 5. Return eids results. + if (return_eids) { + thrust::device_vector eids_merge; + eids_merge.resize(src_size); + auto eids_merge_ptr = eids_merge.begin(); + for (int i = 0; i < num_layers; i++) { + if (i == 0) { + eids_merge_ptr = + thrust::copy(outputs_eids_vec[i].begin(), + outputs_eids_vec[i].end(), eids_merge.begin()); + } else { + eids_merge_ptr = + thrust::copy(outputs_eids_vec[i].begin(), + outputs_eids_vec[i].end(), eids_merge_ptr); + } + } + auto* out_eids = ctx.Output("Out_Eids"); + out_eids->Resize({static_cast(eids_merge.size())}); + T* p_out_eids = out_eids->mutable_data(ctx.GetPlace()); + thrust::copy(eids_merge.begin(), eids_merge.end(), p_out_eids); + } + + int64_t num_sample_edges = thrust::reduce(dst_sample_counts_merge.begin(), + dst_sample_counts_merge.end()); + + PADDLE_ENFORCE_EQ( + src_merge.size(), num_sample_edges, + platform::errors::PreconditionNotMet( + "Number of sample edges dismatch, the sample kernel has error.")); + + // 6. Get hashtable according to unique_dst_merge and src_merge. + // We can get unique items(subset) and reindex src nodes of sample edges. + // We also get Reindex_X for input nodes here. + thrust::device_vector orig_nodes(bs); + thrust::copy(p_vertices, p_vertices + bs, orig_nodes.begin()); + thrust::device_vector reindex_nodes(bs); + thrust::device_vector subset; + ReindexFunc(ctx, &unique_dst_merge, &src_merge, &subset, &orig_nodes, + &reindex_nodes, bs); + auto* reindex_x = ctx.Output("Reindex_X"); + T* p_reindex_x = reindex_x->mutable_data(ctx.GetPlace()); + thrust::copy(reindex_nodes.begin(), reindex_nodes.end(), p_reindex_x); + + auto* sample_index = ctx.Output("Sample_Index"); + sample_index->Resize({static_cast(subset.size())}); + T* p_sample_index = sample_index->mutable_data(ctx.GetPlace()); + thrust::copy(subset.begin(), subset.end(), p_sample_index); // Done! + + // 7. Reindex dst nodes of sample edges. + thrust::device_vector dst_merge(src_size); + thrust::device_vector unique_dst_merge_reindex(unique_dst_size); + thrust::sequence(unique_dst_merge_reindex.begin(), + unique_dst_merge_reindex.end()); + thrust::device_vector dst_ptr(unique_dst_size); + thrust::exclusive_scan(dst_sample_counts_merge.begin(), + dst_sample_counts_merge.end(), dst_ptr.begin()); + constexpr int BLOCK_WARPS = 128 / WARP_SIZE; + constexpr int TILE_SIZE = BLOCK_WARPS * 16; + const dim3 block(WARP_SIZE, BLOCK_WARPS); + const dim3 grid((unique_dst_size + TILE_SIZE - 1) / TILE_SIZE); + + GetDstEdgeCUDAKernel<<< + grid, block, 0, reinterpret_cast( + ctx.device_context()) + .stream()>>>( + unique_dst_size, + thrust::raw_pointer_cast(unique_dst_merge_reindex.data()), + thrust::raw_pointer_cast(dst_sample_counts_merge.data()), + thrust::raw_pointer_cast(dst_ptr.data()), + thrust::raw_pointer_cast(dst_merge.data())); + + // 8. Give operator's outputs. + auto* out_src = ctx.Output("Out_Src"); + auto* out_dst = ctx.Output("Out_Dst"); + out_src->Resize({static_cast(src_merge.size()), 1}); + out_dst->Resize({static_cast(src_merge.size()), 1}); + T* p_out_src = out_src->mutable_data(ctx.GetPlace()); + T* p_out_dst = out_dst->mutable_data(ctx.GetPlace()); + const size_t& memset_bytes = src_merge.size() * sizeof(T); + thrust::copy(src_merge.begin(), src_merge.end(), p_out_src); + thrust::copy(dst_merge.begin(), dst_merge.end(), p_out_dst); + } +}; + +} // namespace operators +} // namespace paddle + +using CUDA = paddle::platform::CUDADeviceContext; +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL(graph_khop_sampler, + ops::GraphKhopSamplerOpCUDAKernel, + ops::GraphKhopSamplerOpCUDAKernel); diff --git a/paddle/fluid/operators/graph_khop_sampler_op.h b/paddle/fluid/operators/graph_khop_sampler_op.h new file mode 100644 index 0000000000000..d7121cb549370 --- /dev/null +++ b/paddle/fluid/operators/graph_khop_sampler_op.h @@ -0,0 +1,368 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +void SampleUniqueNeighbors(bidiiter begin, bidiiter end, int num_samples) { + int left_num = std::distance(begin, end); + std::random_device rd; + std::mt19937 rng{rd()}; + std::uniform_int_distribution dice_distribution( + 0, std::numeric_limits::max()); + for (int i = 0; i < num_samples; i++) { + bidiiter r = begin; + int random_step = dice_distribution(rng) % left_num; + std::advance(r, random_step); + std::swap(*begin, *r); + ++begin; + --left_num; + } +} + +template +void SampleUniqueNeighborsWithEids(bidiiter src_begin, bidiiter src_end, + bidiiter eid_begin, bidiiter eid_end, + int num_samples) { + int left_num = std::distance(src_begin, src_end); + std::random_device rd; + std::mt19937 rng{rd()}; + std::uniform_int_distribution dice_distribution( + 0, std::numeric_limits::max()); + for (int i = 0; i < num_samples; i++) { + bidiiter r1 = src_begin, r2 = eid_begin; + int random_step = dice_distribution(rng) % left_num; + std::advance(r1, random_step); + std::advance(r2, random_step); + std::swap(*src_begin, *r1); + std::swap(*eid_begin, *r2); + ++src_begin; + ++eid_begin; + --left_num; + } +} + +template +void SampleNeighbors(const T* src, const T* dst_count, const T* src_eids, + std::vector* inputs, std::vector* outputs, + std::vector* output_counts, + std::vector* outputs_eids, int k, bool is_first_layer, + bool is_last_layer, bool return_eids) { + const size_t bs = inputs->size(); + // Allocate the memory of outputs + // Collect the neighbors size + std::vector> out_src_vec; + std::vector> out_eids_vec; + // `sample_cumsum_sizes` record the start position and end position after the + // sample. + std::vector sample_cumsum_sizes(bs + 1); + size_t total_neighbors = 0; + // `total_neighbors` the size of output after the sample + sample_cumsum_sizes[0] = total_neighbors; + for (size_t i = 0; i < bs; i++) { + T node = inputs->data()[i]; + T begin = dst_count[node]; + T end = dst_count[node + 1]; + int cap = end - begin; + int sample_size = cap > k ? k : cap; + total_neighbors += sample_size; + sample_cumsum_sizes[i + 1] = total_neighbors; + std::vector out_src; + out_src.resize(cap); + out_src_vec.emplace_back(out_src); + if (return_eids) { + std::vector out_eids; + out_eids.resize(cap); + out_eids_vec.emplace_back(out_eids); + } + } + if (is_first_layer) { + PADDLE_ENFORCE_GT(total_neighbors, 0, + platform::errors::InvalidArgument( + "The input nodes `X` should have at " + "least one neighbors, but none of the " + "input nodes have neighbors.")); + } + output_counts->resize(bs); + outputs->resize(total_neighbors); + if (return_eids) { + outputs_eids->resize(total_neighbors); + } + +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + // Sample the neighbour parallelism + for (size_t i = 0; i < bs; i++) { + T node = inputs->data()[i]; + T begin = dst_count[node]; + T end = dst_count[node + 1]; + int cap = end - begin; + if (k < cap) { + std::copy(src + begin, src + end, out_src_vec[i].begin()); + if (return_eids) { + std::copy(src_eids + begin, src_eids + end, out_eids_vec[i].begin()); + SampleUniqueNeighborsWithEids( + out_src_vec[i].begin(), out_src_vec[i].end(), + out_eids_vec[i].begin(), out_eids_vec[i].end(), k); + } else { + SampleUniqueNeighbors(out_src_vec[i].begin(), out_src_vec[i].end(), k); + } + *(output_counts->data() + i) = k; + } else { + std::copy(src + begin, src + end, out_src_vec[i].begin()); + if (return_eids) { + std::copy(src_eids + begin, src_eids + end, out_eids_vec[i].begin()); + } + *(output_counts->data() + i) = cap; + } + } + +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + // Copy the results parallelism + for (size_t i = 0; i < bs; i++) { + int sample_size = sample_cumsum_sizes[i + 1] - sample_cumsum_sizes[i]; + std::copy(out_src_vec[i].begin(), out_src_vec[i].begin() + sample_size, + outputs->data() + sample_cumsum_sizes[i]); + if (return_eids) { + std::copy(out_eids_vec[i].begin(), out_eids_vec[i].begin() + sample_size, + outputs_eids->data() + sample_cumsum_sizes[i]); + } + } + + if (!is_last_layer) { + std::sort(inputs->begin(), inputs->end()); + std::vector outputs_sort(outputs->size()); + std::copy(outputs->begin(), outputs->end(), outputs_sort.begin()); + std::sort(outputs_sort.begin(), outputs_sort.end()); + auto outputs_sort_end = + std::unique(outputs_sort.begin(), outputs_sort.end()); + outputs_sort.resize(std::distance(outputs_sort.begin(), outputs_sort_end)); + std::vector unique_outputs(outputs_sort.size()); + + auto unique_outputs_end = std::set_difference( + outputs_sort.begin(), outputs_sort.end(), inputs->begin(), + inputs->end(), unique_outputs.begin()); + + inputs->resize(std::distance(unique_outputs.begin(), unique_outputs_end)); + std::copy(unique_outputs.begin(), unique_outputs_end, inputs->begin()); + } +} + +template +class GraphKhopSamplerOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + // 1. Get sample neighbors operators' inputs. + auto* src = ctx.Input("Row"); + auto* dst_count = ctx.Input("Col_Ptr"); + auto* vertices = ctx.Input("X"); + std::vector sample_sizes = ctx.Attr>("sample_sizes"); + bool return_eids = ctx.Attr("return_eids"); + + const T* src_data = src->data(); + const T* dst_count_data = dst_count->data(); + const T* p_vertices = vertices->data(); + const size_t bs = vertices->dims()[0]; + + // 2. Get unique input nodes(X). + std::vector inputs(bs); + std::copy(p_vertices, p_vertices + bs, inputs.begin()); + auto unique_inputs_end = std::unique(inputs.begin(), inputs.end()); + inputs.resize(std::distance(inputs.begin(), unique_inputs_end)); + + // 3. Sample neighbors. We should distinguish w/o "Eids". + std::vector outputs; + std::vector output_counts; + std::vector outputs_eids; + std::vector> dst_vec; + dst_vec.emplace_back(inputs); + std::vector> outputs_vec; + std::vector> output_counts_vec; + std::vector> outputs_eids_vec; + + const size_t num_layers = sample_sizes.size(); + bool is_last_layer = false, is_first_layer = true; + + if (return_eids) { + auto* src_eids = ctx.Input("Eids"); + const T* src_eids_data = src_eids->data(); + for (size_t i = 0; i < num_layers; i++) { + if (i == num_layers - 1) { + is_last_layer = true; + } + if (inputs.size() == 0) { + break; + } + if (i > 0) { + dst_vec.emplace_back(inputs); + is_first_layer = false; + } + SampleNeighbors(src_data, dst_count_data, src_eids_data, &inputs, + &outputs, &output_counts, &outputs_eids, + sample_sizes[i], is_first_layer, is_last_layer, + return_eids); + outputs_vec.emplace_back(outputs); + output_counts_vec.emplace_back(output_counts); + outputs_eids_vec.emplace_back(outputs_eids); + } + } else { + for (size_t i = 0; i < num_layers; i++) { + if (i == num_layers - 1) { + is_last_layer = true; + } + if (inputs.size() == 0) { + break; + } + if (i > 0) { + is_first_layer = false; + dst_vec.emplace_back(inputs); + } + SampleNeighbors(src_data, dst_count_data, nullptr, &inputs, &outputs, + &output_counts, &outputs_eids, sample_sizes[i], + is_first_layer, is_last_layer, return_eids); + outputs_vec.emplace_back(outputs); + output_counts_vec.emplace_back(output_counts); + outputs_eids_vec.emplace_back(outputs_eids); + } + } + + // 4. Concat intermediate sample results. + int64_t unique_dst_size = 0, src_size = 0; + for (size_t i = 0; i < num_layers; i++) { + unique_dst_size += dst_vec[i].size(); + src_size += outputs_vec[i].size(); + } + + std::vector unique_dst_merge(unique_dst_size); + std::vector src_merge(src_size); + std::vector dst_sample_counts_merge(unique_dst_size); + auto unique_dst_merge_ptr = unique_dst_merge.begin(); + auto src_merge_ptr = src_merge.begin(); + auto dst_sample_counts_merge_ptr = dst_sample_counts_merge.begin(); + // TODO(daisiming): We may try to use std::move in the future. + for (size_t i = 0; i < num_layers; i++) { + if (i == 0) { + unique_dst_merge_ptr = std::copy(dst_vec[i].begin(), dst_vec[i].end(), + unique_dst_merge.begin()); + src_merge_ptr = std::copy(outputs_vec[i].begin(), outputs_vec[i].end(), + src_merge.begin()); + dst_sample_counts_merge_ptr = + std::copy(output_counts_vec[i].begin(), output_counts_vec[i].end(), + dst_sample_counts_merge.begin()); + } else { + unique_dst_merge_ptr = std::copy(dst_vec[i].begin(), dst_vec[i].end(), + unique_dst_merge_ptr); + src_merge_ptr = std::copy(outputs_vec[i].begin(), outputs_vec[i].end(), + src_merge_ptr); + dst_sample_counts_merge_ptr = + std::copy(output_counts_vec[i].begin(), output_counts_vec[i].end(), + dst_sample_counts_merge_ptr); + } + } + + // 5. Return eids results. + if (return_eids) { + std::vector eids_merge(src_size); + auto eids_merge_ptr = eids_merge.begin(); + for (size_t i = 0; i < num_layers; i++) { + if (i == 0) { + eids_merge_ptr = + std::copy(outputs_eids_vec[i].begin(), outputs_eids_vec[i].end(), + eids_merge.begin()); + } else { + eids_merge_ptr = std::copy(outputs_eids_vec[i].begin(), + outputs_eids_vec[i].end(), eids_merge_ptr); + } + } + auto* out_eids = ctx.Output("Out_Eids"); + out_eids->Resize({static_cast(eids_merge.size())}); + T* p_out_eids = out_eids->mutable_data(ctx.GetPlace()); + std::copy(eids_merge.begin(), eids_merge.end(), p_out_eids); + } + + int64_t num_sample_edges = std::accumulate( + dst_sample_counts_merge.begin(), dst_sample_counts_merge.end(), 0); + PADDLE_ENFORCE_EQ( + src_merge.size(), num_sample_edges, + platform::errors::PreconditionNotMet( + "Number of sample edges dismatch, the sample kernel has error.")); + + // 6. Reindex edges. + std::unordered_map node_map; + std::vector unique_nodes; + size_t reindex_id = 0; + for (size_t i = 0; i < unique_dst_merge.size(); i++) { + T node = unique_dst_merge[i]; + unique_nodes.emplace_back(node); + node_map[node] = reindex_id++; + } + for (size_t i = 0; i < src_merge.size(); i++) { + T node = src_merge[i]; + if (node_map.find(node) == node_map.end()) { + unique_nodes.emplace_back(node); + node_map[node] = reindex_id++; + } + src_merge[i] = node_map[node]; + } + std::vector dst_merge(src_merge.size()); + size_t cnt = 0; + for (size_t i = 0; i < unique_dst_merge.size(); i++) { + for (T j = 0; j < dst_sample_counts_merge[i]; j++) { + T node = unique_dst_merge[i]; + dst_merge[cnt++] = node_map[node]; + } + } + + // 7. Get Reindex_X for input nodes. + auto* reindex_x = ctx.Output("Reindex_X"); + T* p_reindex_x = reindex_x->mutable_data(ctx.GetPlace()); + for (size_t i = 0; i < bs; i++) { + p_reindex_x[i] = node_map[p_vertices[i]]; + } + + // 8. Get operator's outputs. + auto* sample_index = ctx.Output("Sample_Index"); + auto* out_src = ctx.Output("Out_Src"); + auto* out_dst = ctx.Output("Out_Dst"); + sample_index->Resize({static_cast(unique_nodes.size())}); + out_src->Resize({static_cast(src_merge.size()), 1}); + out_dst->Resize({static_cast(src_merge.size()), 1}); + T* p_sample_index = sample_index->mutable_data(ctx.GetPlace()); + T* p_out_src = out_src->mutable_data(ctx.GetPlace()); + T* p_out_dst = out_dst->mutable_data(ctx.GetPlace()); + std::copy(unique_nodes.begin(), unique_nodes.end(), p_sample_index); + std::copy(src_merge.begin(), src_merge.end(), p_out_src); + std::copy(dst_merge.begin(), dst_merge.end(), p_out_dst); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index a3f0a0c87fd80..4c641a0aedd86 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -2565,6 +2565,72 @@ void BindImperative(py::module *m_ptr) { return imperative::PyLayerApply(place, cls, args, kwargs); }); +#if defined(PADDLE_WITH_CUDA) + m.def("to_uva_tensor", + [](const py::object &obj, int device_id) { + const auto &tracer = imperative::GetCurrentTracer(); + auto new_tensor = std::shared_ptr( + new imperative::VarBase(tracer->GenerateUniqueName())); + auto array = obj.cast(); + if (py::isinstance>(array)) { + SetUVATensorFromPyArray(new_tensor, array, device_id); + } else if (py::isinstance>(array)) { + SetUVATensorFromPyArray(new_tensor, array, device_id); + } else if (py::isinstance>(array)) { + SetUVATensorFromPyArray(new_tensor, array, device_id); + } else if (py::isinstance>(array)) { + SetUVATensorFromPyArray(new_tensor, array, device_id); + } else if (py::isinstance>(array)) { + SetUVATensorFromPyArray(new_tensor, array, device_id); + } else if (py::isinstance>(array)) { + SetUVATensorFromPyArray(new_tensor, array, device_id); + } else if (py::isinstance>( + array)) { + SetUVATensorFromPyArray( + new_tensor, array, device_id); + } else if (py::isinstance>(array)) { + SetUVATensorFromPyArray(new_tensor, array, device_id); + } else { + // obj may be any type, obj.cast() may be failed, + // then the array.dtype will be string of unknown meaning. + PADDLE_THROW(platform::errors::InvalidArgument( + "Input object type error or incompatible array data type. " + "tensor.set() supports array with bool, float16, float32, " + "float64, int8, int16, int32, int64," + "please check your input or input array data type.")); + } + return new_tensor; + }, + py::arg("obj"), py::arg("device_id") = 0, + py::return_value_policy::reference, R"DOC( + Returns tensor with the UVA(unified virtual addressing) created from numpy array. + + Args: + obj(numpy.ndarray): The input numpy array, supporting bool, float16, float32, + float64, int8, int16, int32, int64 dtype currently. + + device_id(int, optional): The destination GPU device id. + Default: 0, means current device. + + Returns: + + new_tensor(paddle.Tensor): Return the UVA Tensor with the sample dtype and + shape with the input numpy array. + + Examples: + .. code-block:: python + + # required: gpu + import numpy as np + import paddle + + data = np.random.randint(10, size=(3, 4)) + tensor = paddle.fluid.core.to_uva_tensor(data) + print(tensor) +)DOC"); + +#endif + #if defined(PADDLE_WITH_CUDA) m.def( "async_write", diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index d916efe605a29..f63c3111bdb3f 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -83,6 +83,7 @@ std::map> op_ins_map = { {"sparse_attention", {"Q", "K", "V", "Offset", "Columns", "KeyPaddingMask", "AttnMask"}}, {"sgd", {"Param", "LearningRate", "Grad", "MasterParam"}}, + {"graph_khop_sampler", {"Row", "Eids", "Col_Ptr", "X"}}, }; // NOTE(zhiqiu): Like op_ins_map. diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index 5fe361b148c41..dac84abfb9857 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -448,6 +448,39 @@ void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj, } } +template +void SetUVATensorFromPyArray( + const std::shared_ptr &self, + const py::array_t &array, int device_id) { +#if defined(PADDLE_WITH_CUDA) + auto *self_tensor = self->MutableVar()->GetMutable(); + std::vector dims; + dims.reserve(array.ndim()); + int64_t numel = 1; + for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) { + dims.emplace_back(static_cast(array.shape()[i])); + numel *= static_cast(array.shape()[i]); + } + self_tensor->Resize(framework::make_ddim(dims)); + + auto data_type = framework::ToDataType(std::type_index(typeid(T))); + const auto &need_allocate_size = numel * framework::SizeOfType(data_type); + T *data_ptr; + cudaHostAlloc(reinterpret_cast(&data_ptr), need_allocate_size, + cudaHostAllocWriteCombined | cudaHostAllocMapped); + std::memcpy(data_ptr, array.data(), array.nbytes()); + + void *cuda_device_pointer = nullptr; + cudaHostGetDevicePointer(reinterpret_cast(&cuda_device_pointer), + reinterpret_cast(data_ptr), 0); + std::shared_ptr holder = + std::make_shared( + cuda_device_pointer, need_allocate_size, + platform::CUDAPlace(device_id)); + self_tensor->ResetHolderWithType(holder, data_type); +#endif +} + template void _sliceCompute(const framework::Tensor *in, framework::Tensor *out, const platform::CPUDeviceContext &ctx, diff --git a/python/paddle/fluid/tests/unittests/test_graph_khop_sampler.py b/python/paddle/fluid/tests/unittests/test_graph_khop_sampler.py new file mode 100644 index 0000000000000..b8071222ac772 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_graph_khop_sampler.py @@ -0,0 +1,207 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid + + +class TestGraphKhopSampler(unittest.TestCase): + def setUp(self): + num_nodes = 20 + edges = np.random.randint(num_nodes, size=(100, 2)) + edges = np.unique(edges, axis=0) + edges_id = np.arange(0, len(edges)) + sorted_edges = edges[np.argsort(edges[:, 1])] + sorted_eid = edges_id[np.argsort(edges[:, 1])] + + # Calculate dst index cumsum counts. + dst_count = np.zeros(num_nodes) + dst_src_dict = {} + for dst in range(0, num_nodes): + true_index = sorted_edges[:, 1] == dst + dst_count[dst] = np.sum(true_index) + dst_src_dict[dst] = sorted_edges[:, 0][true_index] + dst_count = dst_count.astype("int64") + colptr = np.cumsum(dst_count) + colptr = np.insert(colptr, 0, 0) + + self.row = sorted_edges[:, 0].astype("int64") + self.colptr = colptr.astype("int64") + self.sorted_eid = sorted_eid.astype("int64") + self.nodes = np.unique(np.random.randint( + num_nodes, size=5)).astype("int64") + self.sample_sizes = [5, 5] + self.dst_src_dict = dst_src_dict + + def test_sample_result(self): + paddle.disable_static() + row = paddle.to_tensor(self.row) + colptr = paddle.to_tensor(self.colptr) + nodes = paddle.to_tensor(self.nodes) + + edge_src, edge_dst, sample_index, reindex_nodes = \ + paddle.incubate.graph_khop_sampler(row, colptr, + nodes, self.sample_sizes, + return_eids=False) + # Reindex edge_src and edge_dst to original index. + edge_src = edge_src.reshape([-1]) + edge_dst = edge_dst.reshape([-1]) + sample_index = sample_index.reshape([-1]) + + for i in range(len(edge_src)): + edge_src[i] = sample_index[edge_src[i]] + edge_dst[i] = sample_index[edge_dst[i]] + + for n in self.nodes: + edge_src_n = edge_src[edge_dst == n] + if edge_src_n.shape[0] == 0: + continue + # Ensure no repetitive sample neighbors. + self.assertTrue( + edge_src_n.shape[0] == paddle.unique(edge_src_n).shape[0]) + # Ensure the correct sample size. + self.assertTrue(edge_src_n.shape[0] == self.sample_sizes[0] or + edge_src_n.shape[0] == len(self.dst_src_dict[n])) + in_neighbors = np.isin(edge_src_n.numpy(), self.dst_src_dict[n]) + # Ensure the correct sample neighbors. + self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0]) + + def test_uva_sample_result(self): + paddle.disable_static() + if paddle.fluid.core.is_compiled_with_cuda(): + row = paddle.fluid.core.to_uva_tensor( + self.row.astype(self.row.dtype)) + sorted_eid = paddle.fluid.core.to_uva_tensor( + self.sorted_eid.astype(self.sorted_eid.dtype)) + colptr = paddle.to_tensor(self.colptr) + nodes = paddle.to_tensor(self.nodes) + + edge_src, edge_dst, sample_index, reindex_nodes, edge_eids = \ + paddle.incubate.graph_khop_sampler(row, colptr, + nodes, self.sample_sizes, + sorted_eids=sorted_eid, + return_eids=True) + edge_src = edge_src.reshape([-1]) + edge_dst = edge_dst.reshape([-1]) + sample_index = sample_index.reshape([-1]) + + for i in range(len(edge_src)): + edge_src[i] = sample_index[edge_src[i]] + edge_dst[i] = sample_index[edge_dst[i]] + + for n in self.nodes: + edge_src_n = edge_src[edge_dst == n] + if edge_src_n.shape[0] == 0: + continue + self.assertTrue( + edge_src_n.shape[0] == paddle.unique(edge_src_n).shape[0]) + self.assertTrue( + edge_src_n.shape[0] == self.sample_sizes[0] or + edge_src_n.shape[0] == len(self.dst_src_dict[n])) + in_neighbors = np.isin(edge_src_n.numpy(), self.dst_src_dict[n]) + self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0]) + + def test_sample_result_static_with_eids(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + row = paddle.static.data( + name="row", shape=self.row.shape, dtype=self.row.dtype) + sorted_eids = paddle.static.data( + name="eids", + shape=self.sorted_eid.shape, + dtype=self.sorted_eid.dtype) + colptr = paddle.static.data( + name="colptr", shape=self.colptr.shape, dtype=self.colptr.dtype) + nodes = paddle.static.data( + name="nodes", shape=self.nodes.shape, dtype=self.nodes.dtype) + + edge_src, edge_dst, sample_index, reindex_nodes, edge_eids = \ + paddle.incubate.graph_khop_sampler(row, colptr, + nodes, self.sample_sizes, + sorted_eids, True) + exe = paddle.static.Executor(paddle.CPUPlace()) + ret = exe.run(feed={ + 'row': self.row, + 'eids': self.sorted_eid, + 'colptr': self.colptr, + 'nodes': self.nodes + }, + fetch_list=[edge_src, edge_dst, sample_index]) + + edge_src, edge_dst, sample_index = ret + edge_src = edge_src.reshape([-1]) + edge_dst = edge_dst.reshape([-1]) + sample_index = sample_index.reshape([-1]) + + for i in range(len(edge_src)): + edge_src[i] = sample_index[edge_src[i]] + edge_dst[i] = sample_index[edge_dst[i]] + + for n in self.nodes: + edge_src_n = edge_src[edge_dst == n] + if edge_src_n.shape[0] == 0: + continue + self.assertTrue( + edge_src_n.shape[0] == np.unique(edge_src_n).shape[0]) + self.assertTrue( + edge_src_n.shape[0] == self.sample_sizes[0] or + edge_src_n.shape[0] == len(self.dst_src_dict[n])) + in_neighbors = np.isin(edge_src_n, self.dst_src_dict[n]) + self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0]) + + def test_sample_result_static_without_eids(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + row = paddle.static.data( + name="row", shape=self.row.shape, dtype=self.row.dtype) + colptr = paddle.static.data( + name="colptr", shape=self.colptr.shape, dtype=self.colptr.dtype) + nodes = paddle.static.data( + name="nodes", shape=self.nodes.shape, dtype=self.nodes.dtype) + edge_src, edge_dst, sample_index, reindex_nodes = \ + paddle.incubate.graph_khop_sampler(row, colptr, + nodes, self.sample_sizes) + exe = paddle.static.Executor(paddle.CPUPlace()) + ret = exe.run(feed={ + 'row': self.row, + 'colptr': self.colptr, + 'nodes': self.nodes + }, + fetch_list=[edge_src, edge_dst, sample_index]) + edge_src, edge_dst, sample_index = ret + edge_src = edge_src.reshape([-1]) + edge_dst = edge_dst.reshape([-1]) + sample_index = sample_index.reshape([-1]) + + for i in range(len(edge_src)): + edge_src[i] = sample_index[edge_src[i]] + edge_dst[i] = sample_index[edge_dst[i]] + + for n in self.nodes: + edge_src_n = edge_src[edge_dst == n] + if edge_src_n.shape[0] == 0: + continue + self.assertTrue( + edge_src_n.shape[0] == np.unique(edge_src_n).shape[0]) + self.assertTrue( + edge_src_n.shape[0] == self.sample_sizes[0] or + edge_src_n.shape[0] == len(self.dst_src_dict[n])) + in_neighbors = np.isin(edge_src_n, self.dst_src_dict[n]) + self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0]) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_tensor_uva.py b/python/paddle/fluid/tests/unittests/test_tensor_uva.py index 98895202c042e..c60d4d98d7154 100644 --- a/python/paddle/fluid/tests/unittests/test_tensor_uva.py +++ b/python/paddle/fluid/tests/unittests/test_tensor_uva.py @@ -15,7 +15,6 @@ import paddle import unittest import numpy as np -from paddle.fluid.core import LoDTensor as Tensor class TestTensorCopyFrom(unittest.TestCase): @@ -28,5 +27,19 @@ def test_main(self): self.assertTrue(tensor.place.is_gpu_place()) +class TestUVATensorFromNumpy(unittest.TestCase): + def test_uva_tensor_creation(self): + if paddle.fluid.core.is_compiled_with_cuda(): + dtype_list = [ + "int32", "int64", "float32", "float64", "float16", "int8", + "int16", "bool" + ] + for dtype in dtype_list: + data = np.random.randint(10, size=[4, 5]).astype(dtype) + tensor = paddle.fluid.core.to_uva_tensor(data, 0) + self.assertTrue(tensor.place.is_gpu_place()) + self.assertTrue(np.allclose(tensor.numpy(), data)) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/incubate/__init__.py b/python/paddle/incubate/__init__.py index 7c7206d6e89c4..d637def405473 100644 --- a/python/paddle/incubate/__init__.py +++ b/python/paddle/incubate/__init__.py @@ -19,6 +19,7 @@ from .operators import softmax_mask_fuse_upper_triangle # noqa: F401 from .operators import softmax_mask_fuse # noqa: F401 from .operators import graph_send_recv +from .operators import graph_khop_sampler from .tensor import segment_sum from .tensor import segment_mean from .tensor import segment_max @@ -33,6 +34,7 @@ 'softmax_mask_fuse_upper_triangle', 'softmax_mask_fuse', 'graph_send_recv', + 'graph_khop_sampler', 'segment_sum', 'segment_mean', 'segment_max', diff --git a/python/paddle/incubate/operators/__init__.py b/python/paddle/incubate/operators/__init__.py index ecf73fb393cc1..073c3afcbcbfc 100644 --- a/python/paddle/incubate/operators/__init__.py +++ b/python/paddle/incubate/operators/__init__.py @@ -16,3 +16,4 @@ from .softmax_mask_fuse import softmax_mask_fuse # noqa: F401 from .resnet_unit import ResNetUnit #noqa: F401 from .graph_send_recv import graph_send_recv #noqa: F401 +from .graph_khop_sampler import graph_khop_sampler #noqa: F401 diff --git a/python/paddle/incubate/operators/graph_khop_sampler.py b/python/paddle/incubate/operators/graph_khop_sampler.py new file mode 100644 index 0000000000000..71f403dadcf1e --- /dev/null +++ b/python/paddle/incubate/operators/graph_khop_sampler.py @@ -0,0 +1,152 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.framework import in_dygraph_mode +from paddle.fluid.data_feeder import check_variable_and_dtype +from paddle.fluid import core +from paddle import _C_ops + + +def graph_khop_sampler(row, + colptr, + input_nodes, + sample_sizes, + sorted_eids=None, + return_eids=False, + name=None): + """ + Graph Khop Sampler API. + + This API is mainly used in Graph Learning domain, and the main purpose is to + provide high performance graph khop sampling method with subgraph reindex step. + For example, we get the CSC(Compressed Sparse Column) format of the input graph + edges as `row` and `colptr`, so as to covert graph data into a suitable format + for sampling. And the `input_nodes` means the nodes we need to sample neighbors, + and `sample_sizes` means the number of neighbors and number of layers we want + to sample. + + **Note**: + Currently the API will reindex the output edges after finishing sampling. We + will add a choice or a new API for whether to reindex the edges in the near future. + + Args: + row (Tensor): One of the components of the CSC format of the input graph, and + the shape should be [num_edges, 1] or [num_edges]. The available + data type is int32, int64. + colptr (Tensor): One of the components of the CSC format of the input graph, + and the shape should be [num_nodes + 1, 1] or [num_nodes]. + The data type should be the same with `row`. + input_nodes (Tensor): The input nodes we need to sample neighbors for, and the + data type should be the same with `row`. + sample_sizes (list|tuple): The number of neighbors and number of layers we want + to sample. The data type should be int, and the shape + should only have one dimension. + sorted_eids (Tensor): The sorted edge ids, should not be None when `return_eids` + is True. The shape should be [num_edges, 1], and the data + type should be the same with `row`. + return_eids (bool): Whether to return the id of the sample edges. Default is False. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + edge_src (Tensor): The src index of the output edges, also means the first column of + the edges. The shape is [num_sample_edges, 1] currently. + edge_dst (Tensor): The dst index of the output edges, also means the second column + of the edges. The shape is [num_sample_edges, 1] currently. + sample_index (Tensor): The original id of the input nodes and sampled neighbor nodes. + reindex_nodes (Tensor): The reindex id of the input nodes. + edge_eids (Tensor): Return the id of the sample edges if `return_eids` is True. + + Examples: + + .. code-block:: python + + import paddle + + row = [3, 7, 0, 9, 1, 4, 2, 9, 3, 9, 1, 9, 7] + colptr = [0, 2, 4, 5, 6, 7, 9, 11, 11, 13, 13] + nodes = [0, 8, 1, 2] + sample_sizes = [2, 2] + row = paddle.to_tensor(row, dtype="int64") + colptr = paddle.to_tensor(colptr, dtype="int64") + nodes = paddle.to_tensor(nodes, dtype="int64") + + edge_src, edge_dst, sample_index, reindex_nodes = \ + paddle.incubate.graph_khop_sampler(row, colptr, nodes, sample_sizes, False) + + """ + + if in_dygraph_mode(): + if return_eids: + if sorted_eids is None: + raise ValueError(f"`sorted_eid` should not be None " + f"if return_eids is True.") + edge_src, edge_dst, sample_index, reindex_nodes, edge_eids = \ + _C_ops.graph_khop_sampler(row, sorted_eids, + colptr, input_nodes, + "sample_sizes", sample_sizes, + "return_eids", True) + return edge_src, edge_dst, sample_index, reindex_nodes, edge_eids + else: + edge_src, edge_dst, sample_index, reindex_nodes, _ = \ + _C_ops.graph_khop_sampler(row, None, + colptr, input_nodes, + "sample_sizes", sample_sizes, + "return_eids", False) + return edge_src, edge_dst, sample_index, reindex_nodes + + check_variable_and_dtype(row, "Row", ("int32", "int64"), + "graph_khop_sampler") + + if return_eids: + if sorted_eids is None: + raise ValueError(f"`sorted_eid` should not be None " + f"if return_eids is True.") + check_variable_and_dtype(sorted_eids, "Eids", ("int32", "int64"), + "graph_khop_sampler") + + check_variable_and_dtype(colptr, "Col_Ptr", ("int32", "int64"), + "graph_khop_sampler") + check_variable_and_dtype(input_nodes, "X", ("int32", "int64"), + "graph_khop_sampler") + + helper = LayerHelper("graph_khop_sampler", **locals()) + edge_src = helper.create_variable_for_type_inference(dtype=row.dtype) + edge_dst = helper.create_variable_for_type_inference(dtype=row.dtype) + sample_index = helper.create_variable_for_type_inference(dtype=row.dtype) + reindex_nodes = helper.create_variable_for_type_inference(dtype=row.dtype) + edge_eids = helper.create_variable_for_type_inference(dtype=row.dtype) + helper.append_op( + type="graph_khop_sampler", + inputs={ + "Row": row, + "Eids": sorted_eids, + "Col_Ptr": colptr, + "X": input_nodes + }, + outputs={ + "Out_Src": edge_src, + "Out_Dst": edge_dst, + "Sample_Index": sample_index, + "Reindex_X": reindex_nodes, + "Out_Eids": edge_eids + }, + attrs={"sample_sizes": sample_sizes, + "return_eids": return_eids}) + if return_eids: + return edge_src, edge_dst, sample_index, reindex_nodes, edge_eids + else: + return edge_src, edge_dst, sample_index, reindex_nodes