Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

rename fast_conv_bn_eval to efficient_conv_bn_eval #2884

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 24 additions & 20 deletions mmcv/cnn/bricks/conv_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
from .padding import build_padding_layer


def fast_conv_bn_eval_forward(bn: _BatchNorm, conv: nn.modules.conv._ConvNd,
x: torch.Tensor):
def efficient_conv_bn_eval_forward(bn: _BatchNorm,
conv: nn.modules.conv._ConvNd,
x: torch.Tensor):
"""
Implementation based on https://arxiv.org/abs/2305.11624
"Tune-Mode ConvBN Blocks For Efficient Transfer Learning"
Expand Down Expand Up @@ -115,9 +116,9 @@ class ConvModule(nn.Module):
sequence of "conv", "norm" and "act". Common examples are
("conv", "norm", "act") and ("act", "conv", "norm").
Default: ('conv', 'norm', 'act').
fast_conv_bn_eval (bool): Whether use fast conv when the consecutive
bn is in eval mode (either training or testing), as proposed in
https://arxiv.org/abs/2305.11624 . Default: False.
    efficient_conv_bn_eval (bool): Whether to use efficient conv when the
        consecutive bn is in eval mode (either training or testing), as
        proposed in https://arxiv.org/abs/2305.11624 . Default: `False`.
"""

_abbr_ = 'conv_block'
Expand All @@ -138,7 +139,7 @@ def __init__(self,
with_spectral_norm: bool = False,
padding_mode: str = 'zeros',
order: tuple = ('conv', 'norm', 'act'),
fast_conv_bn_eval: bool = False):
efficient_conv_bn_eval: bool = False):
super().__init__()
assert conv_cfg is None or isinstance(conv_cfg, dict)
assert norm_cfg is None or isinstance(norm_cfg, dict)
Expand Down Expand Up @@ -209,7 +210,7 @@ def __init__(self,
else:
self.norm_name = None # type: ignore

self.turn_on_fast_conv_bn_eval(fast_conv_bn_eval)
self.turn_on_efficient_conv_bn_eval(efficient_conv_bn_eval)

# build activation layer
if self.with_activation:
Expand Down Expand Up @@ -263,15 +264,16 @@ def forward(self,
if self.with_explicit_padding:
x = self.padding_layer(x)
# if the next operation is norm and we have a norm layer in
# eval mode and we have enabled fast_conv_bn_eval for the conv
# operator, then activate the optimized forward and skip the
# next norm operator since it has been fused
# eval mode and we have enabled `efficient_conv_bn_eval` for
# the conv operator, then activate the optimized forward and
# skip the next norm operator since it has been fused
if layer_index + 1 < len(self.order) and \
self.order[layer_index + 1] == 'norm' and norm and \
self.with_norm and not self.norm.training and \
self.fast_conv_bn_eval_forward is not None:
self.conv.forward = partial(self.fast_conv_bn_eval_forward,
self.norm, self.conv)
self.efficient_conv_bn_eval_forward is not None:
self.conv.forward = partial(
self.efficient_conv_bn_eval_forward, self.norm,
self.conv)
layer_index += 1
x = self.conv(x)
del self.conv.forward
Expand All @@ -284,20 +286,22 @@ def forward(self,
layer_index += 1
return x

def turn_on_fast_conv_bn_eval(self, fast_conv_bn_eval=True):
# fast_conv_bn_eval works for conv + bn
def turn_on_efficient_conv_bn_eval(self, efficient_conv_bn_eval=True):
# efficient_conv_bn_eval works for conv + bn
# with `track_running_stats` option
if fast_conv_bn_eval and self.norm \
if efficient_conv_bn_eval and self.norm \
and isinstance(self.norm, _BatchNorm) \
and self.norm.track_running_stats:
self.fast_conv_bn_eval_forward = fast_conv_bn_eval_forward
            # use an intermediate name to keep the line within flake8's
            # 79-character limit
enabled = efficient_conv_bn_eval_forward
self.efficient_conv_bn_eval_forward = enabled
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
else:
self.fast_conv_bn_eval_forward = None # type: ignore
self.efficient_conv_bn_eval_forward = None # type: ignore

@staticmethod
def create_from_conv_bn(conv: torch.nn.modules.conv._ConvNd,
bn: torch.nn.modules.batchnorm._BatchNorm,
fast_conv_bn_eval=True) -> 'ConvModule':
efficient_conv_bn_eval=True) -> 'ConvModule':
"""Create a ConvModule from a conv and a bn module."""
self = ConvModule.__new__(ConvModule)
super(ConvModule, self).__init__()
Expand Down Expand Up @@ -331,6 +335,6 @@ def create_from_conv_bn(conv: torch.nn.modules.conv._ConvNd,
self.norm_name, norm = 'bn', bn
self.add_module(self.norm_name, norm)

self.turn_on_fast_conv_bn_eval(fast_conv_bn_eval)
self.turn_on_efficient_conv_bn_eval(efficient_conv_bn_eval)

return self
37 changes: 20 additions & 17 deletions tests/test_cnn/test_conv_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,27 +75,30 @@ def test_conv_module():
output = conv(x)
assert output.shape == (1, 8, 255, 255)

# conv + norm with fast mode
fast_conv = ConvModule(
3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True).eval()
# conv + norm with efficient mode
efficient_conv = ConvModule(
3, 8, 2, norm_cfg=dict(type='BN'), efficient_conv_bn_eval=True).eval()
plain_conv = ConvModule(
3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=False).eval()
for fast_param, plain_param in zip(fast_conv.state_dict().values(),
plain_conv.state_dict().values()):
plain_param.copy_(fast_param)

fast_mode_output = fast_conv(x)
3, 8, 2, norm_cfg=dict(type='BN'),
efficient_conv_bn_eval=False).eval()
for efficient_param, plain_param in zip(
efficient_conv.state_dict().values(),
plain_conv.state_dict().values()):
plain_param.copy_(efficient_param)

efficient_mode_output = efficient_conv(x)
plain_mode_output = plain_conv(x)
assert torch.allclose(fast_mode_output, plain_mode_output, atol=1e-5)
assert torch.allclose(efficient_mode_output, plain_mode_output, atol=1e-5)

# `conv` attribute can be dynamically modified in fast mode
fast_conv = ConvModule(
3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True).eval()
# `conv` attribute can be dynamically modified in efficient mode
efficient_conv = ConvModule(
3, 8, 2, norm_cfg=dict(type='BN'), efficient_conv_bn_eval=True).eval()
new_conv = nn.Conv2d(3, 8, 2).eval()
fast_conv.conv = new_conv
fast_mode_output = fast_conv(x)
plain_mode_output = fast_conv.activate(fast_conv.norm(new_conv(x)))
assert torch.allclose(fast_mode_output, plain_mode_output, atol=1e-5)
efficient_conv.conv = new_conv
efficient_mode_output = efficient_conv(x)
plain_mode_output = efficient_conv.activate(
efficient_conv.norm(new_conv(x)))
assert torch.allclose(efficient_mode_output, plain_mode_output, atol=1e-5)

# conv + act
conv = ConvModule(3, 8, 2)
Expand Down
Loading