Skip to content

Commit

Permalink
General Layout Support (dmlc#447)
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhliu authored and tqchen committed May 29, 2018
1 parent 46d7f97 commit 8f9c593
Show file tree
Hide file tree
Showing 39 changed files with 2,362 additions and 439 deletions.
28 changes: 0 additions & 28 deletions nnvm/include/nnvm/compiler/contrib_op_param.h

This file was deleted.

21 changes: 10 additions & 11 deletions nnvm/include/nnvm/compiler/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <nnvm/graph.h>
#include <vector>
#include <string>
#include "packed_func_ext.h"

namespace nnvm {
namespace compiler {
Expand Down Expand Up @@ -73,19 +74,17 @@ using FTVMSchedule = std::function<
const Array<Tensor>& outs,
const std::string& target)>;

/*! \brief Layout Information about an entry */
using TLayoutInfo = std::string;

/*!
* \brief The producer consumer function of node layout
* \param attrs The attribute of the node.
* \param ilayouts The input layouts that the node request.
* \param olayouts The output layouts that the node produce.
* \return bool The success flag.
* \brief Modify the op node to alter its input layout.
* it is invoked in AlterOpLayout pass.
* \param attrs The attribute of the original node.
* \param inputs The input symbols of the original node.
* \param tinfos The inferred shape and dtype of the inputs.
*/
using FTVMLayoutRequest = std::function<bool (const NodeAttrs& attrs,
std::vector<TLayoutInfo> *ilayouts,
std::vector<TLayoutInfo> *olayouts)>;
using FTVMAlterOpLayout = std::function<
Symbol(const NodeAttrs& attrs,
const Symbol& inputs,
const Array<Tensor>& tinfos)>;

/*!
* \brief Transform from normal operator to vectorized operator
Expand Down
2 changes: 2 additions & 0 deletions nnvm/include/nnvm/compiler/packed_func_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <nnvm/graph.h>
#include <nnvm/symbolic.h>
#include <string>
#include <vector>
#include <unordered_map>

namespace nnvm {
Expand Down Expand Up @@ -52,6 +53,7 @@ template<>
struct extension_class_info<nnvm::compiler::AttrDict> {
static const int code = 18;
};

} // namespace runtime
} // namespace tvm
#endif // NNVM_COMPILER_PACKED_FUNC_EXT_H_
18 changes: 17 additions & 1 deletion nnvm/include/nnvm/graph_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <vector>
#include <string>
#include "./tuple.h"
#include "./layout.h"

namespace nnvm {

Expand Down Expand Up @@ -46,14 +47,29 @@ using ShapeVector = std::vector<TShape>;
* \code
* Graph g = ApplyPass(src_graph, "InferType");
* const DTypeVector& types = g.GetAttr<DTypeVector>("dtype");
* // get shape by entry id
* // get type by entry id
* int entry_type = dtypes[g.indexed_graph().entry_id(my_entry)];
* \endcode
*
* \sa FInferType
*/
using DTypeVector = std::vector<int>;

/*!
* \brief The result holder of layout of each NodeEntry in the graph.
* \note Stored under graph.attrs["layout"], provided by Pass "InferType"
*
* \code
* Graph g = ApplyPass(src_graph, "LayoutTransform");
* const LayoutVector& layouts = g.GetAttr<LayoutVector>("layout");
* // get layout by entry id
* int entry_layout = layouts[g.indexed_graph().entry_id(my_entry)];
* \endcode
*
* \sa FInferLayout
*/
using LayoutVector = std::vector<Layout>;

/*!
* \brief The result holder of device of each operator in the graph.
* \note Stored under graph.attrs["device"], provided by Pass "PlaceDevice"
Expand Down
Loading

0 comments on commit 8f9c593

Please sign in to comment.