Skip to content

Commit

Permalink
fix unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
yghstill committed Jun 9, 2022
1 parent 0db6579 commit c819188
Showing 1 changed file with 42 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,8 @@ def generate_quantized_model(self,
is_full_quantize=False,
is_use_cache_file=False,
is_optimize_model=False,
onnx_format=False):
onnx_format=False,
skip_tensor_list=None):
try:
os.system("mkdir " + self.int8_model)
except Exception as e:
Expand All @@ -267,6 +268,7 @@ def generate_quantized_model(self,
is_full_quantize=is_full_quantize,
optimize_model=is_optimize_model,
onnx_format=onnx_format,
skip_tensor_list=skip_tensor_list,
is_use_cache_file=is_use_cache_file)
ptq.quantize()
ptq.save_quantized_model(self.int8_model)
Expand All @@ -282,7 +284,8 @@ def run_test(self,
is_use_cache_file,
is_optimize_model,
diff_threshold,
onnx_format=False):
onnx_format=False,
skip_tensor_list=None):
infer_iterations = self.infer_iterations
batch_size = self.batch_size
sample_iterations = self.sample_iterations
Expand All @@ -296,10 +299,10 @@ def run_test(self,

print("Start INT8 post training quantization for {0} on {1} images ...".
format(model, sample_iterations * batch_size))
self.generate_quantized_model(model_cache_folder + "/model",
quantizable_op_type, algo, round_type,
is_full_quantize, is_use_cache_file,
is_optimize_model, onnx_format)
self.generate_quantized_model(
model_cache_folder + "/model", quantizable_op_type, algo,
round_type, is_full_quantize, is_use_cache_file, is_optimize_model,
onnx_format, skip_tensor_list)

print("Start INT8 inference for {0} on {1} images ...".format(
model, infer_iterations * batch_size))
Expand Down Expand Up @@ -445,5 +448,38 @@ def test_post_training_onnx_format_mobilenetv1(self):
onnx_format=onnx_format)


class TestPostTrainingForMobilenetv1SkipOP(TestPostTrainingQuantization):
def test_post_training_mobilenetv1_skip(self):
model = "MobileNet-V1"
algo = "avg"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
quantizable_op_type = [
"conv2d",
"depthwise_conv2d",
"mul",
]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = True
diff_threshold = 0.025
skip_tensor_list = ["fc_0.w_0"]
self.run_test(
model,
algo,
round_type,
data_urls,
data_md5s,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
skip_tensor_list=skip_tensor_list)


if __name__ == '__main__':
unittest.main()

0 comments on commit c819188

Please sign in to comment.