Skip to content

Commit

Permalink
feature/mul converter (#10841)
Browse files Browse the repository at this point in the history
  • Loading branch information
Superjomn committed May 29, 2018
1 parent 8f7b020 commit f5fc9c3
Show file tree
Hide file tree
Showing 9 changed files with 242 additions and 21 deletions.
11 changes: 9 additions & 2 deletions paddle/fluid/inference/analysis/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ namespace paddle {
namespace inference {
namespace analysis {

template <typename Vec>
int AccuDims(Vec &&vec, int size) {
int res = 1;
for (int i = 0; i < size; i++) {
res *= std::forward<Vec>(vec)[i];
}
return res;
}

#define SET_TYPE(type__) dic_[typeid(type__).hash_code()] = #type__;
/*
* Map typeid to representation.
Expand Down Expand Up @@ -101,7 +110,5 @@ class OrderedRegistry {
} // namespace paddle

#define PADDLE_DISALLOW_COPY_AND_ASSIGN(type__) \
\
type__(const type__ &) = delete; \
\
void operator=(const type__ &) = delete;
5 changes: 4 additions & 1 deletion paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES})
# Add TRT tests
nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine)
# This test is not stable
# See https://paddleci.ngrok.io/viewLog.html?tab=buildLog&buildTypeId=Paddle_PrCi2&buildId=36834&_focus=8828
#nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc io_converter.cc
# DEPS ${FLUID_CORE_MODULES} activation_op tensorrt_engine
# SERIAL)
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
nv_test(test_trt_mul_op SRCS test_mul_op.cc mul_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL)
16 changes: 15 additions & 1 deletion paddle/fluid/inference/tensorrt/convert/mul_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,25 @@ namespace paddle {
namespace inference {
namespace tensorrt {

/*
* MulOp, IMatrixMultiplyLayer in TRT. This Layer doesn't has weights.
*/
class MulOpConverter : public OpConverter {
public:
MulOpConverter() {}
void operator()(const framework::proto::OpDesc& op) override {
LOG(INFO) << "convert a fluid mul op to tensorrt fc layer without bias";
VLOG(4) << "convert a fluid mul op to tensorrt fc layer without bias";

framework::OpDesc op_desc(op, nullptr, nullptr);
// Declare inputs
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]);
// Both the input1 and input2 do not need transpose.
auto* layer = TRT_ENGINE_ADD_LAYER(
engine_, MatrixMultiply, *const_cast<nvinfer1::ITensor*>(input1), false,
*const_cast<nvinfer1::ITensor*>(input2), false);

engine_->DeclareOutput(layer, 0, op_desc.Output("Out")[0]);
}
};

Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,5 @@ TEST(OpConverter, ConvertRelu) {
} // namespace tensorrt
} // namespace inference
} // namespace paddle

USE_OP(activation);
47 changes: 47 additions & 0 deletions paddle/fluid/inference/tensorrt/convert/test_mul_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"

namespace paddle {
namespace inference {
namespace tensorrt {

TEST(MulOpConverter, main) {
TRTConvertValidation validator(10, 1000);
validator.DeclInputVar("mul-X", nvinfer1::Dims2(10, 6));
validator.DeclInputVar("mul-Y", nvinfer1::Dims2(6, 10));
validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(10, 10));

// Prepare Op description
framework::OpDesc desc;
desc.SetType("mul");
desc.SetInput("X", {"mul-X"});
desc.SetInput("Y", {"mul-Y"});
desc.SetOutput("Out", {"mul-Out"});

LOG(INFO) << "set OP";
validator.SetOp(*desc.Proto());
LOG(INFO) << "execute";

validator.Execute(10);
}

} // namespace tensorrt
} // namespace inference
} // namespace paddle

