[SCAN/Refactor] Refactor scan interface, enable fix point analysis. (#47)
tqchen committed Feb 20, 2017
1 parent 5198c10 commit c8ec411
Showing 21 changed files with 977 additions and 378 deletions.
4 changes: 1 addition & 3 deletions include/tvm/operation.h
@@ -152,14 +152,12 @@ Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor"
 /*!
  * \brief Construct new tensors by scan over scan_axis.
  *
- * \param scan_axis The iteration representing the scan.
  * \param init The initial tensor of the first K steps.
  * \param update The update tensor indicating the updated result after each timestamp.
  * \param state_placeholder The placeholder for the states.
  * \param name The optional name of the tensor.
  */
-Array<Tensor> scan(IterVar scan_axis,
-                   Array<Tensor> init,
+Array<Tensor> scan(Array<Tensor> init,
                    Array<Tensor> update,
                    Array<Tensor> state_placeholder,
                    std::string name = "scan");
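The scan axis thus disappears from the public signature; it is derived inside scan from the shapes of init and update. A sketch of the migration at the Python level, mirroring the docstring change in python/tvm/api.py further down (variable names taken from that example):

    # before this commit, the caller constructed the scan IterVar by hand:
    #   t = tvm.IterVar((1, m), name="t")
    #   res = tvm.scan(t, s_init, s_update, s_state)
    # after it, the axis is inferred from the init/update shapes:
    res = tvm.scan(s_init, s_update, s_state)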
4 changes: 3 additions & 1 deletion include/tvm/schedule.h
@@ -26,7 +26,9 @@ enum AttachType : int {
   kNone = 0,
   kRoot = 1,
   kInline = 2,
-  kScope = 3
+  kInlinedAlready = 3,
+  kScope = 4,
+  kScanUpdate = 5
 };
 
 /*! \brief IterVar type */
2 changes: 2 additions & 0 deletions include/tvm/tensor.h
@@ -175,6 +175,8 @@ class OperationNode : public FunctionBaseNode {
   virtual Type output_dtype(size_t i) const = 0;
   /*! \return shape of i-th output */
   virtual Array<Expr> output_shape(size_t i) const = 0;
+
+  static constexpr const char* _type_key = "Operation";
 };
 
 // Implementations of inline functions
9 changes: 7 additions & 2 deletions python/tvm/addon/nvcc_compiler.py
@@ -4,17 +4,20 @@
 import tempfile
 import subprocess
 
-def compile_source(code, target="cubin"):
+def compile_source(code, target="cubin", options=None):
     """Compile cuda code with NVCC from env.
 
     Parameters
     ----------
     code : str
         The cuda code.
 
-    target: str
+    target : str
         The target format
 
+    options : list of str
+        The additional options
+
     Return
     ------
     cubin : bytearray
@@ -32,6 +35,8 @@ def compile_source(code, target="cubin"):
     cmd = ["nvcc"]
     cmd += ["--%s" % target, "-O3"]
     cmd += ["-o", path_target]
+    if options:
+        cmd += options
     cmd += [path_code]
     args = ' '.join(cmd)
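A hypothetical call site for the new parameter; the architecture flag is illustrative, and this assumes nvcc is available in the environment as the function already requires:

    # kernel_source: CUDA C source text produced elsewhere (hypothetical name)
    cubin = compile_source(kernel_source, target="cubin",
                           options=["-arch=sm_35"])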

11 changes: 4 additions & 7 deletions python/tvm/api.py
@@ -140,14 +140,11 @@ def compute(shape, fcompute, name="compute"):
     return op_node.output(0)
 
 
-def scan(axis, init, update, state_placeholder, name="scan"):
+def scan(init, update, state_placeholder, name="scan"):
     """Construct new tensors by scanning over axis.
 
     Parameters
     ----------
-    axis: IterVar
-        The scanning axis.
-
     init: Tensor or list of Tensor
         The initial condition of first init.shape[0] timestamps
@@ -170,12 +167,11 @@ def scan(axis, init, update, state_placeholder, name="scan"):
       # The following code is equivalent to numpy.cumsum
       m = tvm.Var("m")
       n = tvm.Var("n")
-      t = tvm.IterVar((1, m), name="t")
       X = tvm.placeholder((m, n), name="X")
       s_state = tvm.placeholder((m, n))
       s_init = tvm.compute((1, n), lambda _, i: X[0, i])
-      s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + X[t, i])
-      res = tvm.scan(t, s_init, s_update, s_state)
+      s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
+      res = tvm.scan(s_init, s_update, s_state)
     """
     if isinstance(init, _tensor.Tensor):
         init = [init]
@@ -185,6 +181,7 @@ def scan(axis, init, update, state_placeholder, name="scan"):
         state_placeholder = [state_placeholder]
     if len(init) != len(update) or len(init) != len(state_placeholder):
         raise ValueError("init, update, state_placeholder must have same length")
+    axis = IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name)
     op = _api_internal._ScanOp(name, axis, init, update, state_placeholder)
     res = [op.output(i) for i in range(len(update))]
     return (res[0] if len(res) == 1 else res)
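With the explicit axis argument gone, the scan range is reconstructed from the tensor shapes: it starts at init[0].shape[0] and runs to update[0].shape[0]. Worked out for the cumsum example above, with a concrete size assumed for illustration:

    # init shape (1, n), update shape (m, n); taking m = 10:
    init_steps, total_steps = 1, 10
    axis_min, axis_extent = init_steps, total_steps - init_steps
    assert (axis_min, axis_extent) == (1, 9)  # timestamps 1..9 come from update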
3 changes: 2 additions & 1 deletion python/tvm/build.py
@@ -63,7 +63,8 @@ def build(sch,
             arg_list.append(x)
         else:
             raise ValueError("args must be Tensor, Buffer or Var")
-    # lowering
+    # normalize schedule first
+    sch.normalize()
     bounds = schedule.InferBound(sch)
     stmt = schedule.ScheduleOps(sch, bounds)
     stmt = ir_pass.StorageFlatten(stmt, binds)
3 changes: 3 additions & 0 deletions src/api/api_schedule.cc
@@ -34,6 +34,9 @@ TVM_REGISTER_API(_schedule_AutoInlineElemWise)
 REGISTER_SCHEDULE_PASS1(InferBound);
 REGISTER_SCHEDULE_PASS1(CreateReadGraph);
 REGISTER_SCHEDULE_PASS2(PostDFSOrder);
+REGISTER_SCHEDULE_PASS1(ScanGetBody);
+REGISTER_SCHEDULE_PASS1(CreateAttachPath);
+REGISTER_SCHEDULE_PASS2(ScanFixPointAnalysis);
 REGISTER_SCHEDULE_PASS2(ScheduleOps);
 
 }  // namespace schedule
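The PASS1/PASS2 macros fix the arity of the registered functions. A hedged sketch of how the new analyses might be invoked from Python, assuming they surface under tvm.schedule the same way the existing InferBound and PostDFSOrder registrations do (the argument types are assumptions, not confirmed by this diff):

    # hypothetical usage; only the names and arities come from the
    # registrations above (PASS1 takes one argument, PASS2 takes two)
    body = tvm.schedule.ScanGetBody(scan_op)
    fix = tvm.schedule.ScanFixPointAnalysis(scan_op, body)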
10 changes: 9 additions & 1 deletion src/arithmetic/int_set.cc
@@ -166,7 +166,15 @@ IntSet Union(const Array<IntSet>& set) {
   if (set.size() == 1) return set[0];
   Interval x = set[0].cover_interval().as<IntervalSet>()->i;
   for (size_t i = 1; i < set.size(); ++i) {
-    x.include(set[i].cover_interval().as<IntervalSet>()->i);
+    IntSet s = set[i].cover_interval();
+    const Interval& y = s.as<IntervalSet>()->i;
+    if (can_prove(x.max + 1 >= y.min)) {
+      x.max = y.max;
+    } else if (can_prove(y.max + 1 >= x.min)) {
+      x.min = y.min;
+    } else {
+      x.include(y);
+    }
   }
   return IntervalSet::make(x);
 }
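Union now tries to keep the result exact when consecutive cover intervals provably touch or overlap, falling back to the conservative include only for pieces it cannot relate. The merge rule in a plain Python sketch over integer (min, max) pairs (illustrative only; it mirrors the can_prove branches above, including their asymmetry):

    def merge(x, y):
        # x, y are (min, max) tuples; x is the running union so far
        if x[1] + 1 >= y[0]:        # y begins no later than just past x
            return (x[0], y[1])
        if y[1] + 1 >= x[0]:        # x begins no later than just past y
            return (y[0], x[1])
        return (min(x[0], y[0]), max(x[1], y[1]))  # fallback: convex hull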
28 changes: 15 additions & 13 deletions src/lang/operation.cc
@@ -51,8 +51,6 @@ Operation PlaceholderOpNode::make(std::string name,
   return Operation(n);
 }
 
-
-
 Tensor placeholder(Array<Expr> shape, Type dtype, std::string name) {
   return PlaceholderOpNode::make(name, shape, dtype).output(0);
 }
@@ -162,24 +160,25 @@ Operation ScanOpNode::make(std::string name,
         << " scan_axis.dom.min + scan_axis.dom.extent";
     CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim())
         << "The dimension of init need to match state_placeholder";
-    CHECK_EQ(update[i].ndim() + 1, state_placeholder[i].ndim())
+    CHECK_EQ(update[i].ndim(), state_placeholder[i].ndim())
         << "The update.ndim need to be state_placeholder.ndim - 1";
     for (size_t k = 0; k < update[i].ndim(); ++k) {
       CHECK(prove_equal(
-          update[i]->shape[k], state_placeholder[i]->shape[k + 1]));
-      // setup spatial axis
-      std::ostringstream spatial_name;
-      spatial_name << name << ".out" << i << ".i" << k + 1;
-      n->spatial_axis_.push_back(
-          IterVar(Range::make_with_min_extent(0, update[i]->shape[k]),
-                  spatial_name.str()));
+          update[i]->shape[k], state_placeholder[i]->shape[k]));
+      if (k != 0) {
+        // setup spatial axis
+        std::ostringstream spatial_name;
+        spatial_name << name << ".out" << i << ".i" << k;
+        n->spatial_axis_.push_back(
+            IterVar(Range::make_with_min_extent(0, update[i]->shape[k]),
+                    spatial_name.str()));
+      }
     }
     for (size_t k = 1; k < init[i].ndim(); ++k) {
       CHECK(prove_equal(
           init[i]->shape[k], state_placeholder[i]->shape[k]));
     }
   }
-
   n->name = name;
   n->scan_axis = axis;
   n->init = init;
@@ -188,11 +187,14 @@ Operation ScanOpNode::make(std::string name,
   return Operation(n);
 }
 
-Array<Tensor> scan(IterVar scan_axis,
-                   Array<Tensor> init,
+Array<Tensor> scan(Array<Tensor> init,
                    Array<Tensor> update,
                    Array<Tensor> state_placeholder,
                    std::string name) {
+  IterVar scan_axis(
+      Range::make_with_min_extent(
+          init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
+      name + ".idx");
   Operation op = ScanOpNode::make(
       name, scan_axis, init, update, state_placeholder);
   Array<Tensor> res;
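Because the axis now comes from the shapes, an update tensor carries the time dimension itself and must match its state in every dimension, while init only has to agree from the second dimension on. The invariants ScanOpNode::make enforces after this change, spelled out in plain Python for one state (sizes are illustrative):

    m, n = 10, 4                  # assumed sizes
    state_shape = (m, n)          # state_placeholder[i]
    update_shape = (m, n)         # update[i]: same ndim and dims as the state
    init_shape = (1, n)           # init[i]: first dim is the warm-up length
    assert len(update_shape) == len(state_shape)
    assert update_shape == state_shape        # checked for every k
    assert init_shape[1:] == state_shape[1:]  # checked for k >= 1
    # scan axis: min = init_shape[0], extent = update_shape[0] - init_shape[0]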
4 changes: 3 additions & 1 deletion src/pass/inline.cc
@@ -61,7 +61,9 @@ Stmt Inline(Stmt stmt,
             Expr body) {
   CHECK_EQ(f->num_outputs(), 1)
       << "can only inline output single value operation";
-  return ConvertSSA(IRInline(f, args, body).Mutate(stmt));
+  Stmt ret = IRInline(f, args, body).Mutate(stmt);
+  if (ret.same_as(stmt)) return ret;
+  return ConvertSSA(ret);
 }
 }  // namespace ir
 }  // namespace tvm
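The rewrite returns early when the inliner left the statement untouched, so SSA conversion only runs when an inline actually happened. The same short-circuit idiom in a plain Python sketch (names are illustrative, not the real API):

    def inline(stmt, f, args, body):
        ret = ir_inline(f, args, body).mutate(stmt)  # assumed mutator interface
        if ret is stmt:             # nothing was inlined; skip the SSA pass
            return ret
        return convert_ssa(ret)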