Skip to content

Commit

Permalink
[Feature] Optimize the PyTorch CUDA implementation for Criss Cross At…
Browse files Browse the repository at this point in the history
…tention (#1143)

* optimize criss cross attention

* optimize criss cross attention

* optimize criss cross attention

* fix lint

* fix ci, remove useless variable

* better ca_forward_kernel

Co-authored-by: wondervictor <victorchanchina@gmail.com>
  • Loading branch information
q.yao and wondervictor committed Jun 25, 2021
1 parent 6fe3722 commit 7b150fa
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 115 deletions.
198 changes: 90 additions & 108 deletions mmcv/ops/csrc/cc_attention_cuda_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,17 @@ __global__ void ca_forward_kernel(const T *t, const T *f, T *weight, int num,
int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width;
int len = height + width - 1;
int z = blockIdx.z;

if (x < width && y < height && z < height + width - 1) {
for (int batch = 0; batch < num; ++batch) {
for (int plane = 0; plane < chn; ++plane) {
T _t = t[(batch * chn + plane) * sp + y * width + x];

if (z < width) {
int i = z;
T _f = f[(batch * chn + plane) * sp + y * width + i];
weight[(batch * len + i) * sp + y * width + x] += _t * _f;
} else {
int i = z - width;
int j = i < y ? i : i + 1;

T _f = f[(batch * chn + plane) * sp + j * width + x];
weight[(batch * len + width + i) * sp + y * width + x] += _t * _f;
}
}
int z = blockIdx.z % len;
int batch = blockIdx.z / len;

if (x < width && y < height) {
T *weight_ptr = weight + (batch * len + z) * sp + y * width + x;
const int t_offset = y * width + x;
const int j = (z - width < y) ? z - width : z - width + 1;
const int f_offset = z < width ? y * width + z : j * width + x;
for (int plane = 0; plane < chn; ++plane) {
const int tf_base = (batch * chn + plane) * sp;
*weight_ptr += t[tf_base + t_offset] * f[tf_base + f_offset];
}
}
}
Expand All @@ -44,23 +36,22 @@ __global__ void ca_backward_kernel_t(const T *dw, const T *t, const T *f, T *dt,
int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width;
int len = height + width - 1;
int plane = blockIdx.z;

if (x < width && y < height && plane < chn) {
for (int batch = 0; batch < num; ++batch) {
for (int i = 0; i < width; ++i) {
T _dw = dw[(batch * len + i) * sp + y * width + x];
T _f = f[(batch * chn + plane) * sp + y * width + i];
dt[(batch * chn + plane) * sp + y * width + x] += _dw * _f;
}
for (int i = 0; i < height; ++i) {
if (i == y) continue;
int j = i < y ? i : i - 1;

T _dw = dw[(batch * len + width + j) * sp + y * width + x];
T _f = f[(batch * chn + plane) * sp + i * width + x];
dt[(batch * chn + plane) * sp + y * width + x] += _dw * _f;
}
int plane = blockIdx.z % chn;
int batch = blockIdx.z / chn;

if (x < width && y < height) {
for (int i = 0; i < width; ++i) {
T _dw = dw[(batch * len + i) * sp + y * width + x];
T _f = f[(batch * chn + plane) * sp + y * width + i];
dt[(batch * chn + plane) * sp + y * width + x] += _dw * _f;
}
for (int i = 0; i < height; ++i) {
if (i == y) continue;
int j = i < y ? i : i - 1;

T _dw = dw[(batch * len + width + j) * sp + y * width + x];
T _f = f[(batch * chn + plane) * sp + i * width + x];
dt[(batch * chn + plane) * sp + y * width + x] += _dw * _f;
}
}
}
Expand All @@ -72,23 +63,22 @@ __global__ void ca_backward_kernel_f(const T *dw, const T *t, const T *f, T *df,
int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width;
int len = height + width - 1;
int plane = blockIdx.z;

if (x < width && y < height && plane < chn) {
for (int batch = 0; batch < num; ++batch) {
for (int i = 0; i < width; ++i) {
T _dw = dw[(batch * len + x) * sp + y * width + i];
T _t = t[(batch * chn + plane) * sp + y * width + i];
df[(batch * chn + plane) * sp + y * width + x] += _dw * _t;
}
for (int i = 0; i < height; ++i) {
if (i == y) continue;
int j = i > y ? y : y - 1;

T _dw = dw[(batch * len + width + j) * sp + i * width + x];
T _t = t[(batch * chn + plane) * sp + i * width + x];
df[(batch * chn + plane) * sp + y * width + x] += _dw * _t;
}
int plane = blockIdx.z % chn;
int batch = blockIdx.z / chn;

if (x < width && y < height) {
for (int i = 0; i < width; ++i) {
T _dw = dw[(batch * len + x) * sp + y * width + i];
T _t = t[(batch * chn + plane) * sp + y * width + i];
df[(batch * chn + plane) * sp + y * width + x] += _dw * _t;
}
for (int i = 0; i < height; ++i) {
if (i == y) continue;
int j = i > y ? y : y - 1;

T _dw = dw[(batch * len + width + j) * sp + i * width + x];
T _t = t[(batch * chn + plane) * sp + i * width + x];
df[(batch * chn + plane) * sp + y * width + x] += _dw * _t;
}
}
}
Expand All @@ -100,24 +90,22 @@ __global__ void ca_map_forward_kernel(const T *weight, const T *g, T *out,
int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width;
int len = height + width - 1;
int plane = blockIdx.z;

if (x < width && y < height && plane < chn) {
for (int batch = 0; batch < num; ++batch) {
for (int i = 0; i < width; ++i) {
T _g = g[(batch * chn + plane) * sp + y * width + i];
T _w = weight[(batch * len + i) * sp + y * width + x];
out[(batch * chn + plane) * sp + y * width + x] += _g * _w;
}
for (int i = 0; i < height; ++i) {
if (i == y) continue;

int j = i < y ? i : i - 1;

T _g = g[(batch * chn + plane) * sp + i * width + x];
T _w = weight[(batch * len + width + j) * sp + y * width + x];
out[(batch * chn + plane) * sp + y * width + x] += _g * _w;
}
int plane = blockIdx.z % chn;
int batch = blockIdx.z / chn;
if (x < width && y < height) {
for (int i = 0; i < width; ++i) {
T _g = g[(batch * chn + plane) * sp + y * width + i];
T _w = weight[(batch * len + i) * sp + y * width + x];
out[(batch * chn + plane) * sp + y * width + x] += _g * _w;
}
for (int i = 0; i < height; ++i) {
if (i == y) continue;

int j = i < y ? i : i - 1;

T _g = g[(batch * chn + plane) * sp + i * width + x];
T _w = weight[(batch * len + width + j) * sp + y * width + x];
out[(batch * chn + plane) * sp + y * width + x] += _g * _w;
}
}
}
Expand All @@ -130,25 +118,23 @@ __global__ void ca_map_backward_kernel_w(const T *dout, const T *weight,
int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width;
int len = height + width - 1;
int z = blockIdx.z;

if (x < width && y < height && z < height + width - 1) {
for (int batch = 0; batch < num; ++batch) {
for (int plane = 0; plane < chn; ++plane) {
T _dout = dout[(batch * chn + plane) * sp + y * width + x];

if (z < width) {
int i = z;
T _g = g[(batch * chn + plane) * sp + y * width + i];
dw[(batch * len + i) * sp + y * width + x] += _dout * _g;
} else {
int i = z - width;
int j = i < y ? i : i + 1;

T _g = g[(batch * chn + plane) * sp + j * width + x];
dw[(batch * len + width + i) * sp + y * width + x] += _dout * _g;
}
}

int z = blockIdx.z % len;
int batch = blockIdx.z / len;

if (x < width && y < height) {
int widx = (batch * len + z) * sp + y * width + x;
int dout_idx = batch * chn * sp + y * width + x;
int gidx = batch * chn * sp;
if (z < width) {
gidx += y * width + z;
} else {
int j = z - width;
j = j < y ? j : j + 1;
gidx += j * width + x;
}
for (int plane = 0; plane < chn; plane++) {
dw[widx] += dout[dout_idx + plane * sp] * g[gidx + plane * sp];
}
}
}
Expand All @@ -161,25 +147,21 @@ __global__ void ca_map_backward_kernel_g(const T *dout, const T *weight,
int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width;
int len = height + width - 1;
int plane = blockIdx.z;

if (x < width && y < height && plane < chn) {
for (int batch = 0; batch < num; ++batch) {
for (int i = 0; i < width; ++i) {
T _dout = dout[(batch * chn + plane) * sp + y * width + i];
T _w = weight[(batch * len + x) * sp + y * width + i];
dg[(batch * chn + plane) * sp + y * width + x] += _dout * _w;
}
for (int i = 0; i < height; ++i) {
if (i == y) continue;
int j = i > y ? y : y - 1;

T _dout = dout[(batch * chn + plane) * sp + i * width + x];
T _w = weight[(batch * len + width + j) * sp + i * width + x];
dg[(batch * chn + plane) * sp + y * width + x] += _dout * _w;
}
int plane = blockIdx.z % chn;
int batch = blockIdx.z / chn;
int index = (batch * chn + plane) * sp + y * width + x;

if (x < width && y < height) {
for (int i = 0; i < width; ++i) {
dg[index] += dout[(batch * chn + plane) * sp + y * width + i] *
weight[(batch * len + x) * sp + y * width + i];
}
for (int i = 0; i < height; ++i) {
if (i == y) continue;
int j = i > y ? y : y - 1;
dg[index] += dout[(batch * chn + plane) * sp + i * width + x] *
weight[(batch * len + width + j) * sp + i * width + x];
}
}
}

#endif // CC_ATTENTION_CUDA_KERNEL_CUH
15 changes: 8 additions & 7 deletions mmcv/ops/csrc/pytorch/cc_attention_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ void CAForwardCUDAKernelLauncher(const Tensor t, const Tensor f,
dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y;
int d3 = h + w;
dim3 blocks(d1, d2, d3);
int d3 = h + w - 1;
dim3 blocks(d1, d2, d3 * n);

AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "ca_forward", [&] {
ca_forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
Expand Down Expand Up @@ -53,7 +53,7 @@ void CABackwardCUDAKernelLauncher(const Tensor dw, const Tensor t,
dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y;
int d3 = c;
int d3 = c * n;
dim3 blocks(d1, d2, d3);

AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "ca_backward_kernel_t", [&] {
Expand Down Expand Up @@ -90,7 +90,7 @@ void CAMapForwardCUDAKernelLauncher(const Tensor weight, const Tensor g,
dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y;
int d3 = c;
int d3 = c * n;
dim3 blocks(d1, d2, d3);

AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "ca_map_forward", [&] {
Expand Down Expand Up @@ -119,8 +119,8 @@ void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight,
dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y;
int d3 = h + w;
dim3 blocks(d1, d2, d3);
int d3 = h + w - 1;
dim3 blocks(d1, d2, d3 * n);

AT_DISPATCH_FLOATING_TYPES(
weight.scalar_type(), "ca_map_backward_kernel_w", [&] {
Expand All @@ -130,7 +130,8 @@ void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight,
g.contiguous().data_ptr<scalar_t>(),
dw.contiguous().data_ptr<scalar_t>(), n, c, h, w);
});

d3 = c * n;
blocks = dim3(d1, d2, d3);
AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "ca_map_backward_kernel_g", [&] {
ca_map_backward_kernel_g<scalar_t><<<blocks, threads, 0, stream>>>(
dout.contiguous().data_ptr<scalar_t>(),
Expand Down

0 comments on commit 7b150fa

Please sign in to comment.