diff --git a/mmcv/ops/csrc/cc_attention_cuda_kernel.cuh b/mmcv/ops/csrc/cc_attention_cuda_kernel.cuh
index 0dd9c33c66..15e07d1970 100644
--- a/mmcv/ops/csrc/cc_attention_cuda_kernel.cuh
+++ b/mmcv/ops/csrc/cc_attention_cuda_kernel.cuh
@@ -14,25 +14,17 @@ __global__ void ca_forward_kernel(const T *t, const T *f, T *weight, int num,
   int y = blockIdx.y * blockDim.y + threadIdx.y;
   int sp = height * width;
   int len = height + width - 1;
-  int z = blockIdx.z;
-
-  if (x < width && y < height && z < height + width - 1) {
-    for (int batch = 0; batch < num; ++batch) {
-      for (int plane = 0; plane < chn; ++plane) {
-        T _t = t[(batch * chn + plane) * sp + y * width + x];
-
-        if (z < width) {
-          int i = z;
-          T _f = f[(batch * chn + plane) * sp + y * width + i];
-          weight[(batch * len + i) * sp + y * width + x] += _t * _f;
-        } else {
-          int i = z - width;
-          int j = i < y ? i : i + 1;
-
-          T _f = f[(batch * chn + plane) * sp + j * width + x];
-          weight[(batch * len + width + i) * sp + y * width + x] += _t * _f;
-        }
-      }
+  int z = blockIdx.z % len;
+  int batch = blockIdx.z / len;
+
+  if (x < width && y < height) {
+    T *weight_ptr = weight + (batch * len + z) * sp + y * width + x;
+    const int t_offset = y * width + x;
+    const int j = (z - width < y) ? z - width : z - width + 1;
+    const int f_offset = z < width ? y * width + z : j * width + x;
+    for (int plane = 0; plane < chn; ++plane) {
+      const int tf_base = (batch * chn + plane) * sp;
+      *weight_ptr += t[tf_base + t_offset] * f[tf_base + f_offset];
     }
   }
 }
@@ -44,23 +36,22 @@ __global__ void ca_backward_kernel_t(const T *dw, const T *t, const T *f, T *dt,
   int y = blockIdx.y * blockDim.y + threadIdx.y;
   int sp = height * width;
   int len = height + width - 1;
-  int plane = blockIdx.z;
-
-  if (x < width && y < height && plane < chn) {
-    for (int batch = 0; batch < num; ++batch) {
-      for (int i = 0; i < width; ++i) {
-        T _dw = dw[(batch * len + i) * sp + y * width + x];
-        T _f = f[(batch * chn + plane) * sp + y * width + i];
-        dt[(batch * chn + plane) * sp + y * width + x] += _dw * _f;
-      }
-      for (int i = 0; i < height; ++i) {
-        if (i == y) continue;
-        int j = i < y ? i : i - 1;
-
-        T _dw = dw[(batch * len + width + j) * sp + y * width + x];
-        T _f = f[(batch * chn + plane) * sp + i * width + x];
-        dt[(batch * chn + plane) * sp + y * width + x] += _dw * _f;
-      }
+  int plane = blockIdx.z % chn;
+  int batch = blockIdx.z / chn;
+
+  if (x < width && y < height) {
+    for (int i = 0; i < width; ++i) {
+      T _dw = dw[(batch * len + i) * sp + y * width + x];
+      T _f = f[(batch * chn + plane) * sp + y * width + i];
+      dt[(batch * chn + plane) * sp + y * width + x] += _dw * _f;
+    }
+    for (int i = 0; i < height; ++i) {
+      if (i == y) continue;
+      int j = i < y ? i : i - 1;
+
+      T _dw = dw[(batch * len + width + j) * sp + y * width + x];
+      T _f = f[(batch * chn + plane) * sp + i * width + x];
+      dt[(batch * chn + plane) * sp + y * width + x] += _dw * _f;
     }
   }
 }
@@ -72,23 +63,22 @@ __global__ void ca_backward_kernel_f(const T *dw, const T *t, const T *f, T *df,
   int y = blockIdx.y * blockDim.y + threadIdx.y;
   int sp = height * width;
   int len = height + width - 1;
-  int plane = blockIdx.z;
-
-  if (x < width && y < height && plane < chn) {
-    for (int batch = 0; batch < num; ++batch) {
-      for (int i = 0; i < width; ++i) {
-        T _dw = dw[(batch * len + x) * sp + y * width + i];
-        T _t = t[(batch * chn + plane) * sp + y * width + i];
-        df[(batch * chn + plane) * sp + y * width + x] += _dw * _t;
-      }
-      for (int i = 0; i < height; ++i) {
-        if (i == y) continue;
-        int j = i > y ? y : y - 1;
-
-        T _dw = dw[(batch * len + width + j) * sp + i * width + x];
-        T _t = t[(batch * chn + plane) * sp + i * width + x];
-        df[(batch * chn + plane) * sp + y * width + x] += _dw * _t;
-      }
+  int plane = blockIdx.z % chn;
+  int batch = blockIdx.z / chn;
+
+  if (x < width && y < height) {
+    for (int i = 0; i < width; ++i) {
+      T _dw = dw[(batch * len + x) * sp + y * width + i];
+      T _t = t[(batch * chn + plane) * sp + y * width + i];
+      df[(batch * chn + plane) * sp + y * width + x] += _dw * _t;
+    }
+    for (int i = 0; i < height; ++i) {
+      if (i == y) continue;
+      int j = i > y ? y : y - 1;
+
+      T _dw = dw[(batch * len + width + j) * sp + i * width + x];
+      T _t = t[(batch * chn + plane) * sp + i * width + x];
+      df[(batch * chn + plane) * sp + y * width + x] += _dw * _t;
     }
   }
 }
@@ -100,24 +90,22 @@ __global__ void ca_map_forward_kernel(const T *weight, const T *g, T *out,
   int y = blockIdx.y * blockDim.y + threadIdx.y;
   int sp = height * width;
   int len = height + width - 1;
-  int plane = blockIdx.z;
-
-  if (x < width && y < height && plane < chn) {
-    for (int batch = 0; batch < num; ++batch) {
-      for (int i = 0; i < width; ++i) {
-        T _g = g[(batch * chn + plane) * sp + y * width + i];
-        T _w = weight[(batch * len + i) * sp + y * width + x];
-        out[(batch * chn + plane) * sp + y * width + x] += _g * _w;
-      }
-      for (int i = 0; i < height; ++i) {
-        if (i == y) continue;
-
-        int j = i < y ? i : i - 1;
-
-        T _g = g[(batch * chn + plane) * sp + i * width + x];
-        T _w = weight[(batch * len + width + j) * sp + y * width + x];
-        out[(batch * chn + plane) * sp + y * width + x] += _g * _w;
-      }
+  int plane = blockIdx.z % chn;
+  int batch = blockIdx.z / chn;
+  if (x < width && y < height) {
+    for (int i = 0; i < width; ++i) {
+      T _g = g[(batch * chn + plane) * sp + y * width + i];
+      T _w = weight[(batch * len + i) * sp + y * width + x];
+      out[(batch * chn + plane) * sp + y * width + x] += _g * _w;
+    }
+    for (int i = 0; i < height; ++i) {
+      if (i == y) continue;
+
+      int j = i < y ? i : i - 1;
+
+      T _g = g[(batch * chn + plane) * sp + i * width + x];
+      T _w = weight[(batch * len + width + j) * sp + y * width + x];
+      out[(batch * chn + plane) * sp + y * width + x] += _g * _w;
     }
   }
 }
@@ -130,25 +118,23 @@ __global__ void ca_map_backward_kernel_w(const T *dout, const T *weight,
   int y = blockIdx.y * blockDim.y + threadIdx.y;
   int sp = height * width;
   int len = height + width - 1;
-  int z = blockIdx.z;
-
-  if (x < width && y < height && z < height + width - 1) {
-    for (int batch = 0; batch < num; ++batch) {
-      for (int plane = 0; plane < chn; ++plane) {
-        T _dout = dout[(batch * chn + plane) * sp + y * width + x];
-
-        if (z < width) {
-          int i = z;
-          T _g = g[(batch * chn + plane) * sp + y * width + i];
-          dw[(batch * len + i) * sp + y * width + x] += _dout * _g;
-        } else {
-          int i = z - width;
-          int j = i < y ? i : i + 1;
-
-          T _g = g[(batch * chn + plane) * sp + j * width + x];
-          dw[(batch * len + width + i) * sp + y * width + x] += _dout * _g;
-        }
-      }
+
+  int z = blockIdx.z % len;
+  int batch = blockIdx.z / len;
+
+  if (x < width && y < height) {
+    int widx = (batch * len + z) * sp + y * width + x;
+    int dout_idx = batch * chn * sp + y * width + x;
+    int gidx = batch * chn * sp;
+    if (z < width) {
+      gidx += y * width + z;
+    } else {
+      int j = z - width;
+      j = j < y ? j : j + 1;
+      gidx += j * width + x;
+    }
+    for (int plane = 0; plane < chn; plane++) {
+      dw[widx] += dout[dout_idx + plane * sp] * g[gidx + plane * sp];
     }
   }
 }
@@ -161,25 +147,21 @@ __global__ void ca_map_backward_kernel_g(const T *dout, const T *weight,
   int y = blockIdx.y * blockDim.y + threadIdx.y;
   int sp = height * width;
   int len = height + width - 1;
-  int plane = blockIdx.z;
-
-  if (x < width && y < height && plane < chn) {
-    for (int batch = 0; batch < num; ++batch) {
-      for (int i = 0; i < width; ++i) {
-        T _dout = dout[(batch * chn + plane) * sp + y * width + i];
-        T _w = weight[(batch * len + x) * sp + y * width + i];
-        dg[(batch * chn + plane) * sp + y * width + x] += _dout * _w;
-      }
-      for (int i = 0; i < height; ++i) {
-        if (i == y) continue;
-        int j = i > y ? y : y - 1;
-
-        T _dout = dout[(batch * chn + plane) * sp + i * width + x];
-        T _w = weight[(batch * len + width + j) * sp + i * width + x];
-        dg[(batch * chn + plane) * sp + y * width + x] += _dout * _w;
-      }
+  int plane = blockIdx.z % chn;
+  int batch = blockIdx.z / chn;
+  int index = (batch * chn + plane) * sp + y * width + x;
+
+  if (x < width && y < height) {
+    for (int i = 0; i < width; ++i) {
+      dg[index] += dout[(batch * chn + plane) * sp + y * width + i] *
+                   weight[(batch * len + x) * sp + y * width + i];
+    }
+    for (int i = 0; i < height; ++i) {
+      if (i == y) continue;
+      int j = i > y ? y : y - 1;
+      dg[index] += dout[(batch * chn + plane) * sp + i * width + x] *
+                   weight[(batch * len + width + j) * sp + i * width + x];
     }
   }
 }
-
 #endif  // CC_ATTENTION_CUDA_KERNEL_CUH
diff --git a/mmcv/ops/csrc/pytorch/cc_attention_cuda.cu b/mmcv/ops/csrc/pytorch/cc_attention_cuda.cu
index b948d5406a..fd4e7fd128 100644
--- a/mmcv/ops/csrc/pytorch/cc_attention_cuda.cu
+++ b/mmcv/ops/csrc/pytorch/cc_attention_cuda.cu
@@ -24,8 +24,8 @@ void CAForwardCUDAKernelLauncher(const Tensor t, const Tensor f,
   dim3 threads(32, 32);
   int d1 = (w + threads.x - 1) / threads.x;
   int d2 = (h + threads.y - 1) / threads.y;
-  int d3 = h + w;
-  dim3 blocks(d1, d2, d3);
+  int d3 = h + w - 1;
+  dim3 blocks(d1, d2, d3 * n);
 
   AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "ca_forward", [&] {
     ca_forward_kernel<scalar_t><<<blocks, threads>>>(
@@ -53,7 +53,7 @@ void CABackwardCUDAKernelLauncher(const Tensor dw, const Tensor t,
   dim3 threads(32, 32);
   int d1 = (w + threads.x - 1) / threads.x;
   int d2 = (h + threads.y - 1) / threads.y;
-  int d3 = c;
+  int d3 = c * n;
   dim3 blocks(d1, d2, d3);
 
   AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "ca_backward_kernel_t", [&] {
@@ -90,7 +90,7 @@ void CAMapForwardCUDAKernelLauncher(const Tensor weight, const Tensor g,
   dim3 threads(32, 32);
   int d1 = (w + threads.x - 1) / threads.x;
   int d2 = (h + threads.y - 1) / threads.y;
-  int d3 = c;
+  int d3 = c * n;
   dim3 blocks(d1, d2, d3);
 
   AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "ca_map_forward", [&] {
@@ -119,8 +119,8 @@ void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight,
   dim3 threads(32, 32);
   int d1 = (w + threads.x - 1) / threads.x;
   int d2 = (h + threads.y - 1) / threads.y;
-  int d3 = h + w;
-  dim3 blocks(d1, d2, d3);
+  int d3 = h + w - 1;
+  dim3 blocks(d1, d2, d3 * n);
 
   AT_DISPATCH_FLOATING_TYPES(
       weight.scalar_type(), "ca_map_backward_kernel_w", [&] {
@@ -130,7 +130,8 @@ void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight,
             g.contiguous().data_ptr<scalar_t>(),
             dw.contiguous().data_ptr<scalar_t>(), n, c, h, w);
       });
-
+  d3 = c * n;
+  blocks = dim3(d1, d2, d3);
   AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "ca_map_backward_kernel_g", [&] {
     ca_map_backward_kernel_g<scalar_t><<<blocks, threads>>>(
         dout.contiguous().data_ptr<scalar_t>(),