Skip to content

Commit

Permalink
changes to GetKernelTypeForVar
Browse files Browse the repository at this point in the history
  • Loading branch information
jakpiase committed Jul 5, 2022
1 parent 24eff5e commit 4691061
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 16 deletions.
14 changes: 6 additions & 8 deletions paddle/fluid/operators/pad2d_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -722,14 +722,12 @@ class Pad2dOp : public framework::OperatorWithKernel {
const framework::OpKernelType& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
(tensor.layout() != framework::DataLayout::kMKLDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_format");
return framework::OpKernelType(
expected_kernel_type.data_type_,
tensor.place(),
framework::StringToDataLayout(data_format));
(tensor.layout() != framework::DataLayout::kMKLDNN) &&
paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout() == framework::DataLayout::kNHWC) {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(),
framework::DataLayout::kNHWC);
}
#endif
return framework::OpKernelType(
Expand Down
14 changes: 6 additions & 8 deletions paddle/fluid/operators/pad3d_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,12 @@ class Pad3dOp : public framework::OperatorWithKernel {
const framework::OpKernelType& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
(tensor.layout() != framework::DataLayout::kMKLDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_format");
return framework::OpKernelType(
expected_kernel_type.data_type_,
tensor.place(),
framework::StringToDataLayout(data_format));
(tensor.layout() != framework::DataLayout::kMKLDNN) &&
paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout() == framework::DataLayout::kNHWC) {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(),
framework::DataLayout::kNHWC);
}
#endif
return framework::OpKernelType(
Expand Down

0 comments on commit 4691061

Please sign in to comment.