Skip to content

Commit

Permalink
[SCHEDULE] Improve bound inference, support reduce codegen.
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Feb 2, 2017
1 parent d4af7ad commit 68b30c9
Show file tree
Hide file tree
Showing 32 changed files with 1,250 additions and 646 deletions.
21 changes: 12 additions & 9 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ using Halide::Internal::IRPrinter;
using Halide::Internal::Variable;

using Halide::Internal::make_const;
using Halide::Internal::make_zero;
using Halide::Internal::as_const_int;
using Halide::Internal::as_const_uint;


inline Type TVMType2Type(TVMType t) {
Expand Down Expand Up @@ -126,25 +129,25 @@ using Halide::abs;
using Halide::select;

/*!
* \brief sum of of source expression over rdom
* \brief sum of of source expression over axis
* \param source The source expression.
* \param rdom List of iteration variables that will be used for reduction.
* \param axis List of iteration variables that will be used for reduction.
*/
Expr sum(Expr source, Array<IterVar> rdom);
Expr sum(Expr source, Array<IterVar> axis);

/*!
* \brief max of of source expression over rdom
* \brief max of of source expression over axis
* \param source The source expression.
* \param rdom List of iteration variables that will be used for reduction.
* \param axis List of iteration variables that will be used for reduction.
*/
Expr max(Expr source, Array<IterVar> rdom);
Expr max(Expr source, Array<IterVar> axis);

/*!
* \brief max of of source expression over rdom
* \brief max of of source expression over axis
* \param source The source expression.
* \param rdom List of iteration variables that will be used for reduction.
* \param axis List of iteration variables that will be used for reduction.
*/
Expr min(Expr source, Array<IterVar> rdom);
Expr min(Expr source, Array<IterVar> axis);


// print functions for expr
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ struct Reduce : public ExprNode<Reduce> {
std::string op;
/*! \brief The source operand */
Expr source;
/*! \brief The reduction domains */
Array<IterVar> rdom;
/*! \brief The reduction axis */
Array<IterVar> axis;

/*! \brief construct expr from op and rdom */
static Expr make(std::string op, Expr src, Array<IterVar> rdom);
Expand All @@ -40,7 +40,7 @@ struct Reduce : public ExprNode<Reduce> {
v->Visit("dtype", &type);
v->Visit("op", &op);
v->Visit("source", &source);
v->Visit("rdom", &rdom);
v->Visit("axis", &axis);
}
static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
static constexpr const char* _type_key = "Reduce";
Expand Down
21 changes: 10 additions & 11 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
* \file ir_pass.h
* \brief Collection of IR pass functions
*
* All the pass functions in this file are for Stmt,
* We can use PassFunction(Evaluate(expr)) to apply it to Expr
* When the pass functions in this file are for Stmt,
* we can use PassFunction(Evaluate(expr)) to apply it to Expr
*/
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
Expand Down Expand Up @@ -37,15 +37,6 @@ inline Stmt Simplify(Stmt a) {
return Halide::Internal::simplify(a);
}

/*!
* \brief Schedule s' dependent operations.
*
* \param s The schedule to be realized
* \param dom_map The domain of each iter vars.
* \return the result Stmt
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);

/*!
* \brief verifies whether the IR stmt or Expr is in SSA form.
* That is: each VarExpr is defined and assigned once(in Let/For)
Expand All @@ -69,6 +60,14 @@ bool HasSideEffect(const Expr& e);
*/
Stmt ConvertSSA(Stmt stmt);

/*!
* \brief Substitute the var specified in key->var to be value.
* \param stmt The source statement to be substituted
* \param value_map The map of new values.
* \return The converted form.
*/
Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map);

/*!
* \brief inline all calls of f in stmt.
*
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class ComputeOpNode : public OperationNode {
public:
/*! \brief IterVar on each axis */
Array<IterVar> axis;
/*! \brief IterVar on each reduction axis, if the body is a Reduce */
Array<IterVar> reduce_axis;
/*! \brief the compute expression */
Expr body;
/*! \brief constructor */
Expand All @@ -64,6 +66,7 @@ class ComputeOpNode : public OperationNode {
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("axis", &axis);
v->Visit("reduce_axis", &reduce_axis);
v->Visit("body", &body);
}
static Operation make(std::string name,
Expand Down
37 changes: 37 additions & 0 deletions include/tvm/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ class Stage : public NodeRef {
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor);
// declare container type
using ContainerType = StageNode;
};

/*!
Expand Down Expand Up @@ -152,11 +154,22 @@ class Schedule : public NodeRef {
Stage operator[](const Tensor& tensor) {
return this->operator[](tensor->op);
}
/*!
* \brief Normalize the schedule.
* This is needed before bound inference.
* Insert necessary RebaseNode to make sure all leaf_iter_vars
* are in form [0, extent)
*
* \return A normalized schedule, can be same as current one.
*/
void normalize();
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const ScheduleNode* operator->() const;
// declare container type
using ContainerType = ScheduleNode;
};

