
Commit 8d53c2f
Remove memory copy operation
Zjq9409 committed Sep 23, 2021
1 parent 9e4cf0d commit 8d53c2f
Showing 3 changed files with 20 additions and 23 deletions.
40 changes: 19 additions & 21 deletions paddle/fluid/operators/math/eigen_values_vectors.h
@@ -52,18 +52,17 @@ struct MatrixEighFunctor<platform::CPUDeviceContext, ValueType, T> {
                      Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower,
                      bool has_vectors) {
     auto *out_value = eigen_values->mutable_data<ValueType>(ctx.GetPlace());
+
+    auto dito =
+        math::DeviceIndependenceTensorOperations<platform::CPUDeviceContext, T>(
+            ctx);
+    *eigen_vectors = dito.Transpose(input);
     auto *out_vector = eigen_vectors->mutable_data<T>(ctx.GetPlace());
 
     auto dims = input.dims();
     int dim_size = dims.size();
     int64_t batch_size = GetBatchSize(dims);
 
-    auto dito =
-        math::DeviceIndependenceTensorOperations<platform::CPUDeviceContext, T>(
-            ctx);
-    Tensor output_v_var_trans = dito.Transpose(input);
-    TensorCopy(output_v_var_trans, ctx.GetPlace(), eigen_vectors);
-
     int vector_stride = dims[dim_size - 1] * dims[dim_size - 2];
     int values_stride = dims[dim_size - 1];
     char uplo = is_lower ? 'L' : 'U';
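Why this hunk removes a memory copy: the old code transposed into a temporary and then ran `TensorCopy` into the output, while the new code assigns the transpose result directly to `*eigen_vectors`. Paddle tensors of this era keep their storage behind a shared allocation holder, so tensor assignment shares the buffer rather than duplicating it. A toy sketch of that pattern under this assumption (hypothetical `ToyTensor`, not Paddle's API):

// Toy illustration: assignment shares the storage pointer, so the
// transpose temporary's buffer is adopted by the output; the removed
// TensorCopy path allocated a second buffer and copied every element.
#include <memory>
#include <vector>

struct ToyTensor {
  std::shared_ptr<std::vector<float>> holder;  // shared storage allocation
};

// Returns a new tensor owning freshly permuted data.
ToyTensor Transpose(const ToyTensor &in) {
  ToyTensor out{std::make_shared<std::vector<float>>(*in.holder)};
  // ... permute out.holder's elements according to the transpose ...
  return out;
}

// Deep copy, analogous to the removed TensorCopy call.
void DeepCopy(const ToyTensor &src, ToyTensor *dst) {
  dst->holder = std::make_shared<std::vector<float>>(*src.holder);
}

int main() {
  ToyTensor input{std::make_shared<std::vector<float>>(6, 1.0f)};
  ToyTensor eigen_vectors;

  // Old pattern: materialize a temporary, then copy it into the output.
  ToyTensor tmp = Transpose(input);
  DeepCopy(tmp, &eigen_vectors);  // extra allocation plus element-wise copy

  // New pattern: assign the transpose result directly; the output adopts
  // the temporary's buffer via shared_ptr assignment, with no data copy.
  eigen_vectors = Transpose(input);
  return 0;
}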
@@ -74,27 +73,27 @@ struct MatrixEighFunctor<platform::CPUDeviceContext, ValueType, T> {
     int lwork = -1;
     int lrwork = -1;
     int liwork = -1;
-    int iwork_buffer = -1;
-    T lwork_buffer = static_cast<T>(-1);
-    ValueType rwork_buffer = static_cast<ValueType>(-1);
+    int iwork_opt = -1;
+    T lwork_opt = static_cast<T>(-1);
+    ValueType rwork_opt = static_cast<ValueType>(-1);
 
     Tensor info_tensor;
     auto *infos_data = info_tensor.mutable_data<int>(
         framework::make_ddim({batch_size}), ctx.GetPlace());
 
     math::lapackEvd<T, ValueType>(jobz, uplo, n, out_vector, lda, out_value,
-                                  &lwork_buffer, lwork, &rwork_buffer, lrwork,
-                                  &iwork_buffer, liwork, infos_data);
+                                  &lwork_opt, lwork, &rwork_opt, lrwork,
+                                  &iwork_opt, liwork, infos_data);
 
-    lwork = std::max<int>(1, static_cast<int>(lwork_buffer));
-    liwork = std::max<int>(1, iwork_buffer);
+    lwork = std::max<int>(1, static_cast<int>(lwork_opt));
+    liwork = std::max<int>(1, iwork_opt);
 
     Tensor rwork_tensor;
     ValueType *rwork_data = nullptr;
 
     // complex type
     if (framework::IsComplexType(eigen_vectors->type())) {
-      lrwork = std::max<int>(1, static_cast<int>(rwork_buffer));
+      lrwork = std::max<int>(1, static_cast<int>(rwork_opt));
       rwork_data = rwork_tensor.mutable_data<ValueType>(
           framework::make_ddim({lrwork}), ctx.GetPlace());
     }
@@ -136,6 +135,12 @@ struct MatrixEighFunctor<platform::CUDADeviceContext, ValueType, T> {
                      Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower,
                      bool has_vectors) {
     auto *out_value = eigen_values->mutable_data<ValueType>(ctx.GetPlace());
+
+    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    auto dito =
+        math::DeviceIndependenceTensorOperations<platform::CUDADeviceContext,
+                                                 T>(ctx);
+    *eigen_vectors = dito.Transpose(input);
     auto *out_vector = eigen_vectors->mutable_data<T>(ctx.GetPlace());
 
     auto &dims = input.dims();
Expand All @@ -152,13 +157,6 @@ struct MatrixEighFunctor<platform::CUDADeviceContext, ValueType, T> {
     auto vector_stride = dims[dim_size - 1] * dims[dim_size - 2];
     auto values_stride = dims[dim_size - 1];
 
-    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    auto dito =
-        math::DeviceIndependenceTensorOperations<platform::CUDADeviceContext,
-                                                 T>(ctx);
-    Tensor output_v_var_trans = dito.Transpose(input);
-    TensorCopy(output_v_var_trans, ctx.GetPlace(), eigen_vectors);
-
     int lwork = 0;
     auto info = memory::Alloc(dev_ctx, sizeof(int) * batch_size);
     auto *info_ptr = reinterpret_cast<int *>(info->ptr());
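A note on the `lwork = -1` / `*_opt` variables in the CPU path above: this is LAPACK's standard two-pass workspace query, where the first call only reports optimal buffer sizes and the second performs the decomposition. A minimal standalone sketch of that idiom against a system LAPACK (assuming the standard Fortran `dsyevd_` symbol; illustrative, not Paddle code):

// Two-pass workspace query for the divide-and-conquer symmetric
// eigensolver: query with lwork = liwork = -1, allocate, then solve.
#include <algorithm>
#include <vector>

extern "C" void dsyevd_(char *jobz, char *uplo, int *n, double *a, int *lda,
                        double *w, double *work, int *lwork, int *iwork,
                        int *liwork, int *info);

int main() {
  int n = 3, lda = 3, info = 0;
  // Symmetric 3x3 matrix, column-major as LAPACK expects.
  std::vector<double> a = {2, 1, 0, 1, 2, 1, 0, 1, 2};
  std::vector<double> w(n);  // eigenvalues come back in ascending order
  char jobz = 'V', uplo = 'L';

  // Query pass: lwork = liwork = -1 makes dsyevd report optimal workspace
  // sizes into the "work" arguments instead of computing anything.
  int lwork = -1, liwork = -1, iwork_opt = -1;
  double lwork_opt = -1;
  dsyevd_(&jobz, &uplo, &n, a.data(), &lda, w.data(), &lwork_opt, &lwork,
          &iwork_opt, &liwork, &info);

  // Allocate the recommended workspaces, clamped to at least 1,
  // mirroring the std::max<int>(1, ...) calls in the diff.
  lwork = std::max(1, static_cast<int>(lwork_opt));
  liwork = std::max(1, iwork_opt);
  std::vector<double> work(lwork);
  std::vector<int> iwork(liwork);

  // Real pass: eigenvalues land in w, eigenvectors overwrite a.
  dsyevd_(&jobz, &uplo, &n, a.data(), &lda, w.data(), work.data(), &lwork,
          iwork.data(), &liwork, &info);
  return info;  // 0 on success
}

Link with -llapack. The complex-typed branch in the diff sizes its additional rwork buffer through the same query convention.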
1 change: 0 additions & 1 deletion python/paddle/__init__.py
@@ -104,7 +104,6 @@
 from .tensor.linalg import multi_dot  # noqa: F401
 from .tensor.linalg import matrix_power  # noqa: F401
 from .tensor.linalg import svd  # noqa: F401
-from .tensor.linalg import eigh  # noqa: F401
 from .tensor.linalg import pinv  # noqa: F401
 from .tensor.logic import equal  # noqa: F401
 from .tensor.logic import greater_equal  # noqa: F401
2 changes: 1 addition & 1 deletion python/paddle/tensor/linalg.py
@@ -1656,7 +1656,7 @@ def eigh(x, UPLO='L', name=None):
             x_data = np.array([[1, -2j], [2j, 5]])
             x = paddle.to_tensor(x_data)
-            out_value, out_vector = paddle.eigh(x, UPLO='L')
+            out_value, out_vector = paddle.linalg.eigh(x, UPLO='L')
             print(out_value)
             #[0.17157288, 5.82842712]
             print(out_vector)
