Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Paddle-TRT] IPluginExt -> IPluginV2 #33680

Merged
merged 58 commits into from
Jul 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
f40cee7
add trt LT version helper
zlsh80826 Jun 18, 2021
75d20da
upgrade PluginTensorRT to IPluginV2Ext
zlsh80826 Jun 20, 2021
5d63c78
trt plugin factory is not usable in IPluginV2
zlsh80826 Jun 20, 2021
f7e94a0
upgrade add plugin api to use IPluginV2
zlsh80826 Jun 20, 2021
ac6bd52
remove IPlugin register and adapt getSerializeSize(), serialize()
zlsh80826 Jun 20, 2021
b311c3f
adapt IPluginV2Layer
zlsh80826 Jun 20, 2021
02b4120
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Jun 22, 2021
70e8a1e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Jun 24, 2021
21c9c35
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Jun 24, 2021
b290aa9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Jun 24, 2021
99e1461
downgrade to IPluginV2
zlsh80826 Jun 24, 2021
e7e8d92
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Jun 24, 2021
686b368
implement elementwise clone
zlsh80826 Jun 24, 2021
fffb0b5
add gelu plugin creator and fix gelu serialization bug
zlsh80826 Jun 25, 2021
b55ac3a
add swish plugin creator and fix swish serialization bug
zlsh80826 Jun 25, 2021
a852793
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Jun 25, 2021
c5b08bb
format
zlsh80826 Jun 25, 2021
326e093
fix typo
zlsh80826 Jun 25, 2021
d434252
add elementwise plugin creator and fix serialization
zlsh80826 Jun 25, 2021
3d36ece
add base creator class
zlsh80826 Jun 25, 2021
2020e58
add gelu plugin creator
zlsh80826 Jun 25, 2021
84e3675
add hard swish creator and fix serialization
zlsh80826 Jun 25, 2021
20c91fa
add instance norm creator and fix serialization
zlsh80826 Jun 25, 2021
8fcd8fd
add layer norm creator and fix serialization
zlsh80826 Jun 25, 2021
571ca99
add pool creator and fix serialization
zlsh80826 Jun 25, 2021
f60ec9c
add prelu creator and fix serialization
zlsh80826 Jun 25, 2021
4c1ab09
add slice creator and fix serialization
zlsh80826 Jun 25, 2021
c7053fa
add swish creator and fix serialization
zlsh80826 Jun 25, 2021
b0df6b1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Jun 25, 2021
3b2e493
add instance norm op unittest
zlsh80826 Jun 25, 2021
3c0a212
remove redundant api
zlsh80826 Jun 25, 2021
3d59b63
fix wrong graph size to enable trt
zlsh80826 Jun 25, 2021
045e906
instance norm function move to cc
zlsh80826 Jun 26, 2021
3ad1eb6
add trt elementwise ut to trigger coverage
zlsh80826 Jun 26, 2021
92849b9
remove opt cache to hit serialization coverage
zlsh80826 Jun 26, 2021
3723bbb
remove opt cache to hit serialization coverage
zlsh80826 Jun 26, 2021
5a6b329
remove unused code
zlsh80826 Jun 26, 2021
b57de25
remove unused inputs_
zlsh80826 Jun 26, 2021
4a03a9c
add dbg info
zlsh80826 Jun 26, 2021
1d5c960
remove dbg info
zlsh80826 Jun 26, 2021
d29a365
add instance norm serialization
zlsh80826 Jun 26, 2021
6ac45e2
roll back
zlsh80826 Jun 26, 2021
ce98f21
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Jun 28, 2021
4397920
remove comment code
zlsh80826 Jun 28, 2021
faa9bf6
remove trt plugin registry
zlsh80826 Jun 28, 2021
8efb2c5
fix prelu dynamic serialization
zlsh80826 Jun 28, 2021
55fb335
add prelu ut and reduce the input size to reduce memory usage
zlsh80826 Jun 28, 2021
a340884
fix pool dynamic plugin serialization and add ut
zlsh80826 Jun 28, 2021
6097f42
refine pool ut with subtest
zlsh80826 Jun 28, 2021
eafd8c7
add env for avoiding oom
zlsh80826 Jun 28, 2021
d9a6505
reduce test input size & increase pool op ut to 45s
zlsh80826 Jun 28, 2021
80cef1e
add the contributor
zlsh80826 Jun 29, 2021
1c0f1f7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Jun 29, 2021
8a10d21
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Jul 8, 2021
33c79db
Merge branch 'trt-IPluginV2Ext' of github.com:zlsh80826/Paddle into t…
zlsh80826 Jul 8, 2021
ca6b8b3
remove copyright (will add in contributor)
zlsh80826 Jul 8, 2021
df49070
remove copyright (will add in contributor)
zlsh80826 Jul 9, 2021
95f4d3d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Jul 9, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions paddle/fluid/inference/tensorrt/convert/elementwise_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,10 @@ class ElementwiseTensorOpConverter : public OpConverter {
} else {
plugin::ElementWisePlugin* plugin =
new plugin::ElementWisePlugin(op_type_, dims_x, dims_y, axis);
plugin->AddInput(X);
plugin->AddInput(Y);
nvinfer1::IPluginLayer* plugin_layer = engine_->AddPlugin(
plugin->GetInputs().data(), 2,

std::vector<nvinfer1::ITensor*> inputs{X, Y};
auto* plugin_layer = engine_->AddPlugin(
inputs.data(), inputs.size(),
reinterpret_cast<plugin::PluginTensorRT*>(plugin));

layer = plugin_layer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class InstanceNormOpConverter : public OpConverter {
plugin::InstanceNormPlugin* plugin =
new plugin::InstanceNormPlugin(eps, scale_v, bias_v);
plugin->getPluginType();
nvinfer1::IPluginLayer* layer = engine_->AddPlugin(&input, 1, plugin);
auto* layer = engine_->AddPlugin(&input, 1, plugin);

auto output_name = op_desc.Output("Y")[0];
RreplenishLayerAndOutput(layer, "instance_norm", {output_name}, test_mode);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ class ShuffleChannelOpConverter : public OpConverter {
reshape_layer->setReshapeDimensions(reshape_dim2);

auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(reshape_layer, "concat", {output_name}, test_mode);
RreplenishLayerAndOutput(reshape_layer, "shuffle_channel", {output_name},
test_mode);
}
};

Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/inference/tensorrt/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,11 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name,

int TensorRTEngine::GetRuntimeBatch() { return runtime_batch_; }

nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
nvinfer1::IPluginV2Layer *TensorRTEngine::AddPlugin(
nvinfer1::ITensor *const *inputs, int num_inputs,
plugin::PluginTensorRT *plugin) {
owned_plugin_.emplace_back(plugin);
return network()->addPluginExt(inputs, num_inputs, *plugin);
return network()->addPluginV2(inputs, num_inputs, *plugin);
}

nvinfer1::IPluginV2Layer *TensorRTEngine::AddPluginV2Ext(
Expand Down
20 changes: 4 additions & 16 deletions paddle/fluid/inference/tensorrt/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ limitations under the License. */
#include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#include "paddle/fluid/inference/utils/singleton.h"

Expand Down Expand Up @@ -276,19 +275,8 @@ class TensorRTEngine {
}
}

if (with_dynamic_shape_) {
infer_engine_.reset(runtime->deserializeCudaEngine(
engine_serialized_data.c_str(), engine_serialized_data.size()));
} else {
#if IS_TRT_VERSION_LT(8000)
infer_engine_.reset(runtime->deserializeCudaEngine(
engine_serialized_data.c_str(), engine_serialized_data.size(),
&inference::Singleton<plugin::PluginFactoryTensorRT>::Global()));
#else
infer_engine_.reset(runtime->deserializeCudaEngine(
engine_serialized_data.c_str(), engine_serialized_data.size()));
#endif
}
infer_engine_.reset(runtime->deserializeCudaEngine(
engine_serialized_data.c_str(), engine_serialized_data.size()));

PADDLE_ENFORCE_NOT_NULL(
infer_engine_,
Expand All @@ -311,8 +299,8 @@ class TensorRTEngine {

int GetDeviceId() { return device_id_; }

nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs,
int num_inputs, plugin::PluginTensorRT*);
nvinfer1::IPluginV2Layer* AddPlugin(nvinfer1::ITensor* const* inputs,
int num_inputs, plugin::PluginTensorRT*);

nvinfer1::IPluginV2Layer* AddPluginV2Ext(nvinfer1::ITensor* const* inputs,
int num_inputs,
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
nv_library(tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
prelu_op_plugin.cu gelu_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
#include <cassert>

#include "paddle/fluid/inference/tensorrt/plugin/anchor_generator_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"

#include "paddle/fluid/operators/detection/anchor_generator_op.h"

namespace paddle {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,12 @@ limitations under the License. */

#include <glog/logging.h>
#include "paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

ElementWisePlugin *CreateElementWisePluginDeserialize(const void *buffer,
size_t length) {
return new ElementWisePlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("elementwise_plugin", CreateElementWisePluginDeserialize);

namespace details {
template <typename T>
struct Add {
Expand Down
47 changes: 35 additions & 12 deletions paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,16 @@ class ElementWisePlugin : public PluginTensorRT {
const char* elementwise_type;
DeserializeValue(&serial_data, &serial_length, &elementwise_type);
type_ = std::string(elementwise_type);
DeserializeValue(&serial_data, &serial_length, &axis_);
DeserializeValue(&serial_data, &serial_length, &dims_x_);
DeserializeValue(&serial_data, &serial_length, &dims_y_);
DeserializeValue(&serial_data, &serial_length, &axis_);
DeserializeValue(&serial_data, &serial_length, &prev_size_);
DeserializeValue(&serial_data, &serial_length, &midd_size_);
DeserializeValue(&serial_data, &serial_length, &post_size_);
}

ElementWisePlugin* clone() const override {
// return new ElementWisePlugin(dims_x_, dims_y_, axis_);
return nullptr;
return new ElementWisePlugin(type_, dims_x_, dims_y_, axis_);
}

const char* getPluginType() const override { return "elementwise_plugin"; }
Expand All @@ -65,22 +67,25 @@ class ElementWisePlugin : public PluginTensorRT {
#endif
void* workspace, cudaStream_t stream);

protected:
size_t getSerializationSize() override {
return SerializedSize(getPluginType()) + SerializedSize(axis_) +
size_t getSerializationSize() const override {
return getBaseSerializationSize() + SerializedSize(type_.c_str()) +
SerializedSize(dims_x_) + SerializedSize(dims_y_) +
getBaseSerializationSize();
SerializedSize(axis_) + SerializedSize(prev_size_) +
SerializedSize(midd_size_) + SerializedSize(post_size_);
}

void serialize(void* buffer) override {
SerializeValue(&buffer, getPluginType());
void serialize(void* buffer) const override {
serializeBase(buffer);
SerializeValue(&buffer, type_.c_str());
SerializeValue(&buffer, axis_);
SerializeValue(&buffer, dims_x_);
SerializeValue(&buffer, dims_y_);
SerializeValue(&buffer, axis_);
SerializeValue(&buffer, prev_size_);
SerializeValue(&buffer, midd_size_);
SerializeValue(&buffer, post_size_);
}

protected:
std::string type_;
nvinfer1::Dims dims_x_;
nvinfer1::Dims dims_y_;
Expand All @@ -90,6 +95,20 @@ class ElementWisePlugin : public PluginTensorRT {
int post_size_;
};

// Plugin creator for ElementWisePlugin (static-shape path). Inherits the
// common creator boilerplate from TensorRTPluginCreator and supplies only
// the name/version identity plus deserialization.
class ElementWisePluginCreator : public TensorRTPluginCreator {
public:
// Must match ElementWisePlugin::getPluginType() so TensorRT can pair a
// serialized engine entry with this creator at deserialization time.
const char* getPluginName() const override { return "elementwise_plugin"; }

const char* getPluginVersion() const override { return "1"; }

// Reconstructs a plugin instance from the byte stream produced by
// ElementWisePlugin::serialize(). Ownership of the returned object passes
// to TensorRT. `name` is unused here — identity is fixed by this creator.
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
return new ElementWisePlugin(serial_data, serial_length);
}
};
// Registers the creator with TensorRT's global plugin registry, replacing
// the removed REGISTER_TRT_PLUGIN (IPluginExt-era) factory mechanism.
REGISTER_TRT_PLUGIN_V2(ElementWisePluginCreator);

#if IS_TRT_VERSION_GE(6000)
class ElementwisePluginDynamic : public DynamicPluginTensorRT {
public:
Expand All @@ -105,7 +124,9 @@ class ElementwisePluginDynamic : public DynamicPluginTensorRT {
return new ElementwisePluginDynamic(type_, axis_);
}

const char* getPluginType() const override { return "elementwise_plugin"; }
const char* getPluginType() const override {
return "elementwise_plugin_dynamic";
}
int getNbOutputs() const override { return 1; }
int initialize() override;

Expand Down Expand Up @@ -150,7 +171,9 @@ class ElementwisePluginDynamic : public DynamicPluginTensorRT {
class ElementwisePluginDynamicCreator : public nvinfer1::IPluginCreator {
public:
ElementwisePluginDynamicCreator() {}
const char* getPluginName() const override { return "elementwise_plugin"; }
const char* getPluginName() const override {
return "elementwise_plugin_dynamic";
}

const char* getPluginVersion() const override { return "1"; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"

namespace paddle {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

#include "NvInferRuntimeCommon.h"
#include "paddle/fluid/inference/tensorrt/plugin/gather_nd_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
Expand Down
7 changes: 0 additions & 7 deletions paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
Expand All @@ -31,12 +30,6 @@ static const float kAT = 0.5;
static const float kBT = 0.7978845608028654; // sqrt(2.0/M_PI)
static const float kCT = 0.035677408136300125; // 0.044715 * sqrt(2.0/M_PI)

GeluPlugin* CreateGeluPluginDeserialize(const void* buffer, size_t length) {
return new GeluPlugin(buffer, length);
}

REGISTER_TRT_PLUGIN("gelu_plugin", CreateGeluPluginDeserialize);

bool GeluPlugin::supportsFormat(nvinfer1::DataType type,
nvinfer1::PluginFormat format) const {
if (with_fp16_) {
Expand Down
53 changes: 19 additions & 34 deletions paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,28 @@ class GeluPlugin : public PluginTensorRT {
#endif
void* workspace, cudaStream_t stream) override;

protected:
size_t getSerializationSize() override {
return getBaseSerializationSize() + SerializedSize(getPluginType());
size_t getSerializationSize() const override {
return getBaseSerializationSize();
}

// TRT will call this func to serialize the configuration of TRT
// It should not be called by users.
void serialize(void* buffer) override {
SerializeValue(&buffer, getPluginType());
serializeBase(buffer);
void serialize(void* buffer) const override { serializeBase(buffer); }
};

// Plugin creator for GeluPlugin (static-shape path). The shared
// TensorRTPluginCreator base provides field-collection/namespace handling;
// this class only defines the plugin identity and deserialization.
class GeluPluginCreator : public TensorRTPluginCreator {
public:
// Must match GeluPlugin::getPluginType() so the serialized engine can be
// matched back to this creator.
const char* getPluginName() const override { return "gelu_plugin"; }

const char* getPluginVersion() const override { return "1"; }

// Rebuilds a GeluPlugin from the buffer written by GeluPlugin::serialize().
// TensorRT takes ownership of the returned pointer; `name` is unused.
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
return new GeluPlugin(serial_data, serial_length);
}
};
// Registers this creator in TensorRT's plugin registry (IPluginV2 path),
// superseding the removed REGISTER_TRT_PLUGIN factory registration.
REGISTER_TRT_PLUGIN_V2(GeluPluginCreator);

#if IS_TRT_VERSION_GE(6000)
class GeluPluginDynamic : public DynamicPluginTensorRT {
Expand All @@ -77,7 +87,7 @@ class GeluPluginDynamic : public DynamicPluginTensorRT {
return new GeluPluginDynamic(with_fp16_);
}

const char* getPluginType() const override { return "gelu_plugin"; }
const char* getPluginType() const override { return "gelu_plugin_dynamic"; }
int getNbOutputs() const override { return 1; }
int initialize() override { return 0; }

Expand Down Expand Up @@ -119,44 +129,19 @@ class GeluPluginDynamic : public DynamicPluginTensorRT {
void destroy() override { delete this; }
};

class GeluPluginDynamicCreator : public nvinfer1::IPluginCreator {
class GeluPluginDynamicCreator : public TensorRTPluginCreator {
public:
GeluPluginDynamicCreator() {}
const char* getPluginName() const override { return "gelu_plugin"; }
const char* getPluginName() const override { return "gelu_plugin_dynamic"; }

const char* getPluginVersion() const override { return "1"; }

const nvinfer1::PluginFieldCollection* getFieldNames() override {
return &field_collection_;
}

nvinfer1::IPluginV2* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override {
return nullptr;
}

nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
auto plugin = new GeluPluginDynamic(serial_data, serial_length);
return plugin;
}

void setPluginNamespace(const char* lib_namespace) override {
plugin_namespace_ = lib_namespace;
}

const char* getPluginNamespace() const override {
return plugin_namespace_.c_str();
}

private:
std::string plugin_namespace_;
std::string plugin_name_;
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
std::vector<nvinfer1::PluginField> plugin_attributes_;
};

REGISTER_TRT_PLUGIN_V2(GeluPluginDynamicCreator);
#endif

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,12 @@
#include <cassert>
#include <cstring>
#include "paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

HardSwishPlugin* CreateHardSwishPluginDeserialize(const void* buffer,
size_t length) {
return new HardSwishPlugin(buffer, length);
}

REGISTER_TRT_PLUGIN("hard_swish_plugin", CreateHardSwishPluginDeserialize);

nvinfer1::Dims HardSwishPlugin::getOutputDimensions(
int index, const nvinfer1::Dims* in_dims, int nb_inputs) {
assert(nb_inputs == 1);
Expand Down
Loading