From ee54cb4472127c84b221cc38b37220ee3aa7b9a1 Mon Sep 17 00:00:00 2001
From: wanglipeng
Date: Fri, 29 Dec 2023 10:44:47 +0800
Subject: [PATCH 1/4] trie by custname, aigc-model-42

---
 paddle/fluid/framework/trie.h          |  6 ++--
 paddle/fluid/framework/trie_manager.cc | 43 ++++++++++++++++++++++++--
 paddle/fluid/framework/trie_manager.h  |  7 +++--
 paddle/fluid/pybind/box_helper_py.cc   |  5 ++-
 4 files changed, 52 insertions(+), 9 deletions(-)

diff --git a/paddle/fluid/framework/trie.h b/paddle/fluid/framework/trie.h
index e0d72d55ae10a..b62b3a820ef43 100644
--- a/paddle/fluid/framework/trie.h
+++ b/paddle/fluid/framework/trie.h
@@ -64,7 +64,7 @@ struct File {
 struct Node {
   uint32_t id = 0;
-  uint16_t label = 0;
+  uint32_t label = 0;
   std::vector<uint32_t> child;
   uint8_t aleaf = 0;
 };
@@ -74,7 +74,7 @@ struct Node {
   virtual ~Trie() {}
   int load(const std::string& dir, const uint32_t thr_num=20u);
 
-  uint16_t label(uint32_t id) {
+  uint32_t label(uint32_t id) {
     return label_.at(id);
   }
@@ -157,7 +157,7 @@ struct Node {
   void load_file(uint32_t thr_id, File& file);
   void stat_file(uint32_t thr_id, File& file);
 
-  std::vector<uint16_t> label_;
+  std::vector<uint32_t> label_;
   std::vector<uint8_t> aleaf_;
   std::vector<uint32_t> child_mem_;
   std::vector<uint32_t> mem_off_;
diff --git a/paddle/fluid/framework/trie_manager.cc b/paddle/fluid/framework/trie_manager.cc
index a7c3e4bf175e3..0f6ac35e95ac0 100644
--- a/paddle/fluid/framework/trie_manager.cc
+++ b/paddle/fluid/framework/trie_manager.cc
@@ -19,6 +19,45 @@ namespace paddle {
 namespace framework {
 
 std::shared_ptr<TrieManager> TrieManager::_s_instance = nullptr;
 
+void TrieManager::reset(const std::vector<uint32_t>& labels) {
+  VLOG(3) << "trie reset...";
+  std::unique_lock<std::mutex> lock(mtx_);
+
+  size_t root = 0;
+  size_t chs = trie_.child_size(root);
+  std::unordered_map<uint32_t, uint32_t> l2n;
+  for (size_t i = 0; i < chs; ++i) {
+    uint32_t cid = trie_.child_at(root, i);
+    uint32_t lab = trie_.label(cid);
+    l2n.insert({lab, cid});
+  }
+
+  parent_idx_.mutable_data<int64_t>({int(labels.size())}, phi::GPUPinnedPlace());
+  int64_t* parent_idx = parent_idx_.data<int64_t>();
+
+  select_ids_.mutable_data<int64_t>({int(labels.size())}, phi::GPUPinnedPlace());
+  int64_t* select_ids = select_ids_.data<int64_t>();
+
+  label2node_.resize(labels.size());
+  for (size_t i = 0; i < labels.size(); ++i) {
+    auto it = l2n.find(labels[i]);
+    uint32_t label = endid_;
+    uint32_t nodeid = end_nodeid_;
+
+    if (it != l2n.end()) {
+      label = labels[i];
+      nodeid = it->second;
+    }
+
+    parent_idx[i] = i;
+    select_ids[i] = label;
+    label2node_[i].insert({label, nodeid});
+  }
+
+  phase_ = Phase::run;
+  cv_.notify_one();
+}
+
 void TrieManager::reset() {
   VLOG(3) << "trie reset...";
   std::unique_lock<std::mutex> lock(mtx_);
@@ -84,8 +123,8 @@ void TrieManager::run() {
   int64_t* parent_idx = parent_idx_.data<int64_t>();
   int64_t* select_ids = select_ids_.data<int64_t>();
 
-  std::vector<std::unordered_map<uint16_t, uint32_t>> label2node(numel);
-  std::vector<std::vector<uint16_t>> outs(numel);
+  std::vector<std::unordered_map<uint32_t, uint32_t>> label2node(numel);
+  std::vector<std::vector<uint32_t>> outs(numel);
   parallel_run_range(numel, thr_num, [this, parent_idx, select_ids, &outs, &label2node] (
     uint32_t thr_id, uint32_t start, uint32_t end) {
     for (size_t i = start; i < end; ++i) {
diff --git a/paddle/fluid/framework/trie_manager.h b/paddle/fluid/framework/trie_manager.h
index cc727f3efc0ad..605ba641dbf91 100644
--- a/paddle/fluid/framework/trie_manager.h
+++ b/paddle/fluid/framework/trie_manager.h
@@ -69,7 +69,7 @@ enum class Phase {
 };
 
 public:
-  TrieManager(uint16_t endid) : endid_(endid),
+  TrieManager(uint32_t endid) : endid_(endid),
     place_(platform::GetCurrentDeviceId()) {
     thread_ = std::thread(&TrieManager::run, this);
   }
@@ -94,7 +94,7 @@ enum class Phase {
     return _s_instance;
  }
 
-  static std::shared_ptr<TrieManager> SetInstance(uint16_t endid) {
+  static std::shared_ptr<TrieManager> SetInstance(uint32_t endid) {
     static std::mutex mutex;
     std::lock_guard<std::mutex> lock(mutex);
     if (nullptr == _s_instance) {
@@ -111,6 +111,7 @@ enum class Phase {
     return trie_.load(dir, thr_num);
   }
   void reset();
+  void reset(const std::vector<uint32_t>& labels);
   void search_start(const Tensor* d_parent, const Tensor* d_select);
   void search_wait();
@@ -124,7 +125,7 @@ enum class Phase {
   // cpu
   Tensor parent_idx_;
   Tensor select_ids_;
-  std::vector<std::unordered_map<uint16_t, uint32_t>> label2node_;
+  std::vector<std::unordered_map<uint32_t, uint32_t>> label2node_;
 
   // cpu
   Tensor next_out_;
diff --git a/paddle/fluid/pybind/box_helper_py.cc b/paddle/fluid/pybind/box_helper_py.cc
index a68ffe903ac4f..e36edc421750c 100644
--- a/paddle/fluid/pybind/box_helper_py.cc
+++ b/paddle/fluid/pybind/box_helper_py.cc
@@ -142,7 +142,10 @@ void BindTrieManager(py::module* m) {
            py::arg("thr_num")=20u,
            py::call_guard<py::gil_scoped_release>())
       .def("reset",
-           &framework::TrieManager::reset,
+           py::overload_cast<>(&framework::TrieManager::reset),
+           py::call_guard<py::gil_scoped_release>())
+      .def("reset",
+           py::overload_cast<const std::vector<uint32_t>&>(&framework::TrieManager::reset),
            py::call_guard<py::gil_scoped_release>());
 }  // end TrieManager

From a39f3752c1e02e4664f08d9671dab590498a515d Mon Sep 17 00:00:00 2001
From: laipaang
Date: Mon, 8 Jan 2024 16:24:16 +0800
Subject: [PATCH 2/4] one stage topk, aigc-model-74

---
 .../fused/fused_multi_transformer_int8_op.cu  |   2 +-
 paddle/phi/api/yaml/ops.yaml                  |   2 +-
 paddle/phi/infermeta/multiary.cc              |   1 +
 paddle/phi/infermeta/multiary.h               |   1 +
 .../phi/kernels/fusion/beam_search_softmax.h  |   1 +
 .../kernels/fusion/gpu/beam_search_softmax.cu | 237 +++++++++----
 .../kernels/fusion/gpu/beam_search_topk.cu.h  | 316 ++++++++++++++++++
 python/paddle/tensor/search.py                |   5 +-
 8 files changed, 497 insertions(+), 68 deletions(-)
 create mode 100644 paddle/phi/kernels/fusion/gpu/beam_search_topk.cu.h

diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu
index 83bd95eaa8f1c..17bce3961a910 100644
--- a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu
+++ b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu
@@ -245,7 +245,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
     qktv_out.Resize({{bsz, num_head, seq_len, dim_head}});
     auto *qktv_out_data =
         dev_ctx.Alloc<T>(&qktv_out, qktv_out.numel() * sizeof(T));
-    fmha_out.Resize({{token_num, num_head, dim_head}});
+    fmha_out.Resize({{bsz, seq_len, num_head, dim_head}});
     auto *fmha_out_data =
         dev_ctx.Alloc<T>(&fmha_out, fmha_out.numel() * sizeof(T));
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 9ad89b04bfe82..6c07069d39bc9 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -210,7 +210,7 @@
   backward : flip_grad

 - op : beam_search_softmax
-  args : (Tensor logits, Tensor cum_scores, Tensor sequence_lengths, Tensor stop_flags, Tensor end_ids, Tensor step_ids, Tensor last_cache_ids, Tensor last_beam_offsets, int beam_size, int max_seq_len, int max_dec_len, bool fuse_softmax, bool early_stop, float length_penalty=0.0)
+  args : (Tensor logits, Tensor cum_scores, Tensor sequence_lengths, Tensor stop_flags, Tensor end_ids, Tensor step_ids, Tensor last_cache_ids, Tensor last_beam_offsets, int beam_size, int max_seq_len, int max_dec_len, bool fuse_softmax, bool early_stop, float length_penalty=0.0, bool one_stage_topk=false)
   output : Tensor(ids_this_time), Tensor(out_cum_scores), Tensor(cache_ids), Tensor(beam_offsets), Tensor(parent_idx), Tensor(stop_flags_out),
Tensor(seq_lens_out), Tensor(step_ids_out) infer_meta : func : BeamSearchSoftmaxInferMeta diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 567573bb1e151..c3e1cd07b37d4 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -688,6 +688,7 @@ void BeamSearchSoftmaxInferMeta(const MetaTensor& logits, bool fuse_softmax, bool early_stop, float length_penalty, + bool one_stage_topk, MetaTensor* ids_this_time, MetaTensor* out_cum_scores, MetaTensor* cache_ids, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index cf8ac463b6775..5bfcfe889fa1c 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -204,6 +204,7 @@ void BeamSearchSoftmaxInferMeta(const MetaTensor& logits, bool fuse_softmax, bool early_stop, float length_penalty, + bool one_stage_topk, MetaTensor* ids_this_time, MetaTensor* out_cum_scores, MetaTensor* cache_ids, diff --git a/paddle/phi/kernels/fusion/beam_search_softmax.h b/paddle/phi/kernels/fusion/beam_search_softmax.h index 6970286f0370f..60cd528a12bd6 100644 --- a/paddle/phi/kernels/fusion/beam_search_softmax.h +++ b/paddle/phi/kernels/fusion/beam_search_softmax.h @@ -35,6 +35,7 @@ void BeamSearchSoftmaxKernel(const Context &dev_ctx, bool fuse_softmax, bool early_stop, float length_penalty, + bool one_stage_topk, DenseTensor *ids_this_time, DenseTensor *out_cum_scores, DenseTensor *cache_ids, diff --git a/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu b/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu index 4a0a63bfd4407..bc114e566e69d 100644 --- a/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu +++ b/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu @@ -20,6 +20,7 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/fusion/beam_search_softmax.h" +#include "paddle/phi/kernels/fusion/gpu/beam_search_topk.cu.h" namespace phi { namespace fusion { @@ -27,36 +28,37 @@ namespace fusion { #define FLT_MAX 1e38 // #define DEBUG_BEAM_SEARCH_SOFTMAX -#define CASE_K(K) \ - case K: \ - invokeTopKSoftMaxLauncher(dev_ctx, \ - log_probs, \ - stop_flags, \ - sequence_lengths, \ - cum_log_probs, \ - step_ids, \ - last_cache_ids, \ - last_beam_offsets, \ - end_ids, \ - out_cum_log_probs, \ - stop_flags_out, \ - seq_lens_out, \ - step_ids_out, \ - ids, \ - tmp_ids, \ - tmp_vals, \ - parent_idx, \ - cache_ids, \ - beam_offsets, \ - batch_size, \ - beam_size, \ - vocab_size, \ - max_seq_len, \ - max_dec_len, \ - fuse_softmax, \ - early_stop, \ - length_penalty, \ - stream); \ +#define CASE_K(K) \ + case K: \ + invokeTopKSoftMaxLauncher(dev_ctx, \ + log_probs, \ + stop_flags, \ + sequence_lengths, \ + cum_log_probs, \ + step_ids, \ + last_cache_ids, \ + last_beam_offsets, \ + end_ids, \ + out_cum_log_probs, \ + stop_flags_out, \ + seq_lens_out, \ + step_ids_out, \ + ids, \ + tmp_ids, \ + tmp_vals, \ + parent_idx, \ + cache_ids, \ + beam_offsets, \ + batch_size, \ + beam_size, \ + vocab_size, \ + max_seq_len, \ + max_dec_len, \ + fuse_softmax, \ + early_stop, \ + length_penalty, \ + one_stage_topk, \ + stream); \ break struct __align__(8) DySoftMaxStruct { @@ -165,7 +167,7 @@ __global__ void batch_topk(const int *topk_tmp_id_buf, int *step_ids_out) { int thread_id = threadIdx.x; int block_id = blockIdx.x; // bs - const int beam_size = K / 2; + const int beam_size = K; TopK partial; if (thread_id == 0) { for (int i = 0; i < beam_size; ++i) { @@ -220,7 +222,7 @@ __global__ void 
batch_topk(const int *topk_tmp_id_buf, int *step_ids_out) { int thread_id = threadIdx.x; int block_id = blockIdx.x; // bs - const int beam_size = K / 2; + const int beam_size = K; TopK partial; if (thread_id == 0) { for (int i = 0; i < beam_size; ++i) { @@ -665,35 +667,19 @@ void invokeUpdateCacheIds(const int *last_cache_ids, } template -void invokeTopKSoftMaxLauncher(const Context &dev_ctx, - const T *log_probs, - const bool *stop_flags, - const int *sequence_lengths, - const float *cum_log_probs, - const int *step_ids, - const int *last_cache_ids, - const int *last_beam_offsets, - const int *end_ids, - float *out_cum_log_probs, - bool *stop_flags_out, - int *seq_lens_out, - int *step_ids_out, - int *ids, - int *tmp_ids, - T *tmp_vals, - int *parent_idx, - int *cache_ids, - int *beam_offsets, - const int batch_size, - const int beam_size, - const int vocab_size, - const int max_seq_len, - const int max_dec_len, - const bool fuse_softmax, - const bool early_stop, - const float length_penalty, - cudaStream_t stream) { - // K = 2 * beam_size +void invokeTwoStageTopK(const Context &dev_ctx, + const T *log_probs, + const bool *stop_flags, + const float *cum_log_probs, + const int *step_ids, + const int *end_ids, + int *tmp_ids, + T *tmp_vals, + const int batch_size, + const int beam_size, + const int vocab_size, + const bool fuse_softmax, + const float length_penalty) { const int block_size = 128; int voc_parts = vocab_size / 1024; voc_parts = std::min(128, voc_parts); @@ -715,7 +701,7 @@ void invokeTopKSoftMaxLauncher(const Context &dev_ctx, cudaSharedmemCarveoutMaxL1); // (bs, bm, voc_parts, 2 * K + 2) beam_search_softmax_topk_stage1 - <<>>( + <<>>( log_probs, stop_flags, end_ids, tmp_buffer, vocab_size, fuse_softmax); } else { cudaFuncSetAttribute(beam_search_softmax_topk_stage1, @@ -723,7 +709,7 @@ void invokeTopKSoftMaxLauncher(const Context &dev_ctx, cudaSharedmemCarveoutMaxL1); // (bs, bm, voc_parts, 2 * K) beam_search_softmax_topk_stage1 - <<>>( + <<>>( log_probs, stop_flags, end_ids, tmp_buffer, vocab_size, fuse_softmax); } // (bs, bm, K) @@ -738,7 +724,120 @@ void invokeTopKSoftMaxLauncher(const Context &dev_ctx, fuse_softmax, length_penalty, step_ids, - stream); + dev_ctx.stream()); +} + +template +void invokeOneStageTopK(const Context &dev_ctx, + const T *log_probs, + const bool *stop_flags, + const float *cum_log_probs, + const int *step_ids, + const int *end_ids, + int *tmp_ids, + T *tmp_vals, + const int batch_size, + const int beam_size, + const int vocab_size, + const float length_penalty) { + int input_height = batch_size * beam_size; + int input_width = vocab_size; + int k = beam_size; + + if (k > input_width) { + PADDLE_THROW(paddle::platform::errors::Unavailable( + "Calculation error occurred in TopK Operator's CUDA Kernel.")); + } + + const int kMaxHeight = 2048; + int gridx = input_height < kMaxHeight ? 
input_height : kMaxHeight; + + switch (GetDesiredBlockDim(input_width)) { + FIXED_BLOCK_DIM( + BeamSearchTopK + <<>>(tmp_vals, + k, + tmp_ids, + log_probs, + input_width, + input_width, + k, + gridx, + input_height, + cum_log_probs, + stop_flags, + step_ids, + end_ids, + length_penalty)); + default: + PADDLE_THROW(paddle::platform::errors::Unavailable( + "Calculation error occurred in TopK Operator's CUDA Kernel.")); + } +} + +template +void invokeTopKSoftMaxLauncher(const Context &dev_ctx, + const T *log_probs, + const bool *stop_flags, + const int *sequence_lengths, + const float *cum_log_probs, + const int *step_ids, + const int *last_cache_ids, + const int *last_beam_offsets, + const int *end_ids, + float *out_cum_log_probs, + bool *stop_flags_out, + int *seq_lens_out, + int *step_ids_out, + int *ids, + int *tmp_ids, + T *tmp_vals, + int *parent_idx, + int *cache_ids, + int *beam_offsets, + const int batch_size, + const int beam_size, + const int vocab_size, + const int max_seq_len, + const int max_dec_len, + const bool fuse_softmax, + const bool early_stop, + const float length_penalty, + const bool one_stage_topk, + cudaStream_t stream) { + if (one_stage_topk) { + if (fuse_softmax) { + PADDLE_THROW(paddle::platform::errors::Unavailable( + "one stage topk not support fuse softmax.")); + } + invokeOneStageTopK(dev_ctx, + log_probs, + stop_flags, + cum_log_probs, + step_ids, + end_ids, + tmp_ids, + tmp_vals, + batch_size, + beam_size, + vocab_size, + length_penalty); + } else { + invokeTwoStageTopK(dev_ctx, + log_probs, + stop_flags, + cum_log_probs, + step_ids, + end_ids, + tmp_ids, + tmp_vals, + batch_size, + beam_size, + vocab_size, + fuse_softmax, + length_penalty); + } + // (bs, bm) if (early_stop) { batch_topk<<>>( @@ -822,6 +921,7 @@ void invokeTopkSoftMax(const Context &dev_ctx, const bool fuse_softmax, const bool early_stop, const float length_penalty, + const bool one_stage_topk, cudaStream_t stream) { switch (beam_size) { CASE_K(1); @@ -865,6 +965,7 @@ void BeamSearchSoftmaxKernel(const Context &dev_ctx, bool fuse_softmax, bool early_stop, float length_penalty, + bool one_stage_topk, DenseTensor *ids_this_time, DenseTensor *out_cum_scores, DenseTensor *cache_ids, @@ -897,7 +998,12 @@ void BeamSearchSoftmaxKernel(const Context &dev_ctx, phi::Copy( dev_ctx, step_ids, dev_ctx.GetPlace(), false, step_ids_out); - const int tmp_size = batch_size * beam_size * beam_size * 2; + if (max_seq_len == 0) { // dynamic max_seq_len + const DDim& dims = last_beam_offsets.dims(); // bs, beam_size, max_seq_len + max_dec_len + max_seq_len = dims[2] - max_dec_len; + } + + const int tmp_size = batch_size * beam_size * beam_size; DenseTensor tmp_topk_id, tmp_topk_val; tmp_topk_id.Resize(phi::make_ddim({tmp_size})); dev_ctx.template Alloc(&tmp_topk_id); @@ -931,6 +1037,7 @@ void BeamSearchSoftmaxKernel(const Context &dev_ctx, fuse_softmax, early_stop, length_penalty, + one_stage_topk, dev_ctx.stream()); } diff --git a/paddle/phi/kernels/fusion/gpu/beam_search_topk.cu.h b/paddle/phi/kernels/fusion/gpu/beam_search_topk.cu.h new file mode 100644 index 0000000000000..63c8c1441ef4a --- /dev/null +++ b/paddle/phi/kernels/fusion/gpu/beam_search_topk.cu.h @@ -0,0 +1,316 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace phi { +namespace fusion { + +#define WRAP_SIZE 32 +#define FLT_MAX 1e38 + +inline static int GetDesiredBlockDim(int dim) { + if (dim > 128) { + return 256; + } else if (dim > 64) { + return 128; + } else if (dim > 32) { + return 64; + } else { + return 32; + } +} + +#define FIXED_BLOCK_DIM_BASE(dim, ...) \ + case (dim): { \ + constexpr auto kBlockDim = (dim); \ + __VA_ARGS__; \ + } break + +#define FIXED_BLOCK_DIM(...) \ + FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__) + +template +struct Pair { + __device__ __forceinline__ Pair() {} + __device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {} + + __device__ __forceinline__ void set(T value, int id) { + v = value; + this->id = id; + } + + __device__ __forceinline__ void operator=(const Pair& in) { + v = in.v; + id = in.id; + } + + __device__ __forceinline__ bool operator<(const T value) const { + return (v < value); + } + + __device__ __forceinline__ bool operator>(const T value) const { + return (v > value); + } + __device__ __forceinline__ bool operator<(const Pair& in) const { + return (v < in.v) || ((v == in.v) && (id > in.id)); + } + + __device__ __forceinline__ bool operator>(const Pair& in) const { + return (v > in.v) || ((v == in.v) && (id < in.id)); + } + + T v; + int id; +}; + +template +__device__ __forceinline__ void AddTo(Pair topk[], + const Pair& p, + int beam_size) { + for (int k = beam_size - 2; k >= 0; k--) { + if (topk[k] < p) { + topk[k + 1] = topk[k]; + } else { + topk[k + 1] = p; + return; + } + } + topk[0] = p; +} + +template +__device__ __forceinline__ void GetTopK(Pair topk[], + const T* src, + int idx, + int dim, + int beam_size, + const bool stop_flag, + const int end_id) { + Pair tmp; + while (idx < dim) { + if (stop_flag) { + tmp.set(idx == end_id ? 0 : -static_cast(FLT_MAX), idx); + } else { + tmp.set(src[idx], idx); + } + if (topk[beam_size - 1] < tmp) { + AddTo(topk, tmp, beam_size); + } + idx += BlockSize; + } +} + +template +__device__ __forceinline__ void GetTopK(Pair topk[], + const T* src, + int idx, + int dim, + const Pair& max, + int beam_size, + const bool stop_flag, + const int end_id) { + Pair tmp; + while (idx < dim) { + if (stop_flag) { + tmp.set(idx == end_id ? 0 : -static_cast(FLT_MAX), idx); + } else { + tmp.set(src[idx], idx); + } + if (topk[beam_size - 1] < tmp && tmp < max) { + AddTo(topk, tmp, beam_size); + } + + idx += BlockSize; + } +} + +template +__device__ __forceinline__ void ThreadGetTopK(Pair topk[], + int* beam, + int beam_size, + const T* src, + bool* firstStep, + bool* is_empty, + Pair* max, + int dim, + const int tid, + const bool stop_flag, + const int end_id) { + if (*beam > 0) { + int length = (*beam) < beam_size ? 
*beam : beam_size; + if (*firstStep) { + *firstStep = false; + GetTopK(topk, src, tid, dim, length, stop_flag, end_id); + } else { + for (int k = 0; k < MaxLength; k++) { + if (k < MaxLength - (*beam)) { + topk[k] = topk[k + *beam]; + } else { + topk[k].set(-static_cast(FLT_MAX), -1); + } + } + if (!(*is_empty)) { + GetTopK( + topk + MaxLength - *beam, src, tid, dim, *max, length, stop_flag, end_id); + } + } + + *max = topk[MaxLength - 1]; + if (max->id == -1) { + *is_empty = true; + } + *beam = 0; + } +} + +template +__device__ __forceinline__ void BlockReduce(Pair* sh_topk, + int* maxid, + Pair topk[], + T** topVal, + int** topIds, + int* beam, + int* k, + const int tid, + const int warp, + const T cum_log_prob, + const T length_penalty) { + while (true) { + __syncthreads(); + if (tid < BlockSize / 2) { + if (sh_topk[tid] < sh_topk[tid + BlockSize / 2]) { + maxid[tid] = tid + BlockSize / 2; + } else { + maxid[tid] = tid; + } + } + __syncthreads(); + for (int stride = BlockSize / 4; stride > 0; stride = stride / 2) { + if (tid < stride) { + if (sh_topk[maxid[tid]] < sh_topk[maxid[tid + stride]]) { + maxid[tid] = maxid[tid + stride]; + } + } + __syncthreads(); + } + __syncthreads(); + + if (tid == 0) { + **topVal = sh_topk[maxid[0]].v / length_penalty + cum_log_prob; + **topIds = sh_topk[maxid[0]].id; + (*topVal)++; + (*topIds)++; + } + if (tid == maxid[0]) (*beam)++; + if (--(*k) == 0) break; + __syncthreads(); + + if (tid == maxid[0]) { + if (*beam < MaxLength) { + sh_topk[tid] = topk[*beam]; + } + } + + if (*beam == MaxLength) { + break; + } + } +} + +template +__global__ void BeamSearchTopK(T* output, + int output_stride, + int* indices, + const T* src, + int lds, + int dim, + int k, + int grid_dim, + int num, + const T *cum_log_probs, // batch * beam + const bool *stop_flags, // batch * beam + const int *step_ids, + const int *end_ids, + const T length_penalty) { + __shared__ Pair sh_topk[BlockSize]; + const int tid = threadIdx.x; + const int warp = threadIdx.x / WRAP_SIZE; + + const int bid = blockIdx.x; + for (int i = bid; i < num; i += grid_dim) { + int top_num = k; + __shared__ int maxid[BlockSize / 2]; + T* out = output + i * output_stride; + int* inds = indices + i * k; + Pair topk[MaxLength]; + int beam = MaxLength; + Pair max; + bool is_empty = false; + bool firststep = true; + + T cum_log_prob = cum_log_probs[i]; + T current_penalty = 1; + const bool stop_flag = stop_flags[i]; + const int end_id = end_ids[0]; + + if (length_penalty == 0.0) { + // do nothing + } else { + // new_prob = (prob + cum_log_prob * previous_penalty) / current_penalty; + T previous_penalty = static_cast(powf(step_ids[i], length_penalty)); + current_penalty = static_cast(powf(step_ids[i] + 1, length_penalty)); + cum_log_prob = cum_log_prob * previous_penalty / current_penalty; + } + + #pragma unroll 1 + for (int j = 0; j < MaxLength; j++) { + topk[j].set(-static_cast(FLT_MAX), -1); + } + + while (top_num) { + ThreadGetTopK(topk, + &beam, + k, + src + i * lds, + &firststep, + &is_empty, + &max, + dim, + tid, + stop_flag, + end_id); + + sh_topk[tid] = topk[0]; + BlockReduce(sh_topk, + maxid, + topk, + &out, + &inds, + &beam, + &top_num, + tid, + warp, + cum_log_prob, + current_penalty); + } + } +} + +#undef FLT_MAX +} // namespace fusion +} // namespace phi diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index eb1975fdcc246..4650f5cf933ba 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -1123,6 +1123,7 @@ def beam_search_softmax( 
fuse_softmax, early_stop, length_penalty=0.0, + one_stage_topk=False, ): if in_dygraph_mode(): return _C_ops.beam_search_softmax( @@ -1139,7 +1140,8 @@ def beam_search_softmax( max_dec_len, fuse_softmax, early_stop, - length_penalty + length_penalty, + one_stage_topk ) inputs = { @@ -1159,6 +1161,7 @@ def beam_search_softmax( attrs['fuse_softmax'] = fuse_softmax attrs['early_stop'] = early_stop attrs['length_penalty'] = length_penalty + attrs['one_stage_topk'] = one_stage_topk helper = LayerHelper('beam_search_softmax', **locals()) ids_this_time = helper.create_variable_for_type_inference(dtype="int32") From 97ed59fe0418eacd818c572e656fb43a52511b43 Mon Sep 17 00:00:00 2001 From: laipaang Date: Thu, 11 Jan 2024 14:17:13 +0800 Subject: [PATCH 3/4] copy from cpu paddle tensor & infer run release GIL, aigc-model-79 --- paddle/fluid/pybind/inference_api.cc | 49 ++++++++++++++++++++---- python/paddle/fluid/inference/wrapper.py | 2 + 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index 32b5d2cb2a6af..b2ab2c6cbdc65 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -228,6 +228,35 @@ void PaddleInferTensorCreate( tensor.CopyFromCpu(static_cast(data.data())); } +void CopyFromCpuPaddleTensor(paddle_infer::Tensor &tensor, + paddle::experimental::Tensor &&paddle_tensor) { + std::vector shape; + for (int i = 0; i < paddle_tensor.dims().size(); ++i) { + shape.push_back(paddle_tensor.dims()[i]); + } + tensor.Reshape(std::move(shape)); + + switch (paddle_tensor.dtype()) { + case paddle::experimental::DataType::FLOAT16: + tensor.CopyFromCpu(static_cast( + paddle_tensor.data())); + break; + case paddle::experimental::DataType::FLOAT32: + tensor.CopyFromCpu(static_cast(paddle_tensor.data())); + break; + case paddle::experimental::DataType::INT32: + tensor.CopyFromCpu(static_cast(paddle_tensor.data())); + break; + case paddle::experimental::DataType::INT64: + tensor.CopyFromCpu(static_cast(paddle_tensor.data())); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported data type. 
Now copy_from_cpu only supports FLOAT16, FLOAT32, "
+        "INT32, and INT64."));
+  }
+}
+
 paddle_infer::PlaceType ToPaddleInferPlace(
     phi::AllocationType allocation_type) {
   if (allocation_type == phi::AllocationType::CPU) {
@@ -585,7 +614,8 @@ void BindPaddlePredictor(py::module *m) {
             std::vector<PaddleTensor> outputs;
             self.Run(inputs, &outputs);
             return outputs;
-          })
+          },
+          py::call_guard<py::gil_scoped_release>())
      .def("get_input_tensor", &PaddlePredictor::GetInputTensor)
      .def("get_output_tensor", &PaddlePredictor::GetOutputTensor)
      .def("get_input_names", &PaddlePredictor::GetInputNames)
@@ -634,7 +664,8 @@ void BindNativePredictor(py::module *m) {
             std::vector<PaddleTensor> outputs;
             self.Run(inputs, &outputs);
             return outputs;
-          })
+          },
+          py::call_guard<py::gil_scoped_release>())
      .def("get_input_tensor", &NativePaddlePredictor::GetInputTensor)
      .def("get_output_tensor", &NativePaddlePredictor::GetOutputTensor)
      .def("zero_copy_run", &NativePaddlePredictor::ZeroCopyRun)
@@ -926,7 +957,8 @@ void BindAnalysisPredictor(py::module *m) {
             std::vector<PaddleTensor> outputs;
             self.Run(inputs, &outputs);
             return outputs;
-          })
+          },
+          py::call_guard<py::gil_scoped_release>())
      .def("get_input_tensor", &AnalysisPredictor::GetInputTensor)
      .def("get_output_tensor", &AnalysisPredictor::GetOutputTensor)
      .def("get_input_names", &AnalysisPredictor::GetInputNames)
@@ -972,11 +1004,9 @@ void BindPaddleInferPredictor(py::module *m) {
      .def("get_output_handle", &paddle_infer::Predictor::GetOutputHandle)
      .def("run",
           [](paddle_infer::Predictor &self) {
-#ifdef PADDLE_WITH_ASCEND_CL
-            pybind11::gil_scoped_release release;
-#endif
             self.Run();
-          })
+          },
+          py::call_guard<py::gil_scoped_release>())
      .def("clone",
           [](paddle_infer::Predictor &self) { return self.Clone(nullptr); })
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -1024,6 +1054,11 @@ void BindPaddleInferTensor(py::module *m) {
      .def("copy_from_cpu_bind", &PaddleInferTensorCreate)
      .def("copy_from_cpu_bind", &PaddleInferStringTensorCreate)
+     .def("_copy_from_cpu_bind",
+          [](paddle_infer::Tensor &self, const py::handle &input) {
+            PyObject *obj = input.ptr();
+            CopyFromCpuPaddleTensor(self, std::move(CastPyArg2Tensor(obj, 0)));
+          })
      .def("share_external_data_bind", &PaddleInferShareExternalData)
      .def("_share_external_data_paddle_tensor_bind",
          [](paddle_infer::Tensor &self, const py::handle &input) {
diff --git a/python/paddle/fluid/inference/wrapper.py b/python/paddle/fluid/inference/wrapper.py
index 9afd08f01c946..0b91f4ff2933b 100644
--- a/python/paddle/fluid/inference/wrapper.py
+++ b/python/paddle/fluid/inference/wrapper.py
@@ -38,6 +38,8 @@ def tensor_copy_from_cpu(self, data):
     if isinstance(data, np.ndarray) or (isinstance(data, list) and len(data) > 0
                                         and isinstance(data[0], str)):
         self.copy_from_cpu_bind(data)
+    elif isinstance(data, paddle.Tensor):
+        self._copy_from_cpu_bind(data)
     else:
         raise TypeError(
             "In copy_from_cpu, we only support numpy ndarray and list[str] data type."
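
Usage sketch (illustrative, not part of the diff). With the two changes above, an input handle's copy_from_cpu accepts a CPU paddle.Tensor in addition to numpy arrays and list[str], and predictor.run() executes with the GIL released. The model files below are placeholders, and the float32 input assumes one of the dtypes the new tensor path supports (FLOAT16, FLOAT32, INT32, INT64):

    import numpy as np
    import paddle
    from paddle.inference import Config, create_predictor

    # hypothetical model files; any exported inference model works here
    config = Config("model.pdmodel", "model.pdiparams")
    predictor = create_predictor(config)

    name = predictor.get_input_names()[0]
    handle = predictor.get_input_handle(name)

    # numpy path (pre-existing behavior)
    handle.copy_from_cpu(np.ones([1, 16], dtype="float32"))

    # paddle.Tensor path (added by this patch); the tensor must live on CPU
    handle.copy_from_cpu(paddle.ones([1, 16], dtype="float32").cpu())

    # run() now releases the GIL, so other Python threads keep making progress
    predictor.run()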
From 687821408b022890999a0ff26cec3ba7ad90f54d Mon Sep 17 00:00:00 2001
From: laipaang
Date: Tue, 16 Jan 2024 10:29:22 +0800
Subject: [PATCH 4/4] reduced batch topk, aigc-model-85

---
 .../kernels/fusion/gpu/beam_search_softmax.cu | 129 +++++++++++----
 .../kernels/fusion/gpu/beam_search_topk.cu.h  | 153 ++++++++++++++++++
 2 files changed, 251 insertions(+), 31 deletions(-)

diff --git a/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu b/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu
index bc114e566e69d..4407f795eec5d 100644
--- a/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu
+++ b/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu
@@ -775,6 +775,88 @@ void invokeOneStageTopK(const Context &dev_ctx,
   }
 }
 
+template <typename T, int K>
+void invokeBatchTopK(int *tmp_ids,
+                     T *tmp_vals,
+                     const float *cum_log_probs,
+                     const int *step_ids,
+                     const bool *stop_flags,  // bs * beam_size
+                     const int *sequence_lengths,
+                     const int *end_ids,
+                     int *ids,
+                     T *out_cum_log_probs,
+                     int *parent_idx,
+                     bool *stop_flags_out,
+                     int *seq_lens_out,
+                     int *step_ids_out,
+                     const int batch_size,
+                     const bool early_stop,
+                     cudaStream_t stream) {
+  if (early_stop) {
+    if (K > 10) {
+      reduced_batch_topk<<>>(
+          tmp_ids,
+          tmp_vals,
+          cum_log_probs,
+          step_ids,
+          stop_flags,
+          sequence_lengths,
+          end_ids,
+          ids,
+          out_cum_log_probs,
+          parent_idx,
+          stop_flags_out,
+          seq_lens_out,
+          step_ids_out);
+    } else {
+      batch_topk<<>>(
+          tmp_ids,
+          tmp_vals,
+          cum_log_probs,
+          step_ids,
+          stop_flags,
+          sequence_lengths,
+          end_ids,
+          ids,
+          out_cum_log_probs,
+          parent_idx,
+          stop_flags_out,
+          seq_lens_out,
+          step_ids_out);
+    }
+  } else {
+    if (K > 10) {
+      reduced_batch_topk<<>>(
+          tmp_ids,
+          tmp_vals,
+          step_ids,
+          stop_flags,
+          sequence_lengths,
+          end_ids,
+          ids,
+          out_cum_log_probs,
+          parent_idx,
+          stop_flags_out,
+          seq_lens_out,
+          step_ids_out);
+    } else {
+      batch_topk<<>>(
+          tmp_ids,
+          tmp_vals,
+          step_ids,
+          stop_flags,
+          sequence_lengths,
+          end_ids,
+          ids,
+          out_cum_log_probs,
+          parent_idx,
+          stop_flags_out,
+          seq_lens_out,
+          step_ids_out);
+    }
+  }
+}
+
 template
 void invokeTopKSoftMaxLauncher(const Context &dev_ctx,
@@ -838,37 +920,22 @@ void invokeTopKSoftMaxLauncher(const Context &dev_ctx,
                      length_penalty);
   }
 
-  // (bs, bm)
-  if (early_stop) {
-    batch_topk<<>>(
-        tmp_ids,
-        tmp_vals,
-        cum_log_probs,
-        step_ids,
-        stop_flags,
-        sequence_lengths,
-        end_ids,
-        ids,
-        out_cum_log_probs,
-        parent_idx,
-        stop_flags_out,
-        seq_lens_out,
-        step_ids_out);
-  } else {
-    batch_topk<<>>(
-        tmp_ids,
-        tmp_vals,
-        step_ids,
-        stop_flags,
-        sequence_lengths,
-        end_ids,
-        ids,
-        out_cum_log_probs,
-        parent_idx,
-        stop_flags_out,
-        seq_lens_out,
-        step_ids_out);
-  }
+  invokeBatchTopK<T, K>(tmp_ids,
+                        tmp_vals,
+                        cum_log_probs,
+                        step_ids,
+                        stop_flags,
+                        sequence_lengths,
+                        end_ids,
+                        ids,
+                        out_cum_log_probs,
+                        parent_idx,
+                        stop_flags_out,
+                        seq_lens_out,
+                        step_ids_out,
+                        batch_size,
+                        early_stop,
+                        stream);
 
   invokeUpdateBeamOffset(last_beam_offsets,
                          parent_idx,
                          sequence_lengths,
diff --git a/paddle/phi/kernels/fusion/gpu/beam_search_topk.cu.h b/paddle/phi/kernels/fusion/gpu/beam_search_topk.cu.h
index 63c8c1441ef4a..dcabede3fc1dc 100644
--- a/paddle/phi/kernels/fusion/gpu/beam_search_topk.cu.h
+++ b/paddle/phi/kernels/fusion/gpu/beam_search_topk.cu.h
@@ -311,6 +311,159 @@ __global__ void BeamSearchTopK(T* output,
   }
 }
 
+template <typename T>
+struct BatchTopK {
+  int id = -1;
+  int tid = -1;      // token id
+  int pid = -1;      // parent beam id
+  T val = -FLT_MAX;  // score value
+
+  __device__ __forceinline__ void insert(int id, T val, int tid, int pid) {
+    if (val > this->val) {
+      this->id = id;
+      this->tid = tid;
+      this->pid = pid;
+      this->val = val;
+    }
+  }
+
+  __device__ __forceinline__ void init() {
+    id = -1;
+    tid = -1;
+    pid = -1;
+    val = -FLT_MAX;
+  }
+};
+
+template <typename T>
+__device__ __forceinline__ BatchTopK<T> reduce_topk_beam_op(
+    const BatchTopK<T>& a, const BatchTopK<T>& b) {
+  if (a.val > b.val || (a.val == b.val && a.tid < b.tid)) {
+    return a;
+  } else {
+    return b;
+  }
+}
+
+template <typename T, int K, int THREADBLOCK_SIZE>
+__global__ void reduced_batch_topk(int *topk_tmp_id_buf,  // batch_size * K * K
+                                   T *topk_tmp_val_buf,
+                                   const int *step_ids,
+                                   const bool *stop_flags,
+                                   const int *seq_lens,
+                                   const int *end_ids,
+                                   int *id_buf,  // batch_size * K
+                                   T *val_buf,
+                                   int *parent_idx,
+                                   bool *stop_flags_out,
+                                   int *seq_lens_out,
+                                   int *step_ids_out) {
+  const int tid = threadIdx.x;  // 0 - BlockSize
+  const int bid = blockIdx.x;   // 0 - BatchSize
+  const int tmp_buf_idx = bid * K * K;
+  const int tmp_buf_len = step_ids[0] == 0 ? K : K * K;
+  const int buf_idx = bid * K;
+
+  typedef cub::BlockReduce<BatchTopK<T>, THREADBLOCK_SIZE> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+  BatchTopK<T> partial;
+
+  for (int k = 0; k < K; ++k) {
+    partial.init();
+    for (int i = tid; i < tmp_buf_len; i += THREADBLOCK_SIZE) {
+      const int idx = tmp_buf_idx + i;
+      partial.insert(idx, topk_tmp_val_buf[idx], topk_tmp_id_buf[idx], i / K);
+    }
+
+    BatchTopK<T> total =
+        BlockReduce(temp_storage).Reduce(partial, reduce_topk_beam_op<T>);
+
+    if (tid == 0) {
+      const int idx = buf_idx + k;
+      id_buf[idx] = total.tid;
+      val_buf[idx] = total.val;
+      parent_idx[idx] = total.pid;
+      stop_flags_out[idx] = stop_flags[buf_idx + total.pid];
+      seq_lens_out[idx] = seq_lens[buf_idx + total.pid];
+      step_ids_out[idx] = step_ids[buf_idx + total.pid];
+
+      topk_tmp_val_buf[total.id] = -FLT_MAX;
+    }
+    __syncthreads();
+  }
+}
+
+// early stop
+template <typename T, int K, int THREADBLOCK_SIZE>
+__global__ void reduced_batch_topk(int* topk_tmp_id_buf,
+                                   T* topk_tmp_val_buf,
+                                   const float* cum_log_probs,
+                                   const int* step_ids,
+                                   const bool* stop_flags,  // bs * beam_size
+                                   const int* seq_lens,
+                                   const int* end_ids,
+                                   int* id_buf,
+                                   T* val_buf,
+                                   int* parent_idx,
+                                   bool* stop_flags_out,
+                                   int* seq_lens_out,
+                                   int* step_ids_out) {
+  const int tid = threadIdx.x;  // 0 - BlockSize
+  const int bid = blockIdx.x;   // 0 - BatchSize
+  const int tmp_buf_idx = bid * K * K;
+  const int buf_idx = bid * K;
+
+  typedef cub::BlockReduce<BatchTopK<T>, THREADBLOCK_SIZE> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+  BatchTopK<T> partial;
+
+  for (int k = 0; k < K; ++k) {
+    const int idx = buf_idx + k;
+    if (stop_flags[idx]) {
+      if (tid == 0) {
+        id_buf[idx] = end_ids[0];
+        val_buf[idx] = cum_log_probs[idx];
+        parent_idx[idx] = k;
+        stop_flags_out[idx] = stop_flags[idx];
+        seq_lens_out[idx] = seq_lens[idx];
+        step_ids_out[idx] = step_ids[idx];
+      }
+
+      continue;
+    }
+
+    partial.init();
+
+    if (step_ids[0] == 0) {
+      for (int i = tid; i < K; i += THREADBLOCK_SIZE) {
+        const int idx = tmp_buf_idx + i;
+        partial.insert(idx, topk_tmp_val_buf[idx], topk_tmp_id_buf[idx], i / K);
+      }
+    } else {
+      for (int i = tid; i < K * K; i += THREADBLOCK_SIZE) {
+        // if a beam has already stopped, its branch is final; do not update it
+        if (!stop_flags[buf_idx + i / K]) {
+          const int idx = tmp_buf_idx + i;
+          partial.insert(idx, topk_tmp_val_buf[idx], topk_tmp_id_buf[idx], i / K);
+        }
+      }
+    }
+
+    BatchTopK<T> total =
+        BlockReduce(temp_storage).Reduce(partial, reduce_topk_beam_op<T>);
+
+    if (tid == 0) {
+      id_buf[idx] = total.tid;
+      val_buf[idx] = total.val;
+      parent_idx[idx] = total.pid;
+      stop_flags_out[idx] = stop_flags[buf_idx + total.pid];
+      seq_lens_out[idx] = seq_lens[buf_idx + total.pid];
+      step_ids_out[idx] = step_ids[buf_idx + total.pid];
+
+      topk_tmp_val_buf[total.id] = -FLT_MAX;
+    }
+    __syncthreads();
+  }
+}
+
 #undef FLT_MAX
 }  // namespace fusion
 }  // namespace phi
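
Reference sketch (illustrative, not part of the diff). The reduced_batch_topk kernels above pick, for each batch entry, the K best (score, token, parent beam) candidates out of the K*K per-beam top-K results, knocking out each winner by overwriting its score with -FLT_MAX. Below is a NumPy model of the non-early-stop variant for a single batch entry, under the same buffer layout (K parents times K candidates each, only the first K valid at step 0). Exact tie-breaking differs: the kernel prefers the smaller token id, while np.argmax takes the first occurrence:

    import numpy as np

    def reduced_batch_topk_ref(tmp_ids, tmp_vals, K, first_step=False):
        # tmp_ids/tmp_vals: flat (K*K,) candidate token ids and scores
        n = K if first_step else K * K
        vals = tmp_vals[:n].astype(np.float64).copy()
        tids, scores, pids = [], [], []
        for _ in range(K):
            j = int(np.argmax(vals))       # the block-wide reduction in the kernel
            tids.append(int(tmp_ids[j]))   # selected token id
            scores.append(float(vals[j]))  # selected score
            pids.append(j // K)            # parent beam the candidate came from
            vals[j] = -np.inf              # knock out the winner, as the kernel does
        return tids, scores, pids

    # beam size K=2: candidates of parent 0 are tokens [5, 7], of parent 1 are [9, 11]
    ids = np.array([5, 7, 9, 11])
    vals = np.array([0.1, 0.4, 0.3, 0.2])
    print(reduced_batch_topk_ref(ids, vals, K=2))
    # -> ([7, 9], [0.4, 0.3], [0, 1])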