Skip to content

Commit

Permalink
[Relay][Quantization] Fix out-of-date realize
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Aug 16, 2019
1 parent 674feba commit 8fd8cc2
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions src/relay/pass/quantize/realize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
} else if (static_cast<int>(factor) == factor) {
return Multiply(data, MakeConstantScalar(dtype, factor));
} else {
LOG(FATAL) << "fall back to float computation";
data = Cast(data, Float(32));
data = Multiply(data, MakeConstantScalar(Float(32), factor));
return Cast(Round(data), dtype);
Expand Down Expand Up @@ -147,15 +146,21 @@ Expr QuantizeRealize(const Call& ref_call,
}

float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm);
CHECK_GT(shift_nbit, 0);
CHECK_NE(shift_nbit, 0);
if (static_cast<int>(shift_nbit) == shift_nbit) {
// use right shift
if (cfg->round_for_shift) {
float round_bias = std::pow(2.0, shift_nbit - 1);
data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast<int>(round_bias)));
if (shift_nbit > 0) {
// use right shift
if (cfg->round_for_shift) {
float round_bias = std::pow(2.0, shift_nbit - 1);
data = Add(data, MakeConstantScalar(cfg->dtype_activation,
static_cast<int>(round_bias)));
}
data = RightShift(data, MakeConstantScalar(cfg->dtype_activation,
static_cast<int>(shift_nbit)));
} else {
data = LeftShift(data, MakeConstantScalar(cfg->dtype_activation,
static_cast<int>(shift_nbit)));
}
data = RightShift(data, MakeConstantScalar(cfg->dtype_activation,
static_cast<int>(shift_nbit)));
data = Clip(data, clip_min_imm, clip_max_imm);
return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
} else {
Expand Down

0 comments on commit 8fd8cc2

Please sign in to comment.