From 5f0406e571739a0ca60fc04ffff163fcb509b1b7 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 22 Jul 2019 06:40:13 +0000 Subject: [PATCH] address comments --- python/tvm/relay/quantize/quantize.py | 24 ++++++ src/relay/pass/quantize/calibration.cc | 90 +++++++++++++++++++++++ src/relay/pass/{ => quantize}/quantize.cc | 62 +--------------- src/relay/pass/{ => quantize}/quantize.h | 18 ++++- 4 files changed, 133 insertions(+), 61 deletions(-) create mode 100644 src/relay/pass/quantize/calibration.cc rename src/relay/pass/{ => quantize}/quantize.cc (92%) rename src/relay/pass/{ => quantize}/quantize.h (91%) diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 562e97a79768a..ac806f2ab5f0b 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -196,6 +196,20 @@ def annotate_context(): def collect_stats(graph): + """Given an annotated graph, create a profile graph to collect profile data from the + calibration dataset. This pass finds simulated_quantize op and collects its input into a tuple. + The tuple is the output of the profile graph. + + Parameters + ---------- + graph: Function + The simulation graph after annotation. + + Returns + ------- + ret: Function + The profile graph which outputs a tuple of profile data. + """ return _quantize.CollectStats(graph) @@ -215,6 +229,16 @@ def calibrate(graph, mod=None, ctx=None, weight_scales='power2', scales=None): ctx: tvm.relay.PassContext The pass context used for calibration. + weight_scales: 'power2' or 'max'. + The way to calculate scales for weights (annotated with QAnnotateKind.WEIGHT). + power2: Find the maximum of the absolute value of the tensor, and then round up to power + of two. + max: Find the maximum of the absolute value of the tensor. + + scales: List[float] + Pre-calculated scales for input and activations. Length and the order of elements of the + scales list should match the output tuple of the profile graph created by collect_stats. + Returns ------- ret: Function diff --git a/src/relay/pass/quantize/calibration.cc b/src/relay/pass/quantize/calibration.cc new file mode 100644 index 0000000000000..c91ee01fd80d8 --- /dev/null +++ b/src/relay/pass/quantize/calibration.cc @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * + * \file calibration.cc + * + * \brief Create profile graph and calibrate on dataset + */ +#include +#include +#include "./quantize.h" + + +namespace tvm { +namespace relay { +namespace quantize { + +class StatsCollector : private ExprMutator { + public: + Expr Collect(const Expr& expr) { + auto new_e = this->Mutate(expr); + const FunctionNode* func = new_e.as(); + CHECK(func) << "Input shoule be Function"; + Expr new_body = TupleNode::make(std::move(profile_data_)); + return FunctionNode::make(FreeVars(new_body), new_body, NullValue(), func->type_params, + func->attrs); + } + + private: + Array profile_data_; + + Expr VisitExpr_(const CallNode* call) { + static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize"); + Expr new_e = ExprMutator::VisitExpr_(call); + const CallNode* new_call = new_e.as(); + CHECK(new_call); + if (new_call->op.same_as(simulated_quantize)) { + auto attrs = new_call->attrs.as(); + const Expr& quantize_input = new_call->args[0]; // expression being quantized + if (attrs->kind != QAnnotateKind::kQWeight) { + CHECK(!quantize_input.as()); + profile_data_.push_back(quantize_input); + } + return quantize_input; + } else { + return new_e; + } + } +}; + +/* + * \brief Given an annotated graph, create a profile graph to collect profile data from the + * + * calibration dataset. + * + * This pass finds simulated_quantize op and collects its input into a tuple. The tuple is the + * output of the profile graph. Both input and output of this pass + * are relay::Function. + * + * \param expr Expression after Annotate pass. + * \return The profile graph. + */ +Expr CollectStats(const Expr& expr) { + return StatsCollector().Collect(expr); +} + +TVM_REGISTER_API("relay._quantize.CollectStats") +.set_body_typed(CollectStats); + +} // namespace quantize +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize/quantize.cc similarity index 92% rename from src/relay/pass/quantize.cc rename to src/relay/pass/quantize/quantize.cc index 81044a9addad7..6cffc2053e5cd 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize/quantize.cc @@ -36,8 +36,8 @@ #include #include #include -#include "pattern_util.h" -#include "quantize.h" +#include "../pattern_util.h" +#include "./quantize.h" namespace tvm { @@ -46,22 +46,6 @@ namespace quantize { using namespace relay::transform; -/*! \brief Attribute for simulated quantize operator */ -struct SimulatedQuantizeAttrs : public tvm::AttrsNode { - int kind; - bool sign; - std::string rounding; - - TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") { - TVM_ATTR_FIELD(kind) - .describe("kind of field, hint for nbit/dtype configuration."); - TVM_ATTR_FIELD(sign).set_default(true) - .describe("whether to use signed data type."); - TVM_ATTR_FIELD(rounding).set_default("round") - .describe("rounding mode. Can be 'floor', 'ceil', 'round'"); - } -}; - TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs); bool SimulatedQuantizeRel(const Array& types, @@ -739,48 +723,6 @@ TVM_REGISTER_API("relay._quantize.temp_expr_realize") return n->Realize(); }); -// ============= -// calibration - -class StatsCollector : private ExprMutator { - public: - Expr Collect(const Expr& expr) { - auto new_e = this->Mutate(expr); - const FunctionNode* func = new_e.as(); - CHECK(func); - Expr new_body = TupleNode::make(std::move(profile_data_)); - return FunctionNode::make(FreeVars(new_body), new_body, NullValue(), func->type_params, - func->attrs); - } - - private: - Array profile_data_; - - Expr VisitExpr_(const CallNode* call) { - static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize"); - Expr new_e = ExprMutator::VisitExpr_(call); - const CallNode* new_call = new_e.as(); - CHECK(new_call); - if (new_call->op.same_as(simulated_quantize)) { - auto attrs = new_call->attrs.as(); - if (attrs->kind != QAnnotateKind::kQWeight) { - CHECK(!new_call->args[0].as()); - const Expr& quantize_input = new_call->args[0]; // expression being quantized - profile_data_.push_back(quantize_input); - } - return new_call->args[0]; - } else { - return new_e; - } - } -}; - -Expr CollectStats(const Expr& expr) { - return StatsCollector().Collect(expr); -} - -TVM_REGISTER_API("relay._quantize.CollectStats") -.set_body_typed(CollectStats); } // namespace quantize } // namespace relay diff --git a/src/relay/pass/quantize.h b/src/relay/pass/quantize/quantize.h similarity index 91% rename from src/relay/pass/quantize.h rename to src/relay/pass/quantize/quantize.h index 262d420acf97c..367b6785ed199 100644 --- a/src/relay/pass/quantize.h +++ b/src/relay/pass/quantize/quantize.h @@ -29,7 +29,7 @@ #include #include #include -#include "pattern_util.h" +#include "../pattern_util.h" namespace tvm { namespace relay { @@ -42,6 +42,22 @@ enum QAnnotateKind : int { kQActivation = 3, }; +/*! \brief Attribute for simulated quantize operator */ +struct SimulatedQuantizeAttrs : public tvm::AttrsNode { + int kind; + bool sign; + std::string rounding; + + TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") { + TVM_ATTR_FIELD(kind) + .describe("kind of field, hint for nbit/dtype configuration."); + TVM_ATTR_FIELD(sign).set_default(true) + .describe("whether to use signed data type."); + TVM_ATTR_FIELD(rounding).set_default("round") + .describe("rounding mode. Can be 'floor', 'ceil', 'round'"); + } +}; + /*! * \brief TempExpr used during annotate forward rewrite. */