
[Enhancement] Support LayerScale #2451

Merged · 16 commits · Dec 11, 2022
4 changes: 2 additions & 2 deletions mmcv/cnn/bricks/__init__.py
@@ -14,7 +14,7 @@
from .norm import build_norm_layer, is_norm
from .padding import build_padding_layer
from .plugin import build_plugin_layer
from .scale import Scale
from .scale import LayerScale, Scale
from .swish import Swish
from .upsample import build_upsample_layer
from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
@@ -28,5 +28,5 @@
'Scale', 'ConvAWS2d', 'ConvWS2d', 'conv_ws_2d',
'DepthwiseSeparableConvModule', 'Swish', 'Linear', 'Conv2dAdaptivePadding',
'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d',
'Conv3d', 'Dropout', 'DropPath'
'Conv3d', 'Dropout', 'DropPath', 'LayerScale'
]
32 changes: 32 additions & 0 deletions mmcv/cnn/bricks/scale.py
@@ -19,3 +19,35 @@ def __init__(self, scale: float = 1.0):

def forward(self, x: torch.Tensor) -> torch.Tensor:
return x * self.scale


class LayerScale(nn.Module):
"""LayerScale layer.

Args:
dim (int): Dimension of input features.
inplace (bool): Whether to perform the operation in-place.
Default: `False`.
data_format (str): The input data format, could be 'channels_last'
or 'channels_first', representing (B, N, C) and
(B, C, H, W) format data respectively. Default: 'channels_last'.
"""

def __init__(self,
dim: int,
inplace: bool = False,
data_format: str = 'channels_last'):
super().__init__()
assert data_format in ('channels_last', 'channels_first'), \
"'data_format' could only be channels_last or channels_first."
self.inplace = inplace
if data_format == 'channels_first':
self.weight = nn.Parameter(torch.ones(dim, 1, 1) * 1e-5)
else:
self.weight = nn.Parameter(torch.ones(dim) * 1e-5)

def forward(self, x):
if self.inplace:
return x.mul_(self.weight)
else:
return x * self.weight
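
For reference, a minimal usage sketch of the new brick, mirroring the unit tests added below; the import path follows the updated mmcv/cnn/bricks/__init__.py from this diff:

import torch
from mmcv.cnn.bricks import LayerScale

# channels_last: token-style features of shape (B, N, C)
ls = LayerScale(dim=256, data_format='channels_last')
tokens = torch.randn(2, 49, 256)
assert torch.equal(ls(tokens), tokens * 1e-5)  # weights are initialized to 1e-5

# channels_first: feature maps of shape (B, C, H, W)
ls_cf = LayerScale(dim=256, data_format='channels_first')
feats = torch.randn(2, 256, 7, 7)
assert ls_cf(feats).shape == (2, 256, 7, 7)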
10 changes: 10 additions & 0 deletions mmcv/cnn/bricks/transformer.py
@@ -15,6 +15,7 @@
from mmcv.cnn import (Linear, build_activation_layer, build_conv_layer,
build_norm_layer)
from .drop import build_dropout
from .scale import LayerScale

# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
try:
@@ -572,6 +573,8 @@ class FFN(BaseModule):
when adding the shortcut.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
use_layer_scale (bool): Whether to apply LayerScale to the
output of the FFN. Default: `True`.
"""

@deprecated_api_warning(
@@ -589,6 +592,7 @@ def __init__(self,
dropout_layer=None,
add_identity=True,
init_cfg=None,
use_layer_scale=True,
**kwargs):
super().__init__(init_cfg)
assert num_fcs >= 2, 'num_fcs should be no less ' \
@@ -614,13 +618,19 @@ def __init__(self,
dropout_layer) if dropout_layer else torch.nn.Identity()
self.add_identity = add_identity

if use_layer_scale:
self.gamma2 = LayerScale(embed_dims)
else:
self.gamma2 = nn.Identity()

@deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
def forward(self, x, identity=None):
"""Forward function for `FFN`.

The function adds x to the output tensor if identity is None.
"""
out = self.layers(x)
out = self.gamma2(out)
if not self.add_identity:
return self.dropout_layer(out)
if identity is None:
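
A quick sketch of how the new use_layer_scale flag fits the existing FFN API; the constructor arguments below are the FFN defaults and are only illustrative, not part of this diff:

import torch
from mmcv.cnn.bricks.transformer import FFN

# The FFN output is now scaled by a learnable per-channel gamma (initialized to 1e-5)
ffn = FFN(embed_dims=256, feedforward_channels=1024, use_layer_scale=True)
x = torch.rand(2, 20, 256)
out = ffn(x)  # identity defaults to x, so out = x + dropout(gamma2(layers(x)))
assert out.shape == (2, 20, 256)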
42 changes: 41 additions & 1 deletion tests/test_cnn/test_scale.py
@@ -1,7 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmcv.cnn.bricks import Scale
from mmcv.cnn.bricks import LayerScale, Scale


def test_scale():
@@ -20,3 +21,42 @@ def test_scale():
x = torch.rand(1, 3, 64, 64)
output = scale(x)
assert output.shape == (1, 3, 64, 64)


def test_layer_scale():
with pytest.raises(AssertionError):
cfg = dict(
dim=10,
data_format='BNC',
)
LayerScale(**cfg)

# test init
cfg = dict(dim=10)
ls = LayerScale(**cfg)
assert torch.equal(ls.weight, torch.ones(10, requires_grad=True) * 1e-5)

# test forward
# test channels_last
cfg = dict(dim=256, inplace=False, data_format='channels_last')
ls_channels_last = LayerScale(**cfg)
x = torch.randn((4, 49, 256))
out = ls_channels_last(x)
assert tuple(out.size()) == (4, 49, 256)
assert torch.equal(x * 1e-5, out)

# test channels_first
cfg = dict(dim=256, inplace=False, data_format='channels_first')
ls_channels_first = LayerScale(**cfg)
x = torch.randn((4, 256, 7, 7))
out = ls_channels_first(x)
assert tuple(out.size()) == (4, 256, 7, 7)
assert torch.equal(x * 1e-5, out)

# test inplace True
cfg = dict(dim=256, inplace=True, data_format='channels_first')
ls_channels_first = LayerScale(**cfg)
x = torch.randn((4, 256, 7, 7))
out = ls_channels_first(x)
assert tuple(out.size()) == (4, 256, 7, 7)
assert x is out
8 changes: 7 additions & 1 deletion tests/test_cnn/test_transformer.py
@@ -538,7 +538,6 @@ def test_ffn():
with pytest.raises(AssertionError):
# num_fcs should be no less than 2
FFN(num_fcs=1)
FFN(dropout=0, add_residual=True)
ffn = FFN(dropout=0, add_identity=True)

input_tensor = torch.rand(2, 20, 256)
@@ -553,6 +552,13 @@
ffn(input_tensor, identity=residual).sum(),
ffn(input_tensor).sum() + residual.sum() - input_tensor.sum())

# test with layer_scale
ffn = FFN(dropout=0, add_identity=True, use_layer_scale=True)

input_tensor = torch.rand(2, 20, 256)
input_tensor_nbc = input_tensor.transpose(0, 1)
assert torch.allclose(ffn(input_tensor).sum(), ffn(input_tensor_nbc).sum())


@pytest.mark.skipif(not torch.cuda.is_available(), reason='Cuda not available')
def test_basetransformerlayer_cuda():