diff --git a/mmcv/ops/csrc/cc_attention_cuda_kernel.cuh b/mmcv/ops/csrc/cc_attention_cuda_kernel.cuh index 019e4110e3..15e07d1970 100644 --- a/mmcv/ops/csrc/cc_attention_cuda_kernel.cuh +++ b/mmcv/ops/csrc/cc_attention_cuda_kernel.cuh @@ -18,19 +18,13 @@ __global__ void ca_forward_kernel(const T *t, const T *f, T *weight, int num, int batch = blockIdx.z / len; if (x < width && y < height) { + T *weight_ptr = weight + (batch * len + z) * sp + y * width + x; + const int t_offset = y * width + x; + const int j = (z - width < y) ? z - width : z - width + 1; + const int f_offset = z < width ? y * width + z : j * width + x; for (int plane = 0; plane < chn; ++plane) { - T _t = t[(batch * chn + plane) * sp + y * width + x]; - - if (z < width) { - int i = z; - T _f = f[(batch * chn + plane) * sp + y * width + i]; - weight[(batch * len + i) * sp + y * width + x] += _t * _f; - } else { - int i = z - width; - int j = i < y ? i : i + 1; - T _f = f[(batch * chn + plane) * sp + j * width + x]; - weight[(batch * len + width + i) * sp + y * width + x] += _t * _f; - } + const int tf_base = (batch * chn + plane) * sp; + *weight_ptr += t[tf_base + t_offset] * f[tf_base + f_offset]; } } }