diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_weight_only_op.cc b/paddle/fluid/operators/fused/fused_multi_transformer_weight_only_op.cc
index 4a4fb2c17b1c4..79befbf905fdd 100644
--- a/paddle/fluid/operators/fused/fused_multi_transformer_weight_only_op.cc
+++ b/paddle/fluid/operators/fused/fused_multi_transformer_weight_only_op.cc
@@ -58,6 +58,11 @@ class FusedMultiTransformerWeightOnlyOp : public framework::OperatorWithKernel {
     CHECK_OUTPUT(Out);
+    const std::string weight_dtype =
+        ctx->Attrs().Get<std::string>("weight_dtype");
+    PADDLE_ENFORCE(weight_dtype == "int8" || weight_dtype == "int4",
+                   platform::errors::InvalidArgument(
+                       "weight_dtype must be 'int8' or 'int4'."));
     // x: qkv's input [batch_size, seq_len, dim_embed]
     // y: qkv's weight: [3, num_head, dim_head, dim_embed]
     auto x_dim = ctx->GetInputDim("X");
@@ -112,8 +117,9 @@ class FusedMultiTransformerWeightOnlyOp : public framework::OperatorWithKernel {
                           "head %d, but got %d",
                           y_dim[1], c_dim[2]));  // num_head
+    int64_t head_size = (weight_dtype == "int4") ? y_dim[2] * 2 : y_dim[2];
     PADDLE_ENFORCE_EQ(c_dim[4],
-                      y_dim[2],
+                      head_size,
                       paddle::platform::errors::InvalidArgument(
                           "The fifth dim of CacheKV must be equal with head "
                           "size %d, but got %d",
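Reviewer note (not part of the patch): the InferShape change above recovers the
logical head size from the packed qkv weight. With weight_dtype == "int4", two
4-bit values share one byte, so the stored dim_head (y_dim[2]) is half the
logical head size that CacheKV is laid out with. A minimal Python sketch of
that relation (the helper name is invented for illustration):

    def logical_head_size(y_dim, weight_dtype):
        # y_dim mirrors the qkv weight shape [3, num_head, dim_head, dim_embed];
        # int4 packs two 4-bit weights per byte, so double the stored dim back.
        return y_dim[2] * 2 if weight_dtype == "int4" else y_dim[2]

    assert logical_head_size([3, 32, 64, 4096], "int4") == 128  # packed int4
    assert logical_head_size([3, 32, 64, 4096], "int8") == 64   # one byte each
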
diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_weight_only_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_weight_only_op.cu
index 52659b60b0d81..f836e476ab2dc 100644
--- a/paddle/fluid/operators/fused/fused_multi_transformer_weight_only_op.cu
+++ b/paddle/fluid/operators/fused/fused_multi_transformer_weight_only_op.cu
@@ -32,8 +32,6 @@ static void PrintMatrix(const T* mat_d, int num, std::string name, int i) {
   outfile.close();
 }
 
-
-
 template <typename T>
 class FusedMultiTransformerWeightOnlyOpKernel : public framework::OpKernel<T> {
  public:
@@ -49,7 +47,9 @@ class FusedMultiTransformerWeightOnlyOpKernel : public framework::OpKernel<T> {
     int seq_len = input_x_dims[1];
    int dim_embed = input_x_dims[2];
     int bsz_seq = bsz * seq_len;
+#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
     LOG(INFO) << "input X: bsz: " << bsz << ", seq_len: " << seq_len << ", dim_embed: " << dim_embed;
+#endif
     const std::string act_method = ctx.Attr<std::string>("act_method");
     const std::string none_act = "none";
     bool use_glu = (act_method == "geglu");
@@ -68,7 +68,9 @@ class FusedMultiTransformerWeightOnlyOpKernel : public framework::OpKernel<T> {
     phi::DenseTensor padding_offset_tensor;
     phi::DenseTensor x_remove_padding;
     bool encoder_remove_padding = (remove_padding && !time_step);
+#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
     LOG(INFO) << "remove padding: " << encoder_remove_padding;
+#endif
     int token_num = 0;
 
     auto *out = ctx.Output<phi::DenseTensor>("Out");
@@ -138,20 +140,24 @@ class FusedMultiTransformerWeightOnlyOpKernel : public framework::OpKernel<T> {
     auto qkv_scales = ctx.MultiInput<phi::DenseTensor>("QKVWScale");
     auto qkv_biases = ctx.MultiInput<phi::DenseTensor>("QKVBias");
     const std::string weight_dtype = ctx.Attr<std::string>("weight_dtype");
-    //const bool trans_qkvw = ctx.Attr<bool>("trans_qkvw");
+    const bool is_int4 = (weight_dtype == "int4");
+
     const auto qkv_w_dims = qkv_weights[0]->dims();
-    //int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2];
-    //int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3];
     int num_head = qkv_w_dims[1];
     int dim_head = qkv_w_dims[2];
+    if (is_int4) {
+      dim_head = dim_head * 2;
+    }
     int hidden_size = num_head * dim_head;
+#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
     LOG(INFO) << "num head: " << num_head << ", dim head: " << dim_head
               << ", hidden size:" << hidden_size;
+#endif
     int output_size = 3 * hidden_size;
     int qkv_output_size = 3 * hidden_size;
     int input_size = dim_embed;
     //weight only gemm
     auto weight_only_gemm =
-        AttnMatMulWeightOnly(dev_ctx, (weight_dtype == "int4"));
+        AttnMatMulWeightOnly<T>(dev_ctx, is_int4);
     int default_act = weight_only_gemm.GetActivation("none");
     int ffn_act = weight_only_gemm.GetActivation(act_method);
@@ -287,6 +293,9 @@ class FusedMultiTransformerWeightOnlyOpKernel : public framework::OpKernel<T> {
 
     auto ffn1_weight_dim = ffn1_weights[0]->dims();
     int dim_ffn = ffn1_weight_dim[0];
+    if (is_int4) {
+      dim_ffn = dim_ffn * 2;
+    }
     //int dim_ffn = ffn1_weight_dim[1];
     FFNGluHelper<T> ffn1_glu_helper(
         dev_ctx, act_method, token_num, dim_ffn / 2, dim_ffn, dim_embed);
@@ -408,12 +417,14 @@ class FusedMultiTransformerWeightOnlyOpKernel : public framework::OpKernel<T> {
       } else {
         //qkv_compute.ComputeForward(
         //    qkv_weights[i], buf1, bias, &qkv_out, &qkv_out);
+#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
         VLOG(0) << "layer id=" << i << ", qkv input=" << buf1->dims()
                 << ", weight=" << qkv_weights[i]->dims()
                 << ", scale=" << qkv_scales[i]->dims()
-                << ", output=" << qkv_out.dims();
-        VLOG(0) << "token num=" << token_num << ", output size=" << qkv_output_size
-                << ", dim_embed=" << dim_embed;
+                << ", output=" << qkv_out.dims()
+                << ", token num=" << token_num << ", output size=" << qkv_output_size
+                << ", dim_embed=" << dim_embed;
+#endif
         weight_only_gemm.Linear(
             *buf1,
             *qkv_weights[i],
@@ -595,13 +606,13 @@ class FusedMultiTransformerWeightOnlyOpKernel : public framework::OpKernel<T> {
       }
 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
       VLOG(0) << "step3";
-#endif
       VLOG(0) << "layer id=" << i << ", out linear input=" << fmha_out.dims()
               << ", weight=" << out_linear_weights[i]->dims()
              << ", scale=" << out_linear_scales[i]->dims()
-              << ", out linear out: " << buf1->dims();
-      VLOG(0) << "token num=" << token_num << ", dim embed=" << dim_embed
+              << ", out linear out: " << buf1->dims()
+              << ", token num=" << token_num << ", dim embed=" << dim_embed
               << ", hidden size=" << hidden_size;
+#endif
       //PrintMatrix(fmha_out_data, bsz*seq_len*num_head*dim_head, "fmha_out", i);
       if (pre_layer_norm) {
         //out_linear_compute.ComputeForward(
@@ -689,13 +700,14 @@ class FusedMultiTransformerWeightOnlyOpKernel : public framework::OpKernel<T> {
             ffn1_weights[i], buf1, nullptr, &ffn1_out, nullptr);
       }
       **/
-
+#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
       VLOG(0) << "layer id=" << i << ", ffn1 input=" << buf1->dims()
              << ", weight=" << ffn1_weights[i]->dims()
              << ", scale=" << ffn1_weights_scales[i]->dims()
-              << ", ffn1 out: " << (ffn1_out).dims();
-      VLOG(0) << "token num=" << token_num << ", dim ffn=" << dim_ffn
+              << ", ffn1 out: " << (ffn1_out).dims()
+              << ", token num=" << token_num << ", dim ffn=" << dim_ffn
              << ", dim_embed=" << dim_embed;
+#endif
       weight_only_gemm.Linear(*buf1,
                               *ffn1_weights[i],
                               nullptr,
@@ -723,12 +735,14 @@ class FusedMultiTransformerWeightOnlyOpKernel : public framework::OpKernel<T> {
       if (pre_layer_norm) {
         //ffn2_linear_compute.ComputeForward(
         //    ffn2_weights[i], &ffn1_dropout_out, nullptr, buf1, nullptr);
+#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
         VLOG(0) << "layer id=" << i << ", ffn2 input=" << ffn1_dropout_out.dims()
                << ", weight=" << ffn2_weights[i]->dims()
                << ", scale=" << ffn2_weights_scales[i]->dims()
-                << ", ffn2 out: " << buf1->dims();
-        VLOG(0) << "token num=" << token_num << ", dim embed=" << dim_embed
+                << ", ffn2 out: " << buf1->dims()
+                << ", token num=" << token_num << ", dim embed=" << dim_embed
                << ", dim_ffn=" << dim_ffn;
+#endif
         weight_only_gemm.Linear(ffn1_dropout_out,
                                 *ffn2_weights[i],
                                 nullptr,
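Reviewer note (not part of the patch): the kernel-side doublings of dim_head
and dim_ffn follow the same int4 packing rule, and the verbose LOG/VLOG lines
now compile only under -D_DEBUG_FUSED_MULTI_TRANSFORMER, keeping the release
hot path quiet. Below is a minimal numpy sketch of the assumed packing scheme
(two signed 4-bit values per uint8 byte; the actual kernel may pack along a
different axis or use a different nibble order):

    import numpy as np

    def pack_int4_pairs(w):
        # w: int8 values in [-8, 7] with an even last dim; returns uint8 with
        # the last dim halved. The low nibble holds the even-indexed weight.
        lo = (w[..., 0::2] & 0x0F).astype(np.uint8)
        hi = (w[..., 1::2] & 0x0F).astype(np.uint8)
        return lo | (hi << 4)

    w = np.random.randint(-8, 8, size=(4, 8), dtype=np.int8)
    packed = pack_int4_pairs(w)
    assert packed.shape == (4, 4) and packed.dtype == np.uint8
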
diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py
index 8878c4e29b4de..547f2aded1f00 100644
--- a/python/paddle/incubate/nn/layer/fused_transformer.py
+++ b/python/paddle/incubate/nn/layer/fused_transformer.py
@@ -1544,14 +1544,11 @@ def __init__(
         self.ln_scales, self.ln_biases = ParameterList(), ParameterList()
         self.qkv_weights, self.qkv_scales, self.qkv_biases = ParameterList(), ParameterList(), ParameterList()
-        #self.qkv_weights, self.qkv_biases = ParameterList(), ParameterList()
         self.linear_weights, self.linear_scales, self.linear_biases = ParameterList(), ParameterList(), ParameterList()
-        #self.linear_weights, self.linear_biases = ParameterList(), ParameterList()
         self.ffn_ln_scales, self.ffn_ln_biases = ParameterList(), ParameterList()
         self.ffn1_weights, self.ffn1_scales, self.ffn1_biases = ParameterList(), ParameterList(), ParameterList()
-        #self.ffn1_weights, self.ffn1_biases = ParameterList(), ParameterList()
         self.ffn2_weights, self.ffn2_scales, self.ffn2_biases = ParameterList(), ParameterList(), ParameterList()
-        #self.ffn2_weights, self.ffn2_biases = ParameterList(), ParameterList()
+
         def get_attr(attrs, idx):
             if isinstance(attrs, (list, tuple, ParameterList)):
                 assert len(attrs) == num_layers
@@ -1589,10 +1586,14 @@ def get_attr(attrs, idx):
             attr=ln_bias_attr, shape=[embed_dim], is_bias=True, dtype="float32"
         )
         qkv_weight = self.create_parameter(
-            shape=[3, num_heads, self.head_dim, embed_dim],
+            shape=[3,
+                   num_heads,
+                   self.head_dim if weight_int8 else int(self.head_dim / 2),
+                   embed_dim],
             attr=qkv_weight_attr,
-            dtype=self._dtype,
+            dtype="uint8",
             is_bias=False,
+            default_initializer=Constant(value=0),
         )
         qkv_scale = self.create_parameter(
             shape=[int(3 * num_heads * self.head_dim)],
@@ -1617,10 +1618,11 @@ def get_attr(attrs, idx):
         '''
         linear_weight = self.create_parameter(
             shape=[embed_dim if weight_int8 else int(embed_dim / 2),
-                int(num_heads * self.head_dim)],
+                   int(num_heads * self.head_dim)],
             attr=linear_weight_attr,
-            dtype=self._dtype,
+            dtype="uint8",
             is_bias=False,
+            default_initializer=Constant(value=0),
         )
         linear_scale = self.create_parameter(
             shape=[embed_dim],
@@ -1654,10 +1656,12 @@ def get_attr(attrs, idx):
         )
         '''
         ffn1_weight = self.create_parameter(
-            shape=[dim_feedforward, embed_dim],
+            shape=[dim_feedforward if weight_int8 else int(dim_feedforward / 2),
+                   embed_dim],
             attr=ffn1_weight_attr,
-            dtype=self._dtype,
+            dtype="uint8",
             is_bias=False,
+            default_initializer=Constant(value=0),
         )
         ffn1_scale = self.create_parameter(
             shape=[dim_feedforward],
@@ -1681,10 +1685,12 @@ def get_attr(attrs, idx):
         )
         '''
         ffn2_weight = self.create_parameter(
-            shape=[embed_dim, dim_feedforward],
+            shape=[embed_dim if weight_int8 else int(embed_dim / 2),
+                   dim_feedforward],
             attr=ffn2_weight_attr,
-            dtype=self._dtype,
+            dtype="uint8",
             is_bias=False,
+            default_initializer=Constant(value=0),
         )
         ffn2_scale = self.create_parameter(
             shape=[embed_dim],
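Reviewer note (not part of the patch): all four weight parameters are now
stored as uint8 in both modes, with the packed dim halved when weight_int8 is
False (int4). The explicit Constant(value=0) initializer is presumably needed
because the default float initializers cannot fill an integer parameter. A
sketch of the resulting storage shapes (the helper name is invented for
illustration; the shapes mirror the create_parameter calls above):

    def packed_weight_shapes(num_heads, head_dim, embed_dim,
                             dim_feedforward, weight_int8):
        # int8 stores one weight per byte; int4 packs two, halving one dim.
        half = lambda d: d if weight_int8 else d // 2
        return {
            "qkv": [3, num_heads, half(head_dim), embed_dim],
            "linear": [half(embed_dim), num_heads * head_dim],
            "ffn1": [half(dim_feedforward), embed_dim],
            "ffn2": [half(embed_dim), dim_feedforward],
        }

    shapes = packed_weight_shapes(32, 128, 4096, 16384, weight_int8=False)
    assert shapes["qkv"] == [3, 32, 64, 4096]  # int4: stored head_dim halved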