Skip to content

Commit

Permalink
fix pass and convert_op for preln_ernie
Browse files Browse the repository at this point in the history
  • Loading branch information
Wangzheee committed Feb 21, 2022
1 parent b2ecdaf commit 78fcfbf
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,19 @@ PrelnEmbeddingEltwiseLayerNormFusePass::

void PrelnEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);

bool enable_int8 = Get<bool>("enable_int8");
bool use_oss = Get<bool>("use_oss");
bool with_interleaved = Get<bool>("with_interleaved");
bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
if (!(enable_int8 && use_oss && with_interleaved && with_dynamic_shape)) {
VLOG(4) << "preln_embedding_eltwise_layernorm_fuse_pass need: use_trt, "
"enable_int8, "
"use_oss, with_interleaved, with_dynamic_shape. Stop this pass, "
"please reconfig.";
return;
}

int fusion_count =
PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(graph, name_scope_);
if (fusion_count > 0) {
Expand Down
12 changes: 12 additions & 0 deletions paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init("preln_skip_layernorm_fuse", graph);
bool enable_int8 = Get<bool>("enable_int8");
bool use_oss = Get<bool>("use_oss");
bool with_interleaved = Get<bool>("with_interleaved");
bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
if (!(enable_int8 && use_oss && with_interleaved && with_dynamic_shape)) {
VLOG(4) << "preln_skip_layernorm_fuse_pass need: use_trt, enable_int8, "
"use_oss, "
"with_interleaved, with_dynamic_shape. Stop this pass, please "
"reconfig. ";
return;
}

int found_subgraph_count = 0;

GraphPatternDetector gpd;
Expand Down
46 changes: 20 additions & 26 deletions paddle/fluid/inference/analysis/ir_pass_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,24 @@ void IRPassManager::CreatePasses(Argument *argument,
int pass_num = 0;
for (const std::string &pass_name : passes) {
auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
pass->Set("use_oss", new bool(argument->tensorrt_use_oss()));
pass->Set("with_interleaved",
new bool(argument->tensorrt_with_interleaved()));
pass->Set("disable_logs", new bool(argument->disable_logs()));
auto precision_mode = argument->tensorrt_precision_mode();
bool enable_int8 = precision_mode == AnalysisConfig::Precision::kInt8;
pass->Set("enable_int8", new bool(enable_int8));
pass->Set("max_input_shape", new std::map<std::string, std::vector<int>>(
argument->max_input_shape()));
pass->Set("min_input_shape", new std::map<std::string, std::vector<int>>(
argument->min_input_shape()));
pass->Set("optim_input_shape", new std::map<std::string, std::vector<int>>(
argument->optim_input_shape()));
bool with_dynamic_shape = (argument->max_input_shape().size() > 0 &&
argument->min_input_shape().size() > 0 &&
argument->optim_input_shape().size() > 0) ||
argument->tensorrt_tuned_dynamic_shape();
pass->Set("with_dynamic_shape", new bool(with_dynamic_shape));

if (pass_name == "graph_viz_pass") {
std::string optim_cache_dir = argument->optim_cache_dir();
Expand Down Expand Up @@ -99,17 +117,9 @@ void IRPassManager::CreatePasses(Argument *argument,
new int(argument->tensorrt_min_subgraph_size()));
pass->Set("program",
new framework::ProgramDesc *(&argument->main_program()));

auto precision_mode = argument->tensorrt_precision_mode();
bool enable_int8 = precision_mode == AnalysisConfig::Precision::kInt8;

pass->Set("predictor_id", new int(argument->predictor_id()));
bool use_calib_mode = argument->tensorrt_use_calib_mode();
pass->Set("enable_int8", new bool(enable_int8));
pass->Set("use_calib_mode", new bool(use_calib_mode));
pass->Set("use_oss", new bool(argument->tensorrt_use_oss()));
pass->Set("with_interleaved",
new bool(argument->tensorrt_with_interleaved()));
pass->Set("precision_mode",
new AnalysisConfig::Precision(precision_mode));

Expand Down Expand Up @@ -165,18 +175,6 @@ void IRPassManager::CreatePasses(Argument *argument,
new bool(argument->tensorrt_tuned_dynamic_shape()));
pass->Set("trt_allow_build_at_runtime",
new bool(argument->tensorrt_allow_build_at_runtime()));
pass->Set("max_input_shape", new std::map<std::string, std::vector<int>>(
argument->max_input_shape()));
pass->Set("min_input_shape", new std::map<std::string, std::vector<int>>(
argument->min_input_shape()));
pass->Set("optim_input_shape",
new std::map<std::string, std::vector<int>>(
argument->optim_input_shape()));
bool with_dynamic_shape = (argument->max_input_shape().size() > 0 &&
argument->min_input_shape().size() > 0 &&
argument->optim_input_shape().size() > 0) ||
argument->tensorrt_tuned_dynamic_shape();
pass->Set("with_dynamic_shape", new bool(with_dynamic_shape));
pass->Set("trt_disabled_ops", new std::vector<std::string>(
argument->tensorrt_disabled_ops()));
pass->Set("trt_use_dla", new bool(argument->tensorrt_use_dla()));
Expand All @@ -192,14 +190,14 @@ void IRPassManager::CreatePasses(Argument *argument,
new framework::ProgramDesc *(&argument->main_program()));
}
if (pass_name == "lite_subgraph_pass") {
bool enable_int8 =
bool lite_enable_int8 =
argument->lite_precision_mode() == AnalysisConfig::Precision::kInt8;
pass->Set("program",
new framework::ProgramDesc *(&argument->main_program()));
pass->Set("lite_ops_filter",
new std::vector<std::string>(argument->lite_ops_filter()));
pass->Set("predictor_id", new int(argument->predictor_id()));
pass->Set("enable_int8", new bool(enable_int8));
pass->Set("enable_int8", new bool(lite_enable_int8));
pass->Set("use_gpu", new bool(argument->use_gpu()));
pass->Set("zero_copy", new bool(argument->lite_zero_copy()));
pass->Set("use_xpu", new bool(argument->use_xpu()));
Expand Down Expand Up @@ -236,7 +234,6 @@ void IRPassManager::CreatePasses(Argument *argument,
new std::vector<std::string>(
argument->nnadapter_model_cache_token()));
}
disable_logs_ = argument->disable_logs();
if (pass_name == "fc_fuse_pass") {
pass->Set("use_gpu", new bool(argument->use_gpu()));
bool fc_mkldnn_pass = 0;
Expand All @@ -248,9 +245,6 @@ void IRPassManager::CreatePasses(Argument *argument,
bool use_fc_padding = !fc_mkldnn_pass && argument->use_fc_padding();
pass->Set("use_fc_padding", new bool(use_fc_padding));
}

pass->Set("disable_logs", new bool(disable_logs_));

pre_pass = pass_name;

passes_.emplace_back(std::move(pass));
Expand Down
12 changes: 6 additions & 6 deletions paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,12 @@ void AnalysisPredictor::PrepareArgument() {
argument_.SetModelParamsPath(config_.params_file());
}

argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
argument_.SetTensorRtUseOSS(config_.trt_use_oss_);
argument_.SetTensorRtWithInterleaved(config_.trt_with_interleaved_);
argument_.SetMinInputShape(config_.min_input_shape_);
argument_.SetMaxInputShape(config_.max_input_shape_);
argument_.SetOptimInputShape(config_.optim_input_shape_);
if (config_.use_gpu() && config_.tensorrt_engine_enabled()) {
LOG(INFO) << "TensorRT subgraph engine is enabled";
argument_.SetUseTensorRT(true);
Expand All @@ -601,14 +607,8 @@ void AnalysisPredictor::PrepareArgument() {
argument_.SetTensorRtDisabledOPs(config_.trt_disabled_ops_);
argument_.SetTensorRtUseDLA(config_.trt_use_dla_);
argument_.SetTensorRtDLACore(config_.trt_dla_core_);
argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
argument_.SetTensorRtUseStaticEngine(config_.trt_use_static_engine_);
argument_.SetTensorRtUseCalibMode(config_.trt_use_calib_mode_);
argument_.SetTensorRtUseOSS(config_.trt_use_oss_);
argument_.SetTensorRtWithInterleaved(config_.trt_with_interleaved_);
argument_.SetMinInputShape(config_.min_input_shape_);
argument_.SetMaxInputShape(config_.max_input_shape_);
argument_.SetOptimInputShape(config_.optim_input_shape_);
argument_.SetCloseTrtPluginFp16(config_.disable_trt_plugin_fp16_);
argument_.SetTensorRtShapeRangeInfoPath(config_.shape_range_info_path());
argument_.SetTensorRtTunedDynamicShape(
Expand Down

0 comments on commit 78fcfbf

Please sign in to comment.