Support quantization of condition block (#37498)
* Support post-training quantization of sub graphs
yghstill committed Dec 10, 2021
1 parent 76c7322 commit 89069af
Showing 3 changed files with 387 additions and 41 deletions.
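
In effect, this change lets PostTrainingQuantization process inference programs whose graph contains control-flow sub blocks (e.g. a while op), where it previously logged an error for any program with more than one block. As a rough orientation, a calibration-and-quantize run might look like the sketch below; the model path, input shape, and random calibration reader are illustrative assumptions, not part of this commit.

import numpy as np
import paddle
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())

def sample_generator():
    # Yield calibration samples shaped like the model input (assumed shape).
    for _ in range(10):
        yield [np.random.random((1, 1, 28, 28)).astype("float32")]

ptq = PostTrainingQuantization(
    executor=exe,
    model_dir="./model_with_while",  # assumed: a saved inference model containing a while op
    sample_generator=sample_generator,
    batch_size=1,
    batch_nums=10,
    algo="KL")
ptq.quantize()
ptq.save_quantized_model("./model_with_while_quant")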
python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py

@@ -410,6 +410,23 @@ def quantize(self):
                for op_type in self._dynamic_quantize_op_type):
             self._collect_dynamic_quantize_op_threshold(
                 self._dynamic_quantize_op_type)
+
+        # Move the persistable vars of sub blocks to the global block
+        global_block = self._program.global_block()
+        for _op in global_block.ops:
+            if _op.type == "while":
+                _block_id = _op.attr("sub_block").id
+                _block = self._program.block(_block_id)
+                persistables = []
+                for _name, _var in _block.vars.items():
+                    if _var.persistable:
+                        global_block._clone_variable(_var)
+                        persistables.append(_name)
+                for _name in persistables:
+                    _block._remove_var(_name)
+                persistables.extend(_op.input('X'))
+                _op.desc.set_input("X", persistables)
+
         return self._program
 
     def save_quantized_model(self,
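
The hunk above runs at the end of quantize(): it hoists every persistable var out of each while op's sub block into the global block, deletes the originals, and appends the hoisted names to the op's X input, so the executor still feeds those vars into the sub block's execution scope after the move. A small inspection sketch using only the Program APIs visible in the hunk (prog is an assumed handle to a program before quantize() has run):

for op in prog.global_block().ops:
    if op.type == "while":
        sub_block = prog.block(op.attr("sub_block").id)
        # Persistable vars still owned by the sub block; after quantize()
        # these are cloned into the global block and removed here.
        hoisted = [name for name, var in sub_block.vars.items()
                   if var.persistable]
        print("while op inputs X:", op.input("X"))
        print("sub-block persistables:", hoisted)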
@@ -451,10 +468,6 @@ def _load_model_data(self):
             model_filename=self._model_filename,
             params_filename=self._params_filename)
 
-        if self._program.num_blocks > 1:
-            _logger.error("The post training quantization requires that the "
-                          "program only has one block.")
-
         if self._optimize_model:
             self._optimize_fp32_model()
 
@@ -505,23 +518,26 @@ def collect_var_name(var_name_list, persistable_var_names, op_type):
                     self._quantized_act_var_name.add(var_name)
 
         persistable_var_names = _all_persistable_var_names(self._program)
-        for op in self._program.global_block().ops:
-            op_type = op.type
-            if self._is_full_quantize and \
-                op_type not in self._quantizable_op_type:
-                _logger.warning(op_type + " is not supported for quantization.")
-            # For quantized ops, sample inputs and outputs
-            if op_type in self._quantizable_op_type:
-                collect_var_name(
-                    _get_op_input_var_names(op), persistable_var_names, op_type)
-                collect_var_name(
-                    _get_op_output_var_names(op), persistable_var_names,
-                    op_type)
-            # For other op, only sample output scale
-            elif op_type in self._out_scale_op_list:
-                collect_var_name(
-                    _get_op_output_var_names(op), persistable_var_names,
-                    op_type)
+        for block_id in range(len(self._program.blocks)):
+            for op in self._program.blocks[block_id].ops:
+                op_type = op.type
+                if self._is_full_quantize and \
+                    op_type not in self._quantizable_op_type:
+                    _logger.warning(op_type +
+                                    " is not supported for quantization.")
+                # For quantized ops, sample inputs and outputs
+                if op_type in self._quantizable_op_type:
+                    collect_var_name(
+                        _get_op_input_var_names(op), persistable_var_names,
+                        op_type)
+                    collect_var_name(
+                        _get_op_output_var_names(op), persistable_var_names,
+                        op_type)
+                # For other op, only sample output scale
+                elif op_type in self._out_scale_op_list:
+                    collect_var_name(
+                        _get_op_output_var_names(op), persistable_var_names,
+                        op_type)
 
     def _set_activation_persistable(self):
         '''
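
The replacement loop above (and the similar ones later in this diff) all follow one pattern: iterate every op in every block of the program, not just the global block, so ops inside while/condition bodies are also sampled. A tiny helper expressing that pattern, offered only as a sketch (it is not part of the commit):

def iter_all_ops(program):
    # Visit ops in the global block and in every sub block (e.g. while bodies).
    for block in program.blocks:
        for op in block.ops:
            yield op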
@@ -696,16 +712,17 @@ def _save_input_threhold(self):
         '''
         assert self._algo == "min_max", \
             "The algo should be min_max to save input threshold."
-        for op in self._program.global_block().ops:
-            if op.type in self._quantizable_op_type:
-                for var_name in _get_op_input_var_names(op):
-                    assert var_name in self._quantized_var_min
-                    assert var_name in self._quantized_var_max
-                    op._set_attr(var_name + ".min",
-                                 self._quantized_var_min[var_name])
-                    op._set_attr(var_name + ".max",
-                                 self._quantized_var_max[var_name])
-                    op._set_attr("with_quant_attr", True)
+        for block_id in range(len(self._program.blocks)):
+            for op in self._program.blocks[block_id].ops:
+                if op.type in self._quantizable_op_type:
+                    for var_name in _get_op_input_var_names(op):
+                        assert var_name in self._quantized_var_min
+                        assert var_name in self._quantized_var_max
+                        op._set_attr(var_name + ".min",
+                                     self._quantized_var_min[var_name])
+                        op._set_attr(var_name + ".max",
+                                     self._quantized_var_max[var_name])
+                        op._set_attr("with_quant_attr", True)
 
     def _collect_activation_abs_min_max(self):
         '''
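
For the min_max algorithm, the thresholds are stored directly on each quantizable op as attributes named <var_name>.min and <var_name>.max, plus a with_quant_attr flag. A sketch that reads them back across all blocks, with the attribute names taken from the hunk above (prog is an assumed handle to the processed program):

for block in prog.blocks:
    for op in block.ops:
        if op.has_attr("with_quant_attr"):
            for name in op.input_arg_names:
                if op.has_attr(name + ".min"):
                    print(op.type, name,
                          op.attr(name + ".min"), op.attr(name + ".max"))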
@@ -795,7 +812,12 @@ def _update_program(self):
             activation_quantize_type=self._activation_quantize_type,
             weight_quantize_type=self._weight_quantize_type,
             quantizable_op_type=major_quantizable_op_types)
-        transform_pass.apply(graph)
+
+        for sub_graph in graph.all_sub_graphs():
+            # The fake_quant/fake_dequantize ops must be inserted into a test
+            # graph, so set each sub graph's _for_test to True.
+            sub_graph._for_test = True
+            transform_pass.apply(sub_graph)
 
         # use AddQuantDequantPass to insert fake_quant_dequant op
         minor_quantizable_op_types = []
@@ -806,7 +828,10 @@ def _update_program(self):
             scope=self._scope,
             place=self._place,
             quantizable_op_type=minor_quantizable_op_types)
-        add_quant_dequant_pass.apply(graph)
+
+        for sub_graph in graph.all_sub_graphs():
+            sub_graph._for_test = True
+            add_quant_dequant_pass.apply(sub_graph)
 
         # save threshold to scale var node
         if self._algo in ["KL", "hist"]:
@@ -836,7 +861,11 @@ def _update_program(self):
             activation_bits=self._activation_bits,
             weight_quantize_type=self._weight_quantize_type,
             quantizable_op_type=major_quantizable_op_types)
-        freeze_pass.apply(graph)
+
+        for sub_graph in graph.all_sub_graphs():
+            sub_graph._for_test = True
+            freeze_pass.apply(sub_graph)
+
         self._program = graph.to_program()
 
     def _save_output_threshold(self):
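
The three pass applications above (transform_pass, add_quant_dequant_pass, freeze_pass) are all rewritten in the same way: instead of applying the pass once to the whole IrGraph, iterate graph.all_sub_graphs(), force each sub graph into test mode (the fake quant/dequant insertion only operates on test graphs), and apply the pass per sub graph. A hedged helper capturing that shape, assuming only the IrGraph API already used in the hunks:

def apply_pass_to_sub_graphs(graph, ir_pass):
    for sub_graph in graph.all_sub_graphs():
        sub_graph._for_test = True  # quant passes expect test-mode graphs
        ir_pass.apply(sub_graph)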
@@ -888,13 +917,15 @@ def analysis_and_save_info(op_node, out_var_name):
                 save_info(op_node, out_var_name, self._quantized_var_max,
                           "out_max", "post_min_max")
 
-        for op in self._program.global_block().ops:
-            if op.type in (self._quantizable_op_type + self._out_scale_op_list):
-                out_var_names = _get_op_output_var_names(op)
-                assert len(out_var_names) == 1, "Post training " + \
-                    "quantization only support one output for " + op.type
-                for var_name in out_var_names:
-                    analysis_and_save_info(op, var_name)
+        for block_id in range(len(self._program.blocks)):
+            for op in self._program.blocks[block_id].ops:
+                if op.type in (
+                        self._quantizable_op_type + self._out_scale_op_list):
+                    out_var_names = _get_op_output_var_names(op)
+                    assert len(out_var_names) == 1, "Post training " + \
+                        "quantization only supports one output for " + op.type
+                    for var_name in out_var_names:
+                        analysis_and_save_info(op, var_name)
 
     def _collect_dynamic_quantize_op_threshold(self, target_ops_type):
         """
2 changes: 2 additions & 0 deletions python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
@@ -139,6 +139,7 @@ endfunction()
 if(WIN32)
     list(REMOVE_ITEM TEST_OPS test_light_nas)
     list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mnist)
+    list(REMOVE_ITEM TEST_OPS test_post_training_quantization_while)
     list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1)
     list(REMOVE_ITEM TEST_OPS test_post_training_quantization_resnet50)
     list(REMOVE_ITEM TEST_OPS test_post_training_quantization_lstm_model)
@@ -336,6 +337,7 @@ if(NOT WIN32)
     set_tests_properties(test_post_training_quantization_mobilenetv1 PROPERTIES TIMEOUT 600 LABELS "RUN_TYPE=NIGHTLY")
     set_tests_properties(test_post_training_quantization_resnet50 PROPERTIES TIMEOUT 600 LABELS "RUN_TYPE=NIGHTLY")
     set_tests_properties(test_post_training_quantization_mnist PROPERTIES TIMEOUT 120)
+    set_tests_properties(test_post_training_quantization_while PROPERTIES TIMEOUT 120)
     set_tests_properties(test_imperative_ptq PROPERTIES TIMEOUT 120)
     set_tests_properties(test_weight_quantization_mobilenetv1 PROPERTIES TIMEOUT 120)
 endif()
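
The CMake changes register the new test_post_training_quantization_while test with a 120-second timeout on non-Windows builds and exclude it on Windows, mirroring the other post-training quantization tests. Assuming a configured Paddle build tree, the test could be run on its own with:

ctest -R test_post_training_quantization_while --output-on-failure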
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_while.py (new test file; its diff did not load in this view. Path inferred from the CMakeLists.txt entry above.)
