From d04092221e91f1be2137b9e84ea8e60d4ae31927 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Tue, 15 Jun 2021 19:37:49 +0200 Subject: [PATCH 01/10] base changes for split op --- .../fluid/operators/mkldnn/split_mkldnn_op.cc | 145 +++++++++++++++ paddle/fluid/operators/split_op.cc | 24 +-- paddle/fluid/platform/mkldnn_reuse.h | 36 ++++ .../unittests/mkldnn/test_split_mkldnn_op.py | 171 ++++++++++++++++++ 4 files changed, 365 insertions(+), 11 deletions(-) create mode 100644 paddle/fluid/operators/mkldnn/split_mkldnn_op.cc create mode 100644 python/paddle/fluid/tests/unittests/mkldnn/test_split_mkldnn_op.py diff --git a/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc new file mode 100644 index 0000000000000..4cf920c835f5f --- /dev/null +++ b/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc @@ -0,0 +1,145 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/mkldnn_reuse.h" +#include "paddle/fluid/operators/utils.h" + +namespace paddle { +namespace operators { + +using paddle::framework::Tensor; + +static inline std::vector> CalculateOutsDims( + const framework::DDim& in_dims, const size_t num, + const std::vector& sections, const size_t axis, const int outs_number) { + PADDLE_ENFORCE_NE(num, 0, platform::errors::InvalidArgument( + "Only num option is implemented for now, num " + "must be different than 0")); + + std::vector> outs_dims(outs_number, framework::vectorize(in_dims)); + + if (num > 0) { + PADDLE_ENFORCE_EQ( + in_dims[axis] % num, 0, + platform::errors::InvalidArgument( + "The input's size along the split dimension " + "must be evenly divisible by Attr(num_or_sections). " + "But received Attr(num_or_sections) " + "= %d, input(X)'s shape = [%s], Attr(dim) = %d.", + num, in_dims, axis)); + + const size_t out_axis_dim = in_dims[axis] / num; + + for (auto& out_dim : outs_dims) + out_dim[axis] = out_axis_dim; + } else { + for (size_t i=0;i +class SplitMKLDNNKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + this->RunKernel(ctx); + } + + void RunKernel(const framework::ExecutionContext& ctx) const { + const auto& dev_ctx = + ctx.template device_context(); + const auto& onednn_engine = dev_ctx.GetEngine(); + + const auto* x = ctx.Input("X"); + auto outs = ctx.MultiOutput("Out"); + + int num = ctx.Attr("num"); + auto sections = ctx.Attr>("sections"); + int axis = ctx.Attr("axis"); + auto outs_number = outs.size(); + const auto x_dims = x->dims(); + + bool need_resize = false; + if (ctx.HasInput("AxisTensor")) { + auto* axis_tensor = ctx.Input("AxisTensor"); + axis = GetDataFromTensor(axis_tensor)[0]; + need_resize = true; + } + + auto sections_tensor_list = + ctx.MultiInput("SectionsTensorList"); + if (sections_tensor_list.size() > 0) { + sections = GetDataFromTensorList(sections_tensor_list); + need_resize = true; + } + + if(need_resize){ + const auto outs_dims = CalculateOutsDims(x->dims(), num, sections, axis, outs_number); + for(size_t i=0;iResize(framework::make_ddim(outs_dims[i])); + } + } + + auto x_vec_dims = framework::vectorize(x_dims); + + mkldnn::memory::data_type x_type = + framework::ToMKLDNNDataType(x->type()); + std::string key = platform::CreateKey( + dev_ctx, x_vec_dims, x->format(), x->format(), x_type); + + auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + + std::vector offset(x_vec_dims.size(), 0); + + platform::ReorderMKLDNNHandler reorder_handler( + x_vec_dims, x->type(), x_type, dev_ctx, onednn_engine, key); + auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( + x->format(), platform::to_void_cast(x->data())); + + for(size_t i=0;idims()); + const auto slice_md = reorder_src_memory_p->get_desc().submemory_desc(out_vec_dims, {offset}); + auto slice_mem_p = std::make_shared(slice_md, onednn_engine, reorder_src_memory_p->get_data_handle()); + + // change in mkldnn_reuse AcquireDstMemory and add keys in split case!!! new function is needed + auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( + outs[i], out_vec_dims, i, x->format(), ctx.GetPlace()); + auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p, + slice_mem_p, i); + + reorder_p->execute(astream, *slice_mem_p, *reorder_dst_memory_p); + + offset[axis] += num > 0 ? x->dims()[axis] / num : sections[i]; + + outs[i]->set_layout(framework::DataLayout::kMKLDNN); + outs[i]->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p)); + } + astream.wait(); + // add also sections case + + //for(int i=0;iset_layout(framework::DataLayout::kMKLDNN); + // outs[i]->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p)); + //} + }// +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_KERNEL(split, MKLDNN, paddle::platform::CPUPlace, + ops::SplitMKLDNNKernel, + ops::SplitMKLDNNKernel); + diff --git a/paddle/fluid/operators/split_op.cc b/paddle/fluid/operators/split_op.cc index 0151778075de0..cc0fce1942948 100644 --- a/paddle/fluid/operators/split_op.cc +++ b/paddle/fluid/operators/split_op.cc @@ -73,18 +73,17 @@ class SplitOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); - } - - framework::OpKernelType GetKernelTypeForVar( - const std::string &var_name, const Tensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - if (var_name == "AxisTensor" || var_name == "SectionsTensorList") { - return expected_kernel_type; + auto input_data_type = + framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); + +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } - return framework::OpKernelType(expected_kernel_type.data_type_, - tensor.place(), tensor.layout()); +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; @@ -136,6 +135,9 @@ This operator splits the input tensor into multiple sub-tensors. "(int, default 0) " "The axis which the input will be split on.") .SetDefault(0); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); } }; diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 2981e5502ce6a..4bf2e885c551c 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -1032,6 +1032,42 @@ class ReorderMKLDNNHandler : public MKLDNNHandler { return mem_p; } + std::shared_ptr AcquireDstMemory( + framework::Tensor* output, const std::vector& dims, const int memory_number, + const MKLDNNMemoryFormat& fmt, platform::Place place) { + auto local_key = key_ + "@user_dst_mem" + std::to_string(memory_number) + "_p"; + auto mem_p = + std::static_pointer_cast(dev_ctx_.GetBlob(local_key)); + if (mem_p == nullptr) { + auto dst_md = platform::MKLDNNMemDesc(dims, dtype_dst_, fmt); + auto dst_data = + output->mutable_data(place, vtype_dst_, dst_md.get_size()); + + mem_p = std::make_shared(dst_md, engine_, dst_data); + dev_ctx_.SetBlob(local_key, mem_p); + } else { + // Even if memory object exists , we may be using it for diffrent tensor + auto dst_data = + output->mutable_data(place, vtype_dst_, mem_p->get_desc().get_size()); + mem_p->set_data_handle(dst_data); + } + return mem_p; + } + + std::shared_ptr AcquireReorder( + std::shared_ptr dst_memory_p, + std::shared_ptr src_memory_p, int reorder_number) { + auto prim_key = key_ + "@reorder" + std::to_string(reorder_number) + "_p"; + auto reorder_p = + std::static_pointer_cast(dev_ctx_.GetBlob(prim_key)); + if (reorder_p == nullptr) { + reorder_p = + std::make_shared(*(src_memory_p), *(dst_memory_p)); + dev_ctx_.SetBlob(prim_key, reorder_p); + } + return reorder_p; + } + std::shared_ptr AcquireReorder( std::shared_ptr dst_memory_p, std::shared_ptr src_memory_p) { diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_split_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_split_mkldnn_op.py new file mode 100644 index 0000000000000..dc4bfefffbe2d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_split_mkldnn_op.py @@ -0,0 +1,171 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard, core +from paddle.fluid.tests.unittests.op_test import OpTest + + +class TestSplitSectionsOneDNNOp(OpTest): + def init_data(self): + self.x = np.random.random((4, 5, 6)).astype("float32") + self.axis = 1 + self.sections = [2, 1, 2] + self.num = None + indices_or_sections = [2, 3] # sections + np_sections = [2, 3] + self.out = np.split(self.x, np_sections, self.axis) + + def setUp(self): + self.op_type = "split" + self.axis_tensor = None + self.init_data() + self.inputs = {'X': self.x} + self.attrs = {'use_mkldnn' : True} + + if self.axis is not None: + self.attrs['axis'] = self.axis + if self.num is not None: + self.attrs['num'] = self.num + if self.sections is not None: + self.attrs['sections'] = self.sections + if self.axis_tensor is not None: + self.inputs['AxisTensor'] = self.axis_tensor + + self.outputs = {'Out': [('out%d' % i, self.out[i]) \ + for i in range(len(self.out))]} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], ['out0', 'out1', 'out2']) + + +# test with attr(num) +class TestSplitNumOneDNNop(TestSplitSectionsOneDNNOp): + def init_data(self): + self.x = np.random.random((4, 8, 5)).astype("float32") + self.axis = 1 + self.sections = [] + self.num = 4 + indices_or_sections = 4 #indices + self.out = np.split(self.x, indices_or_sections, self.axis) + + def test_check_grad(self): + self.check_grad(['X'], ['out0', 'out1', 'out2', 'out3']) + + +class TestSplitNumAxisTensorOneDNN(TestSplitSectionsOneDNNOp): + def init_data(self): + self.x = np.random.random((4, 5, 6)).astype("float32") + self.axis = 1 + self.sections = [] + self.num = 3 + indices_or_sections = 3 #indices + self.axis_tensor = np.array([2]).astype("int32") + self.out = np.split(self.x, indices_or_sections, 2) + + def test_check_output(self): + self.check_output() + +# +# attr(sections) is list containing Tensor +#class TestSplitOp_SectionsTensor(OpTest): +# def setUp(self): +# self._set_op_type() +# self.dtype = self.get_dtype() +# self.init_data() +# self.inputs = {'X': self.x} +# +# sections_tensor = [] +# for index, ele in enumerate(self.sections): +# sections_tensor.append(("x" + str(index), np.ones( +# (1)).astype('int32') * ele)) +# +# self.inputs['SectionsTensorList'] = sections_tensor +# +# self.attrs = { +# 'axis': self.axis, +# 'sections': self.sections_infer, +# 'num': self.num +# } +# +# out = np.split(self.x, self.indices_or_sections, self.axis) +# self.outputs = {'Out': [('out%d' % i, out[i]) \ +# for i in range(len(out))]} +# +# def init_data(self): +# self.x = np.random.random((4, 5, 6)).astype(self.dtype) +# self.axis = 1 +# self.sections = [2, 1, 2] +# self.sections_infer = [-1, -1, -1] +# self.num = 0 +# self.indices_or_sections = [2, 3] +# +# def get_dtype(self): +# return "float64" +# +# def _set_op_type(self): +# self.op_type = "split" +# +# def test_check_output(self): +# self.check_output() +# +# def test_check_grad(self): +# self.check_grad(['X'], ['out0', 'out1', 'out2']) + + +#class TestSplitOp_unk_section(OpTest): +# def setUp(self): +# self._set_op_type() +# self.dtype = self.get_dtype() +# self.init_data() +# self.inputs = {'X': self.x} +# self.attrs = { +# 'axis': self.axis, +# 'sections': self.sections, +# 'num': self.num +# } +# +# out = np.split(self.x, self.indices_or_sections, self.axis) +# self.outputs = {'Out': [('out%d' % i, out[i]) \ +# for i in range(len(out))]} +# +# def init_data(self): +# self.x = np.random.random((4, 5, 6)).astype(self.dtype) +# self.axis = 2 +# self.sections = [2, 1, -1] +# self.num = 0 +# self.indices_or_sections = [2, 3] +# +# def get_dtype(self): +# return "float64" +# +# def _set_op_type(self): +# self.op_type = "split" +# +# def test_check_output(self): +# self.check_output() +# +# def test_check_grad(self): +# self.check_grad(['X'], ['out0', 'out1', 'out2']) + +if __name__ == '__main__': + paddle.enable_static() + unittest.main() From e1e00f30b13563c130a768de487c993653cb24b7 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Wed, 16 Jun 2021 11:30:20 +0200 Subject: [PATCH 02/10] 90% of split functionality added --- .../operators/mkldnn/concat_mkldnn_op.cc | 12 ++++ .../fluid/operators/mkldnn/split_mkldnn_op.cc | 4 -- .../unittests/mkldnn/test_split_mkldnn_op.py | 70 ++++++------------- 3 files changed, 33 insertions(+), 53 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc index df1b5af121da9..7df951998c845 100644 --- a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/fluid/operators/concat_op.h" #include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_reuse.h" +#include "paddle/fluid/operators/utils.h" namespace paddle { namespace operators { @@ -156,6 +157,17 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { "The axis is expected to be in range of [%d, %d), but got %d", -rank, rank, concat_axis)); platform::MKLDNNDeviceContext::tls().log_lib_version(); + + if(ctx.HasInput("AxisTensor")){ + auto* axis_tensor = ctx.Input("AxisTensor"); + concat_axis = GetDataFromTensor(axis_tensor)[0]; + auto out_dims = multi_input[0]->dims(); + for(size_t i=1;idims()[concat_axis]; + } + output->Resize(out_dims); + } + if (concat_axis < 0) { concat_axis = concat_axis + rank; } diff --git a/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc index 4cf920c835f5f..7557c32e0b57b 100644 --- a/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc @@ -23,10 +23,6 @@ using paddle::framework::Tensor; static inline std::vector> CalculateOutsDims( const framework::DDim& in_dims, const size_t num, const std::vector& sections, const size_t axis, const int outs_number) { - PADDLE_ENFORCE_NE(num, 0, platform::errors::InvalidArgument( - "Only num option is implemented for now, num " - "must be different than 0")); - std::vector> outs_dims(outs_number, framework::vectorize(in_dims)); if (num > 0) { diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_split_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_split_mkldnn_op.py index dc4bfefffbe2d..cc9cfa14a219c 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_split_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_split_mkldnn_op.py @@ -34,6 +34,7 @@ def init_data(self): def setUp(self): self.op_type = "split" self.axis_tensor = None + self.sections_tensor_list = None self.init_data() self.inputs = {'X': self.x} self.attrs = {'use_mkldnn' : True} @@ -46,6 +47,10 @@ def setUp(self): self.attrs['sections'] = self.sections if self.axis_tensor is not None: self.inputs['AxisTensor'] = self.axis_tensor + if self.sections_tensor_list is not None: + print("haha") + print(self.sections_tensor_list) + self.inputs['SectionsTensorList'] = self.sections_tensor_list self.outputs = {'Out': [('out%d' % i, self.out[i]) \ for i in range(len(self.out))]} @@ -58,7 +63,7 @@ def test_check_grad(self): # test with attr(num) -class TestSplitNumOneDNNop(TestSplitSectionsOneDNNOp): +class TestSplitNumOneDNNOp(TestSplitSectionsOneDNNOp): def init_data(self): self.x = np.random.random((4, 8, 5)).astype("float32") self.axis = 1 @@ -71,7 +76,7 @@ def test_check_grad(self): self.check_grad(['X'], ['out0', 'out1', 'out2', 'out3']) -class TestSplitNumAxisTensorOneDNN(TestSplitSectionsOneDNNOp): +class TestSplitNumAxisTensorOneDNNOp(TestSplitSectionsOneDNNOp): def init_data(self): self.x = np.random.random((4, 5, 6)).astype("float32") self.axis = 1 @@ -81,55 +86,22 @@ def init_data(self): self.axis_tensor = np.array([2]).astype("int32") self.out = np.split(self.x, indices_or_sections, 2) - def test_check_output(self): - self.check_output() -# # attr(sections) is list containing Tensor -#class TestSplitOp_SectionsTensor(OpTest): -# def setUp(self): -# self._set_op_type() -# self.dtype = self.get_dtype() -# self.init_data() -# self.inputs = {'X': self.x} -# -# sections_tensor = [] -# for index, ele in enumerate(self.sections): -# sections_tensor.append(("x" + str(index), np.ones( -# (1)).astype('int32') * ele)) -# -# self.inputs['SectionsTensorList'] = sections_tensor -# -# self.attrs = { -# 'axis': self.axis, -# 'sections': self.sections_infer, -# 'num': self.num -# } -# -# out = np.split(self.x, self.indices_or_sections, self.axis) -# self.outputs = {'Out': [('out%d' % i, out[i]) \ -# for i in range(len(out))]} -# -# def init_data(self): -# self.x = np.random.random((4, 5, 6)).astype(self.dtype) -# self.axis = 1 -# self.sections = [2, 1, 2] -# self.sections_infer = [-1, -1, -1] -# self.num = 0 -# self.indices_or_sections = [2, 3] -# -# def get_dtype(self): -# return "float64" -# -# def _set_op_type(self): -# self.op_type = "split" -# -# def test_check_output(self): -# self.check_output() -# -# def test_check_grad(self): -# self.check_grad(['X'], ['out0', 'out1', 'out2']) - +class TestSplitSectionsTensorOneDNNOp(TestSplitSectionsOneDNNOp): + def init_data(self): + self.x = np.random.random((4, 5, 6)).astype("float32") + self.axis = 1 + self.sections = [2, 1, 2] + self.sections_tensor_list = [] + for index, ele in enumerate(self.sections): + self.sections_tensor_list.append(("x" + str(index), np.ones( + (1)).astype('int32') * ele)) + self.sections = [-1, -1, -1] + + self.num = 0 + indices_or_sections = [2, 3] #sections + self.out = np.split(self.x, indices_or_sections, self.axis) #class TestSplitOp_unk_section(OpTest): # def setUp(self): From 50fd22cdfba665c2c50937034879f79a7e726d76 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Wed, 16 Jun 2021 13:30:56 +0200 Subject: [PATCH 03/10] full fp32 functionality --- .../operators/mkldnn/concat_mkldnn_op.cc | 6 +- .../fluid/operators/mkldnn/split_mkldnn_op.cc | 85 +++++++++---------- paddle/fluid/platform/mkldnn_reuse.h | 8 +- .../unittests/mkldnn/test_split_mkldnn_op.py | 63 ++++---------- tools/static_mode_white_list.py | 1 + 5 files changed, 64 insertions(+), 99 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc index 7df951998c845..df4750321e3fc 100644 --- a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc @@ -14,9 +14,9 @@ limitations under the License. */ #include #include "paddle/fluid/operators/concat_op.h" +#include "paddle/fluid/operators/utils.h" #include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_reuse.h" -#include "paddle/fluid/operators/utils.h" namespace paddle { namespace operators { @@ -158,11 +158,11 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { -rank, rank, concat_axis)); platform::MKLDNNDeviceContext::tls().log_lib_version(); - if(ctx.HasInput("AxisTensor")){ + if (ctx.HasInput("AxisTensor")) { auto* axis_tensor = ctx.Input("AxisTensor"); concat_axis = GetDataFromTensor(axis_tensor)[0]; auto out_dims = multi_input[0]->dims(); - for(size_t i=1;idims()[concat_axis]; } output->Resize(out_dims); diff --git a/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc index 7557c32e0b57b..b4185c37b502a 100644 --- a/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/platform/mkldnn_reuse.h" #include "paddle/fluid/operators/utils.h" +#include "paddle/fluid/platform/mkldnn_reuse.h" namespace paddle { namespace operators { @@ -22,26 +22,26 @@ using paddle::framework::Tensor; static inline std::vector> CalculateOutsDims( const framework::DDim& in_dims, const size_t num, - const std::vector& sections, const size_t axis, const int outs_number) { - std::vector> outs_dims(outs_number, framework::vectorize(in_dims)); + const std::vector& sections, const size_t axis, + const int outs_number) { + std::vector> outs_dims(outs_number, + framework::vectorize(in_dims)); if (num > 0) { - PADDLE_ENFORCE_EQ( - in_dims[axis] % num, 0, - platform::errors::InvalidArgument( - "The input's size along the split dimension " - "must be evenly divisible by Attr(num_or_sections). " - "But received Attr(num_or_sections) " - "= %d, input(X)'s shape = [%s], Attr(dim) = %d.", - num, in_dims, axis)); + PADDLE_ENFORCE_EQ(in_dims[axis] % num, 0, + platform::errors::InvalidArgument( + "The input's size along the split dimension " + "must be evenly divisible by Attr(num_or_sections). " + "But received Attr(num_or_sections) " + "= %d, input(X)'s shape = [%s], Attr(dim) = %d.", + num, in_dims, axis)); const size_t out_axis_dim = in_dims[axis] / num; - - for (auto& out_dim : outs_dims) - out_dim[axis] = out_axis_dim; + + for (auto& out_dim : outs_dims) out_dim[axis] = out_axis_dim; } else { - for (size_t i=0;i { need_resize = true; } - auto sections_tensor_list = - ctx.MultiInput("SectionsTensorList"); + auto sections_tensor_list = ctx.MultiInput("SectionsTensorList"); if (sections_tensor_list.size() > 0) { sections = GetDataFromTensorList(sections_tensor_list); need_resize = true; } - if(need_resize){ - const auto outs_dims = CalculateOutsDims(x->dims(), num, sections, axis, outs_number); - for(size_t i=0;idims(), num, sections, axis, outs_number); + for (size_t i = 0; i < outs.size(); ++i) { outs[i]->Resize(framework::make_ddim(outs_dims[i])); } } auto x_vec_dims = framework::vectorize(x_dims); - mkldnn::memory::data_type x_type = - framework::ToMKLDNNDataType(x->type()); - std::string key = platform::CreateKey( - dev_ctx, x_vec_dims, x->format(), x->format(), x_type); + mkldnn::memory::data_type x_type = framework::ToMKLDNNDataType(x->type()); + std::string key = platform::CreateKey(dev_ctx, x_vec_dims, sections, + x->format(), x->format(), x_type); auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); @@ -104,32 +103,27 @@ class SplitMKLDNNKernel : public framework::OpKernel { auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( x->format(), platform::to_void_cast(x->data())); - for(size_t i=0;idims()); - const auto slice_md = reorder_src_memory_p->get_desc().submemory_desc(out_vec_dims, {offset}); - auto slice_mem_p = std::make_shared(slice_md, onednn_engine, reorder_src_memory_p->get_data_handle()); + for (size_t i = 0; i < outs_number; ++i) { + auto out_vec_dims = framework::vectorize(outs[i]->dims()); + const auto slice_md = reorder_src_memory_p->get_desc().submemory_desc( + out_vec_dims, {offset}); + auto slice_mem_p = std::make_shared( + slice_md, onednn_engine, reorder_src_memory_p->get_data_handle()); - // change in mkldnn_reuse AcquireDstMemory and add keys in split case!!! new function is needed - auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( - outs[i], out_vec_dims, i, x->format(), ctx.GetPlace()); - auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p, - slice_mem_p, i); + auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( + outs[i], out_vec_dims, i, x->format(), ctx.GetPlace()); + auto reorder_p = + reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p, i); - reorder_p->execute(astream, *slice_mem_p, *reorder_dst_memory_p); + reorder_p->execute(astream, *slice_mem_p, *reorder_dst_memory_p); - offset[axis] += num > 0 ? x->dims()[axis] / num : sections[i]; + offset[axis] += num > 0 ? x->dims()[axis] / num : sections[i]; - outs[i]->set_layout(framework::DataLayout::kMKLDNN); - outs[i]->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p)); + outs[i]->set_layout(framework::DataLayout::kMKLDNN); + outs[i]->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p)); } astream.wait(); - // add also sections case - - //for(int i=0;iset_layout(framework::DataLayout::kMKLDNN); - // outs[i]->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p)); - //} - }// + } }; } // namespace operators } // namespace paddle @@ -138,4 +132,3 @@ namespace ops = paddle::operators; REGISTER_OP_KERNEL(split, MKLDNN, paddle::platform::CPUPlace, ops::SplitMKLDNNKernel, ops::SplitMKLDNNKernel); - diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 4bf2e885c551c..23872d0503184 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -1033,9 +1033,11 @@ class ReorderMKLDNNHandler : public MKLDNNHandler { } std::shared_ptr AcquireDstMemory( - framework::Tensor* output, const std::vector& dims, const int memory_number, - const MKLDNNMemoryFormat& fmt, platform::Place place) { - auto local_key = key_ + "@user_dst_mem" + std::to_string(memory_number) + "_p"; + framework::Tensor* output, const std::vector& dims, + const int memory_number, const MKLDNNMemoryFormat& fmt, + platform::Place place) { + auto local_key = + key_ + "@user_dst_mem" + std::to_string(memory_number) + "_p"; auto mem_p = std::static_pointer_cast(dev_ctx_.GetBlob(local_key)); if (mem_p == nullptr) { diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_split_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_split_mkldnn_op.py index cc9cfa14a219c..1b6cc63fbc719 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_split_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_split_mkldnn_op.py @@ -26,8 +26,7 @@ def init_data(self): self.x = np.random.random((4, 5, 6)).astype("float32") self.axis = 1 self.sections = [2, 1, 2] - self.num = None - indices_or_sections = [2, 3] # sections + indices_or_sections = [2, 3] # sections np_sections = [2, 3] self.out = np.split(self.x, np_sections, self.axis) @@ -35,21 +34,18 @@ def setUp(self): self.op_type = "split" self.axis_tensor = None self.sections_tensor_list = None + self.num = 0 self.init_data() self.inputs = {'X': self.x} - self.attrs = {'use_mkldnn' : True} + self.attrs = {'use_mkldnn': True, 'num': self.num} if self.axis is not None: self.attrs['axis'] = self.axis - if self.num is not None: - self.attrs['num'] = self.num if self.sections is not None: self.attrs['sections'] = self.sections if self.axis_tensor is not None: self.inputs['AxisTensor'] = self.axis_tensor if self.sections_tensor_list is not None: - print("haha") - print(self.sections_tensor_list) self.inputs['SectionsTensorList'] = self.sections_tensor_list self.outputs = {'Out': [('out%d' % i, self.out[i]) \ @@ -65,11 +61,11 @@ def test_check_grad(self): # test with attr(num) class TestSplitNumOneDNNOp(TestSplitSectionsOneDNNOp): def init_data(self): - self.x = np.random.random((4, 8, 5)).astype("float32") + self.x = np.random.random((4, 8, 5, 3)).astype("float32") self.axis = 1 self.sections = [] self.num = 4 - indices_or_sections = 4 #indices + indices_or_sections = 4 #indices self.out = np.split(self.x, indices_or_sections, self.axis) def test_check_grad(self): @@ -79,10 +75,10 @@ def test_check_grad(self): class TestSplitNumAxisTensorOneDNNOp(TestSplitSectionsOneDNNOp): def init_data(self): self.x = np.random.random((4, 5, 6)).astype("float32") - self.axis = 1 + self.axis = None self.sections = [] self.num = 3 - indices_or_sections = 3 #indices + indices_or_sections = 3 #indices self.axis_tensor = np.array([2]).astype("int32") self.out = np.split(self.x, indices_or_sections, 2) @@ -98,45 +94,18 @@ def init_data(self): self.sections_tensor_list.append(("x" + str(index), np.ones( (1)).astype('int32') * ele)) self.sections = [-1, -1, -1] + indices_or_sections = [2, 3] #sections + self.out = np.split(self.x, indices_or_sections, self.axis) - self.num = 0 - indices_or_sections = [2, 3] #sections + +class TestSplitOpUnknownSectionOneDNNOp(TestSplitSectionsOneDNNOp): + def init_data(self): + self.x = np.random.random((4, 5, 6)).astype("float32") + self.axis = 2 + self.sections = [2, 2, -1] + indices_or_sections = [2, 4] #sections self.out = np.split(self.x, indices_or_sections, self.axis) -#class TestSplitOp_unk_section(OpTest): -# def setUp(self): -# self._set_op_type() -# self.dtype = self.get_dtype() -# self.init_data() -# self.inputs = {'X': self.x} -# self.attrs = { -# 'axis': self.axis, -# 'sections': self.sections, -# 'num': self.num -# } -# -# out = np.split(self.x, self.indices_or_sections, self.axis) -# self.outputs = {'Out': [('out%d' % i, out[i]) \ -# for i in range(len(out))]} -# -# def init_data(self): -# self.x = np.random.random((4, 5, 6)).astype(self.dtype) -# self.axis = 2 -# self.sections = [2, 1, -1] -# self.num = 0 -# self.indices_or_sections = [2, 3] -# -# def get_dtype(self): -# return "float64" -# -# def _set_op_type(self): -# self.op_type = "split" -# -# def test_check_output(self): -# self.check_output() -# -# def test_check_grad(self): -# self.check_grad(['X'], ['out0', 'out1', 'out2']) if __name__ == '__main__': paddle.enable_static() diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index d1e4680e63f95..76fd474697546 100644 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -474,6 +474,7 @@ 'test_split_and_merge_lod_tensor_op', 'test_split_ids_op', 'test_split_op', + 'test_split_mkldnn_op' 'test_spp_op', 'test_square_error_cost', 'test_squared_l2_norm_op', From a09a69145bbb55c9831b704abd07ab802730721f Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Wed, 16 Jun 2021 19:35:26 +0200 Subject: [PATCH 04/10] added bf16 test --- .../fluid/operators/mkldnn/split_mkldnn_op.cc | 12 +- paddle/fluid/platform/mkldnn_reuse.h | 19 +++ .../mkldnn/test_split_bf16_mkldnn_op.py | 114 ++++++++++++++++++ .../unittests/mkldnn/test_split_mkldnn_op.py | 2 +- 4 files changed, 141 insertions(+), 6 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py diff --git a/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc index b4185c37b502a..36ba6f4c3cc05 100644 --- a/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc @@ -91,7 +91,7 @@ class SplitMKLDNNKernel : public framework::OpKernel { auto x_vec_dims = framework::vectorize(x_dims); mkldnn::memory::data_type x_type = framework::ToMKLDNNDataType(x->type()); - std::string key = platform::CreateKey(dev_ctx, x_vec_dims, sections, + std::string key = platform::CreateKey(dev_ctx, x_vec_dims, axis, num, sections, x->format(), x->format(), x_type); auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); @@ -105,10 +105,12 @@ class SplitMKLDNNKernel : public framework::OpKernel { for (size_t i = 0; i < outs_number; ++i) { auto out_vec_dims = framework::vectorize(outs[i]->dims()); - const auto slice_md = reorder_src_memory_p->get_desc().submemory_desc( - out_vec_dims, {offset}); - auto slice_mem_p = std::make_shared( - slice_md, onednn_engine, reorder_src_memory_p->get_data_handle()); + auto slice_mem_p = reorder_handler.AcquireSrcSubmemory(out_vec_dims, offset, reorder_src_memory_p); + //const auto slice_md = reorder_src_memory_p->get_desc().submemory_desc( + // out_vec_dims, {offset}); + //auto slice_mem_p = std::make_shared( + // slice_md, onednn_engine, reorder_src_memory_p->get_data_handle()); + auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( outs[i], out_vec_dims, i, x->format(), ctx.GetPlace()); diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 23872d0503184..efe51c2228efb 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -1010,6 +1010,25 @@ class ReorderMKLDNNHandler : public MKLDNNHandler { return this->AcquireMemory(dims_, dtype_, fmt, ptr, "@user_src_mem_p"); } + std::shared_ptr AcquireSrcSubmemory( + const std::vector& dims, const std::vector& offset, + const std::shared_ptr mem_p){ + std::string local_key = key_; + AppendKey(&local_key, dims); + AppendKey(&local_key, offset); + + auto sub_mem_p = std::static_pointer_cast(dev_ctx_.GetBlob(local_key)); + if (mem_p == nullptr){ + auto sub_md = mem_p->get_desc().submemory_desc(dims, {offset}); + auto sub_mem_p = std::make_shared( + sub_md, engine_, mem_p->get_data_handle()); + dev_ctx_.SetBlob(local_key, mem_p); + } else { + sub_mem_p->set_data_handle(mem_p->get_data_handle()); + } + return sub_mem_p; + } + std::shared_ptr AcquireDstMemory( framework::Tensor* output, const MKLDNNMemoryFormat& fmt, platform::Place place) { diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py new file mode 100644 index 0000000000000..2ae440a471fd1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py @@ -0,0 +1,114 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard, core +from paddle.fluid.tests.unittests.op_test import OpTest + +@unittest.skipIf(not core.supports_bfloat16(), + "place does not support BF16 evaluation") +class TestSplitSectionsBF16OneDNNOp(OpTest): + def init_data(self): + self.x = np.random.random((4, 5, 6)).astype("uint16") + self.axis = 1 + self.sections = [2, 1, 2] + indices_or_sections = [2, 3] # sections + np_sections = [2, 3] + self.out = np.split(self.x, np_sections, self.axis) + + def setUp(self): + self.op_type = "split" + self.axis_tensor = None + self.sections_tensor_list = None + self.num = 0 + self.init_data() + self.inputs = {'X': self.x} + self.attrs = {'use_mkldnn': True, 'num': self.num} + + if self.axis is not None: + self.attrs['axis'] = self.axis + if self.sections is not None: + self.attrs['sections'] = self.sections + if self.axis_tensor is not None: + self.inputs['AxisTensor'] = self.axis_tensor + if self.sections_tensor_list is not None: + self.inputs['SectionsTensorList'] = self.sections_tensor_list + + self.outputs = {'Out': [('out%d' % i, self.out[i]) \ + for i in range(len(self.out))]} + + def test_check_output(self): + self.check_output(check_dygraph=False) + +# TODO jakpiase enable grad check(concat op) +# def test_check_grad(self): +# self.check_grad_with_place( +# core.CPUPlace(), ["X"], +# "Out", +# check_dygraph=False, +# user_defined_grads=[self.inputs['X']], +# user_defined_grad_outputs=self.out[0]) + + +class TestSplitNumBF16OneDNNOp(TestSplitSectionsBF16OneDNNOp): + def init_data(self): + self.x = np.random.random((4, 8, 5, 3)).astype("uint16") + self.axis = 1 + self.sections = [] + self.num = 4 + indices_or_sections = 4 #indices + self.out = np.split(self.x, indices_or_sections, self.axis) + + +class TestSplitNumAxisTensorBF16OneDNNOp(TestSplitSectionsBF16OneDNNOp): + def init_data(self): + self.x = np.random.random((4, 5, 6)).astype("uint16") + self.axis = None + self.sections = [] + self.num = 3 + indices_or_sections = 3 #indices + self.axis_tensor = np.array([2]).astype("int32") + self.out = np.split(self.x, indices_or_sections, 2) + + +class TestSplitSectionsTensorBF16OneDNNOp(TestSplitSectionsBF16OneDNNOp): + def init_data(self): + self.x = np.random.random((4, 5, 6)).astype("uint16") + self.axis = 1 + self.sections = [2, 1, 2] + self.sections_tensor_list = [] + for index, ele in enumerate(self.sections): + self.sections_tensor_list.append(("x" + str(index), np.ones( + (1)).astype('int32') * ele)) + self.sections = [-1, -1, -1] + indices_or_sections = [2, 3] #sections + self.out = np.split(self.x, indices_or_sections, self.axis) + + +class TestSplitOpUnknownSectionBF16OneDNNOp(TestSplitSectionsBF16OneDNNOp): + def init_data(self): + self.x = np.random.random((4, 5, 6)).astype("uint16") + self.axis = 2 + self.sections = [2, 2, -1] + indices_or_sections = [2, 4] #sections + self.out = np.split(self.x, indices_or_sections, self.axis) + + +if __name__ == '__main__': + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_split_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_split_mkldnn_op.py index 1b6cc63fbc719..55b56434f3eb1 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_split_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_split_mkldnn_op.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 1270057fbee5941ce7ceb50ee97afe8e606a8a36 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Thu, 17 Jun 2021 15:55:33 +0200 Subject: [PATCH 05/10] added submemory caching --- .../fluid/operators/mkldnn/split_mkldnn_op.cc | 12 +++---- paddle/fluid/operators/split_op.cc | 10 ++++++ paddle/fluid/platform/mkldnn_reuse.h | 32 ++++++++++--------- .../mkldnn/test_split_bf16_mkldnn_op.py | 2 ++ 4 files changed, 33 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc index 36ba6f4c3cc05..afbe330305b7e 100644 --- a/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc @@ -91,8 +91,8 @@ class SplitMKLDNNKernel : public framework::OpKernel { auto x_vec_dims = framework::vectorize(x_dims); mkldnn::memory::data_type x_type = framework::ToMKLDNNDataType(x->type()); - std::string key = platform::CreateKey(dev_ctx, x_vec_dims, axis, num, sections, - x->format(), x->format(), x_type); + auto key = platform::CreateKey(dev_ctx, x_vec_dims, axis, num, sections, + x->format(), x_type); auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); @@ -105,12 +105,8 @@ class SplitMKLDNNKernel : public framework::OpKernel { for (size_t i = 0; i < outs_number; ++i) { auto out_vec_dims = framework::vectorize(outs[i]->dims()); - auto slice_mem_p = reorder_handler.AcquireSrcSubmemory(out_vec_dims, offset, reorder_src_memory_p); - //const auto slice_md = reorder_src_memory_p->get_desc().submemory_desc( - // out_vec_dims, {offset}); - //auto slice_mem_p = std::make_shared( - // slice_md, onednn_engine, reorder_src_memory_p->get_data_handle()); - + auto slice_mem_p = reorder_handler.AcquireSrcSubmemory( + out_vec_dims, offset, reorder_src_memory_p, i); auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( outs[i], out_vec_dims, i, x->format(), ctx.GetPlace()); diff --git a/paddle/fluid/operators/split_op.cc b/paddle/fluid/operators/split_op.cc index cc0fce1942948..37a7575c12c2c 100644 --- a/paddle/fluid/operators/split_op.cc +++ b/paddle/fluid/operators/split_op.cc @@ -85,6 +85,16 @@ class SplitOp : public framework::OperatorWithKernel { #endif return framework::OpKernelType(input_data_type, ctx.GetPlace()); } + + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + if (var_name == "AxisTensor" || var_name == "SectionsTensorList") { + return expected_kernel_type; + } + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } }; class SplitOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index efe51c2228efb..f8fb3228827a1 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -1012,21 +1012,23 @@ class ReorderMKLDNNHandler : public MKLDNNHandler { std::shared_ptr AcquireSrcSubmemory( const std::vector& dims, const std::vector& offset, - const std::shared_ptr mem_p){ - std::string local_key = key_; - AppendKey(&local_key, dims); - AppendKey(&local_key, offset); - - auto sub_mem_p = std::static_pointer_cast(dev_ctx_.GetBlob(local_key)); - if (mem_p == nullptr){ - auto sub_md = mem_p->get_desc().submemory_desc(dims, {offset}); - auto sub_mem_p = std::make_shared( - sub_md, engine_, mem_p->get_data_handle()); - dev_ctx_.SetBlob(local_key, mem_p); - } else { - sub_mem_p->set_data_handle(mem_p->get_data_handle()); - } - return sub_mem_p; + const std::shared_ptr& mem_p, int submemory_number) { + std::string local_key = key_; + local_key.append("@submem") + .append(std::to_string(submemory_number)) + .append("_p"); + + auto sub_mem_p = + std::static_pointer_cast(dev_ctx_.GetBlob(local_key)); + if (sub_mem_p == nullptr) { + auto sub_md = mem_p->get_desc().submemory_desc(dims, {offset}); + sub_mem_p = std::make_shared(sub_md, engine_, + mem_p->get_data_handle()); + dev_ctx_.SetBlob(local_key, sub_mem_p); + } else { + sub_mem_p->set_data_handle(mem_p->get_data_handle()); + } + return sub_mem_p; } std::shared_ptr AcquireDstMemory( diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py index 2ae440a471fd1..ba20b998b35fd 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py @@ -20,6 +20,7 @@ from paddle.fluid import compiler, Program, program_guard, core from paddle.fluid.tests.unittests.op_test import OpTest + @unittest.skipIf(not core.supports_bfloat16(), "place does not support BF16 evaluation") class TestSplitSectionsBF16OneDNNOp(OpTest): @@ -55,6 +56,7 @@ def setUp(self): def test_check_output(self): self.check_output(check_dygraph=False) + # TODO jakpiase enable grad check(concat op) # def test_check_grad(self): # self.check_grad_with_place( From 5afd5e32db062ec0db4c02c4c07a4f7f823df657 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Thu, 17 Jun 2021 16:00:23 +0200 Subject: [PATCH 06/10] added bf test to static mode whitelist --- tools/static_mode_white_list.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index 76fd474697546..2bf887ddf3907 100644 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -474,7 +474,8 @@ 'test_split_and_merge_lod_tensor_op', 'test_split_ids_op', 'test_split_op', - 'test_split_mkldnn_op' + 'test_split_mkldnn_op', + 'test_split_bf16_mkldnn_op', 'test_spp_op', 'test_square_error_cost', 'test_squared_l2_norm_op', From e72f151472a0354be1ade5aebe0f2e28b82c94bd Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Thu, 17 Jun 2021 20:57:11 +0200 Subject: [PATCH 07/10] minor change --- .../fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py index ba20b998b35fd..19407d8944a25 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py @@ -23,6 +23,8 @@ @unittest.skipIf(not core.supports_bfloat16(), "place does not support BF16 evaluation") +@unittest.skipIf(core.is_compiled_with_cuda(), + "core is compiled with CUDA which has no BF implementation") class TestSplitSectionsBF16OneDNNOp(OpTest): def init_data(self): self.x = np.random.random((4, 5, 6)).astype("uint16") From 2bfbf627087cbb6221d5c535616666b025756c24 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Tue, 22 Jun 2021 12:13:53 +0200 Subject: [PATCH 08/10] enabled split op for inference --- paddle/fluid/framework/ir/graph_pattern_detector.cc | 10 +++++----- paddle/fluid/operators/split_op.cc | 5 +++++ .../unittests/mkldnn/test_split_bf16_mkldnn_op.py | 8 ++++++-- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 064da3d941602..573cb7dcd09b0 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2262,11 +2262,11 @@ PDNode *patterns::QuantizePlacement::operator()( PDNode *patterns::Bfloat16Placement::operator()( const std::unordered_set &bfloat16_enabled_op_types) { std::unordered_set supported_op_types = - std::unordered_set({"concat", "conv2d", "conv2d_transpose", - "elementwise_add", "elementwise_mul", - "fc", "fusion_gru", "gelu", "layer_norm", - "matmul", "pool2d", "relu", "reshape2", - "softmax", "sum", "transpose2"}); + std::unordered_set( + {"concat", "conv2d", "conv2d_transpose", "elementwise_add", + "elementwise_mul", "fc", "fusion_gru", "gelu", "layer_norm", + "matmul", "pool2d", "relu", "reshape2", "softmax", "split", "sum", + "transpose2"}); if (!bfloat16_enabled_op_types.empty()) { supported_op_types = bfloat16_enabled_op_types; } diff --git a/paddle/fluid/operators/split_op.cc b/paddle/fluid/operators/split_op.cc index 37a7575c12c2c..661e4ca727bee 100644 --- a/paddle/fluid/operators/split_op.cc +++ b/paddle/fluid/operators/split_op.cc @@ -148,6 +148,11 @@ This operator splits the input tensor into multiple sub-tensors. AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); + AddAttr( + "mkldnn_data_type", + "(string, default \"float32\"). Data type of mkldnn kernel") + .SetDefault("float32") + .InEnum({"float32", "bfloat16"}); } }; diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py index 19407d8944a25..200360859b076 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py @@ -41,7 +41,11 @@ def setUp(self): self.num = 0 self.init_data() self.inputs = {'X': self.x} - self.attrs = {'use_mkldnn': True, 'num': self.num} + self.attrs = { + 'use_mkldnn': True, + 'num': self.num, + 'mkldnn_data_type': "bfloat16" + } if self.axis is not None: self.attrs['axis'] = self.axis @@ -56,7 +60,7 @@ def setUp(self): for i in range(len(self.out))]} def test_check_output(self): - self.check_output(check_dygraph=False) + self.check_output_with_place(core.CPUPlace()) # TODO jakpiase enable grad check(concat op) From d849d16cdb81d3bd5de2b0b1f68dc7504ef7fab6 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Tue, 22 Jun 2021 12:28:00 +0200 Subject: [PATCH 09/10] minor fix --- .../fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py index 200360859b076..af270dfe535cf 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py @@ -68,7 +68,7 @@ def test_check_output(self): # self.check_grad_with_place( # core.CPUPlace(), ["X"], # "Out", -# check_dygraph=False, +# check_dygraph=, # user_defined_grads=[self.inputs['X']], # user_defined_grad_outputs=self.out[0]) From bdf98a18a9743b63deeea1b1368662500baaf965 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Tue, 22 Jun 2021 12:29:41 +0200 Subject: [PATCH 10/10] minor fix --- .../fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py index af270dfe535cf..4cb559fc15407 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py @@ -68,7 +68,7 @@ def test_check_output(self): # self.check_grad_with_place( # core.CPUPlace(), ["X"], # "Out", -# check_dygraph=, +# chck_dgrph= # user_defined_grads=[self.inputs['X']], # user_defined_grad_outputs=self.out[0])