From e7932647c9b97d7718164da324135c077d60bf35 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 20 Jan 2017 19:32:19 -0800 Subject: [PATCH] [CODEGEN] Add LoweredFunc, MakeAPI to build a C API function with checks. --- HalideIR | 2 +- include/tvm/buffer.h | 3 + include/tvm/c_runtime_api.h | 40 +++++- include/tvm/codegen.h | 51 +++++++ include/tvm/ir.h | 42 ++++++ include/tvm/ir_pass.h | 1 - include/tvm/module.h | 108 ++++++++++++++ python/tvm/collections.py | 6 + src/base/common.h | 3 +- src/c_api/c_api_codegen.cc | 8 +- src/codegen/codegen_c.cc | 205 +++++++++++++++++++++------ src/codegen/codegen_c.h | 23 +-- src/codegen/make_api.cc | 190 +++++++++++++++++++++++++ src/pass/ir_util.h | 70 +++++++++ src/pass/schedule_ops.cc | 40 +----- tests/python/test_codegen_cuda.py | 7 +- tests/python/test_codegen_makeapi.py | 29 ++++ 17 files changed, 726 insertions(+), 102 deletions(-) create mode 100644 include/tvm/codegen.h create mode 100644 include/tvm/module.h create mode 100644 src/codegen/make_api.cc create mode 100644 src/pass/ir_util.h create mode 100644 tests/python/test_codegen_makeapi.py diff --git a/HalideIR b/HalideIR index adfa662402650..d8200348fcef1 160000 --- a/HalideIR +++ b/HalideIR @@ -1 +1 @@ -Subproject commit adfa662402650e2f9b02ea600ffb70d6e7bb5adf +Subproject commit d8200348fcef184a374c2dbd46d3f5a5136a53e3 diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h index beed7e9d1281c..2e4d7debcf427 100644 --- a/include/tvm/buffer.h +++ b/include/tvm/buffer.h @@ -50,6 +50,9 @@ class Buffer : public NodeRef { * \return the pointer to the internal node container */ inline const BufferNode* operator->() const; + + /*! \brief specify container node */ + using ContainerType = BufferNode; }; /*! 
\brief Node to represent a buffer */ diff --git a/include/tvm/c_runtime_api.h b/include/tvm/c_runtime_api.h index 1a21adc41cd69..25b81d80ce5af 100644 --- a/include/tvm/c_runtime_api.h +++ b/include/tvm/c_runtime_api.h @@ -30,6 +30,7 @@ #endif #include +#include TVM_EXTERN_C { @@ -216,18 +217,45 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream); /*! - * \brief Launch a generated TVM function + * \brief TVM Function API: Get resource requirement + * + * By default TVM function try not to do internal allocations. + * Instead, TVMFuncRequirement can be called, given the input arguments. + * + * \param func function handle to be launched. + * \param args The arguments + * \param arg_type_ids The type id of the arguments + * \param num_args Number of arguments. + * \param out_workspace_size The workspace size needed to launch this function. + * \param out_workspace_align The alignment requirement of workspace. + * + * \note The data pointer in the arrays is not used by requirement. + */ +TVM_DLL int TVMFuncRequirement(TVMFunctionHandle func, + TVMArg* args, + int* arg_type_ids, + int num_args, + size_t* out_workspace_size, + size_t* out_workspace_align); + +/*! + * \brief TVM Function API: Launch generated function. + * * \param func function handle to be launched. * \param args The arguments * \param arg_type_ids The type id of the arguments * \param num_args Number of arguments. * \param stream The stream this function to be launched on. + * \param workspace Additional workspace used to launch this function. 
+ * + * \sa TVMFuncRequirement */ -TVM_DLL int TVMLaunch(TVMFunctionHandle func, - TVMArg* args, - int* arg_type_ids, - int num_args, - TVMStreamHandle stream); +TVM_DLL int TVMFuncLaunch(TVMFunctionHandle func, + TVMArg* args, + int* arg_type_ids, + int num_args, + TVMStreamHandle stream, + TVMArrayHandle workspace); } // TVM_EXTERN_C #endif // TVM_C_RUNTIME_API_H_ diff --git a/include/tvm/codegen.h b/include/tvm/codegen.h new file mode 100644 index 0000000000000..f47471d7ef0fc --- /dev/null +++ b/include/tvm/codegen.h @@ -0,0 +1,51 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file codegen.h + * \brief Collection of Lowlevel IR pass to codegen. + */ +#ifndef TVM_CODEGEN_H_ +#define TVM_CODEGEN_H_ + +#include +#include "./base.h" +#include "./expr.h" +#include "./module.h" + +namespace tvm { +/*! \brief namespace for lowlevel IR pass and codegen */ +namespace codegen { +/*! + * \brief Make an user callable API LoweredFunc. + * + * The main task of this function is to create code to : + * - Map the values in the api_args to of Var that is required by body. + * - Insert assertions to check type/value of the passed arguments. + * + * \param body The body of the function. + * \param name The name of the function. + * \param api_args Arguments to the function, can be either Var, or Buffer + * \param num_packed_args Number of arguments that are processed in packed form. + * \return a LoweredFunc with the specified signiture. + * + * \note + * The function signiture have two cases + * + * if num_packed_args is zero: + * f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args) + * + * if num_packed_args is not zero: + * f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args, + * api_arg_k, api_arg_k+1, ... api_arg_n) + * + * where n == len(api_args), k == num_packed_args + * + * There is no thread_axis in generated function. 
/*!
 * \brief Namespace of TVM intrinsic functions.
 *
 * Each constant below is the name of an intrinsic Call node. The pseudo code
 * attached to each name describes the intended meaning of that call.
 */
