From 3920035366ea18dd204549c19135085c5862b67e Mon Sep 17 00:00:00 2001 From: Kenny Date: Fri, 25 Jun 2021 16:22:56 +0800 Subject: [PATCH 1/5] fix saconv --- mmcv/ops/saconv.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mmcv/ops/saconv.py b/mmcv/ops/saconv.py index cd7eea122f..d2d3a441d7 100644 --- a/mmcv/ops/saconv.py +++ b/mmcv/ops/saconv.py @@ -98,6 +98,8 @@ def forward(self, x): switch = self.switch(avg_x) # sac weight = self._get_weight(self.weight) + zero_bias = torch.zeros(self.out_channels, device=weight.device) + if self.use_deform: offset = self.offset_s(avg_x) out_s = deform_conv2d(x, offset, weight, self.stride, self.padding, @@ -106,7 +108,7 @@ def forward(self, x): if TORCH_VERSION < '1.5.0' or TORCH_VERSION == 'parrots': out_s = super().conv2d_forward(x, weight) else: - out_s = super()._conv_forward(x, weight) + out_s = super()._conv_forward(x, weight, zero_bias) ori_p = self.padding ori_d = self.dilation self.padding = tuple(3 * p for p in self.padding) @@ -120,7 +122,7 @@ def forward(self, x): if TORCH_VERSION < '1.5.0' or TORCH_VERSION == 'parrots': out_l = super().conv2d_forward(x, weight) else: - out_l = super()._conv_forward(x, weight) + out_l = super()._conv_forward(x, weight, zero_bias) out = switch * out_s + (1 - switch) * out_l self.padding = ori_p self.dilation = ori_d From 88434812f42b23254c64d50546a8950883197058 Mon Sep 17 00:00:00 2001 From: Kenny Date: Fri, 25 Jun 2021 21:35:50 +0800 Subject: [PATCH 2/5] update --- mmcv/ops/saconv.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmcv/ops/saconv.py b/mmcv/ops/saconv.py index d2d3a441d7..ec7d2fe853 100644 --- a/mmcv/ops/saconv.py +++ b/mmcv/ops/saconv.py @@ -98,7 +98,8 @@ def forward(self, x): switch = self.switch(avg_x) # sac weight = self._get_weight(self.weight) - zero_bias = torch.zeros(self.out_channels, device=weight.device) + zero_bias = torch.zeros( + self.out_channels, device=weight.device, dtype=weight.dtype) if self.use_deform: offset = self.offset_s(avg_x) From b9b887db5e06a77c116639ac2b9db122b2a693a8 Mon Sep 17 00:00:00 2001 From: Kenny Date: Fri, 25 Jun 2021 22:05:25 +0800 Subject: [PATCH 3/5] update --- mmcv/ops/saconv.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mmcv/ops/saconv.py b/mmcv/ops/saconv.py index ec7d2fe853..dcb8dedabf 100644 --- a/mmcv/ops/saconv.py +++ b/mmcv/ops/saconv.py @@ -109,6 +109,7 @@ def forward(self, x): if TORCH_VERSION < '1.5.0' or TORCH_VERSION == 'parrots': out_s = super().conv2d_forward(x, weight) else: + # bias is a required argument of _conv_forward in torch 1.9.0 out_s = super()._conv_forward(x, weight, zero_bias) ori_p = self.padding ori_d = self.dilation @@ -123,6 +124,7 @@ def forward(self, x): if TORCH_VERSION < '1.5.0' or TORCH_VERSION == 'parrots': out_l = super().conv2d_forward(x, weight) else: + # bias is a required argument of _conv_forward in torch 1.9.0 out_l = super()._conv_forward(x, weight, zero_bias) out = switch * out_s + (1 - switch) * out_l self.padding = ori_p From 3798aba22599e5aef3c66276b1e77257842f4a45 Mon Sep 17 00:00:00 2001 From: Kenny Date: Tue, 29 Jun 2021 12:03:41 +0800 Subject: [PATCH 4/5] fix --- mmcv/ops/saconv.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/mmcv/ops/saconv.py b/mmcv/ops/saconv.py index dcb8dedabf..e6c9da050d 100644 --- a/mmcv/ops/saconv.py +++ b/mmcv/ops/saconv.py @@ -108,9 +108,11 @@ def forward(self, x): else: if TORCH_VERSION < '1.5.0' or TORCH_VERSION == 'parrots': out_s = super().conv2d_forward(x, weight) - else: - # bias is a required argument of _conv_forward in torch 1.9.0 + elif TORCH_VERSION >= '1.8.0': + # bias is a required argument of _conv_forward in torch 1.8.0 out_s = super()._conv_forward(x, weight, zero_bias) + else: + out_s = super()._conv_forward(x, weight) ori_p = self.padding ori_d = self.dilation self.padding = tuple(3 * p for p in self.padding) @@ -123,9 +125,12 @@ def forward(self, x): else: if TORCH_VERSION < '1.5.0' or TORCH_VERSION == 'parrots': out_l = super().conv2d_forward(x, weight) - else: - # bias is a required argument of _conv_forward in torch 1.9.0 + elif TORCH_VERSION >= '1.8.0': + # bias is a required argument of _conv_forward in torch 1.8.0 out_l = super()._conv_forward(x, weight, zero_bias) + else: + out_l = super()._conv_forward(x, weight) + out = switch * out_s + (1 - switch) * out_l self.padding = ori_p self.dilation = ori_d From 07152acdec9ad402fcf911238c7796cb48d0191c Mon Sep 17 00:00:00 2001 From: Kenny Date: Tue, 29 Jun 2021 21:58:34 +0800 Subject: [PATCH 5/5] use LooseVersion --- mmcv/ops/saconv.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/mmcv/ops/saconv.py b/mmcv/ops/saconv.py index e6c9da050d..6b19ce5719 100644 --- a/mmcv/ops/saconv.py +++ b/mmcv/ops/saconv.py @@ -1,3 +1,5 @@ +from distutils.version import LooseVersion + import torch import torch.nn as nn import torch.nn.functional as F @@ -106,9 +108,10 @@ def forward(self, x): out_s = deform_conv2d(x, offset, weight, self.stride, self.padding, self.dilation, self.groups, 1) else: - if TORCH_VERSION < '1.5.0' or TORCH_VERSION == 'parrots': + if (LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0') + or TORCH_VERSION == 'parrots'): out_s = super().conv2d_forward(x, weight) - elif TORCH_VERSION >= '1.8.0': + elif LooseVersion(TORCH_VERSION) >= LooseVersion('1.8.0'): # bias is a required argument of _conv_forward in torch 1.8.0 out_s = super()._conv_forward(x, weight, zero_bias) else: @@ -123,9 +126,10 @@ def forward(self, x): out_l = deform_conv2d(x, offset, weight, self.stride, self.padding, self.dilation, self.groups, 1) else: - if TORCH_VERSION < '1.5.0' or TORCH_VERSION == 'parrots': + if (LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0') + or TORCH_VERSION == 'parrots'): out_l = super().conv2d_forward(x, weight) - elif TORCH_VERSION >= '1.8.0': + elif LooseVersion(TORCH_VERSION) >= LooseVersion('1.8.0'): # bias is a required argument of _conv_forward in torch 1.8.0 out_l = super()._conv_forward(x, weight, zero_bias) else: