Skip to content

Commit

Permalink
Use a callback in the exact matcher (apache#18)
Browse files Browse the repository at this point in the history
* Don't use checked_type for annotating types in the matcher (breaks tests and relies on type checking)

* Add callback feature to exact matcher, expand tests, simplify recent additions
  • Loading branch information
slyubomirsky authored and gussmith23 committed Dec 29, 2021
1 parent 7c7ca4c commit dae31ce
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 64 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from ..transform import gradient
from .exact_matcher import annotate_exact_matches
# these are just for testing
from .exact_matcher import deduplicate_vars, check_compiler_call
from .exact_matcher import deduplicate_vars, check_compiler_call, check_annotations

from .op_summary import count_all_ops, count_all_overloads, count_all_ops_in_overloads

Expand Down
77 changes: 25 additions & 52 deletions python/tvm/relay/testing/exact_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,19 @@
from tvm.relay.analysis import free_vars, bound_vars

# dumb copy of what src/relay/transforms/de_duplicate.cc is doing
def deduplicate_vars(expr, var_map={}, use_original=False):
def deduplicate_vars(expr):
"""
Given the expr, replace all vars in the expression with fresh ones.
This is done to preserve well-formedness in Relay (all var definitions must be unique)
"""
class Deduplicator(ExprMutator):
def __init__(self):
super().__init__()
self.var_map = var_map.copy()
self.var_map = {}

def visit_var(self, var):
if var in self.var_map:
return self.var_map[var]
if use_original:
return var
fresh_var = relay.Var(var.name_hint)
self.var_map[var] = fresh_var
return fresh_var
Expand All @@ -48,20 +46,6 @@ def visit_match(self, match):
for c in match.clauses]
return relay.Match(new_val, clauses)

def visit_function(self, func):
args = list(map(self.visit, func.params))
body = self.visit(func.body)
return relay.Function(args, body, func.ret_type, func.type_params)

def visit_let(self, let_expr):
new_var = self.visit(let_expr.var)
new_body = self.visit(let_expr.body)
new_value = self.visit(let_expr.value)
# if isinstance(let_expr.value, relay.Function):
# print(let_expr.var, let_expr.value)
# print(new_value)
return relay.Let(new_var, new_value, new_body)

dedup = Deduplicator()
return dedup.visit(expr)

Expand Down Expand Up @@ -284,13 +268,16 @@ def visit_ref_write(self, ref_write):


class MatchMutator(ExprMutator):
def __init__(self, target, compiler_name, composite_name, composite_counter=0):
def __init__(self, target, compiler_name, composite_name, composite_counter=0, callback=None):
"""
Target: Expression that the matcher is seeking
Compiler name: Name for the custom codegen
Composite name: Name for the *construct produced* in the custom codegen
Composite counter: Id number used for generating compiler IDs
(they must be globally unique)
Callback: Function (expr -> bool) that checks properties of the matched expr.
Register a match only if callback(expr) is True.
Default behavior is always to return True.
Free vars in the target expression will be arguments to the extracted function
"""
Expand All @@ -302,20 +289,7 @@ def __init__(self, target, compiler_name, composite_name, composite_counter=0):
self.compiler_name = compiler_name
self.composite_name = composite_name
self.composite_counter = composite_counter

def to_key(self, relay_var):
def hash_type(relay_type):
{
tvm.ir.type.PrimType: lambda: relay_type.dtype,
tvm.ir.type.PointerType: lambda: hash_type(relay_type.element_type),
tvm.ir.type.TypeVar: lambda: (relay_type.name_hint, relay_type.kind),
tvm.ir.type.GlobalTypeVar: lambda: (relay_type.name_hint, relay_type.kind),
tvm.ir.type.TupleType: lambda: tuple(map(hash_type, relay_type.fields)),
tvm.ir.type.FuncType: lambda: (tuple(map(hash_type, relay_type.arg_types)), hash_type(relay_type.ret_type)),
tvm.ir.type.IncompleteType:lambda: relay_type.kind,
tvm.ir.type.RelayRefType: lambda: hash_type(relay_type.value),
}.get(type(relay_type), lambda: None)()
return (relay_var.name_hint, hash_type(relay_var.type_annotation))
self.callback = (lambda expr: True) if callback is None else callback

