From 85e5165fae23ca7eb869cb3b21bd6de42360c092 Mon Sep 17 00:00:00 2001
From: Jakub Piasecki
Date: Sat, 22 May 2021 14:42:39 +0200
Subject: [PATCH 1/9] added op cast functionality for fp32/bf16

---
 paddle/fluid/operators/cast_op.cc             |  25 ++++
 .../fluid/operators/mkldnn/cast_mkldnn_op.cc  |  73 ++++++++++
 paddle/fluid/platform/mkldnn_reuse.h          |  29 +++-
 .../unittests/mkldnn/test_cast_mkldnn_op.py   | 134 ++++++++++++++++++
 .../paddle/fluid/tests/unittests/op_test.py   |  18 ++-
 5 files changed, 270 insertions(+), 9 deletions(-)
 create mode 100644 paddle/fluid/operators/mkldnn/cast_mkldnn_op.cc
 create mode 100644 python/paddle/fluid/tests/unittests/mkldnn/test_cast_mkldnn_op.py

diff --git a/paddle/fluid/operators/cast_op.cc b/paddle/fluid/operators/cast_op.cc
index 7252ed72b2083..4119561a2c8aa 100644
--- a/paddle/fluid/operators/cast_op.cc
+++ b/paddle/fluid/operators/cast_op.cc
@@ -27,6 +27,9 @@ class CastOpProtoMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("Out", "The output tensor of cast op");
     AddAttr<int>("out_dtype", "output data type");
     AddAttr<int>("in_dtype", "input data type");
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
     AddComment(R"DOC(
 Cast Operator.
 
@@ -50,6 +53,7 @@ class CastOpGradMaker : public framework::SingleGradOpMaker<T> {
     grad->SetOutput("Out", this->InputGrad("X"));
     grad->SetAttr("out_dtype", this->GetAttr("in_dtype"));
     grad->SetAttr("in_dtype", this->GetAttr("out_dtype"));
+    grad->SetAttr("use_mkldnn", this->GetAttr("use_mkldnn"));
   }
 };
 
@@ -77,6 +81,27 @@ class CastOp : public framework::OperatorWithKernel {
     if (platform::is_cuda_pinned_place(tensor_place)) {
       return framework::OpKernelType(tensor->type(), ctx.device_context());
     }
+
+#ifdef PADDLE_WITH_MKLDNN
+    int in_dtype = ctx.Attr<int>("in_dtype");
+    int out_dtype = ctx.Attr<int>("out_dtype");
+
+    auto MKLDNNSupportsCast = [&]() -> bool {
+      int dtype_fp32 = (int)framework::proto::VarType::FP32;
+      int dtype_bf16 = (int)framework::proto::VarType::BF16;
+
+      if (in_dtype != dtype_fp32 && in_dtype != dtype_bf16) return false;
+      if (out_dtype != dtype_fp32 && out_dtype != dtype_bf16) return false;
+
+      return true;
+    };
+
+    if (this->CanMKLDNNBeUsed(ctx, tensor->type()) && MKLDNNSupportsCast()) {
+      return framework::OpKernelType(tensor->type(), ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
+    }
+#endif
     return framework::OpKernelType(tensor->type(), tensor_place);
   }
 };
