-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] : Add ScatterND TensorRT Plugin (#786)
* Add scatter plugin
* Fix bugs of ScatterND
* Add TensorRT ScatterND plugin
* Format code with clang-format
* Add test for ScatterND
* Skip test_tensorrt in CI
* Remove unused variable

Co-authored-by: maningsheng <maningsheng@sensetime.com>
- Loading branch information
1 parent
8e3a801
commit 2ab544f
Showing
7 changed files
with
509 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
#include "trt_scatternd.hpp" | ||
|
||
#include <assert.h> | ||
#include <stdio.h> | ||
|
||
#include <chrono> | ||
|
||
#include "trt_serialize.hpp" | ||
|
||
extern void TRTONNXScatterNDKernelLauncher_float( | ||
const float *data, const int *indices, const float *update, const int *dims, | ||
int nbDims, const int *indices_dims, int indice_nbDims, float *output, | ||
cudaStream_t stream); | ||
|
||
extern void TRTONNXScatterNDKernelLauncher_int32( | ||
const int *data, const int *indices, const int *update, const int *dims, | ||
int nbDims, const int *indices_dims, int indice_nbDims, int *output, | ||
cudaStream_t stream); | ||
|
||
namespace {
// Plugin identity reported to the TensorRT plugin registry. The (name,
// version) pair must match what the ONNX-to-TRT converter looks up for
// ScatterND nodes.
static const char *PLUGIN_VERSION{"1"};
static const char *PLUGIN_NAME{"ScatterND"};
}  // namespace

// Creator static state: the attribute schema advertised via getFieldNames().
// Left empty — ScatterND carries no ONNX node attributes.
nvinfer1::PluginFieldCollection ONNXScatterNDDynamicCreator::mFC{};
std::vector<nvinfer1::PluginField>
    ONNXScatterNDDynamicCreator::mPluginAttributes;
|
||
// Construct a ScatterND plugin instance. The plugin is stateless apart from
// its layer name and namespace.
ONNXScatterNDDynamic::ONNXScatterNDDynamic(const std::string &name)
    : mLayerName(name) {}

// Deserialization constructor. The serialized buffer is intentionally
// ignored: getSerializationSize() returns 0, so there is no state to restore.
// NOTE(review): `name` is taken by value here but by const-ref above —
// presumably mirrors the header declaration; confirm before changing.
ONNXScatterNDDynamic::ONNXScatterNDDynamic(const std::string name,
                                           const void *data, size_t length)
    : mLayerName(name) {}

// Create a functionally identical copy (TensorRT clones plugins between the
// builder and runtime phases). Only the name and namespace need copying.
nvinfer1::IPluginV2DynamicExt *ONNXScatterNDDynamic::clone() const {
  ONNXScatterNDDynamic *plugin = new ONNXScatterNDDynamic(mLayerName);
  plugin->setPluginNamespace(getPluginNamespace());

  return plugin;
}

// ScatterND writes updates into a copy of the data tensor, so the single
// output always has exactly the shape of input 0 (data).
nvinfer1::DimsExprs ONNXScatterNDDynamic::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
    nvinfer1::IExprBuilder &exprBuilder) {
  return inputs[0];
}
|
||
bool ONNXScatterNDDynamic::supportsFormatCombination( | ||
int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs, | ||
int nbOutputs) { | ||
if (pos < nbInputs) { | ||
switch (pos) { | ||
case 0: | ||
// data | ||
return (inOut[pos].type == nvinfer1::DataType::kFLOAT && | ||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR) || | ||
(inOut[pos].type == nvinfer1::DataType::kINT32 && | ||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR); | ||
case 1: | ||
// indices | ||
return inOut[pos].type == nvinfer1::DataType::kINT32 && | ||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; | ||
case 2: | ||
// updates | ||
return inOut[pos].type == inOut[0].type && | ||
inOut[pos].format == inOut[0].format; | ||
default: | ||
return true; | ||
} | ||
} else { | ||
switch (pos - nbInputs) { | ||
case 0: | ||
// output | ||
return inOut[pos].type == inOut[0].type && | ||
inOut[pos].format == inOut[0].format; | ||
default: | ||
return true; | ||
} | ||
} | ||
return true; | ||
} | ||
|
||
// No build-time configuration is needed: the plugin is stateless and all
// shape/type information is read from the tensor descriptors in enqueue().
void ONNXScatterNDDynamic::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {}

