diff --git a/mmcv/cnn/bricks/wrappers.py b/mmcv/cnn/bricks/wrappers.py
index 07eb04ee32..fc98c35584 100644
--- a/mmcv/cnn/bricks/wrappers.py
+++ b/mmcv/cnn/bricks/wrappers.py
@@ -41,7 +41,7 @@ def backward(ctx, grad: torch.Tensor) -> tuple:
 class Conv2d(nn.Conv2d):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+        if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0:
             out_shape = [x.shape[0], self.out_channels]
             for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
                                      self.padding, self.stride, self.dilation):
@@ -62,7 +62,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class Conv3d(nn.Conv3d):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+        if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0:
             out_shape = [x.shape[0], self.out_channels]
             for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
                                      self.padding, self.stride, self.dilation):
@@ -84,7 +84,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class ConvTranspose2d(nn.ConvTranspose2d):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+        if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0:
             out_shape = [x.shape[0], self.out_channels]
             for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
                                          self.padding, self.stride,
@@ -106,7 +106,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class ConvTranspose3d(nn.ConvTranspose3d):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+        if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0:
             out_shape = [x.shape[0], self.out_channels]
             for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
                                          self.padding, self.stride,
@@ -127,7 +127,7 @@ class MaxPool2d(nn.MaxPool2d):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # PyTorch 1.9 does not support empty tensor inference yet
-        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
+        if obsolete_torch_version(TORCH_VERSION, (1, 9)) and x.numel() == 0:
             out_shape = list(x.shape[:2])
             for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
                                      _pair(self.padding), _pair(self.stride),
@@ -145,7 +145,7 @@ class MaxPool3d(nn.MaxPool3d):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # PyTorch 1.9 does not support empty tensor inference yet
-        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
+        if obsolete_torch_version(TORCH_VERSION, (1, 9)) and x.numel() == 0:
             out_shape = list(x.shape[:2])
             for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size),
                                      _triple(self.padding),
@@ -164,7 +164,7 @@ class Linear(torch.nn.Linear):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # empty tensor forward of Linear layer is supported in Pytorch 1.6
-        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)):
+        if obsolete_torch_version(TORCH_VERSION, (1, 5)) and x.numel() == 0:
             out_shape = [x.shape[0], self.out_features]
             empty = NewEmptyTensorOp.apply(x, out_shape)
             if self.training:
diff --git a/tests/test_cnn/test_wrappers.py b/tests/test_cnn/test_wrappers.py
index 02e0f13cd7..8c76ccbdd4 100644
--- a/tests/test_cnn/test_wrappers.py
+++ b/tests/test_cnn/test_wrappers.py
@@ -4,6 +4,8 @@
 import pytest
 import torch
 import torch.nn as nn
+from mmengine.utils import digit_version
+from mmengine.utils.dl_utils import TORCH_VERSION
 
 from mmcv.cnn.bricks import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
                              Linear, MaxPool2d, MaxPool3d)
@@ -374,3 +376,21 @@ def test_nn_op_forward_called():
     wrapper = Linear(3, 3)
     wrapper(x_normal)
     nn_module_forward.assert_called_with(x_normal)
+
+
+@pytest.mark.skipif(
+    digit_version(TORCH_VERSION) < digit_version('1.10'),
+    reason='MaxPool2d and MaxPool3d will fail fx for torch<=1.9')
+def test_fx_compatibility():
+    from torch import fx
+
+    # ensure the fx trace can pass the network
+    for Net in (MaxPool2d, MaxPool3d):
+        net = Net(1)
+        gm_module = fx.symbolic_trace(net)  # noqa: F841
+    for Net in (Linear, ):
+        net = Net(1, 1)
+        gm_module = fx.symbolic_trace(net)  # noqa: F841
+    for Net in (Conv2d, ConvTranspose2d, Conv3d, ConvTranspose3d):
+        net = Net(1, 1, 1)
+        gm_module = fx.symbolic_trace(net)  # noqa: F841
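
For context, a minimal sketch of why swapping the operands fixes `torch.fx` tracing: during `fx.symbolic_trace`, the input `x` is a `Proxy`, so `x.numel() == 0` evaluates to another `Proxy`, and placing it first in the `and` forces a `bool(Proxy)` conversion that raises a `TraceError`. With the version check first, the guard is a plain Python bool that is `False` on any fx-capable PyTorch, so short-circuiting skips the `Proxy` comparison entirely. The `LEGACY_TORCH` flag and module names below are illustrative stand-ins, not code from this patch:

```python
import torch.nn as nn
from torch import fx

# Stands in for the plain-bool result of obsolete_torch_version(...);
# it is False on any PyTorch recent enough to ship torch.fx.
LEGACY_TORCH = False


class OldOrder(nn.Module):

    def forward(self, x):
        # During tracing `x` is a Proxy, so `x.numel() == 0` is also a
        # Proxy; using it as the first operand of `and` calls bool(Proxy),
        # which raises a TraceError.
        if x.numel() == 0 and LEGACY_TORCH:
            return x
        return x + 1


class NewOrder(nn.Module):

    def forward(self, x):
        # The plain bool is evaluated first and short-circuits, so the
        # Proxy comparison is never reached during tracing.
        if LEGACY_TORCH and x.numel() == 0:
            return x
        return x + 1


fx.symbolic_trace(NewOrder())  # traces fine
try:
    fx.symbolic_trace(OldOrder())
except Exception as e:
    print(type(e).__name__)  # TraceError
```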