From 0a6981a9eacf9832937bd12431da116b4dc63d14 Mon Sep 17 00:00:00 2001 From: DanieeelLiu Date: Fri, 26 Jul 2024 17:48:11 +0800 Subject: [PATCH] [Feature] Support MsDeformAttnForward with fast kernel --- .../mlu/ms_deform_attn_fast_mlu_kernel.hpp | 23 + .../mlu/ms_deform_attn_fast_mlu_kernel.mlu | 724 ++++++++++++++++++ .../csrc/pytorch/mlu/ms_deform_attn_mlu.cpp | 559 ++------------ 3 files changed, 798 insertions(+), 508 deletions(-) create mode 100644 mmcv/ops/csrc/common/mlu/ms_deform_attn_fast_mlu_kernel.hpp create mode 100644 mmcv/ops/csrc/common/mlu/ms_deform_attn_fast_mlu_kernel.mlu diff --git a/mmcv/ops/csrc/common/mlu/ms_deform_attn_fast_mlu_kernel.hpp b/mmcv/ops/csrc/common/mlu/ms_deform_attn_fast_mlu_kernel.hpp new file mode 100644 index 0000000000..45da263fea --- /dev/null +++ b/mmcv/ops/csrc/common/mlu/ms_deform_attn_fast_mlu_kernel.hpp @@ -0,0 +1,23 @@ +/************************************************************************* + * Copyright (C) [2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef MS_DEFORM_ATTN_FORWARD_FAST_MLU_KERNEL_HPP_ +#define MS_DEFORM_ATTN_FORWARD_FAST_MLU_KERNEL_HPP_ +void KernelMsDeformAttnForwardFast( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + const cnrtDataType_t d_type, const char *data_value_gdram, + const char *data_spatial_shapes_gdram, + const char *data_level_start_index_gdram, + const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram, + const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_queries, + const int32_t num_points, char *data_col_gdram); +#endif // MS_DEFORM_ATTN_FORWARD_FAST_MLU_KERNEL_HPP_ diff --git a/mmcv/ops/csrc/common/mlu/ms_deform_attn_fast_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/ms_deform_attn_fast_mlu_kernel.mlu new file mode 100644 index 0000000000..6afe03ebc2 --- /dev/null +++ b/mmcv/ops/csrc/common/mlu/ms_deform_attn_fast_mlu_kernel.mlu @@ -0,0 +1,724 @@ +/************************************************************************* + * Copyright (C) [2024] by Cambricon, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include "common_mlu_helper.hpp" +#include "ms_deform_attn_fast_mlu_kernel.hpp" + +#define NRAM_REMAIN_SIZE (48 * 1024) +#define NRAM_AVALIABLE_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE) + +__nram__ char nram_buffer[NRAM_AVALIABLE_SIZE]; + +template +__mlu_func__ inline T __mluop_min(T a, T b) { + return a < b ? a : b; +} + +template +__mlu_func__ inline T __mluop_max(T a, T b) { + return a > b ? a : b; +} + +template +__mlu_func__ void __mluop_floor(T* dst_ram, T* src_ram, int size) { + if (sizeof(T) == sizeof(float)) { + int16* mid = (int16*)(dst_ram + size / 2); + __bang_float2int16_dn(mid, (float*)src_ram, size, 0); + __bang_int162float((float*)dst_ram, (int16_t*)mid, size, 0); + } else { + __bang_half2int16_dn((int16_t*)dst_ram, (half*)src_ram, size, 0); + __bang_int162half((half*)dst_ram, (int16_t*)dst_ram, size, 0); + } +} + +__mlu_func__ void broadcastSpatialHW( + float* spatial_offset_bd_nram, // (num_levels, num_points) + float* spatial_h_bd_nram, // (num_levels, num_points) + float* spatial_w_bd_nram, // (num_levels, num_points) + int32_t* spatial_shapes_nram, // (num_levels, 2) + int32_t* spatial_offset_nram, // (num_levels) + const int32_t num_levels, const int32_t num_points) { + for (int i = 0; i < num_levels * 2; i++) { + ((float*)spatial_shapes_nram)[i] = (float)spatial_shapes_nram[i]; + } + + for (int i = 0; i < num_levels; i++) { + ((float*)spatial_offset_nram)[i] = (float)spatial_offset_nram[i]; + } + + for (int i = 0; i < num_levels; i++) { + __memcpy(spatial_h_bd_nram + i * num_points, spatial_shapes_nram + i * 2, + sizeof(float), NRAM2NRAM, sizeof(float), 0, num_points - 1); + } + + for (int i = 0; i < num_levels; i++) { + __memcpy(spatial_w_bd_nram + i * num_points, + spatial_shapes_nram + 1 + i * 2, sizeof(float), NRAM2NRAM, + sizeof(float), 0, num_points - 1); + } + + for (int i = 0; i < num_levels; i++) { + __memcpy(spatial_offset_bd_nram + i * num_points, spatial_offset_nram + i, + sizeof(float), NRAM2NRAM, sizeof(float), 0, num_points - 1); + } +} + +template +__mlu_func__ void getConditionCoordWeight( + int32_t* data_offset_nram, T* weight_polation_nram, + T* cond_point_polation_nram, T* cond_point_valid_nram, T* loc_nram, + T* weight_attn_nram, T* spatial_offset_bd_nram, T* spatial_w_bd_nram, + T* spatial_h_bd_nram, T* buf_nram, const int32_t deal_n, + const int32_t num_levels, const int32_t num_points, const int32_t num_heads, + int32_t pad_num_levels_points) { + int32_t pad_total_points = deal_n * pad_num_levels_points; + int32_t pad_block_points = pad_num_levels_points; + T* buf_x_nram = buf_nram; + T* buf_y_nram = buf_nram + pad_total_points; + T* buf_cond_nram = buf_nram + 2 * pad_total_points; + T* buf_x_floor = buf_nram + 2 * pad_total_points; + T* buf_y_floor = buf_nram + 3 * pad_total_points; + T* buf_x_ceil = buf_nram + 4 * pad_total_points; + T* buf_y_ceil = buf_nram + 5 * pad_total_points; + __sync_io_move_compute(); + __bang_write_value(buf_x_nram, pad_total_points, 0); + __bang_write_value(buf_y_nram, pad_total_points, 0); + __bang_write_value(buf_x_floor, pad_total_points, 0); + __bang_write_value(buf_x_ceil, pad_total_points, 0); + __bang_write_value(buf_y_floor, pad_total_points, 0); + __bang_write_value(buf_y_ceil, pad_total_points, 0); + + //================================================================================================ + __memcpy(buf_x_nram, loc_nram, sizeof(T), NRAM2NRAM, sizeof(T), 2 * sizeof(T), + pad_total_points - 1); + __memcpy(buf_y_nram, loc_nram + 1, sizeof(T), NRAM2NRAM, sizeof(T), + 2 * sizeof(T), pad_total_points - 1); + + // x = loc_x * spatial_w - 0.5; y = loc_y * spatial_h - 0.5; + __bang_cycle_mul(buf_x_nram, buf_x_nram, spatial_w_bd_nram, pad_total_points, + pad_block_points); + __bang_sub_scalar(buf_x_nram, buf_x_nram, (T)0.5, pad_total_points); + __bang_cycle_mul(buf_y_nram, buf_y_nram, spatial_h_bd_nram, pad_total_points, + pad_block_points); + __bang_sub_scalar(buf_y_nram, buf_y_nram, (T)0.5, pad_total_points); + + //================================================================================================ + // get point condition. use buf0, buf1, buf2 + // (x > -1 && y > -1 && y < spatial_h && x < spatial_w) + __bang_write_value(cond_point_valid_nram, pad_total_points, (T)-1.0); + __bang_gt(cond_point_valid_nram, buf_x_nram, cond_point_valid_nram, + pad_total_points); + __bang_write_value(buf_cond_nram, pad_total_points, (T)-1.0); + __bang_gt(buf_cond_nram, buf_y_nram, buf_cond_nram, pad_total_points); + + __bang_and(cond_point_valid_nram, cond_point_valid_nram, buf_cond_nram, + pad_total_points); + __bang_cycle_lt(buf_cond_nram, buf_x_nram, spatial_w_bd_nram, + pad_total_points, pad_block_points); + __bang_and(cond_point_valid_nram, cond_point_valid_nram, buf_cond_nram, + pad_total_points); + __bang_cycle_lt(buf_cond_nram, buf_y_nram, spatial_h_bd_nram, + pad_total_points, pad_block_points); + __bang_and(cond_point_valid_nram, cond_point_valid_nram, buf_cond_nram, + pad_total_points); + //================================================================================================ + __mluop_floor(buf_x_floor, buf_x_nram, 2 * pad_total_points); + __bang_add_scalar(buf_x_ceil, buf_x_floor, 1.0, pad_total_points); + __bang_add_scalar(buf_y_ceil, buf_y_floor, 1.0, pad_total_points); + + T* cond_point_polation_nram_tl = cond_point_polation_nram; + T* cond_point_polation_nram_bl = cond_point_polation_nram + pad_total_points; + T* cond_point_polation_nram_tr = + cond_point_polation_nram + 2 * pad_total_points; + T* cond_point_polation_nram_br = + cond_point_polation_nram + 3 * pad_total_points; + T* cond_point_polation_nram_cond1 = weight_polation_nram; + T* cond_point_polation_nram_cond2 = weight_polation_nram + pad_total_points; + T* cond_point_polation_nram_cond3 = + weight_polation_nram + 2 * pad_total_points; + T* cond_point_polation_nram_cond4 = + weight_polation_nram + 3 * pad_total_points; + __bang_ge_scalar(cond_point_polation_nram_cond1, buf_x_floor, (T)0, + pad_total_points); + __bang_cycle_lt(cond_point_polation_nram_cond2, buf_x_ceil, spatial_w_bd_nram, + pad_total_points, pad_block_points); + __bang_ge_scalar(cond_point_polation_nram_cond3, buf_y_floor, (T)0, + pad_total_points); + __bang_cycle_lt(cond_point_polation_nram_cond4, buf_y_ceil, spatial_h_bd_nram, + pad_total_points, pad_block_points); + __bang_and(cond_point_polation_nram_tl, cond_point_polation_nram_cond1, + cond_point_polation_nram_cond4, pad_total_points); + __bang_and(cond_point_polation_nram_bl, cond_point_polation_nram_cond1, + cond_point_polation_nram_cond3, pad_total_points); + __bang_and(cond_point_polation_nram_tr, cond_point_polation_nram_cond2, + cond_point_polation_nram_cond4, pad_total_points); + __bang_and(cond_point_polation_nram_br, cond_point_polation_nram_cond2, + cond_point_polation_nram_cond3, pad_total_points); + //================================================================================================ + // get polation weight. + T* buf_dx = (T*)data_offset_nram; + T* buf_dy = buf_dx + pad_total_points; + T* buf_dx_1 = buf_dy + pad_total_points; + T* buf_dy_1 = buf_dx_1 + pad_total_points; + // -dx = x_floor-x + // -dy = y_floor-y + // w1 = (1-dx)*dy = (dx-1)*(-dy) + // w2 = (1-dx)*(1-dy) = (dx-1)*(dy-1) + // w3 = dx*dy = (-dx)*(-dy) + // w4 = dx*(1-dy) = (-dx)*(dy-1) + T* weight_polation_nram_1 = weight_polation_nram; + T* weight_polation_nram_2 = weight_polation_nram + 1 * pad_total_points; + T* weight_polation_nram_3 = weight_polation_nram + 2 * pad_total_points; + T* weight_polation_nram_4 = weight_polation_nram + 3 * pad_total_points; + // T* weight_polation_nram_buf = buf_nram + 4 * total_points; + __bang_sub(buf_dx, buf_x_floor, buf_x_nram, pad_total_points); + __bang_sub(buf_dy, buf_y_floor, buf_y_nram, pad_total_points); + + __bang_sub(buf_dx_1, buf_x_nram, buf_x_floor, pad_total_points); + __bang_sub_scalar(buf_dx_1, buf_dx_1, (T)1.0, pad_total_points); + + __bang_sub(buf_dy_1, buf_y_nram, buf_y_floor, pad_total_points); + __bang_sub_scalar(buf_dy_1, buf_dy_1, (T)1.0, pad_total_points); + + __bang_mul(weight_polation_nram_1, buf_dx_1, buf_dy, pad_total_points); + __bang_mul(weight_polation_nram_2, buf_dx_1, buf_dy_1, pad_total_points); + __bang_mul(weight_polation_nram_3, buf_dx, buf_dy, pad_total_points); + __bang_mul(weight_polation_nram_4, buf_dx, buf_dy_1, pad_total_points); + //================================================================================================ + // correct the x,y in [0, w-1] and [0, h-1] + T* spatial_w1_bd_nram = buf_nram; + T* spatial_h1_bd_nram = buf_nram + pad_total_points; + __bang_sub_scalar(spatial_w1_bd_nram, spatial_w_bd_nram, (T)1, + pad_total_points); + __bang_sub_scalar(spatial_h1_bd_nram, spatial_h_bd_nram, (T)1, + pad_total_points); + T* maxtemp = (T*)data_offset_nram; + __bang_write_value(maxtemp, pad_total_points, (T)0); + __bang_maxequal(buf_x_floor, buf_x_floor, maxtemp, pad_total_points); + __bang_maxequal(buf_x_ceil, buf_x_ceil, maxtemp, pad_total_points); + __bang_cycle_minequal(buf_x_floor, buf_x_floor, spatial_w1_bd_nram, + pad_total_points, pad_block_points); + __bang_cycle_minequal(buf_x_ceil, buf_x_ceil, spatial_w1_bd_nram, + pad_total_points, pad_block_points); + __bang_maxequal(buf_y_floor, buf_y_floor, maxtemp, pad_total_points); + __bang_maxequal(buf_y_ceil, buf_y_ceil, maxtemp, pad_total_points); + __bang_cycle_minequal(buf_y_floor, buf_y_floor, spatial_h1_bd_nram, + pad_total_points, pad_block_points); + __bang_cycle_minequal(buf_y_ceil, buf_y_ceil, spatial_h1_bd_nram, + pad_total_points, pad_block_points); + //================================================================================================ + // offset = y*w + x + T* buf_hw_offset = buf_nram; + T* data_offset_nram_tl = (T*)data_offset_nram; + T* data_offset_nram_bl = data_offset_nram_tl + pad_total_points; + T* data_offset_nram_tr = data_offset_nram_bl + pad_total_points; + T* data_offset_nram_br = data_offset_nram_tr + pad_total_points; + // y_ceil*w + offset + x_floor + __bang_cycle_mul(buf_hw_offset, buf_y_ceil, spatial_w_bd_nram, + pad_total_points, pad_block_points); + __bang_cycle_add(buf_hw_offset, buf_hw_offset, spatial_offset_bd_nram, + pad_total_points, pad_block_points); + __bang_add(data_offset_nram_tl, buf_hw_offset, buf_x_floor, pad_total_points); + // y_ceil*w + offset + x_ceil + __bang_add(data_offset_nram_tr, buf_hw_offset, buf_x_ceil, pad_total_points); + // y_floor*w + offset + x_foor + __bang_cycle_mul(buf_hw_offset, buf_y_floor, spatial_w_bd_nram, + pad_total_points, pad_block_points); + __bang_cycle_add(buf_hw_offset, buf_hw_offset, spatial_offset_bd_nram, + pad_total_points, pad_block_points); + + __bang_add(data_offset_nram_bl, buf_hw_offset, buf_x_floor, pad_total_points); + // y_floor*w + offset + x_ceil + __bang_add(data_offset_nram_br, buf_hw_offset, buf_x_ceil, pad_total_points); + __bang_cycle_and(cond_point_polation_nram, cond_point_polation_nram, + cond_point_valid_nram, 4 * pad_total_points, + pad_total_points); + __bang_cycle_mul(weight_polation_nram, weight_polation_nram, weight_attn_nram, + 4 * pad_total_points, pad_total_points); + __bang_mul(weight_polation_nram, weight_polation_nram, + cond_point_polation_nram, pad_total_points * 4); + __bang_sub((float*)data_offset_nram_bl, (float*)data_offset_nram_bl, + (float*)data_offset_nram_tl, pad_total_points); + __bang_sub((float*)data_offset_nram_tr, (float*)data_offset_nram_tr, + (float*)data_offset_nram_tl, pad_total_points); +} + +/* + shape of each tensor: + output_nram: (channels) + input_nram: (4, valid_num, channels) + input_trans: (channels, 4, valid_num) + weight_selected_base: (4, deal_n, num_levels, num_points) + weight_compute: (4, valid_num) +*/ +template +__mlu_func__ void reduceLevelByConv(T* output_nram, T* input_nram, + T* input_trans, T* weight_selected_base, + T* weight_compute, + const int32_t pad_num_levels_points, + const int32_t pad_channels, + const int32_t pad_sample_stride_3) { + int32_t ci = 4 * pad_num_levels_points; + int32_t co = pad_channels; + __bang_write_value(weight_compute, 4 * pad_num_levels_points, 0); + __memcpy(weight_compute, weight_selected_base, + pad_num_levels_points * sizeof(T), NRAM2NRAM, + pad_num_levels_points * sizeof(T), pad_sample_stride_3 * sizeof(T), + 3); + __bang_transpose(input_trans, input_nram, ci, co); + __bang_cycle_mul(input_trans, input_trans, weight_compute, co * ci, ci); + __bang_sumpool(input_nram, input_trans, pad_num_levels_points, pad_channels, + 4, 1, 4, 1, 1); + __bang_transpose(input_trans, input_nram, pad_channels, + pad_num_levels_points); + __bang_sumpool(output_nram, input_trans, pad_channels, pad_num_levels_points, + 1, pad_num_levels_points, 1, 1, 1); +} + +__mlu_func__ void loadNram2Gpr(int32_t& v1, int32_t& v2, int32_t& v3, + int32_t* p1, int32_t* p2, int32_t* p3, + int32_t num_heads, int32_t channels_size) { + int32_t stride = num_heads * channels_size; + v1 = (int32_t)(*(float*)p1) * stride; + v2 = (int32_t)(*(float*)p2) * stride; + v3 = (int32_t)(*(float*)p3) * stride; +} + +/* + Load 4 neighbors use one 3D-memcpy, just use offset of N1, stride_3_1 + and + stride_2_1. + |<- stride_3_1 ->| + N1 N3 + ^ + | + stride_2_1 + | + v + N2 N4 + + Trickly fold the loop as 2. +*/ +template +__mlu_func__ void loadDataValueXram2NramAsync( + T* buf_value_nram_1, int32_t* offset_1, int32_t* stride_2_1, + int32_t* stride_3_1, T* value_src, const int32_t num_levels_points, + const int32_t channel_size, const int32_t num_heads) { + int32_t offset_1_a, stride_2_1_a, stride_3_1_a; + int32_t offset_1_b, stride_2_1_b, stride_3_1_b; + loadNram2Gpr(offset_1_a, stride_2_1_a, stride_3_1_a, offset_1, stride_2_1, + stride_3_1, num_heads, channel_size); + loadNram2Gpr(offset_1_b, stride_2_1_b, stride_3_1_b, offset_1 + 1, + stride_2_1 + 1, stride_3_1 + 1, num_heads, channel_size); + + int32_t value_offset = 0; + int32_t next = 0; + int32_t loop_num = num_levels_points / 2; + int32_t remain = num_levels_points % 2; + + int32_t pad_num_levels_points = + PAD_UP(num_levels_points, NFU_ALIGN_SIZE / sizeof(T)); + int32_t pad_channels = + PAD_UP(channel_size / sizeof(T), NFU_ALIGN_SIZE / sizeof(T)); + int32_t pad_channels_size = pad_channels * sizeof(T); + int32_t pad_data_value_stride = pad_num_levels_points * pad_channels_size; + for (int32_t j = 0; j < loop_num * 2; j += 2) { + value_offset = j * pad_channels_size; + next = j + 2; + for (int i = 0; i < 2; i++) { + __memcpy_async( + (int8_t*)buf_value_nram_1 + value_offset + pad_data_value_stride * i, + (int8_t*)value_src + offset_1_a + i * stride_2_1_a, channel_size, + GDRAM2NRAM, 2 * pad_data_value_stride, stride_3_1_a, 1); + } + + loadNram2Gpr(offset_1_a, stride_2_1_a, stride_3_1_a, offset_1 + next, + stride_2_1 + next, stride_3_1 + next, num_heads, channel_size); + + for (int i = 0; i < 2; i++) { + __memcpy_async((int8_t*)buf_value_nram_1 + value_offset + + pad_channels_size + pad_data_value_stride * i, + (int8_t*)value_src + offset_1_b + i * stride_2_1_b, + channel_size, GDRAM2NRAM, 2 * pad_data_value_stride, + stride_3_1_b, 1); + } + + loadNram2Gpr(offset_1_b, stride_2_1_b, stride_3_1_b, offset_1 + next + 1, + stride_2_1 + next + 1, stride_3_1 + next + 1, num_heads, + channel_size); + } + + if (remain > 0) { + value_offset = loop_num * 2 * pad_channels_size; + for (int i = 0; i < 2; i++) { + __memcpy_async( + (int8_t*)buf_value_nram_1 + value_offset + pad_data_value_stride * i, + (int8_t*)value_src + offset_1_a + i * stride_2_1_a, channel_size, + GDRAM2NRAM, 2 * pad_data_value_stride, stride_3_1_a, 1); + } + } +} + +template +__mlu_func__ void loadNeighborPolationAttn( + T* value_output_nram, T* value_gdram, int32_t* data_offset_nram, + T* weight_polation_nram, T* cond_point_polation_nram, + T* cond_point_valid_nram, T* weight_attn_nram, T* buf_nram, + T* compute_buf_nram, const int32_t deal_n, const int32_t num_levels, + const int32_t num_points, const int32_t num_keys, const int32_t channels, + const int32_t num_heads, const int32_t pad_channels, + const int32_t pad_num_levels_points) { + int32_t channel_size = channels * sizeof(T); + int32_t pad_sample_stride_3 = deal_n * pad_num_levels_points; + + T* buf_value_nram = buf_nram; // (4, num_levels, num_points, channels) + T* buf_value_nram_trans = buf_nram + 4 * pad_num_levels_points * pad_channels; + T* weight_compute_nram = compute_buf_nram; // (4, num_levels, num_points) + + int32_t* offset = data_offset_nram; + int32_t* stride_2_1 = offset + pad_sample_stride_3; + int32_t* stride_3_1 = stride_2_1 + pad_sample_stride_3; + T* output_nram = value_output_nram; + int32_t step_offset = 0; + int32_t num_levels_points = num_levels * num_points; + for (int32_t i = 0; i < deal_n; i++) { + __bang_write_value(buf_value_nram, 4 * pad_num_levels_points * pad_channels, + 0); + __sync_compute(); + loadDataValueXram2NramAsync(buf_value_nram, offset, stride_2_1, + stride_3_1, value_gdram, num_levels_points, + channel_size, num_heads); + __sync_io(); + reduceLevelByConv(output_nram, buf_value_nram, buf_value_nram_trans, + weight_polation_nram + step_offset, weight_compute_nram, + pad_num_levels_points, pad_channels, pad_sample_stride_3); + step_offset += pad_num_levels_points; + offset = data_offset_nram + step_offset; + stride_2_1 = offset + pad_sample_stride_3; + stride_3_1 = stride_2_1 + pad_sample_stride_3; + output_nram += pad_channels; + } +} + +template +__mlu_func__ void prepareLoop( + int32_t* spatial_offset_nram, int32_t* spatial_hw_nram, + T* spatial_offset_bd_nram, T* spatial_h_bd_nram, T* spatial_w_bd_nram, + const char* data_level_start_index_gdram, + const char* data_spatial_shapes_gdram, const int32_t num_keys, + const int32_t num_levels, const int32_t num_points, + const int32_t max_deal_n, const int32_t channels) { + __memcpy(spatial_offset_nram, data_level_start_index_gdram, + num_levels * sizeof(int32_t), GDRAM2NRAM); + __memcpy(spatial_hw_nram, data_spatial_shapes_gdram, + num_levels * 2 * sizeof(int32_t), GDRAM2NRAM); + broadcastSpatialHW(spatial_offset_bd_nram, spatial_h_bd_nram, + spatial_w_bd_nram, spatial_hw_nram, spatial_offset_nram, + num_levels, num_points); +} + +/* + The shape of each tensor: + buf_compute_nram: (8, num_levels, num_points) + spatial_offset_nram: (num_levels) + spatial_hw_nram: (num_levels, 2) + spatial_offset_bd_nram: (num_levels, num_points) + spatial_w_bd_nram: (num_levels, num_points) + spatial_h_bd_nram: (num_levels, num_points) + value_output_nram: (deal_n, channels) + data_offset_nram: (4, deal_n, num_levels, num_points) + weight_polation_nram: (4, deal_n, num_levels, num_points) + cond_point_polation_nram: (4, deal_n, num_levels, num_points) + cond_point_valid_nram: (deal_n, num_levels, num_points) + loc_nram: (deal_n, num_levels, num_points, 2) + weight_attn_nram: (deal_n, num_levels, num_points) + buf_nram: (6, deal_n, num_levels, num_points) + + Note: buf_nram is reused in polation computing. +*/ +template +__mlu_func__ void memPolicyCommon( + T*& buf_compute_nram, T*& value_output_nram, int32_t*& data_offset_nram, + T*& weight_polation_nram, T*& cond_point_polation_nram, + T*& cond_point_valid_nram, T*& loc_nram, T*& weight_attn_nram, T*& buf_nram, + T*& buf_nram_end, T*& spatial_offset_bd_nram, T*& spatial_w_bd_nram, + T*& spatial_h_bd_nram, int32_t*& spatial_offset_nram, + int32_t*& spatial_hw_nram, int32_t& max_deal_n, int32_t& pad_channels, + int32_t& pad_num_levels_points, int32_t& pad_total_points, + const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_queries, + const int32_t num_points) { + pad_channels = PAD_UP(channels, NFU_ALIGN_SIZE / sizeof(T)); + int32_t num_levels_points = num_levels * num_points; + pad_num_levels_points = PAD_UP(num_levels_points, NFU_ALIGN_SIZE / sizeof(T)); + int32_t pad_num_levels_points_8 = 8 * pad_num_levels_points; + int32_t spatial_info_size = + PAD_UP(3 * num_levels * sizeof(int32_t), NFU_ALIGN_SIZE); + int32_t fix_space_size = + spatial_info_size + + (3 * pad_num_levels_points + pad_num_levels_points) * sizeof(T); + int32_t left_space_size = NRAM_AVALIABLE_SIZE - fix_space_size; + int32_t common_buffer_size_each = 6 * pad_num_levels_points * sizeof(T); + int32_t inter_result_size_each = + 17 * pad_num_levels_points * sizeof(T) + pad_channels * sizeof(T); + + max_deal_n = + left_space_size / (common_buffer_size_each + inter_result_size_each); + + int32_t compute_buffer_size = + (9 * pad_num_levels_points * pad_channels) * sizeof(T); + int32_t common_buffer_size = max_deal_n * common_buffer_size_each; + // make sure buf_nram is large enough for compute + if (compute_buffer_size > common_buffer_size) { + int32_t tmp_deal_n = + (left_space_size - compute_buffer_size) / inter_result_size_each; + max_deal_n = __mluop_min(max_deal_n, tmp_deal_n); + } + + pad_total_points = max_deal_n * pad_num_levels_points; + buf_compute_nram = (T*)nram_buffer; + spatial_offset_nram = (int32_t*)(buf_compute_nram + pad_num_levels_points_8); + int32_t pad_3_levels = PAD_UP(3 * num_levels, NFU_ALIGN_SIZE / sizeof(T)); + spatial_hw_nram = spatial_offset_nram + num_levels; + spatial_offset_bd_nram = (T*)(spatial_offset_nram + pad_3_levels); + spatial_w_bd_nram = spatial_offset_bd_nram + pad_num_levels_points; + spatial_h_bd_nram = spatial_w_bd_nram + pad_num_levels_points; + value_output_nram = spatial_h_bd_nram + pad_num_levels_points; + data_offset_nram = (int32_t*)(value_output_nram + max_deal_n * pad_channels); + weight_polation_nram = (T*)(data_offset_nram + 4 * pad_total_points); + cond_point_polation_nram = weight_polation_nram + 4 * pad_total_points; + cond_point_valid_nram = cond_point_polation_nram + 4 * pad_total_points; + loc_nram = cond_point_valid_nram + pad_total_points; + weight_attn_nram = + loc_nram + + 2 * pad_total_points; // total_coord_pad = 2 * pad_total_points + buf_nram = weight_attn_nram + pad_total_points; + buf_nram_end = buf_nram + 6 * max_deal_n * pad_num_levels_points; +} + +template +__mlu_func__ void MLUKernelMsDeformAttnForwardFastImpl( + const char* data_value_gdram, const char* data_spatial_shapes_gdram, + const char* data_level_start_index_gdram, + const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram, + const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_queries, + const int32_t num_points, char* data_col_gdram) { + int32_t input_stride_4 = num_queries * num_heads * num_levels * num_points; + int32_t input_stride_3 = num_heads * num_levels * num_points; + int32_t input_stride_2 = num_levels * num_points; + int32_t output_stride_3 = num_queries * num_heads * channels; + int32_t output_stride_2 = num_heads * channels; + int32_t data_value_stride_3 = num_keys * num_heads * channels; + + T* value_output_nram = nullptr; // (deal_n, channels) + int32_t* data_offset_nram = nullptr; // (4, deal_n, num_levels, num_points) + T* weight_polation_nram = nullptr; // (4, deal_n, num_levels, num_points) + T* cond_point_polation_nram = nullptr; // (4, deal_n, num_levels, num_points) + T* cond_point_valid_nram = nullptr; // (deal_n, num_levels, num_points) + T* loc_nram = nullptr; // (deal_n, num_levels, num_points, 2) + T* weight_attn_nram = nullptr; // (deal_n, num_levels, num_points) + T* buf_nram = nullptr; // (6, deal_n, num_levels, num_points) + T* buf_nram_end = nullptr; + T* spatial_offset_bd_nram = nullptr; // (num_levels, num_points) + T* spatial_w_bd_nram = nullptr; // (num_levels, num_points) + T* spatial_h_bd_nram = nullptr; // (num_levels, num_points) + int32_t* spatial_offset_nram = nullptr; // (num_levels) + int32_t* spatial_hw_nram = nullptr; // (num_levels, 2) + T* buf_compute_nram = nullptr; // (8, num_levels, num_points) + int32_t max_deal_n = 0; + int32_t pad_channels = 0; + int32_t pad_num_levels_points = 0; + int32_t pad_total_points = 0; + + memPolicyCommon(buf_compute_nram, value_output_nram, data_offset_nram, + weight_polation_nram, cond_point_polation_nram, + cond_point_valid_nram, loc_nram, weight_attn_nram, buf_nram, + buf_nram_end, spatial_offset_bd_nram, spatial_w_bd_nram, + spatial_h_bd_nram, spatial_offset_nram, spatial_hw_nram, + max_deal_n, pad_channels, pad_num_levels_points, + pad_total_points, batch_size, num_keys, num_heads, channels, + num_levels, num_queries, num_points); + if (max_deal_n <= 0) { + return; + } + + // split batch*head into taskDimY + int32_t batch_head = batch_size * num_heads; + int32_t cluster_avg_batch_head = (batch_head + taskDimY - 1) / taskDimY; + int32_t cluster_begin_batch_head = taskIdY * cluster_avg_batch_head; + int32_t cluster_act_batch_head = __mluop_min( + cluster_avg_batch_head, batch_head - cluster_begin_batch_head); + int32_t cluster_end_batch_head = + cluster_begin_batch_head + cluster_act_batch_head; + // split query into coreDim + int32_t core_avg_query = (num_queries + coreDim - 1) / coreDim; + int32_t core_begin_query = coreId * core_avg_query; + int32_t core_act_query = + __mluop_min(num_queries - core_begin_query, core_avg_query); + int32_t core_loop_num = (core_act_query + max_deal_n - 1) / max_deal_n; + int32_t core_step_query = + core_loop_num > 0 ? (core_act_query + core_loop_num - 1) / core_loop_num + : 0; + int32_t core_remain_query = + core_act_query - (core_loop_num - 1) * core_step_query; + int32_t first_deal_query = + (int)(core_loop_num > 0) * + (core_loop_num > 1 ? core_step_query : core_remain_query); + + prepareLoop(spatial_offset_nram, spatial_hw_nram, spatial_offset_bd_nram, + spatial_h_bd_nram, spatial_w_bd_nram, + data_level_start_index_gdram, data_spatial_shapes_gdram, num_keys, + num_levels, num_points, max_deal_n, channels); + + for (int32_t bh_idx = cluster_begin_batch_head; + bh_idx < cluster_end_batch_head; bh_idx++) { + int32_t b = bh_idx / num_heads; + int32_t head_idx = bh_idx % num_heads; + + size_t output_base_offset = + (size_t)b * output_stride_3 + head_idx * channels; + int32_t attn_weight_base_offset = + b * input_stride_4 + head_idx * input_stride_2; + + __sync_cluster(); + + if (__is_ipu()) { + // compute weight, offset and condition + int32_t attn_weight_offset = + attn_weight_base_offset + core_begin_query * input_stride_3; + int32_t loc_offset = attn_weight_offset * 2; + if (first_deal_query > 0) { + __bang_write_value(loc_nram, 2 * pad_total_points, 0); + __bang_write_value(weight_attn_nram, pad_total_points, 0); + __sync_compute(); + __memcpy_async(loc_nram, (T*)data_sampling_loc_gdram + loc_offset, + input_stride_2 * 2 * sizeof(T), GDRAM2NRAM, + pad_num_levels_points * 2 * sizeof(T), + input_stride_3 * 2 * sizeof(T), first_deal_query - 1); + __memcpy_async(weight_attn_nram, + (T*)data_attn_weight_gdram + attn_weight_offset, + input_stride_2 * sizeof(T), GDRAM2NRAM, + pad_num_levels_points * sizeof(T), + input_stride_3 * sizeof(T), first_deal_query - 1); + getConditionCoordWeight( + data_offset_nram, weight_polation_nram, cond_point_polation_nram, + cond_point_valid_nram, loc_nram, weight_attn_nram, + spatial_offset_bd_nram, spatial_w_bd_nram, spatial_h_bd_nram, + buf_nram, first_deal_query, num_levels, num_points, num_heads, + pad_num_levels_points); + } + } + + for (int32_t i = 0; __is_ipu() && i < core_loop_num; i++) { + __bang_write_value(loc_nram, 2 * pad_total_points, 0); + __bang_write_value(weight_attn_nram, pad_total_points, 0); + int32_t deal_n = + i < core_loop_num - 1 ? core_step_query : core_remain_query; + int32_t load_n = + i < core_loop_num - 2 ? core_step_query : core_remain_query; + // load value and polation + loadNeighborPolationAttn( + value_output_nram, + (T*)data_value_gdram + b * data_value_stride_3 + head_idx * channels, + data_offset_nram, weight_polation_nram, cond_point_polation_nram, + cond_point_valid_nram, weight_attn_nram, buf_nram, buf_compute_nram, + deal_n, num_levels, num_points, num_keys, channels, num_heads, + pad_channels, pad_num_levels_points); + __sync_io_move_compute(); + // load next weight and loc + if (i < core_loop_num - 1) { + int32_t core_query_offset = (i + 1) * core_step_query; + int32_t attn_weight_offset = + attn_weight_base_offset + + (core_begin_query + core_query_offset) * input_stride_3; + int32_t loc_offset = attn_weight_offset * 2; + __memcpy_async(loc_nram, (T*)data_sampling_loc_gdram + loc_offset, + input_stride_2 * 2 * sizeof(T), GDRAM2NRAM, + pad_num_levels_points * 2 * sizeof(T), + input_stride_3 * 2 * sizeof(T), load_n - 1); + __memcpy_async(weight_attn_nram, + (T*)data_attn_weight_gdram + attn_weight_offset, + input_stride_2 * sizeof(T), GDRAM2NRAM, + pad_num_levels_points * sizeof(T), + input_stride_3 * sizeof(T), load_n - 1); + __sync_io_move_compute(); + } + // store result + size_t output_offset = + ((size_t)core_begin_query + i * core_step_query) * output_stride_2; + __memcpy_async((T*)data_col_gdram + output_base_offset + output_offset, + value_output_nram, channels * sizeof(T), NRAM2GDRAM, + output_stride_2 * sizeof(T), pad_channels * sizeof(T), + deal_n - 1); + + // compute cond/weight/offset + if (i < core_loop_num - 1) { + getConditionCoordWeight( + data_offset_nram, weight_polation_nram, cond_point_polation_nram, + cond_point_valid_nram, loc_nram, weight_attn_nram, + spatial_offset_bd_nram, spatial_w_bd_nram, spatial_h_bd_nram, + buf_nram, load_n, num_levels, num_points, num_heads, + pad_num_levels_points); + } + __sync_io_move_compute(); + } + __sync_cluster(); + } +} + +template +__mlu_global__ void MLUKernelMsDeformAttnForwardFast( + const char* data_value_gdram, const char* data_spatial_shapes_gdram, + const char* data_level_start_index_gdram, + const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram, + const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_queries, + const int32_t num_points, char* data_col_gdram) { + MLUKernelMsDeformAttnForwardFastImpl( + data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram, + data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys, + num_heads, channels, num_levels, num_queries, num_points, data_col_gdram); +} + +template __mlu_global__ void MLUKernelMsDeformAttnForwardFast( + const char* data_value_gdram, const char* data_spatial_shapes_gdram, + const char* data_level_start_index_gdram, + const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram, + const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_queries, + const int32_t num_points, char* data_col_gdram); + +void KernelMsDeformAttnForwardFast( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + const cnrtDataType_t d_type, const char* data_value_gdram, + const char* data_spatial_shapes_gdram, + const char* data_level_start_index_gdram, + const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram, + const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_queries, + const int32_t num_points, char* data_col_gdram) { + MLUKernelMsDeformAttnForwardFast<<>>( + data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram, + data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys, + num_heads, channels, num_levels, num_queries, num_points, data_col_gdram); +} diff --git a/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp index 845465ae4b..25c8f6209b 100644 --- a/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp @@ -1,517 +1,60 @@ -/************************************************************************* - * Copyright (C) 2022 by Cambricon. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS - * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - *************************************************************************/ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from +*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include "pytorch_cpp_helper.hpp" #include "pytorch_device_registry.hpp" -#include "pytorch_mlu_helper.hpp" -#define MIN(a, b) (((a) < (b)) ? (a) : (b)) - -typedef enum { - MS_DEFORM_ATTN_FORWARD_INVALID = 0, /*!< Index is invalid. */ - MS_DEFORM_ATTN_FORWARD_DEFAULT = - 1, /*!< MLUKernelMsDeformAttnForwardDefault */ - MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL = - 2, /*!< MLUKernelMsDeformAttnForwardSmallChannel */ -} MsDeformAttnForwardPolicy; - -void KernelMsDeformAttnForwardDefault( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const cnrtDataType_t d_type, const char* data_value_gdram, - const char* data_spatial_shapes_gdram, - const char* data_level_start_index_gdram, - const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram, - const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, - const int32_t channels, const int32_t num_levels, const int32_t num_queries, - const int32_t num_points, char* data_col_gdram); -void KernelMsDeformAttnForwardSmallChannel( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const cnrtDataType_t d_type, const char* data_value_gdram, - const char* data_spatial_shapes_gdram, - const char* data_level_start_index_gdram, - const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram, - const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, - const int32_t channels, const int32_t num_levels, const int32_t num_queries, - const int32_t num_points, char* data_col_gdram); - -typedef enum { - MS_DEFORM_ATTN_BACKWARD_DEFAULT = 0, - MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL = 1, -} MsDeformAttnBackwardKernelPolicy; - -MsDeformAttnBackwardKernelPolicy msDeformAttnBackwardPolicyFunc( - const int32_t channels, const int32_t num_levels, - const int32_t num_points) { - const int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore); - const uint64_t max_num = nram_size / sizeof(float); - const uint64_t deal_num = - 12 * PAD_UP(channels, 32) + 3 * PAD_UP(num_levels, 32) + 3 * num_points; - - if (max_num >= deal_num) { - return MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL; - } - - return MS_DEFORM_ATTN_BACKWARD_DEFAULT; +Tensor ms_deform_attn_impl_forward(const Tensor &value, + const Tensor &spatial_shapes, + const Tensor &level_start_index, + const Tensor &sampling_loc, + const Tensor &attn_weight, + const int im2col_step) { + return DISPATCH_DEVICE_IMPL(ms_deform_attn_impl_forward, value, + spatial_shapes, level_start_index, sampling_loc, + attn_weight, im2col_step); } -void KernelMsDeformAttnBackwardDefaultKernel( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const cnrtDataType_t d_type, const float* data_value, - const int32_t* spatial_shapes, const int32_t* data_level_start_index, - const float* data_sampling_loc, const float* data_attn_weight, - const float* grad_output, const int32_t batch_size, const int32_t num_keys, - const int32_t num_heads, const int32_t channels, const int32_t num_levels, - const int32_t num_queries, const int32_t num_points, float* grad_value, - float* grad_sampling_loc, float* grad_attn_weight); - -void KernelMsDeformAttnBackwardSmallChannelsKernel( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const cnrtDataType_t d_type, const float* data_value, - const int32_t* spatial_shapes, const int32_t* data_level_start_index, - const float* data_sampling_loc, const float* data_attn_weight, - const float* grad_output, const int32_t batch, const int32_t spatial_size, - const int32_t num_heads, const int32_t channels, const int32_t num_levels, - const int32_t num_query, const int32_t num_points, float* grad_value, - float* grad_sampling_loc, float* grad_attn_weight); - -// policy function -MsDeformAttnForwardPolicy msDeformAttnForwardPolicyFunc( - cnrtDim3_t* k_dim, cnrtFunctionType_t* k_type, const int32_t batch_size, - const int32_t num_keys, const int32_t num_heads, const int32_t channels, - const int32_t num_levels, const int32_t num_queries, - const int32_t num_points) { - k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); - k_dim->y = - MIN((batch_size * num_queries * num_heads + k_dim->x - 1) / k_dim->x, - torch_mlu::getDeviceAttr(cnrtAttrClusterCount)); - k_dim->z = 1; -#if __BANG_ARCH__ == 520 - *k_type = CNRT_FUNC_TYPE_BLOCK; -#else - *k_type = CNRT_FUNC_TYPE_UNION1; -#endif - - int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore); - if (num_levels * num_points * 3 * sizeof(int32_t) > nram_size) { - return MS_DEFORM_ATTN_FORWARD_DEFAULT; - } else if (channels > nram_size / 12 / sizeof(float)) { - return MS_DEFORM_ATTN_FORWARD_DEFAULT; - } else { - return MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL; - } -} - -// policy function for backward -static void policyFuncBackward(const int32_t batch_size, - const int32_t num_queries, - const int32_t num_heads, - const int32_t num_levels, - cnrtFunctionType_t* k_type, cnrtDim3_t* k_dim) { - size_t cluster_limit = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); - size_t core_limit = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); - k_dim->x = core_limit; - int32_t total_num = batch_size * num_queries * num_heads * num_levels; - size_t total_num_align = CEIL_ALIGN(total_num, core_limit); - k_dim->y = (total_num_align / core_limit) > cluster_limit - ? cluster_limit - : (total_num_align / core_limit); - k_dim->z = 1; - *k_type = CNRT_FUNC_TYPE_UNION1; +void ms_deform_attn_impl_backward( + const Tensor &value, const Tensor &spatial_shapes, + const Tensor &level_start_index, const Tensor &sampling_loc, + const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value, + Tensor &grad_sampling_loc, Tensor &grad_attn_weight, + const int im2col_step) { + DISPATCH_DEVICE_IMPL(ms_deform_attn_impl_backward, value, spatial_shapes, + level_start_index, sampling_loc, attn_weight, + grad_output, grad_value, grad_sampling_loc, + grad_attn_weight, im2col_step); } -Tensor ms_deform_attn_mlu_forward(const Tensor& value, - const Tensor& spatial_shapes, - const Tensor& level_start_index, - const Tensor& sampling_loc, - const Tensor& attn_weight, - const int im2col_step) { - // check contiguous - AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); - AT_ASSERTM(spatial_shapes.is_contiguous(), - "spatial_shapes tensor has to be contiguous"); - AT_ASSERTM(level_start_index.is_contiguous(), - "level_start_index tensor has to be contiguous"); - AT_ASSERTM(sampling_loc.is_contiguous(), - "sampling_loc tensor has to be contiguous"); - AT_ASSERTM(attn_weight.is_contiguous(), - "attn_weight tensor has to be contiguous"); - - // check datatype - TORCH_CHECK((value.scalar_type() == at::kFloat), - "value type should be Float, got ", value.scalar_type(), "."); - TORCH_CHECK((spatial_shapes.scalar_type() == at::kInt || - spatial_shapes.scalar_type() == at::kLong), - "spatial_shapes type should be Int, got ", - spatial_shapes.scalar_type(), "."); - TORCH_CHECK((level_start_index.scalar_type() == at::kInt || - level_start_index.scalar_type() == at::kLong), - "level_start_index type should be Int, got ", - level_start_index.scalar_type(), "."); - TORCH_CHECK((sampling_loc.scalar_type() == at::kFloat), - "sampling_loc type should be Float, got ", - sampling_loc.scalar_type(), "."); - TORCH_CHECK((attn_weight.scalar_type() == at::kFloat), - "attn_weight type should be Float, got ", - attn_weight.scalar_type(), "."); - - // check shape - TORCH_CHECK(value.dim() == 4, "value should be a 4d tensor, got ", - value.dim(), "D."); - TORCH_CHECK(spatial_shapes.dim() == 2, - "spatial_shapes should be a 2d tensor, got ", - spatial_shapes.dim(), "D."); - TORCH_CHECK(level_start_index.dim() == 1, - "level_start_index should be a 1d tensor, got ", - level_start_index.dim(), "D."); - TORCH_CHECK(sampling_loc.dim() == 6, - "sampling_loc should be a 6d tensor, got ", sampling_loc.dim(), - "D."); - TORCH_CHECK(attn_weight.dim() == 5, "attn_weight should be a 5d tensor, got ", - attn_weight.dim(), "D."); - - const int batch_size = value.size(0); - const int num_keys = value.size(1); - const int num_heads = value.size(2); - const int channels = value.size(3); - const int num_levels = spatial_shapes.size(0); - const int num_queries = sampling_loc.size(1); - const int num_points = sampling_loc.size(4); - - TORCH_CHECK(spatial_shapes.size(1) == 2, - "the 2nd dimensions of spatial_shapes should be 2, got ", - spatial_shapes.size(1), "."); - TORCH_CHECK(sampling_loc.size(5) == 2, - "the 6th dimensions of sampling_loc should be 2, got ", - sampling_loc.size(5), "."); - TORCH_CHECK((sampling_loc.size(0) == batch_size), - "the 1st dimensions of sampling_loc should be batch_size, ", - "but now the 1st dimension of sampling_loc is ", - sampling_loc.size(0), ", and batch_size is ", batch_size, "."); - TORCH_CHECK((attn_weight.size(0) == batch_size), - "the 1st dimensions of attn_weight should be batch_size, ", - "but now the 1st dimension of attn_weight is ", - attn_weight.size(0), ", and batch_size is ", batch_size, "."); - TORCH_CHECK((sampling_loc.size(2) == num_heads), - "the 3rd dimensions of sampling_loc should be num_heads, ", - "but now the 3rd dimension of sampling_loc is ", - sampling_loc.size(2), ", and num_heads is ", num_heads, "."); - TORCH_CHECK((attn_weight.size(2) == num_heads), - "the 3rd dimensions of attn_weight should be num_heads, ", - "but now the 3rd dimension of attn_weight is ", - attn_weight.size(2), ", and num_heads is ", num_heads, "."); - TORCH_CHECK((level_start_index.size(0) == num_levels), - "the 1st dimensions of level_start_index should be num_levels, ", - "but now the 1st dimension of level_start_index is ", - level_start_index.size(0), ", and num_levels is ", num_levels, - "."); - TORCH_CHECK((sampling_loc.size(3) == num_levels), - "the 4th dimensions of sampling_loc should be num_levels, ", - "but now the 4th dimension of sampling_loc is ", - sampling_loc.size(3), ", and num_levels is ", num_levels, "."); - TORCH_CHECK((attn_weight.size(3) == num_levels), - "the 4th dimensions of attn_weight should be num_levels, ", - "but now the 4th dimension of attn_weight is ", - attn_weight.size(3), ", and num_levels is ", num_levels, "."); - TORCH_CHECK((attn_weight.size(1) == num_queries), - "the 2nd dimensions of attn_weight should be num_queries, ", - "but now the 2nd dimension of attn_weight is ", - attn_weight.size(1), ", and num_queries is ", num_queries, "."); - TORCH_CHECK((attn_weight.size(4) == num_points), - "the 5th dimensions of attn_weight should be num_points, ", - "but now the 5th dimension of attn_weight is ", - attn_weight.size(4), ", and num_points is ", num_points, "."); - - auto output = at::zeros({batch_size, num_queries, num_heads, channels}, - value.options()); - - // large tensor check - const size_t max_input_size = 2147483648; - TORCH_CHECK(value.numel() < max_input_size, - "value element num should be less than 2^31, got ", value.numel(), - "."); - TORCH_CHECK(sampling_loc.numel() < max_input_size, - "sampling_loc element num should be less than 2^31, got ", - sampling_loc.numel(), "."); - TORCH_CHECK(output.numel() < max_input_size, - "output element num should be less than 2^31, got ", - output.numel(), "."); - - // check zero element - TORCH_CHECK(batch_size != 0, "batch_size should not be zero"); - TORCH_CHECK(num_heads != 0, "num_heads should not be zero"); - TORCH_CHECK(channels != 0, "channels should not be zero"); - TORCH_CHECK(num_queries != 0, "num_queries should not be zero"); - - if (num_keys == 0 || num_levels == 0 || num_points == 0) { - return output; - } - - // calculate task dimension - cnrtDim3_t k_dim; - cnrtFunctionType_t k_type; - MsDeformAttnForwardPolicy policy = msDeformAttnForwardPolicyFunc( - &k_dim, &k_type, batch_size, num_keys, num_heads, channels, num_levels, - num_queries, num_points); - - // get compute queue - auto queue = torch_mlu::getCurQueue(); - - auto spatial_shapes_ = spatial_shapes.to(at::kInt); - auto level_start_index_ = level_start_index.to(at::kInt); - - // get ptr of tensors - auto value_impl = torch_mlu::getMluTensorImpl(value); - auto value_ptr = value_impl->cnnlMalloc(); - auto spatial_shapes_impl = torch_mlu::getMluTensorImpl(spatial_shapes_); - auto spatial_shapes_ptr = spatial_shapes_impl->cnnlMalloc(); - auto level_start_index_impl = torch_mlu::getMluTensorImpl(level_start_index_); - auto level_start_index_ptr = level_start_index_impl->cnnlMalloc(); - auto sampling_loc_impl = torch_mlu::getMluTensorImpl(sampling_loc); - auto sampling_loc_ptr = sampling_loc_impl->cnnlMalloc(); - auto attn_weight_impl = torch_mlu::getMluTensorImpl(attn_weight); - auto attn_weight_ptr = attn_weight_impl->cnnlMalloc(); - auto output_impl = torch_mlu::getMluTensorImpl(output); - auto output_ptr = output_impl->cnnlMalloc(); - - // get compute dtype of input - cnrtDataType_t data_type = torch_mlu::toCnrtDtype(value.dtype()); - - // launch kernel - switch (policy) { - default: { - VLOG(5) << "MsDeformAttnForward Policy not supported"; - }; break; - case MS_DEFORM_ATTN_FORWARD_DEFAULT: { - CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardDefault<<<" - << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>"; - KernelMsDeformAttnForwardDefault( - k_dim, k_type, queue, data_type, (char*)value_ptr, - (char*)spatial_shapes_ptr, (char*)level_start_index_ptr, - (char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys, - num_heads, channels, num_levels, num_queries, num_points, - (char*)output_ptr); - break; - } - case MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL: { - CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardSmallChannel<<<" - << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>"; - KernelMsDeformAttnForwardSmallChannel( - k_dim, k_type, queue, data_type, (char*)value_ptr, - (char*)spatial_shapes_ptr, (char*)level_start_index_ptr, - (char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys, - num_heads, channels, num_levels, num_queries, num_points, - (char*)output_ptr); - break; - } - } - - output = output.view({batch_size, num_queries, num_heads * channels}); - return output; +Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes, + const Tensor &level_start_index, + const Tensor &sampling_loc, + const Tensor &attn_weight, + const int im2col_step) { + at::DeviceGuard guard(value.device()); + return ms_deform_attn_impl_forward(value, spatial_shapes, level_start_index, + sampling_loc, attn_weight, im2col_step); } -void ms_deform_attn_mlu_backward( - const Tensor& value, const Tensor& spatial_shapes, - const Tensor& level_start_index, const Tensor& sampling_loc, - const Tensor& attn_weight, const Tensor& grad_output, Tensor& grad_value, - Tensor& grad_sampling_loc, Tensor& grad_attn_weight, - const int im2col_step) { - // check contiguous - AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); - AT_ASSERTM(spatial_shapes.is_contiguous(), - "spatial_shapes tensor has to be contiguous"); - AT_ASSERTM(level_start_index.is_contiguous(), - "level_start_index tensor has to be contiguous"); - AT_ASSERTM(sampling_loc.is_contiguous(), - "sampling_loc tensor has to be contiguous"); - AT_ASSERTM(attn_weight.is_contiguous(), - "attn_weight tensor has to be contiguous"); - AT_ASSERTM(grad_output.is_contiguous(), - "grad_output tensor has to be contiguous"); - - // check datatype - TORCH_CHECK((value.scalar_type() == at::kFloat), - "value type should be Float, got ", value.scalar_type(), "."); - TORCH_CHECK((spatial_shapes.scalar_type() == at::kInt || - spatial_shapes.scalar_type() == at::kLong), - "spatial_shapes type should be Int, got ", - spatial_shapes.scalar_type(), "."); - TORCH_CHECK((level_start_index.scalar_type() == at::kInt || - level_start_index.scalar_type() == at::kLong), - "level_start_index type should be Int, got ", - level_start_index.scalar_type(), "."); - TORCH_CHECK((sampling_loc.scalar_type() == at::kFloat), - "sampling_loc type should be Float, got ", - sampling_loc.scalar_type(), "."); - TORCH_CHECK((attn_weight.scalar_type() == at::kFloat), - "attn_weight type should be Float, got ", - attn_weight.scalar_type(), "."); - TORCH_CHECK((grad_output.scalar_type() == at::kFloat), - "grad_output type should be Float, got ", - grad_output.scalar_type(), "."); - - const int batch_size = value.size(0); - const int num_keys = value.size(1); - const int num_heads = value.size(2); - const int channels = value.size(3); - const int num_levels = spatial_shapes.size(0); - const int num_queries = sampling_loc.size(1); - const int num_points = sampling_loc.size(4); - // Check shape. - TORCH_CHECK(spatial_shapes.size(1) == 2, - "the 2nd dimensions of spatial_shapes should be 2, got ", - spatial_shapes.size(1), "."); - - TORCH_CHECK((level_start_index.size(0) == num_levels), - "the 1st dimensions of level_start_index should be num_levels, ", - "but now the 1st dimension of level_start_index is ", - level_start_index.size(0), ", and num_levels is ", num_levels, - "."); - - TORCH_CHECK((sampling_loc.size(0) == batch_size), - "the 1st dimensions of sampling_loc should be batch_size, ", - "but now the 1st dimension of sampling_loc is ", - sampling_loc.size(0), ", and batch_size is ", batch_size, "."); - TORCH_CHECK((sampling_loc.size(2) == num_heads), - "the 3rd dimensions of sampling_loc should be num_heads, ", - "but now the 3rd dimension of sampling_loc is ", - sampling_loc.size(2), ", and num_heads is ", num_heads, "."); - TORCH_CHECK((sampling_loc.size(3) == num_levels), - "the 4th dimensions of sampling_loc should be num_levels, ", - "but now the 4th dimension of sampling_loc is ", - sampling_loc.size(3), ", and num_levels is ", num_levels, "."); - TORCH_CHECK(sampling_loc.size(5) == 2, - "the 6th dimensions of sampling_loc should be 2, got ", - sampling_loc.size(5), "."); - - TORCH_CHECK((attn_weight.size(0) == batch_size), - "the 1st dimensions of attn_weight should be batch_size, ", - "but now the 1st dimension of attn_weight is ", - attn_weight.size(0), ", and batch_size is ", batch_size, "."); - TORCH_CHECK((attn_weight.size(1) == num_queries), - "the 2nd dimensions of attn_weight should be num_queries, ", - "but now the 2nd dimension of attn_weight is ", - attn_weight.size(1), ", and num_queries is ", num_queries, "."); - - TORCH_CHECK((attn_weight.size(2) == num_heads), - "the 3rd dimensions of attn_weight should be num_heads, ", - "but now the 3rd dimension of attn_weight is ", - attn_weight.size(2), ", and num_heads is ", num_heads, "."); - TORCH_CHECK((attn_weight.size(3) == num_levels), - "the 4th dimensions of attn_weight should be num_levels, ", - "but now the 4th dimension of attn_weight is ", - attn_weight.size(3), ", and num_levels is ", num_levels, "."); - TORCH_CHECK((attn_weight.size(4) == num_points), - "the 5th dimensions of attn_weight should be num_points, ", - "but now the 5th dimension of attn_weight is ", - attn_weight.size(4), ", and num_points is ", num_points, "."); - - TORCH_CHECK((grad_output.size(0) == batch_size), - "the 1st dimensions of grad_output should be batch_size, ", - "but now the 1st dimension of grad_output is ", - grad_output.size(0), ", and batch_size is ", batch_size, "."); - TORCH_CHECK((grad_output.size(1) == num_queries), - "the 2nd dimensions of grad_output should be num_queries, ", - "but now the 2nd dimension of grad_output is ", - grad_output.size(1), ", and num_queries is ", num_queries, "."); - TORCH_CHECK( - (grad_output.size(2) == num_heads * channels), - "the 3rd dimensions of grad_output should be num_heads * channels, ", - "but now the 3rd dimension of grad_output is ", grad_output.size(2), - ", and num_heads * channels is ", num_heads * channels, "."); - - // check zero element - TORCH_CHECK(batch_size != 0, "The batch_size is zero."); - TORCH_CHECK(channels != 0, "The channels is zero."); - TORCH_CHECK(num_keys != 0, "The num_keys is zero."); - TORCH_CHECK(num_heads != 0, "The num_heads is zero."); - TORCH_CHECK(num_queries != 0, "The num_queries is zero."); - if (num_levels == 0 || num_points == 0) { - return; - } - - // calculate task dimension - cnrtDim3_t k_dim; - cnrtFunctionType_t k_type; - policyFuncBackward(batch_size, num_queries, num_heads, num_levels, &k_type, - &k_dim); - - // get compute queue - auto queue = torch_mlu::getCurQueue(); - - // get ptr of tensors - auto value_impl = torch_mlu::getMluTensorImpl(value); - auto value_ptr = value_impl->cnnlMalloc(); - auto spatial_shapes_impl = torch_mlu::getMluTensorImpl(spatial_shapes); - auto spatial_shapes_ptr = spatial_shapes_impl->cnnlMalloc(); - auto level_start_index_impl = torch_mlu::getMluTensorImpl(level_start_index); - auto level_start_index_ptr = level_start_index_impl->cnnlMalloc(); - auto sampling_loc_impl = torch_mlu::getMluTensorImpl(sampling_loc); - auto sampling_loc_ptr = sampling_loc_impl->cnnlMalloc(); - auto attn_weight_impl = torch_mlu::getMluTensorImpl(attn_weight); - auto attn_weight_ptr = attn_weight_impl->cnnlMalloc(); - auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output); - auto grad_output_ptr = grad_output_impl->cnnlMalloc(); - auto grad_value_impl = torch_mlu::getMluTensorImpl(grad_value); - auto grad_value_ptr = grad_value_impl->cnnlMalloc(); - auto grad_sampling_loc_impl = torch_mlu::getMluTensorImpl(grad_sampling_loc); - auto grad_sampling_loc_ptr = grad_sampling_loc_impl->cnnlMalloc(); - auto grad_attn_weight_impl = torch_mlu::getMluTensorImpl(grad_attn_weight); - auto grad_attn_weight_ptr = grad_attn_weight_impl->cnnlMalloc(); - - // get comput dtype of input - cnrtDataType_t data_type = torch_mlu::toCnrtDtype(value.dtype()); - - // launch kernel - CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnBackward<<<" << k_dim.x - << ", " << k_dim.y << ", " << k_dim.z << ">>>"; - MsDeformAttnBackwardKernelPolicy kernelPolicy = - msDeformAttnBackwardPolicyFunc(channels, num_levels, num_points); - switch (kernelPolicy) { - default: { - VLOG(5) << "NotImplemented."; - } break; - case MS_DEFORM_ATTN_BACKWARD_DEFAULT: { - KernelMsDeformAttnBackwardDefaultKernel( - k_dim, k_type, queue, data_type, (float*)value_ptr, - (int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr, - (float*)sampling_loc_ptr, (float*)attn_weight_ptr, - (float*)grad_output_ptr, batch_size, num_keys, num_heads, channels, - num_levels, num_queries, num_points, (float*)grad_value_ptr, - (float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr); - } break; - case MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL: { - KernelMsDeformAttnBackwardSmallChannelsKernel( - k_dim, k_type, queue, data_type, (float*)value_ptr, - (int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr, - (float*)sampling_loc_ptr, (float*)attn_weight_ptr, - (float*)grad_output_ptr, batch_size, num_keys, num_heads, channels, - num_levels, num_queries, num_points, (float*)grad_value_ptr, - (float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr); - } break; - } +void ms_deform_attn_backward(const Tensor &value, const Tensor &spatial_shapes, + const Tensor &level_start_index, + const Tensor &sampling_loc, + const Tensor &attn_weight, + const Tensor &grad_output, Tensor &grad_value, + Tensor &grad_sampling_loc, + Tensor &grad_attn_weight, const int im2col_step) { + at::DeviceGuard guard(value.device()); + ms_deform_attn_impl_backward(value, spatial_shapes, level_start_index, + sampling_loc, attn_weight, grad_output, + grad_value, grad_sampling_loc, grad_attn_weight, + im2col_step); } - -Tensor ms_deform_attn_impl_forward(const Tensor& value, - const Tensor& spatial_shapes, - const Tensor& level_start_index, - const Tensor& sampling_loc, - const Tensor& attn_weight, - const int im2col_step); - -void ms_deform_attn_impl_backward( - const Tensor& value, const Tensor& spatial_shapes, - const Tensor& level_start_index, const Tensor& sampling_loc, - const Tensor& attn_weight, const Tensor& grad_output, Tensor& grad_value, - Tensor& grad_sampling_loc, Tensor& grad_attn_weight, const int im2col_step); - -REGISTER_DEVICE_IMPL(ms_deform_attn_impl_forward, MLU, - ms_deform_attn_mlu_forward); -REGISTER_DEVICE_IMPL(ms_deform_attn_impl_backward, MLU, - ms_deform_attn_mlu_backward);