diff --git a/mmcv/cnn/bricks/conv_module.py b/mmcv/cnn/bricks/conv_module.py
index 1f8e160517..a8a55ff316 100644
--- a/mmcv/cnn/bricks/conv_module.py
+++ b/mmcv/cnn/bricks/conv_module.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings
+from functools import partial
 from typing import Dict, Optional, Tuple, Union
 
 import torch
@@ -14,6 +15,55 @@
 from .padding import build_padding_layer
 
 
+def fast_conv_bn_eval_forward(bn: _BatchNorm, conv: nn.modules.conv._ConvNd,
+                              x: torch.Tensor):
+    """
+    Implementation based on https://arxiv.org/abs/2305.11624
+    "Tune-Mode ConvBN Blocks For Efficient Transfer Learning".
+    It leverages the associative law between convolution and affine transform,
+    i.e., normalize (weight conv feature) = (normalize weight) conv feature.
+    It works for eval mode of ConvBN blocks during validation, and can be used
+    for training as well. It reduces memory and computation cost.
+
+    Args:
+        bn (_BatchNorm): a BatchNorm module.
+        conv (nn._ConvNd): a conv module.
+        x (torch.Tensor): Input feature map.
+    """
+    # These lines of code are designed to deal with various cases
+    # like bn without affine transform, and conv without bias
+    weight_on_the_fly = conv.weight
+    if conv.bias is not None:
+        bias_on_the_fly = conv.bias
+    else:
+        bias_on_the_fly = torch.zeros_like(bn.running_var)
+
+    if bn.weight is not None:
+        bn_weight = bn.weight
+    else:
+        bn_weight = torch.ones_like(bn.running_var)
+
+    if bn.bias is not None:
+        bn_bias = bn.bias
+    else:
+        bn_bias = torch.zeros_like(bn.running_var)
+
+    # shape of [C_out, 1, 1, 1] in Conv2d
+    weight_coeff = torch.rsqrt(bn.running_var +
+                               bn.eps).reshape([-1] + [1] *
+                                               (len(conv.weight.shape) - 1))
+    # shape of [C_out, 1, 1, 1] in Conv2d
+    coeff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff
+
+    # shape of [C_out, C_in, k, k] in Conv2d
+    weight_on_the_fly = weight_on_the_fly * coeff_on_the_fly
+    # shape of [C_out] in Conv2d
+    bias_on_the_fly = bn_bias + coeff_on_the_fly.flatten() *\
+        (bias_on_the_fly - bn.running_mean)
+
+    return conv._conv_forward(x, weight_on_the_fly, bias_on_the_fly)
+
+
 @MODELS.register_module()
 class ConvModule(nn.Module):
     """A conv block that bundles conv/norm/activation layers.
@@ -65,6 +115,9 @@ class ConvModule(nn.Module):
             sequence of "conv", "norm" and "act". Common examples are
             ("conv", "norm", "act") and ("act", "conv", "norm").
            Default: ('conv', 'norm', 'act').
+        fast_conv_bn_eval (bool): Whether to use fast conv when the
+            consecutive bn is in eval mode (either training or testing), as
+            proposed in https://arxiv.org/abs/2305.11624 . Default: False.
""" _abbr_ = 'conv_block' @@ -84,7 +137,8 @@ def __init__(self, inplace: bool = True, with_spectral_norm: bool = False, padding_mode: str = 'zeros', - order: tuple = ('conv', 'norm', 'act')): + order: tuple = ('conv', 'norm', 'act'), + fast_conv_bn_eval: bool = False): super().__init__() assert conv_cfg is None or isinstance(conv_cfg, dict) assert norm_cfg is None or isinstance(norm_cfg, dict) @@ -155,6 +209,16 @@ def __init__(self, else: self.norm_name = None # type: ignore + # fast_conv_bn_eval works for conv + bn + # with `track_running_stats` option + if fast_conv_bn_eval and self.norm and isinstance( + self.norm, _BatchNorm) and self.norm.track_running_stats: + self.fast_conv_bn_eval_forward = partial(fast_conv_bn_eval_forward, + self.norm, self.conv) + else: + self.fast_conv_bn_eval_forward = None # type: ignore + self.original_conv_forward = self.conv.forward + # build activation layer if self.with_activation: act_cfg_ = act_cfg.copy() # type: ignore @@ -200,13 +264,77 @@ def forward(self, x: torch.Tensor, activate: bool = True, norm: bool = True) -> torch.Tensor: - for layer in self.order: + layer_index = 0 + while layer_index < len(self.order): + layer = self.order[layer_index] if layer == 'conv': if self.with_explicit_padding: x = self.padding_layer(x) + # if the next operation is norm and we have a norm layer in + # eval mode and we have enabled fast_conv_bn_eval for the conv + # operator, then activate the optimized forward and skip the + # next norm operator since it has been fused + if layer_index + 1 < len(self.order) and \ + self.order[layer_index + 1] == 'norm' and norm and \ + self.with_norm and not self.norm.training and \ + self.fast_conv_bn_eval_forward is not None: + self.conv.forward = self.fast_conv_bn_eval_forward + layer_index += 1 + else: + self.conv.forward = self.original_conv_forward x = self.conv(x) elif layer == 'norm' and norm and self.with_norm: x = self.norm(x) elif layer == 'act' and activate and self.with_activation: x = self.activate(x) + layer_index += 1 return x + + @staticmethod + def create_from_conv_bn(conv: torch.nn.modules.conv._ConvNd, + bn: torch.nn.modules.batchnorm._BatchNorm, + fast_conv_bn_eval=True) -> 'ConvModule': + """Create a ConvModule from a conv and a bn module.""" + self = ConvModule.__new__(ConvModule) + super(ConvModule, self).__init__() + + self.conv_cfg = None + self.norm_cfg = None + self.act_cfg = None + self.inplace = False + self.with_spectral_norm = False + self.with_explicit_padding = False + self.order = ('conv', 'norm', 'act') + + self.with_norm = True + self.with_activation = False + self.with_bias = conv.bias is not None + + # build convolution layer + self.conv = conv + # export the attributes of self.conv to a higher level for convenience + self.in_channels = self.conv.in_channels + self.out_channels = self.conv.out_channels + self.kernel_size = self.conv.kernel_size + self.stride = self.conv.stride + self.padding = self.conv.padding + self.dilation = self.conv.dilation + self.transposed = self.conv.transposed + self.output_padding = self.conv.output_padding + self.groups = self.conv.groups + + # build normalization layers + self.norm_name, norm = 'bn', bn + self.add_module(self.norm_name, norm) + + # fast_conv_bn_eval works for conv + bn + # with `track_running_stats` option + if fast_conv_bn_eval and self.norm and isinstance( + self.norm, _BatchNorm) and self.norm.track_running_stats: + self.fast_conv_bn_eval_forward = partial(fast_conv_bn_eval_forward, + self.norm, self.conv) + else: + 
+            self.fast_conv_bn_eval_forward = None  # type: ignore
+        self.original_conv_forward = self.conv.forward
+
+        return self
diff --git a/tests/test_cnn/test_conv_module.py b/tests/test_cnn/test_conv_module.py
index d31167a743..af7fc25ec1 100644
--- a/tests/test_cnn/test_conv_module.py
+++ b/tests/test_cnn/test_conv_module.py
@@ -75,6 +75,16 @@ def test_conv_module():
     output = conv(x)
     assert output.shape == (1, 8, 255, 255)
 
+    # conv + norm with fast mode
+    conv = ConvModule(
+        3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True)
+    conv.norm.eval()
+    x = torch.rand(1, 3, 256, 256)
+    fast_mode_output = conv(x)
+    conv.conv.forward = conv.original_conv_forward
+    plain_implementation = conv.activate(conv.norm(conv.conv(x)))
+    assert torch.allclose(fast_mode_output, plain_implementation, atol=1e-5)
+
     # conv + act
     conv = ConvModule(3, 8, 2)
     assert conv.with_activation
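
Note: the following standalone sketch is not part of the diff. It mirrors the arithmetic that fast_conv_bn_eval_forward folds into the conv weights (BN in eval mode: y = gamma * (Wx + b - mean) / sqrt(var + eps) + beta), using only plain PyTorch; the names coeff, fused_weight and fused_bias are illustrative, not taken from the patch.

import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv2d(3, 8, 3, bias=True)
bn = nn.BatchNorm2d(8).eval()
# give the running statistics non-trivial values
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2)

x = torch.rand(1, 3, 32, 32)

# coefficient gamma / sqrt(var + eps), broadcast over [C_out, C_in, k, k]
coeff = bn.weight * torch.rsqrt(bn.running_var + bn.eps)
fused_weight = conv.weight * coeff.reshape(-1, 1, 1, 1)
fused_bias = bn.bias + coeff * (conv.bias - bn.running_mean)

# a single convolution with the fused parameters matches conv -> bn (eval)
fused_out = F.conv2d(x, fused_weight, fused_bias,
                     conv.stride, conv.padding, conv.dilation, conv.groups)
plain_out = bn(conv(x))
assert torch.allclose(fused_out, plain_out, atol=1e-5)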
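A possible usage sketch for the new create_from_conv_bn factory, assuming this patch is applied; the equivalence check at the end mirrors the added unit test.

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

conv = nn.Conv2d(3, 8, 3)
bn = nn.BatchNorm2d(8)
# wrap an existing conv + bn pair; no activation layer is attached
module = ConvModule.create_from_conv_bn(conv, bn, fast_conv_bn_eval=True)

module.eval()  # the fused forward is only taken when the bn is in eval mode
x = torch.rand(1, 3, 32, 32)
out = module(x)
assert torch.allclose(out, bn(conv(x)), atol=1e-5)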