
Commit

update code
yeliang2258 committed Sep 1, 2022
1 parent d84f969 commit 6d85e3b
Showing 2 changed files with 94 additions and 188 deletions.
263 changes: 88 additions & 175 deletions paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc
@@ -99,7 +99,6 @@ void QuantDequantMkldnnPass::CollectInfoFromFake(
void QuantDequantMkldnnPass::CollectWeightScalesInfoFromONNXFormatDequantize(
    ir::Graph* graph,
    Scope* scope,
-    const std::unordered_set<std::string>& onnx_format_dequantize_types,
    std::unordered_map<std::string, std::vector<float>>* weight_thresholds,
    std::unordered_map<std::string, std::vector<float>>* var_quant_scales,
    bool* onnx_format_quantize_model) const {
@@ -108,7 +107,7 @@ void QuantDequantMkldnnPass::CollectWeightScalesInfoFromONNXFormatDequantize(
       ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
    if (!op_node->IsOp()) continue;

-    if (onnx_format_dequantize_types.count(op_node->Name())) {
+    if (op_node->Name() == "dequantize_linear") {
      auto* op_desc = op_node->Op();
      auto x_var_name = op_desc->Input("X")[0];
      auto* weight_var = scope->FindVar(x_var_name);
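Note on the hunk above: the helper set that this comparison replaces only ever held "dequantize_linear", so a direct string test is behavior-preserving. A minimal standalone sketch of the equivalence (the predicate names are illustrative, not part of the pass):

    #include <string>
    #include <unordered_set>

    // Old form: membership test against a one-element set.
    bool IsOnnxFormatDequantOld(const std::string& op_type) {
      static const std::unordered_set<std::string> kTypes = {"dequantize_linear"};
      return kTypes.count(op_type) > 0;
    }

    // New form: a plain comparison that accepts exactly the same op type.
    bool IsOnnxFormatDequantNew(const std::string& op_type) {
      return op_type == "dequantize_linear";
    }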
@@ -140,58 +139,7 @@ void QuantDequantMkldnnPass::CollectWeightScalesInfoFromONNXFormatDequantize(
  }
}

-void QuantDequantMkldnnPass::CollectInputScalesFromONNXFormatQuantize(
-    ir::Graph* graph,
-    Scope* scope,
-    const std::unordered_set<std::string>& onnx_format_quantize_types,
-    std::unordered_map<std::string, std::vector<float>>* var_quant_scales)
-    const {
-  VLOG(3) << "gather input scales from quantize_linear op";
-  for (auto* op_node :
-       ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
-    if (!op_node->IsOp()) continue;
-
-    if (onnx_format_quantize_types.count(op_node->Name())) {
-      auto* op_desc = op_node->Op();
-      const int bit_length =
-          PADDLE_GET_CONST(int, op_desc->GetAttr("bit_length"));
-      PADDLE_ENFORCE_EQ(bit_length,
-                        8,
-                        platform::errors::InvalidArgument(
-                            "Unsupported number quantization "
-                            "bits: %d, only 8 is supported now.",
-                            bit_length));
-
-      auto x_var_name = op_desc->Input("X")[0];
-      auto scale_name = op_desc->Input("Scale")[0];
-      auto out_var_name = op_desc->Output("Y")[0];
-      auto* var = scope->FindVar(scale_name);
-      PADDLE_ENFORCE_NOT_NULL(
-          var,
-          platform::errors::NotFound(
-              "The InScale variable [%s] of quantize op is not found.", var));
-
-      auto* scale_tensor = var->GetMutable<LoDTensor>();
-      auto* scale_data = scale_tensor->data<float>();
-      float scale = 1.0 / scale_data[0];
-      if (std::isinf(scale) || std::isnan(scale)) {
-        scale = 0.0;
-      }
-
-      if (!var_quant_scales->count(x_var_name)) {
-        std::vector<float> scale_v = {scale};
-        var_quant_scales->insert(std::make_pair(x_var_name, scale_v));
-      }
-
-      if (!var_quant_scales->count(out_var_name)) {
-        std::vector<float> scale_v = {scale};
-        var_quant_scales->insert(std::make_pair(out_var_name, scale_v));
-      }
-    }
-  }
-}
-
-void QuantDequantMkldnnPass::CollectInputScalesFromFake(
+void QuantDequantMkldnnPass::CollectInputScalesFromQuantize(
    ir::Graph* graph,
    Scope* scope,
    const std::unordered_set<std::string>& fake_quantize_types,
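The function deleted above is not lost: its behavior moves into the renamed CollectInputScalesFromQuantize, which the next two hunks teach to match quantize_linear as well. The rule worth remembering is the scale arithmetic: the pass stores the inverse of the op's scale tensor, degraded to 0.0 when the inverse is inf/nan (i.e. when the raw scale is 0). A standalone sketch of that rule (RecordInputScale is a hypothetical helper, not the pass's API):

    #include <cmath>
    #include <string>
    #include <unordered_map>
    #include <vector>

    // Record 1/raw_scale for a variable, first writer wins, mirroring the
    // inf/nan guard in the collector above.
    void RecordInputScale(
        const std::string& var_name,
        float raw_scale,
        std::unordered_map<std::string, std::vector<float>>* var_quant_scales) {
      float scale = 1.0f / raw_scale;
      if (std::isinf(scale) || std::isnan(scale)) {
        scale = 0.0f;  // raw_scale == 0 would otherwise poison later passes
      }
      if (!var_quant_scales->count(var_name)) {
        var_quant_scales->insert({var_name, std::vector<float>{scale}});
      }
    }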
@@ -203,6 +151,7 @@ void QuantDequantMkldnnPass::CollectInputScalesFromFake(
    if (!op_node->IsOp()) continue;

    if (op_node->Name() == "fake_quantize_dequantize_moving_average_abs_max" ||
+        op_node->Name() == "quantize_linear" ||
        fake_quantize_types.count(op_node->Name())) {
      auto* op_desc = op_node->Op();
      const int bit_length =
@@ -214,10 +163,17 @@
                            "bits: %d, only 8 is supported now.",
                            bit_length));

+      std::string scale_name = "InScale";
+      std::string out_name = "Out";
+      if (op_node->Name() == "quantize_linear") {
+        scale_name = "Scale";
+        out_name = "Y";
+      }
      auto x_var_name = op_desc->Input("X")[0];
-      auto scale_name = op_desc->Input("InScale")[0];
-      auto out_var_name = op_desc->Output("Out")[0];
-      auto* var = scope->FindVar(scale_name);
+      auto scale_var_name = op_desc->Input(scale_name)[0];
+      auto out_var_name = op_desc->Output(out_name)[0];
+
+      auto* var = scope->FindVar(scale_var_name);
      PADDLE_ENFORCE_NOT_NULL(
          var,
          platform::errors::NotFound(
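This hunk carries the core of the merge: quantize_linear exposes its scale through the "Scale" input and its result through the "Y" output, whereas the fake quantize ops use "InScale" and "Out"; once the slot names are chosen, the lookup and scale math are shared. A compact sketch of that dispatch (the struct and function are illustrative, not the pass's API):

    #include <string>

    // Slot names differ between the ONNX-format quantize op and Paddle's
    // fake quantize ops; everything downstream of this lookup is common code.
    struct QuantSlots {
      const char* scale;  // name of the scale input slot
      const char* out;    // name of the quantized output slot
    };

    QuantSlots SlotsFor(const std::string& op_type) {
      if (op_type == "quantize_linear") {
        return {"Scale", "Y"};  // ONNX-format naming
      }
      return {"InScale", "Out"};  // fake_quantize_* naming
    }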
@@ -505,64 +461,54 @@ bool QuantDequantMkldnnPass::IsInt8Weight(
  return is_int8;
}

-void QuantDequantMkldnnPass::DequantizeOpWeights(
-    Node* op_node,
-    Scope* scope,
-    const std::string& weight_name,
-    const std::string& output_name,
-    const std::unordered_map<std::string, std::vector<float>>&
-        weight_thresholds) const {
-  auto* op_desc = op_node->Op();
-  std::string weight_var_name = op_desc->Input(weight_name)[0];
-  std::string output_var_name = op_desc->Output(output_name)[0];
-
-  std::vector<float> scales;
-  auto iter = weight_thresholds.find(output_var_name);
-  if (iter != weight_thresholds.end()) {
-    scales = iter->second;
-  } else {
-    PADDLE_THROW(paddle::platform::errors::Fatal(
-        "Could not find threshold information for [%s] var, please check if "
-        "the model is correct.",
-        output_var_name));
-  }
-
-  auto* var = scope->FindVar(weight_var_name);
-  PADDLE_ENFORCE_NOT_NULL(
-      var,
-      platform::errors::NotFound(
-          "The input persistable [%s] var of [%s] op is not found.",
-          weight_var_name,
-          op_desc->Type()));
-  auto* weight_tensor = var->GetMutable<LoDTensor>();
+void QuantDequantMkldnnPass::ConvertFromINT8ToFP32(
+    const std::vector<float>& scales,
+    Tensor* weight_tensor,
+    int8_t* int8_weight_data,
+    float* fp32_weight_data,
+    const std::string& weight_var_name) const {
  const auto weight_dims = weight_tensor->dims();

+  std::vector<float> weight_data;
+  weight_data.resize(weight_tensor->numel());
  const int size = scales.size();

  if (size == 1 || size == weight_dims[0]) {
-    auto* weight_data =
-        weight_tensor->mutable_data<float>(platform::CPUPlace());
    for (int i = 0; i < weight_tensor->numel(); i++) {
-      weight_data[i] /= 127;
+      if (int8_weight_data) {
+        weight_data[i] = static_cast<float>(int8_weight_data[i]) / 127.0;
+      } else {
+        weight_data[i] = fp32_weight_data[i] / 127.0;
+      }
    }

+    weight_tensor->clear();  // clear int weight
+    weight_tensor->Resize(phi::make_ddim(phi::vectorize(weight_dims)));
+    auto* new_weight_data =
+        weight_tensor->mutable_data<float>(platform::CPUPlace());
+    memcpy(new_weight_data,
+           weight_data.data(),
+           weight_tensor->numel() * sizeof(float));
+
    TransposeWeight(weight_tensor);

    if (size == 1) {
      for (int i = 0; i < weight_tensor->numel(); i++) {
-        weight_data[i] *= scales[0];
+        new_weight_data[i] *= scales[0];
      }
    } else {
      for (int i = 0; i < weight_tensor->numel(); i++) {
-        weight_data[i] *= scales[i % size];
+        new_weight_data[i] *= scales[i % size];
      }
    }

    TransposeWeight(weight_tensor);
  } else if (weight_dims.size() > 1 && size == weight_dims[1]) {
-    auto* weight_data =
-        weight_tensor->mutable_data<float>(platform::CPUPlace());
    for (int i = 0; i < weight_tensor->numel(); i++) {
-      weight_data[i] /= 127;
+      if (int8_weight_data) {
+        weight_data[i] = static_cast<float>(int8_weight_data[i]) / 127.0;
+      } else {
+        weight_data[i] = fp32_weight_data[i] / 127.0;
+      }
    }

    int step_n = 1;
@@ -581,6 +527,13 @@ void QuantDequantMkldnnPass::DequantizeOpWeights(
        }
      }
    }
+    weight_tensor->clear();  // clear int weight
+    weight_tensor->Resize(phi::make_ddim(phi::vectorize(weight_dims)));
+    auto* new_weight_data =
+        weight_tensor->mutable_data<float>(platform::CPUPlace());
+    memcpy(new_weight_data,
+           weight_data.data(),
+           weight_tensor->numel() * sizeof(float));
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "The size of weight scales vector (%d) does not "
@@ -589,10 +542,45 @@ void QuantDequantMkldnnPass::DequantizeOpWeights(
        weight_tensor->dims().size(),
        weight_var_name));
  }

+  weight_tensor->Resize(weight_dims);
}

+void QuantDequantMkldnnPass::DequantizeOpWeights(
+    Node* op_node,
+    Scope* scope,
+    const std::string& weight_name,
+    const std::string& output_name,
+    const std::unordered_map<std::string, std::vector<float>>&
+        weight_thresholds) const {
+  auto* op_desc = op_node->Op();
+  std::string weight_var_name = op_desc->Input(weight_name)[0];
+  std::string output_var_name = op_desc->Output(output_name)[0];
+
+  std::vector<float> scales;
+  auto iter = weight_thresholds.find(output_var_name);
+  if (iter != weight_thresholds.end()) {
+    scales = iter->second;
+  } else {
+    PADDLE_THROW(paddle::platform::errors::Fatal(
+        "Could not find threshold information for [%s] var, please check if "
+        "the model is correct.",
+        output_var_name));
+  }
+
+  auto* var = scope->FindVar(weight_var_name);
+  PADDLE_ENFORCE_NOT_NULL(
+      var,
+      platform::errors::NotFound(
+          "The input persistable [%s] var of [%s] op is not found.",
+          weight_var_name,
+          op_desc->Type()));
+  auto* weight_tensor = var->GetMutable<LoDTensor>();
+  float* fp32_weight_data =
+      weight_tensor->mutable_data<float>(platform::CPUPlace());
+  ConvertFromINT8ToFP32(
+      scales, weight_tensor, nullptr, fp32_weight_data, weight_var_name);
+}

void QuantDequantMkldnnPass::DequantizeOpWeightsFromONNXFormat(
    Node* op_node,
    Scope* scope,
@@ -624,77 +612,11 @@ void QuantDequantMkldnnPass::DequantizeOpWeightsFromONNXFormat(
          weight_var_name,
          op_desc->Type()));
  auto* weight_tensor = var->GetMutable<LoDTensor>();
-  const auto weight_dims = weight_tensor->dims();
  int8_t* int8_weight_data =
      weight_tensor->mutable_data<int8_t>(platform::CPUPlace());

-  std::vector<float> weight_data;
-  weight_data.resize(weight_tensor->numel());
-
-  const int size = scales.size();
-  if (size == 1 || size == weight_dims[0]) {
-    for (int i = 0; i < weight_tensor->numel(); i++) {
-      weight_data[i] = static_cast<float>(int8_weight_data[i]) / 127.0;
-    }
-
-    weight_tensor->clear();  // clear int weight
-    weight_tensor->Resize(phi::make_ddim(phi::vectorize(weight_dims)));
-    auto* new_weight_data =
-        weight_tensor->mutable_data<float>(platform::CPUPlace());
-    memcpy(new_weight_data,
-           weight_data.data(),
-           weight_tensor->numel() * sizeof(float));
-
-    TransposeWeight(weight_tensor);
-
-    if (size == 1) {
-      for (int i = 0; i < weight_tensor->numel(); i++) {
-        new_weight_data[i] *= scales[0];
-      }
-    } else {
-      for (int i = 0; i < weight_tensor->numel(); i++) {
-        new_weight_data[i] *= scales[i % size];
-      }
-    }
-    TransposeWeight(weight_tensor);
-  } else if (weight_dims.size() > 1 && size == weight_dims[1]) {
-    for (int i = 0; i < weight_tensor->numel(); i++) {
-      weight_data[i] = static_cast<float>(int8_weight_data[i]) / 127.0;
-    }
-
-    int step_n = 1;
-    for (int i = 1; i < weight_dims.size(); i++) {
-      step_n *= weight_dims[i];
-    }
-    int step_c = step_n / size;
-    for (int i = 0; i < weight_dims[0]; i++) {
-      int begin_n = i * step_n;
-      for (int j = begin_n; j < begin_n + step_n; j++) {
-        for (int k = 0; k < size; k++) {
-          int begin_c = k * step_c;
-          for (int m = begin_c; m < begin_c + step_c; m++) {
-            weight_data[m] *= scales[k];
-          }
-        }
-      }
-    }
-    weight_tensor->clear();  // clear int weight
-    weight_tensor->Resize(phi::make_ddim(phi::vectorize(weight_dims)));
-    auto* new_weight_data =
-        weight_tensor->mutable_data<float>(platform::CPUPlace());
-    memcpy(new_weight_data,
-           weight_data.data(),
-           weight_tensor->numel() * sizeof(float));
-  } else {
-    PADDLE_THROW(platform::errors::InvalidArgument(
-        "The size of weight scales vector (%d) does not "
-        "match the dimensions (%d) of the weights tensor %s.",
-        size,
-        weight_tensor->dims().size(),
-        weight_var_name));
-  }
-
-  weight_tensor->Resize(weight_dims);
+  ConvertFromINT8ToFP32(
+      scales, weight_tensor, int8_weight_data, nullptr, weight_var_name);
}

void QuantDequantMkldnnPass::DequantizeWeights(
@@ -788,12 +710,6 @@ void QuantDequantMkldnnPass::ApplyImpl(ir::Graph* graph) const {
  const std::unordered_set<std::string> onnx_format_quantize_dequantize_types =
      {"quantize_linear", "dequantize_linear"};

-  const std::unordered_set<std::string> onnx_format_dequantize_types = {
-      "dequantize_linear"};
-
-  const std::unordered_set<std::string> onnx_format_quantize_types = {
-      "quantize_linear"};
-
  std::unordered_map<std::string, std::vector<float>> weight_thresholds{};
  std::unordered_map<std::string, std::vector<float>> var_quant_scales{};
  bool onnx_format_quantize_model = false;
@@ -806,14 +722,11 @@ void QuantDequantMkldnnPass::ApplyImpl(ir::Graph* graph) const {
  CollectInfoFromFake(graph, scope, fake_dequantize_types, &weight_thresholds);
  CollectWeightScalesInfoFromONNXFormatDequantize(graph,
                                                  scope,
-                                                  onnx_format_dequantize_types,
                                                  &weight_thresholds,
                                                  &var_quant_scales,
                                                  &onnx_format_quantize_model);
-  CollectInputScalesFromFake(
+  CollectInputScalesFromQuantize(
      graph, scope, fake_quantize_types, &var_quant_scales);
-  CollectInputScalesFromONNXFormatQuantize(
-      graph, scope, onnx_format_quantize_types, &var_quant_scales);
  CollectOutputScalesFromAttr(graph, &var_quant_scales);
  RemoveFakeOps(graph,
                fake_quantize_types,
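Taken together, the weight hunks fold two near-duplicate code paths into the shared ConvertFromINT8ToFP32: DequantizeOpWeights hands it fp32-stored weights (int8_weight_data == nullptr), DequantizeOpWeightsFromONNXFormat hands it raw int8 data, and both divide by 127 and then apply scales per tensor (scales.size() == 1), per output channel (scales.size() == dims[0], bracketed by the two TransposeWeight calls), or per input channel (scales.size() == dims[1], via the step_n/step_c strides). A standalone sketch of the per-input-channel case on a flat row-major buffer (a hypothetical free function with illustrative names; the real pass mutates a Paddle LoDTensor in place):

    #include <cstdint>
    #include <vector>

    // Dequantize an int8 weight of shape [n, c, spatial...] flattened
    // row-major, one scale per input channel c. step_n is the number of
    // elements in one slice along dim 0 (assumed to divide w.size()).
    std::vector<float> DequantPerInputChannel(const std::vector<int8_t>& w,
                                              const std::vector<float>& scales,
                                              int step_n) {
      const int c = static_cast<int>(scales.size());
      const int step_c = step_n / c;  // elements covered by one scale
      std::vector<float> out(w.size());
      for (size_t i = 0; i < w.size(); ++i) {
        out[i] = static_cast<float>(w[i]) / 127.0f;  // undo int8 storage
      }
      for (size_t begin_n = 0; begin_n < out.size(); begin_n += step_n) {
        for (int k = 0; k < c; ++k) {
          for (int m = 0; m < step_c; ++m) {
            out[begin_n + k * step_c + m] *= scales[k];  // channel k's scale
          }
        }
      }
      return out;
    }

For a conv weight of shape [out_c, in_c, kh, kw], step_n would be in_c * kh * kw and scales would hold in_c entries.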
