diff --git a/python/tvm/relay/backend/contrib/tidl.py b/python/tvm/relay/backend/contrib/tidl.py
index eb27edc4d522..8d796119b509 100755
--- a/python/tvm/relay/backend/contrib/tidl.py
+++ b/python/tvm/relay/backend/contrib/tidl.py
@@ -379,6 +379,11 @@ def visit_call(self, call):
             return self.call_map[call]
         return super().visit_call(call)
 
+    def visit_tuple_getitem(self, tuplegetitem):
+        if tuplegetitem in self.call_map:
+            return self.call_map[tuplegetitem]
+        return super().visit_tuple_getitem(tuplegetitem)
+
 class VarRenamer(ExprMutator):
     """
     Renames vars to match the new subgraph name. Used when subgraphs are renamed starting from zero.
@@ -571,24 +576,50 @@ def visit_call(self, call):
             counter = SubgraphSizeCounter()
             counter.visit(self.mod[name])
             if counter.num_layers > self.max_num_layers or counter.get_total_memory_mb() > self.max_total_memory_mb:
-                # Mark that we have reduced the subgraph size.
-                self.reduced = True
                 # "Inline" the last op only back into new main function.
                 original_func = self.mod[name]
-                # Get last_op
                 last_op = original_func.body
+                if isinstance(last_op, tvm.relay.expr.Tuple) and len(last_op.fields) > 2:
+                    # Currently can't reduce when there are more than 2 outputs.
+                    args = []
+                    for arg in call.args:
+                        args.append(super().visit(arg))
+                    subgraph_gv = relay.GlobalVar(name)
+                    self.new_mod[subgraph_gv] = self.mod[name]
+                    return subgraph_gv(*args)
+                # Mark that we have reduced the subgraph size.
+                self.reduced = True
                 last_op_args = []
                 if isinstance(last_op, tvm.relay.expr.Tuple):
                     # Subgraph has multiple outputs!
                     ancestor, dist0, dist1 = FindCommonAncestor(last_op)
 
+                    def get_field(field, exclude):
+                        """Get field as it is, unless it is a TupleGetItem which we will remove."""
+                        if isinstance(field, tvm.relay.expr.Call):
+                            return [field]
+                        elif isinstance(field, tvm.relay.expr.TupleGetItem):
+                            args = []
+                            for arg in field.tuple_value.args:
+                                if arg not in exclude:
+                                    args.append(arg)
+                            return args
+                        else:
+                            raise ValueError("New output of subgraph must be Call node.")
+
                     def get_args(field, exclude):
                         """Gather args from field, excluding exclude node"""
                         args = []
-                        assert isinstance(field, tvm.relay.expr.Call)
-                        for arg in field.args:
-                            if arg != exclude:
-                                args.append(arg)
+                        if isinstance(field, tvm.relay.expr.Call):
+                            for arg in field.args:
+                                if arg not in exclude:
+                                    args.append(arg)
+                        elif isinstance(field, tvm.relay.expr.TupleGetItem):
+                            for arg in field.tuple_value.args:
+                                if arg not in exclude:
+                                    args.append(arg)
+                        else:
+                            raise ValueError("New output of subgraph must be Call node.")
                         return args
 
                     # If all fields in tuple are not CallNodes, we will just remove all up to common ancestor.
@@ -596,14 +627,22 @@ def get_args(field, exclude):
                         last_op_args = ancestor.args
                     elif dist0 > dist1:
                         # field[0] is further from LCA, remove it by replacing it with its args.
-                        last_op_args = get_args(last_op.fields[0], exclude=last_op.fields[1]) + [last_op.fields[1]]
+                        from_field_0 = get_args(last_op.fields[0], exclude=[last_op.fields[1]])
+                        from_field_1 = get_field(last_op.fields[1], exclude=from_field_0)
+                        last_op_args = from_field_0 + from_field_1
                     elif dist1 >= dist0:
                         # field[1] is further from LCA, Remove it by replacing it with its args.
-                        last_op_args = [last_op.fields[0]] + get_args(last_op.fields[1], exclude=last_op.fields[0])
+                        from_field_0 = get_field(last_op.fields[0], exclude=[last_op.fields[1]])
+                        from_field_1 = get_args(last_op.fields[1], exclude=from_field_0)
+                        last_op_args = from_field_0 + from_field_1
                 elif isinstance(last_op, tvm.relay.expr.Call):
                     last_op_args = last_op.args
+                elif isinstance(last_op, tvm.relay.expr.TupleGetItem):
+                    last_op_args = []
+                    for arg in last_op.tuple_value.args:
+                        last_op_args.append(arg)
                 else:
-                    raise ValueError("Input to last op is not call or tuple")
+                    raise ValueError("Last op is not Call, Tuple, or TupleGetItem")
                 # Gather new outputs of the subgraph - from removed op's inputs
                 # This map will map Expr to index in new_outputs tuple
                 #print('last_op_args', last_op_args)
@@ -669,11 +708,25 @@ def ReduceSubgraphSize(mod, compiler="tidl", max_num_layers=256, max_total_memor
         new_mod['main'] = reducer.visit(mod["main"])
         # If no subgraphs where reduced in size, we are done.
         if not reducer.reduced:
-            return new_mod
+            break
         mod = new_mod
         # Avoid infinite loop.
        sanity_counter -= 1
-    return mod
+
+    # Fallback: Completely remove all subgraphs still in violation.
+    # SubgraphReducer can only handle subgraphs with 1 or 2 outputs.
+    subgraph_names_to_remove = []
+    for subgraph in mod.get_global_vars():
+        name = subgraph.name_hint
+        if not mod[name].attrs or mod[name].attrs["Compiler"] != compiler:
+            continue
+        counter = SubgraphSizeCounter()
+        counter.visit(mod[name])
+        if counter.num_layers > max_num_layers or counter.get_total_memory_mb() > max_total_memory_mb:
+            subgraph_names_to_remove.append(name)
+    new_mod = tvm.IRModule()
+    new_mod["main"] = SubgraphRemover(subgraph_names_to_remove, mod, new_mod).visit(mod["main"])
+    return new_mod
 
 def PruneSubgraphsWithMoreThanOneInput(mod, compiler="tidl"):
     subgraph_names_to_remove = []
diff --git a/python/tvm/relay/op/contrib/tidl.py b/python/tvm/relay/op/contrib/tidl.py
index 8c5093feec66..11e7b564e34c 100755
--- a/python/tvm/relay/op/contrib/tidl.py
+++ b/python/tvm/relay/op/contrib/tidl.py
@@ -124,20 +124,6 @@ def _dense_bias_pattern():
         bias_out = is_op('nn.bias_add')(dense_out, is_constant())
         return bias_out
 
-    def _bn_tuple_get_item():
-        bn_out = is_op('nn.batch_norm')(wildcard(), is_constant(), is_constant(), is_constant(), is_constant())
-        tuple_get_item_node = is_tuple_get_item(bn_out, 0)
-        return tuple_get_item_node
-
-    def _bn_tuple_get_item_checker(extract):
-        bn_op = extract.tuple_value
-        data1 = infer_type(bn_op.args[1])
-        if data1.checked_type.dtype != 'float32':
-            return False
-        elif bn_op.attrs.axis != 1 and bn_op.attrs.axis != 3:
-            return False
-        return True
-
     pattern_table = [
         ('tidl.squeeze_reshape', _squeeze_reshape_pattern()),
         #TODO: add import of op 'transpose' and uncomment 2 items below
@@ -156,7 +142,6 @@ def _bn_tuple_get_item_checker(extract):
         ('tidl.dense_relu', _dense_relu_pattern()),
         ('tidl.dense_bias_relu', _dense_bias_relu_pattern()),
         ('tidl.dense_bias', _dense_bias_pattern()),
-        ('tidl.bn_tuple_get_item', _bn_tuple_get_item(), _bn_tuple_get_item_checker),
     ]
 
     return relay.transform.MergeComposite(pattern_table)(mod)
@@ -243,10 +228,6 @@ def _conv2d_pad_whitelist_fn(attrs, args):
     supported = pad_supported and conv2d_supported
     return supported
 
-@tvm.ir.register_op_attr("tidl.bn_tuple_get_item", "target.tidl")
-def _bn_tuple_get_item_whitelist_fn(attrs, args):
-    return True
-
 @tvm.ir.register_op_attr("add", "target.tidl")
 def _add_whitelist_fn(attrs, args):
     supported = True
@@ -280,8 +261,12 @@ def _batch_flatten_fn(attrs, args):
@tvm.ir.register_op_attr("nn.batch_norm", "target.tidl") def _batch_norm_whitelist_fn(attrs, args): - # Standalone batch_norm is supported only as a pattern (bn_tuple_get_item). - return False + data1 = infer_type(args[1]) + if data1.checked_type.dtype != 'float32': + return False + elif attrs.axis != 1 and attrs.axis != 3: + return False + return True @tvm.ir.register_op_attr("nn.bias_add", "target.tidl") def _bias_add_whitelist_fn(attrs, args): diff --git a/tests/python/relay/test_tidl_reduce_subgraph_size.py b/tests/python/relay/test_tidl_reduce_subgraph_size.py index bf38763d53ef..5751db5bd98d 100644 --- a/tests/python/relay/test_tidl_reduce_subgraph_size.py +++ b/tests/python/relay/test_tidl_reduce_subgraph_size.py @@ -25,7 +25,7 @@ def test_reduce_subgraph_size_single_output(): def create_graph(): ishape = (1, 3, 12, 12) - x = relay.var('tidl_i0', shape=ishape, dtype='float32') + x = relay.var('tidl_0_i0', shape=ishape, dtype='float32') y = relay.nn.relu(x) out = relay.nn.relu(y) func = relay.Function([x], out) @@ -41,7 +41,7 @@ def create_graph(): def expected(): ishape = (1, 3, 12, 12) - x = relay.var('tidl_i0', shape=ishape, dtype='float32') + x = relay.var('tidl_0_i0', shape=ishape, dtype='float32') out = relay.nn.relu(x) func = relay.Function([x], out) func = set_func_attr(func, "tidl", "tidl_0") @@ -171,9 +171,119 @@ def expected_2(): # Will remove 2nd conv2d. ref_mod = expected_2() reduced = ReduceSubgraphSize(create_graph(), max_num_layers=1, compiler="tidl") - print('reduced', reduced) + assert tvm.ir.structural_equal(reduced, ref_mod, map_free_vars=True) + +def test_reduce_subgraph_size_tuple_get_item(): + def create_graph(): + ishape = (1, 32, 14, 14) + w1shape = (32, ) + dtype = "float32" + data0 = relay.var("tidl_0_i0", shape=(ishape), dtype=dtype) + input0 = relay.var("tidl_0_i1", shape=(w1shape), dtype=dtype) + input1 = relay.var("tidl_0_i2", shape=(w1shape), dtype=dtype) + input2 = relay.var("tidl_0_i3", shape=(w1shape), dtype=dtype) + input3 = relay.var("tidl_0_i4", shape=(w1shape), dtype=dtype) + input4 = relay.var("tidl_0_i5", shape=(w1shape), dtype=dtype) + input5 = relay.var("tidl_0_i6", shape=(w1shape), dtype=dtype) + input6 = relay.var("tidl_0_i7", shape=(w1shape), dtype=dtype) + input7 = relay.var("tidl_0_i8", shape=(w1shape), dtype=dtype) + params = {"tidl_0_i" + str(i): np.ones(w1shape, dtype="float32") for i in range(1, 9)} + r = relay.nn.relu(data0) + batch_norm_0 = relay.nn.batch_norm(r, input0, input1, input2, input3) + batch_norm_1 = relay.nn.batch_norm(r, input4, input5, input6, input7) + tup_get_0 = batch_norm_0[0] + tup_get_1 = batch_norm_1[0] + out = relay.Tuple([tup_get_0, tup_get_1]) + func = relay.Function([data0, input0, input1, input2, input3, input4, input5, input6, input7], out) + func = set_func_attr(func, "tidl", "tidl_0") + func = bind_params_by_name(func, params) + gv = relay.GlobalVar("tidl_0") + + mod = tvm.IRModule() + mod[gv] = func + x_main = relay.var('x', shape=ishape, dtype='float32') + main_f = relay.Function([x_main], gv(x_main)) + mod['main'] = main_f #bind_params_by_name(main_f, params) + return mod + + def expected(): + ishape = (1, 32, 14, 14) + w1shape = (32, ) + dtype = "float32" + data0 = relay.var("tidl_0_i0", shape=(ishape), dtype=dtype) + input0 = relay.var("tidl_0_i1", shape=(w1shape), dtype=dtype) + input1 = relay.var("tidl_0_i2", shape=(w1shape), dtype=dtype) + input2 = relay.var("tidl_0_i3", shape=(w1shape), dtype=dtype) + input3 = relay.var("tidl_0_i4", shape=(w1shape), dtype=dtype) + input4 = 
relay.var("tidl_0_i5", shape=(w1shape), dtype=dtype) + input5 = relay.var("tidl_0_i6", shape=(w1shape), dtype=dtype) + input6 = relay.var("tidl_0_i7", shape=(w1shape), dtype=dtype) + input7 = relay.var("tidl_0_i8", shape=(w1shape), dtype=dtype) + params = {"tidl_0_i" + str(i): np.ones(w1shape, dtype="float32") for i in range(1, 9)} + r = relay.nn.relu(data0) + func = relay.Function([data0], r) + func = set_func_attr(func, "tidl", "tidl_0") + func = bind_params_by_name(func, params) + gv = relay.GlobalVar("tidl_0") + + mod = tvm.IRModule() + mod[gv] = func + x_main = relay.var('x', shape=ishape, dtype='float32') + call = gv(x_main) + batch_norm_0 = relay.nn.batch_norm(call, input0, input1, input2, input3) + batch_norm_1 = relay.nn.batch_norm(call, input4, input5, input6, input7) + tup_get_0 = batch_norm_0[0] + tup_get_1 = batch_norm_1[0] + out = relay.Tuple([tup_get_0, tup_get_1]) + main_f = relay.Function([x_main, input0, input1, input2, input3, input4, input5, input6, input7], out) + mod['main'] = bind_params_by_name(main_f, params) + return mod + + ref_mod = expected() + reduced = ReduceSubgraphSize(create_graph(), max_num_layers=2, compiler="tidl") + assert tvm.ir.structural_equal(reduced, ref_mod, map_free_vars=True) + +def test_reduce_subgraph_size_three_outputs_fallback(): + def create_graph(): + ishape = (1, 32, 14, 14) + dtype = "float32" + data0 = relay.var("tidl_0_i0", shape=(ishape), dtype=dtype) + r = relay.nn.relu(data0) + r0 = relay.nn.relu(r) + r1 = relay.tanh(r) + r2 = relay.sin(r) + out = relay.Tuple([r0, r1, r2]) + func = relay.Function([data0], out) + func = set_func_attr(func, "tidl", "tidl_0") + gv = relay.GlobalVar("tidl_0") + + mod = tvm.IRModule() + mod[gv] = func + x_main = relay.var('x', shape=ishape, dtype='float32') + main_f = relay.Function([x_main], gv(x_main)) + mod['main'] = main_f + return mod + + def expected(): + ishape = (1, 32, 14, 14) + dtype = "float32" + data0 = relay.var("x", shape=(ishape), dtype=dtype) + r = relay.nn.relu(data0) + r0 = relay.nn.relu(r) + r1 = relay.tanh(r) + r2 = relay.sin(r) + out = relay.Tuple([r0, r1, r2]) + + mod = tvm.IRModule() + mod['main'] = relay.Function([data0], out) + return mod + + ref_mod = expected() + reduced = ReduceSubgraphSize(create_graph(), max_num_layers=1, compiler="tidl") assert tvm.ir.structural_equal(reduced, ref_mod, map_free_vars=True) if __name__ == '__main__': - #test_reduce_subgraph_size_single_output() + test_reduce_subgraph_size_single_output() test_reduce_subgraph_size_multiple_output() + test_reduce_subgraph_size_tuple_get_item() + test_reduce_subgraph_size_three_outputs_fallback()