Skip to content

Commit

Permalink
A few fixes for task extraction (apache#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Feb 15, 2022
1 parent 8e3df1b commit 0d40561
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 26 deletions.
12 changes: 9 additions & 3 deletions python/tvm/auto_scheduler/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,9 @@ def auto_schedule_topi(func_name, outs):
"""

# pylint: disable=import-outside-toplevel
from tvm.auto_scheduler.measure import (
from tvm.auto_scheduler.measure import ( # lazily import to avoid recursive dependency
prepare_input_map,
) # lazily import to avoid recursive dependency
)

io_tensors, has_layout_free, has_complex_op = traverse_to_get_io_tensors(outs)
if not io_tensors: # The compute includes dynamic shapes which are not supported yet.
Expand Down Expand Up @@ -482,4 +482,10 @@ def is_auto_scheduler_enabled():
enabled: bool
Whether the auto-scheduler is enabled
"""
return PassContext.current().config.get("relay.backend.use_auto_scheduler", False)
return PassContext.current().config.get(
"relay.backend.use_auto_scheduler",
False,
) or PassContext.current().config.get(
"relay.backend.use_meta_schedule",
False,
)
49 changes: 26 additions & 23 deletions python/tvm/meta_schedule/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@

import logging
import os.path
from typing import Callable, Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import tvm
from tvm._ffi.registry import register_func
from tvm.ir import IRModule, structural_equal, structural_hash
from tvm.ir import IRModule, structural_hash
from tvm.relay import Function as RelayFunc
from tvm.relay import build as relay_build
from tvm.relay.backend.executor_factory import ExecutorFactoryModule
Expand Down Expand Up @@ -602,21 +602,24 @@ def tune_te(
)


def deduplicate_tune_contexts(tune_contexts: List[TuneContext]) -> List[TuneContext]:
results: List[TuneContext] = []
hashs: List[int] = []
for i, task in enumerate(tune_contexts):
struct_hash: int = structural_hash(task.mod)
flag: bool = False
if struct_hash in hashs:
for other_task in tune_contexts[i + 1 :]:
if structural_equal(task.mod, other_task.mod):
flag = True
break
if not flag:
results.append(task)
hashs.append(struct_hash)
return results
def deduplicate_extracted_tasks(
extracted_tasks: List[ExtractedTask],
) -> Tuple[List[ExtractedTask], List[int]]:
hash2idx: Dict[int, int] = {}
dedup: List[ExtractedTask] = []
count: List[int] = []

for task in extracted_tasks:
assert len(task.dispatched) == 1, "Only size 1 dispatched task list is supported for now"
mod = Parse._mod(task.dispatched[0])
hash = structural_hash(mod)
if hash in hash2idx:
count[hash2idx[hash]] += 1
else:
hash2idx[hash] = len(dedup)
dedup.append(task)
count.append(1)
return dedup, count


def tune_extracted_tasks(
Expand All @@ -637,11 +640,14 @@ def tune_extracted_tasks(
mutator_probs: Optional[FnMutatorProb] = None,
num_threads: Optional[int] = None,
) -> Database:
# deduplication
logger.info(f"Before task deduplication: {len(extracted_tasks)} tasks")
extracted_tasks, _ = deduplicate_extracted_tasks(extracted_tasks)
logger.info(f"After task deduplication: {len(extracted_tasks)} tasks")
# pylint: disable=protected-access
tune_contexts = []
target = Parse._target(target)
database = Parse._database(database, "default", work_dir)
# parse the tuning contexts
tune_contexts = []
for task in extracted_tasks:
assert len(task.dispatched) == 1, "Only size 1 dispatched task list is supported for now"
tune_contexts.append(
Expand All @@ -658,11 +664,8 @@ def tune_extracted_tasks(
num_threads=num_threads,
)
)
# deduplication
logger.info(f"Before task deduplication: {len(tune_contexts)} tasks")
tune_contexts = deduplicate_tune_contexts(tune_contexts)
logger.info(f"After task deduplication: {len(tune_contexts)} tasks")
# parse the task scheduler
database = Parse._database(database, "default", work_dir)
task_scheduler = Parse._task_scheduler(
task_scheduler,
tune_contexts,
Expand Down

0 comments on commit 0d40561

Please sign in to comment.