diff --git a/paddle/fluid/operators/mkldnn/cast_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/cast_mkldnn_op.cc
new file mode 100644
index 0000000000000..7d94a9c700743
--- /dev/null
+++ b/paddle/fluid/operators/mkldnn/cast_mkldnn_op.cc
@@ -0,0 +1,73 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/platform/mkldnn_reuse.h"
+
+namespace paddle {
+namespace operators {
+
+using paddle::framework::Tensor;
+
+template <typename T>
+class CastMKLDNNKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    this->RunKernel(ctx);
+  }
+
+  void RunKernel(const framework::ExecutionContext& ctx) const {
+    const auto& dev_ctx =
+        ctx.template device_context<platform::MKLDNNDeviceContext>();
+
+    auto* x = ctx.Input<Tensor>("X");
+    auto* out = ctx.Output<Tensor>("Out");
+
+    int in_dtype = ctx.Attr<int>("in_dtype");
+    int out_dtype = ctx.Attr<int>("out_dtype");
+
+    auto x_paddle_type = framework::proto::VarType::Type(in_dtype);
+    auto out_paddle_type = framework::proto::VarType::Type(out_dtype);
+
+    mkldnn::memory::data_type x_type =
+        framework::ToMKLDNNDataType(x_paddle_type);
+    mkldnn::memory::data_type out_type =
+        framework::ToMKLDNNDataType(out_paddle_type);
+
+    auto x_tz = framework::vectorize(x->dims());
+
+    std::string key =
+        platform::CreateKey(dev_ctx, x_tz, x->format(), x->format(), x_type);
+    platform::ReorderMKLDNNHandler reorder_handler(
+        x_tz, x_paddle_type, x_type, out_paddle_type, out_type, dev_ctx,
+        dev_ctx.GetEngine(), key);
+
+    auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
+        x->format(), platform::to_void_cast(x->data<T>()));
+    auto reorder_dst_memory_p =
+        reorder_handler.AcquireDstMemory(out, x->format(), dev_ctx.GetPlace());
+    auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
+                                                    reorder_src_memory_p);
+
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
+    reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
+    astream.wait();
+
+    out->set_layout(framework::DataLayout::kMKLDNN);
+    out->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_KERNEL(cast, MKLDNN, paddle::platform::CPUPlace,
+                   ops::CastMKLDNNKernel<float>,
+                   ops::CastMKLDNNKernel<paddle::platform::bfloat16>);
\ No newline at end of file
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index 5ff6f893a8953..d6563be48fe48 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -926,7 +926,23 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
       : platform::MKLDNNHandler(dev_ctx, engine, base_key),
         dims_(dims),
         vtype_(vtype),
