Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#3 from jczaja/prv-onednn-3.0-quantize
Browse files Browse the repository at this point in the history
onednn 3.0 quantize & dequantize
  • Loading branch information
jczaja committed Mar 9, 2023
2 parents 0b287b2 + 89926a7 commit 2894b63
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 29 deletions.
17 changes: 7 additions & 10 deletions paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<phi::DenseTensor>("Input");
const auto quantization_scale = ctx.Attr<float>("Scale");
const auto quantization_shift = ctx.Attr<float>("Shift");
const auto quantization_shift =
static_cast<int32_t>(ctx.Attr<float>("Shift"));
const bool with_shift = quantization_shift != 0.0f;
auto* out = ctx.Output<phi::DenseTensor>("Output");

Expand All @@ -56,12 +57,10 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
dnnl::primitive_attr attrs;
static constexpr int32_t mask = 0; // same shift and scale for whole tensor

const float reorder_scale = 1. / quantization_scale;
// attrs.set_output_scales(mask, {reorder_scale});
attrs.set_scales_mask(DNNL_ARG_DST, mask);

if (with_shift) {
attrs.set_zero_points_mask(DNNL_ARG_DST, mask);
attrs.set_zero_points_mask(DNNL_ARG_SRC, mask);
}

phi::funcs::ReorderOneDNNHandler reorder_handler(
Expand All @@ -82,25 +81,23 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
auto scales_mem =
dnnl::memory(scales_md,
dev_ctx.GetEngine(),
phi::funcs::to_void_cast<float>(&reorder_scale));
phi::funcs::to_void_cast<float>(&quantization_scale));

auto zero_points_md = dnnl::memory::desc(
{1}, dnnl::memory::data_type::u8, dnnl::memory::format_tag::x);
{1}, dnnl::memory::data_type::s32, dnnl::memory::format_tag::x);
auto zero_points_mem =
dnnl::memory(zero_points_md,
dev_ctx.GetEngine(),
phi::funcs::to_void_cast<float>(&quantization_shift));
phi::funcs::to_void_cast<int32_t>(&quantization_shift));
std::unordered_map<int, dnnl::memory> reorder_args;
reorder_args.insert({DNNL_ARG_SRC, *reorder_src_memory_p});
reorder_args.insert({DNNL_ARG_DST, *reorder_dst_memory_p});
reorder_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, scales_mem});
if (with_shift) {
reorder_args.insert(
{DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, zero_points_mem});
{DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, zero_points_mem});
}
reorder_p->execute(astream, reorder_args);

reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();

