diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index 8890394637b36..550ea8220303b 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -133,6 +133,16 @@ class BasicBlockSimplify : public IRVisitor { } } + void visit(UnaryOpStmt *stmt) override { + if (stmt->op_type == UnaryOpType::abs) { + auto operand_type = stmt->operand->ret_type; + if (is_integral(operand_type) && is_unsigned(operand_type)) { + // abs(u) -> u + stmt->replace_usages_with(stmt->operand); + modifier.erase(stmt); + } + } + } template static bool identical_vectors(const std::vector &a, const std::vector &b) { diff --git a/tests/python/test_abs.py b/tests/python/test_abs.py index 8ff4bc35aeddd..ab86b392507f6 100644 --- a/tests/python/test_abs.py +++ b/tests/python/test_abs.py @@ -78,3 +78,13 @@ def foo(x: ti.i64) -> ti.i64: for x in [-(2**40), 0, 2**40]: assert foo(x) == abs(x) + + +@test_utils.test() +def test_abs_u32(): + @ti.kernel + def foo(x: ti.u32) -> ti.u32: + return abs(x) + + for x in [0, 2**20]: + assert foo(x) == abs(x)