diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 76bf5948a2b98..7bb092d0e3c1c 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1192,6 +1192,7 @@ USE_TRT_CONVERTER(scale); USE_TRT_CONVERTER(stack); USE_TRT_CONVERTER(clip); USE_TRT_CONVERTER(gather); +USE_TRT_CONVERTER(roi_align); USE_TRT_CONVERTER(affine_channel); USE_TRT_CONVERTER(multiclass_nms); USE_TRT_CONVERTER(nearest_interp); diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 6af76bd11cd59..bc7b7355ea192 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -6,6 +6,7 @@ nv_library(tensorrt_converter shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc gather_op.cc + roi_align_op.cc affine_channel_op.cc multiclass_nms_op.cc nearest_interp_op.cc diff --git a/paddle/fluid/inference/tensorrt/convert/roi_align_op.cc b/paddle/fluid/inference/tensorrt/convert/roi_align_op.cc new file mode 100644 index 0000000000000..1329608aecd20 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/roi_align_op.cc @@ -0,0 +1,86 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/roi_align_op_plugin.h" + +namespace paddle { +namespace framework { +class Scope; + +namespace proto { +class OpDesc; +} // namespace proto +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace inference { +namespace tensorrt { + +/* + * Roi Align Op + */ +class RoiAlignOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(3) << "convert a fluid roi align op to tensorrt plugin"; + + framework::OpDesc op_desc(op, nullptr); + std::string input_name = op_desc.Input("X").front(); + std::string rois_name = op_desc.Input("ROIs").front(); + std::string output_name = op_desc.Output("Out").front(); + + const auto pooled_height = + BOOST_GET_CONST(int, op_desc.GetAttr("pooled_height")); + const auto pooled_width = + BOOST_GET_CONST(int, op_desc.GetAttr("pooled_width")); + const auto spatial_scale = + BOOST_GET_CONST(float, op_desc.GetAttr("spatial_scale")); + const auto sampling_ratio = + BOOST_GET_CONST(int, op_desc.GetAttr("sampling_ratio")); + + const auto input_tensor = engine_->GetITensor(input_name); + const auto rois_tensor = engine_->GetITensor(rois_name); + + const nvinfer1::DataType data_type_ = engine_->WithFp16() + ? nvinfer1::DataType::kHALF + : nvinfer1::DataType::kFLOAT; + + std::vector inputs{input_tensor, rois_tensor}; + nvinfer1::ILayer* layer = nullptr; + + PADDLE_ENFORCE_EQ( + engine_->with_dynamic_shape(), true, + platform::errors::InvalidArgument( + "TRT roi align plugin only accept the dynamic shape, because that " + "the roi_align will change the batch size.")); + + auto* roi_align_plugin = new plugin::RoiAlignPluginDynamic( + data_type_, pooled_height, pooled_width, spatial_scale, sampling_ratio); + auto roi_align_layer = engine_->network()->addPluginV2( + inputs.data(), inputs.size(), *roi_align_plugin); + layer = roi_align_layer; + + std::vector output_names{output_name}; + RreplenishLayerAndOutput(layer, "roi_align", output_names, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(roi_align, RoiAlignOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index eb429405d18ae..7c1b2e8001edb 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -111,6 +111,7 @@ struct SimpleOpTypeSetTeller : public Teller { "flatten2", "flatten", "gather", + "roi_align", "affine_channel", "multiclass_nms", "nearest_interp", @@ -263,6 +264,29 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, BOOST_GET_CONST(std::string, desc.GetAttr("interp_method")); if (interp_method != "nearest") return false; } + + if (op_type == "roi_align") { + if (!with_dynamic_shape) return false; + + std::vector attrs{"pooled_height", "pooled_width", + "spatial_scale", "sampling_ratio"}; + for (auto const attr : attrs) { + if (!desc.HasAttr(attr)) return false; + } + + const auto pooled_height = + BOOST_GET_CONST(int, desc.GetAttr("pooled_height")); + if (pooled_height <= 0) return false; + + const auto pooled_width = + BOOST_GET_CONST(int, desc.GetAttr("pooled_width")); + if (pooled_width <= 0) return false; + + const auto spatial_scale = + BOOST_GET_CONST(float, desc.GetAttr("spatial_scale")); + if (spatial_scale <= 0.f) return false; + } + if ((*teller)(op_type, desc, use_no_calib_int8)) return true; } return false; diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index 7ee16a598d2d0..4107f9ef67433 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -5,6 +5,7 @@ nv_library(tensorrt_plugin instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu hard_swish_op_plugin.cu stack_op_plugin.cu special_slice_plugin.cu + roi_align_op_plugin.cu DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor) nv_test(test_split_plugin SRCS test_split_plugin.cc DEPS diff --git a/paddle/fluid/inference/tensorrt/plugin/roi_align_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/roi_align_op_plugin.cu new file mode 100644 index 0000000000000..42c0df41a1b5e --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/roi_align_op_plugin.cu @@ -0,0 +1,380 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "paddle/fluid/inference/tensorrt/plugin/roi_align_op_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +template +__inline__ __device__ T BilinearInterpolate(const T* input_data, + const int height, const int width, + T y, T x) { + if (y < -1.f || y > height || x < -1.f || x > width) return 0; + y = y <= 0.f ? 0.f : y; + x = x <= 0.f ? 0.f : x; + int y_low = static_cast(y); + int x_low = static_cast(x); + int y_high; + int x_high; + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = static_cast(y_low); + } else { + y_high = y_low + 1; + } + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = static_cast(x_low); + } else { + x_high = x_low + 1; + } + T ly = y - y_low, lx = x - x_low; + T hy = 1.f - ly, hx = 1.f - lx; + T v1 = input_data[y_low * width + x_low]; + T v2 = input_data[y_low * width + x_high]; + T v3 = input_data[y_high * width + x_low]; + T v4 = input_data[y_high * width + x_high]; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__global__ void GPUROIAlignOpt(const int nthreads, + const T* __restrict__ input_data, + const T* __restrict__ input_rois, + const float spatial_scale, const int channels, + const int height, const int width, + const int pooled_height, const int pooled_width, + const int sampling_ratio, const int num_rois, + OutT* __restrict__ output_data) { + const int batch = blockIdx.x; + const int channel = blockIdx.y; + const T* offset_input_data = + input_data + (batch * channels + channel) * height * width; + extern __shared__ T s_input_data[]; + if (USE_SMEM) { + for (int idx = threadIdx.x; idx < height * width; idx += blockDim.x) { + s_input_data[idx] = offset_input_data[idx]; + } + __syncthreads(); + } + for (int idx = threadIdx.x; idx < num_rois * pooled_height * pooled_width; + idx += blockDim.x) { + const int pw = idx % pooled_width; + const int ph = (idx / pooled_width) % pooled_height; + const int roi_idx = (idx / pooled_width / pooled_height) % num_rois; + const int n = batch * num_rois + roi_idx; + const float4 rois_offset = reinterpret_cast(input_rois)[n]; + const T roi_xmin = rois_offset.x * spatial_scale; + const T roi_ymin = rois_offset.y * spatial_scale; + const T roi_xmax = rois_offset.z * spatial_scale; + const T roi_ymax = rois_offset.w * spatial_scale; + const T roi_width = max(roi_xmax - roi_xmin, static_cast(1.f)); + const T roi_height = max(roi_ymax - roi_ymin, static_cast(1.f)); + const T bin_size_h = roi_height / static_cast(pooled_height); + const T bin_size_w = roi_width / static_cast(pooled_width); + const int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); + const int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + const T count = roi_bin_grid_h * roi_bin_grid_w; + + T output_val = 0.f; + for (int iy = 0; iy < roi_bin_grid_h; ++iy) { + const T y = roi_ymin + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ++ix) { + const T x = roi_xmin + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + if (USE_SMEM) { + T val = BilinearInterpolate(s_input_data, height, width, y, x); + output_val += val; + } else { + T val = + BilinearInterpolate(offset_input_data, height, width, y, x); + output_val += val; + } + } + } + output_val /= count; + const int out_offset = + batch * num_rois * channels * pooled_height * pooled_width + + roi_idx * channels * pooled_height * pooled_width + + channel * pooled_height * pooled_width + ph * pooled_width + pw; + output_data[out_offset] = static_cast(output_val); + } +} + +#if IS_TRT_VERSION_GE(6000) +RoiAlignPluginDynamic::RoiAlignPluginDynamic(const nvinfer1::DataType data_type, + const int pooled_height, + const int pooled_width, + float spatial_scale, + int sampling_ratio) + : data_type_(data_type), + pooled_height_(pooled_height), + pooled_width_(pooled_width), + spatial_scale_(spatial_scale), + sampling_ratio_(sampling_ratio) { + bool data_type_is_valid = data_type_ == nvinfer1::DataType::kFLOAT || + data_type_ == nvinfer1::DataType::kHALF; + PADDLE_ENFORCE_EQ(data_type_is_valid, true, + platform::errors::InvalidArgument( + "TRT RoiAlign plugin only accepts kFLOAT(%d) or " + "kHALF(%d) data type, but the received data type = %d", + static_cast(nvinfer1::DataType::kFLOAT), + static_cast(nvinfer1::DataType::kHALF), + static_cast(data_type_))); + + PADDLE_ENFORCE_GT(pooled_height_, 0, + platform::errors::InvalidArgument( + "TRT RoiAlign plugin only accepts pooled_height " + "greater than %d, but the received pooled_height = %d", + 0, pooled_height_)); + + PADDLE_ENFORCE_GT(pooled_width_, 0, + platform::errors::InvalidArgument( + "TRT RoiAlign plugin only accepts pooled_width greater " + "than %d, but the received pooled_width = %d", + 0, pooled_height_)); + + PADDLE_ENFORCE_GT(spatial_scale_, 0.f, + platform::errors::InvalidArgument( + "TRT RoiAlign plugin only accepts spatial_scale " + "greater than %f, but the received spatial_scale = %f", + 0, spatial_scale_)); + + int smem_per_block = -1; + int device = -1; + cudaGetDevice(&device); + + PADDLE_ENFORCE_GE( + device, 0, + platform::errors::InvalidArgument( + "The cuda device ID should be greater than %d, but device ID is %d", + 0, device)); + + cudaDeviceGetAttribute(&smem_per_block, cudaDevAttrMaxSharedMemoryPerBlock, + device); + smem_per_block_ = smem_per_block; +} + +RoiAlignPluginDynamic::RoiAlignPluginDynamic(void const* data, size_t length) { + DeserializeValue(&data, &length, &data_type_); + DeserializeValue(&data, &length, &pooled_height_); + DeserializeValue(&data, &length, &pooled_width_); + DeserializeValue(&data, &length, &spatial_scale_); + DeserializeValue(&data, &length, &sampling_ratio_); + int smem_per_block = -1; + int device = -1; + cudaGetDevice(&device); + PADDLE_ENFORCE_GE( + device, 0, + platform::errors::InvalidArgument( + "The cuda device ID should be greater than %d, but device ID is %d", + 0, device)); + cudaDeviceGetAttribute(&smem_per_block, cudaDevAttrMaxSharedMemoryPerBlock, + device); + smem_per_block_ = smem_per_block; +} + +nvinfer1::IPluginV2DynamicExt* RoiAlignPluginDynamic::clone() const { + auto* plugin = + new RoiAlignPluginDynamic(data_type_, pooled_height_, pooled_width_, + spatial_scale_, sampling_ratio_); + plugin->setPluginNamespace(namespace_.c_str()); + return plugin; +} + +nvinfer1::DimsExprs RoiAlignPluginDynamic::getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) { + nvinfer1::DimsExprs ret{}; + ret.nbDims = 4; + ret.d[0] = inputs[1].d[0]; // roi + ret.d[1] = inputs[0].d[1]; // X + ret.d[2] = exprBuilder.constant(pooled_height_); + ret.d[3] = exprBuilder.constant(pooled_width_); + return ret; +} + +bool RoiAlignPluginDynamic::supportsFormatCombination( + int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, + int nbOutputs) { + if (inOut[pos].format != nvinfer1::TensorFormat::kLINEAR) { + return false; + } + if (pos < 2) { // input + return inOut[pos].type == nvinfer1::DataType::kFLOAT; + } + return inOut[pos].type == data_type_; +} + +void RoiAlignPluginDynamic::configurePlugin( + const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) {} + +size_t RoiAlignPluginDynamic::getWorkspaceSize( + const nvinfer1::PluginTensorDesc* inputs, int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const { + return 0; +} + +template +int RoiAlignPluginDynamic::enqueue_impl( + const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, + void* const* outputs, void* workspace, cudaStream_t stream) { + auto in_dims = inputDesc[0].dims; + auto rois_dims = inputDesc[1].dims; + auto out_dims = outputDesc[0].dims; + + int rois_num = rois_dims.d[0]; + if (rois_num == 0) return cudaGetLastError() != cudaSuccess; + + int batch = in_dims.d[0]; + int channels = in_dims.d[1]; + int height = in_dims.d[2]; + int width = in_dims.d[3]; + + int output_size = + out_dims.d[0] * out_dims.d[1] * out_dims.d[2] * out_dims.d[3]; + + const dim3 blocks(batch, channels); + const int threads = 512; + + if (smem_per_block_ < width * height * sizeof(T)) { + GPUROIAlignOpt<<>>( + output_size, static_cast(inputs[0]), + static_cast(inputs[1]), spatial_scale_, channels, height, + width, pooled_height_, pooled_width_, sampling_ratio_, rois_num / batch, + static_cast(outputs[0])); + } else { + GPUROIAlignOpt< + T, OutT, true><<>>( + output_size, static_cast(inputs[0]), + static_cast(inputs[1]), spatial_scale_, channels, height, + width, pooled_height_, pooled_width_, sampling_ratio_, rois_num / batch, + static_cast(outputs[0])); + } + + return cudaGetLastError() != cudaSuccess; +} + +int RoiAlignPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, void* workspace, + cudaStream_t stream) { + PADDLE_ENFORCE_EQ(outputDesc[0].type, data_type_, + platform::errors::InvalidArgument( + "TRT RoiAlignPluginDynamic expects outputDesc[0].type " + "equal to data_type_")); + + if (data_type_ == nvinfer1::DataType::kHALF) { + return enqueue_impl(inputDesc, outputDesc, inputs, outputs, + workspace, stream); + } + return enqueue_impl(inputDesc, outputDesc, inputs, outputs, + workspace, stream); +} + +nvinfer1::DataType RoiAlignPluginDynamic::getOutputDataType( + int index, const nvinfer1::DataType* inputTypes, int nbInputs) const { + return data_type_; +} + +const char* RoiAlignPluginDynamic::getPluginType() const { + return "roi_align_plugin_dynamic"; +} + +int RoiAlignPluginDynamic::getNbOutputs() const { return 1; } + +int RoiAlignPluginDynamic::initialize() { return 0; } + +void RoiAlignPluginDynamic::terminate() {} + +size_t RoiAlignPluginDynamic::getSerializationSize() const { + size_t serialize_size = 0; + serialize_size += SerializedSize(data_type_); + serialize_size += SerializedSize(pooled_height_); + serialize_size += SerializedSize(pooled_width_); + serialize_size += SerializedSize(spatial_scale_); + serialize_size += SerializedSize(sampling_ratio_); + return serialize_size; +} + +void RoiAlignPluginDynamic::serialize(void* buffer) const { + SerializeValue(&buffer, data_type_); + SerializeValue(&buffer, pooled_height_); + SerializeValue(&buffer, pooled_width_); + SerializeValue(&buffer, spatial_scale_); + SerializeValue(&buffer, sampling_ratio_); +} + +void RoiAlignPluginDynamic::destroy() {} + +RoiAlignPluginDynamicCreator::RoiAlignPluginDynamicCreator() {} + +void RoiAlignPluginDynamicCreator::setPluginNamespace( + const char* lib_namespace) { + namespace_ = std::string(lib_namespace); +} + +const char* RoiAlignPluginDynamicCreator::getPluginNamespace() const { + return namespace_.c_str(); +} + +const char* RoiAlignPluginDynamicCreator::getPluginName() const { + return "roi_align_plugin_dynamic"; +} + +const char* RoiAlignPluginDynamicCreator::getPluginVersion() const { + return "1"; +} + +const nvinfer1::PluginFieldCollection* +RoiAlignPluginDynamicCreator::getFieldNames() { + return &field_collection_; +} + +nvinfer1::IPluginV2Ext* RoiAlignPluginDynamicCreator::createPlugin( + const char* name, const nvinfer1::PluginFieldCollection* fc) { + const nvinfer1::PluginField* fields = fc->fields; +} + +nvinfer1::IPluginV2Ext* RoiAlignPluginDynamicCreator::deserializePlugin( + const char* name, const void* serial_data, size_t serial_length) { + auto plugin = new RoiAlignPluginDynamic(serial_data, serial_length); + plugin->setPluginNamespace(namespace_.c_str()); + return plugin; +} +#endif + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/roi_align_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/roi_align_op_plugin.h new file mode 100644 index 0000000000000..bba7d0d5a9966 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/roi_align_op_plugin.h @@ -0,0 +1,112 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +#if IS_TRT_VERSION_GE(6000) +class RoiAlignPluginDynamic : public DynamicPluginTensorRT { + public: + explicit RoiAlignPluginDynamic(const nvinfer1::DataType data_type, + const int pooled_height, + const int pooled_width, float spatial_scale, + int sampling_ratio); + RoiAlignPluginDynamic(void const* data, size_t length); + ~RoiAlignPluginDynamic() = default; + nvinfer1::IPluginV2DynamicExt* clone() const override; + nvinfer1::DimsExprs getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) override; + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, int nbOutputs) override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) override; + + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const override; + + const char* getPluginType() const override; + int getNbOutputs() const override; + int initialize() override; + void terminate() override; + size_t getSerializationSize() const override; + void serialize(void* buffer) const override; + void destroy() override; + + private: + template + int enqueue_impl(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, void* const* outputs, + void* workspace, cudaStream_t stream); + + nvinfer1::DataType data_type_; + int pooled_height_; + int pooled_width_; + float spatial_scale_; + int sampling_ratio_; + int smem_per_block_; + std::string namespace_; +}; + +class RoiAlignPluginDynamicCreator : public nvinfer1::IPluginCreator { + public: + RoiAlignPluginDynamicCreator(); + ~RoiAlignPluginDynamicCreator() override = default; + + void setPluginNamespace(const char* lib_namespace) override; + const char* getPluginNamespace() const override; + const char* getPluginName() const override; + const char* getPluginVersion() const override; + const nvinfer1::PluginFieldCollection* getFieldNames() override; + + nvinfer1::IPluginV2Ext* createPlugin( + const char* name, const nvinfer1::PluginFieldCollection* fc) override; + nvinfer1::IPluginV2Ext* deserializePlugin(const char* name, + const void* serial_data, + size_t serial_length) override; + + private: + std::string namespace_; + nvinfer1::PluginFieldCollection field_collection_; +}; +REGISTER_TRT_PLUGIN_V2(RoiAlignPluginDynamicCreator); +#endif + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_roi_align_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_roi_align_op.py new file mode 100644 index 0000000000000..fa276dd342bc6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_roi_align_op.py @@ -0,0 +1,119 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from inference_pass_test import InferencePassTest +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.core import PassVersionChecker +from paddle.fluid.core import AnalysisConfig + + +class TRTRoiAlignTest(InferencePassTest): + def setUp(self): + self.bs = 2 + self.num_rois = 4 + self.channel = 16 + self.height = 32 + self.width = 32 + self.precision = AnalysisConfig.Precision.Float32 + self.serialize = False + self.enable_trt = True + + def build(self): + self.trt_parameters = TRTRoiAlignTest.TensorRTParam( + 1 << 30, self.bs * self.num_rois, 1, self.precision, self.serialize, + False) + with fluid.program_guard(self.main_program, self.startup_program): + data_shape = [-1, self.channel, self.height, self.width] + data = fluid.data(name='data', shape=data_shape, dtype='float32') + rois = fluid.data( + name='rois', shape=[-1, 4], dtype='float32', lod_level=1) + roi_align_out = fluid.layers.roi_align(data, rois) + out = fluid.layers.batch_norm(roi_align_out, is_test=True) + + rois_lod = fluid.create_lod_tensor( + np.random.random([self.bs * self.num_rois, 4]).astype('float32'), + [[self.num_rois, self.num_rois]], fluid.CPUPlace()) + + data_shape[0] = self.bs + self.feeds = { + 'data': np.random.random(data_shape).astype('float32'), + 'rois': rois_lod, + } + self.fetch_list = [out] + + def check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + atol = 1e-5 + if self.trt_parameters.precision == AnalysisConfig.Precision.Half: + atol = 1e-3 + self.check_output_with_option(use_gpu, atol, flatten=True) + self.assertTrue( + PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) + + def set_dynamic(self): + min_shape_spec = dict() + max_shape_spec = dict() + opt_shape_spec = dict() + min_shape_spec['data'] = [ + self.bs, self.channel, self.height // 2, self.width // 2 + ] + min_shape_spec['rois'] = [1, 4] + max_shape_spec[ + 'data'] = [self.bs, self.channel, self.height * 2, self.width * 2] + max_shape_spec['rois'] = [self.bs * self.num_rois, 4] + opt_shape_spec[ + 'data'] = [self.bs, self.channel, self.height, self.width] + opt_shape_spec['rois'] = [self.bs * self.num_rois, 4] + + self.dynamic_shape_params = InferencePassTest.DynamicShapeParam( + min_shape_spec, max_shape_spec, opt_shape_spec, False) + + def run_test(self): + self.build() + self.check_output() + + def test_base(self): + self.run_test() + + def test_fp16(self): + self.precision = AnalysisConfig.Precision.Half + self.run_test() + + def test_serialize(self): + self.serialize = True + self.run_test() + + def test_dynamic(self): + self.set_dynamic() + self.run_test() + + def test_dynamic_fp16(self): + self.set_dynamic() + self.precision = AnalysisConfig.Precision.Half + self.run_test() + + def test_dynamic_serialize(self): + self.set_dynamic() + self.serialize = True + self.run_test() + + +if __name__ == "__main__": + unittest.main()