diff --git a/mmcv/ops/csrc/common/mlu/ms_deform_attn_fast_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/ms_deform_attn_fast_mlu_kernel.mlu index 6afe03ebc2..f95d3129c8 100644 --- a/mmcv/ops/csrc/common/mlu/ms_deform_attn_fast_mlu_kernel.mlu +++ b/mmcv/ops/csrc/common/mlu/ms_deform_attn_fast_mlu_kernel.mlu @@ -25,667 +25,752 @@ #define NRAM_REMAIN_SIZE (48 * 1024) #define NRAM_AVALIABLE_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE) - + +#define SRAM_REMAIN_SIZE (32 * 1024) +#define SRAM_AVALIABLE_SIZE (__MLU_SRAM_SIZE__ * 1024 - SRAM_REMAIN_SIZE) +#define SRAM_FOR_VALUE_SIZE (SRAM_AVALIABLE_SIZE - 128) + +#define MAX_MEMCPY_SEGNUM 65536 + __nram__ char nram_buffer[NRAM_AVALIABLE_SIZE]; - +__mlu_shared__ char sram_buffer[SRAM_AVALIABLE_SIZE]; + template __mlu_func__ inline T __mluop_min(T a, T b) { return a < b ? a : b; } - + template __mlu_func__ inline T __mluop_max(T a, T b) { return a > b ? a : b; } - -template -__mlu_func__ void __mluop_floor(T* dst_ram, T* src_ram, int size) { - if (sizeof(T) == sizeof(float)) { - int16* mid = (int16*)(dst_ram + size / 2); - __bang_float2int16_dn(mid, (float*)src_ram, size, 0); - __bang_int162float((float*)dst_ram, (int16_t*)mid, size, 0); - } else { - __bang_half2int16_dn((int16_t*)dst_ram, (half*)src_ram, size, 0); - __bang_int162half((half*)dst_ram, (int16_t*)dst_ram, size, 0); - } -} - -__mlu_func__ void broadcastSpatialHW( - float* spatial_offset_bd_nram, // (num_levels, num_points) - float* spatial_h_bd_nram, // (num_levels, num_points) - float* spatial_w_bd_nram, // (num_levels, num_points) - int32_t* spatial_shapes_nram, // (num_levels, 2) - int32_t* spatial_offset_nram, // (num_levels) - const int32_t num_levels, const int32_t num_points) { - for (int i = 0; i < num_levels * 2; i++) { - ((float*)spatial_shapes_nram)[i] = (float)spatial_shapes_nram[i]; - } - - for (int i = 0; i < num_levels; i++) { - ((float*)spatial_offset_nram)[i] = (float)spatial_offset_nram[i]; - } - - for (int i = 0; i < num_levels; i++) { - __memcpy(spatial_h_bd_nram + i * num_points, spatial_shapes_nram + i * 2, - sizeof(float), NRAM2NRAM, sizeof(float), 0, num_points - 1); - } - - for (int i = 0; i < num_levels; i++) { - __memcpy(spatial_w_bd_nram + i * num_points, - spatial_shapes_nram + 1 + i * 2, sizeof(float), NRAM2NRAM, - sizeof(float), 0, num_points - 1); - } - - for (int i = 0; i < num_levels; i++) { - __memcpy(spatial_offset_bd_nram + i * num_points, spatial_offset_nram + i, - sizeof(float), NRAM2NRAM, sizeof(float), 0, num_points - 1); - } -} - -template -__mlu_func__ void getConditionCoordWeight( - int32_t* data_offset_nram, T* weight_polation_nram, - T* cond_point_polation_nram, T* cond_point_valid_nram, T* loc_nram, - T* weight_attn_nram, T* spatial_offset_bd_nram, T* spatial_w_bd_nram, - T* spatial_h_bd_nram, T* buf_nram, const int32_t deal_n, - const int32_t num_levels, const int32_t num_points, const int32_t num_heads, - int32_t pad_num_levels_points) { - int32_t pad_total_points = deal_n * pad_num_levels_points; - int32_t pad_block_points = pad_num_levels_points; - T* buf_x_nram = buf_nram; - T* buf_y_nram = buf_nram + pad_total_points; - T* buf_cond_nram = buf_nram + 2 * pad_total_points; - T* buf_x_floor = buf_nram + 2 * pad_total_points; - T* buf_y_floor = buf_nram + 3 * pad_total_points; - T* buf_x_ceil = buf_nram + 4 * pad_total_points; - T* buf_y_ceil = buf_nram + 5 * pad_total_points; - __sync_io_move_compute(); - __bang_write_value(buf_x_nram, pad_total_points, 0); - __bang_write_value(buf_y_nram, pad_total_points, 0); - __bang_write_value(buf_x_floor, pad_total_points, 0); - __bang_write_value(buf_x_ceil, pad_total_points, 0); - __bang_write_value(buf_y_floor, pad_total_points, 0); - __bang_write_value(buf_y_ceil, pad_total_points, 0); - - //================================================================================================ - __memcpy(buf_x_nram, loc_nram, sizeof(T), NRAM2NRAM, sizeof(T), 2 * sizeof(T), - pad_total_points - 1); - __memcpy(buf_y_nram, loc_nram + 1, sizeof(T), NRAM2NRAM, sizeof(T), - 2 * sizeof(T), pad_total_points - 1); - - // x = loc_x * spatial_w - 0.5; y = loc_y * spatial_h - 0.5; - __bang_cycle_mul(buf_x_nram, buf_x_nram, spatial_w_bd_nram, pad_total_points, - pad_block_points); - __bang_sub_scalar(buf_x_nram, buf_x_nram, (T)0.5, pad_total_points); - __bang_cycle_mul(buf_y_nram, buf_y_nram, spatial_h_bd_nram, pad_total_points, - pad_block_points); - __bang_sub_scalar(buf_y_nram, buf_y_nram, (T)0.5, pad_total_points); - - //================================================================================================ - // get point condition. use buf0, buf1, buf2 - // (x > -1 && y > -1 && y < spatial_h && x < spatial_w) - __bang_write_value(cond_point_valid_nram, pad_total_points, (T)-1.0); - __bang_gt(cond_point_valid_nram, buf_x_nram, cond_point_valid_nram, - pad_total_points); - __bang_write_value(buf_cond_nram, pad_total_points, (T)-1.0); - __bang_gt(buf_cond_nram, buf_y_nram, buf_cond_nram, pad_total_points); - - __bang_and(cond_point_valid_nram, cond_point_valid_nram, buf_cond_nram, + + template + __mlu_func__ void __mluop_floor(T* dst_ram, T* src_ram, int size) { + if (sizeof(T) == sizeof(float)) { + int16* mid = (int16*)(dst_ram + size / 2); + __bang_float2int16_dn(mid, (float*)src_ram, size, 0); + __bang_int162float((float*)dst_ram, (int16_t*)mid, size, 0); + } else { + __bang_half2int16_dn((int16_t*)dst_ram, (half*)src_ram, size, 0); + __bang_int162half((half*)dst_ram, (int16_t*)dst_ram, size, 0); + } + } + + __mlu_func__ void broadcastSpatialHW( + float* spatial_offset_bd_nram, // (num_levels, num_points) + float* spatial_h_bd_nram, // (num_levels, num_points) + float* spatial_w_bd_nram, // (num_levels, num_points) + int32_t* spatial_shapes_nram, // (num_levels, 2) + int32_t* spatial_offset_nram, // (num_levels) + const int32_t num_levels, const int32_t num_points) { + for (int i = 0; i < num_levels * 2; i++) { + ((float*)spatial_shapes_nram)[i] = (float)spatial_shapes_nram[i]; + } + + for (int i = 0; i < num_levels; i++) { + ((float*)spatial_offset_nram)[i] = (float)spatial_offset_nram[i]; + } + + for (int i = 0; i < num_levels; i++) { + __memcpy(spatial_h_bd_nram + i * num_points, spatial_shapes_nram + i * 2, + sizeof(float), NRAM2NRAM, sizeof(float), 0, num_points - 1); + } + + for (int i = 0; i < num_levels; i++) { + __memcpy(spatial_w_bd_nram + i * num_points, + spatial_shapes_nram + 1 + i * 2, sizeof(float), NRAM2NRAM, + sizeof(float), 0, num_points - 1); + } + + for (int i = 0; i < num_levels; i++) { + __memcpy(spatial_offset_bd_nram + i * num_points, spatial_offset_nram + i, + sizeof(float), NRAM2NRAM, sizeof(float), 0, num_points - 1); + } + + } + + template + __mlu_func__ void getConditionCoordWeight( + int32_t* data_offset_nram, T* weight_polation_nram, + T* cond_point_polation_nram, T* cond_point_valid_nram, T* loc_nram, + T* weight_attn_nram, T* spatial_offset_bd_nram, T* spatial_w_bd_nram, + T* spatial_h_bd_nram, T* buf_nram, const int32_t deal_n, + const int32_t num_levels, const int32_t num_points, const int32_t num_heads, + int32_t pad_num_levels_points) { + int32_t pad_total_points = deal_n * pad_num_levels_points; + int32_t pad_block_points = pad_num_levels_points; + T* buf_x_nram = buf_nram; + T* buf_y_nram = buf_nram + pad_total_points; + T* buf_cond_nram = buf_nram + 2 * pad_total_points; + T* buf_x_floor = buf_nram + 2 * pad_total_points; + T* buf_y_floor = buf_nram + 3 * pad_total_points; + T* buf_x_ceil = buf_nram + 4 * pad_total_points; + T* buf_y_ceil = buf_nram + 5 * pad_total_points; + // __sync_io_move_compute(); + __bang_write_value(buf_x_nram, pad_total_points, 0); + __bang_write_value(buf_y_nram, pad_total_points, 0); + __bang_write_value(buf_x_floor, pad_total_points, 0); + __bang_write_value(buf_x_ceil, pad_total_points, 0); + __bang_write_value(buf_y_floor, pad_total_points, 0); + __bang_write_value(buf_y_ceil, pad_total_points, 0); + + //================================================================================================ + __memcpy(buf_x_nram, loc_nram, sizeof(T), NRAM2NRAM, sizeof(T), 2 * sizeof(T), + pad_total_points - 1); + __memcpy(buf_y_nram, loc_nram + 1, sizeof(T), NRAM2NRAM, sizeof(T), + 2 * sizeof(T), pad_total_points - 1); + + // x = loc_x * spatial_w - 0.5; y = loc_y * spatial_h - 0.5; + __bang_cycle_mul(buf_x_nram, buf_x_nram, spatial_w_bd_nram, pad_total_points, + pad_block_points); + __bang_sub_scalar(buf_x_nram, buf_x_nram, (T)0.5, pad_total_points); + __bang_cycle_mul(buf_y_nram, buf_y_nram, spatial_h_bd_nram, pad_total_points, + pad_block_points); + __bang_sub_scalar(buf_y_nram, buf_y_nram, (T)0.5, pad_total_points); + + //================================================================================================ + // get point condition. use buf0, buf1, buf2 + // (x > -1 && y > -1 && y < spatial_h && x < spatial_w) + __bang_write_value(cond_point_valid_nram, pad_total_points, (T)-1.0); + __bang_gt(cond_point_valid_nram, buf_x_nram, cond_point_valid_nram, pad_total_points); - __bang_cycle_lt(buf_cond_nram, buf_x_nram, spatial_w_bd_nram, - pad_total_points, pad_block_points); - __bang_and(cond_point_valid_nram, cond_point_valid_nram, buf_cond_nram, - pad_total_points); - __bang_cycle_lt(buf_cond_nram, buf_y_nram, spatial_h_bd_nram, - pad_total_points, pad_block_points); - __bang_and(cond_point_valid_nram, cond_point_valid_nram, buf_cond_nram, - pad_total_points); - //================================================================================================ - __mluop_floor(buf_x_floor, buf_x_nram, 2 * pad_total_points); - __bang_add_scalar(buf_x_ceil, buf_x_floor, 1.0, pad_total_points); - __bang_add_scalar(buf_y_ceil, buf_y_floor, 1.0, pad_total_points); - - T* cond_point_polation_nram_tl = cond_point_polation_nram; - T* cond_point_polation_nram_bl = cond_point_polation_nram + pad_total_points; - T* cond_point_polation_nram_tr = - cond_point_polation_nram + 2 * pad_total_points; - T* cond_point_polation_nram_br = - cond_point_polation_nram + 3 * pad_total_points; - T* cond_point_polation_nram_cond1 = weight_polation_nram; - T* cond_point_polation_nram_cond2 = weight_polation_nram + pad_total_points; - T* cond_point_polation_nram_cond3 = - weight_polation_nram + 2 * pad_total_points; - T* cond_point_polation_nram_cond4 = - weight_polation_nram + 3 * pad_total_points; - __bang_ge_scalar(cond_point_polation_nram_cond1, buf_x_floor, (T)0, - pad_total_points); - __bang_cycle_lt(cond_point_polation_nram_cond2, buf_x_ceil, spatial_w_bd_nram, - pad_total_points, pad_block_points); - __bang_ge_scalar(cond_point_polation_nram_cond3, buf_y_floor, (T)0, - pad_total_points); - __bang_cycle_lt(cond_point_polation_nram_cond4, buf_y_ceil, spatial_h_bd_nram, - pad_total_points, pad_block_points); - __bang_and(cond_point_polation_nram_tl, cond_point_polation_nram_cond1, - cond_point_polation_nram_cond4, pad_total_points); - __bang_and(cond_point_polation_nram_bl, cond_point_polation_nram_cond1, - cond_point_polation_nram_cond3, pad_total_points); - __bang_and(cond_point_polation_nram_tr, cond_point_polation_nram_cond2, - cond_point_polation_nram_cond4, pad_total_points); - __bang_and(cond_point_polation_nram_br, cond_point_polation_nram_cond2, - cond_point_polation_nram_cond3, pad_total_points); - //================================================================================================ - // get polation weight. - T* buf_dx = (T*)data_offset_nram; - T* buf_dy = buf_dx + pad_total_points; - T* buf_dx_1 = buf_dy + pad_total_points; - T* buf_dy_1 = buf_dx_1 + pad_total_points; - // -dx = x_floor-x - // -dy = y_floor-y - // w1 = (1-dx)*dy = (dx-1)*(-dy) - // w2 = (1-dx)*(1-dy) = (dx-1)*(dy-1) - // w3 = dx*dy = (-dx)*(-dy) - // w4 = dx*(1-dy) = (-dx)*(dy-1) - T* weight_polation_nram_1 = weight_polation_nram; - T* weight_polation_nram_2 = weight_polation_nram + 1 * pad_total_points; - T* weight_polation_nram_3 = weight_polation_nram + 2 * pad_total_points; - T* weight_polation_nram_4 = weight_polation_nram + 3 * pad_total_points; - // T* weight_polation_nram_buf = buf_nram + 4 * total_points; - __bang_sub(buf_dx, buf_x_floor, buf_x_nram, pad_total_points); - __bang_sub(buf_dy, buf_y_floor, buf_y_nram, pad_total_points); - - __bang_sub(buf_dx_1, buf_x_nram, buf_x_floor, pad_total_points); - __bang_sub_scalar(buf_dx_1, buf_dx_1, (T)1.0, pad_total_points); - - __bang_sub(buf_dy_1, buf_y_nram, buf_y_floor, pad_total_points); - __bang_sub_scalar(buf_dy_1, buf_dy_1, (T)1.0, pad_total_points); - - __bang_mul(weight_polation_nram_1, buf_dx_1, buf_dy, pad_total_points); - __bang_mul(weight_polation_nram_2, buf_dx_1, buf_dy_1, pad_total_points); - __bang_mul(weight_polation_nram_3, buf_dx, buf_dy, pad_total_points); - __bang_mul(weight_polation_nram_4, buf_dx, buf_dy_1, pad_total_points); - //================================================================================================ - // correct the x,y in [0, w-1] and [0, h-1] - T* spatial_w1_bd_nram = buf_nram; - T* spatial_h1_bd_nram = buf_nram + pad_total_points; - __bang_sub_scalar(spatial_w1_bd_nram, spatial_w_bd_nram, (T)1, - pad_total_points); - __bang_sub_scalar(spatial_h1_bd_nram, spatial_h_bd_nram, (T)1, - pad_total_points); - T* maxtemp = (T*)data_offset_nram; - __bang_write_value(maxtemp, pad_total_points, (T)0); - __bang_maxequal(buf_x_floor, buf_x_floor, maxtemp, pad_total_points); - __bang_maxequal(buf_x_ceil, buf_x_ceil, maxtemp, pad_total_points); - __bang_cycle_minequal(buf_x_floor, buf_x_floor, spatial_w1_bd_nram, - pad_total_points, pad_block_points); - __bang_cycle_minequal(buf_x_ceil, buf_x_ceil, spatial_w1_bd_nram, - pad_total_points, pad_block_points); - __bang_maxequal(buf_y_floor, buf_y_floor, maxtemp, pad_total_points); - __bang_maxequal(buf_y_ceil, buf_y_ceil, maxtemp, pad_total_points); - __bang_cycle_minequal(buf_y_floor, buf_y_floor, spatial_h1_bd_nram, - pad_total_points, pad_block_points); - __bang_cycle_minequal(buf_y_ceil, buf_y_ceil, spatial_h1_bd_nram, - pad_total_points, pad_block_points); - //================================================================================================ - // offset = y*w + x - T* buf_hw_offset = buf_nram; - T* data_offset_nram_tl = (T*)data_offset_nram; - T* data_offset_nram_bl = data_offset_nram_tl + pad_total_points; - T* data_offset_nram_tr = data_offset_nram_bl + pad_total_points; - T* data_offset_nram_br = data_offset_nram_tr + pad_total_points; - // y_ceil*w + offset + x_floor - __bang_cycle_mul(buf_hw_offset, buf_y_ceil, spatial_w_bd_nram, + __bang_write_value(buf_cond_nram, pad_total_points, (T)-1.0); + __bang_gt(buf_cond_nram, buf_y_nram, buf_cond_nram, pad_total_points); + + __bang_and(cond_point_valid_nram, cond_point_valid_nram, buf_cond_nram, + pad_total_points); + __bang_cycle_lt(buf_cond_nram, buf_x_nram, spatial_w_bd_nram, pad_total_points, pad_block_points); - __bang_cycle_add(buf_hw_offset, buf_hw_offset, spatial_offset_bd_nram, + __bang_and(cond_point_valid_nram, cond_point_valid_nram, buf_cond_nram, + pad_total_points); + __bang_cycle_lt(buf_cond_nram, buf_y_nram, spatial_h_bd_nram, pad_total_points, pad_block_points); - __bang_add(data_offset_nram_tl, buf_hw_offset, buf_x_floor, pad_total_points); - // y_ceil*w + offset + x_ceil - __bang_add(data_offset_nram_tr, buf_hw_offset, buf_x_ceil, pad_total_points); - // y_floor*w + offset + x_foor - __bang_cycle_mul(buf_hw_offset, buf_y_floor, spatial_w_bd_nram, + __bang_and(cond_point_valid_nram, cond_point_valid_nram, buf_cond_nram, + pad_total_points); + //================================================================================================ + __mluop_floor(buf_x_floor, buf_x_nram, 2 * pad_total_points); + __bang_add_scalar(buf_x_ceil, buf_x_floor, 1.0, pad_total_points); + __bang_add_scalar(buf_y_ceil, buf_y_floor, 1.0, pad_total_points); + + T* cond_point_polation_nram_tl = cond_point_polation_nram; + T* cond_point_polation_nram_bl = cond_point_polation_nram + pad_total_points; + T* cond_point_polation_nram_tr = + cond_point_polation_nram + 2 * pad_total_points; + T* cond_point_polation_nram_br = + cond_point_polation_nram + 3 * pad_total_points; + T* cond_point_polation_nram_cond1 = weight_polation_nram; + T* cond_point_polation_nram_cond2 = weight_polation_nram + pad_total_points; + T* cond_point_polation_nram_cond3 = + weight_polation_nram + 2 * pad_total_points; + T* cond_point_polation_nram_cond4 = + weight_polation_nram + 3 * pad_total_points; + __bang_ge_scalar(cond_point_polation_nram_cond1, buf_x_floor, (T)0, + pad_total_points); + __bang_cycle_lt(cond_point_polation_nram_cond2, buf_x_ceil, spatial_w_bd_nram, pad_total_points, pad_block_points); - __bang_cycle_add(buf_hw_offset, buf_hw_offset, spatial_offset_bd_nram, + __bang_ge_scalar(cond_point_polation_nram_cond3, buf_y_floor, (T)0, + pad_total_points); + __bang_cycle_lt(cond_point_polation_nram_cond4, buf_y_ceil, spatial_h_bd_nram, pad_total_points, pad_block_points); - - __bang_add(data_offset_nram_bl, buf_hw_offset, buf_x_floor, pad_total_points); - // y_floor*w + offset + x_ceil - __bang_add(data_offset_nram_br, buf_hw_offset, buf_x_ceil, pad_total_points); - __bang_cycle_and(cond_point_polation_nram, cond_point_polation_nram, - cond_point_valid_nram, 4 * pad_total_points, - pad_total_points); - __bang_cycle_mul(weight_polation_nram, weight_polation_nram, weight_attn_nram, - 4 * pad_total_points, pad_total_points); - __bang_mul(weight_polation_nram, weight_polation_nram, - cond_point_polation_nram, pad_total_points * 4); - __bang_sub((float*)data_offset_nram_bl, (float*)data_offset_nram_bl, - (float*)data_offset_nram_tl, pad_total_points); - __bang_sub((float*)data_offset_nram_tr, (float*)data_offset_nram_tr, - (float*)data_offset_nram_tl, pad_total_points); -} - -/* - shape of each tensor: - output_nram: (channels) - input_nram: (4, valid_num, channels) - input_trans: (channels, 4, valid_num) - weight_selected_base: (4, deal_n, num_levels, num_points) - weight_compute: (4, valid_num) -*/ -template -__mlu_func__ void reduceLevelByConv(T* output_nram, T* input_nram, - T* input_trans, T* weight_selected_base, - T* weight_compute, - const int32_t pad_num_levels_points, - const int32_t pad_channels, - const int32_t pad_sample_stride_3) { - int32_t ci = 4 * pad_num_levels_points; - int32_t co = pad_channels; - __bang_write_value(weight_compute, 4 * pad_num_levels_points, 0); - __memcpy(weight_compute, weight_selected_base, - pad_num_levels_points * sizeof(T), NRAM2NRAM, - pad_num_levels_points * sizeof(T), pad_sample_stride_3 * sizeof(T), - 3); - __bang_transpose(input_trans, input_nram, ci, co); - __bang_cycle_mul(input_trans, input_trans, weight_compute, co * ci, ci); - __bang_sumpool(input_nram, input_trans, pad_num_levels_points, pad_channels, - 4, 1, 4, 1, 1); - __bang_transpose(input_trans, input_nram, pad_channels, - pad_num_levels_points); - __bang_sumpool(output_nram, input_trans, pad_channels, pad_num_levels_points, - 1, pad_num_levels_points, 1, 1, 1); -} - -__mlu_func__ void loadNram2Gpr(int32_t& v1, int32_t& v2, int32_t& v3, - int32_t* p1, int32_t* p2, int32_t* p3, - int32_t num_heads, int32_t channels_size) { - int32_t stride = num_heads * channels_size; - v1 = (int32_t)(*(float*)p1) * stride; - v2 = (int32_t)(*(float*)p2) * stride; - v3 = (int32_t)(*(float*)p3) * stride; -} - -/* - Load 4 neighbors use one 3D-memcpy, just use offset of N1, stride_3_1 - and - stride_2_1. - |<- stride_3_1 ->| - N1 N3 - ^ - | - stride_2_1 - | - v - N2 N4 - - Trickly fold the loop as 2. -*/ -template + __bang_and(cond_point_polation_nram_tl, cond_point_polation_nram_cond1, + cond_point_polation_nram_cond4, pad_total_points); + __bang_and(cond_point_polation_nram_bl, cond_point_polation_nram_cond1, + cond_point_polation_nram_cond3, pad_total_points); + __bang_and(cond_point_polation_nram_tr, cond_point_polation_nram_cond2, + cond_point_polation_nram_cond4, pad_total_points); + __bang_and(cond_point_polation_nram_br, cond_point_polation_nram_cond2, + cond_point_polation_nram_cond3, pad_total_points); + //================================================================================================ + // get polation weight. + T* buf_dx = (T*)data_offset_nram; + T* buf_dy = buf_dx + pad_total_points; + T* buf_dx_1 = buf_dy + pad_total_points; + T* buf_dy_1 = buf_dx_1 + pad_total_points; + // -dx = x_floor-x + // -dy = y_floor-y + // w1 = (1-dx)*dy = (dx-1)*(-dy) + // w2 = (1-dx)*(1-dy) = (dx-1)*(dy-1) + // w3 = dx*dy = (-dx)*(-dy) + // w4 = dx*(1-dy) = (-dx)*(dy-1) + T* weight_polation_nram_1 = weight_polation_nram; + T* weight_polation_nram_2 = weight_polation_nram + 1 * pad_total_points; + T* weight_polation_nram_3 = weight_polation_nram + 2 * pad_total_points; + T* weight_polation_nram_4 = weight_polation_nram + 3 * pad_total_points; + // T* weight_polation_nram_buf = buf_nram + 4 * total_points; + __bang_sub(buf_dx, buf_x_floor, buf_x_nram, pad_total_points); + __bang_sub(buf_dy, buf_y_floor, buf_y_nram, pad_total_points); + + __bang_sub(buf_dx_1, buf_x_nram, buf_x_floor, pad_total_points); + __bang_sub_scalar(buf_dx_1, buf_dx_1, (T)1.0, pad_total_points); + + __bang_sub(buf_dy_1, buf_y_nram, buf_y_floor, pad_total_points); + __bang_sub_scalar(buf_dy_1, buf_dy_1, (T)1.0, pad_total_points); + + __bang_mul(weight_polation_nram_1, buf_dx_1, buf_dy, pad_total_points); + __bang_mul(weight_polation_nram_2, buf_dx_1, buf_dy_1, pad_total_points); + __bang_mul(weight_polation_nram_3, buf_dx, buf_dy, pad_total_points); + __bang_mul(weight_polation_nram_4, buf_dx, buf_dy_1, pad_total_points); + //================================================================================================ + // correct the x,y in [0, w-1] and [0, h-1] + T* spatial_w1_bd_nram = buf_nram; + T* spatial_h1_bd_nram = buf_nram + pad_total_points; + __bang_sub_scalar(spatial_w1_bd_nram, spatial_w_bd_nram, (T)1, + pad_total_points); + __bang_sub_scalar(spatial_h1_bd_nram, spatial_h_bd_nram, (T)1, + pad_total_points); + T* maxtemp = (T*)data_offset_nram; + __bang_write_value(maxtemp, pad_total_points, (T)0); + __bang_maxequal(buf_x_floor, buf_x_floor, maxtemp, pad_total_points); + __bang_maxequal(buf_x_ceil, buf_x_ceil, maxtemp, pad_total_points); + __bang_cycle_minequal(buf_x_floor, buf_x_floor, spatial_w1_bd_nram, + pad_total_points, pad_block_points); + __bang_cycle_minequal(buf_x_ceil, buf_x_ceil, spatial_w1_bd_nram, + pad_total_points, pad_block_points); + __bang_maxequal(buf_y_floor, buf_y_floor, maxtemp, pad_total_points); + __bang_maxequal(buf_y_ceil, buf_y_ceil, maxtemp, pad_total_points); + __bang_cycle_minequal(buf_y_floor, buf_y_floor, spatial_h1_bd_nram, + pad_total_points, pad_block_points); + __bang_cycle_minequal(buf_y_ceil, buf_y_ceil, spatial_h1_bd_nram, + pad_total_points, pad_block_points); + //================================================================================================ + // offset = y*w + x + T* buf_hw_offset = buf_nram; + T* data_offset_nram_tl = (T*)data_offset_nram; + T* data_offset_nram_bl = data_offset_nram_tl + pad_total_points; + T* data_offset_nram_tr = data_offset_nram_bl + pad_total_points; + T* data_offset_nram_br = data_offset_nram_tr + pad_total_points; + // y_ceil*w + offset + x_floor + __bang_cycle_mul(buf_hw_offset, buf_y_ceil, spatial_w_bd_nram, + pad_total_points, pad_block_points); + __bang_cycle_add(buf_hw_offset, buf_hw_offset, spatial_offset_bd_nram, + pad_total_points, pad_block_points); + __bang_add(data_offset_nram_tl, buf_hw_offset, buf_x_floor, pad_total_points); + // y_ceil*w + offset + x_ceil + __bang_add(data_offset_nram_tr, buf_hw_offset, buf_x_ceil, pad_total_points); + // y_floor*w + offset + x_foor + __bang_cycle_mul(buf_hw_offset, buf_y_floor, spatial_w_bd_nram, + pad_total_points, pad_block_points); + __bang_cycle_add(buf_hw_offset, buf_hw_offset, spatial_offset_bd_nram, + pad_total_points, pad_block_points); + + __bang_add(data_offset_nram_bl, buf_hw_offset, buf_x_floor, pad_total_points); + // y_floor*w + offset + x_ceil + __bang_add(data_offset_nram_br, buf_hw_offset, buf_x_ceil, pad_total_points); + __bang_cycle_and(cond_point_polation_nram, cond_point_polation_nram, + cond_point_valid_nram, 4 * pad_total_points, + pad_total_points); + __bang_cycle_mul(weight_polation_nram, weight_polation_nram, weight_attn_nram, + 4 * pad_total_points, pad_total_points); + __bang_mul(weight_polation_nram, weight_polation_nram, + cond_point_polation_nram, pad_total_points * 4); + __bang_sub((float*)data_offset_nram_bl, (float*)data_offset_nram_bl, + (float*)data_offset_nram_tl, pad_total_points); + __bang_sub((float*)data_offset_nram_tr, (float*)data_offset_nram_tr, + (float*)data_offset_nram_tl, pad_total_points); + } + + /* + shape of each tensor: + output_nram: (channels) + input_nram: (4, valid_num, channels) + input_trans: (channels, 4, valid_num) + weight_selected_base: (4, deal_n, num_levels, num_points) + weight_compute: (4, valid_num) + */ + template + __mlu_func__ void reduceLevelByConv(T* output_nram, T* input_nram, + T* input_trans, T* weight_selected_base, + T* weight_compute, + const int32_t pad_num_levels_points, + const int32_t pad_channels, + const int32_t pad_sample_stride_3) { + int32_t ci = 4 * pad_num_levels_points; + int32_t co = pad_channels; + __bang_write_value(weight_compute, 4 * pad_num_levels_points, 0); + __memcpy(weight_compute, weight_selected_base, + pad_num_levels_points * sizeof(T), NRAM2NRAM, + pad_num_levels_points * sizeof(T), pad_sample_stride_3 * sizeof(T), + 3); + __bang_transpose(input_trans, input_nram, ci, co); + __bang_cycle_mul(input_trans, input_trans, weight_compute, co * ci, ci); + __bang_sumpool(input_nram, input_trans, pad_num_levels_points, pad_channels, + 4, 1, 4, 1, 1); + __bang_transpose(input_trans, input_nram, pad_channels, + pad_num_levels_points); + __bang_sumpool(output_nram, input_trans, pad_channels, pad_num_levels_points, + 1, pad_num_levels_points, 1, 1, 1); + } + + + __mlu_func__ void loadNram2Gpr(int32_t& v1, int32_t& v2, int32_t& v3, + int32_t* p1, int32_t* p2, int32_t* p3, + int32_t num_heads, int32_t channels_size, bool sram_stay, int32_t sram_level_start_index) { + v1 = (int32_t)(*(float*)p1); + v2 = (int32_t)(*(float*)p2); + v3 = (int32_t)(*(float*)p3); + int32_t stride = sram_stay? channels_size : num_heads * channels_size; + if(sram_stay) { + v1 = (v1 - sram_level_start_index) * stride; + v2 = v2 * stride; + v3 = v3 * stride; + } else { + v1 = v1 * stride; + v2 = v2 * stride; + v3 = v3 * stride; + } + } + + /* + Load 4 neighbors use one 3D-memcpy, just use offset of N1, stride_3_1 + and + stride_2_1. + |<- stride_3_1 ->| + N1 N3 + ^ + | + stride_2_1 + | + v + N2 N4 + + Trickly fold the loop as 2. + */ +template __mlu_func__ void loadDataValueXram2NramAsync( T* buf_value_nram_1, int32_t* offset_1, int32_t* stride_2_1, - int32_t* stride_3_1, T* value_src, const int32_t num_levels_points, - const int32_t channel_size, const int32_t num_heads) { + int32_t* stride_3_1, T* value_src, const int32_t pad_num_levels_points, const int32_t deal_points, const int32_t start_points_index, + const int32_t channel_size, const int32_t num_heads, bool sram_stay, const int32_t sram_level_start_offset) { int32_t offset_1_a, stride_2_1_a, stride_3_1_a; int32_t offset_1_b, stride_2_1_b, stride_3_1_b; - loadNram2Gpr(offset_1_a, stride_2_1_a, stride_3_1_a, offset_1, stride_2_1, - stride_3_1, num_heads, channel_size); - loadNram2Gpr(offset_1_b, stride_2_1_b, stride_3_1_b, offset_1 + 1, - stride_2_1 + 1, stride_3_1 + 1, num_heads, channel_size); + loadNram2Gpr(offset_1_a, stride_2_1_a, stride_3_1_a, offset_1 + start_points_index, stride_2_1 + start_points_index, + stride_3_1 + start_points_index, num_heads, channel_size, sram_stay, sram_level_start_offset); + loadNram2Gpr(offset_1_b, stride_2_1_b, stride_3_1_b, offset_1 + start_points_index + 1, + stride_2_1 + start_points_index + 1, stride_3_1 + start_points_index + 1, num_heads, channel_size, sram_stay, sram_level_start_offset); int32_t value_offset = 0; int32_t next = 0; - int32_t loop_num = num_levels_points / 2; - int32_t remain = num_levels_points % 2; - - int32_t pad_num_levels_points = - PAD_UP(num_levels_points, NFU_ALIGN_SIZE / sizeof(T)); + int32_t loop_num = deal_points / 2; + int32_t remain = deal_points % 2; int32_t pad_channels = PAD_UP(channel_size / sizeof(T), NFU_ALIGN_SIZE / sizeof(T)); int32_t pad_channels_size = pad_channels * sizeof(T); int32_t pad_data_value_stride = pad_num_levels_points * pad_channels_size; - for (int32_t j = 0; j < loop_num * 2; j += 2) { - value_offset = j * pad_channels_size; + for (int32_t j = start_points_index; j < start_points_index + loop_num * 2; j += 2) { + value_offset = j * pad_channels_size; next = j + 2; for (int i = 0; i < 2; i++) { __memcpy_async( (int8_t*)buf_value_nram_1 + value_offset + pad_data_value_stride * i, (int8_t*)value_src + offset_1_a + i * stride_2_1_a, channel_size, - GDRAM2NRAM, 2 * pad_data_value_stride, stride_3_1_a, 1); + DIR, 2 * pad_data_value_stride, stride_3_1_a, 1); } loadNram2Gpr(offset_1_a, stride_2_1_a, stride_3_1_a, offset_1 + next, - stride_2_1 + next, stride_3_1 + next, num_heads, channel_size); + stride_2_1 + next, stride_3_1 + next, num_heads, channel_size, sram_stay, sram_level_start_offset); for (int i = 0; i < 2; i++) { __memcpy_async((int8_t*)buf_value_nram_1 + value_offset + - pad_channels_size + pad_data_value_stride * i, - (int8_t*)value_src + offset_1_b + i * stride_2_1_b, - channel_size, GDRAM2NRAM, 2 * pad_data_value_stride, - stride_3_1_b, 1); + pad_channels_size + pad_data_value_stride * i, + (int8_t*)value_src + offset_1_b + i * stride_2_1_b, + channel_size, DIR, 2 * pad_data_value_stride, + stride_3_1_b, 1); } loadNram2Gpr(offset_1_b, stride_2_1_b, stride_3_1_b, offset_1 + next + 1, - stride_2_1 + next + 1, stride_3_1 + next + 1, num_heads, - channel_size); + stride_2_1 + next + 1, stride_3_1 + next + 1, num_heads, + channel_size, sram_stay, sram_level_start_offset); } if (remain > 0) { - value_offset = loop_num * 2 * pad_channels_size; + value_offset = (start_points_index + loop_num * 2) * pad_channels_size; for (int i = 0; i < 2; i++) { __memcpy_async( (int8_t*)buf_value_nram_1 + value_offset + pad_data_value_stride * i, (int8_t*)value_src + offset_1_a + i * stride_2_1_a, channel_size, - GDRAM2NRAM, 2 * pad_data_value_stride, stride_3_1_a, 1); + DIR, 2 * pad_data_value_stride, stride_3_1_a, 1); } } } - -template -__mlu_func__ void loadNeighborPolationAttn( - T* value_output_nram, T* value_gdram, int32_t* data_offset_nram, - T* weight_polation_nram, T* cond_point_polation_nram, - T* cond_point_valid_nram, T* weight_attn_nram, T* buf_nram, - T* compute_buf_nram, const int32_t deal_n, const int32_t num_levels, - const int32_t num_points, const int32_t num_keys, const int32_t channels, - const int32_t num_heads, const int32_t pad_channels, - const int32_t pad_num_levels_points) { - int32_t channel_size = channels * sizeof(T); - int32_t pad_sample_stride_3 = deal_n * pad_num_levels_points; - - T* buf_value_nram = buf_nram; // (4, num_levels, num_points, channels) - T* buf_value_nram_trans = buf_nram + 4 * pad_num_levels_points * pad_channels; - T* weight_compute_nram = compute_buf_nram; // (4, num_levels, num_points) - - int32_t* offset = data_offset_nram; - int32_t* stride_2_1 = offset + pad_sample_stride_3; - int32_t* stride_3_1 = stride_2_1 + pad_sample_stride_3; - T* output_nram = value_output_nram; - int32_t step_offset = 0; - int32_t num_levels_points = num_levels * num_points; - for (int32_t i = 0; i < deal_n; i++) { - __bang_write_value(buf_value_nram, 4 * pad_num_levels_points * pad_channels, - 0); - __sync_compute(); - loadDataValueXram2NramAsync(buf_value_nram, offset, stride_2_1, - stride_3_1, value_gdram, num_levels_points, - channel_size, num_heads); - __sync_io(); - reduceLevelByConv(output_nram, buf_value_nram, buf_value_nram_trans, - weight_polation_nram + step_offset, weight_compute_nram, - pad_num_levels_points, pad_channels, pad_sample_stride_3); - step_offset += pad_num_levels_points; - offset = data_offset_nram + step_offset; - stride_2_1 = offset + pad_sample_stride_3; - stride_3_1 = stride_2_1 + pad_sample_stride_3; - output_nram += pad_channels; + + template + __mlu_func__ void loadNeighborPolationAttn( + T* value_output_nram, T* value_gdram, int32_t* data_offset_nram, + T* weight_polation_nram, T* cond_point_polation_nram, + T* cond_point_valid_nram, T* weight_attn_nram, T* buf_nram, + T* compute_buf_nram, const int32_t deal_n, const int32_t num_levels, + const int32_t num_points, const int32_t num_keys, const int32_t channels, + const int32_t num_heads, const int32_t pad_channels, + const int32_t pad_num_levels_points, T* value_sram, const int32_t sram_level_start_index, const int32_t sram_level_start_offset) { + // int32_t sram_level_start_offset= 204800; + // return; + int32_t channel_size = channels * sizeof(T); + int32_t pad_sample_stride_3 = deal_n * pad_num_levels_points; + + T* buf_value_nram = buf_nram; // (4, num_levels, num_points, channels) + T* buf_value_nram_trans = buf_nram + 4 * pad_num_levels_points * pad_channels; + T* weight_compute_nram = compute_buf_nram; // (4, num_levels, num_points) + + int32_t* offset = data_offset_nram; + int32_t* stride_2_1 = offset + pad_sample_stride_3; + int32_t* stride_3_1 = stride_2_1 + pad_sample_stride_3; + T* output_nram = value_output_nram; + int32_t step_offset = 0; + for (int32_t i = 0; i < deal_n; i++) { + __bang_write_value(buf_value_nram, 4 * pad_num_levels_points * pad_channels, + 0); + __sync_compute(); + if(sram_level_start_index > 0) { + loadDataValueXram2NramAsync(buf_value_nram, offset, stride_2_1, + stride_3_1, value_gdram, pad_num_levels_points, num_points*sram_level_start_index, 0, + channel_size, num_heads, false, sram_level_start_offset); + } + if(sram_level_start_index < num_levels) { + loadDataValueXram2NramAsync(buf_value_nram, offset, stride_2_1, + stride_3_1, value_sram, pad_num_levels_points, num_points*(num_levels - sram_level_start_index), num_points*sram_level_start_index, + channel_size, num_heads, true, sram_level_start_offset); + } + __sync_io(); + reduceLevelByConv(output_nram, buf_value_nram, buf_value_nram_trans, + weight_polation_nram + step_offset, weight_compute_nram, + pad_num_levels_points, pad_channels, pad_sample_stride_3); + step_offset += pad_num_levels_points; + offset = data_offset_nram + step_offset; + stride_2_1 = offset + pad_sample_stride_3; + stride_3_1 = stride_2_1 + pad_sample_stride_3; + output_nram += pad_channels; + } + } + + + template + __mlu_func__ void prepareLoop( + int32_t* spatial_offset_nram, int32_t* spatial_hw_nram, + T* spatial_offset_bd_nram, T* spatial_h_bd_nram, T* spatial_w_bd_nram, + const char* data_level_start_index_gdram, + const char* data_spatial_shapes_gdram, const int32_t num_keys, + const int32_t num_levels, const int32_t num_points, + const int32_t max_deal_n, const int32_t channels) { + __memcpy(spatial_offset_nram, data_level_start_index_gdram, + num_levels * sizeof(int32_t), GDRAM2NRAM); + __memcpy(spatial_hw_nram, data_spatial_shapes_gdram, + num_levels * 2 * sizeof(int32_t), GDRAM2NRAM); + broadcastSpatialHW(spatial_offset_bd_nram, spatial_h_bd_nram, + spatial_w_bd_nram, spatial_hw_nram, spatial_offset_nram, + num_levels, num_points); + } + + /* + The shape of each tensor: + buf_compute_nram: (8, num_levels, num_points) + spatial_offset_nram: (num_levels) + spatial_hw_nram: (num_levels, 2) + spatial_offset_bd_nram: (num_levels, num_points) + spatial_w_bd_nram: (num_levels, num_points) + spatial_h_bd_nram: (num_levels, num_points) + value_output_nram: (deal_n, channels) + data_offset_nram: (4, deal_n, num_levels, num_points) + weight_polation_nram: (4, deal_n, num_levels, num_points) + cond_point_polation_nram: (4, deal_n, num_levels, num_points) + cond_point_valid_nram: (deal_n, num_levels, num_points) + loc_nram: (deal_n, num_levels, num_points, 2) + weight_attn_nram: (deal_n, num_levels, num_points) + buf_nram: (6, deal_n, num_levels, num_points) + + Note: buf_nram is reused in polation computing. + */ + template + __mlu_func__ void memPolicyCommon( + T*& buf_compute_nram, T*& value_output_nram, int32_t*& data_offset_nram, + T*& weight_polation_nram, T*& cond_point_polation_nram, + T*& cond_point_valid_nram, T*& loc_nram, T*& weight_attn_nram, T*& buf_nram, + T*& buf_nram_end, T*& spatial_offset_bd_nram, T*& spatial_w_bd_nram, + T*& spatial_h_bd_nram, T* &value_sram, int32_t*& spatial_offset_nram, + int32_t*& spatial_hw_nram, int32_t& max_deal_n, int32_t& pad_channels, + int32_t& pad_num_levels_points, int32_t& pad_total_points, + const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_queries, + const int32_t num_points) { + pad_channels = PAD_UP(channels, NFU_ALIGN_SIZE / sizeof(T)); + int32_t num_levels_points = num_levels * num_points; + pad_num_levels_points = PAD_UP(num_levels_points, NFU_ALIGN_SIZE / sizeof(T)); + int32_t pad_num_levels_points_8 = 8 * pad_num_levels_points; + int32_t spatial_info_size = + PAD_UP(3 * num_levels * sizeof(int32_t), NFU_ALIGN_SIZE); + int32_t fix_space_size = + spatial_info_size + + (3 * pad_num_levels_points + pad_num_levels_points) * sizeof(T); + int32_t left_space_size = NRAM_AVALIABLE_SIZE - fix_space_size; + int32_t common_buffer_size_each = 6 * pad_num_levels_points * sizeof(T); + int32_t inter_result_size_each = + 17 * pad_num_levels_points * sizeof(T) + pad_channels * sizeof(T); + + max_deal_n = + left_space_size / (common_buffer_size_each + inter_result_size_each); + + int32_t compute_buffer_size = + (9 * pad_num_levels_points * pad_channels) * sizeof(T); + int32_t common_buffer_size = max_deal_n * common_buffer_size_each; + // make sure buf_nram is large enough for compute + if (compute_buffer_size > common_buffer_size) { + int32_t tmp_deal_n = + (left_space_size - compute_buffer_size) / inter_result_size_each; + max_deal_n = __mluop_min(max_deal_n, tmp_deal_n); + } + + pad_total_points = max_deal_n * pad_num_levels_points; + buf_compute_nram = (T*)nram_buffer; + spatial_offset_nram = (int32_t*)(buf_compute_nram + pad_num_levels_points_8); + int32_t pad_3_levels = PAD_UP(3 * num_levels, NFU_ALIGN_SIZE / sizeof(T)); + spatial_hw_nram = spatial_offset_nram + num_levels; + spatial_offset_bd_nram = (T*)(spatial_offset_nram + pad_3_levels); + spatial_w_bd_nram = spatial_offset_bd_nram + pad_num_levels_points; + spatial_h_bd_nram = spatial_w_bd_nram + pad_num_levels_points; + value_output_nram = spatial_h_bd_nram + pad_num_levels_points; + data_offset_nram = (int32_t*)(value_output_nram + max_deal_n * pad_channels); + weight_polation_nram = (T*)(data_offset_nram + 4 * pad_total_points); + cond_point_polation_nram = weight_polation_nram + 4 * pad_total_points; + cond_point_valid_nram = cond_point_polation_nram + 4 * pad_total_points; + loc_nram = cond_point_valid_nram + pad_total_points; + weight_attn_nram = + loc_nram + + 2 * pad_total_points; // total_coord_pad = 2 * pad_total_points + buf_nram = weight_attn_nram + pad_total_points; + buf_nram_end = buf_nram + 6 * max_deal_n * pad_num_levels_points; + value_sram = (T*)sram_buffer; + } + + template +__mlu_func__ void loadDataValueGdram2Sram(T* value_sram, + T* data_value_gdram, + const int32_t batch_idx, + const int32_t head_idx, + const int32_t sram_num_keys, + const int32_t num_heads, + const int32_t channels, + const int32_t skip_num_key) { + int32_t loop_num = (sram_num_keys + MAX_MEMCPY_SEGNUM - 1) / MAX_MEMCPY_SEGNUM; + int32_t num_heads_channels = num_heads * channels; + for (int32_t i = 0; i < loop_num; i++) { + int32_t load_num = + __mluop_min(MAX_MEMCPY_SEGNUM, sram_num_keys - i * MAX_MEMCPY_SEGNUM); + size_t src_offset = ((size_t)batch_idx * sram_num_keys + skip_num_key + i * MAX_MEMCPY_SEGNUM) * + num_heads_channels + + head_idx * channels; + int32_t dst_offset = i * MAX_MEMCPY_SEGNUM * channels; + __memcpy(value_sram + dst_offset, (T*)data_value_gdram + src_offset, + channels * sizeof(T), GDRAM2SRAM, channels * sizeof(T), + num_heads_channels * sizeof(T), load_num - 1); } } template -__mlu_func__ void prepareLoop( - int32_t* spatial_offset_nram, int32_t* spatial_hw_nram, - T* spatial_offset_bd_nram, T* spatial_h_bd_nram, T* spatial_w_bd_nram, - const char* data_level_start_index_gdram, - const char* data_spatial_shapes_gdram, const int32_t num_keys, - const int32_t num_levels, const int32_t num_points, - const int32_t max_deal_n, const int32_t channels) { - __memcpy(spatial_offset_nram, data_level_start_index_gdram, - num_levels * sizeof(int32_t), GDRAM2NRAM); - __memcpy(spatial_hw_nram, data_spatial_shapes_gdram, - num_levels * 2 * sizeof(int32_t), GDRAM2NRAM); - broadcastSpatialHW(spatial_offset_bd_nram, spatial_h_bd_nram, - spatial_w_bd_nram, spatial_hw_nram, spatial_offset_nram, - num_levels, num_points); -} - -/* - The shape of each tensor: - buf_compute_nram: (8, num_levels, num_points) - spatial_offset_nram: (num_levels) - spatial_hw_nram: (num_levels, 2) - spatial_offset_bd_nram: (num_levels, num_points) - spatial_w_bd_nram: (num_levels, num_points) - spatial_h_bd_nram: (num_levels, num_points) - value_output_nram: (deal_n, channels) - data_offset_nram: (4, deal_n, num_levels, num_points) - weight_polation_nram: (4, deal_n, num_levels, num_points) - cond_point_polation_nram: (4, deal_n, num_levels, num_points) - cond_point_valid_nram: (deal_n, num_levels, num_points) - loc_nram: (deal_n, num_levels, num_points, 2) - weight_attn_nram: (deal_n, num_levels, num_points) - buf_nram: (6, deal_n, num_levels, num_points) - - Note: buf_nram is reused in polation computing. -*/ -template -__mlu_func__ void memPolicyCommon( - T*& buf_compute_nram, T*& value_output_nram, int32_t*& data_offset_nram, - T*& weight_polation_nram, T*& cond_point_polation_nram, - T*& cond_point_valid_nram, T*& loc_nram, T*& weight_attn_nram, T*& buf_nram, - T*& buf_nram_end, T*& spatial_offset_bd_nram, T*& spatial_w_bd_nram, - T*& spatial_h_bd_nram, int32_t*& spatial_offset_nram, - int32_t*& spatial_hw_nram, int32_t& max_deal_n, int32_t& pad_channels, - int32_t& pad_num_levels_points, int32_t& pad_total_points, - const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, - const int32_t channels, const int32_t num_levels, const int32_t num_queries, - const int32_t num_points) { - pad_channels = PAD_UP(channels, NFU_ALIGN_SIZE / sizeof(T)); - int32_t num_levels_points = num_levels * num_points; - pad_num_levels_points = PAD_UP(num_levels_points, NFU_ALIGN_SIZE / sizeof(T)); - int32_t pad_num_levels_points_8 = 8 * pad_num_levels_points; - int32_t spatial_info_size = - PAD_UP(3 * num_levels * sizeof(int32_t), NFU_ALIGN_SIZE); - int32_t fix_space_size = - spatial_info_size + - (3 * pad_num_levels_points + pad_num_levels_points) * sizeof(T); - int32_t left_space_size = NRAM_AVALIABLE_SIZE - fix_space_size; - int32_t common_buffer_size_each = 6 * pad_num_levels_points * sizeof(T); - int32_t inter_result_size_each = - 17 * pad_num_levels_points * sizeof(T) + pad_channels * sizeof(T); - - max_deal_n = - left_space_size / (common_buffer_size_each + inter_result_size_each); - - int32_t compute_buffer_size = - (9 * pad_num_levels_points * pad_channels) * sizeof(T); - int32_t common_buffer_size = max_deal_n * common_buffer_size_each; - // make sure buf_nram is large enough for compute - if (compute_buffer_size > common_buffer_size) { - int32_t tmp_deal_n = - (left_space_size - compute_buffer_size) / inter_result_size_each; - max_deal_n = __mluop_min(max_deal_n, tmp_deal_n); - } - - pad_total_points = max_deal_n * pad_num_levels_points; - buf_compute_nram = (T*)nram_buffer; - spatial_offset_nram = (int32_t*)(buf_compute_nram + pad_num_levels_points_8); - int32_t pad_3_levels = PAD_UP(3 * num_levels, NFU_ALIGN_SIZE / sizeof(T)); - spatial_hw_nram = spatial_offset_nram + num_levels; - spatial_offset_bd_nram = (T*)(spatial_offset_nram + pad_3_levels); - spatial_w_bd_nram = spatial_offset_bd_nram + pad_num_levels_points; - spatial_h_bd_nram = spatial_w_bd_nram + pad_num_levels_points; - value_output_nram = spatial_h_bd_nram + pad_num_levels_points; - data_offset_nram = (int32_t*)(value_output_nram + max_deal_n * pad_channels); - weight_polation_nram = (T*)(data_offset_nram + 4 * pad_total_points); - cond_point_polation_nram = weight_polation_nram + 4 * pad_total_points; - cond_point_valid_nram = cond_point_polation_nram + 4 * pad_total_points; - loc_nram = cond_point_valid_nram + pad_total_points; - weight_attn_nram = - loc_nram + - 2 * pad_total_points; // total_coord_pad = 2 * pad_total_points - buf_nram = weight_attn_nram + pad_total_points; - buf_nram_end = buf_nram + 6 * max_deal_n * pad_num_levels_points; -} - -template -__mlu_func__ void MLUKernelMsDeformAttnForwardFastImpl( - const char* data_value_gdram, const char* data_spatial_shapes_gdram, - const char* data_level_start_index_gdram, - const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram, - const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, - const int32_t channels, const int32_t num_levels, const int32_t num_queries, - const int32_t num_points, char* data_col_gdram) { - int32_t input_stride_4 = num_queries * num_heads * num_levels * num_points; - int32_t input_stride_3 = num_heads * num_levels * num_points; - int32_t input_stride_2 = num_levels * num_points; - int32_t output_stride_3 = num_queries * num_heads * channels; - int32_t output_stride_2 = num_heads * channels; - int32_t data_value_stride_3 = num_keys * num_heads * channels; - - T* value_output_nram = nullptr; // (deal_n, channels) - int32_t* data_offset_nram = nullptr; // (4, deal_n, num_levels, num_points) - T* weight_polation_nram = nullptr; // (4, deal_n, num_levels, num_points) - T* cond_point_polation_nram = nullptr; // (4, deal_n, num_levels, num_points) - T* cond_point_valid_nram = nullptr; // (deal_n, num_levels, num_points) - T* loc_nram = nullptr; // (deal_n, num_levels, num_points, 2) - T* weight_attn_nram = nullptr; // (deal_n, num_levels, num_points) - T* buf_nram = nullptr; // (6, deal_n, num_levels, num_points) - T* buf_nram_end = nullptr; - T* spatial_offset_bd_nram = nullptr; // (num_levels, num_points) - T* spatial_w_bd_nram = nullptr; // (num_levels, num_points) - T* spatial_h_bd_nram = nullptr; // (num_levels, num_points) - int32_t* spatial_offset_nram = nullptr; // (num_levels) - int32_t* spatial_hw_nram = nullptr; // (num_levels, 2) - T* buf_compute_nram = nullptr; // (8, num_levels, num_points) - int32_t max_deal_n = 0; - int32_t pad_channels = 0; - int32_t pad_num_levels_points = 0; - int32_t pad_total_points = 0; - - memPolicyCommon(buf_compute_nram, value_output_nram, data_offset_nram, - weight_polation_nram, cond_point_polation_nram, - cond_point_valid_nram, loc_nram, weight_attn_nram, buf_nram, - buf_nram_end, spatial_offset_bd_nram, spatial_w_bd_nram, - spatial_h_bd_nram, spatial_offset_nram, spatial_hw_nram, - max_deal_n, pad_channels, pad_num_levels_points, - pad_total_points, batch_size, num_keys, num_heads, channels, - num_levels, num_queries, num_points); - if (max_deal_n <= 0) { - return; - } - - // split batch*head into taskDimY - int32_t batch_head = batch_size * num_heads; - int32_t cluster_avg_batch_head = (batch_head + taskDimY - 1) / taskDimY; - int32_t cluster_begin_batch_head = taskIdY * cluster_avg_batch_head; - int32_t cluster_act_batch_head = __mluop_min( - cluster_avg_batch_head, batch_head - cluster_begin_batch_head); - int32_t cluster_end_batch_head = - cluster_begin_batch_head + cluster_act_batch_head; - // split query into coreDim - int32_t core_avg_query = (num_queries + coreDim - 1) / coreDim; - int32_t core_begin_query = coreId * core_avg_query; - int32_t core_act_query = - __mluop_min(num_queries - core_begin_query, core_avg_query); - int32_t core_loop_num = (core_act_query + max_deal_n - 1) / max_deal_n; - int32_t core_step_query = - core_loop_num > 0 ? (core_act_query + core_loop_num - 1) / core_loop_num - : 0; - int32_t core_remain_query = - core_act_query - (core_loop_num - 1) * core_step_query; - int32_t first_deal_query = - (int)(core_loop_num > 0) * - (core_loop_num > 1 ? core_step_query : core_remain_query); - - prepareLoop(spatial_offset_nram, spatial_hw_nram, spatial_offset_bd_nram, - spatial_h_bd_nram, spatial_w_bd_nram, - data_level_start_index_gdram, data_spatial_shapes_gdram, num_keys, - num_levels, num_points, max_deal_n, channels); - - for (int32_t bh_idx = cluster_begin_batch_head; - bh_idx < cluster_end_batch_head; bh_idx++) { - int32_t b = bh_idx / num_heads; - int32_t head_idx = bh_idx % num_heads; - - size_t output_base_offset = - (size_t)b * output_stride_3 + head_idx * channels; - int32_t attn_weight_base_offset = - b * input_stride_4 + head_idx * input_stride_2; - - __sync_cluster(); - - if (__is_ipu()) { - // compute weight, offset and condition - int32_t attn_weight_offset = - attn_weight_base_offset + core_begin_query * input_stride_3; - int32_t loc_offset = attn_weight_offset * 2; - if (first_deal_query > 0) { - __bang_write_value(loc_nram, 2 * pad_total_points, 0); - __bang_write_value(weight_attn_nram, pad_total_points, 0); - __sync_compute(); - __memcpy_async(loc_nram, (T*)data_sampling_loc_gdram + loc_offset, - input_stride_2 * 2 * sizeof(T), GDRAM2NRAM, - pad_num_levels_points * 2 * sizeof(T), - input_stride_3 * 2 * sizeof(T), first_deal_query - 1); - __memcpy_async(weight_attn_nram, - (T*)data_attn_weight_gdram + attn_weight_offset, - input_stride_2 * sizeof(T), GDRAM2NRAM, - pad_num_levels_points * sizeof(T), - input_stride_3 * sizeof(T), first_deal_query - 1); - getConditionCoordWeight( - data_offset_nram, weight_polation_nram, cond_point_polation_nram, - cond_point_valid_nram, loc_nram, weight_attn_nram, - spatial_offset_bd_nram, spatial_w_bd_nram, spatial_h_bd_nram, - buf_nram, first_deal_query, num_levels, num_points, num_heads, - pad_num_levels_points); - } +__mlu_func__ void computeSramCacheSizeAndOffset(int32_t *sram_level_cache_size, + int32_t *sram_level_start_index, + int32_t *sram_level_start_offset, + const int32_t num_levels, + const int32_t num_keys, + const int32_t channels, + const T *data_level_start_index_gdram, + const int32_t sram_size){ + for(int32_t level_id = num_levels;level_id > 0; level_id--){ + int current_level_end_index = level_id == num_levels ? num_keys : ((int32_t*)data_level_start_index_gdram)[level_id]; + int32_t current_level_size = current_level_end_index - ((int32_t*)data_level_start_index_gdram)[level_id - 1]; + if((*sram_level_cache_size + current_level_size) * channels * sizeof(T) > sram_size){ + break; } - - for (int32_t i = 0; __is_ipu() && i < core_loop_num; i++) { - __bang_write_value(loc_nram, 2 * pad_total_points, 0); - __bang_write_value(weight_attn_nram, pad_total_points, 0); - int32_t deal_n = - i < core_loop_num - 1 ? core_step_query : core_remain_query; - int32_t load_n = - i < core_loop_num - 2 ? core_step_query : core_remain_query; - // load value and polation - loadNeighborPolationAttn( - value_output_nram, - (T*)data_value_gdram + b * data_value_stride_3 + head_idx * channels, - data_offset_nram, weight_polation_nram, cond_point_polation_nram, - cond_point_valid_nram, weight_attn_nram, buf_nram, buf_compute_nram, - deal_n, num_levels, num_points, num_keys, channels, num_heads, - pad_channels, pad_num_levels_points); - __sync_io_move_compute(); - // load next weight and loc - if (i < core_loop_num - 1) { - int32_t core_query_offset = (i + 1) * core_step_query; - int32_t attn_weight_offset = - attn_weight_base_offset + - (core_begin_query + core_query_offset) * input_stride_3; - int32_t loc_offset = attn_weight_offset * 2; - __memcpy_async(loc_nram, (T*)data_sampling_loc_gdram + loc_offset, - input_stride_2 * 2 * sizeof(T), GDRAM2NRAM, - pad_num_levels_points * 2 * sizeof(T), - input_stride_3 * 2 * sizeof(T), load_n - 1); - __memcpy_async(weight_attn_nram, - (T*)data_attn_weight_gdram + attn_weight_offset, - input_stride_2 * sizeof(T), GDRAM2NRAM, - pad_num_levels_points * sizeof(T), - input_stride_3 * sizeof(T), load_n - 1); - __sync_io_move_compute(); - } - // store result - size_t output_offset = - ((size_t)core_begin_query + i * core_step_query) * output_stride_2; - __memcpy_async((T*)data_col_gdram + output_base_offset + output_offset, - value_output_nram, channels * sizeof(T), NRAM2GDRAM, - output_stride_2 * sizeof(T), pad_channels * sizeof(T), - deal_n - 1); - - // compute cond/weight/offset - if (i < core_loop_num - 1) { - getConditionCoordWeight( - data_offset_nram, weight_polation_nram, cond_point_polation_nram, - cond_point_valid_nram, loc_nram, weight_attn_nram, - spatial_offset_bd_nram, spatial_w_bd_nram, spatial_h_bd_nram, - buf_nram, load_n, num_levels, num_points, num_heads, - pad_num_levels_points); - } - __sync_io_move_compute(); - } - __sync_cluster(); + *sram_level_cache_size += current_level_size; + *sram_level_start_index = level_id - 1; + *sram_level_start_offset = num_keys - *sram_level_cache_size; } } + template + __mlu_func__ void MLUKernelMsDeformAttnForwardFastImpl( + const char* data_value_gdram, const char* data_spatial_shapes_gdram, + const char* data_level_start_index_gdram, + const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram, + const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_queries, + const int32_t num_points, char* data_col_gdram) { + int32_t input_stride_4 = num_queries * num_heads * num_levels * num_points; + int32_t input_stride_3 = num_heads * num_levels * num_points; + int32_t input_stride_2 = num_levels * num_points; + int32_t output_stride_3 = num_queries * num_heads * channels; + int32_t output_stride_2 = num_heads * channels; + int32_t data_value_stride_3 = num_keys * num_heads * channels; + // constexpr bool sram_stay = (POLICY == 0); + + T* value_output_nram = nullptr; // (deal_n, channels) + int32_t* data_offset_nram = nullptr; // (4, deal_n, num_levels, num_points) + T* weight_polation_nram = nullptr; // (4, deal_n, num_levels, num_points) + T* cond_point_polation_nram = nullptr; // (4, deal_n, num_levels, num_points) + T* cond_point_valid_nram = nullptr; // (deal_n, num_levels, num_points) + T* loc_nram = nullptr; // (deal_n, num_levels, num_points, 2) + T* weight_attn_nram = nullptr; // (deal_n, num_levels, num_points) + T* buf_nram = nullptr; // (6, deal_n, num_levels, num_points) + T* buf_nram_end = nullptr; + T* spatial_offset_bd_nram = nullptr; // (num_levels, num_points) + T* spatial_w_bd_nram = nullptr; // (num_levels, num_points) + T* spatial_h_bd_nram = nullptr; // (num_levels, num_points) + int32_t* spatial_offset_nram = nullptr; // (num_levels) + int32_t* spatial_hw_nram = nullptr; // (num_levels, 2) + T* buf_compute_nram = nullptr; // (8, num_levels, num_points) + int32_t max_deal_n = 0; + int32_t pad_channels = 0; + int32_t pad_num_levels_points = 0; + int32_t pad_total_points = 0; + T* value_sram = nullptr; // (num_keys, channels) + + memPolicyCommon(buf_compute_nram, value_output_nram, data_offset_nram, + weight_polation_nram, cond_point_polation_nram, + cond_point_valid_nram, loc_nram, weight_attn_nram, buf_nram, + buf_nram_end, spatial_offset_bd_nram, spatial_w_bd_nram, + spatial_h_bd_nram, value_sram, spatial_offset_nram, spatial_hw_nram, + max_deal_n, pad_channels, pad_num_levels_points, + pad_total_points, batch_size, num_keys, num_heads, channels, + num_levels, num_queries, num_points); + if (max_deal_n <= 0) { + return; + } + + // split batch*head into taskDimY + int32_t batch_head = batch_size * num_heads; + int32_t cluster_avg_batch_head = (batch_head + taskDimY - 1) / taskDimY; + int32_t cluster_begin_batch_head = taskIdY * cluster_avg_batch_head; + int32_t cluster_act_batch_head = __mluop_min( + cluster_avg_batch_head, batch_head - cluster_begin_batch_head); + int32_t cluster_end_batch_head = + cluster_begin_batch_head + cluster_act_batch_head; + // split query into coreDim + int32_t core_avg_query = (num_queries + coreDim - 1) / coreDim; + int32_t core_begin_query = coreId * core_avg_query; + int32_t core_act_query = + __mluop_min(num_queries - core_begin_query, core_avg_query); + int32_t core_loop_num = (core_act_query + max_deal_n - 1) / max_deal_n; + int32_t core_step_query = + core_loop_num > 0 ? (core_act_query + core_loop_num - 1) / core_loop_num + : 0; + int32_t core_remain_query = + core_act_query - (core_loop_num - 1) * core_step_query; + int32_t first_deal_query = + (int)(core_loop_num > 0) * + (core_loop_num > 1 ? core_step_query : core_remain_query); + + prepareLoop(spatial_offset_nram, spatial_hw_nram, spatial_offset_bd_nram, + spatial_h_bd_nram, spatial_w_bd_nram, + data_level_start_index_gdram, data_spatial_shapes_gdram, num_keys, + num_levels, num_points, max_deal_n, channels); + + int sram_total_size = 0; + int sram_level_start_index = num_levels; + int sram_level_start_offset = 0; + computeSramCacheSizeAndOffset(&sram_total_size, &sram_level_start_index, + &sram_level_start_offset, num_levels, + num_keys, channels, (int32_t *)data_level_start_index_gdram, SRAM_FOR_VALUE_SIZE); + + for (int32_t bh_idx = cluster_begin_batch_head; + bh_idx < cluster_end_batch_head; bh_idx++) { + int32_t b = bh_idx / num_heads; + int32_t head_idx = bh_idx % num_heads; + + size_t output_base_offset = + (size_t)b * output_stride_3 + head_idx * channels; + int32_t attn_weight_base_offset = + b * input_stride_4 + head_idx * input_stride_2; + + if(__is_mpu() && (sram_level_start_index != num_levels)) + { + loadDataValueGdram2Sram(value_sram, (T*)data_value_gdram, b, head_idx, + sram_total_size, num_heads, channels, sram_level_start_offset); + } + + __sync_cluster(); + + if (__is_ipu()) { + // compute weight, offset and condition + int32_t attn_weight_offset = + attn_weight_base_offset + core_begin_query * input_stride_3; + int32_t loc_offset = attn_weight_offset * 2; + if (first_deal_query > 0) { + __bang_write_value(loc_nram, 2 * pad_total_points, 0); + __bang_write_value(weight_attn_nram, pad_total_points, 0); + __sync_compute(); + __memcpy_async(loc_nram, (T*)data_sampling_loc_gdram + loc_offset, + input_stride_2 * 2 * sizeof(T), GDRAM2NRAM, + pad_num_levels_points * 2 * sizeof(T), + input_stride_3 * 2 * sizeof(T), first_deal_query - 1); + __memcpy_async(weight_attn_nram, + (T*)data_attn_weight_gdram + attn_weight_offset, + input_stride_2 * sizeof(T), GDRAM2NRAM, + pad_num_levels_points * sizeof(T), + input_stride_3 * sizeof(T), first_deal_query - 1); + getConditionCoordWeight( + data_offset_nram, weight_polation_nram, cond_point_polation_nram, + cond_point_valid_nram, loc_nram, weight_attn_nram, + spatial_offset_bd_nram, spatial_w_bd_nram, spatial_h_bd_nram, + buf_nram, first_deal_query, num_levels, num_points, num_heads, + pad_num_levels_points); + } + } + + for (int32_t i = 0; __is_ipu() && i < core_loop_num; i++) { + __bang_write_value(loc_nram, 2 * pad_total_points, 0); + __bang_write_value(weight_attn_nram, pad_total_points, 0); + int32_t deal_n = + i < core_loop_num - 1 ? core_step_query : core_remain_query; + int32_t load_n = + i < core_loop_num - 2 ? core_step_query : core_remain_query; + // load value and polation + loadNeighborPolationAttn( + value_output_nram, + (T*)data_value_gdram + b * data_value_stride_3 + head_idx * channels, + data_offset_nram, weight_polation_nram, cond_point_polation_nram, + cond_point_valid_nram, weight_attn_nram, buf_nram, buf_compute_nram, + deal_n, num_levels, num_points, num_keys, channels, num_heads, + pad_channels, pad_num_levels_points, value_sram, sram_level_start_index, sram_level_start_offset); + __sync_io_move_compute(); + // load next weight and loc + if (i < core_loop_num - 1) { + int32_t core_query_offset = (i + 1) * core_step_query; + int32_t attn_weight_offset = + attn_weight_base_offset + + (core_begin_query + core_query_offset) * input_stride_3; + int32_t loc_offset = attn_weight_offset * 2; + __memcpy_async(loc_nram, (T*)data_sampling_loc_gdram + loc_offset, + input_stride_2 * 2 * sizeof(T), GDRAM2NRAM, + pad_num_levels_points * 2 * sizeof(T), + input_stride_3 * 2 * sizeof(T), load_n - 1); + __memcpy_async(weight_attn_nram, + (T*)data_attn_weight_gdram + attn_weight_offset, + input_stride_2 * sizeof(T), GDRAM2NRAM, + pad_num_levels_points * sizeof(T), + input_stride_3 * sizeof(T), load_n - 1); + __sync_io_move_compute(); + } + // store result + size_t output_offset = + ((size_t)core_begin_query + i * core_step_query) * output_stride_2; + __memcpy_async((T*)data_col_gdram + output_base_offset + output_offset, + value_output_nram, channels * sizeof(T), NRAM2GDRAM, + output_stride_2 * sizeof(T), pad_channels * sizeof(T), + deal_n - 1); + + // compute cond/weight/offset + if (i < core_loop_num - 1) { + getConditionCoordWeight( + data_offset_nram, weight_polation_nram, cond_point_polation_nram, + cond_point_valid_nram, loc_nram, weight_attn_nram, + spatial_offset_bd_nram, spatial_w_bd_nram, spatial_h_bd_nram, + buf_nram, load_n, num_levels, num_points, num_heads, + pad_num_levels_points); + } + __sync_io_move_compute(); + } + __sync_cluster(); + } + } + template __mlu_global__ void MLUKernelMsDeformAttnForwardFast( const char* data_value_gdram, const char* data_spatial_shapes_gdram,