// The scatter kernel works directly in the output buffer and needs no extra
// device workspace.
size_t ONNXScatterNDDynamic::getWorkspaceSize(
    const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
    const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const {
  return 0;
}
|
||
// Execute ScatterND on `stream`.
// inputs:  [0] data, [1] indices (int32), [2] updates; outputs: [0] result.
// The launcher copies `data` into `output`, then scatters `updates` at the
// positions addressed by `indices`. Returns 0 (success) unconditionally.
// Fix vs. original: C-style casts (which silently stripped const from the
// input buffers) replaced with const-correct static_casts; the launchers
// take const pointers, so no const needs to be cast away.
int ONNXScatterNDDynamic::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
                                  const nvinfer1::PluginTensorDesc *outputDesc,
                                  const void *const *inputs,
                                  void *const *outputs, void *workSpace,
                                  cudaStream_t stream) {
  // Shapes come straight from the runtime tensor descriptors.
  const int *dims = &(inputDesc[0].dims.d[0]);
  const int *indices_dims = &(inputDesc[1].dims.d[0]);
  int nbDims = inputDesc[0].dims.nbDims;
  int indice_nbDims = inputDesc[1].dims.nbDims;

  const void *data = inputs[0];
  const void *indices = inputs[1];
  const void *update = inputs[2];
  void *output = outputs[0];

  auto data_type = inputDesc[0].type;

  // Dispatch on the data type negotiated in supportsFormatCombination().
  switch (data_type) {
    case nvinfer1::DataType::kFLOAT:
      TRTONNXScatterNDKernelLauncher_float(
          static_cast<const float *>(data), static_cast<const int *>(indices),
          static_cast<const float *>(update), dims, nbDims, indices_dims,
          indice_nbDims, static_cast<float *>(output), stream);
      break;

    case nvinfer1::DataType::kINT32:
      TRTONNXScatterNDKernelLauncher_int32(
          static_cast<const int *>(data), static_cast<const int *>(indices),
          static_cast<const int *>(update), dims, nbDims, indices_dims,
          indice_nbDims, static_cast<int *>(output), stream);
      break;
    default:
      // Unreachable for formats accepted by supportsFormatCombination().
      break;
  }

  return 0;
}
|
||
// The output inherits its data type from input 0 (data).
nvinfer1::DataType ONNXScatterNDDynamic::getOutputDataType(
    int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
  return inputTypes[0];
}

// IPluginV2 Methods

// Registry type name ("ScatterND"); must match the creator's getPluginName().
const char *ONNXScatterNDDynamic::getPluginType() const { return PLUGIN_NAME; }

// Plugin version string; must match the creator's getPluginVersion().
const char *ONNXScatterNDDynamic::getPluginVersion() const {
  return PLUGIN_VERSION;
}

// ScatterND produces exactly one output tensor.
int ONNXScatterNDDynamic::getNbOutputs() const { return 1; }

// No resources to acquire.
int ONNXScatterNDDynamic::initialize() { return 0; }

// No resources to release.
void ONNXScatterNDDynamic::terminate() {}

// The plugin is stateless, so nothing is serialized into the engine.
size_t ONNXScatterNDDynamic::getSerializationSize() const { return 0; }

// Nothing to write: see getSerializationSize().
void ONNXScatterNDDynamic::serialize(void *buffer) const {}

void ONNXScatterNDDynamic::destroy() {
  // This gets called when the network containing plugin is destroyed
  delete this;
}

// Namespace used by TensorRT to disambiguate same-named plugins.
void ONNXScatterNDDynamic::setPluginNamespace(const char *libNamespace) {
  mNamespace = libNamespace;
}

const char *ONNXScatterNDDynamic::getPluginNamespace() const {
  return mNamespace.c_str();
}
|
||
////////////////////// creator /////////////////////////////

// The creator advertises an empty attribute schema: ScatterND reads
// everything it needs from its input tensors, not from ONNX node attributes.
ONNXScatterNDDynamicCreator::ONNXScatterNDDynamicCreator() {
  mPluginAttributes.clear();
  mFC.nbFields = mPluginAttributes.size();
  mFC.fields = mPluginAttributes.data();
}

// Must match ONNXScatterNDDynamic::getPluginType().
const char *ONNXScatterNDDynamicCreator::getPluginName() const {
  return PLUGIN_NAME;
}

