diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index d008355e0ed5b..5e787394bce25 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -2089,6 +2089,7 @@ USE_TRT_CONVERTER(top_k) USE_TRT_CONVERTER(top_k_v2) USE_TRT_CONVERTER(squeeze2) USE_TRT_CONVERTER(unsqueeze2) +USE_TRT_CONVERTER(fused_token_prune) #if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000) USE_TRT_CONVERTER(sparse_fc) USE_TRT_CONVERTER(sparse_multihead_matmul) diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 90089fcbfd806..ca91df902a9a1 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -68,7 +68,8 @@ list( c_allreduce_op.cc top_k_op.cc squeeze2_op.cc - unsqueeze2_op.cc) + unsqueeze2_op.cc + fused_token_prune_op.cc) if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8) list(APPEND CONVERT_FILES sparse_fc_op.cc sparse_multihead_matmul_op.cc) diff --git a/paddle/fluid/inference/tensorrt/convert/fused_token_prune_op.cc b/paddle/fluid/inference/tensorrt/convert/fused_token_prune_op.cc new file mode 100644 index 0000000000000..bab04ac16aac9 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/fused_token_prune_op.cc @@ -0,0 +1,76 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class FusedTokenPruneOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, + bool test_mode) override { + framework::OpDesc op_desc(op, nullptr); + nvinfer1::ILayer* layer = nullptr; + + auto* Attn = engine_->GetITensor(op_desc.Input("Attn").front()); + auto* X = engine_->GetITensor(op_desc.Input("X").front()); + auto* Mask = engine_->GetITensor(op_desc.Input("Mask").front()); + auto* NewMask = engine_->GetITensor(op_desc.Input("NewMask").front()); + bool keep_first_token = + op_desc.HasAttr("keep_first_token") + ? BOOST_GET_CONST(bool, op_desc.GetAttr("keep_first_token")) + : true; + bool keep_order = op_desc.HasAttr("keep_order") + ? 
BOOST_GET_CONST(bool, op_desc.GetAttr("keep_order")) + : false; + + std::vector itensors = {Attn, X, Mask, NewMask}; + + auto output_name = op_desc.Output("SlimmedX")[0]; + auto out_inds_name = op_desc.Output("CLSInds")[0]; + if (engine_->with_dynamic_shape()) { +#if IS_TRT_VERSION_GE(6000) + bool with_fp16 = + engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); + + if (engine_->precision() == AnalysisConfig::Precision::kInt8) { + with_fp16 = true; + } + plugin::FusedTokenPrunePluginDynamic* plugin = + new plugin::FusedTokenPrunePluginDynamic( + with_fp16, keep_first_token, keep_order); + layer = engine_->AddDynamicPlugin(itensors.data(), 4, plugin); +#else + PADDLE_THROW(platform::errors::Fatal( + "You are running the TRT Dynamic Shape mode, need to confirm that " + "your TRT version is no less than 6.0")); +#endif + } else { + PADDLE_THROW(platform::errors::Fatal( + "You are running the Ernie(Bert) model in static shape mode, which " + "is not supported for the time being.\n" + "You can use the config.SetTRTDynamicShapeInfo(...) interface to set " + "the shape information to run the dynamic shape mode.")); + } + RreplenishLayerAndOutput( + layer, "fused_token_prune", {output_name, out_inds_name}, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(fused_token_prune, FusedTokenPruneOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 190f6c731a3b4..d7f66d4cdbc2c 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -275,7 +275,8 @@ struct SimpleOpTypeSetTeller : public Teller { "recover_padding", "remove_padding", "squeeze2", - "unsqueeze2"}; + "unsqueeze2", + "fused_token_prune"}; }; bool OpTeller::Tell(const framework::ir::Node* node, diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index cd65316fb4a63..90344fc0adae8 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -29,7 +29,8 @@ list( remove_padding_plugin.cu recover_padding_plugin.cu c_allreduce_op_plugin.cu - preln_residual_bias_plugin.cu) + preln_residual_bias_plugin.cu + fused_token_prune_op_plugin.cu) if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8) list(APPEND TRT_FILES spmm_plugin.cu) @@ -44,3 +45,10 @@ nv_test( test_split_plugin SRCS test_split_plugin.cc DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_plugin) + +if(NOT WIN32) + nv_test( + test_fused_token_prune_plugin + SRCS test_fused_token_prune_plugin.cc + DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_plugin) +endif() diff --git a/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu new file mode 100644 index 0000000000000..627ef44e6fd75 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu @@ -0,0 +1,527 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "cub/cub.cuh" + +#include "paddle/phi/kernels/funcs/math_function.h" + +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/fluid/platform/device_context.h" + +#include "paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h" +#include "paddle/fluid/operators/fused_token_prune_op.cu.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +#if IS_TRT_VERSION_GE(6000) + +template +__global__ void ElementwiseMask(const T* a, + const T* b, + T* res, + int num_elements) { + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= num_elements) return; + const T zero = 0; + res[tid] = b[tid] >= zero ? a[tid] : zero; +} + +template +__global__ void FillZero(T* data, int len) { + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= len) return; + const T zero = 0; + data[tid] = zero; +} + +__global__ void FillIndex(int32_t* indices, int num_raws, int num_cols) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= num_raws * num_cols) return; + + int col = tid % num_cols; + int raw = tid / num_cols; + + indices[tid] = col; +} + +template +__global__ void MaximumFirst(T* mat, int num_raws, int num_cols, T max_value) { + auto raw = blockIdx.x * blockDim.x + threadIdx.x; + if (raw >= num_raws) return; + mat[raw * num_cols] = max_value; +} + +__global__ void FillOffsets(int* offsets, int num_raws, int num_cols) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid > num_raws) return; + + offsets[tid] = tid * num_cols; +} + +template +__global__ void Slice( + const T* src, T* dst, int num_raws, int src_num_cols, int dst_num_cols) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= num_raws * dst_num_cols) return; + int raw = tid / dst_num_cols; + int col = tid % dst_num_cols; + dst[tid] = src[raw * src_num_cols + col]; +} + +template +__global__ void ReduceSum2( + const T* src, T* dst, int bsz, int nb_head, int max_seq_len) { + int tid = threadIdx.x; + int bid = blockIdx.x; + int num_blocks_per_head = ((max_seq_len / blockDim.x) * max_seq_len); + int batch = bid / (nb_head * num_blocks_per_head); + int col = bid % max_seq_len; + int head = (bid / num_blocks_per_head) % nb_head; + + extern __shared__ T res_float[]; + res_float[tid] = + src[batch * (nb_head * max_seq_len * max_seq_len) + + head * (max_seq_len * max_seq_len) + col + tid * max_seq_len]; + __syncthreads(); + + for (int offset = blockDim.x >> 1; offset > 0; offset >>= 1) { + if (tid < offset) { + res_float[tid] += res_float[tid + offset]; + } + __syncthreads(); + if (offset % 2 == 1 && tid == offset - 2) { + res_float[tid] += res_float[tid + 1]; + } + } + + if (tid == 0) { + auto* dst_addr = dst + batch * max_seq_len + col; + atomicAdd(dst_addr, res_float[0]); + } +} + +template <> +__global__ void ReduceSum2( + const half* src, half* dst, int bsz, int nb_head, int max_seq_len) { + int tid = threadIdx.x; + int bid = blockIdx.x; + int num_blocks_per_head = ((max_seq_len / blockDim.x) * 
max_seq_len); + int batch = bid / (nb_head * num_blocks_per_head); + int col = bid % max_seq_len; + int head = (bid / num_blocks_per_head) % nb_head; + + extern __shared__ half res_half[]; + res_half[tid] = + src[batch * (nb_head * max_seq_len * max_seq_len) + + head * (max_seq_len * max_seq_len) + col + tid * max_seq_len]; + __syncthreads(); + + for (int offset = blockDim.x >> 1; offset > 0; offset >>= 1) { + if (tid < offset) { + res_half[tid] += res_half[tid + offset]; + } + __syncthreads(); + if (offset % 2 == 1 && tid == offset - 2) { + res_half[tid] += res_half[tid + 1]; + } + __syncthreads(); + } + + if (tid == 0) { + platform::fastAtomicAdd( + reinterpret_cast(dst), + static_cast(batch * max_seq_len + col), + static_cast(bsz * max_seq_len), + static_cast(res_half[0])); + } +} + +template +__global__ void TakeAlongAxis(const T* src, + T* dst, + int32_t* indices, + int num_raws, + int src_num_cols, + int dst_num_cols, + int num_elements) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= num_raws * dst_num_cols) return; + + int raw = tid / dst_num_cols; + int col = tid % dst_num_cols; + for (int i = 0; i < num_elements; ++i) { + dst[tid * num_elements + i] = + *(src + (raw * src_num_cols + indices[tid]) * num_elements + i); + } +} + +nvinfer1::DimsExprs FusedTokenPrunePluginDynamic::getOutputDimensions( + int output_index, + const nvinfer1::DimsExprs* inputs, + int nb_inputs, + nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT { + auto x_dims = inputs[1], new_mask_dims = inputs[3]; + if (output_index == 0) { + nvinfer1::DimsExprs ret = x_dims; + ret.d[1] = new_mask_dims.d[2]; + return ret; + } else { + nvinfer1::DimsExprs ret; + ret.nbDims = 2; + ret.d[0] = new_mask_dims.d[0]; + ret.d[1] = new_mask_dims.d[2]; + return ret; + } +} + +bool FusedTokenPrunePluginDynamic::supportsFormatCombination( + int pos, + const nvinfer1::PluginTensorDesc* in_out, + int nb_inputs, + int nb_outputs) TRT_NOEXCEPT { + PADDLE_ENFORCE_NOT_NULL( + in_out, + platform::errors::InvalidArgument( + "The input of swish plugin shoule not be nullptr.")); + + PADDLE_ENFORCE_LT( + pos, + nb_inputs + nb_outputs, + platform::errors::InvalidArgument("The pos(%d) should be less than the " + "num(%d) of the input and the output.", + pos, + nb_inputs + nb_outputs)); + + const nvinfer1::PluginTensorDesc& in = in_out[pos]; + if (pos == 0) { + if (with_fp16_) { +#ifdef TRT_PLUGIN_FP16_AVALIABLE + return (in.type == nvinfer1::DataType::kFLOAT || + in.type == nvinfer1::DataType::kHALF) && + (in.format == nvinfer1::TensorFormat::kLINEAR); +#else + return (in.type == nvinfer1::DataType::kFLOAT) && + (in.format == nvinfer1::TensorFormat::kLINEAR); +#endif + } else { + return (in.type == nvinfer1::DataType::kFLOAT) && + (in.format == nvinfer1::TensorFormat::kLINEAR); + } + } else if (pos <= 4) { + const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1]; + return in.type == prev.type && in.format == prev.format; + } else { + const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1]; + return in.type == nvinfer1::DataType::kINT32 && in.format == prev.format; + } +} + +nvinfer1::DataType FusedTokenPrunePluginDynamic::getOutputDataType( + int index, + const nvinfer1::DataType* input_types, + int nb_inputs) const TRT_NOEXCEPT { + if (index == 0) { + return input_types[1]; + } else if (index == 1) { + return nvinfer1::DataType::kINT32; + } +} + +size_t FusedTokenPrunePluginDynamic::getWorkspaceSize( + const nvinfer1::PluginTensorDesc* inputs, + int nb_inputs, + const nvinfer1::PluginTensorDesc* outputs, + int nb_outputs) 
const TRT_NOEXCEPT { + auto attn_dims = inputs[0].dims; + auto x_dims = inputs[1].dims; + auto new_mask_dims = inputs[3].dims; + auto bsz = attn_dims.d[0], nb_head = attn_dims.d[1], + max_seq_len = attn_dims.d[2]; + + int slimmed_x_len = new_mask_dims.d[2]; + int total = bsz * nb_head * max_seq_len * max_seq_len; + size_t size = total * sizeof(float); + size += bsz * max_seq_len * sizeof(float); + size += bsz * max_seq_len * sizeof(int32_t); + size += bsz * max_seq_len * sizeof(float); + size += bsz * max_seq_len * sizeof(int32_t); + size += (bsz + 1) * sizeof(int); + size += bsz * slimmed_x_len * sizeof(int32_t); + return size; +} + +template +int FusedTokenPrunePluginDynamic::enqueueImpl( + const nvinfer1::PluginTensorDesc* input_desc, + const nvinfer1::PluginTensorDesc* output_desc, + const void* const* inputs, + void* const* outputs, + void* workspace_ptr, + cudaStream_t stream, + int device_id, + T max_value) { + // Dims + auto attn_dims = input_desc[0].dims; + auto x_dims = input_desc[1].dims; + auto new_mask_dims = input_desc[3].dims; + + auto bsz = attn_dims.d[0], nb_head = attn_dims.d[1], + max_seq_len = attn_dims.d[2]; + auto c = x_dims.d[2]; + auto slimmed_x_len = new_mask_dims.d[2]; + + // Inputs + const T* attn_data = static_cast(inputs[0]); + const T* x_data = static_cast(inputs[1]); + const T* mask_data = static_cast(inputs[2]); + + // Outputs + T* output_data = static_cast(outputs[0]); + int32_t* output_indices_data = static_cast(outputs[1]); + + int total = bsz * nb_head * max_seq_len * max_seq_len; + int block = operators::ComputeBlockSize(max_seq_len); + int grid = operators::CeilDivide(total, block); + + // Workspace for intermediate variable + char* workspace = static_cast(workspace_ptr); + T* attn_tmp_data = reinterpret_cast(workspace); + size_t offset = total * sizeof(T); + T* attn_accu_data = reinterpret_cast(workspace + offset); + offset += bsz * max_seq_len * sizeof(T); + int32_t* attn_accu_indices_data = + reinterpret_cast(workspace + offset); + offset += bsz * max_seq_len * sizeof(int32_t); + T* sort_attn_accu_data = reinterpret_cast(workspace + offset); + offset += bsz * max_seq_len * sizeof(T); + int32_t* sort_attn_accu_indices_data = + reinterpret_cast(workspace + offset); + offset += bsz * max_seq_len * sizeof(int32_t); + int* offsets_data = reinterpret_cast(workspace + offset); + offset += (bsz + 1) * sizeof(int); + int32_t* slimmed_sort_attn_accu_indices_data = + reinterpret_cast(workspace + offset); + + // 1. Filter attn by mask + ElementwiseMask + <<>>(attn_data, mask_data, attn_tmp_data, total); + + total = bsz * max_seq_len; + block = operators::ComputeBlockSize(max_seq_len); + grid = operators::CeilDivide(total, block); + FillZero<<>>(attn_accu_data, total); + + // 2. Reduce sum + total = bsz * nb_head * max_seq_len * max_seq_len; + int block_tmp = max_seq_len; + while (block_tmp > 1024) + block_tmp /= 2; // if max seq len > 1024, it must be 2^n + block = + block_tmp; // make sure max_seq_len is an integral multiple of block_size + grid = operators::CeilDivide(total, block); + ReduceSum2<<>>( + attn_tmp_data, attn_accu_data, bsz, nb_head, max_seq_len); + + // 3. Prepare token indices + total = bsz * max_seq_len; + block = operators::ComputeBlockSize(max_seq_len); + grid = operators::CeilDivide(total, block); + + FillIndex<<>>( + attn_accu_indices_data, bsz, max_seq_len); + + // 4. 
Sort token indices by attn + if (keep_first_token_) { + MaximumFirst + <<>>(attn_accu_data, bsz, max_seq_len, max_value); + } + size_t temp_storage_bytes = -1; + int num_items = bsz * max_seq_len; + int num_segments = bsz; + FillOffsets<<>>(offsets_data, bsz, max_seq_len); + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedRadixSort::SortPairsDescending( + nullptr, + temp_storage_bytes, + attn_accu_data, + sort_attn_accu_data, + attn_accu_indices_data, + sort_attn_accu_indices_data, + num_items, + num_segments, + offsets_data, + offsets_data + 1, + 0, + sizeof(T) * 8, + stream)); + int64_t temp_size = temp_storage_bytes; + framework::Tensor temp_storage; + auto* temp_storage_data = temp_storage.mutable_data( + {temp_size}, platform::CUDAPlace(device_id)); + + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedRadixSort::SortPairsDescending( + temp_storage_data, + temp_storage_bytes, + attn_accu_data, + sort_attn_accu_data, + attn_accu_indices_data, + sort_attn_accu_indices_data, + num_items, + num_segments, + offsets_data, + offsets_data + 1, + 0, + sizeof(T) * 8, + stream)); + // 5. Slice + total = bsz * slimmed_x_len; + block = operators::ComputeBlockSize(slimmed_x_len); + grid = operators::CeilDivide(total, block); + + Slice + <<>>(sort_attn_accu_indices_data, + slimmed_sort_attn_accu_indices_data, + bsz, + max_seq_len, + slimmed_x_len); + + if (keep_order_) { + // 6. reorder + num_items = bsz * slimmed_x_len; + FillOffsets<<>>(offsets_data, bsz, slimmed_x_len); + temp_storage_bytes = -1; + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, + temp_storage_bytes, + slimmed_sort_attn_accu_indices_data, + output_indices_data, + num_items, + num_segments, + offsets_data, + offsets_data + 1, + 0, + sizeof(int32_t) * 8, + stream)); + + temp_size = temp_storage_bytes; + temp_storage.Resize({temp_size}); + temp_storage_data = + temp_storage.mutable_data(platform::CUDAPlace(device_id)); + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedRadixSort::SortKeys( + temp_storage_data, + temp_storage_bytes, + slimmed_sort_attn_accu_indices_data, + output_indices_data, + num_items, + num_segments, + offsets_data, + offsets_data + 1, + 0, + sizeof(int32_t) * 8, + stream)); + + TakeAlongAxis<<>>(x_data, + output_data, + output_indices_data, + bsz, + max_seq_len, + slimmed_x_len, + c); + } else { + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpy(output_indices_data, + slimmed_sort_attn_accu_indices_data, + bsz * slimmed_x_len * sizeof(int32_t), + cudaMemcpyDeviceToDevice)); + TakeAlongAxis + <<>>(x_data, + output_data, + slimmed_sort_attn_accu_indices_data, + bsz, + max_seq_len, + slimmed_x_len, + c); + } + + return cudaGetLastError() != cudaSuccess; +} + +int FusedTokenPrunePluginDynamic::enqueue( + const nvinfer1::PluginTensorDesc* input_desc, + const nvinfer1::PluginTensorDesc* output_desc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT { + auto input_type = input_desc[0].type; + auto attn_dims = input_desc[0].dims; + auto bsz = attn_dims.d[0], nb_head = attn_dims.d[1], + max_seq_len = attn_dims.d[2]; + int device_id; + cudaGetDevice(&device_id); + + if (input_type == nvinfer1::DataType::kFLOAT) { + VLOG(1) << "TRT Plugin DataType selected. 
FusedTokenPrune-->fp32"; + + float max = std::numeric_limits::max(); + + return enqueueImpl(input_desc, + output_desc, + inputs, + outputs, + workspace, + stream, + device_id, + max); + + } else if (input_type == nvinfer1::DataType::kHALF) { +#ifdef TRT_PLUGIN_FP16_AVALIABLE + VLOG(1) << "TRT Plugin DataType selected. FusedTokenPrune-->fp16"; + + half max = 65504.0; + + return enqueueImpl(input_desc, + output_desc, + inputs, + outputs, + workspace, + stream, + device_id, + max); + +#else + PADDLE_THROW(platform::errors::Fatal( + "The Ernie(Bert) TensorRT Plugin should be " + "complied with CUDA version >= 10.0 when running with fp16. " + "Please recomplie it or try to use fp32 by set " + "config.SetTRTDynamicShapeInfo(min_input_shape, " + "max_input_shape, opt_input_shape, true")); +#endif + } else { + PADDLE_THROW( + platform::errors::Fatal("The FusedTokenPrune TRT Plugin's input type " + "should be float or half.")); + } +} + +#endif +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h new file mode 100644 index 0000000000000..fcd91522ca39c --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h @@ -0,0 +1,159 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +#if IS_TRT_VERSION_GE(6000) + +class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT { + public: + explicit FusedTokenPrunePluginDynamic(bool with_fp16, + bool keep_first_token, + bool keep_order) + : keep_first_token_(keep_first_token), keep_order_(keep_order) { + with_fp16_ = with_fp16; + } + FusedTokenPrunePluginDynamic(void const* serial_data, size_t serial_length) { + DeserializeValue(&serial_data, &serial_length, &with_fp16_); + DeserializeValue(&serial_data, &serial_length, &keep_first_token_); + DeserializeValue(&serial_data, &serial_length, &keep_order_); + } + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override { + return new FusedTokenPrunePluginDynamic( + with_fp16_, keep_first_token_, keep_order_); + } + + const char* getPluginType() const TRT_NOEXCEPT override { + return "fused_token_prune_plugin_dynamic"; + } + int getNbOutputs() const TRT_NOEXCEPT override { return 2; } + int initialize() TRT_NOEXCEPT override { return 0; } + + size_t getSerializationSize() const TRT_NOEXCEPT override { + return SerializedSize(with_fp16_) + SerializedSize(keep_first_token_) + + SerializedSize(keep_order_); + } + void serialize(void* buffer) const TRT_NOEXCEPT override { + SerializeValue(&buffer, with_fp16_); + SerializeValue(&buffer, keep_first_token_); + SerializeValue(&buffer, keep_order_); + } + + nvinfer1::DimsExprs getOutputDimensions( + int output_index, + const nvinfer1::DimsExprs* inputs, + int nb_inputs, + nvinfer1::IExprBuilder& expr_builder) // NOLINT + TRT_NOEXCEPT override; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* in_out, + int nb_inputs, + int nb_outputs) TRT_NOEXCEPT override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nb_inputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nb_outputs) TRT_NOEXCEPT override {} + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nb_inputs, + const nvinfer1::PluginTensorDesc* outputs, + int nb_outputs) const TRT_NOEXCEPT override; + + int enqueue(const nvinfer1::PluginTensorDesc* input_desc, + const nvinfer1::PluginTensorDesc* output_desc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* input_types, + int nb_inputs) const + TRT_NOEXCEPT override; + + void destroy() TRT_NOEXCEPT override { delete this; } + + private: + template + int enqueueImpl(const nvinfer1::PluginTensorDesc* input_desc, + const nvinfer1::PluginTensorDesc* output_desc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream, + int device_id, + T max_value); + bool keep_first_token_; + bool keep_order_; +}; + +class FusedTokenPrunePluginDynamicCreator : public nvinfer1::IPluginCreator { + public: + FusedTokenPrunePluginDynamicCreator() {} + const char* getPluginName() const TRT_NOEXCEPT override { + return "fused_token_prune_plugin_dynamic"; + } + + const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; } + + const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override { + return &field_collection_; + } + + nvinfer1::IPluginV2* createPlugin(const char* name, + const nvinfer1::PluginFieldCollection* fc) + 
TRT_NOEXCEPT override { + return nullptr; + } + + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serial_data, + size_t serial_length) + TRT_NOEXCEPT override { + auto plugin = new FusedTokenPrunePluginDynamic(serial_data, serial_length); + return plugin; + } + + void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override { + plugin_namespace_ = lib_namespace; + } + + const char* getPluginNamespace() const TRT_NOEXCEPT override { + return plugin_namespace_.c_str(); + } + + private: + std::string plugin_namespace_; + std::string plugin_name_; + nvinfer1::PluginFieldCollection field_collection_; + std::vector plugin_attributes_; +}; +REGISTER_TRT_PLUGIN_V2(FusedTokenPrunePluginDynamicCreator); + +#endif + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/test_fused_token_prune_plugin.cc b/paddle/fluid/inference/tensorrt/plugin/test_fused_token_prune_plugin.cc new file mode 100644 index 0000000000000..131ce46d89a66 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/test_fused_token_prune_plugin.cc @@ -0,0 +1,48 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +TEST(fused_token_prune_op_plugin, test_plugin) { + FusedTokenPrunePluginDynamic plugin( + true, /*keep_first_token*/ false, /*keep_order*/ true); + plugin.configurePlugin(nullptr, 4, nullptr, 2); + plugin.initialize(); + plugin.getPluginType(); + plugin.getNbOutputs(); + auto clone_plugin = plugin.clone(); + clone_plugin->destroy(); + size_t buf_size = plugin.getSerializationSize(); + std::vector buf(buf_size); + plugin.serialize(buf.data()); +} + +TEST(fused_token_prune_op_plugin, test_plugin_creater) { + FusedTokenPrunePluginDynamicCreator creator; + creator.getFieldNames(); + creator.createPlugin("test", nullptr); + creator.setPluginNamespace("test"); +} + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc b/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc index eae1e2baf9ad1..8d95bbea5b89f 100644 --- a/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc +++ b/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc @@ -22,6 +22,7 @@ limitations under the License. 
*/ #if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000) #include "paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h" #endif +#include "paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/phi/common/float16.h" @@ -195,6 +196,197 @@ TEST_F(TensorRTDynamicEngineTest, test_spmm) { return; } +class TensorRTDynamicTestFusedTokenPrune : public ::testing::Test { + protected: + void SetUp() override { + ctx_ = new platform::CUDADeviceContext(platform::CUDAPlace(0)); + ctx_->SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(platform::CUDAPlace(0), ctx_->stream()) + .get()); + ctx_->SetHostAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + ctx_->SetZeroAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetZeroAllocator(platform::CUDAPlace(0)) + .get()); + ctx_->SetPinnedAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CUDAPinnedPlace()) + .get()); + ctx_->PartialInitWithAllocator(); + + std::map> min_input_shape = { + {"attn", {4, 1, 4, 4}}, + {"x", {4, 4, 1}}, + {"mask", {4, 1, 4, 4}}, + {"new_mask", {4, 1, 2, 2}}}; + std::map> max_input_shape = { + {"attn", {4, 1, 4, 4}}, + {"x", {4, 4, 1}}, + {"mask", {4, 1, 4, 4}}, + {"new_mask", {4, 1, 2, 2}}}; + std::map> optim_input_shape = { + {"attn", {4, 1, 4, 4}}, + {"x", {4, 4, 1}}, + {"mask", {4, 1, 4, 4}}, + {"new_mask", {4, 1, 2, 2}}}; + + engine_ = new TensorRTEngine(16, + 1 << 10, + AnalysisConfig::Precision::kHalf, + nullptr, + 0, + min_input_shape, + max_input_shape, + optim_input_shape, + false, + phi::DataType::FLOAT32, + NaiveLogger::Global()); + engine_->InitNetwork(); + } + + void TearDown() override { + if (engine_) { + delete engine_; + engine_ = nullptr; + } + } + + void PrepareInputOutput(const std::vector> inputs, + std::vector> output_shapes) { + LOG(INFO) << "PrepareInputOutput"; + int num_inputs = inputs.size(); + int num_outputs = output_shapes.size(); + inputs_.resize(num_inputs); + outputs_.resize(num_outputs); + for (int i = 0; i < num_inputs; ++i) { + paddle::framework::TensorFromVector(inputs[i], *ctx_, &inputs_[i]); + } + for (int i = 0; i < num_outputs; ++i) { + outputs_[i].Resize(phi::make_ddim(output_shapes[i])); + } + } + + void GetOutput(std::vector &slimmed_x, // NOLINT + std::vector &cls_inds) { // NOLINT + paddle::framework::TensorToVector(outputs_[0], *ctx_, &slimmed_x); + paddle::framework::TensorToVector(outputs_[1], *ctx_, &cls_inds); + } + + protected: + std::vector inputs_; + std::vector outputs_; + TensorRTEngine *engine_; + platform::CUDADeviceContext *ctx_; +}; + +TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) { +#if IS_TRT_VERSION_GE(8000) + auto *attn = engine_->DeclareInput( + "attn", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 4, 4}); + auto *x = engine_->DeclareInput( + "x", nvinfer1::DataType::kHALF, nvinfer1::Dims3{-1, 4, 1}); + auto *mask = engine_->DeclareInput( + "mask", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 4, 4}); + auto *new_mask = engine_->DeclareInput( + "new_mask", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 2, 2}); + plugin::FusedTokenPrunePluginDynamic *plugin = + new plugin::FusedTokenPrunePluginDynamic( + true, /*keep_first_token*/ false, /*keep_order*/ true); + std::vector itensors = {attn, x, mask, new_mask}; + auto *layer = engine_->AddDynamicPlugin(itensors.data(), 4, plugin); + 
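+  // Note: the plugin consumes the four declared inputs (attn, x, mask,
+  // new_mask) and, per FusedTokenPrunePluginDynamic::getNbOutputs(), produces
+  // two outputs (the slimmed x and the kept token indices), so the built
+  // engine is expected to expose six bindings, as asserted below.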
PADDLE_ENFORCE_NOT_NULL(layer, + platform::errors::InvalidArgument( + "TRT fused_token_prune layer building failed.")); + std::vector output_tensor_names{"out_slimmed_x", "out_cls_inds"}; + for (size_t i = 0; i < 2; i++) { + layer->getOutput(i)->setName(output_tensor_names[i].c_str()); + engine_->DeclareOutput(layer, i, output_tensor_names[i]); + } + engine_->FreezeNetwork(); + + ASSERT_EQ(engine_->engine()->getNbBindings(), 6); + LOG(INFO) << "create input"; + std::vector attn_v(64); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + for (int k = 0; k < 4; ++k) { + attn_v[i * 16 + j * 4 + k] = k; + } + } + } + std::vector x_v(16); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + x_v[i * 4 + j] = 1; + } + } + std::vector mask_v(64); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + for (int k = 0; k < 4; ++k) { + mask_v[i * 16 + j * 4 + k] = 1; + } + } + } + std::vector new_mask_v(16); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 2; ++j) { + for (int k = 0; k < 2; ++k) { + new_mask_v[i * 4 + j * 2 + k] = 1; + } + } + } + + LOG(INFO) << "create output"; + std::vector out_slimmed_x_shape{4, 2, 1}; + std::vector out_cls_ins_shape{4, 2}; + + PrepareInputOutput({attn_v, x_v, mask_v, new_mask_v}, + {out_slimmed_x_shape, out_cls_ins_shape}); + + auto *attn_gpu_data = inputs_[0].mutable_data(ctx_->GetPlace()); + auto *x_gpu_data = inputs_[1].mutable_data(ctx_->GetPlace()); + auto *mask_gpu_data = inputs_[2].mutable_data(ctx_->GetPlace()); + auto *new_mask_gpu_data = inputs_[3].mutable_data(ctx_->GetPlace()); + + auto *slimmed_x_gpu_data = outputs_[0].mutable_data(ctx_->GetPlace()); + auto *cls_inds_gpu_data = outputs_[1].mutable_data(ctx_->GetPlace()); + + LOG(INFO) << "create buffers"; + + std::vector buffers(6); + buffers[0] = reinterpret_cast(attn_gpu_data); + buffers[1] = reinterpret_cast(x_gpu_data); + buffers[2] = reinterpret_cast(mask_gpu_data); + buffers[3] = reinterpret_cast(new_mask_gpu_data); + buffers[4] = reinterpret_cast(slimmed_x_gpu_data); + buffers[5] = reinterpret_cast(cls_inds_gpu_data); + + LOG(INFO) << "Execute"; + + engine_->Execute(4, &buffers, ctx_->stream()); + + std::vector slimmed_x_v; + std::vector cls_inds_v; + + LOG(INFO) << "GetOutput"; + GetOutput(slimmed_x_v, cls_inds_v); + + ASSERT_EQ(cls_inds_v[0], 2); + ASSERT_EQ(cls_inds_v[1], 3); + ASSERT_EQ(cls_inds_v[2], 2); + ASSERT_EQ(cls_inds_v[3], 3); + ASSERT_EQ(cls_inds_v[4], 2); + ASSERT_EQ(cls_inds_v[5], 3); + ASSERT_EQ(cls_inds_v[6], 2); + ASSERT_EQ(cls_inds_v[7], 3); + LOG(INFO) << "finish"; +#endif +} + } // namespace tensorrt } // namespace inference } // namespace paddle diff --git a/paddle/fluid/operators/fused_token_prune_op.cc b/paddle/fluid/operators/fused_token_prune_op.cc new file mode 100644 index 0000000000000..50ca45967b7bd --- /dev/null +++ b/paddle/fluid/operators/fused_token_prune_op.cc @@ -0,0 +1,187 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class FusedTokenPruneOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Attn", + "(Tensor)" + "The input of fused_token_prune op, whose shape should be [bsz, " + "num_head, max_seq_len, max_seq_len] and dtype should be " + "float32/float64," + "Attn is attention scores of input sequences which will be used " + "to sort another input tensor: X's indices so that " + "some elements of X with lower attention score will not be " + "considered after this op."); + + AddInput("X", + "(Tensor)" + "The input of fused_token_prune op, whose shape should be [bsz, " + "max_seq_len, c] and dtype should be float32/float64."); + + AddInput( + "Mask", + "(Tensor)" + "The input of fused_token_prune op, whose shape should be [bsz, " + "num_head, " + "max_seq_len, max_seq_len] and dtype should be float32/float64." + "Mask is corresponding to Attn's elemnts one by one. Elements of Attn " + "will be set to zero if their corresponding mask is smaller than 0." + "This process happens before sorting X by attn."); + + AddInput("NewMask", + "(Tensor)" + "The input of fused_token_prune op, whose shape should be [bsz, " + "num_head, slimmed_seq_len, slimmed_seq_len]." + "NewMask is just used to get slimmed_seq_len, so the value of " + "this input is not important in this op."); + + AddOutput("SlimmedX", + "(Tensor)" + "The output of fused_token_prune op, whose shape should be [bsz, " + "slimmed_seq_len, C]." + "The tokens of X will be sorted by Attn firstly and then the " + "last (max_seq_len - slimmed_seq_len)" + "tokens will be deleted. SlimmedX is the remainning part of X. " + ""); + + AddOutput( + "CLSInds", + "(Tensor)" + "The output of fused_token_prune op, whose shape should be [bsz, " + "slimmed_seq_len] and dtype is int64. CLSInds contains token indices " + " of each batch after sorting and pruning. "); + + AddAttr("keep_first_token", + "If keep_first_token is True, the element located in " + "CLSInds[:, 1] must be 0.") + .SetDefault(true); + + AddAttr("keep_order", + "If keep_order is True, the relative order of SlimmedX and " + "CLSInds remains unchanged") + .SetDefault(false); + + AddComment(R"DOC( + fused_token_prune op is used to fuse multiple ops to perform token pruning. + In this op: + 1. Elements of Attn will be set to zero if their corresponding mask is smaller than 0. + 2. The second dimension of X will be sorted by Attn. + 3. The last (max_seq_len - slimmed_seq_len) lines of X will be pruned. + 4. The remainning part of sorted X will output. 
+ )DOC"); + } +}; + +class FusedTokenPruneOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("Attn"), "Input", "Attn", "FusedTokenPrune"); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedTokenPrune"); + OP_INOUT_CHECK(ctx->HasInput("Mask"), "Input", "Mask", "FusedTokenPrune"); + OP_INOUT_CHECK( + ctx->HasInput("NewMask"), "Input", "NewMask", "FusedTokenPrune"); + OP_INOUT_CHECK( + ctx->HasOutput("SlimmedX"), "Output", "SlimmedX", "FusedTokenPrune"); + OP_INOUT_CHECK( + ctx->HasOutput("CLSInds"), "Output", "CLSInds", "FusedTokenPrune"); + + auto mask_dim = ctx->GetInputDim("Mask"); + auto attn_dim = ctx->GetInputDim("Attn"); + auto x_dim = ctx->GetInputDim("X"); + auto new_mask_dim = ctx->GetInputDim("NewMask"); + + // check input dims number + PADDLE_ENFORCE_EQ(mask_dim.size(), + 4, + platform::errors::InvalidArgument( + "The input mask must be 4-dimention")); + PADDLE_ENFORCE_EQ(attn_dim.size(), + 4, + platform::errors::InvalidArgument( + "The input attn must be 4-dimention")); + PADDLE_ENFORCE_EQ( + x_dim.size(), + 3, + platform::errors::InvalidArgument("The input x must be 4-dimention")); + PADDLE_ENFORCE_EQ(new_mask_dim.size(), + 4, + platform::errors::InvalidArgument( + "The input attn must be 4-dimention")); + + // check input dims relations + PADDLE_ENFORCE_EQ(mask_dim[0], + attn_dim[0], + platform::errors::InvalidArgument( + "The first dim of mask and attn should be the same" + "which is batch size")); + PADDLE_ENFORCE_EQ(mask_dim[1], + attn_dim[1], + platform::errors::InvalidArgument( + "The second dim of mask and attn should be the same" + "which is nb_head")); + PADDLE_ENFORCE_EQ(mask_dim[0], + x_dim[0], + platform::errors::InvalidArgument( + "The first dim of mask and x should be the same" + "which is batch size")); + PADDLE_ENFORCE_EQ( + mask_dim[2], + mask_dim[3], + platform::errors::InvalidArgument( + "The third dim and the fourth dim of mask should be the same" + "which is max seq len")); + PADDLE_ENFORCE_EQ( + attn_dim[2], + attn_dim[3], + platform::errors::InvalidArgument( + "The third dim and the fourth dim of mask should be the same" + "which is max seq len")); + PADDLE_ENFORCE_EQ(attn_dim[2], + mask_dim[2], + platform::errors::InvalidArgument( + "The third dim of mask and attn should be the same" + "which is max seq len")); + PADDLE_ENFORCE_EQ(attn_dim[2], + x_dim[1], + platform::errors::InvalidArgument( + "The third dim of mask and the second dim of attn" + "should be the same which is max seq len")); + + auto bsz = mask_dim[0]; + auto c = x_dim[2]; + auto slim_seq_len = new_mask_dim[2]; + + ctx->SetOutputDim("SlimmedX", {bsz, slim_seq_len, c}); + ctx->SetOutputDim("CLSInds", {bsz, slim_seq_len}); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR( + fused_token_prune, + ops::FusedTokenPruneOp, + ops::FusedTokenPruneOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); diff --git a/paddle/fluid/operators/fused_token_prune_op.cu b/paddle/fluid/operators/fused_token_prune_op.cu new file mode 100644 index 0000000000000..90044f30d8a6e --- /dev/null +++ b/paddle/fluid/operators/fused_token_prune_op.cu @@ -0,0 +1,287 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#ifdef __NVCC__ +#include +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif + +#include "paddle/phi/backends/gpu/gpu_launch_config.h" + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/fused_token_prune_op.cu.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +template +struct AttnMaskFunctor { + inline HOSTDEVICE T operator()(const T a, const T b) const { + return b >= 0 ? a : 0; + } +}; + +__global__ void FillIndex(int64_t* indices, int num_raws, int num_cols) { + int num_threads = num_raws * num_cols; + int tid = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + + for (; tid < num_threads; tid += stride) { + int col = tid % num_cols; + indices[tid] = (int64_t)col; + } +} + +template +__global__ void TakeAlongAxis(const T* src, + T* dst, + int64_t* indices, + int num_raws, + int src_num_cols, + int dst_num_cols, + int num_elements) { + int num_threads = num_raws * dst_num_cols; + int tid = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + + for (; tid < num_threads; tid += stride) { + int raw = tid / dst_num_cols; + int col = tid % dst_num_cols; + for (int i = 0; i < num_elements; ++i) { + dst[tid * num_elements + i] = + *(src + (raw * src_num_cols + indices[tid]) * num_elements + i); + } + } +} + +template +__global__ void MaximumFirst(T* mat, int num_raws, int num_cols, T max_value) { + int num_threads = num_raws; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + + for (; tid < num_threads; tid += stride) { + mat[tid * num_cols] = max_value; + } +} + +template +class FusedTokenPruneOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto& dev_ctx = context.cuda_device_context(); + // Inouts + const Tensor* attn = context.Input("Attn"); + const Tensor* x = context.Input("X"); + const Tensor* mask = context.Input("Mask"); + const Tensor* new_mask = context.Input("NewMask"); + + // Input dims + auto attn_dims = attn->dims(); + auto x_dims = x->dims(); + auto new_mask_dims = new_mask->dims(); + + auto bsz = attn_dims[0]; + auto num_heads = attn_dims[1]; + auto max_seq_len = attn_dims[2]; + auto c = x_dims[2]; + int slimmed_x_len = new_mask_dims[2]; + + // Attrs + const bool keep_first_token = context.Attr("keep_first_token"); + const bool keep_order = context.Attr("keep_order"); + + // Outputs + Tensor* out_slimmed_x = context.Output("SlimmedX"); + Tensor* slimmed_indices = context.Output("CLSInds"); + auto* out_slimmed_x_data = + out_slimmed_x->mutable_data(context.GetPlace()); + auto* slimmed_indices_data = + slimmed_indices->mutable_data(context.GetPlace()); + + // Intermediate variable + Tensor attn_tmp; + auto* attn_tmp_data = + attn_tmp.mutable_data(attn_dims, context.GetPlace()); + Tensor attn_accu; + auto* attn_accu_data = + 
attn_accu.mutable_data({bsz, max_seq_len}, context.GetPlace()); + Tensor attn_accu_indices; + auto* attn_accu_indices_data = attn_accu_indices.mutable_data( + {bsz, max_seq_len}, context.GetPlace()); + Tensor sort_attn_accu; + auto* sort_attn_accu_data = + sort_attn_accu.mutable_data({bsz, max_seq_len}, context.GetPlace()); + Tensor sort_attn_accu_indices; + auto* sort_attn_accu_indices_data = + sort_attn_accu_indices.mutable_data({bsz, max_seq_len}, + context.GetPlace()); + Tensor temp_storage; + + // 1. Filter attn by mask + std::vector ins; + std::vector outs; + ins.emplace_back(attn); + ins.emplace_back(mask); + outs.emplace_back(&attn_tmp); + LaunchElementwiseCudaKernel( + dev_ctx, ins, &outs, -1, AttnMaskFunctor()); + + // 2. Reduce sum + const std::vector reduce_dims{1, 2}; + phi::Reduce(dev_ctx, + attn_tmp, + false, + reduce_dims, + false, + attn_accu.dtype(), + &attn_accu); + // 3. Prepare token indices + phi::backends::gpu::GpuLaunchConfig config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, bsz * max_seq_len); + FillIndex<<>>(attn_accu_indices_data, bsz, max_seq_len); + + // 4. Sort token indices by attn + if (keep_first_token) { + T max = std::numeric_limits::max(); + config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, bsz); + MaximumFirst + <<>>(attn_accu_data, bsz, max_seq_len, max); + } + size_t temp_storage_bytes = -1; + int num_items = bsz * max_seq_len; + int num_segments = bsz; + + cub::CountingInputIterator counting_iter(0); + cub::TransformInputIterator> + segment_offsets_t(counting_iter, SegmentOffsetIter(max_seq_len)); + // Determine temporary device storage requirements + PADDLE_ENFORCE_GPU_SUCCESS( + cub::DeviceSegmentedRadixSort::SortPairsDescending( + nullptr, + temp_storage_bytes, + attn_accu_data, + sort_attn_accu_data, + attn_accu_indices_data, + sort_attn_accu_indices_data, + num_items, + num_segments, + segment_offsets_t, + segment_offsets_t + 1, + 0, + sizeof(T) * 8, + dev_ctx.stream())); + // Allocate temporary storage + int64_t temp_size = temp_storage_bytes; + auto* temp_storage_data = + temp_storage.mutable_data({temp_size}, context.GetPlace()); + // Run sorting operation + PADDLE_ENFORCE_GPU_SUCCESS( + cub::DeviceSegmentedRadixSort::SortPairsDescending( + temp_storage_data, + temp_storage_bytes, + attn_accu_data, + sort_attn_accu_data, + attn_accu_indices_data, + sort_attn_accu_indices_data, + num_items, + num_segments, + segment_offsets_t, + segment_offsets_t + 1, + 0, + sizeof(T) * 8, + dev_ctx.stream())); + // 5. Slice + auto slimmed_indices_tmp = + phi::funcs::Slice(dev_ctx, + sort_attn_accu_indices, + {1} /*axes*/, + {0} /*starts*/, + {slimmed_x_len} /*ends*/); + if (keep_order) { + // 6. 
reorder + num_items = bsz * slimmed_x_len; + temp_storage_bytes = -1; + cub::TransformInputIterator> + segment_offsets_t2(counting_iter, SegmentOffsetIter(slimmed_x_len)); + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, + temp_storage_bytes, + static_cast(slimmed_indices_tmp.data()), + static_cast(slimmed_indices->data()), + num_items, + num_segments, + segment_offsets_t2, + segment_offsets_t2 + 1, + 0, + sizeof(int64_t) * 8, + dev_ctx.stream())); + temp_size = temp_storage_bytes; + temp_storage.Resize({temp_size}); + temp_storage_data = + temp_storage.mutable_data(context.GetPlace()); + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedRadixSort::SortKeys( + temp_storage_data, + temp_storage_bytes, + static_cast(slimmed_indices_tmp.data()), + static_cast(slimmed_indices->data()), + num_items, + num_segments, + segment_offsets_t2, + segment_offsets_t2 + 1, + 0, + sizeof(int64_t) * 8, + dev_ctx.stream())); + } else { + framework::TensorCopy( + slimmed_indices_tmp, context.GetPlace(), slimmed_indices); + } + // 7. Get slimmed X by indices + config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, bsz * slimmed_x_len); + TakeAlongAxis<<>>(x->data(), + out_slimmed_x_data, + slimmed_indices->data(), + bsz, + max_seq_len, + slimmed_x_len, + c); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(fused_token_prune, + ops::FusedTokenPruneOpCUDAKernel, + ops::FusedTokenPruneOpCUDAKernel); diff --git a/paddle/fluid/operators/fused_token_prune_op.cu.h b/paddle/fluid/operators/fused_token_prune_op.cu.h new file mode 100644 index 0000000000000..e1e73a5e3d9e2 --- /dev/null +++ b/paddle/fluid/operators/fused_token_prune_op.cu.h @@ -0,0 +1,50 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" +#include "paddle/phi/kernels/funcs/slice.h" +#include "paddle/phi/kernels/gpu/reduce.h" +#include "paddle/phi/kernels/primitive/functor_primitives.h" + +namespace paddle { +namespace operators { + +HOSTDEVICE inline int CeilDivide(int n, int m) { return (n + m - 1) / m; } + +inline int ComputeBlockSize(int col) { + if (col > 512) + return 1024; + else if (col > 256 && col <= 512) + return 512; + else if (col > 128 && col <= 256) + return 256; + else if (col > 64 && col <= 128) + return 128; + else + return 64; +} + +// Iter for move to next row +struct SegmentOffsetIter { + EIGEN_DEVICE_FUNC + explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {} + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const { + return idx * num_cols_; + } + + int num_cols_; +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt index 6f79a248cf38b..7a67bf95d15a8 100755 --- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt @@ -28,6 +28,13 @@ if(NOT WITH_DISTRIBUTE) list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_c_allreduce") endif() +if(WIN32) + list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES + "test_trt_convert_fused_token_prune") + list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_fused_token_prune") + list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_fused_token_prune") +endif() + # Only for cpu(mkl + openblas) set(TEST_INFERENCE_CPU_UT "test_mul_lstm_fuse_pass" "test_mul_gru_fuse_pass") diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_fused_token_prune.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_fused_token_prune.py new file mode 100644 index 0000000000000..85c56506de5cf --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_fused_token_prune.py @@ -0,0 +1,129 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + + +class TrtConvertFusedTokenPruneTest(TrtLayerAutoScanTest): + + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_configs(self): + self.trt_param.workspace_size = 1073741824 + + def generate_attn_or_mask(attrs: List[Dict[str, Any]]): + return np.ones([4, 12, 64, 64]).astype(np.float32) + + def generate_x(attrs: List[Dict[str, Any]]): + return np.random.random([4, 64, 76]).astype(np.float32) + + def generate_new_mask(attrs: List[Dict[str, Any]]): + return np.random.random([4, 12, 32, 32]).astype(np.float32) + + for keep_first_token in [True, False]: + for keep_order in [True, False]: + dics = [{ + "keep_first_token": keep_first_token, + "keep_order": keep_order + }] + ops_config = [{ + "op_type": "fused_token_prune", + "op_inputs": { + "Attn": ["attn"], + "X": ["x"], + "Mask": ["mask"], + "NewMask": ["new_mask"] + }, + "op_outputs": { + "SlimmedX": ["slimmed_x"], + "CLSInds": ["cls_inds"] + }, + "op_attrs": dics[0] + }] + ops = self.generate_op_config(ops_config) + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "attn": + TensorConfig( + data_gen=partial(generate_attn_or_mask, dics)), + "x": + TensorConfig(data_gen=partial(generate_x, dics)), + "mask": + TensorConfig( + data_gen=partial(generate_attn_or_mask, dics)), + "new_mask": + TensorConfig(data_gen=partial(generate_new_mask, dics)) + }, + outputs=["slimmed_x", "cls_inds"]) + + yield program_config + + def sample_predictor_configs( + self, program_config) -> (paddle_infer.Config, List[int], float): + + def generate_dynamic_shape(attrs): + self.dynamic_shape.min_input_shape = { + "attn": [4, 12, 64, 64], + "x": [4, 64, 76], + "mask": [4, 12, 64, 64], + "new_mask": [4, 12, 32, 32] + } + self.dynamic_shape.max_input_shape = { + "attn": [4, 12, 64, 64], + "x": [4, 64, 76], + "mask": [4, 12, 64, 64], + "new_mask": [4, 12, 32, 32] + } + self.dynamic_shape.opt_input_shape = { + "attn": [4, 12, 64, 64], + "x": [4, 64, 76], + "mask": [4, 12, 64, 64], + "new_mask": [4, 12, 32, 32] + } + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + return 1, 6 + + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), (1e-5, 1e-5, 1e-5, 1e-5) + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), (1e-5, 1e-5, 1e-5, 1e-5) + + def test(self): + self.run_test() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fused_token_prune_op.py b/python/paddle/fluid/tests/unittests/test_fused_token_prune_op.py new file mode 100644 index 0000000000000..9425283f078c0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_token_prune_op.py @@ -0,0 +1,112 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import paddle +from op_test import OpTest +from paddle.framework import core + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFusedTokenPruneOp(OpTest): + + def setDtype(self): + self.dtype = np.float32 + + def setInouts(self): + attn = [[1, 2], [3, 4]] + attn = np.array(attn, dtype=self.dtype) + attn = np.expand_dims(attn, axis=0) + self.attn = np.expand_dims( + attn, axis=0) # [1,1,2,2] bsz = 1, nd_head=1, max_seq_len=2 + mask = [[1, 1], [-1, -1]] + mask = np.array(mask, dtype=self.dtype) + mask = np.expand_dims(mask, axis=0) + self.mask = np.expand_dims(mask, axis=0) # same as attn + x = [[1, 2, 3], [4, 5, 6]] + x = np.array(x, dtype=self.dtype) + self.x = np.expand_dims(x, + axis=0) # [1, 2, 3] bsz = 1, max_seq_len=2, c=3 + new_mask = [[1]] + new_mask = np.array(new_mask, dtype=self.dtype) + new_mask = np.expand_dims(new_mask, axis=0) + self.new_mask = np.expand_dims(new_mask, axis=0) #[1, 1, 1, 1] + + out_slimmedx_py = [[[1, 2, 3]]] + self.out_slimmedx_py = np.array(out_slimmedx_py, dtype=self.dtype) + + out_cls_inds_py = [[0]] + self.out_cls_inds_py = np.array(out_cls_inds_py, dtype='int64') + + def setUp(self): + self.op_type = 'fused_token_prune' + self.setDtype() + self.setInouts() + self.inputs = { + 'Attn': self.attn, + 'Mask': self.mask, + 'X': self.x, + 'NewMask': self.new_mask + } + + self.outputs = { + 'SlimmedX': self.out_slimmedx_py, + 'CLSInds': self.out_cls_inds_py + } + + def test_check_output(self): + self.check_output_with_place(core.CUDAPlace(0)) + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFusedTokenPruneOpFloat64(TestFusedTokenPruneOp): + + def setDtype(self): + self.dtype = np.float64 + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFusedTokenPruneOp2(TestFusedTokenPruneOp): + + def setInouts(self): + attn = [[[[1, 2, 3, 4], [4, 3, 2, 1], [5, 9, 5, 4], [9, 6, 5, 4]], + [[8, 5, 2, 0], [1, 0, 2, 3], [2, 2, 3, 2], [7, 4, 1, 8]]]] + self.attn = np.array( + attn, + dtype=self.dtype) # [1,2,4,4] bsz = 1, nd_head=2, max_seq_len=4 + mask = [[[[-1, -1, -1, 1], [-1, -1, 1, 1], [-1, -1, 1, 1], + [-1, -1, 1, 1]], + [[-1, -1, 1, 1], [-1, -1, 1, 1], [-1, -1, 1, 1], + [-1, -1, 1, 1]]]] + self.mask = np.array(mask, dtype=self.dtype) # same as attn + x = [[[1.1, 1.1, 1.1], [2.2, 2.2, 2.2], [3.3, 3.3, 3.3], + [4.4, 4.4, 4.4]]] + self.x = np.array( + x, dtype=self.dtype) # [1, 4, 3] bsz = 1, max_seq_len=4, c=3 + self.new_mask = np.random.rand(1, 2, 2, + 2).astype(self.dtype) #[1, 2, 2, 2] + + out_slimmedx_py = [[[1.1, 1.1, 1.1], [4.4, 4.4, 4.4]]] #[1, 2, 3] + self.out_slimmedx_py = np.array(out_slimmedx_py, dtype=self.dtype) + + out_cls_inds_py = [[0, 3]] + self.out_cls_inds_py = np.array(out_cls_inds_py, dtype='int64') + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/static_mode_white_list.py 
b/tools/static_mode_white_list.py index 95c5ecf713112..7e92b6b9b7afc 100755 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -233,6 +233,7 @@ 'test_fused_elemwise_activation_op', 'test_fused_emb_seq_pool_op', 'test_fused_embedding_fc_lstm_op', + 'test_fused_token_prune_op', 'test_fusion_gru_op', 'test_fusion_lstm_op', 'test_fusion_repeated_fc_relu_op',
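For reviewers who want the op's semantics at a glance, the following is a minimal NumPy sketch of what fused_token_prune computes, mirroring the numbered steps in the op comment and in the plugin's enqueueImpl. It is not part of the patch; the function name and the use of a stable argsort for tie-breaking are illustrative assumptions.

import numpy as np

def fused_token_prune_ref(attn, x, mask, new_mask,
                          keep_first_token=True, keep_order=False):
    """Reference sketch of the fused_token_prune semantics (not the CUDA kernel)."""
    slim_len = new_mask.shape[2]
    # 1. Zero out attention scores whose corresponding mask entry is negative.
    attn = np.where(mask >= 0, attn, 0)
    # 2. Accumulate scores per token: sum over heads and over rows.
    accu = attn.sum(axis=(1, 2))                    # [bsz, max_seq_len]
    if keep_first_token:
        # Pin the first (CLS) token so it always survives pruning.
        accu[:, 0] = np.finfo(accu.dtype).max
    # 3./4. Sort token indices by accumulated score, descending.
    order = np.argsort(-accu, axis=1, kind="stable")
    # 5. Keep only the first slim_len token indices.
    inds = order[:, :slim_len]
    if keep_order:
        # 6. Restore the original token order among the kept tokens.
        inds = np.sort(inds, axis=1)
    # 7. Gather the kept tokens from x.
    slimmed_x = np.take_along_axis(x, inds[..., None], axis=1)
    return slimmed_x, inds.astype(np.int64)

Run on the inputs of TestFusedTokenPruneOp2 above (default attrs: keep_first_token=True, keep_order=False), this reproduces CLSInds == [[0, 3]] and the corresponding rows of X, matching the expected outputs in that test.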