Merge pull request #75 from shang1017/query_seq_slots
add fused seq tensor && support transpose batch fc weight
qingshui committed Jan 8, 2024
2 parents 1ee052d + 08fbd50 commit 61fa982
Showing 6 changed files with 648 additions and 2 deletions.
56 changes: 56 additions & 0 deletions paddle/fluid/operators/batch_fc_op.cc
@@ -44,6 +44,61 @@ class BatchFCOp : public framework::OperatorWithKernel {
auto w_dims = ctx->GetInputDim("W");

int batchcount = ctx->Attrs().Get<int>("batchcount");
bool transpose_weight = ctx->Attrs().Get<bool>("transpose_weight");

if (transpose_weight) {
// Input_dim: [batch_count, ?, in_dim]
// W_dim: [in_dim, batch_count * out_dim]
// Bias_dim: [1, batch_count * out_dim]
// Out_dim: [batch_count, ?, out_dim]
PADDLE_ENFORCE_GT(
batchcount,
0,
platform::errors::PreconditionNotMet(
"with transpose weight, batchcount should > 0"));
PADDLE_ENFORCE_EQ(
w_dims.size(),
2,
platform::errors::InvalidArgument(
"W of BatchFCOp should have 2D."));

int out_dim = w_dims[1] / batchcount;
PADDLE_ENFORCE_EQ(
input_dims.size(),
3,
platform::errors::InvalidArgument(
"Input of BatchFCOp should have 3D."));
PADDLE_ENFORCE_EQ(
input_dims[2],
w_dims[0],
platform::errors::InvalidArgument(
"Input.dim[2] and W.dim[0] of BatchFCOp should be same."));
PADDLE_ENFORCE_EQ(
input_dims[0],
batchcount,
platform::errors::InvalidArgument(
"Input.dim[0] and batchcount of BatchFCOp should be same."));

auto bias_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(
bias_dims.size(),
2,
platform::errors::InvalidArgument("Bias of BatchFCOp should be a 2-D tensor."));
PADDLE_ENFORCE_EQ(
bias_dims[1],
w_dims[1],
platform::errors::InvalidArgument(
"Bias.dim[1] and W.dim[1] of BatchFCOp should be same."));

ctx->SetOutputDim("Out", {input_dims[0], input_dims[1], out_dim});
ctx->ShareLoD("Input", /*->*/ "Out");
return;
}
if (batchcount > 0) {
int feature_dim = input_dims[1] / batchcount;
PADDLE_ENFORCE_EQ(feature_dim, w_dims[0],
@@ -139,6 +194,7 @@ class BatchFCOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Bias", "(Tensor) Input tensor of batch_fc_op operator.");
AddOutput("Out", "Output tensor of batch_fc_op operator.");
AddAttr<int>("batchcount", "(int64_t) the batchcount").SetDefault(0);
AddAttr<bool>("transpose_weight", "(bool) whether to transpose W from [in_dim, batchcount * out_dim] into batchcount [in_dim, out_dim] slices").SetDefault(false);
AddComment(R"DOC(
BatchFC Operator.
Notice: It currently supports GPU device.
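For reference, the shape contract this new transpose_weight path enforces can be summarized in a small host-side sketch (illustrative only, not part of the commit; the helper name BatchFCOutShape is hypothetical):

#include <array>
#include <cassert>
#include <cstdint>

// Mirrors BatchFCOp::InferShape for the transpose_weight path:
// Input [batchcount, ins_num, in_dim], W [in_dim, batchcount * out_dim]
// -> Out [batchcount, ins_num, out_dim].
std::array<int64_t, 3> BatchFCOutShape(const std::array<int64_t, 3>& input_dims,
                                       const std::array<int64_t, 2>& w_dims,
                                       int64_t batchcount) {
  assert(batchcount > 0);
  assert(input_dims[0] == batchcount);  // one Input slice per batch
  assert(input_dims[2] == w_dims[0]);   // shared in_dim
  assert(w_dims[1] % batchcount == 0);  // W columns split evenly per batch
  const int64_t out_dim = w_dims[1] / batchcount;
  return {input_dims[0], input_dims[1], out_dim};
}

// e.g. BatchFCOutShape({4, 100, 8}, {8, 64}, 4) yields {4, 100, 16}.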
85 changes: 85 additions & 0 deletions paddle/fluid/operators/batch_fc_op.cu
@@ -171,11 +171,96 @@ void transpose_split_row(cudaStream_t stream, const unsigned int rown,
stream>>>(rown, coln, num_block, source, dest);
}

template <typename T>
__global__ void transpose_weight_kernel(const T* source, T* dest,
const unsigned int rown, const unsigned int coln, const int64_t batch_count) {
// Reshuffles W from [rown, coln] = [in_dim, batch_count * out_dim] into
// batch_count contiguous [in_dim, out_dim] slices in dest.
int x = blockIdx.x * blockDim.x + threadIdx.x;
int y = blockIdx.y * blockDim.y + threadIdx.y;
if (x < rown && y < coln) {
int dst_coln = coln / batch_count; // out_dim: columns per slice
int dst_x = x + y / dst_coln * rown; // slice index * rown + source row
int dst_y = y % dst_coln; // column within the slice
dest[dst_x * dst_coln + dst_y] = source[x * coln + y];
}
}

template <typename T>
void transpose_weight_impl(cudaStream_t stream, const T* source, T* dest,
const unsigned int rown, const unsigned int coln, const int64_t batch_count) {
dim3 grid((rown + 15) / 16, (coln + 15) / 16);
dim3 block(16, 16);
transpose_weight_kernel<<<grid, block, 0, stream>>>(source, dest, rown, coln, batch_count);
}
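// For clarity, an equivalent CPU reference of the reshuffle performed by
// transpose_weight_kernel (an illustrative sketch, not used by the op):
// source is row-major [rown, coln] with coln == batch_count * out_dim; dest
// holds batch_count contiguous [rown, out_dim] weight slices.
template <typename T>
void transpose_weight_cpu_reference(const T* source, T* dest,
    const unsigned int rown, const unsigned int coln,
    const int64_t batch_count) {
  const int64_t out_dim = coln / batch_count;
  for (unsigned int x = 0; x < rown; ++x) {
    for (unsigned int y = 0; y < coln; ++y) {
      const int64_t slice = y / out_dim;       // which batch slice y feeds
      const int64_t dst_x = slice * rown + x;  // row in the stacked dest
      const int64_t dst_y = y % out_dim;       // column within the slice
      dest[dst_x * out_dim + dst_y] = source[x * coln + y];
    }
  }
}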

template <typename DeviceContext, typename T>
class BatchFCCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
int batchcount = ctx.Attr<int>("batchcount");
auto transpose_weight = ctx.Attr<bool>("transpose_weight");
if (transpose_weight) {
// Input_dim: [batch_count, ?, in_dim]
// W_dim: [in_dim, batch_count * out_dim]
// Bias_dim: [1, batch_count * out_dim]
// Out_dim: [batch_count, ?, out_dim]
auto* input = ctx.Input<framework::LoDTensor>("Input");
auto* w = ctx.Input<Tensor>("W");
auto* bias = ctx.Input<Tensor>("Bias");
auto* output = ctx.Output<framework::LoDTensor>("Out");
auto input_dims = input->dims();
auto w_dims = w->dims();
auto slot_pairs_num = input_dims[0];
auto ins_num = input_dims[1];
auto in_dim = input_dims[2];
auto out_dim = w_dims[1] / batchcount;

// get data ptr
const T* in_data = input->data<T>();
const T* w_data = w->data<T>();
const T* bias_data = bias->data<T>();

output->Resize({slot_pairs_num, ins_num, out_dim});
T* out_data = output->mutable_data<T>(ctx.GetPlace());

auto& dev_ctx = ctx.template device_context<phi::GPUContext>();

Tensor w_help = ctx.AllocateTmpTensor<T, DeviceContext>(
{batchcount, w_dims[0], out_dim}, dev_ctx);
T* w_help_data = w_help.data<T>();

transpose_weight_impl<T>(ctx.cuda_device_context().stream(), w_data, w_help_data, w_dims[0], w_dims[1], batchcount);

CBLAS_TRANSPOSE transA = CblasNoTrans;
CBLAS_TRANSPOSE transB = CblasNoTrans;

T alpha = 1;
T beta = 0;
int64_t strideA = ins_num * in_dim;
int64_t strideB = in_dim * out_dim;

auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx);
blas.BatchedGEMM(transA,
transB,
ins_num,
out_dim,
in_dim,
alpha,
in_data,
w_help_data,
beta,
out_data,
slot_pairs_num,
strideA,
strideB);
add_bias<T>(ctx.cuda_device_context().stream(),
out_data,
slot_pairs_num,
ins_num,
out_dim,
bias_data);
return;
}
if (batchcount > 0) {
auto* input = ctx.Input<framework::LoDTensor>("Input");
auto* w = ctx.Input<Tensor>("W");
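The BatchedGEMM call above multiplies each of the slot_pairs_num Input slices by its corresponding transposed weight slice, with strideA and strideB advancing the A and B pointers between slices. A naive loop equivalent (an illustrative sketch with a hypothetical naive_gemm helper, not an actual Paddle API):

// C[M x N] = A[M x K] * B[K x N], row-major (alpha = 1, beta = 0).
void naive_gemm(const float* A, const float* B, float* C,
                int64_t M, int64_t N, int64_t K) {
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int64_t k = 0; k < K; ++k) acc += A[m * K + k] * B[k * N + n];
      C[m * N + n] = acc;
    }
  }
}

// Reference expansion of the strided batched GEMM in the kernel above;
// add_bias then broadcasts Bias[1, batchcount * out_dim] across ins_num rows.
void batch_fc_reference(const float* in, const float* w_help, float* out,
                        int64_t slot_pairs_num, int64_t ins_num,
                        int64_t in_dim, int64_t out_dim) {
  const int64_t strideA = ins_num * in_dim;   // one Input slice
  const int64_t strideB = in_dim * out_dim;   // one w_help slice
  const int64_t strideC = ins_num * out_dim;  // one Out slice
  for (int64_t b = 0; b < slot_pairs_num; ++b) {
    naive_gemm(in + b * strideA, w_help + b * strideB, out + b * strideC,
               ins_num, out_dim, in_dim);
  }
}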
132 changes: 132 additions & 0 deletions paddle/fluid/operators/fused/fused_seq_tensor_op.cc
@@ -0,0 +1,132 @@
#include "paddle/fluid/operators/fused/fused_seq_tensor_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include <string>

namespace paddle {
namespace operators {

class FusedSeqTensorOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "FusedSeqTensorOp");
OP_INOUT_CHECK(ctx->HasInput("ADInput"), "ADInput", "ADInput", "FusedSeqTensorOp");

OP_INOUT_CHECK(ctx->HasOutput("DINOut"), "DINOut", "DINOut", "FusedSeqTensorOp");
OP_INOUT_CHECK(ctx->HasOutput("MaskOut"), "MaskOut", "MaskOut", "FusedSeqTensorOp");
OP_INOUT_CHECK(ctx->HasOutput("SideInfoOut"), "SideInfoOut", "SideInfoOut", "FusedSeqTensorOp");
OP_INOUT_CHECK(ctx->HasOutput("ADSlotSessionOut"), "ADSlotSessionOut", "ADSlotSessionOut", "FusedSeqTensorOp");

const framework::DDim input_dims = ctx->GetInputDim("Input");
const framework::DDim ad_input_dims = ctx->GetInputDim("ADInput");

auto ad_slot_num = ctx->Attrs().Get<int64_t>("ad_slot_num");
auto batch_count = ctx->Attrs().Get<int64_t>("batch_count");
auto max_length = ctx->Attrs().Get<int64_t>("max_length");
auto slot_num = ctx->Attrs().Get<int64_t>("slot_num");
auto fea_emb_dim = ctx->Attrs().Get<int64_t>("fea_emb_dim");
auto ad_slot_offset = ctx->Attrs().Get<int64_t>("ad_slot_offset");

int64_t one_ins_dim = batch_count * max_length * slot_num * fea_emb_dim;
PADDLE_ENFORCE_EQ(
input_dims[1], one_ins_dim,
platform::errors::InvalidArgument(
"input dims error, %ld != %ld", input_dims[1], one_ins_dim));

int64_t one_ins_ad_dim = batch_count * 1 * ad_slot_num * fea_emb_dim;
PADDLE_ENFORCE_EQ(
ad_input_dims[1], one_ins_ad_dim,
platform::errors::InvalidArgument(
"input dims error, %ld != %ld", ad_input_dims[1], one_ins_ad_dim));
PADDLE_ENFORCE_LT(
ad_slot_num, slot_num,
platform::errors::InvalidArgument(
"ad_slot_num [%ld] > slot_num [%ld]", ad_slot_num, slot_num));
PADDLE_ENFORCE_GT(
ad_slot_num, 0,
platform::errors::InvalidArgument(
"ad_slot_num [%ld] <= 0", ad_slot_num));
PADDLE_ENFORCE_LT(
ad_slot_offset, slot_num - 1,
platform::errors::InvalidArgument(
"ad_slot_num [%ld] > slot_num - 1 [%ld]", ad_slot_offset, slot_num));
PADDLE_ENFORCE_GE(
ad_slot_offset, 0,
platform::errors::InvalidArgument(
"ad_slot_offset [%ld] < 0", ad_slot_offset));
if (ad_slot_offset != 0) {
PADDLE_ENFORCE_EQ(
ad_slot_num + ad_slot_offset, slot_num,
platform::errors::InvalidArgument(
"ad_slot_num [%ld] + ad_slot_offset [%ld] != slot_num [%ld]", ad_slot_num, ad_slot_offset, slot_num));
}

auto ins_num = input_dims[0];
if (batch_count > 1) {
ctx->SetOutputDim("DINOut", {batch_count, ins_num * max_length, ad_slot_num * fea_emb_dim * 4});
ctx->SetOutputDim("MaskOut", {batch_count, ins_num, max_length});
ctx->SetOutputDim("SideInfoOut", {batch_count, ins_num * max_length, (slot_num - ad_slot_num) * fea_emb_dim});
ctx->SetOutputDim("ADSlotSessionOut", {batch_count, ins_num * max_length, ad_slot_num, fea_emb_dim});
} else {
ctx->SetOutputDim("DINOut", {ins_num, max_length, ad_slot_num * fea_emb_dim * 4});
ctx->SetOutputDim("MaskOut", {ins_num, max_length});
ctx->SetOutputDim("SideInfoOut", {ins_num, max_length, (slot_num - ad_slot_num) * fea_emb_dim});
ctx->SetOutputDim("ADSlotSessionOut", {ins_num, max_length, ad_slot_num * fea_emb_dim});
}
ctx->ShareLoD("Input", "DINOut");
ctx->ShareLoD("Input", "MaskOut");
ctx->ShareLoD("Input", "SideInfoOut");
ctx->ShareLoD("Input", "ADSlotSessionOut");
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
};

class FusedSeqTensorOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input",
"The input tensors of operator.");
AddInput("ADInput",
"The input ad tensors of operator. ");
AddOutput("DINOut",
"DINOut");
AddOutput("MaskOut",
"MaskOut");
AddOutput("SideInfoOut",
"SideInfoOut");
AddOutput("ADSlotSessionOut",
"ADSlotSessionOut");

AddAttr<int64_t>("batch_count", "(int, default 1)");
AddAttr<int64_t>("max_length", "(int, default 1)");
AddAttr<int64_t>("slot_num", "(int, default 1)");
AddAttr<int64_t>("fea_emb_dim", "(int, default 1)");
AddAttr<int64_t>("ad_slot_num", "(int, default 1)");
AddAttr<int64_t>("ad_slot_offset", "(int, default 1)");

AddComment(R"DOC(
Fused sequence tensor operator.
)DOC");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(fused_seq_tensor,
ops::FusedSeqTensorOp, ops::FusedSeqTensorOpMaker);

REGISTER_OP_CPU_KERNEL(
fused_seq_tensor,
ops::FusedSeqTensorCPUKernel<phi::CPUContext, float>);
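
To make the output-shape arithmetic in FusedSeqTensorOp::InferShape concrete, a worked example with illustrative attribute values (not taken from the diff): with batch_count = 2, max_length = 50, slot_num = 10, ad_slot_num = 4, fea_emb_dim = 8, and ins_num = 512, Input.dim[1] must equal 2 * 50 * 10 * 8 = 8000 and ADInput.dim[1] must equal 2 * 1 * 4 * 8 = 64. The batch_count > 1 branch then yields DINOut [2, 25600, 128] (since 512 * 50 = 25600 and 4 * 8 * 4 = 128), MaskOut [2, 512, 50], SideInfoOut [2, 25600, 48] (since (10 - 4) * 8 = 48), and the 4-D ADSlotSessionOut [2, 25600, 4, 8].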