Skip to content

Commit

Permalink
better ca_forward_kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
grimoire committed Jun 24, 2021
1 parent d0e6c5d commit bc3d033
Showing 1 changed file with 6 additions and 12 deletions.
18 changes: 6 additions & 12 deletions mmcv/ops/csrc/cc_attention_cuda_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,13 @@ __global__ void ca_forward_kernel(const T *t, const T *f, T *weight, int num,
int batch = blockIdx.z / len;

if (x < width && y < height) {
T *weight_ptr = weight + (batch * len + z) * sp + y * width + x;
const int t_offset = y * width + x;
const int j = (z - width < y) ? z - width : z - width + 1;
const int f_offset = z < width ? y * width + z : j * width + x;
for (int plane = 0; plane < chn; ++plane) {
T _t = t[(batch * chn + plane) * sp + y * width + x];

if (z < width) {
int i = z;
T _f = f[(batch * chn + plane) * sp + y * width + i];
weight[(batch * len + i) * sp + y * width + x] += _t * _f;
} else {
int i = z - width;
int j = i < y ? i : i + 1;
T _f = f[(batch * chn + plane) * sp + j * width + x];
weight[(batch * len + width + i) * sp + y * width + x] += _t * _f;
}
const int tf_base = (batch * chn + plane) * sp;
*weight_ptr += t[tf_base + t_offset] * f[tf_base + f_offset];
}
}
}
Expand Down

0 comments on commit bc3d033

Please sign in to comment.