From deef0321260c14e19871016bd435a773aeb21275 Mon Sep 17 00:00:00 2001 From: xingmingyyj <135400902+xingmingyyj@users.noreply.github.com> Date: Fri, 22 Dec 2023 16:26:53 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90PIR=20OpTest=20Fix=20No.3=E3=80=91=20f?= =?UTF-8?q?ix=20test=5Fmatrix=5Frank=5Fop=20(#59959)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix * fix * Update op_translator.cc * Update op_translator.cc * Update op_translator.cc * Update op_translator.cc * fix * fix codestyle --- .../ir_adaptor/translator/op_compat_gen.py | 6 ++++++ .../ir_adaptor/translator/op_translator.cc | 21 +++++++++++++++++++ test/white_list/pir_op_test_white_list | 1 + 3 files changed, 28 insertions(+) diff --git a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py index 587df51f0039f..596bf8534bfe6 100644 --- a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py +++ b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py @@ -149,6 +149,12 @@ def insert_new_mutable_attributes( } op_arg_name_mappings["matmul"] = {"x": "X", "y": "Y", "out": "Out"} + op_arg_name_mappings["matrix_rank"] = { + "x": "X", + "atol_tensor": "TolTensor", + "out": "Out", + } + op_name_normailzer_template = env.get_template("op_compat_info.cc.j2") with open(output_source_file, 'wt') as f: op_compat_definition = op_name_normailzer_template.render( diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 74901075e5204..913903dc75611 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -2671,6 +2671,26 @@ struct FusedElemwiseAddActivationGradOpTranscriber } }; +struct MatrixRankOpTranscriber : public OpTranscriber { + pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, + const OpDesc& op_desc) override { + std::string target_op_name = ""; + if (op_desc.HasInput("TolTensor") && !op_desc.Input("TolTensor").empty()) { + target_op_name = "pd_op.matrix_rank_tol"; + } else { + target_op_name = "pd_op.matrix_rank"; + } + const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); + if (!op_info) { + IR_THROW( + "Op matrix_rank should have corresponding OpInfo pd_op.matrix_rank " + "or " + "pd_op.matrix_rank_tol."); + } + return op_info; + } +}; + struct LodArrayLengthOpTranscriber : public OpTranscriber { pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, const OpDesc& op_desc) override { @@ -2956,6 +2976,7 @@ OpTranslator::OpTranslator() { special_handlers["split"] = SplitOpTranscriber(); special_handlers["sum"] = AddNOpTranscriber(); special_handlers["tril_triu"] = TrilAndTriuOpTranscriber(); + special_handlers["matrix_rank"] = MatrixRankOpTranscriber(); special_handlers["mul"] = MulOpTranscriber(); special_handlers["mul_grad"] = MulGradOpTranscriber(); special_handlers["select_input"] = SelectInputOpTranscriber(); diff --git a/test/white_list/pir_op_test_white_list b/test/white_list/pir_op_test_white_list index c00f7fba97698..fe57bbe32693f 100644 --- a/test/white_list/pir_op_test_white_list +++ b/test/white_list/pir_op_test_white_list @@ -197,6 +197,7 @@ test_matmul_v2_op test_matmul_v2_op_static_build test_matrix_nms_op test_matrix_power_op +test_matrix_rank_op test_maxout_op test_mean_op test_memcpy_op