Skip to content

Commit

Permalink
Merge 07152ac into 1d5ee6e
Browse files Browse the repository at this point in the history
  • Loading branch information
kennymckormick committed Jun 29, 2021
2 parents 1d5ee6e + 07152ac commit e02a978
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions mmcv/ops/saconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ def forward(self, x):
switch = self.switch(avg_x)
# sac
weight = self._get_weight(self.weight)
zero_bias = torch.zeros(
self.out_channels, device=weight.device, dtype=weight.dtype)

if self.use_deform:
offset = self.offset_s(avg_x)
out_s = deform_conv2d(x, offset, weight, self.stride, self.padding,
Expand All @@ -108,6 +111,9 @@ def forward(self, x):
if (LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0')
or TORCH_VERSION == 'parrots'):
out_s = super().conv2d_forward(x, weight)
elif LooseVersion(TORCH_VERSION) >= LooseVersion('1.8.0'):
# bias is a required argument of _conv_forward in torch 1.8.0
out_s = super()._conv_forward(x, weight, zero_bias)
else:
out_s = super()._conv_forward(x, weight)
ori_p = self.padding
Expand All @@ -123,8 +129,12 @@ def forward(self, x):
if (LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0')
or TORCH_VERSION == 'parrots'):
out_l = super().conv2d_forward(x, weight)
elif LooseVersion(TORCH_VERSION) >= LooseVersion('1.8.0'):
# bias is a required argument of _conv_forward in torch 1.8.0
out_l = super()._conv_forward(x, weight, zero_bias)
else:
out_l = super()._conv_forward(x, weight)

out = switch * out_s + (1 - switch) * out_l
self.padding = ori_p
self.dilation = ori_d
Expand Down

0 comments on commit e02a978

Please sign in to comment.