namespace intrinsic {
// NOTE(review): the original comment here was cut off mid-sentence
// ("Most of the intrinsics is to enab"); its intended text is unknown.
/*!
 * \brief See pseudo code
 *
 *  Type tvm_api_load_arg(TVMArg* args, int* arg_type_ids, int i) {
 *     assert(arg_type_ids[i] == typeid(Type));
 *     return args[i];
 *  }
 *
 * NOTE(review): the C code generator in this change prints only the load of
 * args[i]; it does not emit the type-id assertion shown above.
 */
constexpr const char* tvm_api_load_arg = "tvm_api_load_arg";
/*!
 * \brief See pseudo code
 *
 *  Type tvm_array_get_field(TVMArray* arr, int field_id) {
 *     return arr->field;
 *  }
 *
 * field_id is one of the TVMArrayFieldKind values.
 * \sa TVMArrayFieldKind
 */
constexpr const char* tvm_array_get_field = "tvm_array_get_field";
/*!
 * \brief See pseudo code
 *
 *  bool tvm_handle_is_null(void* handle) {
 *     return handle == nullptr;
 *  }
 */
constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";

/*!
 * \brief The field id of each field in array.
 *
 * Passed as the second argument of a tvm_array_get_field call to select
 * which TVMArray member is read.
 */
enum TVMArrayFieldKind {
  kData = 0,
  kNDim = 1,
  kShape = 2,
  kStrides = 3,
  kTypeCode = 4,
  kTypeBits = 5,
  kTypeLanes = 6
};
}  // namespace intrinsic
+ * Copyright (c) 2016 by Contributors + * \file module.h + * \brief Low level IR module, + * Contains lowered function information. + */ +#ifndef TVM_MODULE_H_ +#define TVM_MODULE_H_ + +#include +#include +#include + +#include "./base.h" +#include "./expr.h" +#include "./tensor.h" + +namespace tvm { + +// Internal node container of lowered function. +class LoweredFuncNode; + +// Internal node container of module. +class ModuleNode; + +/*! + * \brief LoweredFunc represents function after lowering. + * This is the final IR representation before codegen. + */ +class LoweredFunc : public FunctionRef { + public: + LoweredFunc() {} + explicit LoweredFunc(std::shared_ptr n) : FunctionRef(n) {} + /*! + * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline const LoweredFuncNode* operator->() const; + /*! \brief specify container node */ + using ContainerType = LoweredFuncNode; +}; + +/*! \brief Node container of LoweredFunc */ +class LoweredFuncNode : public FunctionBaseNode { + public: + /*! \brief The name of the function */ + std::string name; + /*! + * \brief The arguments of the function + * This function can only take pod type(int, float) and void* as arguments. + */ + Array args; + /*! + * \brief The IterVar axis of threads + * Each axis need host function to specify a size. + * \note Calling convention into LoweredFunc + * + * Assume we have a LoweredFunc f, a call into f + * Call(f, arg1, arg2, ..., arg_n, + * size_axis_1, size_axis_2, ... size_axis_m) + * + * Here n = len(args), m = len(thread_axis) + * + * The CodeGen should take this and translate this call + * to corresponding API specific kernel launchs or function calls. + */ + Array thread_axis; + /*! + * \brief The hint data type of Var handles defined in LetStmt + * Can be used as hint when generating type signiture. 
+ * The creation rule is given by + * handle_data_type[var_handle] = make_const(the_type, 0); + * + * \note Expr is used instead Type, because Type cannot be hold by Map. + * constant Expr of given type is used. + */ + Map handle_data_type; + /*! \brief The body statment of the function */ + Stmt body; + /*! \return name of the operation */ + const std::string& func_name() const final { + return name; + } + // there is no return value, but return 1 + // to enable Call into this function. + int num_outputs() const final { + return 1; + } + void VisitAttrs(AttrVisitor* v) final { + v->Visit("name", &name); + v->Visit("args", &args); + v->Visit("thread_axis", &thread_axis); + v->Visit("handle_data_type", &handle_data_type); + v->Visit("body", &body); + } + + static constexpr const char* _type_key = "LoweredFunc"; + TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode); +}; + +// Implementations of inline functions +inline const LoweredFuncNode* LoweredFunc::operator->() const { + return static_cast(node_.get()); +} + +} // namespace tvm + +#endif // TVM_MODULE_H_ diff --git a/python/tvm/collections.py b/python/tvm/collections.py index 85e629cc96da9..2e43e2e6bec0b 100644 --- a/python/tvm/collections.py +++ b/python/tvm/collections.py @@ -56,3 +56,9 @@ class IterVar(NodeBase, _expr.ExprOp): class Buffer(NodeBase): """Represent a Buffer in TVM.""" pass + + +@register_node +class LoweredFunc(NodeBase): + """Represent a LoweredFunc in TVM.""" + pass diff --git a/src/base/common.h b/src/base/common.h index ea2f4bdad9e58..432ec74db9af0 100644 --- a/src/base/common.h +++ b/src/base/common.h @@ -7,6 +7,7 @@ #define TVM_BASE_COMMON_H_ #include +#include #include namespace tvm { @@ -30,7 +31,7 @@ inline Type String2Type(std::string s) { } else if (s.substr(0, 5) == "float") { code = Type::Float; s = s.substr(5); } else if (s == "handle") { - return Type(Type::Handle, 32, 1); + return Handle(); } else { LOG(FATAL) << "unknown type " << s; } diff --git a/src/c_api/c_api_codegen.cc 
// Registers the API function _codegen_CompileToC.
// args[0] and args[1] are forwarded to CodeGenC::Compile, whose parameters
// are (LoweredFunc f, bool output_ssa); a fresh CodeGenC is created per call.
TVM_REGISTER_API(_codegen_CompileToC)
.set_body([](const ArgStack& args, RetValue *ret) {
    *ret = CodeGenC().Compile(args.at(0), args.at(1));
  });

