Feature/engine support load model #10580

Closed
2 changes: 1 addition & 1 deletion cmake/tensorrt.cmake
@@ -9,7 +9,7 @@ find_path(TENSORRT_INCLUDE_DIR NvInfer.h
NO_DEFAULT_PATH
)

find_library(TENSORRT_LIBRARY NAMES libnvinfer.so libnvinfer.a
find_library(TENSORRT_LIBRARY NAMES libnvinfer.so libnvinfer.a libnvparsers.so libnvparsers.a
PATHS ${TENSORRT_ROOT} ${TENSORRT_ROOT}/lib
$ENV{TENSORRT_ROOT} $ENV{TENSORRT_ROOT}/lib
NO_DEFAULT_PATH
18 changes: 9 additions & 9 deletions contrib/inference/paddle_inference_api.h
@@ -1,16 +1,16 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

4 changes: 2 additions & 2 deletions paddle/fluid/inference/tensorrt/CMakeLists.txt
@@ -1,5 +1,5 @@
nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto)
nv_library(tensorrt_engine SRCS engine.cc helper.h DEPS framework_proto)
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine dynamic_loader nvparsers nvinfer)

add_subdirectory(convert)
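
The new test dependencies (dynamic_loader, nvparsers, nvinfer) are what let test_tensorrt_engine link against the ONNX parser and exercise the loading path added in engine.cc/engine.h below. A minimal sketch of such a test case, assuming a gtest setup like the existing test_engine.cc; the test name and the model path are hypothetical and not part of this PR:

#include <cuda_runtime_api.h>
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/engine.h"

namespace paddle {
namespace inference {
namespace tensorrt {

TEST(TensorRTEngineOnnx, build_from_onnx) {
  cudaStream_t stream;
  ASSERT_EQ(0, cudaStreamCreate(&stream));
  // max_batch = 1, max_workspace = 1 MB; both are illustrative values.
  TensorRTEngine engine(1, 1 << 20, &stream);
  // Parses the ONNX file, deserializes the engine and allocates GPU buffers.
  engine.BuildFromONNX("/path/to/model.onnx");  // hypothetical path
  EXPECT_GT(engine.engine()->getNbBindings(), 0);
  engine.Execute(1);  // runs on whatever data the input buffers hold
  cudaStreamDestroy(stream);
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle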
99 changes: 93 additions & 6 deletions paddle/fluid/inference/tensorrt/engine.cc
@@ -26,10 +26,31 @@ namespace inference {
namespace tensorrt {

void TensorRTEngine::Build(const DescType& paddle_model) {
PADDLE_ENFORCE(false, "not implemented");
PADDLE_THROW("not implemented");
}

void TensorRTEngine::BuildFromONNX(const std::string& model_path) {
PADDLE_ENFORCE(IsFileExists(model_path));
infer_builder_.reset(createInferBuilder(&logger_));
nvinfer1::IHostMemory* model;
OnnxToGIEModel("", model_path, max_batch_, model, &logger_);
infer_runtime_.reset(createInferRuntime(&logger_));
VLOG(4) << "build engine";
infer_engine_.reset(infer_runtime_->deserializeCudaEngine(
model->data(), model->size(), nullptr));
if (model) model->destroy();
VLOG(4) << "create context";
PADDLE_ENFORCE(infer_engine_ != nullptr);
infer_context_.reset(infer_engine_->createExecutionContext());

PADDLE_ENFORCE_GT(infer_engine_->getNbBindings(), 0,
                  "model does not have any inputs or outputs");
DetectInputsAndOutputs();
FreezeNetwork();
}

void TensorRTEngine::Execute(int batch_size) {
// TODO(Superjomn) consider making the buffers reusable instead of a temporary variable.
std::vector<void*> buffers;
for (auto& buf : buffers_) {
PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated");
@@ -61,11 +82,17 @@ void TensorRTEngine::FreezeNetwork() {
infer_builder_->setMaxBatchSize(max_batch_);
infer_builder_->setMaxWorkspaceSize(max_workspace_);

infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_));
PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!");
if (!infer_engine_) {
VLOG(4) << "build cuda engine";
infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_));
PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!");

infer_context_.reset(infer_engine_->createExecutionContext());
VLOG(4) << "create execution context";
infer_context_.reset(infer_engine_->createExecutionContext());
}

VLOG(4) << "allocate gpu buffers";
PADDLE_ENFORCE(!buffer_sizes_.empty());
// allocate GPU buffers.
buffers_.resize(buffer_sizes_.size());
for (auto& item : buffer_sizes_) {
@@ -75,6 +102,7 @@
infer_engine_->getBindingDataType(slot_offset))] *
AccumDims(infer_engine_->getBindingDimensions(slot_offset));
}
PADDLE_ENFORCE_GT(item.second, 0);
auto& buf = buffer(item.first);
CHECK(buf.buffer == nullptr); // buffer should be allocated only once.
PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, item.second));
@@ -89,14 +117,42 @@ nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name,
PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate input name %s",
name);

PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first");
PADDLE_ENFORCE(infer_network_ != nullptr, "should InitNetwork first");
auto* input = infer_network_->addInput(name.c_str(), dtype, dim);
PADDLE_ENFORCE(input, "infer network add input %s failed", name);
buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] * AccumDims(dim);
TensorRTEngine::SetITensor(name, input);
return input;
}

// nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name) {
// PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
// name);

// auto* x = TensorRTEngine::GetITensor(name);
// PADDLE_ENFORCE(x != nullptr);
// // output->setName(name.c_str());
// infer_network_->
// markInput(*x);
// // output buffers' size can only be decided latter, set zero here to mark
// this
// // and will reset latter.
// buffer_sizes_[name] = 0;
// }

nvinfer1::ITensor* TensorRTEngine::DeclareInput(int offset) {
// This is a trick to reuse some facilities of manual network building.
auto name = ibuffer_name(offset);
PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate input name %s",
name);
PADDLE_ENFORCE(infer_network_ != nullptr, "should InitNetwork first");
auto* x = infer_network_->getInput(offset);
x->setName(name.c_str());
buffer_sizes_[name] = kDataTypeSize[static_cast<int>(x->getType())] *
AccumDims(x->getDimensions());
return x;
}

void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset,
const std::string& name) {
PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
@@ -117,13 +173,35 @@ void TensorRTEngine::DeclareOutput(const std::string& name) {

auto* output = TensorRTEngine::GetITensor(name);
PADDLE_ENFORCE(output != nullptr);
output->setName(name.c_str());
// output->setName(name.c_str());
infer_network_->markOutput(*output);
// output buffers' size can only be decided later; set zero here to mark this,
// and it will be reset later.
buffer_sizes_[name] = 0;
}

void TensorRTEngine::DeclareOutput(int offset) {
// This is a trick to reuse some facilities of manual network building.
auto name = obuffer_name(offset);
PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
name);
PADDLE_ENFORCE(infer_network_ != nullptr, "should InitNetwork first");
auto* x = infer_network_->getOutput(offset);
x->setName(name.c_str());
buffer_sizes_[name] = kDataTypeSize[static_cast<int>(x->getType())] *
AccumDims(x->getDimensions());
}

void TensorRTEngine::DetectInputsAndOutputs() {
for (int i = 0; i < infer_engine_->getNbBindings(); ++i) {
const auto* name = infer_engine_->getBindingName(i);
const auto& dims = infer_engine_->getBindingDimensions(i);
buffer_sizes_[name] = AccumDims(dims);
VLOG(4) << "get ONNX model input/output: " << name
<< " dims: " << buffer_sizes_[name];
}
}

void* TensorRTEngine::GetOutputInGPU(const std::string& name) {
return buffer(name).buffer;
}
@@ -149,6 +227,15 @@ Buffer& TensorRTEngine::buffer(const std::string& name) {
return buffers_[slot_offset];
}

Buffer& TensorRTEngine::ibuffer(int offset) {
auto name = ibuffer_name(offset);
PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first.");
auto it = buffer_sizes_.find(name);
PADDLE_ENFORCE(it != buffer_sizes_.end());
auto slot_offset = infer_engine_->getBindingIndex(name.c_str());
return buffers_[slot_offset];
}

void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data,
size_t size) {
auto& buf = buffer(name);
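
DetectInputsAndOutputs above walks the engine bindings by offset. The same information is reachable from outside through the engine() accessor; the following sketch (an illustration, not code from this PR) shows how a caller might list what an ONNX-loaded engine expects. The function name DumpBindings is hypothetical.

#include <cstdint>
#include <iostream>
#include "paddle/fluid/inference/tensorrt/engine.h"

// Print every binding of an engine built with BuildFromONNX, mirroring the
// information DetectInputsAndOutputs records into buffer_sizes_.
void DumpBindings(paddle::inference::tensorrt::TensorRTEngine* engine) {
  auto* trt_engine = engine->engine();  // raw nvinfer1::ICudaEngine*
  for (int i = 0; i < trt_engine->getNbBindings(); ++i) {
    const auto dims = trt_engine->getBindingDimensions(i);
    int64_t numel = 1;
    for (int d = 0; d < dims.nbDims; ++d) numel *= dims.d[d];
    std::cout << (trt_engine->bindingIsInput(i) ? "input  " : "output ")
              << trt_engine->getBindingName(i) << " elements: " << numel
              << std::endl;
  }
}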
113 changes: 97 additions & 16 deletions paddle/fluid/inference/tensorrt/engine.h
@@ -29,24 +29,53 @@ namespace tensorrt {
/*
* TensorRT Engine.
*
* There are two alternative ways to use it, one is to build from a paddle
* protobuf model, another way is to manully construct the network.
* There are two alternative ways to use it:
* 1. manually build the network by adding layers; we call it the manual way.
* 2. load the network from an ONNX model; we call it the ONNX way.
*
* The manual way:
*
* // Init
* TensorRTEngine engine(...);
* engine.InitNetwork();
*
* // Add layers one by one
* TRT_ENGINE_ADD_LAYER
*
* engine.DeclareInput("x", ...)
* engine.DeclareOutput("y", ...)
* engine.FreezeNetwork(); // end network building
*
* // Ready to predict any number of times.
*
* // Set input data.
* cudaMemcpy(buffer(in), ...)
*
* engine.Execute();
*
* // Get output data.
* cudaMemcpy(..., buffer(out), ...)
*
* The ONNX way:
*
* TensorRTEngine engine(...);
* // Load model from ONNX.
* engine.BuildFromONNX(...);
*
* // Ready to predict any number of times.
*
* // Set input data.
* cudaMemcpy(ibuffer(0), ...)
*
* engine.Execute(batch_size);
*
* // Get output data.
* for (int i = 0; i < num_outputs; i++) cudaMemcpy(..., obuffer(i), ...)
*/
class TensorRTEngine : public EngineBase {
public:
// Weight is model parameter.
class Weight {
public:
Weight(nvinfer1::DataType dtype, void* value, int num_elem) {
w_.type = dtype;
w_.values = value;
w_.count = num_elem;
}
const nvinfer1::Weights& get() { return w_; }

private:
nvinfer1::Weights w_;
};
class Weight;

TensorRTEngine(int max_batch, int max_workspace, cudaStream_t* stream,
nvinfer1::ILogger& logger = NaiveLogger::Global())
@@ -60,6 +89,9 @@ class TensorRTEngine : public EngineBase {
// TODO(Superjomn) implement it later when graph segmentation is supported.
void Build(const DescType& paddle_model) override;

// Build the TensorRT engine with an ONNX model.
void BuildFromONNX(const std::string& model_path);

void Execute(int batch_size) override;

// Initialize the inference network, so that TensorRT layers can add to this
@@ -72,23 +104,46 @@
// environment.
void FreezeNetwork();

// Add an input and set its name, data type and dimention.
// Add an input and set its name, data type and dimension. This should be used
// in manual network building.
nvinfer1::ITensor* DeclareInput(const std::string& name,
nvinfer1::DataType dtype,
const nvinfer1::Dims& dim);

// Collect the input ITensor's information after the network is already built.
// It can be used when loading an ONNX model or another existing network.
nvinfer1::ITensor* DeclareInput(int offset);

nvinfer1::ITensor* DeclareInput(const std::string& name);

// Set the offset-th output from a layer as the network's output, and set its
// name.
void DeclareOutput(const nvinfer1::ILayer* layer, int offset,
const std::string& name);
// Set the itensor_map_[name] as the network's output, and set its name.
void DeclareOutput(const std::string& name);
// Collect the output ITensor's information after the network is already
// built. It can be used when loading an ONNX model or another existing network.
void DeclareOutput(int offset);

// Detect inputs and outputs from an existing TensorRT network loaded from ONNX
// or some other format, and prepare buffers for them.
void DetectInputsAndOutputs();

// GPU memory address for an ITensor with specific name. One can operate on
// these memory directly for acceleration, for example, output the converted
// data directly to the buffer to save data copy overhead.
// data directly to the buffer to save data copy overhead. This method can
// only be used in manual network building where the inputs and outputs are
// manually declared with a unique name.
// NOTE this should be used after calling `FreezeNetwork`.
Buffer& buffer(const std::string& name) override;

// ibuffer and obuffer return the offset-th input/output buffer of the network.
// They are used when loading directly from an existing model, because the
// inputs and outputs do not have unique names and can only be identified by offset.
Buffer& ibuffer(int offset);
Buffer& obuffer(int offset);

cudaStream_t* stream() { return stream_; }

// Fill an input from CPU memory with name and size.
@@ -110,6 +165,29 @@
nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); }
nvinfer1::INetworkDefinition* network() { return infer_network_.get(); }

class Weight {
public:
Weight(nvinfer1::DataType dtype, void* value, int num_elem) {
w_.type = dtype;
w_.values = value;
w_.count = num_elem;
}
const nvinfer1::Weights& get() { return w_; }

private:
nvinfer1::Weights w_;
};

protected:
// Get an input buffer's string id.
std::string ibuffer_name(int offset) const {
return "in-" + std::to_string(offset);
}
// Get an output buffer's string id.
std::string obuffer_name(int offset) const {
return "out-" + std::to_string(offset);
}

private:
// the max batch size
int max_batch_;
@@ -131,6 +209,9 @@
};
template <typename T>
using infer_ptr = std::unique_ptr<T, Destroyer<T>>;
// The following members are declared for the different Build methods; for
// each kind of Build, not all of these members are used.
infer_ptr<nvinfer1::IRuntime> infer_runtime_;
infer_ptr<nvinfer1::IBuilder> infer_builder_;
infer_ptr<nvinfer1::INetworkDefinition> infer_network_;
infer_ptr<nvinfer1::ICudaEngine> infer_engine_;
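
For reference, here is the "ONNX way" from the engine.h usage comment above, expanded into a compilable sketch. It is a hedged illustration rather than code from this PR: the model path, the tensor sizes and the float element type are assumptions, and it relies only on the interface declared in this diff (BuildFromONNX, ibuffer, obuffer, Execute).

#include <cuda_runtime_api.h>
#include <vector>
#include "paddle/fluid/inference/tensorrt/engine.h"

int main() {
  using paddle::inference::tensorrt::TensorRTEngine;

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  TensorRTEngine engine(/*max_batch=*/1, /*max_workspace=*/1 << 20, &stream);
  // Parse the ONNX file, deserialize the engine and allocate GPU buffers.
  engine.BuildFromONNX("/path/to/model.onnx");  // hypothetical path

  // Sizes assumed for illustration; a real caller would derive them from the
  // binding dimensions reported by engine.engine().
  std::vector<float> input(784, 0.f), output(10, 0.f);

  // ibuffer(0) resolves the first input to its GPU binding slot.
  cudaMemcpy(engine.ibuffer(0).buffer, input.data(),
             input.size() * sizeof(float), cudaMemcpyHostToDevice);

  engine.Execute(/*batch_size=*/1);

  // obuffer(0) is the output counterpart; copy the result back to the host.
  cudaMemcpy(output.data(), engine.obuffer(0).buffer,
             output.size() * sizeof(float), cudaMemcpyDeviceToHost);

  cudaStreamDestroy(stream);
  return 0;
}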