From d9456334edb842a4b6eb4a7dfe860ea6ba337da3 Mon Sep 17 00:00:00 2001 From: zhiboniu Date: Mon, 5 Jul 2021 10:13:13 +0000 Subject: [PATCH] add paddle.Tensor api fill_(inplace), zero_(inplace) add fill_ backward --- paddle/fluid/operators/fill_any_op.cc | 108 ++++++++++++++++++ paddle/fluid/operators/fill_any_op.cu.cc | 35 ++++++ paddle/fluid/operators/fill_any_op.h | 65 +++++++++++ python/paddle/__init__.py | 6 +- python/paddle/fluid/framework.py | 10 +- .../fluid/tests/unittests/test_fill_any_op.py | 74 ++++++++++++ .../tests/unittests/test_tensor_fill_.py | 84 ++++++++++++++ .../tests/unittests/test_tensor_zero_.py | 47 ++++++++ python/paddle/tensor/manipulation.py | 69 +++++++++++ tools/static_mode_white_list.py | 1 + 10 files changed, 494 insertions(+), 5 deletions(-) create mode 100644 paddle/fluid/operators/fill_any_op.cc create mode 100644 paddle/fluid/operators/fill_any_op.cu.cc create mode 100644 paddle/fluid/operators/fill_any_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_fill_any_op.py create mode 100644 python/paddle/fluid/tests/unittests/test_tensor_fill_.py create mode 100644 python/paddle/fluid/tests/unittests/test_tensor_zero_.py diff --git a/paddle/fluid/operators/fill_any_op.cc b/paddle/fluid/operators/fill_any_op.cc new file mode 100644 index 0000000000000..22d0c5fb8cbe0 --- /dev/null +++ b/paddle/fluid/operators/fill_any_op.cc @@ -0,0 +1,108 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fill_any_op.h" + +namespace paddle { +namespace operators { + +class FillAnyOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor) The input tensor."); + AddOutput("Out", "Tensor, the tensor filled with input value "); + AddAttr("value_float", "The float var to fill in Tensor") + .SetDefault(0); + AddAttr("value_int", "The int var to fill in Tensor").SetDefault(0); + AddComment(R"DOC(Fill operator with backward; + Fill an tensor with `value`. + )DOC"); + }; +}; + +class FillAnyOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *context) const override { + OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "FillAny"); + OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "FillAny"); + auto x_dims = context->GetInputDim("X"); + context->SetOutputDim("Out", x_dims); + } +}; + +class FillAnyGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@GRAD", "mul"); + auto x_dims = ctx->GetInputDim(framework::GradVarName("Out")); + auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + } +}; + +template +class FillAnyGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr retv) const override { + retv->SetType(this->ForwardOpType() + "_grad"); + retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + retv->SetAttrMap(this->Attrs()); + } +}; + +DECLARE_INPLACE_OP_INFERER(FillAnyOpInplaceInferer, {"X", "Out"}); +DECLARE_INPLACE_OP_INFERER(FillAnyGradInplaceInferer, + {framework::GradVarName("Out"), + framework::GradVarName("X")}); +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; + +REGISTER_OPERATOR(fill_any, ops::FillAnyOp, ops::FillAnyOpMaker, + ops::FillAnyGradOpMaker, + ops::FillAnyGradOpMaker, + ops::FillAnyOpInplaceInferer); + +REGISTER_OPERATOR(fill_any_grad, ops::FillAnyGradOp, + ops::FillAnyGradInplaceInferer); + +REGISTER_OP_CPU_KERNEL( + fill_any, ops::FillAnyKernel, + ops::FillAnyKernel, + ops::FillAnyKernel, + ops::FillAnyKernel, + ops::FillAnyKernel, + ops::FillAnyKernel); + +REGISTER_OP_CPU_KERNEL( + fill_any_grad, + ops::FillAnyGradKernel, + ops::FillAnyGradKernel, + ops::FillAnyGradKernel, + ops::FillAnyGradKernel, + ops::FillAnyGradKernel, + ops::FillAnyGradKernel); diff --git a/paddle/fluid/operators/fill_any_op.cu.cc b/paddle/fluid/operators/fill_any_op.cu.cc new file mode 100644 index 0000000000000..9a587219eea60 --- /dev/null +++ b/paddle/fluid/operators/fill_any_op.cu.cc @@ -0,0 +1,35 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fill_any_op.h" +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + fill_any, ops::FillAnyKernel, + ops::FillAnyKernel, + ops::FillAnyKernel, + ops::FillAnyKernel, + ops::FillAnyKernel, + ops::FillAnyKernel); + +REGISTER_OP_CUDA_KERNEL( + fill_any_grad, + ops::FillAnyGradKernel, + ops::FillAnyGradKernel, + ops::FillAnyGradKernel, + ops::FillAnyGradKernel, + ops::FillAnyGradKernel, + ops::FillAnyGradKernel); diff --git a/paddle/fluid/operators/fill_any_op.h b/paddle/fluid/operators/fill_any_op.h new file mode 100644 index 0000000000000..f483e05a08fd6 --- /dev/null +++ b/paddle/fluid/operators/fill_any_op.h @@ -0,0 +1,65 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +class FillAnyKernel : public framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext &ctx) const override { + auto *out = ctx.Output("Out"); + auto floatvar = ctx.template Attr("value_float"); + auto intvar = ctx.template Attr("value_int"); + auto isfloat = ((typeid(float) == typeid(T)) || + (typeid(double) == typeid(T) || + typeid(paddle::platform::float16) == typeid(T))); + + T fill_var = static_cast(floatvar); + if (!isfloat) { + fill_var = static_cast(intvar); + } + + PADDLE_ENFORCE_EQ( + std::isnan(static_cast(fill_var)), false, + platform::errors::InvalidArgument("fill value should not be NaN," + " but received NaN")); + + out->mutable_data(ctx.GetPlace()); + auto &dev_ctx = ctx.template device_context(); + math::SetConstant functor; + functor(reinterpret_cast(dev_ctx), out, + static_cast(fill_var)); + } +}; + +template +class FillAnyGradKernel : public framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext &ctx) const override { + auto *dx = ctx.Output(framework::GradVarName("X")); + if (dx) { + dx->mutable_data(ctx.GetPlace()); + auto &dev_ctx = ctx.template device_context(); + math::SetConstant functor; + functor(reinterpret_cast(dev_ctx), dx, T(0)); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 9d60a5b381575..d976ee32f58ce 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -276,6 +276,8 @@ from .fluid.framework import is_compiled_with_cuda # noqa: F401 from .fluid.framework import is_compiled_with_rocm # noqa: F401 from .fluid.framework import disable_signal_handler # noqa: F401 +from .fluid.framework import get_flags # noqa: F401 +from .fluid.framework import set_flags # noqa: F401 from .device import is_compiled_with_xpu # noqa: F401 from .device import is_compiled_with_npu # noqa: F401 from .device import XPUPlace # noqa: F401 @@ -521,5 +523,7 @@ 'standard_normal', 'diagonal', 'broadcast_tensors', - 'einsum' + 'einsum', + 'set_flags', + 'get_flags' ] diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 762bfb01fe14e..04666470ea847 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -6273,6 +6273,7 @@ def device_guard(device=None): def set_flags(flags): """ This function sets the GFlags value in Paddle. + For FLAGS please refer to :ref:`en_guides_flags_flags` Args: flags (dict): A dict contains flags and its value. @@ -6280,8 +6281,8 @@ def set_flags(flags): Examples: .. code-block:: python - import paddle.fluid as fluid - fluid.set_flags({'FLAGS_eager_delete_tensor_gb': 1.0}) + import paddle + paddle.set_flags({'FLAGS_eager_delete_tensor_gb': 1.0}) """ if not isinstance(flags, dict): raise TypeError('flags in set_flags should be a dict') @@ -6296,6 +6297,7 @@ def set_flags(flags): def get_flags(flags): """ This function gets the GFlags value in Paddle. + For FLAGS please refer to :ref:`en_guides_flags_flags` Args: flags(list|tuple|str): A list/tuple of string or a string which is the flag's name. @@ -6306,10 +6308,10 @@ def get_flags(flags): Examples: .. code-block:: python - import paddle.fluid as fluid + import paddle flags = ['FLAGS_eager_delete_tensor_gb', 'FLAGS_check_nan_inf'] - res = fluid.get_flags(flags) + res = paddle.get_flags(flags) print(res) # {'FLAGS_eager_delete_tensor_gb': 0.0, 'FLAGS_check_nan_inf': False} """ diff --git a/python/paddle/fluid/tests/unittests/test_fill_any_op.py b/python/paddle/fluid/tests/unittests/test_fill_any_op.py new file mode 100644 index 0000000000000..2066084753631 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fill_any_op.py @@ -0,0 +1,74 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import paddle +import paddle.fluid.core as core +import unittest +import numpy as np +from op_test import OpTest + + +class TestFillAnyOp(OpTest): + def setUp(self): + self.op_type = "fill_any" + self.dtype = 'float64' + self.value = 0.0 + self.init() + self.inputs = {'X': np.random.random((20, 30)).astype(self.dtype)} + self.attrs = { + 'value_float': float(self.value), + 'value_int': int(self.value) + } + self.outputs = { + 'Out': + self.value * np.ones_like(self.inputs["X"]).astype(self.dtype) + } + + def init(self): + pass + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +class TestFillAnyOpFloat32(TestFillAnyOp): + def init(self): + self.dtype = np.float32 + self.value = 0.0 + + +class TestFillAnyOpFloat16(TestFillAnyOp): + def init(self): + self.dtype = np.float16 + + +class TestFillAnyOpvalue1(TestFillAnyOp): + def init(self): + self.dtype = np.float32 + self.value = 111111555 + + +class TestFillAnyOpvalue2(TestFillAnyOp): + def init(self): + self.dtype = np.float32 + self.value = 11111.1111 + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_tensor_fill_.py b/python/paddle/fluid/tests/unittests/test_tensor_fill_.py new file mode 100644 index 0000000000000..bc4456bb9696a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_tensor_fill_.py @@ -0,0 +1,84 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import unittest +import numpy as np +import six +import paddle + + +class TensorFill_Test(unittest.TestCase): + def setUp(self): + self.shape = [32, 32] + + def test_tensor_fill_true(self): + typelist = ['float32', 'float64', 'int32', 'int64', 'float16'] + places = [fluid.CPUPlace()] + if fluid.core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + places.append(fluid.CUDAPinnedPlace()) + + for idx, p in enumerate(places): + if idx == 0: + paddle.set_device('cpu') + else: + paddle.set_device('gpu') + np_arr = np.reshape( + np.array(six.moves.range(np.prod(self.shape))), self.shape) + for dtype in typelist: + var = 1. + tensor = paddle.to_tensor(np_arr, place=p, dtype=dtype) + newtensor = tensor.clone() + newtensor[...] = var + + tensor.fill_(var) #var type is basic type in typelist + self.assertEqual((tensor.numpy() == newtensor.numpy()).all(), + True) + + def test_tensor_fill_backward(self): + typelist = ['float32'] + places = [fluid.CPUPlace()] + if fluid.core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + places.append(fluid.CUDAPinnedPlace()) + + for idx, p in enumerate(places): + if idx == 0: + paddle.set_device('cpu') + else: + paddle.set_device('gpu') + np_arr = np.reshape( + np.array(six.moves.range(np.prod(self.shape))), self.shape) + for dtype in typelist: + var = int(1) + tensor = paddle.to_tensor(np_arr, place=p, dtype=dtype) + tensor.stop_gradient = False + y = tensor * 2 + y.fill_(var) + loss = y.sum() + loss.backward() + + self.assertEqual((y.grad.numpy() == 0).all().item(), True) + + def test_errors(self): + def test_list(): + x = paddle.to_tensor([2, 3, 4]) + x.fill_([1]) + + self.assertRaises(TypeError, test_list) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_tensor_zero_.py b/python/paddle/fluid/tests/unittests/test_tensor_zero_.py new file mode 100644 index 0000000000000..716607710f11a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_tensor_zero_.py @@ -0,0 +1,47 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import unittest +import numpy as np +import six +import paddle + + +class TensorFill_Test(unittest.TestCase): + def setUp(self): + self.shape = [32, 32] + + def test_tensor_fill_true(self): + typelist = ['float32', 'float64', 'int32', 'int64', 'float16'] + places = [fluid.CPUPlace()] + if fluid.core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + places.append(fluid.CUDAPinnedPlace()) + + for p in places: + np_arr = np.reshape( + np.array(six.moves.range(np.prod(self.shape))), self.shape) + for dtype in typelist: + tensor = paddle.to_tensor(np_arr, place=p, dtype=dtype) + newtensor = tensor.clone() + newtensor[...] = 0 + + tensor.zero_() + self.assertEqual( + (tensor.numpy() == newtensor.numpy()).all().item(), True) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 6d043a7592eda..30477d20e7518 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -29,6 +29,7 @@ from ..fluid.layers import scatter_nd # noqa: F401 from ..fluid.layers import shard_index # noqa: F401 +from ..fluid.layers.nn import _elementwise_op_in_dygraph from ..fluid import layers from ..fluid.dygraph.inplace_utils import inplace_apis_in_dygraph_only import paddle @@ -37,6 +38,74 @@ __all__ = [] +@dygraph_only +def fill_(x, value): + """ + **Notes**: + **This API is ONLY available in Dygraph mode** + + This function fill the Tensor with value inplace. + + Args: + x(Tensor): ``x`` is the Tensor we want to filled data inplace + value(Scale): ``value`` is the value to be filled in x + + Returns: + x(Tensor): Tensor x filled with value inplace + + Examples: + .. code-block:: python + + import paddle + + tensor = paddle.to_tensor([0, 1, 2, 3, 4]) + + tensor.fill_(0) + print(tensor.tolist()) #[0, 0, 0, 0, 0] + + """ + if not isinstance(value, (float, int)): + raise TypeError( + "The type of 'value' must be int or float, but received %s." % + (type(value))) + return core.ops.fill_any_(x, "value_float", + float(value), "value_int", int(value)) + + +setattr(core.VarBase, 'fill_', fill_) + + +@dygraph_only +def zero_(x): + """ + **Notes**: + **This API is ONLY available in Dygraph mode** + + This function fill the Tensor with zero inplace. + + Args: + x(Tensor): ``x`` is the Tensor we want to filled with zero inplace + + Returns: + x(Tensor): Tensor x filled with zero inplace + + Examples: + .. code-block:: python + + import paddle + + tensor = paddle.to_tensor([0, 1, 2, 3, 4]) + + tensor.zero_() + print(tensor.tolist()) #[0, 0, 0, 0, 0] + + """ + return core.ops.fill_any_(x, "value_float", 0., "value_int", int(0)) + + +setattr(core.VarBase, 'zero_', zero_) + + @dygraph_only def fill_diagonal_(x, value, offset=0, wrap=False, name=None): """ diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index 3609cfd183bf3..5fa3a25f4caf0 100644 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -722,5 +722,6 @@ 'test_c_embedding_op', 'test_class_center_sample_op', 'test_fill_diagonal_tensor_op', + 'test_fill_any_op', 'test_margin_cross_entropy_op', ]