// Registers the API function _codegen_MakeAPI.
// args[0..3] are forwarded to MakeAPI, whose parameters are
// (body, name, api_args, num_packed_args).
TVM_REGISTER_API(_codegen_MakeAPI)
.set_body([](const ArgStack& args, RetValue *ret) {
    *ret = MakeAPI(
        args.at(0), args.at(1), args.at(2), args.at(3));
  });
+ for (const auto & kv : f->handle_data_type) { + HandleTypeRegister(kv.first.get(), kv.second.type()); + } this->indent += 2; - this->stream << "void " << fun_name << "("; - for (size_t i = 0; i < args.size(); ++i) { - Var v = args[i]; + this->stream << "void " << f->name << "("; + for (size_t i = 0; i < f->args.size(); ++i) { + Var v = f->args[i]; std::string vid = AllocVarID(v.get()); if (i != 0) stream << ", "; PrintType(v.type(), stream); stream << ' ' << vid; } stream << ") {\n"; - this->PrintStmt(stmt); + this->PrintStmt(f->body); this->indent -= 2; this->PrintIndent(); this->stream << "}\n"; @@ -104,12 +107,22 @@ std::string CodeGenC::GetVarID(const Variable* v) const { return it->second; } -bool CodeGenC::BufferTypeMatch(const Variable* buf_var, Type t) const { - auto it = alloc_buf_type_.find(buf_var); - if (it == alloc_buf_type_.end()) return false; +bool CodeGenC::HandleTypeMatch(const Variable* buf_var, Type t) const { + auto it = handle_data_type_.find(buf_var); + if (it == handle_data_type_.end()) return false; return it->second == t; } +void CodeGenC::HandleTypeRegister(const Variable* buf_var, Type t) { + auto it = handle_data_type_.find(buf_var); + if (it == handle_data_type_.end()) { + handle_data_type_[buf_var] = t; + } else { + CHECK(it->second == t) + << "conflicting buf var type"; + } +} + void CodeGenC::PrintIndent() { for (int i = 0; i < this->indent; ++i) { this->stream << ' '; @@ -234,6 +247,18 @@ inline void PrintBinaryExpr(const T* op, os << ')'; } +inline void PrintBinaryIntrinsitc(const Call* op, + const char *opstr, + std::ostream& os, // NOLINT(*) + CodeGenC* p) { + CHECK_EQ(op->args.size(), 2U); + os << '('; + p->PrintExpr(op->args[0], os); + os << opstr; + p->PrintExpr(op->args[1], os); + os << ')'; +} + TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr) .set_dispatch([](const Cast *op, std::ostream& os, CodeGenC *p) { // NOLINT(*) p->PrintType(op->type, os); @@ -300,24 +325,9 @@ TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr) 
.set_dispatch([](const Not *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) os << '!'; p->PrintExpr(op->a, os); - }) -.set_dispatch([](const Call *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - os << op->name << "("; - for (size_t i = 0; i < op->args.size(); i++) { - p->PrintExpr(op->args[i], os); - if (i < op->args.size() - 1) { - os << ", "; - } - } - os << ")"; }); TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt) -.set_dispatch([](const AssertStmt *op, CodeGenC* p) { - std::string cond = p->PrintExpr(op->condition); - p->PrintIndent(); - p->stream << "assert(" << cond << ");\n"; - }) .set_dispatch([](const ProducerConsumer *op, CodeGenC* p) { p->PrintStmt(op->body); }) @@ -372,14 +382,95 @@ TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt) TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr) .DISPATCH_EXPR(Load) +.DISPATCH_EXPR(Call) .DISPATCH_EXPR(Let) .DISPATCH_EXPR(Ramp) .DISPATCH_EXPR(Broadcast) .DISPATCH_EXPR(Select); + +void CodeGenC::PrintExpr(const Call *op, std::ostream& os) { // NOLINT(*) + CodeGenC* p = this; + if (op->is_intrinsic(Call::bitwise_and)) { + PrintBinaryIntrinsitc(op, " & ", os, p); + } else if (op->is_intrinsic(Call::bitwise_xor)) { + PrintBinaryIntrinsitc(op, " ^ ", os, p); + } else if (op->is_intrinsic(Call::bitwise_or)) { + PrintBinaryIntrinsitc(op, " | ", os, p); + } else if (op->is_intrinsic(Call::bitwise_not)) { + CHECK_EQ(op->args.size(), 1U); + os << "(~"; + p->PrintExpr(op->args[0], os); + os << ')'; + } else if (op->is_intrinsic(Call::shift_left)) { + PrintBinaryIntrinsitc(op, " << ", os, p); + } else if (op->is_intrinsic(Call::shift_right)) { + PrintBinaryIntrinsitc(op, " >> ", os, p); + } else if (op->is_intrinsic(Call::address_of)) { + const Load *l = op->args[0].as(); + CHECK(op->args.size() == 1 && l); + os << "(("; + p->PrintType(l->type.element_of(), os); + os << " *)" << p->GetVarID(l->buffer_var.get()) + << " + "; + p->PrintExpr(l->index, os); + os << ')'; + } else if 
(op->is_intrinsic(intrinsic::tvm_api_load_arg)) { + CHECK_EQ(op->args.size(), 3U); + if (!op->type.is_handle()) { + os << '('; + p->PrintType(op->type, os); + os << ')'; + } + os << "(((TVMArg*)"; + p->PrintExpr(op->args[0], os); + os << ")[" << op->args[2] << "]."; + if (op->type.is_handle()) { + os << "v_handle"; + } else if (op->type.is_float()) { + os << "v_double"; + } else if (op->type.is_int() || op->type.is_uint()) { + os << "v_long"; + } else { + LOG(FATAL) << "donot know how to handle type" << op->type; + } + os << ")"; + } else if (op->is_intrinsic(intrinsic::tvm_array_get_field)) { + CHECK_EQ(op->args.size(), 2U); + os << "(((TVMArray*)"; + p->PrintExpr(op->args[0], os); + os << ")->"; + switch (op->args[1].as()->value) { + case intrinsic::kData: os << "data"; break; + case intrinsic::kShape: os << "shape"; break; + case intrinsic::kStrides: os << "strides"; break; + case intrinsic::kNDim: os << "ndim"; break; + case intrinsic::kTypeCode: os << "dtype.type_code"; break; + case intrinsic::kTypeBits: os << "dtype.bits"; break; + case intrinsic::kTypeLanes: os << "dtype.lanes"; break; + default: LOG(FATAL) << "unknown field code"; + } + os << ')'; + } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { + CHECK_EQ(op->args.size(), 1U); + os << "("; + p->PrintExpr(op->args[0], os); + os << " == NULL)"; + } else { + os << op->name << "("; + for (size_t i = 0; i < op->args.size(); i++) { + p->PrintExpr(op->args[i], os); + if (i < op->args.size() - 1) { + os << ", "; + } + } + os << ")"; + } +} + void CodeGenC::PrintExpr(const Load* op, std::ostream& os) { // NOLINT(*) std::string vid = GetVarID(op->buffer_var.get()); - if (!BufferTypeMatch(op->buffer_var.get(), op->type)) { + if (!HandleTypeMatch(op->buffer_var.get(), op->type)) { os << "((const "; PrintType(op->type, os); os << "*)" << vid << ')'; @@ -416,7 +507,8 @@ TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt) .set_dispatch([](const LetStmt *op, CodeGenC* p) { p->PrintStmt(op); }) 
.set_dispatch([](const Store *op, CodeGenC* p) { p->PrintStmt(op); }) .set_dispatch([](const Allocate *op, CodeGenC* p) { p->PrintStmt(op); }) -.set_dispatch([](const AttrStmt *op, CodeGenC* p) { p->PrintStmt(op); }); +.set_dispatch([](const AttrStmt *op, CodeGenC* p) { p->PrintStmt(op); }) +.set_dispatch([](const AssertStmt *op, CodeGenC* p) { p->PrintStmt(op); }); void CodeGenC::PrintStmt(const LetStmt* op) { @@ -426,10 +518,20 @@ void CodeGenC::PrintStmt(const LetStmt* op) { var_idmap_[op->var.get()] = value; } else { PrintIndent(); - PrintType(op->var.type(), this->stream); - this->stream << ' ' - << AllocVarID(op->var.get()) - << " = " << value << ";\n"; + if (op->var.type() == Handle() && + handle_data_type_.count(op->var.get())) { + PrintType(handle_data_type_.at(op->var.get()), stream); + stream << "* " + << AllocVarID(op->var.get()) + << " = ("; + PrintType(handle_data_type_.at(op->var.get()), stream); + stream << "*)" << value << ";\n"; + } else { + PrintType(op->var.type(), this->stream); + this->stream << ' ' + << AllocVarID(op->var.get()) + << " = " << value << ";\n"; + } } PrintStmt(op->body); } @@ -439,7 +541,7 @@ void CodeGenC::PrintStmt(const Store* op) { std::string value = this->PrintExpr(op->value); this->PrintIndent(); std::string vid = GetVarID(op->buffer_var.get()); - if (!BufferTypeMatch(op->buffer_var.get(), op->value.type())) { + if (!HandleTypeMatch(op->buffer_var.get(), op->value.type())) { this->stream << "(("; PrintType(op->value.type(), this->stream); this->stream << "*)" << vid << ')'; @@ -452,16 +554,25 @@ void CodeGenC::PrintStmt(const Store* op) { } void CodeGenC::PrintStmt(const Allocate* op) { - this->PrintIndent(); - int32_t constant_size = op->constant_allocation_size(); - std::string vid = AllocVarID(op->buffer_var.get()); - CHECK(!op->new_expr.defined()); CHECK(!is_zero(op->condition)); - CHECK_GT(constant_size, 0) - << "Can only handle constant size stack allocation for now"; - PrintType(op->type, stream); - stream << ' '<< 
vid << '[' - << constant_size << "]\n;"; + std::string vid = AllocVarID(op->buffer_var.get()); + if (op->new_expr.defined()) { + // Prefer global static allocation for the program + CHECK_EQ(op->free_function, "nop"); + std::string new_data = PrintExpr(op->new_expr); + this->PrintIndent(); + PrintType(op->type, stream); + stream << "* "<< vid << '=' << new_data << ";\n"; + } else { + this->PrintIndent(); + int32_t constant_size = op->constant_allocation_size(); + CHECK_GT(constant_size, 0) + << "Can only handle constant size stack allocation for now"; + PrintType(op->type, stream); + stream << ' '<< vid << '[' + << constant_size << "]\n;"; + } + HandleTypeRegister(op->buffer_var.get(), op->type); this->PrintStmt(op->body); } @@ -479,5 +590,17 @@ void CodeGenC::PrintStmt(const AttrStmt* op) { this->PrintStmt(op->body); } +void CodeGenC::PrintStmt(const AssertStmt* op) { + std::string cond = PrintExpr(op->condition); + PrintIndent(); + if (op->message.as()) { + // GLOG style check + stream << "CHECK(" << cond << ") << \"" + << op->message.as()->value << "\";\n"; + } else { + stream << "assert(" << cond << ");\n"; + } +} + } // namespace codegen } // namespace tvm diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index a8ce1828e4b43..4630e9990b568 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -23,16 +24,12 @@ class CodeGenC { public: /*! * \brief Generate the C code of statement - * \param body The body of the function. - * \param fun_name The name of the function. - * \param args The arguments to the function. + * \param f The function to be compiled * \param output_ssa Whether output ssa form. * \note Only call compile once, * create a new codegen object each time. */ - std::string Compile(Stmt body, - std::string fun_name, - Array args, + std::string Compile(LoweredFunc f, bool output_ssa); /*! 
* \brief Print the Stmt n to CodeGenC->stream @@ -49,7 +46,7 @@ class CodeGenC { * \brief Same as PrintExpr, but simply returns result string * \param n The expression to be printed. */ - inline std::string PrintExpr(const Expr& n) { + std::string PrintExpr(const Expr& n) { std::ostringstream os; PrintExpr(n, os); return os.str(); @@ -85,7 +82,9 @@ class CodeGenC { virtual void PrintStmt(const ir::Store* op); virtual void PrintStmt(const ir::Allocate* op); virtual void PrintStmt(const ir::AttrStmt* op); + virtual void PrintStmt(const ir::AssertStmt* op); virtual void PrintExpr(const ir::Load* op, std::ostream& os); // NOLINT(*) + virtual void PrintExpr(const ir::Call* op, std::ostream& os); // NOLINT(*) virtual void PrintExpr(const ir::Let* op, std::ostream& os); // NOLINT(*) virtual void PrintExpr(const ir::Ramp* op, std::ostream& os); // NOLINT(*) virtual void PrintExpr(const ir::Broadcast* op, std::ostream& os); // NOLINT(*) @@ -116,7 +115,13 @@ class CodeGenC { * \param buf_var The buffer variable. * \param t The type to be checked. */ - bool BufferTypeMatch(const Variable* buf_var, Type t) const; + bool HandleTypeMatch(const Variable* buf_var, Type t) const; + /*! + * \brief Register the data type of buf_var + * \param buf_var The buffer variable. + * \param t The type to be checked. + */ + void HandleTypeRegister(const Variable* buf_var, Type t); /*! * \brief get a unique name with the corresponding prefix * \param prefix The prefix of the name @@ -128,7 +133,7 @@ class CodeGenC { /*! \brief name of each variable */ std::unordered_map var_idmap_; /*! \brief the data type of allocated buffers */ - std::unordered_map alloc_buf_type_; + std::unordered_map handle_data_type_; /*! \brief name allocation map */ std::unordered_map name_alloc_map_; /*! \brief assignment map of ssa */ diff --git a/src/codegen/make_api.cc b/src/codegen/make_api.cc new file mode 100644 index 0000000000000..717410f65738e --- /dev/null +++ b/src/codegen/make_api.cc @@ -0,0 +1,190 @@ +/*! 
+ * Copyright (c) 2017 by Contributors + * \file make_api.cc Build API function. + */ +#include +#include +#include + +#include +#include +#include + +#include "../pass/ir_util.h" + +namespace tvm { +namespace codegen { +using namespace ir; + +inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMArrayFieldKind kind) { + return Call::make( + t, intrinsic::tvm_array_get_field, + {arr, IntImm::make(Int(32), kind)}, + Call::PureIntrinsic); +} + +inline Stmt AssertNull(Var handle, std::string msg) { + return AssertStmt::make(Call::make( + Bool(1), intrinsic::tvm_handle_is_null, + {handle}, Call::PureIntrinsic), msg); +} + +inline Stmt MakeAssertEQ(Expr lhs, Expr rhs, std::string msg) { + return AssertStmt::make(lhs == rhs, msg); +} + +LoweredFunc MakeAPI(Stmt body, + std::string name, + Array api_args, + int num_packed_args) { + const Type tvm_index_type = UInt(32); + const Stmt nop = Evaluate::make(0); + // Data field definitions + // The packed fields + Var v_packed_args("args", Handle()); + Var v_packed_arg_type_ids("arg_type_ids", Handle()); + Var v_num_packed_args("num_args", Int(32)); + // The arguments of the function. + Array args; + // seq_init gives sequence of initialization + // seq_check gives sequence of later checks after iniit + std::vector seq_init, seq_check; + std::unordered_set visited; + // the handle data types + Map handle_data_type; + // --------------------------- + // local function defintiions + // load i-th argument as type t + auto f_arg_value = [&](Type t, int i) { + Array call_args{ + v_packed_args, v_packed_arg_type_ids, IntImm::make(Int(32), i)}; + return Call::make( + t, intrinsic::tvm_api_load_arg, call_args, + Call::PureIntrinsic); + }; + // get declaration of argument i + auto f_arg_decl = [&](int i) { + std::ostringstream os; + os << "arg" << i; + const Variable* v = api_args[i].as(); + return Var(os.str(), v ? 
v->type: Handle()); + }; + // Push related into assertions or variable defintion + // given the symbolic declaration and concrete value + auto f_push = [&](Expr sym, Expr value, std::string field) { + if (sym.as()) { + // If sym is a Variable and this Variable is not yet defined + // add this to defintion. + Var v(sym.node_); + if (!visited.count(v.get())) { + seq_init.emplace_back(LetStmt::make(v, value, nop)); + visited.insert(v.get()); + return true; + } + } + // otherwise, assume sym is already defined, insert assertion. + std::ostringstream os; + os << "Field " << field << " has a unsatisfied constraint"; + seq_check.emplace_back(MakeAssertEQ(sym, value, os.str())); + return false; + }; + // --------------------------- + // start of logics + // add signiture for packed arguments. + if (num_packed_args != 0) { + args.push_back(v_packed_args); + args.push_back(v_packed_arg_type_ids); + args.push_back(v_num_packed_args); + std::ostringstream os; + os << "expected num_args to be " << num_packed_args; + seq_init.emplace_back( + MakeAssertEQ(v_num_packed_args, num_packed_args, os.str())); + } + + for (size_t i = 0; i < api_args.size(); ++i) { + Var v_arg = f_arg_decl(i); + if (i < static_cast(num_packed_args)) { + seq_init.emplace_back(LetStmt::make( + v_arg, f_arg_value(v_arg.type(), i), nop)); + } else { + args.push_back(v_arg); + } + // add checks for functions. 
+ if (api_args[i].as()) { + f_push(Var(api_args[i].node_), v_arg, v_arg->name_hint); + } else { + // Buffer checks + CHECK(api_args[i].as()) + << "api_args can only be Buffer or Var"; + Buffer buf(api_args[i].node_); + // dimension checks + Expr v_ndim = TVMArrayGet(tvm_index_type, v_arg, intrinsic::kNDim); + std::ostringstream ndim_err_msg; + ndim_err_msg << "arg_" << i + << ": TVMArray->ndim is expected to equal " + << buf->shape.size(); + seq_init.emplace_back( + MakeAssertEQ(v_ndim, make_const(tvm_index_type, buf->shape.size()), + ndim_err_msg.str())); + // type checks + Type dtype = buf->dtype; + std::ostringstream type_err_msg; + type_err_msg << "arg" << i << ".dtype is expected to be " << dtype; + Expr cond = (TVMArrayGet(UInt(8), v_arg, intrinsic::kTypeCode) == + make_const(UInt(8), dtype.code()) && + TVMArrayGet(UInt(8), v_arg, intrinsic::kTypeBits) == + make_const(UInt(8), dtype.bits()) && + TVMArrayGet(UInt(16), v_arg, intrinsic::kTypeLanes) == + make_const(UInt(16), dtype.lanes())); + seq_init.emplace_back(AssertStmt::make(cond, type_err_msg.str())); + // Data Field + if (f_push(buf->ptr, TVMArrayGet(Handle(), v_arg, intrinsic::kData), + v_arg->name_hint + ".data")) { + Var vptr(buf->ptr); + handle_data_type.Set(vptr, make_const(buf->dtype, 0)); + } + // shape field + Var v_shape(v_arg->name_hint + ".shape", Handle()); + handle_data_type.Set(v_shape, make_const(tvm_index_type, 0)); + seq_init.emplace_back(LetStmt::make( + v_shape, TVMArrayGet(Handle(), v_arg, intrinsic::kShape), nop)); + for (size_t k = 0; k < buf->shape.size(); ++k) { + std::ostringstream field_name; + field_name << v_shape->name_hint << '[' << k << ']'; + f_push(buf->shape[k], + cast(buf->shape[k].type(), + Load::make(tvm_index_type, v_shape, make_const(Int(32), k))), + field_name.str()); + } + // strides field + Var v_strides(v_arg->name_hint + ".strides", Handle()); + handle_data_type.Set(v_strides, make_const(tvm_index_type, 0)); + seq_init.emplace_back(LetStmt::make( + v_strides, 
TVMArrayGet(Handle(), v_arg, intrinsic::kStrides), nop)); + if (buf->strides.size() == 0) { + std::ostringstream stride_err_msg; + stride_err_msg << "arg_" << i << ".strides:" + << " expected to be nullptr for contiguous array"; + seq_init.emplace_back(AssertNull(v_strides, stride_err_msg.str())); + } else { + for (size_t k = 0; k < buf->strides.size(); ++k) { + std::ostringstream field_name; + field_name << v_strides->name_hint << '[' << k << ']'; + f_push(buf->strides[k], + cast(buf->shape[k].type(), + Load::make(tvm_index_type, v_strides, make_const(Int(32), k))), + field_name.str()); + } + } + } + } + + std::shared_ptr n = std::make_shared(); + n->name = name; + n->args = args; + n->handle_data_type = handle_data_type; + n->body = MergeNest({seq_init, seq_check}, body); + return LoweredFunc(n); +} +} // namespace codegen +} // namespace tvm diff --git a/src/pass/ir_util.h b/src/pass/ir_util.h new file mode 100644 index 0000000000000..794dcd8207153 --- /dev/null +++ b/src/pass/ir_util.h @@ -0,0 +1,70 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file ir_util.h + * \brief Helper functions to construct and compose IR nodes. + */ +#ifndef TVM_PASS_IR_UTIL_H_ +#define TVM_PASS_IR_UTIL_H_ + +#include +#include + +namespace tvm { +namespace ir { + +/*! + * \brief combine the nest stmt, whose body is not defined. + * \param nest A list of For and LetStmt, whose body is not defined. 
+ * \param body body + * \return The combined Stmt + */ +inline Stmt MergeNest(std::vector nest, Stmt body) { + // use reverse iteration + for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { + Stmt s = *ri; + if (s.as()) { + auto n = std::make_shared(*s.as()); + CHECK(is_no_op(n->body)); + n->body = body; + body = Stmt(n); + } else if (s.as()) { + auto n = std::make_shared(*s.as()); + CHECK(is_no_op(n->body)); + n->body = body; + body = Stmt(n); + } else if (s.as()) { + auto n = std::make_shared(*s.as()); + CHECK(is_no_op(n->body)); + n->body = body; + body = Stmt(n); + } else if (s.as()) { + auto n = std::make_shared(*s.as()); + CHECK(is_no_op(n->then_case)); + CHECK(!n->else_case.defined()); + n->then_case = body; + body = Stmt(n); + } else if (s.as()) { + body = Block::make(s, body); + } else { + LOG(FATAL) << "not supported nest type"; + } + } + return body; +} + +/*! + * \brief combine the nest stmt, whose body is not defined. + * \param nest A list of For and LetStmt, whose body is not defined. + * \param body body + * \return The combined Stmt + */ +inline Stmt MergeNest(std::vector > nest, Stmt body) { + for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { + body = MergeNest(*ri, body); + } + return body; +} + +} // namespace ir +} // namespace tvm +#endif // TVM_PASS_IR_UTIL_H_ diff --git a/src/pass/schedule_ops.cc b/src/pass/schedule_ops.cc index a62cf678b8cf6..2ac21105a8326 100644 --- a/src/pass/schedule_ops.cc +++ b/src/pass/schedule_ops.cc @@ -9,6 +9,7 @@ #include #include "./scope.h" +#include "./ir_util.h" #include "../schedule/graph.h" namespace tvm { @@ -81,45 +82,6 @@ void SplitByAdd(Expr expr, } } -/*! - * \brief combine the nest stmt, whose body is not defined. - * \param nest A list of For and LetStmt, whose body is not defined. 
import tvm
import numpy


def test_makeapi():
    """Build an element-wise add, wrap it with MakeAPI and print the C code.

    Original note: not yet working, mock design.
    """
    size = tvm.Var('n')
    lhs = tvm.placeholder((size,), name='A')
    rhs = tvm.placeholder((size,), name='B')
    out = tvm.compute(lhs.shape, lambda *i: lhs(*i) + rhs(*i), name='C')
    sched = tvm.Schedule(out.op)

    # lower the schedule into a statement
    stmt = tvm.ir_pass.ScheduleOps(sched, tvm.schedule.InferBound(sched))

    # one flat buffer per tensor
    lhs_buf = tvm.Buffer(lhs.shape, lhs.dtype, name='A')
    rhs_buf = tvm.Buffer(rhs.shape, rhs.dtype, name='B')
    out_buf = tvm.Buffer(out.shape, out.dtype, name='C')
    stmt = tvm.ir_pass.StorageFlatten(
        stmt, {lhs: lhs_buf, rhs: rhs_buf, out: out_buf})

    # the first two api arguments are passed in packed form
    func = tvm.codegen.MakeAPI(
        stmt, "myadd", [size, lhs_buf, rhs_buf, out_buf], 2)
    assert func.handle_data_type[lhs_buf.ptr].dtype == lhs_buf.dtype
    # three packed-signature arguments plus the two unpacked buffers
    assert len(func.args) == 5
    print(tvm.codegen.CompileToC(func, False))


if __name__ == "__main__":
    test_makeapi()