Skip to content

Commit

Permalink
[PASS] Assign unique names to variables in ConvertSSA pass (#18)
Browse files Browse the repository at this point in the history
* [PASS] Assign unique names to variables in ConvertSSA pass

* revert change to ConverSSA pass
  • Loading branch information
icemelon authored and tqchen committed Jan 18, 2017
1 parent 9e1a5ec commit 110c9be
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
18 changes: 9 additions & 9 deletions src/c_api/c_api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,22 @@ using RetValue = APIVariantValue;

TVM_REGISTER_API(_pass_Simplify)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
if (dynamic_cast<Expr::ContainerType*>(args.at(0).sptr.get())) {
*ret = Simplify(args.at(0).operator Expr());
} else {
if (dynamic_cast<Stmt::ContainerType*>(args.at(0).sptr.get())) {
*ret = Simplify(args.at(0).operator Stmt());
} else {
*ret = Simplify(args.at(0).operator Expr());
}
});

TVM_REGISTER_API(_pass_Equal)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
CHECK(args.at(1).type_id == kNodeHandle);
if (dynamic_cast<Expr::ContainerType*>(args.at(0).sptr.get())) {
*ret = Equal(args.at(0).operator Expr(), args.at(1).operator Expr());
} else {
if (dynamic_cast<Stmt::ContainerType*>(args.at(0).sptr.get())) {
CHECK(args.at(1).type_id == kNodeHandle);
*ret = Equal(args.at(0).operator Stmt(), args.at(1).operator Stmt());
} else {
Expr a = args.at(0).operator Expr();
Expr b = args.at(1).operator Expr();
*ret = Equal(a, b);
}
});

Expand Down
8 changes: 6 additions & 2 deletions tests/python/test_pass_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ def test_simplify():
assert(tvm.ir_pass.Equal(e2, x * 8))
e3 = tvm.ir_pass.Simplify(x - x / 3 * 3)
assert(tvm.ir_pass.Equal(e3, tvm.make.Mod(x, 3)))
let = tvm.make.Let(x, 1, x + 3)
e4 = tvm.ir_pass.Simplify(let)
assert(tvm.ir_pass.Equal(e4, 4))


def test_verify_ssa():
Expand All @@ -20,8 +23,9 @@ def test_verify_ssa():
def test_convert_ssa():
x = tvm.Var('x')
y = tvm.Var()
let = tvm.make.Let(x, 1, x + 1)
z = tvm.make.Evaluate(let + let)
let1 = tvm.make.Let(x, 1, x + 1)
let2 = tvm.make.Let(x, 1, x + y)
z = tvm.make.Evaluate(let1 + let2)
assert(not tvm.ir_pass.VerifySSA(z))
z_ssa = tvm.ir_pass.ConvertSSA(z)
assert(tvm.ir_pass.VerifySSA(z_ssa))

0 comments on commit 110c9be

Please sign in to comment.