diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index c0d813ccc215e..4e435660ff6dc 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -655,6 +655,24 @@ class MatMulOp : public framework::OperatorWithKernel { platform::errors::InvalidArgument("reshape_out supported rank is 3, " "received %d", reshape_out_size)); + + auto it = std::find(reshape_out.begin(), reshape_out.end(), -1); + + // if "-1" is present then one of reshape dims must be infered + if (it != reshape_out.end()) { + int index = std::distance(reshape_out.begin(), it); + + auto ddim_out_vec = framework::vectorize(ddim_out); + + int ddim_out_product = + std::accumulate(ddim_out_vec.begin(), ddim_out_vec.end(), 1, + std::multiplies()); + int reshape_out_product = std::accumulate( + reshape_out.begin(), reshape_out.end(), -1, std::multiplies()); + + reshape_out[index] = ddim_out_product / reshape_out_product; + } + framework::DDim shape_out = ddim_out.transpose(transpose_out).reshape(reshape_out); context->SetOutputDim("Out", shape_out); diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index 01c2d95a0782b..f14f92cb51fdb 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -358,6 +358,8 @@ inline MKLDNNMemoryFormat MKLDNNFormatForSize(size_t dims_size, } else if (data_format == MKLDNNMemoryFormat::nhwc) { return MKLDNNMemoryFormat::ndhwc; } + } else if (dims_size == 6) { + return MKLDNNMemoryFormat::abcdef; } return data_format; }