Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Jul 22, 2019
1 parent 577387d commit 5f0406e
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 61 deletions.
24 changes: 24 additions & 0 deletions python/tvm/relay/quantize/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,20 @@ def annotate_context():


def collect_stats(graph):
    """Create the profile graph used to gather calibration statistics.

    Given an annotated graph, this pass finds every simulated_quantize op and
    collects its input into a tuple. That tuple is the output of the returned
    profile graph, which can then be run on the calibration dataset.

    Parameters
    ----------
    graph: Function
        The simulation graph after annotation.

    Returns
    -------
    ret: Function
        The profile graph which outputs a tuple of profile data.
    """
    profile_graph = _quantize.CollectStats(graph)
    return profile_graph


Expand All @@ -215,6 +229,16 @@ def calibrate(graph, mod=None, ctx=None, weight_scales='power2', scales=None):
ctx: tvm.relay.PassContext
The pass context used for calibration.
weight_scales: 'power2' or 'max'.
The way to calculate scales for weights (annotated with QAnnotateKind.WEIGHT).
power2: Find the maximum of the absolute value of the tensor, and then round up to power
of two.
max: Find the maximum of the absolute value of the tensor.
scales: List[float]
Pre-calculated scales for input and activations. Length and the order of elements of the
scales list should match the output tuple of the profile graph created by collect_stats.
Returns
-------
ret: Function
Expand Down
90 changes: 90 additions & 0 deletions src/relay/pass/quantize/calibration.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
*
* \file calibration.cc
*
* \brief Create profile graph and calibrate on dataset
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include "./quantize.h"


namespace tvm {
namespace relay {
namespace quantize {

/*!
 * \brief Mutator that removes simulated_quantize calls from an annotated function and
 *  records the inputs of the non-weight ones as profile data.
 */
class StatsCollector : private ExprMutator {
 public:
  /*!
   * \brief Build the profile graph for an annotated function.
   * \param expr The annotated expression; it must be a Function.
   * \return A new Function whose body is the tuple of collected profile data and whose
   *  parameters are the free variables of that tuple.
   */
  Expr Collect(const Expr& expr) {
    auto new_e = this->Mutate(expr);
    const FunctionNode* func = new_e.as<FunctionNode>();
    CHECK(func) << "Input should be Function";
    Expr new_body = TupleNode::make(std::move(profile_data_));
    return FunctionNode::make(FreeVars(new_body), new_body, NullValue<Type>(), func->type_params,
                              func->attrs);
  }

 private:
  /*! \brief Inputs of the non-weight simulated_quantize calls, in the order they are visited. */
  Array<Expr> profile_data_;

  /*!
   * \brief Replace a simulated_quantize call by its first argument, recording that argument
   *  as profile data unless the call is annotated as a weight.
   * \param call The call node being visited.
   * \return The rewritten expression.
   */
  Expr VisitExpr_(const CallNode* call) override {
    static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
    // Rewrite the arguments first so nested simulated_quantize calls are already stripped.
    Expr new_e = ExprMutator::VisitExpr_(call);
    const CallNode* new_call = new_e.as<CallNode>();
    CHECK(new_call);
    if (!new_call->op.same_as(simulated_quantize)) {
      return new_e;
    }
    const auto* attrs = new_call->attrs.as<SimulatedQuantizeAttrs>();
    // as<>() yields nullptr on a type mismatch; fail loudly instead of dereferencing it.
    CHECK(attrs) << "simulated_quantize call must carry SimulatedQuantizeAttrs";
    const Expr& quantize_input = new_call->args[0];  // expression being quantized
    if (attrs->kind != QAnnotateKind::kQWeight) {
      // Only inputs and activations are profiled; they are not expected to be constants.
      CHECK(!quantize_input.as<ConstantNode>());
      profile_data_.push_back(quantize_input);
    }
    return quantize_input;
  }
};

/*!
 * \brief Given an annotated graph, create a profile graph to collect profile data from the
 *  calibration dataset.
 *
 * This pass finds simulated_quantize ops and collects their inputs into a tuple. The tuple is
 * the output of the profile graph. Both the input and the output of this pass are
 * relay::Function.
 *
 * \param expr Expression after Annotate pass.
 * \return The profile graph.
 */
Expr CollectStats(const Expr& expr) {
  StatsCollector collector;
  return collector.Collect(expr);
}

// Expose CollectStats through the TVM API registry under the name
// "relay._quantize.CollectStats".
TVM_REGISTER_API("relay._quantize.CollectStats")
.set_body_typed(CollectStats);

} // namespace quantize
} // namespace relay
} // namespace tvm
62 changes: 2 additions & 60 deletions src/relay/pass/quantize.cc → src/relay/pass/quantize/quantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
#include <vector>
#include <stack>
#include <utility>
#include "pattern_util.h"
#include "quantize.h"
#include "../pattern_util.h"
#include "./quantize.h"


