From 5753cd0c28d2708bedddbb875f689dc3024127e6 Mon Sep 17 00:00:00 2001
From: mojingcj
Date: Wed, 6 Dec 2023 16:57:35 +0800
Subject: [PATCH 1/4] fused_seqpool_cvm_with_conv: support filtering by threshold

---
 .../fused/fused_seqpool_cvm_with_conv_op.cc |  4 ++
 .../fused/fused_seqpool_cvm_with_conv_op.cu | 55 +++++++++++++++++--
 python/paddle/fluid/contrib/layers/nn.py    |  8 +++
 3 files changed, 63 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.cc b/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.cc
index 2cb1a0caf30ea..66bb9afdde8c6 100644
--- a/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.cc
+++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.cc
@@ -109,6 +109,10 @@ class FusedSeqpoolCVMOpWithConvMaker : public framework::OpProtoAndCheckerMaker
         "(float, default 0.0) The value to pad for empty sequence.")
         .SetDefault(0.0);
     AddAttr<bool>("use_cvm", "bool, use cvm or not").SetDefault(true);
+    AddAttr<bool>("need_filter", "(bool, default false)").SetDefault(false);
+    AddAttr<float>("show_coeff", "(float, default 0.2)").SetDefault(0.2);
+    AddAttr<float>("clk_coeff", "(float, default 1)").SetDefault(1);
+    AddAttr<float>("threshold", "(float, default 0.96)").SetDefault(0.96);
     AddAttr<int>("cvm_offset", "(int, default 3)").SetDefault(3);
     AddAttr<bool>("show_filter", "(bool, default false)").SetDefault(false);
     AddAttr<int>("embedx_concate_size", "(int, default 1)").SetDefault(1);
diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.cu b/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.cu
index cb56a9109e6c7..0e01eb1785132 100644
--- a/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.cu
+++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.cu
@@ -53,6 +53,38 @@ __global__ void FusedSeqpoolWithConvKernelNormal(const size_t N, T **input_value
   }
 }
 
+// filter by weighted show/click score
+template <typename T>
+__global__ void FusedSeqpoolWithConvKernelFilter(const size_t N, T **input_values,
+                                                 T **seqpool_output_values,
+                                                 size_t **lods_values,
+                                                 const int batch_size,
+                                                 const int embedding_size,
+                                                 const float pad_value,
+                                                 const float show_coeff,
+                                                 const float clk_coeff,
+                                                 const float threshold) {
+  CUDA_KERNEL_LOOP(i, N) {
+    int key = i / embedding_size;
+    int offset = i % embedding_size;
+    int x = key / batch_size;  // slot id
+    int y = key % batch_size;  // ins id
+    auto &start = *(lods_values[x] + y);
+    auto &end = *(lods_values[x] + y + 1);
+
+    double val = pad_value;
+    for (auto k = start; k < end; ++k) {
+      T &show = *(input_values[x] + k * embedding_size);
+      T &click = *(input_values[x] + k * embedding_size + 1);
+      if ((show - click) * show_coeff + click * clk_coeff < threshold) {
+        continue;
+      }
+      val += *(input_values[x] + k * embedding_size + offset);
+    }
+    *(seqpool_output_values[x] + y * embedding_size + offset) = val;
+  }
+}
+
 // normal & expand slot's feasign
 template <typename T>
 __global__ void FusedSeqpoolWithConvKernelNormalEmbedxConcate(const size_t N, T **input_values,
@@ -257,6 +289,8 @@ void FusedSeqpoolCVMWithConv(const paddle::platform::Place &place,
                        std::vector<const size_t *> lods, const int batch_size,
                        const int slot_num, const int embedding_size,
                        const float padding_value, const bool use_cvm,
+                       const bool need_filter, const float show_coeff,
+                       const float clk_coeff, const float threshold,
                        const int cvm_offset, bool show_filter,
                        const int embedx_concate_size) {
   auto stream = dynamic_cast<phi::GPUContext *>(
      platform::DeviceContextPool::Instance().Get(place))
      ->stream();
@@ -290,10 +324,17 @@ void FusedSeqpoolCVMWithConv(const paddle::platform::Place &place,
   size_t N = static_cast<size_t>(batch_size * slot_num * embedding_size);
   // first sum pool
   if
(embedx_concate_size == 1){ + if (need_filter) { //filter + FusedSeqpoolWithConvKernelFilter<<>>( + N, gpu_input_values, gpu_seqpool_output_values, lods_values, batch_size, + embedding_size, padding_value, show_coeff, clk_coeff, threshold); + } else { //normal FusedSeqpoolWithConvKernelNormal<<>>( - N, gpu_input_values, gpu_seqpool_output_values, lods_values, batch_size, - embedding_size, padding_value); + stream>>>( + N, gpu_input_values, gpu_seqpool_output_values, lods_values, batch_size, + embedding_size, padding_value); + } } else { FusedSeqpoolWithConvKernelNormalEmbedxConcate<<>>( @@ -595,6 +636,10 @@ class FusedSeqpoolCVMWithConvCUDAKernel : public framework::OpKernel { auto padding_value = ctx.Attr("pad_value"); auto use_cvm = ctx.Attr("use_cvm"); + bool need_filter = ctx.Attr("need_filter"); + float show_coeff = ctx.Attr("show_coeff"); + float clk_coeff = ctx.Attr("clk_coeff"); + float threshold = ctx.Attr("threshold"); const int cvm_offset = ctx.Attr("cvm_offset"); bool show_filter = ctx.Attr("show_filter"); const int embedx_concate_size = ctx.Attr("embedx_concate_size"); @@ -638,7 +683,9 @@ class FusedSeqpoolCVMWithConvCUDAKernel : public framework::OpKernel { } FusedSeqpoolCVMWithConv(ctx.GetPlace(), input_data, output_data, seqpool_output_data, lods_data, batch_size, slot_size, - embedding_size, padding_value, use_cvm, cvm_offset, show_filter, embedx_concate_size); + embedding_size, padding_value, use_cvm, + need_filter, show_coeff, clk_coeff, threshold, + cvm_offset, show_filter, embedx_concate_size); } }; diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py index 5e3eb92f2d401..222ac19dca143 100644 --- a/python/paddle/fluid/contrib/layers/nn.py +++ b/python/paddle/fluid/contrib/layers/nn.py @@ -1908,6 +1908,10 @@ def fused_seqpool_cvm_with_conv(input, cvm, pad_value=0.0, use_cvm=True, + need_filter=False, + show_coeff=0.2, + clk_coeff=1.0, + threshold=0.96, show_filter=False, cvm_offset=3, embedx_concate_size=1): @@ -1955,6 +1959,10 @@ def fused_seqpool_cvm_with_conv(input, "pad_value": pad_value, "use_cvm": use_cvm, "cvm_offset": cvm_offset, + "need_filter": need_filter, + "show_coeff": show_coeff, + "clk_coeff": clk_coeff, + "threshold": threshold, "show_filter": show_filter, "embedx_concate_size": embedx_concate_size, }) From b888ec3ba90ef5f9eeb69248e8764e282c67297b Mon Sep 17 00:00:00 2001 From: yuandong1998 <1377526365@qq.com> Date: Wed, 20 Dec 2023 00:34:24 +0800 Subject: [PATCH 2/4] add fill zero in fused_seqpool_cvm --- .../operators/fused/fused_seqpool_cvm_op.cc | 2 ++ .../operators/fused/fused_seqpool_cvm_op.cu | 36 ++++++++++++++----- python/paddle/fluid/contrib/layers/nn.py | 6 ++-- 3 files changed, 34 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc index 1dfbc30d06606..3945c027364e7 100644 --- a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc +++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc @@ -48,6 +48,7 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel { bool clk_filter = ctx->Attrs().Get("clk_filter"); const int embed_thres_size = ctx->Attrs().Get("embed_thres_size"); const int embedx_concate_size = ctx->Attrs().Get("embedx_concate_size"); + //const bool fill_zero = ctx->Attrs().Get("fill_zero"); // need filter quant_ratio more than zero if (ctx->Attrs().Get("need_filter")) { @@ -142,6 +143,7 @@ class FusedSeqpoolCVMOpMaker : public framework::OpProtoAndCheckerMaker { 
AddAttr("embed_thres_size", "(int, default 0)").SetDefault(0); AddAttr("embedx_concate_size", "(int, default 1)").SetDefault(1); AddAttr("embedx_concate_filter", "(bool, default false)").SetDefault(false); + AddAttr("fill_zero", "(bool, default true)").SetDefault(true); AddComment(R"DOC( Fuse multiple pairs of Sequence Pool and CVM Operator. diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu index 93cd61f6df2b4..45dc840d28995 100644 --- a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu +++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu @@ -177,7 +177,7 @@ __global__ void FusedSeqpoolKernelEmbedQuantFilterEmbedxConcate( size_t **lods_values, const int batch_size, const int embedding_size, const float pad_value, const int cvm_offset, const float show_coeff, const float clk_coeff, const float threshold, const int quant_ratio, - const float embed_threshold, const int embedx_concate_size, bool embedx_concate_filter) { + const float embed_threshold, const int embedx_concate_size, bool embedx_concate_filter, bool fill_zero) { CUDA_KERNEL_LOOP(i, N) { int key = i / embedding_size; int offset = i % embedding_size; // embedx id @@ -188,11 +188,17 @@ __global__ void FusedSeqpoolKernelEmbedQuantFilterEmbedxConcate( double val = pad_value; int concate_index = 0; + bool val_use_zero = false; for (auto k = start; k < end; ++k) { + val_use_zero = false; T &show = *(input_values[x] + k * embedding_size); T &click = *(input_values[x] + k * embedding_size + 1); if (embedx_concate_filter && (show - click) * show_coeff + click * clk_coeff < threshold) { - continue; + if (fill_zero) { + val_use_zero = true; + } else { + continue; + } } T &embedw = *(input_values[x] + k * embedding_size + cvm_offset); T embedx_weight_score = 0.0; @@ -202,16 +208,28 @@ __global__ void FusedSeqpoolKernelEmbedQuantFilterEmbedxConcate( } embedx_weight_score = std::sqrt(embedx_weight_score) + std::abs(embedw); if (embedx_concate_filter && embedx_weight_score < embed_threshold) { - continue; + if (fill_zero) { + val_use_zero = true; + } else { + continue; + } } if (offset < cvm_offset) { // show & click - val = *(input_values[x] + k * embedding_size + offset); + if (val_use_zero) { + val = pad_value; + } else { + val = *(input_values[x] + k * embedding_size + offset); + } } else { - val = ((static_cast( + if (val_use_zero) { + val = pad_value; + } else { + val = ((static_cast( *(input_values[x] + k * embedding_size + offset) * quant_ratio + 0.5)) / static_cast(quant_ratio)); + } } if (concate_index == embedx_concate_size) { *(seqpool_output_values[x] + y * embedding_size * embedx_concate_size + (embedx_concate_size-1) * embedding_size + offset) += val; @@ -352,7 +370,8 @@ void FusedSeqpoolCVM(const paddle::platform::Place &place, float clk_coeff, float threshold, float embed_threshold, const int quant_ratio, const bool clk_filter, const int embed_thres_size, const int embedx_concate_size, - bool embedx_concate_filter) { + bool embedx_concate_filter, + bool fill_zero) { auto stream = dynamic_cast( platform::DeviceContextPool::Instance().Get(place)) ->stream(); @@ -395,7 +414,7 @@ void FusedSeqpoolCVM(const paddle::platform::Place &place, 0, stream>>>( N, gpu_input_values, gpu_seqpool_output_values, lods_values, batch_size, embedding_size, padding_value, cvm_offset, show_coeff, clk_coeff, - threshold, quant_ratio, embed_threshold, embedx_concate_size, embedx_concate_filter); + threshold, quant_ratio, embed_threshold, embedx_concate_size, 
embedx_concate_filter, fill_zero); } } else if (need_filter) { // quant need filter FusedSeqpoolKernelQuantFilter<< { const int embed_thres_size = ctx.Attr("embed_thres_size"); const int embedx_concate_size = ctx.Attr("embedx_concate_size"); bool embedx_concate_filter = ctx.Attr("embedx_concate_filter"); + bool fill_zero = ctx.Attr("fill_zero"); framework::GPULodVector gpu_lods[slot_size]; auto place = ctx.GetPlace(); @@ -737,7 +757,7 @@ class FusedSeqpoolCVMCUDAKernel : public framework::OpKernel { embedding_size, padding_value, use_cvm, cvm_offset, need_filter, embed_threshold_filter, show_coeff, clk_coeff, threshold, embed_threshold, quant_ratio, clk_filter, - embed_thres_size, embedx_concate_size, embedx_concate_filter); + embed_thres_size, embedx_concate_size, embedx_concate_filter, fill_zero); } }; diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py index 222ac19dca143..6d945cc9369fd 100644 --- a/python/paddle/fluid/contrib/layers/nn.py +++ b/python/paddle/fluid/contrib/layers/nn.py @@ -1759,7 +1759,8 @@ def fused_seqpool_cvm(input, clk_filter=False, embed_thres_size=0, embedx_concate_size=1, - embedx_concate_filter=False): + embedx_concate_filter=False, + fill_zero=True): """ **Notes: The Op only receives List of LoDTensor as input, only support SUM pooling now. :attr:`input`. @@ -1818,7 +1819,8 @@ def fused_seqpool_cvm(input, "clk_filter": clk_filter, "embed_thres_size": embed_thres_size, "embedx_concate_size": embedx_concate_size, - "embedx_concate_filter": embedx_concate_filter + "embedx_concate_filter": embedx_concate_filter, + "fill_zero": fill_zero }) return outs From 08fbd501278dc5b6e9d198752153ee09a3c02bdb Mon Sep 17 00:00:00 2001 From: shangzhongbin Date: Thu, 4 Jan 2024 20:55:19 +0800 Subject: [PATCH 3/4] add fused seq tensor && support transpose batch fc weight --- paddle/fluid/operators/batch_fc_op.cc | 56 ++++ paddle/fluid/operators/batch_fc_op.cu | 85 +++++ .../operators/fused/fused_seq_tensor_op.cc | 132 ++++++++ .../operators/fused/fused_seq_tensor_op.cu | 290 ++++++++++++++++++ .../operators/fused/fused_seq_tensor_op.h | 16 + python/paddle/fluid/contrib/layers/nn.py | 71 ++++- 6 files changed, 648 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/operators/fused/fused_seq_tensor_op.cc create mode 100644 paddle/fluid/operators/fused/fused_seq_tensor_op.cu create mode 100644 paddle/fluid/operators/fused/fused_seq_tensor_op.h diff --git a/paddle/fluid/operators/batch_fc_op.cc b/paddle/fluid/operators/batch_fc_op.cc index 9ee4bad1d73b7..7cc1844393b03 100644 --- a/paddle/fluid/operators/batch_fc_op.cc +++ b/paddle/fluid/operators/batch_fc_op.cc @@ -44,6 +44,61 @@ class BatchFCOp : public framework::OperatorWithKernel { auto w_dims = ctx->GetInputDim("W"); int batchcount = ctx->Attrs().Get("batchcount"); + int transpose_weight = ctx->Attrs().Get("transpose_weight"); + + if (transpose_weight) { + // Input_dim: [batch_count, ?, in_dim] + // W_dim: [in_dim, batch_count * out_dim] + // Bias_dim: [1, batch_count * out_dim] + // Out_dim: [batch_count, ?, out_dim] + PADDLE_ENFORCE_GT( + batchcount, + 0, + platform::errors::PreconditionNotMet( + "with transpose weight, batchcount should > 0")); + PADDLE_ENFORCE_EQ( + w_dims.size(), + 2, + platform::errors::InvalidArgument( + "W of BatchFCOp should have 2D.")); + + int out_dim = w_dims[1] / batchcount; + PADDLE_ENFORCE_EQ( + input_dims.size(), + 3, + platform::errors::InvalidArgument( + "Input of BatchFCOp should have 3D.")); + PADDLE_ENFORCE_EQ( + input_dims[2], + 
w_dims[0],
+          platform::errors::InvalidArgument(
+              "Input.dim[2] and W.dim[0] of BatchFCOp should be same."));
+      PADDLE_ENFORCE_EQ(
+          input_dims[0],
+          batchcount,
+          platform::errors::InvalidArgument(
+              "Input.dim[0] and batchcount of BatchFCOp should be same."));
+      PADDLE_ENFORCE_EQ(
+          input_dims[2],
+          w_dims[0],
+          platform::errors::InvalidArgument(
+              "Input.dim[2] and W.dim[0] of BatchFCOp should be same."));
+
+      auto bias_dims = ctx->GetInputDim("Bias");
+      PADDLE_ENFORCE_EQ(
+          bias_dims.size(),
+          2,
+          platform::errors::InvalidArgument("Bias of BatchFCOp should have 2D."));
+      PADDLE_ENFORCE_EQ(
+          bias_dims[1],
+          w_dims[1],
+          platform::errors::InvalidArgument(
+              "Bias.dim[1] and W.dim[1] of BatchFCOp should be same."));
+
+      ctx->SetOutputDim("Out", {input_dims[0], input_dims[1], out_dim});
+      ctx->ShareLoD("Input", /*->*/ "Out");
+      return;
+    }
     if (batchcount > 0) {
       int feature_dim = input_dims[1] / batchcount;
       PADDLE_ENFORCE_EQ(feature_dim, w_dims[0],
@@ -139,6 +194,7 @@ class BatchFCOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Bias", "(Tensor) Input tensor of batch_fc_op operator.");
     AddOutput("Out", "Output tensor of batch_fc_op operator.");
     AddAttr<int64_t>("batchcount", "(int64_t) the batchcount").SetDefault(0);
+    AddAttr<bool>("transpose_weight", "(bool) whether to transpose W before the batched GEMM").SetDefault(false);
     AddComment(R"DOC(
 BatchFC Operator.
 Notice: It currently supports GPU device.
diff --git a/paddle/fluid/operators/batch_fc_op.cu b/paddle/fluid/operators/batch_fc_op.cu
index f9fac45ef6e5e..652eddb560099 100644
--- a/paddle/fluid/operators/batch_fc_op.cu
+++ b/paddle/fluid/operators/batch_fc_op.cu
@@ -171,11 +171,96 @@ void transpose_split_row(cudaStream_t stream, const unsigned int rown,
                          stream>>>(rown, coln, num_block, source, dest);
 }
 
+template <typename T>
+__global__ void transpose_weight_kernel(const T* source, T* dest,
+    const unsigned int rown, const unsigned int coln, const int64_t batch_count) {
+  int x = blockIdx.x * blockDim.x + threadIdx.x;
+  int y = blockIdx.y * blockDim.y + threadIdx.y;
+  if (x < rown && y < coln) {
+    int dst_coln = coln / batch_count;
+    int dst_x = x + y / dst_coln * rown;
+    int dst_y = y % dst_coln;
+    dest[dst_x * dst_coln + dst_y] = source[x * coln + y];
+  }
+}
+
+template <typename T>
+void transpose_weight_impl(cudaStream_t stream, const T* source, T* dest,
+    const unsigned int rown, const unsigned int coln, const int64_t batch_count) {
+  dim3 grid((rown + 15) / 16, (coln + 15) / 16);
+  dim3 block(16, 16);
+  transpose_weight_kernel<<<grid, block, 0, stream>>>(source, dest, rown, coln, batch_count);
+}
+
 template <typename DeviceContext, typename T>
 class BatchFCCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     int batchcount = ctx.Attr<int64_t>("batchcount");
+    auto transpose_weight = ctx.Attr<bool>("transpose_weight");
+    if (transpose_weight) {
+      // Input_dim: [batch_count, ?, in_dim]
+      // W_dim: [in_dim, batch_count * out_dim]
+      // Bias_dim: [1, batch_count * out_dim]
+      // Out_dim: [batch_count, ?, out_dim]
+      auto* input = ctx.Input<framework::LoDTensor>("Input");
+      auto* w = ctx.Input<framework::LoDTensor>("W");
+      auto* bias = ctx.Input<framework::LoDTensor>("Bias");
+      auto* output = ctx.Output<framework::LoDTensor>("Out");
+      auto input_dims = input->dims();
+      auto w_dims = w->dims();
+      auto slot_pairs_num = input_dims[0];
+      auto ins_num = input_dims[1];
+      auto in_dim = input_dims[2];
+      auto out_dim = w_dims[1] / batchcount;
+
+      // get data ptr
+      const T* in_data = input->data<T>();
+      const T* w_data = w->data<T>();
+      const T* bias_data = bias->data<T>();
+
+      output->Resize({slot_pairs_num, ins_num, out_dim});
+      T* out_data = output->mutable_data<T>(ctx.GetPlace());
+
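+      // Layout note: W arrives as [in_dim, batch_count * out_dim]; the helper
+      // below repacks it into batch-major [batch_count, in_dim, out_dim] so
+      // BatchedGEMM can walk it with a fixed stride (in_dim * out_dim) per batch.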
+      auto& dev_ctx = ctx.template device_context<DeviceContext>();
+
+      Tensor w_help;
+      w_help =
+          ctx.AllocateTmpTensor<T, DeviceContext>({batchcount, w_dims[0], w_dims[1] / batchcount}, dev_ctx);
+      T* w_help_data = w_help.data<T>();
+
+      transpose_weight_impl(ctx.cuda_device_context().stream(), w_data, w_help_data, w_dims[0], w_dims[1], batchcount);
+
+      CBLAS_TRANSPOSE transA = CblasNoTrans;
+      CBLAS_TRANSPOSE transB = CblasNoTrans;
+
+      T alpha = 1;
+      T beta = 0;
+      int64_t strideA = ins_num * in_dim;
+      int64_t strideB = in_dim * out_dim;
+
+      auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
+      blas.BatchedGEMM(transA,
+                       transB,
+                       ins_num,
+                       out_dim,
+                       in_dim,
+                       alpha,
+                       in_data,
+                       w_help_data,
+                       beta,
+                       out_data,
+                       slot_pairs_num,
+                       strideA,
+                       strideB);
+      add_bias(ctx.cuda_device_context().stream(),
+               out_data,
+               slot_pairs_num,
+               ins_num,
+               out_dim,
+               bias_data);
+      return;
+    }
     if (batchcount > 0) {
       auto* input = ctx.Input<framework::LoDTensor>("Input");
       auto* w = ctx.Input<framework::LoDTensor>("W");
diff --git a/paddle/fluid/operators/fused/fused_seq_tensor_op.cc b/paddle/fluid/operators/fused/fused_seq_tensor_op.cc
new file mode 100644
index 0000000000000..5ca2ec345f10e
--- /dev/null
+++ b/paddle/fluid/operators/fused/fused_seq_tensor_op.cc
@@ -0,0 +1,132 @@
+#include "paddle/fluid/operators/fused/fused_seq_tensor_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include <string>
+
+namespace paddle {
+namespace operators {
+
+class FusedSeqTensorOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "FusedSeqTensorOp");
+    OP_INOUT_CHECK(ctx->HasInput("ADInput"), "ADInput", "ADInput", "FusedSeqTensorOp");
+
+    OP_INOUT_CHECK(ctx->HasOutput("DINOut"), "DINOut", "DINOut", "FusedSeqTensorOp");
+    OP_INOUT_CHECK(ctx->HasOutput("MaskOut"), "MaskOut", "MaskOut", "FusedSeqTensorOp");
+    OP_INOUT_CHECK(ctx->HasOutput("SideInfoOut"), "SideInfoOut", "SideInfoOut", "FusedSeqTensorOp");
+    OP_INOUT_CHECK(ctx->HasOutput("ADSlotSessionOut"), "ADSlotSessionOut", "ADSlotSessionOut", "FusedSeqTensorOp");
+
+    const framework::DDim input_dims = ctx->GetInputDim("Input");
+    const framework::DDim ad_input_dims = ctx->GetInputDim("ADInput");
+
+    auto ad_slot_num = ctx->Attrs().Get<int>("ad_slot_num");
+    auto batch_count = ctx->Attrs().Get<int>("batch_count");
+    auto max_length = ctx->Attrs().Get<int>("max_length");
+    auto slot_num = ctx->Attrs().Get<int>("slot_num");
+    auto fea_emb_dim = ctx->Attrs().Get<int>("fea_emb_dim");
+    auto ad_slot_offset = ctx->Attrs().Get<int>("ad_slot_offset");
+
+    int64_t one_ins_dim = batch_count * max_length * slot_num * fea_emb_dim;
+    PADDLE_ENFORCE_EQ(
+        input_dims[1], one_ins_dim,
+        platform::errors::InvalidArgument(
+            "input dims error, %ld != %ld", input_dims[1], one_ins_dim));
+
+    int64_t one_ins_ad_dim = batch_count * 1 * ad_slot_num * fea_emb_dim;
+    PADDLE_ENFORCE_EQ(
+        ad_input_dims[1], one_ins_ad_dim,
+        platform::errors::InvalidArgument(
+            "ad input dims error, %ld != %ld", ad_input_dims[1], one_ins_ad_dim));
+    PADDLE_ENFORCE_LT(
+        ad_slot_num, slot_num,
+        platform::errors::InvalidArgument(
+            "ad_slot_num [%ld] >= slot_num [%ld]", ad_slot_num, slot_num));
+    PADDLE_ENFORCE_GT(
+        ad_slot_num, 0,
+        platform::errors::InvalidArgument(
+            "ad_slot_num [%ld] <= 0", ad_slot_num));
+    PADDLE_ENFORCE_LT(
+        ad_slot_offset, slot_num - 1,
+        platform::errors::InvalidArgument(
+            "ad_slot_num [%ld] > slot_num - 1 [%ld]", ad_slot_offset, slot_num));
+    PADDLE_ENFORCE_GE(
+        ad_slot_offset, 0,
+        platform::errors::InvalidArgument(
+            "ad_slot_offset [%ld] < 0", ad_slot_offset));
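+    // The ad slots must form one contiguous block: either at the head of the
+    // slot layout (ad_slot_offset == 0) or flush against the tail, in which
+    // case ad_slot_offset + ad_slot_num must equal slot_num (checked below).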
[%ld] < 0", ad_slot_offset)); + if (ad_slot_offset != 0) { + PADDLE_ENFORCE_EQ( + ad_slot_num + ad_slot_offset, slot_num, + platform::errors::InvalidArgument( + "ad_slot_num [%ld] + ad_slot_offset [%ld] != slot_num [%ld]", ad_slot_num, ad_slot_offset, slot_num)); + } + + auto ins_num = input_dims[0]; + if (batch_count > 1) { + ctx->SetOutputDim("DINOut", {batch_count, ins_num * max_length, ad_slot_num * fea_emb_dim * 4}); + ctx->SetOutputDim("MaskOut", {batch_count, ins_num, max_length}); + ctx->SetOutputDim("SideInfoOut", {batch_count, ins_num * max_length, (slot_num - ad_slot_num) * fea_emb_dim}); + ctx->SetOutputDim("ADSlotSessionOut", {batch_count, ins_num * max_length, ad_slot_num, fea_emb_dim}); + } else { + ctx->SetOutputDim("DINOut", {ins_num, max_length, ad_slot_num * fea_emb_dim * 4}); + ctx->SetOutputDim("MaskOut", {ins_num, max_length}); + ctx->SetOutputDim("SideInfoOut", {ins_num, max_length, (slot_num - ad_slot_num) * fea_emb_dim}); + ctx->SetOutputDim("ADSlotSessionOut", {ins_num, max_length, ad_slot_num * fea_emb_dim}); + } + ctx->ShareLoD("Input", "DINOut"); + ctx->ShareLoD("Input", "MaskOut"); + ctx->ShareLoD("Input", "SideInfoOut"); + ctx->ShareLoD("Input", "ADSlotSessionOut"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); + } +}; + +class FusedSeqTensorOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Input", + "The input tensors of operator."); + AddInput("ADInput", + "The input ad tensors of operator. "); + AddOutput("DINOut", + "DINOut"); + AddOutput("MaskOut", + "MaskOut"); + AddOutput("SideInfoOut", + "SideInfoOut"); + AddOutput("ADSlotSessionOut", + "ADSlotSessionOut"); + + AddAttr("batch_count", "(int, default 1)"); + AddAttr("max_length", "(int, default 1)"); + AddAttr("slot_num", "(int, default 1)"); + AddAttr("fea_emb_dim", "(int, default 1)"); + AddAttr("ad_slot_num", "(int, default 1)"); + AddAttr("ad_slot_offset", "(int, default 1)"); + + AddComment(R"DOC( +Fuse seq tensor. 
+ +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(fused_seq_tensor, + ops::FusedSeqTensorOp, ops::FusedSeqTensorOpMaker); + +REGISTER_OP_CPU_KERNEL( + fused_seq_tensor, + ops::FusedSeqTensorCPUKernel); diff --git a/paddle/fluid/operators/fused/fused_seq_tensor_op.cu b/paddle/fluid/operators/fused/fused_seq_tensor_op.cu new file mode 100644 index 0000000000000..d2fdf364d731d --- /dev/null +++ b/paddle/fluid/operators/fused/fused_seq_tensor_op.cu @@ -0,0 +1,290 @@ +#include +#include +#include +#include "paddle/fluid/operators/fused/fused_seq_tensor_op.h" // don't remove this +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace paddle { +namespace operators { + +template +__global__ void cal_ad_slot_session_kernel(const T* input, + const T* ad_input, + T* din_output, + T* ad_slot_session_output, + const size_t batch_num, + const size_t ins_num, + const size_t slot_num, + const size_t max_length, + const size_t fea_emb_dim, + const size_t ad_slot_num, + const size_t ad_slot_offset) { + + size_t batch_idx = blockIdx.x; + size_t ins_idx = blockIdx.y; + size_t fea_idx = blockIdx.z; + + const size_t one_slot_dim = max_length * fea_emb_dim; + const size_t one_seq_dim = slot_num * one_slot_dim; + const size_t ad_seq_dim = ad_slot_num * one_slot_dim; + + const size_t piece_of_ad_seq_dim = ad_slot_num * fea_emb_dim; + for (size_t idx = threadIdx.x; idx < piece_of_ad_seq_dim; idx += blockDim.x) { + size_t slot_idx = idx / fea_emb_dim + ad_slot_offset; + size_t out_slot_idx = idx / fea_emb_dim; + size_t fea_dim_idx = idx % fea_emb_dim; + + size_t input_fea_begin_idx = ins_idx * (batch_num * one_seq_dim) + batch_idx * one_seq_dim + + slot_idx * one_slot_dim + fea_idx * fea_emb_dim; + + size_t ad_fea_begin_idx = + ins_idx * (1 * batch_num * piece_of_ad_seq_dim) + batch_idx * piece_of_ad_seq_dim + + out_slot_idx * fea_emb_dim; + + const T input_val = input[input_fea_begin_idx + fea_dim_idx]; + const T ad_val = ad_input[ad_fea_begin_idx + fea_dim_idx]; + + size_t fea_concat_start_idx = + batch_idx * (ins_num * ad_seq_dim * 4) + ins_idx * (ad_seq_dim * 4) + + fea_idx * (piece_of_ad_seq_dim * 4) + out_slot_idx * fea_emb_dim; + + din_output[fea_concat_start_idx + fea_dim_idx] = input_val; + din_output[fea_concat_start_idx + fea_dim_idx + piece_of_ad_seq_dim] = ad_val; + din_output[fea_concat_start_idx + fea_dim_idx + piece_of_ad_seq_dim * 2] = input_val - ad_val; + din_output[fea_concat_start_idx + fea_dim_idx + piece_of_ad_seq_dim * 3] = input_val * ad_val; + + size_t ad_slot_session_out_start_idx = + batch_idx * (ins_num * ad_seq_dim) + ins_idx * ad_seq_dim + + fea_idx * piece_of_ad_seq_dim + out_slot_idx * fea_emb_dim; + ad_slot_session_output[ad_slot_session_out_start_idx + fea_dim_idx] = input_val; + } +} + +template +__global__ void cal_sideinfo_kernel(const T* input, + T* side_info_output, + const size_t batch_num, + const size_t ins_num, + const size_t slot_num, + const size_t max_length, + const size_t fea_emb_dim, + const size_t sideinfo_slot_num, + const size_t sideinfo_slot_offset) { + + size_t batch_idx = blockIdx.x; + size_t ins_idx = blockIdx.y; + size_t fea_idx = blockIdx.z; + + const size_t one_slot_dim = max_length * fea_emb_dim; + const size_t input_one_seq_dim = slot_num * one_slot_dim; + 
const size_t sideinfo_seq_dim = sideinfo_slot_num * one_slot_dim; + + const size_t piece_of_sideinfo_seq_dim = sideinfo_slot_num * fea_emb_dim; + for (size_t idx = threadIdx.x; idx < piece_of_sideinfo_seq_dim; idx += blockDim.x) { + size_t out_slot_idx = idx / fea_emb_dim; + size_t slot_idx = out_slot_idx + sideinfo_slot_offset; + size_t fea_dim_idx = idx % fea_emb_dim; + + size_t input_fea_begin_idx = ins_idx * (batch_num * input_one_seq_dim) + batch_idx * input_one_seq_dim + + slot_idx * one_slot_dim + fea_idx * fea_emb_dim; + + size_t fea_transpose_start_idx = + batch_idx * (ins_num * sideinfo_seq_dim) + ins_idx * sideinfo_seq_dim + + fea_idx * (sideinfo_slot_num * fea_emb_dim) + out_slot_idx * fea_emb_dim; + + side_info_output[fea_transpose_start_idx + fea_dim_idx] = input[input_fea_begin_idx + fea_dim_idx]; + } +} + +template +__global__ void cal_sideinfo_kernel_without_loop(const T* input, + T* side_info_output, + const size_t batch_num, + const size_t ins_num, + const size_t slot_num, + const size_t max_length, + const size_t fea_emb_dim, + const size_t sideinfo_slot_num, + const size_t sideinfo_slot_offset) { + + size_t batch_idx = blockIdx.x; + size_t ins_idx = blockIdx.y; + size_t fea_idx = blockIdx.z; + + size_t slot_idx = threadIdx.y + sideinfo_slot_offset; + size_t out_slot_idx = threadIdx.y; + size_t fea_dim_idx = threadIdx.x; + + const size_t one_slot_dim = max_length * fea_emb_dim; + size_t input_one_seq_dim = slot_num * one_slot_dim; + size_t out_one_seq_dim = sideinfo_slot_num * one_slot_dim; + + size_t input_fea_begin_idx = ins_idx * (batch_num * input_one_seq_dim) + batch_idx * (input_one_seq_dim) + + slot_idx * one_slot_dim + fea_idx * fea_emb_dim; + + size_t fea_transpose_start_idx = + batch_idx * (ins_num * out_one_seq_dim) + ins_idx * out_one_seq_dim + + fea_idx * (sideinfo_slot_num * fea_emb_dim) + out_slot_idx * fea_emb_dim; + + side_info_output[fea_transpose_start_idx + fea_dim_idx] = input[input_fea_begin_idx + fea_dim_idx]; +} + +template +__device__ void warpReduce(volatile T* cache, int tid) { + cache[tid] += cache[tid+32]; + cache[tid] += cache[tid+16]; + cache[tid] += cache[tid+8]; + cache[tid] += cache[tid+4]; + cache[tid] += cache[tid+2]; + cache[tid] += cache[tid+1]; +} + +#define THREAD_PER_BLOCK 128 +template +__global__ void reduce_sum_max_length(const T* input, // 1 + T* mask_output, // mask + const size_t batch_count, + const size_t ins_num, + const size_t slot_num, + const size_t max_length, + const size_t fea_emb_dim) { + size_t batch_idx = blockIdx.x; + size_t ins_idx = blockIdx.y; + size_t fea_idx = blockIdx.z; + + size_t data_len_per_block = slot_num * fea_emb_dim; + + __shared__ T sdata[THREAD_PER_BLOCK]; + //each thread loads one element from global memory to shared mem + size_t input_start_idx = ins_idx * (batch_count * slot_num * max_length * fea_emb_dim) + + batch_idx * (slot_num * max_length * fea_emb_dim); + + size_t tid = threadIdx.x; + // memset shared mem + sdata[tid] = 0; + for (size_t idx = tid; idx < data_len_per_block; idx += blockDim.x) { + size_t slot_idx = idx / fea_emb_dim; + size_t fea_dim_idx = idx % fea_emb_dim; + size_t offset = slot_idx * (max_length * fea_emb_dim) + fea_idx * fea_emb_dim + fea_dim_idx; + sdata[tid] += input[input_start_idx + offset]; + } + __syncthreads(); + + for(size_t s = blockDim.x / 2; s > 32; s >>= 1) { + if (tid < s) { + sdata[tid] += sdata[tid + s]; + } + __syncthreads(); + } + // When s < 32, we have only one warp left, no need to sync threads, no need to if (tid < s) + if(tid < 32) { + 
warpReduce(sdata, tid);
+  }
+
+  if(tid == 0) {
+    // [batch_count, ins_num, max_length]
+    size_t out_idx = batch_idx * (ins_num * max_length) +
+                     ins_idx * (max_length) +
+                     fea_idx;
+    if (fabs(sdata[tid]) > 1e-8) {
+      mask_output[out_idx] = 1;
+    } else {
+      mask_output[out_idx] = 0;
+    }
+  }
+}
+
+template <typename T>
+class FusedSeqTensorCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto input = ctx.Input<framework::LoDTensor>("Input");
+    PADDLE_ENFORCE_NOT_NULL(input, platform::errors::NotFound("Input not found"));
+    auto ad_input = ctx.Input<framework::LoDTensor>("ADInput");
+    PADDLE_ENFORCE_NOT_NULL(ad_input, platform::errors::NotFound("ADInput not found"));
+
+    auto din_output = ctx.Output<framework::LoDTensor>("DINOut");
+    PADDLE_ENFORCE_NOT_NULL(din_output,
+                            platform::errors::NotFound("DINOut not found"));
+    T* din_output_data = din_output->mutable_data<T>(ctx.GetPlace());
+    auto mask_output = ctx.Output<framework::LoDTensor>("MaskOut");
+    PADDLE_ENFORCE_NOT_NULL(mask_output,
+                            platform::errors::NotFound("MaskOut not found"));
+    T* mask_output_data = mask_output->mutable_data<T>(ctx.GetPlace());
+    auto side_info_output = ctx.Output<framework::LoDTensor>("SideInfoOut");
+    PADDLE_ENFORCE_NOT_NULL(side_info_output,
+                            platform::errors::NotFound("SideInfoOut not found"));
+    T* side_info_output_data =
+        side_info_output->mutable_data<T>(ctx.GetPlace());
+    auto ad_slot_session_output =
+        ctx.Output<framework::LoDTensor>("ADSlotSessionOut");
+    PADDLE_ENFORCE_NOT_NULL(ad_slot_session_output,
+                            platform::errors::NotFound("ADSlotSessionOut not found"));
+    T* ad_slot_session_output_data =
+        ad_slot_session_output->mutable_data<T>(ctx.GetPlace());
+
+    auto batch_count = ctx.Attr<int>("batch_count");
+    auto max_length = ctx.Attr<int>("max_length");
+    auto slot_num = ctx.Attr<int>("slot_num");
+    auto fea_emb_dim = ctx.Attr<int>("fea_emb_dim");
+    auto ad_slot_num = ctx.Attr<int>("ad_slot_num");
+    auto ad_slot_offset = ctx.Attr<int>("ad_slot_offset");
+
+    auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
+    auto stream = ctx.cuda_device_context().stream();
+
+    auto input_dims = input->dims();
+    size_t ins_num = input_dims[0];
+
+    dim3 ad_grid(batch_count, ins_num, max_length);
+    dim3 ad_block(std::min(static_cast<size_t>(1024), static_cast<size_t>(ad_slot_num * fea_emb_dim)));
+
+    cal_ad_slot_session_kernel<<<ad_grid, ad_block, 0, stream>>>(
+        input->data<T>(), ad_input->data<T>(), din_output_data,
+        ad_slot_session_output_data,
+        batch_count, ins_num, slot_num, max_length, fea_emb_dim,
+        ad_slot_num, ad_slot_offset);
+
+    size_t sideinfo_slot_offset = 0;
+    if (ad_slot_offset == 0) {
+      sideinfo_slot_offset = ad_slot_num;
+    }
+    size_t fea_padding_dim = ((fea_emb_dim + 31) / 32) * 32;  // currently unused
+    size_t sideinfo_slot_num = slot_num - ad_slot_num;
+
+    if (sideinfo_slot_num * fea_emb_dim < 1024) {
+      dim3 sideinfo_grid(batch_count, ins_num, max_length);
+      dim3 sideinfo_block(fea_emb_dim, sideinfo_slot_num);
+      cal_sideinfo_kernel_without_loop<<<sideinfo_grid, sideinfo_block, 0, stream>>>(
+          input->data<T>(), side_info_output_data, batch_count, ins_num,
+          slot_num, max_length, fea_emb_dim,
+          sideinfo_slot_num, sideinfo_slot_offset);
+    } else {
+      dim3 sideinfo_grid(batch_count, ins_num, max_length);
+      // cap the block at 1024 threads; the looping kernel strides over the rest
+      dim3 sideinfo_block(std::min(static_cast<size_t>(1024), static_cast<size_t>(sideinfo_slot_num * fea_emb_dim)));
+      cal_sideinfo_kernel<<<sideinfo_grid, sideinfo_block, 0, stream>>>(
+          input->data<T>(), side_info_output_data, batch_count, ins_num,
+          slot_num, max_length, fea_emb_dim,
+          sideinfo_slot_num, sideinfo_slot_offset);
+    }
+
+    dim3 reduce_grid(batch_count, ins_num, max_length);
+    dim3 reduce_block(THREAD_PER_BLOCK);
+    reduce_sum_max_length<<<reduce_grid, reduce_block, 0, stream>>>(
+        input->data<T>(), mask_output_data, batch_count,
+        ins_num, slot_num, max_length, fea_emb_dim);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(
+    fused_seq_tensor,
+    ops::FusedSeqTensorCUDAKernel<float>);
\ No newline at end of file
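For reference, MaskOut marks which timesteps of a sequence are non-empty: a
step is kept when the absolute value of the summed slot features at that step
exceeds 1e-8. A minimal NumPy sketch of the same reduction (shapes and names
are illustrative, not part of the op's API):

    import numpy as np

    def mask_out(inp, batch_count, slot_num, max_length, fea_emb_dim):
        # inp: [ins_num, batch_count * slot_num * max_length * fea_emb_dim]
        ins_num = inp.shape[0]
        x = inp.reshape(ins_num, batch_count, slot_num, max_length, fea_emb_dim)
        # sum over all slots and embedding dims at each timestep,
        # then threshold the absolute value of the sum (as the kernel does)
        s = x.sum(axis=(2, 4))                 # [ins_num, batch_count, max_length]
        m = (np.abs(s) > 1e-8).astype(np.float32)
        return m.transpose(1, 0, 2)            # [batch_count, ins_num, max_length]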
diff --git a/paddle/fluid/operators/fused/fused_seq_tensor_op.h b/paddle/fluid/operators/fused/fused_seq_tensor_op.h
new file mode 100644
index 0000000000000..d7bbadd72e3b5
--- /dev/null
+++ b/paddle/fluid/operators/fused/fused_seq_tensor_op.h
@@ -0,0 +1,16 @@
+#pragma once
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class FusedSeqTensorCPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext&) const override {
+    PADDLE_THROW(platform::errors::Unimplemented("fused_seq_tensor supports only GPU"));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py
index 6d945cc9369fd..5143303a286b1 100644
--- a/python/paddle/fluid/contrib/layers/nn.py
+++ b/python/paddle/fluid/contrib/layers/nn.py
@@ -71,6 +71,7 @@
     'fused_seqpool_concat',
     'fused_concat',
     'rank_attention2',
+    'fused_seq_tensor',
 ]


@@ -1601,7 +1602,7 @@ def rank_attention2(input,
     return output


-def batch_fc(input, param_size, param_attr, bias_size, bias_attr, act=None, batchcount=0):
+def batch_fc(input, param_size, param_attr, bias_size, bias_attr, act=None, batchcount=0, transpose_weight=False):
     """
     **Batch FC layer**
     This Op can calculate BatchFC. This is similar to matmul op,
@@ -1666,7 +1667,10 @@ def batch_fc(input, param_size, param_attr, bias_size, bias_attr, act=None, batc
             "W": w,
             "Bias": b
         },
-        attrs={'batchcount': batchcount},
+        attrs={
+            'batchcount': batchcount,
+            'transpose_weight': transpose_weight
+        },
         outputs={"Out": pre_act})
     return helper.append_activation(pre_act)

@@ -2827,3 +2831,66 @@ def fused_concat(input, start_index=0, length=-1, axis=1):
                "length": length})
     return out

+def fused_seq_tensor(input,
+                     batch_count,
+                     max_length,
+                     slot_num,
+                     ad_slot_num,
+                     fea_emb_dim,
+                     ad_slot_offset):
+    """
+    **fused seq tensor**
+    Notice: It currently only supports GPU device.
+
+    Args:
+        input: [input, ad_input], input tensor list with data type float32.
+        batch_count: parallel num.
+        max_length: maximum length of each behavior sequence.
+        slot_num: total slot num (ad slots plus side-info slots).
+        ad_slot_num: ad slot num.
+        fea_emb_dim: embedding dim.
+        ad_slot_offset: offset of the first ad slot in the slot layout.
+
+    Returns:
+        Variable:
+        din_out, mask_out, side_info_out, ad_slot_session_out
+    """
+
+    helper = LayerHelper("fused_seq_tensor", **locals())
+
+    check_type(input, "input", list, 'fused_seq_tensor')
+
+    dtype = helper.input_dtype()
+    check_dtype(dtype, 'input', ['float32', 'float64'], 'fused_seq_tensor')
+
+    check_type(batch_count, 'batch_count', (int, Variable), 'fused_seq_tensor')
+    check_type(max_length, 'max_length', (int, Variable), 'fused_seq_tensor')
+    check_type(slot_num, 'slot_num', (int, Variable), 'fused_seq_tensor')
+    check_type(fea_emb_dim, 'fea_emb_dim', (int, Variable), 'fused_seq_tensor')
+
+    din_out = helper.create_variable_for_type_inference(dtype=dtype)
+    mask_out = helper.create_variable_for_type_inference(dtype=dtype)
+    side_info_out = helper.create_variable_for_type_inference(dtype=dtype)
+    ad_slot_session_out = helper.create_variable_for_type_inference(dtype=dtype)
+
+    helper.append_op(
+        type="fused_seq_tensor",
+        inputs={"Input": input[0],
+                "ADInput": input[1]
+        },
+        attrs={
+            'batch_count': batch_count,
+            'max_length': max_length,
+            'slot_num': slot_num,
+            'fea_emb_dim': fea_emb_dim,
+            'ad_slot_num': ad_slot_num,
+            'ad_slot_offset': ad_slot_offset
+        },
+        outputs={
+            "DINOut": din_out,
+            "MaskOut": mask_out,
+            "SideInfoOut": side_info_out,
+            "ADSlotSessionOut": ad_slot_session_out
+        })
+
+    return din_out, mask_out, side_info_out, ad_slot_session_out
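A minimal usage sketch for the new Python wrapper, assuming the old fluid
static-graph API; variable names and shapes are illustrative, and the
flattened layout per instance must be batch_count * max_length * slot_num *
fea_emb_dim (and batch_count * 1 * ad_slot_num * fea_emb_dim for the ad side):

    import paddle.fluid as fluid
    from paddle.fluid.contrib.layers import fused_seq_tensor

    batch_count, max_length, slot_num, ad_slot_num, fea_emb_dim = 1, 50, 10, 4, 8
    seq = fluid.data(name="seq", dtype="float32",
                     shape=[-1, batch_count * max_length * slot_num * fea_emb_dim])
    ad = fluid.data(name="ad", dtype="float32",
                    shape=[-1, batch_count * 1 * ad_slot_num * fea_emb_dim])

    din_out, mask_out, side_info_out, ad_slot_session_out = fused_seq_tensor(
        [seq, ad], batch_count, max_length, slot_num, ad_slot_num,
        fea_emb_dim, ad_slot_offset=0)
    # din_out:  [ins_num, max_length, ad_slot_num * fea_emb_dim * 4]
    # mask_out: [ins_num, max_length]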
From 7acd059a1b1c45973926bfbcdf575516a930c0d5 Mon Sep 17 00:00:00 2001
From: shangzhongbin
Date: Wed, 24 Jan 2024 10:53:59 +0800
Subject: [PATCH 4/4] fused_seq_tensor: fix offset comparison and tidy kernel comments

---
 paddle/fluid/operators/fused/fused_seq_tensor_op.cc | 4 ++--
 paddle/fluid/operators/fused/fused_seq_tensor_op.cu | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/operators/fused/fused_seq_tensor_op.cc b/paddle/fluid/operators/fused/fused_seq_tensor_op.cc
index 5ca2ec345f10e..7430d0d32ca37 100644
--- a/paddle/fluid/operators/fused/fused_seq_tensor_op.cc
+++ b/paddle/fluid/operators/fused/fused_seq_tensor_op.cc
@@ -47,10 +47,10 @@ class FusedSeqTensorOp : public framework::OperatorWithKernel {
         ad_slot_num, 0,
         platform::errors::InvalidArgument(
             "ad_slot_num [%ld] <= 0", ad_slot_num));
-    PADDLE_ENFORCE_LT(
+    PADDLE_ENFORCE_LE(
         ad_slot_offset, slot_num - 1,
         platform::errors::InvalidArgument(
-            "ad_slot_num [%ld] > slot_num - 1 [%ld]", ad_slot_offset, slot_num));
+            "ad_slot_offset [%ld] > slot_num - 1 [%ld]", ad_slot_offset, slot_num));
     PADDLE_ENFORCE_GE(
         ad_slot_offset, 0,
         platform::errors::InvalidArgument(
diff --git a/paddle/fluid/operators/fused/fused_seq_tensor_op.cu b/paddle/fluid/operators/fused/fused_seq_tensor_op.cu
index d2fdf364d731d..8210cd43808c3 100644
--- a/paddle/fluid/operators/fused/fused_seq_tensor_op.cu
+++ b/paddle/fluid/operators/fused/fused_seq_tensor_op.cu
@@ -145,8 +145,8 @@ __device__ void warpReduce(volatile T* cache, int tid) {

 #define THREAD_PER_BLOCK 128
 template <typename T>
-__global__ void reduce_sum_max_length(const T* input, // 1
-                                      T* mask_output, // mask
+__global__ void reduce_sum_max_length(const T* input,
+                                      T* mask_output,
                                       const size_t batch_count,
                                       const size_t ins_num,
                                       const size_t slot_num,
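Across patches 1 and 2 the filtering rule is the same linear show/click
score: a feasign is skipped (patch 1) or zero-filled (patch 2 with
fill_zero=True) when the score falls below threshold. The predicate as a
standalone sketch, mirroring the CUDA kernels:

    def keep_feasign(show, click, show_coeff=0.2, clk_coeff=1.0, threshold=0.96):
        # mirrors: (show - click) * show_coeff + click * clk_coeff < threshold
        return (show - click) * show_coeff + click * clk_coeff >= threshold

The design difference between the two modes: with fill_zero the filtered
position still occupies its slot in the concatenated output (written as the
pad value), keeping the layout stable, whereas the plain continue path drops
it entirely.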