[Feature] Add fast_conv_bn_eval option in ConvModule for fast validation and training in Eval mode (#2807)
youkaichao committed Jun 13, 2023
1 parent f01d301 commit 36003b7
Showing 2 changed files with 140 additions and 2 deletions.
132 changes: 130 additions & 2 deletions mmcv/cnn/bricks/conv_module.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from functools import partial
from typing import Dict, Optional, Tuple, Union

import torch
@@ -14,6 +15,55 @@
from .padding import build_padding_layer


def fast_conv_bn_eval_forward(bn: _BatchNorm, conv: nn.modules.conv._ConvNd,
                              x: torch.Tensor):
    """
    Implementation based on https://arxiv.org/abs/2305.11624
    "Tune-Mode ConvBN Blocks For Efficient Transfer Learning".
    It leverages the associative law between convolution and affine transform,
    i.e., normalize (weight conv feature) = (normalize weight) conv feature.
    It works for Eval mode of ConvBN blocks during validation, and can be used
    for training as well. It reduces memory and computation cost.

    Args:
        bn (_BatchNorm): a BatchNorm module.
        conv (nn._ConvNd): a conv module.
        x (torch.Tensor): Input feature map.
    """
    # These lines of code are designed to deal with various cases,
    # like bn without affine transform and conv without bias.
    weight_on_the_fly = conv.weight
    if conv.bias is not None:
        bias_on_the_fly = conv.bias
    else:
        bias_on_the_fly = torch.zeros_like(bn.running_var)

    if bn.weight is not None:
        bn_weight = bn.weight
    else:
        bn_weight = torch.ones_like(bn.running_var)

    if bn.bias is not None:
        bn_bias = bn.bias
    else:
        bn_bias = torch.zeros_like(bn.running_var)

    # shape of [C_out, 1, 1, 1] in Conv2d
    weight_coeff = torch.rsqrt(bn.running_var +
                               bn.eps).reshape([-1] + [1] *
                                               (len(conv.weight.shape) - 1))
    # shape of [C_out, 1, 1, 1] in Conv2d
    coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff

    # shape of [C_out, C_in, k, k] in Conv2d
    weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly
    # shape of [C_out] in Conv2d
    bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() * \
        (bias_on_the_fly - bn.running_mean)

    return conv._conv_forward(x, weight_on_the_fly, bias_on_the_fly)


@MODELS.register_module()
class ConvModule(nn.Module):
"""A conv block that bundles conv/norm/activation layers.
Expand Down Expand Up @@ -65,6 +115,9 @@ class ConvModule(nn.Module):
sequence of "conv", "norm" and "act". Common examples are
("conv", "norm", "act") and ("act", "conv", "norm").
Default: ('conv', 'norm', 'act').
fast_conv_bn_eval (bool): Whether use fast conv when the consecutive
bn is in eval mode (either training or testing), as proposed in
https://arxiv.org/abs/2305.11624 . Default: False.
"""

_abbr_ = 'conv_block'
@@ -84,7 +137,8 @@ def __init__(self,
                 inplace: bool = True,
                 with_spectral_norm: bool = False,
                 padding_mode: str = 'zeros',
                 order: tuple = ('conv', 'norm', 'act')):
                 order: tuple = ('conv', 'norm', 'act'),
                 fast_conv_bn_eval: bool = False):
        super().__init__()
        assert conv_cfg is None or isinstance(conv_cfg, dict)
        assert norm_cfg is None or isinstance(norm_cfg, dict)
@@ -155,6 +209,16 @@ def __init__(self,
        else:
            self.norm_name = None  # type: ignore

        # fast_conv_bn_eval works for conv + bn
        # with `track_running_stats` option
        if fast_conv_bn_eval and self.norm and isinstance(
                self.norm, _BatchNorm) and self.norm.track_running_stats:
            self.fast_conv_bn_eval_forward = partial(fast_conv_bn_eval_forward,
                                                     self.norm, self.conv)
        else:
            self.fast_conv_bn_eval_forward = None  # type: ignore
        self.original_conv_forward = self.conv.forward

        # build activation layer
        if self.with_activation:
            act_cfg_ = act_cfg.copy()  # type: ignore
@@ -200,13 +264,77 @@ def forward(self,
                x: torch.Tensor,
                activate: bool = True,
                norm: bool = True) -> torch.Tensor:
        for layer in self.order:
        layer_index = 0
        while layer_index < len(self.order):
            layer = self.order[layer_index]
            if layer == 'conv':
                if self.with_explicit_padding:
                    x = self.padding_layer(x)
                # if the next operation is norm and we have a norm layer in
                # eval mode and we have enabled fast_conv_bn_eval for the conv
                # operator, then activate the optimized forward and skip the
                # next norm operator since it has been fused
                if layer_index + 1 < len(self.order) and \
                        self.order[layer_index + 1] == 'norm' and norm and \
                        self.with_norm and not self.norm.training and \
                        self.fast_conv_bn_eval_forward is not None:
                    self.conv.forward = self.fast_conv_bn_eval_forward
                    layer_index += 1
                else:
                    self.conv.forward = self.original_conv_forward
                x = self.conv(x)
            elif layer == 'norm' and norm and self.with_norm:
                x = self.norm(x)
            elif layer == 'act' and activate and self.with_activation:
                x = self.activate(x)
            layer_index += 1
        return x

    @staticmethod
    def create_from_conv_bn(conv: torch.nn.modules.conv._ConvNd,
                            bn: torch.nn.modules.batchnorm._BatchNorm,
                            fast_conv_bn_eval=True) -> 'ConvModule':
        """Create a ConvModule from a conv and a bn module."""
        self = ConvModule.__new__(ConvModule)
        super(ConvModule, self).__init__()

        self.conv_cfg = None
        self.norm_cfg = None
        self.act_cfg = None
        self.inplace = False
        self.with_spectral_norm = False
        self.with_explicit_padding = False
        self.order = ('conv', 'norm', 'act')

        self.with_norm = True
        self.with_activation = False
        self.with_bias = conv.bias is not None

        # build convolution layer
        self.conv = conv
        # export the attributes of self.conv to a higher level for convenience
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
        self.stride = self.conv.stride
        self.padding = self.conv.padding
        self.dilation = self.conv.dilation
        self.transposed = self.conv.transposed
        self.output_padding = self.conv.output_padding
        self.groups = self.conv.groups

        # build normalization layers
        self.norm_name, norm = 'bn', bn
        self.add_module(self.norm_name, norm)

        # fast_conv_bn_eval works for conv + bn
        # with `track_running_stats` option
        if fast_conv_bn_eval and self.norm and isinstance(
                self.norm, _BatchNorm) and self.norm.track_running_stats:
            self.fast_conv_bn_eval_forward = partial(fast_conv_bn_eval_forward,
                                                     self.norm, self.conv)
        else:
            self.fast_conv_bn_eval_forward = None  # type: ignore
        self.original_conv_forward = self.conv.forward

        return self
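
The following is an editorial usage sketch, not part of the diff: assuming an mmcv build that includes this commit, the new option and the create_from_conv_bn helper can be exercised as below (module paths and argument names are taken from the code above).

import torch
from mmcv.cnn import ConvModule

# A conv + BN block with the new option; the fused path is only taken
# while the BN layer is in eval mode (during training or testing).
block = ConvModule(3, 8, 3, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True)
block.norm.eval()
out = block(torch.rand(1, 3, 32, 32))

# Wrap an existing conv/bn pair, e.g. taken from a pretrained backbone.
conv = torch.nn.Conv2d(3, 8, 3)
bn = torch.nn.BatchNorm2d(8).eval()
wrapped = ConvModule.create_from_conv_bn(conv, bn, fast_conv_bn_eval=True)
out = wrapped(torch.rand(1, 3, 32, 32))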
10 changes: 10 additions & 0 deletions tests/test_cnn/test_conv_module.py
@@ -75,6 +75,16 @@ def test_conv_module():
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # conv + norm with fast mode
    conv = ConvModule(
        3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True)
    conv.norm.eval()
    x = torch.rand(1, 3, 256, 256)
    fast_mode_output = conv(x)
    conv.conv.forward = conv.original_conv_forward
    plain_implementation = conv.activate(conv.norm(conv.conv(x)))
    assert torch.allclose(fast_mode_output, plain_implementation, atol=1e-5)

    # conv + act
    conv = ConvModule(3, 8, 2)
    assert conv.with_activation
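For reference, the algebra that fast_conv_bn_eval_forward relies on can be checked numerically with plain PyTorch. This is an editorial sketch under the assumption of a Conv2d with bias and a BatchNorm2d in eval mode; it folds the normalization into the convolution weight and bias in the same way as the new function and compares against the plain conv -> bn computation.

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, 3, bias=True)
bn = nn.BatchNorm2d(8)
# give the BN layer non-trivial statistics and affine parameters
with torch.no_grad():
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 1.5)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
bn.eval()

x = torch.rand(1, 3, 16, 16)
expected = bn(conv(x))  # plain conv -> bn in eval mode

# fold the normalization into the conv parameters:
# bn(conv(x)) = (coeff * W) * x + (beta + coeff * (b - running_mean))
coeff = bn.weight * torch.rsqrt(bn.running_var + bn.eps)  # shape [C_out]
weight = conv.weight * coeff.reshape(-1, 1, 1, 1)  # shape [C_out, C_in, k, k]
bias = bn.bias + coeff * (conv.bias - bn.running_mean)  # shape [C_out]
fused = conv._conv_forward(x, weight, bias)

assert torch.allclose(expected, fused, atol=1e-5)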
