Merge pull request neo-ai#8 from trevor-m/trevmorr-fix-reduce
Remove bn_tuple_get_item pattern, support tuple get item in ReduceSubgraphSize
jianzhong-xu committed Jun 25, 2020
2 parents c249527 + 5a5fd85 commit 9fec6a5
Showing 3 changed files with 185 additions and 37 deletions.
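What the change does: SubgraphReducer can now peel a trailing TupleGetItem (for example the first output of nn.batch_norm) out of an oversized TIDL subgraph, and any subgraph it still cannot shrink (more than two outputs) is removed from offload entirely by the new fallback in ReduceSubgraphSize. Below is a minimal sketch of driving the pass, adapted from the new tests in this diff; the set_func_attr helper, its attribute names, and the shapes used here are assumptions for illustration, not part of the commit.

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay.build_module import bind_params_by_name
    from tvm.relay.backend.contrib.tidl import ReduceSubgraphSize

    def set_func_attr(func, compiler, symbol):
        # Assumed helper (mirrors TVM's BYOC test utilities): mark func as an external subgraph.
        func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
        func = func.with_attr("Compiler", compiler)
        func = func.with_attr("global_symbol", symbol)
        return func

    def make_module():
        # A "tidl_0" subgraph whose last op is a TupleGetItem: nn.batch_norm(...)[0].
        data = relay.var("tidl_0_i0", shape=(1, 8, 4, 4), dtype="float32")
        bn_vars = [relay.var("tidl_0_i%d" % i, shape=(8,), dtype="float32") for i in range(1, 5)]
        body = relay.nn.batch_norm(relay.nn.relu(data), *bn_vars)[0]
        func = set_func_attr(relay.Function([data] + bn_vars, body), "tidl", "tidl_0")
        func = bind_params_by_name(func, {"tidl_0_i%d" % i: np.ones((8,), "float32") for i in range(1, 5)})
        gv = relay.GlobalVar("tidl_0")
        mod = tvm.IRModule()
        mod[gv] = func
        x = relay.var("x", shape=(1, 8, 4, 4), dtype="float32")
        mod["main"] = relay.Function([x], gv(x))
        return mod

    # With a one-layer budget, the batch_norm/TupleGetItem tail should be inlined
    # back into main, leaving only the relu inside the "tidl" subgraph.
    reduced = ReduceSubgraphSize(make_module(), max_num_layers=1, compiler="tidl")
    print(reduced)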
77 changes: 65 additions & 12 deletions python/tvm/relay/backend/contrib/tidl.py
@@ -379,6 +379,11 @@ def visit_call(self, call):
return self.call_map[call]
return super().visit_call(call)

def visit_tuple_getitem(self, tuplegetitem):
if tuplegetitem in self.call_map:
return self.call_map[tuplegetitem]
return super().visit_tuple_getitem(tuplegetitem)

class VarRenamer(ExprMutator):
"""
Renames vars to match the new subgraph name. Used when subgraphs are renamed starting from zero.
@@ -571,39 +576,73 @@ def visit_call(self, call):
counter = SubgraphSizeCounter()
counter.visit(self.mod[name])
if counter.num_layers > self.max_num_layers or counter.get_total_memory_mb() > self.max_total_memory_mb:
# Mark that we have reduced the subgraph size.
self.reduced = True
# "Inline" the last op only back into new main function.
original_func = self.mod[name]
# Get last_op
last_op = original_func.body
if isinstance(last_op, tvm.relay.expr.Tuple) and len(last_op.fields) > 2:
# Currently can't reduce when there are more than 2 outputs.
args = []
for arg in call.args:
args.append(super().visit(arg))
subgraph_gv = relay.GlobalVar(name)
self.new_mod[subgraph_gv] = self.mod[name]
return subgraph_gv(*args)
# Mark that we have reduced the subgraph size.
self.reduced = True
last_op_args = []
if isinstance(last_op, tvm.relay.expr.Tuple):
# Subgraph has multiple outputs!
ancestor, dist0, dist1 = FindCommonAncestor(last_op)

def get_field(field, exclude):
"""Get field as it is, unless it is a TupleGetItem which we will remove."""
if isinstance(field, tvm.relay.expr.Call):
return [field]
elif isinstance(field, tvm.relay.expr.TupleGetItem):
args = []
for arg in field.tuple_value.args:
if arg not in exclude:
args.append(arg)
return args
else:
raise ValueError("New output of subgraph must be Call node.")

def get_args(field, exclude):
"""Gather args from field, excluding exclude node"""
args = []
assert isinstance(field, tvm.relay.expr.Call)
for arg in field.args:
if arg != exclude:
args.append(arg)
if isinstance(field, tvm.relay.expr.Call):
for arg in field.args:
if arg not in exclude:
args.append(arg)
elif isinstance(field, tvm.relay.expr.TupleGetItem):
for arg in field.tuple_value.args:
if arg not in exclude:
args.append(arg)
else:
raise ValueError("New output of subgraph must be Call node.")
return args

# If all fields in tuple are not CallNodes, we will just remove all up to common ancestor.
if (dist0 == 0 and dist1 == 0):
last_op_args = ancestor.args
elif dist0 > dist1:
# field[0] is further from LCA, remove it by replacing it with its args.
last_op_args = get_args(last_op.fields[0], exclude=last_op.fields[1]) + [last_op.fields[1]]
from_field_0 = get_args(last_op.fields[0], exclude=[last_op.fields[1]])
from_field_1 = get_field(last_op.fields[1], exclude=from_field_0)
last_op_args = from_field_0 + from_field_1
elif dist1 >= dist0:
# field[1] is further from LCA, Remove it by replacing it with its args.
last_op_args = [last_op.fields[0]] + get_args(last_op.fields[1], exclude=last_op.fields[0])
from_field_0 = get_field(last_op.fields[0], exclude=[last_op.fields[1]])
from_field_1 = get_args(last_op.fields[1], exclude=from_field_0)
last_op_args = from_field_0 + from_field_1
elif isinstance(last_op, tvm.relay.expr.Call):
last_op_args = last_op.args
elif isinstance(last_op, tvm.relay.expr.TupleGetItem):
last_op_args = []
for arg in last_op.tuple_value.args:
last_op_args.append(arg)
else:
raise ValueError("Input to last op is not call or tuple")
raise ValueError("Last op is not Call, Tuple, or TupleGetItem")
# Gather new outputs of the subgraph - from removed op's inputs
# This map will map Expr to index in new_outputs tuple
#print('last_op_args', last_op_args)
@@ -669,11 +708,25 @@ def ReduceSubgraphSize(mod, compiler="tidl", max_num_layers=256, max_total_memor
new_mod['main'] = reducer.visit(mod["main"])
# If no subgraphs were reduced in size, we are done.
if not reducer.reduced:
return new_mod
break
mod = new_mod
# Avoid infinite loop.
sanity_counter -= 1
return mod

# Fallback: Completely remove all subgraphs still in violation.
# SubgraphReducer can only handle subgraphs with 1 or 2 outputs.
subgraph_names_to_remove = []
for subgraph in mod.get_global_vars():
name = subgraph.name_hint
if not mod[name].attrs or mod[name].attrs["Compiler"] != compiler:
continue
counter = SubgraphSizeCounter()
counter.visit(mod[name])
if counter.num_layers > max_num_layers or counter.get_total_memory_mb() > max_total_memory_mb:
subgraph_names_to_remove.append(name)
new_mod = tvm.IRModule()
new_mod["main"] = SubgraphRemover(subgraph_names_to_remove, mod, new_mod).visit(mod["main"])
return new_mod

def PruneSubgraphsWithMoreThanOneInput(mod, compiler="tidl"):
subgraph_names_to_remove = []
27 changes: 6 additions & 21 deletions python/tvm/relay/op/contrib/tidl.py
@@ -124,20 +124,6 @@ def _dense_bias_pattern():
bias_out = is_op('nn.bias_add')(dense_out, is_constant())
return bias_out

def _bn_tuple_get_item():
bn_out = is_op('nn.batch_norm')(wildcard(), is_constant(), is_constant(), is_constant(), is_constant())
tuple_get_item_node = is_tuple_get_item(bn_out, 0)
return tuple_get_item_node

def _bn_tuple_get_item_checker(extract):
bn_op = extract.tuple_value
data1 = infer_type(bn_op.args[1])
if data1.checked_type.dtype != 'float32':
return False
elif bn_op.attrs.axis != 1 and bn_op.attrs.axis != 3:
return False
return True

pattern_table = [
('tidl.squeeze_reshape', _squeeze_reshape_pattern()),
#TODO: add import of op 'transpose' and uncomment 2 items below
@@ -156,7 +142,6 @@ def _bn_tuple_get_item_checker(extract):
('tidl.dense_relu', _dense_relu_pattern()),
('tidl.dense_bias_relu', _dense_bias_relu_pattern()),
('tidl.dense_bias', _dense_bias_pattern()),
('tidl.bn_tuple_get_item', _bn_tuple_get_item(), _bn_tuple_get_item_checker),
]

return relay.transform.MergeComposite(pattern_table)(mod)
@@ -243,10 +228,6 @@ def _conv2d_pad_whitelist_fn(attrs, args):
supported = pad_supported and conv2d_supported
return supported

@tvm.ir.register_op_attr("tidl.bn_tuple_get_item", "target.tidl")
def _bn_tuple_get_item_whitelist_fn(attrs, args):
return True

@tvm.ir.register_op_attr("add", "target.tidl")
def _add_whitelist_fn(attrs, args):
supported = True
@@ -280,8 +261,12 @@ def _batch_flatten_fn(attrs, args):

@tvm.ir.register_op_attr("nn.batch_norm", "target.tidl")
def _batch_norm_whitelist_fn(attrs, args):
# Standalone batch_norm is supported only as a pattern (bn_tuple_get_item).
return False
data1 = infer_type(args[1])
if data1.checked_type.dtype != 'float32':
return False
elif attrs.axis != 1 and attrs.axis != 3:
return False
return True

@tvm.ir.register_op_attr("nn.bias_add", "target.tidl")
def _bias_add_whitelist_fn(attrs, args):
118 changes: 114 additions & 4 deletions tests/python/relay/test_tidl_reduce_subgraph_size.py
@@ -25,7 +25,7 @@
def test_reduce_subgraph_size_single_output():
def create_graph():
ishape = (1, 3, 12, 12)
x = relay.var('tidl_i0', shape=ishape, dtype='float32')
x = relay.var('tidl_0_i0', shape=ishape, dtype='float32')
y = relay.nn.relu(x)
out = relay.nn.relu(y)
func = relay.Function([x], out)
@@ -41,7 +41,7 @@ def create_graph():

def expected():
ishape = (1, 3, 12, 12)
x = relay.var('tidl_i0', shape=ishape, dtype='float32')
x = relay.var('tidl_0_i0', shape=ishape, dtype='float32')
out = relay.nn.relu(x)
func = relay.Function([x], out)
func = set_func_attr(func, "tidl", "tidl_0")
@@ -171,9 +171,119 @@ def expected_2():
# Will remove 2nd conv2d.
ref_mod = expected_2()
reduced = ReduceSubgraphSize(create_graph(), max_num_layers=1, compiler="tidl")
print('reduced', reduced)
assert tvm.ir.structural_equal(reduced, ref_mod, map_free_vars=True)

def test_reduce_subgraph_size_tuple_get_item():
def create_graph():
ishape = (1, 32, 14, 14)
w1shape = (32, )
dtype = "float32"
data0 = relay.var("tidl_0_i0", shape=(ishape), dtype=dtype)
input0 = relay.var("tidl_0_i1", shape=(w1shape), dtype=dtype)
input1 = relay.var("tidl_0_i2", shape=(w1shape), dtype=dtype)
input2 = relay.var("tidl_0_i3", shape=(w1shape), dtype=dtype)
input3 = relay.var("tidl_0_i4", shape=(w1shape), dtype=dtype)
input4 = relay.var("tidl_0_i5", shape=(w1shape), dtype=dtype)
input5 = relay.var("tidl_0_i6", shape=(w1shape), dtype=dtype)
input6 = relay.var("tidl_0_i7", shape=(w1shape), dtype=dtype)
input7 = relay.var("tidl_0_i8", shape=(w1shape), dtype=dtype)
params = {"tidl_0_i" + str(i): np.ones(w1shape, dtype="float32") for i in range(1, 9)}
r = relay.nn.relu(data0)
batch_norm_0 = relay.nn.batch_norm(r, input0, input1, input2, input3)
batch_norm_1 = relay.nn.batch_norm(r, input4, input5, input6, input7)
tup_get_0 = batch_norm_0[0]
tup_get_1 = batch_norm_1[0]
out = relay.Tuple([tup_get_0, tup_get_1])
func = relay.Function([data0, input0, input1, input2, input3, input4, input5, input6, input7], out)
func = set_func_attr(func, "tidl", "tidl_0")
func = bind_params_by_name(func, params)
gv = relay.GlobalVar("tidl_0")

mod = tvm.IRModule()
mod[gv] = func
x_main = relay.var('x', shape=ishape, dtype='float32')
main_f = relay.Function([x_main], gv(x_main))
mod['main'] = main_f #bind_params_by_name(main_f, params)
return mod

def expected():
ishape = (1, 32, 14, 14)
w1shape = (32, )
dtype = "float32"
data0 = relay.var("tidl_0_i0", shape=(ishape), dtype=dtype)
input0 = relay.var("tidl_0_i1", shape=(w1shape), dtype=dtype)
input1 = relay.var("tidl_0_i2", shape=(w1shape), dtype=dtype)
input2 = relay.var("tidl_0_i3", shape=(w1shape), dtype=dtype)
input3 = relay.var("tidl_0_i4", shape=(w1shape), dtype=dtype)
input4 = relay.var("tidl_0_i5", shape=(w1shape), dtype=dtype)
input5 = relay.var("tidl_0_i6", shape=(w1shape), dtype=dtype)
input6 = relay.var("tidl_0_i7", shape=(w1shape), dtype=dtype)
input7 = relay.var("tidl_0_i8", shape=(w1shape), dtype=dtype)
params = {"tidl_0_i" + str(i): np.ones(w1shape, dtype="float32") for i in range(1, 9)}
r = relay.nn.relu(data0)
func = relay.Function([data0], r)
func = set_func_attr(func, "tidl", "tidl_0")
func = bind_params_by_name(func, params)
gv = relay.GlobalVar("tidl_0")

mod = tvm.IRModule()
mod[gv] = func
x_main = relay.var('x', shape=ishape, dtype='float32')
call = gv(x_main)
batch_norm_0 = relay.nn.batch_norm(call, input0, input1, input2, input3)
batch_norm_1 = relay.nn.batch_norm(call, input4, input5, input6, input7)
tup_get_0 = batch_norm_0[0]
tup_get_1 = batch_norm_1[0]
out = relay.Tuple([tup_get_0, tup_get_1])
main_f = relay.Function([x_main, input0, input1, input2, input3, input4, input5, input6, input7], out)
mod['main'] = bind_params_by_name(main_f, params)
return mod

ref_mod = expected()
reduced = ReduceSubgraphSize(create_graph(), max_num_layers=2, compiler="tidl")
assert tvm.ir.structural_equal(reduced, ref_mod, map_free_vars=True)

def test_reduce_subgraph_size_three_outputs_fallback():
def create_graph():
ishape = (1, 32, 14, 14)
dtype = "float32"
data0 = relay.var("tidl_0_i0", shape=(ishape), dtype=dtype)
r = relay.nn.relu(data0)
r0 = relay.nn.relu(r)
r1 = relay.tanh(r)
r2 = relay.sin(r)
out = relay.Tuple([r0, r1, r2])
func = relay.Function([data0], out)
func = set_func_attr(func, "tidl", "tidl_0")
gv = relay.GlobalVar("tidl_0")

mod = tvm.IRModule()
mod[gv] = func
x_main = relay.var('x', shape=ishape, dtype='float32')
main_f = relay.Function([x_main], gv(x_main))
mod['main'] = main_f
return mod

def expected():
ishape = (1, 32, 14, 14)
dtype = "float32"
data0 = relay.var("x", shape=(ishape), dtype=dtype)
r = relay.nn.relu(data0)
r0 = relay.nn.relu(r)
r1 = relay.tanh(r)
r2 = relay.sin(r)
out = relay.Tuple([r0, r1, r2])

mod = tvm.IRModule()
mod['main'] = relay.Function([data0], out)
return mod

ref_mod = expected()
reduced = ReduceSubgraphSize(create_graph(), max_num_layers=1, compiler="tidl")
assert tvm.ir.structural_equal(reduced, ref_mod, map_free_vars=True)

if __name__ == '__main__':
#test_reduce_subgraph_size_single_output()
test_reduce_subgraph_size_single_output()
test_reduce_subgraph_size_multiple_output()
test_reduce_subgraph_size_tuple_get_item()
test_reduce_subgraph_size_three_outputs_fallback()
