From 5eaa6c91388f8f7e5e64a875d7dc1e761b2e9a13 Mon Sep 17 00:00:00 2001 From: zhangshaopeng Date: Fri, 2 Sep 2022 10:46:46 +0800 Subject: [PATCH 1/5] [Feature] Support RoipointPool3d with cambricon MLU backend --- docs/en/understand_mmcv/ops.md | 2 +- docs/zh_cn/understand_mmcv/ops.md | 2 +- ...oint_pool3d_large_boxes_num_mlu_kernel.mlu | 537 +++++++++++++++++ .../common/mlu/roipoint_pool3d_mlu_kernel.mlu | 545 ++++++++++++++++++ .../csrc/common/mlu/roipoint_pool3d_utils.hpp | 18 + .../csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp | 175 ++++++ tests/test_ops/test_roipoint_pool3d.py | 76 ++- 7 files changed, 1326 insertions(+), 29 deletions(-) create mode 100644 mmcv/ops/csrc/common/mlu/roipoint_pool3d_large_boxes_num_mlu_kernel.mlu create mode 100644 mmcv/ops/csrc/common/mlu/roipoint_pool3d_mlu_kernel.mlu create mode 100644 mmcv/ops/csrc/common/mlu/roipoint_pool3d_utils.hpp create mode 100644 mmcv/ops/csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index 9b9b7a2841..d35a05a88c 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -40,7 +40,7 @@ We implement common ops used in detection, segmentation, etc. | PointsInPolygons | | √ | | | | PSAMask | √ | √ | √ | | | RotatedFeatureAlign | √ | √ | | | -| RoIPointPool3d | | √ | | | +| RoIPointPool3d | | √ | √ | | | RoIPool | | √ | √ | | | RoIAlignRotated | √ | √ | √ | | | RiRoIAlignRotated | | √ | | | diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index 0b39b2afe0..4f2c4d3847 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -40,7 +40,7 @@ MMCV 提供了检测、分割等任务中常用的算子 | PointsInPolygons | | √ | | | | PSAMask | √ | √ | √ | | | RotatedFeatureAlign | √ | √ | | | -| RoIPointPool3d | | √ | | | +| RoIPointPool3d | | √ | √ | | | RoIPool | | √ | √ | | | RoIAlignRotated | √ | √ | √ | | | RiRoIAlignRotated | | √ | | | diff --git a/mmcv/ops/csrc/common/mlu/roipoint_pool3d_large_boxes_num_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/roipoint_pool3d_large_boxes_num_mlu_kernel.mlu new file mode 100644 index 0000000000..ec179396f9 --- /dev/null +++ b/mmcv/ops/csrc/common/mlu/roipoint_pool3d_large_boxes_num_mlu_kernel.mlu @@ -0,0 +1,537 @@ +/************************************************************************* + * Copyright (C) 2022 Cambricon. + * + * OR IMPLIED, INCLUDING BUvoid NOKType LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENvoid SHALL THE AUTHORS OR COPYRIGHKType HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORvoid OR OTHERWISE, ARISING FROM, OUKType OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "common_mlu_helper.hpp" +#include "roipoint_pool3d_utils.hpp" + +/************************************************************************* + * + * NRAM partition: + * | boxes3d | ping points + pong points | aux_a ~ aux_f | + * | 7 * sizeof(T) | 6 * deal_num * sizeof(T) | 6 * deal_num * sizeof(T) | + * + *************************************************************************/ +#define TWELVE_SPLIT 12 + +__nram__ char nram_buffer[MAX_NRAM_SIZE]; + +template +__mlu_func__ void checkPointsInBox3d(const T *boxes3d, + const size_t deal_num, + T *x, + T *y, + T *z, + T *auxiliary_a, + T *auxiliary_b, + T *auxiliary_c, + T *auxiliary_d, + T *auxiliary_e, + T *auxiliary_f, + T *pts_assign) { + // param box3d: (cx, cy, cz, dx, dy, dz, rz) in LiDAR coordinate + T cx = boxes3d[0]; + T cy = boxes3d[1]; + T cz = boxes3d[2]; + T dx = boxes3d[3]; + T dy = boxes3d[4]; + T dz = boxes3d[5]; + T rz = boxes3d[6]; + // shift to the center since cz in box3d is the bottom center + cz += 0.5 * dz; + + T cosa = (T)std::cos(-rz); + T sina = (T)std::sin(-rz); + + // x - cx + __bang_sub_scalar((T *)auxiliary_a, (T *)x, (T)cx, deal_num); + // y - cy + __bang_sub_scalar((T *)auxiliary_b, (T *)y, (T)cy, deal_num); + // z - cz + __bang_sub_scalar((T *)auxiliary_c, (T *)z, (T)cz, deal_num); + // |z - cz| + __bang_active_abs((T *)auxiliary_c, (T *)auxiliary_c, deal_num); + // |z - cz| > dz / 2.0 +#if __BANG_ARCH__ >= 322 + __bang_gt_scalar((T *)auxiliary_c, (T *)auxiliary_c, (T)(0.5 * dz), deal_num); +#else + __bang_write_value((T *)auxiliary_d, deal_num, (T)(0.5 * dz)); + __bang_lt((T *)auxiliary_c, (T *)auxiliary_d, (T *)auxiliary_c, deal_num); +#endif + // !(|z - cz| > dz / 2.0) + __bang_not((T *)auxiliary_c, (T *)auxiliary_c, deal_num); + // (x - cx) * cos(-rz) + __bang_mul_scalar((T *)auxiliary_d, (T *)auxiliary_a, (T)cosa, deal_num); + // (y - cy) * sin(-rz) + __bang_mul_scalar((T *)auxiliary_e, (T *)auxiliary_b, (T)sina, deal_num); + // local_x = (x - cx) * cos(-rz) + (y - cy) * -sin(-rz) + __bang_sub((T *)auxiliary_d, (T *)auxiliary_d, (T *)auxiliary_e, deal_num); + // |local_x| + __bang_active_abs((T *)auxiliary_d, (T *)auxiliary_d, deal_num); + // |local_x| < dx / 2.0 +#if __BANG_ARCH__ >= 322 + __bang_lt_scalar(auxiliary_d, auxiliary_d, (T)(0.5 * dx), deal_num); +#else + __bang_write_value((T *)auxiliary_e, deal_num, (T)(0.5 * dx)); + __bang_gt((T *)auxiliary_d, (T *)auxiliary_e, (T *)auxiliary_d, deal_num); +#endif + // (x - cx) * sin(-rz) + __bang_mul_scalar((T *)auxiliary_e, (T *)auxiliary_a, (T)sina, deal_num); + // (y - cy) * cos(-rz) + __bang_mul_scalar((T *)auxiliary_f, (T *)auxiliary_b, (T)cosa, deal_num); + // local_y = (x - cx) * sin(-rz) + (y - cy) * cos(-rz) + __bang_add((T *)auxiliary_e, (T *)auxiliary_e, (T *)auxiliary_f, deal_num); + // |local_y| + __bang_active_abs((T *)auxiliary_e, (T *)auxiliary_e, deal_num); + // |local_y| < dy / 2.0 +#if __BANG_ARCH__ >= 322 + __bang_lt_scalar(auxiliary_e, auxiliary_e, (T)(0.5 * dy), deal_num); +#else + __bang_write_value((T *)auxiliary_f, deal_num, (T)(0.5 * dy)); + __bang_gt((T *)auxiliary_e, (T *)auxiliary_f, (T *)auxiliary_e, deal_num); +#endif + // pts_assign = |x - cx| < dx / 2.0 && |y - cy| < dy / 2.0 && |z - cz| <= dz / 2.0 + __bang_mul((T *)pts_assign, (T *)auxiliary_c, (T *)auxiliary_d, deal_num); + __bang_mul((T *)pts_assign, (T *)pts_assign, (T *)auxiliary_e, deal_num); +} + +template +__mlu_func__ void computeStoreRoipointPool3d(char *boxes3d, + int *cnt, + char *points_x, + char *points_y, + char *points_z, + const char *point_features, + char *auxiliary_a, + char *auxiliary_b, + char *auxiliary_c, + char *auxiliary_d, + char *auxiliary_e, + char *auxiliary_f, + const int box_idx, + const int pts_num, + const int feature_in_len, + const int sampled_pts_num, + const size_t span_num_deal, + char *pooled_features_gdram, + char *pooled_empty_flag_gdram) { + char *pts_assign = auxiliary_a; + if (*cnt >= sampled_pts_num) { + return; + } + checkPointsInBox3d((T *)boxes3d, span_num_deal, (T *)points_x, (T *)points_y, (T *)points_z, + (T *)auxiliary_a, (T *)auxiliary_b, (T *)auxiliary_c, (T *)auxiliary_d, + (T *)auxiliary_e, (T *)auxiliary_f, (T *)pts_assign); + + // __bang_select returns selected elements vector and the number of selected elements + __bang_select((T *)auxiliary_b, (T *)points_x, (T *)pts_assign, span_num_deal); + uint32_t select_num = *((uint32_t *)auxiliary_b); + + if (select_num == 0) { + return; + } + int sampled_pts_num_rem = sampled_pts_num - *cnt; + int segnum = MIN(select_num, sampled_pts_num_rem) - 1; + + // copy x to pooled_features_gdram + // The result of __bang_select is composed of three parts: + // The first 4-byte is the number of selected element, whose data type is unsigned int. + // The next 124-byte is zero. The rest bytes are the selected elements. + int select_num_size = 128; + __memcpy( + pooled_features_gdram + (box_idx * sampled_pts_num + *cnt) * (3 + feature_in_len) * sizeof(T), + (T *)((int8_t *)auxiliary_b + select_num_size), sizeof(T), NRAM2GDRAM, + (3 + feature_in_len) * sizeof(T), sizeof(T), segnum); + + // copy y to pooled_features_gdram + __bang_collect((T *)auxiliary_d, (T *)points_y, (T *)pts_assign, span_num_deal); + __memcpy(pooled_features_gdram + + (box_idx * sampled_pts_num + *cnt) * (3 + feature_in_len) * sizeof(T) + + 1 * sizeof(T), + (T *)auxiliary_d, sizeof(T), NRAM2GDRAM, (3 + feature_in_len) * sizeof(T), sizeof(T), + segnum); + + // copy z to pooled_features_gdram + __bang_collect((T *)auxiliary_e, (T *)points_z, (T *)pts_assign, span_num_deal); + __memcpy(pooled_features_gdram + + (box_idx * sampled_pts_num + *cnt) * (3 + feature_in_len) * sizeof(T) + + 2 * sizeof(T), + (T *)auxiliary_e, sizeof(T), NRAM2GDRAM, (3 + feature_in_len) * sizeof(T), sizeof(T), + segnum); + + // copy features to pooled_features_gdram + for (int c_idx = 0; c_idx < feature_in_len; c_idx++) { + __memcpy(auxiliary_d, point_features + c_idx * pts_num * sizeof(T), span_num_deal * sizeof(T), + GDRAM2NRAM); + __bang_collect((T *)auxiliary_e, (T *)auxiliary_d, (T *)pts_assign, span_num_deal); + __memcpy(pooled_features_gdram + + (box_idx * sampled_pts_num + *cnt) * (3 + feature_in_len) * sizeof(T) + + (3 + c_idx) * sizeof(T), + auxiliary_e, sizeof(T), NRAM2GDRAM, (3 + feature_in_len) * sizeof(T), sizeof(T), + segnum); + } + + *cnt += select_num; +} + +template +__mlu_func__ void computeStoreLastBlockRoipointPool3d(char *boxes3d, + int *cnt, + char *points_x, + char *points_y, + char *points_z, + const char *point_features, + char *auxiliary_a, + char *auxiliary_b, + char *auxiliary_c, + char *auxiliary_d, + char *auxiliary_e, + char *auxiliary_f, + const int box_idx, + const int pts_num, + const int feature_in_len, + const int sampled_pts_num, + const size_t span_num_deal, + const size_t auxiliary_num_deal, + char *pooled_features_gdram, + char *pooled_empty_flag_gdram) { + char *pts_assign = auxiliary_a; + if (*cnt >= sampled_pts_num) { + // pooled_empty_flag_gdram set 0 + *((int *)auxiliary_a) = 0; + __memcpy(pooled_empty_flag_gdram + box_idx * sizeof(int), auxiliary_a, sizeof(int), NRAM2GDRAM); + return; + } + checkPointsInBox3d((T *)boxes3d, span_num_deal, (T *)points_x, (T *)points_y, (T *)points_z, + (T *)auxiliary_a, (T *)auxiliary_b, (T *)auxiliary_c, (T *)auxiliary_d, + (T *)auxiliary_e, (T *)auxiliary_f, (T *)pts_assign); + + // __bang_select returns selected elements vector and the number of selected elements + __bang_select((T *)auxiliary_b, (T *)points_x, (T *)pts_assign, span_num_deal); + uint32_t select_num = *((uint32_t *)auxiliary_b); + + if (*cnt + select_num == 0) { + // pooled_empty_flag_gdram set 1 + *((int *)auxiliary_a) = 1; + __memcpy(pooled_empty_flag_gdram + box_idx * sizeof(int), auxiliary_a, sizeof(int), NRAM2GDRAM); + + // pooled_features_gdram set 0 + int repeat = (sampled_pts_num * (3 + feature_in_len)) / (auxiliary_num_deal * 6); + int rem = (sampled_pts_num * (3 + feature_in_len)) % (auxiliary_num_deal * 6); + // use auxiliary_a to auxiliary_f + __bang_write_zero((T *)auxiliary_a, PAD_UP(auxiliary_num_deal * 6, NFU_ALIGN_SIZE)); + if (repeat > 0) { + __memcpy(pooled_features_gdram + box_idx * sampled_pts_num * (3 + feature_in_len) * sizeof(T), + auxiliary_a, auxiliary_num_deal * 6 * sizeof(T), NRAM2GDRAM, + auxiliary_num_deal * 6 * sizeof(T), 0, repeat - 1); + } + if (rem > 0) { + __memcpy(pooled_features_gdram + + box_idx * sampled_pts_num * (3 + feature_in_len) * sizeof(T) + + repeat * auxiliary_num_deal * 6 * sizeof(T), + auxiliary_a, rem * sizeof(T), NRAM2GDRAM); + } + return; + } + + if (select_num > 0) { + int sampled_pts_num_rem = sampled_pts_num - *cnt; + int segnum = MIN(select_num, sampled_pts_num_rem) - 1; + + // copy x to pooled_features_gdram + // The result of __bang_select is composed of three parts: + // The first 4-byte is the number of selected element, whose data type is unsigned int. + // The next 124-byte is zero. The rest bytes are the selected elements. + int select_num_size = 128; + __memcpy(pooled_features_gdram + + (box_idx * sampled_pts_num + *cnt) * (3 + feature_in_len) * sizeof(T), + (T *)((int8_t *)auxiliary_b + select_num_size), sizeof(T), NRAM2GDRAM, + (3 + feature_in_len) * sizeof(T), sizeof(T), segnum); + + // copy y to pooled_features_gdram + __bang_collect((T *)auxiliary_d, (T *)points_y, (T *)pts_assign, span_num_deal); + __memcpy(pooled_features_gdram + + (box_idx * sampled_pts_num + *cnt) * (3 + feature_in_len) * sizeof(T) + + 1 * sizeof(T), + (T *)auxiliary_d, sizeof(T), NRAM2GDRAM, (3 + feature_in_len) * sizeof(T), sizeof(T), + segnum); + + // copy z to pooled_features_gdram + __bang_collect((T *)auxiliary_e, (T *)points_z, (T *)pts_assign, span_num_deal); + __memcpy(pooled_features_gdram + + (box_idx * sampled_pts_num + *cnt) * (3 + feature_in_len) * sizeof(T) + + 2 * sizeof(T), + (T *)auxiliary_e, sizeof(T), NRAM2GDRAM, (3 + feature_in_len) * sizeof(T), sizeof(T), + segnum); + + // copy features to pooled_features_gdram + for (int c_idx = 0; c_idx < feature_in_len; c_idx++) { + __memcpy(auxiliary_d, point_features + c_idx * pts_num * sizeof(T), span_num_deal * sizeof(T), + GDRAM2NRAM); + __bang_collect((T *)auxiliary_e, (T *)auxiliary_d, (T *)pts_assign, span_num_deal); + __memcpy(pooled_features_gdram + + (box_idx * sampled_pts_num + *cnt) * (3 + feature_in_len) * sizeof(T) + + (3 + c_idx) * sizeof(T), + auxiliary_e, sizeof(T), NRAM2GDRAM, (3 + feature_in_len) * sizeof(T), sizeof(T), + segnum); + } + } + + // pooled_empty_flag_gdram set 0 + *((int *)auxiliary_a) = 0; + __memcpy(pooled_empty_flag_gdram + box_idx * sizeof(int), auxiliary_a, sizeof(int), NRAM2GDRAM); + + *cnt += select_num; + if (*cnt < sampled_pts_num) { + // duplicate same points for sampling + int repeat = sampled_pts_num / (*cnt) - 1; + int rem = sampled_pts_num % (*cnt); + if (repeat > 0) { + __memcpy(pooled_features_gdram + + (box_idx * sampled_pts_num + *cnt) * (3 + feature_in_len) * sizeof(T), + pooled_features_gdram + box_idx * sampled_pts_num * (3 + feature_in_len) * sizeof(T), + (*cnt) * (3 + feature_in_len) * sizeof(T), GDRAM2GDRAM, + (*cnt) * (3 + feature_in_len) * sizeof(T), 0, repeat - 1); + } + if (rem > 0) { + __memcpy( + pooled_features_gdram + + (box_idx * sampled_pts_num + (repeat + 1) * (*cnt)) * (3 + feature_in_len) * + sizeof(T), + pooled_features_gdram + box_idx * sampled_pts_num * (3 + feature_in_len) * sizeof(T), + rem * (3 + feature_in_len) * sizeof(T), GDRAM2GDRAM); + } + } +} + +template +__mlu_global__ void MLUUnion1KernelRoiPointPool3dLargeBoxesNumForward( + const int batch_size, + const int pts_num, + const int boxes_num, + const int feature_in_len, + const int sampled_pts_num, + const char *points_xyz_gdram, + const char *point_features_gdram, + const char *boxes3d_gdram, + char *pooled_features_gdram, + char *pooled_empty_flag_gdram) { + if (coreId == 0x80) { + return; + } + size_t boxes_per_core = (batch_size * boxes_num) / taskDim; + size_t boxes_rem = (batch_size * boxes_num) % taskDim; + // calc batch_start, batch_end, first_batch_box_start, last batch_box_end for each core + int32_t batch_start = taskId < (boxes_rem + 1) ? + (taskId * (boxes_per_core + 1)) / boxes_num : + (taskId * boxes_per_core + boxes_rem) / boxes_num; + int32_t batch_end = taskId < boxes_rem ? + ((taskId + 1) * (boxes_per_core + 1) - 1) / boxes_num : + ((taskId + 1) * boxes_per_core + boxes_rem - 1) / boxes_num; + size_t first_batch_box_start = taskId < (boxes_rem + 1) ? + (taskId * (boxes_per_core + 1)) - batch_start * boxes_num : + taskId * boxes_per_core + boxes_rem - batch_start * boxes_num; + size_t last_batch_box_end = taskId < boxes_rem ? + (taskId + 1) * (boxes_per_core + 1) - batch_end * boxes_num : + ((taskId + 1) * boxes_per_core + boxes_rem) - batch_end * boxes_num; + + // points_xyz : [3, B, N] + const char *points_x_gdram = points_xyz_gdram; + const char *points_y_gdram = points_xyz_gdram + (1 * batch_size * pts_num) * sizeof(T); + const char *points_z_gdram = points_xyz_gdram + (2 * batch_size * pts_num) * sizeof(T); + + size_t boxes3d_size = PAD_UP(7, NFU_ALIGN_SIZE) * sizeof(T); + size_t span_num_deal = PAD_DOWN(MAX_NRAM_SIZE / TWELVE_SPLIT / sizeof(T), NFU_ALIGN_SIZE); + size_t align_num = NFU_ALIGN_SIZE; + int32_t repeat = pts_num / span_num_deal; + size_t rem = pts_num % span_num_deal; + size_t align_rem = CEIL_ALIGN(rem, align_num); + char *boxes3d = nram_buffer; + char *ping_points_x = nram_buffer + boxes3d_size; + char *ping_points_y = ping_points_x + span_num_deal * sizeof(T); + char *ping_points_z = ping_points_y + span_num_deal * sizeof(T); + size_t ping_pong_gap = 3 * span_num_deal * sizeof(T); + char *auxiliary_a = ping_points_x + 2 * ping_pong_gap; + char *auxiliary_b = auxiliary_a + span_num_deal * sizeof(T); + char *auxiliary_c = auxiliary_b + span_num_deal * sizeof(T); + char *auxiliary_d = auxiliary_c + span_num_deal * sizeof(T); + char *auxiliary_e = auxiliary_d + span_num_deal * sizeof(T); + char *auxiliary_f = auxiliary_e + span_num_deal * sizeof(T); + size_t span_load_input1_size = span_num_deal * sizeof(T); + size_t span_load_input2_size = span_num_deal * sizeof(T); + size_t span_load_input3_size = span_num_deal * sizeof(T); + size_t span_load_input4_size = span_num_deal * sizeof(T); + int cnt = 0; + + for (int bs_idx = batch_start; bs_idx <= batch_end; bs_idx++) { + const char *points_x_start = points_x_gdram + bs_idx * pts_num * sizeof(T); + const char *points_y_start = points_y_gdram + bs_idx * pts_num * sizeof(T); + const char *points_z_start = points_z_gdram + bs_idx * pts_num * sizeof(T); + const char *point_features_start = + point_features_gdram + bs_idx * feature_in_len * pts_num * sizeof(T); + char *pooled_features_start = + pooled_features_gdram + + (bs_idx * boxes_num * sampled_pts_num * (3 + feature_in_len)) * sizeof(T); + char *pooled_empty_flag_start = pooled_empty_flag_gdram + bs_idx * boxes_num * sizeof(int); + size_t box_start = bs_idx == batch_start ? first_batch_box_start : 0; + size_t box_end = bs_idx == batch_end ? last_batch_box_end : boxes_num; + + for (int box_idx = box_start; box_idx < box_end; box_idx++) { + __memcpy_async(boxes3d, + boxes3d_gdram + bs_idx * boxes_num * 7 * sizeof(T) + box_idx * 7 * sizeof(T), + 7 * sizeof(T), GDRAM2NRAM); + cnt = 0; + if (repeat > 0) { + __memcpy_async(ping_points_x, points_x_start, span_load_input1_size, GDRAM2NRAM); + __memcpy_async(ping_points_y, points_y_start, span_load_input2_size, GDRAM2NRAM); + __memcpy_async(ping_points_z, points_z_start, span_load_input3_size, GDRAM2NRAM); + __asm__ volatile("sync;"); + } + + for (int i = 0; i < repeat - 1; i++) { + __memcpy_async(ping_points_x + ((i + 1) % 2) * ping_pong_gap, + points_x_start + (i + 1) * span_load_input1_size, span_load_input1_size, + GDRAM2NRAM); + __memcpy_async(ping_points_y + ((i + 1) % 2) * ping_pong_gap, + points_y_start + (i + 1) * span_load_input2_size, span_load_input2_size, + GDRAM2NRAM); + __memcpy_async(ping_points_z + ((i + 1) % 2) * ping_pong_gap, + points_z_start + (i + 1) * span_load_input3_size, span_load_input3_size, + GDRAM2NRAM); + computeStoreRoipointPool3d( + boxes3d, &cnt, ping_points_x + (i % 2) * ping_pong_gap, + ping_points_y + (i % 2) * ping_pong_gap, ping_points_z + (i % 2) * ping_pong_gap, + point_features_start + i * span_load_input4_size, auxiliary_a, auxiliary_b, auxiliary_c, + auxiliary_d, auxiliary_e, auxiliary_f, box_idx, pts_num, feature_in_len, + sampled_pts_num, span_num_deal, pooled_features_start, pooled_empty_flag_start); + __asm__ volatile("sync;"); + } + + if (rem > 0) { + if (sizeof(T) == sizeof(float)) { + __bang_write_value((T *)(ping_points_x + (repeat % 2) * ping_pong_gap + + PAD_DOWN(rem, NFU_ALIGN_SIZE) * sizeof(T)), + NFU_ALIGN_SIZE, (T)NAN); + __bang_write_value((T *)(ping_points_y + (repeat % 2) * ping_pong_gap + + PAD_DOWN(rem, NFU_ALIGN_SIZE) * sizeof(T)), + NFU_ALIGN_SIZE, (T)NAN); + __bang_write_value((T *)(ping_points_z + (repeat % 2) * ping_pong_gap + + PAD_DOWN(rem, NFU_ALIGN_SIZE) * sizeof(T)), + NFU_ALIGN_SIZE, (T)NAN); + } else { + __bang_write_value((T *)(ping_points_x + (repeat % 2) * ping_pong_gap + + PAD_DOWN(rem, NFU_ALIGN_SIZE) * sizeof(T)), + NFU_ALIGN_SIZE, (T)NAN); + __bang_write_value((T *)(ping_points_y + (repeat % 2) * ping_pong_gap + + PAD_DOWN(rem, NFU_ALIGN_SIZE) * sizeof(T)), + NFU_ALIGN_SIZE, (T)NAN); + __bang_write_value((T *)(ping_points_z + (repeat % 2) * ping_pong_gap + + PAD_DOWN(rem, NFU_ALIGN_SIZE) * sizeof(T)), + NFU_ALIGN_SIZE, (T)NAN); + } + __memcpy_async(ping_points_x + (repeat % 2) * ping_pong_gap, + points_x_start + repeat * span_load_input1_size, rem * sizeof(T), + GDRAM2NRAM); + __memcpy_async(ping_points_y + (repeat % 2) * ping_pong_gap, + points_y_start + repeat * span_load_input2_size, rem * sizeof(T), + GDRAM2NRAM); + __memcpy_async(ping_points_z + (repeat % 2) * ping_pong_gap, + points_z_start + repeat * span_load_input3_size, rem * sizeof(T), + GDRAM2NRAM); + } + + if (repeat > 0 && rem > 0) { + computeStoreRoipointPool3d( + boxes3d, &cnt, ping_points_x + ((repeat - 1) % 2) * ping_pong_gap, + ping_points_y + ((repeat - 1) % 2) * ping_pong_gap, + ping_points_z + ((repeat - 1) % 2) * ping_pong_gap, + point_features_start + (repeat - 1) * span_load_input4_size, auxiliary_a, auxiliary_b, + auxiliary_c, auxiliary_d, auxiliary_e, auxiliary_f, box_idx, pts_num, feature_in_len, + sampled_pts_num, span_num_deal, pooled_features_start, pooled_empty_flag_start); + } else if (repeat > 0 && rem == 0) { + computeStoreLastBlockRoipointPool3d( + boxes3d, &cnt, ping_points_x + ((repeat - 1) % 2) * ping_pong_gap, + ping_points_y + ((repeat - 1) % 2) * ping_pong_gap, + ping_points_z + ((repeat - 1) % 2) * ping_pong_gap, + point_features_start + (repeat - 1) * span_load_input4_size, auxiliary_a, auxiliary_b, + auxiliary_c, auxiliary_d, auxiliary_e, auxiliary_f, box_idx, pts_num, feature_in_len, + sampled_pts_num, span_num_deal, span_num_deal, pooled_features_start, + pooled_empty_flag_start); + } + + if (rem > 0) { + __asm__ volatile("sync;"); + computeStoreLastBlockRoipointPool3d( + boxes3d, &cnt, ping_points_x + (repeat % 2) * ping_pong_gap, + ping_points_y + (repeat % 2) * ping_pong_gap, + ping_points_z + (repeat % 2) * ping_pong_gap, + point_features_start + repeat * span_load_input4_size, auxiliary_a, auxiliary_b, + auxiliary_c, auxiliary_d, auxiliary_e, auxiliary_f, box_idx, pts_num, feature_in_len, + sampled_pts_num, align_rem, span_num_deal, pooled_features_start, + pooled_empty_flag_start); + } + } + } +} + +template __mlu_global__ void MLUUnion1KernelRoiPointPool3dLargeBoxesNumForward( + const int batch_size, + const int pts_num, + const int boxes_num, + const int feature_in_len, + const int sampled_pts_num, + const char *points_xyz_gdram, + const char *point_features_gdram, + const char *boxes3d_gdram, + char *pooled_features_gdram, + char *pooled_empty_flag_gdram); + +template __mlu_global__ void MLUUnion1KernelRoiPointPool3dLargeBoxesNumForward( + const int batch_size, + const int pts_num, + const int boxes_num, + const int feature_in_len, + const int sampled_pts_num, + const char *points_xyz_gdram, + const char *point_features_gdram, + const char *boxes3d_gdram, + char *pooled_features_gdram, + char *pooled_empty_flag_gdram); + +void KernelRoiPointPool3dLargeBoxesNumForward(cnrtDim3_t k_dim, + cnrtFunctionType_t k_type, + cnrtQueue_t queue, + const cnrtDataType_t d_type, + const int batch_size, + const int pts_num, + const int boxes_num, + const int feature_in_len, + const int sampled_pts_num, + const void *points_xyz, + const void *boxes3d, + const void *point_features, + void *pooled_features, + int *pooled_empty_flag) { + switch (d_type) { + default: { break; } + case CNRT_FLOAT32: { + MLUUnion1KernelRoiPointPool3dLargeBoxesNumForward<<>>( + batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num, + (char *)points_xyz, (char *)point_features, (char *)boxes3d, + (char *)pooled_features, (char *)pooled_empty_flag); + }; break; + case CNRT_FLOAT16: { + MLUUnion1KernelRoiPointPool3dLargeBoxesNumForward<<>>( + batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num, + (char *)points_xyz, (char *)point_features, (char *)boxes3d, + (char *)pooled_features, (char *)pooled_empty_flag); + }; break; + } +} diff --git a/mmcv/ops/csrc/common/mlu/roipoint_pool3d_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/roipoint_pool3d_mlu_kernel.mlu new file mode 100644 index 0000000000..9cfecec13a --- /dev/null +++ b/mmcv/ops/csrc/common/mlu/roipoint_pool3d_mlu_kernel.mlu @@ -0,0 +1,545 @@ +/************************************************************************* + * Copyright (C) 2022 Cambricon. + * + * OR IMPLIED, INCLUDING BUvoid NOKType LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENvoid SHALL THE AUTHORS OR COPYRIGHKType HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORvoid OR OTHERWISE, ARISING FROM, OUKType OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "common_mlu_helper.hpp" +#include "roipoint_pool3d_utils.hpp" + +/************************************************************************************** + * + * NRAM partition: + * | boxes3d | cnt | + * | boxes_num * 7 * sizeof(T) | boxes_num * sizeof(int) | + * + * | ping points | pong points | aux_a ~ aux_f | + * | 3 * deal_num * sizeof(T) | 3 * deal_num * sizeof(T) | 6 * deal_num * sizeof(T) | + * + ***************************************************************************************/ +#define TWELVE_SPLIT 12 + +__nram__ char nram_buffer[MAX_NRAM_SIZE]; + +template +__mlu_func__ void checkPointsInBox3d(const T *boxes3d, + const size_t deal_num, + T *x, + T *y, + T *z, + T *auxiliary_a, + T *auxiliary_b, + T *auxiliary_c, + T *auxiliary_d, + T *auxiliary_e, + T *auxiliary_f, + T *pts_assign) { + // param box3d: (cx, cy, cz, dx, dy, dz, rz) in LiDAR coordinate + T cx = boxes3d[0]; + T cy = boxes3d[1]; + T cz = boxes3d[2]; + T dx = boxes3d[3]; + T dy = boxes3d[4]; + T dz = boxes3d[5]; + T rz = boxes3d[6]; + // shift to the center since cz in box3d is the bottom center + cz += 0.5 * dz; + + T cosa = (T)std::cos(-rz); + T sina = (T)std::sin(-rz); + + // x - cx + __bang_sub_scalar((T *)auxiliary_a, (T *)x, (T)cx, deal_num); + // y - cy + __bang_sub_scalar((T *)auxiliary_b, (T *)y, (T)cy, deal_num); + // z - cz + __bang_sub_scalar((T *)auxiliary_c, (T *)z, (T)cz, deal_num); + // |z - cz| + __bang_active_abs((T *)auxiliary_c, (T *)auxiliary_c, deal_num); + // |z - cz| > dz / 2.0 +#if __BANG_ARCH__ >= 322 + __bang_gt_scalar((T *)auxiliary_c, (T *)auxiliary_c, (T)(0.5 * dz), deal_num); +#else + __bang_write_value((T *)auxiliary_d, deal_num, (T)(0.5 * dz)); + __bang_lt((T *)auxiliary_c, (T *)auxiliary_d, (T *)auxiliary_c, deal_num); +#endif + // !(|z - cz| > dz / 2.0) + __bang_not((T *)auxiliary_c, (T *)auxiliary_c, deal_num); + // (x - cx) * cos(-rz) + __bang_mul_scalar((T *)auxiliary_d, (T *)auxiliary_a, (T)cosa, deal_num); + // (y - cy) * sin(-rz) + __bang_mul_scalar((T *)auxiliary_e, (T *)auxiliary_b, (T)sina, deal_num); + // local_x = (x - cx) * cos(-rz) + (y - cy) * -sin(-rz) + __bang_sub((T *)auxiliary_d, (T *)auxiliary_d, (T *)auxiliary_e, deal_num); + // |local_x| + __bang_active_abs((T *)auxiliary_d, (T *)auxiliary_d, deal_num); + // |local_x| < dx / 2.0 +#if __BANG_ARCH__ >= 322 + __bang_lt_scalar(auxiliary_d, auxiliary_d, (T)(0.5 * dx), deal_num); +#else + __bang_write_value((T *)auxiliary_e, deal_num, (T)(0.5 * dx)); + __bang_gt((T *)auxiliary_d, (T *)auxiliary_e, (T *)auxiliary_d, deal_num); +#endif + // (x - cx) * sin(-rz) + __bang_mul_scalar((T *)auxiliary_e, (T *)auxiliary_a, (T)sina, deal_num); + // (y - cy) * cos(-rz) + __bang_mul_scalar((T *)auxiliary_f, (T *)auxiliary_b, (T)cosa, deal_num); + // local_y = (x - cx) * sin(-rz) + (y - cy) * cos(-rz) + __bang_add((T *)auxiliary_e, (T *)auxiliary_e, (T *)auxiliary_f, deal_num); + // |local_y| + __bang_active_abs((T *)auxiliary_e, (T *)auxiliary_e, deal_num); + // |local_y| < dy / 2.0 +#if __BANG_ARCH__ >= 322 + __bang_lt_scalar(auxiliary_e, auxiliary_e, (T)(0.5 * dy), deal_num); +#else + __bang_write_value((T *)auxiliary_f, deal_num, (T)(0.5 * dy)); + __bang_gt((T *)auxiliary_e, (T *)auxiliary_f, (T *)auxiliary_e, deal_num); +#endif + // pts_assign = |x - cx| < dx / 2.0 && |y - cy| < dy / 2.0 && |z - cz| <= dz / 2.0 + __bang_mul((T *)pts_assign, (T *)auxiliary_c, (T *)auxiliary_d, deal_num); + __bang_mul((T *)pts_assign, (T *)pts_assign, (T *)auxiliary_e, deal_num); +} + +template +__mlu_func__ void computeStoreRoipointPool3d(char *boxes3d, + int *cnt, + char *points_x, + char *points_y, + char *points_z, + const char *point_features, + char *auxiliary_a, + char *auxiliary_b, + char *auxiliary_c, + char *auxiliary_d, + char *auxiliary_e, + char *auxiliary_f, + const int box_idx, + const int pts_num, + const int feature_in_len, + const int sampled_pts_num, + const size_t span_num_deal, + char *pooled_features_gdram, + char *pooled_empty_flag_gdram) { + char *pts_assign = auxiliary_a; + if (cnt[box_idx] >= sampled_pts_num) { + return; + } + checkPointsInBox3d((T *)(boxes3d + box_idx * 7 * sizeof(T)), span_num_deal, (T *)points_x, + (T *)points_y, (T *)points_z, (T *)auxiliary_a, (T *)auxiliary_b, + (T *)auxiliary_c, (T *)auxiliary_d, (T *)auxiliary_e, (T *)auxiliary_f, + (T *)pts_assign); + + // __bang_select returns selected elements vector and the number of selected elements + __bang_select((T *)auxiliary_b, (T *)points_x, (T *)pts_assign, span_num_deal); + uint32_t select_num = *((uint32_t *)auxiliary_b); + + if (select_num == 0) { + return; + } + int sampled_pts_num_rem = sampled_pts_num - cnt[box_idx]; + int segnum = MIN(select_num, sampled_pts_num_rem) - 1; + + // copy x to pooled_features_gdram + // The result of __bang_select is composed of three parts: + // The first 4-byte is the number of selected element, whose data type is unsigned int. + // The next 124-byte is zero. The rest bytes are the selected elements. + int select_num_size = 128; + __memcpy(pooled_features_gdram + + (box_idx * sampled_pts_num + cnt[box_idx]) * (3 + feature_in_len) * sizeof(T), + (T *)((int8_t *)auxiliary_b + select_num_size), sizeof(T), NRAM2GDRAM, + (3 + feature_in_len) * sizeof(T), sizeof(T), segnum); + + // copy y to pooled_features_gdram + __bang_collect((T *)auxiliary_d, (T *)points_y, (T *)pts_assign, span_num_deal); + __memcpy(pooled_features_gdram + + (box_idx * sampled_pts_num + cnt[box_idx]) * (3 + feature_in_len) * sizeof(T) + + 1 * sizeof(T), + (T *)auxiliary_d, sizeof(T), NRAM2GDRAM, (3 + feature_in_len) * sizeof(T), sizeof(T), + segnum); + + // copy z to pooled_features_gdram + __bang_collect((T *)auxiliary_e, (T *)points_z, (T *)pts_assign, span_num_deal); + __memcpy(pooled_features_gdram + + (box_idx * sampled_pts_num + cnt[box_idx]) * (3 + feature_in_len) * sizeof(T) + + 2 * sizeof(T), + (T *)auxiliary_e, sizeof(T), NRAM2GDRAM, (3 + feature_in_len) * sizeof(T), sizeof(T), + segnum); + + // copy features to pooled_features_gdram + for (int c_idx = 0; c_idx < feature_in_len; c_idx++) { + __memcpy(auxiliary_d, point_features + c_idx * pts_num * sizeof(T), span_num_deal * sizeof(T), + GDRAM2NRAM); + __bang_collect((T *)auxiliary_e, (T *)auxiliary_d, (T *)pts_assign, span_num_deal); + __memcpy(pooled_features_gdram + + (box_idx * sampled_pts_num + cnt[box_idx]) * (3 + feature_in_len) * sizeof(T) + + (3 + c_idx) * sizeof(T), + auxiliary_e, sizeof(T), NRAM2GDRAM, (3 + feature_in_len) * sizeof(T), sizeof(T), + segnum); + } + + cnt[box_idx] += select_num; +} + +template +__mlu_func__ void computeStoreLastBlockRoipointPool3d(char *boxes3d, + int *cnt, + char *points_x, + char *points_y, + char *points_z, + const char *point_features, + char *auxiliary_a, + char *auxiliary_b, + char *auxiliary_c, + char *auxiliary_d, + char *auxiliary_e, + char *auxiliary_f, + const int box_idx, + const int pts_num, + const int feature_in_len, + const int sampled_pts_num, + const size_t span_num_deal, + const size_t auxiliary_num_deal, + char *pooled_features_gdram, + char *pooled_empty_flag_gdram) { + char *pts_assign = auxiliary_a; + if (cnt[box_idx] >= sampled_pts_num) { + // pooled_empty_flag_gdram set 0 + *((int *)auxiliary_a) = 0; + __memcpy(pooled_empty_flag_gdram + box_idx * sizeof(int), auxiliary_a, sizeof(int), NRAM2GDRAM); + return; + } + checkPointsInBox3d((T *)(boxes3d + box_idx * 7 * sizeof(T)), span_num_deal, (T *)points_x, + (T *)points_y, (T *)points_z, (T *)auxiliary_a, (T *)auxiliary_b, + (T *)auxiliary_c, (T *)auxiliary_d, (T *)auxiliary_e, (T *)auxiliary_f, + (T *)pts_assign); + + // __bang_select returns selected elements vector and the number of selected elements + __bang_select((T *)auxiliary_b, (T *)points_x, (T *)pts_assign, span_num_deal); + uint32_t select_num = *((uint32_t *)auxiliary_b); + + if (cnt[box_idx] + select_num == 0) { + // pooled_empty_flag_gdram set 1 + *((int *)auxiliary_a) = 1; + __memcpy(pooled_empty_flag_gdram + box_idx * sizeof(int), auxiliary_a, sizeof(int), NRAM2GDRAM); + + // pooled_features_gdram set 0 + int repeat = (sampled_pts_num * (3 + feature_in_len)) / (auxiliary_num_deal * 6); + int rem = (sampled_pts_num * (3 + feature_in_len)) % (auxiliary_num_deal * 6); + // use auxiliary_a to auxiliary_f + __bang_write_zero((T *)auxiliary_a, PAD_UP(auxiliary_num_deal * 6, NFU_ALIGN_SIZE)); + if (repeat > 0) { + __memcpy(pooled_features_gdram + box_idx * sampled_pts_num * (3 + feature_in_len) * sizeof(T), + auxiliary_a, auxiliary_num_deal * 6 * sizeof(T), NRAM2GDRAM, + auxiliary_num_deal * 6 * sizeof(T), 0, repeat - 1); + } + if (rem > 0) { + __memcpy(pooled_features_gdram + + box_idx * sampled_pts_num * (3 + feature_in_len) * sizeof(T) + + repeat * auxiliary_num_deal * 6 * sizeof(T), + auxiliary_a, rem * sizeof(T), NRAM2GDRAM); + } + return; + } + + if (select_num > 0) { + int sampled_pts_num_rem = sampled_pts_num - cnt[box_idx]; + int segnum = MIN(select_num, sampled_pts_num_rem) - 1; + + // copy x to pooled_features_gdram + // The result of __bang_select is composed of three parts: + // The first 4-byte is the number of selected element, whose data type is unsigned int. + // The next 124-byte is zero. The rest bytes are the selected elements. + int select_num_size = 128; + __memcpy(pooled_features_gdram + + (box_idx * sampled_pts_num + cnt[box_idx]) * (3 + feature_in_len) * sizeof(T), + (T *)((int8_t *)auxiliary_b + select_num_size), sizeof(T), NRAM2GDRAM, + (3 + feature_in_len) * sizeof(T), sizeof(T), segnum); + + // copy y to pooled_features_gdram + __bang_collect((T *)auxiliary_d, (T *)points_y, (T *)pts_assign, span_num_deal); + __memcpy(pooled_features_gdram + + (box_idx * sampled_pts_num + cnt[box_idx]) * (3 + feature_in_len) * sizeof(T) + + 1 * sizeof(T), + (T *)auxiliary_d, sizeof(T), NRAM2GDRAM, (3 + feature_in_len) * sizeof(T), sizeof(T), + segnum); + + // copy z to pooled_features_gdram + __bang_collect((T *)auxiliary_e, (T *)points_z, (T *)pts_assign, span_num_deal); + __memcpy(pooled_features_gdram + + (box_idx * sampled_pts_num + cnt[box_idx]) * (3 + feature_in_len) * sizeof(T) + + 2 * sizeof(T), + (T *)auxiliary_e, sizeof(T), NRAM2GDRAM, (3 + feature_in_len) * sizeof(T), sizeof(T), + segnum); + + // copy features to pooled_features_gdram + for (int c_idx = 0; c_idx < feature_in_len; c_idx++) { + __memcpy(auxiliary_d, point_features + c_idx * pts_num * sizeof(T), span_num_deal * sizeof(T), + GDRAM2NRAM); + __bang_collect((T *)auxiliary_e, (T *)auxiliary_d, (T *)pts_assign, span_num_deal); + __memcpy(pooled_features_gdram + + (box_idx * sampled_pts_num + cnt[box_idx]) * (3 + feature_in_len) * sizeof(T) + + (3 + c_idx) * sizeof(T), + auxiliary_e, sizeof(T), NRAM2GDRAM, (3 + feature_in_len) * sizeof(T), sizeof(T), + segnum); + } + } + + // pooled_empty_flag_gdram set 0 + *((int *)auxiliary_a) = 0; + __memcpy(pooled_empty_flag_gdram + box_idx * sizeof(int), auxiliary_a, sizeof(int), NRAM2GDRAM); + + cnt[box_idx] += select_num; + if (cnt[box_idx] < sampled_pts_num) { + // duplicate same points for sampling + int repeat = sampled_pts_num / cnt[box_idx] - 1; + int rem = sampled_pts_num % cnt[box_idx]; + if (repeat > 0) { + __memcpy(pooled_features_gdram + + (box_idx * sampled_pts_num + cnt[box_idx]) * (3 + feature_in_len) * sizeof(T), + pooled_features_gdram + box_idx * sampled_pts_num * (3 + feature_in_len) * sizeof(T), + cnt[box_idx] * (3 + feature_in_len) * sizeof(T), GDRAM2GDRAM, + cnt[box_idx] * (3 + feature_in_len) * sizeof(T), 0, repeat - 1); + } + if (rem > 0) { + __memcpy(pooled_features_gdram + (box_idx * sampled_pts_num + (repeat + 1) * cnt[box_idx]) * + (3 + feature_in_len) * sizeof(T), + pooled_features_gdram + box_idx * sampled_pts_num * (3 + feature_in_len) * sizeof(T), + rem * (3 + feature_in_len) * sizeof(T), GDRAM2GDRAM); + } + } +} + +template +__mlu_global__ void MLUUnion1KernelRoiPointPool3dForward( + const int batch_size, + const int pts_num, + const int boxes_num, + const int feature_in_len, + const int sampled_pts_num, + const char *points_xyz_gdram, + const char *point_features_gdram, + const char *boxes3d_gdram, + char *pooled_features_gdram, + char *pooled_empty_flag_gdram) { + if (coreId == 0x80) { + return; + } + size_t boxes_per_core = (batch_size * boxes_num) / taskDim; + size_t boxes_rem = (batch_size * boxes_num) % taskDim; + // calc batch_start, batch_end, first_batch_box_start, last batch_box_end for each core + int32_t batch_start = taskId < (boxes_rem + 1) ? + (taskId * (boxes_per_core + 1)) / boxes_num : + (taskId * boxes_per_core + boxes_rem) / boxes_num; + int32_t batch_end = taskId < boxes_rem ? + ((taskId + 1) * (boxes_per_core + 1) - 1) / boxes_num : + ((taskId + 1) * boxes_per_core + boxes_rem - 1) / boxes_num; + size_t first_batch_box_start = taskId < (boxes_rem + 1) ? + (taskId * (boxes_per_core + 1)) - batch_start * boxes_num : + taskId * boxes_per_core + boxes_rem - batch_start * boxes_num; + size_t last_batch_box_end = taskId < boxes_rem ? + (taskId + 1) * (boxes_per_core + 1) - batch_end * boxes_num : + ((taskId + 1) * boxes_per_core + boxes_rem) - batch_end * boxes_num; + + // points_xyz : [3, B, N] + const char *points_x_gdram = points_xyz_gdram; + const char *points_y_gdram = points_xyz_gdram + (1 * batch_size * pts_num) * sizeof(T); + const char *points_z_gdram = points_xyz_gdram + (2 * batch_size * pts_num) * sizeof(T); + + size_t boxes3d_size = PAD_UP(boxes_num * 7, NFU_ALIGN_SIZE) * sizeof(T); + size_t cnt_size = PAD_UP(boxes_num, NFU_ALIGN_SIZE) * sizeof(int); + size_t span_num_deal = PAD_DOWN( + (MAX_NRAM_SIZE - boxes3d_size - cnt_size) / TWELVE_SPLIT / sizeof(T), NFU_ALIGN_SIZE); + size_t align_num = NFU_ALIGN_SIZE; + int32_t repeat = pts_num / span_num_deal; + size_t rem = pts_num % span_num_deal; + size_t align_rem = CEIL_ALIGN(rem, align_num); + char *boxes3d = nram_buffer; + char *cnt = nram_buffer + boxes3d_size; + char *ping_points_x = cnt + cnt_size; + char *ping_points_y = ping_points_x + span_num_deal * sizeof(T); + char *ping_points_z = ping_points_y + span_num_deal * sizeof(T); + size_t ping_pong_gap = 3 * span_num_deal * sizeof(T); + char *auxiliary_a = ping_points_x + 2 * ping_pong_gap; + char *auxiliary_b = auxiliary_a + span_num_deal * sizeof(T); + char *auxiliary_c = auxiliary_b + span_num_deal * sizeof(T); + char *auxiliary_d = auxiliary_c + span_num_deal * sizeof(T); + char *auxiliary_e = auxiliary_d + span_num_deal * sizeof(T); + char *auxiliary_f = auxiliary_e + span_num_deal * sizeof(T); + size_t span_load_input1_size = span_num_deal * sizeof(T); + size_t span_load_input2_size = span_num_deal * sizeof(T); + size_t span_load_input3_size = span_num_deal * sizeof(T); + size_t span_load_input4_size = span_num_deal * sizeof(T); + + for (int bs_idx = batch_start; bs_idx <= batch_end; bs_idx++) { + __memcpy_async(boxes3d, boxes3d_gdram + bs_idx * boxes_num * 7 * sizeof(T), + boxes_num * 7 * sizeof(T), GDRAM2NRAM); + __bang_write_zero((int *)cnt, PAD_UP(boxes_num, NFU_ALIGN_SIZE)); + + const char *points_x_start = points_x_gdram + bs_idx * pts_num * sizeof(T); + const char *points_y_start = points_y_gdram + bs_idx * pts_num * sizeof(T); + const char *points_z_start = points_z_gdram + bs_idx * pts_num * sizeof(T); + const char *point_features_start = + point_features_gdram + bs_idx * feature_in_len * pts_num * sizeof(T); + char *pooled_features_start = + pooled_features_gdram + + (bs_idx * boxes_num * sampled_pts_num * (3 + feature_in_len)) * sizeof(T); + char *pooled_empty_flag_start = pooled_empty_flag_gdram + bs_idx * boxes_num * sizeof(int); + size_t box_start = bs_idx == batch_start ? first_batch_box_start : 0; + size_t box_end = bs_idx == batch_end ? last_batch_box_end : boxes_num; + + if (repeat > 0) { + __memcpy_async(ping_points_x, points_x_start, span_load_input1_size, GDRAM2NRAM); + __memcpy_async(ping_points_y, points_y_start, span_load_input2_size, GDRAM2NRAM); + __memcpy_async(ping_points_z, points_z_start, span_load_input3_size, GDRAM2NRAM); + __asm__ volatile("sync;"); + } + + for (int i = 0; i < repeat - 1; i++) { + __memcpy_async(ping_points_x + ((i + 1) % 2) * ping_pong_gap, + points_x_start + (i + 1) * span_load_input1_size, span_load_input1_size, + GDRAM2NRAM); + __memcpy_async(ping_points_y + ((i + 1) % 2) * ping_pong_gap, + points_y_start + (i + 1) * span_load_input2_size, span_load_input2_size, + GDRAM2NRAM); + __memcpy_async(ping_points_z + ((i + 1) % 2) * ping_pong_gap, + points_z_start + (i + 1) * span_load_input3_size, span_load_input3_size, + GDRAM2NRAM); + for (int box_idx = box_start; box_idx < box_end; box_idx++) { + computeStoreRoipointPool3d( + boxes3d, (int *)cnt, ping_points_x + (i % 2) * ping_pong_gap, + ping_points_y + (i % 2) * ping_pong_gap, ping_points_z + (i % 2) * ping_pong_gap, + point_features_start + i * span_load_input4_size, auxiliary_a, auxiliary_b, auxiliary_c, + auxiliary_d, auxiliary_e, auxiliary_f, box_idx, pts_num, feature_in_len, + sampled_pts_num, span_num_deal, pooled_features_start, pooled_empty_flag_start); + } + __asm__ volatile("sync;"); + } + + if (rem > 0) { + if (sizeof(T) == sizeof(float)) { + __bang_write_value((T *)(ping_points_x + (repeat % 2) * ping_pong_gap + + PAD_DOWN(rem, NFU_ALIGN_SIZE) * sizeof(T)), + NFU_ALIGN_SIZE, (T)NAN); + __bang_write_value((T *)(ping_points_y + (repeat % 2) * ping_pong_gap + + PAD_DOWN(rem, NFU_ALIGN_SIZE) * sizeof(T)), + NFU_ALIGN_SIZE, (T)NAN); + __bang_write_value((T *)(ping_points_z + (repeat % 2) * ping_pong_gap + + PAD_DOWN(rem, NFU_ALIGN_SIZE) * sizeof(T)), + NFU_ALIGN_SIZE, (T)NAN); + } else { + __bang_write_value((T *)(ping_points_x + (repeat % 2) * ping_pong_gap + + PAD_DOWN(rem, NFU_ALIGN_SIZE) * sizeof(T)), + NFU_ALIGN_SIZE, (T)NAN); + __bang_write_value((T *)(ping_points_y + (repeat % 2) * ping_pong_gap + + PAD_DOWN(rem, NFU_ALIGN_SIZE) * sizeof(T)), + NFU_ALIGN_SIZE, (T)NAN); + __bang_write_value((T *)(ping_points_z + (repeat % 2) * ping_pong_gap + + PAD_DOWN(rem, NFU_ALIGN_SIZE) * sizeof(T)), + NFU_ALIGN_SIZE, (T)NAN); + } + __memcpy_async(ping_points_x + (repeat % 2) * ping_pong_gap, + points_x_start + repeat * span_load_input1_size, rem * sizeof(T), GDRAM2NRAM); + __memcpy_async(ping_points_y + (repeat % 2) * ping_pong_gap, + points_y_start + repeat * span_load_input2_size, rem * sizeof(T), GDRAM2NRAM); + __memcpy_async(ping_points_z + (repeat % 2) * ping_pong_gap, + points_z_start + repeat * span_load_input3_size, rem * sizeof(T), GDRAM2NRAM); + } + + if (repeat > 0 && rem > 0) { + for (int box_idx = box_start; box_idx < box_end; box_idx++) { + computeStoreRoipointPool3d( + boxes3d, (int *)cnt, ping_points_x + ((repeat - 1) % 2) * ping_pong_gap, + ping_points_y + ((repeat - 1) % 2) * ping_pong_gap, + ping_points_z + ((repeat - 1) % 2) * ping_pong_gap, + point_features_start + (repeat - 1) * span_load_input4_size, auxiliary_a, auxiliary_b, + auxiliary_c, auxiliary_d, auxiliary_e, auxiliary_f, box_idx, pts_num, feature_in_len, + sampled_pts_num, span_num_deal, pooled_features_start, pooled_empty_flag_start); + } + } else if (repeat > 0 && rem == 0) { + for (int box_idx = box_start; box_idx < box_end; box_idx++) { + computeStoreLastBlockRoipointPool3d( + boxes3d, (int *)cnt, ping_points_x + ((repeat - 1) % 2) * ping_pong_gap, + ping_points_y + ((repeat - 1) % 2) * ping_pong_gap, + ping_points_z + ((repeat - 1) % 2) * ping_pong_gap, + point_features_start + (repeat - 1) * span_load_input4_size, auxiliary_a, auxiliary_b, + auxiliary_c, auxiliary_d, auxiliary_e, auxiliary_f, box_idx, pts_num, feature_in_len, + sampled_pts_num, span_num_deal, span_num_deal, pooled_features_start, + pooled_empty_flag_start); + } + } + + if (rem > 0) { + __asm__ volatile("sync;"); + for (int box_idx = box_start; box_idx < box_end; box_idx++) { + computeStoreLastBlockRoipointPool3d( + boxes3d, (int *)cnt, ping_points_x + (repeat % 2) * ping_pong_gap, + ping_points_y + (repeat % 2) * ping_pong_gap, + ping_points_z + (repeat % 2) * ping_pong_gap, + point_features_start + repeat * span_load_input4_size, auxiliary_a, auxiliary_b, + auxiliary_c, auxiliary_d, auxiliary_e, auxiliary_f, box_idx, pts_num, feature_in_len, + sampled_pts_num, align_rem, span_num_deal, pooled_features_start, + pooled_empty_flag_start); + } + } + } +} + +template __mlu_global__ void MLUUnion1KernelRoiPointPool3dForward( + const int batch_size, + const int pts_num, + const int boxes_num, + const int feature_in_len, + const int sampled_pts_num, + const char *points_xyz_gdram, + const char *point_features_gdram, + const char *boxes3d_gdram, + char *pooled_features_gdram, + char *pooled_empty_flag_gdram); + +template __mlu_global__ void MLUUnion1KernelRoiPointPool3dForward( + const int batch_size, + const int pts_num, + const int boxes_num, + const int feature_in_len, + const int sampled_pts_num, + const char *points_xyz_gdram, + const char *point_features_gdram, + const char *boxes3d_gdram, + char *pooled_features_gdram, + char *pooled_empty_flag_gdram); + +void KernelRoiPointPool3dForward(cnrtDim3_t k_dim, + cnrtFunctionType_t k_type, + cnrtQueue_t queue, + const cnrtDataType_t d_type, + const int batch_size, + const int pts_num, + const int boxes_num, + const int feature_in_len, + const int sampled_pts_num, + const void *points_xyz, + const void *boxes3d, + const void *point_features, + void *pooled_features, + int *pooled_empty_flag) { + switch (d_type) { + default: { break; } + case CNRT_FLOAT32: { + MLUUnion1KernelRoiPointPool3dForward<<>>( + batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num, + (char *)points_xyz, (char *)point_features, (char *)boxes3d, + (char *)pooled_features, (char *)pooled_empty_flag); + }; break; + case CNRT_FLOAT16: { + MLUUnion1KernelRoiPointPool3dForward<<>>( + batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num, + (char *)points_xyz, (char *)point_features, (char *)boxes3d, + (char *)pooled_features, (char *)pooled_empty_flag); + }; break; + } +} diff --git a/mmcv/ops/csrc/common/mlu/roipoint_pool3d_utils.hpp b/mmcv/ops/csrc/common/mlu/roipoint_pool3d_utils.hpp new file mode 100644 index 0000000000..db618756c0 --- /dev/null +++ b/mmcv/ops/csrc/common/mlu/roipoint_pool3d_utils.hpp @@ -0,0 +1,18 @@ +/************************************************************************* + * Copyright (C) 2022 Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef ROIPOINT_POOL3D_UTILS_HPP_ +#define ROIPOINT_POOL3D_UTILS_HPP_ + +#define MIN(a, b) (((a) < (b)) ? (a) : (b)) +#define MAX(a, b) (((a) > (b)) ? (a) : (b)) + +#endif // ROIPOINT_POOL3D_UTILS_HPP_ diff --git a/mmcv/ops/csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp new file mode 100644 index 0000000000..28b3f98be0 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp @@ -0,0 +1,175 @@ +/************************************************************************* + * Copyright (C) 2022 by Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include "pytorch_device_registry.hpp" +#include "pytorch_mlu_helper.hpp" +#include "roipoint_pool3d_utils.hpp" + +void KernelRoiPointPool3dForward(cnrtDim3_t k_dim, + cnrtFunctionType_t k_type, + cnrtQueue_t queue, + const cnrtDataType_t d_type, + const int batch_size, + const int pts_num, + const int boxes_num, + const int feature_in_len, + const int sampled_pts_num, + const void *xyz, + const void *boxes3d, + const void *pts_feature, + void *pooled_features, + int *pooled_empty_flag); + +void KernelRoiPointPool3dLargeBoxesNumForward(cnrtDim3_t k_dim, + cnrtFunctionType_t k_type, + cnrtQueue_t queue, + const cnrtDataType_t d_type, + const int batch_size, + const int pts_num, + const int boxes_num, + const int feature_in_len, + const int sampled_pts_num, + const void *xyz, + const void *boxes3d, + const void *pts_feature, + void *pooled_features, + int *pooled_empty_flag); + +// policy function +static void policyFuncForward(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) { + // start U1 task, occupy all available clusters + k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); + k_dim->y = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); + k_dim->z = 1; + *k_type = CNRT_FUNC_TYPE_UNION1; +} + +void RoIPointPool3dForwardMLUKernelLauncher(int batch_size, + int pts_num, + int boxes_num, + int feature_in_len, + int sampled_pts_num, + const Tensor xyz, + const Tensor boxes3d, + const Tensor pts_feature, + Tensor pooled_features, + Tensor pooled_empty_flag) { + // check datatype + TORCH_CHECK(((xyz.scalar_type() == pooled_features.scalar_type()) && + (boxes3d.scalar_type() == pooled_features.scalar_type()) && + (pts_feature.scalar_type() == pooled_features.scalar_type())), + "data types of xyz, boxes3d, pts_feature and pooled_features should be the same, ", + "but now xyz type is ", xyz.scalar_type(), ", boxes3d type is ", + boxes3d.scalar_type(), ", pts_feature type is ", pts_feature.scalar_type(), + ", pooled_features type is ", pooled_features.scalar_type(), "."); + TORCH_CHECK((xyz.scalar_type() == at::kFloat || xyz.scalar_type() == at::kHalf), + "xyz type should be Float or Half, got ", xyz.scalar_type(), "."); + TORCH_CHECK((pooled_empty_flag.scalar_type() == at::kInt), + "pooled_empty_flag type should be Int, got ", pooled_empty_flag.scalar_type(), "."); + + // check shape + TORCH_CHECK(boxes3d.dim() == 3, "boxes3d should be a 3d tensor, got ", boxes3d.dim(), "D."); + TORCH_CHECK(pts_feature.dim() == 3, "pts_feature should be a 3d tensor, got ", + pts_feature.dim(), "D."); + + TORCH_CHECK(boxes3d.size(2) == 7, "the 3rd dimensions of boxes3d should be 7, got ", + boxes3d.size(2), "."); + TORCH_CHECK((boxes3d.size(0) == batch_size), + "the 1st dimensions of boxes3d should be batch_size, ", + "but now the 1st dimension of boxes3d is ", boxes3d.size(0), + ", and batch_size is ", batch_size, "."); + TORCH_CHECK((pts_feature.size(0) == batch_size), + "the 1st dimensions of pts_feature should be batch_size, ", + "but now the 1st dimension of pts_feature is ", pts_feature.size(0), + ", and batch_size is ", batch_size, "."); + TORCH_CHECK((pts_feature.size(1) == pts_num), + "the 2nd dimensions of pts_feature should be pts_num, ", + "but now the 2nd dimension of pts_feature is ", pts_feature.size(1), + ", and pts_num is ", pts_num, "."); + + // check zero element + if (xyz.numel() == 0 || pts_feature.numel() == 0 || boxes3d.numel() == 0 || + pooled_features.numel() == 0 || pooled_empty_flag.numel() == 0) { + return; + } + + // large tensor check + const size_t max_input_size = 2147483648; + TORCH_CHECK(xyz.numel() < max_input_size, "xyz element num should be less than 2^31, got ", + xyz.numel(), "."); + TORCH_CHECK(boxes3d.numel() < max_input_size, "boxes3d element num should be less than 2^31, got ", + boxes3d.numel(), "."); + TORCH_CHECK(pts_feature.numel() < max_input_size, "pts_feature element num should be less than 2^31, got ", + pts_feature.numel(), "."); + + // calculate task dimension + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + policyFuncForward(&k_dim, &k_type); + + // get compute queue + auto queue = torch_mlu::getCurQueue(); + + // get ptr of tensors + // transpose points [B, N ,3] -> [3, B, N] + auto xyz_ = xyz.permute({2, 0, 1}).contiguous(); + auto xyz_impl = torch_mlu::getMluTensorImpl(xyz_); + auto xyz_ptr = xyz_impl->cnnlMalloc(); + // transpose point_features [B, N, C] -> [B, C, N] + auto pts_feature_ = pts_feature.permute({0, 2, 1}).contiguous(); + auto pts_feature_impl = torch_mlu::getMluTensorImpl(pts_feature_); + auto pts_feature_ptr = pts_feature_impl->cnnlMalloc(); + auto boxes3d_impl = torch_mlu::getMluTensorImpl(boxes3d); + auto boxes3d_ptr = boxes3d_impl->cnnlMalloc(); + auto pooled_features_impl = torch_mlu::getMluTensorImpl(pooled_features); + auto pooled_features_ptr = pooled_features_impl->cnnlMalloc(); + auto pooled_empty_flag_impl = torch_mlu::getMluTensorImpl(pooled_empty_flag); + auto pooled_empty_flag_ptr = pooled_empty_flag_impl->cnnlMalloc(); + + // get compute dtype of input + cnrtDataType_t data_type = torch_mlu::toCnrtDtype(xyz_.dtype()); + + // launch kernel + if (boxes_num <= 10240) { + CNLOG(INFO) << "Launch Kernel MLUKernelRoiPointPool3dForward<<<" << k_dim.x + << ", " << k_dim.y << ", " << k_dim.z << ">>>"; + KernelRoiPointPool3dForward(k_dim, k_type, queue, data_type, batch_size, pts_num, boxes_num, + feature_in_len, sampled_pts_num, xyz_ptr, boxes3d_ptr, + pts_feature_ptr, pooled_features_ptr, (int *)pooled_empty_flag_ptr); + } else { + CNLOG(INFO) << "Launch Kernel MLUKernelRoiPointPool3dLargeBoxesNumForward<<<" + << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>"; + KernelRoiPointPool3dLargeBoxesNumForward(k_dim, k_type, queue, data_type, batch_size, pts_num, + boxes_num, feature_in_len, sampled_pts_num, xyz_ptr, + boxes3d_ptr, pts_feature_ptr, pooled_features_ptr, + (int *)pooled_empty_flag_ptr); + } +} + +void roipoint_pool3d_forward_mlu(int batch_size, int pts_num, int boxes_num, + int feature_in_len, int sampled_pts_num, + const Tensor xyz, const Tensor boxes3d, + const Tensor pts_feature, + Tensor pooled_features, + Tensor pooled_empty_flag) { + RoIPointPool3dForwardMLUKernelLauncher(batch_size, pts_num, boxes_num, feature_in_len, + sampled_pts_num, xyz, boxes3d, pts_feature, + pooled_features, pooled_empty_flag); +} + +void roipoint_pool3d_forward_impl(int batch_size, int pts_num, int boxes_num, + int feature_in_len, int sampled_pts_num, + const Tensor xyz, const Tensor boxes3d, + const Tensor pts_feature, + Tensor pooled_features, + Tensor pooled_empty_flag); + +REGISTER_DEVICE_IMPL(roipoint_pool3d_forward_impl, MLU, roipoint_pool3d_forward_mlu); diff --git a/tests/test_ops/test_roipoint_pool3d.py b/tests/test_ops/test_roipoint_pool3d.py index 6619a36148..36f018d781 100644 --- a/tests/test_ops/test_roipoint_pool3d.py +++ b/tests/test_ops/test_roipoint_pool3d.py @@ -3,34 +3,56 @@ import torch from mmcv.ops import RoIPointPool3d +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE -@pytest.mark.skipif( - not torch.cuda.is_available(), reason='requires CUDA support') -def test_roipoint(): - feats = torch.tensor( - [[1, 2, 3.3], [1.2, 2.5, 3.0], [0.8, 2.1, 3.5], [1.6, 2.6, 3.6], - [0.8, 1.2, 3.9], [-9.2, 21.0, 18.2], [3.8, 7.9, 6.3], - [4.7, 3.5, -12.2], [3.8, 7.6, -2], [-10.6, -12.9, -20], [-16, -18, 9], - [-21.3, -52, -5], [0, 0, 0], [6, 7, 8], [-2, -3, -4]], - dtype=torch.float32).unsqueeze(0).cuda() - points = feats.clone() - rois = torch.tensor([[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.3], - [-10.0, 23.0, 16.0, 10, 20, 20, 0.5]]], - dtype=torch.float32).cuda() +class TestRoiPointPool3d: - roipoint_pool3d = RoIPointPool3d(num_sampled_points=4) - roi_feat, empty_flag = roipoint_pool3d(feats, points, rois) - expected_roi_feat = torch.tensor([[[[1, 2, 3.3, 1, 2, 3.3], - [1.2, 2.5, 3, 1.2, 2.5, 3], - [0.8, 2.1, 3.5, 0.8, 2.1, 3.5], - [1.6, 2.6, 3.6, 1.6, 2.6, 3.6]], - [[-9.2, 21, 18.2, -9.2, 21, 18.2], - [-9.2, 21, 18.2, -9.2, 21, 18.2], - [-9.2, 21, 18.2, -9.2, 21, 18.2], - [-9.2, 21, 18.2, -9.2, 21, - 18.2]]]]).cuda() - expected_empty_flag = torch.tensor([[0, 0]]).int().cuda() + def _test_roipoint_pool3d_allclose(self, device, dtype): + points = torch.tensor([[1, 2, 3.3], [1.2, 2.5, 3.0], [0.8, 2.1, 3.5], + [1.6, 2.6, 3.6], [0.8, 1.2, 3.9], + [-9.2, 21.0, 18.2], [3.8, 7.9, 6.3], + [4.7, 3.5, -12.2], [3.8, 7.6, -2], + [-10.6, -12.9, -20], [-16, -18, 9], + [-21.3, -52, -5], [0, 0, 0], [6, 7, 8], + [-2, -3, + -4]]).unsqueeze(0).to(device).type(dtype) + feats = points.clone() + rois = torch.tensor([[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.3], + [-10.0, 23.0, 16.0, 10, 20, 20, + 0.5]]]).to(device).type(dtype) - assert torch.allclose(roi_feat, expected_roi_feat) - assert torch.allclose(empty_flag, expected_empty_flag) + roipoint_pool3d = RoIPointPool3d(num_sampled_points=4) + pooled_features, pooled_empty_flag = roipoint_pool3d( + points, feats, rois) + expected_pooled_features = torch.tensor( + [[[[1, 2, 3.3, 1, 2, 3.3], [1.2, 2.5, 3, 1.2, 2.5, 3], + [0.8, 2.1, 3.5, 0.8, 2.1, 3.5], [1.6, 2.6, 3.6, 1.6, 2.6, 3.6]], + [[-9.2, 21, 18.2, -9.2, 21, 18.2], + [-9.2, 21, 18.2, -9.2, 21, 18.2], + [-9.2, 21, 18.2, -9.2, 21, 18.2], + [-9.2, 21, 18.2, -9.2, 21, 18.2]]]]).to(device).type(dtype) + expected_pooled_empty_flag = torch.tensor([[0, 0]]).int().to(device) + + assert torch.allclose(pooled_features, expected_pooled_features) + assert torch.allclose(pooled_empty_flag, expected_pooled_empty_flag) + + @pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'mlu', + marks=pytest.mark.skipif( + not IS_MLU_AVAILABLE, reason='requires MLU support')) + ]) + @pytest.mark.parametrize('dtype', [ + torch.float, torch.half, + pytest.param( + torch.double, + marks=pytest.mark.skipif( + IS_MLU_AVAILABLE, reason='MLU does not support for double')) + ]) + def test_roipoint_pool3d_allclose(self, device, dtype): + self._test_roipoint_pool3d_allclose(device, dtype) From 0edce250e53d0dfc3f07ff656f2248803b2d4972 Mon Sep 17 00:00:00 2001 From: zhangshaopeng Date: Fri, 2 Sep 2022 14:45:30 +0800 Subject: [PATCH 2/5] [Feature] Support RoipointPool3d with cambricon MLU backend --- tests/test_ops/test_roipoint_pool3d.py | 88 ++++++++++++-------------- 1 file changed, 40 insertions(+), 48 deletions(-) diff --git a/tests/test_ops/test_roipoint_pool3d.py b/tests/test_ops/test_roipoint_pool3d.py index 36f018d781..e762cab3dd 100644 --- a/tests/test_ops/test_roipoint_pool3d.py +++ b/tests/test_ops/test_roipoint_pool3d.py @@ -6,53 +6,45 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE -class TestRoiPointPool3d: +@pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'mlu', + marks=pytest.mark.skipif( + not IS_MLU_AVAILABLE, reason='requires MLU support')) +]) +@pytest.mark.parametrize('dtype', [ + torch.float, torch.half, + pytest.param( + torch.double, + marks=pytest.mark.skipif( + IS_MLU_AVAILABLE, reason='MLU does not support for double')) +]) +def test_roipoint(device, dtype): + points = torch.tensor( + [[1, 2, 3.3], [1.2, 2.5, 3.0], [0.8, 2.1, 3.5], [1.6, 2.6, 3.6], + [0.8, 1.2, 3.9], [-9.2, 21.0, 18.2], [3.8, 7.9, 6.3], + [4.7, 3.5, -12.2], [3.8, 7.6, -2], [-10.6, -12.9, -20], [-16, -18, 9], + [-21.3, -52, -5], [0, 0, 0], [6, 7, 8], [-2, -3, -4]], + dtype=dtype).unsqueeze(0).to(device) + feats = points.clone() + rois = torch.tensor([[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.3], + [-10.0, 23.0, 16.0, 10, 20, 20, 0.5]]], + dtype=dtype).to(device) - def _test_roipoint_pool3d_allclose(self, device, dtype): - points = torch.tensor([[1, 2, 3.3], [1.2, 2.5, 3.0], [0.8, 2.1, 3.5], - [1.6, 2.6, 3.6], [0.8, 1.2, 3.9], - [-9.2, 21.0, 18.2], [3.8, 7.9, 6.3], - [4.7, 3.5, -12.2], [3.8, 7.6, -2], - [-10.6, -12.9, -20], [-16, -18, 9], - [-21.3, -52, -5], [0, 0, 0], [6, 7, 8], - [-2, -3, - -4]]).unsqueeze(0).to(device).type(dtype) - feats = points.clone() - rois = torch.tensor([[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.3], - [-10.0, 23.0, 16.0, 10, 20, 20, - 0.5]]]).to(device).type(dtype) + roipoint_pool3d = RoIPointPool3d(num_sampled_points=4) + pooled_features, pooled_empty_flag = roipoint_pool3d(points, feats, rois) + expected_pooled_features = torch.tensor( + [[[[1, 2, 3.3, 1, 2, 3.3], [1.2, 2.5, 3, 1.2, 2.5, 3], + [0.8, 2.1, 3.5, 0.8, 2.1, 3.5], [1.6, 2.6, 3.6, 1.6, 2.6, 3.6]], + [[-9.2, 21, 18.2, -9.2, 21, 18.2], [-9.2, 21, 18.2, -9.2, 21, 18.2], + [-9.2, 21, 18.2, -9.2, 21, 18.2], [-9.2, 21, 18.2, -9.2, 21, 18.2]]] + ], + dtype=dtype).to(device) + expected_pooled_empty_flag = torch.tensor([[0, 0]]).int().to(device) - roipoint_pool3d = RoIPointPool3d(num_sampled_points=4) - pooled_features, pooled_empty_flag = roipoint_pool3d( - points, feats, rois) - expected_pooled_features = torch.tensor( - [[[[1, 2, 3.3, 1, 2, 3.3], [1.2, 2.5, 3, 1.2, 2.5, 3], - [0.8, 2.1, 3.5, 0.8, 2.1, 3.5], [1.6, 2.6, 3.6, 1.6, 2.6, 3.6]], - [[-9.2, 21, 18.2, -9.2, 21, 18.2], - [-9.2, 21, 18.2, -9.2, 21, 18.2], - [-9.2, 21, 18.2, -9.2, 21, 18.2], - [-9.2, 21, 18.2, -9.2, 21, 18.2]]]]).to(device).type(dtype) - expected_pooled_empty_flag = torch.tensor([[0, 0]]).int().to(device) - - assert torch.allclose(pooled_features, expected_pooled_features) - assert torch.allclose(pooled_empty_flag, expected_pooled_empty_flag) - - @pytest.mark.parametrize('device', [ - pytest.param( - 'cuda', - marks=pytest.mark.skipif( - not IS_CUDA_AVAILABLE, reason='requires CUDA support')), - pytest.param( - 'mlu', - marks=pytest.mark.skipif( - not IS_MLU_AVAILABLE, reason='requires MLU support')) - ]) - @pytest.mark.parametrize('dtype', [ - torch.float, torch.half, - pytest.param( - torch.double, - marks=pytest.mark.skipif( - IS_MLU_AVAILABLE, reason='MLU does not support for double')) - ]) - def test_roipoint_pool3d_allclose(self, device, dtype): - self._test_roipoint_pool3d_allclose(device, dtype) + assert torch.allclose(pooled_features, expected_pooled_features) + assert torch.allclose(pooled_empty_flag, expected_pooled_empty_flag) From fad78561c745b22e20f1dc82d12d472fe4e54180 Mon Sep 17 00:00:00 2001 From: zhangshaopeng Date: Fri, 2 Sep 2022 14:51:54 +0800 Subject: [PATCH 3/5] [Feature] Support RoipointPool3d with cambricon MLU backend --- tests/test_ops/test_roipoint_pool3d.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_ops/test_roipoint_pool3d.py b/tests/test_ops/test_roipoint_pool3d.py index e762cab3dd..391a0bf3a4 100644 --- a/tests/test_ops/test_roipoint_pool3d.py +++ b/tests/test_ops/test_roipoint_pool3d.py @@ -36,15 +36,15 @@ def test_roipoint(device, dtype): dtype=dtype).to(device) roipoint_pool3d = RoIPointPool3d(num_sampled_points=4) - pooled_features, pooled_empty_flag = roipoint_pool3d(points, feats, rois) - expected_pooled_features = torch.tensor( + roi_feat, empty_flag = roipoint_pool3d(points, feats, rois) + expected_roi_feat = torch.tensor( [[[[1, 2, 3.3, 1, 2, 3.3], [1.2, 2.5, 3, 1.2, 2.5, 3], [0.8, 2.1, 3.5, 0.8, 2.1, 3.5], [1.6, 2.6, 3.6, 1.6, 2.6, 3.6]], [[-9.2, 21, 18.2, -9.2, 21, 18.2], [-9.2, 21, 18.2, -9.2, 21, 18.2], [-9.2, 21, 18.2, -9.2, 21, 18.2], [-9.2, 21, 18.2, -9.2, 21, 18.2]]] ], dtype=dtype).to(device) - expected_pooled_empty_flag = torch.tensor([[0, 0]]).int().to(device) + expected_empty_flag = torch.tensor([[0, 0]]).int().to(device) - assert torch.allclose(pooled_features, expected_pooled_features) - assert torch.allclose(pooled_empty_flag, expected_pooled_empty_flag) + assert torch.allclose(roi_feat, expected_roi_feat) + assert torch.allclose(empty_flag, expected_empty_flag) From 5bcdae1613f15c25d17b63e852cdae03e26c69f1 Mon Sep 17 00:00:00 2001 From: zhangshaopeng Date: Mon, 5 Sep 2022 10:37:55 +0800 Subject: [PATCH 4/5] [Feature] Support RoipointPool3d with cambricon MLU backend --- ...point_pool3d_large_boxes_num_mlu_kernel.mlu | 5 ++--- .../common/mlu/roipoint_pool3d_mlu_kernel.mlu | 5 ++--- .../csrc/common/mlu/roipoint_pool3d_utils.hpp | 18 ------------------ .../csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp | 1 - 4 files changed, 4 insertions(+), 25 deletions(-) delete mode 100644 mmcv/ops/csrc/common/mlu/roipoint_pool3d_utils.hpp diff --git a/mmcv/ops/csrc/common/mlu/roipoint_pool3d_large_boxes_num_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/roipoint_pool3d_large_boxes_num_mlu_kernel.mlu index ec179396f9..58a15d8765 100644 --- a/mmcv/ops/csrc/common/mlu/roipoint_pool3d_large_boxes_num_mlu_kernel.mlu +++ b/mmcv/ops/csrc/common/mlu/roipoint_pool3d_large_boxes_num_mlu_kernel.mlu @@ -10,7 +10,6 @@ *************************************************************************/ #include "common_mlu_helper.hpp" -#include "roipoint_pool3d_utils.hpp" /************************************************************************* * @@ -138,7 +137,7 @@ __mlu_func__ void computeStoreRoipointPool3d(char *boxes3d, return; } int sampled_pts_num_rem = sampled_pts_num - *cnt; - int segnum = MIN(select_num, sampled_pts_num_rem) - 1; + int segnum = min((int)select_num, sampled_pts_num_rem) - 1; // copy x to pooled_features_gdram // The result of __bang_select is composed of three parts: @@ -243,7 +242,7 @@ __mlu_func__ void computeStoreLastBlockRoipointPool3d(char *boxes3d, if (select_num > 0) { int sampled_pts_num_rem = sampled_pts_num - *cnt; - int segnum = MIN(select_num, sampled_pts_num_rem) - 1; + int segnum = min((int)select_num, sampled_pts_num_rem) - 1; // copy x to pooled_features_gdram // The result of __bang_select is composed of three parts: diff --git a/mmcv/ops/csrc/common/mlu/roipoint_pool3d_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/roipoint_pool3d_mlu_kernel.mlu index 9cfecec13a..f16d84047d 100644 --- a/mmcv/ops/csrc/common/mlu/roipoint_pool3d_mlu_kernel.mlu +++ b/mmcv/ops/csrc/common/mlu/roipoint_pool3d_mlu_kernel.mlu @@ -10,7 +10,6 @@ *************************************************************************/ #include "common_mlu_helper.hpp" -#include "roipoint_pool3d_utils.hpp" /************************************************************************************** * @@ -142,7 +141,7 @@ __mlu_func__ void computeStoreRoipointPool3d(char *boxes3d, return; } int sampled_pts_num_rem = sampled_pts_num - cnt[box_idx]; - int segnum = MIN(select_num, sampled_pts_num_rem) - 1; + int segnum = min((int)select_num, sampled_pts_num_rem) - 1; // copy x to pooled_features_gdram // The result of __bang_select is composed of three parts: @@ -248,7 +247,7 @@ __mlu_func__ void computeStoreLastBlockRoipointPool3d(char *boxes3d, if (select_num > 0) { int sampled_pts_num_rem = sampled_pts_num - cnt[box_idx]; - int segnum = MIN(select_num, sampled_pts_num_rem) - 1; + int segnum = min((int)select_num, sampled_pts_num_rem) - 1; // copy x to pooled_features_gdram // The result of __bang_select is composed of three parts: diff --git a/mmcv/ops/csrc/common/mlu/roipoint_pool3d_utils.hpp b/mmcv/ops/csrc/common/mlu/roipoint_pool3d_utils.hpp deleted file mode 100644 index db618756c0..0000000000 --- a/mmcv/ops/csrc/common/mlu/roipoint_pool3d_utils.hpp +++ /dev/null @@ -1,18 +0,0 @@ -/************************************************************************* - * Copyright (C) 2022 Cambricon. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS - * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - *************************************************************************/ -#ifndef ROIPOINT_POOL3D_UTILS_HPP_ -#define ROIPOINT_POOL3D_UTILS_HPP_ - -#define MIN(a, b) (((a) < (b)) ? (a) : (b)) -#define MAX(a, b) (((a) > (b)) ? (a) : (b)) - -#endif // ROIPOINT_POOL3D_UTILS_HPP_ diff --git a/mmcv/ops/csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp index 28b3f98be0..7b49ead24a 100644 --- a/mmcv/ops/csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp @@ -11,7 +11,6 @@ *************************************************************************/ #include "pytorch_device_registry.hpp" #include "pytorch_mlu_helper.hpp" -#include "roipoint_pool3d_utils.hpp" void KernelRoiPointPool3dForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, From c1f616b19e5add59314481a053d244d43e9c9463 Mon Sep 17 00:00:00 2001 From: zhangshaopeng Date: Thu, 8 Sep 2022 14:20:28 +0800 Subject: [PATCH 5/5] [Feature] Support RoipointPool3d with cambricon MLU backend --- .../csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp | 130 ++++++++---------- 1 file changed, 61 insertions(+), 69 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp index 7b49ead24a..49dfe0ecad 100644 --- a/mmcv/ops/csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp @@ -12,35 +12,20 @@ #include "pytorch_device_registry.hpp" #include "pytorch_mlu_helper.hpp" -void KernelRoiPointPool3dForward(cnrtDim3_t k_dim, - cnrtFunctionType_t k_type, - cnrtQueue_t queue, - const cnrtDataType_t d_type, - const int batch_size, - const int pts_num, - const int boxes_num, - const int feature_in_len, - const int sampled_pts_num, - const void *xyz, - const void *boxes3d, - const void *pts_feature, - void *pooled_features, - int *pooled_empty_flag); - -void KernelRoiPointPool3dLargeBoxesNumForward(cnrtDim3_t k_dim, - cnrtFunctionType_t k_type, - cnrtQueue_t queue, - const cnrtDataType_t d_type, - const int batch_size, - const int pts_num, - const int boxes_num, - const int feature_in_len, - const int sampled_pts_num, - const void *xyz, - const void *boxes3d, - const void *pts_feature, - void *pooled_features, - int *pooled_empty_flag); +void KernelRoiPointPool3dForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, + cnrtQueue_t queue, const cnrtDataType_t d_type, + const int batch_size, const int pts_num, + const int boxes_num, const int feature_in_len, + const int sampled_pts_num, const void *xyz, + const void *boxes3d, const void *pts_feature, + void *pooled_features, int *pooled_empty_flag); + +void KernelRoiPointPool3dLargeBoxesNumForward( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + const cnrtDataType_t d_type, const int batch_size, const int pts_num, + const int boxes_num, const int feature_in_len, const int sampled_pts_num, + const void *xyz, const void *boxes3d, const void *pts_feature, + void *pooled_features, int *pooled_empty_flag); // policy function static void policyFuncForward(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) { @@ -51,35 +36,36 @@ static void policyFuncForward(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) { *k_type = CNRT_FUNC_TYPE_UNION1; } -void RoIPointPool3dForwardMLUKernelLauncher(int batch_size, - int pts_num, - int boxes_num, - int feature_in_len, - int sampled_pts_num, - const Tensor xyz, - const Tensor boxes3d, - const Tensor pts_feature, - Tensor pooled_features, - Tensor pooled_empty_flag) { +void RoIPointPool3dForwardMLUKernelLauncher( + int batch_size, int pts_num, int boxes_num, int feature_in_len, + int sampled_pts_num, const Tensor xyz, const Tensor boxes3d, + const Tensor pts_feature, Tensor pooled_features, + Tensor pooled_empty_flag) { // check datatype TORCH_CHECK(((xyz.scalar_type() == pooled_features.scalar_type()) && (boxes3d.scalar_type() == pooled_features.scalar_type()) && (pts_feature.scalar_type() == pooled_features.scalar_type())), - "data types of xyz, boxes3d, pts_feature and pooled_features should be the same, ", + "data types of xyz, boxes3d, pts_feature and pooled_features " + "should be the same, ", "but now xyz type is ", xyz.scalar_type(), ", boxes3d type is ", - boxes3d.scalar_type(), ", pts_feature type is ", pts_feature.scalar_type(), - ", pooled_features type is ", pooled_features.scalar_type(), "."); - TORCH_CHECK((xyz.scalar_type() == at::kFloat || xyz.scalar_type() == at::kHalf), - "xyz type should be Float or Half, got ", xyz.scalar_type(), "."); + boxes3d.scalar_type(), ", pts_feature type is ", + pts_feature.scalar_type(), ", pooled_features type is ", + pooled_features.scalar_type(), "."); + TORCH_CHECK( + (xyz.scalar_type() == at::kFloat || xyz.scalar_type() == at::kHalf), + "xyz type should be Float or Half, got ", xyz.scalar_type(), "."); TORCH_CHECK((pooled_empty_flag.scalar_type() == at::kInt), - "pooled_empty_flag type should be Int, got ", pooled_empty_flag.scalar_type(), "."); + "pooled_empty_flag type should be Int, got ", + pooled_empty_flag.scalar_type(), "."); // check shape - TORCH_CHECK(boxes3d.dim() == 3, "boxes3d should be a 3d tensor, got ", boxes3d.dim(), "D."); + TORCH_CHECK(boxes3d.dim() == 3, "boxes3d should be a 3d tensor, got ", + boxes3d.dim(), "D."); TORCH_CHECK(pts_feature.dim() == 3, "pts_feature should be a 3d tensor, got ", pts_feature.dim(), "D."); - TORCH_CHECK(boxes3d.size(2) == 7, "the 3rd dimensions of boxes3d should be 7, got ", + TORCH_CHECK(boxes3d.size(2) == 7, + "the 3rd dimensions of boxes3d should be 7, got ", boxes3d.size(2), "."); TORCH_CHECK((boxes3d.size(0) == batch_size), "the 1st dimensions of boxes3d should be batch_size, ", @@ -87,26 +73,29 @@ void RoIPointPool3dForwardMLUKernelLauncher(int batch_size, ", and batch_size is ", batch_size, "."); TORCH_CHECK((pts_feature.size(0) == batch_size), "the 1st dimensions of pts_feature should be batch_size, ", - "but now the 1st dimension of pts_feature is ", pts_feature.size(0), - ", and batch_size is ", batch_size, "."); + "but now the 1st dimension of pts_feature is ", + pts_feature.size(0), ", and batch_size is ", batch_size, "."); TORCH_CHECK((pts_feature.size(1) == pts_num), "the 2nd dimensions of pts_feature should be pts_num, ", - "but now the 2nd dimension of pts_feature is ", pts_feature.size(1), - ", and pts_num is ", pts_num, "."); + "but now the 2nd dimension of pts_feature is ", + pts_feature.size(1), ", and pts_num is ", pts_num, "."); // check zero element if (xyz.numel() == 0 || pts_feature.numel() == 0 || boxes3d.numel() == 0 || pooled_features.numel() == 0 || pooled_empty_flag.numel() == 0) { - return; + return; } // large tensor check const size_t max_input_size = 2147483648; - TORCH_CHECK(xyz.numel() < max_input_size, "xyz element num should be less than 2^31, got ", - xyz.numel(), "."); - TORCH_CHECK(boxes3d.numel() < max_input_size, "boxes3d element num should be less than 2^31, got ", + TORCH_CHECK(xyz.numel() < max_input_size, + "xyz element num should be less than 2^31, got ", xyz.numel(), + "."); + TORCH_CHECK(boxes3d.numel() < max_input_size, + "boxes3d element num should be less than 2^31, got ", boxes3d.numel(), "."); - TORCH_CHECK(pts_feature.numel() < max_input_size, "pts_feature element num should be less than 2^31, got ", + TORCH_CHECK(pts_feature.numel() < max_input_size, + "pts_feature element num should be less than 2^31, got ", pts_feature.numel(), "."); // calculate task dimension @@ -140,16 +129,18 @@ void RoIPointPool3dForwardMLUKernelLauncher(int batch_size, if (boxes_num <= 10240) { CNLOG(INFO) << "Launch Kernel MLUKernelRoiPointPool3dForward<<<" << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>"; - KernelRoiPointPool3dForward(k_dim, k_type, queue, data_type, batch_size, pts_num, boxes_num, - feature_in_len, sampled_pts_num, xyz_ptr, boxes3d_ptr, - pts_feature_ptr, pooled_features_ptr, (int *)pooled_empty_flag_ptr); + KernelRoiPointPool3dForward( + k_dim, k_type, queue, data_type, batch_size, pts_num, boxes_num, + feature_in_len, sampled_pts_num, xyz_ptr, boxes3d_ptr, pts_feature_ptr, + pooled_features_ptr, (int *)pooled_empty_flag_ptr); } else { - CNLOG(INFO) << "Launch Kernel MLUKernelRoiPointPool3dLargeBoxesNumForward<<<" - << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>"; - KernelRoiPointPool3dLargeBoxesNumForward(k_dim, k_type, queue, data_type, batch_size, pts_num, - boxes_num, feature_in_len, sampled_pts_num, xyz_ptr, - boxes3d_ptr, pts_feature_ptr, pooled_features_ptr, - (int *)pooled_empty_flag_ptr); + CNLOG(INFO) + << "Launch Kernel MLUKernelRoiPointPool3dLargeBoxesNumForward<<<" + << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>"; + KernelRoiPointPool3dLargeBoxesNumForward( + k_dim, k_type, queue, data_type, batch_size, pts_num, boxes_num, + feature_in_len, sampled_pts_num, xyz_ptr, boxes3d_ptr, pts_feature_ptr, + pooled_features_ptr, (int *)pooled_empty_flag_ptr); } } @@ -159,9 +150,9 @@ void roipoint_pool3d_forward_mlu(int batch_size, int pts_num, int boxes_num, const Tensor pts_feature, Tensor pooled_features, Tensor pooled_empty_flag) { - RoIPointPool3dForwardMLUKernelLauncher(batch_size, pts_num, boxes_num, feature_in_len, - sampled_pts_num, xyz, boxes3d, pts_feature, - pooled_features, pooled_empty_flag); + RoIPointPool3dForwardMLUKernelLauncher( + batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num, xyz, + boxes3d, pts_feature, pooled_features, pooled_empty_flag); } void roipoint_pool3d_forward_impl(int batch_size, int pts_num, int boxes_num, @@ -171,4 +162,5 @@ void roipoint_pool3d_forward_impl(int batch_size, int pts_num, int boxes_num, Tensor pooled_features, Tensor pooled_empty_flag); -REGISTER_DEVICE_IMPL(roipoint_pool3d_forward_impl, MLU, roipoint_pool3d_forward_mlu); +REGISTER_DEVICE_IMPL(roipoint_pool3d_forward_impl, MLU, + roipoint_pool3d_forward_mlu);