From 4d0fe2d620c8f3201104baa950acdd3eda690b2a Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 16 Aug 2019 03:26:29 +0000 Subject: [PATCH] [Relay][Quantization] Fix out-of-date realize --- src/relay/pass/quantize/realize.cc | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/relay/pass/quantize/realize.cc b/src/relay/pass/quantize/realize.cc index e4bc63adc6a0a..7eae9992c9e4b 100644 --- a/src/relay/pass/quantize/realize.cc +++ b/src/relay/pass/quantize/realize.cc @@ -110,7 +110,6 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) { } else if (static_cast(factor) == factor) { return Multiply(data, MakeConstantScalar(dtype, factor)); } else { - LOG(FATAL) << "fall back to float computation"; data = Cast(data, Float(32)); data = Multiply(data, MakeConstantScalar(Float(32), factor)); return Cast(Round(data), dtype); @@ -147,15 +146,21 @@ Expr QuantizeRealize(const Call& ref_call, } float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm); - CHECK_GT(shift_nbit, 0); + CHECK_NE(shift_nbit, 0); if (static_cast(shift_nbit) == shift_nbit) { - // use right shift - if (cfg->round_for_shift) { - float round_bias = std::pow(2.0, shift_nbit - 1); - data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast(round_bias))); + if (shift_nbit > 0) { + // use right shift + if (cfg->round_for_shift) { + float round_bias = std::pow(2.0, shift_nbit - 1); + data = Add(data, MakeConstantScalar(cfg->dtype_activation, + static_cast(round_bias))); + } + data = RightShift(data, MakeConstantScalar(cfg->dtype_activation, + static_cast(shift_nbit))); + } else { + data = LeftShift(data, MakeConstantScalar(cfg->dtype_activation, + static_cast(shift_nbit))); } - data = RightShift(data, MakeConstantScalar(cfg->dtype_activation, - static_cast(shift_nbit))); data = Clip(data, clip_min_imm, clip_max_imm); return QRealizeIntExprNode::make(data, dom_scale, n->dtype); } else {