diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h
index bef0052a00d6b..066e7e15e8831 100644
--- a/paddle/fluid/operators/fused/fmha_ref.h
+++ b/paddle/fluid/operators/fused/fmha_ref.h
@@ -69,7 +69,7 @@ class FMHARef {
   ~FMHARef() {}

   void ComputeForward(const Tensor& qkv_input_tensor,
-                      const Tensor& src_mask_tensor,
+                      const Tensor* src_mask_tensor,
                       Tensor* transpose_2_out_tensor, Tensor* qk_out_tensor,
                       Tensor* src_mask_out_tensor, Tensor* softmax_out_tensor,
                       Tensor* dropout_mask_out_tensor,
@@ -111,17 +111,17 @@ class FMHARef {
     blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, q_ptr,
                      k_ptr, beta, qk_out_data, gemm_batch_size, stride_a,
                      stride_b);
-
-    std::vector<const Tensor*> ins;
-    std::vector<Tensor*> outs;
-    ins.emplace_back(qk_out_tensor);
-    ins.emplace_back(&src_mask_tensor);
-    outs.emplace_back(src_mask_out_tensor);
-    int elewise_add_axis = -1;
     int softmax_axis = -1;
-    if (&src_mask_tensor != nullptr) {
+    if (src_mask_tensor != nullptr) {
+      std::vector<const Tensor*> ins;
+      std::vector<Tensor*> outs;
+      ins.emplace_back(qk_out_tensor);
+      ins.emplace_back(src_mask_tensor);
+      outs.emplace_back(src_mask_out_tensor);
+      int elewise_add_axis = -1;
       LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
           dev_ctx_, ins, &outs, elewise_add_axis, AddFunctor<T>());
+
       SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *src_mask_out_tensor,
                                         softmax_axis, softmax_out_tensor);
     } else {
@@ -165,7 +165,7 @@ class FMHARef {
   }

   void ComputeBackward(
-      const Tensor& transpose_2_out_tensor, const Tensor& src_mask_tensor,
+      const Tensor& transpose_2_out_tensor, const Tensor* src_mask_tensor,
       const Tensor& softmax_out_tensor, const Tensor& dropout_mask_out_tensor,
       const Tensor& dropout_out_tensor, const Tensor& qk_out_tensor,
       const Tensor& src_mask_out_tensor, const Tensor& fmha_out_grad_tensor,
@@ -249,7 +249,7 @@ class FMHARef {
                                          softmax_out_grad_tensor);
     }

-    if (&src_mask_tensor != nullptr) {
+    if (src_mask_tensor != nullptr) {
       SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx_, softmax_out_tensor,
                                          *softmax_out_grad_tensor, softmax_axis,
                                          src_mask_out_grad_tensor);
diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc
index f7c7129c7732b..96e2a0fcad2b2 100644
--- a/paddle/fluid/operators/fused/fused_attention_op.cc
+++ b/paddle/fluid/operators/fused/fused_attention_op.cc
@@ -27,8 +27,6 @@ class FusedAttentionOp : public framework::OperatorWithKernel {

   void InferShape(framework::InferShapeContext *ctx) const override {
     OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionOp");
-    OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask",
-                   "FusedAttentionOp");
     OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionOp");
     OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
                    "FusedAttentionOp");
@@ -57,8 +55,11 @@
                    "FusedAttentionOp");
     OP_INOUT_CHECK(ctx->HasOutput("QKTVOut"), "Output", "QKTVOut",
                    "FusedAttentionOp");
-    OP_INOUT_CHECK(ctx->HasOutput("SrcMaskOut"), "Output", "SrcMaskOut",
-                   "FusedAttentionOp");
+
+    if (ctx->HasInput("SrcMask")) {
+      OP_INOUT_CHECK(ctx->HasOutput("SrcMaskOut"), "Output", "SrcMaskOut",
+                     "FusedAttentionOp");
+    }
     OP_INOUT_CHECK(ctx->HasOutput("SoftmaxOut"), "Output", "SoftmaxOut",
                    "FusedAttentionOp");
     OP_INOUT_CHECK(ctx->HasOutput("AttnDropoutMaskOut"), "Output",
@@ -119,7 +120,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
                       {y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]});
     // [batch, num_head, seq_len, seq_len]
     ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
-    ctx->SetOutputDim("SrcMaskOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
+
+    if (ctx->HasInput("SrcMask")) {
+      ctx->SetOutputDim("SrcMaskOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
+    }
     // the same as QKOut's shape.
     ctx->SetOutputDim("AttnDropoutOut",
                       {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
@@ -320,7 +324,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
     {
         out = transpose(out, perm=[2, 0, 3, 1, 4]);
         out = q * k^t;
-        out = attn_mark + out;
+        out = attn_mask + out;
         out = softmax(out);
         out = dropout(out);
         out = out * v;
@@ -368,8 +372,6 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
                    "FusedAttentionGrad");
     OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
                    "FusedAttentionGrad");
-    OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask",
-                   "FusedAttentionGrad");
     OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW",
                    "FusedAttentionGrad");
     OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
@@ -413,8 +415,11 @@
                       ctx->GetInputDim("SoftmaxOut"));
     ctx->SetOutputDim(framework::GradVarName("AttnDropoutOut"),
                       ctx->GetInputDim("AttnDropoutOut"));
-    ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"),
-                      ctx->GetInputDim("SrcMaskOut"));
+
+    if (ctx->HasOutput(framework::GradVarName("SrcMaskOut"))) {
+      ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"),
+                        ctx->GetInputDim("SrcMaskOut"));
+    }
     ctx->SetOutputDim(framework::GradVarName("QKVOut"),
                       ctx->GetInputDim("QKVOut"));
     ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"),
@@ -448,7 +453,14 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetInput("X", this->Input("X"));
     op->SetInput("QKVW", this->Input("QKVW"));
     op->SetInput("QKVBias", this->Input("QKVBias"));
-    op->SetInput("SrcMask", this->Input("SrcMask"));
+
+    if (this->HasInput("SrcMask")) {
+      op->SetInput("SrcMask", this->Input("SrcMask"));
+      op->SetInput("SrcMaskOut", this->Output("SrcMaskOut"));
+      op->SetOutput(framework::GradVarName("SrcMaskOut"),
+                    this->OutputGrad("SrcMaskOut"));
+    }
+
     op->SetInput("OutLinearW", this->Input("OutLinearW"));
     op->SetInput("OutLinearBias", this->Input("OutLinearBias"));

@@ -508,7 +520,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetInput("SoftmaxOut", this->Output("SoftmaxOut"));
     op->SetInput("AttnDropoutMaskOut", this->Output("AttnDropoutMaskOut"));
     op->SetInput("AttnDropoutOut", this->Output("AttnDropoutOut"));
-    op->SetInput("SrcMaskOut", this->Output("SrcMaskOut"));
+
     op->SetInput("FMHAOut", this->Output("FMHAOut"));
     op->SetInput("OutLinearOut", this->Output("OutLinearOut"));

@@ -538,8 +550,7 @@
                   this->OutputGrad("SoftmaxOut"));
     op->SetOutput(framework::GradVarName("AttnDropoutOut"),
                   this->OutputGrad("AttnDropoutOut"));
-    op->SetOutput(framework::GradVarName("SrcMaskOut"),
-                  this->OutputGrad("SrcMaskOut"));
+
     op->SetOutput(framework::GradVarName("FMHAOut"),
                   this->OutputGrad("FMHAOut"));
     op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"),
diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu
index 01bc49bcf4079..99f08d38b6da1 100644
--- a/paddle/fluid/operators/fused/fused_attention_op.cu
+++ b/paddle/fluid/operators/fused/fused_attention_op.cu
@@ -114,7 +114,9 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
         transpose_out_2->mutable_data<T>(ctx.GetPlace());
     auto *qk_out_data = qk_out->mutable_data<T>(ctx.GetPlace());
     auto *qktv_out_data = qktv_out->mutable_data<T>(ctx.GetPlace());
-    auto *src_mask_out_data = src_mask_out->mutable_data<T>(ctx.GetPlace());
+    auto *src_mask_out_data =
+        (src_mask == nullptr) ? nullptr
+                              : src_mask_out->mutable_data<T>(ctx.GetPlace());
     auto *softmax_out_data = softmax_out->mutable_data<T>(ctx.GetPlace());
     auto *attn_dropout_mask_out_data =
         attn_dropout_mask_out->mutable_data<T>(ctx.GetPlace());
@@ -184,10 +186,11 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
       qkv_compute.ComputeForward(qkv_weight_data, x_data, qkv_bias_data,
                                  qkv_out_data, qkv_bias_out_data);
     }
-    fmha_ref_compute.ComputeForward(*qkv_bias_out, *src_mask, transpose_out_2,
+    fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2,
                                     qk_out, src_mask_out, softmax_out,
                                     attn_dropout_mask_out, attn_dropout_out,
                                     qktv_out, fmha_out);
+
     // fmha_out: [batch_size, seq_len, num_head, head_dim]
     // weight: [embed_dim, embed_dim]
     // out_linear_out: [batch_size, seq_len, embed_dim]
@@ -265,7 +268,8 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
     auto *qk_out_data = qk_out->data<T>();
     auto *qktv_out_data = qktv_out->data<T>();
     auto *softmax_out_data = softmax_out->data<T>();
-    auto *src_mask_out_data = src_mask_out->data<T>();
+    auto *src_mask_out_data =
+        (src_mask == nullptr) ? nullptr : src_mask_out->data<T>();
     auto *out_linear_out_data = out_linear_out->data<T>();
     auto *ln_2_mean_data = ln_2_mean->data<U>();
     auto *ln_2_var_data = ln_2_var->data<U>();
@@ -302,7 +306,9 @@
     auto *d_softmax_out_data = d_softmax_out->mutable_data<T>(ctx.GetPlace());
     auto *d_attn_dropout_out_data =
         d_attn_dropout_out->mutable_data<T>(ctx.GetPlace());
-    auto *d_src_mask_out_data = d_src_mask_out->mutable_data<T>(ctx.GetPlace());
+    auto *d_src_mask_out_data =
+        (src_mask == nullptr) ? nullptr
+                              : d_src_mask_out->mutable_data<T>(ctx.GetPlace());
     auto *d_fmha_out_data = d_fmha_out->mutable_data<T>(ctx.GetPlace());
     auto *d_out_linear_out_data =
         d_out_linear_out->mutable_data<T>(ctx.GetPlace());
@@ -386,7 +392,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
         d_out_linear_out_data, d_fmha_out_data, d_out_linear_weight_data,
         nullptr);
     fmha_ref_compute.ComputeBackward(
-        *transpose_out_2, *src_mask, *softmax_out, *attn_dropout_mask_out,
+        *transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
         *attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
         d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out,
         d_transpose_out_2, nullptr, d_qkv_bias_out);
diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py
index c33e1f53dfdb6..41962d5ada0a1 100644
--- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py
@@ -66,6 +66,7 @@ def config(self):
         self.x_type = np.float32
         self.attn_mask_type = np.float64
         self.pre_layer_norm = False
+        self.has_attn_mask = True
         self.training = True

         self.batch_size = 8
@@ -84,16 +85,20 @@ def config(self):
     def generate_input_data(self):
         self.query = np.random.rand(self.batch_size, self.query_length,
                                     self.embed_dim).astype(self.x_type)
-        self.attn_mask = np.ones(
-            (self.batch_size, self.num_heads, self.query_length,
-             self.key_length),
-            dtype=self.attn_mask_type)
-        if self.attn_mask_type == np.int64:
-            self.attn_mask = np.tril(self.attn_mask)
-        elif self.attn_mask_type == np.float64:
-            self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
+        if self.has_attn_mask:
+            self.attn_mask = np.ones(
+                (self.batch_size, self.num_heads, self.query_length,
+                 self.key_length),
+                dtype=self.attn_mask_type)
+            if self.attn_mask_type == np.int64:
+                self.attn_mask = np.tril(self.attn_mask)
+            elif self.attn_mask_type == np.float64:
+                self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
+            else:
+                raise ValueError(
+                    "'attn_mask_type' should be 'int64' or 'float64'.")
         else:
-            raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.")
+            self.attn_mask = None
         self.key, self.value = self.query, self.query

         self.dout = np.random.random((self.batch_size, self.query_length,
@@ -102,7 +107,10 @@ def generate_input_data(self):
     def GetBaselineOut(self):
         paddle.disable_static(place=paddle.CUDAPlace(0))
         tensor_query = paddle.to_tensor(self.query, stop_gradient=False)
-        attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
+        if self.has_attn_mask:
+            attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
+        else:
+            attn_mask = None
         residual = tensor_query

         ln1_out = tensor_query
@@ -187,7 +195,10 @@ def GetFusedAttentionOut(self):
         qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))

         x = paddle.to_tensor(self.query, stop_gradient=False)
-        attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
+        if self.has_attn_mask:
+            attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
+        else:
+            attn_mask = None
         qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False)
         qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
         epsilon = 1e-05
@@ -218,6 +229,37 @@ def config(self):
         self.x_type = np.float32
         self.attn_mask_type = np.float64
         self.pre_layer_norm = True
+        self.has_attn_mask = True
+        self.training = True
+
+        self.batch_size = 8
+        self.query_length = 128
+        self.head_dim = 64
+        self.num_heads = 16
+        self.embed_dim = self.head_dim * self.num_heads
+
+        self.dropout_prob = 0.0
+        self.attn_dropout_prob = 0.0
+        self.weight_attr = None
+        self.bias_attr = None
+        self.kdim, self.vdim = self.embed_dim, self.embed_dim
+        self.key_length, self.value_length = self.query_length, self.query_length
+
+    def test_fused_attention_op(self):
+        final_out_ref, x_grad_ref = self.GetBaselineOut()
+        final_out, x_grad = self.GetFusedAttentionOut()
+        np.testing.assert_allclose(
+            final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1)
+        np.testing.assert_allclose(
+            x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-1)
+
+
+class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp):
+    def config(self):
+        self.x_type = np.float32
+        self.attn_mask_type = np.float64
+        self.pre_layer_norm = True
+        self.has_attn_mask = False
         self.training = True

         self.batch_size = 8
@@ -247,6 +289,7 @@ def config(self):
         self.x_type = np.float16
         self.attn_mask_type = np.float64
         self.pre_layer_norm = False
+        self.has_attn_mask = True
         self.training = True

         self.batch_size = 8
diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py
index 5fa9446763b1f..02695be61c30a 100644
--- a/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py
+++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py
@@ -152,6 +152,7 @@ def config(self):
         self.x_type = np.float32
         self.attn_mask_type = np.float64
         self.pre_layer_norm = True
+        self.has_attn_mask = True
         self.training = True
         self.need_weight = False

@@ -172,19 +173,27 @@ def generate_input_data(self):
         self.query = np.random.rand(self.batch_size, self.query_length,
                                     self.embed_dim).astype(self.x_type)
-        self.attn_mask = np.ones(
-            (self.batch_size, self.num_heads, self.query_length,
-             self.key_length),
-            dtype=self.attn_mask_type)
-        if self.attn_mask_type == np.int64:
-            self.attn_mask = np.tril(self.attn_mask)
-        elif self.attn_mask_type == np.float64:
-            self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
+        if self.has_attn_mask:
+            self.attn_mask = np.ones(
+                (self.batch_size, self.num_heads, self.query_length,
+                 self.key_length),
+                dtype=self.attn_mask_type)
+            if self.attn_mask_type == np.int64:
+                self.attn_mask = np.tril(self.attn_mask)
+            elif self.attn_mask_type == np.float64:
+                self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
+            else:
+                raise ValueError(
+                    "'attn_mask_type' should be 'int64' or 'float64'.")
         else:
-            raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.")
+            self.attn_mask = None
         self.key, self.value = self.query, self.query

     def run_imperative(self):
+        if self.has_attn_mask:
+            attn_mask_tensor = paddle.to_tensor(self.attn_mask)
+        else:
+            attn_mask_tensor = None
         fused_attn = FusedMultiHeadAttention(
             self.embed_dim, self.num_heads, self.dropout_prob,
             self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm,
@@ -192,7 +201,7 @@ def run_imperative(self):
         out = fused_attn(
             paddle.to_tensor(self.query),
             paddle.to_tensor(self.query),
-            paddle.to_tensor(self.query), paddle.to_tensor(self.attn_mask))
+            paddle.to_tensor(self.query), attn_mask_tensor)
         ref_out = compute_reference(self.pre_layer_norm, self.query,
                                     self.attn_mask,
                                     fused_attn.pre_ln_scale.numpy(),
@@ -203,7 +212,7 @@ def run_imperative(self):
                                     fused_attn.qkv_bias.numpy(),
                                     fused_attn.linear_weight.numpy(),
                                     fused_attn.linear_bias.numpy())
-        self.assertTrue(np.allclose(ref_out, out, rtol=1e-5, atol=1e-5))
+        np.testing.assert_allclose(ref_out, out.numpy(), rtol=1e-5, atol=1e-5)

     def run_static(self):
         fused_attn = FusedMultiHeadAttention(
@@ -215,29 +224,42 @@ def run_static(self):
             name='X',
             shape=[self.batch_size, self.query_length, self.embed_dim],
             dtype=self.x_type)
-        attn_mask = paddle.static.data(
-            name='SrcMask',
-            shape=[
-                self.batch_size, self.num_heads, self.query_length,
-                self.key_length
-            ],
-            dtype=self.attn_mask_type)
-        final_out = fused_attn(x, x, x, attn_mask)
+        if self.has_attn_mask:
+            attn_mask = paddle.static.data(
+                name='SrcMask',
+                shape=[
+                    self.batch_size, self.num_heads, self.query_length,
+                    self.key_length
+                ],
+                dtype=self.attn_mask_type)
+            final_out = fused_attn(x, x, x, attn_mask)
+        else:
+            final_out = fused_attn(x, x, x)

         place = paddle.CUDAPlace(0)
         exe = paddle.static.Executor(place)
         exe.run(paddle.static.default_startup_program())
-        out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
-            paddle.static.default_main_program(),
-            feed={"X": self.query,
-                  "SrcMask": self.attn_mask},
-            fetch_list=[
-                final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
-                fused_attn.linear_weight, fused_attn.linear_bias,
-                fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
-                fused_attn.ln_scale, fused_attn.ln_bias
-            ])
-
+        if self.has_attn_mask:
+            out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
+                paddle.static.default_main_program(),
+                feed={"X": self.query,
+                      "SrcMask": self.attn_mask},
+                fetch_list=[
+                    final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
+                    fused_attn.linear_weight, fused_attn.linear_bias,
+                    fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
+                    fused_attn.ln_scale, fused_attn.ln_bias
+                ])
+        else:
+            out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
+                paddle.static.default_main_program(),
+                feed={"X": self.query, },
+                fetch_list=[
+                    final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
+                    fused_attn.linear_weight, fused_attn.linear_bias,
+                    fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
+                    fused_attn.ln_scale, fused_attn.ln_bias
+                ])
         return out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias
@@ -249,14 +271,36 @@ def test_static_api(self):
             self.attn_mask, ln_scale, ln_bias, ln_2_scale, ln_2_bias,
             qkv_weight, qkv_bias, linear_weight, linear_bias)
-        self.assertTrue(
-            np.allclose(
-                np.array(ref_out), np.array(out), rtol=1e-5, atol=1e-5))
+        np.testing.assert_allclose(ref_out, out, rtol=1e-5, atol=1e-5)

     def test_dynamic_api(self):
         paddle.disable_static(place=paddle.CUDAPlace(0))
         self.run_imperative()


+class TestFusedAttentionAPINoneAttnMask(TestFusedAttentionAPI):
+    def config(self):
+        self.x_type = np.float32
+        self.attn_mask_type = np.float64
+        self.pre_layer_norm = True
+        self.has_attn_mask = False
+        self.training = True
+        self.need_weight = False
+
+        self.batch_size = 1
+        self.query_length = 2
+        self.head_dim = 2
+        self.num_heads = 2
+        self.embed_dim = self.head_dim * self.num_heads
+
+        self.dropout_prob = 0.0
+        self.attn_dropout_prob = 0.0
+        self.weight_attr = None
+        self.bias_attr = None
+
+        self.kdim, self.vdim = self.embed_dim, self.embed_dim
+        self.key_length, self.value_length = self.query_length, self.query_length
+
+
 if __name__ == "__main__":
     unittest.main()
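
Usage note (not part of the patch): the change above can be exercised from Python much like the new unit tests do. The sketch below is a minimal dygraph example; it assumes FusedMultiHeadAttention is importable from paddle.incubate.nn (the tests' import line is not shown in this diff, so treat the path as an assumption) and that a CUDA device is available, since fused_attention is a GPU-only op.

# Minimal sketch of the optional-mask behaviour enabled by this patch.
# Assumptions: FusedMultiHeadAttention lives in paddle.incubate.nn and a
# CUDA device is available (the fused op is GPU-only).
import numpy as np
import paddle
from paddle.incubate.nn import FusedMultiHeadAttention

paddle.disable_static(place=paddle.CUDAPlace(0))

batch_size, seq_len, num_heads, head_dim = 2, 4, 2, 8
embed_dim = num_heads * head_dim  # 16

x = paddle.to_tensor(
    np.random.rand(batch_size, seq_len, embed_dim).astype('float32'))

# Constructor arguments mirror run_imperative() above: dropout and
# attention dropout are set to 0.0 so the two calls are comparable.
fused_attn = FusedMultiHeadAttention(embed_dim, num_heads, 0.0, 0.0)

# New behaviour: attn_mask may be None (SrcMask is now an optional input,
# so the elementwise-add / SrcMaskOut branch is skipped in the kernel).
out_no_mask = fused_attn(x, x, x, None)

# Old behaviour still works: a [batch, num_heads, seq_len, seq_len] float
# mask, built the same way as generate_input_data() in the tests.
mask = np.ones((batch_size, num_heads, seq_len, seq_len), dtype=np.float64)
mask = (np.tril(mask) - 1.0) * 1e9
out_masked = fused_attn(x, x, x, paddle.to_tensor(mask))

print(out_no_mask.shape, out_masked.shape)  # both [2, 4, 16]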