Added cast op oneDNN kernel for bf16/fp32 data type casting (FWD/BWD) #33056

Merged: 9 commits, merged on May 26, 2021. Changes shown are from 7 commits.
26 changes: 26 additions & 0 deletions paddle/fluid/operators/cast_op.cc
@@ -27,6 +27,9 @@ class CastOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "The output tensor of cast op");
AddAttr<int>("out_dtype", "output data type");
AddAttr<int>("in_dtype", "input data type");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC(
Cast Operator.

@@ -50,6 +53,7 @@ class CastOpGradMaker : public framework::SingleGradOpMaker<T> {
grad->SetOutput("Out", this->InputGrad("X"));
grad->SetAttr("out_dtype", this->GetAttr("in_dtype"));
grad->SetAttr("in_dtype", this->GetAttr("out_dtype"));
grad->SetAttr("use_mkldnn", this->GetAttr("use_mkldnn"));
}
};

@@ -77,6 +81,28 @@ class CastOp : public framework::OperatorWithKernel {
if (platform::is_cuda_pinned_place(tensor_place)) {
return framework::OpKernelType(tensor->type(), ctx.device_context());
}

#ifdef PADDLE_WITH_MKLDNN
int in_dtype = ctx.Attr<int>("in_dtype");
int out_dtype = ctx.Attr<int>("out_dtype");

auto MKLDNNSupportsCast = [&]() -> bool {
int dtype_fp32 = static_cast<int>(framework::proto::VarType::FP32);
int dtype_bf16 = static_cast<int>(framework::proto::VarType::BF16);

if ((in_dtype != dtype_fp32 && in_dtype != dtype_bf16) ||
    (out_dtype != dtype_fp32 && out_dtype != dtype_bf16))
return false;

return true;
};

if (this->CanMKLDNNBeUsed(ctx, tensor->type()) && MKLDNNSupportsCast()) {
return framework::OpKernelType(tensor->type(), ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(tensor->type(), tensor_place);
}
};
73 changes: 73 additions & 0 deletions paddle/fluid/operators/mkldnn/cast_mkldnn_op.cc
@@ -0,0 +1,73 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

using paddle::framework::Tensor;

template <typename T>
class CastMKLDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
this->RunKernel(ctx);
}

void RunKernel(const framework::ExecutionContext& ctx) const {
const auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();

auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");

int in_dtype = ctx.Attr<int>("in_dtype");
int out_dtype = ctx.Attr<int>("out_dtype");

auto x_paddle_type = framework::proto::VarType::Type(in_dtype);
auto out_paddle_type = framework::proto::VarType::Type(out_dtype);

mkldnn::memory::data_type x_type =
framework::ToMKLDNNDataType(x_paddle_type);
mkldnn::memory::data_type out_type =
framework::ToMKLDNNDataType(out_paddle_type);

auto x_tz = framework::vectorize(x->dims());

std::string key =
platform::CreateKey(dev_ctx, x_tz, x->format(), x->format(), x_type);
platform::ReorderMKLDNNHandler reorder_handler(
x_tz, x_paddle_type, x_type, out_paddle_type, out_type, dev_ctx,
dev_ctx.GetEngine(), key);

auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x->format(), platform::to_void_cast(x->data<T>()));
auto reorder_dst_memory_p =
reorder_handler.AcquireDstMemory(out, x->format(), dev_ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
reorder_src_memory_p);

auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();

out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
}
};
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_KERNEL(cast, MKLDNN, paddle::platform::CPUPlace,
ops::CastMKLDNNKernel<float>,
ops::CastMKLDNNKernel<paddle::platform::bfloat16>);
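For context, the kernel above implements the cast as a single oneDNN reorder between f32 and bf16 memory. Below is a minimal standalone sketch of that primitive, independent of Paddle; the include path, format_tag::nc, and buffer handling are assumptions that depend on the installed oneDNN version, not code from this PR.

// Standalone sketch: cast a 2-D fp32 buffer to bf16 via a oneDNN reorder.
#include <cstdint>
#include <vector>
#include "dnnl.hpp"

int main() {
  using namespace dnnl;
  engine eng(engine::kind::cpu, 0);
  stream s(eng);

  memory::dims dims = {10, 10};
  std::vector<float> src_buf(10 * 10, 1.5f);
  std::vector<uint16_t> dst_buf(10 * 10);  // bf16 payload stored in 16-bit slots

  auto src_md = memory::desc(dims, memory::data_type::f32, memory::format_tag::nc);
  auto dst_md = memory::desc(dims, memory::data_type::bf16, memory::format_tag::nc);
  memory src_mem(src_md, eng, src_buf.data());
  memory dst_mem(dst_md, eng, dst_buf.data());

  // The reorder primitive performs the data-type conversion (fp32 -> bf16);
  // this is what ReorderMKLDNNHandler wraps for the cast kernel above.
  reorder(src_mem, dst_mem).execute(s, src_mem, dst_mem);
  s.wait();
  return 0;
}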
1 change: 1 addition & 0 deletions paddle/fluid/operators/unity_build_rule.cmake
@@ -30,6 +30,7 @@ register_unity_group(cc
bmm_op.cc
bpr_loss_op.cc
cast_op.cc
mkldnn/cast_mkldnn_op.cc
cholesky_op.cc
chunk_eval_op.cc
clip_by_norm_op.cc
29 changes: 23 additions & 6 deletions paddle/fluid/platform/mkldnn_reuse.h
@@ -926,7 +926,23 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
dims_(dims),
vtype_(vtype),
dtype_(dtype) {}
vtype_dst_(vtype),
dtype_(dtype),
dtype_dst_(dtype) {}

ReorderMKLDNNHandler(std::vector<int64_t>& dims, // NOLINT
framework::proto::VarType::Type vtype,
mkldnn::memory::data_type dtype,
framework::proto::VarType::Type vtype_dst,
mkldnn::memory::data_type dtype_dst,
const platform::MKLDNNDeviceContext& dev_ctx,
mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
dims_(dims),
vtype_(vtype),
vtype_dst_(vtype_dst),
dtype_(dtype),
dtype_dst_(dtype_dst) {}

std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const MKLDNNMemoryFormat& fmt, void* ptr) {
@@ -940,15 +956,16 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_, fmt);
auto dst_data = output->mutable_data(place, vtype_, dst_md.get_size());
auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_dst_, fmt);
auto dst_data =
output->mutable_data(place, vtype_dst_, dst_md.get_size());

mem_p = std::make_shared<mkldnn::memory>(dst_md, engine_, dst_data);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
// Even if the memory object exists, we may be using it for a different tensor
auto dst_data =
output->mutable_data(place, vtype_, mem_p->get_desc().get_size());
output->mutable_data(place, vtype_dst_, mem_p->get_desc().get_size());
mem_p->set_data_handle(dst_data);
}
return mem_p;
@@ -970,8 +987,8 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {

private:
std::vector<int64_t> dims_;
framework::proto::VarType::Type vtype_;
mkldnn::memory::data_type dtype_;
framework::proto::VarType::Type vtype_, vtype_dst_;
mkldnn::memory::data_type dtype_, dtype_dst_;
};

template <typename T>
78 changes: 78 additions & 0 deletions python/paddle/fluid/tests/unittests/mkldnn/test_cast_mkldnn_op.py
@@ -0,0 +1,78 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16, convert_uint16_to_float
import unittest
import numpy as np

import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard


@unittest.skipIf(not core.supports_bfloat16(),
"place does not support BF16 evaluation")
class TestCastBF16ToFP32MKLDNNOp(OpTest):
def init_data(self):
self.out = np.random.random(size=[10, 10]).astype("float32")
self.x = convert_float_to_uint16(self.out)

def setUp(self):
self.init_data()
self.inputs = {'X': self.x}
self.outputs = {'Out': self.out}
prepare_dtype = lambda x: int(core.VarDesc.VarType.BF16 if x.dtype != np.float32 else core.VarDesc.VarType.FP32)
self.attrs = {
'in_dtype': prepare_dtype(self.x),
'out_dtype': prepare_dtype(self.out),
'use_mkldnn': True
}
self.op_type = 'cast'

def test_check_output(self):
self.check_output(check_dygraph=False)

def test_check_grad(self):
self.check_grad_with_place(
core.CPUPlace(), ["X"],
"Out",
check_dygraph=False,
Reviewer comment (Contributor), quoting check_output_with_place from op_test.py:
    def check_output_with_place(self,
                                place,
                                atol=0,
                                no_check_set=None,
                                equal_nan=False,
                                check_dygraph=True,
                                inplace_atol=None):
        self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
        if self.dtype == np.float64 and \
            self.op_type not in op_threshold_white_list.NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST:
            atol = 0
        if self.is_bfloat16_op():
            check_dygraph = False
            if hasattr(self, 'force_fp32_output') and getattr(

Since you add check_dygraph = False for bf16 ops in check_output_with_place, could you add the same code in check_grad_with_place? That way you can avoid one extra approval.
[screenshot]

@jakpiase @lidanqing-intel @jczaja

@wzzju Please note this.

user_defined_grads=[self.inputs['X']],
user_defined_grad_outputs=[self.outputs['Out']])
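Below is a minimal, self-contained sketch of the reviewer's suggestion in the comment above, assuming the bf16 guard can be mirrored verbatim; only is_bfloat16_op and check_dygraph come from the quoted op_test.py snippet, and every other name is illustrative rather than Paddle's actual OpTest source.

# Trimmed stand-in for OpTest; the real class has many more parameters and steps.
class OpTestSketch:
    def __init__(self, dtype):
        self.dtype = dtype

    def is_bfloat16_op(self):
        # bf16 tensors are carried as uint16 in these tests
        return self.dtype == "uint16"

    def check_output_with_place(self, check_dygraph=True):
        if self.is_bfloat16_op():
            check_dygraph = False  # guard already present upstream
        return check_dygraph

    def check_grad_with_place(self, check_dygraph=True):
        if self.is_bfloat16_op():
            check_dygraph = False  # the suggestion: add the same guard here
        return check_dygraph

assert OpTestSketch("uint16").check_grad_with_place() is False
assert OpTestSketch("float32").check_grad_with_place() is True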


class TestCastFP32ToBF16MKLDNNOp(TestCastBF16ToFP32MKLDNNOp):
def init_data(self):
self.x = np.random.random(size=[2, 6]).astype("float32")
self.out = convert_float_to_uint16(self.x)


class TestCastBF16ToBF16MKLDNNOp(TestCastBF16ToFP32MKLDNNOp):
def init_data(self):
self.x = np.random.random(size=[6, 13]).astype("uint16")
self.out = self.x


class TestCastFP32ToFP32MKLDNNOp(TestCastBF16ToFP32MKLDNNOp):
def init_data(self):
self.x = np.random.random(size=[7, 15]).astype("float32")
self.out = self.x


if __name__ == '__main__':
paddle.enable_static()
unittest.main()
20 changes: 16 additions & 4 deletions python/paddle/fluid/tests/unittests/op_test.py
Original file line number Diff line number Diff line change
@@ -1191,8 +1191,12 @@ def find_actual(target_name, fetch_list):
np.float32, np.float64
]:
actual_t = convert_uint16_to_float(actual_t)
atol = 0.03
atol = max(atol, 0.03)

if expect_t.dtype == np.uint16 and actual_t.dtype == np.uint16:
expect_t = convert_uint16_to_float(expect_t)
actual_t = convert_uint16_to_float(actual_t)
atol = max(atol, 0.03)
# NOTE(zhiqiu): np.allclose([], [1.]) returns True
# see details: https://stackoverflow.com/questions/38331703/why-does-numpys-broadcasting-sometimes-allow-comparing-arrays-of-different-leng
if expect_t.size == 0:
@@ -1501,13 +1505,21 @@ def check_grad_with_place(self,

# comparison of bf16 results will happen as fp32
# loop over list of grads and convert bf16 to fp32
fp32_grads = []
fp32_analytic_grads = []
for grad in analytic_grads:
if grad.dtype == np.uint16:
grad = convert_uint16_to_float(grad)
max_relative_error = 0.03
fp32_grads.append(grad)
analytic_grads = fp32_grads
fp32_analytic_grads.append(grad)
analytic_grads = fp32_analytic_grads

fp32_numeric_grads = []
for grad in numeric_grads:
if grad.dtype == np.uint16:
grad = convert_uint16_to_float(grad)
max_relative_error = 0.03
fp32_numeric_grads.append(grad)
numeric_grads = fp32_numeric_grads

self._assert_is_close(numeric_grads, analytic_grads, inputs_to_check,
max_relative_error,
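The 0.03 tolerances above reflect bf16's reduced (roughly 8-bit) mantissa. As a rough illustration, here is a small NumPy sketch of what the uint16/float conversion amounts to conceptually; the function names are hypothetical, and Paddle's actual convert_float_to_uint16 / convert_uint16_to_float helpers may differ in details such as rounding.

import numpy as np

def float_to_bf16_bits(x):
    # Keep the upper 16 bits of the fp32 bit pattern (truncation, no rounding).
    return (np.asarray(x, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float(x):
    # Shift the stored 16 bits back into the fp32 sign/exponent/mantissa slots.
    return (np.asarray(x, dtype=np.uint16).astype(np.uint32) << 16).view(np.float32)

a = np.random.random([4, 4]).astype(np.float32)
roundtrip = bf16_bits_to_float(float_to_bf16_bits(a))
# bf16 keeps only ~2-3 significant decimal digits, hence the relaxed 0.03 tolerance.
print(np.max(np.abs(roundtrip - a)))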
1 change: 1 addition & 0 deletions tools/static_mode_white_list.py
@@ -587,6 +587,7 @@
'test_matmul_op_with_head',
'test_var_conv_2d',
'test_batch_norm_mkldnn_op',
'test_cast_mkldnn_op',
'test_concat_int8_mkldnn_op',
'test_concat_bf16_mkldnn_op',
'test_concat_mkldnn_op',