Skip to content

Commit

Permalink
Fix match case in Python-side expr functor (apache#4037)
Browse files Browse the repository at this point in the history
  • Loading branch information
weberlo authored and wweic committed Oct 18, 2019
1 parent f499fd5 commit 15b5752
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
5 changes: 4 additions & 1 deletion python/tvm/relay/expr_functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,10 @@ def visit_constructor(self, con):
return con

def visit_match(self, m):
return Match(self.visit(m.data), [Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses])
return Match(
self.visit(m.data),
[Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses],
complete=m.complete)

def visit_ref_create(self, r):
return RefCreate(self.visit(r.value))
Expand Down
11 changes: 11 additions & 0 deletions tests/python/relay/test_expr_functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,16 @@ def test_match():
p = relay.prelude.Prelude()
check_visit(p.mod[p.map])


def test_match_completeness():
p = relay.prelude.Prelude()
for completeness in [True, False]:
match_expr = relay.adt.Match(p.nil, [], complete=completeness)
result_expr = ExprMutator().visit(match_expr)
# ensure the mutator doesn't mangle the completeness flag
assert result_expr.complete == completeness


if __name__ == "__main__":
test_constant()
test_tuple()
Expand All @@ -139,3 +149,4 @@ def test_match():
test_ref_write()
test_memo()
test_match()
test_match_completeness()

0 comments on commit 15b5752

Please sign in to comment.