Skip to content

Commit

Permalink
Minor fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
zxybazh committed Jan 7, 2022
1 parent 81cd0b1 commit 03a6e9a
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 113 deletions.
2 changes: 0 additions & 2 deletions include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,6 @@ Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,

PrimExpr NormalizeIterMapToExpr(const IterMapExpr& expr);

PrimExpr NormalizeIterMapToExpr(const IterMapExpr& expr);

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_ITER_AFFINE_MAP_H_
3 changes: 1 addition & 2 deletions tests/python/unittest/test_meta_schedule_search_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,5 +269,4 @@ def predict(


if __name__ == "__main__":
test_meta_schedule_evolutionary_search()
# sys.exit(pytest.main([__file__] + sys.argv[1:]))
sys.exit(pytest.main([__file__] + sys.argv[1:]))
109 changes: 0 additions & 109 deletions tests/python/unittest/test_meta_schedule_tune_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,115 +212,6 @@ def _postproc():
)


@pytest.mark.skip("Integeration test")
def test_tune_matmul_cuda_tensor_core():
def f_tune_context(mod, target, config, task_name):
return TuneContext(
mod=mod,
target=target,
space_generator=PostOrderApply(),
search_strategy=config.create_strategy(),
sch_rules=[
schedule_rule.AutoInline(
into_producer=False,
into_consumer=True,
into_cache_only=False,
inline_const_tensor=True,
disallow_if_then_else=False,
require_injective=False,
require_ordered=False,
disallow_op=None,
),
schedule_rule.MultiLevelTiling(
structure="SSSRRSRS",
tile_binds=["blockIdx.x", "blockIdx.y", "threadIdx.y"],
use_tensor_core=True,
max_innermost_factor=64,
vector_load_max_len=4,
reuse_read=schedule_rule.ReuseType(
req="must",
levels=[4],
scope="shared",
),
reuse_write=schedule_rule.ReuseType(
req="no",
levels=[],
scope="",
),
),
schedule_rule.AutoInline(
into_producer=True,
into_consumer=True,
into_cache_only=True,
inline_const_tensor=True,
disallow_if_then_else=False,
require_injective=False,
require_ordered=False,
disallow_op=None,
),
schedule_rule.ParallelizeVectorizeUnroll(
max_jobs_per_core=-1, # disable parallelize
max_vectorize_extent=-1, # disable vectorize
unroll_max_steps=[0, 16, 64, 512, 1024],
unroll_explicit=True,
),
],
postprocs=[
postproc.RewriteCooperativeFetch(),
# postproc.RewriteUnboundBlock(),
postproc.RewriteParallelVectorizeUnroll(),
postproc.RewriteReductionBlock(),
postproc.RewriteTensorCore(),
postproc.VerifyGPUCode(),
],
mutators=[],
task_name=task_name,
rand_state=-1,
num_threads=None,
)

n = 4096
mod = create_prim_func(te_workload.matmul_fp16(n, n, n))
target = Target("nvidia/geforce-rtx-3070")
config = ReplayTraceConfig(
num_trials_per_iter=32,
num_trials_total=320,
)

sch: Schedule = tune_tir(mod=mod, target=target, config=config, f_tune_context=f_tune_context)
if sch is None:
print("No valid schedule found!")
else:
print(sch.mod.script())
print(sch.trace)

from tvm.contrib import nvcc
import numpy as np

ctx = tvm.gpu(0)
if nvcc.have_tensorcore(ctx.compute_version):
with tvm.transform.PassContext():
func = tvm.build(sch.mod["main"], [], "cuda")
print(sch.mod.script())
print(func.imported_modules[0].get_source())
a_np = np.random.uniform(size=(n, n)).astype("float16")
b_np = np.random.uniform(size=(n, n)).astype("float16")
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros((n, n), dtype="float32"), ctx)
evaluator = func.time_evaluator(
func.entry_name, ctx, number=3, repeat=1, min_repeat_ms=40
)
print("matmul with tensor core: %f ms" % (evaluator(a, b, c).mean * 1e3))

np.testing.assert_allclose(
c.asnumpy(),
np.matmul(a_np.astype("float32"), b_np.astype("float32")),
rtol=1e-4,
atol=1e-4,
)


if __name__ == """__main__""":
test_tune_matmul_cpu()
test_tune_matmul_cuda()
Expand Down

0 comments on commit 03a6e9a

Please sign in to comment.