Skip to content

Commit

Permalink
[CODEGEN] Add LoweredFunc, MakeAPI to build a C API function with che…
Browse files Browse the repository at this point in the history
…cks.
  • Loading branch information
tqchen committed Jan 22, 2017
1 parent 3c1020d commit e793264
Show file tree
Hide file tree
Showing 17 changed files with 726 additions and 102 deletions.
2 changes: 1 addition & 1 deletion HalideIR
Submodule HalideIR updated from adfa66 to d82003
3 changes: 3 additions & 0 deletions include/tvm/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class Buffer : public NodeRef {
* \return the pointer to the internal node container
*/
inline const BufferNode* operator->() const;

/*! \brief specify container node */
using ContainerType = BufferNode;
};

/*! \brief Node to represent a buffer */
Expand Down
40 changes: 34 additions & 6 deletions include/tvm/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#endif

#include <stdint.h>
#include <stddef.h>


TVM_EXTERN_C {
Expand Down Expand Up @@ -216,18 +217,45 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);

/*!
* \brief Launch a generated TVM function
* \brief TVM Function API: Get resource requirement
*
* By default TVM function try not to do internal allocations.
* Instead, TVMFuncRequirement can be called, given the input arguments.
*
* \param func function handle to be launched.
* \param args The arguments
* \param arg_type_ids The type id of the arguments
* \param num_args Number of arguments.
* \param out_workspace_size The workspace size needed to launch this function.
* \param out_workspace_align The alignment requirement of workspace.
*
* \note The data pointer in the arrays is not used by requirement.
*/
TVM_DLL int TVMFuncRequirement(TVMFunctionHandle func,
TVMArg* args,
int* arg_type_ids,
int num_args,
size_t* out_workspace_size,
size_t* out_workspace_align);

/*!
* \brief TVM Function API: Launch generated function.
*
* \param func function handle to be launched.
* \param args The arguments
* \param arg_type_ids The type id of the arguments
* \param num_args Number of arguments.
* \param stream The stream this function to be launched on.
* \param workspace Additional workspace used to launch this function.
*
* \sa TVMFuncRequirement
*/
TVM_DLL int TVMLaunch(TVMFunctionHandle func,
TVMArg* args,
int* arg_type_ids,
int num_args,
TVMStreamHandle stream);
TVM_DLL int TVMFuncLaunch(TVMFunctionHandle func,
TVMArg* args,
int* arg_type_ids,
int num_args,
TVMStreamHandle stream,
TVMArrayHandle workspace);
} // TVM_EXTERN_C

#endif // TVM_C_RUNTIME_API_H_
51 changes: 51 additions & 0 deletions include/tvm/codegen.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*!
* Copyright (c) 2016 by Contributors
* \file codegen.h
* \brief Collection of Lowlevel IR pass to codegen.
*/
#ifndef TVM_CODEGEN_H_
#define TVM_CODEGEN_H_

#include <string>
#include "./base.h"
#include "./expr.h"
#include "./module.h"

namespace tvm {
/*! \brief namespace for lowlevel IR pass and codegen */
namespace codegen {
/*!
* \brief Make an user callable API LoweredFunc.
*
* The main task of this function is to create code to :
* - Map the values in the api_args to of Var that is required by body.
* - Insert assertions to check type/value of the passed arguments.
*
* \param body The body of the function.
* \param name The name of the function.
* \param api_args Arguments to the function, can be either Var, or Buffer
* \param num_packed_args Number of arguments that are processed in packed form.
* \return a LoweredFunc with the specified signiture.
*
* \note
* The function signiture have two cases
*
* if num_packed_args is zero:
* f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
*
* if num_packed_args is not zero:
* f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
* api_arg_k, api_arg_k+1, ... api_arg_n)
*
* where n == len(api_args), k == num_packed_args
*
* There is no thread_axis in generated function.
*/
LoweredFunc MakeAPI(Stmt body,
std::string name,
Array<NodeRef> api_args,
int num_packed_args);
} // namespace codegen
} // namespace tvm

#endif // TVM_CODEGEN_H_
42 changes: 42 additions & 0 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,48 @@ struct Reduce : public ExprNode<Reduce> {
static constexpr const char* Min = "Min";
};

/*! \brief namespace of TVM Intrinsic functions */
namespace intrinsic {
// Most of the intrinsics is to enab
/*!
* \brief See pesudo code
*
* Type tvm_api_load_arg(TVMArg* args, int* args_type_id, i) {
* assert(arg_type_id[i] == typeid(Type));
* return args[i];
* }
*/
constexpr const char* tvm_api_load_arg = "tvm_api_load_arg";
/*!
* \brief See pesudo code
*
* Type tvm_array_get_field(TVMArray* arr, int field_id) {
* return arr->field;
* }
* \sa TVMArrayFieldKind
*/
constexpr const char* tvm_array_get_field = "tvm_array_get_field";
/*!
* \brief See pesudo code
*
* bool tvm_handle_is_null(void* handle) {
* return handle == nullptr
* }
*/
constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";

/*! \brief The field id of each field in array */
enum TVMArrayFieldKind {
kData = 0,
kNDim = 1,
kShape = 2,
kStrides = 3,
kTypeCode = 4,
kTypeBits = 5,
kTypeLanes = 6
};
} // namespace intrinsic

