Skip to content

Commit

Permalink
[NHWC] InferShape Layout conversion fix. (dmlc#372)
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 authored and tqchen committed May 29, 2018
1 parent 3c3387d commit 833eb95
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
4 changes: 2 additions & 2 deletions nnvm/src/top/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
param.kernel_size[0],
param.kernel_size[1]});

wshape = ConvertLayout(wshape, kNCHW, param.layout);
wshape = ConvertLayout(wshape, kNCHW, param.layout, true);
wshape[0] *= param.groups;

NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape);
Expand Down Expand Up @@ -189,7 +189,7 @@ inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs,
param.channels / param.groups,
param.kernel_size[0],
param.kernel_size[1]});
wshape = ConvertLayout(wshape, kNCHW, param.layout);
wshape = ConvertLayout(wshape, kNCHW, param.layout, true);
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DTransposeParam::kWeight, wshape);

if (param.use_bias) {
Expand Down
28 changes: 21 additions & 7 deletions nnvm/src/top/nn/nn_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ inline std::vector<std::string> UseBiasListInputNames(const NodeAttrs& attrs) {
* \param dst_layout target layout
* \return shape in target layout
*/
inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout) {
inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout, bool is_weight = false) {
if (src_layout == dst_layout) return src;
TShape dst = src;
if (src.ndim() == 3) {
Expand Down Expand Up @@ -68,9 +68,16 @@ inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout) {
switch (src_layout) {
case kNCHW: break;
case kNHWC: {
dst[2] = src[1];
dst[3] = src[2];
dst[1] = src[3];
if (is_weight) {
dst[2] = src[0];
dst[3] = src[1];
dst[1] = src[2];
dst[0] = src[3];
} else {
dst[2] = src[1];
dst[3] = src[2];
dst[1] = src[3];
}
break;
}
default: {
Expand All @@ -81,9 +88,16 @@ inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout) {
switch (dst_layout) {
case kNCHW: break;
case kNHWC: {
dst[1] = src[2];
dst[2] = src[3];
dst[3] = src[1];
if (is_weight) {
dst[0] = src[2];
dst[1] = src[3];
dst[2] = src[1];
dst[3] = src[0];
} else {
dst[1] = src[2];
dst[2] = src[3];
dst[3] = src[1];
}
break;
}
default: {
Expand Down

0 comments on commit 833eb95

Please sign in to comment.