
Commit

transfer the svd infer into phi infermeta (#44528)
* transfer the svd infer into phi infermeta

* remove the svd.h

* modify svd api

* fix svd error by insert optional
2742195759 committed Jul 26, 2022
1 parent 8d3672f commit 25d3dce
Showing 9 changed files with 136 additions and 81 deletions.
55 changes: 8 additions & 47 deletions paddle/fluid/operators/svd_op.cc
@@ -17,64 +17,20 @@
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/infermeta/unary.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

namespace paddle {
namespace operators {

using DDim = framework::DDim;
static DDim UDDim(const DDim& x_dim, int k) {
// get x_dim and return the ddim of U
auto x_vec = vectorize(x_dim);
x_vec[x_vec.size() - 1] = k;
return phi::make_ddim(x_vec);
}
static DDim VHDDim(const DDim& x_dim, int k) {
// get x_dim and return the ddim of U
auto x_vec = vectorize(x_dim);
x_vec[x_vec.size() - 2] = k;
return phi::make_ddim(x_vec);
}
static DDim SDDim(const DDim& x_dim, int k) {
// get x_dim and return the ddim of U
auto x_vec = vectorize(x_dim);
x_vec[x_vec.size() - 2] = k;
x_vec.erase(x_vec.end() - 1); // rank - 1
return phi::make_ddim(x_vec);
}

class SvdOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "svd");
OP_INOUT_CHECK(ctx->HasOutput("U"), "Output", "U", "svd");
OP_INOUT_CHECK(ctx->HasOutput("VH"), "Output", "VH", "svd");
OP_INOUT_CHECK(ctx->HasOutput("S"), "Output", "S", "svd");

auto in_dims = ctx->GetInputDim("X");
int x_rank = in_dims.size();
PADDLE_ENFORCE_GE(in_dims.size(),
2,
platform::errors::InvalidArgument(
"the rank of input must greater than 2"));
int m = in_dims[x_rank - 2];
int n = in_dims[x_rank - 1];
int k = std::min(m, n);
const bool full_uv = ctx->Attrs().Get<bool>("full_matrices");
ctx->SetOutputDim("U", !full_uv ? UDDim(in_dims, k) : UDDim(in_dims, m));
ctx->SetOutputDim("VH", !full_uv ? VHDDim(in_dims, k) : VHDDim(in_dims, n));
ctx->SetOutputDim("S", SDDim(in_dims, k));

ctx->ShareLoD("X", /*->*/ "U");
ctx->ShareLoD("X", /*->*/ "VH");
ctx->ShareLoD("X", /*->*/ "S");
}
};

class SvdOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -159,10 +115,15 @@ class SvdGradMaker : public framework::SingleGradOpMaker<T> {

namespace ops = paddle::operators;

DECLARE_INFER_SHAPE_FUNCTOR(svd,
SvdInferShapeFunctor,
PD_INFER_META(phi::SvdInferMeta));

REGISTER_OPERATOR(svd,
ops::SvdOp,
ops::SvdOpMaker,
ops::SvdGradMaker<paddle::framework::OpDesc>,
ops::SvdGradMaker<paddle::imperative::OpBase>);
ops::SvdGradMaker<paddle::imperative::OpBase>,
SvdInferShapeFunctor);

REGISTER_OPERATOR(svd_grad, ops::SvdGradOp);
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/legacy_api.yaml
@@ -2140,6 +2140,15 @@
data_type : x
backward : sum_grad

- api : svd
args : (Tensor x, bool full_metrices)
output : Tensor(u), Tensor(s), Tensor(vh)
infer_meta :
func : SvdInferMeta
kernel :
func : svd
backward : svd_grad

# The python API paddle.nn.functional.swish has no `bete` argument, it may be removed later
- api : swish
args : (Tensor x, float beta=1.0)
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
@@ -2133,6 +2133,17 @@
output : Tensor(grad_grad_x_grad)
invoke : sum_grad(grad_grad_x, grad_grad_out_grad, dims, keep_dim, reduce_all, grad_grad_x_grad)

- backward_api : svd_grad
forward : svd (Tensor x, bool full) -> Tensor(u), Tensor(s), Tensor(vh)
args : (Tensor x, Tensor u, Tensor vh, Tensor s, Tensor u_grad, Tensor vh_grad, Tensor s_grad, bool full)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : svd_grad
optional : u_grad, vh_grad, s_grad

- backward_api : swish_grad
forward : swish (Tensor x, float beta=1.0) -> Tensor(out)
args : (Tensor x, Tensor out_grad, float bete=1.0)
47 changes: 47 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -2715,6 +2715,53 @@ void SumRawInferMeta(const MetaTensor& x,
out->set_layout(x.layout());
}

void SvdInferMeta(const MetaTensor& x,
bool full_matrices,
MetaTensor* u,
MetaTensor* s,
MetaTensor* vh) {
auto UDDim = [](const DDim& x_dim, int k) {
// get x_dim and return the ddim of U
auto x_vec = vectorize(x_dim);
x_vec[x_vec.size() - 1] = k;
return phi::make_ddim(x_vec);
};

auto VHDDim = [](const DDim& x_dim, int k) {
// get x_dim and return the ddim of VH
auto x_vec = vectorize(x_dim);
x_vec[x_vec.size() - 2] = k;
return phi::make_ddim(x_vec);
};

auto SDDim = [](const DDim& x_dim, int k) {
// get x_dim and return the ddim of S
auto x_vec = vectorize(x_dim);
x_vec[x_vec.size() - 2] = k;
x_vec.erase(x_vec.end() - 1); // rank - 1
return phi::make_ddim(x_vec);
};

auto in_dims = x.dims();
int x_rank = in_dims.size();
PADDLE_ENFORCE_GE(
in_dims.size(),
2,
phi::errors::InvalidArgument("the rank of input must be at least 2"));
int m = in_dims[x_rank - 2];
int n = in_dims[x_rank - 1];
int k = std::min(m, n);
u->set_dims(!full_matrices ? UDDim(in_dims, k) : UDDim(in_dims, m));
vh->set_dims(!full_matrices ? VHDDim(in_dims, k) : VHDDim(in_dims, n));
s->set_dims(SDDim(in_dims, k));
u->share_lod(x);
vh->share_lod(x);
s->share_lod(x);
u->set_dtype(x.dtype());
vh->set_dtype(x.dtype());
s->set_dtype(x.dtype());
}

void TemporalShiftInferMeta(const MetaTensor& x,
int seg_num,
float shift_ratio,
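As a sanity check on the shapes SvdInferMeta assigns, the same rule can be reproduced with NumPy (an illustrative sketch only; the array shapes below are chosen for this example and are not part of the patch):

```python
import numpy as np

# Batched input with m = 6 rows, n = 4 columns, so k = min(m, n) = 4.
x = np.random.rand(2, 6, 4)

for full in (False, True):
    u, s, vh = np.linalg.svd(x, full_matrices=full)
    # full_matrices=False: u (2, 6, 4), s (2, 4), vh (2, 4, 4)
    # full_matrices=True:  u (2, 6, 6), s (2, 4), vh (2, 4, 4)
    # These agree with UDDim / VHDDim / SDDim above: U keeps x's dims with the
    # last one set to k (or m), VH sets the second-to-last to k (or n), and S
    # sets the second-to-last to k and drops the last dim.
    print(full, u.shape, s.shape, vh.shape)
```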
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -387,6 +387,12 @@ void SumRawInferMeta(const MetaTensor& x,
DataType dtype,
MetaTensor* out);

void SvdInferMeta(const MetaTensor& x,
bool full_matrices,
MetaTensor* u,
MetaTensor* s,
MetaTensor* vh);

void TemporalShiftInferMeta(const MetaTensor& x,
int seg_num,
float shift_ratio,
55 changes: 34 additions & 21 deletions paddle/phi/kernels/impl/svd_grad_kernel_impl.h
@@ -71,9 +71,9 @@ void SvdGradKernel(const Context& dev_ctx,
const DenseTensor& u,
const DenseTensor& vh,
const DenseTensor& s,
const DenseTensor& u_grad,
const DenseTensor& vh_grad,
const DenseTensor& s_grad,
const paddle::optional<DenseTensor>& u_grad,
const paddle::optional<DenseTensor>& vh_grad,
const paddle::optional<DenseTensor>& s_grad,
bool full_matrices,
DenseTensor* x_grad) {
const auto& dX = *x_grad;
@@ -87,15 +87,33 @@
dev_ctx, u, {u.dims().size() - 1}, {0}, {k}, {1}, {});
VH = SliceKernel<T, Context>(
dev_ctx, vh, {vh.dims().size() - 2}, {0}, {k}, {1}, {});
dU = SliceKernel<T, Context>(
dev_ctx, u_grad, {u_grad.dims().size() - 1}, {0}, {k}, {1}, {});
dVH = SliceKernel<T, Context>(
dev_ctx, vh_grad, {vh.dims().size() - 2}, {0}, {k}, {1}, {});
if (u_grad.get_ptr() != nullptr) {
dU = SliceKernel<T, Context>(dev_ctx,
*(u_grad.get_ptr()),
{u.dims().size() - 1},
{0},
{k},
{1},
{});
}
if (vh_grad.get_ptr() != nullptr) {
dVH = SliceKernel<T, Context>(dev_ctx,
*(vh_grad.get_ptr()),
{vh.dims().size() - 2},
{0},
{k},
{1},
{});
}
} else {
U = u;
VH = vh;
dU = u_grad;
dVH = vh_grad;
if (u_grad.get_ptr() != nullptr) {
dU = *(u_grad.get_ptr());
}
if (vh_grad.get_ptr() != nullptr) {
dVH = *(vh_grad.get_ptr());
}
}
auto s_inverse = Pow<T, Context>(dev_ctx, s, -1);
auto s_square = Pow<T, Context>(dev_ctx, s, 2);
@@ -106,19 +124,17 @@
F,
Diag<T, Context>(dev_ctx, Infinits<T, Context>(dev_ctx, {k}), 0, 0));
F = Pow<T, Context>(dev_ctx, F, -1);
DenseTensor sigma_term;
DenseTensor u_term;
DenseTensor v_term;
DenseTensor sigma_term = Fill<T, Context>(dev_ctx, {1}, 0.0);
DenseTensor u_term = Fill<T, Context>(dev_ctx, {1}, 0.0);
DenseTensor v_term = Fill<T, Context>(dev_ctx, {1}, 0.0);

// if (ctx.HasInput(framework::GradVarName("S")))
{
const DenseTensor& gS = s_grad;
if (s_grad.get_ptr() != nullptr) {
const DenseTensor& gS = *(s_grad.get_ptr());
sigma_term = Multiply<T, Context>(dev_ctx, Unsqueeze(gS, -2), U);
sigma_term = Matmul<T, Context>(dev_ctx, sigma_term, VH);
}

// if (ctx.HasInput(framework::GradVarName("U"))) {
{
if (u_grad.get_ptr() != nullptr) {
auto UTG = Matmul<T, Context>(dev_ctx, U, dU, true, false);
auto GTU = Matmul<T, Context>(dev_ctx, dU, U, true, false);
u_term = Multiply<T, Context>(
@@ -141,10 +157,7 @@
}
u_term = Matmul<T, Context>(dev_ctx, u_term, VH);
}
// }

// if (ctx.HasInput(framework::GradVarName("VH"))) {
{
if (vh_grad.get_ptr() != nullptr) {
auto UTG = Matmul<T, Context>(dev_ctx, VH, dVH, false, true);
auto GTU = Matmul<T, Context>(dev_ctx, dVH, VH, false, true);
v_term = Multiply<T, Context>(
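The switch to paddle::optional reflects that a caller may backpropagate through only one of the three outputs, in which case the other output grads simply never arrive. A minimal sketch of such a call (assuming a Paddle build that includes this patch; shapes chosen for illustration):

```python
import paddle

x = paddle.rand([6, 4], dtype='float64')
x.stop_gradient = False

u, s, vh = paddle.linalg.svd(x, full_matrices=False)

# Backpropagate through S only: U and VH receive no upstream gradient, so
# SvdGradKernel sees u_grad and vh_grad as empty optionals and u_term / v_term
# keep their zero-filled defaults.
s.sum().backward()
print(x.grad.shape)  # [6, 4]
```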
14 changes: 7 additions & 7 deletions paddle/phi/kernels/svd_grad_kernel.h
@@ -20,13 +20,13 @@ namespace phi {

template <typename T, typename Context>
void SvdGradKernel(const Context& dev_ctx,
const DenseTensor& X,
const DenseTensor& U,
const DenseTensor& VH,
const DenseTensor& S,
const DenseTensor& U_grad,
const DenseTensor& VH_grad,
const DenseTensor& S_grad,
const DenseTensor& x,
const DenseTensor& u,
const DenseTensor& vh,
const DenseTensor& s,
const paddle::optional<DenseTensor>& u_grad,
const paddle::optional<DenseTensor>& vh_grad,
const paddle::optional<DenseTensor>& s_grad,
bool full_matrices,
DenseTensor* X_grad);
} // namespace phi
15 changes: 11 additions & 4 deletions python/paddle/fluid/tests/unittests/test_svd_op.py
@@ -29,6 +29,7 @@ class TestSvdOp(OpTest):

def setUp(self):
paddle.enable_static()
self.python_api = paddle.linalg.svd
self.generate_input()
self.generate_output()
self.op_type = "svd"
@@ -55,7 +56,7 @@ def generate_output(self):
self._output_data = np.linalg.svd(self._input_data)

def test_check_output(self):
self.check_output(no_check_set=['U', 'VH'])
self.check_output(no_check_set=['U', 'VH'], check_eager=True)

def test_svd_forward(self):
""" u matmul diag(s) matmul vt must become X
@@ -75,13 +76,19 @@ def test_svd_forward(self):
paddle.enable_static()

def check_S_grad(self):
self.check_grad(['X'], ['S'], numeric_grad_delta=0.001)
self.check_grad(['X'], ['S'],
numeric_grad_delta=0.001,
check_eager=True)

def check_U_grad(self):
self.check_grad(['X'], ['U'], numeric_grad_delta=0.001)
self.check_grad(['X'], ['U'],
numeric_grad_delta=0.001,
check_eager=True)

def check_V_grad(self):
self.check_grad(['X'], ['VH'], numeric_grad_delta=0.001)
self.check_grad(['X'], ['VH'],
numeric_grad_delta=0.001,
check_eager=True)

def test_check_grad(self):
"""
5 changes: 3 additions & 2 deletions python/paddle/tensor/linalg.py
@@ -1857,8 +1857,9 @@ def svd(x, full_matrices=False, name=None):
# U * UH == I
# V * VH == I
"""

if paddle.in_dynamic_mode():
if in_dygraph_mode():
return _C_ops.final_state_svd(x, full_matrices)
if _in_legacy_dygraph():
return _C_ops.svd(x, 'full_matrices', full_matrices)
check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'svd')
check_type(full_matrices, 'full_matrices', bool, 'svd')
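For reference, a minimal usage sketch of the dispatch being changed here (assuming a Paddle build with this patch; in eager mode the call now routes through _C_ops.final_state_svd). The reconstruction check mirrors the defining property U * diag(S) * VH == X:

```python
import paddle

x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype='float64')
u, s, vh = paddle.linalg.svd(x, full_matrices=False)
print(u.shape, s.shape, vh.shape)   # [3, 2] [2] [2, 2]

# Reconstruction: U @ diag(S) @ VH should recover x up to float tolerance.
x_rec = u @ paddle.diag(s) @ vh
print(paddle.allclose(x_rec, x))    # True
```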
