diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index 3e033f70aca38..1ddb9c8e5fa9f 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -410,6 +410,23 @@ def quantize(self): for op_type in self._dynamic_quantize_op_type): self._collect_dynamic_quantize_op_threshold( self._dynamic_quantize_op_type) + + # Move the sub-blocks' persistable vars to the global block + global_block = self._program.global_block() + for _op in global_block.ops: + if _op.type == "while": + _block_id = _op.attr("sub_block").id + _block = self._program.block(_block_id) + persistables = [] + for _name, _var in _block.vars.items(): + if _var.persistable: + global_block._clone_variable(_var) + persistables.append(_name) + for _name in persistables: + _block._remove_var(_name) + persistables.extend(_op.input('X')) + _op.desc.set_input("X", persistables) + return self._program def save_quantized_model(self, @@ -451,10 +468,6 @@ def _load_model_data(self): model_filename=self._model_filename, params_filename=self._params_filename) - if self._program.num_blocks > 1: - _logger.error("The post training quantization requires that the " - "program only has one block.") - if self._optimize_model: self._optimize_fp32_model() @@ -505,23 +518,26 @@ def collect_var_name(var_name_list, persistable_var_names, op_type): self._quantized_act_var_name.add(var_name) persistable_var_names = _all_persistable_var_names(self._program) - for op in self._program.global_block().ops: - op_type = op.type - if self._is_full_quantize and \ - op_type not in self._quantizable_op_type: - _logger.warning(op_type + " is not supported for quantization.") - # For quantized ops, sample inputs and outputs - if op_type in self._quantizable_op_type: - collect_var_name( - _get_op_input_var_names(op), persistable_var_names, op_type) - collect_var_name( - _get_op_output_var_names(op), persistable_var_names, - op_type) - # For other op, only sample output scale - elif op_type in self._out_scale_op_list: - collect_var_name( - _get_op_output_var_names(op), persistable_var_names, - op_type) + for block_id in range(len(self._program.blocks)): + for op in self._program.blocks[block_id].ops: + op_type = op.type + if self._is_full_quantize and \ + op_type not in self._quantizable_op_type: + _logger.warning(op_type + + " is not supported for quantization.") + # For quantized ops, sample inputs and outputs + if op_type in self._quantizable_op_type: + collect_var_name( + _get_op_input_var_names(op), persistable_var_names, + op_type) + collect_var_name( + _get_op_output_var_names(op), persistable_var_names, + op_type) + # For other ops, only sample the output scale + elif op_type in self._out_scale_op_list: + collect_var_name( + _get_op_output_var_names(op), persistable_var_names, + op_type) def _set_activation_persistable(self): ''' @@ -696,16 +712,17 @@ def _save_input_threhold(self): ''' assert self._algo == "min_max", \ "The algo should be min_max to save input threshold."
- for op in self._program.global_block().ops: - if op.type in self._quantizable_op_type: - for var_name in _get_op_input_var_names(op): - assert var_name in self._quantized_var_min - assert var_name in self._quantized_var_max - op._set_attr(var_name + ".min", - self._quantized_var_min[var_name]) - op._set_attr(var_name + ".max", - self._quantized_var_max[var_name]) - op._set_attr("with_quant_attr", True) + for block_id in range(len(self._program.blocks)): + for op in self._program.blocks[block_id].ops: + if op.type in self._quantizable_op_type: + for var_name in _get_op_input_var_names(op): + assert var_name in self._quantized_var_min + assert var_name in self._quantized_var_max + op._set_attr(var_name + ".min", + self._quantized_var_min[var_name]) + op._set_attr(var_name + ".max", + self._quantized_var_max[var_name]) + op._set_attr("with_quant_attr", True) def _collect_activation_abs_min_max(self): ''' @@ -795,7 +812,12 @@ def _update_program(self): activation_quantize_type=self._activation_quantize_type, weight_quantize_type=self._weight_quantize_type, quantizable_op_type=major_quantizable_op_types) - transform_pass.apply(graph) + + for sub_graph in graph.all_sub_graphs(): + # The fake_quant/fake_dequant ops must be inserted into a test + # graph, so set each sub graph's _for_test to True. + sub_graph._for_test = True + transform_pass.apply(sub_graph) # use AddQuantDequantPass to insert fake_quant_dequant op minor_quantizable_op_types = [] @@ -806,7 +828,10 @@ scope=self._scope, place=self._place, quantizable_op_type=minor_quantizable_op_types) - add_quant_dequant_pass.apply(graph) + + for sub_graph in graph.all_sub_graphs(): + sub_graph._for_test = True + add_quant_dequant_pass.apply(sub_graph) # save threshold to scale var node if self._algo in ["KL", "hist"]: @@ -836,7 +861,11 @@ activation_bits=self._activation_bits, weight_quantize_type=self._weight_quantize_type, quantizable_op_type=major_quantizable_op_types) - freeze_pass.apply(graph) + + for sub_graph in graph.all_sub_graphs(): + sub_graph._for_test = True + freeze_pass.apply(sub_graph) + self._program = graph.to_program() def _save_output_threshold(self): @@ -888,13 +917,15 @@ def analysis_and_save_info(op_node, out_var_name): save_info(op_node, out_var_name, self._quantized_var_max, "out_max", "post_min_max") - for op in self._program.global_block().ops: - if op.type in (self._quantizable_op_type + self._out_scale_op_list): - out_var_names = _get_op_output_var_names(op) - assert len(out_var_names) == 1, "Post training " + \ - "quantization only support one output for " + op.type - for var_name in out_var_names: - analysis_and_save_info(op, var_name) + for block_id in range(len(self._program.blocks)): + for op in self._program.blocks[block_id].ops: + if op.type in ( + self._quantizable_op_type + self._out_scale_op_list): + out_var_names = _get_op_output_var_names(op) + assert len(out_var_names) == 1, "Post training " + \ + "quantization only supports one output for " + op.type + for var_name in out_var_names: + analysis_and_save_info(op, var_name) def _collect_dynamic_quantize_op_threshold(self, target_ops_type): """ diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt index 94d7a2ed15348..494ea96979719 100644 --- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt @@ -139,6 +139,7 @@ endfunction() if(WIN32) list(REMOVE_ITEM TEST_OPS test_light_nas)
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mnist) + list(REMOVE_ITEM TEST_OPS test_post_training_quantization_while) list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1) list(REMOVE_ITEM TEST_OPS test_post_training_quantization_resnet50) list(REMOVE_ITEM TEST_OPS test_post_training_quantization_lstm_model) @@ -336,6 +337,7 @@ if(NOT WIN32) set_tests_properties(test_post_training_quantization_mobilenetv1 PROPERTIES TIMEOUT 600 LABELS "RUN_TYPE=NIGHTLY") set_tests_properties(test_post_training_quantization_resnet50 PROPERTIES TIMEOUT 600 LABELS "RUN_TYPE=NIGHTLY") set_tests_properties(test_post_training_quantization_mnist PROPERTIES TIMEOUT 120) + set_tests_properties(test_post_training_quantization_while PROPERTIES TIMEOUT 120) set_tests_properties(test_imperative_ptq PROPERTIES TIMEOUT 120) set_tests_properties(test_weight_quantization_mobilenetv1 PROPERTIES TIMEOUT 120) endif() diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_while.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_while.py new file mode 100644 index 0000000000000..3c3dfd08fccfa --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_while.py @@ -0,0 +1,313 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+import unittest +import os +import time +import sys +import random +import math +import functools +import contextlib +import numpy as np +import paddle +import paddle.fluid as fluid +from paddle.dataset.common import download +from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization + +paddle.enable_static() + +random.seed(0) +np.random.seed(0) + + +class TestPostTrainingQuantization(unittest.TestCase): + def setUp(self): + self.download_path = 'int8/download' + self.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' + + self.download_path) + self.timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) + self.int8_model_path = os.path.join(os.getcwd(), + "post_training_" + self.timestamp) + try: + os.system("mkdir -p " + self.int8_model_path) + except Exception as e: + print("Failed to create {} due to {}".format(self.int8_model_path, + str(e))) + sys.exit(-1) + + def tearDown(self): + try: + os.system("rm -rf {}".format(self.int8_model_path)) + except Exception as e: + print("Failed to delete {} due to {}".format(self.int8_model_path, + str(e))) + + def cache_unzipping(self, target_folder, zip_path): + cmd = 'tar xf {0} -C {1}'.format(zip_path, target_folder) + os.system(cmd) + + def download_model(self, data_url, data_md5, folder_name): + download(data_url, self.download_path, data_md5) + file_name = data_url.split('/')[-1] + zip_path = os.path.join(self.cache_folder, file_name) + print('Data is downloaded at {0}'.format(zip_path)) + + data_cache_folder = os.path.join(self.cache_folder, folder_name) + self.cache_unzipping(self.cache_folder, zip_path) + return data_cache_folder + + def run_program(self, model_path, batch_size, infer_iterations): + print("test model path:" + model_path) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + [infer_program, feed_dict, fetch_targets] = \ + fluid.io.load_inference_model(model_path, + model_filename='model.pdmodel', + params_filename='model.pdiparams', executor=exe) + val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size) + + img_shape = [1, 28, 28] + test_info = [] + cnt = 0 + periods = [] + for batch_id, data in enumerate(val_reader()): + image = np.array( + [x[0].reshape(img_shape) for x in data]).astype("float32") + input_label = np.array([x[1] for x in data]).astype("int64") + + t1 = time.time() + out = exe.run(infer_program, + feed={feed_dict[0]: image}, + fetch_list=fetch_targets) + t2 = time.time() + period = t2 - t1 + periods.append(period) + + out_label = np.argmax(np.array(out[0]), axis=1) + top1_num = sum(input_label == out_label) + test_info.append(top1_num) + cnt += len(data) + + if (batch_id + 1) == infer_iterations: + break + + throughput = cnt / np.sum(periods) + latency = np.average(periods) + acc1 = np.sum(test_info) / cnt + return (throughput, latency, acc1) + + def generate_quantized_model(self, + model_path, + algo="KL", + quantizable_op_type=["conv2d"], + is_full_quantize=False, + is_use_cache_file=False, + is_optimize_model=False, + batch_size=10, + batch_nums=10): + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + scope = fluid.global_scope() + val_reader = paddle.dataset.mnist.train() + + ptq = PostTrainingQuantization( + executor=exe, + model_dir=model_path, + model_filename='model.pdmodel', + params_filename='model.pdiparams', + sample_generator=val_reader, + batch_size=batch_size, + batch_nums=batch_nums, + algo=algo, + quantizable_op_type=quantizable_op_type, + is_full_quantize=is_full_quantize, + optimize_model=is_optimize_model, + 
is_use_cache_file=is_use_cache_file) + ptq.quantize() + ptq.save_quantized_model( + self.int8_model_path, + model_filename='model.pdmodel', + params_filename='model.pdiparams') + + def run_test(self, + model_name, + data_url, + data_md5, + algo, + quantizable_op_type, + is_full_quantize, + is_use_cache_file, + is_optimize_model, + diff_threshold, + batch_size=10, + infer_iterations=10, + quant_iterations=5): + + origin_model_path = self.download_model(data_url, data_md5, model_name) + #origin_model_path = os.path.join(origin_model_path, model_name) + + print("Start FP32 inference for {0} on {1} images ...".format( + model_name, infer_iterations * batch_size)) + (fp32_throughput, fp32_latency, fp32_acc1) = self.run_program( + origin_model_path, batch_size, infer_iterations) + + print("Start INT8 post training quantization for {0} on {1} images ...". + format(model_name, quant_iterations * batch_size)) + self.generate_quantized_model( + origin_model_path, algo, quantizable_op_type, is_full_quantize, + is_use_cache_file, is_optimize_model, batch_size, quant_iterations) + + print("Start INT8 inference for {0} on {1} images ...".format( + model_name, infer_iterations * batch_size)) + (int8_throughput, int8_latency, int8_acc1) = self.run_program( + self.int8_model_path, batch_size, infer_iterations) + + print("---Post training quantization of {} method---".format(algo)) + print( + "FP32 {0}: batch_size {1}, throughput {2} img/s, latency {3} s, acc1 {4}.". + format(model_name, batch_size, fp32_throughput, fp32_latency, + fp32_acc1)) + print( + "INT8 {0}: batch_size {1}, throughput {2} img/s, latency {3} s, acc1 {4}.\n". + format(model_name, batch_size, int8_throughput, int8_latency, + int8_acc1)) + sys.stdout.flush() + + delta_value = fp32_acc1 - int8_acc1 + self.assertLess(delta_value, diff_threshold) + + +class TestPostTrainingKLForWhile(TestPostTrainingQuantization): + def test_post_training_kl(self): + model_name = "mnist_while" + data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_while.tar.gz" + data_md5 = "2387390beeb37b51dec041c27b8a681f" + algo = "KL" + quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] + is_full_quantize = False + is_use_cache_file = False + is_optimize_model = True + diff_threshold = 0.01 + batch_size = 10 + infer_iterations = 50 + quant_iterations = 5 + self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type, + is_full_quantize, is_use_cache_file, is_optimize_model, + diff_threshold, batch_size, infer_iterations, + quant_iterations) + + +class TestPostTraininghistForWhile(TestPostTrainingQuantization): + def test_post_training_hist(self): + model_name = "mnist_while" + data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_while.tar.gz" + data_md5 = "2387390beeb37b51dec041c27b8a681f" + algo = "hist" + quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] + is_full_quantize = False + is_use_cache_file = False + is_optimize_model = True + diff_threshold = 0.01 + batch_size = 10 + infer_iterations = 50 + quant_iterations = 5 + self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type, + is_full_quantize, is_use_cache_file, is_optimize_model, + diff_threshold, batch_size, infer_iterations, + quant_iterations) + + +class TestPostTrainingmseForWhile(TestPostTrainingQuantization): + def test_post_training_mse(self): + model_name = "mnist_while" + data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_while.tar.gz" + data_md5 = "2387390beeb37b51dec041c27b8a681f" + algo = "mse" + 
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] + is_full_quantize = False + is_use_cache_file = False + is_optimize_model = True + diff_threshold = 0.01 + batch_size = 10 + infer_iterations = 50 + quant_iterations = 5 + self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type, + is_full_quantize, is_use_cache_file, is_optimize_model, + diff_threshold, batch_size, infer_iterations, + quant_iterations) + + +class TestPostTrainingavgForWhile(TestPostTrainingQuantization): + def test_post_training_avg(self): + model_name = "mnist_while" + data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_while.tar.gz" + data_md5 = "2387390beeb37b51dec041c27b8a681f" + algo = "avg" + quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] + is_full_quantize = False + is_use_cache_file = False + is_optimize_model = True + diff_threshold = 0.01 + batch_size = 10 + infer_iterations = 50 + quant_iterations = 5 + self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type, + is_full_quantize, is_use_cache_file, is_optimize_model, + diff_threshold, batch_size, infer_iterations, + quant_iterations) + + +class TestPostTrainingMinMaxForWhile(TestPostTrainingQuantization): + def test_post_training_min_max(self): + model_name = "mnist_while" + data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_while.tar.gz" + data_md5 = "2387390beeb37b51dec041c27b8a681f" + algo = "min_max" + quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] + is_full_quantize = False + is_use_cache_file = False + is_optimize_model = True + diff_threshold = 0.01 + batch_size = 10 + infer_iterations = 50 + quant_iterations = 5 + self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type, + is_full_quantize, is_use_cache_file, is_optimize_model, + diff_threshold, batch_size, infer_iterations, + quant_iterations) + + +class TestPostTrainingAbsMaxForWhile(TestPostTrainingQuantization): + def test_post_training_abs_max(self): + model_name = "mnist_while" + data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_while.tar.gz" + data_md5 = "2387390beeb37b51dec041c27b8a681f" + algo = "abs_max" + quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] + is_full_quantize = False + is_use_cache_file = False + is_optimize_model = True + diff_threshold = 0.01 + batch_size = 10 + infer_iterations = 50 + quant_iterations = 5 + self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type, + is_full_quantize, is_use_cache_file, is_optimize_model, + diff_threshold, batch_size, infer_iterations, + quant_iterations) + + +if __name__ == '__main__': + unittest.main()
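Background for the multi-block handling above: a static-mode program that contains a while op keeps the loop body in a separate sub-block, so a pass that iterates only over global_block() never sees the ops inside the loop. The standalone sketch below (not part of the patch; variable names are illustrative) builds such a program and walks all of its blocks with the same for block_id in range(len(program.blocks)) pattern the patch introduces.

import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    i = paddle.full(shape=[1], fill_value=0, dtype='int64')
    limit = paddle.full(shape=[1], fill_value=10, dtype='int64')

    def cond(i, limit):
        return paddle.less_than(i, limit)

    def body(i, limit):
        return [i + 1, limit]

    i, limit = paddle.static.nn.while_loop(cond, body, [i, limit])

# The while op itself sits in block 0 (the global block); the ops of its
# loop body sit in a sub-block, so the program has more than one block
# and a pass that only visits global_block() misses the loop body.
for block_id in range(len(main_prog.blocks)):
    for op in main_prog.blocks[block_id].ops:
        print(block_id, op.type)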