Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PHI] Separate xshape kernel from normal kernel #44315

Merged
merged 5 commits into from
Jul 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/fluid/operators/einsum_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ namespace ops = paddle::operators;

DECLARE_INFER_SHAPE_FUNCTOR(einsum,
EinsumInferShapeFunctor,
PD_INFER_META(phi::EinsumInferMeta));
PD_INFER_META(phi::EinsumRawInferMeta));

REGISTER_OPERATOR(einsum,
ops::EinsumOp,
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/squeeze_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ namespace ops = paddle::operators;

DECLARE_INFER_SHAPE_FUNCTOR(squeeze2,
SqueezeInferShapeFunctor,
PD_INFER_META(phi::SqueezeInferMeta));
PD_INFER_META(phi::SqueezeWithXShapeInferMeta));

REGISTER_OPERATOR(squeeze,
ops::SqueezeOp,
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/unsqueeze_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(UnsqueezeGradOpNoNeedBufferVarInferer, "X");

DECLARE_INFER_SHAPE_FUNCTOR(unsqueeze2,
Unsqueeze2InferShapeFunctor,
PD_INFER_META(phi::UnsqueezeInferMeta));
PD_INFER_META(phi::UnsqueezeWithXShapeInferMeta));

namespace ops = paddle::operators;
REGISTER_OPERATOR(unsqueeze,
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/api/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,8 @@ add_custom_command(
${dygraph_api_header_file}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${dygraph_api_source_file_tmp}
${dygraph_api_source_file}
DEPENDS ${api_yaml_file} ${sparse_api_yaml_file} ${im_api_gen_file}
${api_gen_base} ${api_gen_file}
DEPENDS ${api_yaml_file} ${legacy_api_yaml_file} ${sparse_api_yaml_file}
${im_api_gen_file} ${api_gen_base} ${api_gen_file}
VERBATIM)

# generate wrapped infermeta
Expand Down
12 changes: 6 additions & 6 deletions paddle/phi/api/yaml/legacy_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -549,10 +549,10 @@
args : (Tensor[] x, str equation)
output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()}
infer_meta :
func : EinsumInferMeta
func : EinsumRawInferMeta
param : [x, equation]
kernel :
func : einsum
func : einsum_raw
backward : einsum_grad

- api : elementwise_pow
Expand Down Expand Up @@ -2003,9 +2003,9 @@
args : (Tensor x, int[] axes)
output : Tensor(out), Tensor(xshape)
infer_meta :
func : SqueezeInferMeta
func : SqueezeWithXShapeInferMeta
kernel :
func : squeeze
func : squeeze_with_xshape
view: (x -> out)
intermediate : xshape
backward : squeeze_grad
Expand Down Expand Up @@ -2236,9 +2236,9 @@
args : (Tensor x, IntArray axis)
output : Tensor(out), Tensor(xshape)
infer_meta :
func : UnsqueezeInferMeta
func : UnsqueezeWithXShapeInferMeta
kernel :
func : unsqueeze
func : unsqueeze_with_xshape
view: (x -> out)
intermediate : xshape
backward : unsqueeze_grad
Expand Down
56 changes: 39 additions & 17 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -442,9 +442,7 @@ void EigvalsInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config) {

void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out,
std::vector<MetaTensor*> inner_cache,
std::vector<MetaTensor*> xshape) {
MetaTensor* out) {
// collect the following informations to prepare einsum.
LabelMap labelshape(0);
LabelMap labeltype(LabelType::Reduction);
Expand Down Expand Up @@ -481,6 +479,14 @@ void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape);
out->set_dims(make_ddim(output_dims));
out->set_dtype(inputs[0]->dtype());
}

void EinsumRawInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out,
std::vector<MetaTensor*> inner_cache,
std::vector<MetaTensor*> xshape) {
EinsumInferMeta(inputs, equation, out);
for (size_t i = 0; i < xshape.size(); ++i) {
if (xshape[i] != nullptr) {
xshape[i]->set_dims(inputs[i]->dims());
Expand Down Expand Up @@ -2320,8 +2326,7 @@ void SplitInferMeta(const MetaTensor& x,

void SqueezeInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
MetaTensor* out,
MetaTensor* xshape) {
MetaTensor* out) {
const auto& x_dims = x.dims();
// Check input tensor dims (<6) Eigen limit.
PADDLE_ENFORCE_LE(x_dims.size(),
Expand All @@ -2341,15 +2346,25 @@ void SqueezeInferMeta(const MetaTensor& x,
out->share_lod(x);
}

out->set_dtype(x.dtype());
}

void SqueezeWithXShapeInferMeta(const MetaTensor& x,
                                const std::vector<int>& axes,
                                MetaTensor* out,
                                MetaTensor* xshape) {
  // Infer the squeezed output meta (dims/lod/dtype) via the plain variant,
  // then additionally populate the auxiliary `xshape` output.
  SqueezeInferMeta(x, axes, out);
  const auto& x_dims = x.dims();
  // xshape records [0, x_dims...]; the leading 0 presumably marks it as a
  // shape-carrying tensor with no real data (used by the grad op to recover
  // x's original shape) — TODO(review): confirm against the grad kernel.
  std::vector<int64_t> xshape_dims(x_dims.size() + 1);
  xshape_dims[0] = 0;
  for (int i = 0; i < x_dims.size(); ++i) {
    xshape_dims[i + 1] = x_dims[i];
  }
  // xshape is an intermediate output and may be absent; guard against null.
  if (xshape) {
    xshape->set_dims(phi::make_ddim(xshape_dims));
    xshape->share_lod(x);
    xshape->set_dtype(x.dtype());
  }
}

void StridedSliceRawInferMeta(const MetaTensor& x,
Expand Down Expand Up @@ -3182,7 +3197,6 @@ void UniqueRawInferMeta(const MetaTensor& x,
void UnsqueezeInferMeta(const MetaTensor& x,
const IntArray& axes,
MetaTensor* out,
MetaTensor* xshape,
MetaConfig config) {
const auto& x_dims = x.dims();
// Validity Check: input tensor dims (<6).
Expand Down Expand Up @@ -3211,14 +3225,22 @@ void UnsqueezeInferMeta(const MetaTensor& x,
}
out->set_dtype(x.dtype());
}
if (xshape) {
// set xshape dims.
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
}

void UnsqueezeWithXShapeInferMeta(const MetaTensor& x,
const IntArray& axes,
MetaTensor* out,
MetaTensor* xshape,
MetaConfig config) {
const auto& x_dims = x.dims();
UnsqueezeInferMeta(x, axes, out, config);
// set xshape dims.
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
if (xshape) {
xshape->set_dims(phi::make_ddim(xshape_dims));
xshape->share_lod(x);
xshape->set_dtype(x.dtype());
Expand Down
25 changes: 19 additions & 6 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,13 @@ void EigvalsInferMeta(const MetaTensor& x,

void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out,
std::vector<MetaTensor*> inner_cache,
std::vector<MetaTensor*> xshape);
MetaTensor* out);

void EinsumRawInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out,
std::vector<MetaTensor*> inner_cache,
std::vector<MetaTensor*> xshape);

void ExpandInferMeta(const MetaTensor& x,
const IntArray& shape,
Expand Down Expand Up @@ -332,8 +336,12 @@ void SplitInferMeta(const MetaTensor& x_meta,

void SqueezeInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
MetaTensor* out,
MetaTensor* xshape);
MetaTensor* out);

void SqueezeWithXShapeInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
MetaTensor* out,
MetaTensor* xshape);

void StridedSliceRawInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
Expand Down Expand Up @@ -461,9 +469,14 @@ void UniqueRawInferMeta(const MetaTensor& x,
void UnsqueezeInferMeta(const MetaTensor& x,
const IntArray& axes,
MetaTensor* out,
MetaTensor* xshape,
MetaConfig config = MetaConfig());

void UnsqueezeWithXShapeInferMeta(const MetaTensor& x,
const IntArray& axes,
MetaTensor* out,
MetaTensor* xshape,
MetaConfig config = MetaConfig());

void UnStackInferMeta(const MetaTensor& x,
int axis,
int num,
Expand Down
11 changes: 10 additions & 1 deletion paddle/phi/kernels/cpu/einsum_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,20 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"

PD_REGISTER_KERNEL(einsum,
PD_REGISTER_KERNEL(einsum_raw,
CPU,
ALL_LAYOUT,
phi::EinsumKernelRaw,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(einsum,
CPU,
ALL_LAYOUT,
phi::EinsumKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
15 changes: 15 additions & 0 deletions paddle/phi/kernels/cpu/squeeze_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,18 @@ PD_REGISTER_KERNEL(squeeze,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(squeeze_with_xshape,
CPU,
ALL_LAYOUT,
phi::SqueezeWithXShapeKernel,
float,
double,
phi::dtype::bfloat16,
bool,
int,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
16 changes: 16 additions & 0 deletions paddle/phi/kernels/cpu/unsqueeze_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,19 @@ PD_REGISTER_KERNEL(unsqueeze,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(unsqueeze_with_xshape,
CPU,
ALL_LAYOUT,
phi::UnsqueezeWithXShapeKernel,
float,
double,
phi::dtype::bfloat16,
bool,
int,
int16_t,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
13 changes: 12 additions & 1 deletion paddle/phi/kernels/gpu/einsum_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"

PD_REGISTER_KERNEL(einsum,
PD_REGISTER_KERNEL(einsum_raw,
GPU,
ALL_LAYOUT,
phi::EinsumKernelRaw,
Expand All @@ -28,3 +28,14 @@ PD_REGISTER_KERNEL(einsum,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(einsum,
GPU,
ALL_LAYOUT,
phi::EinsumKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
16 changes: 16 additions & 0 deletions paddle/phi/kernels/gpu/squeeze_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,19 @@ PD_REGISTER_KERNEL(squeeze,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(squeeze_with_xshape,
GPU,
ALL_LAYOUT,
phi::SqueezeWithXShapeKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16,
bool,
int,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
17 changes: 17 additions & 0 deletions paddle/phi/kernels/gpu/unsqueeze_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,20 @@ PD_REGISTER_KERNEL(unsqueeze,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(unsqueeze_with_xshape,
GPU,
ALL_LAYOUT,
phi::UnsqueezeWithXShapeKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
bool,
int,
int16_t,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
6 changes: 4 additions & 2 deletions paddle/phi/kernels/impl/solve_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/expand_as_kernel.h"
#include "paddle/phi/kernels/funcs/matrix_solve.h"
Expand Down Expand Up @@ -77,7 +79,7 @@ static std::vector<int64_t> get_broadcast_batch_portion(
static inline std::vector<int> convert_to_int_vec(std::vector<int64_t> a) {
std::vector<int> ret;
for (size_t i = 0; i < a.size(); i++) {
ret.emplace_back(int(a[i]));
ret.emplace_back(static_cast<int>(a[i]));
}

return ret;
Expand Down Expand Up @@ -167,7 +169,7 @@ static void linalg_solve(const Context& dev_ctx,
out_tmp.Resize(out->dims());
out_tmp = *out;

phi::SqueezeKernel<T, Context>(dev_ctx, out_tmp, {-1}, out, nullptr);
phi::SqueezeKernel<T, Context>(dev_ctx, out_tmp, {-1}, out);
} else {
PADDLE_ENFORCE_EQ(
x_dim[x_dim_size - 1],
Expand Down
Loading