Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Makes UR cuda backend compatible with MPI #2077

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 44 additions & 29 deletions source/adapters/cuda/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,35 @@ typedef void (*ur_context_extended_deleter_t)(void *user_data);
/// if necessary.
///
///
namespace {
class ScopedContext {
public:
ScopedContext(ur_device_handle_t Device) {
if (!Device) {
throw UR_RESULT_ERROR_INVALID_DEVICE;
}
setContext(Device->getNativeContext());
}

ScopedContext(CUcontext NativeContext) { setContext(NativeContext); }

~ScopedContext() {}

private:
void setContext(CUcontext Desired) {
CUcontext Original = nullptr;

UR_CHECK_ERROR(cuCtxGetCurrent(&Original));

// Make sure the desired context is active on the current thread, setting
// it if necessary
if (Original != Desired) {
UR_CHECK_ERROR(cuCtxSetCurrent(Desired));
}
}
};
} // namespace

struct ur_context_handle_t_ {

struct deleter_data {
Expand All @@ -88,14 +117,29 @@ struct ur_context_handle_t_ {

ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices)
: Devices{Devs, Devs + NumDevices}, RefCount{1} {
CUevent EvBase = nullptr;
int i = 0;
for (auto &Dev : Devices) {
urDeviceRetain(Dev);
Dev->retainNativeContext();
Copy link
Contributor

@hdelan hdelan Sep 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would personally prefer to see all of this logic for contexts happen in ur_queue_handle_t_s. This avoids giving sycl::contexts extra semantics for the CUDA backend. Within urQueueCreate you could call something like ur_device_handle_t_::init_device() which would retain the primary ctx and then set the base event, which would then be cached in the device, so if another queue is created for the same device, it doesn't need to do the same base event getting, info querying, etc.

Let's see what @npmiller thinks.

// The first device in the context is used to create a base event for all
// devices in the context. Any queue created with this context will have
// the same base event used as a base timestamp for profiling.
if (i == 0) {
ScopedContext Active(Dev);
UR_CHECK_ERROR(cuEventCreate(&EvBase, CU_EVENT_DEFAULT));
// Use default stream to record base event counter
UR_CHECK_ERROR(cuEventRecord(EvBase, 0));
}
Dev->setBaseEvent(EvBase);
i++;
}
};

~ur_context_handle_t_() {
for (auto &Dev : Devices) {
urDeviceRelease(Dev);
UR_CHECK_ERROR(cuDevicePrimaryCtxRelease(Dev->get()));
}
}

Expand Down Expand Up @@ -140,32 +184,3 @@ struct ur_context_handle_t_ {
std::vector<deleter_data> ExtendedDeleters;
std::set<ur_usm_pool_handle_t> PoolHandles;
};

namespace {
class ScopedContext {
public:
ScopedContext(ur_device_handle_t Device) {
if (!Device) {
throw UR_RESULT_ERROR_INVALID_DEVICE;
}
setContext(Device->getNativeContext());
}

ScopedContext(CUcontext NativeContext) { setContext(NativeContext); }

~ScopedContext() {}

private:
void setContext(CUcontext Desired) {
CUcontext Original = nullptr;

UR_CHECK_ERROR(cuCtxGetCurrent(&Original));

// Make sure the desired context is active on the current thread, setting
// it if necessary
if (Original != Desired) {
UR_CHECK_ERROR(cuCtxSetCurrent(Desired));
}
}
};
} // namespace
14 changes: 12 additions & 2 deletions source/adapters/cuda/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,

static constexpr uint32_t MaxWorkItemDimensions = 3u;

ScopedContext Active(hDevice);

switch ((uint32_t)propName) {
case UR_DEVICE_INFO_TYPE: {
return ReturnValue(UR_DEVICE_TYPE_GPU);
Expand Down Expand Up @@ -809,9 +807,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
case UR_DEVICE_INFO_GLOBAL_MEM_FREE: {
size_t FreeMemory = 0;
size_t TotalMemory = 0;
// This driver api call requires a CUcontext set on the device. In the case
// that a CUcontext is not already initialized on the device, we implicitly
// choose to momentarily initialize the primary CUcontext of the device (by
// retaining it). The CUcontext is immediately released in case the device
// is not later used. In the unlikely circumstance that the primary
// CUcontext has not already been retained by the programmer, the implicit
// initialization/destruction here adds a ~0.1s overhead, but this ensures
// SYCL specification compliance and MPI compatibility.
CUcontext tmpContext;
UR_CHECK_ERROR(cuDevicePrimaryCtxRetain(&tmpContext, hDevice->get()));
ScopedContext Active(tmpContext);
detail::ur::assertion(cuMemGetInfo(&FreeMemory, &TotalMemory) ==
CUDA_SUCCESS,
"failed cuMemGetInfo() API.");
UR_CHECK_ERROR(cuDevicePrimaryCtxRelease(hDevice->get()));
return ReturnValue(FreeMemory);
}
case UR_DEVICE_INFO_MEMORY_CLOCK_RATE: {
Expand Down
8 changes: 7 additions & 1 deletion source/adapters/cuda/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,13 @@ struct ur_device_handle_t_ {
UR_CHECK_ERROR(cuDeviceTotalMem(&MaxAllocSize, cuDevice));
}

~ur_device_handle_t_() { cuDevicePrimaryCtxRelease(CuDevice); }
~ur_device_handle_t_() {}

void retainNativeContext() {
UR_CHECK_ERROR(cuDevicePrimaryCtxRetain(&CuContext, CuDevice));
};

void setBaseEvent(const CUevent &event) { EvBase = event; };

native_type get() const noexcept { return CuDevice; };

Expand Down
12 changes: 2 additions & 10 deletions source/adapters/cuda/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,11 @@ urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
int NumDevices = 0;
UR_CHECK_ERROR(cuDeviceGetCount(&NumDevices));
try {
CUevent EvBase = nullptr;
for (int i = 0; i < NumDevices; ++i) {
CUdevice Device;
UR_CHECK_ERROR(cuDeviceGet(&Device, i));
CUcontext Context;
UR_CHECK_ERROR(cuDevicePrimaryCtxRetain(&Context, Device));

ScopedContext Active(Context); // Set native ctx as active
CUevent EvBase;
UR_CHECK_ERROR(cuEventCreate(&EvBase, CU_EVENT_DEFAULT));

// Use default stream to record base event counter
UR_CHECK_ERROR(cuEventRecord(EvBase, 0));

CUcontext Context = nullptr;
Platform.Devices.emplace_back(
new ur_device_handle_t_{Device, Context, EvBase, &Platform,
static_cast<uint32_t>(i)});
Expand Down
1 change: 1 addition & 0 deletions test/conformance/device/device_adapter_cuda.match
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
urDeviceCreateWithNativeHandleTest.SuccessWithUnOwnedNativeHandle
{{OPT}}urDeviceGetGlobalTimestampTest.Success
{{OPT}}urDeviceGetGlobalTimestampTest.SuccessSynchronizedTime
Loading