out->set_mem_desc(reorder_dst_memory_p->get_desc());
Expand Down
9 changes: 5 additions & 4 deletions paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ class QuantOpKernel : public framework::OpKernel<T> {
auto* out = ctx.Output<phi::DenseTensor>("Output");

const auto quantization_scale = ctx.Attr<float>("Scale");
const auto quantization_shift = ctx.Attr<float>("Shift");
const auto quantization_shift =
static_cast<int32_t>(ctx.Attr<float>("Shift"));
const bool with_scale = quantization_scale != 1.0f;
const bool with_shift = quantization_shift != 0.0f;

Expand All @@ -61,7 +62,7 @@ class QuantOpKernel : public framework::OpKernel<T> {
static constexpr int32_t mask = 0;

if (with_scale) {
attrs.set_scales_mask(DNNL_ARG_DST, mask);
attrs.set_scales_mask(DNNL_ARG_SRC, mask);
}

if (with_shift) {
Expand Down Expand Up @@ -105,13 +106,13 @@ class QuantOpKernel : public framework::OpKernel<T> {
auto zero_points_mem =
dnnl::memory(zero_points_md,
dev_ctx.GetEngine(),
phi::funcs::to_void_cast<float>(&quantization_shift));
phi::funcs::to_void_cast<int32_t>(&quantization_shift));

std::unordered_map<int, dnnl::memory> reorder_args;
reorder_args.insert({DNNL_ARG_SRC, *reorder_src_memory_p});
reorder_args.insert({DNNL_ARG_DST, *reorder_dst_memory_p});
if (with_scale) {
reorder_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, scales_mem});
reorder_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, scales_mem});
}
if (with_shift) {
reorder_args.insert(
Expand Down
37 changes: 22 additions & 15 deletions paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<phi::DenseTensor>("Input");
auto scale_in = ctx.Attr<float>("Scale_in");
auto shift_in = ctx.Attr<float>("Shift_in");
auto shift_in = static_cast<int32_t>(ctx.Attr<float>("Shift_in"));
auto scale_out = ctx.Attr<float>("Scale_out");
auto shift_out = ctx.Attr<float>("Shift_out");
bool with_shift = shift_in != 0.0f || shift_out != 0.0f;
auto shift_out = static_cast<int32_t>(ctx.Attr<float>("Shift_out"));
bool with_shift = shift_in != 0 || shift_out != 0;
auto* output = ctx.Output<phi::DenseTensor>("Output");

PADDLE_ENFORCE_NE(
Expand All @@ -53,7 +53,7 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
scale_out,
0.0f,
platform::errors::InvalidArgument("Scale of output cannot be 0.0"));
if (shift_in != 0.0f) {
if (shift_in != 0) {
PADDLE_ENFORCE_EQ(
input->dtype(),
DataType::UINT8,
Expand All @@ -72,7 +72,7 @@ class ReQuantOpKernel : public framework::OpKernel<T> {

dnnl::primitive_attr attrs;
int mask = 0;
float reorder_scale = scale_out / scale_in;
float reorder_scale = scale_in / scale_out;
attrs.set_scales_mask(DNNL_ARG_DST, mask);
auto scales_md = dnnl::memory::desc(
{1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
Expand All @@ -81,12 +81,12 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
dev_ctx.GetEngine(),
phi::funcs::to_void_cast<float>(&reorder_scale));

uint8_t reorder_shift =
with_shift ? clip_to_uint8(shift_out - reorder_scale * shift_in) : 0;

if (with_shift) {
if (shift_out != 0) {
attrs.set_zero_points_mask(DNNL_ARG_DST, mask);
}
if (shift_in != 0) {
attrs.set_zero_points_mask(DNNL_ARG_SRC, mask);
}

phi::funcs::ReorderOneDNNHandler reorder_handler(
src_tz,
Expand All @@ -107,18 +107,25 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
auto& astream = phi::OneDNNContext::tls().get_stream();

auto zero_points_md = dnnl::memory::desc(
{1}, dnnl::memory::data_type::u8, dnnl::memory::format_tag::x);
auto zero_points_mem = dnnl::memory(zero_points_md,
dev_ctx.GetEngine(),
static_cast<uint8_t*>(&reorder_shift));
{1}, dnnl::memory::data_type::s32, dnnl::memory::format_tag::x);
auto zero_points_in_mem =
dnnl::memory(zero_points_md, dev_ctx.GetEngine(), &shift_in);
auto zero_points_out_mem =
dnnl::memory(zero_points_md, dev_ctx.GetEngine(), &shift_out);

std::unordered_map<int, dnnl::memory> reorder_args;
reorder_args.insert({DNNL_ARG_SRC, *src_memory_p});
reorder_args.insert({DNNL_ARG_DST, *dst_memory_p});
reorder_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, scales_mem});
if (with_shift) {
// Zero point of the source tensor. The primitive attribute registers this
// mask on DNNL_ARG_SRC, so the runtime argument must use the same SRC key.
// Keying it on DST would leave the SRC zero point unset and, because
// unordered_map::insert never overwrites an existing key, would also block
// the DST zero point inserted below.
if (shift_in != 0) {
  reorder_args.insert(
      {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, zero_points_in_mem});
}
// shift for DST
if (shift_out != 0) {
reorder_args.insert(
{DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, zero_points_mem});
{DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, zero_points_out_mem});
}

reorder_p->execute(astream, reorder_args);
Expand Down

0 comments on commit 2894b63

Please sign in to comment.