fix the bug of channel-wise quantization for ernie #34948

Merged
24 commits merged on Aug 26, 2021
Changes from all commits
Commits (24)
f4f31f2
support quantization of conv2d_transpose
XGZhang11 Aug 2, 2021
11fbba0
Merge branch 'PaddlePaddle:develop' into develop
XGZhang11 Aug 5, 2021
ac21a60
fix quantization bugs
XGZhang11 Aug 5, 2021
350048e
Update post_training_quantization.py
XGZhang11 Aug 8, 2021
cdfa3fe
Update quantization_pass.py
XGZhang11 Aug 8, 2021
111387f
Merge branch 'PaddlePaddle:develop' into develop
XGZhang11 Aug 8, 2021
9cfc38f
Merge branch 'PaddlePaddle:develop' into develop
XGZhang11 Aug 9, 2021
4b047da
update docs
XGZhang11 Aug 9, 2021
e5ea4eb
add tests for quantized_conv2d_transpose
XGZhang11 Aug 9, 2021
3231853
update codestyle
XGZhang11 Aug 9, 2021
da48df7
update docs
XGZhang11 Aug 9, 2021
7981cb3
Merge branch 'PaddlePaddle:develop' into develop
XGZhang11 Aug 14, 2021
43976be
update tests and conv2dtranspose layer
XGZhang11 Aug 14, 2021
8ec36b6
update quant tests
XGZhang11 Aug 14, 2021
fa20111
Merge branch 'PaddlePaddle:develop' into develop
XGZhang11 Aug 16, 2021
fc74ab0
update sampcd_processor for tests
XGZhang11 Aug 16, 2021
ccd1675
update code examples
XGZhang11 Aug 16, 2021
199cf30
Merge branch 'PaddlePaddle:develop' into develop
XGZhang11 Aug 16, 2021
a5b7c71
fix channel_wise quantization for ernie
XGZhang11 Aug 16, 2021
a0f8b32
Merge branch 'PaddlePaddle:develop' into ernie_channel
XGZhang11 Aug 17, 2021
891d1e5
update fake_dequant op
XGZhang11 Aug 17, 2021
5f047ff
register new attr in enhanced pass
XGZhang11 Aug 17, 2021
e33249a
Merge branch 'PaddlePaddle:develop' into ernie_channel
XGZhang11 Aug 17, 2021
76d7f07
Update quant_conv2d_dequant_fuse_pass.cc
XGZhang11 Aug 17, 2021
4 changes: 4 additions & 0 deletions paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
@@ -115,6 +115,10 @@ QuantDequantFusePass::QuantDequantFusePass() {
.AddAttr("quant_axis")
.IsIntIn({0, 1})
.IsOptional()
.End()
.AddAttr("x_num_col_dims")
.IsType<int>()
.IsOptional()
.End();
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
(separate operator-compat .pbtxt file; file name not captured in this view)
@@ -17,4 +17,8 @@ def {
name: "quant_axis"
type: INT
}
attrs {
name: "x_num_col_dims"
type: INT
}
}
81 changes: 61 additions & 20 deletions paddle/fluid/operators/fake_dequantize_op.cc
@@ -39,7 +39,7 @@ struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor** scales,
const int scale_num, T max_range, const int quant_axis,
framework::Tensor* out) {
const int x_num_col_dims, framework::Tensor* out) {
if (scale_num == 1) {
// Dequant op is before quantized op
// Dequantize the weight of quantized op
@@ -81,23 +81,50 @@ struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
} else if (scale_num == 2) {
// Dequant op is after quantized op
// Dequantize the output tensor of quantized op
int batch_size = in->dims()[0];
int channel = in->dims()[1];
const T* scale_one = scales[0]->data<T>();
const T* scale_two = scales[1]->data<T>();
for (int i = 0; i < batch_size; i++) {
framework::Tensor one_batch_in = in->Slice(i, i + 1).Resize(
framework::slice_ddim(in->dims(), 1, in->dims().size()));
framework::Tensor one_batch_out = out->Slice(i, i + 1).Resize(
framework::slice_ddim(out->dims(), 1, out->dims().size()));
for (int j = 0; j < channel; j++) {
T s = scale_one[j];
framework::Tensor one_channel_in = one_batch_in.Slice(j, j + 1);
framework::Tensor one_channel_out = one_batch_out.Slice(j, j + 1);
auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
auto& dev = *dev_ctx.eigen_device();
out_e.device(dev) = in_e * s * scale_two[0] / max_range;
if (x_num_col_dims > 1) {
auto in_dims = in->dims();
const int64_t channel = in_dims[x_num_col_dims];
const T* scale_one = scales[0]->data<T>();
const T* scale_two = scales[1]->data<T>();
int64_t out_iter = 1;
for (int i = 0; i < x_num_col_dims; i++) {
out_iter *= in_dims[i];
}
int64_t step_i = in->numel() / out_iter;
int64_t step_j = in->numel() / (out_iter * channel);
auto* in_data = in->data<T>();
auto* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
for (int64_t i = 0; i < out_iter; i++) {
for (int64_t j = 0; j < channel; j++) {
auto* cur_in = in_data + i * step_i + j * step_j;
auto* cur_out = out_data + i * step_i + j * step_j;
T s = scale_one[j];
for (int64_t k = 0; k < step_j; k++) {
*cur_out = (*cur_in) * s * scale_two[0] / max_range;
++cur_in;
++cur_out;
}
}
}
} else {
int batch_size = in->dims()[0];
int channel = in->dims()[1];
const T* scale_one = scales[0]->data<T>();
const T* scale_two = scales[1]->data<T>();
for (int i = 0; i < batch_size; i++) {
framework::Tensor one_batch_in = in->Slice(i, i + 1).Resize(
framework::slice_ddim(in->dims(), 1, in->dims().size()));
framework::Tensor one_batch_out = out->Slice(i, i + 1).Resize(
framework::slice_ddim(out->dims(), 1, out->dims().size()));
for (int j = 0; j < channel; j++) {
T s = scale_one[j];
framework::Tensor one_channel_in = one_batch_in.Slice(j, j + 1);
framework::Tensor one_channel_out = one_batch_out.Slice(j, j + 1);
auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
auto& dev = *dev_ctx.eigen_device();
out_e.device(dev) = in_e * s * scale_two[0] / max_range;
}
}
}
}
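For reference, the new `scale_num == 2` branch is equivalent to broadcasting the per-channel scales along axis `x_num_col_dims` of the quantized output. A minimal NumPy sketch of that math (assumed shapes and names, not the Paddle API):

```python
import numpy as np

def channel_wise_dequant(x, scale_one, scale_two, max_range, x_num_col_dims=1):
    """Dequantize x with one scale per channel along axis `x_num_col_dims`."""
    channel = x.shape[x_num_col_dims]
    assert scale_one.shape == (channel,)
    # Reshape the per-channel scales so they broadcast along the channel axis;
    # this matches the out_iter / step_i / step_j loop in the CPU functor above.
    bcast = [1] * x.ndim
    bcast[x_num_col_dims] = channel
    return x * scale_one.reshape(bcast) * scale_two[0] / max_range

# e.g. an ERNIE-style mul output of shape [batch, seq_len, out_features] with x_num_col_dims = 2
x = np.random.randint(-127, 128, size=(4, 16, 8)).astype(np.float32)
out = channel_wise_dequant(x, np.random.rand(8).astype(np.float32),
                           np.array([0.5], dtype=np.float32), 127.0, x_num_col_dims=2)
```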
@@ -199,7 +226,16 @@ class FakeChannelWiseDequantizeMaxAbsOpMaker
"the received is %d",
quant_axis));
});

AddAttr<int>("x_num_col_dims",
"The x_num_col_dims of mul. Only used for mul or matmul.")
.SetDefault(1)
.AddCustomChecker([](const int& x_num_col_dims) {
PADDLE_ENFORCE_EQ(x_num_col_dims == 0, false,
platform::errors::InvalidArgument(
"'x_num_col_dims' should be larger than 0, but "
"the received is %d",
x_num_col_dims));
});
AddComment(R"DOC(
FakeChannelWiseDequantizeMaxAbsOp operator.

@@ -245,4 +281,9 @@ REGISTER_OP_VERSION(fake_channel_wise_dequantize_max_abs)
R"ROC(add new attributes [quant_axis] for applying per-channel "
"dequantization to conv2d_tranpose and mul ops.)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"quant_axis", "The axis for dequantization.", 0));
"quant_axis", "The axis for dequantization.", 0))
.AddCheckpoint(
R"ROC(add new attributes [x_num_col_dims] for applying per-channel "
"dequantization to mul ops.)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"x_num_col_dims", "The x_num_col_dims for dequantization.", 1));
17 changes: 10 additions & 7 deletions paddle/fluid/operators/fake_dequantize_op.cu
@@ -77,9 +77,9 @@ __global__ void DequantizeOneScaleQuantAxis1(const T* in, const T* scale,
template <typename T>
__global__ void DequantizeTwoScale(const T* in, const T* scale_one,
const T* scale_two, T max_range, int num,
int batch_size, int channel, T* out) {
int iter_size, int channel, T* out) {
int tid = threadIdx.x;
int channel_size = num / (batch_size * channel);
int channel_size = num / (iter_size * channel);
int scale_index = blockIdx.x % channel;
const T* in_c = in + blockIdx.x * channel_size;
T* out_c = out + blockIdx.x * channel_size;
@@ -93,7 +93,7 @@ struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor** scales,
const int scale_num, T max_range, const int quant_axis,
framework::Tensor* out) {
const int x_num_col_dims, framework::Tensor* out) {
auto in_dims = in->dims();
const T* in_data = in->data<T>();
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
@@ -116,14 +116,17 @@ struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
} else if (scale_num == 2) {
// Not need to consider quant_axis
int num = in->numel();
int batch_size = in->dims()[0];
int channel = in->dims()[1];
int iter_size = 1;
for (int i = 0; i < x_num_col_dims; i++) {
iter_size *= in->dims()[i];
}
int channel = in->dims()[x_num_col_dims];
const T* scale_one = scales[0]->data<T>();
const T* scale_two = scales[1]->data<T>();
int block = 1024;
int grid = batch_size * channel;
int grid = iter_size * channel;
DequantizeTwoScale<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, scale_one, scale_two, max_range, num, batch_size, channel,
in_data, scale_one, scale_two, max_range, num, iter_size, channel,
out_data);
}
}
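The CUDA change mirrors the CPU fix: the launch now covers the product of the leading dims rather than only dims[0], so each block handles one (leading index, channel) slice and selects its scale with blockIdx.x % channel. A rough dry run of that geometry in Python (illustrative numbers only):

```python
# Illustrative launch geometry for a 3-D mul output, e.g. [batch, seq_len, out_features].
in_shape = (8, 128, 768)
x_num_col_dims = 2

num = 8 * 128 * 768
channel = in_shape[x_num_col_dims]            # 768, should equal scales[0]->numel()
iter_size = 8 * 128                           # product of the leading dims (previously just batch_size)
grid = iter_size * channel                    # one CUDA block per (leading index, channel) pair
channel_size = num // (iter_size * channel)   # contiguous elements each block dequantizes (1 here)

# Before this change the channel was taken from dims[1] (seq_len in this example), which for a
# 3-D output cannot match the per-output-channel scales of a mul, so ERNIE outputs failed to
# dequantize correctly.
assert grid * channel_size == num
```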
12 changes: 7 additions & 5 deletions paddle/fluid/operators/fake_dequantize_op.h
@@ -33,7 +33,8 @@ template <typename DeviceContext, typename T>
struct ChannelDequantizeFunctor {
void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in,
const framework::Tensor** scales, const int scale_num,
T max_range, const int quant_axis, framework::Tensor* out);
T max_range, const int quant_axis, const int x_num_col_dims,
framework::Tensor* out);
};

template <typename DeviceContext, typename T>
@@ -64,6 +65,7 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {

auto quant_bits = ctx.Attr<std::vector<int>>("quant_bits");
auto quant_axis = ctx.Attr<int>("quant_axis");
auto x_num_col_dims = ctx.Attr<int>("x_num_col_dims");
int max_range = 1;

auto& dev_ctx = ctx.template device_context<DeviceContext>();
@@ -80,11 +82,11 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
max_range *= (std::pow(2, quant_bits[0] - 1) - 1);
} else if (scale_num == 2) {
PADDLE_ENFORCE_EQ(
scales[0]->numel(), in->dims()[1],
scales[0]->numel(), in->dims()[x_num_col_dims],
platform::errors::PreconditionNotMet(
"The number of first scale values must be the same with "
"second dimension value of Input(X) when the `Scales` has two "
"elements, but %ld != %ld here.",
"corresponding dimension value of Input(X) when the `Scales` "
"has two elements, but %ld != %ld here.",
scales[0]->numel(), in->dims()[1]));
PADDLE_ENFORCE_EQ(scales[1]->numel(), 1,
platform::errors::PreconditionNotMet(
@@ -96,7 +98,7 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
}
ChannelDequantizeFunctor<DeviceContext, T>()(
dev_ctx, in, scales.data(), scale_num, static_cast<T>(max_range),
quant_axis, out);
quant_axis, x_num_col_dims, out);
}
};

(separate Python file; file name not captured in this view)
@@ -1273,12 +1273,17 @@ def _insert_post_channel_dequant_op(self, graph, op_node, quant_axis):
var_type=output_var_node.type(),
shape=output_var_node.shape(),
var_dtype=output_var_node.dtype())
if op_node.op().has_attr("x_num_col_dims"):
x_num_col_dims = op_node.op().attr("x_num_col_dims")
else:
x_num_col_dims = 1
dequant_op_node = graph.create_op_node(
op_type='fake_channel_wise_dequantize_max_abs',
attrs={
'quant_bits': [self._weight_bits, self._activation_bits],
'quant_axis': quant_axis,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
'x_num_col_dims': x_num_col_dims
},
inputs={
'X': output_var_node,
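Putting it together, the pass now copies x_num_col_dims from the quantized op (defaulting to 1) onto the per-channel dequant op it inserts, so the dequant kernel knows which output axis the channel scales belong to. A small sketch of the attribute flow, using plain dicts rather than the real IrGraph/OpDesc API:

```python
def dequant_attrs_for(quantized_op_attrs, weight_bits=8, activation_bits=8, quant_axis=1):
    """Build the attrs of the inserted fake_channel_wise_dequantize_max_abs op (sketch)."""
    return {
        "quant_bits": [weight_bits, activation_bits],
        "quant_axis": quant_axis,
        # mul/matmul ops carry x_num_col_dims; other quantized ops fall back to 1
        "x_num_col_dims": quantized_op_attrs.get("x_num_col_dims", 1),
    }

# An ERNIE fc implemented as mul over a 3-D activation typically uses x_num_col_dims = 2,
# so the inserted dequant op scales along axis 2 of the mul output.
print(dequant_attrs_for({"x_num_col_dims": 2}))
```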