From 81c16c169ae11a8dc95b2f1cc121d777c17a1287 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 1 Feb 2022 12:26:09 +0900 Subject: [PATCH] [Ansor] Improve OpenCL support (#10108) * Support OpenCL in Autoscheduler tuning * add warning * Update src/auto_scheduler/search_task.cc Co-authored-by: Cody Yu * fix lint Co-authored-by: Cody Yu --- apps/topi_recipe/gemm/cuda_gemm_square.py | 21 ----------------- src/auto_scheduler/search_task.cc | 28 +++++++++++++++++++++-- src/runtime/opencl/opencl_device_api.cc | 4 +++- 3 files changed, 29 insertions(+), 24 deletions(-) diff --git a/apps/topi_recipe/gemm/cuda_gemm_square.py b/apps/topi_recipe/gemm/cuda_gemm_square.py index f9b10bd495c6..be55d158fcbc 100644 --- a/apps/topi_recipe/gemm/cuda_gemm_square.py +++ b/apps/topi_recipe/gemm/cuda_gemm_square.py @@ -27,27 +27,6 @@ USE_MANUAL_CODE = False -@tvm.register_func("tvm_callback_cuda_compile", override=True) -def tvm_callback_cuda_compile(code): - ptx = nvcc.compile_cuda(code, target_format="ptx") - return ptx - - -def write_code(code, fname): - with open(fname, "w") as f: - f.write(code) - - -@tvm.register_func -def tvm_callback_cuda_postproc(code): - if not os.path.exists("perf"): - os.mkdir("perf") - write_code(code, "perf/%s_generated.cu" % TASK) - if USE_MANUAL_CODE: - code = open("perf/%s_manual.cu" % TASK).read() - return code - - def test_gemm(): # graph nn = 2048 diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc index 03d880e7769e..cc18de25ee9e 100755 --- a/src/auto_scheduler/search_task.cc +++ b/src/auto_scheduler/search_task.cc @@ -104,8 +104,32 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target max_threads_per_block, max_vthread_extent, warp_size); } else { // add other opencl target - auto target_device = target->GetAttr<String>("device", ""); - LOG(FATAL) << "No default hardware parameters for opencl target device: " << target_device; + auto dev = Device{static_cast<DLDeviceType>(device_type), 0}; + auto 
device_name = "device_api.opencl"; + auto func = tvm::runtime::Registry::Get(device_name); + ICHECK(func != nullptr) << "Cannot find OpenCL device_api in registry"; + auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*()); + + tvm::runtime::TVMRetValue ret; + device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret); + int max_shared_memory_per_block = ret; + + int max_local_memory_per_block = INT32_MAX; + + device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret); + int max_threads_per_block = ret; + + device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kWarpSize, &ret); + int warp_size = ret; + + if (warp_size == 1) { + LOG(WARNING) + << "Warp size 1 is not recommended for OpenCL devices. Tuning might crash or stuck"; + } + + int max_vthread_extent = warp_size / 4; + return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block, + max_threads_per_block, max_vthread_extent, warp_size); } } else if (device_type == kDLVulkan) { auto dev = Device{static_cast<DLDeviceType>(device_type), 0}; diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index f12a143ab0cc..5274d7713441 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -20,6 +20,7 @@ /*! * \file opencl_device_api.cc */ +#include <dmlc/parameter.h> #include <dmlc/thread_local.h> #include <tvm/runtime/registry.h> @@ -122,7 +123,8 @@ void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) corresponding to the number of SIMD entries the heardware configures. We need to figure out a way to query this information from the hardware. */ - *rv = 1; + const int warp_size = dmlc::GetEnv("TVM_OPENCL_WARP_SIZE", 1); + *rv = warp_size; break; } case kMaxSharedMemoryPerBlock: {