Skip to content

Commit

Permalink
added fix for matmul and support for 6 rank tensor (#35740)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakpiase committed Sep 15, 2021
1 parent bd79ae0 commit e80acff
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
18 changes: 18 additions & 0 deletions paddle/fluid/operators/matmul_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,24 @@ class MatMulOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument("reshape_out supported rank is 3, "
"received %d",
reshape_out_size));

auto it = std::find(reshape_out.begin(), reshape_out.end(), -1);

// if "-1" is present then one of reshape dims must be infered
if (it != reshape_out.end()) {
int index = std::distance(reshape_out.begin(), it);

auto ddim_out_vec = framework::vectorize(ddim_out);

int ddim_out_product =
std::accumulate(ddim_out_vec.begin(), ddim_out_vec.end(), 1,
std::multiplies<int>());
int reshape_out_product = std::accumulate(
reshape_out.begin(), reshape_out.end(), -1, std::multiplies<int>());

reshape_out[index] = ddim_out_product / reshape_out_product;
}

framework::DDim shape_out =
ddim_out.transpose(transpose_out).reshape(reshape_out);
context->SetOutputDim("Out", shape_out);
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/platform/mkldnn_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,8 @@ inline MKLDNNMemoryFormat MKLDNNFormatForSize(size_t dims_size,
} else if (data_format == MKLDNNMemoryFormat::nhwc) {
return MKLDNNMemoryFormat::ndhwc;
}
} else if (dims_size == 6) {
return MKLDNNMemoryFormat::abcdef;
}
return data_format;
}
Expand Down

0 comments on commit e80acff

Please sign in to comment.