USE_OP(mul);
2 changes: 0 additions & 2 deletions paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ namespace tensorrt {
TEST(OpConverter, ConvertBlock) {
framework::ProgramDesc prog;
auto* block = prog.MutableBlock(0);
auto* mul_op = block->AppendOp();
mul_op->SetType("mul");
auto* conv2d_op = block->AppendOp();
conv2d_op->SetType("conv2d");

Expand Down
156 changes: 156 additions & 0 deletions paddle/fluid/inference/tensorrt/convert/ut_helper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

/*
* This file implements a UT framework to make the validation of transforming
* Fluid Op to TRT Layer.
*/

#pragma once

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"

namespace paddle {
namespace inference {
namespace tensorrt {

/*
* Get a random float value between [low, high]
*/
float random(float low, float high) {
static std::random_device rd;
static std::mt19937 mt(rd());
std::uniform_real_distribution<double> dist(1.0, 10.0);
return dist(mt);
}

void RandomizeTensor(framework::LoDTensor* tensor, const platform::Place& place,
const platform::DeviceContext& ctx) {
auto dims = tensor->dims();
size_t num_elements = analysis::AccuDims(dims, dims.size());
PADDLE_ENFORCE_GT(num_elements, 0);
auto* data = tensor->mutable_data<float>(place);
for (size_t i = 0; i < num_elements; i++) {
*(data + i) = random(0., 1.);
}
}

/*
* Help to validate the correctness between Fluid Op and the corresponding TRT
* layer.
*/
class TRTConvertValidation {
public:
TRTConvertValidation() = delete;

TRTConvertValidation(int batch_size, int workspace_size = 1 << 10) {
// create engine.
engine_.reset(new TensorRTEngine(10, 1 << 10, &stream_));
engine_->InitNetwork();

PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0);
}

// Declare a Variable as input with random initialization.
void DeclInputVar(const std::string& name, const nvinfer1::Dims& dims) {
DeclVar(name, dims);
// Declare TRT inputs.
engine_->DeclareInput(name, nvinfer1::DataType::kFLOAT, dims);
}

void DeclOutputVar(const std::string& name, const nvinfer1::Dims& dims) {
DeclVar(name, dims);
}

void DeclVar(const std::string& name, const nvinfer1::Dims& dims) {
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);

// Init Fluid tensor.
std::vector<int> dim_vec(dims.nbDims);
for (int i = 0; i < dims.nbDims; i++) {
dim_vec[i] = dims.d[i];
}
auto* x = scope_.Var(name);
auto* x_tensor = x->GetMutable<framework::LoDTensor>();
x_tensor->Resize(framework::make_ddim(dim_vec));
RandomizeTensor(x_tensor, place, ctx);
}

void SetOp(const framework::proto::OpDesc& desc) {
op_ = framework::OpRegistry::CreateOp(desc);

OpConverter op_converter;
op_converter.ConvertOp(desc, engine_.get());

engine_->FreezeNetwork();

// Declare outputs.
op_desc_.reset(new framework::OpDesc(desc, nullptr, nullptr));

// Set Inputs.
for (const auto& input : op_desc_->InputArgumentNames()) {
auto* var = scope_.FindVar(input);
PADDLE_ENFORCE(var);
auto tensor = var->GetMutable<framework::LoDTensor>();
engine_->SetInputFromCPU(
input, static_cast<void*>(tensor->data<float>()),
sizeof(float) *
analysis::AccuDims(tensor->dims(), tensor->dims().size()));
}
}

void Execute(int batch_size) {
// Execute Fluid Op
// Execute TRT
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
engine_->Execute(batch_size);

op_->Run(scope_, place);

ASSERT_FALSE(op_desc_->OutputArgumentNames().empty());
for (const auto& output : op_desc_->OutputArgumentNames()) {
std::vector<float> fluid_out;
std::vector<float> trt_out(200);
engine_->GetOutputInCPU(output, &trt_out[0], 200 * sizeof(float));

auto* var = scope_.FindVar(output);
auto tensor = var->GetMutable<framework::LoDTensor>();
framework::TensorToVector(*tensor, ctx, &fluid_out);
// Compare two output
ASSERT_FALSE(fluid_out.empty());
for (size_t i = 0; i < fluid_out.size(); i++) {
EXPECT_LT(std::abs(fluid_out[i] - trt_out[i]), 0.001);
}
}
}

framework::Scope& scope() { return scope_; }

private:
std::unique_ptr<TensorRTEngine> engine_;
cudaStream_t stream_;
framework::Scope scope_;
std::unique_ptr<framework::OperatorBase> op_;
std::unique_ptr<framework::OpDesc> op_desc_;
};

} // namespace tensorrt
} // namespace inference
} // namespace paddle
15 changes: 9 additions & 6 deletions paddle/fluid/inference/tensorrt/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License. */
#include <cuda.h>
#include <glog/logging.h>
#include <string>
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/platform/enforce.h"

Expand Down Expand Up @@ -71,9 +72,10 @@ void TensorRTEngine::FreezeNetwork() {
for (auto& item : buffer_sizes_) {
if (item.second == 0) {
auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str());
auto dims = infer_engine_->getBindingDimensions(slot_offset);
item.second = kDataTypeSize[static_cast<int>(
infer_engine_->getBindingDataType(slot_offset))] *
AccumDims(infer_engine_->getBindingDimensions(slot_offset));
analysis::AccuDims(dims.d, dims.nbDims);
}
auto& buf = buffer(item.first);
CHECK(buf.buffer == nullptr); // buffer should be allocated only once.
Expand All @@ -85,14 +87,15 @@ void TensorRTEngine::FreezeNetwork() {

nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name,
nvinfer1::DataType dtype,
const nvinfer1::Dims& dim) {
const nvinfer1::Dims& dims) {
PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate input name %s",
name);

PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first");
auto* input = infer_network_->addInput(name.c_str(), dtype, dim);
auto* input = infer_network_->addInput(name.c_str(), dtype, dims);
PADDLE_ENFORCE(input, "infer network add input %s failed", name);
buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] * AccumDims(dim);
buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] *
analysis::AccuDims(dims.d, dims.nbDims);
TensorRTEngine::SetITensor(name, input);
return input;
}
Expand Down Expand Up @@ -162,13 +165,13 @@ void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data,
void TensorRTEngine::SetITensor(const std::string& name,
nvinfer1::ITensor* tensor) {
PADDLE_ENFORCE(tensor != nullptr);
PADDLE_ENFORCE_EQ(0, itensor_map_.count(name), "duplicate itensor name %s",
PADDLE_ENFORCE_EQ(0, itensor_map_.count(name), "duplicate ITensor name %s",
name);
itensor_map_[name] = tensor;
}

nvinfer1::ITensor* TensorRTEngine::GetITensor(const std::string& name) {
PADDLE_ENFORCE(itensor_map_.count(name), "no itensor %s", name);
PADDLE_ENFORCE(itensor_map_.count(name), "no ITensor %s", name);
return itensor_map_[name];
}

Expand Down
9 changes: 0 additions & 9 deletions paddle/fluid/inference/tensorrt/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,6 @@ namespace tensorrt {

namespace dy = paddle::platform::dynload;

static size_t AccumDims(nvinfer1::Dims dims) {
size_t num = dims.nbDims == 0 ? 0 : 1;
for (int i = 0; i < dims.nbDims; i++) {
PADDLE_ENFORCE_GT(dims.d[i], 0);
num *= dims.d[i];
}
return num;
}

// TensorRT data type to size
const int kDataTypeSize[] = {
4, // kFLOAT
Expand Down

0 comments on commit f5fc9c3

Please sign in to comment.