Commit
add paddle.Tensor api fill_(inplace), zero_(inplace)
add fill_ backward
zhiboniu committed Sep 13, 2021
1 parent a4b67f7 commit d945633
Showing 10 changed files with 494 additions and 5 deletions.
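Before the diff, a minimal usage sketch of the two in-place methods named in the commit title (hedged: the Python-side binding is among the files not shown below, so the exact call form is assumed from the title):

import paddle

x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])
x.fill_(5.0)   # in-place: every element becomes 5.0
print(x.numpy())

y = paddle.ones([2, 3])
y.zero_()      # in-place zeroing, equivalent to y.fill_(0.0)
print(y.numpy())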
108 changes: 108 additions & 0 deletions paddle/fluid/operators/fill_any_op.cc
@@ -0,0 +1,108 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fill_any_op.h"

namespace paddle {
namespace operators {

class FillAnyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) The input tensor.");
AddOutput("Out", "Tensor, the tensor filled with input value ");
AddAttr<float>("value_float", "The float var to fill in Tensor")
.SetDefault(0);
AddAttr<int>("value_int", "The int var to fill in Tensor").SetDefault(0);
AddComment(R"DOC(Fill operator with backward;
Fill an tensor with `value`.
)DOC");
}
};

class FillAnyOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *context) const override {
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "FillAny");
OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "FillAny");
auto x_dims = context->GetInputDim("X");
context->SetOutputDim("Out", x_dims);
}
};

class FillAnyGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "FillAnyGrad");
auto x_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
};

template <typename T>
class FillAnyGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType(this->ForwardOpType() + "_grad");
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
}
};

DECLARE_INPLACE_OP_INFERER(FillAnyOpInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(FillAnyGradInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;

REGISTER_OPERATOR(fill_any, ops::FillAnyOp, ops::FillAnyOpMaker,
ops::FillAnyGradOpMaker<paddle::framework::OpDesc>,
ops::FillAnyGradOpMaker<paddle::imperative::OpBase>,
ops::FillAnyOpInplaceInferer);

REGISTER_OPERATOR(fill_any_grad, ops::FillAnyGradOp,
ops::FillAnyGradInplaceInferer);

REGISTER_OP_CPU_KERNEL(
fill_any, ops::FillAnyKernel<paddle::platform::CPUDeviceContext, float>,
ops::FillAnyKernel<paddle::platform::CPUDeviceContext, double>,
ops::FillAnyKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FillAnyKernel<paddle::platform::CPUDeviceContext, int>,
ops::FillAnyKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::FillAnyKernel<paddle::platform::CPUDeviceContext, bool>);

REGISTER_OP_CPU_KERNEL(
fill_any_grad,
ops::FillAnyGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::FillAnyGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::FillAnyGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FillAnyGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::FillAnyGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::FillAnyGradKernel<paddle::platform::CPUDeviceContext, bool>);
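The registrations above cover float, double, int, int64, float16, and bool. A hedged Python smoke test over the CPU-friendly dtypes (float16 is omitted since it generally needs a GPU place):

import paddle

# Iterate over the CPU-registered dtypes; the kernel routes the fill
# value through value_float or value_int depending on the dtype.
for dtype in ['float32', 'float64', 'int32', 'int64', 'bool']:
    t = paddle.zeros([2, 2], dtype=dtype)
    t.fill_(1)
    print(dtype, t.numpy()[0, 0])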
35 changes: 35 additions & 0 deletions paddle/fluid/operators/fill_any_op.cu.cc
@@ -0,0 +1,35 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fill_any_op.h"
namespace ops = paddle::operators;

REGISTER_OP_CUDA_KERNEL(
fill_any, ops::FillAnyKernel<paddle::platform::CUDADeviceContext, float>,
ops::FillAnyKernel<paddle::platform::CUDADeviceContext, double>,
ops::FillAnyKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FillAnyKernel<paddle::platform::CUDADeviceContext, int>,
ops::FillAnyKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FillAnyKernel<paddle::platform::CUDADeviceContext, bool>);

REGISTER_OP_CUDA_KERNEL(
fill_any_grad,
ops::FillAnyGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::FillAnyGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::FillAnyGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FillAnyGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::FillAnyGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FillAnyGradKernel<paddle::platform::CUDADeviceContext, bool>);
65 changes: 65 additions & 0 deletions paddle/fluid/operators/fill_any_op.h
@@ -0,0 +1,65 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class FillAnyKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto *out = ctx.Output<framework::Tensor>("Out");
auto floatvar = ctx.template Attr<float>("value_float");
auto intvar = ctx.template Attr<int>("value_int");
auto isfloat = (typeid(T) == typeid(float)) ||
(typeid(T) == typeid(double)) ||
(typeid(T) == typeid(paddle::platform::float16));

T fill_var = static_cast<T>(floatvar);
if (!isfloat) {
fill_var = static_cast<T>(intvar);
}

PADDLE_ENFORCE_EQ(
std::isnan(static_cast<double>(fill_var)), false,
platform::errors::InvalidArgument("fill value should not be NaN,"
" but received NaN"));

out->mutable_data<T>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> functor;
functor(reinterpret_cast<const DeviceContext &>(dev_ctx), out,
static_cast<T>(fill_var));
}
};

template <typename DeviceContext, typename T>
class FillAnyGradKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto *dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> functor;
functor(reinterpret_cast<const DeviceContext &>(dev_ctx), dx, T(0));
}
}
};

} // namespace operators
} // namespace paddle
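The one non-obvious step in FillAnyKernel is attribute selection: floating-point instantiations read value_float, everything else reads value_int, and the result is cast to T. A small NumPy analogue of that selection (pick_fill_value is a hypothetical helper for illustration, not part of this patch):

import numpy as np

def pick_fill_value(dtype, value_float, value_int):
    # Mirrors FillAnyKernel: floating dtypes take the float attribute,
    # integral/bool dtypes take the int attribute, then cast to T.
    if np.issubdtype(np.dtype(dtype), np.floating):
        return np.dtype(dtype).type(value_float)
    return np.dtype(dtype).type(value_int)

print(pick_fill_value(np.float32, 3.5, 0))   # 3.5
print(pick_fill_value(np.int64, 3.5, 7))     # 7
print(pick_fill_value(np.bool_, 3.5, 1))     # True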
6 changes: 5 additions & 1 deletion python/paddle/__init__.py
@@ -276,6 +276,8 @@
from .fluid.framework import is_compiled_with_cuda # noqa: F401
from .fluid.framework import is_compiled_with_rocm # noqa: F401
from .fluid.framework import disable_signal_handler # noqa: F401
from .fluid.framework import get_flags # noqa: F401
from .fluid.framework import set_flags # noqa: F401
from .device import is_compiled_with_xpu # noqa: F401
from .device import is_compiled_with_npu # noqa: F401
from .device import XPUPlace # noqa: F401
@@ -521,5 +523,7 @@
'standard_normal',
'diagonal',
'broadcast_tensors',
'einsum',
'set_flags',
'get_flags'
]
10 changes: 6 additions & 4 deletions python/paddle/fluid/framework.py
@@ -6273,15 +6273,16 @@ def device_guard(device=None):
def set_flags(flags):
"""
This function sets the GFlags values in Paddle.
For available FLAGS, please refer to :ref:`en_guides_flags_flags`.
Args:
flags (dict): A dict that contains flags and their values.
Examples:
.. code-block:: python
import paddle
paddle.set_flags({'FLAGS_eager_delete_tensor_gb': 1.0})
"""
if not isinstance(flags, dict):
raise TypeError('flags in set_flags should be a dict')
@@ -6296,6 +6297,7 @@ def set_flags(flags):
def get_flags(flags):
"""
This function gets the GFlags values in Paddle.
For available FLAGS, please refer to :ref:`en_guides_flags_flags`.
Args:
flags (list|tuple|str): A list/tuple of strings, or a single string, giving the flag name(s).
@@ -6306,10 +6308,10 @@ def get_flags(flags):
Examples:
.. code-block:: python
import paddle
flags = ['FLAGS_eager_delete_tensor_gb', 'FLAGS_check_nan_inf']
res = paddle.get_flags(flags)
print(res)
# {'FLAGS_eager_delete_tensor_gb': 0.0, 'FLAGS_check_nan_inf': False}
"""
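Taken together, the two flag APIs round-trip; a quick sketch using the flag from the docstring examples:

import paddle

paddle.set_flags({'FLAGS_eager_delete_tensor_gb': 1.0})
res = paddle.get_flags(['FLAGS_eager_delete_tensor_gb'])
print(res)  # expected: {'FLAGS_eager_delete_tensor_gb': 1.0}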
74 changes: 74 additions & 0 deletions python/paddle/fluid/tests/unittests/test_fill_any_op.py
@@ -0,0 +1,74 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import paddle
import paddle.fluid.core as core
import unittest
import numpy as np
from op_test import OpTest


class TestFillAnyOp(OpTest):
def setUp(self):
self.op_type = "fill_any"
self.dtype = 'float64'
self.value = 0.0
self.init()
self.inputs = {'X': np.random.random((20, 30)).astype(self.dtype)}
self.attrs = {
'value_float': float(self.value),
'value_int': int(self.value)
}
self.outputs = {
'Out':
self.value * np.ones_like(self.inputs["X"]).astype(self.dtype)
}

def init(self):
pass

def test_check_output(self):
self.check_output()

def test_check_grad(self):
self.check_grad(["X"], "Out")


class TestFillAnyOpFloat32(TestFillAnyOp):
def init(self):
self.dtype = np.float32
self.value = 0.0


class TestFillAnyOpFloat16(TestFillAnyOp):
def init(self):
self.dtype = np.float16


class TestFillAnyOpvalue1(TestFillAnyOp):
def init(self):
self.dtype = np.float32
self.value = 111111555


class TestFillAnyOpvalue2(TestFillAnyOp):
def init(self):
self.dtype = np.float32
self.value = 11111.1111


if __name__ == "__main__":
unittest.main()
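The test_check_grad calls above exercise the zero-gradient backward from FillAnyGradKernel. A hedged dygraph sketch of that semantics (assuming in-place fill_ is traced by autograd, as the inplace inferers suggest):

import paddle

x = paddle.ones([2, 2])
x.stop_gradient = False
y = x * 3.0
y.fill_(7.0)            # forward: y becomes all 7s
y.sum().backward()
print(x.grad)           # all zeros: fill_'s backward zeroes the incoming grad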

1 comment on commit d945633

@paddle-bot-old commented:

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