[Enhancement] Support LayerScale #2451

Merged (16 commits) on Dec 11, 2022
2 changes: 1 addition & 1 deletion mmcv/cnn/bricks/__init__.py
@@ -28,5 +28,5 @@
'Scale', 'ConvAWS2d', 'ConvWS2d', 'conv_ws_2d',
'DepthwiseSeparableConvModule', 'Swish', 'Linear', 'Conv2dAdaptivePadding',
'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d',
'Conv3d', 'Dropout', 'DropPath'
'Conv3d', 'Dropout', 'DropPath', 'LayerScale'
]
41 changes: 41 additions & 0 deletions mmcv/cnn/bricks/transformer.py
@@ -551,6 +551,38 @@ def forward(self,
return identity + self.dropout_layer(self.proj_drop(out))


class LayerScale(nn.Module):
"""LayerScale layer.

    Args:
        dim (int): Dimension of input features.
        inplace (bool): Whether to perform the operation in-place.
            Default: `False`.
        data_format (str): The input data format, which can be
            'channels_last' or 'channels_first', representing (B, N, C)
            and (B, C, H, W) format data respectively.
            Default: 'channels_last'.
    """

def __init__(self,
dim: int,
inplace: bool = False,
data_format: str = 'channels_last'):
super().__init__()
assert data_format in ('channels_last', 'channels_first'), \
"'data_format' could only be channels_last or channels_first."
self.inplace = inplace
self.data_format = data_format
self.weight = nn.Parameter(torch.ones(dim) * 1e-5)

def forward(self, x):
if self.data_format == 'channels_first':
if self.inplace:
return x.mul_(self.weight.view(-1, 1, 1))
else:
return x * self.weight.view(-1, 1, 1)
return x.mul_(self.weight) if self.inplace else x * self.weight
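A minimal usage sketch of the class above, relying only on the interface in this diff (tensor shapes are illustrative):

    import torch
    from mmcv.cnn.bricks.transformer import LayerScale

    # channels_last (default): (B, N, C) token sequences
    ls = LayerScale(dim=256)
    out = ls(torch.randn(2, 196, 256))  # each channel scaled by a learnable weight initialized to 1e-5

    # channels_first: (B, C, H, W) feature maps; the weight is broadcast via view(-1, 1, 1)
    ls_cf = LayerScale(dim=64, data_format='channels_first')
    out_cf = ls_cf(torch.randn(2, 64, 56, 56))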


@MODELS.register_module()
class FFN(BaseModule):
"""Implements feed-forward networks (FFNs) with identity connection.
@@ -568,6 +600,8 @@ class FFN(BaseModule):
zeroed in FFN. Default 0.0.
add_identity (bool, optional): Whether to add the
identity connection. Default: `True`.
use_layer_scale (bool): Whether to use layer_scale in FFN.
Default: `True`.
dropout_layer (obj:`ConfigDict`): The dropout_layer used
when adding the shortcut.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
@@ -588,6 +622,7 @@ def __init__(self,
ffn_drop=0.,
dropout_layer=None,
add_identity=True,
use_layer_scale=True,
init_cfg=None,
**kwargs):
super().__init__(init_cfg)
@@ -614,13 +649,19 @@ def __init__(self,
dropout_layer) if dropout_layer else torch.nn.Identity()
self.add_identity = add_identity

if use_layer_scale:
self.gamma2 = LayerScale(embed_dims)
else:
self.gamma2 = nn.Identity()

@deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
def forward(self, x, identity=None):
"""Forward function for `FFN`.

The function adds x to the output tensor if `identity` is None.
"""
out = self.layers(x)
out = self.gamma2(out)
if not self.add_identity:
return self.dropout_layer(out)
if identity is None:
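With the FFN changes above, LayerScale is enabled through a constructor flag; a short sketch of the intended call pattern (argument values are illustrative, not taken from the PR):

    import torch
    from mmcv.cnn.bricks.transformer import FFN

    # gamma2 is a LayerScale module when use_layer_scale=True, otherwise nn.Identity
    ffn = FFN(embed_dims=256, feedforward_channels=1024, use_layer_scale=True)
    x = torch.rand(2, 20, 256)
    out = ffn(x)  # layers -> gamma2 -> dropout_layer, plus the identity connection; shape (2, 20, 256)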
49 changes: 47 additions & 2 deletions tests/test_cnn/test_transformer.py
@@ -7,7 +7,7 @@

from mmcv.cnn.bricks.drop import DropPath
from mmcv.cnn.bricks.transformer import (FFN, AdaptivePadding,
BaseTransformerLayer,
BaseTransformerLayer, LayerScale,
MultiheadAttention, PatchEmbed,
PatchMerging,
TransformerLayerSequence)
@@ -538,7 +538,6 @@ def test_ffn():
with pytest.raises(AssertionError):
# num_fcs should be no less than 2
FFN(num_fcs=1)
FFN(dropout=0, add_residual=True)
ffn = FFN(dropout=0, add_identity=True)

input_tensor = torch.rand(2, 20, 256)
@@ -553,6 +552,52 @@ def test_ffn():
ffn(input_tensor, identity=residual).sum(),
ffn(input_tensor).sum() + residual.sum() - input_tensor.sum())

# test with layer_scale
ffn = FFN(dropout=0, add_identity=True, use_layer_scale=True)

input_tensor = torch.rand(2, 20, 256)
input_tensor_nbc = input_tensor.transpose(0, 1)
assert torch.allclose(ffn(input_tensor).sum(), ffn(input_tensor_nbc).sum())


def test_layer_scale():
with pytest.raises(AssertionError):
cfg = dict(
dim=10,
data_format='BNC',
)
LayerScale(**cfg)

# test init
cfg = dict(dim=10)
ls = LayerScale(**cfg)
assert torch.equal(ls.weight, torch.ones(10, requires_grad=True) * 1e-5)

# test forward
# test channels_last
cfg = dict(dim=256, inplace=False, data_format='channels_last')
ls_channels_last = LayerScale(**cfg)
x = torch.randn((4, 49, 256))
out = ls_channels_last(x)
assert tuple(out.size()) == (4, 49, 256)
assert torch.equal(x * 1e-5, out)

# test channels_first
cfg = dict(dim=256, inplace=False, data_format='channels_first')
ls_channels_first = LayerScale(**cfg)
x = torch.randn((4, 256, 7, 7))
out = ls_channels_first(x)
assert tuple(out.size()) == (4, 256, 7, 7)
assert torch.equal(x * 1e-5, out)

# test inplace True
cfg = dict(dim=256, inplace=True, data_format='channels_first')
ls_channels_first = LayerScale(**cfg)
x = torch.randn((4, 256, 7, 7))
out = ls_channels_first(x)
assert tuple(out.size()) == (4, 256, 7, 7)
assert x is out


@pytest.mark.skipif(not torch.cuda.is_available(), reason='Cuda not available')
def test_basetransformerlayer_cuda():