diff --git a/cmake/paddle2onnx.cmake b/cmake/paddle2onnx.cmake old mode 100644 new mode 100755 index 3fc84c77fd..de52b6abca --- a/cmake/paddle2onnx.cmake +++ b/cmake/paddle2onnx.cmake @@ -43,7 +43,7 @@ else() endif(WIN32) set(PADDLE2ONNX_URL_BASE "https://bj.bcebos.com/fastdeploy/third_libs/") -set(PADDLE2ONNX_VERSION "1.0.1rc") +set(PADDLE2ONNX_VERSION "1.0.1") if(WIN32) set(PADDLE2ONNX_FILE "paddle2onnx-win-x64-${PADDLE2ONNX_VERSION}.zip") if(NOT CMAKE_CL_64) diff --git a/fastdeploy/backends/paddle/paddle_backend.cc b/fastdeploy/backends/paddle/paddle_backend.cc index 831aae0a81..b1d9282e1b 100644 --- a/fastdeploy/backends/paddle/paddle_backend.cc +++ b/fastdeploy/backends/paddle/paddle_backend.cc @@ -16,13 +16,23 @@ namespace fastdeploy { -void PaddleBackend::BuildOption(const PaddleBackendOption& option) { +void PaddleBackend::BuildOption(const PaddleBackendOption& option, + const std::string& model_file) { if (option.use_gpu) { config_.EnableUseGpu(option.gpu_mem_init_size, option.gpu_id); } else { config_.DisableGpu(); if (option.enable_mkldnn) { config_.EnableMKLDNN(); + std::string contents; + if (!ReadBinaryFromFile(model_file, &contents)) { + return; + } + auto reader = + paddle2onnx::PaddleReader(contents.c_str(), contents.size()); + if (reader.is_quantize_model) { + config_.EnableMkldnnInt8(); + } config_.SetMkldnnCacheCapacity(option.mkldnn_cache_size); } } @@ -52,7 +62,7 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file, return false; } config_.SetModel(model_file, params_file); - BuildOption(option); + BuildOption(option, model_file); predictor_ = paddle_infer::CreatePredictor(config_); std::vector input_names = predictor_->GetInputNames(); std::vector output_names = predictor_->GetOutputNames(); diff --git a/fastdeploy/backends/paddle/paddle_backend.h b/fastdeploy/backends/paddle/paddle_backend.h old mode 100644 new mode 100755 index 43f5a4a174..b14a0e27c0 --- a/fastdeploy/backends/paddle/paddle_backend.h +++ b/fastdeploy/backends/paddle/paddle_backend.h @@ -20,6 +20,7 @@ #include #include "fastdeploy/backends/backend.h" +#include "paddle2onnx/converter.h" #include "paddle_inference_api.h" // NOLINT namespace fastdeploy { @@ -61,7 +62,8 @@ class PaddleBackend : public BaseBackend { public: PaddleBackend() {} virtual ~PaddleBackend() = default; - void BuildOption(const PaddleBackendOption& option); + void BuildOption(const PaddleBackendOption& option, + const std::string& model_file); bool InitFromPaddle( const std::string& model_file, const std::string& params_file, diff --git a/fastdeploy/backends/tensorrt/trt_backend.cc b/fastdeploy/backends/tensorrt/trt_backend.cc index bec41d761a..38171a596a 100644 --- a/fastdeploy/backends/tensorrt/trt_backend.cc +++ b/fastdeploy/backends/tensorrt/trt_backend.cc @@ -131,10 +131,13 @@ bool TrtBackend::InitFromPaddle(const std::string& model_file, } char* model_content_ptr; int model_content_size = 0; + char* calibration_cache_ptr; + int calibration_cache_size = 0; if (!paddle2onnx::Export(model_file.c_str(), params_file.c_str(), &model_content_ptr, &model_content_size, 11, true, verbose, true, true, true, custom_ops.data(), - custom_ops.size())) { + custom_ops.size(), "tensorrt", + &calibration_cache_ptr, &calibration_cache_size)) { FDERROR << "Error occured while export PaddlePaddle to ONNX format." << std::endl; return false; @@ -151,6 +154,13 @@ bool TrtBackend::InitFromPaddle(const std::string& model_file, delete[] model_content_ptr; std::string onnx_model_proto(new_model, new_model + new_model_size); delete[] new_model; + if (calibration_cache_size) { + std::string calibration_str( + calibration_cache_ptr, + calibration_cache_ptr + calibration_cache_size); + calibration_str_ = calibration_str; + delete[] calibration_cache_ptr; + } return InitFromOnnx(onnx_model_proto, option, true); } @@ -158,6 +168,12 @@ bool TrtBackend::InitFromPaddle(const std::string& model_file, model_content_ptr + model_content_size); delete[] model_content_ptr; model_content_ptr = nullptr; + if (calibration_cache_size) { + std::string calibration_str(calibration_cache_ptr, + calibration_cache_ptr + calibration_cache_size); + calibration_str_ = calibration_str; + delete[] calibration_cache_ptr; + } return InitFromOnnx(onnx_model_proto, option, true); #else FDERROR << "Didn't compile with PaddlePaddle frontend, you can try to " @@ -409,6 +425,7 @@ bool TrtBackend::BuildTrtEngine() { "will use FP32 instead." << std::endl; } else { + FDINFO << "[TrtBackend] Use FP16 to inference." << std::endl; config->setFlag(nvinfer1::BuilderFlag::kFP16); } } @@ -459,6 +476,20 @@ bool TrtBackend::BuildTrtEngine() { } config->addOptimizationProfile(profile); + if (calibration_str_.size()) { + if (!builder_->platformHasFastInt8()) { + FDWARNING << "Detected INT8 is not supported in the current GPU, " + "will use FP32 instead." + << std::endl; + } else { + FDINFO << "[TrtBackend] Use INT8 to inference." << std::endl; + config->setFlag(nvinfer1::BuilderFlag::kINT8); + Int8EntropyCalibrator2* calibrator = + new Int8EntropyCalibrator2(calibration_str_); + config->setInt8Calibrator(calibrator); + } + } + FDUniquePtr plan{ builder_->buildSerializedNetwork(*network_, *config)}; if (!plan) { diff --git a/fastdeploy/backends/tensorrt/trt_backend.h b/fastdeploy/backends/tensorrt/trt_backend.h old mode 100644 new mode 100755 index 82b43ab46c..ad3ace6a43 --- a/fastdeploy/backends/tensorrt/trt_backend.h +++ b/fastdeploy/backends/tensorrt/trt_backend.h @@ -26,6 +26,32 @@ #include "fastdeploy/backends/backend.h" #include "fastdeploy/backends/tensorrt/utils.h" +class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 { + public: + explicit Int8EntropyCalibrator2(const std::string& calibration_cache) + : calibration_cache_(calibration_cache) {} + + int getBatchSize() const noexcept override { return 0; } + + bool getBatch(void* bindings[], const char* names[], + int nbBindings) noexcept override { + return false; + } + + const void* readCalibrationCache(size_t& length) noexcept override { + length = calibration_cache_.size(); + return length ? calibration_cache_.data() : nullptr; + } + + void writeCalibrationCache(const void* cache, + size_t length) noexcept override { + std::cout << "NOT IMPLEMENT." << std::endl; + } + + private: + const std::string calibration_cache_; +}; + namespace fastdeploy { struct TrtValueInfo { @@ -95,6 +121,8 @@ class TrtBackend : public BaseBackend { std::map inputs_buffer_; std::map outputs_buffer_; + std::string calibration_str_; + // Sometimes while the number of outputs > 1 // the output order of tensorrt may not be same // with the original onnx model diff --git a/fastdeploy/pybind/runtime.cc b/fastdeploy/pybind/runtime.cc old mode 100644 new mode 100755 index 023daaf74e..685b44222a --- a/fastdeploy/pybind/runtime.cc +++ b/fastdeploy/pybind/runtime.cc @@ -25,6 +25,7 @@ void BindRuntime(pybind11::module& m) { .def("set_cpu_thread_num", &RuntimeOption::SetCpuThreadNum) .def("use_paddle_backend", &RuntimeOption::UsePaddleBackend) .def("use_ort_backend", &RuntimeOption::UseOrtBackend) + .def("set_ort_graph_opt_level", &RuntimeOption::SetOrtGraphOptLevel) .def("use_trt_backend", &RuntimeOption::UseTrtBackend) .def("use_openvino_backend", &RuntimeOption::UseOpenVINOBackend) .def("use_lite_backend", &RuntimeOption::UseLiteBackend) diff --git a/fastdeploy/runtime.cc b/fastdeploy/runtime.cc old mode 100644 new mode 100755 index 2bbe643ae5..44cd5e9cba --- a/fastdeploy/runtime.cc +++ b/fastdeploy/runtime.cc @@ -198,6 +198,13 @@ void RuntimeOption::SetCpuThreadNum(int thread_num) { cpu_thread_num = thread_num; } +void RuntimeOption::SetOrtGraphOptLevel(int level) { + std::vector supported_level{-1, 0, 1, 2}; + auto valid_level = std::find(supported_level.begin(), supported_level.end(), level) != supported_level.end(); + FDASSERT(valid_level, "The level must be -1, 0, 1, 2."); + ort_graph_opt_level = level; +} + // use paddle inference backend void RuntimeOption::UsePaddleBackend() { #ifdef ENABLE_PADDLE_BACKEND diff --git a/fastdeploy/runtime.h b/fastdeploy/runtime.h old mode 100644 new mode 100755 index a4d857ac33..fc5c1b824f --- a/fastdeploy/runtime.h +++ b/fastdeploy/runtime.h @@ -22,6 +22,7 @@ #include #include +#include #include "fastdeploy/backends/backend.h" #include "fastdeploy/utils/perf.h" @@ -104,6 +105,9 @@ struct FASTDEPLOY_DECL RuntimeOption { */ void SetCpuThreadNum(int thread_num); + /// Use ORT graph opt level + void SetOrtGraphOptLevel(int level = -1); + /// Set Paddle Inference as inference backend, support CPU/GPU void UsePaddleBackend(); diff --git a/python/fastdeploy/runtime.py b/python/fastdeploy/runtime.py old mode 100644 new mode 100755 index 17cb8fe3a7..901ef953c2 --- a/python/fastdeploy/runtime.py +++ b/python/fastdeploy/runtime.py @@ -117,6 +117,9 @@ def set_cpu_thread_num(self, thread_num=-1): """ return self._option.set_cpu_thread_num(thread_num) + def set_ort_graph_opt_level(self, level=-1): + return self._option.set_ort_graph_opt_level(level) + def use_paddle_backend(self): """Use Paddle Inference backend, support inference Paddle model on CPU/Nvidia GPU. """ diff --git a/tests/eval_example/test_quantize_diff.py b/tests/eval_example/test_quantize_diff.py new file mode 100755 index 0000000000..2c9454dd34 --- /dev/null +++ b/tests/eval_example/test_quantize_diff.py @@ -0,0 +1,96 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import fastdeploy as fd +import cv2 +import os +import pickle +import numpy as np + +model_url = "https://bj.bcebos.com/fastdeploy/tests/yolov6_quant.tgz" +fd.download_and_decompress(model_url, ".") + + +def test_quant_mkldnn(): + model_path = "./yolov6_quant" + model_file = os.path.join(model_path, "model.pdmodel") + params_file = os.path.join(model_path, "model.pdiparams") + + input_file = os.path.join(model_path, "input.npy") + output_file = os.path.join(model_path, "mkldnn_output.npy") + + option = fd.RuntimeOption() + option.use_paddle_backend() + option.use_cpu() + + option.set_model_path(model_file, params_file) + runtime = fd.Runtime(option) + input_name = runtime.get_input_info(0).name + data = np.load(input_file) + outs = runtime.infer({input_name: data}) + expected = np.load(output_file) + diff = np.fabs(outs[0] - expected) + thres = 1e-05 + assert diff.max() < thres, "The diff is %f, which is bigger than %f" % ( + diff.max(), thres) + + +def test_quant_ort(): + model_path = "./yolov6_quant" + model_file = os.path.join(model_path, "model.pdmodel") + params_file = os.path.join(model_path, "model.pdiparams") + + input_file = os.path.join(model_path, "input.npy") + output_file = os.path.join(model_path, "ort_output.npy") + + option = fd.RuntimeOption() + option.use_ort_backend() + option.use_cpu() + + option.set_ort_graph_opt_level(1) + + option.set_model_path(model_file, params_file) + runtime = fd.Runtime(option) + input_name = runtime.get_input_info(0).name + data = np.load(input_file) + outs = runtime.infer({input_name: data}) + expected = np.load(output_file) + diff = np.fabs(outs[0] - expected) + thres = 1e-05 + assert diff.max() < thres, "The diff is %f, which is bigger than %f" % ( + diff.max(), thres) + + +def test_quant_trt(): + model_path = "./yolov6_quant" + model_file = os.path.join(model_path, "model.pdmodel") + params_file = os.path.join(model_path, "model.pdiparams") + + input_file = os.path.join(model_path, "input.npy") + output_file = os.path.join(model_path, "trt_output.npy") + + option = fd.RuntimeOption() + option.use_trt_backend() + option.use_gpu() + + option.set_model_path(model_file, params_file) + runtime = fd.Runtime(option) + input_name = runtime.get_input_info(0).name + data = np.load(input_file) + outs = runtime.infer({input_name: data}) + expected = np.load(output_file) + diff = np.fabs(outs[0] - expected) + thres = 1e-05 + assert diff.max() < thres, "The diff is %f, which is bigger than %f" % ( + diff.max(), thres)