namespace tvm {
Expand All @@ -46,22 +46,6 @@ namespace quantize {

using namespace relay::transform;

/*!
 * \brief Attribute for simulated quantize operator
 *
 * Declares three fields: kind (no default), sign (default true) and
 * rounding (default "round").
 */
struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
/*! \brief Kind of the annotated field; a hint for nbit/dtype configuration. */
int kind;
/*! \brief Whether to use a signed data type. */
bool sign;
/*! \brief Rounding mode: 'floor', 'ceil' or 'round'. */
std::string rounding;

TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") {
// kind is declared without set_default, unlike the two fields below.
TVM_ATTR_FIELD(kind)
.describe("kind of field, hint for nbit/dtype configuration.");
TVM_ATTR_FIELD(sign).set_default(true)
.describe("whether to use signed data type.");
TVM_ATTR_FIELD(rounding).set_default("round")
.describe("rounding mode. Can be 'floor', 'ceil', 'round'");
}
};

// Register SimulatedQuantizeAttrs as a node type.
TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs);

bool SimulatedQuantizeRel(const Array<Type>& types,
Expand Down Expand Up @@ -739,48 +723,6 @@ TVM_REGISTER_API("relay._quantize.temp_expr_realize")
return n->Realize();
});

// =============
// calibration

/*!
 * \brief Mutator that removes simulated_quantize calls from an annotated function and
 *  records the inputs of the non-weight ones as profile data.
 */
class StatsCollector : private ExprMutator {
 public:
  /*!
   * \brief Build the profile graph for an annotated function.
   * \param expr The annotated expression; it must be a Function.
   * \return A new Function whose body is the tuple of collected profile data and whose
   *  parameters are the free variables of that tuple.
   */
  Expr Collect(const Expr& expr) {
    auto new_e = this->Mutate(expr);
    const FunctionNode* func = new_e.as<FunctionNode>();
    CHECK(func) << "Input should be Function";
    Expr new_body = TupleNode::make(std::move(profile_data_));
    return FunctionNode::make(FreeVars(new_body), new_body, NullValue<Type>(), func->type_params,
                              func->attrs);
  }

 private:
  /*! \brief Inputs of the non-weight simulated_quantize calls, in the order they are visited. */
  Array<Expr> profile_data_;

  /*!
   * \brief Replace a simulated_quantize call by its first argument, recording that argument
   *  as profile data unless the call is annotated as a weight.
   * \param call The call node being visited.
   * \return The rewritten expression.
   */
  Expr VisitExpr_(const CallNode* call) override {
    static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
    // Rewrite the arguments first so nested simulated_quantize calls are already stripped.
    Expr new_e = ExprMutator::VisitExpr_(call);
    const CallNode* new_call = new_e.as<CallNode>();
    CHECK(new_call);
    if (!new_call->op.same_as(simulated_quantize)) {
      return new_e;
    }
    const auto* attrs = new_call->attrs.as<SimulatedQuantizeAttrs>();
    // as<>() yields nullptr on a type mismatch; fail loudly instead of dereferencing it.
    CHECK(attrs) << "simulated_quantize call must carry SimulatedQuantizeAttrs";
    const Expr& quantize_input = new_call->args[0];  // expression being quantized
    if (attrs->kind != QAnnotateKind::kQWeight) {
      // Only inputs and activations are profiled; they are not expected to be constants.
      CHECK(!quantize_input.as<ConstantNode>());
      profile_data_.push_back(quantize_input);
    }
    return quantize_input;
  }
};

/*!
 * \brief Create the profile graph for an annotated expression.
 * \param expr The annotated expression, handed to StatsCollector::Collect.
 * \return The expression produced by StatsCollector::Collect.
 */
Expr CollectStats(const Expr& expr) {
  StatsCollector collector;
  return collector.Collect(expr);
}

// Expose CollectStats through the TVM API registry under the name
// "relay._quantize.CollectStats".
TVM_REGISTER_API("relay._quantize.CollectStats")
.set_body_typed(CollectStats);

} // namespace quantize
} // namespace relay
Expand Down
18 changes: 17 additions & 1 deletion src/relay/pass/quantize.h → src/relay/pass/quantize/quantize.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <string>
#include "pattern_util.h"
#include "../pattern_util.h"

namespace tvm {
namespace relay {
Expand All @@ -42,6 +42,22 @@ enum QAnnotateKind : int {
kQActivation = 3,
};

/*!
 * \brief Attribute for simulated quantize operator
 *
 * Declares three fields: kind (no default), sign (default true) and
 * rounding (default "round").
 */
struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
/*! \brief Kind of the annotated field; a hint for nbit/dtype configuration. */
int kind;
/*! \brief Whether to use a signed data type. */
bool sign;
/*! \brief Rounding mode: 'floor', 'ceil' or 'round'. */
std::string rounding;

TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") {
// kind is declared without set_default, unlike the two fields below.
TVM_ATTR_FIELD(kind)
.describe("kind of field, hint for nbit/dtype configuration.");
TVM_ATTR_FIELD(sign).set_default(true)
.describe("whether to use signed data type.");
TVM_ATTR_FIELD(rounding).set_default("round")
.describe("rounding mode. Can be 'floor', 'ceil', 'round'");
}
};

/*!
* \brief TempExpr used during annotate forward rewrite.
*/
Expand Down

0 comments on commit 5f0406e

Please sign in to comment.