Skip to content

Commit

Permalink
[Ansor] Improve OpenCL support (apache#10108)
Browse files Browse the repository at this point in the history
* Support OpenCL in Autoscheduler tuning

* add warning

* Update src/auto_scheduler/search_task.cc

Co-authored-by: Cody Yu <comaniac0422@gmail.com>

* fix lint

Co-authored-by: Cody Yu <comaniac0422@gmail.com>
  • Loading branch information
2 people authored and ylc committed Feb 16, 2022
1 parent 435ed93 commit 81c16c1
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 24 deletions.
21 changes: 0 additions & 21 deletions apps/topi_recipe/gemm/cuda_gemm_square.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,6 @@
USE_MANUAL_CODE = False


@tvm.register_func("tvm_callback_cuda_compile", override=True)
def tvm_callback_cuda_compile(code):
ptx = nvcc.compile_cuda(code, target_format="ptx")
return ptx


def write_code(code, fname):
with open(fname, "w") as f:
f.write(code)


@tvm.register_func
def tvm_callback_cuda_postproc(code):
if not os.path.exists("perf"):
os.mkdir("perf")
write_code(code, "perf/%s_generated.cu" % TASK)
if USE_MANUAL_CODE:
code = open("perf/%s_manual.cu" % TASK).read()
return code


def test_gemm():
# graph
nn = 2048
Expand Down
28 changes: 26 additions & 2 deletions src/auto_scheduler/search_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,32 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target
max_threads_per_block, max_vthread_extent, warp_size);
} else {
// add other opencl target
auto target_device = target->GetAttr<String>("device", "");
LOG(FATAL) << "No default hardware parameters for opencl target device: " << target_device;
auto dev = Device{static_cast<DLDeviceType>(device_type), 0};
auto device_name = "device_api.opencl";
auto func = tvm::runtime::Registry::Get(device_name);
ICHECK(func != nullptr) << "Cannot find OpenCL device_api in registry";
auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*());

tvm::runtime::TVMRetValue ret;
device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret);
int max_shared_memory_per_block = ret;

int max_local_memory_per_block = INT32_MAX;

device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret);
int max_threads_per_block = ret;

device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kWarpSize, &ret);
int warp_size = ret;

if (warp_size == 1) {
LOG(WARNING)
<< "Warp size 1 is not recommended for OpenCL devices. Tuning might crash or stuck";
}

int max_vthread_extent = warp_size / 4;
return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block,
max_threads_per_block, max_vthread_extent, warp_size);
}
} else if (device_type == kDLVulkan) {
auto dev = Device{static_cast<DLDeviceType>(device_type), 0};
Expand Down
4 changes: 3 additions & 1 deletion src/runtime/opencl/opencl_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
/*!
* \file opencl_device_api.cc
*/
#include <dmlc/parameter.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>

Expand Down Expand Up @@ -122,7 +123,8 @@ void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
corresponding to the number of SIMD entries the heardware configures.
We need to figure out a way to query this information from the hardware.
*/
*rv = 1;
const int warp_size = dmlc::GetEnv("TVM_OPENCL_WARP_SIZE", 1);
*rv = warp_size;
break;
}
case kMaxSharedMemoryPerBlock: {
Expand Down

0 comments on commit 81c16c1

Please sign in to comment.