Modify compare logical inplace #56888

Merged
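All five files below apply the same recipe. Previously, when the output aliased the input (out->IsSharedWith(x)), the kernels computed with output type T instead of bool so as not to clobber the input they were still reading. This PR instead snapshots the input holder (auto x_origin = x;), retags the output as BOOL, allocates it as bool, and always computes a bool result. The following is a minimal, self-contained C++17 sketch of why the snapshot matters; Tensor, alloc_bool, and greater_than_inplace are invented stand-ins for DenseTensor, Alloc<bool>, and the phi kernels (not phi API), and it assumes allocation hands back a distinct buffer:

#include <cstddef>
#include <cstdio>
#include <memory>

// Invented stand-in for phi::DenseTensor: a type-erased, shared buffer.
// Copying a Tensor copies the handle, not the bytes -- the same shallow
// copy that `auto x_origin = x;` performs in the kernels below.
struct Tensor {
  std::shared_ptr<void> data;
  size_t numel = 0;
};

// Stand-in for `out->set_type(phi::DataType::BOOL)` followed by
// `ctx.template Alloc<bool>(out)`: point `out` at a fresh bool buffer.
bool* alloc_bool(Tensor* out) {
  std::shared_ptr<bool[]> buf(new bool[out->numel]());
  out->data = buf;  // drops out's reference to the old buffer
  return buf.get();
}

// The in-place pattern: `out` aliases `x` on entry, so snapshot the
// handle first; the snapshot keeps the float storage alive and readable
// while results are written into the new bool storage.
void greater_than_inplace(Tensor* x_and_out, float threshold) {
  Tensor x_origin = *x_and_out;  // pin the input buffer
  bool* out = alloc_bool(x_and_out);
  const float* in = static_cast<const float*>(x_origin.data.get());
  for (size_t i = 0; i < x_origin.numel; ++i) out[i] = in[i] > threshold;
}  // x_origin dies here; only now can the float buffer be freed

int main() {
  Tensor t{std::shared_ptr<float[]>(new float[3]{-1.f, 0.5f, 2.f}), 3};
  greater_than_inplace(&t, 0.f);
  const bool* r = static_cast<const bool*>(t.data.get());
  std::printf("%d %d %d\n", r[0], r[1], r[2]);  // prints: 0 1 1
}

The point of the snapshot: Tensor copies share the underlying buffer, so x_origin keeps the original element data alive even after the aliased output has been re-pointed at new bool storage.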
paddle/phi/kernels/cpu/compare_kernel.cc (42 additions & 28 deletions)
@@ -30,22 +30,34 @@ inline void CompareKernelImpl(const Context& ctx,
                               const DenseTensor& y,
                               int axis,
                               DenseTensor* out) {
-  if (!out->IsSharedWith(x)) {
-    ctx.template Alloc<bool>(out);
-    if (x.dims().size() >= y.dims().size()) {
-      funcs::ElementwiseCompute<Functor, T, bool>(
-          ctx, x, y, Functor(), out, axis);
-    } else {
-      funcs::ElementwiseCompute<InverseFunctor, T, bool>(
-          ctx, x, y, InverseFunctor(), out, axis);
-    }
+  ctx.template Alloc<bool>(out);
+  if (x.dims().size() >= y.dims().size()) {
+    funcs::ElementwiseCompute<Functor, T, bool>(
+        ctx, x, y, Functor(), out, axis);
   } else {
-    if (x.dims().size() >= y.dims().size()) {
-      funcs::ElementwiseCompute<Functor, T, T>(ctx, x, y, Functor(), out, axis);
-    } else {
-      funcs::ElementwiseCompute<InverseFunctor, T, T>(
-          ctx, x, y, InverseFunctor(), out, axis);
-    }
+    funcs::ElementwiseCompute<InverseFunctor, T, bool>(
+        ctx, x, y, InverseFunctor(), out, axis);
   }
 }
+
+template <typename T,
+          typename Context,
+          typename Functor,
+          typename InverseFunctor>
+inline void InplaceCompareKernelImpl(const Context& ctx,
+                                     const DenseTensor& x,
+                                     const DenseTensor& y,
+                                     int axis,
+                                     DenseTensor* out) {
+  auto x_origin = x;
+  out->set_type(phi::DataType::BOOL);
+  ctx.template Alloc<bool>(out);
+  if (x_origin.dims().size() >= y.dims().size()) {
+    funcs::ElementwiseCompute<Functor, T, bool>(
+        ctx, x_origin, y, Functor(), out, axis);
+  } else {
+    funcs::ElementwiseCompute<InverseFunctor, T, bool>(
+        ctx, x_origin, y, InverseFunctor(), out, axis);
+  }
+}

@@ -92,19 +104,21 @@ PD_REGISTER_KERNEL(equal_all,
   kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
 }
 
-#define PD_REGISTER_COMPARE_KERNEL(name, func) \
-  PD_REGISTER_KERNEL(name,                     \
-                     CPU,                      \
-                     ALL_LAYOUT,               \
-                     phi::func##Kernel,        \
-                     bool,                     \
-                     int16_t,                  \
-                     int,                      \
-                     int64_t,                  \
-                     float,                    \
-                     double,                   \
-                     phi::dtype::float16,      \
-                     phi::dtype::bfloat16) {}
+#define PD_REGISTER_COMPARE_KERNEL(name, func)             \
+  PD_REGISTER_KERNEL(name,                                 \
+                     CPU,                                  \
+                     ALL_LAYOUT,                           \
+                     phi::func##Kernel,                    \
+                     bool,                                 \
+                     int16_t,                              \
+                     int,                                  \
+                     int64_t,                              \
+                     float,                                \
+                     double,                               \
+                     phi::dtype::float16,                  \
+                     phi::dtype::bfloat16) {               \
+    kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);  \
+  }
 PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
 PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
 PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan)
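The registration macros also change: the body goes from {} to tagging output 0 as BOOL, just as the equal_all registration above already does. A toy registry (invented, far simpler than phi's real one) showing what that metadata is for:

#include <cstdio>
#include <functional>
#include <map>
#include <string>

// A toy registry, not phi's: each entry records which dtype its first
// output produces. Compare/logical kernels register BOOL here so dtype
// inference stops assuming "output dtype == input dtype" -- exactly the
// assumption that breaks when a float tensor is compared in place.
enum class DataType { FLOAT32, BOOL };

struct KernelEntry {
  std::function<void()> body;
  DataType out0;
};

std::map<std::string, KernelEntry> registry;

int main() {
  registry["less_than.CPU.float32"] = {
      [] { std::puts("running less_than kernel"); }, DataType::BOOL};

  const KernelEntry& k = registry.at("less_than.CPU.float32");
  std::printf("out0 is BOOL: %d\n", k.out0 == DataType::BOOL);  // 1
  k.body();
}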
paddle/phi/kernels/cpu/logical_kernel.cc (46 additions & 21 deletions)
@@ -24,20 +24,40 @@
 
 namespace phi {
 
-#define DEFINE_LOGICAL_BINARY_KERNEL(type)                                    \
-  template <typename T, typename Context>                                     \
-  void Logical##type##Kernel(const Context& dev_ctx,                          \
-                             const DenseTensor& x,                            \
-                             const DenseTensor& y,                            \
-                             DenseTensor* out) {                              \
-    funcs::Logical##type##Functor<T> binary_func;                             \
-    if (out->IsSharedWith(x)) {                                               \
-      funcs::ElementwiseCompute<funcs::Logical##type##Functor<T>, T, T>(      \
-          dev_ctx, x, y, binary_func, out);                                   \
-    } else {                                                                  \
-      funcs::ElementwiseCompute<funcs::Logical##type##Functor<T>, T, bool>(   \
-          dev_ctx, x, y, binary_func, out);                                   \
-    }                                                                         \
-  }
+template <typename T, typename Context, typename Functor>
+void LogicalKernelImpl(const Context& dev_ctx,
+                       const DenseTensor& x,
+                       const DenseTensor& y,
+                       DenseTensor* out) {
+  Functor binary_func;
+  funcs::ElementwiseCompute<Functor, T, bool>(dev_ctx, x, y, binary_func, out);
+}
+
+template <typename T, typename Context, typename Functor>
+void InplaceLogicalKernelImpl(const Context& dev_ctx,
+                              const DenseTensor& x,
+                              const DenseTensor& y,
+                              DenseTensor* out) {
+  Functor binary_func;
+  auto x_origin = x;
+  out->set_type(phi::DataType::BOOL);
+  funcs::ElementwiseCompute<Functor, T, bool>(
+      dev_ctx, x_origin, y, binary_func, out);
+}
+
+#define DEFINE_LOGICAL_BINARY_KERNEL(type)                                    \
+  template <typename T, typename Context>                                     \
+  void Logical##type##Kernel(const Context& dev_ctx,                          \
+                             const DenseTensor& x,                            \
+                             const DenseTensor& y,                            \
+                             DenseTensor* out) {                              \
+    if (out->IsSharedWith(x)) {                                               \
+      InplaceLogicalKernelImpl<T, Context, funcs::Logical##type##Functor<T>>( \
+          dev_ctx, x, y, out);                                                \
+    } else {                                                                  \
+      LogicalKernelImpl<T, Context, funcs::Logical##type##Functor<T>>(        \
+          dev_ctx, x, y, out);                                                \
+    }                                                                         \
+  }
 
 DEFINE_LOGICAL_BINARY_KERNEL(And)
@@ -52,15 +72,18 @@ void LogicalNotKernel(const Context& dev_ctx,
   funcs::LogicalNotFunctor<T> unary_func;
 
   phi::Transform<Context> trans;
-  if (!out->IsSharedWith(x)) {
-    auto* out_ptr = dev_ctx.template Alloc<bool>(out);
-    trans(dev_ctx, x.data<T>(), x.data<T>() + x.numel(), out_ptr, unary_func);
-  } else {
+  if (out->IsSharedWith(x)) {
+    auto x_origin = x;
+    out->set_type(phi::DataType::BOOL);
+    auto* out_ptr = dev_ctx.template Alloc<bool>(out);
     trans(dev_ctx,
-          x.data<T>(),
-          x.data<T>() + x.numel(),
-          reinterpret_cast<T*>(out->data()),
+          x_origin.data<T>(),
+          x_origin.data<T>() + x_origin.numel(),
+          out_ptr,
           unary_func);
+  } else {
+    auto* out_ptr = dev_ctx.template Alloc<bool>(out);
+    trans(dev_ctx, x.data<T>(), x.data<T>() + x.numel(), out_ptr, unary_func);
   }
 }

@@ -79,7 +102,9 @@
                     int8_t,                                   \
                     phi::dtype::complex<float>,               \
                     phi::dtype::complex<double>,              \
-                    int16_t) {}
+                    int16_t) {                                \
+    kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);     \
+  }
 
 REGISTER_LOGICAL_CPU_KERNEL(logical_and, And)
 REGISTER_LOGICAL_CPU_KERNEL(logical_or, Or)
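LogicalNotKernel applies the same recipe to the unary case. A standalone analogue of its in-place branch using std::transform, reusing the invented Tensor stand-in from the sketch above (again assuming the output allocation is a distinct buffer; none of these names are phi API):

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <memory>

// Invented Tensor stand-in, as before: a shared, type-erased buffer.
struct Tensor {
  std::shared_ptr<void> data;
  size_t numel = 0;
};

// In-place logical_not over a float tensor. The output dtype (bool)
// differs from the input dtype (float), so rather than writing over the
// float buffer, snapshot the handle, make a bool buffer, and transform
// old -> new, mirroring trans(dev_ctx, x_origin.data<T>(), ..., out_ptr).
void logical_not_inplace(Tensor* x_and_out) {
  Tensor x_origin = *x_and_out;  // keeps the float storage alive
  std::shared_ptr<bool[]> out(new bool[x_and_out->numel]());
  const float* in = static_cast<const float*>(x_origin.data.get());
  std::transform(in, in + x_origin.numel, out.get(),
                 [](float v) { return !static_cast<bool>(v); });
  x_and_out->data = out;  // only now is the tensor re-typed
}

int main() {
  Tensor t{std::shared_ptr<float[]>(new float[3]{0.f, 2.f, 0.f}), 3};
  logical_not_inplace(&t);
  const bool* r = static_cast<const bool*>(t.data.get());
  std::printf("%d %d %d\n", r[0], r[1], r[2]);  // prints: 1 0 1
}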
paddle/phi/kernels/impl/compare_kernel_impl.h (23 additions & 8 deletions)
@@ -30,20 +30,35 @@ inline void CompareKernelImpl(const Context& ctx,
                               int axis,
                               DenseTensor* out);
 
+template <typename T,
+          typename Context,
+          typename Functor,
+          typename InverseFunctor>
+inline void InplaceCompareKernelImpl(const Context& ctx,
+                                     const DenseTensor& x,
+                                     const DenseTensor& y,
+                                     int axis,
+                                     DenseTensor* out);
+
 template <typename T, typename Context, typename Functor>
 inline void CompareAllKernelImpl(const Context& ctx,
                                  const DenseTensor& x,
                                  const DenseTensor& y,
                                  DenseTensor* out);
 
-#define DEFINE_COMPARE_KERNEL(name, functor, inverse_functor)               \
-  template <typename T, typename Context>                                   \
-  void name##Kernel(const Context& ctx,                                     \
-                    const DenseTensor& x,                                   \
-                    const DenseTensor& y,                                   \
-                    DenseTensor* out) {                                     \
-    CompareKernelImpl<T, Context, functor<T>, inverse_functor<T>>(          \
-        ctx, x, y, -1, out);                                                \
+#define DEFINE_COMPARE_KERNEL(name, functor, inverse_functor)               \
+  template <typename T, typename Context>                                   \
+  void name##Kernel(const Context& ctx,                                     \
+                    const DenseTensor& x,                                   \
+                    const DenseTensor& y,                                   \
+                    DenseTensor* out) {                                     \
+    if (out->IsSharedWith(x)) {                                             \
+      InplaceCompareKernelImpl<T, Context, functor<T>, inverse_functor<T>>( \
+          ctx, x, y, -1, out);                                              \
+    } else {                                                                \
+      CompareKernelImpl<T, Context, functor<T>, inverse_functor<T>>(        \
+          ctx, x, y, -1, out);                                              \
+    }                                                                       \
   }
 
 DEFINE_COMPARE_KERNEL(LessThan,
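The header above is where in-place dispatch actually happens for the compare ops. Hand-expanding DEFINE_COMPARE_KERNEL for less_than gives roughly the following; the functor pair is truncated in the diff, so funcs::LessThanFunctor / funcs::GreaterThanFunctor is an assumed pairing, not confirmed by the hunk:

// Approximate hand-expansion of DEFINE_COMPARE_KERNEL(LessThan, ...).
template <typename T, typename Context>
void LessThanKernel(const Context& ctx,
                    const DenseTensor& x,
                    const DenseTensor& y,
                    DenseTensor* out) {
  if (out->IsSharedWith(x)) {
    // out aliases x: snapshot x, retag out as BOOL, compare into it.
    InplaceCompareKernelImpl<T,
                             Context,
                             funcs::LessThanFunctor<T>,
                             funcs::GreaterThanFunctor<T>>(ctx, x, y, -1, out);
  } else {
    CompareKernelImpl<T,
                      Context,
                      funcs::LessThanFunctor<T>,
                      funcs::GreaterThanFunctor<T>>(ctx, x, y, -1, out);
  }
}

The inverse functor is what makes operand swapping safe when y has more dimensions than x: in that branch the impl computes with the operands reversed, so x < y is evaluated as y > x.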
paddle/phi/kernels/kps/compare_kernel.cu (34 additions & 20 deletions)
@@ -52,16 +52,27 @@ inline void CompareKernelImpl(const Context& ctx,
                               const DenseTensor& y,
                               int axis,
                               DenseTensor* out) {
-  if (!out->IsSharedWith(x)) {
-    ctx.template Alloc<bool>(out);
-  }
+  ctx.template Alloc<bool>(out);
   std::vector<const DenseTensor*> ins{&x, &y};
   std::vector<DenseTensor*> outs{out};
-  if (!out->IsSharedWith(x)) {
-    funcs::BroadcastKernel<bool>(ctx, ins, &outs, Functor(), axis);
-  } else {
-    funcs::BroadcastKernel<T>(ctx, ins, &outs, Functor(), axis);
-  }
+  funcs::BroadcastKernel<bool>(ctx, ins, &outs, Functor(), axis);
+}
+
+template <typename T,
+          typename Context,
+          typename Functor,
+          typename InverseFunctor>
+inline void InplaceCompareKernelImpl(const Context& ctx,
+                                     const DenseTensor& x,
+                                     const DenseTensor& y,
+                                     int axis,
+                                     DenseTensor* out) {
+  auto x_origin = x;
+  ctx.template Alloc<bool>(out);
+  out->set_type(phi::DataType::BOOL);
+  std::vector<const DenseTensor*> ins{&x_origin, &y};
+  std::vector<DenseTensor*> outs{out};
+  funcs::BroadcastKernel<bool>(ctx, ins, &outs, Functor(), axis);
 }
 
 #ifndef PADDLE_WITH_XPU_KP
@@ -134,18 +145,21 @@ PD_REGISTER_KERNEL(equal_all,
   kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
 }
 
-#define PD_REGISTER_COMPARE_KERNEL(name, func) \
-  PD_REGISTER_KERNEL(name,                     \
-                     KPS,                      \
-                     ALL_LAYOUT,               \
-                     phi::func##Kernel,        \
-                     bool,                     \
-                     int16_t,                  \
-                     int,                      \
-                     int64_t,                  \
-                     float,                    \
-                     double,                   \
-                     phi::dtype::float16) {}
+#define PD_REGISTER_COMPARE_KERNEL(name, func)             \
+  PD_REGISTER_KERNEL(name,                                 \
+                     KPS,                                  \
+                     ALL_LAYOUT,                           \
+                     phi::func##Kernel,                    \
+                     bool,                                 \
+                     int16_t,                              \
+                     int,                                  \
+                     int64_t,                              \
+                     float,                                \
+                     double,                               \
+                     phi::dtype::float16,                  \
+                     phi::dtype::bfloat16) {               \
+    kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);  \
+  }
 
 PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
 PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
paddle/phi/kernels/kps/logical_kernel.cu (52 additions & 25 deletions)
@@ -25,24 +25,45 @@
 
 namespace phi {
 
-#define DEFINE_LOGICAL_BINARY_KERNEL(type)                             \
-  template <typename T, typename Context>                              \
-  void Logical##type##Kernel(const Context& dev_ctx,                   \
-                             const DenseTensor& x,                     \
-                             const DenseTensor& y,                     \
-                             DenseTensor* out) {                       \
-    if (!out->IsSharedWith(x)) {                                       \
-      dev_ctx.template Alloc<bool>(out);                               \
-    }                                                                  \
-                                                                       \
-    funcs::Logical##type##Functor<T> binary_func;                      \
-    std::vector<const DenseTensor*> ins = {&x, &y};                    \
-    std::vector<DenseTensor*> outs = {out};                            \
-    if (!out->IsSharedWith(x)) {                                       \
-      funcs::BroadcastKernel<bool>(dev_ctx, ins, &outs, binary_func);  \
-    } else {                                                           \
-      funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, binary_func);     \
-    }                                                                  \
-  }
+template <typename T, typename Context, typename Functor>
+void LogicalKernelImpl(const Context& dev_ctx,
+                       const DenseTensor& x,
+                       const DenseTensor& y,
+                       DenseTensor* out) {
+  dev_ctx.template Alloc<bool>(out);
+  Functor binary_func;
+  std::vector<const DenseTensor*> ins = {&x, &y};
+  std::vector<DenseTensor*> outs = {out};
+  funcs::BroadcastKernel<bool>(dev_ctx, ins, &outs, binary_func);
+}
+
+template <typename T, typename Context, typename Functor>
+void InplaceLogicalKernelImpl(const Context& dev_ctx,
+                              const DenseTensor& x,
+                              const DenseTensor& y,
+                              DenseTensor* out) {
+  auto x_origin = x;
+  dev_ctx.template Alloc<bool>(out);
+  out->set_type(phi::DataType::BOOL);
+  Functor binary_func;
+  std::vector<const DenseTensor*> ins = {&x_origin, &y};
+  std::vector<DenseTensor*> outs = {out};
+  funcs::BroadcastKernel<bool>(dev_ctx, ins, &outs, binary_func);
+}
+
+#define DEFINE_LOGICAL_BINARY_KERNEL(type)                                    \
+  template <typename T, typename Context>                                     \
+  void Logical##type##Kernel(const Context& dev_ctx,                          \
+                             const DenseTensor& x,                            \
+                             const DenseTensor& y,                            \
+                             DenseTensor* out) {                              \
+    if (out->IsSharedWith(x)) {                                               \
+      InplaceLogicalKernelImpl<T, Context, funcs::Logical##type##Functor<T>>( \
+          dev_ctx, x, y, out);                                                \
+    } else {                                                                  \
+      LogicalKernelImpl<T, Context, funcs::Logical##type##Functor<T>>(        \
+          dev_ctx, x, y, out);                                                \
+    }                                                                         \
+  }
 
 DEFINE_LOGICAL_BINARY_KERNEL(And)
@@ -56,14 +77,18 @@ void LogicalNotKernel(const Context& dev_ctx,
                       DenseTensor* out) {
   if (!out->IsSharedWith(x)) {
     dev_ctx.template Alloc<bool>(out);
-  }
-  funcs::LogicalNotFunctor<T> unary_func;
-  std::vector<const DenseTensor*> ins = {&x};
-  std::vector<DenseTensor*> outs = {out};
-  if (!out->IsSharedWith(x)) {
+    funcs::LogicalNotFunctor<T> unary_func;
+    std::vector<const DenseTensor*> ins = {&x};
+    std::vector<DenseTensor*> outs = {out};
     funcs::BroadcastKernel<bool>(dev_ctx, ins, &outs, unary_func);
   } else {
-    funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, unary_func);
+    auto x_origin = x;
+    out->set_type(phi::DataType::BOOL);
+    dev_ctx.template Alloc<bool>(out);
+    funcs::LogicalNotFunctor<T> unary_func;
+    std::vector<const DenseTensor*> ins = {&x_origin};
+    std::vector<DenseTensor*> outs = {out};
+    funcs::BroadcastKernel<bool>(dev_ctx, ins, &outs, unary_func);
   }
 }

@@ -99,7 +124,9 @@ PD_REGISTER_KERNEL(logical_xor, KPS, ALL_LAYOUT, phi::LogicalXorKernel, int) {
                     int8_t,                                   \
                     phi::dtype::complex<float>,               \
                     phi::dtype::complex<double>,              \
-                    int16_t) {}
+                    int16_t) {                                \
+    kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);     \
+  }
 
 REGISTER_LOGICAL_CUDA_KERNEL(logical_and, And)
 REGISTER_LOGICAL_CUDA_KERNEL(logical_or, Or)