transfer the svd infer into phi infermeta #44528

Merged
4 commits merged on Jul 26, 2022
55 changes: 8 additions & 47 deletions paddle/fluid/operators/svd_op.cc
@@ -17,64 +17,20 @@
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/infermeta/unary.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

namespace paddle {
namespace operators {

using DDim = framework::DDim;
static DDim UDDim(const DDim& x_dim, int k) {
// get x_dim and return the ddim of U
auto x_vec = vectorize(x_dim);
x_vec[x_vec.size() - 1] = k;
return phi::make_ddim(x_vec);
}
static DDim VHDDim(const DDim& x_dim, int k) {
// get x_dim and return the ddim of U
auto x_vec = vectorize(x_dim);
x_vec[x_vec.size() - 2] = k;
return phi::make_ddim(x_vec);
}
static DDim SDDim(const DDim& x_dim, int k) {
// get x_dim and return the ddim of U
auto x_vec = vectorize(x_dim);
x_vec[x_vec.size() - 2] = k;
x_vec.erase(x_vec.end() - 1); // rank - 1
return phi::make_ddim(x_vec);
}

class SvdOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "svd");
OP_INOUT_CHECK(ctx->HasOutput("U"), "Output", "U", "svd");
OP_INOUT_CHECK(ctx->HasOutput("VH"), "Output", "VH", "svd");
OP_INOUT_CHECK(ctx->HasOutput("S"), "Output", "S", "svd");

auto in_dims = ctx->GetInputDim("X");
int x_rank = in_dims.size();
PADDLE_ENFORCE_GE(in_dims.size(),
2,
platform::errors::InvalidArgument(
"the rank of input must greater than 2"));
int m = in_dims[x_rank - 2];
int n = in_dims[x_rank - 1];
int k = std::min(m, n);
const bool full_uv = ctx->Attrs().Get<bool>("full_matrices");
ctx->SetOutputDim("U", !full_uv ? UDDim(in_dims, k) : UDDim(in_dims, m));
ctx->SetOutputDim("VH", !full_uv ? VHDDim(in_dims, k) : VHDDim(in_dims, n));
ctx->SetOutputDim("S", SDDim(in_dims, k));

ctx->ShareLoD("X", /*->*/ "U");
ctx->ShareLoD("X", /*->*/ "VH");
ctx->ShareLoD("X", /*->*/ "S");
}
};

class SvdOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -159,10 +115,15 @@ class SvdGradMaker : public framework::SingleGradOpMaker<T> {

namespace ops = paddle::operators;

DECLARE_INFER_SHAPE_FUNCTOR(svd,
SvdInferShapeFunctor,
PD_INFER_META(phi::SvdInferMeta));

REGISTER_OPERATOR(svd,
ops::SvdOp,
ops::SvdOpMaker,
ops::SvdGradMaker<paddle::framework::OpDesc>,
ops::SvdGradMaker<paddle::imperative::OpBase>);
ops::SvdGradMaker<paddle::imperative::OpBase>,
SvdInferShapeFunctor);

REGISTER_OPERATOR(svd_grad, ops::SvdGradOp);
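
Note: the InferShape body deleted above is exactly what phi::SvdInferMeta now encodes (added in paddle/phi/infermeta/unary.cc below), and DECLARE_INFER_SHAPE_FUNCTOR wires it back to the fluid operator. A minimal Python sketch of the shape rule, assuming the last two dimensions of x are (m, n):

def svd_output_shapes(x_shape, full_matrices=False):
    # Mirrors UDDim / VHDDim / SDDim: k = min(m, n).
    m, n = x_shape[-2], x_shape[-1]
    k = min(m, n)
    batch = list(x_shape[:-2])
    u_shape = batch + [m, m if full_matrices else k]   # UDDim: last dim becomes k (or m)
    vh_shape = batch + [n if full_matrices else k, n]  # VHDDim: second-to-last becomes k (or n)
    s_shape = batch + [k]                              # SDDim: second-to-last becomes k, last dim dropped
    return u_shape, s_shape, vh_shape

# Example: a batch of 3 matrices of shape 4 x 5
print(svd_output_shapes([3, 4, 5]))        # ([3, 4, 4], [3, 4], [3, 4, 5])
print(svd_output_shapes([3, 4, 5], True))  # ([3, 4, 4], [3, 4], [3, 5, 5])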
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/legacy_api.yaml
@@ -2101,6 +2101,15 @@
data_type : x
backward : sum_grad

- api : svd
args : (Tensor x, bool full_metrices)
output : Tensor(u), Tensor(s), Tensor(vh)
infer_meta :
func : SvdInferMeta
kernel :
func : svd
backward : svd_grad

# The python API paddle.nn.functional.swish has no `bete` argument, it may be removed later
- api : swish
args : (Tensor x, float beta=1.0)
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
@@ -2092,6 +2092,17 @@
output : Tensor(grad_grad_x_grad)
invoke : sum_grad(grad_grad_x, grad_grad_out_grad, dims, keep_dim, reduce_all, grad_grad_x_grad)

- backward_api : svd_grad
forward : svd (Tensor x, bool full) -> Tensor(u), Tensor(s), Tensor(vh)
args : (Tensor x, Tensor u, Tensor vh, Tensor s, Tensor u_grad, Tensor vh_grad, Tensor s_grad, bool full)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : svd_grad
optional: u_grad, vh_grad, s_grad

- backward_api : swish_grad
forward : swish (Tensor x, float beta=1.0) -> Tensor(out)
args : (Tensor x, Tensor out_grad, float bete=1.0)
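Note: declaring u_grad, vh_grad and s_grad as optional lets svd_grad run when only some of the SVD outputs reach the loss; the kernel guards further below skip the missing terms. A small dygraph sketch of that situation (assumed usage, not part of this PR):

import paddle

x = paddle.rand([4, 3])
x.stop_gradient = False
u, s, vh = paddle.linalg.svd(x)
# Only S feeds the loss, so u_grad and vh_grad reach svd_grad as empty
# optionals and their corresponding terms are skipped in the kernel.
s.sum().backward()
print(x.grad.shape)  # [4, 3]
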
47 changes: 47 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -2674,6 +2674,53 @@ void SumRawInferMeta(const MetaTensor& x,
out->set_layout(x.layout());
}

void SvdInferMeta(const MetaTensor& x,
bool full_matrices,
MetaTensor* u,
MetaTensor* s,
MetaTensor* vh) {
auto UDDim = [](const DDim& x_dim, int k) {
// get x_dim and return the ddim of U
auto x_vec = vectorize(x_dim);
x_vec[x_vec.size() - 1] = k;
return phi::make_ddim(x_vec);
};

auto VHDDim = [](const DDim& x_dim, int k) {
// get x_dim and return the ddim of VH
auto x_vec = vectorize(x_dim);
x_vec[x_vec.size() - 2] = k;
return phi::make_ddim(x_vec);
};

auto SDDim = [](const DDim& x_dim, int k) {
// get x_dim and return the ddim of S
auto x_vec = vectorize(x_dim);
x_vec[x_vec.size() - 2] = k;
x_vec.erase(x_vec.end() - 1); // rank - 1
return phi::make_ddim(x_vec);
};

auto in_dims = x.dims();
int x_rank = in_dims.size();
PADDLE_ENFORCE_GE(
in_dims.size(),
2,
phi::errors::InvalidArgument("the rank of input must be at least 2"));
int m = in_dims[x_rank - 2];
int n = in_dims[x_rank - 1];
int k = std::min(m, n);
u->set_dims(!full_matrices ? UDDim(in_dims, k) : UDDim(in_dims, m));
vh->set_dims(!full_matrices ? VHDDim(in_dims, k) : VHDDim(in_dims, n));
s->set_dims(SDDim(in_dims, k));
u->share_lod(x);
vh->share_lod(x);
s->share_lod(x);
u->set_dtype(x.dtype());
vh->set_dtype(x.dtype());
s->set_dtype(x.dtype());
}

void TemporalShiftInferMeta(const MetaTensor& x,
int seg_num,
float shift_ratio,
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -383,6 +383,12 @@ void SumRawInferMeta(const MetaTensor& x,
DataType dtype,
MetaTensor* out);

void SvdInferMeta(const MetaTensor& x,
bool full_matrices,
MetaTensor* u,
MetaTensor* s,
MetaTensor* vh);

void TemporalShiftInferMeta(const MetaTensor& x,
int seg_num,
float shift_ratio,
55 changes: 34 additions & 21 deletions paddle/phi/kernels/impl/svd_grad_kernel_impl.h
@@ -71,9 +71,9 @@ void SvdGradKernel(const Context& dev_ctx,
const DenseTensor& u,
const DenseTensor& vh,
const DenseTensor& s,
const DenseTensor& u_grad,
const DenseTensor& vh_grad,
const DenseTensor& s_grad,
const paddle::optional<DenseTensor>& u_grad,
const paddle::optional<DenseTensor>& vh_grad,
const paddle::optional<DenseTensor>& s_grad,
bool full_matrices,
DenseTensor* x_grad) {
const auto& dX = *x_grad;
@@ -87,15 +87,33 @@
dev_ctx, u, {u.dims().size() - 1}, {0}, {k}, {1}, {});
VH = SliceKernel<T, Context>(
dev_ctx, vh, {vh.dims().size() - 2}, {0}, {k}, {1}, {});
dU = SliceKernel<T, Context>(
dev_ctx, u_grad, {u_grad.dims().size() - 1}, {0}, {k}, {1}, {});
dVH = SliceKernel<T, Context>(
dev_ctx, vh_grad, {vh.dims().size() - 2}, {0}, {k}, {1}, {});
if (u_grad.get_ptr() != nullptr) {
dU = SliceKernel<T, Context>(dev_ctx,
*(u_grad.get_ptr()),
{u.dims().size() - 1},
{0},
{k},
{1},
{});
}
if (vh_grad.get_ptr() != nullptr) {
dVH = SliceKernel<T, Context>(dev_ctx,
*(vh_grad.get_ptr()),
{vh.dims().size() - 2},
{0},
{k},
{1},
{});
}
} else {
U = u;
VH = vh;
dU = u_grad;
dVH = vh_grad;
if (u_grad.get_ptr() != nullptr) {
dU = *(u_grad.get_ptr());
}
if (vh_grad.get_ptr() != nullptr) {
dVH = *(vh_grad.get_ptr());
}
}
auto s_inverse = Pow<T, Context>(dev_ctx, s, -1);
auto s_square = Pow<T, Context>(dev_ctx, s, 2);
@@ -106,19 +124,17 @@
F,
Diag<T, Context>(dev_ctx, Infinits<T, Context>(dev_ctx, {k}), 0, 0));
F = Pow<T, Context>(dev_ctx, F, -1);
DenseTensor sigma_term;
DenseTensor u_term;
DenseTensor v_term;
DenseTensor sigma_term = Fill<T, Context>(dev_ctx, {1}, 0.0);
DenseTensor u_term = Fill<T, Context>(dev_ctx, {1}, 0.0);
DenseTensor v_term = Fill<T, Context>(dev_ctx, {1}, 0.0);

// if (ctx.HasInput(framework::GradVarName("S")))
{
const DenseTensor& gS = s_grad;
if (s_grad.get_ptr() != nullptr) {
const DenseTensor& gS = *(s_grad.get_ptr());
sigma_term = Multiply<T, Context>(dev_ctx, Unsqueeze(gS, -2), U);
sigma_term = Matmul<T, Context>(dev_ctx, sigma_term, VH);
}

// if (ctx.HasInput(framework::GradVarName("U"))) {
{
if (u_grad.get_ptr() != nullptr) {
auto UTG = Matmul<T, Context>(dev_ctx, U, dU, true, false);
auto GTU = Matmul<T, Context>(dev_ctx, dU, U, true, false);
u_term = Multiply<T, Context>(
@@ -141,10 +157,7 @@
}
u_term = Matmul<T, Context>(dev_ctx, u_term, VH);
}
// }

// if (ctx.HasInput(framework::GradVarName("VH"))) {
{
if (vh_grad.get_ptr() != nullptr) {
auto UTG = Matmul<T, Context>(dev_ctx, VH, dVH, false, true);
auto GTU = Matmul<T, Context>(dev_ctx, dVH, VH, false, true);
v_term = Multiply<T, Context>(
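The s_grad branch above computes Multiply(Unsqueeze(gS, -2), U) followed by a matmul with VH, which is U @ diag(gS) @ VH; the u_grad and vh_grad branches follow the same guarded pattern. A NumPy sketch of just the sigma term, for illustration only:

import numpy as np

def svd_sigma_term(u, vh, s_grad):
    # Unsqueeze(gS, -2) * U broadcasts s_grad across the rows of U,
    # i.e. U @ diag(s_grad); then matmul with VH.
    return (u * s_grad[..., None, :]) @ vh

x = np.random.rand(4, 3)
u, s, vh = np.linalg.svd(x, full_matrices=False)
g_s = np.ones_like(s)
print(np.allclose(svd_sigma_term(u, vh, g_s), u @ np.diag(g_s) @ vh))  # True
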
14 changes: 7 additions & 7 deletions paddle/phi/kernels/svd_grad_kernel.h
@@ -20,13 +20,13 @@ namespace phi {

template <typename T, typename Context>
void SvdGradKernel(const Context& dev_ctx,
const DenseTensor& X,
const DenseTensor& U,
const DenseTensor& VH,
const DenseTensor& S,
const DenseTensor& U_grad,
const DenseTensor& VH_grad,
const DenseTensor& S_grad,
const DenseTensor& x,
const DenseTensor& u,
const DenseTensor& vh,
const DenseTensor& s,
const paddle::optional<DenseTensor>& u_grad,
const paddle::optional<DenseTensor>& vh_grad,
const paddle::optional<DenseTensor>& s_grad,
bool full_matrices,
DenseTensor* X_grad);
} // namespace phi
15 changes: 11 additions & 4 deletions python/paddle/fluid/tests/unittests/test_svd_op.py
@@ -29,6 +29,7 @@ class TestSvdOp(OpTest):

def setUp(self):
paddle.enable_static()
self.python_api = paddle.linalg.svd
self.generate_input()
self.generate_output()
self.op_type = "svd"
@@ -55,7 +56,7 @@ def generate_output(self):
self._output_data = np.linalg.svd(self._input_data)

def test_check_output(self):
self.check_output(no_check_set=['U', 'VH'])
self.check_output(no_check_set=['U', 'VH'], check_eager=True)

def test_svd_forward(self):
""" u matmul diag(s) matmul vt must become X
@@ -75,13 +76,19 @@ def test_svd_forward(self):
paddle.enable_static()

def check_S_grad(self):
self.check_grad(['X'], ['S'], numeric_grad_delta=0.001)
self.check_grad(['X'], ['S'],
numeric_grad_delta=0.001,
check_eager=True)

def check_U_grad(self):
self.check_grad(['X'], ['U'], numeric_grad_delta=0.001)
self.check_grad(['X'], ['U'],
numeric_grad_delta=0.001,
check_eager=True)

def check_V_grad(self):
self.check_grad(['X'], ['VH'], numeric_grad_delta=0.001)
self.check_grad(['X'], ['VH'],
numeric_grad_delta=0.001,
check_eager=True)

def test_check_grad(self):
"""
5 changes: 3 additions & 2 deletions python/paddle/tensor/linalg.py
@@ -1854,8 +1854,9 @@ def svd(x, full_matrices=False, name=None):
# U * UH == I
# V * VH == I
"""

if paddle.in_dynamic_mode():
if in_dygraph_mode():
return _C_ops.final_state_svd(x, full_matrices)
if _in_legacy_dygraph():
return _C_ops.svd(x, 'full_matrices', full_matrices)
check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'svd')
check_type(full_matrices, 'full_matrices', bool, 'svd')
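
With this dispatch, eager (final-state) mode now routes paddle.linalg.svd to _C_ops.final_state_svd while legacy dygraph keeps _C_ops.svd, and static mode is unchanged. A hedged usage sketch that also checks the reconstruction property the unit test relies on (u matmul diag(s) matmul vh recovers x):

import paddle

x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
u, s, vh = paddle.linalg.svd(x, full_matrices=False)
print(u.shape, s.shape, vh.shape)  # [3, 2] [2] [2, 2]
print(paddle.allclose(u @ paddle.diag(s) @ vh, x))  # Tensor(True), up to floating point error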