From 187aeb5fe813c5bffc81e88ba890399ce7d53e1e Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Mon, 31 Jan 2022 19:26:46 -0800 Subject: [PATCH] [AUTOTVM] Use opt level 3 when extracting tasks (#10065) * [AUTOTVM] Use opt level 3 when extracting tasks Autotvm was implicitly ignoring opt_level when extracting tasks because pass opt_level is a thread local variable and extraction happens in a new thread. Not having opt_level 3 causes alter op layout to not fire, which in turn prevents tuning from finding all possible kernels. * disable alter op layout --- python/tvm/autotvm/task/relay_integration.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index d2861cfe8ad6..04ca333a5ea8 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -34,7 +34,7 @@ # TODO(moreau89) find a more elegant way to lower for VTAs -def _lower(mod, target, params): +def _lower(mod, target, params, opt_level=3): """Helper to lower VTA properly.""" # pylint: disable=import-outside-toplevel from tvm import relay @@ -43,16 +43,19 @@ def _lower(mod, target, params): if hasattr(target, "device_name") and target.device_name == "vta": import vta - with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): + with vta.build_config(opt_level=opt_level, disabled_pass={"AlterOpLayout"}): mod, _ = relay.optimize(mod, target, params) grc = graph_executor_codegen.GraphExecutorCodegen(None, target) grc.codegen(mod, mod["main"]) return - compiler = relay.vm.VMCompiler() - if params: - compiler.set_params(params) - compiler.lower(mod, target=target) + # Alter op layout code has been written expecting that tuning is applied + # without it, so we disable AlterOpLayout to maintain that behavior. + with tvm.transform.PassContext(opt_level=opt_level, disabled_pass={"AlterOpLayout"}): + compiler = relay.vm.VMCompiler() + if params: + compiler.set_params(params) + compiler.lower(mod, target=target) def extract_from_program(mod, params, target, target_host=None, ops=None):