From c54f510e44c70d98384997066da0fb8a9008d8d0 Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Thu, 26 Aug 2021 18:06:23 +0100 Subject: [PATCH] [Pattern matching] Add an option to rewrite the graph only once (#8843) * [Pattern matching] Add an option to rewrite the graph only once If the graph returned from the callback consists of the original pattern, the rewriter will run in the loop, which is not always desired. So this patch proposes an option to run the rewriter only once. Change-Id: I85cf0a055b8961d52394f21c1e4d7aad0a7e1d06 * Make rewrite_once default to false Change-Id: Idf6f01f254c403158883681e75c2a5978efbd2d0 --- include/tvm/relay/dataflow_matcher.h | 6 +- python/tvm/relay/dataflow_pattern/__init__.py | 17 +++- src/relay/ir/dataflow_matcher.cc | 11 ++- tests/python/relay/test_dataflow_pattern.py | 98 +++++++------------ 4 files changed, 58 insertions(+), 74 deletions(-) diff --git a/include/tvm/relay/dataflow_matcher.h b/include/tvm/relay/dataflow_matcher.h index 12e4e3f45fef..10e461645c8b 100644 --- a/include/tvm/relay/dataflow_matcher.h +++ b/include/tvm/relay/dataflow_matcher.h @@ -47,10 +47,13 @@ class DFPatternCallbackNode : public Object { PackedFunc function; /*! \brief Require InferType to be run before the callback */ bool require_type; + /*! \brief Run the callback only once */ + bool rewrite_once; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pattern", &pattern); v->Visit("require_type", &require_type); + v->Visit("rewrite_once", &rewrite_once); } static constexpr const char* _type_key = "DFPatternCallbackNode"; @@ -63,7 +66,8 @@ class DFPatternCallbackNode : public Object { */ class DFPatternCallback : public ObjectRef { public: - TVM_DLL DFPatternCallback(DFPattern pattern, PackedFunc callback, bool require_type); + TVM_DLL DFPatternCallback(DFPattern pattern, PackedFunc callback, bool require_type, + bool rewrite_once = false); TVM_DEFINE_OBJECT_REF_METHODS(DFPatternCallback, ObjectRef, DFPatternCallbackNode); }; diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index 320a599d5d91..1f6d8bb9ab0b 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -796,11 +796,14 @@ class DFPatternCallback: ---------- require_type: bool Whether InferType is required to be run before the callback. + rewrite_once: bool + If True, run the callback only once. """ - def __init__(self, require_type=False): + def __init__(self, require_type=False, rewrite_once=False): self.pattern = None self.require_type = require_type + self.rewrite_once = rewrite_once def rewrite(self, expr: Expr) -> Expr: """ @@ -842,8 +845,10 @@ def callback(self, pre: Expr, post: Expr, node_map: tvm.ir.container.Map) -> Exp class _DFPatternCallback(Object): """C++ implemenation""" - def __init__(self, pattern, callback, require_type): - self.__init_handle_by_constructor__(ffi.DFPatternCallback, pattern, callback, require_type) + def __init__(self, pattern, callback, require_type, rewrite_once): + self.__init_handle_by_constructor__( + ffi.DFPatternCallback, pattern, callback, require_type, rewrite_once + ) def rewrite(callbacks, expr: Expr, mod: Optional[_ir.IRModule] = None) -> Expr: @@ -870,7 +875,11 @@ def rewrite(callbacks, expr: Expr, mod: Optional[_ir.IRModule] = None) -> Expr: tmp = [] for callback in callbacks: assert callback.pattern is not None - tmp.append(_DFPatternCallback(callback.pattern, callback.callback, callback.require_type)) + tmp.append( + _DFPatternCallback( + callback.pattern, callback.callback, callback.require_type, callback.rewrite_once + ) + ) return ffi.rewrite(tmp, expr, mod) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index d7f130f2796d..851a498377b2 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -752,19 +752,22 @@ bool PatternGrouper::EmbedConst(const Expr& expr, const DFPattern pattern) { // Rewrite -DFPatternCallback::DFPatternCallback(DFPattern pattern, PackedFunc function, bool require_type) { +DFPatternCallback::DFPatternCallback(DFPattern pattern, PackedFunc function, bool require_type, + bool rewrite_once) { ObjectPtr n = make_object(); n->pattern = std::move(pattern); n->function = std::move(function); n->require_type = require_type; + n->rewrite_once = rewrite_once; data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode); TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DFPatternCallback") - .set_body_typed([](DFPattern pattern, PackedFunc function, bool require_type) { - return DFPatternCallback(pattern, function, require_type); + .set_body_typed([](DFPattern pattern, PackedFunc function, bool require_type, + bool rewrite_once) { + return DFPatternCallback(pattern, function, require_type, rewrite_once); }); Expr PatternRewriter::Rewrite(const Array& callbacks, const Expr& pre) { @@ -790,7 +793,7 @@ Expr PatternRewriter::Rewrite(const Array& callbacks, const E count++; } equal = (*structural_equal)(last, post, false, true); - } while (!equal && count < 100); + } while (!equal && count < 100 && !callback_->rewrite_once); if (count >= 100) { LOG(FATAL) << "Observed 100 rewrite passes, possible conflicting passes?"; } diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 1c721f40d129..74e03f6a9755 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -1727,69 +1727,37 @@ def test_partition_constant_embedding(): assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc)) +def test_rewrite_once(): + # This class recursively removes the arguments to concat until there is nothing left to concatenate. + class ConcatRewriter(DFPatternCallback): + def __init__(self, rewrite_once): + super().__init__(rewrite_once=rewrite_once) + self.pattern = is_op("concatenate")(None) + + def callback(self, pre, post, node_map): + concat_args = post.args[0] + # Remove the last argument + new_args = [concat_args[i] for i in range(len(concat_args) - 1)] + if new_args: + return relay.op.concatenate(relay.expr.Tuple(new_args), axis=0) + else: + return concat_args + + x = relay.var("x") + y = relay.var("y") + z = relay.var("z") + concat = relay.op.concatenate(relay.expr.Tuple([x, y, z]), axis=0) + + # Let the rewriter run recursively + out = rewrite(ConcatRewriter(False), concat) + expected = relay.expr.Tuple([x]) + assert tvm.ir.structural_equal(out, expected) + + # Run the rewriter once + out = rewrite(ConcatRewriter(True), concat) + expected = relay.op.concatenate(relay.expr.Tuple([x, y]), axis=0) + assert tvm.ir.structural_equal(out, expected) + + if __name__ == "__main__": - test_expr_pattern() - test_var_pattern() - test_constant_pattern() - test_wildcard_pattern() - test_CallPattern() - test_TuplePattern() - test_TupleGetItemPattern() - test_AltPattern() - test_TypePattern() - test_DataTypePattern() - test_ShapePattern() - test_AttrPattern() - test_match_op() - test_no_match_op() - test_match_op_or() - test_match_call_commutive() - test_no_match_call_commutive() - test_match_call() - test_no_match_call() - test_match_option() - test_no_match_option() - test_match_const() - test_match_tuple() - test_no_match_tuple() - test_match_type() - test_no_match_type() - test_match_dtype() - test_no_match_dtype() - test_match_shape() - test_no_match_shape() - test_match_op_attr() - test_no_match_op_attr() - test_match_func_attr() - test_no_match_func_attr() - test_match_call_attr() - test_no_match_call_attr() - test_match_diamond() - test_no_match_diamond() - test_match_fake_diamond() - test_match_dominator() - test_not_match_dominator() - test_rewrite() - test_rewrite_func() - test_nested_rewrite() - test_not_fuse_multi_diamond() - test_fuse_batchnorm() - test_no_fuse_batchnorm() - test_fuse_double_batchnorm() - test_partial_fuse_double_batchnorm() - test_fuse_batchnorm_commutation() - test_quadruple_rewrite_dominator() - test_algebraic_simplify() - test_double_partition() - test_partition_dominator() - test_quadruple_partition_dominator() - test_partition_batchnorm() - test_partition_double_batchnorm() - test_partition_check() - test_partition_check_types() - test_partition_option() - test_match_match() - test_partition_constant_embedding() - test_IfPattern() - test_match_if() - test_no_match_if() + pytest.main([__file__])