Commit

opt embedding score
humingqing authored and humingqing committed Feb 23, 2024
1 parent a4297df commit b3d53e6
Showing 1 changed file with 23 additions and 26 deletions.
paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu (23 additions, 26 deletions)
@@ -114,18 +114,16 @@ __global__ void FusedSeqpoolKernelQuantFilter(const size_t N,
 
     T val = pad_value;
     for (auto k = start; k < end; ++k) {
-      T &show = *(input_values[x] + k * embedding_size);
-      T &click = *(input_values[x] + k * embedding_size + 1);
+      T *in = (input_values[x] + k * embedding_size);
+      T &show = in[0];
+      T &click = in[1];
       if ((show - click) * show_coeff + click * clk_coeff < threshold) {
         continue;
       }
       if (offset < cvm_offset) {  // show & click
-        val += *(input_values[x] + k * embedding_size + offset);
+        val += in[offset];
       } else {
-        val += ((static_cast<int>(
-                    *(input_values[x] + k * embedding_size + offset) *
-                        quant_ratio +
-                    0.5)) /
+        val += ((static_cast<int>(in[offset] * quant_ratio + 0.5)) /
                 static_cast<float>(quant_ratio));
       }
     }
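Note on the else branch above: the quantization is plain round-to-nearest on a 1/quant_ratio grid, and the patch only collapses the repeated pointer arithmetic into in[offset]. A minimal standalone sketch of the rounding (the quantize name is ours, not from the file, and it assumes non-negative inputs):

// Sketch (ours): for v >= 0, quantize(v, q) snaps v to the nearest
// multiple of 1/q, e.g. quantize(0.1234f, 128) -> 16 / 128.0f = 0.125f.
float quantize(float v, int quant_ratio) {
  return static_cast<int>(v * quant_ratio + 0.5f) /
         static_cast<float>(quant_ratio);
}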
@@ -155,19 +153,20 @@ __global__ void KernelEmbedQuantFilter(T **input_values,
   auto &slot_offset = gpu_slot_fea_offsets[x];
   int *out_ptr = &gpu_slot_fea_flag[slot_offset];
   for (auto k = threadIdx.x + start; k < end; k += blockDim.x) {
-    T &show = *(input_values[x] + k * embedding_size);
-    T &click = *(input_values[x] + k * embedding_size + 1);
+    T *in = (input_values[x] + k * embedding_size);
+    T &show = in[0];   // show
+    T &click = in[1];  // click
     if ((show - click) * show_coeff + click * clk_coeff < threshold) {
       out_ptr[k] = 0;
       continue;
     }
-    T &embedw = *(input_values[x] + k * embedding_size + cvm_offset);
+    T *embed_ptr = &in[cvm_offset];
     T embedx_weight_score = 0.0;
-    for (int i = cvm_offset + 1; i < cvm_offset + embed_thres_size; i++) {
-      embedx_weight_score +=
-          pow(*(input_values[x] + k * embedding_size + i), 2);
+    for (int i = 1; i < embed_thres_size; ++i) {
+      embedx_weight_score += embed_ptr[i] * embed_ptr[i];
     }
-    embedx_weight_score = std::sqrt(embedx_weight_score) + std::abs(embedw);
+    embedx_weight_score =
+        std::sqrt(embedx_weight_score) + std::abs(embed_ptr[0]);
     if (embedx_weight_score < embed_threshold) {
       out_ptr[k] = 0;
       continue;
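The score this kernel thresholds is the L2 norm of the embedx weights plus the absolute embedding weight; the patch hoists the row pointer once per feature and replaces the per-element pow(x, 2) call with x * x. A host-side sketch under those definitions (the embed_score name is ours; assumes <cmath>):

#include <cmath>

// Sketch (ours): embed_ptr[0] is embedw, embed_ptr[1..embed_thres_size-1]
// are the embedx weights entering the filter score.
float embed_score(const float *embed_ptr, int embed_thres_size) {
  float sum_sq = 0.0f;
  for (int i = 1; i < embed_thres_size; ++i) {
    sum_sq += embed_ptr[i] * embed_ptr[i];  // pow(x, 2) -> x * x
  }
  return std::sqrt(sum_sq) + std::abs(embed_ptr[0]);
}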
@@ -646,11 +645,11 @@ __global__ void FusedSeqpoolCVMGradKernelWithCVM(const size_t N,
     int x = key / batch_size;  // slot id
     int y = key % batch_size;  // ins id
-    auto &start = lods_values[x * (batch_size + 1) + y];
-    auto &end = lods_values[x * (batch_size + 1) + y + 1];
     T &val = (offset < cvm_offset)
                  ? *(cvm_values[x] + y * cvm_offset + offset)
                  : *(out_grads_values[x] + y * embedding_size + offset);
+    auto &start = lods_values[x * (batch_size + 1) + y];
+    auto &end = lods_values[x * (batch_size + 1) + y + 1];
     for (auto k = start; k < end; ++k) {
       *(in_grads_values[x] + k * embedding_size + offset) = val;
     }
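This and the four gradient kernels below receive the same mechanical change: the two lods_values reads move below the computation of val, while the scatter loop is untouched. For orientation, a host-side sketch of that shared scatter pattern, simplified to a single slot (all names ours, not from the file):

#include <vector>

// Sketch (ours): per (instance y, column offset), broadcast one pooled
// gradient value to every row of the instance's sequence [start, end).
void scatter_grad(const std::vector<size_t> &lods,  // batch_size + 1 offsets
                  int y, float val, int embedding_size, int offset,
                  float *in_grads) {
  size_t start = lods[y];
  size_t end = lods[y + 1];
  for (size_t k = start; k < end; ++k) {
    in_grads[k * embedding_size + offset] = val;
  }
}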
@@ -672,12 +671,12 @@ __global__ void FusedSeqpoolCVMGradKernelWithShow(const size_t N,
     int x = key / batch_size;  // slot id
     int y = key % batch_size;  // ins id
-    auto &start = lods_values[x * (batch_size + 1) + y];
-    auto &end = lods_values[x * (batch_size + 1) + y + 1];
     T &val =
         (offset < cvm_offset)
             ? *(cvm_values[x] + y * cvm_offset + offset)
             : *(out_grads_values[x] + y * (embedding_size - 1) + offset - 1);
+    auto &start = lods_values[x * (batch_size + 1) + y];
+    auto &end = lods_values[x * (batch_size + 1) + y + 1];
     for (auto k = start; k < end; ++k) {
       *(in_grads_values[x] + k * embedding_size + offset) = val;
     }
@@ -705,15 +704,13 @@ __global__ void FusedSeqpoolCVMGradKernelWithShowConcate(
     int x = key / batch_size;  // slot id
     int y = key % batch_size;  // ins id
-    auto &start = lods_values[x * (batch_size + 1) + y];
-    auto &end = lods_values[x * (batch_size + 1) + y + 1];
     T &val = (offset < cvm_offset)
                  ? *(cvm_values[x] + y * cvm_offset + offset)
                  : *(out_grads_values[x] +
                        y * (embedding_size - 1) * embedx_concate_size +
                        (embedding_size - 1) * concate_index + offset - 1);
+    auto &start = lods_values[x * (batch_size + 1) + y];
+    auto &end = lods_values[x * (batch_size + 1) + y + 1];
     auto concat_end = start + concate_index + 1;
     if (concat_end > end || concate_index == embedx_concate_size - 1) {
       concat_end = end;
@@ -741,8 +738,6 @@ __global__ void FusedSeqpoolCVMGradKernelNoCVM(const size_t N,
     int x = key / batch_size;  // slot id
     int y = key % batch_size;  // ins id
-    auto &start = lods_values[x * (batch_size + 1) + y];
-    auto &end = lods_values[x * (batch_size + 1) + y + 1];
     T val = 0;
     if (embed_thres_size == 0) {
       val = (offset < cvm_offset)
@@ -756,6 +751,8 @@
                    y * (embedding_size - cvm_offset - embed_thres_size) +
                    offset - cvm_offset - embed_thres_size);
     }
+    auto &start = lods_values[x * (batch_size + 1) + y];
+    auto &end = lods_values[x * (batch_size + 1) + y + 1];
     for (auto k = start; k < end; ++k) {
       *(in_grads_values[x] + k * embedding_size + offset) = val;
     }
@@ -783,14 +780,14 @@ __global__ void FusedSeqpoolCVMGradKernelNoCVMConcate(
     int x = key / batch_size;  // slot id
     int y = key % batch_size;  // ins id
-    auto &start = lods_values[x * (batch_size + 1) + y];
-    auto &end = lods_values[x * (batch_size + 1) + y + 1];
     T &val = (offset < cvm_offset)
                  ? *(cvm_values[x] + y * cvm_offset + offset)
                  : *(out_grads_values[x] +
                        y * (embedding_size - cvm_offset) * embedx_concate_size +
                        (embedding_size - cvm_offset) * concate_index + offset -
                        cvm_offset);
+    auto &start = lods_values[x * (batch_size + 1) + y];
+    auto &end = lods_values[x * (batch_size + 1) + y + 1];
     auto concat_end = start + concate_index + 1;
     if (concat_end > end || concate_index == embedx_concate_size - 1) {
       concat_end = end;
