Reimplement cc_attention using pure pytorch (#1201)
* Reimplement cc_attention using pure pytorch

* fix: avoid BC-Breaking

* delete cc_attention related cpp and cuda files

* delete cc_attention related lines in pybind.cpp

* make out Tensor contiguous.

* remove unneeded lines.

* Update mmcv/ops/cc_attention.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update TestCrissCrossAttention

* passing pre-commit

* Update docstring of CrissCrossAttention

* Update docstring of CrissCrossAttention

* Update mmcv/ops/cc_attention.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* [docs]polish the docstring

* [Docs] Polish the docstring

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Leojc and zhouzaida committed Sep 9, 2021
1 parent 642d281 commit 2a3d2d4
Showing 9 changed files with 53 additions and 695 deletions.
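
The rewrite keeps the module's public interface intact, so existing configs and checkpoints keep working (the "avoid BC-Breaking" commit above). A minimal usage sketch, assuming only the signature visible in the diff below:

    import torch
    from mmcv.ops import CrissCrossAttention

    # one constructor argument, as before: the input channel count
    cca = CrissCrossAttention(in_channels=512)
    x = torch.randn(2, 512, 97, 97)  # the [2,512,97,97] benchmark shape
    out = cca(x)
    assert out.shape == x.shape  # output keeps (batch, channels, H, W)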
112 changes: 49 additions & 63 deletions mmcv/ops/cc_attention.py
@@ -2,92 +2,78 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.autograd.function import once_differentiable
 
 from mmcv.cnn import PLUGIN_LAYERS, Scale
-from ..utils import ext_loader
 
-ext_module = ext_loader.load_ext(
-    '_ext', ['ca_forward', 'ca_backward', 'ca_map_forward', 'ca_map_backward'])
-
-
-class CAWeightFunction(torch.autograd.Function):
-
-    @staticmethod
-    def symbolic(g, t, f):
-        return g.op('mmcv::MMCVCAWeight', t, f)
-
-    @staticmethod
-    def forward(ctx, t, f):
-        n, c, h, w = t.size()
-        weight = torch.zeros(n, h + w - 1, h, w).to(t.device)
-        ext_module.ca_forward(t, f, weight)
-
-        ctx.save_for_backward(t, f)
-
-        return weight
-
-    @staticmethod
-    @once_differentiable
-    def backward(ctx, dw):
-        t, f = ctx.saved_tensors
-        dt = torch.zeros_like(t)
-        df = torch.zeros_like(f)
-        ext_module.ca_backward(dw, t, f, dt, df)
-        return dt, df
-
-
-class CAMapFunction(torch.autograd.Function):
-
-    @staticmethod
-    def symbolic(g, weight, v):
-        return g.op('mmcv::MMCVCAMap', weight, v)
-
-    @staticmethod
-    def forward(ctx, weight, v):
-        out = torch.zeros_like(v)
-        ext_module.ca_map_forward(weight, v, out)
-
-        ctx.save_for_backward(weight, v)
-
-        return out
-
-    @staticmethod
-    @once_differentiable
-    def backward(ctx, dout):
-        weight, v = ctx.saved_tensors
-        dw = torch.zeros_like(weight)
-        dv = torch.zeros_like(v)
-        ext_module.ca_map_backward(dout, weight, v, dw, dv)
-        return dw, dv
-
-
-ca_weight = CAWeightFunction.apply
-ca_map = CAMapFunction.apply
+
+def NEG_INF_DIAG(n, device):
+    """Returns a diagonal matrix of size [n, n].
+
+    The diagonal are all "-inf". This is for avoiding calculating the
+    overlapped element in the Criss-Cross twice.
+    """
+    return torch.diag(torch.tensor(float('-inf')).to(device).repeat(n), 0)
 
 
 @PLUGIN_LAYERS.register_module()
 class CrissCrossAttention(nn.Module):
-    """Criss-Cross Attention Module."""
+    """Criss-Cross Attention Module.
+
+    .. note::
+        Before v1.3.13, we use a CUDA op. Since v1.3.13, we switch
+        to a pure PyTorch and equivalent implementation. For more
+        details, please refer to https://github.com/open-mmlab/mmcv/pull/1201.
+
+        Speed comparison for one forward pass
+
+        - Input size: [2,512,97,97]
+        - Device: 1 NVIDIA GeForce RTX 2080 Ti
+
+        +-----------------------+---------------+------------+---------------+
+        |                       |PyTorch version|CUDA version|Relative speed |
+        +=======================+===============+============+===============+
+        |with torch.no_grad()   |0.00554402 s   |0.0299619 s |5.4x           |
+        +-----------------------+---------------+------------+---------------+
+        |no with torch.no_grad()|0.00562803 s   |0.0301349 s |5.4x           |
+        +-----------------------+---------------+------------+---------------+
+
+    Args:
+        in_channels (int): Channels of the input feature map.
+    """
 
     def __init__(self, in_channels):
-        super(CrissCrossAttention, self).__init__()
+        super().__init__()
         self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
         self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
         self.value_conv = nn.Conv2d(in_channels, in_channels, 1)
         self.gamma = Scale(0.)
         self.in_channels = in_channels
 
     def forward(self, x):
-        proj_query = self.query_conv(x)
-        proj_key = self.key_conv(x)
-        proj_value = self.value_conv(x)
-
-        energy = ca_weight(proj_query, proj_key)
-        attention = F.softmax(energy, 1)
-        out = ca_map(attention, proj_value)
+        """forward function of Criss-Cross Attention.
+
+        Args:
+            x (Tensor): Input feature. \
+                shape (batch_size, in_channels, height, width)
+        Returns:
+            Tensor: Output of the layer, with shape of \
+                (batch_size, in_channels, height, width)
+        """
+        B, C, H, W = x.size()
+        query = self.query_conv(x)
+        key = self.key_conv(x)
+        value = self.value_conv(x)
+        energy_H = torch.einsum('bchw,bciw->bwhi', query, key) + NEG_INF_DIAG(
+            H, query.device)
+        energy_H = energy_H.transpose(1, 2)
+        energy_W = torch.einsum('bchw,bchj->bhwj', query, key)
+        attn = F.softmax(
+            torch.cat([energy_H, energy_W], dim=-1), dim=-1)  # [B,H,W,(H+W)]
+        out = torch.einsum('bciw,bhwi->bchw', value, attn[..., :H])
+        out += torch.einsum('bchj,bhwj->bchw', value, attn[..., H:])
+
         out = self.gamma(out) + x
+        out = out.contiguous()
 
         return out

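In the rewritten forward pass, each position attends to its column (energy_H) and its row (energy_W). The position itself lies on both, so NEG_INF_DIAG masks the diagonal of the column energies before the joint softmax; the overlapping element is then counted exactly once, via the row term. A tiny demo of the mask, using nothing beyond the function in the diff:

    import torch

    # same expression as NEG_INF_DIAG(3, 'cpu'): -inf on the diagonal,
    # zeros elsewhere, so adding it to the column energies zeroes the
    # diagonal's softmax weight
    mask = torch.diag(torch.tensor(float('-inf')).repeat(3), 0)
    print(mask)
    # tensor([[-inf, 0., 0.],
    #         [0., -inf, 0.],
    #         [0., 0., -inf]])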
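The speed table in the docstring above reports a single forward pass on a [2,512,97,97] input. The benchmark script itself is not part of the commit; a hypothetical harness along those lines (time_forward and the warm-up call are assumptions, not the authors' code):

    import time

    import torch
    from mmcv.ops import CrissCrossAttention

    cca = CrissCrossAttention(512).cuda().eval()
    x = torch.randn(2, 512, 97, 97, device='cuda')

    def time_forward(no_grad):
        # synchronize around the call so GPU work is measured, not launch time
        ctx = torch.no_grad() if no_grad else torch.enable_grad()
        torch.cuda.synchronize()
        start = time.perf_counter()
        with ctx:
            cca(x)
        torch.cuda.synchronize()
        return time.perf_counter() - start

    time_forward(True)  # warm-up: CUDA context init, allocator, autotuning
    for no_grad in (True, False):
        print(f'no_grad={no_grad}: {time_forward(no_grad):.6f} s')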
168 changes: 0 additions & 168 deletions mmcv/ops/csrc/common/cuda/cc_attention_cuda_kernel.cuh

This file was deleted.
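
The deleted kernels included hand-written backward passes (ca_backward, ca_map_backward); after this commit, autograd differentiates the einsum implementation directly. A quick sanity check, offered as a sketch: it assumes mmcv's Scale wrapper stores its learnable factor as `.scale`, and it bumps gamma off its zero init so the attention branch actually contributes to the gradient.

    import torch
    from mmcv.ops import CrissCrossAttention

    torch.manual_seed(0)
    cca = CrissCrossAttention(8).double()
    # gamma starts at 0, which would reduce the output to the identity x
    # and hide the attention branch from gradcheck; assumption: the
    # parameter inside mmcv's Scale wrapper is named `scale`
    cca.gamma.scale.data.fill_(1.0)
    x = torch.randn(1, 8, 5, 5, dtype=torch.double, requires_grad=True)
    # numerical-vs-analytical gradient comparison through the einsums,
    # the -inf diagonal mask, and the softmax
    assert torch.autograd.gradcheck(cca, (x,))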

