diff --git a/demo/quant/quant_aware/README.md b/demo/quant/quant_aware/README.md index 5fae50c5ff752..0f690179739ca 100644 --- a/demo/quant/quant_aware/README.md +++ b/demo/quant/quant_aware/README.md @@ -20,8 +20,7 @@ quant_config = { 'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'], 'dtype': 'int8', 'window_size': 10000, - 'moving_rate': 0.9, - 'quant_weight_only': False + 'moving_rate': 0.9 } ``` @@ -49,7 +48,7 @@ compiled_train_prog = compiled_train_prog.with_data_parallel( ### 4. freeze program ``` -float_program, int8_program = convert(val_program, +float_program, int8_program = convert(val_program, place, quant_config, scope=None, diff --git a/demo/quant/quant_aware/train.py b/demo/quant/quant_aware/train.py index 45b1aa72c062e..1933a7d88aafa 100644 --- a/demo/quant/quant_aware/train.py +++ b/demo/quant/quant_aware/train.py @@ -78,27 +78,24 @@ def compress(args): # 1. quantization configs ############################################################################################################ quant_config = { - # weight quantize type, default is 'abs_max' - 'weight_quantize_type': 'abs_max', - # activation quantize type, default is 'abs_max' + # weight quantize type, default is 'channel_wise_abs_max' + 'weight_quantize_type': 'channel_wise_abs_max', + # activation quantize type, default is 'moving_average_abs_max' 'activation_quantize_type': 'moving_average_abs_max', # weight quantize bit num, default is 8 'weight_bits': 8, # activation quantize bit num, default is 8 'activation_bits': 8, - # op of name_scope in not_quant_pattern list, will not quantized + # ops of name_scope in not_quant_pattern list, will not be quantized 'not_quant_pattern': ['skip_quant'], - # op of types in quantize_op_types, will quantized + # ops of type in quantize_op_types, will be quantized 'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'], - # data type after quantization, default is 'int8' + # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8' 'dtype': 'int8', # window size for 'range_abs_max' quantization. defaulf is 10000 'window_size': 10000, # The decay coefficient of moving average, default is 0.9 'moving_rate': 0.9, - # if set quant_weight_only True, then only quantize parameters of layers which need quantization, - # and insert anti-quantization op for parameters of these layers. - 'quant_weight_only': False } train_reader = None @@ -141,8 +138,10 @@ def compress(args): # According to the weight and activation quantization type, the graph will be added # some fake quantize operators and fake dequantize operators. ############################################################################################################ - val_program = quant_aware(val_program, place, quant_config, scope=None, for_test=True) - compiled_train_prog = quant_aware(train_prog, place, quant_config, scope=None, for_test=False) + val_program = quant_aware( + val_program, place, quant_config, scope=None, for_test=True) + compiled_train_prog = quant_aware( + train_prog, place, quant_config, scope=None, for_test=False) opt = create_optimizer(args) opt.minimize(avg_cost) @@ -152,7 +151,8 @@ def compress(args): if args.pretrained_model: def if_exist(var): - return os.path.exists(os.path.join(args.pretrained_model, var.name)) + return os.path.exists( + os.path.join(args.pretrained_model, var.name)) fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist) @@ -199,9 +199,9 @@ def train(epoch, compiled_train_prog): build_strategy.sync_batch_norm = False exec_strategy = fluid.ExecutionStrategy() compiled_train_prog = compiled_train_prog.with_data_parallel( - loss_name=avg_cost.name, - build_strategy=build_strategy, - exec_strategy=exec_strategy) + loss_name=avg_cost.name, + build_strategy=build_strategy, + exec_strategy=exec_strategy) batch_id = 0 for data in train_reader(): @@ -242,8 +242,8 @@ def train(epoch, compiled_train_prog): # 4. Save inference model ############################################################################################################ model_path = os.path.join(quantization_model_save_dir, args.model, - 'act_' + quant_config['activation_quantize_type'] + '_w_' + quant_config[ - 'weight_quantize_type']) + 'act_' + quant_config['activation_quantize_type'] + + '_w_' + quant_config['weight_quantize_type']) float_path = os.path.join(model_path, 'float') int8_path = os.path.join(model_path, 'int8') if not os.path.isdir(model_path): @@ -252,7 +252,8 @@ def train(epoch, compiled_train_prog): fluid.io.save_inference_model( dirname=float_path, feeded_var_names=[image.name], - target_vars=[out], executor=exe, + target_vars=[out], + executor=exe, main_program=float_program, model_filename=float_path + '/model', params_filename=float_path + '/params') @@ -260,7 +261,8 @@ def train(epoch, compiled_train_prog): fluid.io.save_inference_model( dirname=int8_path, feeded_var_names=[image.name], - target_vars=[out], executor=exe, + target_vars=[out], + executor=exe, main_program=int8_program, model_filename=int8_path + '/model', params_filename=int8_path + '/params') diff --git a/docs/docs/api/quantization_api.md b/docs/docs/api/quantization_api.md index 356b937e8495d..1169308841d33 100644 --- a/docs/docs/api/quantization_api.md +++ b/docs/docs/api/quantization_api.md @@ -4,29 +4,50 @@ 通过字典配置量化参数 ``` -quant_config_default = { - 'weight_quantize_type': 'abs_max', - 'activation_quantize_type': 'abs_max', +TENSORRT_OP_TYPES = [ + 'mul', 'conv2d', 'pool2d', 'depthwise_conv2d', 'elementwise_add', + 'leaky_relu' +] +TRANSFORM_PASS_OP_TYPES = ['conv2d', 'depthwise_conv2d', 'mul'] + +QUANT_DEQUANT_PASS_OP_TYPES = [ + "pool2d", "elementwise_add", "concat", "softmax", "argmax", "transpose", + "equal", "gather", "greater_equal", "greater_than", "less_equal", + "less_than", "mean", "not_equal", "reshape", "reshape2", + "bilinear_interp", "nearest_interp", "trilinear_interp", "slice", + "squeeze", "elementwise_sub", "relu", "relu6", "leaky_relu", "tanh", "swish" + ] + +_quant_config_default = { + # weight quantize type, default is 'channel_wise_abs_max' + 'weight_quantize_type': 'channel_wise_abs_max', + # activation quantize type, default is 'moving_average_abs_max' + 'activation_quantize_type': 'moving_average_abs_max', + # weight quantize bit num, default is 8 'weight_bits': 8, + # activation quantize bit num, default is 8 'activation_bits': 8, # ops of name_scope in not_quant_pattern list, will not be quantized 'not_quant_pattern': ['skip_quant'], # ops of type in quantize_op_types, will be quantized - 'quantize_op_types': - ['conv2d', 'depthwise_conv2d', 'mul', 'elementwise_add', 'pool2d'], + 'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'], # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8' 'dtype': 'int8', # window size for 'range_abs_max' quantization. defaulf is 10000 'window_size': 10000, # The decay coefficient of moving average, default is 0.9 'moving_rate': 0.9, + # if True, 'quantize_op_types' will be TENSORRT_OP_TYPES + 'for_tensorrt': False, + # if True, 'quantoze_op_types' will be TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES + 'is_full_quantize': False } ``` **参数:** -- **weight_quantize_type(str)** - 参数量化方式。可选``'abs_max'``, ``'channel_wise_abs_max'``, ``'range_abs_max'``, ``'moving_average_abs_max'``。 默认``'abs_max'``。 -- **activation_quantize_type(str)** - 激活量化方式,可选``'abs_max'``, ``'range_abs_max'``, ``'moving_average_abs_max'``,默认``'abs_max'``。 +- **weight_quantize_type(str)** - 参数量化方式。可选``'abs_max'``, ``'channel_wise_abs_max'``, ``'range_abs_max'``, ``'moving_average_abs_max'``。如果使用``TensorRT``加载量化后的模型来预测,请使用``'channel_wise_abs_max'``。 默认``'channel_wise_abs_max'``。 +- **activation_quantize_type(str)** - 激活量化方式,可选``'abs_max'``, ``'range_abs_max'``, ``'moving_average_abs_max'``。如果使用``TensorRT``加载量化后的模型来预测,请使用``'range_abs_max', 'moving_average_abs_max'``。,默认``'moving_average_abs_max'``。 - **weight_bits(int)** - 参数量化bit数,默认8, 推荐设为8。 - **activation_bits(int)** - 激活量化bit数,默认8, 推荐设为8。 - **not_quant_pattern(str | list[str])** - 所有``name_scope``包含``'not_quant_pattern'``字符串的``op``,都不量化, 设置方式请参考[*fluid.name_scope*](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/name_scope_cn.html#name-scope)。 @@ -34,7 +55,12 @@ quant_config_default = { - **dtype(int8)** - 量化后的参数类型,默认 ``int8``, 目前仅支持``int8``。 - **window_size(int)** - ``'range_abs_max'``量化方式的``window size``,默认10000。 - **moving_rate(int)** - ``'moving_average_abs_max'``量化方式的衰减系数,默认 0.9。 +- **for_tensorrt(bool)** - 量化后的模型是否使用``TensorRT``进行预测。如果是的话,量化op类型为:``TENSORRT_OP_TYPES``。默认值为False. +- **is_full_quantize(bool)** - 是否量化所有可支持op类型。默认值为False. +!!! note "注意事项" + +- 目前``Paddle-Lite``有int8 kernel来加速的op只有 ``['conv2d', 'depthwise_conv2d', 'mul']``, 其他op的int8 kernel将陆续支持。 ## quant_aware paddleslim.quant.quant_aware(program, place, config, scope=None, for_test=False)[[源代码]](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/quant/quanter.py) @@ -67,7 +93,7 @@ paddleslim.quant.quant_aware(program, place, config, scope=None, for_test=False) -## convert +## convert paddleslim.quant.convert(program, place, config, scope=None, save_int8=False)[[源代码]](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/quant/quanter.py) @@ -135,7 +161,7 @@ inference_prog = quant.convert(quant_eval_program, place, config) 更详细的用法请参考 量化训练demo。 ## quant_post -paddleslim.quant.quant_post(executor, model_dir, quantize_model_path,sample_generator, model_filename=None, params_filename=None, batch_size=16,batch_nums=None, scope=None, algo='KL', quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"])[[源代码]](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/quant/quanter.py) +paddleslim.quant.quant_post(executor, model_dir, quantize_model_path,sample_generator, model_filename=None, params_filename=None, batch_size=16,batch_nums=None, scope=None, algo='KL', quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"], is_full_quantize=False, is_use_cache_file=False, cache_dir="./temp_post_training")[[源代码]](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/quant/quanter.py) : 对保存在``${model_dir}``下的模型进行量化,使用``sample_generator``的数据进行参数校正。 @@ -152,6 +178,9 @@ paddleslim.quant.quant_post(executor, model_dir, quantize_model_path,sample_gene - **scope(fluid.Scope, optional)** - 用来获取和写入``Variable``, 如果设置为``None``,则使用[*fluid.global_scope()*](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/executor_cn/global_scope_cn.html). 默认值是``None``. - **algo(str)** - 量化时使用的算法名称,可为``'KL'``或者``'direct'``。该参数仅针对激活值的量化,因为参数值的量化使用的方式为``'channel_wise_abs_max'``. 当``algo`` 设置为``'direct'``时,使用校正数据的激活值的绝对值的最大值当作``Scale``值,当设置为``'KL'``时,则使用``KL``散度的方法来计算``Scale``值。默认值为``'KL'``。 - **quantizable_op_type(list[str])** - 需要量化的``op``类型列表。默认值为``["conv2d", "depthwise_conv2d", "mul"]``。 +- **is_full_quantize(bool)** - 是否量化所有可支持的op类型。如果设置为False, 则按照 ``'quantizable_op_type'`` 的设置进行量化。 +- **is_use_cache_file(bool)** - 是否使用硬盘对中间结果进行存储。如果为False, 则将中间结果存储在内存中。 +- **cache_dir(str)** - 如果 ``'is_use_cache_file'``为True, 则将中间结果存储在此参数设置的路径下。 **返回** @@ -159,7 +188,8 @@ paddleslim.quant.quant_post(executor, model_dir, quantize_model_path,sample_gene !!! note "注意事项" -因为该接口会收集校正数据的所有的激活值,所以使用的校正图片不能太多。``'KL'``散度的计算也比较耗时。 +- 因为该接口会收集校正数据的所有的激活值,当校正图片比较多时,请设置``'is_use_cache_file'``为True, 将中间结果存储在硬盘中。另外,``'KL'``散度的计算比较耗时。 +- 目前``Paddle-Lite``有int8 kernel来加速的op只有 ``['conv2d', 'depthwise_conv2d', 'mul']``, 其他op的int8 kernel将陆续支持。 **代码示例** diff --git a/paddleslim/quant/quanter.py b/paddleslim/quant/quanter.py index 254cf4958643e..b726ec329fa90 100755 --- a/paddleslim/quant/quanter.py +++ b/paddleslim/quant/quanter.py @@ -13,6 +13,8 @@ # limitations under the License. import copy +import logging + import paddle import paddle.fluid as fluid from paddle.fluid.framework import IrGraph @@ -24,22 +26,37 @@ from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass from paddle.fluid import core +from ..common import get_logger +_logger = get_logger(__name__, level=logging.INFO) + WEIGHT_QUANTIZATION_TYPES = [ 'abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max' ] +WEIGHT_QUANTIZATION_TYPES_TENSORRT = ['channel_wise_abs_max'] + ACTIVATION_QUANTIZATION_TYPES = [ 'abs_max', 'range_abs_max', 'moving_average_abs_max' ] + +ACTIVATION_QUANTIZATION_TYPES_TENSORRT = [ + 'range_abs_max', 'moving_average_abs_max' +] + VALID_DTYPES = ['int8'] -TRANSFORM_PASS_OP_TYPES = ['conv2d', 'depthwise_conv2d', 'mul'] -QUANT_DEQUANT_PASS_OP_TYPES = ['elementwise_add', 'pool2d'] +TRANSFORM_PASS_OP_TYPES = QuantizationTransformPass._supported_quantizable_op_type +QUANT_DEQUANT_PASS_OP_TYPES = AddQuantDequantPass._supported_quantizable_op_type + \ + AddQuantDequantPass._activation_type +TENSORRT_OP_TYPES = [ + 'mul', 'conv2d', 'pool2d', 'depthwise_conv2d', 'elementwise_add', + 'leaky_relu' +] _quant_config_default = { - # weight quantize type, default is 'abs_max' - 'weight_quantize_type': 'abs_max', - # activation quantize type, default is 'abs_max' - 'activation_quantize_type': 'abs_max', + # weight quantize type, default is 'channel_wise_abs_max' + 'weight_quantize_type': 'channel_wise_abs_max', + # activation quantize type, default is 'moving_average_abs_max' + 'activation_quantize_type': 'moving_average_abs_max', # weight quantize bit num, default is 8 'weight_bits': 8, # activation quantize bit num, default is 8 @@ -47,25 +64,25 @@ # ops of name_scope in not_quant_pattern list, will not be quantized 'not_quant_pattern': ['skip_quant'], # ops of type in quantize_op_types, will be quantized - 'quantize_op_types': - ['conv2d', 'depthwise_conv2d', 'mul', 'elementwise_add', 'pool2d'], + 'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'], # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8' 'dtype': 'int8', # window size for 'range_abs_max' quantization. defaulf is 10000 'window_size': 10000, # The decay coefficient of moving average, default is 0.9 'moving_rate': 0.9, - # if set quant_weight_only True, then only quantize parameters of layers which need to be quantized, - # and activations will not be quantized. - 'quant_weight_only': False + # if True, 'quantize_op_types' will be TENSORRT_OP_TYPES + 'for_tensorrt': False, + # if True, 'quantoze_op_types' will be TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES + 'is_full_quantize': False } def _parse_configs(user_config): """ - check user configs is valid, and set default value if user not config. + check if user's configs are valid. Args: - user_config(dict):the config of user. + user_config(dict): user's config. Return: configs(dict): final configs will be used. """ @@ -73,12 +90,26 @@ def _parse_configs(user_config): configs = copy.deepcopy(_quant_config_default) configs.update(user_config) - # check configs is valid - assert configs['weight_quantize_type'] in WEIGHT_QUANTIZATION_TYPES, \ - "Unknown weight_quantize_type: '%s'. It can only be " + " ".join(WEIGHT_QUANTIZATION_TYPES) + assert isinstance(configs['for_tensorrt'], bool) and isinstance( + configs['is_full_quantize'], + bool), "'for_tensorrt' and 'is_full_quantize' must both be bool'" + + # check if configs is valid + if configs['for_tensorrt']: + weight_types = WEIGHT_QUANTIZATION_TYPES_TENSORRT + activation_types = ACTIVATION_QUANTIZATION_TYPES_TENSORRT + platform = 'TensorRT' + else: + weight_types = WEIGHT_QUANTIZATION_TYPES + activation_types = WEIGHT_QUANTIZATION_TYPES + platform = 'PaddleLite' + assert configs['weight_quantize_type'] in weight_types, \ + "Unknown weight_quantize_type: {}. {} only supports {} ".format(configs['weight_quantize_type'], + platform, weight_types) - assert configs['activation_quantize_type'] in ACTIVATION_QUANTIZATION_TYPES, \ - "Unknown activation_quantize_type: '%s'. It can only be " + " ".join(ACTIVATION_QUANTIZATION_TYPES) + assert configs['activation_quantize_type'] in activation_types, \ + "Unknown activation_quantize_type: {}. {} only supports {}".format(configs['activation_quantize_type'], + platform, activation_types) assert isinstance(configs['weight_bits'], int), \ "weight_bits must be int value." @@ -92,17 +123,24 @@ def _parse_configs(user_config): assert (configs['activation_bits'] >= 1 and configs['activation_bits'] <= 16), \ "activation_bits should be between 1 and 16." - assert isinstance(configs['not_quant_pattern'], list), \ - "not_quant_pattern must be a list" + assert isinstance(configs['not_quant_pattern'], (list, str)), \ + "not_quant_pattern must be list or str" assert isinstance(configs['quantize_op_types'], list), \ "quantize_op_types must be a list" - for op_type in configs['quantize_op_types']: - assert (op_type in QUANT_DEQUANT_PASS_OP_TYPES) or ( - op_type in TRANSFORM_PASS_OP_TYPES), "{} is not support, \ - now support op types are {}".format( - op_type, TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES) + if configs['for_tensorrt']: + configs['quantize_op_types'] = TENSORRT_OP_TYPES + elif configs['is_full_quantize']: + configs[ + 'quantize_op_types'] = TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES + else: + for op_type in configs['quantize_op_types']: + assert (op_type in QUANT_DEQUANT_PASS_OP_TYPES) or ( + op_type in TRANSFORM_PASS_OP_TYPES), "{} is not support, \ + now support op types are {}".format( + op_type, + TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES) assert isinstance(configs['dtype'], str), \ "dtype must be a str." @@ -116,36 +154,31 @@ def _parse_configs(user_config): assert isinstance(configs['moving_rate'], float), \ "moving_rate must be float value, The decay coefficient of moving average, default is 0.9." - assert isinstance(configs['quant_weight_only'], bool), \ - "quant_weight_only must be bool value, if set quant_weight_only True, " \ - "then only quantize parameters of layers which need to be quantized, " \ - " and activations will not be quantized." - return configs -def quant_aware(program, place, config, scope=None, for_test=False): +def quant_aware(program, place, config=None, scope=None, for_test=False): """ add trainable quantization ops in program. Args: - program(fluid.Program): program - scope(fluid.Scope): the scope to store var, it's should be the value of program's scope, usually it's fluid.global_scope(). - place(fluid.CPUPlace or fluid.CUDAPlace): place - config(dict): configs for quantization, default values are in quant_config_default dict. - for_test: if program is test program, for_test should be set True, else False. + program(fluid.Program): program to quant + place(fluid.CPUPlace or fluid.CUDAPlace): CPU or CUDA device + config(dict, optional): configs for quantization. if None, will use default config. Default is None. + scope(fluid.Scope): the scope to store var, it should be program's scope. if None, will use fluid.global_scope(). + default is None. + for_test(bool): if program is test program, set True when program is for test, False when program is for train. Default is False. Return: fluid.Program: user can finetune this quantization program to enhance the accuracy. """ scope = fluid.global_scope() if not scope else scope - assert isinstance(config, dict), "config must be dict" - - assert 'weight_quantize_type' in config.keys( - ), 'weight_quantize_type must be configured' - assert 'activation_quantize_type' in config.keys( - ), 'activation_quantize_type must be configured' + if config is None: + config = _quant_config_default + else: + assert isinstance(config, dict), "config must be dict" + config = _parse_configs(config) + _logger.info("quant_aware config {}".format(config)) - config = _parse_configs(config) main_graph = IrGraph(core.Graph(program.desc), for_test=for_test) transform_pass_ops = [] @@ -197,7 +230,10 @@ def quant_post(executor, batch_nums=None, scope=None, algo='KL', - quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"]): + quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"], + is_full_quantize=False, + is_use_cache_file=False, + cache_dir="./temp_post_training"): """ The function utilizes post training quantization method to quantize the fp32 model. It uses calibrate data to calculate the scale factor of @@ -232,6 +268,11 @@ def quant_post(executor, quantizable_op_type(list[str], optional): The list of op types that will be quantized. Default is ["conv2d", "depthwise_conv2d", "mul"]. + is_full_quantize(bool): if True, apply quantization to all supported quantizable op type. + If False, only apply quantization to the input quantizable_op_type. Default is False. + is_use_cache_file(bool): If False, all temp data will be saved in memory. If True, + all temp data will be saved to disk. Defalut is False. + cache_dir(str): When 'is_use_cache_file' is True, temp data will be save in 'cache_dir'. Default is './temp_post_training'. Returns: None """ @@ -246,41 +287,64 @@ def quant_post(executor, scope=scope, algo=algo, quantizable_op_type=quantizable_op_type, - is_full_quantize=False) + is_full_quantize=is_full_quantize, + is_use_cache_file=is_use_cache_file, + cache_dir=cache_dir) post_training_quantization.quantize() post_training_quantization.save_quantized_model(quantize_model_path) -def convert(program, place, config, scope=None, save_int8=False): +def convert(program, place, config=None, scope=None, save_int8=False): """ - add quantization ops in program. the program returned is not trainable. + change quantization ops order in program. return program that can used by Paddle-Lite. Args: - program(fluid.Program): program - scope(fluid.Scope): the scope to store var, when is None will use fluid.global_scope() - place(fluid.CPUPlace or fluid.CUDAPlace): place - config(dict): configs for quantization, default values are in quant_config_default dict. - save_int8: is export int8 freezed program. + program(fluid.Program): program that returned by quant_aware + place(fluid.CPUPlace or fluid.CUDAPlace): CPU or CUDA device + scope(fluid.Scope, optional): the scope to store var, it should be program's scope. if None, will use fluid.global_scope(). + default is None. + config(dict, optional): configs for convert. if set None, will use default config. Default is None.\ + It must be same with config that used in 'quant_aware'. + save_int8: if return int8 freezed program. Int8 program can only be used to check size of model weights. \ + It cannot be used in Fluid or Paddle-Lite. Return: - fluid.Program: freezed program which can be used for inference. + freezed_program(fluid.Program): freezed program which can be used for inference. parameters is float32 type, but it's value in int8 range. - fluid.Program: freezed int8 program which can be used for inference. - if save_int8 is False, this value is None. + freezed_program_int8(fluid.Program): freezed int8 program. + when save_int8 is False, return freezed_program. + when save_int8 is True, return freezed_program and freezed_program_int8 """ scope = fluid.global_scope() if not scope else scope + + if config is None: + config = _quant_config_default + else: + assert isinstance(config, dict), "config must be dict" + config = _parse_configs(config) + _logger.info("convert config {}".format(config)) + test_graph = IrGraph(core.Graph(program.desc), for_test=True) + support_op_types = [] + for op in config['quantize_op_types']: + if op in QuantizationFreezePass._supported_quantizable_op_type: + support_op_types.append(op) # Freeze the graph after training by adjusting the quantize # operators' order for the inference. freeze_pass = QuantizationFreezePass( scope=scope, place=place, - weight_quantize_type=config['weight_quantize_type']) + weight_bits=config['weight_bits'], + activation_bits=config['activation_bits'], + weight_quantize_type=config['weight_quantize_type'], + quantizable_op_type=support_op_types) freeze_pass.apply(test_graph) freezed_program = test_graph.to_program() if save_int8: convert_int8_pass = ConvertToInt8Pass( - scope=fluid.global_scope(), place=place) + scope=fluid.global_scope(), + place=place, + quantizable_op_type=support_op_types) convert_int8_pass.apply(test_graph) freezed_program_int8 = test_graph.to_program() return freezed_program, freezed_program_int8