From d41491b831e4ae29d5da78e8be943e9d13746ee3 Mon Sep 17 00:00:00 2001 From: wenbin Date: Mon, 6 Sep 2021 11:57:18 +0000 Subject: [PATCH 1/5] conv3d --- .../fluid/inference/api/analysis_predictor.cc | 2 + .../inference/tensorrt/convert/CMakeLists.txt | 1 + .../inference/tensorrt/convert/conv3d_op.cc | 174 ++++++++++++++++++ paddle/fluid/inference/tensorrt/engine.h | 14 +- paddle/fluid/inference/tensorrt/op_teller.cc | 143 +++++++++----- .../unittests/ir/inference/CMakeLists.txt | 1 + .../ir/inference/test_trt_conv3d_op.py | 152 +++++++++++++++ .../inference/test_trt_conv3d_transpose_op.py | 131 +++++++++++++ 8 files changed, 568 insertions(+), 50 deletions(-) create mode 100644 paddle/fluid/inference/tensorrt/convert/conv3d_op.cc create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_op.py create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_transpose_op.py diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index b31b5f906b9b9..25f4cfea0ba24 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1257,6 +1257,8 @@ USE_TRT_CONVERTER(reduce_sum); USE_TRT_CONVERTER(gather_nd); USE_TRT_CONVERTER(reduce_mean); USE_TRT_CONVERTER(tile); +USE_TRT_CONVERTER(conv3d); +USE_TRT_CONVERTER(conv3d_transpose); #endif namespace paddle_infer { diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 63d9114e1acda..c79915629b70d 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -16,6 +16,7 @@ nv_library(tensorrt_converter reduce_op.cc gather_nd_op.cc tile_op.cc + conv3d_op.cc DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry) nv_test(test_op_converter SRCS test_op_converter.cc DEPS diff --git a/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc new file mode 100644 index 0000000000000..e3dc268ff9085 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc @@ -0,0 +1,174 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace framework { +class Scope; + +namespace proto { +class OpDesc; +} // namespace proto +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace inference { +namespace tensorrt { + +template +void ConvertConv3d(TensorRTEngine* engine, const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode, + RegistFunc fadd_layer, SetDilationFunc fset_dilation, + const std::string& name) { + VLOG(3) << "convert a fluid " << name << " op to tensorrt layer without bias"; + + framework::OpDesc op_desc(op, nullptr); + + auto* X = engine->GetITensor(op_desc.Input("Input").front()); + std::string filter_var_name = op_desc.Input("Filter").front(); + auto* Y_v = scope.FindVar(filter_var_name); + PADDLE_ENFORCE_NOT_NULL( + Y_v, platform::errors::NotFound( + "Can not find %s presistale var in scope.", filter_var_name)); + auto* Y_t = Y_v->GetMutable(); + float* weight_data = nullptr; + bool enable_int8 = op_desc.HasAttr("enable_int8"); + + if (enable_int8) { +#if IS_TRT_VERSION_GE(5000) + float in_scale = + BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127; + auto weight_scale = + BOOST_GET_CONST(std::vector, op_desc.GetAttr("weight_scale")); + weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t, + true, weight_scale); + engine->SetTensorDynamicRange(X, in_scale); +#endif + } else { + weight_data = + engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t, false); + } + + PADDLE_ENFORCE_EQ(Y_t->dims().size(), 5UL, + platform::errors::InvalidArgument( + "The conv3d filter's dims size should be 5, but got %d", + Y_t->dims().size())); + + const int n_output = Y_t->dims()[0]; + const int n_input = Y_t->dims()[1]; + const int filter_d = Y_t->dims()[2]; + const int filter_h = Y_t->dims()[3]; + const int filter_w = Y_t->dims()[4]; + const int groups = BOOST_GET_CONST(int, op_desc.GetAttr("groups")); + const std::vector dilations = + BOOST_GET_CONST(std::vector, op_desc.GetAttr("dilations")); + const std::vector strides = + BOOST_GET_CONST(std::vector, op_desc.GetAttr("strides")); + const std::vector paddings = + BOOST_GET_CONST(std::vector, op_desc.GetAttr("paddings")); + std::string padding_algorithm = "EXPLICIT"; + if (op_desc.HasAttr("padding_algorithm")) + padding_algorithm = + BOOST_GET_CONST(std::string, op_desc.GetAttr("padding_algorithm")); + + nvinfer1::Dims3 nv_ksize(filter_d, filter_h, filter_w); + nvinfer1::Dims3 nv_dilations(dilations[0], dilations[1], dilations[2]); + nvinfer1::Dims3 nv_strides(strides[0], strides[1], strides[2]); + nvinfer1::Dims3 nv_paddings(paddings[0], paddings[1], paddings[2]); + + TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, + static_cast(weight_data), + static_cast(Y_t->numel())}; + float* bias_data = nullptr; + size_t bias_size = 0; + + TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, + static_cast(bias_data), bias_size}; + // In conv3d_transpose output channels = filter_dims[1] * groups + auto* layer = (op_desc.Type() == "conv3d_transpose") + ? fadd_layer(const_cast(X), + n_input * groups, nv_ksize, weight, bias) + : fadd_layer(const_cast(X), n_output, + nv_ksize, weight, bias); + + PADDLE_ENFORCE_NOT_NULL( + layer, platform::errors::Fatal("TensorRT create conv3d/conv3d_transpose" + " layer failed.")); + layer->setStrideNd(nv_strides); + layer->setPaddingNd(nv_paddings); + layer->setNbGroups(groups); + if (padding_algorithm == "SAME") { + layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); + } + // set dilations + fset_dilation(layer, nv_dilations); + + auto output_name = op_desc.Output("Output").front(); + layer->setName((name + " (Output: " + output_name + ")").c_str()); + layer->getOutput(0)->setName(output_name.c_str()); + engine->SetITensor(output_name, layer->getOutput(0)); + + if (test_mode) { + engine->DeclareOutput(output_name); + } +} + +class Conv3dOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + ConvertConv3d( + engine_, op, scope, test_mode, + [&](nvinfer1::ITensor* inputs, int n_output, /* Conv output maps */ + nvinfer1::Dims& ksize, TensorRTEngine::Weight& weight, + TensorRTEngine::Weight& bias) -> nvinfer1::IConvolutionLayer* { + auto* layer = + TRT_ENGINE_ADD_LAYER(engine_, ConvolutionNd, *inputs, n_output, + ksize, weight.get(), bias.get()); + return layer; + }, + [](nvinfer1::IConvolutionLayer* layer, nvinfer1::Dims& dilations) { + layer->setDilationNd(dilations); + }, + "conv3d"); + } +}; + +class Deconv3dOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + ConvertConv3d( + engine_, op, scope, test_mode, + [&](nvinfer1::ITensor* inputs, int n_output, /* Deconv input maps */ + nvinfer1::Dims& ksize, TensorRTEngine::Weight& weight, + TensorRTEngine::Weight& bias) -> nvinfer1::IDeconvolutionLayer* { + auto* layer = + TRT_ENGINE_ADD_LAYER(engine_, DeconvolutionNd, *inputs, n_output, + ksize, weight.get(), bias.get()); + return layer; + }, + [](nvinfer1::IDeconvolutionLayer* layer, nvinfer1::Dims& dilations) {}, + "conv3d_transpose"); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(conv3d, Conv3dOpConverter); +REGISTER_TRT_OP_CONVERTER(conv3d_transpose, Deconv3dOpConverter); diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index 3604a47a7eb90..721af98ce9ba6 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -76,11 +76,7 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector& shape, std::string input, "TensorRT's tensor input requires at least 1 " "dimensions, but input %s has %d dims.", input, shape.size())); - PADDLE_ENFORCE_LE(shape.size(), 4UL, - platform::errors::InvalidArgument( - "TensorRT's tensor input requires at most 4 " - "dimensions, but input %s has %d dims.", - input, shape.size())); + auto ShapeStr = [](const std::vector& shape) { std::ostringstream os; os << "["; @@ -103,6 +99,14 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector& shape, std::string input, input, ShapeStr(shape))); } return nvinfer1::Dims3(shape[1], shape[2], shape[3]); + } else if (shape.size() == 5UL) { + if (shape[2] == -1 || shape[3] == -1 || shape[4] == -1) { + PADDLE_THROW(platform::errors::InvalidArgument( + "The input [%s] shape of trt subgraph is %s, please enable " + "trt dynamic_shape mode by SetTRTDynamicShapeInfo.", + input, ShapeStr(shape))); + } + return nvinfer1::Dims4(shape[1], shape[2], shape[3], shape[4]); } else if (shape.size() == 3UL) { if (shape[1] == -1 || shape[2] == -1) { PADDLE_THROW(platform::errors::InvalidArgument( diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 72f20790f3524..47f290d411776 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -90,51 +90,51 @@ struct SimpleOpTypeSetTeller : public Teller { "elementwise_mul", "conv2d_transpose", "hard_swish"}; - std::unordered_set teller_set{ - "mul", - "matmul", - "conv2d", - "conv2d_fusion", - "pool2d", - "relu", - "softmax", - "sigmoid", - "hard_swish", - "depthwise_conv2d", - "batch_norm", - "concat", - "tanh", - "pad", - "elementwise_add", - "elementwise_mul", - "dropout", - "prelu", - "conv2d_transpose", - "depthwise_conv2d_transpose", - "leaky_relu", - "fc", - "shuffle_channel", - "swish", - "split", - "instance_norm", - "gelu", - "layer_norm", - "scale", - "stack", - "transpose2", - "transpose", - "flatten2", - "flatten", - "gather", - "gather_nd", - "yolo_box", - "roi_align", - "affine_channel", - "nearest_interp", - "anchor_generator", - "reduce_sum", - "reduce_mean", - }; + std::unordered_set teller_set{"mul", + "matmul", + "conv2d", + "conv2d_fusion", + "pool2d", + "relu", + "softmax", + "sigmoid", + "hard_swish", + "depthwise_conv2d", + "batch_norm", + "concat", + "tanh", + "pad", + "elementwise_add", + "elementwise_mul", + "dropout", + "prelu", + "conv2d_transpose", + "depthwise_conv2d_transpose", + "leaky_relu", + "fc", + "shuffle_channel", + "swish", + "split", + "instance_norm", + "gelu", + "layer_norm", + "scale", + "stack", + "transpose2", + "transpose", + "flatten2", + "flatten", + "gather", + "gather_nd", + "yolo_box", + "roi_align", + "affine_channel", + "nearest_interp", + "anchor_generator", + "reduce_sum", + "reduce_mean", + "conv3d", + "conv3d_transpose"}; }; bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, @@ -758,6 +758,59 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, } #endif + if (op_type == "conv3d" || op_type == "conv3d_transpose") { + std::vector paddings = + BOOST_GET_CONST(std::vector, desc.GetAttr("paddings")); + + if (desc.HasAttr("padding_algorithm")) { + std::string padding_algorithm = + BOOST_GET_CONST(std::string, desc.GetAttr("padding_algorithm")); + + // trt error is arised if conv3d_transpose and SAME + if (op_type == "conv3d_transpose" && padding_algorithm == "SAME" && + !with_dynamic_shape) { + return false; + } + } + + // conv3d and conv3d_transpose need padding check + if (paddings.size() > 3) return false; + + if (desc.Input("Input").size() != 1) { + VLOG(3) << "TRT Conv3d expect 1 input, but got " + << desc.Input("Input").size() << " input."; + return false; + } + + if (desc.Input("Filter").size() != 1) { + VLOG(3) << "TRT Conv3d expect 1 filter, but got " + << desc.Input("Filter").size() << " filter."; + return false; + } + + if (op_type == "conv3d_transpose") { + if (!desc.HasAttr("dilations")) { + return false; + } else { + const std::vector dilations = + BOOST_GET_CONST(std::vector, desc.GetAttr("dilations")); + if (dilations[0] != 1 || dilations[1] != 1 || dilations[2] != 1) { + VLOG(3) << "In conv3d_transpose, Dilations must be (1, 1, 1) for " + "tensorRT, but given (" + << dilations[0] << ", " << dilations[1] << ", " + << dilations[2] << ")"; + return false; + } + } + } + + if (desc.Output("Output").size() != 1) { + VLOG(3) << "TRT Conv3d expect 1 output, but got " + << desc.Output("Output").size() << " output."; + return false; + } + } + if ((*teller)(op_type, desc, use_no_calib_int8)) return true; } diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt index f73327f8248d8..2c24e6862c222 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt @@ -55,4 +55,5 @@ set_tests_properties(test_trt_tile_op PROPERTIES TIMEOUT 60) set_tests_properties(test_trt_fc_fuse_quant_dequant_pass PROPERTIES TIMEOUT 100) set_tests_properties(test_trt_conv_quant_dequant_pass PROPERTIES TIMEOUT 100) set_tests_properties(test_trt_matmul_quant_dequant PROPERTIES TIMEOUT 100) +set_tests_properties(test_trt_conv3d_op PROPERTIES TIMEOUT 60) endif() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_op.py new file mode 100644 index 0000000000000..43faaa67e9bf0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_op.py @@ -0,0 +1,152 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import unittest +import numpy as np +from inference_pass_test import InferencePassTest +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.core import PassVersionChecker +from paddle.fluid.core import AnalysisConfig + + +class TensorRTSubgraphPassConv3dTest(InferencePassTest): + def setUp(self): + self.init_params() + self.set_params() + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=[-1, 3, 12, 32, 32], dtype="float32") + conv_out = fluid.layers.conv3d( + input=data, + num_filters=self.conv_num_filters, + filter_size=self.conv_filter_size, + groups=self.conv_groups, + padding=self.conv_padding, + bias_attr=False, + use_cudnn=self.use_cudnn, + stride=self.stride, + act=None) + self.feeds = { + "data": np.random.random([1, 3, 12, 32, 32]).astype("float32"), + } + self.enable_trt = True + self.trt_parameters = TensorRTSubgraphPassConv3dTest.TensorRTParam( + 1 << 30, 32, 1, self.precision, self.use_static, False) + self.fetch_list = [conv_out] + + def init_params(self): + self.conv_num_filters = 6 + self.conv_filter_size = 6 + self.conv_groups = 3 + self.conv_padding = [1, 1, 1] + self.use_cudnn = True + self.use_static = False + self.precision = AnalysisConfig.Precision.Float32 + self.stride = 1 + + def set_params(self): + pass + + def test_check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu) + self.assertTrue( + PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) + + +class TensorRTSubgraphPassConv3dValidPaddingTest( + TensorRTSubgraphPassConv3dTest): + def set_params(self): + self.conv_num_filters = 6 + self.conv_filter_size = 6 + self.conv_groups = 3 + self.conv_padding = 'VALID' + + +class TensorRTSubgraphPassConv3dSamePaddingTest(TensorRTSubgraphPassConv3dTest): + def set_params(self): + self.conv_num_filters = 6 + self.conv_filter_size = 6 + self.conv_groups = 3 + self.conv_padding = 'SAME' + + +class TensorRTSubgraphPassConv3dPaddingTest(TensorRTSubgraphPassConv3dTest): + def set_params(self): + self.conv_num_filters = 6 + self.conv_filter_size = 6 + self.conv_groups = 3 + self.conv_padding = [2, 3, 3] + + +class TensorRTSubgraphPassConv3dStrideTest(TensorRTSubgraphPassConv3dTest): + def set_params(self): + self.conv_num_filters = 6 + self.conv_filter_size = 6 + self.conv_groups = 3 + self.conv_padding = 'SAME' + self.stride = [1, 2, 2] + + +class DynamicShapeTensorRTSubgraphPassConv3dTest(InferencePassTest): + def setUp(self): + self.set_params() + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=[-1, 6, -1, -1, -1], dtype="float32") + conv_out = fluid.layers.conv3d( + input=data, + num_filters=self.conv_num_filters, + filter_size=self.conv_filter_size, + groups=self.conv_groups, + padding=self.conv_padding, + bias_attr=False, + use_cudnn=self.use_cudnn, + stride=self.stride, + act=None) + self.feeds = { + "data": np.random.random([1, 6, 64, 64, 8]).astype("float32"), + } + self.enable_trt = True + self.trt_parameters = DynamicShapeTensorRTSubgraphPassConv3dTest.TensorRTParam( + 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False) + self.dynamic_shape_params = DynamicShapeTensorRTSubgraphPassConv3dTest.DynamicShapeParam( + { + "data": [1, 6, 8, 8, 8], + }, {"data": [32, 6, 64, 64, 8], }, {"data": [16, 6, 16, 16, 8], + }, False) + self.fetch_list = [conv_out] + + def set_params(self): + self.conv_num_filters = 6 + self.conv_filter_size = 6 + self.conv_groups = 6 + self.conv_padding = 'SAME' + self.use_cudnn = True + self.stride = [2, 2, 2] + + def test_check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu) + self.assertTrue( + PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_transpose_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_transpose_op.py new file mode 100644 index 0000000000000..d2304e09c7177 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_transpose_op.py @@ -0,0 +1,131 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import unittest +import numpy as np +from inference_pass_test import InferencePassTest +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.core import PassVersionChecker +from paddle.fluid.core import AnalysisConfig + + +class TensorRTSubgraphPassConv3dTransposeTest(InferencePassTest): + def setUp(self): + self.set_params() + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=[-1, 4, 4, 64, 64], dtype="float32") + conv_out = fluid.layers.conv3d_transpose( + input=data, + num_filters=self.conv_num_filters, + filter_size=self.conv_filter_size, + groups=self.conv_groups, + padding=self.conv_padding, + bias_attr=False, + use_cudnn=self.use_cudnn, + stride=1, + act=None) + self.feeds = { + "data": np.random.random([1, 4, 4, 64, 64]).astype("float32"), + } + self.enable_trt = True + self.trt_parameters = TensorRTSubgraphPassConv3dTransposeTest.TensorRTParam( + 1 << 30, 32, 1, AnalysisConfig.Precision.Float32, False, False) + self.fetch_list = [conv_out] + + def set_params(self): + self.conv_num_filters = 6 + self.conv_filter_size = 6 + self.conv_groups = 1 + self.conv_padding = [1, 1, 1] + self.use_cudnn = True + + def test_check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu) + self.assertTrue( + PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) + + +class TensorRTSubgraphPassConv3dTransposeSamePaddingTest( + TensorRTSubgraphPassConv3dTransposeTest): + def set_params(self): + self.conv_num_filters = 6 + self.conv_filter_size = 6 + self.conv_groups = 1 + self.conv_padding = 'VALID' + self.use_cudnn = True + + +class TensorRTSubgraphPassConv3dTransposeMultigroupTest( + TensorRTSubgraphPassConv3dTransposeTest): + def set_params(self): + self.conv_num_filters = 6 + self.conv_filter_size = 6 + self.conv_groups = 2 + self.conv_padding = 'VALID' + self.use_cudnn = True + + +class DynamicShapeTensorRTSubgraphPassConv3dTransposeTest(InferencePassTest): + def setUp(self): + self.set_params() + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=[-1, 6, -1, -1, -1], dtype="float32") + conv_out = fluid.layers.conv3d_transpose( + input=data, + num_filters=self.conv_num_filters, + filter_size=self.conv_filter_size, + groups=self.conv_groups, + padding=self.conv_padding, + bias_attr=False, + use_cudnn=self.use_cudnn, + stride=self.stride, + act=None) + self.feeds = { + "data": np.random.random([1, 6, 64, 64, 8]).astype("float32"), + } + self.enable_trt = True + self.trt_parameters = DynamicShapeTensorRTSubgraphPassConv3dTransposeTest.TensorRTParam( + 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False) + self.dynamic_shape_params = DynamicShapeTensorRTSubgraphPassConv3dTransposeTest.DynamicShapeParam( + { + "data": [1, 6, 8, 8, 8], + }, {"data": [32, 6, 64, 64, 8], }, {"data": [16, 6, 16, 16, 8], + }, False) + self.fetch_list = [conv_out] + + def set_params(self): + self.conv_num_filters = 6 + self.conv_filter_size = 6 + self.conv_groups = 6 + self.conv_padding = 'SAME' + self.use_cudnn = True + self.stride = [2, 2, 2] + + def test_check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu) + self.assertTrue( + PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) + + +if __name__ == "__main__": + unittest.main() From 4e11471b81f3c8d73262851f8f281963afac2c0a Mon Sep 17 00:00:00 2001 From: wenbin Date: Mon, 6 Sep 2021 12:11:37 +0000 Subject: [PATCH 2/5] remove const_cast --- paddle/fluid/inference/tensorrt/convert/conv3d_op.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc index e3dc268ff9085..f714fb8942b8f 100644 --- a/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc @@ -99,10 +99,8 @@ void ConvertConv3d(TensorRTEngine* engine, const framework::proto::OpDesc& op, static_cast(bias_data), bias_size}; // In conv3d_transpose output channels = filter_dims[1] * groups auto* layer = (op_desc.Type() == "conv3d_transpose") - ? fadd_layer(const_cast(X), - n_input * groups, nv_ksize, weight, bias) - : fadd_layer(const_cast(X), n_output, - nv_ksize, weight, bias); + ? fadd_layer(X, n_input * groups, nv_ksize, weight, bias) + : fadd_layer(X, n_output, nv_ksize, weight, bias); PADDLE_ENFORCE_NOT_NULL( layer, platform::errors::Fatal("TensorRT create conv3d/conv3d_transpose" From 3efa3a49bbd4939d2db36fd418ab2600d7c75e08 Mon Sep 17 00:00:00 2001 From: wenbin Date: Tue, 7 Sep 2021 02:36:23 +0000 Subject: [PATCH 3/5] modify ut --- .../fluid/tests/unittests/ir/inference/CMakeLists.txt | 1 + .../tests/unittests/ir/inference/test_trt_conv3d_op.py | 8 ++++---- .../ir/inference/test_trt_conv3d_transpose_op.py | 8 ++++---- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt index 2c24e6862c222..16d26a93f9a21 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt @@ -56,4 +56,5 @@ set_tests_properties(test_trt_fc_fuse_quant_dequant_pass PROPERTIES TIMEOUT 100) set_tests_properties(test_trt_conv_quant_dequant_pass PROPERTIES TIMEOUT 100) set_tests_properties(test_trt_matmul_quant_dequant PROPERTIES TIMEOUT 100) set_tests_properties(test_trt_conv3d_op PROPERTIES TIMEOUT 60) +set_tests_properties(test_trt_conv3d_transpose_op PROPERTIES TIMEOUT 60) endif() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_op.py index 43faaa67e9bf0..13aca246a65dd 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_op.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_op.py @@ -29,7 +29,7 @@ def setUp(self): self.set_params() with fluid.program_guard(self.main_program, self.startup_program): data = fluid.data( - name="data", shape=[-1, 3, 12, 32, 32], dtype="float32") + name="data", shape=[-1, 3, 6, 32, 32], dtype="float32") conv_out = fluid.layers.conv3d( input=data, num_filters=self.conv_num_filters, @@ -41,7 +41,7 @@ def setUp(self): stride=self.stride, act=None) self.feeds = { - "data": np.random.random([1, 3, 12, 32, 32]).astype("float32"), + "data": np.random.random([1, 3, 6, 32, 32]).astype("float32"), } self.enable_trt = True self.trt_parameters = TensorRTSubgraphPassConv3dTest.TensorRTParam( @@ -120,7 +120,7 @@ def setUp(self): stride=self.stride, act=None) self.feeds = { - "data": np.random.random([1, 6, 64, 64, 8]).astype("float32"), + "data": np.random.random([1, 6, 32, 32, 8]).astype("float32"), } self.enable_trt = True self.trt_parameters = DynamicShapeTensorRTSubgraphPassConv3dTest.TensorRTParam( @@ -128,7 +128,7 @@ def setUp(self): self.dynamic_shape_params = DynamicShapeTensorRTSubgraphPassConv3dTest.DynamicShapeParam( { "data": [1, 6, 8, 8, 8], - }, {"data": [32, 6, 64, 64, 8], }, {"data": [16, 6, 16, 16, 8], + }, {"data": [32, 6, 32, 32, 8], }, {"data": [16, 6, 16, 16, 8], }, False) self.fetch_list = [conv_out] diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_transpose_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_transpose_op.py index d2304e09c7177..7e0ffbad9f498 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_transpose_op.py @@ -28,7 +28,7 @@ def setUp(self): self.set_params() with fluid.program_guard(self.main_program, self.startup_program): data = fluid.data( - name="data", shape=[-1, 4, 4, 64, 64], dtype="float32") + name="data", shape=[-1, 4, 4, 32, 32], dtype="float32") conv_out = fluid.layers.conv3d_transpose( input=data, num_filters=self.conv_num_filters, @@ -40,7 +40,7 @@ def setUp(self): stride=1, act=None) self.feeds = { - "data": np.random.random([1, 4, 4, 64, 64]).astype("float32"), + "data": np.random.random([1, 4, 4, 32, 32]).astype("float32"), } self.enable_trt = True self.trt_parameters = TensorRTSubgraphPassConv3dTransposeTest.TensorRTParam( @@ -99,7 +99,7 @@ def setUp(self): stride=self.stride, act=None) self.feeds = { - "data": np.random.random([1, 6, 64, 64, 8]).astype("float32"), + "data": np.random.random([1, 6, 32, 32, 8]).astype("float32"), } self.enable_trt = True self.trt_parameters = DynamicShapeTensorRTSubgraphPassConv3dTransposeTest.TensorRTParam( @@ -107,7 +107,7 @@ def setUp(self): self.dynamic_shape_params = DynamicShapeTensorRTSubgraphPassConv3dTransposeTest.DynamicShapeParam( { "data": [1, 6, 8, 8, 8], - }, {"data": [32, 6, 64, 64, 8], }, {"data": [16, 6, 16, 16, 8], + }, {"data": [32, 6, 32, 32, 8], }, {"data": [16, 6, 16, 16, 8], }, False) self.fetch_list = [conv_out] From 98c6cc5142a293c8aeb8c87bb314d9de1e7f10a6 Mon Sep 17 00:00:00 2001 From: wenbin Date: Tue, 7 Sep 2021 08:11:17 +0000 Subject: [PATCH 4/5] disable dynamic shape for trt6.0 --- paddle/fluid/inference/tensorrt/op_teller.cc | 12 +++++++++--- .../unittests/ir/inference/test_trt_conv3d_op.py | 10 ++++++++-- .../ir/inference/test_trt_conv3d_transpose_op.py | 10 ++++++++-- 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 47f290d411776..9bdec858771b0 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -759,9 +759,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, #endif if (op_type == "conv3d" || op_type == "conv3d_transpose") { - std::vector paddings = - BOOST_GET_CONST(std::vector, desc.GetAttr("paddings")); - if (desc.HasAttr("padding_algorithm")) { std::string padding_algorithm = BOOST_GET_CONST(std::string, desc.GetAttr("padding_algorithm")); @@ -773,6 +770,15 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, } } +#if !IS_TRT_VERSION_GE(7000) + // looks like some issues with trt6.0 + if (with_dynamic_shape) { + return false; + } +#endif + std::vector paddings = + BOOST_GET_CONST(std::vector, desc.GetAttr("paddings")); + // conv3d and conv3d_transpose need padding check if (paddings.size() > 3) return false; diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_op.py index 13aca246a65dd..8bca7af3f0d23 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_op.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_op.py @@ -128,8 +128,14 @@ def setUp(self): self.dynamic_shape_params = DynamicShapeTensorRTSubgraphPassConv3dTest.DynamicShapeParam( { "data": [1, 6, 8, 8, 8], - }, {"data": [32, 6, 32, 32, 8], }, {"data": [16, 6, 16, 16, 8], - }, False) + "conv3d_0.tmp_0": [1, 6, 8, 8, 4], + }, { + "data": [32, 6, 32, 32, 8], + "conv3d_0.tmp_0": [32, 6, 32, 32, 8], + }, { + "data": [16, 6, 16, 16, 8], + "conv3d_0.tmp_0": [16, 6, 16, 16, 8], + }, False) self.fetch_list = [conv_out] def set_params(self): diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_transpose_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_transpose_op.py index 7e0ffbad9f498..dfec7ef9b4d7d 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_transpose_op.py @@ -107,8 +107,14 @@ def setUp(self): self.dynamic_shape_params = DynamicShapeTensorRTSubgraphPassConv3dTransposeTest.DynamicShapeParam( { "data": [1, 6, 8, 8, 8], - }, {"data": [32, 6, 32, 32, 8], }, {"data": [16, 6, 16, 16, 8], - }, False) + "conv3d_transpose_0.tmp_0": [1, 6, 8, 8, 1], + }, { + "data": [32, 6, 32, 32, 8], + "conv3d_transpose_0.tmp_0": [32, 6, 64, 64, 16], + }, { + "data": [16, 6, 16, 16, 8], + "conv3d_transpose_0.tmp_0": [16, 6, 16, 16, 8], + }, False) self.fetch_list = [conv_out] def set_params(self): From f0624fc41ec16a4d328b593da24c832ff8cc424c Mon Sep 17 00:00:00 2001 From: wenbin Date: Wed, 8 Sep 2021 06:43:43 +0000 Subject: [PATCH 5/5] remove trt5 --- paddle/fluid/inference/tensorrt/convert/conv3d_op.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc index f714fb8942b8f..dae92264d2c3e 100644 --- a/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc @@ -48,7 +48,6 @@ void ConvertConv3d(TensorRTEngine* engine, const framework::proto::OpDesc& op, bool enable_int8 = op_desc.HasAttr("enable_int8"); if (enable_int8) { -#if IS_TRT_VERSION_GE(5000) float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127; auto weight_scale = @@ -56,7 +55,6 @@ void ConvertConv3d(TensorRTEngine* engine, const framework::proto::OpDesc& op, weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t, true, weight_scale); engine->SetTensorDynamicRange(X, in_scale); -#endif } else { weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t, false);