/*!
Expand Down Expand Up @@ -308,6 +321,30 @@ class FuseNode : public IterVarRelationNode {
TVM_DECLARE_NODE_TYPE_INFO(FuseNode);
};

/*!
* \brief Rebase the iteration to make min to be 0.
* This is useful to normalize the Schedule
* to make every leaf variable's min to be 0.
*/
class RebaseNode : public IterVarRelationNode {
public:
/*! \brief The parent domain */
IterVar parent;
/*! \brief The inner domain */
IterVar rebased;

void VisitAttrs(AttrVisitor* v) final {
v->Visit("parent", &parent);
v->Visit("rebased", &rebased);
}

static IterVarRelation make(IterVar parent, IterVar rebased);

static constexpr const char* _type_key = "Rebase";
TVM_DECLARE_NODE_TYPE_INFO(RebaseNode);
};


// implementations
inline const StageNode* Stage::operator->() const {
return static_cast<const StageNode*>(node_.get());
Expand Down
9 changes: 9 additions & 0 deletions include/tvm/schedule_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ namespace schedule {
*/
Map<IterVar, Range> InferBound(Schedule sch);

/*!
* \brief Schedule s' dependent operations.
*
* \param s The schedule to be realized
* \param dom_map The domain of each iter vars.
* \return the result Stmt
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);

} // namespace schedule
} // namespace tvm
#endif // TVM_SCHEDULE_PASS_H_
36 changes: 18 additions & 18 deletions python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,51 +212,51 @@ def IterVar(dom=None, name=None, thread_tag=''):
return _api_internal._IterVar(dom, name, thread_tag)


def sum(expr, rdom):
"""Create a sum expression over rdom
def sum(expr, axis):
"""Create a sum expression over axis
Parameters
----------
expr : Expr
The source expression.
rdom : RDomain
The reduction domainx
axis : IterVar
The reduction IterVar axis
"""
rdom = rdom if isinstance(rdom, list) else [rdom]
x = _make.Reduce("Add", expr, rdom)
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Add", expr, axis)
return x


def min(expr, rdom):
"""Create a min expression over rdom
def min(expr, axis):
"""Create a min expression over axis
Parameters
----------
expr : Expr
The source expression.
rdom : RDomain
The reduction domainx
axis : IterVar
The reduction IterVar axis
"""
rdom = rdom if isinstance(rdom, list) else [rdom]
x = _make.Reduce("Min", expr, rdom)
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Min", expr, axis)
return x


def max(expr, rdom):
"""Create a min expression over rdom
def max(expr, axis):
"""Create a min expression over axis
Parameters
----------
expr : Expr
The source expression.
rdom : RDomain
The reduction domainx
axis : IterVar
The reduction IterVar axis
"""
rdom = rdom if isinstance(rdom, list) else [rdom]
x = _make.Reduce("Max", expr, rdom)
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Max", expr, axis)
return x


Expand Down
6 changes: 4 additions & 2 deletions python/tvm/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,10 @@ def build(sch,

# lowering
bounds = schedule.InferBound(sch)
stmt = ir_pass.ScheduleOps(sch, bounds)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.StorageFlatten(stmt, binds)
stmt = ir_pass.Simplify(stmt)
print(stmt)
fapi = codegen.MakeAPI(stmt, name, arg_list, len(arg_list))
fsplits = codegen.SplitHostDevice(fapi)

Expand All @@ -73,7 +74,8 @@ def build(sch,
for i, f in enumerate(fsplits):
t = target if i >= 1 else "c"
record_codes.append(codegen.CompileToC(f, output_ssa, t))

for c in record_codes:
print(c)
if target == "cuda":
ret = codegen.BuildNVRTC(fsplits, "stackvm")
elif target == "opencl":
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ def __getitem__(self, k):
raise ValueError("Cannot find the operation %s in schedule" % (str(k)))
return self.stage_map[k]

def normalize(self):
"""Build a normalized schedule.
Insert necessary rebase to make certain iter var to start from 0.
This is needed before bound inference and followup step.
"""
_api_internal._ScheduleNormalize(self)

@register_node
class Stage(NodeBase):
"""A Stage represents schedule for one operation."""
Expand Down
6 changes: 6 additions & 0 deletions src/api/api_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,4 +253,10 @@ TVM_REGISTER_API(_StageTile)
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
});

TVM_REGISTER_API(_ScheduleNormalize)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Schedule()
.normalize();
});

} // namespace tvm
1 change: 0 additions & 1 deletion src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ TVM_REGISTER_API(_pass_Equal)
REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS4(Inline);
REGISTER_PASS2(ScheduleOps);
REGISTER_PASS2(StorageFlatten);

} // namespace ir
Expand Down
1 change: 1 addition & 0 deletions src/api/api_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ namespace schedule {
REGISTER_SCHEDULE_PASS1(InferBound);
REGISTER_SCHEDULE_PASS1(CreateReadGraph);
REGISTER_SCHEDULE_PASS2(PostDFSOrder);
REGISTER_SCHEDULE_PASS2(ScheduleOps);

} // namespace schedule
} // namespace tvm
5 changes: 3 additions & 2 deletions src/codegen/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
* Copyright (c) 2017 by Contributors
* \file codegen_c.cc
*/
#include <iomanip>
#include "./codegen_c.h"

namespace tvm {
Expand Down Expand Up @@ -216,7 +217,7 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // N
switch (op->type.bits()) {
case 64: case 32: {
std::ostringstream temp;
temp << op->value;
temp << std::scientific << op->value;
if (op->type.bits() == 32) temp << 'f';
p->MarkConst(temp.str());
os << temp.str();
Expand All @@ -225,7 +226,7 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // N
case 16: {
os << '(';
p->PrintType(op->type, os);
os << ')' << op->value << 'f';
os << ')' << std::scientific <<op->value << 'f';
break;
}
default: LOG(FATAL) << "Bad bit-width for float: " << op->type << "\n";
Expand Down
10 changes: 5 additions & 5 deletions src/lang/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
<< op->op
<< ", ";
p->print(op->source);
p->stream << ", rdom=" << op->rdom << ")";
p->stream << ", axis=" << op->axis << ")";
});

} // namespace Internal
Expand All @@ -35,16 +35,16 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
namespace tvm {
namespace ir {

Expr Reduce::make(std::string op, Expr source, Array<IterVar> rdom) {
Expr Reduce::make(std::string op, Expr source, Array<IterVar> axis) {
auto n = std::make_shared<Reduce>();
CHECK(source.defined());
for (size_t i = 0; i < rdom.size(); ++i) {
CHECK(rdom[i].defined());
for (size_t i = 0; i < axis.size(); ++i) {
CHECK(axis[i].defined());
}
n->type = source.type();
n->source = source;
n->op = op;
n->rdom = rdom;
n->axis = axis;
return Expr(n);
}

Expand Down
Loading

0 comments on commit 68b30c9

Please sign in to comment.