[Ansor] Improve OpenCL support #10108
Changes from 2 commits
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -104,8 +104,31 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target | |
max_threads_per_block, max_vthread_extent, warp_size); | ||
} else { | ||
// add other opencl target | ||
auto target_device = target->GetAttr<String>("device", ""); | ||
LOG(FATAL) << "No default hardware parameters for opencl target device: " << target_device; | ||
auto dev = Device{static_cast<DLDeviceType>(device_type), 0}; | ||
auto device_name = "device_api.opencl"; | ||
auto func = tvm::runtime::Registry::Get(device_name); | ||
ICHECK(func != nullptr) << "Cannot find OpenCL device_api in registry"; | ||
auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*()); | ||
|
||
tvm::runtime::TVMRetValue ret; | ||
device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret); | ||
int max_shared_memory_per_block = ret; | ||
|
||
int max_local_memory_per_block = INT32_MAX; | ||
|
||
device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret); | ||
int max_threads_per_block = ret; | ||
|
||
device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kWarpSize, &ret); | ||
int warp_size = ret; | ||
|
||
if (warp_size == 1) { | ||
LOG(WARNING) << "The warp size is 1, tuning might crash or get stuck."; | ||
|
||
} | ||
|
||
int max_vthread_extent = warp_size / 4; | ||
Sorry, I just came back from vacation. I want to check `max_vthread_extent` here. As I wrote in the tutorial https://github.com/apache/tvm/blob/main/gallery/how_to/tune_with_autoscheduler/tune_network_mali.py#L188-L194 :

I think we should do it like Vulkan. |
||
return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block, | ||
max_threads_per_block, max_vthread_extent, warp_size); | ||
} | ||
} else if (device_type == kDLVulkan) { | ||
auto dev = Device{static_cast<DLDeviceType>(device_type), 0}; | ||
|
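The expression `warp_size / 4` in the OpenCL branch above uses integer division, so with the default warp size of 1 reported by the OpenCL device API it evaluates to 0. The sketch below illustrates that behavior and one possible guard. The helper names and the clamp to 1 are assumptions for illustration; they are not taken from this PR, and the thread does not spell out what the Vulkan branch does.

```cpp
#include <algorithm>
#include <cassert>

// Mirrors the expression in the diff: integer division truncates toward zero,
// so any warp_size below 4 yields 0.
inline int UnclampedMaxVthreadExtent(int warp_size) { return warp_size / 4; }

// Hypothetical guard (an assumption, not code from this PR):
// never let the vthread extent drop below 1.
inline int ClampedMaxVthreadExtent(int warp_size) {
  assert(warp_size >= 1 && "warp size must be positive");
  return std::max(1, warp_size / 4);
}
```

With a warp size of 1, the unclamped helper returns 0 while the clamped one returns 1; for warp sizes of 4 or more the two agree.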
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
/*! | ||
* \file opencl_device_api.cc | ||
*/ | ||
#include <dmlc/parameter.h> | ||
#include <dmlc/thread_local.h> | ||
#include <tvm/runtime/registry.h> | ||
|
||
|
@@ -122,7 +123,8 @@ void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) | |
corresponding to the number of SIMD entries the hardware configures. | ||
We need to figure out a way to query this information from the hardware. | ||
*/ | ||
*rv = 1; | ||
const int warp_size = dmlc::GetEnv("TVM_OPENCL_WARP_SIZE", 1); | ||
Although I don't really like environment variables, since they create side effects, I don't have a better solution, as the TODO above mentions. Maybe that's it for now.

Our Vulkan backend has a better solution: it uses a Vulkan API function to query the warp size on a given HW. However, I didn't find such an API in OpenCL, for some reason.

Understood. It's reasonable that the desired API is not always available. A better direction, I think, is to expose this option through the hardware parameters in the tuning options instead of an environment variable. For example, the default warp size of OpenCL devices would always be 1, and otherwise the user would provide the warp size in the hardware parameters.

That reminds me that it is already possible to set HW params from a Python script: https://github.com/apache/tvm/blob/main/gallery/how_to/tune_with_autoscheduler/tune_network_mali.py#L188-L194 So in practice, this patch might not be necessary. But since the possibility of manually specifying HW params is not well known, and is cumbersome anyway, I still want to land this PR. What do you think?

Hmm, I agree with having this change, but for a different reason. If we just look at the device API without Ansor, this change could be the only general workaround for OpenCL devices. Specifically, any place in TVM may query the device API to get the warp size, so setting the default warp size to 1 in Ansor might not be a general solution. |
||
*rv = warp_size; | ||
break; | ||
} | ||
case kMaxSharedMemoryPerBlock: { | ||
|
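The change above reads the warp size from the `TVM_OPENCL_WARP_SIZE` environment variable and falls back to 1 when it is unset. Below is a minimal standalone sketch of that lookup, assuming `std::getenv` as a stand-in for `dmlc::GetEnv`. `GetEnvInt` and `QueryWarpSize` are hypothetical names, and the fallback on unparsable or non-positive values is an added assumption rather than behavior of the patch.

```cpp
#include <cassert>
#include <cstdlib>
#include <exception>
#include <string>

// Hypothetical stand-in for dmlc::GetEnv(key, default), restricted to int.
inline int GetEnvInt(const char* key, int default_value) {
  assert(key != nullptr);
  const char* raw = std::getenv(key);
  if (raw == nullptr || *raw == '\0') return default_value;  // unset or empty
  try {
    return std::stoi(raw);
  } catch (const std::exception&) {
    // Assumption: an unparsable value falls back to the default.
    return default_value;
  }
}

// Returns the user-supplied warp size, or 1 when none is given.
inline int QueryWarpSize() {
  const int warp_size = GetEnvInt("TVM_OPENCL_WARP_SIZE", 1);
  // Assumption: reject non-positive values instead of passing them on.
  return warp_size >= 1 ? warp_size : 1;
}
```

In this sketch, launching a process with `TVM_OPENCL_WARP_SIZE=64` in its environment makes `QueryWarpSize()` return 64. The side effect the reviewers mention is visible here: the result depends on process-wide state, not on the device being queried.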
This code is stale and blocks this script from running. We don't need this anymore, so better just to remove it.