Skip to content

Commit

Permalink
enhance cas and fix bugs (PaddlePaddle#390)
Browse files Browse the repository at this point in the history
Co-authored-by: haozech <chenhaoze94@gmail.com>
  • Loading branch information
wenming2014 and haozech committed May 28, 2021
1 parent 4278b3a commit 53bdc07
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 19 deletions.
156 changes: 139 additions & 17 deletions cinn/common/cas.cc
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -172,17 +172,30 @@ Expr Divide(const Sum* a, int b) {
return Sum::Make(args);
}
//! Divide the product \p a by the integer \p b.
//!
//! Scans the operands of \p a for the first IntImm whose value is a multiple of \p b and replaces it with the
//! quotient (the operand is dropped entirely when the quotient is 1). All other operands are kept in their
//! original relative order, placed after the quotient.
//!
//! @param a The product to divide; it must contain an IntImm operand divisible by \p b.
//! @param b The non-zero integer divisor.
//! @return A new Product built from the quotient (if not 1) and the remaining operands.
Expr Divide(const Product* a, int b) {
  // `%` and `/` by zero are undefined behavior, so fail loudly instead.
  CHECK(b != 0) << "b should not be zero";

  // Hoist the size as a signed int to avoid signed/unsigned comparisons in the loops below.
  const int num_operands = static_cast<int>(a->operands().size());

  int divisible_idx = -1;  // index of the operand that absorbs the division, -1 if none was found
  int times         = -1;  // quotient of that operand by b
  for (int i = 0; i < num_operands; i++) {
    auto* a_i = a->operand(i).As<IntImm>();
    if (a_i && a_i->value % b == 0) {
      times         = a_i->value / b;
      divisible_idx = i;
      break;
    }
  }
  // NOTE that a should be divisible by b.
  CHECK(divisible_idx >= 0) << "a should be divisible by b";

  std::vector<Expr> args;
  args.reserve(num_operands);
  // A quotient of 1 is the multiplicative identity, so it is not emitted.
  if (times != 1) {
    args.push_back(make_const(a->type(), times));
  }
  for (int j = 0; j < num_operands; j++) {
    if (j == divisible_idx) continue;
    args.push_back(a->operand(j));
  }
  return Product::Make(args);
}

// @}

//! Integer quotient of \p n divided by \p d, using the built-in `/` (truncates toward zero).
inline int Iquot(int n, int d) {
  const int quotient = n / d;
  return quotient;
}
Expand Down Expand Up @@ -768,8 +781,11 @@ std::vector<Expr> CasSimplifyMutator::MergeSum(const std::vector<Expr>& p, const

return MergeExprs(p, q, [this](Expr left, Expr right) -> std::vector<Expr> {
auto&& h = SimplifyBinarySum(std::move(left), std::move(right));
if (h.size() == 1 && h[0].is_constant() && h[0].get_constant() == 0) {return {};}
else {return std::move(h);}
if (h.size() == 1 && h[0].is_constant() && h[0].get_constant() == 0) {
return {};
} else {
return std::move(h);
}
});
}

Expand Down Expand Up @@ -946,6 +962,7 @@ bool CasSimplifyMutator::GetVarBound(Expr* lower_bound, Expr* upper_bound, Expr
CHECK(upper_bound);
auto v_var = var.As<_Var_>();
auto v_product = var.As<Product>();
auto v_frac = var.As<FracOp>();
if (v_var && (var_intervals.count(v_var->name) || !unfold_const_bound)) {
UnfoldBound(lower_bound, upper_bound, var, unfold_const_bound);
return true;
Expand All @@ -972,6 +989,29 @@ bool CasSimplifyMutator::GetVarBound(Expr* lower_bound, Expr* upper_bound, Expr
AddBaseAndSimplify(upper_bound, p_upper_bound);
return true;
}
} else if (v_frac) {
// only deal with x/2
Expr p_lower_bound;
Expr p_upper_bound;
Expr non_const_oper = v_frac->a();
Expr const_oper = v_frac->b();
auto v_var = non_const_oper.As<_Var_>();
if (v_var && var_intervals.count(v_var->name)) {
Expr v_lower, v_upper;
UnfoldBound(&v_lower, &v_upper, non_const_oper, unfold_const_bound);
auto const_v = const_oper.get_constant();
CHECK(v_lower.defined() && v_upper.defined());
if (const_v > 0) {
p_lower_bound = FracOp::Make(v_lower, const_oper);
p_upper_bound = FracOp::Make(v_upper, const_oper);
} else {
p_lower_bound = FracOp::Make(v_upper, const_oper);
p_upper_bound = FracOp::Make(v_lower, const_oper);
}
AddBaseAndSimplify(lower_bound, p_lower_bound);
AddBaseAndSimplify(upper_bound, p_upper_bound);
return true;
}
}
return false;
}
Expand Down Expand Up @@ -1221,20 +1261,102 @@ Expr CasSimplifyMutator::SimplifyMinAndMax(Expr u) {
auto* u_max = u.As<Max>();
auto* u_min = u.As<Min>();
if (u_max) {
Expr a = CasSimplify(u_max->a(), var_intervals);
Expr b = CasSimplify(u_max->b(), var_intervals);
if (a.is_constant() && b.is_constant()) {
Expr a = CasSimplify(u_max->a(), var_intervals);
Expr b = CasSimplify(u_max->b(), var_intervals);
bool is_a_const = a.is_constant();
bool is_b_const = b.is_constant();
if (is_a_const && is_b_const) {
return a.get_constant() >= b.get_constant() ? a : b;
}
Expr lower_bound, upper_bound;
Expr const_operand, non_const_operand;
if (is_a_const) {
const_operand = a;
non_const_operand = b;
}
if (is_b_const) {
const_operand = b;
non_const_operand = a;
}
if (const_operand.defined() && non_const_operand.defined()) {
auto const_size = const_operand.get_constant();
// unfold var with bounds
if (GetExprBound(&lower_bound, &upper_bound, non_const_operand, true)) {
// if non_const_operand's lower_bound is larger than const_operand, then non_const_operand must be larger than
// const_operand
if (lower_bound.is_constant() && const_size <= lower_bound.get_constant()) {
return non_const_operand;
}
// if non_const_operand's upper_bound is smaller than a, then const_operand must be larger than
// non_const_operand
if (upper_bound.is_constant() && const_size >= upper_bound.get_constant()) {
return const_operand;
}
}
// Do not unfold the var here, since the var may be eliminated in the calculation
if (GetExprBound(&lower_bound, &upper_bound, non_const_operand, false)) {
// if non_const_operand's lower_bound is larger than const_operand, then non_const_operand must be larger than
// const_operand
lower_bound = CasSimplify(lower_bound, var_intervals);
upper_bound = CasSimplify(upper_bound, var_intervals);
if (lower_bound.is_constant() && const_size <= lower_bound.get_constant()) {
return non_const_operand;
}
// if non_const_operand's upper_bound is smaller than a, then const_operand must be larger than
// non_const_operand
if (upper_bound.is_constant() && const_size >= upper_bound.get_constant()) {
return const_operand;
}
}
}
return ir::Max::Make(a, b);
}

if (u_min) {
Expr a = CasSimplify(u_min->a(), var_intervals);
Expr b = CasSimplify(u_min->b(), var_intervals);
if (a.is_constant() && b.is_constant()) {
Expr a = CasSimplify(u_min->a(), var_intervals);
Expr b = CasSimplify(u_min->b(), var_intervals);
bool is_a_const = a.is_constant();
bool is_b_const = b.is_constant();
if (is_a_const && is_b_const) {
return a.get_constant() <= b.get_constant() ? a : b;
}
Expr lower_bound, upper_bound;
Expr const_operand, non_const_operand;
if (is_a_const) {
const_operand = a;
non_const_operand = b;
}
if (is_b_const) {
const_operand = b;
non_const_operand = a;
}
if (const_operand.defined() && non_const_operand.defined()) {
auto const_size = const_operand.get_constant();
if (GetExprBound(&lower_bound, &upper_bound, non_const_operand, true)) {
// if non_const_operand's lower_bound is larger than const_operand, then non_const_operand must be larger than
// const_operand
if (lower_bound.is_constant() && const_size <= lower_bound.get_constant()) {
return const_operand;
}
// if non_const_operand's upper_bound is smaller than a, then const_operand must be larger than
// non_const_operand
if (upper_bound.is_constant() && const_size >= upper_bound.get_constant()) {
return non_const_operand;
}
}
if (GetExprBound(&lower_bound, &upper_bound, non_const_operand, false)) {
// if non_const_operand's lower_bound is larger than const_operand, then non_const_operand must be larger than
// const_operand
if (lower_bound.is_constant() && const_size <= lower_bound.get_constant()) {
return const_operand;
}
// if non_const_operand's upper_bound is smaller than a, then const_operand must be larger than
// non_const_operand
if (upper_bound.is_constant() && const_size >= upper_bound.get_constant()) {
return non_const_operand;
}
}
}
return ir::Min::Make(a, b);
}
return u;
Expand Down
18 changes: 16 additions & 2 deletions cinn/optim/ir_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ void PartialSimplify(Expr* expr, const std::unordered_map<std::string, common::C

//! Simplify the expression but Load.
struct SimplifyButStoreLoadMutator : public ir::IRMutator<ir::Expr*> {
const common::cas_intervals_t& var_intervals;
explicit SimplifyButStoreLoadMutator(const common::cas_intervals_t& var_intervals) : var_intervals(var_intervals) {}
common::cas_intervals_t& var_intervals;
explicit SimplifyButStoreLoadMutator(common::cas_intervals_t& var_intervals) : var_intervals(var_intervals) {}

//! Entry point of the mutator: forwards \p x as both arguments of the base-class IRMutator::Visit.
void operator()(Expr* x) { ir::IRMutator<ir::Expr*>::Visit(x, x); }

Expand All @@ -48,6 +48,8 @@ struct SimplifyButStoreLoadMutator : public ir::IRMutator<ir::Expr*> {
__(Mul)
__(Sub)
__(Div)
__(Min)
__(Max)
#undef __

void Visit(const Ramp* op, Expr* expr) override {
Expand All @@ -72,8 +74,20 @@ struct SimplifyButStoreLoadMutator : public ir::IRMutator<ir::Expr*> {

//! Simplify a For node: first its min and extent, then its body.
//!
//! While the body is visited, the loop variable's interval is registered in `var_intervals` so that the
//! simplifications applied inside the body can use it. The interval is removed again afterwards, but only if this
//! visit was the one that inserted it, so a pre-existing entry under the same name is left untouched.
void Visit(const For* op, Expr* expr) override {
  auto* node = expr->As<ir::For>();
  Visit(&node->min, &node->min);
  Visit(&node->extent, &node->extent);

  auto* min_i    = op->min.As<IntImm>();
  auto* extent_i = op->extent.As<IntImm>();

  // NOTE(review): the upper bound is `extent - 1`, which matches the loop range only when min == 0
  // (otherwise `min + extent - 1` would be expected) -- confirm against how For loops are built.
  // `emplace` does not overwrite an existing key; `.second` tells whether this visit added the entry.
  bool inserted = false;
  if (min_i && extent_i && extent_i->value > min_i->value) {
    inserted =
        var_intervals.emplace(op->loop_var->name, common::CasInterval{min_i->value, extent_i->value - 1}).second;
  } else {
    inserted = var_intervals.emplace(op->loop_var->name, common::CasInterval{op->min, op->extent - 1}).second;
  }

  Visit(&node->body, &node->body);

  // Undo the registration on every path that made one, so that neither constant nor symbolic
  // intervals leak into sibling or enclosing scopes through the shared map.
  if (inserted) {
    var_intervals.erase(op->loop_var->name);
  }
}

void Visit(const _Tensor_* op, Expr* expr) override {
Expand Down
4 changes: 4 additions & 0 deletions cinn/optim/vectorize_loops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -407,12 +407,14 @@ struct VectorizeLoops_ : public IRMutator<Expr *> {
if (extent_min || extent_max || !vectorizable_) {
// not vectorize if has tail blocks, for llvm to optimize
node->reset_vectorize_info();
var_intervals.erase(forloop->loop_var->name);
return;
}

auto _new_forloop = SplitForLoop(node, forloop->vectorize_info().factor);
if (!_new_forloop.defined()) {
IRMutator<>::Visit(&node->body, &node->body);
var_intervals.erase(forloop->loop_var->name);
return;
}

Expand All @@ -426,6 +428,7 @@ struct VectorizeLoops_ : public IRMutator<Expr *> {

if (!extent_int) {
IRMutator<>::Visit(&node->body, &node->body);
var_intervals.erase(forloop->loop_var->name);
return;
}

Expand All @@ -445,6 +448,7 @@ struct VectorizeLoops_ : public IRMutator<Expr *> {
} else {
IRMutator::Visit(forloop, expr);
}
var_intervals.erase(forloop->loop_var->name);
}

//! unroll the forloop if its' extent is min type by solving the condition extent
Expand Down

0 comments on commit 53bdc07

Please sign in to comment.