diff --git a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py index 587df51f0039f..596bf8534bfe6 100644 --- a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py +++ b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py @@ -149,6 +149,12 @@ def insert_new_mutable_attributes( } op_arg_name_mappings["matmul"] = {"x": "X", "y": "Y", "out": "Out"} + op_arg_name_mappings["matrix_rank"] = { + "x": "X", + "atol_tensor": "TolTensor", + "out": "Out", + } + op_name_normailzer_template = env.get_template("op_compat_info.cc.j2") with open(output_source_file, 'wt') as f: op_compat_definition = op_name_normailzer_template.render( diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 74901075e5204..913903dc75611 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -2671,6 +2671,26 @@ struct FusedElemwiseAddActivationGradOpTranscriber } }; +struct MatrixRankOpTranscriber : public OpTranscriber { + pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, + const OpDesc& op_desc) override { + std::string target_op_name = ""; + if (op_desc.HasInput("TolTensor") && !op_desc.Input("TolTensor").empty()) { + target_op_name = "pd_op.matrix_rank_tol"; + } else { + target_op_name = "pd_op.matrix_rank"; + } + const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); + if (!op_info) { + IR_THROW( + "Op matrix_rank should have corresponding OpInfo pd_op.matrix_rank " + "or " + "pd_op.matrix_rank_tol."); + } + return op_info; + } +}; + struct LodArrayLengthOpTranscriber : public OpTranscriber { pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, const OpDesc& op_desc) override { @@ -2956,6 +2976,7 @@ OpTranslator::OpTranslator() { special_handlers["split"] = SplitOpTranscriber(); special_handlers["sum"] = AddNOpTranscriber(); special_handlers["tril_triu"] = TrilAndTriuOpTranscriber(); + special_handlers["matrix_rank"] = MatrixRankOpTranscriber(); special_handlers["mul"] = MulOpTranscriber(); special_handlers["mul_grad"] = MulGradOpTranscriber(); special_handlers["select_input"] = SelectInputOpTranscriber(); diff --git a/test/white_list/pir_op_test_white_list b/test/white_list/pir_op_test_white_list index c00f7fba97698..fe57bbe32693f 100644 --- a/test/white_list/pir_op_test_white_list +++ b/test/white_list/pir_op_test_white_list @@ -197,6 +197,7 @@ test_matmul_v2_op test_matmul_v2_op_static_build test_matrix_nms_op test_matrix_power_op +test_matrix_rank_op test_maxout_op test_mean_op test_memcpy_op