From 8f9c593c696964e1c0d29ab25f7fc1689d0a4051 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Wed, 25 Apr 2018 09:07:31 -0700 Subject: [PATCH] General Layout Support (#447) --- nnvm/include/nnvm/compiler/contrib_op_param.h | 28 -- nnvm/include/nnvm/compiler/op_attr_types.h | 21 +- nnvm/include/nnvm/compiler/packed_func_ext.h | 2 + nnvm/include/nnvm/graph_attr_types.h | 18 +- nnvm/include/nnvm/layout.h | 455 ++++++++++++++++++ nnvm/include/nnvm/op_attr_types.h | 26 + nnvm/include/nnvm/top/nn.h | 80 +-- nnvm/python/nnvm/_ctypes/symbol.py | 5 +- nnvm/python/nnvm/compiler/build_module.py | 36 +- nnvm/python/nnvm/compiler/graph_attr.py | 19 +- nnvm/python/nnvm/contrib.py | 1 + nnvm/python/nnvm/frontend/mxnet.py | 12 +- nnvm/python/nnvm/symbol.py | 1 + nnvm/python/nnvm/top/nn.py | 43 +- nnvm/python/nnvm/top/registry.py | 27 ++ nnvm/src/c_api/c_api_symbolic.cc | 2 +- nnvm/src/compiler/alter_op_layout.cc | 151 ++++++ nnvm/src/compiler/fold_scale_axis.cc | 8 +- nnvm/src/compiler/layout_transform.cc | 159 ------ nnvm/src/compiler/packed_func_ext.cc | 20 +- nnvm/src/compiler/simplify_inference.cc | 16 +- nnvm/src/pass/correct_layout.cc | 169 +++++++ nnvm/src/pass/infer_shape_type.cc | 2 +- nnvm/src/top/elemwise_op_common.h | 173 +++++++ nnvm/src/top/nn/convolution.cc | 143 +++++- nnvm/src/top/nn/nn.cc | 200 +++++++- nnvm/src/top/nn/nn_common.h | 130 ++--- nnvm/src/top/nn/pooling.cc | 241 +++++++--- nnvm/src/top/nn/upsampling.cc | 15 + nnvm/src/top/op_common.h | 15 + nnvm/src/top/tensor/broadcast.cc | 77 ++- nnvm/src/top/tensor/elemwise.cc | 4 + nnvm/src/top/tensor/matrix_op.cc | 26 + nnvm/src/top/tensor/reduce.cc | 2 + nnvm/src/top/tensor/state_op.cc | 9 + nnvm/src/top/tensor/transform.cc | 69 ++- .../python/compiler/test_alter_op_layout.py | 49 ++ .../tests/python/compiler/test_nhwc_layout.py | 9 +- .../python/unittest/test_correct_layout.py | 338 +++++++++++++ 39 files changed, 2362 insertions(+), 439 deletions(-) delete mode 100644 nnvm/include/nnvm/compiler/contrib_op_param.h create mode 100644 nnvm/include/nnvm/layout.h create mode 100644 nnvm/python/nnvm/contrib.py create mode 100644 nnvm/src/compiler/alter_op_layout.cc delete mode 100644 nnvm/src/compiler/layout_transform.cc create mode 100644 nnvm/src/pass/correct_layout.cc create mode 100644 nnvm/tests/python/compiler/test_alter_op_layout.py create mode 100644 nnvm/tests/python/unittest/test_correct_layout.py diff --git a/nnvm/include/nnvm/compiler/contrib_op_param.h b/nnvm/include/nnvm/compiler/contrib_op_param.h deleted file mode 100644 index 4eb91f56c..000000000 --- a/nnvm/include/nnvm/compiler/contrib_op_param.h +++ /dev/null @@ -1,28 +0,0 @@ -/*! - * Copyright (c) 2017 by Contributors - * \file contrib_op_param.h - * \brief Additional parameters for compiler optimized operators. - */ -#ifndef NNVM_COMPILER_CONTRIB_OP_PARAM_H_ -#define NNVM_COMPILER_CONTRIB_OP_PARAM_H_ - -#include -#include - -namespace nnvm { -namespace compiler { - -/*! 
\brief Parameters of layout transform operator */ -struct LayoutTransformParam : public dmlc::Parameter { - std::string src_layout; - std::string dst_layout; - - DMLC_DECLARE_PARAMETER(LayoutTransformParam) { - DMLC_DECLARE_FIELD(src_layout); - DMLC_DECLARE_FIELD(dst_layout); - } -}; -} // namespace compiler -} // namespace nnvm - -#endif // NNVM_COMPILER_CONTRIB_OP_PARAM_H_ diff --git a/nnvm/include/nnvm/compiler/op_attr_types.h b/nnvm/include/nnvm/compiler/op_attr_types.h index f8d32320a..231e85093 100644 --- a/nnvm/include/nnvm/compiler/op_attr_types.h +++ b/nnvm/include/nnvm/compiler/op_attr_types.h @@ -16,6 +16,7 @@ #include #include #include +#include "packed_func_ext.h" namespace nnvm { namespace compiler { @@ -73,19 +74,17 @@ using FTVMSchedule = std::function< const Array& outs, const std::string& target)>; -/*! \brief Layout Information about an entry */ -using TLayoutInfo = std::string; - /*! - * \brief The producer consumer function of node layout - * \param attrs The attribute of the node. - * \param ilayouts The input layouts that the node request. - * \param olayouts The output layouts that the node produce. - * \return bool The success flag. + * \brief Modify the op node to alter its input layout. + * it is invoked in AlterOpLayout pass. + * \param attrs The attribute of the original node. + * \param inputs The input symbols of the original node. + * \param tinfos The inferred shape and dtype of the inputs. */ -using FTVMLayoutRequest = std::function *ilayouts, - std::vector *olayouts)>; +using FTVMAlterOpLayout = std::function< + Symbol(const NodeAttrs& attrs, + const Symbol& inputs, + const Array& tinfos)>; /*! * \brief Transform from normal operator to vectorized operator diff --git a/nnvm/include/nnvm/compiler/packed_func_ext.h b/nnvm/include/nnvm/compiler/packed_func_ext.h index 241febbe9..bd768ff90 100644 --- a/nnvm/include/nnvm/compiler/packed_func_ext.h +++ b/nnvm/include/nnvm/compiler/packed_func_ext.h @@ -11,6 +11,7 @@ #include #include #include +#include #include namespace nnvm { @@ -52,6 +53,7 @@ template<> struct extension_class_info { static const int code = 18; }; + } // namespace runtime } // namespace tvm #endif // NNVM_COMPILER_PACKED_FUNC_EXT_H_ diff --git a/nnvm/include/nnvm/graph_attr_types.h b/nnvm/include/nnvm/graph_attr_types.h index 64894ec58..c8a9e7aad 100644 --- a/nnvm/include/nnvm/graph_attr_types.h +++ b/nnvm/include/nnvm/graph_attr_types.h @@ -9,6 +9,7 @@ #include #include #include "./tuple.h" +#include "./layout.h" namespace nnvm { @@ -46,7 +47,7 @@ using ShapeVector = std::vector; * \code * Graph g = ApplyPass(src_graph, "InferType"); * const DTypeVector& types = g.GetAttr("dtype"); - * // get shape by entry id + * // get type by entry id * int entry_type = dtypes[g.indexed_graph().entry_id(my_entry)]; * \endcode * @@ -54,6 +55,21 @@ using ShapeVector = std::vector; */ using DTypeVector = std::vector; +/*! + * \brief The result holder of layout of each NodeEntry in the graph. + * \note Stored under graph.attrs["layout"], provided by Pass "InferType" + * + * \code + * Graph g = ApplyPass(src_graph, "LayoutTransform"); + * const LayoutVector& layouts = g.GetAttr("layout"); + * // get layout by entry id + * int entry_layout = layouts[g.indexed_graph().entry_id(my_entry)]; + * \endcode + * + * \sa FInferLayout + */ +using LayoutVector = std::vector; + /*! * \brief The result holder of device of each operator in the graph. 
* \note Stored under graph.attrs["device"], provided by Pass "PlaceDevice" diff --git a/nnvm/include/nnvm/layout.h b/nnvm/include/nnvm/layout.h new file mode 100644 index 000000000..494db60c4 --- /dev/null +++ b/nnvm/include/nnvm/layout.h @@ -0,0 +1,455 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file layout.h + * \brief Layout expression. + * The layout is composed of upper cases, lower cases and numbers, + * where upper case indicates a (super-)dimension and + * the corresponding lower case with factor size indicates the split (sub-)dimension. + * For example, NCHW16c can describe a 5-D tensor of + * [batch_size, channel, height, width, channel_block]. + * Here sub-dimension channel_block=16 is the split of super-dimension C (channel). + */ +#ifndef NNVM_LAYOUT_H_ +#define NNVM_LAYOUT_H_ + +#include +#include +#include +#include +#include +#include + +namespace nnvm { + +class Layout { + public: + using LayoutDim = char; + + /*! \brief default constructor */ + Layout() : name_("__undef__") {} // NOLINT(*) + + /*! + * \brief construct from a string. + * \param layout input in layout convention: + * upper case indicates a dimension and + * the corresponding lower case with factor size + * indicates the split dimension. + * return undefined layout if "__undef__" is passed. + */ + inline Layout(const std::string& layout) { // NOLINT(*) + parse(layout); + } + /*! + * \brief copy constructor from another layout + * \param s the source layout + */ + inline Layout(const Layout& s) { // NOLINT(*) + this->parse(s.name_); + } + /*! + * \brief move constructor from Layout + * \param src the source layout + */ + inline Layout(Layout&& src) { // NOLINT(*) + this->swap(src); + } + /*! + * \brief assignment from another layout. + * \param src source layout + * \return reference of self + */ + inline Layout& operator=(const Layout& src) { + this->parse(src.name_); + return *this; + } + /*! + * \brief assignment from rvalue of another layout. + * \param src source layout + * \return reference of self + */ + inline Layout& operator=(Layout&& src) { + Layout(std::move(src)).swap(*this); // NOLINT(*) + return *this; + } + /*! + * \brief assignment from string. + * \param src source layout + * \return reference of self + */ + inline Layout& operator=(const std::string& src) { + this->parse(src); + return *this; + } + /*! + * \return whether two layout equals + * \param s the layout to compare against + */ + inline bool operator==(const Layout& s) const { + return name_ == s.name_; + } + /*! + * \return whether two layout not equal + * \param s the layout to compare against + */ + inline bool operator!=(const Layout& s) const { + return !(*this == s); + } + + /*! + * \brief Append the current layout by another. + * @param other the layout to be appended + * @return a new layout + */ + inline Layout operator+(const Layout& other) const { + if (!this->defined() && !other.defined()) { + return Layout::Undef(); + } else if (!this->defined()) { + return other; + } else if (!other.defined()) { + return *this; + } + return Layout(this->name_ + other.name_); + } + + /*! + * \brief Check whether a given dimension is a super-dimension. + * \param dim input dimension + * \return Whether a given dimension is a super-dimension. + */ + static inline bool is_superdim(LayoutDim dim) { + return dim >= 'A' && dim <= 'Z'; + } + + /*! + * \brief Check whether a given dimension is a sub-dimension. + * \param dim input dimension + * \return Whether a given dimension is a sub-dimension. 
+ */ + static inline bool is_subdim(LayoutDim dim) { + return dim >= 'a' && dim <= 'z'; + } + + /*! + * \brief Convert a given dimension to super-dimension. + * \param dim input dimension + * \return The converted description. + */ + static inline LayoutDim to_superdim(LayoutDim dim) { + if (is_subdim(dim)) { + return dim - 'a' + 'A'; + } + return dim; + } + + /*! + * \brief Convert a given dimension to sub-dimension. + * \param dim input dimension + * \return The converted description. + */ + static inline LayoutDim to_subdim(LayoutDim dim) { + if (is_superdim(dim)) { + return dim - 'A' + 'a'; + } + return dim; + } + + /*! + * \brief Return an undefined layout. + * \return a (global) undefined layout. + */ + static inline const Layout& Undef() { + static Layout undef; + return undef; + } + + /*! + * \brief Swap current object with other + * \param other another object to be swapped. + */ + inline void swap(Layout& other) { // NOLINT(*) + std::swap(name_, other.name_); + std::swap(superdim_pos_, other.superdim_pos_); + std::swap(subdim_pos_, other.subdim_pos_); + std::swap(subdim_size_, other.subdim_size_); + std::swap(layout_simplified_, other.layout_simplified_); + } + + /*! + * \brief Two layouts are convertible only if + * they have same set of super-dimensions. + * e.g., NCHW, NCHW16c, NHWC are convertible between each other, + * but NCHW, CHW, OIHW are not. + * \param dst the target layout + * \return Whether can be converted to dst layout. + */ + inline bool convertible(const Layout &dst) const { + if (!this->defined() || !dst.defined()) return false; + for (size_t i = 0; i < kUniqueDim; ++i) { + if ((superdim_pos_[i] >= 0 && dst.superdim_pos_[i] < 0) || + (superdim_pos_[i] < 0 && dst.superdim_pos_[i] >= 0)) { + return false; + } + } + return true; + } + + /*! + * \brief Returns a sublayout which is the portion of the object + * that starts at dimension \p pos and spans \p len dimensions + * (or until the end of the layout, whichever comes first). + * \param pos The start position. + * \param len The length of the sub-layout. + * \return A newly constructed Layout object. + */ + inline Layout sublayout(size_t pos, size_t len) const { + if (pos > ndim()) return Layout::Undef(); + if (pos + len > ndim()) len = ndim() - pos; + if (len == 0) return Layout::Undef(); + std::ostringstream new_layout; + for (size_t i = pos; i < pos + len; ++i) { + if (is_subdim(layout_simplified_[i])) { + auto block_size = this->subsizeof(layout_simplified_[i]); + CHECK_GT(block_size, 0); + new_layout << block_size; + } + new_layout << layout_simplified_[i]; + } + return Layout(new_layout.str()); + } + + /*! \return A newly constructed reversed Layout object. */ + inline Layout reverse() const { + if (!this->defined()) return Layout::Undef(); + std::ostringstream new_layout; + for (int64_t i = this->ndim() - 1; i >= 0; --i) { + if (is_subdim(layout_simplified_[i])) { + auto block_size = this->subsizeof(layout_simplified_[i]); + CHECK_GT(block_size, 0); + new_layout << block_size; + } + new_layout << layout_simplified_[i]; + } + return Layout(new_layout.str()); + } + + /*! + * \brief Split \p dim by \p size and put the sub-dimension to position \p target_pos. + * \param dim The source dimension to be split. It must be a super-dimension. + * \param target_pos The target position of the newly split sub-dimension. + * \param size size of the sub-dimension. + * \return A newly constructed Layout object. 
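+ * For example, splitting C of NCHW by size 16 at position 4 yields NCHW16c,
+ * while splitting at position 1 yields N16cCHW.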
+ */ + inline Layout split(LayoutDim dim, size_t target_pos, uint32_t size) const { + CHECK(target_pos <= this->ndim()) << "Invalid split position " + << target_pos << " for layout " << name_; + CHECK(is_superdim(dim)) << "Cannot split a sub-dimension " << dim; + CHECK(this->contains(dim)) << "Axis " << dim << " does not exist in " << name_; + CHECK(!this->contains(to_subdim(dim))) << "Dimension " << dim + << " has already been split in " + << name_; + CHECK(size > 0) << "Invalid split size " << size; + std::ostringstream new_layout; + for (size_t i = 0; i <= this->ndim(); ++i) { + if (i == target_pos) { + new_layout << size << Layout::to_subdim(dim); + } + if (i == this->ndim()) break; + new_layout << this->at(i); + } + Layout x(new_layout.str()); + return x; + } + + using iterator = std::vector::const_iterator; + using reverse_iterator = std::vector::const_reverse_iterator; + + /*! \return begin iterator */ + inline iterator begin() const { + return layout_simplified_.begin(); + } + /*! \return end iterator */ + inline iterator end() const { + return layout_simplified_.end(); + } + /*! \return rbegin iterator */ + inline reverse_iterator rbegin() const { + return layout_simplified_.rbegin(); + } + /*! \return rend iterator */ + inline reverse_iterator rend() const { + return layout_simplified_.rend(); + } + + /*! \return number of dimensions */ + inline size_t ndim() const { + return layout_simplified_.size(); + } + + /*! + * \brief The description of the \p i-th dimension. + * If it is a sub-dimension, the size will be returned as well, + * e.g., 16c. Otherwise a single character is returned, e.g., C. + * \param i The position + * \return the description of the dimension. + */ + inline std::string at(size_t i) const { + CHECK_LT(i, this->ndim()) << "position " << i + << " exceeds ndim=" << this->ndim(); + std::ostringstream repr; + if (is_subdim(layout_simplified_[i])) { + auto factor = subsizeof(layout_simplified_[i]); + CHECK_LT(factor, 0); + repr << factor; + } + repr << layout_simplified_[i]; + return repr.str(); + } + + /*! + * \brief return the index of the input dimension. + * If it is not found in the layout or the layout is undefined, + * return -1. + * \param dim the input dimension. + * \return the index or -1 if not found. + */ + inline int32_t indexof(LayoutDim dim) const { + if (!this->defined()) return -1; + else if (is_superdim(dim)) return superdim_pos_[dim - 'A']; + else if (is_subdim(dim)) return subdim_pos_[dim - 'a']; + return -1; + } + + /*! + * \param dim the input super-dimension or sub-dimension. + * \return the size of the sub-dimension of \p dim (if \p dim is a super-dimension), + * or the size of \p dim itself (if \p dim is a sub-dimension). + * Return -1 if \p dim is not in the layout or the layout is undefined. + */ + inline int64_t subsizeof(LayoutDim dim) const { + CHECK(is_superdim(dim) || is_subdim(dim)) << "Invalid dim " << dim; + if (!this->defined() || !this->contains(to_subdim(dim))) { + return -1; + } + int idx = to_subdim(dim) - 'a'; + return subdim_size_[idx]; + } + + /*! + * \brief Whether the layout contains a dimension. + * \param dim dimension to be checked. + * \return Whether the layout contains the dimension. + */ + inline bool contains(LayoutDim dim) const { + if (is_superdim(dim)) { + return superdim_pos_[dim-'A'] >= 0; + } else if (is_subdim(dim)) { + return subdim_pos_[dim-'a'] >= 0; + } + return false; + } + + inline const LayoutDim operator[](size_t i) const { + return layout_simplified_[i]; + } + + /*! 
\return whether the layout is defined */ + inline bool defined() const { + return name_ != "__undef__"; + } + + /*! \return the string description of the layout */ + inline const std::string& name() const { + return name_; + } + + /*! + * \brief Write layout in JSON format. + * \param writer JSONWriter + */ + inline void Save(dmlc::JSONWriter* writer) const { + writer->Write(name_); + } + + /*! + * \brief Load layout from JSON. + * \param reader JSONReader + */ + inline void Load(dmlc::JSONReader* reader) { + std::string tmp; + reader->Read(&tmp); + this->parse(tmp); + } + + /*! + * \brief allow output string of layout to ostream + * \param os the output stream + * \param l the layout + * \return the ostream + */ + friend std::ostream& operator<<(std::ostream& os, const Layout& l) { + os << l.name_; + return os; + } + + private: + static const uint32_t kUniqueDim = 26; + + std::string name_; + int32_t superdim_pos_[kUniqueDim]; + int32_t subdim_pos_[kUniqueDim]; + int64_t subdim_size_[kUniqueDim]; + std::vector layout_simplified_; + + void parse(const std::string& layout) { + name_ = layout; + std::fill_n(superdim_pos_, kUniqueDim, -1); + std::fill_n(subdim_pos_, kUniqueDim, -1); + std::fill_n(subdim_size_, kUniqueDim, -1); + layout_simplified_.clear(); + + if (layout == "__undef__") return; + + int32_t factor = 0; + uint32_t curr = 0; + for (size_t i = 0; i < layout.size(); ++i) { + const LayoutDim c = layout.at(i); + if (is_superdim(c)) { + int pos = c - 'A'; + CHECK_EQ(factor, 0) << "Invalid layout " << layout + << ": invalid factor size " << factor + << " before dimension " << c; + CHECK_EQ(superdim_pos_[pos], -1) << "Invalid layout " << layout + << ": duplicate dimension " << c; + superdim_pos_[pos] = curr++; + layout_simplified_.push_back(c); + } else if (is_subdim(c)) { + int pos = c - 'a'; + CHECK_GT(factor, 0) << "Invalid layout " << layout << ": invalid factor size " + << factor << " for dimension " << c; + CHECK_EQ(subdim_pos_[pos], -1) << "Invalid layout " << layout + << ": duplicate dimension " << c; + CHECK_EQ(subdim_size_[pos], -1) << "Invalid layout " << layout + << ": duplicate dimension " << c; + subdim_pos_[pos] = curr++; + subdim_size_[pos] = factor; + layout_simplified_.push_back(c); + factor = 0; + } else if (c >= '0' && c <= '9') { + CHECK(factor >= 0) << "Invalid layout " << layout << ": _ is adjacent to a number."; + factor = factor * 10 + c - '0'; + } else { + LOG(FATAL) << "Invalid layout " << layout; + } + } + CHECK(!layout_simplified_.empty()) << "Invalid layout " << layout; + for (LayoutDim dim : layout_simplified_) { + CHECK(is_superdim(dim) || superdim_pos_[dim-'a'] >= 0) + << "Invalid layout " << layout << ": missing axis " + << static_cast(dim - 'a' + 'A'); + } + } +}; + +} // namespace nnvm + +#endif // NNVM_LAYOUT_H_ diff --git a/nnvm/include/nnvm/op_attr_types.h b/nnvm/include/nnvm/op_attr_types.h index e7bbb9685..9891b82f9 100644 --- a/nnvm/include/nnvm/op_attr_types.h +++ b/nnvm/include/nnvm/op_attr_types.h @@ -13,6 +13,7 @@ #include "./base.h" #include "./node.h" #include "./tuple.h" +#include "./layout.h" namespace nnvm { @@ -176,6 +177,31 @@ using FSetInputVarAttrOnCompose = std::function; +/*! + * \brief Inference function of node layout. See \p Layout for layout convention + * \param attrs The attribute of the node. + * \param ilayouts Given the input layouts produced by ancestor nodes, + * it should be filled by layouts that the node requests. 
+ * If the requested layout is different from what ancestor produces, + * a __layout_transform__ operator will be inserted automatically. + * \param last_ilayouts The input layouts requested by the node + * at the last infer pass (if any). + * This can be useful when an operator wants to keep + * the input layout the same as the original one. + * For example, after the pass of AlterOpLayout, + * transpose(input, axis=[1, 2, 3, 0]) may receive an input of NCHW16c layout, + * with which it cannot calculate with axis=[1, 2, 3, 0]. + * Last input layouts allow it to know what the layout it originally inferred, + * i.e., the layout in the imported model. + * \param olayouts Inferred output layouts. + * \return success flag. + */ +using FInferLayout = std::function *ilayouts, + const std::vector *last_ilayouts, + std::vector *olayouts)>; + } // namespace nnvm #endif // NNVM_OP_ATTR_TYPES_H_ diff --git a/nnvm/include/nnvm/top/nn.h b/nnvm/include/nnvm/top/nn.h index 7eb2e5e11..80d8e0288 100644 --- a/nnvm/include/nnvm/top/nn.h +++ b/nnvm/include/nnvm/top/nn.h @@ -9,23 +9,12 @@ #include #include #include +#include +#include namespace nnvm { namespace top { -// Layout flag in spatial conv and pooling. -enum LayoutFlag { - kNCHW, - kNHWC, - kCHWN, - kNCW, - kNWC, - kCWN, - kNCDHW, - kNDHWC, - kCDHWN -}; - struct DenseParam : public dmlc::Parameter { int units; bool use_bias; @@ -130,7 +119,9 @@ struct Conv2DParam : public dmlc::Parameter { TShape padding; TShape dilation; int groups; - int layout; + std::string layout; + std::string kernel_layout; + std::string out_layout; bool use_bias; DMLC_DECLARE_PARAMETER(Conv2DParam) { @@ -152,14 +143,19 @@ struct Conv2DParam : public dmlc::Parameter { "At groups=2, the operation becomes equivalent to having two convolution" "layers side by side, each seeing half the input channels, and producing" "half the output channels, and both subsequently concatenated."); - DMLC_DECLARE_FIELD(layout) - .add_enum("NCHW", kNCHW) - .add_enum("NHWC", kNHWC) - .set_default(kNCHW) - .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." + DMLC_DECLARE_FIELD(layout).set_default("NCHW") + .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Convolution is applied on the 'H' and" "'W' dimensions."); + DMLC_DECLARE_FIELD(out_layout).set_default("__undef__") + .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Default to be same as input layout."); + DMLC_DECLARE_FIELD(kernel_layout).set_default("OIHW") + .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." 
+ "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" + "dimensions respectively."); DMLC_DECLARE_FIELD(use_bias).set_default(true) .describe("Whether the layer uses a bias vector."); } @@ -178,7 +174,8 @@ struct Conv2DTransposeParam : public dmlc::Parameter { TShape output_padding; TShape dilation; int groups; - int layout; + std::string layout; + std::string kernel_layout; bool use_bias; DMLC_DECLARE_PARAMETER(Conv2DTransposeParam) { @@ -202,14 +199,15 @@ struct Conv2DTransposeParam : public dmlc::Parameter { "At groups=2, the operation becomes equivalent to having two convolution" "layers side by side, each seeing half the input channels, and producing" "half the output channels, and both subsequently concatenated."); - DMLC_DECLARE_FIELD(layout) - .add_enum("NCHW", kNCHW) - .add_enum("NHWC", kNHWC) - .set_default(kNCHW) - .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." + DMLC_DECLARE_FIELD(layout).set_default("NCHW") + .describe("Dimension ordering of data. Can be 'NCHW', 'NHWC', etc." "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Convolution is applied on the 'H' and" "'W' dimensions."); + DMLC_DECLARE_FIELD(kernel_layout).set_default("OIHW") + .describe("Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc." + "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" + "dimensions respectively."); DMLC_DECLARE_FIELD(use_bias).set_default(true) .describe("Whether the layer uses a bias vector."); } @@ -224,7 +222,7 @@ struct Pool2DParam : public dmlc::Parameter { TShape pool_size; TShape strides; TShape padding; - int layout; + std::string layout; bool ceil_mode; DMLC_DECLARE_PARAMETER(Pool2DParam) { @@ -235,10 +233,7 @@ struct Pool2DParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0})) .describe("If padding is non-zero, then the input is implicitly zero-padded" "on both sides for padding number of points"); - DMLC_DECLARE_FIELD(layout) - .add_enum("NCHW", kNCHW) - .add_enum("NHWC", kNHWC) - .set_default(kNCHW) + DMLC_DECLARE_FIELD(layout).set_default("NCHW") .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Convolution is applied on the 'H' and" @@ -250,13 +245,10 @@ struct Pool2DParam : public dmlc::Parameter { struct GlobalPool2DParam : public dmlc::Parameter { - int layout; + std::string layout; DMLC_DECLARE_PARAMETER(GlobalPool2DParam) { - DMLC_DECLARE_FIELD(layout) - .add_enum("NCHW", kNCHW) - .add_enum("NHWC", kNHWC) - .set_default(kNCHW) + DMLC_DECLARE_FIELD(layout).set_default("NCHW") .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Convolution is applied on the 'H' and" @@ -266,15 +258,13 @@ struct GlobalPool2DParam : public dmlc::Parameter { struct UpSamplingParam : public dmlc::Parameter { int scale; - int layout; + std::string layout; DMLC_DECLARE_PARAMETER(UpSamplingParam) { DMLC_DECLARE_FIELD(scale) .describe("upsampling scaling factor"); DMLC_DECLARE_FIELD(layout) - .add_enum("NCHW", kNCHW) - .add_enum("NHWC", kNHWC) - .set_default(kNCHW) + .set_default("NCHW") .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. 
Convolution is applied on the 'H' and" @@ -282,6 +272,18 @@ struct UpSamplingParam : public dmlc::Parameter { } }; +struct LayoutTransformParam : public dmlc::Parameter { + std::string src_layout; + std::string dst_layout; + + DMLC_DECLARE_PARAMETER(LayoutTransformParam) { + DMLC_DECLARE_FIELD(src_layout).set_default("__undef__") + .describe("Dimension ordering of data"); + DMLC_DECLARE_FIELD(dst_layout).set_default("__undef__") + .describe("Dimension ordering of data."); + } +}; + } // namespace top } // namespace nnvm diff --git a/nnvm/python/nnvm/_ctypes/symbol.py b/nnvm/python/nnvm/_ctypes/symbol.py index 7a680f093..843601c10 100644 --- a/nnvm/python/nnvm/_ctypes/symbol.py +++ b/nnvm/python/nnvm/_ctypes/symbol.py @@ -211,12 +211,15 @@ def _init_symbol_module(symbol_class, root_namespace): op_names.append(py_str(plist[i])) module_obj = sys.modules["%s.symbol" % root_namespace] + module_obj_contrib = sys.modules["%s.contrib" % root_namespace] module_internal = sys.modules["%s._symbol_internal" % root_namespace] for name in op_names: hdl = OpHandle() check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl))) function = _make_atomic_symbol_function(hdl, name) - if function.__name__.startswith('_'): + if function.__name__.startswith('_contrib_'): + setattr(module_obj_contrib, function.__name__.split('_contrib_')[1], function) + elif function.__name__.startswith('_'): setattr(module_internal, function.__name__, function) setattr(module_obj, function.__name__, function) else: diff --git a/nnvm/python/nnvm/compiler/build_module.py b/nnvm/python/nnvm/compiler/build_module.py index d97a8784d..436b06c5e 100644 --- a/nnvm/python/nnvm/compiler/build_module.py +++ b/nnvm/python/nnvm/compiler/build_module.py @@ -15,7 +15,8 @@ "SimplifyInference": 0, "PrecomputePrune": 2, "OpFusion": 1, - "FoldScaleAxis": 3 + "FoldScaleAxis": 3, + "AlterOpLayout": 3, } # List of optimization pass and level when switch on @@ -139,7 +140,7 @@ def _update_shape_dtype(shape, dtype, params): return shape, dtype -def optimize(graph, shape, dtype="float32"): +def optimize(graph, shape, dtype="float32", layout=None): """Perform target and parameter invariant graph optimization. This is an advanced function that usually do not need to be called. @@ -157,6 +158,18 @@ def optimize(graph, shape, dtype="float32"): """ # pylint: disable=unused-argument cfg = BuildConfig.current + + if cfg.pass_enabled("AlterOpLayout"): + layout = layout if layout else {} + graph = graph_attr.set_layout_inputs(graph, layout) + graph = graph.apply(["CorrectLayout"]) + + graph = graph_attr.set_shape_inputs(graph, shape) + graph = graph_attr.set_dtype_inputs(graph, dtype) + graph = graph.apply(["InferShape", "InferType", "AlterOpLayout"]) + graph = graph_attr.set_layout_inputs(graph, layout) + graph = graph.apply(["CorrectLayout"]) + if cfg.pass_enabled("SimplifyInference"): graph = graph_attr.set_shape_inputs(graph, shape) graph = graph.apply(["InferShape", "SimplifyInference"]) @@ -167,7 +180,8 @@ def optimize(graph, shape, dtype="float32"): return graph -def build(graph, target=None, shape=None, dtype="float32", params=None, target_host=None): +def build(graph, target=None, shape=None, dtype="float32", + params=None, target_host=None, layout=None): """Build graph into runtime library. The build function will optimize the graph and do the compilation. @@ -204,8 +218,8 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h By default, llvm is used if it is enabled, otherwise a stackvm intepreter is used. 
- initialize : bool, optional - Whether to initialize variables in global dict _all_var_init. + layout : dict of str to str or str optional + The input layout Returns ------- @@ -230,6 +244,15 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h cfg = BuildConfig.current graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph) shape, dtype = _update_shape_dtype(shape, dtype, params) + + # correct layout if necessary + layout = layout if layout else {} + graph = graph_attr.set_layout_inputs(graph, layout) + graph = graph.apply("CorrectLayout") + index = graph.index + layouts = graph.json_attr("layout") + layout = {x : layouts[index.entry_id(x)] for x in index.input_names} + # Initial pass do shape type inference ishape, _ = graph_util.infer_shape(graph, **shape) shape.update(zip(graph.index.input_names, ishape)) @@ -241,13 +264,14 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h if _all_var_init: init_var = initialize_variables(shape, dtype) # Apply optimization - graph = optimize(graph, shape, dtype) + graph = optimize(graph, shape, dtype, layout) # Precompute prune if params and cfg.pass_enabled("PrecomputePrune"): graph, params = precompute_prune(graph, params) shape, dtype = _update_shape_dtype(shape, dtype, params) # Operator Fusion and generation graph = graph_attr.set_shape_inputs(graph, shape) + graph = graph.apply("InferShape") graph = graph_attr.set_dtype_inputs(graph, dtype) graph._set_json_attr("target", str(target), "str") if target_host is not None: diff --git a/nnvm/python/nnvm/compiler/graph_attr.py b/nnvm/python/nnvm/compiler/graph_attr.py index 94b973353..3ce6c4b53 100644 --- a/nnvm/python/nnvm/compiler/graph_attr.py +++ b/nnvm/python/nnvm/compiler/graph_attr.py @@ -96,11 +96,22 @@ def set_layout_inputs(g, layout): Returns ------- g : Graph - The updated graph with updated dtype. + The updated graph with updated layout. """ - list_shape = [ - layout.get(name, "default") for name in g.index.input_names] - g._set_json_attr("layout_inputs", list_shape, 'list_str') + if isinstance(layout, dict): + list_layout = [ + layout.get(name, "__undef__") for name in g.index.input_names] + elif isinstance(layout, str): + list_layout = ["__undef__"] * len(g.index.input_names) + list_layout[0] = layout + else: + raise ValueError("Input layout must be str or dict") + last_inferred_layouts = g.json_attr("layout") + if last_inferred_layouts: + input_layout = [last_inferred_layouts[g.index.entry_id(x)] for x in g.index.input_names] + for i, layout_stored in enumerate(input_layout): + list_layout[i] = list_layout[i] if list_layout[i] != '__undef__' else layout_stored + g._set_json_attr("layout_inputs", list_layout, 'list_layout') return g _move_out_module = tvm.get_global_func("nnvm.graph._move_module") diff --git a/nnvm/python/nnvm/contrib.py b/nnvm/python/nnvm/contrib.py new file mode 100644 index 000000000..976eb532b --- /dev/null +++ b/nnvm/python/nnvm/contrib.py @@ -0,0 +1 @@ +"""Module space to register contrib functions. 
Leave empty""" diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index 4296705f0..e671acbe7 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -86,6 +86,10 @@ def _conv2d(inputs, attrs): layout = attrs.get('layout', 'NCHW') if layout not in ['NCHW', 'NHWC']: _raise_not_supported('layout: ' + layout, 'conv2d') + if 'kernel_layout' in attrs: + kernel_layout = attrs['kernel_layout'] + else: + kernel_layout = 'HWIO' if layout == 'NHWC' else 'OIHW' op_name, new_attrs = 'conv2d', {} new_attrs['channels'] = _required_attr(attrs, 'num_filter') new_attrs['kernel_size'] = kernel @@ -94,6 +98,7 @@ def _conv2d(inputs, attrs): new_attrs['dilation'] = attrs.get('dilate', (1, 1)) new_attrs['groups'] = attrs.get('num_group', 1) new_attrs['layout'] = layout + new_attrs['kernel_layout'] = kernel_layout new_attrs['use_bias'] = attrs.get('no_bias', 'False').strip() == 'False' return _get_nnvm_op(op_name)(*inputs, **new_attrs) @@ -106,6 +111,10 @@ def _conv2d_transpose(inputs, attrs): layout = attrs.get('layout', 'NCHW') if layout not in ['NCHW', 'NHWC']: _raise_not_supported('layout: ' + layout, 'conv2d_transpose') + if 'kernel_layout' in attrs: + kernel_layout = attrs['kernel_layout'] + else: + kernel_layout = 'HWIO' if layout == 'NHWC' else 'OIHW' op_name, new_attrs = 'conv2d_transpose', {} new_attrs['channels'] = _required_attr(attrs, 'num_filter') new_attrs['kernel_size'] = kernel @@ -115,6 +124,7 @@ def _conv2d_transpose(inputs, attrs): new_attrs['dilation'] = attrs.get('dilate', (1, 1)) new_attrs['groups'] = attrs.get('num_group', 1) new_attrs['layout'] = layout + new_attrs['kernel_layout'] = kernel_layout new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias') return _get_nnvm_op(op_name)(*inputs, **new_attrs) @@ -237,7 +247,7 @@ def _upsampling(inputs, attrs): 'min_axis' : _rename('min'), 'reshape' : _reshape, 'sum_axis' : _rename('sum'), - 'UpSampling' : _upsampling + 'UpSampling' : _upsampling, } def _convert_symbol(op_name, inputs, attrs, diff --git a/nnvm/python/nnvm/symbol.py b/nnvm/python/nnvm/symbol.py index 8b390e2cb..9fa0af286 100644 --- a/nnvm/python/nnvm/symbol.py +++ b/nnvm/python/nnvm/symbol.py @@ -16,6 +16,7 @@ from ._base import _LIB, check_call as _check_call, _FFI_MODE, _all_var_init from .attribute import AttrScope from . import _symbol_internal as _internal +from . import contrib # Use different verison of SymbolBase # When possible, use cython to speedup part of computation. diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py index d7e6e5a08..7a5144983 100644 --- a/nnvm/python/nnvm/top/nn.py +++ b/nnvm/python/nnvm/top/nn.py @@ -5,7 +5,7 @@ import tvm import topi from topi.util import get_const_int -from .tensor import _fschedule_broadcast +from .tensor import _fschedule_broadcast, _fschedule_injective from . 
import registry as reg from .registry import OpPattern @@ -32,6 +32,11 @@ reg.register_pattern("pad", OpPattern.INJECTIVE) +# layout transform +reg.register_schedule("__layout_transform__", _fschedule_injective) +reg.register_pattern("__layout_transform__", OpPattern.INJECTIVE) + + @reg.register_schedule("softmax") def schedule_softmax(_, outs, target): """Schedule definition of softmax""" @@ -108,6 +113,42 @@ def schedule_conv2d(attrs, outs, target): reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) +# convolution NCHWc +@reg.register_compute("_contrib_conv2d_NCHWc") +def compute_contrib_conv2d_NCHWc(attrs, inputs, _): + """Compute definition of conv2d NCHWc""" + padding = attrs.get_int_tuple("padding") + strides = attrs.get_int_tuple("strides") + dilation = attrs.get_int_tuple("dilation") + kh, kw = attrs.get_int_tuple('kernel_size') + groups = attrs.get_int("groups") + channels = attrs.get_int("channels") + assert dilation == (1, 1), "not support dilate now" + if groups == 1: + out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], channels, (kh, kw), strides, padding) + else: + raise ValueError("not support arbitrary group number > 1 for now") + if attrs.get_bool("use_bias"): + bias = inputs[2] + bias = topi.expand_dims(bias, axis=1, num_newaxis=2) + out = topi.broadcast_add(out, bias) + return out + +@reg.register_schedule("_contrib_conv2d_NCHWc") +def schedule_contrib_conv2d_NCHWc(attrs, outs, target): + """Schedule definition of conv2d NCHWc""" + groups = attrs.get_int("groups") + kh, kw = attrs.get_int_tuple('kernel_size') + oc = attrs.get_int("channels") + padding = attrs.get_int_tuple("padding") + strides = attrs.get_int_tuple("strides") + with tvm.target.create(target): + if groups == 1: + return topi.generic.schedule_conv2d_NCHWc(oc, (kh, kw), strides, padding, outs) + else: + raise ValueError("not support group number > 1 for now") + +reg.register_pattern("_contrib_conv2d_NCHWc", OpPattern.OUT_ELEMWISE_FUSABLE) # conv2d_transpose @reg.register_compute("conv2d_transpose") diff --git a/nnvm/python/nnvm/top/registry.py b/nnvm/python/nnvm/top/registry.py index 6a7209442..68ea80e7e 100644 --- a/nnvm/python/nnvm/top/registry.py +++ b/nnvm/python/nnvm/top/registry.py @@ -25,6 +25,7 @@ class OpPattern(object): _register_compute = tvm.get_global_func("nnvm._register_compute") _register_schedule = tvm.get_global_func("nnvm._register_schedule") _register_pattern = tvm.get_global_func("nnvm._register_pattern") +_register_alter_op_layout = tvm.get_global_func("nnvm.compiler._register_alter_op_layout") def register_compute(op_name, f=None, level=10): """Register compute function for operator @@ -93,3 +94,29 @@ def register_pattern(op_name, pattern, level=10): The priority level """ _register_pattern(op_name, pattern, level) + + +def register_alter_op_layout(op_name, f=None, level=10): + """Register alter layout function for operator + + Parameters + ---------- + op_name : str + The name of operator + + f : function + The schedule function + + level : int + The priority level + + Returns + ------- + fregister : function + Register function if f is not specified. 
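+
+    Examples
+    --------
+    A minimal sketch of the intended usage; the operator name, attribute
+    values and the returned symbol below are illustrative only:
+
+    .. code-block:: python
+
+        import nnvm.symbol as sym
+        from nnvm.top import registry as reg
+
+        @reg.register_alter_op_layout("conv2d", level=15)
+        def alter_conv2d_layout(attrs, inputs, tinfos):
+            data, weight = inputs[0], inputs[1]
+            # return a new Symbol that replaces the original operator,
+            # e.g. re-emit conv2d with a blocked kernel layout
+            return sym.conv2d(data, weight, channels=64, kernel_size=(3, 3),
+                              kernel_layout="OIHW16i", use_bias=False)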
+ """ + def register(myf): + """internal register function""" + _register_alter_op_layout(op_name, myf, level) + return myf + return register(f) if f else register diff --git a/nnvm/src/c_api/c_api_symbolic.cc b/nnvm/src/c_api/c_api_symbolic.cc index 0327ca4ae..9f62dbd80 100644 --- a/nnvm/src/c_api/c_api_symbolic.cc +++ b/nnvm/src/c_api/c_api_symbolic.cc @@ -294,7 +294,7 @@ int NNSymbolGetNumOutputs(SymbolHandle symbol, nn_uint *output_count) { Symbol *s = static_cast(symbol); API_BEGIN(); - *output_count = static_cast(s->outputs.size()); + *output_count = static_cast(s->outputs.size()); API_END(); } diff --git a/nnvm/src/compiler/alter_op_layout.cc b/nnvm/src/compiler/alter_op_layout.cc new file mode 100644 index 000000000..893a0d298 --- /dev/null +++ b/nnvm/src/compiler/alter_op_layout.cc @@ -0,0 +1,151 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file alter_op_layout.cc + * \brief Alter the operator layouts. Keep inferred layouts (if any) from previous stages. + * e.g., convolution may calculates faster with NCHW16c layout. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include "./compile_engine.h" +#include "./graph_transform.h" + +namespace nnvm { +namespace compiler { +namespace { + +tvm::Array GetTensorInfo(const IndexedGraph& idx_graph, + const uint32_t nid, + const ShapeVector& shape_vec, + const DTypeVector& dtype_vec) { + tvm::Array vec; + for (uint32_t i = 0; i < idx_graph[nid].source->num_outputs(); ++i) { + tvm::Array shape; + for (int64_t x : shape_vec[idx_graph.entry_id(nid, i)]) { + CHECK_LE(x, static_cast(std::numeric_limits::max())); + shape.push_back(tvm::make_const(tvm::Int(32), x)); + } + vec.push_back(tvm::placeholder( + shape, GetTVMType(dtype_vec[idx_graph.entry_id(nid, i)]))); + } + return vec; +} + +Graph AlterOpLayout(const Graph& src) { + static auto& falter_op_layout = + Op::GetAttr("FTVMAlterOpLayout"); + + const ShapeVector& shape_vec = src.GetAttr("shape"); + const DTypeVector& dtype_vec = src.GetAttr("dtype"); + const IndexedGraph& idx_graph = src.indexed_graph(); + + std::vector > in_layouts_of_node(idx_graph.num_nodes()); + std::vector > out_layouts_of_node(idx_graph.num_nodes()); + std::unordered_map new_nodes; + + if (src.HasAttr("layout")) { + // record layouts so that LayoutTransform pass can fix layouts correctly, + // e.g., conv2d can be replaced by some contrib implement + // whose layout is different from the original one + // (which was imported from a model file). + const auto& layouts = src.GetAttr >("layout"); + for (uint32_t nid = 0; nid < idx_graph.num_nodes(); ++nid) { + const auto &inode = idx_graph[nid]; + if (falter_op_layout.count(inode.source->op())) { + // do not record input layouts of nodes that will be replaced. 
+ continue; + } + std::vector in_layout; + for (const auto& e : inode.inputs) { + in_layout.emplace_back(layouts[idx_graph.entry_id(e)]); + } + in_layouts_of_node[nid] = in_layout; + + std::vector out_layout; + for (uint i = 0; i < inode.source->num_outputs(); ++i) { + out_layout.emplace_back(layouts[idx_graph.entry_id(nid, i)]); + } + out_layouts_of_node[nid] = out_layout; + } + } + + auto transform = [&](uint32_t nid, + const NodePtr& n, + std::vector* ret) { + nnvm::compiler::FTVMAlterOpLayout fn_alter_op_layout = + falter_op_layout.get(n->op(), nullptr); + if (fn_alter_op_layout == nullptr) { + new_nodes[n.get()] = nid; + return false; + } + + // construct parameters for registered function + std::vector op_inputs; + tvm::Array tensor_infos; + CHECK_EQ(n->num_inputs(), idx_graph[nid].inputs.size()); + for (uint32_t i = 0; i < n->num_inputs(); ++i) { + const nnvm::NodeEntry& input = n->inputs[i]; + // input operator + Symbol op_input; + op_input.outputs.push_back(input); + op_inputs.push_back(op_input); + + // input tinfo, extract from the original graph + // because it was where infer_shape & infer_type applied. + tvm::Array op_output_tinfos = + GetTensorInfo(idx_graph, idx_graph[nid].inputs[i].node_id, + shape_vec, dtype_vec); + tensor_infos.push_back(op_output_tinfos[input.index]); + } + // callback registered function to get a new operator. + auto op = fn_alter_op_layout(n->attrs, Symbol::CreateGroup(op_inputs), tensor_infos); + *ret = op.outputs; + return true; + }; + + Graph ret = nnvm::compiler::GraphTransform(src, transform); + + if (src.HasAttr("layout")) { + // restore the layouts to return graph + const auto& ret_idx = ret.indexed_graph(); + std::vector ret_layouts(ret_idx.num_node_entries(), Layout::Undef()); + for (uint32_t nid = 0; nid < ret_idx.num_nodes(); ++nid) { + const auto& inode = ret_idx[nid]; + if (new_nodes.count(inode.source)) { + const std::vector& in_layouts = + in_layouts_of_node[new_nodes[inode.source]]; + for (const auto& e : inode.inputs) { + ret_layouts[ret_idx.entry_id(e)] = in_layouts[e.index]; + } + const std::vector& out_layouts = + out_layouts_of_node[new_nodes[inode.source]]; + for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) { + ret_layouts[ret_idx.entry_id(nid, i)] = out_layouts[i]; + } + } + } + + // cannot call indexed_graph() before return the origin Graph, + // thus create a new one. 
+ nnvm::Graph new_ret; + new_ret.outputs = ret.outputs; + new_ret.attrs["layout"] = std::make_shared(std::move(ret_layouts)); + return new_ret; + } + + return ret; +} + +// register pass +NNVM_REGISTER_PASS(AlterOpLayout) +.set_body(AlterOpLayout) +.set_change_graph(true); + +} // namespace +} // namespace compiler +} // namespace nnvm diff --git a/nnvm/src/compiler/fold_scale_axis.cc b/nnvm/src/compiler/fold_scale_axis.cc index 7b05153b6..f9524eb8e 100644 --- a/nnvm/src/compiler/fold_scale_axis.cc +++ b/nnvm/src/compiler/fold_scale_axis.cc @@ -362,7 +362,7 @@ bool Pool2DBackward( std::vector* in_axis) { using top::Pool2DParam; const Pool2DParam& param = nnvm::get(attrs.parsed); - if (out_info.axis == 1 && param.layout == top::kNCHW) { + if (out_info.axis == 1 && param.layout == "NCHW") { (*in_axis)[0] = out_info; } return false; @@ -376,7 +376,7 @@ bool Pool2DForward( FoldChainInfo* out_info) { using top::Pool2DParam; const Pool2DParam& param = nnvm::get(attrs.parsed); - if ((*in_info)[0].axis == 1 && param.layout == top::kNCHW) { + if ((*in_info)[0].axis == 1 && param.layout == "NCHW") { *out_info = (*in_info)[0]; } return false; @@ -467,7 +467,7 @@ bool Conv2DScaleAxisBackward( const Conv2DParam& param = nnvm::get(attrs.parsed); if (out_info.kind != kPending) return false; // only optimize for nchw for now - if (param.layout == top::kNCHW && out_info.axis == 1) { + if (param.layout == "NCHW" && out_info.axis == 1) { (*in_axis)[1].kind = kMulConsumer; (*in_axis)[1].axis = 0; (*in_axis)[1].source = out_info.source; @@ -492,7 +492,7 @@ bool Conv2DScaleAxisForward( const Conv2DParam& param = nnvm::get(attrs.parsed); if ((*in_info)[0].kind != kPending) return false; // only optimize for nchw for now - if (param.layout == top::kNCHW && (*in_info)[0].axis == 1) { + if (param.layout == "NCHW" && (*in_info)[0].axis == 1) { (*in_info)[1].kind = kMulConsumer; (*in_info)[1].axis = 1; (*in_info)[1].source = (*in_info)[0].source; diff --git a/nnvm/src/compiler/layout_transform.cc b/nnvm/src/compiler/layout_transform.cc deleted file mode 100644 index 5651838ff..000000000 --- a/nnvm/src/compiler/layout_transform.cc +++ /dev/null @@ -1,159 +0,0 @@ -/*! - * Copyright (c) 2017 by Contributors - * \file layout_transform.cc - * \brief Transforms layout. - */ -#include -#include -#include -#include -#include -#include - -namespace nnvm { -namespace compiler { - -const TLayoutInfo& GetDefaultLayout() { - static TLayoutInfo default_layout = "default"; - return default_layout; -} - -nnvm::NodePtr CreateLayoutTransformNode(const std::string& src, - const std::string& dst) { - static const nnvm::Op* trans_op = nnvm::Op::Get("layout_transform"); - static int count = 0; - nnvm::NodePtr n = nnvm::Node::Create(); - n->attrs.op = trans_op; - n->attrs.name = src + "_to_" + dst + std::to_string(count++); - n->attrs.dict["src_layout"] = src; - n->attrs.dict["dst_layout"] = dst; - n->op()->attr_parser(&(n->attrs)); - return n; -} - -/*! - * \brief A simple layout transform pass that will - * insert layout transform nodes automatically. 
- */ -nnvm::Graph LayoutTransform(nnvm::Graph src) { - static auto& op_layout_request = - nnvm::Op::GetAttr("FTVMLayoutRequest"); - static auto& op_vecop = - nnvm::Op::GetAttr("FTVMVectorizedOp"); - static auto& op_pattern = nnvm::Op::GetAttr("TOpPattern"); - - const ShapeVector& shape_vec = src.GetAttr("shape"); - const std::vector& input_layouts = - src.GetAttr >("layout_inputs"); - - const IndexedGraph& idx = src.indexed_graph(); - std::vector produce_vec(idx.num_node_entries(), GetDefaultLayout()); - std::vector mirror_vec(idx.num_nodes(), nullptr); - - // use op pattern to decide whether an op is map - auto is_map_op = [&](size_t nid) { - TOpPattern pt = op_pattern.get(idx[nid].source->op(), kOpaque); - bool is_map = (pt <= kBroadcast); - if (pt == kBroadcast) { - for (const auto& e : idx[nid].inputs) { - if (shape_vec[idx.entry_id(nid, 0)] != shape_vec[idx.entry_id(e)]) { - is_map = false; - break; - } - } - } - return is_map; - }; - - for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { - const auto& inode = idx[nid]; - nnvm::NodePtr new_node = nnvm::Node::Create(); - *new_node = *(inode.source); - if (new_node->is_variable()) { - auto input_iter = std::find( - idx.input_nodes().cbegin(), idx.input_nodes().cend(), nid); - CHECK(input_iter != idx.input_nodes().cend()); - size_t input_id = std::distance(idx.input_nodes().cbegin(), input_iter); - produce_vec[idx.entry_id(nid, 0)] = input_layouts[input_id]; - mirror_vec[nid] = new_node; - continue; - } - - if (op_vecop.count(inode.source->op())) { - new_node = op_vecop[inode.source->op()](inode.source); - new_node->inputs.resize(new_node->num_inputs()); - } - - // set up output and input layouts - std::vector request_ilayouts(new_node->num_inputs(), GetDefaultLayout()); - if (op_layout_request.count(new_node->op())) { - std::vector produce_olayouts(new_node->num_outputs(), GetDefaultLayout()); - CHECK(op_layout_request[new_node->op()]( - new_node->attrs, &request_ilayouts, &produce_olayouts)) - << "Layout request fail"; - - CHECK_EQ(request_ilayouts.size(), new_node->num_inputs()); - CHECK_EQ(produce_olayouts.size(), new_node->num_outputs()); - for (size_t i = 0; i < new_node->num_outputs(); ++i) { - produce_vec[idx.entry_id(nid, i)] = produce_olayouts[i]; - } - } - - bool map_layout = is_map_op(nid); - if (map_layout) { - const TLayoutInfo& layout = produce_vec[idx.entry_id(inode.inputs[0])]; - for (const auto& e : inode.inputs) { - if (produce_vec[idx.entry_id(e)] != layout) { - map_layout = false; - break; - } - } - if (map_layout) { - for (size_t i = 0; i < inode.source->num_outputs(); ++i) { - produce_vec[idx.entry_id(nid, i)] = layout; - } - } - } - - for (size_t i = 0; i < inode.inputs.size(); ++i) { - const auto& e = inode.inputs[i]; - const nnvm::NodePtr& in = mirror_vec[e.node_id]; - new_node->inputs[i] = - nnvm::NodeEntry{in, e.index, e.version}; - - TLayoutInfo produce = produce_vec[idx.entry_id(e)]; - TLayoutInfo request = request_ilayouts[i]; - if (!map_layout && (produce != request)) { - nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, request); - tnode->attrs.name = - idx[e.node_id].source->attrs.name + "_" + request; - tnode->inputs.emplace_back(new_node->inputs[i]); - new_node->inputs[i] = nnvm::NodeEntry{tnode, 0, 0}; - } - } - mirror_vec[nid] = new_node; - } - - std::vector outputs; - for (const auto& e : idx.outputs()) { - TLayoutInfo produce = produce_vec[idx.entry_id(e)]; - if (produce != GetDefaultLayout()) { - nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, GetDefaultLayout()); - 
tnode->attrs.name = - idx[e.node_id].source->attrs.name + "_default"; - tnode->inputs.emplace_back( - nnvm::NodeEntry{mirror_vec[e.node_id], e.index, e.version}); - outputs.emplace_back(nnvm::NodeEntry{tnode, 0, 0}); - } else { - outputs.emplace_back( - nnvm::NodeEntry{mirror_vec[e.node_id], e.index, e.version}); - } - } - - nnvm::Graph ret; - ret.outputs = std::move(outputs); - return ret; -} - -} // namespace compiler -} // namespace nnvm diff --git a/nnvm/src/compiler/packed_func_ext.cc b/nnvm/src/compiler/packed_func_ext.cc index 4483f5303..2587534d7 100644 --- a/nnvm/src/compiler/packed_func_ext.cc +++ b/nnvm/src/compiler/packed_func_ext.cc @@ -8,6 +8,7 @@ #include #include #include +#include #include "./node_attr.h" #include "compile_engine.h" @@ -62,6 +63,23 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._dict_keys") *rv = keys; }); +TVM_REGISTER_GLOBAL("nnvm.compiler._register_alter_op_layout") +.set_body([](TVMArgs args, TVMRetValue *rv) { + // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown + PackedFunc* f = new PackedFunc(args[1].operator PackedFunc()); + Op& op = ::dmlc::Registry::Get()->__REGISTER_OR_GET__(args[0]); + auto fpack = [f](const NodeAttrs& attrs, + const Symbol& inputs, + const Array& tinfos) { + TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs, tinfos); + CHECK_EQ(ret.type_code(), tvm::runtime::extension_class_info::code) + << " expected " << "Symbol (code = " << tvm::runtime::extension_class_info::code + << ") but get code = " << ret.type_code(); + return *(static_cast(ret.value().v_handle)); + }; + op.set_attr("FTVMAlterOpLayout", fpack, args[2]); +}); + // custom version of TVM compute TVM_REGISTER_GLOBAL("nnvm._register_compute") .set_body([](TVMArgs args, TVMRetValue *rv) { @@ -84,7 +102,7 @@ TVM_REGISTER_GLOBAL("nnvm._register_compute") TVM_REGISTER_GLOBAL("nnvm._register_schedule") .set_body([](TVMArgs args, TVMRetValue *rv) { - // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown + // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown PackedFunc* f = new PackedFunc(args[1].operator PackedFunc()); Op& op = ::dmlc::Registry::Get()->__REGISTER_OR_GET__(args[0]); auto fschedule = [f](const NodeAttrs& attrs, diff --git a/nnvm/src/compiler/simplify_inference.cc b/nnvm/src/compiler/simplify_inference.cc index 141950b05..a0782222a 100644 --- a/nnvm/src/compiler/simplify_inference.cc +++ b/nnvm/src/compiler/simplify_inference.cc @@ -22,7 +22,8 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs, nnvm::NodeEntry beta, nnvm::NodeEntry moving_mean, nnvm::NodeEntry moving_var, - TShape dshape) { + TShape dshape, + TShape bshape) { CHECK_NE(dshape.ndim(), 0); CHECK(attrs.op); static const Op* bn_op = Op::Get("batch_norm"); @@ -60,13 +61,14 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs, "elemwise_add", bn_name + "_add_beta", {shift, beta}); } int axis = param.axis; - scale = ExpandBiasToMatchAxis(scale, dshape.ndim(), 1, axis); - shift = ExpandBiasToMatchAxis(shift, dshape.ndim(), 1, axis); + scale = ExpandBiasToMatchAxis(scale, dshape.ndim()-bshape.ndim()+1, 1, axis); + shift = ExpandBiasToMatchAxis(shift, dshape.ndim()-bshape.ndim()+1, 1, axis); + NodeEntry out = MakeNode("broadcast_mul", bn_name + "_a_mul_data", {data, scale}); out = MakeNode("broadcast_add", bn_name + "_out", {out, shift}); - // It is invalid to ref the other values of BN after infernece transform. + // It is invalid to ref the other values of BN after inference transform. 
NodeEntry undef = MakeNode("__undef__", "undef", {}); return {out, undef, undef}; } @@ -87,7 +89,8 @@ Graph SimplifyInference(nnvm::Graph src) { n->inputs[2], n->inputs[3], n->inputs[4], - shape_vec[idx.entry_id(nid, 0)]); + shape_vec[idx.entry_id(nid, 0)], + shape_vec[idx.entry_id(nid, 1)]); return true; } else if (n->op() == dropout_op) { NodeEntry undef = MakeNode("__undef__", "undef", {}); @@ -101,7 +104,8 @@ Graph SimplifyInference(nnvm::Graph src) { } NNVM_REGISTER_PASS(SimplifyInference) -.set_body(SimplifyInference); +.set_body(SimplifyInference) +.set_change_graph(true); } // namespace compiler } // namespace nnvm diff --git a/nnvm/src/pass/correct_layout.cc b/nnvm/src/pass/correct_layout.cc new file mode 100644 index 000000000..0aa7d7478 --- /dev/null +++ b/nnvm/src/pass/correct_layout.cc @@ -0,0 +1,169 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file correct_layout.cc + * \brief Infer and correct layout. + */ +#include +#include +#include +#include +#include + +namespace nnvm { +namespace pass { + +nnvm::NodePtr CreateLayoutTransformNode(const Layout& src, + const Layout& dst) { + static const nnvm::Op* trans_op = nnvm::Op::Get("__layout_transform__"); + static int count = 0; + nnvm::NodePtr n = nnvm::Node::Create(); + n->attrs.op = trans_op; + n->attrs.name = src.name() + "_to_" + dst.name() + std::to_string(count++); + n->attrs.dict["src_layout"] = src.name(); + n->attrs.dict["dst_layout"] = dst.name(); + n->op()->attr_parser(&(n->attrs)); + return n; +} + +using LayoutAttrDict = std::unordered_map >; + +/*! + * \brief A simple layout infer pass that will + * insert layout transform nodes automatically. + */ +nnvm::Graph CorrectLayout(nnvm::Graph src) { + static auto& op_infer_layout = + nnvm::Op::GetAttr("FInferLayout"); + + const IndexedGraph& idx = src.indexed_graph(); + std::vector mirror_vec(idx.num_nodes(), nullptr); + + // (new) NodePtr -> output_layouts + LayoutAttrDict new_layouts; + + for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { + const auto& inode = idx[nid]; + nnvm::NodePtr new_node = nnvm::Node::Create(); + *new_node = *(inode.source); + if (new_node->is_variable()) { + // Variable node. No operator. Only one output entry. + auto input_iter = std::find( + idx.input_nodes().cbegin(), idx.input_nodes().cend(), nid); + CHECK(input_iter != idx.input_nodes().cend()); + int64_t input_id = std::distance(idx.input_nodes().cbegin(), input_iter); + if (src.HasAttr("layout_inputs")) { + new_layouts[new_node.get()] = + {src.GetAttr >("layout_inputs")[input_id]}; + } else { + new_layouts[new_node.get()] = {Layout::Undef()}; + } + mirror_vec[nid] = new_node; + continue; + } + + const uint32_t num_inputs = inode.inputs.size(); + const uint32_t num_outputs = inode.source->num_outputs(); + // set up output and input layouts + std::vector request_ilayouts(num_inputs, Layout::Undef()); + for (size_t i = 0; i < num_inputs; ++i) { + const IndexedGraph::NodeEntry& input_entry = inode.inputs[i]; + const NodePtr& new_input_node = mirror_vec[input_entry.node_id]; + CHECK(new_input_node != nullptr); + + // fill inputs by previous node (DFS order) inferred layouts. + const auto& layouts_iter = new_layouts.find(new_input_node.get()); + CHECK(layouts_iter != new_layouts.end()); + request_ilayouts[i] = layouts_iter->second[input_entry.index]; + } + // layouts produced by previous node. 
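+    // request_ilayouts may be rewritten by the op's FInferLayout below; this
+    // snapshot of what the ancestors actually produce is later compared against
+    // the (possibly rewritten) requests to decide whether a __layout_transform__
+    // node has to be inserted on that input.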
+ std::vector produce_ilayouts(request_ilayouts); + // input layouts from last pass of LayoutTransform (if apply) + std::vector last_request_ilayouts(num_inputs, Layout::Undef()); + // fill outputs by last pass of LayoutTransform (if apply) + std::vector produce_olayouts(num_outputs, Layout::Undef()); + if (src.HasAttr("layout")) { + const auto& layouts = src.GetAttr >("layout"); + for (uint32_t i = 0; i < num_outputs; ++i) { + produce_olayouts[i] = layouts[idx.entry_id(nid, i)]; + } + for (uint32_t i = 0; i < num_inputs; ++i) { + last_request_ilayouts[i] = layouts[idx.entry_id(inode.inputs[i])]; + } + } + + const auto& flayout = op_infer_layout[new_node->op()]; + CHECK(flayout != nullptr) << "Attribute FInferLayout" + << " is not registered by op " << inode.source->op()->name + << " we are not able to complete layout transform."; + CHECK(flayout(new_node->attrs, &request_ilayouts, &last_request_ilayouts, &produce_olayouts)) + << "Layout infer fail"; + CHECK_EQ(request_ilayouts.size(), num_inputs); + CHECK_EQ(produce_olayouts.size(), num_outputs); + + // update new layouts + new_layouts[new_node.get()] = std::move(produce_olayouts); + + for (uint32_t i = 0; i < inode.inputs.size(); ++i) { + const auto& e = inode.inputs[i]; + const nnvm::NodePtr& in = mirror_vec[e.node_id]; + new_node->inputs[i] = nnvm::NodeEntry{in, e.index, e.version}; + + // insert layout_transform if necessary + const Layout& produce = produce_ilayouts[i]; + const Layout& request = request_ilayouts[i]; + if (produce != request && produce.defined()) { + nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, request); + tnode->attrs.name = idx[e.node_id].source->attrs.name + "_" + request.name(); + tnode->inputs.emplace_back(new_node->inputs[i]); + nnvm::NodeEntry tnode_output{tnode, 0, 0}; + new_node->inputs[i] = tnode_output; + // layout produced by LayoutTransformNode + new_layouts[tnode.get()] = {request}; + } else if (!produce.defined()) { + // do reverse infer + new_layouts[in.get()][e.index] = request; + } + } + mirror_vec[nid] = new_node; + } + + std::vector outputs; + for (const auto& e : idx.outputs()) { + outputs.emplace_back(nnvm::NodeEntry{mirror_vec[e.node_id], e.index, e.version}); + } + + nnvm::Graph ret; + ret.outputs = outputs; + // restore the layouts to return graph + const auto& ret_idx = ret.indexed_graph(); + std::vector ret_layouts(ret_idx.num_node_entries(), Layout::Undef()); + for (uint32_t nid = 0; nid < ret_idx.num_nodes(); ++nid) { + const auto& inode = ret_idx[nid]; + const auto& layout_iter = new_layouts.find(inode.source); + if (layout_iter != new_layouts.end()) { + for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) { + ret_layouts[ret_idx.entry_id(nid, i)] = std::move(layout_iter->second[i]); + } + } + } + + // cannot call indexed_graph() before return the origin Graph, + // thus create a new one + nnvm::Graph new_ret; + new_ret.outputs = std::move(outputs); + new_ret.attrs["layout"] = std::make_shared(std::move(ret_layouts)); + + return new_ret; +} + +// register pass +NNVM_REGISTER_PASS(CorrectLayout) +.describe("Return a layout-transformed graph of src.") +.set_body(CorrectLayout) +.provide_graph_attr("layout") +.set_change_graph(true); + +DMLC_JSON_ENABLE_ANY(LayoutVector, list_layout); + +} // namespace pass +} // namespace nnvm diff --git a/nnvm/src/pass/infer_shape_type.cc b/nnvm/src/pass/infer_shape_type.cc index fd9b77c42..cc4916ce0 100644 --- a/nnvm/src/pass/infer_shape_type.cc +++ b/nnvm/src/pass/infer_shape_type.cc @@ -158,7 +158,7 @@ Graph InferAttr(Graph &&ret, } 
else { CHECK(!last_iter) << "Attribute " << infer_name - << " is not registed by op " << inode.source->op()->name + << " is not registered by op " << inode.source->op()->name << " we are not able to complete the inference because of this"; } } diff --git a/nnvm/src/top/elemwise_op_common.h b/nnvm/src/top/elemwise_op_common.h index 7bcc262d3..27a7c2f0e 100644 --- a/nnvm/src/top/elemwise_op_common.h +++ b/nnvm/src/top/elemwise_op_common.h @@ -6,9 +6,12 @@ #ifndef NNVM_TOP_ELEMWISE_OP_COMMON_H_ #define NNVM_TOP_ELEMWISE_OP_COMMON_H_ +#include +#include #include #include #include +#include #include "./op_common.h" namespace nnvm { @@ -100,12 +103,176 @@ inline bool ElementWiseReduceType(const NodeAttrs& attrs, attrs, in_attrs, out_attrs, -1); } +template +inline bool ElemwiseFixedLayout(const NodeAttrs& attrs, + std::vector *in_layouts, + const std::vector *last_in_layouts, + std::vector *out_layouts, + const std::function& finfer) { + const size_t in_size = (n_in == -1) ? in_layouts->size() : static_cast(n_in); + const size_t out_size = (n_out == -1) ? out_layouts->size() : static_cast(n_out); + + auto deduce = [&](Layout *target, const std::vector *vec, + size_t size, const char *name) { + for (size_t i = 0; i < size; ++i) { + if (vec->at(i).defined()) { + if (!target->defined()) { + *target = vec->at(i); + } + CHECK_EQ(*target, vec->at(i)) + << "Incompatible attr in node " << attrs.name << " at " << i << "-th " + << name << ": " << "expected " << *target + << ", got " << vec->at(i); + } + } + }; + + Layout in, last_in, out; + deduce(&in, in_layouts, in_size, "input"); + deduce(&last_in, last_in_layouts, in_size, "input (last infer pass)"); + deduce(&out, out_layouts, out_size, "output"); + + if (!last_in.defined()) { + last_in = in; + } else { + // else we copy in_layout produced by last infer pass to in_layout, + // and let LayoutTransform pass + // to insert an layout_transform node to fix the input layout. + in = last_in; + } + + out = finfer(in); + + auto write = [](std::vector *vec, Layout& value, size_t size) { + for (size_t i = 0; i < size; ++i) { + vec->at(i) = value; + } + }; + if (in.defined()) write(in_layouts, in, in_size); + if (out.defined()) write(out_layouts, out, out_size); + + return true; +} + +/*! \brief Fix the input layout as the previous inferred (if any) and copy to output */ +template +inline bool ElemwiseFixedLayoutCopyToOut(const NodeAttrs& attrs, + std::vector *in_layouts, + const std::vector *last_in_layouts, + std::vector *out_layouts) { + return ElemwiseFixedLayout( + attrs, in_layouts, last_in_layouts, out_layouts, [](const Layout& in) { + return in; + }); +} + +/*! \brief Fix the input layout as the previous inferred (if any) and do not define output */ +template +inline bool ElemwiseFixedLayoutUnknownOut(const NodeAttrs& attrs, + std::vector *in_layouts, + const std::vector *last_in_layouts, + std::vector *out_layouts) { + return ElemwiseFixedLayout( + attrs, in_layouts, last_in_layouts, out_layouts, [](const Layout& in) { + return Layout::Undef(); + }); +} + +/*! \brief take arbitrary input layout and copy to output */ +template +inline bool ElemwiseArbitraryLayout(const NodeAttrs& attrs, + std::vector *in_layouts, + const std::vector *last_in_layouts, + std::vector *out_layouts) { + const size_t in_size = (n_in == -1) ? in_layouts->size() : static_cast(n_in); + const size_t out_size = (n_out == -1) ? 
out_layouts->size() : static_cast(n_out); + + Layout in; + for (size_t i = 0; i < in_size; ++i) { + if (!in.defined()) in = in_layouts->at(i); + CHECK_EQ(in, in_layouts->at(i)) + << "Incompatible attr in node " << attrs.name << " at " << i + << "-th input: expected " << in + << ", got " << in_layouts->at(i); + } + + if (in.defined()) { + for (size_t i = 0; i < out_size; ++i) { + out_layouts->at(i) = in; + } + } + + return true; +} + +/*! + * \brief try to convert right layout to left layout if they are different. + * if the converting fails, it will use the last inferred layouts. + */ +inline bool ElemwiseBinaryKeepLeftLayout(const NodeAttrs& attrs, + std::vector *in_layouts, + const std::vector *last_in_layouts, + std::vector *out_layouts) { + CHECK_EQ(in_layouts->size(), 2U); + CHECK_EQ(last_in_layouts->size(), 2U); + CHECK_EQ(out_layouts->size(), 1U); + + const Layout& lhs_last = (*last_in_layouts)[0]; + const Layout& rhs_last = (*last_in_layouts)[1]; + CHECK((lhs_last.defined() && rhs_last.defined()) || + (!lhs_last.defined() && !rhs_last.defined())); + + const Layout& lhs = (*in_layouts)[0]; + const Layout& rhs = (*in_layouts)[1]; + + if (!lhs.defined() && !rhs.defined()) { + CHECK(!lhs_last.defined() && !rhs_last.defined()) + << "Lost input layouts in node " << attrs.name + << ": last inferred lhs=" << lhs_last << ", rhs=" << rhs_last; + return true; + } else if (!lhs.defined()) { + CHECK(!lhs_last.defined() && !rhs_last.defined()); + in_layouts->at(0) = rhs; + out_layouts->at(0) = rhs; + return true; + } else if (!rhs.defined()) { + CHECK(!lhs_last.defined() && !rhs_last.defined()); + in_layouts->at(1) = lhs; + out_layouts->at(0) = lhs; + return true; + } + + if (lhs == rhs) { + // for same layout, we can always do binary calculation + // and pass the layout to next layer + out_layouts->at(0) = lhs; + return true; + } + + if (rhs.convertible(lhs)) { + in_layouts->at(1) = lhs; + out_layouts->at(0) = lhs; + } else { + CHECK(lhs_last.defined() && rhs_last.defined()) + << "Incompatible input layouts in node " << attrs.name + << ". 
lhs: " << lhs << ", rhs: " << rhs; + CHECK(lhs_last == rhs_last); + in_layouts->at(0) = lhs_last; + in_layouts->at(1) = rhs_last; + out_layouts->at(0) = lhs_last; + } + + return true; +} + #define NNVM_REGISTER_ELEMWISE_UNARY_OP(name) \ NNVM_REGISTER_OP(name) \ .set_num_inputs(1) \ .set_num_outputs(1) \ .set_attr("FInferShape", ElemwiseShape<1, 1>) \ .set_attr("FInferType", ElemwiseType<1, 1>) \ + .set_attr("FInferLayout", \ + ElemwiseArbitraryLayout<1, 1>) \ .set_attr("FInplaceOption", \ [](const NodeAttrs& attrs){ \ return std::vector >{{0, 0}}; \ @@ -131,6 +298,8 @@ inline bool ElementWiseReduceType(const NodeAttrs& attrs, .set_num_outputs(1) \ .set_attr("FInferShape", ElemwiseShape<2, 1>) \ .set_attr("FInferType", ElemwiseType<2, 1>) \ + .set_attr("FInferLayout", \ + ElemwiseBinaryKeepLeftLayout) \ .set_attr("FInplaceOption", \ [](const NodeAttrs& attrs) { \ return std::vector >{{0, 0}, {1, 0}}; \ @@ -150,6 +319,8 @@ inline bool ElementWiseReduceType(const NodeAttrs& attrs, ParamGetAttrDict) \ .set_attr("FInferShape", \ ElementWiseReduceShape) \ + .set_attr("FInferLayout", \ + ElemwiseFixedLayoutCopyToOut<1, 1>) \ .set_attr("FInferType", ElementWiseReduceType) \ .add_argument("args", "Symbol[]", "Positional input arguments") @@ -166,6 +337,8 @@ inline bool ElementWiseReduceType(const NodeAttrs& attrs, static_cast(kFloat32)); \ return true; \ }) \ + .set_attr("FInferLayout", \ + ElemwiseFixedLayoutUnknownOut<1, 1>) \ .set_attr( \ "FGradient", [](const NodePtr& n, \ const std::vector& ograds) { \ diff --git a/nnvm/src/top/nn/convolution.cc b/nnvm/src/top/nn/convolution.cc index d517e7e5d..dc7451e51 100644 --- a/nnvm/src/top/nn/convolution.cc +++ b/nnvm/src/top/nn/convolution.cc @@ -5,11 +5,22 @@ */ #include #include +#include #include #include +#include +#include +#include +#include #include "./nn_common.h" #include "../op_common.h" #include "../elemwise_op_common.h" +#include "topi/nn.h" + + +using tvm::Tensor; +using tvm::Array; +using nnvm::compiler::FTVMCompute; namespace nnvm { namespace top { @@ -20,7 +31,26 @@ DMLC_REGISTER_PARAMETER(Conv2DParam); inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs, std::vector* in_shape, std::vector* out_shape) { + static const Layout kNCHW("NCHW"); + static const Layout kOIHW("OIHW"); + const Conv2DParam& param = nnvm::get(attrs.parsed); + + const Layout in_layout(param.layout); + const Layout kernel_layout(param.kernel_layout); + CHECK(in_layout.convertible(kNCHW)) + << "Conv only support input layouts that are convertible from NCHW." + << " But got " << in_layout; + CHECK(kernel_layout.convertible(kOIHW)) + << "Conv only support kernel layouts that are convertible from OIHW." + << " But got "<< kernel_layout; + + Layout out_layout(param.out_layout); + if (!out_layout.defined()) out_layout = in_layout; + CHECK(out_layout.convertible(kNCHW)) + << "Conv only support output layouts that are convertible from NCHW." 
+ << " But got " << out_layout; + if (param.use_bias) { CHECK_EQ(in_shape->size(), 3U) << "Input:[data, weight, bias]"; } else { @@ -30,7 +60,7 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs, TShape dshape = in_shape->at(0); if (dshape.ndim() == 0) return false; - dshape = ConvertLayout(dshape, param.layout, kNCHW); + dshape = ConvertLayout(dshape, in_layout, kNCHW); CHECK_EQ(dshape.ndim(), 4U) << "Input data should be 4D"; CHECK_EQ(param.kernel_size.ndim(), 2U); @@ -48,13 +78,20 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs, param.kernel_size[0], param.kernel_size[1]}); - wshape = ConvertLayout(wshape, kNCHW, param.layout, true); + wshape = ConvertLayout(wshape, kOIHW, kernel_layout); wshape[0] *= param.groups; NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape); if (param.use_bias) { - NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, - Conv2DParam::kBias, TShape({param.channels})); + static const Layout default_bias_layout("C"); + TShape bias_shape({param.channels}); + auto oc_block = out_layout.subsizeof('C'); + if (oc_block > 0) { + size_t split_axis = (out_layout.indexof('C') < out_layout.indexof('c')) ? 1 : 0; + bias_shape = ConvertLayout(bias_shape, default_bias_layout, + default_bias_layout.split('C', split_axis, oc_block)); + } + NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kBias, bias_shape); } // dilation dim_t dilated_ksize_y = 1 + (param.kernel_size[0] - 1) * param.dilation[0]; @@ -66,12 +103,11 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs, if (dshape[3] != 0) { oshape[3] = (dshape[3] + param.padding[1] * 2 - dilated_ksize_x) / param.strides[1] + 1; } - NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, - ConvertLayout(oshape, kNCHW, param.layout)); + NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, ConvertLayout(oshape, kNCHW, out_layout)); // Perform incomplete shape inference. Fill in the missing values in data shape. // 1) We can always fill in the batch_size. // 2) We can back-calculate the input height/width if the corresponding stride is 1. 
- oshape = ConvertLayout((*out_shape)[0], param.layout, kNCHW); + oshape = ConvertLayout((*out_shape)[0], out_layout, kNCHW); dshape[0] = oshape[0]; if (oshape[2] && param.strides[0] == 1) { dshape[2] = oshape[2] + dilated_ksize_y - 1 - 2 * param.padding[0]; @@ -80,7 +116,7 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs, dshape[3] = oshape[3] + dilated_ksize_x - 1 - 2 * param.padding[1]; } NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kData, - ConvertLayout(dshape, kNCHW, param.layout)); + ConvertLayout(dshape, kNCHW, in_layout)); // Check whether the kernel sizes are valid if (dshape[2] != 0) { CHECK_LE(dilated_ksize_y, dshape[2] + 2 * param.padding[0]) @@ -93,6 +129,41 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs, return true; } +inline bool Conv2DInferLayout(const NodeAttrs& attrs, + std::vector *ilayouts, + const std::vector *last_ilayouts, + std::vector *olayouts) { + const Conv2DParam& param = nnvm::get(attrs.parsed); + + const Layout in_layout(param.layout); + Layout out_layout(param.out_layout); + if (!out_layout.defined()) out_layout = in_layout; + + const Layout kernel_layout(param.kernel_layout); + if (param.use_bias) { + CHECK_EQ(ilayouts->size(), 3U) << "Input:[data, weight, bias]"; + NNVM_ASSIGN_LAYOUT(*ilayouts, 0, in_layout); + NNVM_ASSIGN_LAYOUT(*ilayouts, 1, kernel_layout); + // automatically decide bias layout + Layout bias_layout("C"); + auto oc_block = out_layout.subsizeof('C'); + if (oc_block > 0) { + size_t split_axis = (out_layout.indexof('C') < out_layout.indexof('c')) ? 1 : 0; + bias_layout = bias_layout.split('C', split_axis, oc_block); + } + NNVM_ASSIGN_LAYOUT(*ilayouts, 2, bias_layout); + } else { + CHECK_EQ(ilayouts->size(), 2U) << "Input:[data, weight]"; + NNVM_ASSIGN_LAYOUT(*ilayouts, 0, in_layout); + NNVM_ASSIGN_LAYOUT(*ilayouts, 1, kernel_layout); + } + + CHECK_EQ(olayouts->size(), 1U); + NNVM_ASSIGN_LAYOUT(*olayouts, 0, out_layout); + + return true; +} + NNVM_REGISTER_OP(conv2d) .describe(R"code(2D convolution layer (e.g. spatial convolution over images). @@ -118,6 +189,7 @@ a bias vector is created and added to the outputs. .set_attr("FListInputNames", UseBiasListInputNames) .set_attr("FInferShape", Conv2DInferShape) .set_attr("FInferType", ElemwiseType<-1, 1>) +.set_attr("FInferLayout", Conv2DInferLayout) .set_num_outputs(1) .set_num_inputs(UseBiasNumInputs) .set_support_level(2) @@ -130,6 +202,23 @@ a bias vector is created and added to the outputs. n->attrs.dict); }); +NNVM_REGISTER_OP(_contrib_conv2d_NCHWc) +.describe(R"code(2D convolution layer (e.g. spatial convolution over images). +)code" NNVM_ADD_FILELINE) +.add_argument("data", "5D Tensor", "Packed input data.") +.add_argument("weight", "6D Tensor", "Packed weight matrix.") +.add_argument("bias", "1D Tensor", "Bias parameter.") +.add_arguments(Conv2DParam::__FIELDS__()) +.set_attr_parser(ParamParser) +.set_attr("FGetAttrDict", ParamGetAttrDict) +.set_attr("FListInputNames", UseBiasListInputNames) +.set_attr("FInferShape", Conv2DInferShape) +.set_attr("FInferType", ElemwiseType<-1, 1>) +.set_attr("FInferLayout", Conv2DInferLayout) +.set_num_outputs(1) +.set_num_inputs(UseBiasNumInputs) +.set_support_level(2); + NNVM_REGISTER_OP(_conv2d_grad) .describe(R"code(2D convolution grad. 
@@ -163,16 +252,21 @@ DMLC_REGISTER_PARAMETER(Conv2DTransposeParam); inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs, std::vector* in_shape, std::vector* out_shape) { + static const Layout kNCHW("NCHW"); + static const Layout kOIHW("OIHW"); const Conv2DTransposeParam& param = nnvm::get(attrs.parsed); + const Layout layout(param.layout); + const Layout kernel_layout(param.kernel_layout); if (param.use_bias) { CHECK_EQ(in_shape->size(), 3U) << "Input:[data, weight, bias]"; } else { CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]"; } CHECK_EQ(out_shape->size(), 1U); + const TShape& dshape = (*in_shape)[Conv2DTransposeParam::kData]; if (dshape.ndim() == 0) return false; - TShape dshape_nchw = ConvertLayout(dshape, param.layout, kNCHW); + TShape dshape_nchw = ConvertLayout(dshape, layout, kNCHW); CHECK_EQ(dshape_nchw[1] % param.groups, 0U) << "input num_filter must divide group size"; @@ -189,7 +283,7 @@ inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs, param.channels / param.groups, param.kernel_size[0], param.kernel_size[1]}); - wshape = ConvertLayout(wshape, kNCHW, param.layout, true); + wshape = ConvertLayout(wshape, kOIHW, kernel_layout); NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DTransposeParam::kWeight, wshape); if (param.use_bias) { @@ -208,7 +302,33 @@ inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs, oshape[3] = (param.strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - 2 * param.padding[1] + param.output_padding[1]); NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, - ConvertLayout(oshape, kNCHW, param.layout)); + ConvertLayout(oshape, kNCHW, layout)); + return true; +} + +inline bool Conv2DTransposeInferLayout(const NodeAttrs& attrs, + std::vector *ilayouts, + const std::vector *last_ilayouts, + std::vector *olayouts) { + const Conv2DTransposeParam& param = nnvm::get(attrs.parsed); + + const Layout in_layout(param.layout); + + const Layout kernel_layout(param.kernel_layout); + if (param.use_bias) { + CHECK_EQ(ilayouts->size(), 3U) << "Input:[data, weight, bias]"; + NNVM_ASSIGN_LAYOUT(*ilayouts, 0, in_layout); + NNVM_ASSIGN_LAYOUT(*ilayouts, 1, kernel_layout); + NNVM_ASSIGN_LAYOUT(*ilayouts, 2, Layout("C")); + } else { + CHECK_EQ(ilayouts->size(), 2U) << "Input:[data, weight]"; + NNVM_ASSIGN_LAYOUT(*ilayouts, 0, in_layout); + NNVM_ASSIGN_LAYOUT(*ilayouts, 1, kernel_layout); + } + + CHECK_EQ(olayouts->size(), 1U); + NNVM_ASSIGN_LAYOUT(*olayouts, 0, in_layout); + return true; } @@ -243,6 +363,7 @@ said convolution. .set_attr("FListInputNames", UseBiasListInputNames) .set_attr("FInferShape", Conv2DTransposeInferShape) .set_attr("FInferType", ElemwiseType<-1, 1>) +.set_attr("FInferLayout", Conv2DTransposeInferLayout) .set_num_outputs(1) .set_num_inputs(UseBiasNumInputs) .set_support_level(2); diff --git a/nnvm/src/top/nn/nn.cc b/nnvm/src/top/nn/nn.cc index e3755d952..219b5f7f9 100644 --- a/nnvm/src/top/nn/nn.cc +++ b/nnvm/src/top/nn/nn.cc @@ -3,10 +3,12 @@ * \file nn.cc * \brief Property def of nn operators. */ +#include #include #include #include #include +#include #include #include #include @@ -20,6 +22,8 @@ namespace nnvm { namespace top { +using tvm::Var; +using tvm::Expr; using tvm::Tensor; using tvm::Array; using nnvm::compiler::FTVMCompute; @@ -82,6 +86,8 @@ If ``use_bias`` is set to be false, then the ``bias`` term is ignored. 
.set_attr("FListInputNames", UseBiasListInputNames) .set_attr("FInferShape", DenseInferShape) .set_attr("FInferType", ElemwiseType<-1, 1>) +// leave weight & bias layout undefined +.set_attr("FInferLayout", ElemwiseFixedLayoutCopyToOut<1, 1>) .set_attr( "FGradient", [](const NodePtr& n, const std::vector& ograds) { @@ -161,6 +167,7 @@ NNVM_REGISTER_OP(dropout) .set_num_outputs(2) .set_attr("FInferShape", ElemwiseShape<1, 2>) .set_attr("FInferType", ElemwiseType<1, 2>) +.set_attr("FInferLayout", ElemwiseArbitraryLayout<1, 1>) .set_attr("FNumVisibleOutputs", [](const NodeAttrs& attrs) { return 1; }) @@ -184,13 +191,75 @@ inline bool BatchNormInferShape(const nnvm::NodeAttrs& attrs, CHECK((size_t)param.axis < dshape.Size()); TShape bshape({dshape[param.axis]}); - NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 1, bshape); - NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 2, bshape); - NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 3, bshape); - NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 4, bshape); + if (in_shape->at(1).ndim() == 0) in_shape->at(1) = bshape; + if (in_shape->at(2).ndim() == 0) in_shape->at(2) = bshape; + if (in_shape->at(3).ndim() == 0) in_shape->at(3) = bshape; + if (in_shape->at(4).ndim() == 0) in_shape->at(4) = bshape; NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, dshape); - NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 1, bshape); - NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 2, bshape); + out_shape->at(1) = in_shape->at(3); + out_shape->at(2) = in_shape->at(4); + return true; +} + +inline bool BatchNormInferLayout(const NodeAttrs& attrs, + std::vector *in_layouts, + const std::vector *last_in_layouts, + std::vector *out_layouts) { + const BatchNormParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_layouts->size(), 5U); + CHECK_EQ(last_in_layouts->size(), 5U); + CHECK_EQ(out_layouts->size(), 3U); + + Layout data_layout = in_layouts->at(0); + const Layout& origin_data_layout = last_in_layouts->at(0); + Layout param_layout("C"); + if (data_layout.defined()) { + if (data_layout.indexof('C') != param.axis) { + CHECK(origin_data_layout.defined()) + << "Channel in data layout " << data_layout + << " is not at index " << param.axis; + // convert it to the original one. + data_layout = origin_data_layout; + NNVM_ASSIGN_LAYOUT(*in_layouts, 0, origin_data_layout); + } else if (data_layout.indexof('c') >= 0 && + static_cast(data_layout.indexof('c')) != (data_layout.ndim()-1)) { + CHECK(origin_data_layout.defined()) + << "sub-channel c in data layout " << data_layout + << " does not at the final dimension"; + // convert it to the original one. + data_layout = origin_data_layout; + NNVM_ASSIGN_LAYOUT(*in_layouts, 0, origin_data_layout); + } else { + for (Layout::LayoutDim axis : data_layout) { + if (Layout::is_subdim(axis) && axis != 'c') { + CHECK(origin_data_layout.defined()) + << "sub-axis other than c appears in data layout " << data_layout; + // convert it to the original one. 
+ data_layout = origin_data_layout; + NNVM_ASSIGN_LAYOUT(*in_layouts, 0, origin_data_layout); + break; + } + } + } + + // decide the param layout + if (data_layout.defined()) { + auto channel_block = data_layout.subsizeof('C'); + if (channel_block > 0) { + param_layout = param_layout.split('C', 1, channel_block); + } + } + } + + NNVM_ASSIGN_LAYOUT(*in_layouts, 0, data_layout); + NNVM_ASSIGN_LAYOUT(*in_layouts, 1, param_layout); + NNVM_ASSIGN_LAYOUT(*in_layouts, 2, param_layout); + NNVM_ASSIGN_LAYOUT(*in_layouts, 3, param_layout); + NNVM_ASSIGN_LAYOUT(*in_layouts, 4, param_layout); + + NNVM_ASSIGN_LAYOUT(*out_layouts, 0, data_layout); + NNVM_ASSIGN_LAYOUT(*out_layouts, 1, param_layout); + NNVM_ASSIGN_LAYOUT(*out_layouts, 2, param_layout); return true; } @@ -238,6 +307,7 @@ axis to be the last item in the input shape. .add_arguments(BatchNormParam::__FIELDS__()) .set_attr_parser(ParamParser) .set_attr("FGetAttrDict", ParamGetAttrDict) +.set_attr("FInferLayout", BatchNormInferLayout) .set_num_inputs(5) .set_num_outputs(3) .set_attr("FInferShape", BatchNormInferShape) @@ -275,6 +345,7 @@ NNVM_REGISTER_OP(softmax) .set_num_outputs(1) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferLayout", ElemwiseArbitraryLayout<1, 1>) .set_support_level(1) .set_attr( "FTVMCompute", [](const NodeAttrs& attrs, @@ -331,6 +402,7 @@ NNVM_REGISTER_OP(log_softmax) .set_num_outputs(1) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferLayout", ElemwiseArbitraryLayout<1, 1>) .set_attr( "FTVMCompute", [](const NodeAttrs& attrs, const Array& inputs, @@ -388,6 +460,7 @@ NNVM_REGISTER_OP(leaky_relu) .set_num_outputs(1) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferLayout", ElemwiseArbitraryLayout<1, 1>) .set_attr( "FTVMCompute", [](const NodeAttrs& attrs, const Array& inputs, @@ -439,6 +512,30 @@ inline bool PReluInferShape(const nnvm::NodeAttrs &attrs, return true; } +inline bool PReluInferLayout(const NodeAttrs& attrs, + std::vector *in_layouts, + const std::vector *last_in_layouts, + std::vector *out_layouts) { + const PReLUParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_layouts->size(), 2U); + CHECK_EQ(last_in_layouts->size(), 2U); + CHECK_EQ(out_layouts->size(), 1U); + + const Layout& data_layout = last_in_layouts->at(0).defined() ? + last_in_layouts->at(0) : in_layouts->at(0); + if (data_layout.defined()) { + CHECK(data_layout.indexof('C') == param.axis && !data_layout.contains('c')) + << "Channel in data layout " << data_layout + << " is not at index " << param.axis; + } + + NNVM_ASSIGN_LAYOUT(*in_layouts, 0, data_layout); + NNVM_ASSIGN_LAYOUT(*in_layouts, 1, Layout("C")); + NNVM_ASSIGN_LAYOUT(*out_layouts, 0, data_layout); + + return true; +} + NNVM_REGISTER_OP(prelu) .describe(R"code(Parametric version of a Rectified Linear Unit. 
It accepts two arguments: an input ``x`` and a channelwise slope ``alpha`` @@ -453,6 +550,7 @@ where :math:`*` is an channelwise multiplication for each sample in the .set_num_inputs(2) .set_num_outputs(1) .set_attr("FInferShape", PReluInferShape) +.set_attr("FInferLayout", PReluInferLayout) .set_attr("FListInputNames", [](const NodeAttrs& attrs) { return std::vector{"data", "alpha"}; }) @@ -499,6 +597,7 @@ NNVM_REGISTER_OP(pad) .set_num_inputs(1) .set_attr("FInferShape", PadInferShape) .set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferLayout", ElemwiseFixedLayoutCopyToOut<1, 1>) .set_attr( "FTVMCompute", [](const NodeAttrs& attrs, const Array& inputs, @@ -520,5 +619,94 @@ NNVM_REGISTER_OP(pad) }) .set_support_level(1); +// layout transformer +DMLC_REGISTER_PARAMETER(LayoutTransformParam); + +inline bool LayoutTransformInferShape(const NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U) << "Input: [data]"; + CHECK_EQ(out_attrs->size(), 1U); + const LayoutTransformParam& param = nnvm::get(attrs.parsed); + const TShape &dshape = (*in_attrs)[0]; + if (dshape.ndim() == 0) return false; + const TShape &oshape = ConvertLayout(dshape, + Layout(param.src_layout), + Layout(param.dst_layout)); + NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape); + return true; +} + +NNVM_REGISTER_OP(__layout_transform__) +.describe(R"code(Transform the input data layout. + +For transforming from NCHW to N16cHWC, the `__layout_transform__` operator reshapes +the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w] + +)code" NNVM_ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.add_argument("data", "Tensor", "Input data.") +.add_arguments(LayoutTransformParam::__FIELDS__()) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", LayoutTransformInferShape) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr( + "FInferLayout", [](const NodeAttrs& attrs, + std::vector *ilayouts, + const std::vector *last_ilayouts, + std::vector *olayouts) { + const LayoutTransformParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(ilayouts->size(), 1U); + CHECK_EQ(olayouts->size(), 1U); + NNVM_ASSIGN_LAYOUT(*ilayouts, 0, Layout(param.src_layout)); + NNVM_ASSIGN_LAYOUT(*olayouts, 0, Layout(param.dst_layout)); + return true; +}) +.set_attr( + "FTVMCompute", [](const NodeAttrs& attrs, + const Array& inputs, + const Array& outputs) { + const LayoutTransformParam& param = nnvm::get(attrs.parsed); + + Layout src_layout(param.src_layout); + Layout dst_layout(param.dst_layout); + + if (src_layout == dst_layout) { + return Array{ inputs[0] }; + } else if (!src_layout.defined() || !dst_layout.defined()) { + LOG(FATAL) << "cannot convert from/to undefined layout"; + } + + CHECK(src_layout.convertible(dst_layout)) << "cannot convert from " << param.src_layout + << " to " << param.dst_layout; + + return Array { + topi::layout_transform(inputs[0], outputs[0]->shape, [&](const Array& dst_indices) { + std::vector dst_to_src_indices; + for (Layout::LayoutDim src_axis : src_layout) { + int dst_major_pos = dst_layout.indexof(Layout::to_superdim(src_axis)); + int dst_minor_pos = dst_layout.indexof(Layout::to_subdim(src_axis)); + int32_t src_factor = static_cast(src_layout.subsizeof(src_axis)); + int32_t dst_factor = static_cast(dst_layout.subsizeof(src_axis)); + + Expr src_index(dst_indices[dst_major_pos]); + if (dst_minor_pos >= 0) { + CHECK_GT(dst_factor, 0); + src_index = src_index * dst_factor + dst_indices[dst_minor_pos]; + } + if (Layout::is_superdim(src_axis) && 
src_factor > 0) { + src_index = src_index / src_factor; + } else if (Layout::is_subdim(src_axis) && src_factor > 0) { + src_index = src_index % src_factor; + } + dst_to_src_indices.push_back(src_index); + } + return Array(dst_to_src_indices); + }) + }; +}) +.set_support_level(1); + } // namespace top } // namespace nnvm diff --git a/nnvm/src/top/nn/nn_common.h b/nnvm/src/top/nn/nn_common.h index e9176d17a..49a020348 100644 --- a/nnvm/src/top/nn/nn_common.h +++ b/nnvm/src/top/nn/nn_common.h @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -40,100 +41,47 @@ inline std::vector UseBiasListInputNames(const NodeAttrs& attrs) { * \param dst_layout target layout * \return shape in target layout */ -inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout, bool is_weight = false) { - if (src_layout == dst_layout) return src; - TShape dst = src; - if (src.ndim() == 3) { - switch (src_layout) { - case kNCW: break; - case kNWC: { - std::swap(dst[1], dst[2]); - break; - } - default: { - LOG(FATAL) << "inavlid layout for 3d shape" << src_layout; - } - } - switch (dst_layout) { - case kNCW: break; - case kNWC: { - std::swap(dst[1], dst[2]); - break; - } - default: { - LOG(FATAL) << "inavlid layout for 3d shape" << dst_layout; - } - } - } else if (src.ndim() == 4) { - switch (src_layout) { - case kNCHW: break; - case kNHWC: { - if (is_weight) { - dst[2] = src[0]; - dst[3] = src[1]; - dst[1] = src[2]; - dst[0] = src[3]; - } else { - dst[2] = src[1]; - dst[3] = src[2]; - dst[1] = src[3]; - } - break; - } - default: { - LOG(FATAL) << "inavlid layout for 4d shape" << src_layout; - } - } - src = dst; - switch (dst_layout) { - case kNCHW: break; - case kNHWC: { - if (is_weight) { - dst[0] = src[2]; - dst[1] = src[3]; - dst[2] = src[1]; - dst[3] = src[0]; - } else { - dst[1] = src[2]; - dst[2] = src[3]; - dst[3] = src[1]; - } - break; - } - default: { - LOG(FATAL) << "inavlid layout for 4d shape" << dst_layout; - } - } - } else if (src.ndim() == 5) { - switch (src_layout) { - case kNCDHW: break; - case kNDHWC: { - dst[2] = src[1]; - dst[3] = src[2]; - dst[4] = src[3]; - dst[1] = src[4]; - break; - } - default: { - LOG(FATAL) << "inavlid layout for 5d shape" << src_layout; - } - } - src = dst; - switch (dst_layout) { - case kNCDHW: break; - case kNDHWC: { - dst[1] = src[2]; - dst[2] = src[3]; - dst[3] = src[4]; - dst[4] = src[1]; - break; +inline TShape ConvertLayout(TShape src, const Layout& src_layout, const Layout& dst_layout) { + if (src_layout == dst_layout) { + return src; + } else if (!src_layout.defined()) { + LOG(FATAL) << "cannot convert undefined layout to " << dst_layout; + } else if (!dst_layout.defined()) { + LOG(FATAL) << "cannot convert " << src_layout << " to undefined layout"; + } + + CHECK(src_layout.convertible(dst_layout)) << "cannot convert from " + << src_layout << " to " << dst_layout; + + TShape dst(dst_layout.ndim()); + for (size_t i = 0; i < src_layout.ndim(); ++i) { + Layout::LayoutDim src_dim = src_layout[i]; + if (Layout::is_superdim(src_dim)) { + int dst_major_pos = dst_layout.indexof(Layout::to_superdim(src_dim)); + int dst_minor_pos = dst_layout.indexof(Layout::to_subdim(src_dim)); + int src_minor_pos = src_layout.indexof(Layout::to_subdim(src_dim)); + int src_factor = src_layout.subsizeof(src_dim); + int dst_factor = dst_layout.subsizeof(src_dim); + + uint32_t src_dim_size = src[i]; + if (src_minor_pos >= 0) { + CHECK_EQ(src_factor, src[src_minor_pos]) << "src shape " << src + << " does not agree with layout " << src_layout; + src_dim_size 
*= src_factor; } - default: { - LOG(FATAL) << "inavlid layout for 5d shape" << dst_layout; + + dst[dst_major_pos] = src_dim_size; + if (dst_minor_pos >= 0) { + CHECK_GT(dst_factor, 0); + CHECK_LE(dst_factor, src_dim_size) << "Converting " << src + << " from " << src_layout + << " to " << dst_factor + << ": cannot split dimension size of " + << src_dim_size << " by " << dst_factor; + dst[dst_major_pos] /= dst_factor; + dst[dst_minor_pos] = dst_factor; } } - } else { - LOG(FATAL) << "no layout option for " << dst.ndim() << " dimensions"; } return dst; } diff --git a/nnvm/src/top/nn/pooling.cc b/nnvm/src/top/nn/pooling.cc index 10d9e2785..e32569ce3 100644 --- a/nnvm/src/top/nn/pooling.cc +++ b/nnvm/src/top/nn/pooling.cc @@ -30,34 +30,73 @@ inline bool Pool2DInferShape(const nnvm::NodeAttrs& attrs, TShape dshape = (*in_shape)[0]; if (dshape.ndim() == 0) return false; - dshape = ConvertLayout(dshape, param.layout, kNCHW); + + CHECK_GE(dshape.ndim(), 2U) + << "Pool2D only support input >= 2-D: input must have height and width"; + + Layout layout(param.layout); + CHECK(layout.contains('H') && layout.contains('W') && + !layout.contains('h') && !layout.contains('w')) + << "Invalid layout " << layout + << ". Pool2D layout must have H and W, which cannot be split"; + + const auto hidx = layout.indexof('H'); + const auto widx = layout.indexof('W'); TShape oshape = dshape; - CHECK_EQ(dshape.ndim(), 4U) - << "Pooling: Input data should be 4D"; - CHECK(param.pool_size[0] <= dshape[2] + 2 * param.padding[0]) - << "pool size (" << param.pool_size[0] << ") exceeds input (" << dshape[2] - << " padded to " << (dshape[2] + 2*param.padding[0]) << ")"; - CHECK(param.pool_size[1] <= dshape[3] + 2 * param.padding[1]) - << "pool size (" << param.pool_size[1] << ") exceeds input (" << dshape[3] - << " padded to " << (dshape[3] + 2*param.padding[1]) << ")"; + CHECK(param.pool_size[0] <= dshape[hidx] + 2 * param.padding[0]) + << "pool size (" << param.pool_size[0] << ") exceeds input (" << dshape[hidx] + << " padded to " << (dshape[hidx] + 2*param.padding[0]) << ")"; + CHECK(param.pool_size[1] <= dshape[widx] + 2 * param.padding[1]) + << "pool size (" << param.pool_size[1] << ") exceeds input (" << dshape[widx] + << " padded to " << (dshape[widx] + 2*param.padding[1]) << ")"; if (!param.ceil_mode) { - oshape[2] = ((dshape[2] + 2 * param.padding[0] - param.pool_size[0]) / - param.strides[0]) + 1; - oshape[3] = ((dshape[3] + 2 * param.padding[1] - param.pool_size[1]) / - param.strides[1]) + 1; + oshape[hidx] = ((dshape[hidx] + 2 * param.padding[0] - param.pool_size[0]) / + param.strides[0]) + 1; + oshape[widx] = ((dshape[widx] + 2 * param.padding[1] - param.pool_size[1]) / + param.strides[1]) + 1; } else { - oshape[2] = ((dshape[2] + 2 * param.padding[0] - param.pool_size[0] + - param.strides[0] - 1) / param.strides[0]) + 1; - oshape[3] = ((dshape[3] + 2 * param.padding[1] - param.pool_size[1] + - param.strides[1] - 1) / param.strides[1]) + 1; + oshape[hidx] = ((dshape[hidx] + 2 * param.padding[0] - param.pool_size[0] + + param.strides[0] - 1) / param.strides[0]) + 1; + oshape[widx] = ((dshape[3] + 2 * param.padding[1] - param.pool_size[1] + + param.strides[1] - 1) / param.strides[1]) + 1; } - oshape = ConvertLayout(oshape, kNCHW, param.layout); NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape); return true; } +inline bool Pool2DInferLayout(const NodeAttrs& attrs, + std::vector *ilayouts, + const std::vector *last_ilayouts, + std::vector *olayouts) { + const Pool2DParam ¶m = nnvm::get(attrs.parsed); + 
CHECK_EQ(ilayouts->size(), 1); + CHECK_EQ(last_ilayouts->size(), 1); + CHECK_EQ(olayouts->size(), 1); + + Layout input = (*ilayouts)[0]; + const Layout layout(param.layout); + + if (input.defined()) { + CHECK(input.convertible(layout)) << "Invalid input layout " << input; + if (input.indexof('W') != layout.indexof('W') || + input.indexof('H') != layout.indexof('H') || + input.contains('w') || input.contains('h')) { + // as long as the index doesn't change for width and height + // pool2d can keep the input layout. + input = layout; + } + } else { + input = layout; + } + + NNVM_ASSIGN_LAYOUT(*ilayouts, 0, input); + NNVM_ASSIGN_LAYOUT(*olayouts, 0, input); + + return true; +} + NNVM_REGISTER_OP(max_pool2d) .describe(R"code(Max pooling operation for one dimensional data. @@ -82,20 +121,29 @@ NNVM_REGISTER_OP(max_pool2d) .set_num_inputs(1) .set_attr("FInferShape", Pool2DInferShape) .set_attr("FInferType", ElemwiseType<1, 1>) -.set_attr( - "FTVMCompute", [](const NodeAttrs& attrs, - const Array& inputs, - const Array& out_info) { - const Pool2DParam& param = nnvm::get(attrs.parsed); - auto pool_size = ShapeToArray(param.pool_size); - auto strides = ShapeToArray(param.strides); - auto padding = ShapeToArray(param.padding); - auto ceil_mode = param.ceil_mode; - CHECK(param.layout == kNCHW || param.layout == kNHWC) << "Unsupported layout"; - std::string layout = (param.layout == kNCHW ? "NCHW" : "NHWC"); - return Array{ - topi::nn::pool(inputs[0], pool_size, strides, padding, \ - topi::nn::kMaxPool, ceil_mode, layout) }; +.set_attr("FInferLayout", Pool2DInferLayout) +.set_attr("FTVMCompute", [](const NodeAttrs& attrs, + const Array& inputs, + const Array& out_info) { + const Pool2DParam& param = nnvm::get(attrs.parsed); + auto pool_size = ShapeToArray(param.pool_size); + auto strides = ShapeToArray(param.strides); + auto padding = ShapeToArray(param.padding); + auto ceil_mode = param.ceil_mode; + + Layout layout(param.layout); + CHECK(layout.convertible(Layout("NCHW"))) + << "max_pool2d currently only supports layouts that are convertible from NCHW"; + CHECK_EQ(layout.indexof('h'), -1) << "max_pool2d does not support input split on height"; + CHECK_EQ(layout.indexof('w'), -1) << "max_pool2d does not support input split on width"; + + CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) + << "Pool2D only support 4-D input (e.g., NCHW)" + << " or 5-D input (last dimension is a split of channel)"; + + return Array{ + topi::nn::pool(inputs[0], pool_size, strides, padding, + topi::nn::kMaxPool, ceil_mode, layout.name())}; }) .set_attr( "FGradient", [](const NodePtr& n, @@ -144,20 +192,29 @@ NNVM_REGISTER_OP(avg_pool2d) .set_attr("FGetAttrDict", ParamGetAttrDict) .set_attr("FInferShape", Pool2DInferShape) .set_attr("FInferType", ElemwiseType<1, 1>) -.set_attr( - "FTVMCompute", [](const NodeAttrs& attrs, - const Array& inputs, - const Array& out_info) { - const Pool2DParam& param = nnvm::get(attrs.parsed); - auto pool_size = ShapeToArray(param.pool_size); - auto strides = ShapeToArray(param.strides); - auto padding = ShapeToArray(param.padding); - auto ceil_mode = param.ceil_mode; - CHECK(param.layout == kNCHW || param.layout == kNHWC) << "Unsupported layout"; - std::string layout = (param.layout == kNCHW ? 
"NCHW" : "NHWC"); - return Array{ - topi::nn::pool(inputs[0], pool_size, strides, padding, \ - topi::nn::kAvgPool, ceil_mode, layout) }; +.set_attr("FInferLayout", Pool2DInferLayout) +.set_attr("FTVMCompute", [](const NodeAttrs& attrs, + const Array& inputs, + const Array& out_info) { + const Pool2DParam& param = nnvm::get(attrs.parsed); + auto pool_size = ShapeToArray(param.pool_size); + auto strides = ShapeToArray(param.strides); + auto padding = ShapeToArray(param.padding); + auto ceil_mode = param.ceil_mode; + + Layout layout(param.layout); + CHECK(layout.convertible(Layout("NCHW"))) + << "avg_pool2d currently only supports layouts that are convertible from NCHW"; + CHECK_EQ(layout.indexof('h'), -1) << "avg_pool2d does not support input split on height"; + CHECK_EQ(layout.indexof('w'), -1) << "avg_pool2d does not support input split on width"; + + CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) + << "Pool2D only support 4-D input (e.g., NCHW)" + << " or 5-D input (last dimension is a split of channel)"; + + return Array{ + topi::nn::pool(inputs[0], pool_size, strides, padding, + topi::nn::kAvgPool, ceil_mode, layout.name())}; }) .set_num_outputs(1) .set_num_inputs(1) @@ -169,19 +226,63 @@ DMLC_REGISTER_PARAMETER(GlobalPool2DParam); inline bool GlobalPool2DInferShape(const nnvm::NodeAttrs& attrs, std::vector* in_shape, std::vector* out_shape) { + static const Layout kNCHW("NCHW"); const GlobalPool2DParam& param = nnvm::get(attrs.parsed); CHECK_EQ(in_shape->size(), 1U); CHECK_EQ(out_shape->size(), 1U); + TShape dshape = (*in_shape)[0]; if (dshape.ndim() == 0) return false; - dshape = ConvertLayout(dshape, param.layout, kNCHW); + + CHECK_GE(dshape.ndim(), 2U) + << "Pool2D only support input >= 2-D: input must have height and width"; + + Layout layout(param.layout); + CHECK(layout.contains('H') && layout.contains('W') && + !layout.contains('h') && !layout.contains('w')) + << "Invalid layout " << layout + << ". Pool2D layout must have H and W, which cannot be split"; + + const auto hidx = layout.indexof('H'); + const auto widx = layout.indexof('W'); + TShape oshape = dshape; - oshape[2] = oshape[3] = 1; - oshape = ConvertLayout(oshape, kNCHW, param.layout); + oshape[hidx] = oshape[widx] = 1; NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape); return true; } +inline bool GlobalPool2DInferLayout(const NodeAttrs& attrs, + std::vector *ilayouts, + const std::vector *last_ilayouts, + std::vector *olayouts) { + const GlobalPool2DParam ¶m = nnvm::get(attrs.parsed); + CHECK_EQ(ilayouts->size(), 1); + CHECK_EQ(last_ilayouts->size(), 1); + CHECK_EQ(olayouts->size(), 1); + + Layout input = (*ilayouts)[0]; + const Layout layout(param.layout); + + if (input.defined()) { + CHECK(input.convertible(layout)) << "Invalid input layout " << input; + if (input.indexof('W') != layout.indexof('W') || + input.indexof('H') != layout.indexof('H') || + input.contains('w') || input.contains('h')) { + // as long as the index doesn't change for width and height + // pool2d can keep the input layout. + input = layout; + } + } else { + input = layout; + } + + NNVM_ASSIGN_LAYOUT(*ilayouts, 0, input); + NNVM_ASSIGN_LAYOUT(*olayouts, 0, input); + + return true; +} + NNVM_REGISTER_OP(global_max_pool2d) .describe(R"code(Global max pooling operation for 2D data. 
@@ -197,15 +298,26 @@ NNVM_REGISTER_OP(global_max_pool2d) .set_attr("FGetAttrDict", ParamGetAttrDict) .set_attr("FInferShape", GlobalPool2DInferShape) .set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferLayout", GlobalPool2DInferLayout) .set_attr( "FTVMCompute", [](const NodeAttrs& attrs, const Array& inputs, const Array& out_info) { - const GlobalPool2DParam& param = nnvm::get(attrs.parsed); - CHECK_EQ(param.layout, kNCHW) - << "global_max_pool2d currently only supports NCHW layout"; - return Array{ - topi::nn::global_pool(inputs[0], topi::nn::kMaxPool) }; + const GlobalPool2DParam& param = nnvm::get(attrs.parsed); + Layout layout(param.layout); + CHECK(layout.convertible(Layout("NCHW"))) + << "global_max_pool2d currently only supports layouts that are convertible from NCHW"; + CHECK_EQ(layout.indexof('h'), -1) + << "global_max_pool2d does not support input split on height"; + CHECK_EQ(layout.indexof('w'), -1) + << "global_max_pool2d does not support input split on width"; + + CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) + << "Pool2D only support 4-D input (e.g., NCHW)" + << " or 5-D input (last dimension is a split of channel)"; + + return Array{ + topi::nn::global_pool(inputs[0], topi::nn::kMaxPool, layout.name()) }; }) .set_num_outputs(1) .set_num_inputs(1) @@ -227,15 +339,26 @@ NNVM_REGISTER_OP(global_avg_pool2d) .set_attr("FGetAttrDict", ParamGetAttrDict) .set_attr("FInferShape", GlobalPool2DInferShape) .set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferLayout", GlobalPool2DInferLayout) .set_attr( "FTVMCompute", [](const NodeAttrs& attrs, const Array& inputs, const Array& out_info) { - const GlobalPool2DParam& param = nnvm::get(attrs.parsed); - CHECK_EQ(param.layout, kNCHW) - << "global_avg_pool2d currently only supports NCHW layout"; - return Array{ - topi::nn::global_pool(inputs[0], topi::nn::kAvgPool) }; + const GlobalPool2DParam& param = nnvm::get(attrs.parsed); + Layout layout(param.layout); + CHECK(layout.convertible(Layout("NCHW"))) + << "global_avg_pool2d currently only supports layouts that are convertible from NCHW"; + CHECK_EQ(layout.indexof('h'), -1) + << "global_avg_pool2d does not support input split on height"; + CHECK_EQ(layout.indexof('w'), -1) + << "global_avg_pool2d does not support input split on width"; + + CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) + << "Pool2D only support 4-D input (e.g., NCHW)" + << " or 5-D input (last dimension is a split of channel)"; + + return Array{ + topi::nn::global_pool(inputs[0], topi::nn::kAvgPool, layout.name()) }; }) .set_num_outputs(1) .set_num_inputs(1) diff --git a/nnvm/src/top/nn/upsampling.cc b/nnvm/src/top/nn/upsampling.cc index 3195338c2..0e0f3b274 100644 --- a/nnvm/src/top/nn/upsampling.cc +++ b/nnvm/src/top/nn/upsampling.cc @@ -19,6 +19,7 @@ DMLC_REGISTER_PARAMETER(UpSamplingParam); inline bool UpSamplingInferShape(const nnvm::NodeAttrs& attrs, std::vector* in_shape, std::vector* out_shape) { + static const Layout kNCHW("NCHW"); const UpSamplingParam& param = nnvm::get(attrs.parsed); CHECK_EQ(in_shape->size(), 1U); CHECK_EQ(out_shape->size(), 1U); @@ -33,6 +34,19 @@ inline bool UpSamplingInferShape(const nnvm::NodeAttrs& attrs, return true; } +inline bool UpsamplingLayout(const NodeAttrs& attrs, + std::vector *in_layouts, + const std::vector *last_in_layouts, + std::vector *out_layouts) { + const UpSamplingParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_layouts->size(), 1U); + CHECK_EQ(out_layouts->size(), 1U); + const Layout layout(param.layout); + 
NNVM_ASSIGN_LAYOUT(*in_layouts, 0, layout); + NNVM_ASSIGN_LAYOUT(*out_layouts, 0, layout); + return true; +} + NNVM_REGISTER_OP(upsampling) .describe(R"(Perform nearest neighbor upsampling to input array. @@ -46,6 +60,7 @@ NNVM_REGISTER_OP(upsampling) .set_attr("FGetAttrDict", ParamGetAttrDict) .set_attr("FInferShape", UpSamplingInferShape) .set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferLayout", UpsamplingLayout) .set_num_outputs(1) .set_num_inputs(1) .set_support_level(2); diff --git a/nnvm/src/top/op_common.h b/nnvm/src/top/op_common.h index ae4388ade..826067ed5 100644 --- a/nnvm/src/top/op_common.h +++ b/nnvm/src/top/op_common.h @@ -203,6 +203,13 @@ inline std::string attr_assign_error_msg(const NodeAttrs& attrs, } \ } +#define NNVM_ASSIGN_LAYOUT(outputs, index, layout) \ + { \ + if (layout.defined()) { \ + (outputs)[index] = layout; \ + } \ + } + /*! * \brief macro assign rhs shape to lhs * Use macro so we can see the error file more clearly @@ -253,6 +260,14 @@ inline bool ZeroShape(const NodeAttrs& attrs, } } +// do not infer layout +inline bool ZeroLayout(const NodeAttrs& attrs, + std::vector *in_layouts, + const std::vector *last_in_layouts, + std::vector *out_layouts) { + return true; +} + // simply assign output shape or type from input template inline bool AssignOutputAttr(const NodeAttrs& attrs, diff --git a/nnvm/src/top/tensor/broadcast.cc b/nnvm/src/top/tensor/broadcast.cc index 773281450..3200f1aaf 100644 --- a/nnvm/src/top/tensor/broadcast.cc +++ b/nnvm/src/top/tensor/broadcast.cc @@ -11,6 +11,7 @@ #include #include #include +#include #include "../op_common.h" #include "../elemwise_op_common.h" #include "topi/broadcast.h" @@ -74,6 +75,7 @@ So with `shape=(2,0)`, we will obtain the same result as in the above example. .set_attr("FGetAttrDict", ParamGetAttrDict) .set_attr("FInferShape", BroadcastToInferShape) .set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>) .set_attr( "FTVMCompute", [](const NodeAttrs& attrs, const Array& inputs, @@ -115,7 +117,7 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs, } else { CHECK(l == 1 || r == 1) << "operands could not be broadcast together with shapes " - << lhs << " " << rhs; + << lhs << " " << rhs << ", l=" << l << ", r=" << r; out[i] = std::max(l, r); } } else { @@ -126,6 +128,77 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs, return true; } +inline bool BinaryBroadcastInferLayout(const NodeAttrs& attrs, + std::vector *ilayouts, + const std::vector *last_ilayouts, + std::vector *olayouts) { + CHECK_EQ(ilayouts->size(), 2U); + CHECK_EQ(olayouts->size(), 1U); + Layout lhs = (*ilayouts)[0]; + Layout rhs = (*ilayouts)[1]; + Layout out(Layout::Undef()); + + if (lhs.defined() && rhs.defined()) { + if (lhs == rhs) { + NNVM_ASSIGN_LAYOUT(*olayouts, 0, lhs); + return true; + } + // For example, NCHW <-> CHW, N16nCH16cW <-> HCW16c, etc, are broadcast-convertible + // because as the definition, CHW can broadcast with NCHW. + // For the second case, we can convert HCW16c to CH16cW then it can broadcast with N16nCH16cW. + // But CNHW <-> CHW, NCHW16n <-> CHW are not, + // because not matter how we adjust the layout of 'CHW', + // we can never have an 'N' between 'C' and "HW". 
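+  // The scan below walks both layouts left to right: a leading run of dims that
+  // appear in only one layout is tolerated (the other operand simply broadcasts
+  // along them), but only on one side; once the first common dim is matched,
+  // every remaining dim must appear in both layouts.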
+ size_t l_start = 0, r_start = 0; + size_t l = 0, r = 0; + bool find_first_match = false; + while (l < lhs.ndim() && r < rhs.ndim()) { + if (!rhs.contains(Layout::to_superdim(lhs[l]))) { + CHECK(!find_first_match) << lhs << " and " << rhs << " are not broadcast-convertible"; + l_start = ++l; + } else if (!lhs.contains(Layout::to_superdim(rhs[r]))) { + CHECK(!find_first_match) << lhs << " and " << rhs << " are not broadcast-convertible"; + r_start = ++r; + } else { + find_first_match = true; + ++l; ++r; + } + } + if (l_start > 0 && r_start > 0) { + LOG(FATAL) << lhs << " and " << rhs << " are not broadcast-convertible"; + } else if (l_start > 0) { + rhs = lhs.sublayout(l_start, lhs.ndim()-l_start); + out = lhs; + } else if (r_start > 0) { + lhs = rhs.sublayout(r_start, rhs.ndim()-r_start); + out = rhs; + } else { + // prior to keep left layout + rhs = lhs; + out = lhs; + } + } else if (lhs.defined()) { + const Layout& last_lhs = last_ilayouts->at(0); + if (last_lhs.defined()) { + CHECK(lhs.convertible(last_lhs)) << "current lhs layout " << lhs + << " cannot be converted to the original one " << last_lhs; + lhs = last_lhs; + // cannot decide output layout + } + } else if (rhs.defined()) { + const Layout& last_rhs = last_ilayouts->at(1); + if (last_rhs.defined()) { + CHECK(rhs.convertible(last_rhs)) << "current rhs layout " << rhs + << " cannot be converted to the original one " << last_rhs; + rhs = last_rhs; + // cannot decide output layout + } + } + NNVM_ASSIGN_LAYOUT(*ilayouts, 0, lhs); + NNVM_ASSIGN_LAYOUT(*ilayouts, 1, rhs); + NNVM_ASSIGN_LAYOUT(*olayouts, 0, out); + return true; +} #define NNVM_REGISTER_BINARY_BROADCAST_OP(name) \ NNVM_REGISTER_OP(name) \ @@ -133,6 +206,8 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs, .set_num_outputs(1) \ .set_attr("FInferShape", BinaryBroadcastShape) \ .set_attr("FInferType", ElemwiseType<2, 1>) \ + .set_attr("FInferLayout", \ + BinaryBroadcastInferLayout) \ .set_attr("FInplaceOption", \ [](const NodeAttrs& attrs) { \ return std::vector >{{0, 0}, {1, 0}}; \ diff --git a/nnvm/src/top/tensor/elemwise.cc b/nnvm/src/top/tensor/elemwise.cc index 87fbf5823..51a55e649 100644 --- a/nnvm/src/top/tensor/elemwise.cc +++ b/nnvm/src/top/tensor/elemwise.cc @@ -333,6 +333,7 @@ NNVM_REGISTER_INIT_OP(full) .add_arguments(InitOpWithScalarParam::__FIELDS__()) .set_attr("FInferShape", ZeroShape) .set_attr("FInferType", ZeroType) +.set_attr("FInferLayout", ZeroLayout) .set_support_level(4); NNVM_REGISTER_INIT_OP(zeros) @@ -345,6 +346,7 @@ NNVM_REGISTER_INIT_OP(zeros) .add_arguments(InitOpParam::__FIELDS__()) .set_attr("FInferShape", ZeroShape) .set_attr("FInferType", ZeroType) +.set_attr("FInferLayout", ZeroLayout) .set_support_level(4); NNVM_REGISTER_INIT_OP(ones) @@ -357,6 +359,7 @@ NNVM_REGISTER_INIT_OP(ones) .add_arguments(InitOpParam::__FIELDS__()) .set_attr("FInferShape", ZeroShape) .set_attr("FInferType", ZeroType) +.set_attr("FInferLayout", ZeroLayout) .set_support_level(4); // full_like @@ -693,6 +696,7 @@ Example:: .set_attr("FGetAttrDict", ParamGetAttrDict) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>) .set_attr( "FTVMCompute", [](const NodeAttrs& attrs, const Array& inputs, diff --git a/nnvm/src/top/tensor/matrix_op.cc b/nnvm/src/top/tensor/matrix_op.cc index 149c609ee..d357267be 100644 --- a/nnvm/src/top/tensor/matrix_op.cc +++ b/nnvm/src/top/tensor/matrix_op.cc @@ -41,6 +41,31 @@ inline bool DotShape(const nnvm::NodeAttrs& attrs, 
return true; } +inline bool DotInferLayout(const NodeAttrs& attrs, + std::vector *ilayouts, + const std::vector *last_ilayouts, + std::vector *olayouts) { + const MatMulParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(ilayouts->size(), 2U); + CHECK_EQ(olayouts->size(), 1U); + const Layout& lhs = last_ilayouts->at(0).defined() ? last_ilayouts->at(0) + : ilayouts->at(0); + const Layout& rhs = last_ilayouts->at(1).defined() ? last_ilayouts->at(1) + : ilayouts->at(1); + NNVM_ASSIGN_LAYOUT(*ilayouts, 0, lhs); + NNVM_ASSIGN_LAYOUT(*ilayouts, 1, rhs); + + if (lhs.ndim() > 1 && rhs.ndim() > 1) { + // concat lhs and rhs layout + const Layout& lhs_out = param.transpose_a ? lhs.reverse() : lhs; + const Layout& rhs_out = param.transpose_b ? rhs.reverse() : rhs; + Layout out = std::move(lhs_out.sublayout(0, lhs_out.ndim()-1) + + rhs_out.sublayout(1, rhs_out.ndim()-1)); + NNVM_ASSIGN_LAYOUT(*olayouts, 0, out); + } + return true; +} + NNVM_REGISTER_OP(matmul) .describe(R"doc(Matrix multiplication of two arrays. @@ -67,6 +92,7 @@ NNVM_REGISTER_OP(matmul) .add_argument("rhs", "NDArray-or-Symbol", "The second input") .set_attr("FInferShape", DotShape) .set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FInferLayout", DotInferLayout) .set_attr( "FGradient", [](const NodePtr& n, const std::vector& ograds) { diff --git a/nnvm/src/top/tensor/reduce.cc b/nnvm/src/top/tensor/reduce.cc index 84a7dd0f0..d4049d362 100644 --- a/nnvm/src/top/tensor/reduce.cc +++ b/nnvm/src/top/tensor/reduce.cc @@ -111,6 +111,8 @@ inline void AxesParamParser(nnvm::NodeAttrs* attrs) { .set_attr("FGetAttrDict", ParamGetAttrDict) \ .set_attr("FInferShape", ReduceShape) \ .set_attr("FInferType", ElemwiseType<1, 1>) \ + .set_attr("FInferLayout", \ + ElemwiseFixedLayoutUnknownOut<1, 1>) \ .set_num_inputs(1) \ .set_num_outputs(1) diff --git a/nnvm/src/top/tensor/state_op.cc b/nnvm/src/top/tensor/state_op.cc index ebce07696..9adbf6f36 100644 --- a/nnvm/src/top/tensor/state_op.cc +++ b/nnvm/src/top/tensor/state_op.cc @@ -45,6 +45,15 @@ This is an experimental operator. 
return Array{ topi::identity(inputs[1]) }; }) .set_attr("FInferShape", SameShape) +.set_attr( + "FInferLayout", [](const NodeAttrs& attrs, + std::vector *in_layouts, + const std::vector *last_in_layouts, + std::vector *out_layouts) { + NNVM_ASSIGN_LAYOUT(*in_layouts, 1, (*in_layouts)[0]); + NNVM_ASSIGN_LAYOUT(*out_layouts, 0, (*in_layouts)[0]); + return true; +}) .set_attr( "FInplaceOption", [](const NodeAttrs& attrs) { return std::vector >{{1, 0}}; diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc index 245774734..034fa957b 100644 --- a/nnvm/src/top/tensor/transform.cc +++ b/nnvm/src/top/tensor/transform.cc @@ -10,6 +10,7 @@ #include #include #include +#include #include "../op_common.h" #include "../elemwise_op_common.h" #include "topi/nn/flatten.h" @@ -63,6 +64,7 @@ Example:: .set_num_outputs(1) .set_attr("FInferShape", FlattenInferShape) .set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>) .add_argument("data", "Tensor", "Input data.") .set_attr( "FTVMCompute", [](const NodeAttrs& attrs, @@ -119,6 +121,22 @@ inline bool ConcatenateInferShape(const NodeAttrs& attrs, return dshape.Size() != 0; } +inline bool ConcatenateInferLayout(const NodeAttrs& attrs, + std::vector *ilayouts, + const std::vector *last_ilayouts, + std::vector *olayouts) { + CHECK_EQ(ilayouts->size(), last_ilayouts->size()); + CHECK_EQ(olayouts->size(), 1U); + + for (size_t i = 0; i < ilayouts->size(); ++i) { + const Layout& input = last_ilayouts->at(i).defined() ? + last_ilayouts->at(i) : ilayouts->at(i); + NNVM_ASSIGN_LAYOUT(*ilayouts, i, input); + } + + return true; +} + NNVM_REGISTER_OP(concatenate) .describe(R"code(Joins input arrays along a given axis. @@ -156,6 +174,7 @@ Example:: .set_attr("FGetAttrDict", ParamGetAttrDict) .set_attr("FInferShape", ConcatenateInferShape) .set_attr("FInferType", ElemwiseType<-1, 1>) +.set_attr("FInferLayout", ConcatenateInferLayout) .set_attr( "FTVMCompute", [](const NodeAttrs& attrs, const Array& inputs, @@ -177,7 +196,8 @@ inline bool ExpandDimsInferShape(const NodeAttrs& attrs, CHECK_EQ(in_shape->size(), 1U); const TShape& dshape = in_shape->at(0); int ndim = static_cast(dshape.ndim()); - CHECK(param.axis >= -ndim - 1 && param.axis <= ndim); + CHECK(param.axis >= -ndim - 1 && param.axis <= ndim) + << "with axis = " << param.axis << " ndim = " << ndim; int axis = param.axis < 0 ? ndim + param.axis + 1 : param.axis; std::vector oshape; for (int i = 0; i < axis; ++i) { @@ -198,7 +218,7 @@ NNVM_REGISTER_OP(expand_dims) .describe(R"code(Inserts a new axis of size 1 into the array shape For example, given ``x`` with shape ``(2,3,4)``, then ``expand_dims(x, axis=1, num_newaxis=5)`` -will return a new array with shape ``(2,5,3,4)``. +will return a new array with shape ``(2,1,1,1,1,1,3,4)``. )code" NNVM_ADD_FILELINE) .add_argument("data", "Tensor", "Input tensor") @@ -207,6 +227,7 @@ will return a new array with shape ``(2,5,3,4)``. .set_attr("FGetAttrDict", ParamGetAttrDict) .set_attr("FInferShape", ExpandDimsInferShape) .set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>) .set_num_inputs(1) .set_num_outputs(1) .set_attr( @@ -249,6 +270,8 @@ Examples:: .set_attr("FGetAttrDict", ParamGetAttrDict) .set_attr("FInferShape", AssignOutputAttr) .set_attr("FInferType", ElemwiseType<2, 1>) +// never transform layout of the second input array. 
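+// (ElemwiseFixedLayoutUnknownOut<1, 1> only touches the first input, so the
+// shape-providing second input keeps whatever layout it already carries)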
+.set_attr("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>) .set_num_inputs(2) .set_num_outputs(1) .set_attr( @@ -345,6 +368,7 @@ along which to split the array. .set_attr_parser(SplitParamParser) .set_attr("FInferShape", SplitInferShape) .set_attr("FInferType", ElemwiseType<1, -1>) +.set_attr("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, -1>) .set_num_inputs(1) .set_num_outputs(SplitNumOutputs) .set_attr( @@ -387,6 +411,7 @@ NNVM_REGISTER_OP(cast) .set_attr("FGetAttrDict", ParamGetAttrDict) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", CastInferType) +.set_attr("FInferLayout", ElemwiseArbitraryLayout<1, 1>) .set_num_inputs(1) .set_num_outputs(1) .set_support_level(1); @@ -539,6 +564,7 @@ The significance of each is explained below: .set_attr("FGetAttrDict", ParamGetAttrDict) .set_attr("FInferShape", ReshapeInferShape) .set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>) .set_num_inputs(1) .set_num_outputs(1) .set_attr( @@ -578,6 +604,8 @@ the input array into an output array with the same shape as the second input arr return true; }) .set_attr("FInferType", ElemwiseType<2, 1>) +// never transform layout of the second input array. +.set_attr("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>) .set_attr( "FGradient", [](const NodePtr& n, const std::vector& ograds) { @@ -660,6 +688,7 @@ Examples:: .set_attr("FGetAttrDict", ParamGetAttrDict) .set_attr("FInferShape", SqueezeShape) .set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>) .set_num_inputs(1) .set_num_outputs(1) .set_attr( @@ -680,7 +709,7 @@ Examples:: }) .set_support_level(1); -// tranpose +// transpose DMLC_REGISTER_PARAMETER(TransposeParam); inline bool TransposeShape(const nnvm::NodeAttrs& attrs, @@ -708,6 +737,39 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs, return true; } +inline bool TransposeInferLayout(const NodeAttrs& attrs, + std::vector *ilayouts, + const std::vector *last_ilayouts, + std::vector *olayouts) { + const TransposeParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(ilayouts->size(), 1U); + CHECK_EQ(olayouts->size(), 1U); + + const Layout& input = last_ilayouts->at(0).defined() + ? last_ilayouts->at(0) + : ilayouts->at(0); + + NNVM_ASSIGN_LAYOUT(*ilayouts, 0, input); + + if (input.defined()) { + std::ostringstream new_layout; + if (param.axes.ndim() == 0) { + for (size_t i = 0; i < input.ndim(); ++i) { + new_layout << input.at(input.ndim() - 1 - i); + } + } else { + CHECK_EQ(input.ndim(), param.axes.ndim()); + for (size_t i = 0; i < input.ndim(); ++i) { + CHECK(param.axes[i] < input.ndim()); + new_layout << input.at(param.axes[i]); + } + } + NNVM_ASSIGN_LAYOUT(*olayouts, 0, Layout(new_layout.str())); + } + + return true; +} + NNVM_REGISTER_OP(transpose) .describe(R"code(Permutes the dimensions of an array. 
@@ -743,6 +805,7 @@ Examples:: .set_attr("FGetAttrDict", ParamGetAttrDict) .set_attr("FInferShape", TransposeShape) .set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferLayout", TransposeInferLayout) .set_num_inputs(1) .set_num_outputs(1) .set_support_level(4) diff --git a/nnvm/tests/python/compiler/test_alter_op_layout.py b/nnvm/tests/python/compiler/test_alter_op_layout.py new file mode 100644 index 000000000..d921d4a64 --- /dev/null +++ b/nnvm/tests/python/compiler/test_alter_op_layout.py @@ -0,0 +1,49 @@ +"""Unittest cases for AlterOpLayout pass""" +from nnvm import symbol as sym +from nnvm.compiler import graph_attr +from nnvm.top import registry as reg +import nnvm.graph as graph + +def get_layouts(g): + ldict = {} + vlayout = g.json_attr("layout") + entry_ptr = g.index.entry_ptr + for i, n in enumerate(g.index.nodes): + begin, end = entry_ptr[i], entry_ptr[i + 1] + ldict[n["name"]] = vlayout[begin:end] + return ldict + + +def test_alter_conv2d_layout(): + data = sym.Variable("data", shape=(1, 32, 512, 512)) + conv = sym.conv2d(data, name="conv", channels=16, + kernel_size=(3,3), padding=(1,1), + use_bias=False, layout="NCHW") + relu = sym.relu(conv, name="relu") + flatten = sym.flatten(relu, name="flatten") + softmax = sym.softmax(flatten, name="softmax") + g = graph.create(softmax) + g = g.apply("CorrectLayout") + g = graph_attr.set_dtype_inputs(g, "float32") + g = g.apply(["InferShape", "InferType"]) + layouts_origin = get_layouts(g) + + @reg.register_alter_op_layout("conv2d") + def alter_conv2d_layout(attrs, inputs, tinfos): + new_attrs = {k : attrs[k] for k in attrs.keys()} + new_attrs["layout"] = "NCHW16c" + new_attrs["kernel_layout"] = "NCHW16c" + new_attrs["name"] = "conv_alter" + return sym.conv2d(inputs[0], inputs[1], **new_attrs) + + g = g.apply("AlterOpLayout") + layouts = get_layouts(g) + + # check copy layouts + for node in ["data", "relu", "flatten", "softmax", "conv_weight"]: + assert(layouts[node] == layouts_origin[node]) + assert(layouts["conv_alter"] == layouts_origin["conv"]) + + +if __name__ == "__main__": + test_alter_conv2d_layout() diff --git a/nnvm/tests/python/compiler/test_nhwc_layout.py b/nnvm/tests/python/compiler/test_nhwc_layout.py index 57e27db74..96a813543 100644 --- a/nnvm/tests/python/compiler/test_nhwc_layout.py +++ b/nnvm/tests/python/compiler/test_nhwc_layout.py @@ -5,9 +5,10 @@ import nnvm.compiler from nnvm.testing.config import ctx_list -def get_sym(layout, channels): +def get_sym(layout, kernel_layout, channels): data = sym.Variable(name="data") - data = sym.conv2d(data=data, kernel_size=(3,3), channels=channels, padding=(1, 1), layout=layout, use_bias=True) + data = sym.conv2d(data=data, kernel_size=(3,3), channels=channels, padding=(1, 1), + layout=layout, kernel_layout=kernel_layout, use_bias=True) data = sym.max_pool2d(data=data, pool_size=(2, 2), strides=(2, 2), layout=layout) data = sym.upsampling(data=data, scale=2, layout=layout) softmax_axis = 1 @@ -31,8 +32,8 @@ def build_and_run(sym, params, data, out_shape): def test_nhwc(): data_shape = (1, 3, 224, 224) out_channel = 8 - nchw_sym = get_sym("NCHW", out_channel) - nhwc_sym = get_sym("NHWC", out_channel) + nchw_sym = get_sym("NCHW", "OIHW", out_channel) + nhwc_sym = get_sym("NHWC", "HWIO", out_channel) conv_weight = np.random.uniform(-1, 1, (out_channel, 3, 3, 3)).astype(np.float32) conv_bias = np.random.uniform(-1, 1, (out_channel)).astype(np.float32) nchw_params = { diff --git a/nnvm/tests/python/unittest/test_correct_layout.py 
b/nnvm/tests/python/unittest/test_correct_layout.py new file mode 100644 index 000000000..c428a2f83 --- /dev/null +++ b/nnvm/tests/python/unittest/test_correct_layout.py @@ -0,0 +1,338 @@ +import nnvm +import nnvm.symbol as sym +import nnvm.graph as graph +from nnvm.compiler import graph_attr + +# Level 1 +def correct_layout(g, layout=None): + if isinstance(g, nnvm.symbol.Symbol): + g = graph.create(g) + if layout: + graph_attr.set_layout_inputs(g, layout) + g = g.apply("CorrectLayout") + ldict = {} + vlayout = g.json_attr("layout") + entry_ptr = g.index.entry_ptr + for i, n in enumerate(g.index.nodes): + begin, end = entry_ptr[i], entry_ptr[i + 1] + ldict[n["name"]] = vlayout[begin:end] + return g, ldict + + +def test_dense(): + x = sym.Variable("data", shape=(10, 20)) + y = sym.dense(x, units=30, name="fc") + g, ldict = correct_layout(y, "HW") + assert(ldict["data"][0] == "HW") + assert(ldict["fc"][0] == "HW") + assert(ldict["fc_bias"][0] == "__undef__") + # second pass will insert layout transform + _, ldict = correct_layout(g, "HW16w") + assert(ldict["data"][0] == "HW16w") + assert(ldict["data_HW"][0] == "HW") + assert(ldict["fc"][0] == "HW") + assert(ldict["fc_bias"][0] == "__undef__") + + +def test_matmul(): + a = sym.Variable("a", shape=(10, 20)) + b = sym.Variable("b", shape=(20, 30)) + c = sym.matmul(a, b, name="matmul") + g, ldict = correct_layout(c, {"a" : "HW", "b" : "WC"}) + assert(ldict["a"][0] == "HW") + assert(ldict["b"][0] == "WC") + assert(ldict["matmul"][0] == "HC") + # second pass will insert layout transform + _, ldict = correct_layout(g, {"a" : "HW16w", "b" : "WC16c"}) + assert(ldict["a"][0] == "HW16w") + assert(ldict["a_HW"][0] == "HW") + assert(ldict["b"][0] == "WC16c") + assert(ldict["b_WC"][0] == "WC") + assert(ldict["matmul"][0] == "HC") + a = sym.Variable("a", shape=(20, 10)) + c = sym.matmul(a, b, name="matmul", transpose_a=True) + g, ldict = correct_layout(c, {"a" : "HW", "b" : "HC"}) + assert(ldict["a"][0] == "HW") + assert(ldict["b"][0] == "HC") + assert(ldict["matmul"][0] == "WC") + b = sym.Variable("b", shape=(30, 20)) + c = sym.matmul(a, b, name="matmul", transpose_b=True) + g, ldict = correct_layout(c, {"a" : "HW", "b" : "CW"}) + assert(ldict["a"][0] == "HW") + assert(ldict["b"][0] == "CW") + assert(ldict["matmul"][0] == "HC") + a = sym.Variable("a", shape=(20, 10)) + b = sym.Variable("b", shape=(30, 20)) + c = sym.matmul(a, b, name="matmul", transpose_a=True, transpose_b=True) + g, ldict = correct_layout(c, {"a" : "HW", "b" : "CH"}) + assert(ldict["a"][0] == "HW") + assert(ldict["b"][0] == "CH") + assert(ldict["matmul"][0] == "WC") + + +def test_concatenate(): + x1 = sym.Variable("x", shape=(10, 20)) + x2 = sym.Variable("y", shape=(10, 30)) + z = sym.concatenate(x1, x2, name="concat") + g, ldict = correct_layout(z, {"x": "HW", "y": "HW"}) + assert(ldict["x"][0] == "HW") + assert(ldict["y"][0] == "HW") + assert(ldict["concat"][0] == "__undef__") + # second pass will insert layout transform + _, ldict = correct_layout(g, {"x": "HW16w", "y": "HW16w"}) + assert(ldict["x"][0] == "HW16w") + assert(ldict["y"][0] == "HW16w") + assert(ldict["x_HW"][0] == "HW") + assert(ldict["y_HW"][0] == "HW") + assert(ldict["concat"][0] == "__undef__") + + +def test_expand_dims(): + x = sym.Variable("x", shape=(10, 20)) + y = sym.expand_dims(x, axis=1, name="y") + g, ldict = correct_layout(y, "HW") + assert(ldict["x"][0] == "HW") + assert(ldict["y"][0] == "__undef__") + # second pass will insert layout transform + _, ldict = correct_layout(g, "HW16w") + assert(ldict["x"][0] == 
"HW16w") + assert(ldict["x_HW"][0] == "HW") + assert(ldict["y"][0] == "__undef__") + + +def test_split(): + x = sym.Variable("x", shape=(10, 20)) + y = sym.split(x, indices_or_sections=[11], name="y") + g, ldict = correct_layout(y, "HW") + assert(ldict["x"][0] == "HW") + assert(ldict["y"][0] == "__undef__") + # second pass will insert layout transform + _, ldict = correct_layout(g, "HW16w") + assert(ldict["x"][0] == "HW16w") + assert(ldict["x_HW"][0] == "HW") + assert(ldict["y"][0] == "__undef__") + + +def test_batchnorm(): + x = sym.Variable("data", shape=(10, 20, 30, 40)) + y = sym.batch_norm(x, axis=1, epsilon=2e-5, name="bn") + g, ldict = correct_layout(y, "NCHW") + assert(ldict["data"][0] == "NCHW") + assert(ldict["bn"][0] == "NCHW") + assert(ldict["bn"][1] == "C") + assert(ldict["bn"][2] == "C") + assert(ldict["bn_beta"][0] == "C") + assert(ldict["bn_gamma"][0] == "C") + assert(ldict["bn_moving_mean"][0] == "C") + assert(ldict["bn_moving_var"][0] == "C") + # batch_norm can deal with sub-dim of C at the last dim. + g, ldict = correct_layout(g, "NCHW16c") + assert(ldict["data"][0] == "NCHW16c") + assert(ldict["bn"][0] == "NCHW16c") + assert(ldict["bn"][1] == "C16c") + assert(ldict["bn"][2] == "C16c") + assert(ldict["bn_beta"][0] == "C") + assert(ldict["bn_beta_C16c"][0] == "C16c") + assert(ldict["bn_gamma"][0] == "C") + assert(ldict["bn_gamma_C16c"][0] == "C16c") + assert(ldict["bn_moving_mean"][0] == "C") + assert(ldict["bn_moving_mean_C16c"][0] == "C16c") + assert(ldict["bn_moving_var"][0] == "C") + assert(ldict["bn_moving_var_C16c"][0] == "C16c") + # but for other layout, it does a layout transform for data + g, ldict = correct_layout(g, "NCH16cW") + assert(ldict["data"][0] == "NCH16cW") + assert(ldict["data_NCHW16c"][0] == "NCHW16c") + assert(ldict["bn"][0] == "NCHW16c") + assert(ldict["bn"][1] == "C16c") + assert(ldict["bn"][2] == "C16c") + assert(ldict["bn_beta"][0] == "C") + assert(ldict["bn_beta_C16c"][0] == "C16c") + assert(ldict["bn_gamma"][0] == "C") + assert(ldict["bn_gamma_C16c"][0] == "C16c") + assert(ldict["bn_moving_mean"][0] == "C") + assert(ldict["bn_moving_mean_C16c"][0] == "C16c") + assert(ldict["bn_moving_var"][0] == "C") + assert(ldict["bn_moving_var_C16c"][0] == "C16c") + + +def test_flatten(): + x = sym.Variable("x", shape=(10, 20, 10, 10)) + y = sym.flatten(x, name="y") + g, ldict = correct_layout(y, "NCHW") + assert(ldict["x"][0] == "NCHW") + assert(ldict["y"][0] == "__undef__") + # second pass will insert layout transform + _, ldict = correct_layout(g, "NCHW16c") + assert(ldict["x"][0] == "NCHW16c") + assert(ldict["x_NCHW"][0] == "NCHW") + assert(ldict["y"][0] == "__undef__") + + +# Level 2 +def test_conv2d(): + x = sym.Variable("data", shape=(1, 32, 512, 512)) + y = sym.conv2d(x, name="conv", channels=12, + kernel_size=(3,3), padding=(1,1), layout="NCHW") + _, ldict = correct_layout(y) + assert(ldict["data"][0] == "NCHW") + assert(ldict["conv_weight"][0] == "OIHW") + assert(ldict["conv_bias"][0] == "C") + assert(ldict["conv"][0] == "NCHW") + y = sym.conv2d(x, name="conv", channels=12, + kernel_size=(3,3), padding=(1,1), layout="NCHW16c", + kernel_layout="OIHW16i16o", out_layout="NCHW8c") + _, ldict = correct_layout(y) + assert(ldict["data"][0] == "NCHW16c") + assert(ldict["conv_weight"][0] == "OIHW16i16o") + assert(ldict["conv_bias"][0] == "C8c") + assert(ldict["conv"][0] == "NCHW8c") + y = sym.conv2d(x, name="conv", channels=12, + kernel_size=(3,3), padding=(1,1), layout="N16cHWC") + _, ldict = correct_layout(y) + assert(ldict["data"][0] == "N16cHWC") + 
assert(ldict["conv_weight"][0] == "OIHW") + assert(ldict["conv_bias"][0] == "16cC") + assert(ldict["conv"][0] == "N16cHWC") + + +def test_conv2d_transpose(): + x = sym.Variable("data", shape=(1, 32, 512, 512)) + y = sym.conv2d_transpose(x, name="conv", channels=12, + kernel_size=(3,3), padding=(1,1), layout="NCHW") + _, ldict = correct_layout(y) + assert(ldict["data"][0] == "NCHW") + assert(ldict["conv_weight"][0] == "OIHW") + assert(ldict["conv_bias"][0] == "C") + assert(ldict["conv"][0] == "NCHW") + + +def test_max_pool2d(): + x = sym.Variable("data", shape=(1, 32, 512, 512)) + y = sym.max_pool2d(x, name="pool", pool_size=(3,3), + padding=(1,1), layout="NCHW") + g, ldict = correct_layout(y) + assert(ldict["data"][0] == "NCHW") + assert(ldict["pool"][0] == "NCHW") + # if index of H and W remain the same, + # pool2d does not convert the layout. + g, ldict = correct_layout(g, "NCHW16c") + assert(ldict["data"][0] == "NCHW16c") + assert(ldict["pool"][0] == "NCHW16c") + # for other layout it requires a layout transform. + g, ldict = correct_layout(g, "NHWC") + assert(ldict["data"][0] == "NHWC") + assert(ldict["data_NCHW"][0] == "NCHW") + assert(ldict["pool"][0] == "NCHW") + + +def test_global_pool2d(): + x = sym.Variable("data", shape=(1, 32, 512, 512)) + y = sym.global_max_pool2d(x, name="pool", layout="NCHW") + g, ldict = correct_layout(y) + assert(ldict["data"][0] == "NCHW") + assert(ldict["pool"][0] == "NCHW") + # if index of H and W remain the same, + # pool2d does not convert the layout. + g, ldict = correct_layout(g, "NCHW16c") + assert(ldict["data"][0] == "NCHW16c") + assert(ldict["pool"][0] == "NCHW16c") + # for other layout it requires a layout transform. + g, ldict = correct_layout(g, "NHWC") + assert(ldict["data"][0] == "NHWC") + assert(ldict["data_NCHW"][0] == "NCHW") + assert(ldict["pool"][0] == "NCHW") + + +# Level 3 +def test_reshape(): + x = sym.Variable("x", shape=(4,)) + y = sym.reshape(x, shape=(2,2), name="y") + g, ldict = correct_layout(y, "C") + assert(ldict["x"][0] == "C") + assert(ldict["y"][0] == "__undef__") + # second pass will insert layout transform + g, ldict = correct_layout(g, "C16c") + assert(ldict["x"][0] == "C16c") + assert(ldict["x_C"][0] == "C") + assert(ldict["y"][0] == "__undef__") + + +def test_transpose(): + x = sym.Variable("x", shape=(1, 32, 512, 512)) + y = sym.transpose(x, name="y", axes=(0, 2, 3, 1)) + g, ldict = correct_layout(y, "NCHW") + assert(ldict["x"][0] == "NCHW") + assert(ldict["y"][0] == "NHWC") + # second pass will insert layout transform + g, ldict = correct_layout(g, "NCHW16c") + assert(ldict["x"][0] == "NCHW16c") + assert(ldict["x_NCHW"][0] == "NCHW") + assert(ldict["y"][0] == "NHWC") + + +def test_broadcast_to(): + x = sym.Variable("x", shape=(4, 1)) + y = sym.broadcast_to(x, shape=(0, 4), name="y") + g, ldict = correct_layout(y, "HW") + assert(ldict["x"][0] == "HW") + assert(ldict["y"][0] == "__undef__") + # second pass will insert layout transform + g, ldict = correct_layout(g, "HW16h") + assert(ldict["x"][0] == "HW16h") + assert(ldict["x_HW"][0] == "HW") + assert(ldict["y"][0] == "__undef__") + + +def test_broadcast_binary(): + x = sym.Variable("x", shape=(1, 16, 512, 512)) + y = sym.Variable("y", shape=(16, 512, 512)) + z = sym.broadcast_add(x, y, name="z") + g, ldict = correct_layout(z, {"x": "NCHW", "y": "CHW"}) + assert(ldict["x"][0] == "NCHW") + assert(ldict["y"][0] == "CHW") + assert(ldict["z"][0] == "NCHW") + # prior to keep the left layout if they do not match. 
+ g, ldict = correct_layout(g, {"x": "NCHW16c", "y": "CHW"}) + assert(ldict["x"][0] == "NCHW16c") + assert(ldict["y"][0] == "CHW") + assert(ldict["y_CHW16c"][0] == "CHW16c") + assert(ldict["z"][0] == "NCHW16c") + # broadcast_add(HCW16c, N16nCH16cW) + g, ldict = correct_layout(z, {"x": "HCW16c", "y": "N16nCH16cW"}) + assert(ldict["x"][0] == "HCW16c") + assert(ldict["y"][0] == "N16nCH16cW") + assert(ldict["x_CH16cW"][0] == "CH16cW") + assert(ldict["z"][0] == "N16nCH16cW") + + +def test_reduce(): + x = sym.Variable("x", shape=(1, 16, 512, 512)) + y = sym.sum(x, name="y", axis=1) + g, ldict = correct_layout(y, "NCHW") + assert(ldict["x"][0] == "NCHW") + assert(ldict["y"][0] == "__undef__") + # second pass will insert layout transform + g, ldict = correct_layout(g, "NCHW16c") + assert(ldict["x"][0] == "NCHW16c") + assert(ldict["x_NCHW"][0] == "NCHW") + assert(ldict["y"][0] == "__undef__") + + +if __name__ == "__main__": + test_dense() + test_matmul() + test_concatenate() + test_expand_dims() + test_split() + test_batchnorm() + test_flatten() + test_conv2d() + test_conv2d_transpose() + test_max_pool2d() + test_global_pool2d() + test_reshape() + test_transpose() + test_broadcast_to() + test_broadcast_binary() + test_reduce() \ No newline at end of file
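Many of the assertions above revolve around blocked layouts such as "NCHW16c", where the C dimension is split into an outer C axis and an inner axis of length 16. For readers less familiar with the notation, the data movement performed by the layout-transform nodes that CorrectLayout inserts (the data_NCHW-style nodes checked above) corresponds to the following numpy sketch; the shapes match the tests above, but the code itself is illustration only and not part of this patch:

import numpy as np

# NCHW -> NCHW16c: split C (here 32) into 32 // 16 outer blocks and a 16-wide inner axis
x_nchw = np.zeros((1, 32, 512, 512), dtype="float32")
x_nchw16c = x_nchw.reshape(1, 32 // 16, 16, 512, 512).transpose(0, 1, 3, 4, 2)
assert x_nchw16c.shape == (1, 2, 512, 512, 16)

# and back: NCHW16c -> NCHW
x_back = x_nchw16c.transpose(0, 1, 4, 2, 3).reshape(1, 32, 512, 512)
assert x_back.shape == (1, 32, 512, 512)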