Skip to content

Commit

Permalink
update remove_assign_out_pass: add more constraints to improve robust…
Browse files Browse the repository at this point in the history
…ness
  • Loading branch information
lszxb committed Sep 18, 2024
1 parent 20aa163 commit 7058e45
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion csrc/gpu/pass/remove_assign_out_pass.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "paddle/extension.h"
#include <iostream>
#include <algorithm>

namespace {

Expand All @@ -15,7 +16,17 @@ class RemoveAssignOutPattern : public paddle::drr::DrrPatternBase {

pat.AddConstraint([](const paddle::drr::MatchContext &match_ctx) {
auto &out = match_ctx.Tensor("out");
if (out.use_count() == 1 && out.use_begin()->owner()->name() == "cf.yield") {
auto parent_block = out.defining_op()->GetParent();
auto parent_op = out.defining_op()->GetParentOp();

auto &assign_out = match_ctx.Tensor("assign_out");

if (
parent_block && parent_op &&
parent_op->name() == "pd_op.while" &&
out.use_count() == 1 && out.use_begin()->owner()->name() == "cf.yield" &&
std::find(parent_block->args_begin(), parent_block->args_end(), assign_out) != parent_block->args_end()
) {
return true;
}
return false;
Expand Down

0 comments on commit 7058e45

Please sign in to comment.