-        dtype_(dtype) {}
+        vtype_dst_(vtype),
+        dtype_(dtype),
+        dtype_dst_(dtype) {}
+
+  ReorderMKLDNNHandler(std::vector<int64_t>& dims,  // NOLINT
+                       framework::proto::VarType::Type vtype,
+                       mkldnn::memory::data_type dtype,
+                       framework::proto::VarType::Type vtype_dst,
+                       mkldnn::memory::data_type dtype_dst,
+                       const platform::MKLDNNDeviceContext& dev_ctx,
+                       mkldnn::engine engine, const std::string& base_key)
+      : platform::MKLDNNHandler(dev_ctx, engine, base_key),
+        dims_(dims),
+        vtype_(vtype),
+        vtype_dst_(vtype_dst),
+        dtype_(dtype),
+        dtype_dst_(dtype_dst) {}
 
   std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
       const MKLDNNMemoryFormat& fmt, void* ptr) {
@@ -940,15 +956,16 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
     auto mem_p =
         std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
     if (mem_p == nullptr) {
-      auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_, fmt);
-      auto dst_data = output->mutable_data(place, vtype_, dst_md.get_size());
+      auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_dst_, fmt);
+      auto dst_data =
+          output->mutable_data(place, vtype_dst_, dst_md.get_size());
+
       mem_p = std::make_shared<mkldnn::memory>(dst_md, engine_, dst_data);
       dev_ctx_.SetBlob(local_key,
mem_p); } else { // Even if memory object exists , we may be using it for diffrent tensor auto dst_data = - output->mutable_data(place, vtype_, mem_p->get_desc().get_size()); + output->mutable_data(place, vtype_dst_, mem_p->get_desc().get_size()); mem_p->set_data_handle(dst_data); } return mem_p; @@ -970,8 +987,8 @@ class ReorderMKLDNNHandler : public MKLDNNHandler { private: std::vector dims_; - framework::proto::VarType::Type vtype_; - mkldnn::memory::data_type dtype_; + framework::proto::VarType::Type vtype_, vtype_dst_; + mkldnn::memory::data_type dtype_, dtype_dst_; }; template diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_cast_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_cast_mkldnn_op.py new file mode 100644 index 0000000000000..955c9e6e813ab --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_cast_mkldnn_op.py @@ -0,0 +1,134 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16, convert_uint16_to_float +import unittest +import numpy as np + +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard + + +@unittest.skipIf(not core.supports_bfloat16(), + "place does not support BF16 evaluation") +class TestCastBF16ToFP32MKLDNNOp(OpTest): + def setUp(self): + self.x_fp32 = np.random.random(size=[10, 10]).astype("float32") + self.x_bf16 = convert_float_to_uint16(self.x_fp32) + + self.inputs = {'X': self.x_bf16} + self.outputs = {'Out': self.x_fp32} + self.attrs = { + 'in_dtype': int(core.VarDesc.VarType.BF16), + 'out_dtype': int(core.VarDesc.VarType.FP32), + 'use_mkldnn': True + } + self.op_type = 'cast' + + def test_check_output(self): + self.check_output(check_dygraph=False) + + def test_check_grad(self): + self.check_grad_with_place( + core.CPUPlace(), ["X"], + "Out", + check_dygraph=False, + user_defined_grads=[self.x_bf16], + user_defined_grad_outputs=[self.x_fp32]) + + +class TestCastFP32ToBF16MKLDNNOp(OpTest): + def setUp(self): + self.x_fp32 = np.random.random(size=[2, 6]).astype("float32") + self.x_bf16 = convert_float_to_uint16(self.x_fp32) + + self.inputs = {'X': self.x_fp32} + self.outputs = {'Out': self.x_bf16} + self.attrs = { + 'in_dtype': int(core.VarDesc.VarType.FP32), + 'out_dtype': int(core.VarDesc.VarType.BF16), + 'use_mkldnn': True + } + self.op_type = 'cast' + + def test_check_output(self): + self.check_output(check_dygraph=False) + + def test_check_grad(self): + self.check_grad_with_place( + core.CPUPlace(), ["X"], + "Out", + check_dygraph=False, + user_defined_grads=[self.x_fp32], + user_defined_grad_outputs=[self.x_bf16]) + + +class TestCastBF16ToBF16MKLDNNOp(OpTest): + def setUp(self): + self.x_fp32 = np.random.random(size=[6, 13]).astype("float32") + self.x_bf16 = convert_float_to_uint16(self.x_fp32) + + self.inputs = {'X': self.x_bf16} + self.outputs = 
{'Out': self.x_bf16} + self.attrs = { + 'in_dtype': int(core.VarDesc.VarType.BF16), + 'out_dtype': int(core.VarDesc.VarType.BF16), + 'use_mkldnn': True + } + self.op_type = 'cast' + + def test_check_output(self): + self.check_output(check_dygraph=False) + + def test_check_grad(self): + self.check_grad_with_place( + core.CPUPlace(), ["X"], + "Out", + check_dygraph=False, + user_defined_grads=[self.x_bf16], + user_defined_grad_outputs=[self.x_bf16]) + + +class TestCastFP32ToFP32MKLDNNOp(OpTest): + def setUp(self): + self.x_fp32 = np.random.random(size=[7, 15]).astype("float32") + + self.inputs = {'X': self.x_fp32} + self.outputs = {'Out': self.x_fp32} + self.attrs = { + 'in_dtype': int(core.VarDesc.VarType.FP32), + 'out_dtype': int(core.VarDesc.VarType.FP32), + 'use_mkldnn': True + } + self.op_type = 'cast' + + def test_check_output(self): + self.check_output(check_dygraph=False) + + def test_check_grad(self): + self.check_grad_with_place( + core.CPUPlace(), ["X"], + "Out", + check_dygraph=False, + user_defined_grads=[self.x_fp32], + user_defined_grad_outputs=[self.x_fp32]) + + +if __name__ == '__main__': + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 3524d1e553d1b..8a1046a73510f 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -1193,6 +1193,10 @@ def find_actual(target_name, fetch_list): actual_t = convert_uint16_to_float(actual_t) atol = 0.03 + if expect_t.dtype == np.uint16 and actual_t.dtype == np.uint16: + expect_t = convert_uint16_to_float(expect_t) + actual_t = convert_uint16_to_float(actual_t) + atol = 0.03 # NOTE(zhiqiu): np.allclose([], [1.]) returns True # see details: https://stackoverflow.com/questions/38331703/why-does-numpys-broadcasting-sometimes-allow-comparing-arrays-of-different-leng if expect_t.size == 0: @@ -1501,13 +1505,21 @@ def check_grad_with_place(self, # comparison of bf16 results will happen as fp32 # loop over list of grads and convert bf16 to fp32 - fp32_grads = [] + fp32_analytic_grads = [] for grad in analytic_grads: if grad.dtype == np.uint16: grad = convert_uint16_to_float(grad) max_relative_error = 0.03 - fp32_grads.append(grad) - analytic_grads = fp32_grads + fp32_analytic_grads.append(grad) + analytic_grads = fp32_analytic_grads + + fp32_numeric_grads = [] + for grad in numeric_grads: + if grad.dtype == np.uint16: + grad = convert_uint16_to_float(grad) + max_relative_error = 0.03 + fp32_numeric_grads.append(grad) + numeric_grads = fp32_numeric_grads self._assert_is_close(numeric_grads, analytic_grads, inputs_to_check, max_relative_error, From d9403c0d899baf4587e5f64ce8a461f5ead99ed3 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Sat, 22 May 2021 14:44:58 +0200 Subject: [PATCH 2/9] added newline --- paddle/fluid/operators/mkldnn/cast_mkldnn_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/mkldnn/cast_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/cast_mkldnn_op.cc index 7d94a9c700743..9cfeace6bef99 100644 --- a/paddle/fluid/operators/mkldnn/cast_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/cast_mkldnn_op.cc @@ -70,4 +70,4 @@ class CastMKLDNNKernel : public framework::OpKernel { namespace ops = paddle::operators; REGISTER_OP_KERNEL(cast, MKLDNN, paddle::platform::CPUPlace, ops::CastMKLDNNKernel, - ops::CastMKLDNNKernel); \ No newline at end of file + ops::CastMKLDNNKernel); From 8753181ce2269a7814b345f81e48ba7c74994206 Mon Sep 17 
00:00:00 2001 From: Jakub Piasecki Date: Sat, 22 May 2021 15:24:06 +0200 Subject: [PATCH 3/9] added entries in static mode white list and unity build --- paddle/fluid/operators/unity_build_rule.cmake | 1 + tools/static_mode_white_list.py | 1 + 2 files changed, 2 insertions(+) diff --git a/paddle/fluid/operators/unity_build_rule.cmake b/paddle/fluid/operators/unity_build_rule.cmake index cd8b31d72e72a..4cca6f7fa1285 100644 --- a/paddle/fluid/operators/unity_build_rule.cmake +++ b/paddle/fluid/operators/unity_build_rule.cmake @@ -30,6 +30,7 @@ register_unity_group(cc bmm_op.cc bpr_loss_op.cc cast_op.cc + mkldnn/cast_mkldnn_op.cc cholesky_op.cc chunk_eval_op.cc clip_by_norm_op.cc diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index 15bcae826064d..00fd0a96ca6a6 100644 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -587,6 +587,7 @@ 'test_matmul_op_with_head', 'test_var_conv_2d', 'test_batch_norm_mkldnn_op', + 'test_cast_mkldnn_op', 'test_concat_int8_mkldnn_op', 'test_concat_bf16_mkldnn_op', 'test_concat_mkldnn_op', From ce30ebb0dd574f612ea932937d62de13853c6f9d Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Mon, 24 May 2021 12:57:07 +0200 Subject: [PATCH 4/9] fixed failing tests --- python/paddle/fluid/tests/unittests/op_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 8a1046a73510f..654723d862990 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -1191,12 +1191,12 @@ def find_actual(target_name, fetch_list): np.float32, np.float64 ]: actual_t = convert_uint16_to_float(actual_t) - atol = 0.03 + atol = max(atol, 0.03) if expect_t.dtype == np.uint16 and actual_t.dtype == np.uint16: expect_t = convert_uint16_to_float(expect_t) actual_t = convert_uint16_to_float(actual_t) - atol = 0.03 + atol = max(atol, 0.03) # NOTE(zhiqiu): np.allclose([], [1.]) returns True # see details: https://stackoverflow.com/questions/38331703/why-does-numpys-broadcasting-sometimes-allow-comparing-arrays-of-different-leng if expect_t.size == 0: From a37ca869dd62fd6cb812f92dd9cf4b9f9ba3e82b Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Tue, 25 May 2021 11:30:44 +0200 Subject: [PATCH 5/9] changes after review --- paddle/fluid/operators/cast_op.cc | 3 +- .../unittests/mkldnn/test_cast_mkldnn_op.py | 45 +++---------------- 2 files changed, 7 insertions(+), 41 deletions(-) diff --git a/paddle/fluid/operators/cast_op.cc b/paddle/fluid/operators/cast_op.cc index 4119561a2c8aa..4bf3f4e577902 100644 --- a/paddle/fluid/operators/cast_op.cc +++ b/paddle/fluid/operators/cast_op.cc @@ -90,8 +90,7 @@ class CastOp : public framework::OperatorWithKernel { int dtype_fp32 = (int)framework::proto::VarType::FP32; int dtype_bf16 = (int)framework::proto::VarType::BF16; - if (in_dtype != dtype_fp32 && in_dtype != dtype_bf16) return false; - if (out_dtype != dtype_fp32 && out_dtype != dtype_bf16) return false; + if ((in_dtype != dtype_fp32 && in_dtype != dtype_bf16) or (out_dtype != dtype_fp32 && out_dtype != dtype_bf16)) return false; return true; }; diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_cast_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_cast_mkldnn_op.py index 955c9e6e813ab..27c9852b8d02d 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_cast_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_cast_mkldnn_op.py @@ 
-1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -48,11 +48,11 @@ def test_check_grad(self): core.CPUPlace(), ["X"], "Out", check_dygraph=False, - user_defined_grads=[self.x_bf16], - user_defined_grad_outputs=[self.x_fp32]) + user_defined_grads=[self.inputs['X']], + user_defined_grad_outputs=[self.outputs['Out']]) -class TestCastFP32ToBF16MKLDNNOp(OpTest): +class TestCastFP32ToBF16MKLDNNOp(TestCastBF16ToFP32MKLDNNOp): def setUp(self): self.x_fp32 = np.random.random(size=[2, 6]).astype("float32") self.x_bf16 = convert_float_to_uint16(self.x_fp32) @@ -66,19 +66,8 @@ def setUp(self): } self.op_type = 'cast' - def test_check_output(self): - self.check_output(check_dygraph=False) - - def test_check_grad(self): - self.check_grad_with_place( - core.CPUPlace(), ["X"], - "Out", - check_dygraph=False, - user_defined_grads=[self.x_fp32], - user_defined_grad_outputs=[self.x_bf16]) - -class TestCastBF16ToBF16MKLDNNOp(OpTest): +class TestCastBF16ToBF16MKLDNNOp(TestCastBF16ToFP32MKLDNNOp): def setUp(self): self.x_fp32 = np.random.random(size=[6, 13]).astype("float32") self.x_bf16 = convert_float_to_uint16(self.x_fp32) @@ -92,19 +81,8 @@ def setUp(self): } self.op_type = 'cast' - def test_check_output(self): - self.check_output(check_dygraph=False) - def test_check_grad(self): - self.check_grad_with_place( - core.CPUPlace(), ["X"], - "Out", - check_dygraph=False, - user_defined_grads=[self.x_bf16], - user_defined_grad_outputs=[self.x_bf16]) - - -class TestCastFP32ToFP32MKLDNNOp(OpTest): +class TestCastFP32ToFP32MKLDNNOp(TestCastBF16ToFP32MKLDNNOp): def setUp(self): self.x_fp32 = np.random.random(size=[7, 15]).astype("float32") @@ -117,17 +95,6 @@ def setUp(self): } self.op_type = 'cast' - def test_check_output(self): - self.check_output(check_dygraph=False) - - def test_check_grad(self): - self.check_grad_with_place( - core.CPUPlace(), ["X"], - "Out", - check_dygraph=False, - user_defined_grads=[self.x_fp32], - user_defined_grad_outputs=[self.x_fp32]) - if __name__ == '__main__': paddle.enable_static() From be9e734ff10604ba4eee2e5cb8f60896ca69c6d9 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Tue, 25 May 2021 11:32:46 +0200 Subject: [PATCH 6/9] added formatting --- paddle/fluid/operators/cast_op.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/cast_op.cc b/paddle/fluid/operators/cast_op.cc index 4bf3f4e577902..f976fd62fd82e 100644 --- a/paddle/fluid/operators/cast_op.cc +++ b/paddle/fluid/operators/cast_op.cc @@ -90,7 +90,9 @@ class CastOp : public framework::OperatorWithKernel { int dtype_fp32 = (int)framework::proto::VarType::FP32; int dtype_bf16 = (int)framework::proto::VarType::BF16; - if ((in_dtype != dtype_fp32 && in_dtype != dtype_bf16) or (out_dtype != dtype_fp32 && out_dtype != dtype_bf16)) return false; + if ((in_dtype != dtype_fp32 && in_dtype != dtype_bf16) or + (out_dtype != dtype_fp32 && out_dtype != dtype_bf16)) + return false; return true; }; From ada7c313f6bab66abb546d01949c2c14c00fdbfe Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Tue, 25 May 2021 12:42:58 +0200 Subject: [PATCH 7/9] upgraded tests file as reviewer suggested --- .../unittests/mkldnn/test_cast_mkldnn_op.py | 61 ++++++------------- 1 file changed, 19 insertions(+), 42 deletions(-) diff --git 
a/python/paddle/fluid/tests/unittests/mkldnn/test_cast_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_cast_mkldnn_op.py index 27c9852b8d02d..8986f1f84e36e 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_cast_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_cast_mkldnn_op.py @@ -27,15 +27,18 @@ @unittest.skipIf(not core.supports_bfloat16(), "place does not support BF16 evaluation") class TestCastBF16ToFP32MKLDNNOp(OpTest): - def setUp(self): - self.x_fp32 = np.random.random(size=[10, 10]).astype("float32") - self.x_bf16 = convert_float_to_uint16(self.x_fp32) + def init_data(self): + self.out = np.random.random(size=[10, 10]).astype("float32") + self.x = convert_float_to_uint16(self.out) - self.inputs = {'X': self.x_bf16} - self.outputs = {'Out': self.x_fp32} + def setUp(self): + self.init_data() + self.inputs = {'X': self.x} + self.outputs = {'Out': self.out} + prepare_dtype = lambda x: int(core.VarDesc.VarType.BF16 if x.dtype != np.float32 else core.VarDesc.VarType.FP32) self.attrs = { - 'in_dtype': int(core.VarDesc.VarType.BF16), - 'out_dtype': int(core.VarDesc.VarType.FP32), + 'in_dtype': prepare_dtype(self.x), + 'out_dtype': prepare_dtype(self.out), 'use_mkldnn': True } self.op_type = 'cast' @@ -53,47 +56,21 @@ def test_check_grad(self): class TestCastFP32ToBF16MKLDNNOp(TestCastBF16ToFP32MKLDNNOp): - def setUp(self): - self.x_fp32 = np.random.random(size=[2, 6]).astype("float32") - self.x_bf16 = convert_float_to_uint16(self.x_fp32) - - self.inputs = {'X': self.x_fp32} - self.outputs = {'Out': self.x_bf16} - self.attrs = { - 'in_dtype': int(core.VarDesc.VarType.FP32), - 'out_dtype': int(core.VarDesc.VarType.BF16), - 'use_mkldnn': True - } - self.op_type = 'cast' + def init_data(self): + self.x = np.random.random(size=[2, 6]).astype("float32") + self.out = convert_float_to_uint16(self.x) class TestCastBF16ToBF16MKLDNNOp(TestCastBF16ToFP32MKLDNNOp): - def setUp(self): - self.x_fp32 = np.random.random(size=[6, 13]).astype("float32") - self.x_bf16 = convert_float_to_uint16(self.x_fp32) - - self.inputs = {'X': self.x_bf16} - self.outputs = {'Out': self.x_bf16} - self.attrs = { - 'in_dtype': int(core.VarDesc.VarType.BF16), - 'out_dtype': int(core.VarDesc.VarType.BF16), - 'use_mkldnn': True - } - self.op_type = 'cast' + def init_data(self): + self.x = np.random.random(size=[6, 13]).astype("uint16") + self.out = self.x class TestCastFP32ToFP32MKLDNNOp(TestCastBF16ToFP32MKLDNNOp): - def setUp(self): - self.x_fp32 = np.random.random(size=[7, 15]).astype("float32") - - self.inputs = {'X': self.x_fp32} - self.outputs = {'Out': self.x_fp32} - self.attrs = { - 'in_dtype': int(core.VarDesc.VarType.FP32), - 'out_dtype': int(core.VarDesc.VarType.FP32), - 'use_mkldnn': True - } - self.op_type = 'cast' + def init_data(self): + self.x = np.random.random(size=[7, 15]).astype("float32") + self.out = self.x if __name__ == '__main__': From 5af7681f650b6a6dc8627e15165f2877b63aef2f Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Tue, 25 May 2021 14:02:38 +0200 Subject: [PATCH 8/9] changes after review --- paddle/fluid/operators/cast_op.cc | 4 ++-- .../fluid/tests/unittests/mkldnn/test_cast_mkldnn_op.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/cast_op.cc b/paddle/fluid/operators/cast_op.cc index f976fd62fd82e..171478e6ae867 100644 --- a/paddle/fluid/operators/cast_op.cc +++ b/paddle/fluid/operators/cast_op.cc @@ -87,8 +87,8 @@ class CastOp : public framework::OperatorWithKernel { int out_dtype = 
ctx.Attr("out_dtype"); auto MKLDNNSupportsCast = [&]() -> bool { - int dtype_fp32 = (int)framework::proto::VarType::FP32; - int dtype_bf16 = (int)framework::proto::VarType::BF16; + int dtype_fp32 = static_cast(framework::proto::VarType::FP32); + int dtype_bf16 = static_cast(framework::proto::VarType::BF16); if ((in_dtype != dtype_fp32 && in_dtype != dtype_bf16) or (out_dtype != dtype_fp32 && out_dtype != dtype_bf16)) diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_cast_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_cast_mkldnn_op.py index 8986f1f84e36e..95de37fdc0251 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_cast_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_cast_mkldnn_op.py @@ -14,7 +14,6 @@ from __future__ import print_function -from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16, convert_uint16_to_float import unittest import numpy as np @@ -22,6 +21,7 @@ import paddle.fluid.core as core import paddle.fluid as fluid from paddle.fluid import compiler, Program, program_guard +from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16 @unittest.skipIf(not core.supports_bfloat16(), From 7bcee8e453c395a32ed30aa07baa6cdef138df37 Mon Sep 17 00:00:00 2001 From: jakpiase <62569058+jakpiase@users.noreply.github.com> Date: Tue, 25 May 2021 15:56:55 +0200 Subject: [PATCH 9/9] minor change --- paddle/fluid/operators/cast_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/cast_op.cc b/paddle/fluid/operators/cast_op.cc index 171478e6ae867..952e9ca329f10 100644 --- a/paddle/fluid/operators/cast_op.cc +++ b/paddle/fluid/operators/cast_op.cc @@ -90,7 +90,7 @@ class CastOp : public framework::OperatorWithKernel { int dtype_fp32 = static_cast(framework::proto::VarType::FP32); int dtype_bf16 = static_cast(framework::proto::VarType::BF16); - if ((in_dtype != dtype_fp32 && in_dtype != dtype_bf16) or + if ((in_dtype != dtype_fp32 && in_dtype != dtype_bf16) || (out_dtype != dtype_fp32 && out_dtype != dtype_bf16)) return false;