diff --git a/mmcv/ops/masked_conv.py b/mmcv/ops/masked_conv.py index 82f16d4a39..f591b17088 100644 --- a/mmcv/ops/masked_conv.py +++ b/mmcv/ops/masked_conv.py @@ -52,13 +52,15 @@ def forward(ctx, weight, bias, stride=(stride_h, stride_w), - padding=padding, + padding=(pad_h, pad_w), dilation=(1, 1), groups=1) - features_h, features_w = features.size()[2:] - mask_reshape = mask.reshape(1, 1, features_h, features_w) - mask_bool = mask_reshape > 0 - output = conv * mask_bool + if mask.size()[1:] != conv.size()[2:]: + raise ValueError( + 'The mask is inconsistent with the shape of output_conv.') + conv_h, conv_w = conv.size()[2:] + mask_reshape = mask.reshape(1, 1, conv_h, conv_w) + output = conv * mask_reshape return output batch_size = features.size(0)