【fix-bug】Support attn_mask=None input cases for fused_attention_op. (PaddlePaddle#36951)

The current fused_attention_op does not support attn_mask=None as an input. This PR adds support for that case and adds the corresponding unit-test logic.
limin2021 committed Nov 8, 2021
1 parent b7e8830 commit 472dcca
Showing 5 changed files with 178 additions and 74 deletions.
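
For context, here is a minimal usage sketch of the behavior this commit enables from the Python side. The layer name paddle.incubate.nn.FusedMultiHeadAttention, its constructor arguments, and the GPU requirement are assumptions not taken from this diff; verify them against your Paddle version.

# Minimal sketch (assumed API: paddle.incubate.nn.FusedMultiHeadAttention; requires a CUDA device).
import paddle
from paddle.incubate.nn import FusedMultiHeadAttention

paddle.set_device("gpu")

batch_size, seq_len, num_heads, head_dim = 2, 128, 8, 32
embed_dim = num_heads * head_dim
x = paddle.rand([batch_size, seq_len, embed_dim], dtype="float32")
mha = FusedMultiHeadAttention(embed_dim, num_heads)

# An additive causal mask, built the same way as in the unit test below.
mask = (paddle.tril(paddle.ones([batch_size, num_heads, seq_len, seq_len])) - 1.0) * 1e9
out_masked = mha(x, attn_mask=mask)

# New in this commit: attn_mask=None is accepted; the elementwise add is skipped
# and QK^T goes straight into softmax.
out_unmasked = mha(x, attn_mask=None)
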
22 changes: 11 additions & 11 deletions paddle/fluid/operators/fused/fmha_ref.h
@@ -69,7 +69,7 @@ class FMHARef {
~FMHARef() {}

void ComputeForward(const Tensor& qkv_input_tensor,
const Tensor& src_mask_tensor,
const Tensor* src_mask_tensor,
Tensor* transpose_2_out_tensor, Tensor* qk_out_tensor,
Tensor* src_mask_out_tensor, Tensor* softmax_out_tensor,
Tensor* dropout_mask_out_tensor,
@@ -111,17 +111,17 @@ class FMHARef {
blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, q_ptr,
k_ptr, beta, qk_out_data, gemm_batch_size, stride_a,
stride_b);

std::vector<const Tensor*> ins;
std::vector<Tensor*> outs;
ins.emplace_back(qk_out_tensor);
ins.emplace_back(&src_mask_tensor);
outs.emplace_back(src_mask_out_tensor);
int elewise_add_axis = -1;
int softmax_axis = -1;
if (&src_mask_tensor != nullptr) {
if (src_mask_tensor != nullptr) {
std::vector<const Tensor*> ins;
std::vector<Tensor*> outs;
ins.emplace_back(qk_out_tensor);
ins.emplace_back(src_mask_tensor);
outs.emplace_back(src_mask_out_tensor);
int elewise_add_axis = -1;
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx_, ins, &outs, elewise_add_axis, AddFunctor<T>());

SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *src_mask_out_tensor,
softmax_axis, softmax_out_tensor);
} else {
@@ -165,7 +165,7 @@ class FMHARef {
}

void ComputeBackward(
const Tensor& transpose_2_out_tensor, const Tensor& src_mask_tensor,
const Tensor& transpose_2_out_tensor, const Tensor* src_mask_tensor,
const Tensor& softmax_out_tensor, const Tensor& dropout_mask_out_tensor,
const Tensor& dropout_out_tensor, const Tensor& qk_out_tensor,
const Tensor& src_mask_out_tensor, const Tensor& fmha_out_grad_tensor,
@@ -249,7 +249,7 @@ class FMHARef {
softmax_out_grad_tensor);
}

if (&src_mask_tensor != nullptr) {
if (src_mask_tensor != nullptr) {
SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx_, softmax_out_tensor,
*softmax_out_grad_tensor, softmax_axis,
src_mask_out_grad_tensor);
39 changes: 25 additions & 14 deletions paddle/fluid/operators/fused/fused_attention_op.cc
@@ -27,8 +27,6 @@ class FusedAttentionOp : public framework::OperatorWithKernel {

void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
"FusedAttentionOp");
@@ -57,8 +55,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("QKTVOut"), "Output", "QKTVOut",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("SrcMaskOut"), "Output", "SrcMaskOut",
"FusedAttentionOp");

if (ctx->HasInput("SrcMask")) {
OP_INOUT_CHECK(ctx->HasOutput("SrcMaskOut"), "Output", "SrcMaskOut",
"FusedAttentionOp");
}
OP_INOUT_CHECK(ctx->HasOutput("SoftmaxOut"), "Output", "SoftmaxOut",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("AttnDropoutMaskOut"), "Output",
@@ -119,7 +120,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
{y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]});
// [batch, num_head, seq_len, seq_len]
ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
ctx->SetOutputDim("SrcMaskOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});

if (ctx->HasInput("SrcMask")) {
ctx->SetOutputDim("SrcMaskOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
}
// the same as QKOut's shape.
ctx->SetOutputDim("AttnDropoutOut",
{x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
@@ -320,7 +324,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
{
out = transpose(out, perm=[2, 0, 3, 1, 4]);
out = q * k^t;
out = attn_mark + out;
out = attn_mask + out;
out = softmax(out);
out = dropout(out);
out = out * v;
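
As a reading aid for the comment block above, the following NumPy sketch mirrors the documented steps, with the mask made optional exactly as this PR does; dropout, the 1/sqrt(head_dim) scaling, and the QKV/output projections are omitted.

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_core(q, k, v, attn_mask=None):
    # q, k, v: [batch, num_head, seq_len, head_dim]
    out = q @ k.transpose(0, 1, 3, 2)   # out = q * k^t
    if attn_mask is not None:           # attn_mask may now be None
        out = attn_mask + out           # out = attn_mask + out
    out = softmax(out, axis=-1)         # out = softmax(out)
    return out @ v                      # dropout omitted; out = out * v
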
@@ -368,8 +372,6 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
@@ -413,8 +415,11 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
ctx->GetInputDim("SoftmaxOut"));
ctx->SetOutputDim(framework::GradVarName("AttnDropoutOut"),
ctx->GetInputDim("AttnDropoutOut"));
ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"),
ctx->GetInputDim("SrcMaskOut"));

if (ctx->HasOutput(framework::GradVarName("SrcMaskOut"))) {
ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"),
ctx->GetInputDim("SrcMaskOut"));
}
ctx->SetOutputDim(framework::GradVarName("QKVOut"),
ctx->GetInputDim("QKVOut"));
ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"),
@@ -448,7 +453,14 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("X", this->Input("X"));
op->SetInput("QKVW", this->Input("QKVW"));
op->SetInput("QKVBias", this->Input("QKVBias"));
op->SetInput("SrcMask", this->Input("SrcMask"));

if (this->HasInput("SrcMask")) {
op->SetInput("SrcMask", this->Input("SrcMask"));
op->SetInput("SrcMaskOut", this->Output("SrcMaskOut"));
op->SetOutput(framework::GradVarName("SrcMaskOut"),
this->OutputGrad("SrcMaskOut"));
}

op->SetInput("OutLinearW", this->Input("OutLinearW"));
op->SetInput("OutLinearBias", this->Input("OutLinearBias"));

@@ -508,7 +520,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("SoftmaxOut", this->Output("SoftmaxOut"));
op->SetInput("AttnDropoutMaskOut", this->Output("AttnDropoutMaskOut"));
op->SetInput("AttnDropoutOut", this->Output("AttnDropoutOut"));
op->SetInput("SrcMaskOut", this->Output("SrcMaskOut"));

op->SetInput("FMHAOut", this->Output("FMHAOut"));
op->SetInput("OutLinearOut", this->Output("OutLinearOut"));

@@ -538,8 +550,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
this->OutputGrad("SoftmaxOut"));
op->SetOutput(framework::GradVarName("AttnDropoutOut"),
this->OutputGrad("AttnDropoutOut"));
op->SetOutput(framework::GradVarName("SrcMaskOut"),
this->OutputGrad("SrcMaskOut"));

op->SetOutput(framework::GradVarName("FMHAOut"),
this->OutputGrad("FMHAOut"));
op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"),
16 changes: 11 additions & 5 deletions paddle/fluid/operators/fused/fused_attention_op.cu
@@ -114,7 +114,9 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
transpose_out_2->mutable_data<T>(ctx.GetPlace());
auto *qk_out_data = qk_out->mutable_data<T>(ctx.GetPlace());
auto *qktv_out_data = qktv_out->mutable_data<T>(ctx.GetPlace());
auto *src_mask_out_data = src_mask_out->mutable_data<T>(ctx.GetPlace());
auto *src_mask_out_data =
(src_mask == nullptr) ? nullptr
: src_mask_out->mutable_data<T>(ctx.GetPlace());
auto *softmax_out_data = softmax_out->mutable_data<T>(ctx.GetPlace());
auto *attn_dropout_mask_out_data =
attn_dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());
@@ -184,10 +186,11 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
qkv_compute.ComputeForward(qkv_weight_data, x_data, qkv_bias_data,
qkv_out_data, qkv_bias_out_data);
}
fmha_ref_compute.ComputeForward(*qkv_bias_out, *src_mask, transpose_out_2,
fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2,
qk_out, src_mask_out, softmax_out,
attn_dropout_mask_out, attn_dropout_out,
qktv_out, fmha_out);

// fmha_out: [batch_size, seq_len, num_head, head_dim]
// weight: [embed_dim, embed_dim]
// out_linear_out: [batch_size, seq_len, embed_dim]
@@ -265,7 +268,8 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *qk_out_data = qk_out->data<T>();
auto *qktv_out_data = qktv_out->data<T>();
auto *softmax_out_data = softmax_out->data<T>();
auto *src_mask_out_data = src_mask_out->data<T>();
auto *src_mask_out_data =
(src_mask == nullptr) ? nullptr : src_mask_out->data<T>();
auto *out_linear_out_data = out_linear_out->data<T>();
auto *ln_2_mean_data = ln_2_mean->data<U>();
auto *ln_2_var_data = ln_2_var->data<U>();
@@ -302,7 +306,9 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *d_softmax_out_data = d_softmax_out->mutable_data<T>(ctx.GetPlace());
auto *d_attn_dropout_out_data =
d_attn_dropout_out->mutable_data<T>(ctx.GetPlace());
auto *d_src_mask_out_data = d_src_mask_out->mutable_data<T>(ctx.GetPlace());
auto *d_src_mask_out_data =
(src_mask == nullptr) ? nullptr
: d_src_mask_out->mutable_data<T>(ctx.GetPlace());
auto *d_fmha_out_data = d_fmha_out->mutable_data<T>(ctx.GetPlace());
auto *d_out_linear_out_data =
d_out_linear_out->mutable_data<T>(ctx.GetPlace());
@@ -386,7 +392,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
d_out_linear_out_data, d_fmha_out_data,
d_out_linear_weight_data, nullptr);
fmha_ref_compute.ComputeBackward(
*transpose_out_2, *src_mask, *softmax_out, *attn_dropout_mask_out,
*transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
*attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out,
d_transpose_out_2, nullptr, d_qkv_bias_out);
65 changes: 54 additions & 11 deletions python/paddle/fluid/tests/unittests/test_fused_attention_op.py
@@ -66,6 +66,7 @@ def config(self):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = False
self.has_attn_mask = True
self.training = True

self.batch_size = 8
@@ -84,16 +85,20 @@ def generate_input_data(self):
def generate_input_data(self):
self.query = np.random.rand(self.batch_size, self.query_length,
self.embed_dim).astype(self.x_type)
self.attn_mask = np.ones(
(self.batch_size, self.num_heads, self.query_length,
self.key_length),
dtype=self.attn_mask_type)
if self.attn_mask_type == np.int64:
self.attn_mask = np.tril(self.attn_mask)
elif self.attn_mask_type == np.float64:
self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
if self.has_attn_mask:
self.attn_mask = np.ones(
(self.batch_size, self.num_heads, self.query_length,
self.key_length),
dtype=self.attn_mask_type)
if self.attn_mask_type == np.int64:
self.attn_mask = np.tril(self.attn_mask)
elif self.attn_mask_type == np.float64:
self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
else:
raise ValueError(
"'attn_mask_type' should be 'int64' or 'float64'.")
else:
raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.")
self.attn_mask = None
self.key, self.value = self.query, self.query

self.dout = np.random.random((self.batch_size, self.query_length,
@@ -102,7 +107,10 @@ def GetBaselineOut(self):
def GetBaselineOut(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
tensor_query = paddle.to_tensor(self.query, stop_gradient=False)
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
if self.has_attn_mask:
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
else:
attn_mask = None
residual = tensor_query

ln1_out = tensor_query
@@ -187,7 +195,10 @@ def GetFusedAttentionOut(self):
qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))

x = paddle.to_tensor(self.query, stop_gradient=False)
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
if self.has_attn_mask:
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
else:
attn_mask = None
qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False)
qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
epsilon = 1e-05
@@ -218,6 +229,37 @@ def config(self):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = True
self.has_attn_mask = True
self.training = True

self.batch_size = 8
self.query_length = 128
self.head_dim = 64
self.num_heads = 16
self.embed_dim = self.head_dim * self.num_heads

self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = None
self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length

def test_fused_attention_op(self):
final_out_ref, x_grad_ref = self.GetBaselineOut()
final_out, x_grad = self.GetFusedAttentionOut()
np.testing.assert_allclose(
final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1)
np.testing.assert_allclose(
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-1)


class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp):
def config(self):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = True
self.has_attn_mask = False
self.training = True

self.batch_size = 8
@@ -247,6 +289,7 @@ def config(self):
self.x_type = np.float16
self.attn_mask_type = np.float64
self.pre_layer_norm = False
self.has_attn_mask = True
self.training = True

self.batch_size = 8