Commit

Merge pull request #29 from qingshui/paddlebox
Optimize dump; optimize rank_attention to save about 4 GB of GPU memory per card; refactor the pack logic using Paddle's existing structures.
qingshui committed Feb 10, 2022
2 parents 349bc2d + 9712e31 commit a283817
Showing 18 changed files with 578 additions and 321 deletions.
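The pack refactor in this commit drops the hand-rolled CudaBuffer/HostBuffer helpers from data_feed.h and reuses structures Paddle already provides: std::vector for host-side staging, Tensor for typed device buffers, and memory::AllocShared for raw device allocations. Below is a minimal sketch of the allocate-then-async-copy pattern the reworked MiniBatchGpuPack follows; the helper name stage_slots_to_device is illustrative and not part of the PR, while UsedSlotGpuType and CUDA_CHECK are the helpers declared in data_feed.h.

    // Sketch only: mirrors how gpu_slots_ is now filled in the MiniBatchGpuPack
    // constructor (see data_feed.cc below).
    #include <cuda_runtime.h>
    #include <memory>
    #include <vector>
    #include "paddle/fluid/memory/malloc.h"  // memory::AllocShared

    std::shared_ptr<paddle::memory::allocation::Allocation> stage_slots_to_device(
        const paddle::platform::Place& place, cudaStream_t stream,
        const std::vector<UsedSlotGpuType>& host_slots) {
      // One shared device allocation sized to the host payload.
      auto buf = paddle::memory::AllocShared(
          place, host_slots.size() * sizeof(UsedSlotGpuType));
      // Asynchronous host-to-device copy; the caller synchronizes the stream.
      CUDA_CHECK(cudaMemcpyAsync(buf->ptr(), host_slots.data(),
                                 host_slots.size() * sizeof(UsedSlotGpuType),
                                 cudaMemcpyHostToDevice, stream));
      return buf;
    }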
4 changes: 2 additions & 2 deletions paddle/fluid/framework/boxps_trainer.cc
@@ -40,8 +40,8 @@ void BoxPSTrainer::Initialize(const TrainerDesc& trainer_desc,
VLOG(3) << "async mode ";
}
dump_thread_num_ = param_config_.dump_thread_num();
- if (need_dump_field_ && dump_thread_num_ <= 0) {
- dump_thread_num_ = 1;
+ if (need_dump_field_ && dump_thread_num_ <= 1) {
+ dump_thread_num_ = 20;
}

workers_.resize(thread_num_);
67 changes: 39 additions & 28 deletions paddle/fluid/framework/data_feed.cc
@@ -2581,16 +2581,16 @@ void SlotPaddleBoxDataFeed::BuildSlotBatchGPU(const int ins_num) {
const UsedSlotGpuType* used_slot_gpu_types =
static_cast<const UsedSlotGpuType*>(pack_->get_gpu_slots());
FillSlotValueOffset(ins_num, use_slot_size_, pack_->gpu_slot_offsets(),
- value.d_uint64_offset.data(), uint64_use_slot_size_,
- value.d_float_offset.data(), float_use_slot_size_,
+ value.d_uint64_offset.data<int>(), uint64_use_slot_size_,
+ value.d_float_offset.data<int>(), float_use_slot_size_,
used_slot_gpu_types);
fill_timer_.Pause();
size_t* d_slot_offsets = pack_->gpu_slot_offsets();

offset_timer_.Resume();
- HostBuffer<size_t>& offsets = pack_->offsets();
+ std::vector<size_t>& offsets = pack_->offsets();
offsets.resize(slot_total_num);
- HostBuffer<void*>& h_tensor_ptrs = pack_->h_tensor_ptrs();
+ std::vector<void*>& h_tensor_ptrs = pack_->h_tensor_ptrs();
h_tensor_ptrs.resize(use_slot_size_);
// alloc gpu memory
pack_->resize_tensor();
@@ -2679,15 +2679,13 @@ void SlotPaddleBoxDataFeed::BuildSlotBatchGPU(const int ins_num) {
use_slot_size_ * sizeof(void*),
cudaMemcpyHostToDevice));

- CopyForTensor(ins_num, use_slot_size_, dest_gpu_p,
- (const size_t*)pack_->gpu_slot_offsets(),
- (const uint64_t*)value.d_uint64_keys.data(),
- (const int*)value.d_uint64_offset.data(),
- (const int*)value.d_uint64_lens.data(), uint64_use_slot_size_,
- (const float*)value.d_float_keys.data(),
- (const int*)value.d_float_offset.data(),
- (const int*)value.d_float_lens.data(), float_use_slot_size_,
- used_slot_gpu_types);
+ CopyForTensor(
+ ins_num, use_slot_size_, dest_gpu_p, pack_->gpu_slot_offsets(),
+ reinterpret_cast<const uint64_t*>(value.d_uint64_keys.data<int64_t>()),
+ value.d_uint64_offset.data<int>(), value.d_uint64_lens.data<int>(),
+ uint64_use_slot_size_, value.d_float_keys.data<float>(),
+ value.d_float_offset.data<int>(), value.d_float_lens.data<int>(),
+ float_use_slot_size_, used_slot_gpu_types);
trans_timer_.Pause();
#endif
}
@@ -2708,9 +2706,8 @@ void SlotPaddleBoxDataFeed::GetRankOffsetGPU(const int pv_num,
int* tensor_ptr =
rank_offset_->mutable_data<int>({ins_num, col}, this->place_);
CopyRankOffset(tensor_ptr, ins_num, pv_num, max_rank,
- (const int*)value.d_rank.data(),
- (const int*)value.d_cmatch.data(),
- (const int*)value.d_ad_offset.data(), col);
+ value.d_rank.data<int>(), value.d_cmatch.data<int>(),
+ value.d_ad_offset.data<int>(), col);
#endif
}
void SlotPaddleBoxDataFeed::GetRankOffset(const SlotPvInstance* pv_vec,
@@ -3827,7 +3824,11 @@ MiniBatchGpuPack::MiniBatchGpuPack(const paddle::platform::Place& place,
++used_float_num_;
}
}
- copy_host2device(&gpu_slots_, gpu_used_slots_.data(), gpu_used_slots_.size());
+ gpu_slots_ = memory::AllocShared(
+ place_, gpu_used_slots_.size() * sizeof(UsedSlotGpuType));
+ CUDA_CHECK(cudaMemcpyAsync(gpu_slots_->ptr(), gpu_used_slots_.data(),
+ gpu_used_slots_.size() * sizeof(UsedSlotGpuType),
+ cudaMemcpyHostToDevice, stream_));

slot_buf_ptr_ = memory::AllocShared(place_, used_slot_size_ * sizeof(void*));

@@ -4067,17 +4068,27 @@ void MiniBatchGpuPack::pack_instance(const SlotRecord* ins_vec, int num) {
void MiniBatchGpuPack::transfer_to_gpu(void) {
trans_timer_.Resume();
if (enable_pv_) {
- copy_host2device(&value_.d_ad_offset, buf_.h_ad_offset);
- copy_host2device(&value_.d_rank, buf_.h_rank);
- copy_host2device(&value_.d_cmatch, buf_.h_cmatch);
- }
- copy_host2device(&value_.d_uint64_lens, buf_.h_uint64_lens);
- copy_host2device(&value_.d_uint64_keys, buf_.h_uint64_keys);
- copy_host2device(&value_.d_uint64_offset, buf_.h_uint64_offset);
-
- copy_host2device(&value_.d_float_lens, buf_.h_float_lens);
- copy_host2device(&value_.d_float_keys, buf_.h_float_keys);
- copy_host2device(&value_.d_float_offset, buf_.h_float_offset);
+ copy_host2device(&value_.d_ad_offset, buf_.h_ad_offset.data(),
+ buf_.h_ad_offset.size());
+ copy_host2device(&value_.d_rank, buf_.h_rank.data(), buf_.h_rank.size());
+ copy_host2device(&value_.d_cmatch, buf_.h_cmatch.data(),
+ buf_.h_cmatch.size());
+ }
+ copy_host2device(&value_.d_uint64_lens, buf_.h_uint64_lens.data(),
+ buf_.h_uint64_lens.size());
+ copy_host2device<int64_t>(
+ &value_.d_uint64_keys,
+ reinterpret_cast<int64_t*>(buf_.h_uint64_keys.data()),
+ buf_.h_uint64_keys.size());
+ copy_host2device(&value_.d_uint64_offset, buf_.h_uint64_offset.data(),
+ buf_.h_uint64_offset.size());
+
+ copy_host2device(&value_.d_float_lens, buf_.h_float_lens.data(),
+ buf_.h_float_lens.size());
+ copy_host2device(&value_.d_float_keys, buf_.h_float_keys.data(),
+ buf_.h_float_keys.size());
+ copy_host2device(&value_.d_float_offset, buf_.h_float_offset.data(),
+ buf_.h_float_offset.size());
CUDA_CHECK(cudaStreamSynchronize(stream_));
trans_timer_.Pause();
}
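Every copy in transfer_to_gpu is queued on the pack's CUDA stream, and only the final cudaStreamSynchronize blocks the host, so the h_* staging vectors must stay alive until that point. Note that the staging buffers are now plain std::vector (pageable memory) rather than the pinned HostBuffer removed from data_feed.h below, so the driver may route these async copies through its own staging buffer. A stripped-down, Paddle-free sketch of the same enqueue-many-then-sync-once pattern (the function and buffer names are illustrative):

    #include <cuda_runtime.h>
    #include <vector>

    // Enqueue several small host-to-device copies on one stream, then wait once.
    void upload_batch(cudaStream_t stream, const std::vector<int>& lens,
                      int* d_lens, const std::vector<float>& keys, float* d_keys) {
      cudaMemcpyAsync(d_lens, lens.data(), lens.size() * sizeof(int),
                      cudaMemcpyHostToDevice, stream);
      cudaMemcpyAsync(d_keys, keys.data(), keys.size() * sizeof(float),
                      cudaMemcpyHostToDevice, stream);
      // Single synchronization point, mirroring the end of transfer_to_gpu().
      cudaStreamSynchronize(stream);
    }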
135 changes: 27 additions & 108 deletions paddle/fluid/framework/data_feed.h
@@ -1320,109 +1320,32 @@ struct UsedSlotGpuType {
int slot_value_idx;
};
#define CUDA_CHECK(val) CHECK(val == cudaSuccess)
- template <typename T>
- struct CudaBuffer {
- T* cu_buffer;
- uint64_t buf_size;
-
- CudaBuffer<T>() {
- cu_buffer = NULL;
- buf_size = 0;
- }
- ~CudaBuffer<T>() { free(); }
- T* data() { return cu_buffer; }
- uint64_t size() { return buf_size; }
- void malloc(uint64_t size) {
- buf_size = size;
- CUDA_CHECK(
- cudaMalloc(reinterpret_cast<void**>(&cu_buffer), size * sizeof(T)));
- }
- void free() {
- if (cu_buffer != NULL) {
- CUDA_CHECK(cudaFree(cu_buffer));
- cu_buffer = NULL;
- }
- buf_size = 0;
- }
- void resize(uint64_t size) {
- if (size <= buf_size) {
- return;
- }
- free();
- malloc(size);
- }
- };
- template <typename T>
- struct HostBuffer {
- T* host_buffer;
- size_t buf_size;
- size_t data_len;
-
- HostBuffer<T>() {
- host_buffer = NULL;
- buf_size = 0;
- data_len = 0;
- }
- ~HostBuffer<T>() { free(); }
-
- T* data() { return host_buffer; }
- const T* data() const { return host_buffer; }
- size_t size() const { return data_len; }
- void clear() { free(); }
- T& back() { return host_buffer[data_len - 1]; }
-
- T& operator[](size_t i) { return host_buffer[i]; }
- const T& operator[](size_t i) const { return host_buffer[i]; }
- void malloc(size_t len) {
- buf_size = len;
- CUDA_CHECK(cudaHostAlloc(reinterpret_cast<void**>(&host_buffer),
- buf_size * sizeof(T), cudaHostAllocDefault));
- CHECK(host_buffer != NULL);
- }
- void free() {
- if (host_buffer != NULL) {
- CUDA_CHECK(cudaFreeHost(host_buffer));
- host_buffer = NULL;
- }
- buf_size = 0;
- }
- void resize(size_t size) {
- if (size <= buf_size) {
- data_len = size;
- return;
- }
- data_len = size;
- free();
- malloc(size);
- }
- };

struct BatchCPUValue {
- HostBuffer<int> h_uint64_lens;
- HostBuffer<uint64_t> h_uint64_keys;
- HostBuffer<int> h_uint64_offset;
+ std::vector<int> h_uint64_lens;
+ std::vector<uint64_t> h_uint64_keys;
+ std::vector<int> h_uint64_offset;

- HostBuffer<int> h_float_lens;
- HostBuffer<float> h_float_keys;
- HostBuffer<int> h_float_offset;
+ std::vector<int> h_float_lens;
+ std::vector<float> h_float_keys;
+ std::vector<int> h_float_offset;

- HostBuffer<int> h_rank;
- HostBuffer<int> h_cmatch;
- HostBuffer<int> h_ad_offset;
+ std::vector<int> h_rank;
+ std::vector<int> h_cmatch;
+ std::vector<int> h_ad_offset;
};

struct BatchGPUValue {
- CudaBuffer<int> d_uint64_lens;
- CudaBuffer<uint64_t> d_uint64_keys;
- CudaBuffer<int> d_uint64_offset;
+ Tensor d_uint64_lens;
+ Tensor d_uint64_keys;
+ Tensor d_uint64_offset;

- CudaBuffer<int> d_float_lens;
- CudaBuffer<float> d_float_keys;
- CudaBuffer<int> d_float_offset;
+ Tensor d_float_lens;
+ Tensor d_float_keys;
+ Tensor d_float_offset;

- CudaBuffer<int> d_rank;
- CudaBuffer<int> d_cmatch;
- CudaBuffer<int> d_ad_offset;
+ Tensor d_rank;
+ Tensor d_cmatch;
+ Tensor d_ad_offset;
};

class SlotPaddleBoxDataFeed;
@@ -1439,7 +1362,7 @@ class MiniBatchGpuPack {
BatchGPUValue& value() { return value_; }
BatchCPUValue& cpu_value() { return buf_; }
UsedSlotGpuType* get_gpu_slots(void) {
- return reinterpret_cast<UsedSlotGpuType*>(gpu_slots_.data());
+ return reinterpret_cast<UsedSlotGpuType*>(gpu_slots_->ptr());
}
SlotRecord* get_records(void) { return &ins_vec_[0]; }
double pack_time_span(void) { return pack_timer_.ElapsedSec(); }
@@ -1466,8 +1389,8 @@ class MiniBatchGpuPack {
LoDTensor& float_tensor(void) { return float_tensor_; }
LoDTensor& uint64_tensor(void) { return uint64_tensor_; }

- HostBuffer<size_t>& offsets(void) { return offsets_; }
- HostBuffer<void*>& h_tensor_ptrs(void) { return h_tensor_ptrs_; }
+ std::vector<size_t>& offsets(void) { return offsets_; }
+ std::vector<void*>& h_tensor_ptrs(void) { return h_tensor_ptrs_; }

size_t* gpu_slot_offsets(void) {
return reinterpret_cast<size_t*>(gpu_slot_offsets_.data<int64_t>());
@@ -1497,18 +1420,14 @@ class MiniBatchGpuPack {

public:
template <typename T>
- void copy_host2device(CudaBuffer<T>* buf, const T* val, size_t size) {
+ void copy_host2device(Tensor* buf, const T* val, size_t size) {
if (size == 0) {
return;
}
- buf->resize(size);
- CUDA_CHECK(cudaMemcpyAsync(buf->data(), val, size * sizeof(T),
+ T* data = buf->mutable_data<T>({static_cast<int64_t>(size), 1}, place_);
+ CUDA_CHECK(cudaMemcpyAsync(data, val, size * sizeof(T),
cudaMemcpyHostToDevice, stream_));
}
- template <typename T>
- void copy_host2device(CudaBuffer<T>* buf, const HostBuffer<T>& val) {
- copy_host2device(buf, val.data(), val.size());
- }

private:
paddle::platform::Place place_;
@@ -1523,7 +1442,7 @@ class MiniBatchGpuPack {
int used_uint64_num_ = 0;
int used_slot_size_ = 0;

- CudaBuffer<UsedSlotGpuType> gpu_slots_;
+ std::shared_ptr<paddle::memory::allocation::Allocation> gpu_slots_ = nullptr;
std::vector<UsedSlotGpuType> gpu_used_slots_;
std::vector<SlotRecord> ins_vec_;
const SlotRecord* batch_ins_ = nullptr;
@@ -1536,8 +1455,8 @@ class MiniBatchGpuPack {
// float tensor
LoDTensor float_tensor_;
// batch
- HostBuffer<size_t> offsets_;
- HostBuffer<void*> h_tensor_ptrs_;
+ std::vector<size_t> offsets_;
+ std::vector<void*> h_tensor_ptrs_;
// slot offset
LoDTensor gpu_slot_offsets_;
std::shared_ptr<paddle::memory::allocation::Allocation> slot_buf_ptr_ =
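With CudaBuffer and HostBuffer gone, the members of BatchGPUValue are ordinary Tensor objects, and the single remaining copy_host2device overload both sizes the tensor (via mutable_data) and enqueues the copy on stream_. A short usage sketch, written as if inside MiniBatchGpuPack and with made-up host data, of how call sites in data_feed.cc now read the result back through typed accessors instead of C-style casts:

    // Sketch only; the h_rank contents are invented for illustration.
    std::vector<int> h_rank = {0, 1, 2, 3};

    // Allocates value_.d_rank as a {4, 1} int tensor on place_ and queues the
    // host-to-device copy on stream_ (see copy_host2device above).
    copy_host2device(&value_.d_rank, h_rank.data(), h_rank.size());
    CUDA_CHECK(cudaStreamSynchronize(stream_));  // as transfer_to_gpu() does

    // Consumers such as GetRankOffsetGPU() use the typed accessor directly:
    const int* d_rank_ptr = value_.d_rank.data<int>();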