From 115bbda50c2767420b24231c49c9ada12200600f Mon Sep 17 00:00:00 2001 From: Bai Yifan Date: Wed, 22 Jul 2020 10:19:13 +0800 Subject: [PATCH] Fix unittest error (#393) * fix unittest error * fix quant_post unittest error --- tests/test_darts.py | 2 ++ tests/test_earlystop.py | 2 ++ tests/test_quant_embedding.py | 15 ++++++++++++ tests/test_quant_post_only_weight.py | 7 +++--- tests/test_rl_nas.py | 2 ++ tests/test_sa_nas.py | 36 +++++++++++++++++++++------- 6 files changed, 51 insertions(+), 13 deletions(-) diff --git a/tests/test_darts.py b/tests/test_darts.py index 63383ea46f4cd..9c6a5817c466a 100644 --- a/tests/test_darts.py +++ b/tests/test_darts.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys +sys.path.append("../") import paddle import unittest import paddle.fluid as fluid diff --git a/tests/test_earlystop.py b/tests/test_earlystop.py index 6af4de6df5336..62ea3ed08b2a0 100644 --- a/tests/test_earlystop.py +++ b/tests/test_earlystop.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys +sys.path.append("../") import unittest import paddle from paddleslim.nas import SANAS diff --git a/tests/test_quant_embedding.py b/tests/test_quant_embedding.py index 107ce32ea536c..028d4a6183676 100644 --- a/tests/test_quant_embedding.py +++ b/tests/test_quant_embedding.py @@ -1,3 +1,18 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +sys.path.append("../") import paddle.fluid as fluid import paddleslim.quant as quant import unittest diff --git a/tests/test_quant_post_only_weight.py b/tests/test_quant_post_only_weight.py index e90b13e6fe54d..ede4094dc2c70 100644 --- a/tests/test_quant_post_only_weight.py +++ b/tests/test_quant_post_only_weight.py @@ -77,9 +77,8 @@ def test(program, outputs=[avg_cost, acc_top1, acc_top5]): fetch_list=outputs) iter += 1 if iter % 100 == 0: - print( - 'eval iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'. - format(iter, cost, top1, top5)) + print('eval iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'. + format(iter, cost, top1, top5)) result[0].append(cost) result[1].append(top1) result[2].append(top5) @@ -99,7 +98,7 @@ def test(program, outputs=[avg_cost, acc_top1, acc_top5]): params_filename='params') quant_post_dynamic( - model_dir='./test_quant_post', + model_dir='./test_quant_post_dynamic', save_model_dir='./test_quant_post_inference', model_filename='model', params_filename='params', diff --git a/tests/test_rl_nas.py b/tests/test_rl_nas.py index cd81eaf15a1f7..ba9c92210acb1 100644 --- a/tests/test_rl_nas.py +++ b/tests/test_rl_nas.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys +sys.path.append("../") import unittest import paddle.fluid as fluid from paddleslim.nas import RLNAS diff --git a/tests/test_sa_nas.py b/tests/test_sa_nas.py index 8630136ef2560..1ba7df9befbcd 100644 --- a/tests/test_sa_nas.py +++ b/tests/test_sa_nas.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys +sys.path.append("../") import os import sys import unittest @@ -19,24 +21,27 @@ from paddleslim.analysis import flops import numpy as np + def compute_op_num(program): params = {} ch_list = [] for block in program.blocks: for param in block.all_parameters(): - if len(param.shape) == 4: + if len(param.shape) == 4: params[param.name] = param.shape ch_list.append(int(param.shape[0])) return params, ch_list + class TestSANAS(unittest.TestCase): def setUp(self): self.init_test_case() port = np.random.randint(8337, 8773) - self.sanas = SANAS(configs=self.configs, server_addr=("", port), save_checkpoint=None) + self.sanas = SANAS( + configs=self.configs, server_addr=("", port), save_checkpoint=None) def init_test_case(self): - self.configs=[('MobileNetV2BlockSpace', {'block_mask':[0]})] + self.configs = [('MobileNetV2BlockSpace', {'block_mask': [0]})] self.filter_num = np.array([ 3, 4, 8, 12, 16, 24, 32, 48, 64, 80, 96, 128, 144, 160, 192, 224, 256, 320, 384, 512 @@ -53,7 +58,10 @@ def check_chnum_convnum(self, program): conv_list, ch_pro = compute_op_num(program) ### assert conv number - self.assertTrue((repeat_num * 3) == len(conv_list), "the number of conv is NOT match, the number compute from token: {}, actual conv number: {}".format(repeat_num * 3, len(conv_list))) + self.assertTrue((repeat_num * 3) == len( + conv_list + ), "the number of conv is NOT match, the number compute from token: {}, actual conv number: {}". + format(repeat_num * 3, len(conv_list))) ### assert number of channels ch_token = [] @@ -64,7 +72,10 @@ def check_chnum_convnum(self, program): ch_token.append(filter_num) init_ch_num = filter_num - self.assertTrue(str(ch_token) == str(ch_pro), "channel num is WRONG, channel num from token is {}, channel num come fom program is {}".format(str(ch_token), str(ch_pro))) + self.assertTrue( + str(ch_token) == str(ch_pro), + "channel num is WRONG, channel num from token is {}, channel num come fom program is {}". + format(str(ch_token), str(ch_pro))) def test_all_function(self): ### unittest for next_archs @@ -73,7 +84,8 @@ def test_all_function(self): token2arch_program = fluid.Program() with fluid.program_guard(next_program, startup_program): - inputs = fluid.data(name='input', shape=[None, 3, 32, 32], dtype='float32') + inputs = fluid.data( + name='input', shape=[None, 3, 32, 32], dtype='float32') archs = self.sanas.next_archs() for arch in archs: output = arch(inputs) @@ -85,8 +97,10 @@ def test_all_function(self): ### uniitest for tokens2arch with fluid.program_guard(token2arch_program, startup_program): - inputs = fluid.data(name='input', shape=[None, 3, 32, 32], dtype='float32') - arch = self.sanas.tokens2arch(self.sanas.current_info()['current_tokens']) + inputs = fluid.data( + name='input', shape=[None, 3, 32, 32], dtype='float32') + arch = self.sanas.tokens2arch(self.sanas.current_info()[ + 'current_tokens']) for arch in archs: output = arch(inputs) inputs = output @@ -94,7 +108,11 @@ def test_all_function(self): ### unittest for current_info current_info = self.sanas.current_info() - self.assertTrue(isinstance(current_info, dict), "the type of current info must be dict, but now is {}".format(type(current_info))) + self.assertTrue( + isinstance(current_info, dict), + "the type of current info must be dict, but now is {}".format( + type(current_info))) + if __name__ == '__main__': unittest.main()