Skip to content

Commit

Permalink
optimize thrust alloc (PaddlePaddle#112)
Browse files Browse the repository at this point in the history
  • Loading branch information
Thunderbrook committed Sep 14, 2022
1 parent 9441c97 commit 55739cf
Showing 1 changed file with 33 additions and 1 deletion.
34 changes: 33 additions & 1 deletion paddle/fluid/operators/shuffle_batch_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,37 @@
namespace paddle {
namespace operators {

// Byte allocator exposing the allocate/deallocate/value_type interface so it
// can be handed to a Thrust execution policy (e.g. thrust::cuda::par(alloc)).
//
// Each buffer is obtained through memory::AllocShared for the place given at
// construction. The returned shared_ptr is stored in busy_allocation_, keyed
// by the raw pointer, so the underlying allocation stays referenced until the
// matching deallocate() call (or until this object is destroyed).
// NOTE(review): what happens to the memory once the last reference is dropped
// is decided by phi::Allocation / memory::AllocShared, not by this struct.
struct CacheAllocator {
  using value_type = char;

  // `explicit` prevents a platform::Place from silently converting into an
  // allocator; the member is initialized directly rather than
  // default-constructed and then assigned.
  explicit CacheAllocator(platform::Place place) : place_(place) {
    VLOG(2) << "construct allocator";
  }

  ~CacheAllocator() { VLOG(2) << "destory allocator"; }

  // Allocates `num_bytes` bytes on place_ and returns the raw pointer.
  // The owning shared_ptr is retained internally until deallocate(ptr).
  char *allocate(std::ptrdiff_t num_bytes) {
    VLOG(2) << "allocate " << num_bytes << " bytes";
    auto storage = memory::AllocShared(place_, num_bytes);
    // Read the raw pointer before ownership is moved into the map.
    char *ptr = reinterpret_cast<char *>(storage->ptr());
    // Construct the map entry in place: no temporary pair and no extra
    // shared_ptr reference-count round trip.
    busy_allocation_.emplace(ptr, std::move(storage));
    return ptr;
  }

  // Drops the reference held for `ptr`. `ptr` must have been returned by
  // allocate() on this object and not yet deallocated; otherwise CHECK fails.
  // The size argument is part of the expected allocator signature and unused.
  void deallocate(char *ptr, size_t) {
    VLOG(2) << "deallocate ";
    auto iter = busy_allocation_.find(ptr);
    CHECK(iter != busy_allocation_.end());
    busy_allocation_.erase(iter);
  }

 private:
  using allocation_map_type =
      std::unordered_map<char *, std::shared_ptr<phi::Allocation>>;

  // Live allocations handed out by allocate(), keyed by their raw pointer.
  allocation_map_type busy_allocation_;
  // Place passed to memory::AllocShared for every allocation.
  platform::Place place_;
};

template <typename T, bool kIsForward>
struct ReorderFunctor {
ReorderFunctor(const T *x, const int64_t *shuffle_idx, T *y, int64_t stride)
Expand Down Expand Up @@ -90,7 +121,8 @@ class ShuffleBatchCUDAKernel : public framework::OpKernel<T> {

auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
#ifdef PADDLE_WITH_CUDA
const auto &exec_policy = thrust::cuda::par.on(dev_ctx.stream());
CacheAllocator allocator(ctx.GetPlace());
const auto &exec_policy = thrust::cuda::par(allocator).on(dev_ctx.stream());
#else
const auto &exec_policy = thrust::hip::par.on(dev_ctx.stream());
#endif
Expand Down

0 comments on commit 55739cf

Please sign in to comment.