add weightonly int4 fmt
humingqing authored and humingqing committed Jan 16, 2024
1 parent a7e3f45 commit ca25c78
Showing 3 changed files with 56 additions and 30 deletions.
@@ -58,6 +58,11 @@ class FusedMultiTransformerWeightOnlyOp : public framework::OperatorWithKernel {

CHECK_OUTPUT(Out);

+    const std::string weight_dtype =
+        ctx->Attrs().Get<std::string>("weight_dtype");
+    PADDLE_ENFORCE(weight_dtype == "int8" || weight_dtype == "int4",
+                   platform::errors::InvalidArgument(
+                       "weight_dtype must be 'int8' or 'int4'."));
// x: qkv's input [batch_size, seq_len, dim_embed]
// y: qkv's weight: [3, num_head, dim_head, dim_embed]
auto x_dim = ctx->GetInputDim("X");
@@ -112,8 +117,9 @@ class FusedMultiTransformerWeightOnlyOp : public framework::OperatorWithKernel {
"head %d, but got %d",
y_dim[1],
c_dim[2])); // num_head
+    int64_t head_size = (weight_dtype == "int4") ? y_dim[2] * 2 : y_dim[2];
     PADDLE_ENFORCE_EQ(c_dim[4],
-                      y_dim[2],
+                      head_size,
paddle::platform::errors::InvalidArgument(
"The fifth dim of CacheKV must be equal with head "
"size %d, but got %d",
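The head-size doubling above follows from the int4 packing convention: the stored weight tensor keeps two 4-bit values per byte along the head dimension, so the logical head size is twice the stored one. A minimal sketch of that shape rule (the helper name is hypothetical, not part of the commit):

    def logical_head_size(stored_head_dim: int, weight_dtype: str) -> int:
        # mirrors: head_size = (weight_dtype == "int4") ? y_dim[2] * 2 : y_dim[2]
        return stored_head_dim * 2 if weight_dtype == "int4" else stored_head_dim

    assert logical_head_size(64, "int8") == 64
    assert logical_head_size(32, "int4") == 64  # 32 packed bytes hold 64 int4 weights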
@@ -32,8 +32,6 @@ static void PrintMatrix(const T* mat_d, int num, std::string name, int i) {
outfile.close();
}

-
-
template <typename T>
class FusedMultiTransformerWeightOnlyOpKernel : public framework::OpKernel<T> {
public:
@@ -49,7 +47,9 @@ class FusedMultiTransformerWeightOnlyOpKernel : public framework::OpKernel<T> {
int seq_len = input_x_dims[1];
int dim_embed = input_x_dims[2];
int bsz_seq = bsz * seq_len;
+#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
    LOG(INFO) << "input X: bsz: " << bsz << ", seq_len: " << seq_len << ", dim_embed: " << dim_embed;
+#endif
const std::string act_method = ctx.Attr<std::string>("act_method");
const std::string none_act = "none";
bool use_glu = (act_method == "geglu");
@@ -68,7 +68,9 @@ class FusedMultiTransformerWeightOnlyOpKernel : public framework::OpKernel<T> {
phi::DenseTensor padding_offset_tensor;
phi::DenseTensor x_remove_padding;
bool encoder_remove_padding = (remove_padding && !time_step);
+#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
    LOG(INFO) << "remove padding: " << encoder_remove_padding;
+#endif
int token_num = 0;

auto *out = ctx.Output<phi::DenseTensor>("Out");
@@ -138,20 +140,24 @@ class FusedMultiTransformerWeightOnlyOpKernel : public framework::OpKernel<T> {
auto qkv_scales = ctx.MultiInput<phi::DenseTensor>("QKVWScale");
auto qkv_biases = ctx.MultiInput<phi::DenseTensor>("QKVBias");
const std::string weight_dtype = ctx.Attr<std::string>("weight_dtype");
//const bool trans_qkvw = ctx.Attr<bool>("trans_qkvw");
+    const bool is_int4 = (weight_dtype == "int4");

const auto qkv_w_dims = qkv_weights[0]->dims();
//int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2];
//int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3];
int num_head = qkv_w_dims[1];
int dim_head = qkv_w_dims[2];
+    if (is_int4) {
+      dim_head = dim_head * 2;
+    }
int hidden_size = num_head * dim_head;
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
LOG(INFO) << "num head: " << num_head << ", dim head: " << dim_head << ", hidden size:" << hidden_size;
#endif
-    int output_size = 3 * hidden_size;
+    int qkv_output_size = 3 * hidden_size;
int input_size = dim_embed;
//weight only gemm
auto weight_only_gemm =
-        AttnMatMulWeightOnly<T>(dev_ctx, (weight_dtype == "int4"));
+        AttnMatMulWeightOnly<T>(dev_ctx, is_int4);
int default_act = weight_only_gemm.GetActivation("none");
int ffn_act = weight_only_gemm.GetActivation(act_method);

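Conceptually, the weight-only path dequantizes packed weights against their per-channel scales before (or fused with) the GEMM. The numpy sketch below illustrates the idea only: the function name is hypothetical, the offset-8 encoding and row interleaving are assumptions, and the real AttnMatMulWeightOnly runs a fused CUDA kernel rather than materializing fp32 weights:

    import numpy as np

    def weight_only_int4_matmul(x, packed_w, scale):
        # x: [m, k]; packed_w: [n // 2, k] uint8, two int4 output rows per byte
        # (matching the halved output dimension in this commit); scale: [n]
        low = (packed_w & 0x0F).astype(np.int8) - 8   # assumed offset-8 encoding
        high = (packed_w >> 4).astype(np.int8) - 8
        w = np.empty((packed_w.shape[0] * 2, packed_w.shape[1]), np.float32)
        w[0::2], w[1::2] = low, high                  # assumed row interleaving
        return x @ (w * scale[:, None]).T             # dequantize, then plain GEMM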
@@ -287,6 +293,9 @@ class FusedMultiTransformerWeightOnlyOpKernel : public framework::OpKernel<T> {
auto ffn1_weight_dim = ffn1_weights[0]->dims();

int dim_ffn = ffn1_weight_dim[0];
+    if (is_int4) {
+      dim_ffn = dim_ffn * 2;
+    }
//int dim_ffn = ffn1_weight_dim[1];
FFNGluHelper<T> ffn1_glu_helper(
dev_ctx, act_method, token_num, dim_ffn / 2, dim_ffn, dim_embed);
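FFNGluHelper's dim_ffn / 2 argument reflects how gated-linear-unit activations halve the feature width: the ffn1 projection emits dim_ffn features that are consumed as a value half and a gate half. A rough numpy sketch of geglu under that assumption (the half order is an assumption, not confirmed by the diff):

    import numpy as np

    def gelu(v):  # tanh approximation of GELU
        return 0.5 * v * (1.0 + np.tanh(0.7978845608 * (v + 0.044715 * v ** 3)))

    def geglu(h):  # h: [tokens, dim_ffn]
        value, gate = np.split(h, 2, axis=-1)  # assumed (value, gate) order
        return value * gelu(gate)              # -> [tokens, dim_ffn // 2]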
@@ -408,12 +417,14 @@ class FusedMultiTransformerWeightOnlyOpKernel : public framework::OpKernel<T> {
} else {
//qkv_compute.ComputeForward(
// qkv_weights[i], buf1, bias, &qkv_out, &qkv_out);
+#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
        VLOG(0) << "layer id=" << i << ", qkv input=" << buf1->dims()
                << ", weight=" << qkv_weights[i]->dims()
                << ", scale=" << qkv_scales[i]->dims()
-               << ", output=" << qkv_out.dims();
-       VLOG(0) << "token num=" << token_num << ", output size=" << qkv_output_size
-               << ", dim_embed=" << dim_embed;
+               << ", output=" << qkv_out.dims()
+               << ", token num=" << token_num << ", output size=" << qkv_output_size
+               << ", dim_embed=" << dim_embed;
+#endif
weight_only_gemm.Linear(
*buf1,
*qkv_weights[i],
@@ -595,13 +606,13 @@ class FusedMultiTransformerWeightOnlyOpKernel : public framework::OpKernel<T> {
}
 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
      VLOG(0) << "step3";
-#endif
      VLOG(0) << "layer id=" << i << ", out linear input=" << fmha_out.dims()
              << ", weight=" << out_linear_weights[i]->dims()
              << ", scale=" << out_linear_scales[i]->dims()
-             << ", out linear out: " << buf1->dims();
-     VLOG(0) << "token num=" << token_num << ", dim embed=" << dim_embed
+             << ", out linear out: " << buf1->dims()
+             << ", token num=" << token_num << ", dim embed=" << dim_embed
              << ", hidden size=" << hidden_size;
+#endif
//PrintMatrix(fmha_out_data, bsz*seq_len*num_head*dim_head, "fmha_out", i);
if (pre_layer_norm) {
//out_linear_compute.ComputeForward(
@@ -689,13 +700,14 @@ class FusedMultiTransformerWeightOnlyOpKernel : public framework::OpKernel<T> {
ffn1_weights[i], buf1, nullptr, &ffn1_out, nullptr);
}
**/

#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "layer id=" << i << ", ffn1 input=" << buf1->dims()
<< ", weight=" << ffn1_weights[i]->dims()
<< ", scale=" << ffn1_weights_scales[i]->dims()
<< ", ffn1 out: " << (ffn1_out).dims();
VLOG(0) << "token num=" << token_num << ", dim ffn=" << dim_ffn
<< ", ffn1 out: " << (ffn1_out).dims()
<< ", token num=" << token_num << ", dim ffn=" << dim_ffn
<< ", dim_embed=" << dim_embed;
#endif
weight_only_gemm.Linear(*buf1,
*ffn1_weights[i],
nullptr,
@@ -723,12 +735,14 @@ class FusedMultiTransformerWeightOnlyOpKernel : public framework::OpKernel<T> {
if (pre_layer_norm) {
//ffn2_linear_compute.ComputeForward(
// ffn2_weights[i], &ffn1_dropout_out, nullptr, buf1, nullptr);
+#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
          VLOG(0) << "layer id=" << i << ", ffn2 input=" << ffn1_dropout_out.dims()
                  << ", weight=" << ffn2_weights[i]->dims()
                  << ", scale=" << ffn2_weights_scales[i]->dims()
-                 << ", ffn2 out: " << buf1->dims();
-         VLOG(0) << "token num=" << token_num << ", dim embed=" << dim_embed
+                 << ", ffn2 out: " << buf1->dims()
+                 << ", token num=" << token_num << ", dim embed=" << dim_embed
                  << ", dim_ffn=" << dim_ffn;
+#endif
weight_only_gemm.Linear(ffn1_dropout_out,
*ffn2_weights[i],
nullptr,
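On the conversion side, a checkpoint tool has to produce exactly the byte layout the kernel unpacks. A hypothetical packer matching the unpack sketch earlier (same assumed offset-8 encoding and row interleaving; pack_int4 is not a function from this commit):

    import numpy as np

    def pack_int4(w_q):
        # w_q: [n, k] int8 with values in [-8, 7], n even; returns [n // 2, k] uint8
        w_u = (w_q + 8).astype(np.uint8)      # assumed offset-8 int4 encoding
        return w_u[0::2] | (w_u[1::2] << 4)   # low nibble: even row, high: odd row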
30 changes: 18 additions & 12 deletions python/paddle/incubate/nn/layer/fused_transformer.py
@@ -1544,14 +1544,11 @@ def __init__(

self.ln_scales, self.ln_biases = ParameterList(), ParameterList()
self.qkv_weights, self.qkv_scales, self.qkv_biases = ParameterList(), ParameterList(), ParameterList()
-       #self.qkv_weights, self.qkv_biases = ParameterList(), ParameterList()
        self.linear_weights, self.linear_scales, self.linear_biases = ParameterList(), ParameterList(), ParameterList()
-       #self.linear_weights, self.linear_biases = ParameterList(), ParameterList()
        self.ffn_ln_scales, self.ffn_ln_biases = ParameterList(), ParameterList()
        self.ffn1_weights, self.ffn1_scales, self.ffn1_biases = ParameterList(), ParameterList(), ParameterList()
-       #self.ffn1_weights, self.ffn1_biases = ParameterList(), ParameterList()
        self.ffn2_weights, self.ffn2_scales, self.ffn2_biases = ParameterList(), ParameterList(), ParameterList()
-       #self.ffn2_weights, self.ffn2_biases = ParameterList(), ParameterList()

def get_attr(attrs, idx):
if isinstance(attrs, (list, tuple, ParameterList)):
assert len(attrs) == num_layers
@@ -1589,10 +1586,14 @@ def get_attr(attrs, idx):
attr=ln_bias_attr, shape=[embed_dim], is_bias=True, dtype="float32"
)
            qkv_weight = self.create_parameter(
-               shape=[3, num_heads, self.head_dim, embed_dim],
+               shape=[3,
+                      num_heads,
+                      self.head_dim if weight_int8 else int(self.head_dim / 2),
+                      embed_dim],
                attr=qkv_weight_attr,
-               dtype=self._dtype,
+               dtype="uint8",
                is_bias=False,
+               default_initializer=Constant(value=0),
            )
qkv_scale = self.create_parameter(
shape=[int(3 * num_heads * self.head_dim)],
@@ -1617,10 +1618,11 @@ def get_attr(attrs, idx):
'''
            linear_weight = self.create_parameter(
                shape=[embed_dim if weight_int8 else int(embed_dim / 2),
                       int(num_heads * self.head_dim)],
                attr=linear_weight_attr,
-               dtype=self._dtype,
+               dtype="uint8",
                is_bias=False,
+               default_initializer=Constant(value=0),
            )
linear_scale = self.create_parameter(
shape=[embed_dim],
@@ -1654,10 +1656,12 @@ def get_attr(attrs, idx):
)
'''
            ffn1_weight = self.create_parameter(
-               shape=[dim_feedforward, embed_dim],
+               shape=[dim_feedforward if weight_int8 else int(dim_feedforward / 2),
+                      embed_dim],
                attr=ffn1_weight_attr,
-               dtype=self._dtype,
+               dtype="uint8",
                is_bias=False,
+               default_initializer=Constant(value=0),
            )
ffn1_scale = self.create_parameter(
shape=[dim_feedforward],
@@ -1681,10 +1685,12 @@ def get_attr(attrs, idx):
)
'''
            ffn2_weight = self.create_parameter(
-               shape=[embed_dim, dim_feedforward],
+               shape=[embed_dim if weight_int8 else int(embed_dim / 2),
+                      dim_feedforward],
                attr=ffn2_weight_attr,
-               dtype=self._dtype,
+               dtype="uint8",
                is_bias=False,
+               default_initializer=Constant(value=0),
            )
ffn2_scale = self.create_parameter(
shape=[embed_dim],
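Putting the Python-side shape rules together: each quantized weight is created as uint8, and the packed axis is halved when the weight dtype is int4 (weight_int8 is taken here to mean weight_dtype == "int8"). A quick sanity-check sketch with made-up sizes (packed_dim is a hypothetical helper):

    def packed_dim(d, weight_int8):
        return d if weight_int8 else d // 2

    num_heads, head_dim, embed_dim, dim_feedforward = 16, 64, 1024, 4096
    qkv_int4 = [3, num_heads, packed_dim(head_dim, False), embed_dim]
    ffn1_int8 = [packed_dim(dim_feedforward, True), embed_dim]
    print(qkv_int4, ffn1_int8)  # [3, 16, 32, 1024] [4096, 1024]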
