From b3d53e6ab594f411eb656ecc41e0cd7c886e4c68 Mon Sep 17 00:00:00 2001
From: humingqing
Date: Fri, 23 Feb 2024 16:20:59 +0800
Subject: [PATCH] opt embedding score

---
 .../operators/fused/fused_seqpool_cvm_op.cu  | 49 +++++++++----------
 1 file changed, 23 insertions(+), 26 deletions(-)

diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu
index 1af23cda2e3e8..99c9b2cca7f6f 100644
--- a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu
+++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu
@@ -114,18 +114,16 @@ __global__ void FusedSeqpoolKernelQuantFilter(const size_t N,
     T val = pad_value;
     for (auto k = start; k < end; ++k) {
-      T &show = *(input_values[x] + k * embedding_size);
-      T &click = *(input_values[x] + k * embedding_size + 1);
+      T *in = (input_values[x] + k * embedding_size);
+      T &show = in[0];
+      T &click = in[1];
       if ((show - click) * show_coeff + click * clk_coeff < threshold) {
         continue;
       }
       if (offset < cvm_offset) {  // show & click
-        val += *(input_values[x] + k * embedding_size + offset);
+        val += in[offset];
       } else {
-        val += ((static_cast<int>(
-                    *(input_values[x] + k * embedding_size + offset) *
-                        quant_ratio +
-                    0.5)) /
+        val += ((static_cast<int>(in[offset] * quant_ratio + 0.5)) /
                 static_cast<T>(quant_ratio));
       }
     }
@@ -155,19 +153,20 @@ __global__ void KernelEmbedQuantFilter(T **input_values,
   auto &slot_offset = gpu_slot_fea_offsets[x];
   int *out_ptr = &gpu_slot_fea_flag[slot_offset];
   for (auto k = threadIdx.x + start; k < end; k += blockDim.x) {
-    T &show = *(input_values[x] + k * embedding_size);
-    T &click = *(input_values[x] + k * embedding_size + 1);
+    T *in = (input_values[x] + k * embedding_size);
+    T &show = in[0];   // show
+    T &click = in[1];  // click
     if ((show - click) * show_coeff + click * clk_coeff < threshold) {
       out_ptr[k] = 0;
       continue;
     }
-    T &embedw = *(input_values[x] + k * embedding_size + cvm_offset);
+    T *embed_ptr = &in[cvm_offset];
     T embedx_weight_score = 0.0;
-    for (int i = cvm_offset + 1; i < cvm_offset + embed_thres_size; i++) {
-      embedx_weight_score +=
-          pow(*(input_values[x] + k * embedding_size + i), 2);
+    for (int i = 1; i < embed_thres_size; ++i) {
+      embedx_weight_score += embed_ptr[i] * embed_ptr[i];
     }
-    embedx_weight_score = std::sqrt(embedx_weight_score) + std::abs(embedw);
+    embedx_weight_score =
+        std::sqrt(embedx_weight_score) + std::abs(embed_ptr[0]);
     if (embedx_weight_score < embed_threshold) {
       out_ptr[k] = 0;
       continue;
     }
@@ -646,11 +645,11 @@ __global__ void FusedSeqpoolCVMGradKernelWithCVM(const size_t N,
     int x = key / batch_size;  // slot id
     int y = key % batch_size;  // ins id
-    auto &start = lods_values[x * (batch_size + 1) + y];
-    auto &end = lods_values[x * (batch_size + 1) + y + 1];
     T &val = (offset < cvm_offset)
                  ? *(cvm_values[x] + y * cvm_offset + offset)
                  : *(out_grads_values[x] + y * embedding_size + offset);
+    auto &start = lods_values[x * (batch_size + 1) + y];
+    auto &end = lods_values[x * (batch_size + 1) + y + 1];
     for (auto k = start; k < end; ++k) {
       *(in_grads_values[x] + k * embedding_size + offset) = val;
     }
@@ -672,12 +671,12 @@ __global__ void FusedSeqpoolCVMGradKernelWithShow(const size_t N,
     int x = key / batch_size;  // slot id
     int y = key % batch_size;  // ins id
-    auto &start = lods_values[x * (batch_size + 1) + y];
-    auto &end = lods_values[x * (batch_size + 1) + y + 1];
     T &val = (offset < cvm_offset)
                  ? *(cvm_values[x] + y * cvm_offset + offset)
                  : *(out_grads_values[x] + y * (embedding_size - 1) +
                      offset - 1);
+    auto &start = lods_values[x * (batch_size + 1) + y];
+    auto &end = lods_values[x * (batch_size + 1) + y + 1];
     for (auto k = start; k < end; ++k) {
       *(in_grads_values[x] + k * embedding_size + offset) = val;
     }
@@ -705,15 +704,13 @@ __global__ void FusedSeqpoolCVMGradKernelWithShowConcate(
     int x = key / batch_size;  // slot id
     int y = key % batch_size;  // ins id
-    auto &start = lods_values[x * (batch_size + 1) + y];
-    auto &end = lods_values[x * (batch_size + 1) + y + 1];
-
     T &val = (offset < cvm_offset)
                  ? *(cvm_values[x] + y * cvm_offset + offset)
                  : *(out_grads_values[x] +
                      y * (embedding_size - 1) * embedx_concate_size +
                      (embedding_size - 1) * concate_index + offset - 1);
-
+    auto &start = lods_values[x * (batch_size + 1) + y];
+    auto &end = lods_values[x * (batch_size + 1) + y + 1];
     auto concat_end = start + concate_index + 1;
     if (concat_end > end || concate_index == embedx_concate_size - 1) {
       concat_end = end;
     }
@@ -741,8 +738,6 @@ __global__ void FusedSeqpoolCVMGradKernelNoCVM(const size_t N,
     int x = key / batch_size;  // slot id
     int y = key % batch_size;  // ins id
-    auto &start = lods_values[x * (batch_size + 1) + y];
-    auto &end = lods_values[x * (batch_size + 1) + y + 1];
     T val = 0;
     if (embed_thres_size == 0) {
       val = (offset < cvm_offset)
@@ -756,6 +751,8 @@ __global__ void FusedSeqpoolCVMGradKernelNoCVM(const size_t N,
               y * (embedding_size - cvm_offset - embed_thres_size) +
               offset - cvm_offset - embed_thres_size);
     }
+    auto &start = lods_values[x * (batch_size + 1) + y];
+    auto &end = lods_values[x * (batch_size + 1) + y + 1];
     for (auto k = start; k < end; ++k) {
       *(in_grads_values[x] + k * embedding_size + offset) = val;
     }
@@ -783,14 +780,14 @@ __global__ void FusedSeqpoolCVMGradKernelNoCVMConcate(
     int x = key / batch_size;  // slot id
     int y = key % batch_size;  // ins id
-    auto &start = lods_values[x * (batch_size + 1) + y];
-    auto &end = lods_values[x * (batch_size + 1) + y + 1];
     T &val = (offset < cvm_offset)
                  ? *(cvm_values[x] + y * cvm_offset + offset)
                  : *(out_grads_values[x] +
                      y * (embedding_size - cvm_offset) * embedx_concate_size +
                      (embedding_size - cvm_offset) * concate_index +
                      offset - cvm_offset);
+    auto &start = lods_values[x * (batch_size + 1) + y];
+    auto &end = lods_values[x * (batch_size + 1) + y + 1];
     auto concat_end = start + concate_index + 1;
     if (concat_end > end || concate_index == embedx_concate_size - 1) {
       concat_end = end;
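
Reviewer note (not part of the patch): the filter kernels (FusedSeqpoolKernelQuantFilter, KernelEmbedQuantFilter) hoist the per-row base pointer input_values[x] + k * embedding_size into a local (in / embed_ptr) so the address arithmetic is done once per row instead of once per element, and KernelEmbedQuantFilter additionally squares by multiplication instead of calling pow(x, 2); the gradient kernels only move the lods_values start/end reads to after the val load. The standalone sketch below is meant purely to illustrate the hoisted-pointer and multiply-instead-of-pow pattern on a toy row layout; the file name, kernel name, constants and layout (toy_embed_score.cu, score_kernel, kCvmOffset, kEmbedDim) are invented for the example and are not PaddlePaddle code.

// toy_embed_score.cu -- illustrative sketch only, not part of this patch.
// Mirrors the optimized score loop shape from KernelEmbedQuantFilter:
// hoist the row base pointer once, square by multiplication instead of
// pow(). All names, sizes and the row layout here are invented.
#include <cstdio>
#include <cuda_runtime.h>

// One toy "row" = [show, click, embedw, embedx_0 .. embedx_5]
constexpr int kCvmOffset = 2;                       // show & click come first
constexpr int kEmbedDim = 8;                        // floats per row (toy value)
constexpr int kThresSize = kEmbedDim - kCvmOffset;  // embedw + embedx part

__global__ void score_kernel(const float *values, float *scores, int n) {
  int k = blockIdx.x * blockDim.x + threadIdx.x;
  if (k >= n) return;
  // Compute the row base pointer once instead of re-deriving
  // values + k * kEmbedDim + i for every element access.
  const float *row = values + k * kEmbedDim;
  const float *embed = row + kCvmOffset;  // [embedw, embedx_0, ...]
  float score = 0.f;
  for (int i = 1; i < kThresSize; ++i) {
    score += embed[i] * embed[i];  // cheaper than pow(embed[i], 2)
  }
  // L2 norm of the embedx part plus |embedw|, as in the filter threshold.
  scores[k] = sqrtf(score) + fabsf(embed[0]);
}

int main() {
  const int n = 4;
  float h_vals[n * kEmbedDim];
  for (int i = 0; i < n * kEmbedDim; ++i) h_vals[i] = 0.1f * (i % 7);
  float *d_vals = nullptr, *d_scores = nullptr;
  cudaMalloc((void **)&d_vals, sizeof(h_vals));
  cudaMalloc((void **)&d_scores, n * sizeof(float));
  cudaMemcpy(d_vals, h_vals, sizeof(h_vals), cudaMemcpyHostToDevice);
  score_kernel<<<1, 32>>>(d_vals, d_scores, n);
  float h_scores[n];
  cudaMemcpy(h_scores, d_scores, sizeof(h_scores), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) printf("row %d score %.4f\n", i, h_scores[i]);
  cudaFree(d_vals);
  cudaFree(d_scores);
  return 0;
}

Build with something like `nvcc toy_embed_score.cu -o toy_embed_score`; the per-row score it prints follows the same sqrt(sum of embedx squares) + |embedw| criterion that KernelEmbedQuantFilter compares against embed_threshold.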