Skip to content

Commit

Permalink
fix mmcv ci for parrots (#782)
Browse files Browse the repository at this point in the history
* fix mmcv ci for parrots

* fix mmcv ci

* fix lint
  • Loading branch information
magicdream2222 committed Jan 14, 2021
1 parent f169fb5 commit 8e3a801
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 13 deletions.
2 changes: 1 addition & 1 deletion mmcv/ops/roi_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair

from ..onnx import is_custom_op_loaded
from ..utils import deprecated_api_warning, ext_loader

ext_module = ext_loader.load_ext('_ext',
Expand All @@ -16,6 +15,7 @@ class RoIAlignFunction(Function):
@staticmethod
def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio,
pool_mode, aligned):
from ..onnx import is_custom_op_loaded
has_custom_op = is_custom_op_loaded()
if has_custom_op:
return g.op(
Expand Down
37 changes: 25 additions & 12 deletions tests/test_runner/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from unittest.mock import MagicMock

import pytest
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel

Expand Down Expand Up @@ -47,11 +48,19 @@ def assert_tensor_equal(tensor_a, tensor_b):


def test_get_state_dict():
state_dict_keys = set([
'block.conv.weight', 'block.conv.bias', 'block.norm.weight',
'block.norm.bias', 'block.norm.running_mean', 'block.norm.running_var',
'block.norm.num_batches_tracked', 'conv.weight', 'conv.bias'
])
if torch.__version__ == 'parrots':
state_dict_keys = set([
'block.conv.weight', 'block.conv.bias', 'block.norm.weight',
'block.norm.bias', 'block.norm.running_mean',
'block.norm.running_var', 'conv.weight', 'conv.bias'
])
else:
state_dict_keys = set([
'block.conv.weight', 'block.conv.bias', 'block.norm.weight',
'block.norm.bias', 'block.norm.running_mean',
'block.norm.running_var', 'block.norm.num_batches_tracked',
'conv.weight', 'conv.bias'
])

model = Model()
state_dict = get_state_dict(model)
Expand All @@ -68,8 +77,9 @@ def test_get_state_dict():
model.block.norm.running_mean)
assert_tensor_equal(state_dict['block.norm.running_var'],
model.block.norm.running_var)
assert_tensor_equal(state_dict['block.norm.num_batches_tracked'],
model.block.norm.num_batches_tracked)
if torch.__version__ != 'parrots':
assert_tensor_equal(state_dict['block.norm.num_batches_tracked'],
model.block.norm.num_batches_tracked)
assert_tensor_equal(state_dict['conv.weight'], model.conv.weight)
assert_tensor_equal(state_dict['conv.bias'], model.conv.bias)

Expand All @@ -89,8 +99,10 @@ def test_get_state_dict():
wrapped_model.module.block.norm.running_mean)
assert_tensor_equal(state_dict['block.norm.running_var'],
wrapped_model.module.block.norm.running_var)
assert_tensor_equal(state_dict['block.norm.num_batches_tracked'],
wrapped_model.module.block.norm.num_batches_tracked)
if torch.__version__ != 'parrots':
assert_tensor_equal(
state_dict['block.norm.num_batches_tracked'],
wrapped_model.module.block.norm.num_batches_tracked)
assert_tensor_equal(state_dict['conv.weight'],
wrapped_model.module.conv.weight)
assert_tensor_equal(state_dict['conv.bias'],
Expand All @@ -115,9 +127,10 @@ def test_get_state_dict():
wrapped_model.module.block.module.norm.running_mean)
assert_tensor_equal(state_dict['block.norm.running_var'],
wrapped_model.module.block.module.norm.running_var)
assert_tensor_equal(
state_dict['block.norm.num_batches_tracked'],
wrapped_model.module.block.module.norm.num_batches_tracked)
if torch.__version__ != 'parrots':
assert_tensor_equal(
state_dict['block.norm.num_batches_tracked'],
wrapped_model.module.block.module.norm.num_batches_tracked)
assert_tensor_equal(state_dict['conv.weight'],
wrapped_model.module.conv.module.weight)
assert_tensor_equal(state_dict['conv.bias'],
Expand Down

0 comments on commit 8e3a801

Please sign in to comment.