diff --git a/mmcv/ops/csrc/common/mlu/roi_align_rotated_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/roi_align_rotated_mlu_kernel.mlu deleted file mode 100644 index 9356776c58..0000000000 --- a/mmcv/ops/csrc/common/mlu/roi_align_rotated_mlu_kernel.mlu +++ /dev/null @@ -1,490 +0,0 @@ -/************************************************************************* - * Copyright (C) 2022 Cambricon. - * - * OR IMPLIED, INCLUDING BUvoid NOKType LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENvoid SHALL THE AUTHORS OR COPYRIGHKType HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORvoid OR OTHERWISE, ARISING FROM, OUKType OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - *************************************************************************/ -#include "common_mlu_helper.hpp" -#include "roi_align_rotated_utils.hpp" - -#define ROI_OFFSET 6 -#define SAMPLING_NUM 4 - -__nram__ char nram_buffer[MAX_NRAM_SIZE]; - -template -__mlu_func__ void swap(T &a, T &b) { - T tmp = a; - a = b; - b = tmp; -} - -template -__mlu_func__ void bilinearInterpolate(const int input_height, - const int input_width, T x, T y, T *w1, - T *w2, T *w3, T *w4, int *x_low, - int *x_high, int *y_low, int *y_high, - bool *empty) { - // deal with case that the point is out of feature map boundary - if (y < -1.0 || y > input_height || x < -1.0 || x > input_width) { - *empty = true; - return; - } - - if (y <= 0) y = (T)0; - if (x <= 0) x = (T)0; - - *y_low = int(y); - *x_low = int(x); - - if (*y_low >= input_height - 1) { - *y_high = *y_low = input_height - 1; - y = (T)(*y_low); - } else { - *y_high = *y_low + 1; - } - - if (*x_low >= input_width - 1) { - *x_high = *x_low = input_width - 1; - x = T(*x_low); - } else { - *x_high = *x_low + 1; - } - T ly = y - *y_low; - T lx = x - *x_low; - T hy = 1.0 - ly; - T hx = 1.0 - lx; - *w1 = hy * hx; - *w2 = hy * lx; - *w3 = ly * hx; - *w4 = ly * lx; - return; -} - -template -__mlu_func__ void getRoiBinInfo(const T *rois_dram, const int bin_i, - const RoiAlignRotatedParams ¶ms, - int *batch_idx, int *roi_n, int *pw, int *ph, - T *roi_center_x, T *roi_center_y, T *roi_width, - T *roi_height, T *theta) { - T offset = params.aligned ? (T)0.5 : (T)0.0; - *pw = bin_i % params.pooled_width; - *ph = (bin_i / params.pooled_width) % params.pooled_height; - *roi_n = bin_i / params.pooled_width / params.pooled_height; - const T *roi_info = rois_dram + (*roi_n) * ROI_OFFSET; - *batch_idx = (int)roi_info[0]; - *roi_center_x = roi_info[1] * (T)params.spatial_scale - offset; - *roi_center_y = roi_info[2] * (T)params.spatial_scale - offset; - *roi_width = roi_info[3] * (T)params.spatial_scale; - *roi_height = roi_info[4] * (T)params.spatial_scale; - *theta = roi_info[5]; - if (params.clockwise) { - *theta = -(*theta); - } - if (!params.aligned) { - *roi_width = *roi_width > (T)1.0 ? *roi_width : (T)1.0; - *roi_height = *roi_height > (T)1.0 ? *roi_height : (T)1.0; - } -} - -template -__mlu_func__ void roiAlignRotatedForward(const T *input_dram, - const T *rois_dram, const int batch, - const int height, const int width, - const int channel, const int rois_num, - const RoiAlignRotatedParams ¶ms, - T *output_dram) { - int align_base_128 = NFU_ALIGN_SIZE / sizeof(T); - int channel_max_cap = MAX_NRAM_SIZE / sizeof(T) / (2 * SAMPLING_NUM + 1); - channel_max_cap = channel_max_cap / align_base_128 * align_base_128; - int channel_align = channel < channel_max_cap ? channel : channel_max_cap; - channel_align = CEIL_ALIGN(channel_align, align_base_128); - - T *nram_out = (T *)nram_buffer; - T *nram_ping = nram_out + channel_align; - T *nram_pong = nram_ping + channel_align * SAMPLING_NUM; - - int bin_first = taskId; - int bin_end = rois_num * params.pooled_height * params.pooled_width; - - for (int bin_i = bin_first; bin_i < bin_end; bin_i += taskDim) { - T roi_center_x, roi_center_y, roi_width, roi_height, theta; - int batch_idx, roi_n, pw, ph; - getRoiBinInfo(rois_dram, bin_i, params, &batch_idx, &roi_n, &pw, &ph, - &roi_center_x, &roi_center_y, &roi_width, &roi_height, - &theta); - T bin_size_h = roi_height / params.pooled_height; - T bin_size_w = roi_width / params.pooled_width; - - int roi_bin_grid_h = - (params.sample_ratio > 0) - ? params.sample_ratio - : __float2int_up((float)roi_height / params.pooled_height); - int roi_bin_grid_w = - (params.sample_ratio > 0) - ? params.sample_ratio - : __float2int_up((float)roi_width / params.pooled_width); - T roi_start_y = -roi_height / 2; - T roi_start_x = -roi_width / 2; - const int bin_dim = roi_bin_grid_h * roi_bin_grid_w > 1 - ? roi_bin_grid_h * roi_bin_grid_w - : 1; - T cos_theta = std::cos(theta); - T sin_theta = std::sin(theta); - T zero_sign = 1.0f / bin_dim; - - bool is_first_sample = true; - int src_offset = 0; - int dst_offset = 0; - int c_rem, c_slice, c_slice_align, pongc_slice, pongc_slice_align; - for (int c_offset = 0; c_offset < channel; c_offset += channel_align) { - __bang_write_value(nram_out, channel_align, (T)0); - c_rem = channel - c_offset; - c_slice = channel_align > c_rem ? c_rem : channel_align; - c_slice_align = CEIL_ALIGN(c_slice, align_base_128); - is_first_sample = true; - for (int iy = 0; iy < roi_bin_grid_h; ++iy) { - const T yy = roi_start_y + ph * bin_size_h + - T(iy + 0.5) * bin_size_h / roi_bin_grid_h; - for (int ix = 0; ix < roi_bin_grid_w; ++ix) { - const T xx = roi_start_x + pw * bin_size_w + - T(ix + 0.5) * bin_size_w / roi_bin_grid_w; - int sample_i = iy * roi_bin_grid_w + ix; - - T y = yy * cos_theta - xx * sin_theta + roi_center_y; - T x = yy * sin_theta + xx * cos_theta + roi_center_x; - T w1, w2, w3, w4; - bool empty = false; - int x_low, x_high, y_low, y_high; - bilinearInterpolate(height, width, x, y, &w1, &w2, &w3, &w4, &x_low, - &x_high, &y_low, &y_high, &empty); - /******************************************************* - | ping | pong | - |------|-----|-----|-----|-----|-----|-----|-----|-----| - |output| p1 | p2 | p3 | p4 | p1 | p2 | p3 | p4 | - |------|-----|-----|-----|-----|-----|-----|-----|-----| - ********************************************************/ - if (is_first_sample && !empty) { - // load input data from dram to nram - __bang_write_value(nram_ping, SAMPLING_NUM * c_slice_align, (T)0); - src_offset = - (batch_idx * height * width + y_low * width + x_low) * channel + - c_offset; - dst_offset = 0; - __memcpy(nram_ping + dst_offset, input_dram + src_offset, - c_slice * sizeof(T), GDRAM2NRAM); - src_offset = (batch_idx * height * width + y_low * width + x_high) * - channel + - c_offset; - dst_offset = c_slice_align; - __memcpy(nram_ping + dst_offset, input_dram + src_offset, - c_slice * sizeof(T), GDRAM2NRAM); - src_offset = (batch_idx * height * width + y_high * width + x_low) * - channel + - c_offset; - dst_offset = c_slice_align * 2; - __memcpy(nram_ping + dst_offset, input_dram + src_offset, - c_slice * sizeof(T), GDRAM2NRAM); - src_offset = - (batch_idx * height * width + y_high * width + x_high) * - channel + - c_offset; - dst_offset = c_slice_align * 3; - __memcpy(nram_ping + dst_offset, input_dram + src_offset, - c_slice * sizeof(T), GDRAM2NRAM); - } - // load next input data to nram - if (sample_i + 1 < bin_dim) { - int p_iy = (sample_i + 1) / roi_bin_grid_w; - int p_ix = (sample_i + 1) % roi_bin_grid_w; - const T p_yy = roi_start_y + ph * bin_size_h + - T(p_iy + 0.5) * bin_size_h / roi_bin_grid_h; - const T p_xx = roi_start_x + pw * bin_size_w + - T(p_ix + 0.5) * bin_size_w / roi_bin_grid_w; - T p_y = p_yy * cos_theta - p_xx * sin_theta + roi_center_y; - T p_x = p_yy * sin_theta + p_xx * cos_theta + roi_center_x; - T p_w1, p_w2, p_w3, p_w4; - bool p_empty = false; - int p_x_low, p_x_high, p_y_low, p_y_high; - bilinearInterpolate(height, width, p_x, p_y, &p_w1, &p_w2, &p_w3, - &p_w4, &p_x_low, &p_x_high, &p_y_low, &p_y_high, - &p_empty); - pongc_slice = c_slice; - pongc_slice_align = c_slice_align; - if (!p_empty) { - __bang_write_value(nram_pong, SAMPLING_NUM * pongc_slice_align, - (T)0); - src_offset = - (batch_idx * height * width + p_y_low * width + p_x_low) * - channel + - c_offset; - dst_offset = 0; - __memcpy(nram_pong + dst_offset, input_dram + src_offset, - c_slice * sizeof(T), GDRAM2NRAM); - src_offset = - (batch_idx * height * width + p_y_low * width + p_x_high) * - channel + - c_offset; - dst_offset = pongc_slice_align; - __memcpy(nram_pong + dst_offset, input_dram + src_offset, - c_slice * sizeof(T), GDRAM2NRAM); - src_offset = - (batch_idx * height * width + p_y_high * width + p_x_low) * - channel + - c_offset; - dst_offset = pongc_slice_align * 2; - __memcpy(nram_pong + dst_offset, input_dram + src_offset, - c_slice * sizeof(T), GDRAM2NRAM); - src_offset = - (batch_idx * height * width + p_y_high * width + p_x_high) * - channel + - c_offset; - dst_offset = pongc_slice_align * 3; - __memcpy(nram_pong + dst_offset, input_dram + src_offset, - c_slice * sizeof(T), GDRAM2NRAM); - } - } - T *tmp_sum = nram_ping + 3 * c_slice_align; - if (empty) { - __bang_write_value(tmp_sum, c_slice_align, T(0)); - } else { - __bang_mul_scalar(nram_ping, nram_ping, w1, c_slice_align); - __bang_mul_scalar(nram_ping + c_slice_align, - nram_ping + c_slice_align, w2, c_slice_align); - __bang_mul_scalar(nram_ping + 2 * c_slice_align, - nram_ping + 2 * c_slice_align, w3, c_slice_align); - __bang_mul_scalar(nram_ping + 3 * c_slice_align, - nram_ping + 3 * c_slice_align, w4, c_slice_align); - __bang_sumpool(tmp_sum, nram_ping, c_slice_align, 1, SAMPLING_NUM, - 1, SAMPLING_NUM, 1, 1); - } - __bang_add(nram_out, nram_out, tmp_sum, c_slice_align); - swap(nram_ping, nram_pong); - __asm__ volatile("sync;"); - is_first_sample = false; - } - } - __bang_mul_scalar(nram_out, nram_out, zero_sign, c_slice_align); - // store the result to dram - int output_offset = - ((roi_n * params.pooled_height + ph) * params.pooled_width + pw) * - channel + - c_offset; - __memcpy(output_dram + output_offset, nram_out, c_slice * sizeof(T), - NRAM2GDRAM); - } - } -} - -template -__mlu_func__ void roiAlignRotatedBackward(const T *top_grad_dram, - const T *rois_dram, const int batch, - const int height, const int width, - const int channel, const int rois_num, - const RoiAlignRotatedParams ¶ms, - T *bottom_grad_dram) { - int align_base_128 = NFU_ALIGN_SIZE / sizeof(T); - int channel_align = CEIL_ALIGN(channel, align_base_128); - - unsigned int max_element = MAX_NRAM_SIZE / sizeof(T); - int c_limit = max_element >> 2; - c_limit = c_limit > channel_align ? channel_align : c_limit; - - T *nram_ping = (T *)nram_buffer; - T *nram_pong = nram_ping + 2 * c_limit; - T *nram_output = nullptr; - - int bin_first = taskId; - int bin_end = rois_num * params.pooled_height * params.pooled_width; - bool is_first_bin = true; - T roi_center_x, roi_center_y, roi_width, roi_height, theta; - int batch_idx, roi_n, pw, ph; - T pong_roi_center_x, pong_roi_center_y, pong_roi_width, pong_roi_height, - pong_theta; - int pong_batch_idx, pong_roi_n, pong_pw, pong_ph; - for (int bin_i = bin_first; bin_i < bin_end; bin_i += taskDim) { - getRoiBinInfo(rois_dram, bin_i, params, &batch_idx, &roi_n, &pw, &ph, - &roi_center_x, &roi_center_y, &roi_width, &roi_height, - &theta); - T bin_size_h = roi_height / params.pooled_height; - T bin_size_w = roi_width / params.pooled_width; - - int roi_bin_grid_h = - (params.sample_ratio > 0) - ? params.sample_ratio - : __float2int_up((float)roi_height / params.pooled_height); - int roi_bin_grid_w = - (params.sample_ratio > 0) - ? params.sample_ratio - : __float2int_up((float)roi_width / params.pooled_width); - T roi_start_y = -roi_height / 2; - T roi_start_x = -roi_width / 2; - const int bin_dim = roi_bin_grid_h * roi_bin_grid_w > 1 - ? roi_bin_grid_h * roi_bin_grid_w - : 1; - T cos_theta = std::cos(theta); - T sin_theta = std::sin(theta); - T zero_sign = 1.0f / bin_dim; - int c_rem, c_slice, pongc_slice, c_offset; - c_rem = channel; - c_offset = 0; - /**************************************** - | ping | pong | - |---------|---------|---------|---------| - | input | output | input | output | - |---------|---------|---------|---------| - *****************************************/ - if (is_first_bin) { - // load the first top_grad to nram - c_slice = c_limit < c_rem ? c_limit : c_rem; - int top_grad_offset = - ((roi_n * params.pooled_height + ph) * params.pooled_width + pw) * - channel; - __memcpy(nram_ping, top_grad_dram + top_grad_offset, c_slice * sizeof(T), - GDRAM2NRAM); - } - nram_output = nram_ping + c_limit; - while (c_rem > 0) { - c_slice = c_slice < c_rem ? c_slice : c_rem; - // load the next top_grad to nram - if (c_rem - c_slice > 0) { - // load the rest channels to nram - pongc_slice = (c_rem - c_slice > c_slice) ? c_slice : c_rem - c_slice; - int top_grad_offset = - ((roi_n * params.pooled_height + ph) * params.pooled_width + pw) * - channel + - c_offset + c_slice; - __memcpy_async(nram_pong, top_grad_dram + top_grad_offset, - pongc_slice * sizeof(T), GDRAM2NRAM); - } else if (bin_i + taskDim < bin_end) { - // load next bin's data to nram - getRoiBinInfo(rois_dram, bin_i + taskDim, params, &pong_batch_idx, - &pong_roi_n, &pong_pw, &pong_ph, &pong_roi_center_x, - &pong_roi_center_y, &pong_roi_width, &pong_roi_height, - &pong_theta); - pongc_slice = c_limit < channel ? c_limit : channel; - int top_grad_offset = ((pong_roi_n * params.pooled_height + pong_ph) * - params.pooled_width + - pong_pw) * - channel; - __memcpy_async(nram_pong, top_grad_dram + top_grad_offset, - c_slice * sizeof(T), GDRAM2NRAM); - } - // comput the output in a single bin - - for (int iy = 0; iy < roi_bin_grid_h; ++iy) { - const T yy = roi_start_y + ph * bin_size_h + - T(iy + 0.5) * bin_size_h / roi_bin_grid_h; - for (int ix = 0; ix < roi_bin_grid_w; ++ix) { - const T xx = roi_start_x + pw * bin_size_w + - T(ix + 0.5) * bin_size_w / roi_bin_grid_w; - T y = yy * cos_theta - xx * sin_theta + roi_center_y; - T x = yy * sin_theta + xx * cos_theta + roi_center_x; - T w1, w2, w3, w4; - bool empty = false; - int x_low, x_high, y_low, y_high; - bilinearInterpolate(height, width, x, y, &w1, &w2, &w3, &w4, &x_low, - &x_high, &y_low, &y_high, &empty); - if (empty) { - continue; - } else { - __bang_mul_scalar(nram_output, nram_ping, w1 * zero_sign, c_limit); - __bang_atomic_add( - (T *)nram_output, - bottom_grad_dram + batch_idx * height * width * channel + - y_low * width * channel + x_low * channel + c_offset, - (T *)nram_output, c_slice); - __bang_mul_scalar(nram_output, nram_ping, w2 * zero_sign, c_limit); - __bang_atomic_add( - (T *)nram_output, - bottom_grad_dram + batch_idx * height * width * channel + - y_low * width * channel + x_high * channel + c_offset, - (T *)nram_output, c_slice); - __bang_mul_scalar(nram_output, nram_ping, w3 * zero_sign, c_limit); - __bang_atomic_add( - (T *)nram_output, - bottom_grad_dram + batch_idx * height * width * channel + - y_high * width * channel + x_low * channel + c_offset, - (T *)nram_output, c_slice); - __bang_mul_scalar(nram_output, nram_ping, w4 * zero_sign, c_limit); - __bang_atomic_add( - (T *)nram_output, - bottom_grad_dram + batch_idx * height * width * channel + - y_high * width * channel + x_high * channel + c_offset, - (T *)nram_output, c_slice); - } - } - } - swap(nram_ping, nram_pong); - c_rem -= c_slice; - c_offset += c_slice; - __asm__ volatile("sync;"); - } - is_first_bin = false; - } -} - -__mlu_global__ void MLUUnion1KernelRoiAlignRotatedForward( - const void *features, const void *rois, void *output, const int batch, - const int height, const int width, const int channel, const int rois_num, - const RoiAlignRotatedParams rroiAlignParams, - const cnrtDataType_t data_type) { - if (0x80 == coreId) { - return; - } - - if (data_type == CNRT_FLOAT32) { - roiAlignRotatedForward((float *)features, (float *)rois, batch, height, - width, channel, rois_num, rroiAlignParams, - (float *)output); - } else { - roiAlignRotatedForward((half *)features, (half *)rois, batch, height, width, - channel, rois_num, rroiAlignParams, (half *)output); - } -} - -__mlu_global__ void MLUUnion1KernelRoiAlignRotatedBackward( - const void *top_grad, const void *rois, void *bottom_grad, const int batch, - const int height, const int width, const int channel, const int rois_num, - const RoiAlignRotatedParams rroiAlignParams, - const cnrtDataType_t data_type) { - if (0x80 == coreId) { - return; - } - - if (data_type == CNRT_FLOAT32) { - roiAlignRotatedBackward((float *)top_grad, (float *)rois, batch, height, - width, channel, rois_num, rroiAlignParams, - (float *)bottom_grad); - } else { - roiAlignRotatedBackward((half *)top_grad, (half *)rois, batch, height, - width, channel, rois_num, rroiAlignParams, - (half *)bottom_grad); - } -} - -void KernelRoiAlignRotatedForward( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const cnrtDataType_t d_type, const void *features, const void *rois, - void *output, const int batch, const int height, const int width, - const int channel, const int rois_num, - const RoiAlignRotatedParams roiAlignRotatedParams) { - MLUUnion1KernelRoiAlignRotatedForward<<>>( - features, rois, output, batch, height, width, channel, rois_num, - roiAlignRotatedParams, d_type); -} - -void KernelRoiAlignRotatedBackward( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const cnrtDataType_t d_type, const void *top_grad, const void *rois, - void *bottom_grad, const int batch, const int height, const int width, - const int channel, const int rois_num, - const RoiAlignRotatedParams roiAlignRotatedParams) { - MLUUnion1KernelRoiAlignRotatedBackward<<>>( - top_grad, rois, bottom_grad, batch, height, width, channel, rois_num, - roiAlignRotatedParams, d_type); -} diff --git a/mmcv/ops/csrc/common/mlu/roi_align_rotated_utils.hpp b/mmcv/ops/csrc/common/mlu/roi_align_rotated_utils.hpp deleted file mode 100644 index cd0ec02484..0000000000 --- a/mmcv/ops/csrc/common/mlu/roi_align_rotated_utils.hpp +++ /dev/null @@ -1,24 +0,0 @@ -/************************************************************************* - * Copyright (C) 2022 Cambricon. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS - * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - *************************************************************************/ -#ifndef ROI_ALIGN_ROTATED_UTILS_HPP_ -#define ROI_ALIGN_ROTATED_UTILS_HPP_ - -struct RoiAlignRotatedParams { - int pooled_height; - int pooled_width; - int sample_ratio; - float spatial_scale; - bool aligned; - bool clockwise; -}; - -#endif // ROI_ALIGN_ROTATED_UTILS_HPP_ diff --git a/mmcv/ops/csrc/pytorch/mlu/roi_align_rotated_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/roi_align_rotated_mlu.cpp old mode 100755 new mode 100644 index c3058c01f5..7cf059cd51 --- a/mmcv/ops/csrc/pytorch/mlu/roi_align_rotated_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/roi_align_rotated_mlu.cpp @@ -9,37 +9,7 @@ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ -#include "pytorch_device_registry.hpp" -#include "pytorch_mlu_helper.hpp" -#include "roi_align_rotated_utils.hpp" - -namespace { - -void policyFunc(int bin_num, cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) { - unsigned int core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); - unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); - *k_type = CNRT_FUNC_TYPE_UNION1; - k_dim->x = core_num; - unsigned int use_cluster = (bin_num + core_num - 1) / core_num; - k_dim->y = use_cluster > cluster_num ? cluster_num : use_cluster; - k_dim->z = 1; -} - -} // namespace - -void KernelRoiAlignRotatedForward( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const cnrtDataType_t d_type, const void *features, const void *rois, - void *output, const int batch, const int height, const int width, - const int channel, const int rois_num, - const RoiAlignRotatedParams roiAlignRotatedParams); - -void KernelRoiAlignRotatedBackward( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const cnrtDataType_t d_type, const void *top_grad, const void *rois, - void *bottom_grad, const int batch, const int height, const int width, - const int channel, const int rois_num, - const RoiAlignRotatedParams roiAlignRotatedParams); +#include "mlu_common_helper.h" void ROIAlignRotatedForwardMLUKernelLauncher(Tensor input, Tensor rois, Tensor output, int pooled_height, @@ -47,153 +17,70 @@ void ROIAlignRotatedForwardMLUKernelLauncher(Tensor input, Tensor rois, float spatial_scale, int sampling_ratio, bool aligned, bool clockwise) { - TORCH_CHECK(((input.scalar_type() == output.scalar_type()) && - (output.scalar_type() == rois.scalar_type())), - "data types of input, rois and output should be the same, ", - "but now input type is ", input.scalar_type(), ", rois type is ", - rois.scalar_type(), ", output type is ", output.scalar_type(), - "."); - TORCH_CHECK( - (input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf), - "input type should be Float or Half, got ", input.scalar_type(), "."); - - TORCH_CHECK(input.dim() == 4, "input should be a 4d tensor, got ", - input.dim(), "D."); - TORCH_CHECK(rois.dim() == 2, "rois should be a 2d tensor, got ", rois.dim(), - "D."); - TORCH_CHECK(output.dim() == 4, "output should be a 4d tensor, got ", - output.dim(), "D."); - - TORCH_CHECK((rois.size(0) == output.size(0)), - "the 1st dimensions of rois and output should be the same, ", - "but now the 1st dimension of rois is ", rois.size(0), - ", and output is ", output.size(0), "."); - - TORCH_CHECK((input.size(1) == output.size(1)), - "the 2nd dimensions of input and output should be the same, ", - "but now the 2nd dimension of input is ", input.size(1), - ", and output is ", output.size(1), "."); - - int channel = input.size(1); - int width = input.size(3); - int height = input.size(2); - int batch = input.size(0); - int rois_nums = rois.size(0); - cnrtDataType_t d_type = torch_mlu::toCnrtDtype(input.dtype()); - - // return if zero-elements - if (input.numel() == 0) { - CNLOG(INFO) << "Skip the zero-elements case."; - return; - } - - RoiAlignRotatedParams roiAlignRotatedParams{pooled_height, pooled_width, - sampling_ratio, spatial_scale, - aligned, clockwise}; - cnrtDim3_t k_dim; - cnrtFunctionType_t k_type; - policyFunc(rois_nums * pooled_height * pooled_width, &k_dim, &k_type); - auto memory_format = torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim()); - auto input_tensor = - torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format); - at::Tensor output_tmp = - at::empty({rois_nums, channel, pooled_height, pooled_width}, - input.options(), memory_format); + auto input_ = torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format); + auto rois_contiguous = + torch_mlu::cnnl::ops::cnnl_contiguous(rois, rois.suggest_memory_format()); + auto output_contiguous = + torch_mlu::cnnl::ops::cnnl_contiguous(output, memory_format); - // get compute queue - auto queue = torch_mlu::getCurQueue(); + MluOpTensorDescriptor input_desc, rois_desc, output_desc; + input_desc.set_with_layout(input_, MLUOP_LAYOUT_NHWC); + rois_desc.set(rois_contiguous); + output_desc.set_with_layout(output_contiguous, MLUOP_LAYOUT_NHWC); // get ptr of tensors - auto input_impl = torch_mlu::getMluTensorImpl(input_tensor); + auto input_impl = torch_mlu::getMluTensorImpl(input_); auto input_ptr = input_impl->cnnlMalloc(); - auto rois_impl = torch_mlu::getMluTensorImpl(rois); + auto rois_impl = torch_mlu::getMluTensorImpl(rois_contiguous); auto rois_ptr = rois_impl->cnnlMalloc(); - auto output_impl = torch_mlu::getMluTensorImpl(output_tmp); + auto output_impl = torch_mlu::getMluTensorImpl(output_contiguous); auto output_ptr = output_impl->cnnlMalloc(); - KernelRoiAlignRotatedForward(k_dim, k_type, queue, d_type, input_ptr, - rois_ptr, output_ptr, batch, height, width, - channel, rois_nums, roiAlignRotatedParams); - output.copy_(output_tmp); + // get compute handle + auto handle = mluOpGetCurrentHandle(); + mluOpRoiAlignRotatedForward( + handle, input_desc.desc(), input_ptr, rois_desc.desc(), rois_ptr, + pooled_height, pooled_width, sampling_ratio, spatial_scale, aligned, + clockwise, output_desc.desc(), output_ptr); + + output.copy_(output_contiguous); } void ROIAlignRotatedBackwardMLUKernelLauncher( Tensor top_grad, Tensor rois, Tensor bottom_grad, int pooled_height, int pooled_width, float spatial_scale, int sampling_ratio, bool aligned, bool clockwise) { - TORCH_CHECK(((top_grad.scalar_type() == bottom_grad.scalar_type()) && - (bottom_grad.scalar_type() == rois.scalar_type())), - "data types of top_grad, rois and bottom_grad should be ", - "the same, but now top_grad type is ", top_grad.scalar_type(), - ", rois type is ", rois.scalar_type(), ", bottom_grad type is ", - bottom_grad.scalar_type(), "."); - TORCH_CHECK((bottom_grad.scalar_type() == at::kFloat || - bottom_grad.scalar_type() == at::kHalf), - "Data type of bottom_grad should be Float ro Half, got ", - bottom_grad.scalar_type(), "."); - - TORCH_CHECK(bottom_grad.dim() == 4, "bottom_grad should be a 4d tensor, got ", - top_grad.dim(), "D."); - TORCH_CHECK(rois.dim() == 2, "rois should be a 2d tensor, got ", rois.dim(), - "D."); - TORCH_CHECK(top_grad.dim() == 4, "top_grad should be a 4d tensor, got ", - bottom_grad.dim(), "D."); - - TORCH_CHECK((rois.size(0) == top_grad.size(0)), - "the 1st dimensions of rois and top_grad should be the same, ", - "but now the 1st dimension of rois is ", rois.size(0), - ", and top_grad is ", top_grad.size(0), "."); - - TORCH_CHECK((bottom_grad.size(1) == top_grad.size(1)), - "the 2nd dimensions of bottom_grad and top_grad should be ", - "the same, but now the 2nd dimension of bottom_grad is ", - bottom_grad.size(1), ", and top_grad is ", top_grad.size(1), "."); - - int channel = bottom_grad.size(1); - int width = bottom_grad.size(3); - int height = bottom_grad.size(2); - int batch = bottom_grad.size(0); - int rois_nums = rois.size(0); - cnrtDataType_t d_type = torch_mlu::toCnrtDtype(bottom_grad.dtype()); - - // return if zero-elements - if (bottom_grad.numel() == 0) { - CNLOG(INFO) << "Skip the zero-elements case."; - return; - } - - RoiAlignRotatedParams roiAlignRotatedParams{pooled_height, pooled_width, - sampling_ratio, spatial_scale, - aligned, clockwise}; - cnrtDim3_t k_dim; - cnrtFunctionType_t k_type; - policyFunc(rois_nums * pooled_height * pooled_width, &k_dim, &k_type); - auto memory_format = torch_mlu::cnnl::ops::get_channels_last_memory_format(top_grad.dim()); - auto top_grad_tensor = + auto top_grad_ = torch_mlu::cnnl::ops::cnnl_contiguous(top_grad, memory_format); - at::Tensor bottom_grad_tmp = at::empty({batch, channel, height, width}, - top_grad.options(), memory_format) - .zero_(); - - // get compute queue - auto queue = torch_mlu::getCurQueue(); + auto rois_contiguous = + torch_mlu::cnnl::ops::cnnl_contiguous(rois, rois.suggest_memory_format()); + auto bottom_grad_ = + torch_mlu::cnnl::ops::cnnl_contiguous(bottom_grad, memory_format); // get ptr of tensors - auto bottom_grad_impl = torch_mlu::getMluTensorImpl(bottom_grad_tmp); - auto bottom_grad_ptr = bottom_grad_impl->cnnlMalloc(); - auto rois_impl = torch_mlu::getMluTensorImpl(rois); - auto rois_ptr = rois_impl->cnnlMalloc(); - auto top_grad_impl = torch_mlu::getMluTensorImpl(top_grad_tensor); + auto top_grad_impl = torch_mlu::getMluTensorImpl(top_grad_); auto top_grad_ptr = top_grad_impl->cnnlMalloc(); + auto rois_impl = torch_mlu::getMluTensorImpl(rois_contiguous); + auto rois_ptr = rois_impl->cnnlMalloc(); + auto bottom_grad_impl = torch_mlu::getMluTensorImpl(bottom_grad_); + auto bottom_grad_ptr = bottom_grad_impl->cnnlMalloc(); - KernelRoiAlignRotatedBackward(k_dim, k_type, queue, d_type, top_grad_ptr, - rois_ptr, bottom_grad_ptr, batch, height, width, - channel, rois_nums, roiAlignRotatedParams); - bottom_grad.copy_(bottom_grad_tmp); + MluOpTensorDescriptor top_grad_desc, rois_desc, bottom_grad_desc; + top_grad_desc.set_with_layout(top_grad_, MLUOP_LAYOUT_NHWC); + rois_desc.set(rois_contiguous); + bottom_grad_desc.set_with_layout(bottom_grad_, MLUOP_LAYOUT_NHWC); + + // get compute handle + auto handle = mluOpGetCurrentHandle(); + mluOpRoiAlignRotatedBackward( + handle, top_grad_desc.desc(), top_grad_ptr, rois_desc.desc(), rois_ptr, + pooled_height, pooled_width, sampling_ratio, spatial_scale, aligned, + clockwise, bottom_grad_desc.desc(), bottom_grad_ptr); + bottom_grad.copy_(bottom_grad_); } void roi_align_rotated_forward_mlu(Tensor input, Tensor rois, Tensor output, diff --git a/tests/test_ops/test_roi_align_rotated.py b/tests/test_ops/test_roi_align_rotated.py index 1ad6b6e927..0d5ca432df 100644 --- a/tests/test_ops/test_roi_align_rotated.py +++ b/tests/test_ops/test_roi_align_rotated.py @@ -11,7 +11,6 @@ except ImportError: from torch.autograd import gradcheck _USING_PARROTS = False - # yapf:disable inputs = [([[[[1., 2.], [3., 4.]]]], [[0., 0.5, 0.5, 1., 1., 0]]),