Skip to content

Commit

Permalink
drop QBias
Browse files: browse the repository at this point in the history
  • Loading branch information
vinx13 committed Jul 16, 2019
1 parent 20ec855 commit 9d71db8
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 22 deletions.
7 changes: 4 additions & 3 deletions python/tvm/relay/quantize/_annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def multiply_rewrite(ref_call, new_args, ctx):
if lhs_kind == QAnnotateKind.ACTIVATION:
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
# quantize rhs to WEIGHT field
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.BIAS)
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)

Expand All @@ -251,7 +251,7 @@ def add_rewrite(ref_call, new_args, ctx):

if lhs_kind is None and rhs_kind is not None:
# quantize lhs to INPUT field if it is normal expression
assert rhs_kind == QAnnotateKind.INPUT
assert rhs_kind in [QAnnotateKind.INPUT, QAnnotateKind.ACTIVATION]
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)
Expand All @@ -275,7 +275,8 @@ def add_rewrite(ref_call, new_args, ctx):
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT:
if (lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT) or \
(lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.ACTIVATION):
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
raise ValueError()
Expand Down
23 changes: 10 additions & 13 deletions python/tvm/relay/quantize/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ class QAnnotateKind(object):
INPUT = 1
WEIGHT = 2
ACTIVATION = 3
BIAS = 4


def kind2str(kind):
Expand Down Expand Up @@ -203,7 +202,7 @@ def collect_stats(graph):
return _quantize.CollectStats(graph)


def calibrate(graph, mod=None, ctx=None, scales=None):
def calibrate(graph, mod=None, ctx=None, weight_scales='max', scales=None):
"""The calibrate procedure will try to calculate the content of
dom_scale, nbit, clip_min, clip_max for every `simulated_quantize`
operator.
Expand All @@ -230,14 +229,12 @@ def power2_scale(arr):
return 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0

def max_scale(arr):
    """Calculate a weight scale as the maximum absolute value of ``arr``.

    ``arr`` must provide an ``asnumpy()`` method; the array it returns is
    passed through ``np.abs`` and reduced over all elements with ``np.amax``.
    """
    return np.amax(np.abs(arr.asnumpy()))

scale_idx = 0

#fcalib_weight = power2_scale
fcalib_weight = max_scale

cfg = current_qconfig()
const_params = {}
quantize_op = _op.get("relay.op.annotation.simulated_quantize")
Expand All @@ -253,27 +250,27 @@ def visit_func(expr):

valid_bit = nbit - attrs.sign
if kind in [QAnnotateKind.WEIGHT, QAnnotateKind.BIAS]:
if all([isinstance(arg, _expr.Constant) for arg in [ndom_scale, nclip_min, nclip_max]]):
if all([isinstance(arg, _expr.Constant)
for arg in [ndom_scale, nclip_min, nclip_max]]):
return
var = expr.args[0]
assert isinstance(var, _expr.Constant)
scale = fcalib_weight(var.data)
print('weight scale: {}'.format(scale))
if weight_scales == 'max':
scale = max_scale(var.data)
elif weight_scales == 'power2':
scale = power2_scale(var.data)
else:
raise ValueError('{} not supported'.format(weight_scales))
elif scales is not None:
scale = scales[scale_idx]
scale_idx += 1
print('{} / {} ...'.format(scale_idx, len(scales)))
print('act scale: {}'.format(scale))
else:
scale = cfg.global_scale

def _make_const(val):
return _expr.const(val, 'float32')

valid_range = 2**valid_bit
if kind == QAnnotateKind.BIAS:
# bias hack
valid_range = 2**15

const_params[ndom_scale] = _make_const(scale / valid_range)
const_params[nclip_min] = _make_const(- (valid_range - 1))
Expand Down
3 changes: 1 addition & 2 deletions src/relay/pass/quantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,6 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "nbit_input=" << op->nbit_input << ", ";
p->stream << "nbit_weight=" << op->nbit_weight << ", ";
p->stream << "nbit_activation=" << op->nbit_activation << ", ";
p->stream << "nbit_bias=" << op->nbit_bias << ", ";
p->stream << "global_scale=" << op->global_scale << ", ";
p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", ";
p->stream << "round_for_shift==" << op->round_for_shift << ", ";
Expand Down Expand Up @@ -765,7 +764,7 @@ class StatsCollector : private ExprMutator {
CHECK(new_call);
if (new_call->op.same_as(simulated_quantize)) {
auto attrs = new_call->attrs.as<SimulatedQuantizeAttrs>();
if (attrs->kind != QAnnotateKind::kQWeight && attrs->kind != QAnnotateKind::kQBias) {
if (attrs->kind != QAnnotateKind::kQWeight) {
CHECK(!new_call->args[0].as<ConstantNode>());
const Expr& quantize_input = new_call->args[0]; // expression being quantized
profile_data_.push_back(quantize_input);
Expand Down
4 changes: 0 additions & 4 deletions src/relay/pass/quantize.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ enum QAnnotateKind : int {
kQInput = 1,
kQWeight = 2,
kQActivation = 3,
kQBias = 4,
};

/*!
Expand Down Expand Up @@ -149,7 +148,6 @@ class QConfigNode : public Node {
int nbit_input = 8;
int nbit_weight = 8;
int nbit_activation = 32;
int nbit_bias = 32;
DataType dtype_input = Int(8);
DataType dtype_weight = Int(8);
DataType dtype_activation = Int(32);
Expand All @@ -164,11 +162,9 @@ class QConfigNode : public Node {
v->Visit("nbit_input", &nbit_input);
v->Visit("nbit_weight", &nbit_weight);
v->Visit("nbit_activation", &nbit_activation);
v->Visit("nbit_bias", &nbit_bias);
v->Visit("dtype_input", &dtype_input);
v->Visit("dtype_weight", &dtype_weight);
v->Visit("dtype_activation", &dtype_activation);
v->Visit("dtype_bias", &dtype_bias);
v->Visit("global_scale", &global_scale);
v->Visit("skip_conv_layers", &skip_conv_layers);
v->Visit("round_for_shift", &round_for_shift);
Expand Down

0 comments on commit 9d71db8

Please sign in to comment.