Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PIR OpTest Fix No.3】 fix test_matrix_rank_op #59959

Merged
merged 10 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions paddle/fluid/ir_adaptor/translator/op_compat_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,12 @@ def insert_new_mutable_attributes(
"grad_out": "DY",
}

op_arg_name_mappings["matrix_rank"] = {
"x": "X",
"atol_tensor": "TolTensor",
"out": "Out",
}

op_name_normailzer_template = env.get_template("op_compat_info.cc.j2")
with open(output_source_file, 'wt') as f:
op_compat_definition = op_name_normailzer_template.render(
Expand Down
21 changes: 21 additions & 0 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2634,6 +2634,26 @@ struct FusedElemwiseAddActivationGradOpTranscriber
}
};

struct MatrixRankOpTranscriber : public OpTranscriber {
pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx,
const OpDesc& op_desc) override {
std::string target_op_name = "";
if (op_desc.HasInput("TolTensor") && !op_desc.Input("TolTensor").empty()) {
target_op_name = "pd_op.matrix_rank_tol";
} else {
target_op_name = "pd_op.matrix_rank";
}
const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
if (!op_info) {
IR_THROW(
"Op matrix_rank should have corresponding OpInfo pd_op.matrix_rank "
"or "
"pd_op.matrix_rank_tol.");
}
return op_info;
}
};

struct LodArrayLengthOpTranscriber : public OpTranscriber {
pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx,
const OpDesc& op_desc) override {
Expand Down Expand Up @@ -2851,6 +2871,7 @@ OpTranslator::OpTranslator() {
special_handlers["split"] = SplitOpTranscriber();
special_handlers["sum"] = AddNOpTranscriber();
special_handlers["tril_triu"] = TrilAndTriuOpTranscriber();
special_handlers["matrix_rank"] = MatrixRankOpTranscriber();
special_handlers["mul"] = MulOpTranscriber();
special_handlers["mul_grad"] = MulGradOpTranscriber();
special_handlers["select_input"] = SelectInputOpTranscriber();
Expand Down
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_white_list
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ test_matmul_v2_op
test_matmul_v2_op_static_build
test_matrix_nms_op
test_matrix_power_op
test_matrix_rank_op
test_maxout_op
test_mean_op
test_memcpy_op
Expand Down