fix some convert error found in tipc. #44457

Merged · 3 commits · Jul 21, 2022

190 changes: 177 additions & 13 deletions paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
@@ -14,6 +14,7 @@

#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"

#include <string>
#include <unordered_set>

#include "paddle/fluid/framework/block_desc.h"
@@ -22,6 +23,7 @@
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/io.h"
@@ -89,6 +91,31 @@ bool OutShouldNotConvert(ir::Node* var_node) {
return false;
}

// Get weight names that appear in multiple blocks (block 0 and block n).
std::unordered_set<std::string> GetMultiBlockPersistableNames(
framework::ProgramDesc* program_desc) {
std::unordered_set<std::string> special_weights;
size_t block_size = program_desc->Size();

std::unordered_set<std::string> block_0_weights;
for (auto var : program_desc->Block(0).AllVars()) {
if (var->Persistable()) block_0_weights.insert(var->Name());
}

for (size_t i = 1; i < block_size; ++i) {
auto all_ops = program_desc->Block(i).AllOps();
for (auto op : all_ops) {
for (auto name : op->InputArgumentNames()) {
if (block_0_weights.count(name)) special_weights.insert(name);
}
}
}

return special_weights;
}
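
// Illustrative sketch (not part of this PR's diff): the set returned above is
// consumed in ConvertTensorDtype below to veto low-precision conversion for
// any op that reads a weight shared across blocks, roughly:
//
//   auto shared = GetMultiBlockPersistableNames(program_desc);
//   if (shared.count(in_node->Name())) support_precision = false;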

// Just process special cases for weights conversion.
bool WeightsShouldNotConvert(ir::Node* var_node) {
auto op_nodes = var_node->outputs;
Expand Down Expand Up @@ -116,19 +143,139 @@ bool WeightsShouldNotConvert(ir::Node* var_node) {
}
}

// If cur_op's next op is a condition-flow op, then cur_op should stay fp32.
// Note: we currently only convert to mixed precision in block 0.
for (auto* op_node : op_nodes) {
for (auto var : op_node->outputs) {
for (auto next_op : var->outputs) {
if (next_op->Op()->HasAttr("sub_block")) {
return true;
}
}
}
}

return false;
}

inline bool IsFloatVarType(framework::proto::VarType::Type type) {
if (type == framework::proto::VarType::FP16 ||
type == framework::proto::VarType::FP32 ||
- type == framework::proto::VarType::BF16 ||
- type == framework::proto::VarType::FP64)
type == framework::proto::VarType::BF16)
return true;
return false;
}

- void ConvertTensorDtype(framework::ir::Graph* graph,
void ConvertAllFp64ToFp32(framework::ir::Graph* graph) {
auto op_nodes = framework::ir::TopologySortOperations(*graph);
for (auto* op_node : op_nodes) {
if (!op_node->IsOp()) continue;
auto op_type = op_node->Op()->Type();
if (op_type == "feed" || op_type == "fetch") continue;

if (op_type == "fill_constant") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"dtype", static_cast<int>(framework::proto::VarType::FP32));
} else if (op_type == "assign_value") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"dtype", static_cast<int>(framework::proto::VarType::FP32));
} else if (op_type == "eye") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"dtype", static_cast<int>(framework::proto::VarType::FP32));
} else if (op_type == "fill_any_like") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"dtype", static_cast<int>(framework::proto::VarType::FP32));
} else if (op_type == "cast") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("in_dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"in_dtype", static_cast<int>(framework::proto::VarType::FP32));
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("out_dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"out_dtype", static_cast<int>(framework::proto::VarType::FP32));
}

auto inputs = op_node->inputs;
for (auto* in_node : inputs) {
if (in_node->IsCtrlVar()) continue;
auto* in_var = in_node->Var();
if (!in_var->Persistable() &&
in_var->GetDataType() == framework::proto::VarType::FP64) {
in_var->SetDataType(framework::proto::VarType::FP32);
}
}
}
}

// Handle special ops that carry a dtype attribute, e.g., fill_constant,
// assign_value.
void HandleSpecialOps(framework::OpDesc* op_desc) {
if (op_desc->Type() == "fill_constant") {
if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP32))
op_desc->SetAttr("dtype",
static_cast<int>(framework::proto::VarType::FP16));
} else if (op_desc->Type() == "assign_value") {
if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP32))
op_desc->SetAttr("dtype",
static_cast<int>(framework::proto::VarType::FP16));
} else if (op_desc->Type() == "eye") {
if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP32))
op_desc->SetAttr("dtype",
static_cast<int>(framework::proto::VarType::FP16));
} else if (op_desc->Type() == "fill_any_like") {
if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP32))
op_desc->SetAttr("dtype",
static_cast<int>(framework::proto::VarType::FP16));
}
}
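
// Illustrative effect (sketch, not in the diff): for an op such as
//   fill_constant(shape=[1], dtype=FP32, value=1.0)
// the rewrite above produces
//   fill_constant(shape=[1], dtype=FP16, value=1.0)
// so the constant is created directly in fp16 instead of requiring a cast.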

// We modify the ops' input/output precision, so we need to fix the cast op's
// in_dtype and out_dtype attributes.
void FixCastAttr(framework::ir::Graph* graph) {
auto op_nodes = framework::ir::TopologySortOperations(*graph);
for (auto* op_node : op_nodes) {
if (!op_node->IsOp()) continue;
auto op_type = op_node->Op()->Type();
if (op_type != "cast") continue;

auto input = op_node->inputs[0];
auto output = op_node->outputs[0];
op_node->Op()->SetAttr("in_dtype",
static_cast<int>(input->Var()->GetDataType()));
op_node->Op()->SetAttr("out_dtype",
static_cast<int>(output->Var()->GetDataType()));
}
}
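
// Illustrative effect (sketch, not in the diff): if the precision rewrite
// turned X into an fp16 var around an existing cast,
//   X(fp16) -> cast(in_dtype=FP32, out_dtype=FP32) -> Y(fp32)
// re-reading the actual var dtypes rewrites the attributes to
//   X(fp16) -> cast(in_dtype=FP16, out_dtype=FP32) -> Y(fp32)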

// If an op's output var is a condition-flow op's input, then the op must keep
// fp32 precision.
bool NextOpIncludesConditionFlowOp(framework::ir::Node* cur_op_node) {
auto cur_op_outs = cur_op_node->outputs;
for (auto out_var : cur_op_outs) {
for (auto next_op_node : out_var->outputs) {
if (next_op_node->Op()->HasAttr("sub_block")) {
return true;
}
}
}
return false;
}
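
// Illustrative sketch (not in the diff): control-flow ops such as while and
// conditional_block carry a "sub_block" attribute, so for a graph like
//   matmul -> var -> while(sub_block)
// NextOpIncludesConditionFlowOp(matmul_node) returns true and matmul keeps
// running at fp32.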

void ConvertTensorDtype(framework::ProgramDesc* program_desc,
framework::ir::Graph* graph,
const std::unordered_set<std::string>& blacklist,
bool keep_io_types,
phi::Backend backend,
@@ -145,13 +292,14 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
static_cast<int>(tensor_dtype)));
}

auto weight_name_in_multi_block = GetMultiBlockPersistableNames(program_desc);
int num_low_precision = 0;
int suffix = 0;
framework::BlockDesc* block_desc{nullptr};
std::vector<framework::ir::Node*> output_nodes;
std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map;

- for (auto* op_node : framework::ir::TopologySortOperations(*graph)) {
auto op_nodes = framework::ir::TopologySortOperations(*graph);
for (auto* op_node : op_nodes) {
if (!op_node->IsOp()) continue;
auto op_type = op_node->Op()->Type();
auto phi_op_type = phi::TransToPhiKernelName(op_type);
@@ -167,18 +315,29 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
auto* fetch_var = op_node->inputs[0];
output_nodes.push_back(fetch_var);
continue;
} else if (op_type == "cast") {
continue;
}

// 2. if op support fp16/bf16 and not in blacklist.
// - cast weight to fp16/bf16.
// - add cast op if the input dtype is not fp16/bf16.
// - set output dtype.
else if (blacklist.count(phi_op_type) == 0) { // NOLINT
else if (blacklist.count(phi_op_type) == 0 && // NOLINT
!NextOpIncludesConditionFlowOp(op_node)) {
bool support_precision =
OpSupportPrecision(phi_op_type, backend, tensor_dtype, blacklist);
VLOG(2) << "phi_op_type " << phi_op_type << " support low precision "
<< support_precision;
VLOG(2) << "op_type " << op_type << ", phi_op_type " << phi_op_type
<< " support low precision " << support_precision << ", "
<< reinterpret_cast<void*>(op_node->Op()->Block());

for (auto in_node : op_node->inputs) {
if (weight_name_in_multi_block.count(in_node->Name()))
support_precision = false;
}

if (support_precision) {
HandleSpecialOps(op_node->Op());
++num_low_precision;
auto inputs = op_node->inputs;
for (auto* in_node : inputs) {
Expand Down Expand Up @@ -232,9 +391,8 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
// 3. check op not support fp16/bf16 or in blacklist.
// - add cast op if the input dtype is not fp32.
else { // NOLINT
- // trt pass should explicitly add cast op if input is bf16/tf32, etc.
- if (op_node->Name() == "tensorrt_engine") continue;
- for (auto* in_node : op_node->inputs) {
auto ins = op_node->inputs;
for (auto* in_node : ins) {
if (in_node->IsCtrlVar()) continue;
auto* in_var = in_node->Var();
if (in_var->GetDataType() == to_type) {
@@ -366,8 +524,14 @@ void ConvertToMixedPrecision(const std::string& model_file,
auto graph = std::unique_ptr<framework::ir::Graph>(
new framework::ir::Graph(*program_desc));

- ConvertTensorDtype(
- graph.get(), black_list, keep_io_types, backend, mixed_precision);
ConvertAllFp64ToFp32(graph.get());
ConvertTensorDtype(program_desc.get(),
graph.get(),
black_list,
keep_io_types,
backend,
mixed_precision);
FixCastAttr(graph.get());

framework::ProgramDesc mixed_program_desc;
framework::ir::GraphToProgram(*graph, &mixed_program_desc);
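
Taken together, ConvertToMixedPrecision now runs three passes in order; a condensed sketch of the call sequence this PR introduces (names exactly as in the diff above):

    ConvertAllFp64ToFp32(graph.get());        // 1. normalize fp64 vars/attrs to fp32
    ConvertTensorDtype(program_desc.get(),    // 2. rewrite eligible ops and vars
                       graph.get(), black_list, keep_io_types,
                       backend, mixed_precision);
    FixCastAttr(graph.get());                 // 3. realign cast in_dtype/out_dtype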