diff --git a/source/adapters/cuda/context.hpp b/source/adapters/cuda/context.hpp index a10e8e9ca7..f2bb1580ba 100644 --- a/source/adapters/cuda/context.hpp +++ b/source/adapters/cuda/context.hpp @@ -74,6 +74,35 @@ typedef void (*ur_context_extended_deleter_t)(void *user_data); /// if necessary. /// /// +namespace { +class ScopedContext { +public: + ScopedContext(ur_device_handle_t Device) { + if (!Device) { + throw UR_RESULT_ERROR_INVALID_DEVICE; + } + setContext(Device->getNativeContext()); + } + + ScopedContext(CUcontext NativeContext) { setContext(NativeContext); } + + ~ScopedContext() {} + +private: + void setContext(CUcontext Desired) { + CUcontext Original = nullptr; + + UR_CHECK_ERROR(cuCtxGetCurrent(&Original)); + + // Make sure the desired context is active on the current thread, setting + // it if necessary + if (Original != Desired) { + UR_CHECK_ERROR(cuCtxSetCurrent(Desired)); + } + } +}; +} // namespace + struct ur_context_handle_t_ { struct deleter_data { @@ -88,14 +117,29 @@ struct ur_context_handle_t_ { ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices) : Devices{Devs, Devs + NumDevices}, RefCount{1} { + CUevent EvBase = nullptr; + int i = 0; for (auto &Dev : Devices) { urDeviceRetain(Dev); + Dev->retainNativeContext(); + // The first device in the context is used to create a base event for all + // devices in the context. Any queue created with this context will have + // the same base event used as a base timestamp for profiling. + if (i == 0) { + ScopedContext Active(Dev); + UR_CHECK_ERROR(cuEventCreate(&EvBase, CU_EVENT_DEFAULT)); + // Use default stream to record base event counter + UR_CHECK_ERROR(cuEventRecord(EvBase, 0)); + } + Dev->setBaseEvent(EvBase); + i++; } }; ~ur_context_handle_t_() { for (auto &Dev : Devices) { urDeviceRelease(Dev); + UR_CHECK_ERROR(cuDevicePrimaryCtxRelease(Dev->get())); } } @@ -140,32 +184,3 @@ struct ur_context_handle_t_ { std::vector ExtendedDeleters; std::set PoolHandles; }; - -namespace { -class ScopedContext { -public: - ScopedContext(ur_device_handle_t Device) { - if (!Device) { - throw UR_RESULT_ERROR_INVALID_DEVICE; - } - setContext(Device->getNativeContext()); - } - - ScopedContext(CUcontext NativeContext) { setContext(NativeContext); } - - ~ScopedContext() {} - -private: - void setContext(CUcontext Desired) { - CUcontext Original = nullptr; - - UR_CHECK_ERROR(cuCtxGetCurrent(&Original)); - - // Make sure the desired context is active on the current thread, setting - // it if necessary - if (Original != Desired) { - UR_CHECK_ERROR(cuCtxSetCurrent(Desired)); - } - } -}; -} // namespace diff --git a/source/adapters/cuda/device.cpp b/source/adapters/cuda/device.cpp index 9c8a0c807c..976daa08c2 100644 --- a/source/adapters/cuda/device.cpp +++ b/source/adapters/cuda/device.cpp @@ -47,8 +47,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice, static constexpr uint32_t MaxWorkItemDimensions = 3u; - ScopedContext Active(hDevice); - switch ((uint32_t)propName) { case UR_DEVICE_INFO_TYPE: { return ReturnValue(UR_DEVICE_TYPE_GPU); @@ -809,9 +807,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice, case UR_DEVICE_INFO_GLOBAL_MEM_FREE: { size_t FreeMemory = 0; size_t TotalMemory = 0; + // This driver api call requires a CUcontext set on the device. In the case + // that a CUcontext is not already initialized on the device, we implicitly + // choose to momentarily initialize the primary CUcontext of the device (by + // retaining it). The CUcontext is immediately released in case the device + // is not later used. In the unlikely circumstance that the primary + // CUcontext has not already been retained by the programmer, the implicit + // initialization/destruction here adds a ~0.1s overhead, but this ensures + // SYCL specification compliance and MPI compatibility. + CUcontext tmpContext; + UR_CHECK_ERROR(cuDevicePrimaryCtxRetain(&tmpContext, hDevice->get())); + ScopedContext Active(tmpContext); detail::ur::assertion(cuMemGetInfo(&FreeMemory, &TotalMemory) == CUDA_SUCCESS, "failed cuMemGetInfo() API."); + UR_CHECK_ERROR(cuDevicePrimaryCtxRelease(hDevice->get())); return ReturnValue(FreeMemory); } case UR_DEVICE_INFO_MEMORY_CLOCK_RATE: { diff --git a/source/adapters/cuda/device.hpp b/source/adapters/cuda/device.hpp index 3654f2bb36..0776c2ad0a 100644 --- a/source/adapters/cuda/device.hpp +++ b/source/adapters/cuda/device.hpp @@ -81,7 +81,13 @@ struct ur_device_handle_t_ { UR_CHECK_ERROR(cuDeviceTotalMem(&MaxAllocSize, cuDevice)); } - ~ur_device_handle_t_() { cuDevicePrimaryCtxRelease(CuDevice); } + ~ur_device_handle_t_() {} + + void retainNativeContext() { + UR_CHECK_ERROR(cuDevicePrimaryCtxRetain(&CuContext, CuDevice)); + }; + + void setBaseEvent(const CUevent &event) { EvBase = event; }; native_type get() const noexcept { return CuDevice; }; diff --git a/source/adapters/cuda/platform.cpp b/source/adapters/cuda/platform.cpp index 218cf9b0db..d04f13ce8c 100644 --- a/source/adapters/cuda/platform.cpp +++ b/source/adapters/cuda/platform.cpp @@ -77,19 +77,11 @@ urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries, int NumDevices = 0; UR_CHECK_ERROR(cuDeviceGetCount(&NumDevices)); try { + CUevent EvBase = nullptr; for (int i = 0; i < NumDevices; ++i) { CUdevice Device; UR_CHECK_ERROR(cuDeviceGet(&Device, i)); - CUcontext Context; - UR_CHECK_ERROR(cuDevicePrimaryCtxRetain(&Context, Device)); - - ScopedContext Active(Context); // Set native ctx as active - CUevent EvBase; - UR_CHECK_ERROR(cuEventCreate(&EvBase, CU_EVENT_DEFAULT)); - - // Use default stream to record base event counter - UR_CHECK_ERROR(cuEventRecord(EvBase, 0)); - + CUcontext Context = nullptr; Platform.Devices.emplace_back( new ur_device_handle_t_{Device, Context, EvBase, &Platform, static_cast(i)}); diff --git a/test/conformance/device/device_adapter_cuda.match b/test/conformance/device/device_adapter_cuda.match index 9989fbd774..a138435306 100644 --- a/test/conformance/device/device_adapter_cuda.match +++ b/test/conformance/device/device_adapter_cuda.match @@ -1,2 +1,3 @@ urDeviceCreateWithNativeHandleTest.SuccessWithUnOwnedNativeHandle +{{OPT}}urDeviceGetGlobalTimestampTest.Success {{OPT}}urDeviceGetGlobalTimestampTest.SuccessSynchronizedTime