diff --git a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc old mode 100755 new mode 100644 index 95a2fdee31436..7f1bc37183ec8 --- a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc @@ -336,27 +336,46 @@ void ComputePropagateScalesMkldnnPass::ComputeWeightScales( ComputeLstmWeightScales(graph, scope, "WeightX", "WeightH", var_quant_scales); } -void ComputePropagateScalesMkldnnPass::UpdateScaleOpInScale( +void ComputePropagateScalesMkldnnPass::UpdateScaleOpInOutScales( Node* op_node, const std::string& input_name, const std::string& output_name, StringPairMap* var_quant_scales) const { - auto iter = var_quant_scales->find(output_name); - if (iter != var_quant_scales->end()) { - auto pair = iter->second; - const auto tensor = pair.second; - - const auto scale = PADDLE_GET_CONST(float, op_node->Op()->GetAttr("scale")); - phi::DenseTensor tmp_tensor; - tmp_tensor.Resize(tensor.dims()); - auto* data = tmp_tensor.mutable_data<float>(platform::CPUPlace()); - for (int i = 0; i < tensor.numel(); i++) { - data[i] = data[i] * scale; - } + auto out_iter = var_quant_scales->find(output_name); + auto input_iter = var_quant_scales->find(input_name); + // All the input and output have scales + if (out_iter != var_quant_scales->end() && + input_iter != var_quant_scales->end()) { + return; + } - auto new_pair = std::make_pair(pair.first, tmp_tensor); - var_quant_scales->insert(std::make_pair(input_name, new_pair)); + const auto scale = PADDLE_GET_CONST(float, op_node->Op()->GetAttr("scale")); + if (std::abs(scale) < 1e-6 && out_iter != var_quant_scales->end()) { + return; } + + std::string name = input_name; + auto iter = out_iter; + if (input_iter != var_quant_scales->end()) { + iter = input_iter; + name = output_name; + } + + phi::DenseTensor tmp_tensor; + auto pair = iter->second; + 
const auto tensor = pair.second; + tmp_tensor.Resize(tensor.dims()); + auto* data = tmp_tensor.mutable_data<float>(platform::CPUPlace()); + auto* src_data = tensor.data<float>(); + for (int i = 0; i < tensor.numel(); i++) { + if (out_iter != var_quant_scales->end()) { + data[i] = src_data[i] / scale; + } else { + data[i] = src_data[i] * scale; + } + } + auto new_pair = std::make_pair(pair.first, tmp_tensor); + var_quant_scales->insert(std::make_pair(name, new_pair)); } std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales( @@ -403,10 +422,12 @@ std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales( } } else if (op_name == "scale") { const std::string output_name = op_node->Op()->Output("Out")[0]; + const std::string input_name = op_node->Op()->Input("X")[0]; auto out_iter = var_quant_scales->find(output_name); - if (out_iter != var_quant_scales->end()) { - const std::string input_name = op_node->Op()->Input("X")[0]; - UpdateScaleOpInScale( + auto input_iter = var_quant_scales->find(input_name); + if (out_iter != var_quant_scales->end() || + input_iter != var_quant_scales->end()) { + UpdateScaleOpInOutScales( op_node, input_name, output_name, var_quant_scales); } } diff --git a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.h b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.h index bae810746ae2d..2c2474438bedf 100644 --- a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.h @@ -79,10 +79,10 @@ class ComputePropagateScalesMkldnnPass : public FusePassBase { void UpdateReluOutputScales(ir::Graph* graph, StringPairMap* var_quant_scales) const; - void UpdateScaleOpInScale(Node* op_node, - const std::string& input_name, - const std::string& output_name, - StringPairMap* var_quant_scales) const; + void UpdateScaleOpInOutScales(Node* op_node, + const std::string& input_name, + const std::string& output_name, + 
StringPairMap* var_quant_scales) const; std::unordered_set<std::string> UpdateScales( ir::Graph* graph, diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc old mode 100644 new mode 100755 index 59ebbb5764a56..7286b603c6b0b --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -384,6 +384,7 @@ void CpuPassStrategy::EnableMkldnnInt8() { passes_.push_back("quant_dequant_mkldnn_pass"); passes_.push_back("mkldnn_placement_pass"); passes_.push_back("simplify_with_basic_ops_pass"); + passes_.push_back("constant_folding_pass"); passes_.push_back("layer_norm_fuse_pass"); passes_.push_back("attention_lstm_fuse_pass"); passes_.push_back("seqconv_eltadd_relu_fuse_pass");