Skip to content

Commit

Permalink
fix skip_tensor_list name
Browse files Browse the repository at this point in the history
  • Loading branch information
yghstill committed May 3, 2022
1 parent 6f0e1a0 commit 242fe2f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __init__(self,
onnx_format=False,
optimize_model=False,
is_use_cache_file=False,
skip_op_list=None,
skip_tensor_list=None,
cache_dir=None):
'''
Constructor.
Expand Down Expand Up @@ -199,7 +199,7 @@ def __init__(self,
the model accuracy is usually higher when it is 'channel_wise_abs_max'.
onnx_format(bool): Whether to export the quantized model with format of ONNX.
Default is False.
skip_op_list(list): List of skip quant tensor name.
skip_tensor_list(list): List of tensor names to skip quantization.
optimize_model(bool, optional): If set optimize_model as True, it applies
some passes to the model before quantization, and it supports
`conv2d/depthwise_conv2d + bn` pass so far. Some targets require the
Expand Down Expand Up @@ -303,7 +303,7 @@ def __init__(self,
self._activation_quantize_type = activation_quantize_type
self._weight_quantize_type = weight_quantize_type
self._onnx_format = onnx_format
self._skip_op_list = skip_op_list
self._skip_tensor_list = skip_tensor_list
self._is_full_quantize = is_full_quantize
if is_full_quantize:
self._quantizable_op_type = self._support_quantize_op_type
Expand Down Expand Up @@ -550,10 +550,10 @@ def collect_var_name(var_name_list, persistable_var_names, op_type):
persistable_var_names = _all_persistable_var_names(self._program)
for block_id in range(len(self._program.blocks)):
for op in self._program.blocks[block_id].ops:
# skip quant form self._skip_op_list
if self._skip_op_list is not None:
# skip quant from self._skip_tensor_list
if self._skip_tensor_list is not None:
for inp_name in utils._get_op_input_var_names(op):
if inp_name in self._skip_op_list:
if inp_name in self._skip_tensor_list:
op._set_attr("op_namescope", "skip_quant")

op_type = op.type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def generate_quantized_model(self,
batch_size=10,
batch_nums=10,
onnx_format=False,
skip_op_list=None):
skip_tensor_list=None):

place = fluid.CPUPlace()
exe = fluid.Executor(place)
Expand All @@ -137,7 +137,7 @@ def generate_quantized_model(self,
is_full_quantize=is_full_quantize,
optimize_model=is_optimize_model,
onnx_format=onnx_format,
skip_op_list=skip_op_list,
skip_tensor_list=skip_tensor_list,
is_use_cache_file=is_use_cache_file)
ptq.quantize()
ptq.save_quantized_model(self.int8_model_path)
Expand All @@ -157,7 +157,7 @@ def run_test(self,
infer_iterations=10,
quant_iterations=5,
onnx_format=False,
skip_op_list=None):
skip_tensor_list=None):

origin_model_path = self.download_model(data_url, data_md5, model_name)
origin_model_path = os.path.join(origin_model_path, model_name)
Expand All @@ -172,7 +172,7 @@ def run_test(self,
self.generate_quantized_model(
origin_model_path, algo, round_type, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model, batch_size,
quant_iterations, onnx_format, skip_op_list)
quant_iterations, onnx_format, skip_tensor_list)

print("Start INT8 inference for {0} on {1} images ...".format(
model_name, infer_iterations * batch_size))
Expand Down Expand Up @@ -423,7 +423,7 @@ def test_post_training_avg_skip_op(self):
batch_size = 10
infer_iterations = 50
quant_iterations = 5
skip_op_list = ["fc_0.w_0"]
skip_tensor_list = ["fc_0.w_0"]
self.run_test(
model_name,
data_url,
Expand All @@ -438,7 +438,7 @@ def test_post_training_avg_skip_op(self):
batch_size,
infer_iterations,
quant_iterations,
skip_op_list=skip_op_list)
skip_tensor_list=skip_tensor_list)


if __name__ == '__main__':
Expand Down

1 comment on commit 242fe2f

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulations! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.