def extract_target(self, match_args):
"""
Expand All @@ -335,30 +309,24 @@ def extract_target(self, match_args):
})(a1, ..., an)
})(match_args[0], ..., match_args[n-1])
"""
# print(f'======={self.composite_counter}=======')
assert all(map(lambda v: self.to_key(v) in match_args, self.target_vars))
match_ordering = [match_args[v] for v in map(self.to_key, self.target_vars)]
match_ordering = [match_args[v] for v in self.target_vars]

# we have to deduplicate vars for Relay's well-formedness check
# (all var definitions must be unique)
inner_body = deduplicate_vars(self.target)
inner_free_vars = free_vars(inner_body)
# inner_args = list(map(lambda v: relay.Var(v.name_hint, match_args[self.to_key(v)].checked_type), inner_free_vars))
inner_args_map = {}
for var in inner_free_vars:
inner_args_map[var] = relay.Var(var.name_hint + str(self.composite_counter), match_args[self.to_key(var)].checked_type)
inner_body_rewritten = deduplicate_vars(inner_body, var_map=inner_args_map, use_original=True)
inner_func = relay.Function(list(map(inner_args_map.get, inner_free_vars)), inner_body_rewritten)
inner_args = free_vars(inner_body)
inner_func = relay.Function(inner_args, inner_body)

inner_func = inner_func.with_attr("Composite", self.composite_name)
outer_args = [relay.Var(f"outer_arg_{i}") for i in range(len(inner_free_vars))]
outer_args = [relay.Var(f"outer_arg_{i}") for i in range(len(inner_args))]
outer_func = relay.Function(outer_args, inner_func(*outer_args))
outer_func = outer_func.with_attr("Compiler", self.compiler_name)
outer_func = outer_func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
outer_func = outer_func.with_attr(
"global_symbol",
f"{self.composite_name}_{self.composite_counter}")
self.composite_counter += 1
return outer_func(*match_ordering)
return outer_func(*match_ordering)

def visit(self, expr):
"""
Expand All @@ -372,14 +340,15 @@ def visit(self, expr):
return expr

found_match, match_args = check_match(self.target, expr)
if found_match:
# only permit the match if the callback is true
if found_match and self.callback(expr):
# need to check for matches in the match args too
final_args = {self.to_key(var): self.visit(arg) for var, arg in match_args.items()}
final_args = {var: self.visit(arg) for var, arg in match_args.items()}
return self.extract_target(final_args)
return super().visit(expr)


def annotate_exact_matches(expr, target, compiler_name, composite_name):
def annotate_exact_matches(expr, target, compiler_name, composite_name, callback=None):
"""
Given an expression and a target pattern,
this will replace all instances of the target pattern
Expand All @@ -404,7 +373,7 @@ def annotate_exact_matches(expr, target, compiler_name, composite_name):
This nested function structure is designed to make it easier for BYOC codegens to match those definitions.
"""
mut = MatchMutator(target, compiler_name, composite_name)
mut = MatchMutator(target, compiler_name, composite_name, callback=callback)
return mut.visit(expr)


Expand All @@ -420,16 +389,20 @@ def call_func_with_attr(expr, func_attr):
return func_attr in expr.op.attrs


def check_annotations(expr):
if not call_func_with_attr(expr, "Compiler"):
return False
return call_func_with_attr(expr.op.body, "Composite")


def check_compiler_call(expr, expected_body):
"""
Provided for testing purposes: Checks if the given expression is a
matcher-produced compiler function with the given body
"""
# check for a compiler function with an inner composite
if not call_func_with_attr(expr, "Compiler"):
if not check_annotations(expr):
return False
inner_call = expr.op.body
if not call_func_with_attr(inner_call, "Composite"):
return False
inner_body = inner_call.op.body
return tvm.ir.structural_equal(inner_body, expected_body, True)
Loading

0 comments on commit dae31ce

Please sign in to comment.