// Must match ONNXScatterNDDynamic::getPluginVersion().
const char *ONNXScatterNDDynamicCreator::getPluginVersion() const {
  return PLUGIN_VERSION;
}

// Returns the (empty) attribute schema set up in the constructor.
const nvinfer1::PluginFieldCollection *
ONNXScatterNDDynamicCreator::getFieldNames() {
  return &mFC;
}

// Build a fresh plugin for network construction. `fc` is ignored — there are
// no attributes to parse.
nvinfer1::IPluginV2 *ONNXScatterNDDynamicCreator::createPlugin(
    const char *name, const nvinfer1::PluginFieldCollection *fc) {
  ONNXScatterNDDynamic *plugin = new ONNXScatterNDDynamic(name);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

// Rebuild a plugin from an engine file. The serialized payload is empty
// (the plugin is stateless), so only the name matters.
nvinfer1::IPluginV2 *ONNXScatterNDDynamicCreator::deserializePlugin(
    const char *name, const void *serialData, size_t serialLength) {
  auto plugin = new ONNXScatterNDDynamic(name, serialData, serialLength);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

void ONNXScatterNDDynamicCreator::setPluginNamespace(const char *libNamespace) {
  mNamespace = libNamespace;
}

const char *ONNXScatterNDDynamicCreator::getPluginNamespace() const {
  return mNamespace.c_str();
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
#include <stdio.h> | ||
|
||
#include <vector> | ||
|
||
#include "common_cuda_helper.hpp" | ||
#include "trt_cuda_helper.cuh" | ||
#include "trt_plugin_helper.hpp" | ||
|
||
// Threads per CUDA block: sizeof(unsigned long long int) * 8 == 64.
// NOTE(review): the expression reads as "bits in a 64-bit word"; presumably
// chosen to match other mmcv kernels — confirm before changing.
static int const threadsPerBlock = sizeof(unsigned long long int) * 8;

using mmcv::TensorDesc;
|
||
// Scatter `update` slices into `output` (ONNX ScatterND semantics).
//   n           : number of index tuples = product of indices.shape[:-1].
//   indices     : int32 buffer; each tuple of `indice_cols` values addresses
//                 a position in the leading dims of the output tensor.
//   update      : one contiguous slice of `copy_stride` elements per tuple.
//   tensor_desc : dims/strides of the output, filled by the host launcher.
//   indice_desc : dims/strides of the indices tensor.
// NOTE(review): duplicate index tuples make concurrent memcpys to the same
// output region — which write wins is unspecified.
template <typename T>
__global__ void onnx_scatternd_kernel(const int n, const int* indices,
                                      const T* update, T* output,
                                      TensorDesc tensor_desc,
                                      TensorDesc indice_desc) {
  // Length of one index tuple = last dimension of the indices tensor.
  const int indice_cols = indice_desc.shape[indice_desc.dim - 1];
  // Elements copied per tuple: stride of the deepest indexed dimension.
  const int copy_stride = tensor_desc.stride[indice_cols - 1];
  const int* stride = &(tensor_desc.stride[0]);
  CUDA_1D_KERNEL_LOOP(index, n) {
    // Flatten the index tuple into a linear element offset into output.
    int output_offset = 0;
    const int* indices_current = indices + index * indice_cols;
    for (int i = 0; i < indice_cols; ++i) {
      output_offset += stride[i] * indices_current[i];
    }
    memcpy(output + output_offset, update + index * copy_stride,
           copy_stride * sizeof(T));
  }
}
|
||
// Host launcher for ONNX ScatterND: copies `data` into `output`, then
// scatters `update` at the positions given by `indices`, all ordered on
// `stream`.
//   dims / nbDims               : shape of data (== shape of output).
//   indices_dims / indice_nbDims: shape of the indices tensor; its last
//                                 dimension is the index-tuple length.
// All pointers are device pointers.
// BUGFIX: the data->output cudaMemcpyAsync previously omitted the `stream`
// argument, so the copy ran on the default stream and was NOT ordered before
// the scatter kernel launched on `stream` — a data race in which the kernel
// could write into `output` while the copy was still in flight (and the copy
// could then overwrite scattered updates).
template <typename T>
void TRTONNXScatterNDKernelLauncher(const T* data, const int* indices,
                                    const T* update, const int* dims,
                                    int nbDims, const int* indices_dims,
                                    int indice_nbDims, T* output,
                                    cudaStream_t stream) {
  // Fill the output tensor descriptor: shapes plus row-major strides
  // (stride[i] = product of dims after i).
  TensorDesc tensor_desc;
  memset((void*)&tensor_desc, 0, sizeof(TensorDesc));
  tensor_desc.dim = nbDims;
  tensor_desc.shape[nbDims - 1] = dims[nbDims - 1];
  tensor_desc.stride[nbDims - 1] = 1;
  for (int i = nbDims - 2; i >= 0; --i) {
    tensor_desc.shape[i] = dims[i];
    tensor_desc.stride[i] = dims[i + 1] * tensor_desc.stride[i + 1];
  }
  // Total element count of data/output.
  const int data_size = tensor_desc.stride[0] * tensor_desc.shape[0];

  // Same layout computation for the indices tensor.
  TensorDesc indice_desc;
  memset((void*)&indice_desc, 0, sizeof(TensorDesc));
  indice_desc.dim = indice_nbDims;
  indice_desc.shape[indice_nbDims - 1] = indices_dims[indice_nbDims - 1];
  indice_desc.stride[indice_nbDims - 1] = 1;
  for (int i = indice_nbDims - 2; i >= 0; --i) {
    indice_desc.shape[i] = indices_dims[i];
    indice_desc.stride[i] = indices_dims[i + 1] * indice_desc.stride[i + 1];
  }

  // output = np.copy(data) — issued on `stream` so it is ordered before the
  // scatter kernel below (see BUGFIX note above).
  cudaMemcpyAsync(output, data, data_size * sizeof(T),
                  cudaMemcpyDeviceToDevice, stream);

  // One kernel iteration per index tuple (product of indices.shape[:-1]).
  int num_update_indice = 1;
  for (int i = 0; i < indice_nbDims - 1; ++i) {
    num_update_indice *= indice_desc.shape[i];
  }
  // scatter
  const int col_block = DIVUP(num_update_indice, threadsPerBlock);
  onnx_scatternd_kernel<<<col_block, threadsPerBlock, 0, stream>>>(
      num_update_indice, indices, update, output, tensor_desc, indice_desc);
}
|
||
// Non-template entry point for float32 data, declared `extern` in the plugin
// translation unit (templates cannot cross the .cu/.cpp boundary here).
void TRTONNXScatterNDKernelLauncher_float(const float* data, const int* indices,
                                          const float* update, const int* dims,
                                          int nbDims, const int* indices_dims,
                                          int indice_nbDims, float* output,
                                          cudaStream_t stream) {
  TRTONNXScatterNDKernelLauncher<float>(data, indices, update, dims, nbDims,
                                        indices_dims, indice_nbDims, output,
                                        stream);
}

// Non-template entry point for int32 data; see the float overload above.
void TRTONNXScatterNDKernelLauncher_int32(const int* data, const int* indices,
                                          const int* update, const int* dims,
                                          int nbDims, const int* indices_dims,
                                          int indice_nbDims, int* output,
                                          cudaStream_t stream) {
  TRTONNXScatterNDKernelLauncher<int>(data, indices, update, dims, nbDims,
                                      indices_dims, indice_nbDims, output,
                                      stream);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
#ifndef TRT_CUDA_HELPER_HPP
#define TRT_CUDA_HELPER_HPP

#include <cstdio>   // fprintf
#include <cstdlib>  // exit, EXIT_FAILURE

// Ceiling division: number of blocks of size n needed to cover m items.
// NOTE: m and n are each evaluated twice — do not pass expressions with
// side effects.
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

// Check the last CUDA error and abort the process on failure, reporting
// file and line.
// Fixes vs. original:
//  - exits with EXIT_FAILURE instead of exit(0), which signalled success
//    to the caller even though a CUDA call had failed;
//  - reports to stderr instead of stdout;
//  - wrapped in do{...}while(0) so `cudaCheckError();` behaves as a single
//    statement inside if/else chains.
#define cudaCheckError()                                          \
  do {                                                            \
    cudaError_t e = cudaGetLastError();                           \
    if (e != cudaSuccess) {                                       \
      fprintf(stderr, "Cuda failure %s:%d: '%s'\n", __FILE__,     \
              __LINE__, cudaGetErrorString(e));                   \
      exit(EXIT_FAILURE);                                         \
    }                                                             \
  } while (0)

#endif  // TRT_CUDA_HELPER_HPP
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.