From dae31ce160a65ce3ff86c5c8c91e1ab57741aade Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 29 Oct 2021 00:35:45 -0700 Subject: [PATCH] Use a callback in the exact matcher (#18) * Don't use checked_type for annotating types in the matcher (breaks tests and relies on type checking) * Add callback feature to exact matcher, expand tests, simplify recent additions --- python/tvm/relay/testing/__init__.py | 2 +- python/tvm/relay/testing/exact_matcher.py | 77 ++++-------- tests/python/relay/test_exact_matcher.py | 141 ++++++++++++++++++++-- 3 files changed, 156 insertions(+), 64 deletions(-) diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index f64f96b36900c..36a82d046ce4f 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -48,7 +48,7 @@ from ..transform import gradient from .exact_matcher import annotate_exact_matches # these are just for testing -from .exact_matcher import deduplicate_vars, check_compiler_call +from .exact_matcher import deduplicate_vars, check_compiler_call, check_annotations from .op_summary import count_all_ops, count_all_overloads, count_all_ops_in_overloads diff --git a/python/tvm/relay/testing/exact_matcher.py b/python/tvm/relay/testing/exact_matcher.py index 72c8ce7279fe9..7e79940c7a237 100644 --- a/python/tvm/relay/testing/exact_matcher.py +++ b/python/tvm/relay/testing/exact_matcher.py @@ -9,7 +9,7 @@ from tvm.relay.analysis import free_vars, bound_vars # dumb copy of what src/relay/transforms/de_duplicate.cc is doing -def deduplicate_vars(expr, var_map={}, use_original=False): +def deduplicate_vars(expr): """ Given the expr, replace all vars in the expression with fresh ones. This is done to preserve well-formedness in Relay (all var definitions must be unique) @@ -17,13 +17,11 @@ def deduplicate_vars(expr, var_map={}, use_original=False): class Deduplicator(ExprMutator): def __init__(self): super().__init__() - self.var_map = var_map.copy() + self.var_map = {} def visit_var(self, var): if var in self.var_map: return self.var_map[var] - if use_original: - return var fresh_var = relay.Var(var.name_hint) self.var_map[var] = fresh_var return fresh_var @@ -48,20 +46,6 @@ def visit_match(self, match): for c in match.clauses] return relay.Match(new_val, clauses) - def visit_function(self, func): - args = list(map(self.visit, func.params)) - body = self.visit(func.body) - return relay.Function(args, body, func.ret_type, func.type_params) - - def visit_let(self, let_expr): - new_var = self.visit(let_expr.var) - new_body = self.visit(let_expr.body) - new_value = self.visit(let_expr.value) - # if isinstance(let_expr.value, relay.Function): - # print(let_expr.var, let_expr.value) - # print(new_value) - return relay.Let(new_var, new_value, new_body) - dedup = Deduplicator() return dedup.visit(expr) @@ -284,13 +268,16 @@ def visit_ref_write(self, ref_write): class MatchMutator(ExprMutator): - def __init__(self, target, compiler_name, composite_name, composite_counter=0): + def __init__(self, target, compiler_name, composite_name, composite_counter=0, callback=None): """ Target: Expression that the matcher is seeking Compiler name: Name for the custom codegen Composite name: Name for the *construct produced* in the custom codegen Composite counter: Id number used for generating compiler IDs (they must be globally unique) + Callback: Function (expr -> bool) that checks properties of the matched expr. + Register a match only if callback(expr) is True. + Default behavior is always to return True. Free vars in the target expression will be arguments to the extracted function """ @@ -302,20 +289,7 @@ def __init__(self, target, compiler_name, composite_name, composite_counter=0): self.compiler_name = compiler_name self.composite_name = composite_name self.composite_counter = composite_counter - - def to_key(self, relay_var): - def hash_type(relay_type): - { - tvm.ir.type.PrimType: lambda: relay_type.dtype, - tvm.ir.type.PointerType: lambda: hash_type(relay_type.element_type), - tvm.ir.type.TypeVar: lambda: (relay_type.name_hint, relay_type.kind), - tvm.ir.type.GlobalTypeVar: lambda: (relay_type.name_hint, relay_type.kind), - tvm.ir.type.TupleType: lambda: tuple(map(hash_type, relay_type.fields)), - tvm.ir.type.FuncType: lambda: (tuple(map(hash_type, relay_type.arg_types)), hash_type(relay_type.ret_type)), - tvm.ir.type.IncompleteType:lambda: relay_type.kind, - tvm.ir.type.RelayRefType: lambda: hash_type(relay_type.value), - }.get(type(relay_type), lambda: None)() - return (relay_var.name_hint, hash_type(relay_var.type_annotation)) + self.callback = (lambda expr: True) if callback is None else callback def extract_target(self, match_args): """ @@ -335,22 +309,16 @@ def extract_target(self, match_args): })(a1, ..., an) })(match_args[0], ..., match_args[n-1]) """ - # print(f'======={self.composite_counter}=======') - assert all(map(lambda v: self.to_key(v) in match_args, self.target_vars)) - match_ordering = [match_args[v] for v in map(self.to_key, self.target_vars)] + match_ordering = [match_args[v] for v in self.target_vars] # we have to deduplicate vars for Relay's well-formedness check # (all var definitions must be unique) inner_body = deduplicate_vars(self.target) - inner_free_vars = free_vars(inner_body) - # inner_args = list(map(lambda v: relay.Var(v.name_hint, match_args[self.to_key(v)].checked_type), inner_free_vars)) - inner_args_map = {} - for var in inner_free_vars: - inner_args_map[var] = relay.Var(var.name_hint + str(self.composite_counter), match_args[self.to_key(var)].checked_type) - inner_body_rewritten = deduplicate_vars(inner_body, var_map=inner_args_map, use_original=True) - inner_func = relay.Function(list(map(inner_args_map.get, inner_free_vars)), inner_body_rewritten) + inner_args = free_vars(inner_body) + inner_func = relay.Function(inner_args, inner_body) + inner_func = inner_func.with_attr("Composite", self.composite_name) - outer_args = [relay.Var(f"outer_arg_{i}") for i in range(len(inner_free_vars))] + outer_args = [relay.Var(f"outer_arg_{i}") for i in range(len(inner_args))] outer_func = relay.Function(outer_args, inner_func(*outer_args)) outer_func = outer_func.with_attr("Compiler", self.compiler_name) outer_func = outer_func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) @@ -358,7 +326,7 @@ def extract_target(self, match_args): "global_symbol", f"{self.composite_name}_{self.composite_counter}") self.composite_counter += 1 - return outer_func(*match_ordering) + return outer_func(*match_ordering) def visit(self, expr): """ @@ -372,14 +340,15 @@ def visit(self, expr): return expr found_match, match_args = check_match(self.target, expr) - if found_match: + # only permit the match if the callback is true + if found_match and self.callback(expr): # need to check for matches in the match args too - final_args = {self.to_key(var): self.visit(arg) for var, arg in match_args.items()} + final_args = {var: self.visit(arg) for var, arg in match_args.items()} return self.extract_target(final_args) return super().visit(expr) -def annotate_exact_matches(expr, target, compiler_name, composite_name): +def annotate_exact_matches(expr, target, compiler_name, composite_name, callback=None): """ Given an expression and a target pattern, this will replace all instances of the target pattern @@ -404,7 +373,7 @@ def annotate_exact_matches(expr, target, compiler_name, composite_name): This nested function structure is designed to make it easier for BYOC codegens to match those definitions. """ - mut = MatchMutator(target, compiler_name, composite_name) + mut = MatchMutator(target, compiler_name, composite_name, callback=callback) return mut.visit(expr) @@ -420,16 +389,20 @@ def call_func_with_attr(expr, func_attr): return func_attr in expr.op.attrs +def check_annotations(expr): + if not call_func_with_attr(expr, "Compiler"): + return False + return call_func_with_attr(expr.op.body, "Composite") + + def check_compiler_call(expr, expected_body): """ Provided for testing purposes: Checks if the given expression is a matcher-produced compiler function with the given body """ # check for a compiler function with an inner composite - if not call_func_with_attr(expr, "Compiler"): + if not check_annotations(expr): return False inner_call = expr.op.body - if not call_func_with_attr(inner_call, "Composite"): - return False inner_body = inner_call.op.body return tvm.ir.structural_equal(inner_body, expected_body, True) diff --git a/tests/python/relay/test_exact_matcher.py b/tests/python/relay/test_exact_matcher.py index fbed4a700773e..a83970a5b1f90 100644 --- a/tests/python/relay/test_exact_matcher.py +++ b/tests/python/relay/test_exact_matcher.py @@ -1,32 +1,32 @@ import tvm from tvm import relay -from tvm.relay.testing import annotate_exact_matches, deduplicate_vars, check_compiler_call +from tvm.relay.testing import annotate_exact_matches, deduplicate_vars, check_annotations, check_compiler_call -def assert_simple_cases(pattern, compiler_name, pattern_name): +def assert_simple_cases(pattern, compiler_name, pattern_name, callback=None): fresh_pattern = deduplicate_vars(pattern) - self_match = annotate_exact_matches(fresh_pattern, pattern, compiler_name, pattern_name) + self_match = annotate_exact_matches(fresh_pattern, pattern, compiler_name, pattern_name, callback=callback) assert check_compiler_call(self_match, pattern) a = relay.Var("a") plus = fresh_pattern + a - plus_match = annotate_exact_matches(plus, pattern, compiler_name, pattern_name) + plus_match = annotate_exact_matches(plus, pattern, compiler_name, pattern_name, callback=callback) assert isinstance(plus_match, relay.Call) assert plus_match.op.name == "add" assert plus_match.args[1] == a assert check_compiler_call(plus_match.args[0], pattern) in_func = relay.Function([], fresh_pattern) - in_func_match = annotate_exact_matches(in_func, pattern, compiler_name, pattern_name) + in_func_match = annotate_exact_matches(in_func, pattern, compiler_name, pattern_name, callback=callback) assert isinstance(in_func_match, relay.Function) assert len(in_func_match.params) == 0 assert check_compiler_call(in_func_match.body, pattern) b = relay.Var("b") let = relay.Let(b, fresh_pattern, fresh_pattern + b) - let_match = annotate_exact_matches(let, pattern, compiler_name, pattern_name) + let_match = annotate_exact_matches(let, pattern, compiler_name, pattern_name, callback=callback) assert isinstance(let_match, relay.Let) assert check_compiler_call(let_match.value, pattern) assert isinstance(let_match.body, relay.Call) @@ -35,7 +35,7 @@ def assert_simple_cases(pattern, compiler_name, pattern_name): x, y, z = relay.Var("x"), relay.Var("y"), relay.Var("z") call = relay.Function([x, y, z], (x + y) * z)(a, fresh_pattern, b) - call_match = annotate_exact_matches(call, pattern, compiler_name, pattern_name) + call_match = annotate_exact_matches(call, pattern, compiler_name, pattern_name, callback=callback) assert isinstance(call_match, relay.Call) assert tvm.ir.structural_equal(call_match.op, call.op, True) assert len(call_match.args) == 3 @@ -45,7 +45,7 @@ def assert_simple_cases(pattern, compiler_name, pattern_name): x, y = relay.Var("x"), relay.Var("y") tup = relay.Tuple([x, fresh_pattern, y]) - tup_match = annotate_exact_matches(tup, pattern, compiler_name, pattern_name) + tup_match = annotate_exact_matches(tup, pattern, compiler_name, pattern_name, callback=callback) assert isinstance(tup_match, relay.Tuple) assert isinstance(tup_match.fields[0], relay.Var) assert isinstance(tup_match.fields[2], relay.Var) @@ -59,7 +59,7 @@ def assert_simple_cases(pattern, compiler_name, pattern_name): relay.PatternVar(z), relay.PatternVar(w) ]), fresh_pattern) ]) - match_clause_match = annotate_exact_matches(match_clause, pattern, compiler_name, pattern_name) + match_clause_match = annotate_exact_matches(match_clause, pattern, compiler_name, pattern_name, callback=callback) assert isinstance(match_clause_match, relay.Match) assert len(match_clause_match.clauses) == 3 assert isinstance(match_clause_match.clauses[0].lhs, relay.PatternWildcard) @@ -75,11 +75,57 @@ def assert_simple_cases(pattern, compiler_name, pattern_name): assert check_compiler_call(match_clause_match.clauses[2].rhs, pattern) ref = relay.RefCreate(fresh_pattern) - ref_match = annotate_exact_matches(ref, pattern, compiler_name, pattern_name) + ref_match = annotate_exact_matches(ref, pattern, compiler_name, pattern_name, callback=callback) assert isinstance(ref_match, relay.RefCreate) assert check_compiler_call(ref_match.value, pattern) +def assert_simple_cases_fail(target, pattern, compiler_name, pattern_name, callback=None): + # if you pass a target that the pattern does not match, the nested cases should fail too + basic_match = annotate_exact_matches(target, pattern, compiler_name, pattern_name, callback=callback) + assert not check_compiler_call(basic_match, pattern) + + a = relay.Var("a") + + plus = target + a + plus_match = annotate_exact_matches(plus, pattern, compiler_name, pattern_name, callback=callback) + assert tvm.ir.structural_equal(plus, plus_match, True) + + in_func = relay.Function([], target) + in_func_match = annotate_exact_matches(in_func, pattern, compiler_name, pattern_name, callback=callback) + assert tvm.ir.structural_equal(in_func, in_func_match, True) + + b = relay.Var("b") + let = relay.Let(b, target, target + b) + let_match = annotate_exact_matches(let, pattern, compiler_name, pattern_name, callback=callback) + assert tvm.ir.structural_equal(let, let_match, True) + + x, y, z = relay.Var("x"), relay.Var("y"), relay.Var("z") + call = relay.Function([x, y, z], (x + y) * z)(a, target, b) + call_match = annotate_exact_matches(call, pattern, compiler_name, pattern_name, callback=callback) + assert tvm.ir.structural_equal(call, call_match, True) + + x, y = relay.Var("x"), relay.Var("y") + tup = relay.Tuple([x, target, y]) + tup_match = annotate_exact_matches(tup, pattern, compiler_name, pattern_name, callback=callback) + assert tvm.ir.structural_equal(tup, tup_match, True) + + x, y, z, w = relay.Var("x"), relay.Var("y"), relay.Var("z"), relay.Var("w") + match_clause = relay.Match(x, [ + relay.Clause(relay.PatternWildcard(), target), + relay.Clause(relay.PatternVar(y), y), + relay.Clause(relay.PatternTuple([ + relay.PatternVar(z), relay.PatternVar(w) + ]), target) + ]) + match_clause_match = annotate_exact_matches(match_clause, pattern, compiler_name, pattern_name, callback=callback) + assert tvm.ir.structural_equal(match_clause, match_clause_match, True) + + ref = relay.RefCreate(target) + ref_match = annotate_exact_matches(ref, pattern, compiler_name, pattern_name, callback=callback) + assert tvm.ir.structural_equal(ref, ref_match, True) + + def test_match_misses(): pattern = relay.nn.dense(relay.Var("v"), relay.Var("w")) x, y, z, a = relay.Var("x"), relay.Var("y"), relay.Var("z"), relay.Var("a") @@ -93,7 +139,7 @@ def test_match_misses(): for prog in progs: new_prog = annotate_exact_matches(prog, pattern, "MyCompiler", "Dense") assert tvm.ir.structural_equal(prog, new_prog), (prog, new_prog) - + assert_simple_cases_fail(prog, pattern, "MyCompiler", "Dense") def test_operator_simple_match(): pattern = relay.nn.dense(relay.Var("v"), relay.Var("w")) @@ -326,6 +372,77 @@ def make_linear(d, w, b): assert not check_compiler_call(lin_match.args[2], pattern) +def test_operator_callback(): + def non_grouped(conv2d): + assert isinstance(conv2d, relay.Call) + if "groups" not in conv2d.attrs.keys(): + return True + return conv2d.attrs.groups == 1 + + def grouped(conv2d): + assert isinstance(conv2d, relay.Call) + return "groups" in conv2d.attrs.keys() and conv2d.attrs.groups > 1 + + x, w = relay.Var("x"), relay.Var("w") + pattern = relay.nn.conv2d(x, w) + + y, z = relay.Var("y"), relay.Var("z") + ungrouped_conv = relay.nn.conv2d(y, z) + assert_simple_cases(pattern, "MyCompiler", "ungrouped_conv", callback=non_grouped) + assert_simple_cases_fail(ungrouped_conv, pattern, "MyCompiler", "grouped_conv", callback=grouped) + + grouped_conv = relay.nn.conv2d(y, z, groups=2) + assert_simple_cases(grouped_conv, "MyCompiler", "ungrouped_conv", callback=grouped) + assert_simple_cases_fail(grouped_conv, pattern, "MyCompiler", "ungrouped_conv", callback=non_grouped) + + +def test_linear_layer_case(): + # full-scale case that had problems before + def linear_definition(batch_size, in_features, out_features, num_call=0): + weight = relay.var(f'weight_{num_call}', relay.TensorType((out_features, in_features), 'float32')) + bias = relay.var(f'bias_{num_call}', relay.TensorType((out_features, ), 'float32')) + inp = relay.var(f'input', relay.TensorType((batch_size, in_features))) + return relay.Function([inp], relay.nn.bias_add(relay.nn.dense(inp, weight), bias)) + + batch_size = 8 + in_features = 32 + hidden_dim_1 = 64 + hidden_dim_2 = 8 + + img = relay.var('img', relay.TensorType((batch_size, in_features), 'float32')) + fc1 = linear_definition(batch_size, in_features, hidden_dim_1, 0)(img) + fc2 = linear_definition(batch_size, hidden_dim_1, hidden_dim_2, 1)(fc1) + result = relay.nn.softmax(fc2, axis=-1) + mod = tvm.IRModule.from_expr(result) + mod = relay.transform.InferType()(mod) + + linear_pattern = linear_definition(batch_size, in_features, hidden_dim_1).body + main_mut = annotate_exact_matches(mod["main"], linear_pattern, 'ilaflex', 'ilaflex.linear') + mod_mut = tvm.IRModule.from_expr(main_mut) + mod_mut = relay.transform.InferType()(mod_mut) + final_main = mod_mut["main"] + + # structure should be nn.softmax( + # call(func literal with pattern matches, + # call(func literal with pattern matches, args), + # other args)) + final_body = final_main.body + assert isinstance(final_body, relay.Call) + assert final_body.op.name == "nn.softmax" + second_call = final_body.args[0] + assert isinstance(second_call, relay.Call) + + # check_compiler_call uses structural equality, + # which will reject based on type annotations, + # so this is doing a simpler check + # (if you want to reject based on shape, use a callback) + assert check_annotations(second_call.op.body) + first_call = second_call.args[0] + assert isinstance(first_call, relay.Call) + assert check_annotations(first_call.op.body) + + + if __name__ == "__main__": test_match_misses() test_operator_simple_match() @@ -342,3 +459,5 @@ def make_linear(d, w, b): test_inconsistent_match() test_ref_match() test_multiple_matches() + test_operator_callback() + test_linear_layer_case()