Skip to content

Commit

Permalink
[Cherry-pick][ROCm] fix dcu error in device event base, test=develop (#…
Browse files Browse the repository at this point in the history
…41523)

Cherry-pick of #41521
  • Loading branch information
qili93 committed Apr 8, 2022
1 parent cb7551f commit ebe72b8
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 2 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/platform/device_event.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ using ::paddle::platform::kCPU;
USE_EVENT(kCPU)
USE_EVENT_WAIT(kCPU, kCPU)

#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
USE_EVENT(kCUDA);
USE_EVENT_WAIT(kCUDA, kCUDA)
USE_EVENT_WAIT(kCPU, kCUDA)
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/platform/device_event_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include "paddle/fluid/platform/device_event_base.h"
#include "paddle/fluid/platform/event.h"

#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
namespace paddle {
namespace platform {
struct CUDADeviceEventWrapper {
Expand Down
52 changes: 52 additions & 0 deletions paddle/fluid/platform/device_event_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,58 @@ TEST(DeviceEvent, CUDA) {
}
#endif

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>

TEST(DeviceEvent, CUDA) {
VLOG(1) << "In Test";
using paddle::platform::CUDAPlace;

auto& pool = DeviceContextPool::Instance();
auto place = CUDAPlace(0);
auto* context =
static_cast<paddle::platform::CUDADeviceContext*>(pool.Get(place));

ASSERT_NE(context, nullptr);
// case 1. test for event_creator
DeviceEvent event(place);
ASSERT_NE(event.GetEvent().get(), nullptr);
bool status = event.Query();
ASSERT_EQ(status, true);
// case 2. test for event_recorder
event.Record(context);
status = event.Query();
ASSERT_EQ(status, false);
// case 3. test for event_finisher
event.Finish();
status = event.Query();
ASSERT_EQ(status, true);

// case 4. test for event_waiter
float *src_fp32, *dst_fp32;
int size = 1000000 * sizeof(float);
hipMallocHost(reinterpret_cast<void**>(&src_fp32), size);
hipMalloc(reinterpret_cast<void**>(&dst_fp32), size);
hipMemcpyAsync(dst_fp32, src_fp32, size, hipMemcpyHostToDevice,
context->stream());
event.Record(context); // step 1. record it
status = event.Query();
ASSERT_EQ(status, false);

event.Wait(kCUDA, context); // step 2. add streamWaitEvent
status = event.Query();
ASSERT_EQ(status, false); // async

event.Wait(kCPU, context); // step 3. EventSynchornize
status = event.Query();
ASSERT_EQ(status, true); // sync

// release resource
hipFree(dst_fp32);
hipFreeHost(src_fp32);
}
#endif

TEST(DeviceEvent, CPU) {
using paddle::platform::CPUPlace;
auto place = CPUPlace();
Expand Down

0 comments on commit ebe72b8

Please sign in to comment.