Skip to content

Commit

Permalink
[PHI]Seperate xshape kernel from normal kernel (PaddlePaddle#44315)
Browse files Browse the repository at this point in the history
* seperate xshape kernel from normal kernel

* fix bugs in infermeta

* fix compile bugs

* fix compile bugs
  • Loading branch information
YuanRisheng authored and Aurelius84 committed Jul 29, 2022
1 parent 6f919fe commit d977bf7
Show file tree
Hide file tree
Showing 21 changed files with 239 additions and 61 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/operators/einsum_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ namespace ops = paddle::operators;

DECLARE_INFER_SHAPE_FUNCTOR(einsum,
EinsumInferShapeFunctor,
PD_INFER_META(phi::EinsumInferMeta));
PD_INFER_META(phi::EinsumRawInferMeta));

REGISTER_OPERATOR(einsum,
ops::EinsumOp,
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/squeeze_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ namespace ops = paddle::operators;

DECLARE_INFER_SHAPE_FUNCTOR(squeeze2,
SqueezeInferShapeFunctor,
PD_INFER_META(phi::SqueezeInferMeta));
PD_INFER_META(phi::SqueezeWithXShapeInferMeta));

REGISTER_OPERATOR(squeeze,
ops::SqueezeOp,
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/unsqueeze_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(UnsqueezeGradOpNoNeedBufferVarInferer, "X");

DECLARE_INFER_SHAPE_FUNCTOR(unsqueeze2,
Unsqueeze2InferShapeFunctor,
PD_INFER_META(phi::UnsqueezeInferMeta));
PD_INFER_META(phi::UnsqueezeWithXShapeInferMeta));

namespace ops = paddle::operators;
REGISTER_OPERATOR(unsqueeze,
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/api/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,8 @@ add_custom_command(
${dygraph_api_header_file}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${dygraph_api_source_file_tmp}
${dygraph_api_source_file}
DEPENDS ${api_yaml_file} ${sparse_api_yaml_file} ${im_api_gen_file}
${api_gen_base} ${api_gen_file}
DEPENDS ${api_yaml_file} ${legacy_api_yaml_file} ${sparse_api_yaml_file}
${im_api_gen_file} ${api_gen_base} ${api_gen_file}
VERBATIM)

# generate wrapped infermeta
Expand Down
12 changes: 6 additions & 6 deletions paddle/phi/api/yaml/legacy_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -582,10 +582,10 @@
args : (Tensor[] x, str equation)
output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()}
infer_meta :
func : EinsumInferMeta
func : EinsumRawInferMeta
param : [x, equation]
kernel :
func : einsum
func : einsum_raw
backward : einsum_grad

- api : elementwise_pow
Expand Down Expand Up @@ -2047,9 +2047,9 @@
args : (Tensor x, int[] axes)
output : Tensor(out), Tensor(xshape)
infer_meta :
func : SqueezeInferMeta
func : SqueezeWithXShapeInferMeta
kernel :
func : squeeze
func : squeeze_with_xshape
view: (x -> out)
intermediate : xshape
backward : squeeze_grad
Expand Down Expand Up @@ -2290,9 +2290,9 @@
args : (Tensor x, IntArray axis)
output : Tensor(out), Tensor(xshape)
infer_meta :
func : UnsqueezeInferMeta
func : UnsqueezeWithXShapeInferMeta
kernel :
func : unsqueeze
func : unsqueeze_with_xshape
view: (x -> out)
intermediate : xshape
backward : unsqueeze_grad
Expand Down
56 changes: 39 additions & 17 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -570,9 +570,7 @@ void EigvalsInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config) {

void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out,
std::vector<MetaTensor*> inner_cache,
std::vector<MetaTensor*> xshape) {
MetaTensor* out) {
// collect the following informations to prepare einsum.
LabelMap labelshape(0);
LabelMap labeltype(LabelType::Reduction);
Expand Down Expand Up @@ -609,6 +607,14 @@ void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape);
out->set_dims(make_ddim(output_dims));
out->set_dtype(inputs[0]->dtype());
}

void EinsumRawInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out,
std::vector<MetaTensor*> inner_cache,
std::vector<MetaTensor*> xshape) {
EinsumInferMeta(inputs, equation, out);
for (size_t i = 0; i < xshape.size(); ++i) {
if (xshape[i] != nullptr) {
xshape[i]->set_dims(inputs[i]->dims());
Expand Down Expand Up @@ -2448,8 +2454,7 @@ void SplitInferMeta(const MetaTensor& x,

void SqueezeInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
MetaTensor* out,
MetaTensor* xshape) {
MetaTensor* out) {
const auto& x_dims = x.dims();
// Check input tensor dims (<6) Eigen limit.
PADDLE_ENFORCE_LE(x_dims.size(),
Expand All @@ -2469,15 +2474,25 @@ void SqueezeInferMeta(const MetaTensor& x,
out->share_lod(x);
}

out->set_dtype(x.dtype());
}

void SqueezeWithXShapeInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
MetaTensor* out,
MetaTensor* xshape) {
SqueezeInferMeta(x, axes, out);
const auto& x_dims = x.dims();
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
xshape->set_dims(phi::make_ddim(xshape_dims));
xshape->share_lod(x);
xshape->set_dtype(x.dtype());
out->set_dtype(x.dtype());
if (xshape) {
xshape->set_dims(phi::make_ddim(xshape_dims));
xshape->share_lod(x);
xshape->set_dtype(x.dtype());
}
}

void StridedSliceRawInferMeta(const MetaTensor& x,
Expand Down Expand Up @@ -3310,7 +3325,6 @@ void UniqueRawInferMeta(const MetaTensor& x,
void UnsqueezeInferMeta(const MetaTensor& x,
const IntArray& axes,
MetaTensor* out,
MetaTensor* xshape,
MetaConfig config) {
const auto& x_dims = x.dims();
// Validity Check: input tensor dims (<6).
Expand Down Expand Up @@ -3339,14 +3353,22 @@ void UnsqueezeInferMeta(const MetaTensor& x,
}
out->set_dtype(x.dtype());
}
if (xshape) {
// set xshape dims.
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
}

void UnsqueezeWithXShapeInferMeta(const MetaTensor& x,
const IntArray& axes,
MetaTensor* out,
MetaTensor* xshape,
MetaConfig config) {
const auto& x_dims = x.dims();
UnsqueezeInferMeta(x, axes, out, config);
// set xshape dims.
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
if (xshape) {
xshape->set_dims(phi::make_ddim(xshape_dims));
xshape->share_lod(x);
xshape->set_dtype(x.dtype());
Expand Down
25 changes: 19 additions & 6 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,13 @@ void EigvalsInferMeta(const MetaTensor& x,

void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out,
std::vector<MetaTensor*> inner_cache,
std::vector<MetaTensor*> xshape);
MetaTensor* out);

void EinsumRawInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out,
std::vector<MetaTensor*> inner_cache,
std::vector<MetaTensor*> xshape);

void ExpandInferMeta(const MetaTensor& x,
const IntArray& shape,
Expand Down Expand Up @@ -341,8 +345,12 @@ void SplitInferMeta(const MetaTensor& x_meta,

void SqueezeInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
MetaTensor* out,
MetaTensor* xshape);
MetaTensor* out);

void SqueezeWithXShapeInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
MetaTensor* out,
MetaTensor* xshape);

void StridedSliceRawInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
Expand Down Expand Up @@ -470,9 +478,14 @@ void UniqueRawInferMeta(const MetaTensor& x,
void UnsqueezeInferMeta(const MetaTensor& x,
const IntArray& axes,
MetaTensor* out,
MetaTensor* xshape,
MetaConfig config = MetaConfig());

void UnsqueezeWithXShapeInferMeta(const MetaTensor& x,
const IntArray& axes,
MetaTensor* out,
MetaTensor* xshape,
MetaConfig config = MetaConfig());

void UnStackInferMeta(const MetaTensor& x,
int axis,
int num,
Expand Down
11 changes: 10 additions & 1 deletion paddle/phi/kernels/cpu/einsum_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,20 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"

PD_REGISTER_KERNEL(einsum,
PD_REGISTER_KERNEL(einsum_raw,
CPU,
ALL_LAYOUT,
phi::EinsumKernelRaw,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(einsum,
CPU,
ALL_LAYOUT,
phi::EinsumKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
15 changes: 15 additions & 0 deletions paddle/phi/kernels/cpu/squeeze_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,18 @@ PD_REGISTER_KERNEL(squeeze,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(squeeze_with_xshape,
CPU,
ALL_LAYOUT,
phi::SqueezeWithXShapeKernel,
float,
double,
phi::dtype::bfloat16,
bool,
int,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
16 changes: 16 additions & 0 deletions paddle/phi/kernels/cpu/unsqueeze_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,19 @@ PD_REGISTER_KERNEL(unsqueeze,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(unsqueeze_with_xshape,
CPU,
ALL_LAYOUT,
phi::UnsqueezeWithXShapeKernel,
float,
double,
phi::dtype::bfloat16,
bool,
int,
int16_t,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
13 changes: 12 additions & 1 deletion paddle/phi/kernels/gpu/einsum_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"

PD_REGISTER_KERNEL(einsum,
PD_REGISTER_KERNEL(einsum_raw,
GPU,
ALL_LAYOUT,
phi::EinsumKernelRaw,
Expand All @@ -28,3 +28,14 @@ PD_REGISTER_KERNEL(einsum,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(einsum,
GPU,
ALL_LAYOUT,
phi::EinsumKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
16 changes: 16 additions & 0 deletions paddle/phi/kernels/gpu/squeeze_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,19 @@ PD_REGISTER_KERNEL(squeeze,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(squeeze_with_xshape,
GPU,
ALL_LAYOUT,
phi::SqueezeWithXShapeKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16,
bool,
int,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
17 changes: 17 additions & 0 deletions paddle/phi/kernels/gpu/unsqueeze_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,20 @@ PD_REGISTER_KERNEL(unsqueeze,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(unsqueeze_with_xshape,
GPU,
ALL_LAYOUT,
phi::UnsqueezeWithXShapeKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
bool,
int,
int16_t,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
6 changes: 4 additions & 2 deletions paddle/phi/kernels/impl/solve_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/expand_as_kernel.h"
#include "paddle/phi/kernels/funcs/matrix_solve.h"
Expand Down Expand Up @@ -77,7 +79,7 @@ static std::vector<int64_t> get_broadcast_batch_portion(
static inline std::vector<int> convert_to_int_vec(std::vector<int64_t> a) {
std::vector<int> ret;
for (size_t i = 0; i < a.size(); i++) {
ret.emplace_back(int(a[i]));
ret.emplace_back(static_cast<int>(a[i]));
}

return ret;
Expand Down Expand Up @@ -167,7 +169,7 @@ static void linalg_solve(const Context& dev_ctx,
out_tmp.Resize(out->dims());
out_tmp = *out;

phi::SqueezeKernel<T, Context>(dev_ctx, out_tmp, {-1}, out, nullptr);
phi::SqueezeKernel<T, Context>(dev_ctx, out_tmp, {-1}, out);
} else {
PADDLE_ENFORCE_EQ(
x_dim[x_dim_size - 1],
Expand Down
Loading

0 comments on commit d977bf7

Please sign in to comment.