diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0d1b36c592..18fce68019 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -462,7 +462,7 @@ endif()
 if(ENABLE_ENCRYPTION)
   add_definitions(-DENABLE_ENCRYPTION)
   list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_ENCRYPTION_SRCS})
-  include(${PROJECT_SOURCE_DIR}/cmake/gflags.cmake)
+  # include(${PROJECT_SOURCE_DIR}/cmake/gflags.cmake)
   include(${PROJECT_SOURCE_DIR}/cmake/openssl.cmake)
   list(APPEND DEPEND_LIBS ${OPENSSL_LIBRARIES})
 endif()
diff --git a/benchmark/paddlex/CMakeLists.txt b/benchmark/paddlex/CMakeLists.txt
index b1439c546e..90c2f2e748 100755
--- a/benchmark/paddlex/CMakeLists.txt
+++ b/benchmark/paddlex/CMakeLists.txt
@@ -21,6 +21,7 @@ add_executable(benchmark_structurev2_table ${PROJECT_SOURCE_DIR}/benchmark_structurev2_table.cc)
 add_executable(benchmark_structurev2_layout ${PROJECT_SOURCE_DIR}/benchmark_structurev2_layout.cc)
 add_executable(benchmark_ppshituv2_rec ${PROJECT_SOURCE_DIR}/benchmark_ppshituv2_rec.cc)
 add_executable(benchmark_ppshituv2_det ${PROJECT_SOURCE_DIR}/benchmark_ppshituv2_det.cc)
+add_executable(benchmark_pp3d_cadnn ${PROJECT_SOURCE_DIR}/benchmark_pp3d_cadnn.cc)
 add_executable(benchmark_pp3d_centerpoint ${PROJECT_SOURCE_DIR}/benchmark_pp3d_centerpoint.cc)
 
 if(UNIX AND (NOT APPLE) AND (NOT ANDROID))
@@ -34,6 +35,7 @@ if(UNIX AND (NOT APPLE) AND (NOT ANDROID))
   target_link_libraries(benchmark_structurev2_layout ${FASTDEPLOY_LIBS} gflags pthread)
   target_link_libraries(benchmark_ppshituv2_rec ${FASTDEPLOY_LIBS} gflags pthread)
   target_link_libraries(benchmark_ppshituv2_det ${FASTDEPLOY_LIBS} gflags pthread)
+  target_link_libraries(benchmark_pp3d_cadnn ${FASTDEPLOY_LIBS} gflags pthread)
   target_link_libraries(benchmark_pp3d_centerpoint ${FASTDEPLOY_LIBS} gflags pthread)
 else()
   target_link_libraries(benchmark ${FASTDEPLOY_LIBS} gflags)
@@ -46,6 +48,7 @@ else()
   target_link_libraries(benchmark_structurev2_layout ${FASTDEPLOY_LIBS} gflags)
   target_link_libraries(benchmark_ppshituv2_rec ${FASTDEPLOY_LIBS} gflags)
   target_link_libraries(benchmark_ppshituv2_det ${FASTDEPLOY_LIBS} gflags)
+  target_link_libraries(benchmark_pp3d_cadnn ${FASTDEPLOY_LIBS} gflags)
   target_link_libraries(benchmark_pp3d_centerpoint ${FASTDEPLOY_LIBS} gflags)
 endif()
 # only for Android ADB test
diff --git a/benchmark/paddlex/benchmark_gpu.sh b/benchmark/paddlex/benchmark_gpu.sh
index 87358ad2b8..92b4b33c49 100755
--- a/benchmark/paddlex/benchmark_gpu.sh
+++ b/benchmark/paddlex/benchmark_gpu.sh
@@ -41,9 +41,17 @@ fi
 # PP-ShiTuV2
 ./benchmark_ppshituv2_rec --model PP-ShiTuv2-rec --image ppshituv2_wangzai.png --config_path $CONFIG_PATH
+./benchmark_ppshituv2_det --model PP-ShiTuv2-det --image ppdet_det_img.jpg --config_path $CONFIG_PATH
 # PP-StructureV2
 ./benchmark_structurev2_layout --model PP-Structurev2-layout --image structurev2_layout_val_0002.jpg --config_path $CONFIG_PATH
 ./benchmark_structurev2_table --model PP-Structurev2-SLANet --image structurev2_table.jpg --table_char_dict_path table_structure_dict_ch.txt --config_path $CONFIG_PATH
+./benchmark --model PP-Structurev2-vi-layoutxlm --shapes "1,512:1,512,4:1,512:1,512" --trt_shapes "1,512:1,512:1,512:1,512,4:1,512,4:1,512,4:1,512:1,512:1,512:1,512:1,512:1,512" --names "x_0:x_1:x_2:x_3" --dtypes "INT64:INT64:INT64:INT64" --disable_mkldnn --custom_tensor_value 0.2 --config_path $CONFIG_PATH
+
+# Paddle3D
+./benchmark --model PETRv1_v99 --shapes "1,6,3,320,800:1,6,4,4" --names "images:img2lidars" --dtypes "FP32:FP32" --disable_mkldnn --config_path $CONFIG_PATH
+./benchmark --model PETRv2_v99 --shapes "1,12,3,320,800:1,12,4,4:1,12" --names "images:img2lidars:timestamps" --dtypes "FP32:FP32:FP32" --disable_mkldnn --config_path $CONFIG_PATH
+./benchmark_pp3d_centerpoint --model CenterPoint-Pillars-02Voxel --image paddle3d_centerpoint_n008_LIDAR_TOP__1533151603547590.pcd.bin --config_path $CONFIG_PATH
+./benchmark_pp3d_cadnn --model CADNN_OCRNet-HRNetW18 --image paddle3d_cadnn_kitti_000780.png --config_path $CONFIG_PATH
 set +x
diff --git a/benchmark/paddlex/benchmark_gpu_trt.sh b/benchmark/paddlex/benchmark_gpu_trt.sh
index da67d42ffc..4d449454ce 100755
--- a/benchmark/paddlex/benchmark_gpu_trt.sh
+++ b/benchmark/paddlex/benchmark_gpu_trt.sh
@@ -43,9 +43,17 @@ fi
 # PP-ShiTuV2
 ./benchmark_ppshituv2_rec --model PP-ShiTuv2-rec --image ppshituv2_wangzai.png --config_path $CONFIG_PATH
+./benchmark_ppshituv2_det --model PP-ShiTuv2-det --image ppdet_det_img.jpg --config_path $CONFIG_PATH
 # PP-StructureV2
 ./benchmark_structurev2_layout --model PP-Structurev2-layout --image structurev2_layout_val_0002.jpg --config_path $CONFIG_PATH
 ./benchmark_structurev2_table --model PP-Structurev2-SLANet --image structurev2_table.jpg --table_char_dict_path table_structure_dict_ch.txt --config_path $CONFIG_PATH
+./benchmark --model PP-Structurev2-vi-layoutxlm --shapes "1,512:1,512,4:1,512:1,512" --trt_shapes "1,512:1,512:1,512:1,512,4:1,512,4:1,512,4:1,512:1,512:1,512:1,512:1,512:1,512" --names "x_0:x_1:x_2:x_3" --dtypes "INT64:INT64:INT64:INT64" --disable_mkldnn --custom_tensor_value 0.2 --collect_trt_shape_by_custom_tensor_value --collect_trt_shape_by_device --config_path $CONFIG_PATH
+
+# Paddle3D
+./benchmark --model PETRv1_v99 --shapes "1,6,3,320,800:1,6,4,4" --trt_shapes "1,6,3,320,800:1,6,3,320,800:1,6,3,320,800:1,6,4,4:1,6,4,4:1,6,4,4" --names "images:img2lidars" --dtypes "FP32:FP32" --disable_mkldnn --config_path $CONFIG_PATH
+./benchmark --model PETRv2_v99 --shapes "1,12,3,320,800:1,12,4,4:1,12" --trt_shapes "1,12,3,320,800:1,12,3,320,800:1,12,3,320,800:1,12,4,4:1,12,4,4:1,12,4,4:1,12:1,12:1,12" --names "images:img2lidars:timestamps" --dtypes "FP32:FP32:FP32" --disable_mkldnn --config_path $CONFIG_PATH
+./benchmark_pp3d_centerpoint --model CenterPoint-Pillars-02Voxel --image paddle3d_centerpoint_n008_LIDAR_TOP__1533151603547590.pcd.bin --config_path $CONFIG_PATH
+./benchmark_pp3d_cadnn --model CADNN_OCRNet-HRNetW18 --image paddle3d_cadnn_kitti_000780.png --config_path $CONFIG_PATH
 set +x
diff --git a/benchmark/paddlex/benchmark_pp3d_cadnn.cc b/benchmark/paddlex/benchmark_pp3d_cadnn.cc
new file mode 100644
index 0000000000..d16787bd44
--- /dev/null
+++ b/benchmark/paddlex/benchmark_pp3d_cadnn.cc
@@ -0,0 +1,81 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "flags.h" +#include "macros.h" +#include "option.h" + +namespace vision = fastdeploy::vision; +namespace benchmark = fastdeploy::benchmark; + +int main(int argc, char* argv[]) { +#if defined(ENABLE_BENCHMARK) && defined(ENABLE_VISION) + // Initialization + auto option = fastdeploy::RuntimeOption(); + if (!CreateRuntimeOption(&option, argc, argv, true)) { + return -1; + } + auto im = cv::imread(FLAGS_image); + std::unordered_map config_info; + benchmark::ResultManager::LoadBenchmarkConfig(FLAGS_config_path, + &config_info); + std::string model_name, params_name, config_name; + auto model_format = fastdeploy::ModelFormat::PADDLE; + if (!UpdateModelResourceName(&model_name, ¶ms_name, &config_name, + &model_format, config_info, false)) { + return -1; + } + auto model_file = FLAGS_model + sep + model_name; + auto params_file = FLAGS_model + sep + params_name; + std::vector cam_data{7.183351e+02, 0.000000e+00, 6.003891e+02, + 4.450382e+01, 0.000000e+00, 7.183351e+02, + 1.815122e+02, -5.951107e-01, 0.000000e+00, + 0.000000e+00, 1.000000e+00, 2.616315e-03}; + std::vector lidar_data = { + 0.0048523, -0.9999298, -0.01081266, -0.00711321, + -0.00302069, 0.01079808, -0.99993706, -0.06176636, + 0.99998367, 0.00488465, -0.00296808, -0.26739058, + 0., 0., 0., 1.}; + if (config_info["backend"] == "paddle_trt") { + option.paddle_infer_option.collect_trt_shape = true; + option.paddle_infer_option.collect_trt_shape_by_device = true; + option.paddle_infer_option.trt_min_subgraph_size = 12; + option.paddle_infer_option.DisableTrtOps({"squeeze2"}); + option.trt_option.max_batch_size = 1; + } + if (config_info["backend"] == "paddle_trt" || + config_info["backend"] == "trt") { + // use custom data to perform collect shapes. + option.trt_option.SetShape("images", {1, 3, 375, 1242}, + {1, 3, 375, 1242}, {1, 3, 375, 1242}); + option.trt_option.SetShape("trans_lidar_to_cam", {1, 4, 4}, + {1, 4, 4}, {1, 4, 4}); + option.trt_option.SetShape("trans_cam_to_img", {1, 3, 4}, + {1, 3, 4}, {1, 3, 4}); + std::vector image_data; + image_data.assign(im.data, im.data + 1*3*375*1242); + option.trt_option.SetInputData("trans_lidar_to_cam", lidar_data); + option.trt_option.SetInputData("trans_cam_to_img", cam_data); + option.trt_option.SetInputData("images", image_data); + } + auto model_cadnn = vision::perception::Caddn( + model_file, params_file, "", option, model_format); + vision::PerceptionResult res; + // Run profiling + BENCHMARK_MODEL(model_cadnn, model_cadnn.Predict(im, cam_data, lidar_data, &res)) + std::cout << res.Str() << std::endl; +#endif + + return 0; +} diff --git a/benchmark/paddlex/benchmark_pp3d_centerpoint.cc b/benchmark/paddlex/benchmark_pp3d_centerpoint.cc index f7c81c1b0d..930c178730 100644 --- a/benchmark/paddlex/benchmark_pp3d_centerpoint.cc +++ b/benchmark/paddlex/benchmark_pp3d_centerpoint.cc @@ -93,7 +93,7 @@ int main(int argc, char* argv[]) { vision::PerceptionResult res; // Run profiling BENCHMARK_MODEL(model_centerpoint, model_centerpoint.Predict(point_dir, &res)) - // std::cout << res.Str() << std::endl; + std::cout << res.Str() << std::endl; #endif return 0; diff --git a/benchmark/paddlex/benchmark_x86.sh b/benchmark/paddlex/benchmark_x86.sh index bdcfd8e4fd..ca668bb60a 100755 --- a/benchmark/paddlex/benchmark_x86.sh +++ b/benchmark/paddlex/benchmark_x86.sh @@ -41,9 +41,17 @@ fi # PP-ShiTuV2 ./benchmark_ppshituv2_rec --model PP-ShiTuv2-rec --image ppshituv2_wangzai.png --config_path $CONFIG_PATH +./benchmark_ppshituv2_det --model PP-ShiTuv2-det --image ppdet_det_img.jpg --config_path 
 # PP-StructureV2
 ./benchmark_structurev2_layout --model PP-Structurev2-layout --image structurev2_layout_val_0002.jpg --config_path $CONFIG_PATH
 ./benchmark_structurev2_table --model PP-Structurev2-SLANet --image structurev2_table.jpg --table_char_dict_path table_structure_dict_ch.txt --config_path $CONFIG_PATH
+./benchmark --model PP-Structurev2-vi-layoutxlm --shapes "1,512:1,512,4:1,512:1,512" --names "x_0:x_1:x_2:x_3" --dtypes "INT64:INT64:INT64:INT64" --disable_mkldnn --custom_tensor_value 0.2 --config_path $CONFIG_PATH
+
+# Paddle3D
+./benchmark --model PETRv1_v99 --config_path $CONFIG_PATH --shapes "1,6,3,320,800:1,6,4,4" --names "images:img2lidars" --dtypes "FP32:FP32" --disable_mkldnn --warmup 5 --repeat 20
+./benchmark --model PETRv2_v99 --config_path $CONFIG_PATH --shapes "1,12,3,320,800:1,12,4,4:1,12" --names "images:img2lidars:timestamps" --dtypes "FP32:FP32:FP32" --disable_mkldnn --warmup 5 --repeat 20
+./benchmark_pp3d_centerpoint --model CenterPoint-Pillars-02Voxel --image paddle3d_centerpoint_n008_LIDAR_TOP__1533151603547590.pcd.bin --config_path $CONFIG_PATH
+./benchmark_pp3d_cadnn --model CADNN_OCRNet-HRNetW18 --image paddle3d_cadnn_kitti_000780.png --config_path $CONFIG_PATH
 set +x
\ No newline at end of file
diff --git a/fastdeploy/runtime/backends/paddle/ops/centerpoint_postprocess_op.cc b/fastdeploy/runtime/backends/paddle/ops/centerpoint_postprocess_op.cc
index 42d4e8dc23..e931b9559e 100644
--- a/fastdeploy/runtime/backends/paddle/ops/centerpoint_postprocess_op.cc
+++ b/fastdeploy/runtime/backends/paddle/ops/centerpoint_postprocess_op.cc
@@ -118,4 +118,5 @@ PD_BUILD_OP(centerpoint_postprocess)
     .SetInferShapeFn(PD_INFER_SHAPE(fastdeploy::paddle_custom_ops::PostProcessInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(fastdeploy::paddle_custom_ops::PostProcessInferDtype));
 
-#endif  // WITH_GPU
\ No newline at end of file
+#endif  // WITH_GPU
+
diff --git a/fastdeploy/runtime/backends/paddle/ops/centerpoint_postprocess_op.cu b/fastdeploy/runtime/backends/paddle/ops/centerpoint_postprocess_op.cu
index 05d41b02ce..1edeedb5bb 100644
--- a/fastdeploy/runtime/backends/paddle/ops/centerpoint_postprocess_op.cu
+++ b/fastdeploy/runtime/backends/paddle/ops/centerpoint_postprocess_op.cu
@@ -220,7 +220,7 @@ std::vector<paddle::Tensor> postprocess_gpu(
   // nms
   // in NmsLauncher, rot = - theta - pi / 2
-  const int col_blocks = DIVUP(num_bboxes_for_nms, THREADS_PER_BLOCK_NMS);
+  int col_blocks = DIVUP(num_bboxes_for_nms, THREADS_PER_BLOCK_NMS);
   auto nms_mask = paddle::empty({num_bboxes_for_nms * col_blocks},
                                 paddle::DataType::INT64, paddle::GPUPlace());
   int64_t *nms_mask_data = nms_mask.data<int64_t>();
@@ -291,4 +291,4 @@ std::vector<paddle::Tensor> postprocess_gpu(
 }
 
 }  // namespace fastdeploy
-}  // namespace paddle_custom_ops
\ No newline at end of file
+}  // namespace paddle_custom_ops
diff --git a/fastdeploy/runtime/backends/paddle/ops/grid_sample_3d.cc b/fastdeploy/runtime/backends/paddle/ops/grid_sample_3d.cc
new file mode 100644
index 0000000000..9ff4d47f86
--- /dev/null
+++ b/fastdeploy/runtime/backends/paddle/ops/grid_sample_3d.cc
@@ -0,0 +1,94 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#if defined(WITH_GPU)
+
+#include "grid_sample_3d.h"
+
+#include
+
+#if defined(PADDLEINFERENCE_API_COMPAT_2_4_x)
+#include "paddle/include/experimental/ext_all.h"
+#elif defined(PADDLEINFERENCE_API_COMPAT_2_5_x)
+#include "paddle/include/paddle/extension.h"
+#else
+#include "paddle/extension.h"
+#endif
+
+namespace fastdeploy {
+namespace paddle_custom_ops {
+
+std::vector<paddle::Tensor> GridSample3DCUDAForward(
+    const paddle::Tensor& x, const paddle::Tensor& grid,
+    const std::string& mode, const std::string& padding_mode,
+    bool align_corners);
+
+std::vector<paddle::Tensor> GridSample3DForward(const paddle::Tensor& x,
+                                                const paddle::Tensor& grid,
+                                                const std::string& mode,
+                                                const std::string& padding_mode,
+                                                bool align_corners) {
+  return GridSample3DCUDAForward(x, grid, mode, padding_mode, align_corners);
+}
+
+std::vector<paddle::Tensor> GridSample3DCUDABackward(
+    const paddle::Tensor& x, const paddle::Tensor& grid,
+    const paddle::Tensor& grad_out, const std::string& mode,
+    const std::string& padding_mode, bool align_corners);
+
+std::vector<paddle::Tensor> GridSample3DBackward(
+    const paddle::Tensor& x, const paddle::Tensor& grid,
+    const paddle::Tensor& grad_out, const std::string& mode,
+    const std::string& padding_mode, bool align_corners) {
+  return GridSample3DCUDABackward(x, grid, grad_out, mode, padding_mode,
+                                  align_corners);
+}
+
+std::vector<std::vector<int64_t>> GridSample3DInferShape(
+    std::vector<int64_t> x_shape, std::vector<int64_t> grid_shape) {
+  return {
+      {x_shape[0], x_shape[1], grid_shape[1], grid_shape[2], grid_shape[3]}};
+}
+
+std::vector<std::vector<int64_t>> GridSample3DInferBackShape(
+    std::vector<int64_t> x_shape, std::vector<int64_t> grid_shape) {
+  return {x_shape};
+}
+
+std::vector<paddle::DataType> GridSample3DInferDtype(
+    paddle::DataType x_dtype, paddle::DataType grid_dtype) {
+  return {x_dtype};
+}
+
+}  // namespace paddle_custom_ops
+}  // namespace fastdeploy
+
+PD_BUILD_OP(grid_sample_3d)
+    .Inputs({"x", "grid"})
+    .Attrs({"mode: std::string", "padding_mode: std::string",
+            "align_corners: bool"})
+    .Outputs({"out"})
+    .SetKernelFn(PD_KERNEL(fastdeploy::paddle_custom_ops::GridSample3DForward))
+    .SetInferShapeFn(PD_INFER_SHAPE(fastdeploy::paddle_custom_ops::GridSample3DInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(fastdeploy::paddle_custom_ops::GridSample3DInferDtype));
+
+PD_BUILD_GRAD_OP(grid_sample_3d)
+    .Inputs({"x", "grid", paddle::Grad("out")})
+    .Attrs({"mode: std::string", "padding_mode: std::string",
+            "align_corners: bool"})
+    .Outputs({paddle::Grad("x")})
+    .SetKernelFn(PD_KERNEL(fastdeploy::paddle_custom_ops::GridSample3DBackward))
+    .SetInferShapeFn(PD_INFER_SHAPE(fastdeploy::paddle_custom_ops::GridSample3DInferBackShape));
+
+#endif
\ No newline at end of file
diff --git a/fastdeploy/runtime/backends/paddle/ops/grid_sample_3d.cu b/fastdeploy/runtime/backends/paddle/ops/grid_sample_3d.cu
new file mode 100644
index 0000000000..0176c908ab
--- /dev/null
+++ b/fastdeploy/runtime/backends/paddle/ops/grid_sample_3d.cu
@@ -0,0 +1,657 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "grid_sample_3d.h" + +#if defined(PADDLEINFERENCE_API_COMPAT_2_4_x) +#include "paddle/include/experimental/ext_all.h" +#elif defined(PADDLEINFERENCE_API_COMPAT_2_5_x) +#include "paddle/include/paddle/extension.h" +#else +#include "paddle/extension.h" +#endif + +namespace fastdeploy { +namespace paddle_custom_ops { + +#define CHECK_INPUT_GPU(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.") + +static __forceinline__ __device__ bool InBounds3D(int64_t d, int64_t h, + int64_t w, int64_t D, + int64_t H, int64_t W) { + return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W; +} + +#define CUDA_KERNEL_LOOP_TYPE(i, n, index_type) \ + index_type _i_n_d_e_x = blockIdx.x * blockDim.x + threadIdx.x; \ + for (index_type i = _i_n_d_e_x; _i_n_d_e_x < (n); \ + _i_n_d_e_x += blockDim.x * gridDim.x, i = _i_n_d_e_x) + +#define CUDA_KERNEL_LOOP(i, n) CUDA_KERNEL_LOOP_TYPE(i, n, int) + +template +static __forceinline__ __device__ T Unnormalize(T coord, int size, + bool align_corners) { + if (align_corners) { + return ((coord + 1.f) / 2) * (size - 1); + } else { + return ((coord + 1.f) * size - 1) / 2; + } +} + +template +static __forceinline__ __device__ T ClipIndexes(T in, int max_value) { + return min(static_cast(max_value), max(in, static_cast(0))); +} + +template +static __forceinline__ __device__ T ReflectIndexes(T in, int twice_low, + int twice_high) { + if (twice_low == twice_high) { + return static_cast(0); + } + T min = static_cast(twice_low) / 2; + T span = static_cast(twice_high - twice_low) / 2; + in = fabs(in - min); + T extra = fmod(in, span); + int flips = static_cast(floor(in / span)); + if (flips % 2 == 0) { + return extra + min; + } else { + return span - extra + min; + } +} + +template +static __forceinline__ __device__ T ComputePositions(T coord, int size, + PaddingMode padding_mode, + bool align_corners) { + coord = Unnormalize(coord, size, align_corners); + if (padding_mode == PaddingMode::border) { + coord = ClipIndexes(coord, size - 1); + } else if (padding_mode == PaddingMode::reflect) { + if (align_corners) { + coord = ReflectIndexes(coord, 0, 2 * (size - 1)); + } else { + coord = ReflectIndexes(coord, -1, 2 * size - 1); + } + coord = ClipIndexes(coord, size - 1); + } + return coord; +} + +template +__global__ void GridSample3DCudaKernel( + const index_t nthreads, index_t out_c, index_t out_d, index_t out_h, + index_t out_w, index_t in_d, index_t in_h, index_t in_w, const T* input, + const T* grid, T* output, const Mode interpolation_mode, + const PaddingMode padding_mode, bool align_corners) { + // printf("size: %d, %d, %d, %d, %d, %d \n", out_c, out_d, out_w, out_h, in_d, + // in_w); + index_t inp_sW = 1; + index_t inp_sH = in_w; + index_t inp_sD = in_h * in_w; + index_t inp_sC = in_d * inp_sD; + index_t inp_sN = out_c * inp_sC; + + index_t grid_sCoor = 1; + index_t grid_sW = 3; + index_t grid_sH = out_w * grid_sW; + index_t grid_sD = out_h * grid_sH; + index_t grid_sN = out_d * grid_sD; + + index_t out_sW = 1; + index_t out_sH = out_w; + index_t out_sD = out_h * out_w; + index_t out_sC = out_d * out_sD; + index_t out_sN = out_c * 
out_sC; + + CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) { + const index_t w = index % out_w; + const index_t h = (index / out_w) % out_h; + const index_t d = (index / (out_h * out_w)) % out_d; + const index_t n = index / (out_d * out_h * out_w); + const index_t grid_offset = + n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; + // get the corresponding input x, y, z co-ordinates from grid + T ix = grid[grid_offset]; + T iy = grid[grid_offset + grid_sCoor]; + T iz = grid[grid_offset + 2 * grid_sCoor]; + ix = ComputePositions(ix, in_w, padding_mode, align_corners); + iy = ComputePositions(iy, in_h, padding_mode, align_corners); + iz = ComputePositions(iz, in_d, padding_mode, align_corners); + // printf("ix: %f, iy: %f, iz: %f \n", ix, iy, iz); + if (interpolation_mode == Mode::bilinear) { + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + index_t ix_tnw = static_cast(std::floor(ix)); + index_t iy_tnw = static_cast(std::floor(iy)); + index_t iz_tnw = static_cast(std::floor(iz)); + + index_t ix_tne = ix_tnw + 1; + index_t iy_tne = iy_tnw; + index_t iz_tne = iz_tnw; + + index_t ix_tsw = ix_tnw; + index_t iy_tsw = iy_tnw + 1; + index_t iz_tsw = iz_tnw; + + index_t ix_tse = ix_tnw + 1; + index_t iy_tse = iy_tnw + 1; + index_t iz_tse = iz_tnw; + + index_t ix_bnw = ix_tnw; + index_t iy_bnw = iy_tnw; + index_t iz_bnw = iz_tnw + 1; + + index_t ix_bne = ix_tnw + 1; + index_t iy_bne = iy_tnw; + index_t iz_bne = iz_tnw + 1; + + index_t ix_bsw = ix_tnw; + index_t iy_bsw = iy_tnw + 1; + index_t iz_bsw = iz_tnw + 1; + + index_t ix_bse = ix_tnw + 1; + index_t iy_bse = iy_tnw + 1; + index_t iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + T tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + T tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + T tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + T tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + T bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + T bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + T bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + T bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + auto inp_ptr_NC = input + n * inp_sN; + auto out_ptr_NCDHW = + output + n * out_sN + d * out_sD + h * out_sH + w * out_sW; + for (index_t c = 0; c < out_c; + ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) { + *out_ptr_NCDHW = static_cast(0); + if (InBounds3D(iz_tnw, iy_tnw, ix_tnw, in_d, in_h, in_w)) { + *out_ptr_NCDHW += + inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * + tnw; + } + if (InBounds3D(iz_tne, iy_tne, ix_tne, in_d, in_h, in_w)) { + *out_ptr_NCDHW += + inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * + tne; + } + if (InBounds3D(iz_tsw, iy_tsw, ix_tsw, in_d, in_h, in_w)) { + *out_ptr_NCDHW += + inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * + tsw; + } + if (InBounds3D(iz_tse, iy_tse, ix_tse, in_d, in_h, in_w)) { + *out_ptr_NCDHW += + inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * + tse; + } + if (InBounds3D(iz_bnw, iy_bnw, ix_bnw, in_d, in_h, in_w)) { + *out_ptr_NCDHW += + inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * + bnw; + } + if (InBounds3D(iz_bne, iy_bne, ix_bne, in_d, in_h, in_w)) { + *out_ptr_NCDHW += + inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * + bne; + } + if (InBounds3D(iz_bsw, iy_bsw, ix_bsw, in_d, in_h, in_w)) { + *out_ptr_NCDHW += + inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * 
inp_sW] * + bsw; + } + if (InBounds3D(iz_bse, iy_bse, ix_bse, in_d, in_h, in_w)) { + *out_ptr_NCDHW += + inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * + bse; + } + } + } else if (interpolation_mode == Mode::nearest) { + index_t ix_nearest = static_cast(std::round(ix)); + index_t iy_nearest = static_cast(std::round(iy)); + index_t iz_nearest = static_cast(std::round(iz)); + + // assign nearest neighor pixel value to output pixel + auto inp_ptr_NC = input + n * inp_sN; + auto out_ptr_NCDHW = + output + n * out_sN + d * out_sD + h * out_sH + w * out_sW; + for (index_t c = 0; c < out_c; + ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) { + if (InBounds3D(iz_nearest, iy_nearest, ix_nearest, in_d, in_h, in_w)) { + *out_ptr_NCDHW = + inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + + ix_nearest * inp_sW]; + } else { + *out_ptr_NCDHW = static_cast(0); + } + } + } + } +} + +std::vector GridSample3DCUDAForward( + const paddle::Tensor& x, const paddle::Tensor& grid, + const std::string& mode, const std::string& padding_mode, + bool align_corners) { + CHECK_INPUT_GPU(x); + CHECK_INPUT_GPU(grid); + PaddingMode enum_padding_mode; + Mode enum_mode; + if (padding_mode == "border") { + enum_padding_mode = PaddingMode::border; + } else if (padding_mode == "reflection") { + enum_padding_mode = PaddingMode::reflect; + } else { + enum_padding_mode = PaddingMode::zeros; + } + + if (mode == "nearest") { + enum_mode = Mode::nearest; + } else { + enum_mode = Mode::bilinear; + } + const int n = grid.shape()[0]; + const int out_d = grid.shape()[1]; + const int out_h = grid.shape()[2]; + const int out_w = grid.shape()[3]; + const int c = x.shape()[1]; + const int in_d = x.shape()[2]; + const int in_h = x.shape()[3]; + const int in_w = x.shape()[4]; + + auto output = paddle::full({n, c, out_d, out_h, out_w}, 0, + paddle::DataType::FLOAT32, paddle::GPUPlace()); + const int count = static_cast(n * out_d * out_h * out_w); + + int max_threads_per_block = 512; + int block_num = (count - 1) / max_threads_per_block + 1; + // printf("size: %d, %d, %d, %d, %d, %d \n", n, c, out_d, out_h, count, + // block_num); + GridSample3DCudaKernel + <<>>( + count, c, out_d, out_h, out_w, in_d, in_h, in_w, x.data(), + grid.data(), output.data(), enum_mode, + enum_padding_mode, align_corners); + + cudaError_t error_check; + error_check = cudaGetLastError(); + if (error_check != cudaSuccess) { + printf("%s\n", cudaGetErrorString(error_check)); + } + // printf("size: %d, %d, %d, %d, %d, %d \n", n, c, out_d, out_h, count, + // block_num); + return {output}; +} + +template +static __forceinline__ __device__ T UnnormalizeWithMask(T coord, int size, + bool align_corners, + T* grad_in) { + if (align_corners) { + *grad_in = static_cast(size - 1) / 2; + return ((coord + 1.f) / 2) * (size - 1); + } else { + *grad_in = static_cast(size) / 2; + return ((coord + 1.f) * size - 1) / 2; + } +} + +template +static __forceinline__ __device__ T ClipIndexesWithMask(T in, int clip_limit, + T* grad_in) { + if (in <= static_cast(0)) { + *grad_in = static_cast(0); + return static_cast(0); + } else { + T max = static_cast(clip_limit - 1); + if (in >= max) { + *grad_in = static_cast(0); + return max; + } else { + *grad_in = static_cast(1); + return in; + } + } +} + +template +static __forceinline__ __device__ T ReflectIndexesWithMask(T in, int twice_low, + int twice_high, + T* grad_in) { + if (twice_low == twice_high) { + *grad_in = static_cast(0); + return static_cast(0); + } + int grad_in_mult_; + T min = static_cast(twice_low) / 2; + 
T span = static_cast(twice_high - twice_low) / 2; + in = in - min; + if (in < static_cast(0)) { + grad_in_mult_ = -1; + in = -in; + } else { + grad_in_mult_ = 1; + } + T extra = fmod(in, span); + int flips = static_cast(floor(in / span)); + if (flips % 2 == 0) { + *grad_in = static_cast(grad_in_mult_); + return extra + min; + } else { + *grad_in = static_cast(-grad_in_mult_); + return span - extra + min; + } +} + +template +static __forceinline__ __device__ T +ComputePositionsWithMask(T coord, int size, PaddingMode padding_mode, + bool align_corners, T* grad_in) { + T grad_clip, grad_refl; + coord = UnnormalizeWithMask(coord, size, align_corners, grad_in); + if (padding_mode == PaddingMode::border) { + coord = ClipIndexesWithMask(coord, size, &grad_clip); + *grad_in = (*grad_in) * grad_clip; + } else if (padding_mode == PaddingMode::reflect) { + if (align_corners) { + coord = ReflectIndexesWithMask(coord, 0, 2 * (size - 1), &grad_refl); + } else { + coord = ReflectIndexesWithMask(coord, -1, 2 * size - 1, &grad_refl); + } + coord = ClipIndexesWithMask(coord, size, &grad_clip); + *grad_in = (*grad_in) * grad_refl * grad_clip; + } + + return coord; +} + +template +static __forceinline__ __device__ void AtomicAdd3D( + T* data, int64_t d, int64_t h, int64_t w, int64_t sD, int64_t sH, + int64_t sW, int64_t D, int64_t H, int64_t W, T delta) { + if (InBounds3D(d, h, w, D, H, W)) { + atomicAdd(data + d * sD + h * sH + w * sW, delta); + } +} + +template +__global__ void GridSample3DCudaBackwardKernel( + const index_t nthreads, const T* grad_output, const T* input, const T* grid, + index_t out_c, index_t out_d, index_t out_h, index_t out_w, index_t in_d, + index_t in_h, index_t in_w, T* grad_input, T* grad_grid, const Mode mode, + const PaddingMode padding_mode, bool align_corners) { + index_t inp_sW = 1; + index_t inp_sH = in_w; + index_t inp_sD = in_h * in_w; + index_t inp_sC = in_d * inp_sD; + index_t inp_sN = out_c * inp_sC; + + index_t grid_sCoor = 1; + index_t grid_sW = 3; + index_t grid_sH = out_w * grid_sW; + index_t grid_sD = out_h * grid_sH; + index_t grid_sN = out_d * grid_sD; + + index_t gOut_sW = 1; + index_t gOut_sH = out_w; + index_t gOut_sD = out_h * out_w; + index_t gOut_sC = out_d * gOut_sD; + index_t gOut_sN = out_c * gOut_sC; + + CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) { + const index_t w = index % out_w; + const index_t h = (index / out_w) % out_h; + const index_t d = (index / (out_h * out_w)) % out_d; + const index_t n = index / (out_d * out_h * out_w); + const auto grid_offset = + n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y, z co-ordinates from grid + T ix = grid[grid_offset]; + T iy = grid[grid_offset + grid_sCoor]; + T iz = grid[grid_offset + 2 * grid_sCoor]; + + // multipliers for gradients on ix, iy, and iz + T gix_mult, giy_mult, giz_mult; + ix = ComputePositionsWithMask(ix, in_w, padding_mode, align_corners, + &gix_mult); + iy = ComputePositionsWithMask(iy, in_h, padding_mode, align_corners, + &giy_mult); + iz = ComputePositionsWithMask(iz, in_d, padding_mode, align_corners, + &giz_mult); + + if (mode == Mode::bilinear) { + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + index_t ix_tnw = static_cast(std::floor(ix)); + index_t iy_tnw = static_cast(std::floor(iy)); + index_t iz_tnw = static_cast(std::floor(iz)); + + index_t ix_tne = ix_tnw + 1; + index_t iy_tne = iy_tnw; + index_t iz_tne = iz_tnw; + + index_t ix_tsw = ix_tnw; + index_t iy_tsw = 
iy_tnw + 1; + index_t iz_tsw = iz_tnw; + + index_t ix_tse = ix_tnw + 1; + index_t iy_tse = iy_tnw + 1; + index_t iz_tse = iz_tnw; + + index_t ix_bnw = ix_tnw; + index_t iy_bnw = iy_tnw; + index_t iz_bnw = iz_tnw + 1; + + index_t ix_bne = ix_tnw + 1; + index_t iy_bne = iy_tnw; + index_t iz_bne = iz_tnw + 1; + + index_t ix_bsw = ix_tnw; + index_t iy_bsw = iy_tnw + 1; + index_t iz_bsw = iz_tnw + 1; + + index_t ix_bse = ix_tnw + 1; + index_t iy_bse = iy_tnw + 1; + index_t iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + T tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + T tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + T tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + T tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + T bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + T bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + T bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + T bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + T gix = static_cast(0), giy = static_cast(0), + giz = static_cast(0); + index_t gOut_offset = + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW; + index_t inp_offset_NC = n * inp_sN; + T* gInp_ptr_NC = grad_input + n * inp_sN; + for (index_t c = 0; c < out_c; ++c, gOut_offset += gOut_sC, + gInp_ptr_NC += inp_sC, inp_offset_NC += inp_sC) { + T gOut = grad_output[gOut_offset]; + + AtomicAdd3D(gInp_ptr_NC, iz_tnw, iy_tnw, ix_tnw, inp_sD, inp_sH, inp_sW, + in_d, in_h, in_w, tnw * gOut); + AtomicAdd3D(gInp_ptr_NC, iz_tne, iy_tne, ix_tne, inp_sD, inp_sH, inp_sW, + in_d, in_h, in_w, tne * gOut); + AtomicAdd3D(gInp_ptr_NC, iz_tsw, iy_tsw, ix_tsw, inp_sD, inp_sH, inp_sW, + in_d, in_h, in_w, tsw * gOut); + AtomicAdd3D(gInp_ptr_NC, iz_tse, iy_tse, ix_tse, inp_sD, inp_sH, inp_sW, + in_d, in_h, in_w, tse * gOut); + AtomicAdd3D(gInp_ptr_NC, iz_bnw, iy_bnw, ix_bnw, inp_sD, inp_sH, inp_sW, + in_d, in_h, in_w, bnw * gOut); + AtomicAdd3D(gInp_ptr_NC, iz_bne, iy_bne, ix_bne, inp_sD, inp_sH, inp_sW, + in_d, in_h, in_w, bne * gOut); + AtomicAdd3D(gInp_ptr_NC, iz_bsw, iy_bsw, ix_bsw, inp_sD, inp_sH, inp_sW, + in_d, in_h, in_w, bsw * gOut); + AtomicAdd3D(gInp_ptr_NC, iz_bse, iy_bse, ix_bse, inp_sD, inp_sH, inp_sW, + in_d, in_h, in_w, bse * gOut); + + // calculate grad_grid + if (InBounds3D(iz_tnw, iy_tnw, ix_tnw, in_d, in_h, in_w)) { + T tnw_val = input[inp_offset_NC + iz_tnw * inp_sD + iy_tnw * inp_sH + + ix_tnw * inp_sW]; + gix -= tnw_val * (iy_bse - iy) * (iz_bse - iz) * gOut; + giy -= tnw_val * (ix_bse - ix) * (iz_bse - iz) * gOut; + giz -= tnw_val * (ix_bse - ix) * (iy_bse - iy) * gOut; + } + if (InBounds3D(iz_tne, iy_tne, ix_tne, in_d, in_h, in_w)) { + T tne_val = input[inp_offset_NC + iz_tne * inp_sD + iy_tne * inp_sH + + ix_tne * inp_sW]; + gix += tne_val * (iy_bsw - iy) * (iz_bsw - iz) * gOut; + giy -= tne_val * (ix - ix_bsw) * (iz_bsw - iz) * gOut; + giz -= tne_val * (ix - ix_bsw) * (iy_bsw - iy) * gOut; + } + if (InBounds3D(iz_tsw, iy_tsw, ix_tsw, in_d, in_h, in_w)) { + T tsw_val = input[inp_offset_NC + iz_tsw * inp_sD + iy_tsw * inp_sH + + ix_tsw * inp_sW]; + gix -= tsw_val * (iy - iy_bne) * (iz_bne - iz) * gOut; + giy += tsw_val * (ix_bne - ix) * (iz_bne - iz) * gOut; + giz -= tsw_val * (ix_bne - ix) * (iy - iy_bne) * gOut; + } + if (InBounds3D(iz_tse, iy_tse, ix_tse, in_d, in_h, in_w)) { + T tse_val = input[inp_offset_NC + iz_tse * inp_sD + iy_tse * inp_sH + + ix_tse * inp_sW]; + gix += tse_val * (iy - iy_bnw) * (iz_bnw - iz) * gOut; + giy += tse_val * (ix - ix_bnw) * (iz_bnw - iz) * gOut; + giz -= tse_val * (ix - 
ix_bnw) * (iy - iy_bnw) * gOut; + } + if (InBounds3D(iz_bnw, iy_bnw, ix_bnw, in_d, in_h, in_w)) { + T bnw_val = input[inp_offset_NC + iz_bnw * inp_sD + iy_bnw * inp_sH + + ix_bnw * inp_sW]; + gix -= bnw_val * (iy_tse - iy) * (iz - iz_tse) * gOut; + giy -= bnw_val * (ix_tse - ix) * (iz - iz_tse) * gOut; + giz += bnw_val * (ix_tse - ix) * (iy_tse - iy) * gOut; + } + if (InBounds3D(iz_bne, iy_bne, ix_bne, in_d, in_h, in_w)) { + T bne_val = input[inp_offset_NC + iz_bne * inp_sD + iy_bne * inp_sH + + ix_bne * inp_sW]; + gix += bne_val * (iy_tsw - iy) * (iz - iz_tsw) * gOut; + giy -= bne_val * (ix - ix_tsw) * (iz - iz_tsw) * gOut; + giz += bne_val * (ix - ix_tsw) * (iy_tsw - iy) * gOut; + } + if (InBounds3D(iz_bsw, iy_bsw, ix_bsw, in_d, in_h, in_w)) { + T bsw_val = input[inp_offset_NC + iz_bsw * inp_sD + iy_bsw * inp_sH + + ix_bsw * inp_sW]; + gix -= bsw_val * (iy - iy_tne) * (iz - iz_tne) * gOut; + giy += bsw_val * (ix_tne - ix) * (iz - iz_tne) * gOut; + giz += bsw_val * (ix_tne - ix) * (iy - iy_tne) * gOut; + } + if (InBounds3D(iz_bse, iy_bse, ix_bse, in_d, in_h, in_w)) { + T bse_val = input[inp_offset_NC + iz_bse * inp_sD + iy_bse * inp_sH + + ix_bse * inp_sW]; + gix += bse_val * (iy - iy_tnw) * (iz - iz_tnw) * gOut; + giy += bse_val * (ix - ix_tnw) * (iz - iz_tnw) * gOut; + giz += bse_val * (ix - ix_tnw) * (iy - iy_tnw) * gOut; + } + } + if (grad_grid != nullptr) { + T* gGrid_ptr_NDHW = grad_grid + index * grid_sW; + gGrid_ptr_NDHW[0] = gix_mult * gix; + gGrid_ptr_NDHW[1] = giy_mult * giy; + gGrid_ptr_NDHW[2] = giz_mult * giz; + } + } else if (mode == Mode::nearest) { + auto ix_nearest = static_cast(std::round(ix)); + auto iy_nearest = static_cast(std::round(iy)); + auto iz_nearest = static_cast(std::round(iz)); + + // assign nearest neighor pixel value to output pixel + index_t gOut_offset = + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW; + T* gInp_ptr_NC = grad_input + n * inp_sN; + for (index_t c = 0; c < out_c; + ++c, gOut_offset += gOut_sC, gInp_ptr_NC += inp_sC) { + AtomicAdd3D(gInp_ptr_NC, iz_nearest, iy_nearest, ix_nearest, inp_sD, + inp_sH, inp_sW, in_d, in_h, in_w, grad_output[gOut_offset]); + } + if (grad_grid != nullptr) { + T* gGrid_ptr_NDHW = grad_grid + index * grid_sW; + gGrid_ptr_NDHW[0] = static_cast(0); + gGrid_ptr_NDHW[1] = static_cast(0); + gGrid_ptr_NDHW[2] = static_cast(0); + } + } + } +} + +std::vector GridSample3DCUDABackward( + const paddle::Tensor& x, const paddle::Tensor& grid, + const paddle::Tensor& grad_out, const std::string& mode, + const std::string& padding_mode, bool align_corners) { + PaddingMode enum_padding_mode; + Mode enum_mode; + if (padding_mode == "border") { + enum_padding_mode = PaddingMode::border; + } else if (padding_mode == "reflection") { + enum_padding_mode = PaddingMode::reflect; + } else { + enum_padding_mode = PaddingMode::zeros; + } + + if (mode == "nearest") { + enum_mode = Mode::nearest; + } else { + enum_mode = Mode::bilinear; + } + + const int out_d = grid.shape()[1]; + const int out_h = grid.shape()[2]; + const int out_w = grid.shape()[3]; + const int n = x.shape()[0]; + const int c = x.shape()[1]; + const int in_d = x.shape()[2]; + const int in_h = x.shape()[3]; + const int in_w = x.shape()[4]; + + auto grid_grad_output = + paddle::empty({n, out_d, out_h, out_w, 3}, paddle::DataType::FLOAT32, + paddle::GPUPlace()); + auto x_grad_output = + paddle::full({n, c, in_d, in_h, in_w}, 0, paddle::DataType::FLOAT32, + paddle::GPUPlace()); + + const int count = static_cast(n * out_d * out_h * out_w); + + int max_threads_per_block = 
512;
+  int block_num = (count - 1) / max_threads_per_block + 1;
+
+  GridSample3DCudaBackwardKernel
+      <<<block_num, max_threads_per_block, 0, x.stream()>>>(
+          count, grad_out.data<float>(), x.data<float>(), grid.data<float>(), c,
+          out_d, out_h, out_w, in_d, in_h, in_w, x_grad_output.data<float>(),
+          grid_grad_output.data<float>(), enum_mode, enum_padding_mode,
+          align_corners);
+
+  return {x_grad_output};
+}
+
+}  // namespace paddle_custom_ops
+}  // namespace fastdeploy
\ No newline at end of file
diff --git a/fastdeploy/runtime/backends/paddle/ops/grid_sample_3d.h b/fastdeploy/runtime/backends/paddle/ops/grid_sample_3d.h
new file mode 100644
index 0000000000..580a890867
--- /dev/null
+++ b/fastdeploy/runtime/backends/paddle/ops/grid_sample_3d.h
@@ -0,0 +1,33 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace fastdeploy {
+namespace paddle_custom_ops {
+
+#define HOST_DEVICE __host__ __device__
+#define HOST_DEVICE_INLINE HOST_DEVICE __forceinline__
+
+enum class Mode { bilinear, nearest };
+
+enum class PaddingMode { zeros, border, reflect };
+
+}  // namespace paddle_custom_ops
+}  // namespace fastdeploy
diff --git a/fastdeploy/runtime/backends/paddle/ops/iou3d_cpu.cc b/fastdeploy/runtime/backends/paddle/ops/iou3d_cpu.cc
new file mode 100644
index 0000000000..3b6ba00bac
--- /dev/null
+++ b/fastdeploy/runtime/backends/paddle/ops/iou3d_cpu.cc
@@ -0,0 +1,272 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+3D Rotated IoU Calculation (CPU)
+Written by Shaoshuai Shi
+All Rights Reserved 2020.
+*/
+
+#include "iou3d_cpu.h"
+#include
+#include
+#include
+
+namespace fastdeploy {
+namespace paddle_custom_ops {
+
+static inline float min(float a, float b) { return a > b ? b : a; }
+
+static inline float max(float a, float b) { return a > b ?
a : b; } + +#if defined(_WIN32) +#if defined(EPS) +#undef EPS +#endif +#define EPS 1e-8 +#else +static const float EPS = 1e-8; +#endif + +struct Point { + float x, y; + Point() {} + Point(double _x, double _y) { x = _x, y = _y; } + + void set(float _x, float _y) { + x = _x; + y = _y; + } + + Point operator+(const Point &b) const { + return Point(x + b.x, y + b.y); + } + + Point operator-(const Point &b) const { + return Point(x - b.x, y - b.y); + } +}; + +static inline float cross(const Point &a, const Point &b) { + return a.x * b.y - a.y * b.x; +} + +static inline float cross(const Point &p1, const Point &p2, const Point &p0) { + return (p1.x - p0.x) * (p2.y - p0.y) - (p2.x - p0.x) * (p1.y - p0.y); +} + +static inline int check_rect_cross(const Point &p1, const Point &p2, const Point &q1, + const Point &q2) { + int ret = min(p1.x, p2.x) <= max(q1.x, q2.x) && + min(q1.x, q2.x) <= max(p1.x, p2.x) && + min(p1.y, p2.y) <= max(q1.y, q2.y) && + min(q1.y, q2.y) <= max(p1.y, p2.y); + return ret; +} + +static inline int check_in_box2d(const float *box, const Point &p) { + // params: (7) [x, y, z, dx, dy, dz, heading] + const float MARGIN = 1e-2; + + float center_x = box[0], center_y = box[1]; + float angle_cos = cos(-box[6]), + angle_sin = + sin(-box[6]); // rotate the point in the opposite direction of box + float rot_x = (p.x - center_x) * angle_cos + (p.y - center_y) * (-angle_sin); + float rot_y = (p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos; + + return (fabs(rot_x) < box[3] / 2 + MARGIN && + fabs(rot_y) < box[4] / 2 + MARGIN); +} + +static inline int intersection(const Point &p1, const Point &p0, const Point &q1, + const Point &q0, Point &ans) { + // fast exclusion + if (check_rect_cross(p0, p1, q0, q1) == 0) return 0; + + // check cross standing + float s1 = cross(q0, p1, p0); + float s2 = cross(p1, q1, p0); + float s3 = cross(p0, q1, q0); + float s4 = cross(q1, p1, q0); + + if (!(s1 * s2 > 0 && s3 * s4 > 0)) return 0; + + // calculate intersection of two lines + float s5 = cross(q1, p1, p0); + if (fabs(s5 - s1) > EPS) { + ans.x = (s5 * q0.x - s1 * q1.x) / (s5 - s1); + ans.y = (s5 * q0.y - s1 * q1.y) / (s5 - s1); + + } else { + float a0 = p0.y - p1.y, b0 = p1.x - p0.x, c0 = p0.x * p1.y - p1.x * p0.y; + float a1 = q0.y - q1.y, b1 = q1.x - q0.x, c1 = q0.x * q1.y - q1.x * q0.y; + float D = a0 * b1 - a1 * b0; + + ans.x = (b0 * c1 - b1 * c0) / D; + ans.y = (a1 * c0 - a0 * c1) / D; + } + + return 1; +} + +static inline void rotate_around_center(const Point ¢er, const float angle_cos, + const float angle_sin, Point &p) { + float new_x = + (p.x - center.x) * angle_cos + (p.y - center.y) * (-angle_sin) + center.x; + float new_y = + (p.x - center.x) * angle_sin + (p.y - center.y) * angle_cos + center.y; + p.set(new_x, new_y); +} + +static inline int point_cmp(const Point &a, const Point &b, const Point ¢er) { + return atan2(a.y - center.y, a.x - center.x) > + atan2(b.y - center.y, b.x - center.x); +} + +static inline float box_overlap(const float *box_a, const float *box_b) { + // params: box_a (7) [x, y, z, dx, dy, dz, heading] + // params: box_b (7) [x, y, z, dx, dy, dz, heading] + + // float a_x1 = box_a[0], a_y1 = box_a[1], a_x2 = box_a[2], a_y2 = + // box_a[3], a_angle = box_a[4]; + // float b_x1 = box_b[0], b_y1 = box_b[1], b_x2 = box_b[2], b_y2 = + // box_b[3], b_angle = box_b[4]; + float a_angle = box_a[6], b_angle = box_b[6]; + float a_dx_half = box_a[3] / 2, b_dx_half = box_b[3] / 2, + a_dy_half = box_a[4] / 2, b_dy_half = box_b[4] / 2; + float a_x1 = box_a[0] - a_dx_half, 
a_y1 = box_a[1] - a_dy_half; + float a_x2 = box_a[0] + a_dx_half, a_y2 = box_a[1] + a_dy_half; + float b_x1 = box_b[0] - b_dx_half, b_y1 = box_b[1] - b_dy_half; + float b_x2 = box_b[0] + b_dx_half, b_y2 = box_b[1] + b_dy_half; + + Point center_a(box_a[0], box_a[1]); + Point center_b(box_b[0], box_b[1]); + + Point box_a_corners[5]; + box_a_corners[0].set(a_x1, a_y1); + box_a_corners[1].set(a_x2, a_y1); + box_a_corners[2].set(a_x2, a_y2); + box_a_corners[3].set(a_x1, a_y2); + + Point box_b_corners[5]; + box_b_corners[0].set(b_x1, b_y1); + box_b_corners[1].set(b_x2, b_y1); + box_b_corners[2].set(b_x2, b_y2); + box_b_corners[3].set(b_x1, b_y2); + + // get oriented corners + float a_angle_cos = cos(a_angle), a_angle_sin = sin(a_angle); + float b_angle_cos = cos(b_angle), b_angle_sin = sin(b_angle); + + for (int k = 0; k < 4; k++) { + rotate_around_center(center_a, a_angle_cos, a_angle_sin, box_a_corners[k]); + rotate_around_center(center_b, b_angle_cos, b_angle_sin, box_b_corners[k]); + } + + box_a_corners[4] = box_a_corners[0]; + box_b_corners[4] = box_b_corners[0]; + + // get intersection of lines + Point cross_points[16]; + Point poly_center; + int cnt = 0, flag = 0; + + poly_center.set(0, 0); + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + flag = intersection(box_a_corners[i + 1], box_a_corners[i], + box_b_corners[j + 1], box_b_corners[j], + cross_points[cnt]); + if (flag) { + poly_center = poly_center + cross_points[cnt]; + cnt++; + } + } + } + + // check corners + for (int k = 0; k < 4; k++) { + if (check_in_box2d(box_a, box_b_corners[k])) { + poly_center = poly_center + box_b_corners[k]; + cross_points[cnt] = box_b_corners[k]; + cnt++; + } + if (check_in_box2d(box_b, box_a_corners[k])) { + poly_center = poly_center + box_a_corners[k]; + cross_points[cnt] = box_a_corners[k]; + cnt++; + } + } + + poly_center.x /= cnt; + poly_center.y /= cnt; + + // sort the points of polygon + Point temp; + for (int j = 0; j < cnt - 1; j++) { + for (int i = 0; i < cnt - j - 1; i++) { + if (point_cmp(cross_points[i], cross_points[i + 1], poly_center)) { + temp = cross_points[i]; + cross_points[i] = cross_points[i + 1]; + cross_points[i + 1] = temp; + } + } + } + + // get the overlap areas + float area = 0; + for (int k = 0; k < cnt - 1; k++) { + area += cross(cross_points[k] - cross_points[0], + cross_points[k + 1] - cross_points[0]); + } + + return fabs(area) / 2.0; +} + +static inline float iou_bev(const float *box_a, const float *box_b) { + // params: box_a (7) [x, y, z, dx, dy, dz, heading] + // params: box_b (7) [x, y, z, dx, dy, dz, heading] + float sa = box_a[3] * box_a[4]; + float sb = box_b[3] * box_b[4]; + float s_overlap = box_overlap(box_a, box_b); + return s_overlap / fmaxf(sa + sb - s_overlap, EPS); +} + +int boxes_iou_bev_cpu(paddle::Tensor boxes_a_tensor, + paddle::Tensor boxes_b_tensor, + paddle::Tensor ans_iou_tensor) { + // params boxes_a_tensor: (N, 7) [x, y, z, dx, dy, dz, heading] + // params boxes_b_tensor: (M, 7) [x, y, z, dx, dy, dz, heading] + // params ans_iou_tensor: (N, M) + + // CHECK_CONTIGUOUS(boxes_a_tensor); + // CHECK_CONTIGUOUS(boxes_b_tensor); + + int num_boxes_a = boxes_a_tensor.shape()[0]; + int num_boxes_b = boxes_b_tensor.shape()[0]; + const float *boxes_a = boxes_a_tensor.data(); + const float *boxes_b = boxes_b_tensor.data(); + float *ans_iou = ans_iou_tensor.data(); + + for (int i = 0; i < num_boxes_a; i++) { + for (int j = 0; j < num_boxes_b; j++) { + ans_iou[i * num_boxes_b + j] = iou_bev(boxes_a + i * 7, boxes_b + j * 7); + } + } + return 1; 
+}
+
+}  // namespace paddle_custom_ops
+}  // namespace fastdeploy
diff --git a/fastdeploy/runtime/backends/paddle/ops/iou3d_cpu.h b/fastdeploy/runtime/backends/paddle/ops/iou3d_cpu.h
new file mode 100644
index 0000000000..c56fee7af5
--- /dev/null
+++ b/fastdeploy/runtime/backends/paddle/ops/iou3d_cpu.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#if defined(PADDLEINFERENCE_API_COMPAT_2_4_x)
+#include "paddle/include/experimental/ext_all.h"
+#elif defined(PADDLEINFERENCE_API_COMPAT_2_5_x)
+#include "paddle/include/paddle/extension.h"
+#else
+#include "paddle/extension.h"
+#endif
+
+#include "fastdeploy/utils/utils.h"
+
+namespace fastdeploy {
+namespace paddle_custom_ops {
+
+FASTDEPLOY_DECL int boxes_iou_bev_cpu(
+    paddle::Tensor boxes_a_tensor, paddle::Tensor boxes_b_tensor,
+    paddle::Tensor ans_iou_tensor);
+
+}  // namespace paddle_custom_ops
+}  // namespace fastdeploy
\ No newline at end of file
diff --git a/fastdeploy/runtime/backends/paddle/ops/iou3d_nms.cc b/fastdeploy/runtime/backends/paddle/ops/iou3d_nms.cc
new file mode 100644
index 0000000000..957a05686e
--- /dev/null
+++ b/fastdeploy/runtime/backends/paddle/ops/iou3d_nms.cc
@@ -0,0 +1,237 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+3D IoU Calculation and Rotated NMS (modified from 2D NMS written by others)
+Written by Shaoshuai Shi
+All Rights Reserved 2019-2020.
+*/
+
+#if defined(WITH_GPU)
+
+#include
+#include
+
+#include "iou3d_nms.h"
+
+namespace fastdeploy {
+namespace paddle_custom_ops {
+
+#define CHECK_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")
+// #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
+static inline int DIVUP(const int m, const int n)
+{ return ((m) / (n) + ((m) % (n) > 0)); }
+
+#define CHECK_ERROR(ans) \
+  { gpuAssert((ans), __FILE__, __LINE__); }
+inline void gpuAssert(cudaError_t code, const char *file, int line,
+                      bool abort = true) {
+  if (code != cudaSuccess) {
+    fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
+            line);
+    if (abort) exit(code);
+  }
+}
+
+#define D(x) \
+  PD_THROW('\n', x, \
+           "\n--------------------------------- where is the error ? " \
+           "---------------------------------------\n");
+
+static const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8;
+
+void boxesoverlapLauncher(const int num_a, const float *boxes_a,
+                          const int num_b, const float *boxes_b,
+                          float *ans_overlap);
+void boxesioubevLauncher(const int num_a, const float *boxes_a, const int num_b,
+                         const float *boxes_b, float *ans_iou);
+void nmsLauncher(const float *boxes, unsigned long long *mask, int boxes_num,
+                 float nms_overlap_thresh);
+void nmsNormalLauncher(const float *boxes, unsigned long long *mask,
+                       int boxes_num, float nms_overlap_thresh);
+
+int boxes_overlap_bev_gpu(paddle::Tensor boxes_a, paddle::Tensor boxes_b,
+                          paddle::Tensor ans_overlap) {
+  // params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
+  // params boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading]
+  // params ans_overlap: (N, M)
+
+  CHECK_INPUT(boxes_a);
+  CHECK_INPUT(boxes_b);
+  CHECK_INPUT(ans_overlap);
+
+  int num_a = boxes_a.shape()[0];
+  int num_b = boxes_b.shape()[0];
+
+  const float *boxes_a_data = boxes_a.data<float>();
+  const float *boxes_b_data = boxes_b.data<float>();
+  float *ans_overlap_data = ans_overlap.data<float>();
+
+  boxesoverlapLauncher(num_a, boxes_a_data, num_b, boxes_b_data,
+                       ans_overlap_data);
+
+  return 1;
+}
+
+int boxes_iou_bev_gpu(paddle::Tensor boxes_a, paddle::Tensor boxes_b,
+                      paddle::Tensor ans_iou) {
+  // params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
+  // params boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading]
+  // params ans_iou: (N, M)
+  CHECK_INPUT(boxes_a);
+  CHECK_INPUT(boxes_b);
+  CHECK_INPUT(ans_iou);
+
+  int num_a = boxes_a.shape()[0];
+  int num_b = boxes_b.shape()[0];
+
+  const float *boxes_a_data = boxes_a.data<float>();
+  const float *boxes_b_data = boxes_b.data<float>();
+  float *ans_iou_data = ans_iou.data<float>();
+
+  boxesioubevLauncher(num_a, boxes_a_data, num_b, boxes_b_data, ans_iou_data);
+
+  return 1;
+}
+
+std::vector<paddle::Tensor> nms_gpu(const paddle::Tensor &boxes,
+                                    float nms_overlap_thresh) {
+  // params boxes: (N, 7) [x, y, z, dx, dy, dz, heading]
+  // params keep: (N)
+  CHECK_INPUT(boxes);
+  // CHECK_CONTIGUOUS(keep);
+  auto keep = paddle::empty({boxes.shape()[0]}, paddle::DataType::INT32,
+                            paddle::CPUPlace());
+  auto num_to_keep_tensor =
+      paddle::empty({1}, paddle::DataType::INT32, paddle::CPUPlace());
+  int *num_to_keep_data = num_to_keep_tensor.data<int>();
+
+  int boxes_num = boxes.shape()[0];
+  const float *boxes_data = boxes.data<float>();
+  int *keep_data = keep.data<int>();
+
+  int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
+
+  unsigned long long *mask_data = NULL;
+  CHECK_ERROR(cudaMalloc((void **)&mask_data,
+                         boxes_num * col_blocks * sizeof(unsigned long long)));
+  nmsLauncher(boxes_data, mask_data, boxes_num, nms_overlap_thresh);
+
+  // unsigned long long mask_cpu[boxes_num * col_blocks];
+  // unsigned long long *mask_cpu = new unsigned long long [boxes_num *
+  // col_blocks];
+  std::vector<unsigned long long> mask_cpu(boxes_num * col_blocks);
+
+  // printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks);
+  CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data,
+                         boxes_num * col_blocks * sizeof(unsigned long long),
+                         cudaMemcpyDeviceToHost));
+
+  cudaFree(mask_data);
+
+  // WARN(qiuyanjun): the commented-out variable-length arrays below throw a
+  // compile error on Windows with MSVC, so std::vector is used to store the
+  // result instead.
+  // unsigned long long remv_cpu[col_blocks];
+  // memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long));
+  std::vector<unsigned long long> remv_cpu(col_blocks, 0);
+
+  int num_to_keep = 0;
+
+  // Walk boxes in score order: keep box i unless an earlier kept box has
+  // already set its suppression bit, then OR in the boxes that i suppresses.
+  for (int i = 0; i < boxes_num; i++) {
+    int nblock = i / THREADS_PER_BLOCK_NMS;
+    int inblock = i % THREADS_PER_BLOCK_NMS;
+
+    if (!(remv_cpu[nblock] & (1ULL << inblock))) {
+      keep_data[num_to_keep++] = i;
+      unsigned long long *p = &mask_cpu[0] + i * col_blocks;
+      for (int j = nblock; j < col_blocks; j++) {
+        remv_cpu[j] |= p[j];
+      }
+    }
+  }
+
+  num_to_keep_data[0] = num_to_keep;
+
+  if (cudaSuccess != cudaGetLastError()) printf("Error!\n");
+
+  return {keep, num_to_keep_tensor};
+}
+
+int nms_normal_gpu(paddle::Tensor boxes, paddle::Tensor keep,
+                   float nms_overlap_thresh) {
+  // params boxes: (N, 7) [x, y, z, dx, dy, dz, heading]
+  // params keep: (N)
+
+  CHECK_INPUT(boxes);
+  // CHECK_CONTIGUOUS(keep);
+
+  int boxes_num = boxes.shape()[0];
+  const float *boxes_data = boxes.data<float>();
+  // WARN(qiuyanjun): the long instantiation of the Tensor::data() API is not
+  // exported by paddle, which raises link errors on Windows with MSVC. Please
+  // check:
+  // https://github.com/PaddlePaddle/Paddle/blob/release/2.5/paddle/phi/api/lib/tensor.cc
+#if defined(_WIN32)
+  int *keep_data = keep.data<int>();
+#else
+  long *keep_data = keep.data<long>();
+#endif
+
+  int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
+
+  unsigned long long *mask_data = NULL;
+  CHECK_ERROR(cudaMalloc((void **)&mask_data,
+                         boxes_num * col_blocks * sizeof(unsigned long long)));
+  nmsNormalLauncher(boxes_data, mask_data, boxes_num, nms_overlap_thresh);
+
+  // unsigned long long mask_cpu[boxes_num * col_blocks];
+  // unsigned long long *mask_cpu = new unsigned long long [boxes_num *
+  // col_blocks];
+  std::vector<unsigned long long> mask_cpu(boxes_num * col_blocks);
+
+  // printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks);
+  CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data,
+                         boxes_num * col_blocks * sizeof(unsigned long long),
+                         cudaMemcpyDeviceToHost));
+
+  cudaFree(mask_data);
+
+  // WARN(qiuyanjun): the commented-out variable-length arrays below throw a
+  // compile error on Windows with MSVC, so std::vector is used to store the
+  // result instead.
+  // unsigned long long remv_cpu[col_blocks];
+  // memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long));
+  std::vector<unsigned long long> remv_cpu(col_blocks, 0);
+
+  int num_to_keep = 0;
+
+  for (int i = 0; i < boxes_num; i++) {
+    int nblock = i / THREADS_PER_BLOCK_NMS;
+    int inblock = i % THREADS_PER_BLOCK_NMS;
+
+    if (!(remv_cpu[nblock] & (1ULL << inblock))) {
+      keep_data[num_to_keep++] = i;
+      unsigned long long *p = &mask_cpu[0] + i * col_blocks;
+      for (int j = nblock; j < col_blocks; j++) {
+        remv_cpu[j] |= p[j];
+      }
+    }
+  }
+  if (cudaSuccess != cudaGetLastError()) printf("Error!\n");
+
+  return num_to_keep;
+}
+
+}  // namespace paddle_custom_ops
+}  // namespace fastdeploy
+
+#endif
\ No newline at end of file
diff --git a/fastdeploy/runtime/backends/paddle/ops/iou3d_nms.h b/fastdeploy/runtime/backends/paddle/ops/iou3d_nms.h
new file mode 100644
index 0000000000..c5614edf00
--- /dev/null
+++ b/fastdeploy/runtime/backends/paddle/ops/iou3d_nms.h
@@ -0,0 +1,46 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#if defined(PADDLEINFERENCE_API_COMPAT_2_4_x)
+#include "paddle/include/experimental/ext_all.h"
+#elif defined(PADDLEINFERENCE_API_COMPAT_2_5_x)
+#include "paddle/include/paddle/extension.h"
+#else
+#include "paddle/extension.h"
+#endif
+
+#include "fastdeploy/utils/utils.h"
+
+#if defined(WITH_GPU)
+namespace fastdeploy {
+namespace paddle_custom_ops {
+
+FASTDEPLOY_DECL int boxes_overlap_bev_gpu(
+    paddle::Tensor boxes_a, paddle::Tensor boxes_b,
+    paddle::Tensor ans_overlap);
+FASTDEPLOY_DECL int boxes_iou_bev_gpu(paddle::Tensor boxes_a,
+                                      paddle::Tensor boxes_b,
+                                      paddle::Tensor ans_iou);
+FASTDEPLOY_DECL std::vector<paddle::Tensor> nms_gpu(
+    const paddle::Tensor& boxes, float nms_overlap_thresh);
+FASTDEPLOY_DECL int nms_normal_gpu(
+    paddle::Tensor boxes, paddle::Tensor keep, float nms_overlap_thresh);
+
+}  // namespace paddle_custom_ops
+}  // namespace fastdeploy
+
+#endif
+
diff --git a/fastdeploy/runtime/backends/paddle/ops/iou3d_nms_api.cc b/fastdeploy/runtime/backends/paddle/ops/iou3d_nms_api.cc
new file mode 100644
index 0000000000..87b0374eae
--- /dev/null
+++ b/fastdeploy/runtime/backends/paddle/ops/iou3d_nms_api.cc
@@ -0,0 +1,55 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
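A note on the host-side gather used by `nms_gpu` and `nms_normal_gpu` above: the kernels write one 64-bit word per (box, column-block) pair, and the CPU pass then walks the boxes in descending score order, keeping a box only if no previously kept box has set its bit, and OR-ing each kept box's mask row into a running suppression vector. A minimal CPU-only sketch of that gather step, standalone rather than the FastDeploy API, assuming the mask has already been copied back from the device:

```cpp
#include <cstdint>
#include <vector>

// mask[i * col_blocks + j] is a 64-bit word whose bit k is set when box i
// overlaps box (j * 64 + k) beyond the NMS threshold; boxes are assumed to
// be pre-sorted by descending score.
std::vector<int> GatherKeptBoxes(const std::vector<uint64_t>& mask,
                                 int boxes_num) {
  const int kBits = 64;  // THREADS_PER_BLOCK_NMS
  const int col_blocks = (boxes_num + kBits - 1) / kBits;  // DIVUP
  std::vector<uint64_t> remv(col_blocks, 0);  // accumulated suppression bits
  std::vector<int> keep;
  for (int i = 0; i < boxes_num; ++i) {
    const int nblock = i / kBits;
    const int inblock = i % kBits;
    if (remv[nblock] & (1ULL << inblock)) continue;  // already suppressed
    keep.push_back(i);  // highest-scoring survivor so far
    const uint64_t* row = mask.data() + static_cast<size_t>(i) * col_blocks;
    for (int j = nblock; j < col_blocks; ++j) {
      remv[j] |= row[j];  // suppress everything box i overlaps
    }
  }
  return keep;
}
```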
+
+
+#if defined(PADDLEINFERENCE_API_COMPAT_2_4_x)
+#include "paddle/include/experimental/ext_all.h"
+#elif defined(PADDLEINFERENCE_API_COMPAT_2_5_x)
+#include "paddle/include/paddle/extension.h"
+#else
+#include "paddle/extension.h"
+#endif
+
+#include <vector>
+
+#include "iou3d_cpu.h"
+#include "iou3d_nms.h"
+
+namespace fastdeploy {
+namespace paddle_custom_ops {
+
+std::vector<std::vector<int64_t>> NMSInferShape(
+    std::vector<int64_t> boxes_shape) {
+  int64_t keep_num = 1;
+  return {{boxes_shape[0]}, {keep_num}};
+}
+
+std::vector<paddle::DataType> NMSInferDtype(paddle::DataType boxes_dtype) {
+  return {paddle::DataType::INT64, paddle::DataType::INT64};
+}
+
+}  // namespace paddle_custom_ops
+}  // namespace fastdeploy
+
+#if defined(WITH_GPU)
+
+PD_BUILD_OP(nms_gpu)
+    .Inputs({"boxes"})
+    .Outputs({"keep", "num_to_keep"})
+    .Attrs({"nms_overlap_thresh: float"})
+    .SetKernelFn(PD_KERNEL(fastdeploy::paddle_custom_ops::nms_gpu))
+    .SetInferDtypeFn(PD_INFER_DTYPE(fastdeploy::paddle_custom_ops::NMSInferDtype))
+    .SetInferShapeFn(PD_INFER_SHAPE(fastdeploy::paddle_custom_ops::NMSInferShape));
+
+#endif
diff --git a/fastdeploy/runtime/backends/paddle/ops/iou3d_nms_kernel.cu b/fastdeploy/runtime/backends/paddle/ops/iou3d_nms_kernel.cu
index 3c34ebfe4c..dc56a50cd5 100644
--- a/fastdeploy/runtime/backends/paddle/ops/iou3d_nms_kernel.cu
+++ b/fastdeploy/runtime/backends/paddle/ops/iou3d_nms_kernel.cu
@@ -1,23 +1,8 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
 /*
 3D IoU Calculation and Rotated NMS(modified from 2D NMS written by others)
 Written by Shaoshuai Shi
 All Rights Reserved 2019-2020.
 */
-
 #include <stdio.h>
 
 namespace fastdeploy {
@@ -78,20 +63,36 @@ __device__ int check_rect_cross(const Point &p1, const Point &p2,
 __device__ inline int check_in_box2d(const float *box, const Point &p) {
   // params: (7) [x, y, z, dx, dy, dz, heading]
   const float MARGIN = 1e-2;
+  // Align with the setting of mmdet3d
+  // const float MARGIN = 1e-5;
   float center_x = box[0], center_y = box[1];
-  // rotate the point in the opposite direction of box
-  float angle_cos = cos(-box[6]), angle_sin = sin(-box[6]);
+  float angle_cos = cos(-box[6]),
+        angle_sin =
+            sin(-box[6]);  // rotate the point in the opposite direction of box
   float rot_x = (p.x - center_x) * angle_cos + (p.y - center_y) * (-angle_sin);
   float rot_y = (p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos;
 
   return (fabs(rot_x) < box[3] / 2 + MARGIN &&
           fabs(rot_y) < box[4] / 2 + MARGIN);
+  // Align with the implementation of mmdet3d
+  // float rot_x =
+  //     (p.x - center_x) * angle_cos + (p.y - center_y) * angle_sin + center_x;
+  // float rot_y =
+  //     -(p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos +
+  //     center_y;
+  // float x1 = center_x - box[3] / 2;
+  // float x2 = center_x + box[3] / 2;
+  // float y1 = center_y - box[4] / 2;
+  // float y2 = center_y + box[4] / 2;
+  // return (rot_x > x1 - MARGIN && rot_x < x2 + MARGIN &&
+  //         rot_y > y1 - MARGIN && rot_y < y2 + MARGIN);
 }
 
 __device__ inline int intersection(const Point &p1, const Point &p0,
                                    const Point &q1, const Point &q0,
-                                   Point *ans) {
+                                   Point &ans) {
   // fast exclusion
   if (check_rect_cross(p0, p1, q0, q1) == 0) return 0;
@@ -106,16 +107,16 @@ __device__ inline int intersection(const Point &p1, const Point &p0,
   // calculate intersection of two lines
   float s5 = cross(q1, p1, p0);
   if (fabs(s5 - s1) > EPS) {
-    ans->x = (s5 * q0.x - s1 * q1.x) / (s5 - s1);
-    ans->y = (s5 * q0.y - s1 * q1.y) / (s5 - s1);
+    ans.x = (s5 * q0.x - s1 * q1.x) / (s5 - s1);
+    ans.y = (s5 * q0.y - s1 * q1.y) / (s5 - s1);
   } else {
     float a0 = p0.y - p1.y, b0 = p1.x - p0.x, c0 = p0.x * p1.y - p1.x * p0.y;
     float a1 = q0.y - q1.y, b1 = q1.x - q0.x, c1 = q0.x * q1.y - q1.x * q0.y;
     float D = a0 * b1 - a1 * b0;
 
-    ans->x = (b0 * c1 - b1 * c0) / D;
-    ans->y = (a1 * c0 - a0 * c1) / D;
+    ans.x = (b0 * c1 - b1 * c0) / D;
+    ans.y = (a1 * c0 - a0 * c1) / D;
   }
 
   return 1;
@@ -123,12 +124,18 @@ __device__ inline int intersection(const Point &p1, const Point &p0,
 
 __device__ inline void rotate_around_center(const Point &center,
                                             const float angle_cos,
-                                            const float angle_sin, Point *p) {
-  float new_x = (p->x - center.x) * angle_cos +
-                (p->y - center.y) * (-angle_sin) + center.x;
+                                            const float angle_sin, Point &p) {
+  // float new_x = (p.x - center.x) * angle_cos + (p.y - center.y) *
+  //     (-angle_sin) + center.x;
+  // float new_y = (p.x - center.x) * angle_sin + (p.y - center.y) * angle_cos +
+  //     center.y;
+  // p.set(new_x, new_y);
+  // Align with the implementation of mmdet3d
+  float new_x =
+      (p.x - center.x) * angle_cos + (p.y - center.y) * angle_sin + center.x;
   float new_y =
-      (p->x - center.x) * angle_sin + (p->y - center.y) * angle_cos + center.y;
-  p->set(new_x, new_y);
+      -(p.x - center.x) * angle_sin + (p.y - center.y) * angle_cos + center.y;
+  p.set(new_x, new_y);
 }
 
 __device__ inline int point_cmp(const Point &a, const Point &b,
@@ -152,6 +159,14 @@ __device__ inline float box_overlap(const float *box_a, const float *box_b) {
   Point center_a(box_a[0], box_a[1]);
   Point center_b(box_b[0], box_b[1]);
 
+#ifdef DEBUG
+  printf(
+      "a: (%.3f, %.3f, %.3f, %.3f, %.3f), b: (%.3f, %.3f, %.3f, %.3f, %.3f)\n",
+      a_x1, a_y1, a_x2, a_y2,
+      a_angle, b_x1, b_y1, b_x2, b_y2, b_angle);
+  printf("center a: (%.3f, %.3f), b: (%.3f, %.3f)\n", center_a.x, center_a.y,
+         center_b.x, center_b.y);
+#endif
+
   Point box_a_corners[5];
   box_a_corners[0].set(a_x1, a_y1);
   box_a_corners[1].set(a_x2, a_y1);
@@ -169,8 +184,17 @@ __device__ inline float box_overlap(const float *box_a, const float *box_b) {
   float b_angle_cos = cos(b_angle), b_angle_sin = sin(b_angle);
 
   for (int k = 0; k < 4; k++) {
-    rotate_around_center(center_a, a_angle_cos, a_angle_sin, box_a_corners + k);
-    rotate_around_center(center_b, b_angle_cos, b_angle_sin, box_b_corners + k);
+#ifdef DEBUG
+    printf("before corner %d: a(%.3f, %.3f), b(%.3f, %.3f) \n", k,
+           box_a_corners[k].x, box_a_corners[k].y, box_b_corners[k].x,
+           box_b_corners[k].y);
+#endif
+    rotate_around_center(center_a, a_angle_cos, a_angle_sin, box_a_corners[k]);
+    rotate_around_center(center_b, b_angle_cos, b_angle_sin, box_b_corners[k]);
+#ifdef DEBUG
+    printf("corner %d: a(%.3f, %.3f), b(%.3f, %.3f) \n", k, box_a_corners[k].x,
+           box_a_corners[k].y, box_b_corners[k].x, box_b_corners[k].y);
+#endif
   }
 
   box_a_corners[4] = box_a_corners[0];
@@ -186,10 +210,19 @@ __device__ inline float box_overlap(const float *box_a, const float *box_b) {
     for (int j = 0; j < 4; j++) {
       flag = intersection(box_a_corners[i + 1], box_a_corners[i],
                           box_b_corners[j + 1], box_b_corners[j],
-                          cross_points + cnt);
+                          cross_points[cnt]);
       if (flag) {
         poly_center = poly_center + cross_points[cnt];
         cnt++;
+#ifdef DEBUG
+        printf(
+            "Cross points (%.3f, %.3f): a(%.3f, %.3f)->(%.3f, %.3f), b(%.3f, "
+            "%.3f)->(%.3f, %.3f) \n",
+            cross_points[cnt - 1].x, cross_points[cnt - 1].y,
+            box_a_corners[i].x, box_a_corners[i].y, box_a_corners[i + 1].x,
+            box_a_corners[i + 1].y, box_b_corners[j].x, box_b_corners[j].y,
+            box_b_corners[j + 1].x, box_b_corners[j + 1].y);
+#endif
       }
     }
   }
@@ -200,11 +233,19 @@ __device__ inline float box_overlap(const float *box_a, const float *box_b) {
       poly_center = poly_center + box_b_corners[k];
       cross_points[cnt] = box_b_corners[k];
       cnt++;
+#ifdef DEBUG
+      printf("b corners in a: corner_b(%.3f, %.3f)", cross_points[cnt - 1].x,
+             cross_points[cnt - 1].y);
+#endif
     }
     if (check_in_box2d(box_b, box_a_corners[k])) {
       poly_center = poly_center + box_a_corners[k];
       cross_points[cnt] = box_a_corners[k];
       cnt++;
+#ifdef DEBUG
+      printf("a corners in b: corner_a(%.3f, %.3f)", cross_points[cnt - 1].x,
+             cross_points[cnt - 1].y);
+#endif
     }
   }
 
@@ -223,6 +264,14 @@ __device__ inline float box_overlap(const float *box_a, const float *box_b) {
     }
   }
 
+#ifdef DEBUG
+  printf("cnt=%d\n", cnt);
+  for (int i = 0; i < cnt; i++) {
+    printf("All cross point %d: (%.3f, %.3f)\n", i, cross_points[i].x,
+           cross_points[i].y);
+  }
+#endif
+
   // get the overlap areas
   float area = 0;
   for (int k = 0; k < cnt - 1; k++) {
@@ -242,11 +291,221 @@ __device__ inline float iou_bev(const float *box_a, const float *box_b) {
   return s_overlap / fmaxf(sa + sb - s_overlap, EPS);
 }
 
-__global__ void nms_kernel(const int num_bboxes, const int num_bboxes_for_nms,
-                           const float nms_overlap_thresh,
-                           const int decode_bboxes_dims, const float *bboxes,
-                           const int *index, const int64_t *sorted_index,
-                           int64_t *mask) {
+__global__ void boxes_overlap_kernel(const int num_a, const float *boxes_a,
+                                     const int num_b, const float *boxes_b,
+                                     float *ans_overlap) {
+  // params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
+  // params boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading]
+  const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y;
+  const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;
+
+  if (a_idx >= num_a || b_idx >= num_b) {
+    return;
+  }
+  const float *cur_box_a = boxes_a + a_idx * 7;
+  const float *cur_box_b = boxes_b + b_idx * 7;
+  float s_overlap = box_overlap(cur_box_a, cur_box_b);
+  ans_overlap[a_idx * num_b + b_idx] = s_overlap;
+}
+
+__global__ void boxes_iou_bev_kernel(const int num_a, const float *boxes_a,
+                                     const int num_b, const float *boxes_b,
+                                     float *ans_iou) {
+  // params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
+  // params boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading]
+  const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y;
+  const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;
+
+  if (a_idx >= num_a || b_idx >= num_b) {
+    return;
+  }
+
+  const float *cur_box_a = boxes_a + a_idx * 7;
+  const float *cur_box_b = boxes_b + b_idx * 7;
+  float cur_iou_bev = iou_bev(cur_box_a, cur_box_b);
+  ans_iou[a_idx * num_b + b_idx] = cur_iou_bev;
+}
+
+__global__ void nms_kernel(const int boxes_num, const float nms_overlap_thresh,
+                           const float *boxes, unsigned long long *mask) {
+  // params: boxes (N, 7) [x, y, z, dx, dy, dz, heading]
+  // params: mask (N, N/THREADS_PER_BLOCK_NMS)
+
+  const int row_start = blockIdx.y;
+  const int col_start = blockIdx.x;
+
+  // if (row_start > col_start) return;
+
+  const int row_size = fminf(boxes_num - row_start * THREADS_PER_BLOCK_NMS,
+                             THREADS_PER_BLOCK_NMS);
+  const int col_size = fminf(boxes_num - col_start * THREADS_PER_BLOCK_NMS,
+                             THREADS_PER_BLOCK_NMS);
+
+  __shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 7];
+
+  if (threadIdx.x < col_size) {
+    block_boxes[threadIdx.x * 7 + 0] =
+        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 0];
+    block_boxes[threadIdx.x * 7 + 1] =
+        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 1];
+    block_boxes[threadIdx.x * 7 + 2] =
+        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 2];
+    block_boxes[threadIdx.x * 7 + 3] =
+        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 3];
+    block_boxes[threadIdx.x * 7 + 4] =
+        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 4];
+    block_boxes[threadIdx.x * 7 + 5] =
+        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 5];
+    block_boxes[threadIdx.x * 7 + 6] =
+        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 6];
+  }
+  __syncthreads();
+
+  if (threadIdx.x < row_size) {
+    const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;
+    const float *cur_box = boxes + cur_box_idx * 7;
+
+    int i = 0;
+    unsigned long long t = 0;
+    int start = 0;
+    if (row_start == col_start) {
+      start = threadIdx.x + 1;
+    }
+    for (i = start; i < col_size; i++) {
+      if (iou_bev(cur_box, block_boxes + i * 7) > nms_overlap_thresh) {
+        t |= 1ULL << i;
+      }
+    }
+    int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
+    mask[cur_box_idx * col_blocks + col_start] = t;
+  }
+}
+
+__device__ inline float iou_normal(float const *const a, float const *const b) {
+  // params: a: [x, y, z, dx, dy, dz, heading]
+  // params: b: [x, y, z, dx, dy, dz, heading]
+
+  float left = fmaxf(a[0] - a[3] / 2, b[0] - b[3] / 2),
+        right = fminf(a[0] + a[3] / 2, b[0] + b[3] / 2);
+  float top = fmaxf(a[1] - a[4] / 2, b[1] - b[4] / 2),
+        bottom = fminf(a[1] + a[4] / 2, b[1] + b[4] / 2);
+  float width = fmaxf(right - left, 0.f), height = fmaxf(bottom - top, 0.f);
+  float interS = width * height;
+  float Sa = a[3] * a[4];
+  float Sb = b[3] * b[4];
+  return interS / fmaxf(Sa + Sb - interS, EPS);
+}
+
+__global__ void nms_normal_kernel(const int boxes_num,
+                                  const float nms_overlap_thresh,
+                                  const float *boxes,
+                                  unsigned long long *mask) {
+  // params: boxes (N, 7) [x, y, z, dx, dy, dz, heading]
+  // params: mask (N, N/THREADS_PER_BLOCK_NMS)
+
+  const int row_start = blockIdx.y;
+  const int col_start = blockIdx.x;
+
+  // if (row_start > col_start) return;
+
+  const int row_size = fminf(boxes_num - row_start * THREADS_PER_BLOCK_NMS,
+                             THREADS_PER_BLOCK_NMS);
+  const int col_size = fminf(boxes_num - col_start * THREADS_PER_BLOCK_NMS,
+                             THREADS_PER_BLOCK_NMS);
+
+  __shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 7];
+
+  if (threadIdx.x < col_size) {
+    block_boxes[threadIdx.x * 7 + 0] =
+        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 0];
+    block_boxes[threadIdx.x * 7 + 1] =
+        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 1];
+    block_boxes[threadIdx.x * 7 + 2] =
+        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 2];
+    block_boxes[threadIdx.x * 7 + 3] =
+        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 3];
+    block_boxes[threadIdx.x * 7 + 4] =
+        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 4];
+    block_boxes[threadIdx.x * 7 + 5] =
+        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 5];
+    block_boxes[threadIdx.x * 7 + 6] =
+        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 6];
+  }
+  __syncthreads();
+
+  if (threadIdx.x < row_size) {
+    const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;
+    const float *cur_box = boxes + cur_box_idx * 7;
+
+    int i = 0;
+    unsigned long long t = 0;
+    int start = 0;
+    if (row_start == col_start) {
+      start = threadIdx.x + 1;
+    }
+    for (i = start; i < col_size; i++) {
+      if (iou_normal(cur_box, block_boxes + i * 7) > nms_overlap_thresh) {
+        t |= 1ULL << i;
+      }
+    }
+    int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
+    mask[cur_box_idx * col_blocks + col_start] = t;
+  }
+}
+
+void boxesoverlapLauncher(const int num_a, const float *boxes_a,
+                          const int num_b, const float *boxes_b,
+                          float *ans_overlap) {
+  dim3 blocks(
+      DIVUP(num_b, THREADS_PER_BLOCK),
+      DIVUP(num_a, THREADS_PER_BLOCK));  // blockIdx.x(col), blockIdx.y(row)
+  dim3 threads(THREADS_PER_BLOCK, THREADS_PER_BLOCK);
+
+  boxes_overlap_kernel<<<blocks, threads>>>(num_a, boxes_a, num_b, boxes_b,
+                                            ans_overlap);
+#ifdef DEBUG
+  cudaDeviceSynchronize();  // for using printf in kernel function
+#endif
+}
+
+void boxesioubevLauncher(const int num_a, const float *boxes_a, const int num_b,
+                         const float *boxes_b, float *ans_iou) {
+  dim3 blocks(
+      DIVUP(num_b, THREADS_PER_BLOCK),
+      DIVUP(num_a, THREADS_PER_BLOCK));  // blockIdx.x(col), blockIdx.y(row)
+  dim3 threads(THREADS_PER_BLOCK, THREADS_PER_BLOCK);
+
+  boxes_iou_bev_kernel<<<blocks, threads>>>(num_a, boxes_a, num_b, boxes_b,
+                                            ans_iou);
+#ifdef DEBUG
+  cudaDeviceSynchronize();  // for using printf in kernel function
+#endif
+}
+
+void nmsLauncher(const float *boxes, unsigned long long *mask, int boxes_num,
+                 float nms_overlap_thresh) {
+  dim3 blocks(DIVUP(boxes_num, THREADS_PER_BLOCK_NMS),
+              DIVUP(boxes_num, THREADS_PER_BLOCK_NMS));
+  dim3 threads(THREADS_PER_BLOCK_NMS);
+  nms_kernel<<<blocks, threads>>>(boxes_num, nms_overlap_thresh, boxes, mask);
+}
+
+void nmsNormalLauncher(const float *boxes, unsigned long long *mask,
+                       int boxes_num, float nms_overlap_thresh) {
+  dim3 blocks(DIVUP(boxes_num, THREADS_PER_BLOCK_NMS),
+              DIVUP(boxes_num, THREADS_PER_BLOCK_NMS));
+  dim3 threads(THREADS_PER_BLOCK_NMS);
+  nms_normal_kernel<<<blocks, threads>>>(boxes_num, nms_overlap_thresh, boxes,
+                                         mask);
+}
+
+__global__ void nms_kernel_centerpoint(const int num_bboxes,
+                                       const int num_bboxes_for_nms,
+                                       const float nms_overlap_thresh,
+                                       const int decode_bboxes_dims,
+                                       const float *bboxes, const int *index,
+                                       const int64_t *sorted_index,
+                                       int64_t *mask) {
   // params: boxes (N, 7) [x, y, z, dx, dy, dz, heading]
   // params: mask (N, N/THREADS_PER_BLOCK_NMS)
 
@@ -304,7 +563,7 @@ __global__ void nms_kernel(const int num_bboxes, const int num_bboxes_for_nms,
         t |= 1ULL << i;
       }
     }
-    const int col_blocks = DIVUP(num_bboxes_for_nms, THREADS_PER_BLOCK_NMS);
+    int col_blocks = DIVUP(num_bboxes_for_nms, THREADS_PER_BLOCK_NMS);
     mask[cur_box_idx * col_blocks + col_start] = t;
   }
 }
@@ -317,10 +576,10 @@ void NmsLauncher(const cudaStream_t &stream, const float *bboxes,
   dim3 blocks(DIVUP(num_bboxes_for_nms, THREADS_PER_BLOCK_NMS),
               DIVUP(num_bboxes_for_nms, THREADS_PER_BLOCK_NMS));
   dim3 threads(THREADS_PER_BLOCK_NMS);
-  nms_kernel<<<blocks, threads, 0, stream>>>(
+  nms_kernel_centerpoint<<<blocks, threads, 0, stream>>>(
       num_bboxes, num_bboxes_for_nms, nms_overlap_thresh, decode_bboxes_dims,
       bboxes, index, sorted_index, mask);
 }
 
 }  // namespace fastdeploy
-}  // namespace paddle_custom_ops
\ No newline at end of file
+}  // namespace paddle_custom_ops
diff --git a/fastdeploy/vision.h b/fastdeploy/vision.h
index 67f2ebab3b..3da49fa5f5 100755
--- a/fastdeploy/vision.h
+++ b/fastdeploy/vision.h
@@ -37,6 +37,7 @@
 #include "fastdeploy/vision/perception/paddle3d/smoke/smoke.h"
 #include "fastdeploy/vision/perception/paddle3d/petr/petr.h"
 #include "fastdeploy/vision/perception/paddle3d/centerpoint/centerpoint.h"
+#include "fastdeploy/vision/perception/paddle3d/caddn/caddn.h"
 #include "fastdeploy/vision/detection/ppdet/model.h"
 #include "fastdeploy/vision/facealign/contrib/face_landmark_1000.h"
 #include "fastdeploy/vision/facealign/contrib/pfld.h"
diff --git a/fastdeploy/vision/perception/paddle3d/caddn/caddn.cc b/fastdeploy/vision/perception/paddle3d/caddn/caddn.cc
new file mode 100644
index 0000000000..76ed81b649
--- /dev/null
+++ b/fastdeploy/vision/perception/paddle3d/caddn/caddn.cc
@@ -0,0 +1,86 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision/perception/paddle3d/caddn/caddn.h"
+
+namespace fastdeploy {
+namespace vision {
+namespace perception {
+
+Caddn::Caddn(const std::string& model_file, const std::string& params_file,
+             const std::string& config_file, const RuntimeOption& custom_option,
+             const ModelFormat& model_format)
+    : preprocessor_(config_file) {
+  valid_gpu_backends = {Backend::PDINFER};
+
+  runtime_option = custom_option;
+  runtime_option.model_format = model_format;
+  runtime_option.model_file = model_file;
+  runtime_option.params_file = params_file;
+  initialized = Initialize();
+}
+
+bool Caddn::Initialize() {
+  if (!InitRuntime()) {
+    FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
+    return false;
+  }
+  return true;
+}
+
+bool Caddn::Predict(const cv::Mat& im, std::vector<float>& input_cam_data,
+                    std::vector<float>& input_lidar_data,
+                    PerceptionResult* result) {
+  std::vector<PerceptionResult> results;
+  if (!BatchPredict({im}, input_cam_data, input_lidar_data, &results)) {
+    return false;
+  }
+  if (results.size()) {
+    *result = std::move(results[0]);
+  }
+  return true;
+}
+
+bool Caddn::BatchPredict(const std::vector<cv::Mat>& images,
+                         std::vector<float>& input_cam_data,
+                         std::vector<float>& input_lidar_data,
+                         std::vector<PerceptionResult>* results) {
+  std::vector<FDMat> fd_images = WrapMat(images);
+
+  if (!preprocessor_.Run(&fd_images, input_cam_data, input_lidar_data,
+                         &reused_input_tensors_)) {
+    FDERROR << "Failed to preprocess the input image." << std::endl;
+    return false;
+  }
+
+  reused_input_tensors_[0].name = "images";
+  reused_input_tensors_[1].name = "trans_cam_to_img";
+  reused_input_tensors_[2].name = "trans_lidar_to_cam";
+
+  if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
+    FDERROR << "Failed to inference by runtime." << std::endl;
+    return false;
+  }
+
+  if (!postprocessor_.Run(reused_output_tensors_, results)) {
+    FDERROR << "Failed to postprocess the inference results by runtime."
+            << std::endl;
+    return false;
+  }
+  return true;
+}
+
+}  // namespace perception
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/perception/paddle3d/caddn/caddn.h b/fastdeploy/vision/perception/paddle3d/caddn/caddn.h
new file mode 100755
index 0000000000..b622969709
--- /dev/null
+++ b/fastdeploy/vision/perception/paddle3d/caddn/caddn.h
@@ -0,0 +1,83 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. //NOLINT
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "fastdeploy/fastdeploy_model.h"
+#include "fastdeploy/vision/perception/paddle3d/caddn/preprocessor.h"
+#include "fastdeploy/vision/perception/paddle3d/caddn/postprocessor.h"
+
+namespace fastdeploy {
+namespace vision {
+namespace perception {
+/*! @brief Caddn model object, used to load and run a Caddn model exported by
+ * Paddle3D.
+ */
+class FASTDEPLOY_DECL Caddn : public FastDeployModel {
+ public:
+  /** \brief Set path of model file and the configuration of runtime.
+   *
+   * \param[in] model_file Path of model file, e.g Caddn/model.pdmodel
+   * \param[in] params_file Path of parameter file, e.g Caddn/model.pdiparams, if the model format is ONNX, this parameter will be ignored
+   * \param[in] config_file Path of configuration file for deployment, e.g Caddn/infer_cfg.yml
+   * \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in "valid_cpu_backends"
+   * \param[in] model_format Model format of the loaded model, default is Paddle format
+   */
+  Caddn(const std::string& model_file, const std::string& params_file,
+        const std::string& config_file,
+        const RuntimeOption& custom_option = RuntimeOption(),
+        const ModelFormat& model_format = ModelFormat::PADDLE);
+
+  std::string ModelName() const { return "Paddle3D/Caddn"; }
+
+  /** \brief Predict the perception result for an input image
+   *
+   * \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format
+   * \param[in] input_cam_data The flattened 3x4 camera-to-image matrix
+   * \param[in] input_lidar_data The flattened 4x4 lidar-to-camera matrix
+   * \param[in] results The output perception result will be written to this structure
+   * \return true if the prediction succeeded, otherwise false
+   */
+  virtual bool Predict(const cv::Mat& im,
+                       std::vector<float>& input_cam_data,
+                       std::vector<float>& input_lidar_data,
+                       PerceptionResult* results);
+
+  /** \brief Predict the perception results for a batch of input images
+   *
+   * \param[in] images The input image list, each element comes from cv::imread()
+   * \param[in] input_cam_data The flattened 3x4 camera-to-image matrix
+   * \param[in] input_lidar_data The flattened 4x4 lidar-to-camera matrix
+   * \param[in] results The output perception result list
+   * \return true if the prediction succeeded, otherwise false
+   */
+  virtual bool BatchPredict(const std::vector<cv::Mat>& images,
+                            std::vector<float>& input_cam_data,
+                            std::vector<float>& input_lidar_data,
+                            std::vector<PerceptionResult>* results);
+
+  /// Get preprocessor reference of Caddn
+  virtual CaddnPreprocessor& GetPreprocessor() {
+    return preprocessor_;
+  }
+
+  /// Get postprocessor reference of Caddn
+  virtual CaddnPostprocessor& GetPostprocessor() {
+    return postprocessor_;
+  }
+
+ protected:
+  bool Initialize();
+  CaddnPreprocessor preprocessor_;
+  CaddnPostprocessor postprocessor_;
+  bool initialized_ = false;
+};
+
+}  // namespace perception
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/perception/paddle3d/caddn/caddn_pybind.cc b/fastdeploy/vision/perception/paddle3d/caddn/caddn_pybind.cc
new file mode 100644
index 0000000000..622c4f43ee
--- /dev/null
+++ b/fastdeploy/vision/perception/paddle3d/caddn/caddn_pybind.cc
@@ -0,0 +1,96 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
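For reference, a minimal sketch of how the `Caddn` class declared above could be driven from C++. The file paths, image, and calibration values are placeholders rather than shipped assets, and `PerceptionResult::Str()` is assumed to be available as for the other Paddle3D result types:

```cpp
#include <iostream>
#include <vector>
#include "fastdeploy/vision.h"

int main() {
  fastdeploy::RuntimeOption option;
  option.UseGpu();
  option.UsePaddleInferBackend();  // Caddn only lists Backend::PDINFER on GPU

  fastdeploy::vision::perception::Caddn model(
      "caddn/model.pdmodel", "caddn/model.pdiparams", "caddn/infer_cfg.yml",
      option);
  if (!model.Initialized()) return -1;

  cv::Mat im = cv::imread("kitti_sample.png");
  // Row-major 3x4 camera-to-image and 4x4 lidar-to-camera matrices,
  // flattened exactly as the preprocessor's memcpy expects (values are
  // placeholders; real calibration data is required).
  std::vector<float> cam_data(12, 0.0f);    // trans_cam_to_img
  std::vector<float> lidar_data(16, 0.0f);  // trans_lidar_to_cam

  fastdeploy::vision::PerceptionResult result;
  if (!model.Predict(im, cam_data, lidar_data, &result)) return -1;
  std::cout << result.Str() << std::endl;
  return 0;
}
```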
+ +#include "fastdeploy/pybind/main.h" + +namespace fastdeploy { +void BindCaddn(pybind11::module& m) { + pybind11::class_(m, "CaddnPreprocessor") + .def(pybind11::init()) + .def("run", + [](vision::perception::CaddnPreprocessor& self, + std::vector& im_list, + std::vector& cam_data, std::vector& lidar_data) { + std::vector images; + for (size_t i = 0; i < im_list.size(); ++i) { + images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i]))); + } + std::vector outputs; + if (!self.Run(&images, cam_data, lidar_data, &outputs)) { + throw std::runtime_error( + "Failed to preprocess the input data in CaddnPreprocessor."); + } + for (size_t i = 0; i < outputs.size(); ++i) { + outputs[i].StopSharing(); + } + return outputs; + }); + + pybind11::class_(m, + "CaddnPostprocessor") + .def(pybind11::init<>()) + .def("run", + [](vision::perception::CaddnPostprocessor& self, + std::vector& inputs) { + std::vector results; + if (!self.Run(inputs, &results)) { + throw std::runtime_error( + "Failed to postprocess the runtime result in " + "CaddnPostprocessor."); + } + return results; + }) + .def("run", [](vision::perception::CaddnPostprocessor& self, + std::vector& input_array) { + std::vector results; + std::vector inputs; + PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true); + if (!self.Run(inputs, &results)) { + throw std::runtime_error( + "Failed to postprocess the runtime result in " + "CaddnPostprocessor."); + } + return results; + }); + + pybind11::class_(m, "Caddn") + .def(pybind11::init()) + .def("predict", + [](vision::perception::Caddn& self, pybind11::array& data, + std::vector& cam_data, std::vector& lidar_data) { + auto mat = PyArrayToCvMat(data); + vision::PerceptionResult res; + self.Predict(mat, cam_data, lidar_data, &res); + return res; + }) + .def("batch_predict", + [](vision::perception::Caddn& self, + std::vector& data, std::vector& cam_data, + std::vector& lidar_data) { + std::vector images; + for (size_t i = 0; i < data.size(); ++i) { + images.push_back(PyArrayToCvMat(data[i])); + } + std::vector results; + self.BatchPredict(images, cam_data, lidar_data, &results); + return results; + }) + .def_property_readonly("preprocessor", + &vision::perception::Caddn::GetPreprocessor) + .def_property_readonly("postprocessor", + &vision::perception::Caddn::GetPostprocessor); +} +} // namespace fastdeploy diff --git a/fastdeploy/vision/perception/paddle3d/caddn/postprocessor.cc b/fastdeploy/vision/perception/paddle3d/caddn/postprocessor.cc new file mode 100644 index 0000000000..84a9c2b965 --- /dev/null +++ b/fastdeploy/vision/perception/paddle3d/caddn/postprocessor.cc @@ -0,0 +1,70 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "fastdeploy/vision/perception/paddle3d/caddn/postprocessor.h" + +#include "fastdeploy/vision/utils/utils.h" + +namespace fastdeploy { +namespace vision { +namespace perception { + +CaddnPostprocessor::CaddnPostprocessor() {} + +bool CaddnPostprocessor::Run(const std::vector& tensors, + std::vector* results) { + results->resize(1); + (*results)[0].Clear(); + (*results)[0].Reserve(tensors[0].shape[0]); + if (tensors[0].dtype != FDDataType::FP32) { + FDERROR << "Only support post process with float32 data." << std::endl; + return false; + } + const float* data_0 = reinterpret_cast(tensors[0].Data()); + auto result = &(*results)[0]; + for (int i = 0; i < tensors[0].shape[0] * tensors[0].shape[1]; i += 7) { + // item 1 ~ 3 : box3d bottom center x, y, z + // item 4 ~ 6 : box3d w, h, l + // item 7 : box3d yaw angle + std::vector vec(data_0 + i, data_0 + i + 7); + result->boxes.emplace_back( + std::array{0, 0, 0, 0, vec[3], vec[4], vec[5]}); + result->center.emplace_back(std::array{vec[0], vec[1], vec[2]}); + result->yaw_angle.push_back(vec[6]); + } + const float* data_1 = reinterpret_cast(tensors[2].Data()); + for (int i = 0; i < tensors[2].shape[0]; i += 1) { + std::vector vec(data_1 + i, data_1 + i + 1); + result->scores.push_back(vec[0]); + } + const float* data_2 = reinterpret_cast(tensors[1].Data()); + for (int i = 0; i < tensors[1].shape[0]; i++) { + std::vector vec(data_2 + i, data_2 + i + 1); + result->label_ids.push_back(vec[0]); + } + + result->valid.push_back(true); // 0 scores + result->valid.push_back(true); // 1 label_ids + result->valid.push_back(true); // 2 boxes + result->valid.push_back(true); // 3 center + result->valid.push_back(false); // 4 observation_angle + result->valid.push_back(true); // 5 yaw_angle + result->valid.push_back(false); // 6 velocity + + return true; +} + +} // namespace perception +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/perception/paddle3d/caddn/postprocessor.h b/fastdeploy/vision/perception/paddle3d/caddn/postprocessor.h new file mode 100644 index 0000000000..c7b57c6c2f --- /dev/null +++ b/fastdeploy/vision/perception/paddle3d/caddn/postprocessor.h @@ -0,0 +1,48 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/vision/common/result.h" + +namespace fastdeploy { +namespace vision { + +namespace perception { +/*! @brief Postprocessor object for Caddn serials model. 
+ */
+class FASTDEPLOY_DECL CaddnPostprocessor {
+ public:
+  /** \brief Create a postprocessor instance for the Caddn model
+   */
+  CaddnPostprocessor();
+
+  /** \brief Process the result of runtime and fill it into the PerceptionResult structure
+   *
+   * \param[in] tensors The inference result from runtime
+   * \param[in] results The output perception results
+   * \return true if the postprocess succeeded, otherwise false
+   */
+  bool Run(const std::vector<FDTensor>& tensors,
+           std::vector<PerceptionResult>* results);
+
+ protected:
+  float conf_threshold_;
+};
+
+}  // namespace perception
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/perception/paddle3d/caddn/preprocessor.cc b/fastdeploy/vision/perception/paddle3d/caddn/preprocessor.cc
new file mode 100644
index 0000000000..ca86b89c26
--- /dev/null
+++ b/fastdeploy/vision/perception/paddle3d/caddn/preprocessor.cc
@@ -0,0 +1,112 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision/perception/paddle3d/caddn/preprocessor.h"
+
+#include "fastdeploy/function/concat.h"
+#include "yaml-cpp/yaml.h"
+
+namespace fastdeploy {
+namespace vision {
+namespace perception {
+
+CaddnPreprocessor::CaddnPreprocessor(const std::string& config_file) {
+  config_file_ = config_file;
+  FDASSERT(BuildPreprocessPipeline(),
+           "Failed to create CaddnPreprocessor.");
+  initialized_ = true;
+}
+
+bool CaddnPreprocessor::BuildPreprocessPipeline() {
+  processors_.clear();
+
+  // preprocess
+  processors_.push_back(std::make_shared<BGR2RGB>());
+
+  std::vector<float> alpha = {1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0};
+  std::vector<float> beta = {0.0, 0.0, 0.0};
+  processors_.push_back(std::make_shared<Convert>(alpha, beta));
+
+  processors_.push_back(std::make_shared<Cast>("float"));
+  processors_.push_back(std::make_shared<HWC2CHW>());
+
+  // Fusion will improve performance
+  FuseTransforms(&processors_);
+
+  return true;
+}
+
+bool CaddnPreprocessor::Apply(FDMatBatch* image_batch,
+                              std::vector<float>& input_cam_data,
+                              std::vector<float>& input_lidar_data,
+                              std::vector<FDTensor>* outputs) {
+  if (image_batch->mats->empty()) {
+    FDERROR << "The size of input images should be greater than 0."
+            << std::endl;
+    return false;
+  }
+  if (!initialized_) {
+    FDERROR << "The preprocessor is not initialized." << std::endl;
+    return false;
+  }
+  // There are 3 outputs: image, cam_data, lidar_data
+  outputs->resize(3);
+  int batch = static_cast<int>(image_batch->mats->size());
+
+  // Allocate memory for cam_data
+  (*outputs)[1].Resize({batch, 3, 4}, FDDataType::FP32);
+
+  // Allocate memory for lidar_data
+  (*outputs)[2].Resize({batch, 4, 4}, FDDataType::FP32);
+
+  auto* cam_data_ptr = reinterpret_cast<float*>((*outputs)[1].MutableData());
+  auto* lidar_data_ptr = reinterpret_cast<float*>((*outputs)[2].MutableData());
+
+  for (size_t i = 0; i < image_batch->mats->size(); ++i) {
+    FDMat* mat = &(image_batch->mats->at(i));
+    for (size_t j = 0; j < processors_.size(); ++j) {
+      if (!(*(processors_[j].get()))(mat)) {
+        FDERROR << "Failed to process image:" << i << " in "
+                << processors_[j]->Name() << "." << std::endl;
+        return false;
+      }
+    }
+
+    memcpy(cam_data_ptr + i * 12, input_cam_data.data(), 12 * sizeof(float));
+    memcpy(lidar_data_ptr + i * 16, input_lidar_data.data(),
+           16 * sizeof(float));
+  }
+
+  FDTensor* tensor = image_batch->Tensor();
+  (*outputs)[0].SetExternalData(tensor->Shape(), tensor->Dtype(),
+                                tensor->Data(), tensor->device,
+                                tensor->device_id);
+
+  return true;
+}
+
+bool CaddnPreprocessor::Run(std::vector<FDMat>* images,
+                            std::vector<float>& input_cam_data,
+                            std::vector<float>& input_lidar_data,
+                            std::vector<FDTensor>* outputs) {
+  FDMatBatch image_batch(images);
+  PreApply(&image_batch);
+  bool ret = Apply(&image_batch, input_cam_data, input_lidar_data, outputs);
+  PostApply();
+  return ret;
+}
+
+}  // namespace perception
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/perception/paddle3d/caddn/preprocessor.h b/fastdeploy/vision/perception/paddle3d/caddn/preprocessor.h
new file mode 100755
index 0000000000..26fa1eaa4e
--- /dev/null
+++ b/fastdeploy/vision/perception/paddle3d/caddn/preprocessor.h
@@ -0,0 +1,69 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "fastdeploy/vision/common/processors/manager.h"
+#include "fastdeploy/vision/common/processors/transform.h"
+#include "fastdeploy/vision/common/result.h"
+
+namespace fastdeploy {
+namespace vision {
+
+namespace perception {
+/*! @brief Preprocessor object for the Caddn model.
+ */
+class FASTDEPLOY_DECL CaddnPreprocessor : public ProcessorManager {
+ public:
+  CaddnPreprocessor() = default;
+  /** \brief Create a preprocessor instance for the Caddn model
+   *
+   * \param[in] config_file Path of configuration file for deployment, e.g Caddn/infer_cfg.yml
+   */
+  explicit CaddnPreprocessor(const std::string& config_file);
+
+  bool Run(std::vector<FDMat>* images,
+           std::vector<float>& input_cam_data,
+           std::vector<float>& input_lidar_data,
+           std::vector<FDTensor>* outputs);
+
+  // The camera and lidar matrices are mandatory for Caddn, so the base
+  // ProcessorManager::Apply overload is not supported.
+  bool Apply(FDMatBatch* image_batch, std::vector<FDTensor>* outputs) {
+    FDERROR << "CaddnPreprocessor requires camera and lidar data as inputs."
+            << std::endl;
+    return false;
+  }
+
+  /** \brief Process the input images and prepare input tensors for runtime
+   *
+   * \param[in] image_batch The batch of input images, each element comes from cv::imread()
+   * \param[in] input_cam_data The flattened 3x4 camera-to-image matrix
+   * \param[in] input_lidar_data The flattened 4x4 lidar-to-camera matrix
+   * \param[in] outputs The output tensors which will be fed into runtime
+   * \return true if the preprocess succeeded, otherwise false
+   */
+  bool Apply(FDMatBatch* image_batch,
+             std::vector<float>& input_cam_data,
+             std::vector<float>& input_lidar_data,
+             std::vector<FDTensor>* outputs);
+
+ protected:
+  bool BuildPreprocessPipeline();
+  std::vector<std::shared_ptr<Processor>> processors_;
+
+  bool disable_permute_ = false;
+
+  bool initialized_ = false;
+
+  std::string config_file_;
+};
+
+}  // namespace perception
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/perception/paddle3d/centerpoint/preprocessor.cc b/fastdeploy/vision/perception/paddle3d/centerpoint/preprocessor.cc
index 4697d8ed3a..f4f4266f09 100644
--- a/fastdeploy/vision/perception/paddle3d/centerpoint/preprocessor.cc
+++ b/fastdeploy/vision/perception/paddle3d/centerpoint/preprocessor.cc
@@ -24,8 +24,7 @@ CenterpointPreprocessor::CenterpointPreprocessor(
 
 bool CenterpointPreprocessor::ReadPoint(const std::string &file_path,
                                         const int64_t num_point_dim,
-                                        std::vector<float> &data,
-                                        int64_t *num_points) {
+                                        std::vector<float> &data, int64_t *num_points) {
   std::ifstream file_in(file_path, std::ios::in | std::ios::binary);
   if (num_point_dim < 4) {
     FDERROR << "Point dimension must not be less than 4, but received "
diff --git a/fastdeploy/vision/perception/perception_pybind.cc b/fastdeploy/vision/perception/perception_pybind.cc
index 998776b162..bc325c363e 100755
--- a/fastdeploy/vision/perception/perception_pybind.cc
+++ b/fastdeploy/vision/perception/perception_pybind.cc
@@ -19,6 +19,7 @@ namespace fastdeploy {
 void BindSmoke(pybind11::module& m);
 void BindPetr(pybind11::module& m);
 void BindCenterpoint(pybind11::module& m);
+void BindCaddn(pybind11::module& m);
 
 void BindPerception(pybind11::module& m) {
   auto perception_module =
@@ -26,5 +27,6 @@ void BindPerception(pybind11::module& m) {
   BindSmoke(perception_module);
   BindPetr(perception_module);
   BindCenterpoint(perception_module);
+  BindCaddn(perception_module);
 }
 }  // namespace fastdeploy
diff --git a/python/fastdeploy/vision/perception/__init__.py b/python/fastdeploy/vision/perception/__init__.py
index b97d4f6ec2..fc2f2a21d5 100755
--- a/python/fastdeploy/vision/perception/__init__.py
+++ b/python/fastdeploy/vision/perception/__init__.py
@@ -16,3 +16,4 @@
 from .paddle3d.smoke import *
 from .paddle3d.petr import *
 from .paddle3d.centerpoint import *
+from .paddle3d.caddn import *
diff --git a/python/fastdeploy/vision/perception/paddle3d/caddn.py b/python/fastdeploy/vision/perception/paddle3d/caddn.py
new file mode 100644
index 0000000000..312d82cdb1
--- /dev/null
+++ b/python/fastdeploy/vision/perception/paddle3d/caddn.py
@@ -0,0 +1,108 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import logging
+from .... import FastDeployModel, ModelFormat
+from .... import c_lib_wrap as C
+
+
+class CaddnPreprocessor:
+    def __init__(self, config_file):
+        """Create a preprocessor for Caddn
+        """
+        self._preprocessor = C.vision.perception.CaddnPreprocessor(config_file)
+
+    def run(self, input_ims, cam_data, lidar_data):
+        """Preprocess input images for Caddn
+
+        :param: input_ims: (list of numpy.ndarray) The input images
+        :param: cam_data: (list) The input camera data
+        :param: lidar_data: (list) The input lidar data
+        :return: list of FDTensor
+        """
+        return self._preprocessor.run(input_ims, cam_data, lidar_data)
+
+
+class CaddnPostprocessor:
+    def __init__(self):
+        """Create a postprocessor for Caddn
+        """
+        self._postprocessor = C.vision.perception.CaddnPostprocessor()
+
+    def run(self, runtime_results):
+        """Postprocess the runtime results for Caddn
+
+        :param: runtime_results: (list of FDTensor) The output FDTensor results from runtime
+        :return: list of PerceptionResult (if the runtime_results were predicted from batched samples, the length of this list equals the batch size)
+        """
+        return self._postprocessor.run(runtime_results)
+
+
+class Caddn(FastDeployModel):
+    def __init__(self,
+                 model_file,
+                 params_file,
+                 config_file,
+                 runtime_option=None,
+                 model_format=ModelFormat.PADDLE):
+        """Load a Caddn model exported by Paddle3D.
+
+        :param model_file: (str) Path of model file, e.g ./Caddn.pdmodel
+        :param params_file: (str) Path of parameters file, e.g ./Caddn.pdiparams
+        :param config_file: (str) Path of config file, e.g ./infer_cfg.yaml
+        :param runtime_option: (fastdeploy.RuntimeOption) RuntimeOption for inference of this model; if it's None, the default backend on CPU will be used
+        :param model_format: (fastdeploy.ModelFormat) Model format of the loaded model
+        """
+        super(Caddn, self).__init__(runtime_option)
+
+        self._model = C.vision.perception.Caddn(
+            model_file, params_file, config_file, self._runtime_option,
+            model_format)
+        assert self.initialized, "Caddn initialize failed."
+
+    def predict(self, input_image, cam_data, lidar_data):
+        """Detect objects in an input image
+
+        :param input_image: (numpy.ndarray) The input image data, a 3-D array with layout HWC, BGR format
+        :param: cam_data: (list) The input camera data
+        :param: lidar_data: (list) The input lidar data
+        :return: PerceptionResult
+        """
+        return self._model.predict(input_image, cam_data, lidar_data)
+
+    def batch_predict(self, images, cam_data, lidar_data):
+        """Detect objects in a batch of input images
+
+        :param images: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
+        :param: cam_data: (list) The input camera data
+        :param: lidar_data: (list) The input lidar data
+        :return: list of PerceptionResult
+        """
+
+        return self._model.batch_predict(images, cam_data, lidar_data)
+
+    @property
+    def preprocessor(self):
+        """Get CaddnPreprocessor object of the loaded model
+
+        :return: CaddnPreprocessor
+        """
+        return self._model.preprocessor
+
+    @property
+    def postprocessor(self):
+        """Get CaddnPostprocessor object of the loaded model
+
+        :return: CaddnPostprocessor
+        """
+        return self._model.postprocessor
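The Python classes above are thin wrappers over the pybind bindings; the same preprocess, infer, postprocess flow can also be driven manually from C++ when a standalone `Runtime` is reused. A sketch under the assumption that the runtime was already created for the CADDN model (the config path is a placeholder):

```cpp
#include <vector>
#include "fastdeploy/runtime.h"
#include "fastdeploy/vision.h"

// Manual preprocess -> infer -> postprocess flow mirroring what the Caddn
// model class does internally; runtime setup is elided and the config file
// name is a placeholder.
bool RunCaddnPipeline(fastdeploy::Runtime* runtime, const cv::Mat& image,
                      std::vector<float>& cam_data,
                      std::vector<float>& lidar_data,
                      std::vector<fastdeploy::vision::PerceptionResult>* out) {
  namespace perception = fastdeploy::vision::perception;
  perception::CaddnPreprocessor preprocessor("caddn/infer_cfg.yml");
  perception::CaddnPostprocessor postprocessor;

  std::vector<cv::Mat> images = {image};
  std::vector<fastdeploy::vision::FDMat> mats =
      fastdeploy::vision::WrapMat(images);
  std::vector<fastdeploy::FDTensor> inputs;
  if (!preprocessor.Run(&mats, cam_data, lidar_data, &inputs)) return false;

  // Bind the three tensors to the input names the exported model expects.
  inputs[0].name = "images";
  inputs[1].name = "trans_cam_to_img";
  inputs[2].name = "trans_lidar_to_cam";

  std::vector<fastdeploy::FDTensor> outputs;
  if (!runtime->Infer(inputs, &outputs)) return false;
  return postprocessor.Run(outputs, out);
}
```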