Skip to content

Commit

Permalink
[Pattern matching] Add an option to rewrite the graph only once (#8843)
Browse files Browse the repository at this point in the history
* [Pattern matching] Add an option to rewrite the graph only once

If the graph returned from the callback consists of the original
pattern, the rewriter will run in the loop, which is not always desired.
So this patch proposes an option to run the rewriter only once.

Change-Id: I85cf0a055b8961d52394f21c1e4d7aad0a7e1d06

* Make rewrite_once default to false

Change-Id: Idf6f01f254c403158883681e75c2a5978efbd2d0
  • Loading branch information
ekalda committed Aug 26, 2021
1 parent 3f777d5 commit d263c6d
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 74 deletions.
6 changes: 5 additions & 1 deletion include/tvm/relay/dataflow_matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,13 @@ class DFPatternCallbackNode : public Object {
PackedFunc function;
/*! \brief Require InferType to be run before the callback */
bool require_type;
/*! \brief Run the callback only once */
bool rewrite_once;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pattern", &pattern);
v->Visit("require_type", &require_type);
v->Visit("rewrite_once", &rewrite_once);
}

static constexpr const char* _type_key = "DFPatternCallbackNode";
Expand All @@ -63,7 +66,8 @@ class DFPatternCallbackNode : public Object {
*/
class DFPatternCallback : public ObjectRef {
public:
TVM_DLL DFPatternCallback(DFPattern pattern, PackedFunc callback, bool require_type);
TVM_DLL DFPatternCallback(DFPattern pattern, PackedFunc callback, bool require_type,
bool rewrite_once = false);
TVM_DEFINE_OBJECT_REF_METHODS(DFPatternCallback, ObjectRef, DFPatternCallbackNode);
};

Expand Down
17 changes: 13 additions & 4 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,11 +796,14 @@ class DFPatternCallback:
----------
require_type: bool
Whether InferType is required to be run before the callback.
rewrite_once: bool
If True, run the callback only once.
"""

def __init__(self, require_type=False):
def __init__(self, require_type=False, rewrite_once=False):
self.pattern = None
self.require_type = require_type
self.rewrite_once = rewrite_once

def rewrite(self, expr: Expr) -> Expr:
"""
Expand Down Expand Up @@ -842,8 +845,10 @@ def callback(self, pre: Expr, post: Expr, node_map: tvm.ir.container.Map) -> Exp
class _DFPatternCallback(Object):
"""C++ implemenation"""

def __init__(self, pattern, callback, require_type):
self.__init_handle_by_constructor__(ffi.DFPatternCallback, pattern, callback, require_type)
def __init__(self, pattern, callback, require_type, rewrite_once):
self.__init_handle_by_constructor__(
ffi.DFPatternCallback, pattern, callback, require_type, rewrite_once
)


def rewrite(callbacks, expr: Expr, mod: Optional[_ir.IRModule] = None) -> Expr:
Expand All @@ -870,7 +875,11 @@ def rewrite(callbacks, expr: Expr, mod: Optional[_ir.IRModule] = None) -> Expr:
tmp = []
for callback in callbacks:
assert callback.pattern is not None
tmp.append(_DFPatternCallback(callback.pattern, callback.callback, callback.require_type))
tmp.append(
_DFPatternCallback(
callback.pattern, callback.callback, callback.require_type, callback.rewrite_once
)
)

return ffi.rewrite(tmp, expr, mod)

Expand Down
11 changes: 7 additions & 4 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -752,19 +752,22 @@ bool PatternGrouper::EmbedConst(const Expr& expr, const DFPattern pattern) {

// Rewrite

DFPatternCallback::DFPatternCallback(DFPattern pattern, PackedFunc function, bool require_type) {
DFPatternCallback::DFPatternCallback(DFPattern pattern, PackedFunc function, bool require_type,
bool rewrite_once) {
ObjectPtr<DFPatternCallbackNode> n = make_object<DFPatternCallbackNode>();
n->pattern = std::move(pattern);
n->function = std::move(function);
n->require_type = require_type;
n->rewrite_once = rewrite_once;
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DFPatternCallback")
.set_body_typed([](DFPattern pattern, PackedFunc function, bool require_type) {
return DFPatternCallback(pattern, function, require_type);
.set_body_typed([](DFPattern pattern, PackedFunc function, bool require_type,
bool rewrite_once) {
return DFPatternCallback(pattern, function, require_type, rewrite_once);
});

Expr PatternRewriter::Rewrite(const Array<DFPatternCallback>& callbacks, const Expr& pre) {
Expand All @@ -790,7 +793,7 @@ Expr PatternRewriter::Rewrite(const Array<DFPatternCallback>& callbacks, const E
count++;
}
equal = (*structural_equal)(last, post, false, true);
} while (!equal && count < 100);
} while (!equal && count < 100 && !callback_->rewrite_once);
if (count >= 100) {
LOG(FATAL) << "Observed 100 rewrite passes, possible conflicting passes?";
}
Expand Down
98 changes: 33 additions & 65 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1727,69 +1727,37 @@ def test_partition_constant_embedding():
assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))


def test_rewrite_once():
# This class recursively removes the arguments to concat until there is nothing left to concatenate.
class ConcatRewriter(DFPatternCallback):
def __init__(self, rewrite_once):
super().__init__(rewrite_once=rewrite_once)
self.pattern = is_op("concatenate")(None)

def callback(self, pre, post, node_map):
concat_args = post.args[0]
# Remove the last argument
new_args = [concat_args[i] for i in range(len(concat_args) - 1)]
if new_args:
return relay.op.concatenate(relay.expr.Tuple(new_args), axis=0)
else:
return concat_args

x = relay.var("x")
y = relay.var("y")
z = relay.var("z")
concat = relay.op.concatenate(relay.expr.Tuple([x, y, z]), axis=0)

# Let the rewriter run recursively
out = rewrite(ConcatRewriter(False), concat)
expected = relay.expr.Tuple([x])
assert tvm.ir.structural_equal(out, expected)

# Run the rewriter once
out = rewrite(ConcatRewriter(True), concat)
expected = relay.op.concatenate(relay.expr.Tuple([x, y]), axis=0)
assert tvm.ir.structural_equal(out, expected)


if __name__ == "__main__":
test_expr_pattern()
test_var_pattern()
test_constant_pattern()
test_wildcard_pattern()
test_CallPattern()
test_TuplePattern()
test_TupleGetItemPattern()
test_AltPattern()
test_TypePattern()
test_DataTypePattern()
test_ShapePattern()
test_AttrPattern()
test_match_op()
test_no_match_op()
test_match_op_or()
test_match_call_commutive()
test_no_match_call_commutive()
test_match_call()
test_no_match_call()
test_match_option()
test_no_match_option()
test_match_const()
test_match_tuple()
test_no_match_tuple()
test_match_type()
test_no_match_type()
test_match_dtype()
test_no_match_dtype()
test_match_shape()
test_no_match_shape()
test_match_op_attr()
test_no_match_op_attr()
test_match_func_attr()
test_no_match_func_attr()
test_match_call_attr()
test_no_match_call_attr()
test_match_diamond()
test_no_match_diamond()
test_match_fake_diamond()
test_match_dominator()
test_not_match_dominator()
test_rewrite()
test_rewrite_func()
test_nested_rewrite()
test_not_fuse_multi_diamond()
test_fuse_batchnorm()
test_no_fuse_batchnorm()
test_fuse_double_batchnorm()
test_partial_fuse_double_batchnorm()
test_fuse_batchnorm_commutation()
test_quadruple_rewrite_dominator()
test_algebraic_simplify()
test_double_partition()
test_partition_dominator()
test_quadruple_partition_dominator()
test_partition_batchnorm()
test_partition_double_batchnorm()
test_partition_check()
test_partition_check_types()
test_partition_option()
test_match_match()
test_partition_constant_embedding()
test_IfPattern()
test_match_if()
test_no_match_if()
pytest.main([__file__])

0 comments on commit d263c6d

Please sign in to comment.