Skip to content

Commit

Permalink
[Feature] : Add ScatterND TensorRT Plugin (#786)
Browse files Browse the repository at this point in the history
* add scatter plugin

* fix bugs of scatternd

* add trt scatternd plugin

* format code with clang-format

* add test for scatternd

* skip test_tensorrt in CI

* remove unused variable

Co-authored-by: maningsheng <maningsheng@sensetime.com>
  • Loading branch information
grimoire and RunningLeon committed Jan 20, 2021
1 parent 8e3a801 commit 2ab544f
Show file tree
Hide file tree
Showing 7 changed files with 509 additions and 10 deletions.
2 changes: 2 additions & 0 deletions mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#include "trt_plugin.hpp"

#include "trt_roi_align.hpp"
#include "trt_scatternd.hpp"

// Register each plugin creator with TensorRT's global plugin registry so the
// runtime can look the plugins up by their reported name/version during
// engine build and deserialization.
REGISTER_TENSORRT_PLUGIN(RoIAlignPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(ONNXScatterNDDynamicCreator);

extern "C" {
bool initLibMMCVInferPlugins() { return true; }
Expand Down
206 changes: 206 additions & 0 deletions mmcv/ops/csrc/tensorrt/plugins/trt_scatternd.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
#include "trt_scatternd.hpp"

#include <assert.h>
#include <stdio.h>

#include <chrono>

#include "trt_serialize.hpp"

extern void TRTONNXScatterNDKernelLauncher_float(
const float *data, const int *indices, const float *update, const int *dims,
int nbDims, const int *indices_dims, int indice_nbDims, float *output,
cudaStream_t stream);

extern void TRTONNXScatterNDKernelLauncher_int32(
const int *data, const int *indices, const int *update, const int *dims,
int nbDims, const int *indices_dims, int indice_nbDims, int *output,
cudaStream_t stream);

namespace {
// Name/version under which this plugin registers with TensorRT.
static const char *PLUGIN_VERSION{"1"};
static const char *PLUGIN_NAME{"ScatterND"};
}  // namespace

// Static creator state. The field collection stays empty because the plugin
// takes no creation-time attributes (see the creator constructor below).
nvinfer1::PluginFieldCollection ONNXScatterNDDynamicCreator::mFC{};
std::vector<nvinfer1::PluginField>
    ONNXScatterNDDynamicCreator::mPluginAttributes;

// Build-time constructor: the plugin carries no state other than its name.
ONNXScatterNDDynamic::ONNXScatterNDDynamic(const std::string &name)
    : mLayerName(name) {}

// Deserialization constructor. There is no serialized payload to read
// (getSerializationSize() returns 0), so `data`/`length` are intentionally
// unused. NOTE(review): `name` is taken by value here but by const reference
// above — harmless, but worth unifying together with the header declaration.
ONNXScatterNDDynamic::ONNXScatterNDDynamic(const std::string name,
                                           const void *data, size_t length)
    : mLayerName(name) {}

// Duplicate this plugin instance. Only the layer name and namespace need to
// be carried over, since the plugin holds no other state.
nvinfer1::IPluginV2DynamicExt *ONNXScatterNDDynamic::clone() const {
  auto *copy = new ONNXScatterNDDynamic(mLayerName);
  copy->setPluginNamespace(getPluginNamespace());
  return copy;
}

// ScatterND writes updates into a copy of the `data` input, so the single
// output has exactly the shape of input 0.
nvinfer1::DimsExprs ONNXScatterNDDynamic::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
    nvinfer1::IExprBuilder &exprBuilder) {
  return inputs[0];
}

// Accepted combinations:
//   input 0 (data):    float32 or int32, linear layout
//   input 1 (indices): int32, linear layout
//   input 2 (updates): same type/format as data
//   output 0:          same type/format as data
bool ONNXScatterNDDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
    int nbOutputs) {
  const nvinfer1::PluginTensorDesc &desc = inOut[pos];
  const bool is_linear = desc.format == nvinfer1::TensorFormat::kLINEAR;

  if (pos >= nbInputs) {
    // Output tensors: the first must mirror the data tensor exactly.
    if (pos - nbInputs == 0) {
      return desc.type == inOut[0].type && desc.format == inOut[0].format;
    }
    return true;
  }

  switch (pos) {
    case 0:  // data
      return is_linear && (desc.type == nvinfer1::DataType::kFLOAT ||
                           desc.type == nvinfer1::DataType::kINT32);
    case 1:  // indices
      return desc.type == nvinfer1::DataType::kINT32 && is_linear;
    case 2:  // updates must match data
      return desc.type == inOut[0].type && desc.format == inOut[0].format;
    default:
      return true;
  }
}

// No build-time configuration is needed: all shape and type information is
// read from the tensor descriptors at enqueue time.
void ONNXScatterNDDynamic::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {}

// The scatter kernel works directly on the output buffer and needs no extra
// device workspace.
size_t ONNXScatterNDDynamic::getWorkspaceSize(
    const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
    const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const {
  return 0;
}

// Run ScatterND on the given stream: copies `data` to the output, then
// scatters `updates` at the positions named by `indices`. Dispatches to a
// float32 or int32 CUDA launcher based on the data tensor's dtype.
int ONNXScatterNDDynamic::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
                                  const nvinfer1::PluginTensorDesc *outputDesc,
                                  const void *const *inputs,
                                  void *const *outputs, void *workSpace,
                                  cudaStream_t stream) {
  // Tensor geometry comes straight from the runtime descriptors.
  const int *data_dims = inputDesc[0].dims.d;
  const int *index_dims = inputDesc[1].dims.d;
  const int data_rank = inputDesc[0].dims.nbDims;
  const int index_rank = inputDesc[1].dims.nbDims;

  switch (inputDesc[0].type) {
    case nvinfer1::DataType::kFLOAT:
      TRTONNXScatterNDKernelLauncher_float(
          static_cast<const float *>(inputs[0]),
          static_cast<const int *>(inputs[1]),
          static_cast<const float *>(inputs[2]), data_dims, data_rank,
          index_dims, index_rank, static_cast<float *>(outputs[0]), stream);
      break;
    case nvinfer1::DataType::kINT32:
      TRTONNXScatterNDKernelLauncher_int32(
          static_cast<const int *>(inputs[0]),
          static_cast<const int *>(inputs[1]),
          static_cast<const int *>(inputs[2]), data_dims, data_rank,
          index_dims, index_rank, static_cast<int *>(outputs[0]), stream);
      break;
    default:
      // Other dtypes are rejected by supportsFormatCombination.
      break;
  }

  return 0;
}

// The output tensor shares the dtype of the `data` input.
nvinfer1::DataType ONNXScatterNDDynamic::getOutputDataType(
    int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
  return inputTypes[0];
}

// IPluginV2 Methods
const char *ONNXScatterNDDynamic::getPluginType() const { return PLUGIN_NAME; }

const char *ONNXScatterNDDynamic::getPluginVersion() const {
  return PLUGIN_VERSION;
}

// ScatterND produces a single output tensor.
int ONNXScatterNDDynamic::getNbOutputs() const { return 1; }

// Nothing to allocate or precompute.
int ONNXScatterNDDynamic::initialize() { return 0; }

void ONNXScatterNDDynamic::terminate() {}

// The plugin is stateless, so nothing is written into the engine.
size_t ONNXScatterNDDynamic::getSerializationSize() const { return 0; }

void ONNXScatterNDDynamic::serialize(void *buffer) const {}

void ONNXScatterNDDynamic::destroy() {
  // This gets called when the network containing plugin is destroyed
  delete this;
}

void ONNXScatterNDDynamic::setPluginNamespace(const char *libNamespace) {
  mNamespace = libNamespace;
}

const char *ONNXScatterNDDynamic::getPluginNamespace() const {
  return mNamespace.c_str();
}

////////////////////// creator /////////////////////////////

// The plugin takes no creation-time attributes, so the field collection is
// registered empty.
ONNXScatterNDDynamicCreator::ONNXScatterNDDynamicCreator() {
  mPluginAttributes.clear();
  mFC.nbFields = mPluginAttributes.size();
  mFC.fields = mPluginAttributes.data();
}

const char *ONNXScatterNDDynamicCreator::getPluginName() const {
  return PLUGIN_NAME;
}

const char *ONNXScatterNDDynamicCreator::getPluginVersion() const {
  return PLUGIN_VERSION;
}

const nvinfer1::PluginFieldCollection *
ONNXScatterNDDynamicCreator::getFieldNames() {
  return &mFC;
}

// `fc` is ignored: the plugin has no attributes to read.
nvinfer1::IPluginV2 *ONNXScatterNDDynamicCreator::createPlugin(
    const char *name, const nvinfer1::PluginFieldCollection *fc) {
  ONNXScatterNDDynamic *plugin = new ONNXScatterNDDynamic(name);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

// Rebuilds a plugin instance from a serialized engine; the serialized
// payload is empty for this stateless plugin.
nvinfer1::IPluginV2 *ONNXScatterNDDynamicCreator::deserializePlugin(
    const char *name, const void *serialData, size_t serialLength) {
  auto plugin = new ONNXScatterNDDynamic(name, serialData, serialLength);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

void ONNXScatterNDDynamicCreator::setPluginNamespace(const char *libNamespace) {
  mNamespace = libNamespace;
}

const char *ONNXScatterNDDynamicCreator::getPluginNamespace() const {
  return mNamespace.c_str();
}
92 changes: 92 additions & 0 deletions mmcv/ops/csrc/tensorrt/plugins/trt_scatternd_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#include <stdio.h>

#include <vector>

#include "common_cuda_helper.hpp"
#include "trt_cuda_helper.cuh"
#include "trt_plugin_helper.hpp"

// Launch configuration: 64 threads per block (8 * sizeof(unsigned long long)).
static int const threadsPerBlock = sizeof(unsigned long long int) * 8;

using mmcv::TensorDesc;

// ScatterND inner kernel. Each of the `n` threads in the grid-stride loop
// handles one index tuple: it copies one contiguous slice of `update` into
// `output` at the element offset addressed by that tuple.
//
//   n           - number of index tuples (product of indices dims except last)
//   indices     - [n, indice_cols] int32 index tuples
//   update      - [n, copy_stride] update slices, contiguous
//   output      - destination tensor (already initialized with `data`)
//   tensor_desc - shape/stride of the output tensor (row-major, by value)
//   indice_desc - shape/stride of the indices tensor (by value)
//
// Generalized to accept negative indices, which the ONNX ScatterND spec
// defines as counting back from the end of the corresponding dimension.
template <typename T>
__global__ void onnx_scatternd_kernel(const int n, const int* indices,
                                      const T* update, T* output,
                                      TensorDesc tensor_desc,
                                      TensorDesc indice_desc) {
  const int indice_cols = indice_desc.shape[indice_desc.dim - 1];
  // Elements written per tuple: the stride of the innermost indexed dim.
  const int copy_stride = tensor_desc.stride[indice_cols - 1];
  const int* stride = &(tensor_desc.stride[0]);
  CUDA_1D_KERNEL_LOOP(index, n) {
    int output_offset = 0;
    const int* indices_current = indices + index * indice_cols;
    for (int i = 0; i < indice_cols; ++i) {
      int ind = indices_current[i];
      // Negative index wraps around per the ONNX specification.
      if (ind < 0) ind += tensor_desc.shape[i];
      output_offset += stride[i] * ind;
    }
    memcpy(output + output_offset, update + index * copy_stride,
           copy_stride * sizeof(T));
  }
}

// Host launcher for ONNX ScatterND: output = copy(data); output[indices] =
// update. All pointers are device pointers; `dims`/`indices_dims` are host
// arrays of length `nbDims`/`indice_nbDims`. All work is enqueued on
// `stream`.
template <typename T>
void TRTONNXScatterNDKernelLauncher(const T* data, const int* indices,
                                    const T* update, const int* dims,
                                    int nbDims, const int* indices_dims,
                                    int indice_nbDims, T* output,
                                    cudaStream_t stream) {
  // Describe the data tensor with row-major strides (innermost stride 1).
  TensorDesc tensor_desc;
  memset((void*)&tensor_desc, 0, sizeof(TensorDesc));
  tensor_desc.dim = nbDims;
  tensor_desc.shape[nbDims - 1] = dims[nbDims - 1];
  tensor_desc.stride[nbDims - 1] = 1;
  for (int i = nbDims - 2; i >= 0; --i) {
    tensor_desc.shape[i] = dims[i];
    tensor_desc.stride[i] = dims[i + 1] * tensor_desc.stride[i + 1];
  }
  // Total element count of the data/output tensor.
  const int data_size = tensor_desc.stride[0] * tensor_desc.shape[0];

  // Same row-major layout description for the indices tensor.
  TensorDesc indice_desc;
  memset((void*)&indice_desc, 0, sizeof(TensorDesc));
  indice_desc.dim = indice_nbDims;
  indice_desc.shape[indice_nbDims - 1] = indices_dims[indice_nbDims - 1];
  indice_desc.stride[indice_nbDims - 1] = 1;
  for (int i = indice_nbDims - 2; i >= 0; --i) {
    indice_desc.shape[i] = indices_dims[i];
    indice_desc.stride[i] = indices_dims[i + 1] * indice_desc.stride[i + 1];
  }

  // output = np.copy(data). BUGFIX: enqueue the copy on the caller's stream
  // so it is ordered before the scatter kernel below; the previous code
  // omitted the stream argument and issued the copy on the legacy default
  // stream instead.
  cudaMemcpyAsync(output, data, data_size * sizeof(T),
                  cudaMemcpyDeviceToDevice, stream);

  // One kernel thread per index tuple: product of every indices dimension
  // except the last (which holds the tuple components).
  int num_update_indice = 1;
  for (int i = 0; i < indice_nbDims - 1; ++i) {
    num_update_indice *= indice_desc.shape[i];
  }
  // Nothing to scatter: output is just the copied data. Returning here also
  // avoids launching an invalid zero-sized grid.
  if (num_update_indice == 0) return;

  // scatter
  const int col_block = DIVUP(num_update_indice, threadsPerBlock);
  onnx_scatternd_kernel<<<col_block, threadsPerBlock, 0, stream>>>(
      num_update_indice, indices, update, output, tensor_desc, indice_desc);
}

// Non-template float32 entry point with external linkage, declared `extern`
// and called from the TensorRT plugin in trt_scatternd.cpp.
void TRTONNXScatterNDKernelLauncher_float(const float* data, const int* indices,
                                          const float* update, const int* dims,
                                          int nbDims, const int* indices_dims,
                                          int indice_nbDims, float* output,
                                          cudaStream_t stream) {
  TRTONNXScatterNDKernelLauncher<float>(data, indices, update, dims, nbDims,
                                        indices_dims, indice_nbDims, output,
                                        stream);
}

// Non-template int32 entry point; see the float overload above.
void TRTONNXScatterNDKernelLauncher_int32(const int* data, const int* indices,
                                          const int* update, const int* dims,
                                          int nbDims, const int* indices_dims,
                                          int indice_nbDims, int* output,
                                          cudaStream_t stream) {
  TRTONNXScatterNDKernelLauncher<int>(data, indices, update, dims, nbDims,
                                      indices_dims, indice_nbDims, output,
                                      stream);
}
16 changes: 16 additions & 0 deletions mmcv/ops/csrc/tensorrt/trt_cuda_helper.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef TRT_CUDA_HELPER_HPP
#define TRT_CUDA_HELPER_HPP

#include <stdio.h>
#include <stdlib.h>

// Integer ceiling division: number of size-n blocks needed to cover m
// elements. Intended for non-negative m and positive n.
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

// Print a diagnostic and abort if the most recent CUDA runtime call (or
// kernel launch) failed. Wrapped in do { } while (0) so it behaves as a
// single statement inside if/else chains. BUGFIX: exits with a non-zero
// status so callers and CI can detect the failure (the previous version
// called exit(0), reporting success on error).
#define cudaCheckError()                                             \
  do {                                                               \
    cudaError_t e = cudaGetLastError();                              \
    if (e != cudaSuccess) {                                          \
      printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__,       \
             cudaGetErrorString(e));                                 \
      exit(1);                                                       \
    }                                                                \
  } while (0)

#endif  // TRT_CUDA_HELPER_HPP
8 changes: 8 additions & 0 deletions mmcv/ops/csrc/tensorrt/trt_plugin_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@

namespace mmcv {

// Maximum tensor rank supported by TensorDesc below.
const int MAXTENSORDIMS = 10;

// Plain-old-data description of a tensor's layout, safe to pass by value to
// CUDA kernels.
struct TensorDesc {
  int shape[MAXTENSORDIMS];   // extent of each dimension
  int stride[MAXTENSORDIMS];  // stride of each dimension, in elements
  int dim;                    // number of valid entries in shape/stride
};

inline unsigned int getElementSize(nvinfer1::DataType t) {
switch (t) {
case nvinfer1::DataType::kINT32:
Expand Down
Loading

0 comments on commit 2ab544f

Please sign in to comment.