Support rsqrt_p (#46369)
* support rsqrt_p

* refine code and ut

* add_prim_rsqrt

* fix ut
JiabinYang committed Sep 26, 2022
1 parent 9a29168 commit 4c438d3
Showing 9 changed files with 204 additions and 7 deletions.
3 changes: 2 additions & 1 deletion paddle/fluid/operators/prim_ops/CMakeLists.txt
@@ -37,7 +37,8 @@ set(PRIM_OP_SRCS
max_p_op.cc
erf_p_op.cc
abs_p_op.cc
cast_p_op.cc)
cast_p_op.cc
rsqrt_p_op.cc)

cc_test(
prim_op_test
82 changes: 82 additions & 0 deletions paddle/fluid/operators/prim_ops/rsqrt_p_op.cc
@@ -0,0 +1,82 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle

namespace paddle {
namespace operators {
class RsqrtPrimOp : public framework::OperatorBase {
 public:
  RsqrtPrimOp(const std::string &type,
              const framework::VariableNameMap &inputs,
              const framework::VariableNameMap &outputs,
              const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {}
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Prim operator rsqrt_p should not be executed directly"));
  }
};

class RsqrtPrimOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor), The input tensor of rsqrt_p op.");
    AddOutput("Y", "(Tensor), The output tensor of rsqrt_p op.");
    AddComment(R"DOC(
Autograd primitive rsqrt_p operator.
)DOC");
  }
};

class RsqrtPrimOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
    framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];

    framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);

    PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape());
  }
};

class RsqrtPrimOpVarTypeInference
    : public framework::StaticGraphVarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    auto x_name = Input(ctx, "X")[0];
    auto y_name = Output(ctx, "Y")[0];
    SetType(ctx, y_name, GetType(ctx, x_name));
    SetDataType(ctx, y_name, GetDataType(ctx, x_name));
  }
};

} // namespace operators
} // namespace paddle

REGISTER_OPERATOR(rsqrt_p,
                  paddle::operators::RsqrtPrimOp,
                  paddle::operators::RsqrtPrimOpMaker,
                  paddle::operators::RsqrtPrimOpShapeInference,
                  paddle::operators::RsqrtPrimOpVarTypeInference);
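
Note that `RunImpl` deliberately throws: primitive ops such as `rsqrt_p` exist only symbolically in a static program and must be lowered back to executable operators before running. A minimal sketch of that workflow, assuming the `paddle.incubate.autograd` API of this release (`enable_prim`/`grad`/`prim2orig`); it is illustrative, not part of this commit:

```python
import paddle
from paddle.incubate.autograd import enable_prim, disable_prim, grad, prim2orig

paddle.enable_static()
enable_prim()

main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name='x', shape=[3, 4], dtype='float32')
    x.stop_gradient = False
    y = paddle.rsqrt(x)
    # Differentiation traces through primitive ops (rsqrt_p, div_p, ...).
    x_grad = grad(y, x)
    # Lower rsqrt_p and friends back to executable ops; skipping this and
    # running `main` would hit the Unimplemented error thrown above.
    prim2orig()

disable_prim()
```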
33 changes: 33 additions & 0 deletions python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py
@@ -241,6 +241,39 @@ def init_data(self):
]


class TestRSqrtPJVPAndTranspose(TestAddPJVPAndTranspose):

    def init_data(self):
        # Set prim op
        self.op_type = 'rsqrt_p'
        X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
        self.prim_input = {
            'X': X,
        }
        self.prim_output = {
            'Y':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.prim_attrs = {}

        # Set JVP
        X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
        self.jvp_args = (X_DOT, )
        self.jvp_out_shape_map = {0: self.prim_output['Y']}

        self.all_ops = [
            # prim op:
            'rsqrt_p',
            # jvp ops (y_dot = x_dot * (y / x) / -2, see rsqrt_jvp):
            'div_p',
            'div_p',
            'mul_p',
            'fill_constant_p',
            # 'sqrt_p',
            # transpose op:
        ]


class TestTanhPJVPAndTranspose(TestAddPJVPAndTranspose):

def init_data(self):
21 changes: 21 additions & 0 deletions python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
@@ -879,5 +879,26 @@ def init_data(self):
self.out_map = {0: self.output['Out']}


class TestRSqrtOrig2Prim(TestElementWiseAddOrig2Prim):

    def init_data(self):
        self.op_type = 'rsqrt'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')

        self.input = {
            'X': X,
        }
        self.output = {
            'Out':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {}

        self.orig2prim_args = (X, )
        self.all_ops = ['rsqrt', 'rsqrt_p']
        # { prim_op_output_index: orig_op_output_var }
        self.out_map = {0: self.output['Out']}


if __name__ == '__main__':
    unittest.main()
20 changes: 20 additions & 0 deletions python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
@@ -690,5 +690,25 @@ def init_data(self):
self.out_map = {self.output['Y']: 0}


class TestRsqrtPrim2Orig(TestAddPPrim2Orig):

    def init_data(self):
        self.op_type = 'rsqrt_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')

        self.input = {
            'X': X,
        }
        self.output = {
            'Y':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {}

        self.prim2orig_args = (X, )
        self.all_ops = ['rsqrt_p', 'rsqrt']
        self.out_map = {self.output['Y']: 0}


if __name__ == '__main__':
    unittest.main()
2 changes: 2 additions & 0 deletions python/paddle/fluid/tests/unittests/autograd/test_primapi.py
@@ -152,6 +152,7 @@ def without_program_guard():
('log', paddle.log, (np.random.rand(3, 4), ), None, 'float32'),
('abs', paddle.abs, (np.random.uniform(-10, 10,
(10, 10)), ), None, 'float32'),
('rsqrt', paddle.rsqrt, (np.random.rand(100, 200), ), None, 'float32'),
))
# paddle.where, paddle.pow, paddle.maximum have no double-grad definition,
# so their forward grad cannot be computed via the double-backward trick
@@ -267,6 +268,7 @@ def test_illegal_param(self):
(np.random.rand(3, 3), np.random.rand(3, 3)),
(np.random.rand(3, 3), ), 'float64'),
('sin', paddle.sin, (np.random.rand(100, 200), ), None, 'float32'),
('rsqrt', paddle.rsqrt, (np.random.rand(100, 200), ), None, 'float32'),
('cos', paddle.cos, (np.random.rand(200, 90), ), None, 'float32'),
('exp', paddle.exp, (np.random.rand(299, 320), ), None, 'float32'),
# In where op, grad of condition computed by paddle.static.gradients is None,
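
The two new `('rsqrt', paddle.rsqrt, ...)` entries register reciprocal square root with the parameterized forward- and reverse-mode gradient tests, which compare autodiff output against numeric differentiation. A standalone finite-difference check of the derivative being tested, d/dx x^(-1/2) = -1/2 * x^(-3/2), in plain NumPy (independent of the test suite):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((100, 200)) + 0.1       # keep x away from zero
eps = 1e-6

analytic = -0.5 * x ** -1.5
numeric = ((x + eps) ** -0.5 - (x - eps) ** -0.5) / (2 * eps)  # central diff
assert np.allclose(analytic, numeric, rtol=1e-4)
```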
22 changes: 17 additions & 5 deletions python/paddle/fluid/tests/unittests/autograd/test_transform.py
@@ -48,15 +48,16 @@ def init_data(self):

A = paddle.tanh(X0)
B = paddle.tanh(X1)
Y = paddle.add(A, B)
C = paddle.rsqrt(B)
Y = paddle.add(A, C)

self.orig_xs = [X0, X1]
self.orig_ys = [
Y,
]

self.orig_ops = ['tanh', 'tanh', 'elementwise_add']
self.orig2prim_ops = ['tanh_p', 'tanh_p', 'add_p']
self.orig_ops = ['tanh', 'tanh', 'elementwise_add', 'rsqrt']
self.orig2prim_ops = ['tanh_p', 'tanh_p', 'add_p', 'rsqrt_p']
self.linearize_ops = self.orig2prim_ops + [
# call fill_const() in linearize() function
'fill_constant_p',
@@ -71,6 +72,10 @@ def init_data(self):
'fill_constant_p',
'mul_p',
'add_p',
'fill_constant_p',
'div_p',
'div_p',
'mul_p',
]
self.transpose_ops = self.orig2prim_ops + [
# call fill_const() in transpose() function
@@ -84,6 +89,10 @@ def init_data(self):
'mul_p',
'sub_p',
'fill_constant_p',
'mul_p',
'div_p',
'div_p',
'fill_constant_p',
# transposed op
'mul_p',
'mul_p'
@@ -92,13 +101,16 @@ def init_data(self):
'tanh', 'tanh', 'add_p', 'fill_constant', 'fill_constant',
'fill_constant', 'elementwise_mul', 'sub_p', 'fill_constant',
'elementwise_mul', 'sub_p', 'fill_constant', 'elementwise_mul',
'elementwise_mul'
'elementwise_mul', 'rsqrt', 'fill_constant', 'elementwise_div',
'elementwise_div', 'elementwise_mul'
]
self.prim2orig_ops = [
'tanh', 'tanh', 'elementwise_add', 'fill_constant', 'fill_constant',
'fill_constant', 'elementwise_mul', 'elementwise_sub',
'fill_constant', 'elementwise_mul', 'elementwise_sub',
'fill_constant', 'elementwise_mul', 'elementwise_mul'
'fill_constant', 'elementwise_mul', 'elementwise_mul', 'rsqrt',
'fill_constant', 'elementwise_div', 'elementwise_div',
'elementwise_mul'
]

def test_run(self):
5 changes: 5 additions & 0 deletions python/paddle/incubate/autograd/primops.py
@@ -394,3 +394,8 @@ def cast(x, dtype, out=None):
outputs={'Y': out},
attrs={'dtype': dtype})
return out


@REGISTER_FN('rsqrt_p', 'X', 'Y')
def rsqrt(x, out=None):
    return _simple_unop(LayerHelper('rsqrt_p', **locals()))
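
Like the other unary primitives registered in this file, `rsqrt` delegates to the shared `_simple_unop` helper. A rough sketch of the pattern that helper presumably follows, inferred from the single-input/single-output signature registered above — the exact body is an assumption, not code from this commit:

```python
def _simple_unop_sketch(helper):
    # helper is the LayerHelper('rsqrt_p', **locals()) built by the caller.
    optype = helper.layer_type                      # e.g. 'rsqrt_p'
    x, out = map(helper.kwargs.get, ('x', 'out'))
    if out is None:
        out = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(type=optype, inputs={'X': x}, outputs={'Y': out}, attrs={})
    return out
```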
23 changes: 22 additions & 1 deletion python/paddle/incubate/autograd/primrules.py
@@ -23,7 +23,7 @@
fill_const, gather, ge, gt, log, matmul, max, mul, ne,
neg, reduce_sum, reshape, scatter_add, select, set_value,
sin, slice_assign, slice_select, split, sqrt, sub, tanh,
transpose)
transpose, rsqrt)
from .primreg import (REGISTER_JVP, REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG,
REGISTER_TRANSPOSE, lookup_fn, lookup_jvp,
lookup_orig2prim, lookup_prim2orig, lookup_transpose,
@@ -252,6 +252,11 @@ def sqrt_orig2prim(op, x):
return sqrt(x)


@REGISTER_ORIG2PRIM('rsqrt')
def rsqrt_orig2prim(op, x):
    return rsqrt(x)


@REGISTER_ORIG2PRIM('matmul_v2')
def matmul_v2_orig2prim(op, x, y):

@@ -456,6 +461,11 @@ def sub_prim2orig(op, x, y):
return paddle.subtract(x, y)


@REGISTER_PRIM2ORIG('rsqrt_p')
def rsqrt_prim2orig(op, x):
    return paddle.rsqrt(x)
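
Together, the two registrations make `rsqrt` ⇄ `rsqrt_p` a direct round trip, which is exactly what `TestRSqrtOrig2Prim` and `TestRsqrtPrim2Orig` above assert. A tiny illustrative lookup using the registry helpers imported at the top of this file — usage assumed from their names and the import path in this diff:

```python
from paddle.incubate.autograd.primreg import lookup_orig2prim, lookup_prim2orig

lower = lookup_orig2prim('rsqrt')      # -> rsqrt_orig2prim
raise_ = lookup_prim2orig('rsqrt_p')   # -> rsqrt_prim2orig
```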


@REGISTER_PRIM2ORIG('mul_p')
def mul_prim2orig(op, x, y):
return paddle.multiply(x, y)
@@ -969,6 +979,17 @@ def cast_jvp(op, x_dot):
return primops.cast(x_dot, y.dtype)


@REGISTER_JVP('rsqrt_p')
def rsqrt_jvp(op, x_dot):
    if x_dot is None:
        return None
    y = op_position_output(op)
    x = op_position_inputs(op)
    c2 = fill_const(value=-2.0, shape=y.shape, dtype=y.dtype)
    # d/dx x**(-1/2) = -1/2 * x**(-3/2) = (y / x) / -2
    y_dot = mul(x_dot, div(div(y, x), c2))
    return y_dot
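
The rule follows from y = x^(-1/2): y_dot = -1/2 * x^(-3/2) * x_dot, and rewriting -1/2 * x^(-3/2) as (y / x) / -2 reuses the forward output instead of recomputing a root — hence the `fill_constant_p`, two `div_p`, and `mul_p` ops expected by `TestRSqrtPJVPAndTranspose`. A standalone NumPy check of that identity (independent of Paddle):

```python
import numpy as np

x = np.random.rand(4, 5) + 0.1            # keep inputs away from zero
x_dot = np.random.rand(4, 5)
y = x ** -0.5                             # forward output of rsqrt

y_dot_rule = x_dot * (y / x) / -2.0       # what rsqrt_jvp builds with prim ops
y_dot_closed = x_dot * -0.5 * x ** -1.5   # textbook derivative
assert np.allclose(y_dot_rule, y_dot_closed)
```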


## Register transpose rules


