diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index af4cc0b83ef6c..434223b3ba43e 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1808,6 +1808,8 @@ USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm) USE_TRT_CONVERTER(preln_skip_layernorm) USE_TRT_CONVERTER(roll) USE_TRT_CONVERTER(strided_slice) +USE_TRT_CONVERTER(squeeze2) +USE_TRT_CONVERTER(unsqueeze2) #endif namespace paddle_infer { diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 0ea12084b6bea..5ad25b235588b 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -56,6 +56,8 @@ nv_library( strided_slice_op.cc preln_skip_layernorm.cc roll_op.cc + squeeze2_op.cc + unsqueeze2_op.cc DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry) diff --git a/paddle/fluid/inference/tensorrt/convert/squeeze2_op.cc b/paddle/fluid/inference/tensorrt/convert/squeeze2_op.cc new file mode 100644 index 0000000000000..9b494b0331880 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/squeeze2_op.cc @@ -0,0 +1,82 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class Squeeze2OpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, + bool test_mode) override { + VLOG(4) << "convert a fluid squeeze2 op to tensorrt shuffle layer"; + + framework::OpDesc op_desc(op, nullptr); + // Declare inputs + auto* input = engine_->GetITensor(op_desc.Input("X")[0]); + auto input_dims = input->getDimensions(); + auto output_name = op_desc.Output("Out")[0]; + + // Get Attrs + std::vector axes = + BOOST_GET_CONST(std::vector, op_desc.GetAttr("axes")); + PADDLE_ENFORCE_GT( + axes.size(), + 0, + platform::errors::InvalidArgument( + "Attr(axes).size should be > 0 in squeeze2 op in TensorRT," + "but received axes.size() = %d.", + axes.size())); + + std::vector should_squeeze(input_dims.nbDims, false); + for (size_t i = 0; i < axes.size(); i++) { + if (engine_->with_dynamic_shape()) { + axes[i] += (axes[i] < 0) ? input_dims.nbDims : 0; + } else { + axes[i] += (axes[i] < 0) ? input_dims.nbDims : -1; + } + should_squeeze[axes[i]] = true; + } + + nvinfer1::Dims trt_out_dims; + trt_out_dims.nbDims = 0; + std::vector gather_indices; + for (size_t i = 0; i < should_squeeze.size(); i++) { + if (should_squeeze[i]) continue; + gather_indices.push_back(i); + // for static shape + trt_out_dims.d[trt_out_dims.nbDims] = input_dims.d[i]; + trt_out_dims.nbDims++; + } + + auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); + if (engine_->with_dynamic_shape()) { + auto* shape_tensor = Shape(input); + auto* real_shape_tensor = Gather(shape_tensor, gather_indices); + layer->setInput(1, *real_shape_tensor); + } else { + layer->setReshapeDimensions(trt_out_dims); + } + RreplenishLayerAndOutput(layer, "squeeze2", {output_name}, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(squeeze2, Squeeze2OpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/unsqueeze2_op.cc b/paddle/fluid/inference/tensorrt/convert/unsqueeze2_op.cc new file mode 100644 index 0000000000000..eb25971534f79 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/unsqueeze2_op.cc @@ -0,0 +1,101 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class Unsqueeze2OpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, + bool test_mode) override { + VLOG(4) << "convert a fluid unsqueeze2 op to tensorrt shuffle layer"; + + framework::OpDesc op_desc(op, nullptr); + // Declare inputs + auto* input = engine_->GetITensor(op_desc.Input("X")[0]); + auto input_dims = input->getDimensions(); + auto output_name = op_desc.Output("Out")[0]; + + // Get Attrs + std::vector axes = + BOOST_GET_CONST(std::vector, op_desc.GetAttr("axes")); + PADDLE_ENFORCE_GT( + axes.size(), + 0, + platform::errors::InvalidArgument( + "Attr(axes).size should be > 0 in unsqueeze2 op in TensorRT," + "but received axes.size() = %d.", + axes.size())); + + std::vector should_unsqueeze(input_dims.nbDims + axes.size(), false); + int cur_out_rank = input_dims.nbDims; + for (size_t i = 0; i < axes.size(); i++) { + cur_out_rank++; + if (engine_->with_dynamic_shape()) { + axes[i] += (axes[i] < 0) ? cur_out_rank : 0; + } else { + axes[i] += (axes[i] < 0) ? cur_out_rank : -1; + } + // axes[i] is relative to cur_out_rank + // we make [axes[i], cur_out_rank - 2] shift right + // and make (axes[i]) to true! + for (int j = cur_out_rank - 1; j > axes[i]; j--) { + should_unsqueeze[j] = should_unsqueeze[j - 1]; + } + if (axes[i] >= cur_out_rank) + should_unsqueeze[cur_out_rank - 1] = true; + else + should_unsqueeze[axes[i]] = true; + } + + nvinfer1::Dims trt_out_dims; + trt_out_dims.nbDims = should_unsqueeze.size(); + std::vector gather_indices; + int in_rank_i = 0; + for (size_t i = 0; i < should_unsqueeze.size(); i++) { + if (should_unsqueeze[i]) { + trt_out_dims.d[i] = 1; + gather_indices.push_back(input_dims.nbDims); + continue; + } + trt_out_dims.d[i] = input_dims.d[in_rank_i]; + gather_indices.push_back(in_rank_i); + in_rank_i++; + } + + auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); + if (engine_->with_dynamic_shape()) { + auto* shape_tensor = Shape(input); + std::vector all_one(axes.size(), 1); + auto* all_one_tensor = Add1DConstantLayer(all_one); + std::vector concat_inputs = {shape_tensor, + all_one_tensor}; + auto* real_shape_tensor = Gather(Concat(concat_inputs), gather_indices); + layer->setInput(1, *real_shape_tensor); + } else { + layer->setReshapeDimensions(trt_out_dims); + } + RreplenishLayerAndOutput(layer, "unsqueeze2", {output_name}, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(unsqueeze2, Unsqueeze2OpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index d545d6f0e67e2..6f413dd0042dc 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -114,6 +114,8 @@ struct SimpleOpTypeSetTeller : public Teller { "bilinear_interp_v2", "cast", "pool3d", + "squeeze2", + "unsqueeze2", "deformable_conv", "relu6", "hard_sigmoid", @@ -179,6 +181,8 @@ struct SimpleOpTypeSetTeller : public Teller { "nearest_interp_v2", "cast", "pool3d", + "squeeze2", + "unsqueeze2", "deformable_conv", "relu6", "hard_sigmoid", @@ -891,6 +895,44 @@ bool OpTeller::Tell(const framework::ir::Node* node, } } + if (op_type == "squeeze2") { + std::vector axes; + if (desc.HasAttr("axes")) { + axes = BOOST_GET_CONST(std::vector, desc.GetAttr("axes")); + } + if (axes.size() == 0) { + VLOG(3) << "The necessary attributes of the squeeze2 operator axes is " + "missing."; + return false; + } + if (!with_dynamic_shape) { + if (std::find(axes.begin(), axes.end(), 0) != axes.end()) { + VLOG(3) << "Invalid squeeze axes. Axes having batch axis is not " + "supported in static shape"; + return false; + } + } + } + + if (op_type == "unsqueeze2") { + std::vector axes; + if (desc.HasAttr("axes")) { + axes = BOOST_GET_CONST(std::vector, desc.GetAttr("axes")); + } + if (axes.size() == 0) { + VLOG(3) << "The necessary attributes of the squeeze2 operator axes is " + "missing."; + return false; + } + if (!with_dynamic_shape) { + if (std::find(axes.begin(), axes.end(), 0) != axes.end()) { + VLOG(3) << "Invalid squeeze axes. Axes having batch axis is not " + "supported in static shape"; + return false; + } + } + } + if (op_type == "batch_norm") { const std::vector bn_inputs = { "X", "Bias", "Mean", "Scale", "Variance"}; diff --git a/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu index 9bfe98d759d8e..dcefa1b05a8a4 100644 --- a/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu @@ -61,16 +61,26 @@ void PoolPlugin::serialize(void *buffer) const TRT_NOEXCEPT { } PoolPlugin *PoolPlugin::clone() const TRT_NOEXCEPT { - return new PoolPlugin(ceil_mode_, pool_type_, adaptive_, exclusive_, ksize_, - strides_, paddings_, input_shape_, real_paddings_); + return new PoolPlugin(ceil_mode_, + pool_type_, + adaptive_, + exclusive_, + ksize_, + strides_, + paddings_, + input_shape_, + real_paddings_); } -int PoolPlugin::enqueue(int batchSize, const void *const *inputs, +int PoolPlugin::enqueue(int batchSize, + const void *const *inputs, #if IS_TRT_VERSION_LT(8000) - void **outputs, void *workspace, + void **outputs, + void *workspace, cudaStream_t stream) TRT_NOEXCEPT { #else - void *const *outputs, void *workspace, + void *const *outputs, + void *workspace, cudaStream_t stream) TRT_NOEXCEPT { #endif auto const &input_dims = this->getInputDims(0); @@ -87,14 +97,31 @@ int PoolPlugin::enqueue(int batchSize, const void *const *inputs, phi::funcs::MaxPool pool_process; phi::funcs::Pool2dDirectCUDAFunctor, float> pool2d_forward; - pool2d_forward(idata, input_shape, output_shape, ksize_, strides_, - paddings_, true, false, odatas[0], stream, pool_process); + pool2d_forward(idata, + input_shape, + output_shape, + ksize_, + strides_, + paddings_, + true, + false, + odatas[0], + stream, + pool_process); } else if (pool_type_ == PoolType::avg) { phi::funcs::AvgPool pool_process; phi::funcs::Pool2dDirectCUDAFunctor, float> pool2d_forward; - pool2d_forward(idata, input_shape, output_shape, ksize_, strides_, - paddings_, exclusive_, adaptive_, odatas[0], stream, + pool2d_forward(idata, + input_shape, + output_shape, + ksize_, + strides_, + paddings_, + exclusive_, + adaptive_, + odatas[0], + stream, pool_process); } @@ -137,21 +164,25 @@ void PoolPluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT { } nvinfer1::IPluginV2DynamicExt *PoolPluginDynamic::clone() const TRT_NOEXCEPT { - return new PoolPluginDynamic(ceil_mode_, pool_type_, adaptive_, exclusive_, - ksize_, strides_, paddings_, is_global_); + return new PoolPluginDynamic(ceil_mode_, + pool_type_, + adaptive_, + exclusive_, + ksize_, + strides_, + paddings_, + is_global_); } nvinfer1::DimsExprs PoolPluginDynamic::getOutputDimensions( - int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs, + int output_index, + const nvinfer1::DimsExprs *inputs, + int nb_inputs, nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT { - PADDLE_ENFORCE_EQ(nb_inputs, 1, + PADDLE_ENFORCE_EQ(nb_inputs, + 1, platform::errors::InvalidArgument( "The Split plugin should be only one input.")); - - PADDLE_ENFORCE_EQ( - inputs[0].d[1]->isConstant(), true, - platform::errors::InvalidArgument("The channel dimension should be " - "static, but we found it's dynamic.")); nvinfer1::DimsExprs output(inputs[0]); if (is_global_ && !adaptive_) { output.d[2] = expr_builder.constant(1); @@ -184,16 +215,16 @@ nvinfer1::DimsExprs PoolPluginDynamic::getOutputDimensions( nvinfer1::DimensionOperation::kSUM, *expr_builder.operation( nvinfer1::DimensionOperation::kFLOOR_DIV, - *expr_builder.operation(nvinfer1::DimensionOperation::kSUM, - *inputs[0].d[2], *v0_tmp), + *expr_builder.operation( + nvinfer1::DimensionOperation::kSUM, *inputs[0].d[2], *v0_tmp), *stri_0), *one_value); output.d[3] = expr_builder.operation( nvinfer1::DimensionOperation::kSUM, *expr_builder.operation( nvinfer1::DimensionOperation::kFLOOR_DIV, - *expr_builder.operation(nvinfer1::DimensionOperation::kSUM, - *inputs[0].d[3], *v1_tmp), + *expr_builder.operation( + nvinfer1::DimensionOperation::kSUM, *inputs[0].d[3], *v1_tmp), *stri_1), *one_value); @@ -202,8 +233,8 @@ nvinfer1::DimsExprs PoolPluginDynamic::getOutputDimensions( nvinfer1::DimensionOperation::kSUM, *expr_builder.operation( nvinfer1::DimensionOperation::kFLOOR_DIV, - *expr_builder.operation(nvinfer1::DimensionOperation::kSUM, - *inputs[0].d[2], *ceil_tmp), + *expr_builder.operation( + nvinfer1::DimensionOperation::kSUM, *inputs[0].d[2], *ceil_tmp), *stri_0), *one_value); output.d[3] = expr_builder.operation( @@ -211,7 +242,8 @@ nvinfer1::DimsExprs PoolPluginDynamic::getOutputDimensions( *expr_builder.operation( nvinfer1::DimensionOperation::kFLOOR_DIV, *expr_builder.operation(nvinfer1::DimensionOperation::kSUM, - *inputs[0].d[3], *ceil1_tmp), + *inputs[0].d[3], + *ceil1_tmp), *stri_1), *one_value); } @@ -220,17 +252,22 @@ nvinfer1::DimsExprs PoolPluginDynamic::getOutputDimensions( } bool PoolPluginDynamic::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs, + int pos, + const nvinfer1::PluginTensorDesc *in_out, + int nb_inputs, int nb_outputs) TRT_NOEXCEPT { PADDLE_ENFORCE_NOT_NULL( - in_out, platform::errors::InvalidArgument( - "The input of swish plugin shoule not be nullptr.")); + in_out, + platform::errors::InvalidArgument( + "The input of swish plugin shoule not be nullptr.")); PADDLE_ENFORCE_LT( - pos, nb_inputs + nb_outputs, + pos, + nb_inputs + nb_outputs, platform::errors::InvalidArgument("The pos(%d) should be less than the " "num(%d) of the input and the output.", - pos, nb_inputs + nb_outputs)); + pos, + nb_inputs + nb_outputs)); (in_out && pos < (nb_inputs + nb_outputs)); return ((in_out[pos].type == nvinfer1::DataType::kFLOAT) && @@ -238,13 +275,17 @@ bool PoolPluginDynamic::supportsFormatCombination( } nvinfer1::DataType PoolPluginDynamic::getOutputDataType( - int index, const nvinfer1::DataType *input_types, + int index, + const nvinfer1::DataType *input_types, int nb_inputs) const TRT_NOEXCEPT { - PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument( - "The Pool Plugin only has one input, so the " - "index value should be 0, but get %d.", - index)); - PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT), true, + PADDLE_ENFORCE_EQ(index, + 0, + platform::errors::InvalidArgument( + "The Pool Plugin only has one input, so the " + "index value should be 0, but get %d.", + index)); + PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT), + true, platform::errors::InvalidArgument( "The input type should be half or float")); return input_types[0]; @@ -252,7 +293,8 @@ nvinfer1::DataType PoolPluginDynamic::getOutputDataType( int PoolPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc, const nvinfer1::PluginTensorDesc *output_desc, - const void *const *inputs, void *const *outputs, + const void *const *inputs, + void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT { auto input_dims = input_desc[0].dims; @@ -279,8 +321,8 @@ int PoolPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc, output_shape[2] = 1; output_shape[3] = 1; } else { - auto data_dim = CalcOutputSize({h, w}, ceil_mode_, adaptive_, ksize_, - strides_, paddings_); + auto data_dim = CalcOutputSize( + {h, w}, ceil_mode_, adaptive_, ksize_, strides_, paddings_); output_shape[2] = data_dim[0]; output_shape[3] = data_dim[1]; } @@ -293,14 +335,32 @@ int PoolPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc, phi::funcs::MaxPool pool_process; phi::funcs::Pool2dDirectCUDAFunctor, float> pool2d_forward; - pool2d_forward(input, input_shape, output_shape, ksize, strides_, paddings, - true, false, output, stream, pool_process); + pool2d_forward(input, + input_shape, + output_shape, + ksize, + strides_, + paddings, + true, + false, + output, + stream, + pool_process); } else if (pool_type_ == "avg") { phi::funcs::AvgPool pool_process; phi::funcs::Pool2dDirectCUDAFunctor, float> pool2d_forward; - pool2d_forward(input, input_shape, output_shape, ksize, strides_, paddings, - exclusive_, adaptive_, output, stream, pool_process); + pool2d_forward(input, + input_shape, + output_shape, + ksize, + strides_, + paddings, + exclusive_, + adaptive_, + output, + stream, + pool_process); } return cudaGetLastError() != cudaSuccess; diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_squeeze2.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_squeeze2.py new file mode 100644 index 0000000000000..10af67b51cc1b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_squeeze2.py @@ -0,0 +1,138 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig +import unittest +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set + + +class TrtConvertSqueeze2Test(TrtLayerAutoScanTest): + + def is_program_valid(self, program_config: ProgramConfig) -> bool: + inputs = program_config.inputs + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + if len(inputs['in_data'].shape) <= max(attrs[0]['axes']): + return False + return True + + def sample_program_configs(self): + for dims in [2, 3, 4]: + for batch in [3, 4]: + for axes in [[2], [2, 3], [-1]]: + self.batch = batch + self.dims = dims + self.axes = axes + dics = [{"axes": axes}] + ops_config = [{ + "op_type": "squeeze2", + "op_inputs": { + "X": ["in_data"] + }, + "op_outputs": { + "Out": ["out_data"], + "XShape": ["XShape_data"] + }, + "op_attrs": dics[0] + }] + # new_axes is the update of axes + new_axes = list(axes) + for i in range(len(new_axes)): + if (new_axes[i] < 0): + new_axes[i] += dims + if (max(new_axes) >= dims): + continue + # generate input data + self.input_shape = [1] * dims + for i in range(dims): + self.input_shape[i] = np.random.randint(1, 20) + + def generate_input1(attrs: List[Dict[str, Any]], batch): + self.input_shape[0] = batch + for i in new_axes: + self.input_shape[i] = 1 + return np.random.random(self.input_shape).astype( + np.float32) + + ops = self.generate_op_config(ops_config) + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "in_data": + TensorConfig( + data_gen=partial(generate_input1, dics, batch)) + }, + outputs=["out_data"]) + + yield program_config + + def sample_predictor_configs( + self, program_config) -> (paddle_infer.Config, List[int], float): + + def generate_dynamic_shape(attrs): + max_shape = list(self.input_shape) + min_shape = list(self.input_shape) + opt_shape = list(self.input_shape) + for i in range(len(self.input_shape)): + max_shape[i] = max_shape[i] + 1 + self.dynamic_shape.min_input_shape = {"in_data": min_shape} + self.dynamic_shape.max_input_shape = {"in_data": max_shape} + self.dynamic_shape.opt_input_shape = {"in_data": opt_shape} + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + return 1, 2 + + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + self.trt_param.max_batch_size = 9 + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), 1e-5 + + def add_skip_trt_case(self): + pass + + def test(self): + self.add_skip_trt_case() + self.run_test() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_unsqueeze2.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_unsqueeze2.py new file mode 100644 index 0000000000000..d8f573ca9a92f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_unsqueeze2.py @@ -0,0 +1,124 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig +import unittest +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set + + +class TrtConvertUnsqueeze2Test(TrtLayerAutoScanTest): + + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_configs(self): + for dims in [2, 3, 4]: + for batch in [3, 4]: + for axes in [[-2, 3], [1], [2], [2, 3]]: + self.batch = batch + self.dims = dims + self.axes = axes + dics = [{"axes": axes}] + ops_config = [{ + "op_type": "unsqueeze2", + "op_inputs": { + "X": ["in_data"] + }, + "op_outputs": { + "Out": ["out_data"], + "XShape": ["XShape_data"] + }, + "op_attrs": dics[0] + }] + + # generate input data + self.input_shape = [1] * dims + for i in range(dims): + self.input_shape[i] = np.random.randint(1, 20) + + def generate_input1(attrs: List[Dict[str, Any]], batch): + self.input_shape[0] = batch + return np.random.random(self.input_shape).astype( + np.float32) + + ops = self.generate_op_config(ops_config) + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "in_data": + TensorConfig( + data_gen=partial(generate_input1, dics, batch)) + }, + outputs=["out_data"]) + + yield program_config + + def sample_predictor_configs( + self, program_config) -> (paddle_infer.Config, List[int], float): + + def generate_dynamic_shape(attrs): + max_shape = list(self.input_shape) + min_shape = list(self.input_shape) + opt_shape = list(self.input_shape) + for i in range(len(self.input_shape)): + max_shape[i] = max_shape[i] + 1 + self.dynamic_shape.min_input_shape = {"in_data": min_shape} + self.dynamic_shape.max_input_shape = {"in_data": max_shape} + self.dynamic_shape.opt_input_shape = {"in_data": opt_shape} + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + return 1, 2 + + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + self.trt_param.max_batch_size = 9 + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), 1e-5 + + def add_skip_trt_case(self): + pass + + def test(self): + self.add_skip_trt_case() + self.run_test() + + +if __name__ == "__main__": + unittest.main()