Skip to content

Commit

Permalink
[Lang] Add TensorType support for Constant Folding (#8250)
Browse files Browse the repository at this point in the history
Issue: #

### Brief Summary

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at b0140a9</samp>

Improve and simplify the constant folding transform for arithmetic
operations. Use helper functions to evaluate and replace constant
expressions with matrices in `taichi/transforms/constant_fold.cpp`.

### Walkthrough

<!--
copilot:walkthrough
-->
### <samp>🤖 Generated by Copilot at b0140a9</samp>

* Refactor `visit` functions for `BinaryOpStmt` and `UnaryOpStmt` to use
separate functions `get_scalar_value_to_replace` that return optional
`TypedConstant` based on operands and operation type
([link](https://github.com/taichi-dev/taichi/pull/8250/files?diff=unified&w=0#diff-82a8161a7771f3ee974357cda46f5684aead5225049e864e89767b078ad58b30L49-R52),
[link](https://github.com/taichi-dev/taichi/pull/8250/files?diff=unified&w=0#diff-82a8161a7771f3ee974357cda46f5684aead5225049e864e89767b078ad58b30L233-R301))
* Handle constant folding for binary and unary operations on matrices by
iterating over scalar values and creating new `MatrixInitStmt` with
evaluated constants
([link](https://github.com/taichi-dev/taichi/pull/8250/files?diff=unified&w=0#diff-82a8161a7771f3ee974357cda46f5684aead5225049e864e89767b078ad58b30L152-R211),
[link](https://github.com/taichi-dev/taichi/pull/8250/files?diff=unified&w=0#diff-82a8161a7771f3ee974357cda46f5684aead5225049e864e89767b078ad58b30L233-R301))
* Simplify `insert_and_erase` function calls by using local variables
`res` and `dst_type` and moving them out of switch statements
([link](https://github.com/taichi-dev/taichi/pull/8250/files?diff=unified&w=0#diff-82a8161a7771f3ee974357cda46f5684aead5225049e864e89767b078ad58b30L60-R58),
[link](https://github.com/taichi-dev/taichi/pull/8250/files?diff=unified&w=0#diff-82a8161a7771f3ee974357cda46f5684aead5225049e864e89767b078ad58b30L68-R90),
[link](https://github.com/taichi-dev/taichi/pull/8250/files?diff=unified&w=0#diff-82a8161a7771f3ee974357cda46f5684aead5225049e864e89767b078ad58b30L118-R122),
[link](https://github.com/taichi-dev/taichi/pull/8250/files?diff=unified&w=0#diff-82a8161a7771f3ee974357cda46f5684aead5225049e864e89767b078ad58b30L209-R242))
* Add new function `insert_and_erase` that takes a vector of
`TypedConstant` and creates a vector of `ConstStmt` and a
`MatrixInitStmt` to replace a statement with a constant matrix
([link](https://github.com/taichi-dev/taichi/pull/8250/files?diff=unified&w=0#diff-82a8161a7771f3ee974357cda46f5684aead5225049e864e89767b078ad58b30R327-R343))
  • Loading branch information
jim19930609 committed Jul 4, 2023
1 parent 560b740 commit a992f22
Showing 1 changed file with 157 additions and 92 deletions.
249 changes: 157 additions & 92 deletions taichi/transforms/constant_fold.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,52 +46,48 @@ class ConstantFold : public BasicStmtVisitor {
return false;
}

void visit(BinaryOpStmt *stmt) override {
auto lhs = stmt->lhs->cast<ConstStmt>();
auto rhs = stmt->rhs->cast<ConstStmt>();
if (!lhs || !rhs)
return;
auto dst_type = stmt->ret_type;
std::optional<TypedConstant> get_scalar_value_to_replace(BinaryOpStmt *stmt,
ConstStmt *lhs,
ConstStmt *rhs,
DataType dst_type) {
TypedConstant new_constant(dst_type);

if (stmt->op_type == BinaryOpType::pow) {
if (is_integral(rhs->ret_type)) {
auto rhs_val = rhs->val.val_int();
if (rhs_val < 0 && is_integral(stmt->ret_type)) {
if (rhs_val < 0 && is_integral(dst_type)) {
TI_ERROR("Negative exponent in pow(int, int) is not allowed.");
}
}
}

// Type check should have been done at this point.
auto dt = lhs->val.dt;

std::optional<TypedConstant> res = std::nullopt;
switch (stmt->op_type) {
#define COMMA ,
#define HANDLE_REAL_AND_INTEGRAL_BINARY(OP_TYPE, PREFIX, OP_CPP) \
case BinaryOpType::OP_TYPE: { \
if (dt->is_primitive(PrimitiveTypeID::f32) || \
dt->is_primitive(PrimitiveTypeID::f64)) { \
auto res = TypedConstant( \
dst_type, PREFIX(lhs->val.val_cast_to_float64() \
OP_CPP rhs->val.val_cast_to_float64())); \
insert_and_erase(stmt, res); \
} else if (dt->is_primitive(PrimitiveTypeID::i32) || \
dt->is_primitive(PrimitiveTypeID::i64)) { \
auto res = TypedConstant( \
dst_type, PREFIX(lhs->val.val_int() OP_CPP rhs->val.val_int())); \
insert_and_erase(stmt, res); \
} else if (dt->is_primitive(PrimitiveTypeID::u32) || \
dt->is_primitive(PrimitiveTypeID::u64)) { \
auto res = TypedConstant( \
dst_type, PREFIX(lhs->val.val_uint() OP_CPP rhs->val.val_uint())); \
insert_and_erase(stmt, res); \
} else if (dt->is_primitive(PrimitiveTypeID::u1)) { \
auto res = TypedConstant( \
dst_type, PREFIX(int32(lhs->val.val_uint1()) \
OP_CPP int32(rhs->val.val_uint1()))); \
insert_and_erase(stmt, res); \
} \
break; \
#define HANDLE_REAL_AND_INTEGRAL_BINARY(OP_TYPE, PREFIX, OP_CPP) \
case BinaryOpType::OP_TYPE: { \
if (dt->is_primitive(PrimitiveTypeID::f32) || \
dt->is_primitive(PrimitiveTypeID::f64)) { \
res = TypedConstant(dst_type, \
PREFIX(lhs->val.val_cast_to_float64() \
OP_CPP rhs->val.val_cast_to_float64())); \
} else if (dt->is_primitive(PrimitiveTypeID::i32) || \
dt->is_primitive(PrimitiveTypeID::i64)) { \
res = TypedConstant( \
dst_type, PREFIX(lhs->val.val_int() OP_CPP rhs->val.val_int())); \
} else if (dt->is_primitive(PrimitiveTypeID::u32) || \
dt->is_primitive(PrimitiveTypeID::u64)) { \
res = TypedConstant( \
dst_type, PREFIX(lhs->val.val_uint() OP_CPP rhs->val.val_uint())); \
} else if (dt->is_primitive(PrimitiveTypeID::u1)) { \
res = TypedConstant(dst_type, \
PREFIX(int32(lhs->val.val_uint1()) \
OP_CPP int32(rhs->val.val_uint1()))); \
} \
break; \
}

HANDLE_REAL_AND_INTEGRAL_BINARY(mul, , *)
Expand All @@ -115,18 +111,15 @@ class ConstantFold : public BasicStmtVisitor {
#define HANDLE_INTEGRAL_BINARY(OP_TYPE, PREFIX, OP_CPP) \
case BinaryOpType::OP_TYPE: { \
if (dt->is_primitive(PrimitiveTypeID::i32)) { \
auto res = TypedConstant( \
res = TypedConstant( \
dst_type, PREFIX(lhs->val.val_int32() OP_CPP rhs->val.val_int32())); \
insert_and_erase(stmt, res); \
} else if (dt->is_primitive(PrimitiveTypeID::i64)) { \
auto res = TypedConstant( \
res = TypedConstant( \
dst_type, PREFIX(lhs->val.val_int() OP_CPP rhs->val.val_int())); \
insert_and_erase(stmt, res); \
} else if (dt->is_primitive(PrimitiveTypeID::u32) || \
dt->is_primitive(PrimitiveTypeID::u64)) { \
auto res = TypedConstant( \
res = TypedConstant( \
dst_type, PREFIX(lhs->val.val_uint() OP_CPP rhs->val.val_uint())); \
insert_and_erase(stmt, res); \
} \
break; \
}
Expand All @@ -149,44 +142,73 @@ class ConstantFold : public BasicStmtVisitor {
default:
break;
}

return res;
}

void visit(UnaryOpStmt *stmt) override {
if (stmt->is_cast() && stmt->cast_type == stmt->operand->ret_type) {
stmt->replace_usages_with(stmt->operand);
modifier.erase(stmt);
return;
void visit(BinaryOpStmt *stmt) override {
auto lhs = stmt->lhs;
auto rhs = stmt->rhs;

if (lhs->is<ConstStmt>() && rhs->is<ConstStmt>()) {
auto typed_constant = get_scalar_value_to_replace(
stmt, lhs->as<ConstStmt>(), rhs->as<ConstStmt>(), stmt->ret_type);
if (!typed_constant)
return;

TypedConstant new_constant = *typed_constant;
insert_and_erase(stmt, new_constant);
} else if (lhs->is<MatrixInitStmt>() && rhs->is<MatrixInitStmt>()) {
int num_values = rhs->as<MatrixInitStmt>()->values.size();

std::vector<TypedConstant> typed_constants;
for (int i = 0; i < num_values; i++) {
auto scalar_lhs =
lhs->as<MatrixInitStmt>()->values[i]->cast<ConstStmt>();
auto scalar_rhs =
rhs->as<MatrixInitStmt>()->values[i]->cast<ConstStmt>();
if (!scalar_lhs || !scalar_rhs)
return;

auto typed_constant = get_scalar_value_to_replace(
stmt, scalar_lhs, scalar_rhs, stmt->ret_type.get_element_type());
if (!typed_constant)
return;

TypedConstant new_constant = *typed_constant;
typed_constants.push_back(new_constant);
}
insert_and_erase(stmt, typed_constants);
}
auto operand = stmt->operand->cast<ConstStmt>();
if (!operand)
return;
}

std::optional<TypedConstant> get_scalar_value_to_replace(UnaryOpStmt *stmt,
ConstStmt *operand,
DataType dst_type) {
if (stmt->is_cast() && stmt->op_type == UnaryOpType::cast_bits) {
TypedConstant new_constant(stmt->ret_type);
TypedConstant new_constant(dst_type);
new_constant.value_bits = operand->val.value_bits;
insert_and_erase(stmt, new_constant);
return;
return new_constant;
}
const auto dt = operand->val.dt;
if (!is_good_type(dt))
return;
const auto dst_type = stmt->ret_type;
return std::nullopt;

std::optional<TypedConstant> res = std::nullopt;
switch (stmt->op_type) {
#define HANDLE_REAL_AND_INTEGRAL_UNARY(OP_TYPE, OP_CPP) \
case UnaryOpType::OP_TYPE: { \
if (dt->is_primitive(PrimitiveTypeID::f32) || \
dt->is_primitive(PrimitiveTypeID::f64)) { \
auto res = TypedConstant(dst_type, OP_CPP(operand->val.val_float())); \
insert_and_erase(stmt, res); \
} else if (dt->is_primitive(PrimitiveTypeID::i32) || \
dt->is_primitive(PrimitiveTypeID::i64)) { \
auto res = TypedConstant(dst_type, OP_CPP(operand->val.val_int())); \
insert_and_erase(stmt, res); \
} else if (dt->is_primitive(PrimitiveTypeID::u32) || \
dt->is_primitive(PrimitiveTypeID::u64)) { \
auto res = TypedConstant(dst_type, OP_CPP(operand->val.val_uint())); \
insert_and_erase(stmt, res); \
} \
break; \
#define HANDLE_REAL_AND_INTEGRAL_UNARY(OP_TYPE, OP_CPP) \
case UnaryOpType::OP_TYPE: { \
if (dt->is_primitive(PrimitiveTypeID::f32) || \
dt->is_primitive(PrimitiveTypeID::f64)) { \
res = TypedConstant(dst_type, OP_CPP(operand->val.val_float())); \
} else if (dt->is_primitive(PrimitiveTypeID::i32) || \
dt->is_primitive(PrimitiveTypeID::i64)) { \
res = TypedConstant(dst_type, OP_CPP(operand->val.val_int())); \
} else if (dt->is_primitive(PrimitiveTypeID::u32) || \
dt->is_primitive(PrimitiveTypeID::u64)) { \
res = TypedConstant(dst_type, OP_CPP(operand->val.val_uint())); \
} \
break; \
}

HANDLE_REAL_AND_INTEGRAL_UNARY(neg, -)
Expand All @@ -206,21 +228,18 @@ class ConstantFold : public BasicStmtVisitor {
HANDLE_REAL_AND_INTEGRAL_UNARY(rsqrt, 1.0 / std::sqrt)
#undef HANDLE_REAL_AND_INTEGRAL_UNARY

#define HANDLE_INTEGRAL_UNARY(OP_TYPE, OP_CPP) \
case UnaryOpType::OP_TYPE: { \
if (dt->is_primitive(PrimitiveTypeID::i32) || \
dt->is_primitive(PrimitiveTypeID::i64)) { \
auto res = TypedConstant(dst_type, OP_CPP(operand->val.val_int())); \
insert_and_erase(stmt, res); \
} else if (dt->is_primitive(PrimitiveTypeID::u32) || \
dt->is_primitive(PrimitiveTypeID::u64)) { \
auto res = TypedConstant(dst_type, OP_CPP(operand->val.val_uint())); \
insert_and_erase(stmt, res); \
} else if (dt->is_primitive(PrimitiveTypeID::u1)) { \
auto res = TypedConstant(dst_type, !operand->val.val_uint1()); \
insert_and_erase(stmt, res); \
} \
break; \
#define HANDLE_INTEGRAL_UNARY(OP_TYPE, OP_CPP) \
case UnaryOpType::OP_TYPE: { \
if (dt->is_primitive(PrimitiveTypeID::i32) || \
dt->is_primitive(PrimitiveTypeID::i64)) { \
res = TypedConstant(dst_type, OP_CPP(operand->val.val_int())); \
} else if (dt->is_primitive(PrimitiveTypeID::u32) || \
dt->is_primitive(PrimitiveTypeID::u64)) { \
res = TypedConstant(dst_type, OP_CPP(operand->val.val_uint())); \
} else if (dt->is_primitive(PrimitiveTypeID::u1)) { \
res = TypedConstant(dst_type, !operand->val.val_uint1()); \
} \
break; \
}

HANDLE_INTEGRAL_UNARY(bit_not, ~)
Expand All @@ -230,27 +249,56 @@ class ConstantFold : public BasicStmtVisitor {
case UnaryOpType::cast_value: {
if (dt->is_primitive(PrimitiveTypeID::f32) ||
dt->is_primitive(PrimitiveTypeID::f64)) {
auto res = TypedConstant(dst_type, operand->val.val_float());
insert_and_erase(stmt, res);
res = TypedConstant(dst_type, operand->val.val_float());
} else if (dt->is_primitive(PrimitiveTypeID::i32) ||
dt->is_primitive(PrimitiveTypeID::i64)) {
auto res = TypedConstant(dst_type, operand->val.val_int());
insert_and_erase(stmt, res);
res = TypedConstant(dst_type, operand->val.val_int());
} else if (dt->is_primitive(PrimitiveTypeID::u32) ||
dt->is_primitive(PrimitiveTypeID::u64)) {
auto res = TypedConstant(dst_type, operand->val.val_uint());
insert_and_erase(stmt, res);
res = TypedConstant(dst_type, operand->val.val_uint());
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
auto res = TypedConstant(dst_type, operand->val.val_uint1());
insert_and_erase(stmt, res);
res = TypedConstant(dst_type, operand->val.val_uint1());
}
break;
}
default:
return;
return std::nullopt;
}
return res;
}

void visit(UnaryOpStmt *stmt) override {
if (stmt->is_cast() && stmt->cast_type == stmt->operand->ret_type) {
stmt->replace_usages_with(stmt->operand);
modifier.erase(stmt);
return;
}

return;
if (auto operand = stmt->operand->cast<ConstStmt>()) {
auto typed_constant =
get_scalar_value_to_replace(stmt, operand, stmt->ret_type);
if (!typed_constant)
return;

TypedConstant new_constant = *typed_constant;
insert_and_erase(stmt, new_constant);
} else if (auto operand = stmt->operand->cast<MatrixInitStmt>()) {
std::vector<TypedConstant> typed_constants;
for (auto &scalar_operand : operand->values) {
auto const_scalar_operand = scalar_operand->cast<ConstStmt>();
if (!const_scalar_operand)
return;

auto typed_constant = get_scalar_value_to_replace(
stmt, const_scalar_operand, stmt->ret_type.get_element_type());
if (!typed_constant)
return;

TypedConstant new_constant = *typed_constant;
typed_constants.push_back(new_constant);
}
insert_and_erase(stmt, typed_constants);
}
}

static bool run(IRNode *node) {
Expand All @@ -276,6 +324,23 @@ class ConstantFold : public BasicStmtVisitor {
modifier.insert_before(stmt, std::move(evaluated));
modifier.erase(stmt);
}

void insert_and_erase(Stmt *stmt,
const std::vector<TypedConstant> &new_constants) {
std::vector<Stmt *> values;
for (auto &new_constant : new_constants) {
auto const_stmt = Stmt::make<ConstStmt>(new_constant);
values.push_back(const_stmt.get());
modifier.insert_before(stmt, std::move(const_stmt));
}

auto evaluated = Stmt::make<MatrixInitStmt>(values);
evaluated->ret_type = stmt->ret_type;

stmt->replace_usages_with(evaluated.get());
modifier.insert_before(stmt, std::move(evaluated));
modifier.erase(stmt);
}
};

const PassID ConstantFoldPass::id = "ConstantFoldPass";
Expand Down

0 comments on commit a992f22

Please sign in to comment.