// Reuse IR node defintiion from HalideIR
using Halide::Internal::IntImm;
using Halide::Internal::UIntImm;
Expand Down
1 change: 0 additions & 1 deletion include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ Stmt Inline(Stmt stmt,
Array<Var> args,
Expr body);


/*!
* \brief Flatten the multi-dimensional read/write
* to single dimensional Load/Store
Expand Down
108 changes: 108 additions & 0 deletions include/tvm/module.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*!
* Copyright (c) 2016 by Contributors
* \file module.h
* \brief Low level IR module,
* Contains lowered function information.
*/
#ifndef TVM_MODULE_H_
#define TVM_MODULE_H_

#include <tvm/container.h>
#include <ir/FunctionBase.h>
#include <string>

#include "./base.h"
#include "./expr.h"
#include "./tensor.h"

namespace tvm {

// Internal node container of lowered function.
class LoweredFuncNode;

// Internal node container of module.
class ModuleNode;

/*!
* \brief LoweredFunc represents function after lowering.
* This is the final IR representation before codegen.
*/
class LoweredFunc : public FunctionRef {
public:
LoweredFunc() {}
explicit LoweredFunc(std::shared_ptr<Node> n) : FunctionRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const LoweredFuncNode* operator->() const;
/*! \brief specify container node */
using ContainerType = LoweredFuncNode;
};

/*! \brief Node container of LoweredFunc */
class LoweredFuncNode : public FunctionBaseNode {
public:
/*! \brief The name of the function */
std::string name;
/*!
* \brief The arguments of the function
* This function can only take pod type(int, float) and void* as arguments.
*/
Array<Var> args;
/*!
* \brief The IterVar axis of threads
* Each axis need host function to specify a size.
* \note Calling convention into LoweredFunc
*
* Assume we have a LoweredFunc f, a call into f
* Call(f, arg1, arg2, ..., arg_n,
* size_axis_1, size_axis_2, ... size_axis_m)
*
* Here n = len(args), m = len(thread_axis)
*
* The CodeGen should take this and translate this call
* to corresponding API specific kernel launchs or function calls.
*/
Array<IterVar> thread_axis;
/*!
* \brief The hint data type of Var handles defined in LetStmt
* Can be used as hint when generating type signiture.
* The creation rule is given by
* handle_data_type[var_handle] = make_const(the_type, 0);
*
* \note Expr is used instead Type, because Type cannot be hold by Map.
* constant Expr of given type is used.
*/
Map<Var, Expr> handle_data_type;
/*! \brief The body statment of the function */
Stmt body;
/*! \return name of the operation */
const std::string& func_name() const final {
return name;
}
// there is no return value, but return 1
// to enable Call into this function.
int num_outputs() const final {
return 1;
}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("args", &args);
v->Visit("thread_axis", &thread_axis);
v->Visit("handle_data_type", &handle_data_type);
v->Visit("body", &body);
}

static constexpr const char* _type_key = "LoweredFunc";
TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode);
};

// Implementations of inline functions
inline const LoweredFuncNode* LoweredFunc::operator->() const {
return static_cast<const LoweredFuncNode*>(node_.get());
}

} // namespace tvm

#endif // TVM_MODULE_H_
6 changes: 6 additions & 0 deletions python/tvm/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,9 @@ class IterVar(NodeBase, _expr.ExprOp):
class Buffer(NodeBase):
"""Represent a Buffer in TVM."""
pass


@register_node
class LoweredFunc(NodeBase):
"""Represent a LoweredFunc in TVM."""
pass
3 changes: 2 additions & 1 deletion src/base/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#define TVM_BASE_COMMON_H_

#include <tvm/base.h>
#include <tvm/expr.h>
#include <string>

namespace tvm {
Expand All @@ -30,7 +31,7 @@ inline Type String2Type(std::string s) {
} else if (s.substr(0, 5) == "float") {
code = Type::Float; s = s.substr(5);
} else if (s == "handle") {
return Type(Type::Handle, 32, 1);
return Handle();
} else {
LOG(FATAL) << "unknown type " << s;
}
Expand Down
8 changes: 7 additions & 1 deletion src/c_api/c_api_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/codegen.h>

#include "./c_api_registry.h"
#include "../codegen/codegen_c.h"
Expand All @@ -17,7 +18,12 @@ using RetValue = APIVariantValue;

TVM_REGISTER_API(_codegen_CompileToC)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = CodeGenC().Compile(
*ret = CodeGenC().Compile(args.at(0), args.at(1));
});

TVM_REGISTER_API(_codegen_MakeAPI)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = MakeAPI(
args.at(0), args.at(1), args.at(2), args.at(3));
});

Expand Down
Loading

0 comments on commit e793264

Please sign in to comment.