From 8dab3a3d601a75182bca09b87f61b7083be8a56f Mon Sep 17 00:00:00 2001 From: zcc-zjut Date: Mon, 7 Nov 2022 16:29:09 +0800 Subject: [PATCH] Masked_conv2d NPU --- mmcv/ops/masked_conv.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/mmcv/ops/masked_conv.py b/mmcv/ops/masked_conv.py index 82f16d4a39..a372fb9ed4 100644 --- a/mmcv/ops/masked_conv.py +++ b/mmcv/ops/masked_conv.py @@ -52,13 +52,15 @@ def forward(ctx, weight, bias, stride=(stride_h, stride_w), - padding=padding, + padding=(pad_h, pad_w), dilation=(1, 1), groups=1) - features_h, features_w = features.size()[2:] - mask_reshape = mask.reshape(1, 1, features_h, features_w) - mask_bool = mask_reshape > 0 - output = conv * mask_bool + if mask.size()[1:] != conv.size()[2:]: + raise ValueError( + 'The mask is consistent with the shape of output_conv.') + conv_h, conv_w = conv.size()[2:] + mask_reshape = mask.reshape(1, 1, conv_h, conv_w) + output = conv * mask_reshape return output batch_size = features.size(0)