diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 306f7a77a05bb..a4a4f22142a4e 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -453,6 +453,26 @@ func : cross_grad data_type : out_grad +- backward_op : cummax_grad + forward : cummax(Tensor x, int axis=-1, int dtype=3) -> Tensor(out), Tensor(indices) + args : (Tensor x, Tensor indices, Tensor out_grad, int axis, int dtype) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : cummax_grad + +- backward_op : cummin_grad + forward : cummin(Tensor x, int axis=-1, int dtype=3) -> Tensor(out), Tensor(indices) + args : (Tensor x, Tensor indices, Tensor out_grad, int axis, int dtype) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : cummin_grad + - backward_op : cumprod_grad forward : cumprod (Tensor x, int dim) -> Tensor(out) args : (Tensor x, Tensor out, Tensor out_grad, int dim) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 966444785dd8c..8f262c5c42e10 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -522,6 +522,24 @@ data_type : input backward : cross_entropy_with_softmax_grad +- op : cummax + args : (Tensor x, int axis=-1, int dtype=3) + output : Tensor(out), Tensor(indices) + infer_meta : + func : CumWithIndicesInferMeta + kernel : + func : cummax + backward : cummax_grad + +- op : cummin + args : (Tensor x, int axis=-1, int dtype=3) + output : Tensor(out), Tensor(indices) + infer_meta : + func : CumWithIndicesInferMeta + kernel : + func : cummin + backward : cummin_grad + - op : cumprod args : (Tensor x, int dim) output : Tensor(out) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 92cf654aee8c0..457c207b3c180 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -506,6 +506,69 @@ void CumScalarAxisInferMeta(const MetaTensor& x, CumInferMeta(x, axis.to(), flatten, exclusive, reverse, out); } +void CumWithIndicesInferMeta(const MetaTensor& x, + int axis, + int dtype, + MetaTensor* out, + MetaTensor* indices) { + auto x_dims = x.dims(); + auto indices_type = phi::TransToPhiDataType(dtype); + PADDLE_ENFORCE_EQ( + (indices_type == DataType::INT32 || indices_type == DataType::INT64), + true, + phi::errors::InvalidArgument("dtype of indices must be int32 or int64")); + + if (indices_type == DataType::INT32) { + int _axis; + if (axis < 0) { + _axis = axis + x_dims.size(); + } else { + _axis = axis; + } + PADDLE_ENFORCE_LT( + phi::vectorize(x_dims)[_axis], + INT32_MAX, + phi::errors::OutOfRange( + "cummax with axis %ld may be overflow, set dtype int64 to continue", + axis)); + } + + if (x_dims.size() > 0) { + PADDLE_ENFORCE_GE( + axis, + -x_dims.size(), + phi::errors::OutOfRange( + "axis is out of range (expected to be in range of [%ld, " + "%ld), but got %ld).", + -(x_dims.size()), + x_dims.size(), + axis)); + PADDLE_ENFORCE_LT( + axis, + x_dims.size(), + phi::errors::OutOfRange( + "axis is out of range (expected to be in range of [%ld, " + "%ld), but got %ld).", + -(x_dims.size()), + x_dims.size(), + axis)); + } else { + PADDLE_ENFORCE_EQ( + (axis == 0 || axis == -1), + true, + errors::InvalidArgument("The axis must be -1 or 0 in 0D Tensor, " + "but the value given is %d.", + axis)); + } + + out->set_dims(x_dims); + out->set_dtype(x.dtype()); + out->share_lod(x); + indices->set_dims(x_dims); + indices->set_dtype(indices_type); + indices->share_lod(x); +} + void CropInferMeta(const MetaTensor& x, const IntArray& shape, const IntArray& offsets, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 7dc922ac9a487..5e3e53804fe69 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -111,6 +111,12 @@ void CumScalarAxisInferMeta(const MetaTensor& x, bool reverse, MetaTensor* out); +void CumWithIndicesInferMeta(const MetaTensor& x, + int axis, + int dtype, + MetaTensor* out, + MetaTensor* indices); + void DecodeJpegInferMeta(const MetaTensor& x, const std::string& mode, MetaTensor* out); diff --git a/paddle/phi/kernels/cpu/cum_maxmin_grad_kernel.cc b/paddle/phi/kernels/cpu/cum_maxmin_grad_kernel.cc new file mode 100644 index 0000000000000..88fb4f4feb91f --- /dev/null +++ b/paddle/phi/kernels/cpu/cum_maxmin_grad_kernel.cc @@ -0,0 +1,91 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/cum_maxmin_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/gather_scatter_functor.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void CummaxGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& indices, + const DenseTensor& out_grad, + int axis, + int dtype, + DenseTensor* x_grad) { + dev_ctx.template Alloc(x_grad); + phi::funcs::SetConstant functor; + functor(dev_ctx, x_grad, static_cast(0)); + if (axis < 0) { + axis = axis + x.dims().size(); + } + auto indices_type = phi::TransToPhiDataType(dtype); + if (indices_type == DataType::INT32) { + phi::funcs::cpu_scatter_add_kernel( + *x_grad, axis, indices, out_grad, dev_ctx); + } else if (indices_type == DataType::INT64) { + phi::funcs::cpu_scatter_add_kernel( + *x_grad, axis, indices, out_grad, dev_ctx); + } +} + +template +void CumminGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& indices, + const DenseTensor& out_grad, + int axis, + int dtype, + DenseTensor* x_grad) { + dev_ctx.template Alloc(x_grad); + phi::funcs::SetConstant functor; + functor(dev_ctx, x_grad, static_cast(0)); + if (axis < 0) { + axis = axis + x.dims().size(); + } + auto indices_type = phi::TransToPhiDataType(dtype); + if (indices_type == DataType::INT32) { + phi::funcs::cpu_scatter_add_kernel( + *x_grad, axis, indices, out_grad, dev_ctx); + } else if (indices_type == DataType::INT64) { + phi::funcs::cpu_scatter_add_kernel( + *x_grad, axis, indices, out_grad, dev_ctx); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(cummax_grad, + CPU, + ALL_LAYOUT, + phi::CummaxGradKernel, + float, + double, + int32_t, + int64_t) {} + +PD_REGISTER_KERNEL(cummin_grad, + CPU, + ALL_LAYOUT, + phi::CumminGradKernel, + float, + double, + int32_t, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/cum_maxmin_kernel.cc b/paddle/phi/kernels/cpu/cum_maxmin_kernel.cc new file mode 100644 index 0000000000000..be1cfe3d86b1f --- /dev/null +++ b/paddle/phi/kernels/cpu/cum_maxmin_kernel.cc @@ -0,0 +1,200 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/cum_maxmin_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +#ifdef _MSC_VER +template +typename std::enable_if::value, bool>::type isnan_(T x) { + return false; +} +template +typename std::enable_if::value, bool>::type isnan_(T x) { + return std::isnan(x); +} +#else +template +bool isnan_(T x) { + return std::isnan(x); +} +#endif + +template +T compute_stride(T axis, phi::DDim dims) { + T size = 1; + for (T i = axis + 1; i < dims.size(); i++) { + size *= dims[i]; + } + return size; +} + +template +void ComputeImp(const DenseTensor& x, + DenseTensor* out, + DenseTensor* indices, + int64_t axis) { + int ndims = x.dims().size(); + int finished = 0; + std::vector counter(ndims, 0); + const T1* x_data = x.data(); + T1* values_data = out->data(); + T2* indices_data = indices->data(); + int64_t x_stride = compute_stride(axis, x.dims()); + int64_t values_stride = compute_stride(axis, out->dims()); + int64_t indices_stride = compute_stride(axis, indices->dims()); + auto x_dim_vec = phi::vectorize(x.dims()); + int x_dim_size = x_dim_vec[axis]; + BinaryFunction op; + + while (!finished) { + T1 max = *reinterpret_cast(x_data); + int idx = 0; + for (int i = 0; i < x_dim_size; i++) { + T1 curr_elem = *reinterpret_cast(&x_data[i * x_stride]); + if (isnan_(curr_elem) || (!isnan_(max) && op(curr_elem, max))) { + max = curr_elem; + idx = i; + } + values_data[i * values_stride] = max; + indices_data[i * indices_stride] = idx; + } + if (ndims == 1) break; + for (int dim_i = 0; dim_i < ndims; dim_i++) { + if (dim_i == axis) { + if (dim_i == (ndims - 1)) { + finished = 1; + break; + } + continue; + } + int64_t x_stride_ = compute_stride(dim_i, x.dims()); + int64_t values_stride_ = compute_stride(dim_i, out->dims()); + int64_t indices_stride_ = compute_stride(dim_i, indices->dims()); + counter[dim_i]++; + x_data += x_stride_; + values_data += values_stride_; + indices_data += indices_stride_; + if (counter[dim_i] == x_dim_vec[dim_i]) { + if (dim_i == ndims - 1) { + finished = 1; + break; + } else { + x_data -= counter[dim_i] * x_stride_; + values_data -= counter[dim_i] * values_stride_; + indices_data -= counter[dim_i] * indices_stride_; + counter[dim_i] = 0; + } + } else { + break; + } + } + } +} + +template +void ScanWithIndicesKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + DenseTensor* out, + DenseTensor* indices) { + dev_ctx.template Alloc(out); + dev_ctx.template Alloc(indices); + + // For 0D Tensor + if (x.numel() == 1) { + auto raw_dims = out->dims(); + phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, indices, static_cast(0.0)); + out->Resize(raw_dims); + indices->Resize(raw_dims); + return; + } + auto out_dims = out->dims(); + + PADDLE_ENFORCE_EQ( + axis < out_dims.size() && axis >= (0 - out_dims.size()), + true, + phi::errors::OutOfRange( + "Attr(axis) is out of range, It's expected " + "to be in range of [-%d, %d]. But received Attr(axis) = %d.", + out_dims.size(), + out_dims.size() - 1, + axis)); + + if (axis < 0) { + axis = axis + out_dims.size(); + } + ComputeImp(x, out, indices, axis); +} + +template +void CummaxKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + int dtype, + DenseTensor* out, + DenseTensor* indices) { + auto indices_type = phi::TransToPhiDataType(dtype); + if (indices_type == DataType::INT32) { + ScanWithIndicesKernel, Context>( + dev_ctx, x, axis, out, indices); + } else if (indices_type == DataType::INT64) { + ScanWithIndicesKernel, Context>( + dev_ctx, x, axis, out, indices); + } +} + +template +void CumminKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + int dtype, + DenseTensor* out, + DenseTensor* indices) { + auto indices_type = phi::TransToPhiDataType(dtype); + if (indices_type == DataType::INT32) { + ScanWithIndicesKernel, Context>( + dev_ctx, x, axis, out, indices); + } else if (indices_type == DataType::INT64) { + ScanWithIndicesKernel, Context>( + dev_ctx, x, axis, out, indices); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(cummax, + CPU, + ALL_LAYOUT, + phi::CummaxKernel, + float, + double, + int32_t, + int64_t) {} + +PD_REGISTER_KERNEL(cummin, + CPU, + ALL_LAYOUT, + phi::CumminKernel, + float, + double, + int32_t, + int64_t) {} diff --git a/paddle/phi/kernels/cum_maxmin_grad_kernel.h b/paddle/phi/kernels/cum_maxmin_grad_kernel.h new file mode 100644 index 0000000000000..13a6b7ee6ec1e --- /dev/null +++ b/paddle/phi/kernels/cum_maxmin_grad_kernel.h @@ -0,0 +1,39 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void CummaxGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& indices, + const DenseTensor& out_grad, + int axis, + int dtype, + DenseTensor* x_grad); + +template +void CumminGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& indices, + const DenseTensor& out_grad, + int axis, + int dtype, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/cum_maxmin_kernel.h b/paddle/phi/kernels/cum_maxmin_kernel.h new file mode 100644 index 0000000000000..37755deb5d91e --- /dev/null +++ b/paddle/phi/kernels/cum_maxmin_kernel.h @@ -0,0 +1,37 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void CummaxKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + int dtype, + DenseTensor* out, + DenseTensor* indices); + +template +void CumminKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + int dtype, + DenseTensor* out, + DenseTensor* indices); + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/cum_maxmin_grad_kernel.cu b/paddle/phi/kernels/gpu/cum_maxmin_grad_kernel.cu new file mode 100644 index 0000000000000..a89373c607f7d --- /dev/null +++ b/paddle/phi/kernels/gpu/cum_maxmin_grad_kernel.cu @@ -0,0 +1,91 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/cum_maxmin_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/gather_scatter_functor.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void CummaxGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& indices, + const DenseTensor& out_grad, + int axis, + int dtype, + DenseTensor* x_grad) { + dev_ctx.template Alloc(x_grad); + phi::funcs::SetConstant functor; + functor(dev_ctx, x_grad, static_cast(0)); + if (axis < 0) { + axis = axis + x.dims().size(); + } + auto indices_type = phi::TransToPhiDataType(dtype); + if (indices_type == DataType::INT32) { + phi::funcs::gpu_scatter_add_kernel( + *x_grad, axis, indices, out_grad, dev_ctx); + } else if (indices_type == DataType::INT64) { + phi::funcs::gpu_scatter_add_kernel( + *x_grad, axis, indices, out_grad, dev_ctx); + } +} + +template +void CumminGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& indices, + const DenseTensor& out_grad, + int axis, + int dtype, + DenseTensor* x_grad) { + dev_ctx.template Alloc(x_grad); + phi::funcs::SetConstant functor; + functor(dev_ctx, x_grad, static_cast(0)); + if (axis < 0) { + axis = axis + x.dims().size(); + } + auto indices_type = phi::TransToPhiDataType(dtype); + if (indices_type == DataType::INT32) { + phi::funcs::gpu_scatter_add_kernel( + *x_grad, axis, indices, out_grad, dev_ctx); + } else if (indices_type == DataType::INT64) { + phi::funcs::gpu_scatter_add_kernel( + *x_grad, axis, indices, out_grad, dev_ctx); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(cummax_grad, + GPU, + ALL_LAYOUT, + phi::CummaxGradKernel, + float, + double, + int32_t, + int64_t) {} + +PD_REGISTER_KERNEL(cummin_grad, + GPU, + ALL_LAYOUT, + phi::CumminGradKernel, + float, + double, + int32_t, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/cum_maxmin_kernel.cu b/paddle/phi/kernels/gpu/cum_maxmin_kernel.cu new file mode 100644 index 0000000000000..bf836af72c58f --- /dev/null +++ b/paddle/phi/kernels/gpu/cum_maxmin_kernel.cu @@ -0,0 +1,368 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/cum_maxmin_kernel.h" + +#include + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template < + typename T1, + typename T2, + typename BinaryOperation, + typename std::enable_if::value, int>::type = 0> +__device__ void binary_op_update(const T1 lhs, + T1* rhs, + const T2 lhs_idx, + T2* rhs_idx, + BinaryOperation binary_op) { + if (!isnan(*rhs) && (isnan(lhs) || !binary_op(*rhs, lhs))) { + *rhs = lhs; + *rhs_idx = lhs_idx; + } +} + +template ::value, int>::type = 0> +__device__ void binary_op_update(const T1 lhs, + T1* rhs, + const T2 lhs_idx, + T2* rhs_idx, + BinaryOperation binary_op) { + if (!binary_op(*rhs, lhs)) { + *rhs = lhs; + *rhs_idx = lhs_idx; + } +} + +template < + typename T1, + typename T2, + typename BinaryOperation, + typename std::enable_if::value, int>::type = 0> +__device__ void binary_op_update_v(const T1 lhs, + T1* rhs, + const T2 lhs_idx, + T2* rhs_idx, + BinaryOperation binary_op) { + if (isnan(lhs) || (!isnan(*rhs) && binary_op(lhs, *rhs))) { + *rhs = lhs; + *rhs_idx = lhs_idx; + } +} + +template ::value, int>::type = 0> +__device__ void binary_op_update_v(const T1 lhs, + T1* rhs, + const T2 lhs_idx, + T2* rhs_idx, + BinaryOperation binary_op) { + if (binary_op(lhs, *rhs)) { + *rhs = lhs; + *rhs_idx = lhs_idx; + } +} + +template +__global__ void KernelScanInnerWithIndices(const T1* x_data, + T1* values_data, + T2* indices_data, + int num_rows, + int row_size, + T1 init, + BinaryFunction binary_op) { + __shared__ T1 vbuf[num_threads_y][2 * num_threads_x]; + __shared__ T2 ibuf[num_threads_y][2 * num_threads_x]; + T1* row_buf = vbuf[threadIdx.y]; + T2* row_idx_buf = ibuf[threadIdx.y]; + + for (int block_row = blockIdx.x * blockDim.y; block_row < num_rows; + block_row += blockDim.y * gridDim.x) { + int row = block_row + threadIdx.y; + const T1* row_self = x_data + row * row_size; + T1* row_values = values_data + row * row_size; + T2* row_indices = indices_data + row * row_size; + T1 block_total = init; + T2 block_idx_final = 0; + // Perform scan on one block at a time, keeping track of the total value of + // all blocks processed so far. + for (int block_col = 0; block_col < row_size; + block_col += 2 * num_threads_x) { + // Load data into shared memory (two values per thread). + int col1 = block_col + threadIdx.x; + int col2 = block_col + num_threads_x + threadIdx.x; + if (row < num_rows) { + if (col1 < row_size) { + row_buf[threadIdx.x] = *reinterpret_cast(&row_self[col1]); + row_idx_buf[threadIdx.x] = col1; + } else { + row_buf[threadIdx.x] = init; + } + + if (col2 < row_size) { + row_buf[num_threads_x + threadIdx.x] = + *reinterpret_cast(&row_self[col2]); + row_idx_buf[num_threads_x + threadIdx.x] = col2; + } else { + row_buf[num_threads_x + threadIdx.x] = init; + } + + if (threadIdx.x == 0) { + binary_op_update(block_total, + &row_buf[0], + block_idx_final, + &row_idx_buf[0], + binary_op); + } + } + __syncthreads(); + + // Parallel reduction (up-sweep). + for (int s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) { + if (row < num_rows && threadIdx.x < s) { + int offset = (2 * threadIdx.x + 1) * d - 1; + binary_op_update(row_buf[offset], + &row_buf[offset + d], + row_idx_buf[offset], + &row_idx_buf[offset + d], + binary_op); + } + __syncthreads(); + } + + // Down-sweep. + for (int s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) { + if (row < num_rows && threadIdx.x < s - 1) { + int offset = 2 * (threadIdx.x + 1) * d - 1; + binary_op_update(row_buf[offset], + &row_buf[offset + d], + row_idx_buf[offset], + &row_idx_buf[offset + d], + binary_op); + } + __syncthreads(); + } + + // Write back to output. + if (row < num_rows) { + if (col1 < row_size) { + row_values[col1] = row_buf[threadIdx.x]; + row_indices[col1] = row_idx_buf[threadIdx.x]; + } + if (col2 < row_size) { + row_values[col2] = row_buf[num_threads_x + threadIdx.x]; + row_indices[col2] = row_idx_buf[num_threads_x + threadIdx.x]; + } + } + block_total = row_buf[2 * num_threads_x - 1]; + block_idx_final = row_idx_buf[2 * num_threads_x - 1]; + __syncthreads(); + } + } +} + +template +__global__ void KernelScanOuterWithIndices(const T1* x_data, + T1* values_data, + T2* indices_data, + const uint32_t num_orows, + const uint32_t num_irows, + const uint32_t row_size, + T1 init, + BinaryFunction binary_op) { + for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { + for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; + irow < num_irows; + irow += gridDim.y * blockDim.x) { + const T1* x = x_data + orow * row_size * num_irows + irow; + T1* values = values_data + orow * row_size * num_irows + irow; + T2* indices = indices_data + orow * row_size * num_irows + irow; + T1 out = init; + T2 out_idx = 0; + + for (T2 col = 0; col < row_size; ++col) { + const auto val = *reinterpret_cast(x); + binary_op_update_v(val, &out, col, &out_idx, binary_op); + *values = out; + *indices = out_idx; + x += num_irows; + values += num_irows; + indices += num_irows; + } + } + } +} + +template +void ScanWithIndicesKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + T1 init, + DenseTensor* out, + DenseTensor* indices) { + dev_ctx.template Alloc(out); + dev_ctx.template Alloc(indices); + // For 0D Tensor + if (out->numel() == 1) { + auto raw_dims = out->dims(); + phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, indices, static_cast(0.0)); + out->Resize(raw_dims); + indices->Resize(raw_dims); + return; + } + + BinaryFunction op; + auto out_dims = out->dims(); + auto size = x.numel(); + + PADDLE_ENFORCE_EQ( + axis < out_dims.size() && axis >= (0 - out_dims.size()), + true, + phi::errors::OutOfRange( + "Attr(axis) is out of range, It's expected " + "to be in range of [-%d, %d]. But received Attr(axis) = %d.", + out_dims.size(), + out_dims.size() - 1, + axis)); + if (axis < 0) { + axis += out_dims.size(); + } + + const T1* x_data = x.data(); + T1* values_data = out->data(); + T2* indices_data = indices->data(); + if (axis == out_dims.size() - 1) { + int ndim = x.dims().size(); + int row_size = x.dims()[ndim - 1]; + int num_rows = x.numel() / row_size; + + dim3 threads(16, 32); + dim3 grid( + std::min(dev_ctx.GetCUDAMaxGridDimSize()[0], + static_cast(std::ceil(static_cast(num_rows) / + static_cast(threads.y))))); + + KernelScanInnerWithIndices + <<>>( + x_data, values_data, indices_data, num_rows, row_size, init, op); + } else { + int64_t row_size = x.dims()[axis]; + auto sizes = phi::vectorize(x.dims()); + + const int64_t num_orows = + std::accumulate(sizes.begin(), + sizes.begin() + axis, + int64_t(1), + [](int64_t& a, int64_t& b) { return a * b; }); + const int64_t num_irows = + std::accumulate(sizes.begin() + axis + 1, + sizes.end(), + int64_t(1), + [](int64_t& a, int64_t& b) { return a * b; }); + + dim3 threads(std::min(512, static_cast(num_irows))); + int64_t maxGridDim = dev_ctx.GetCUDAMaxGridDimSize()[1]; + dim3 grid(std::min(maxGridDim, num_orows), + std::min(maxGridDim, + static_cast( + std::ceil(static_cast(num_irows) / + static_cast(threads.x))))); + + KernelScanOuterWithIndices + <<>>(x_data, + values_data, + indices_data, + num_orows, + num_irows, + row_size, + init, + op); + } +} + +template +void CummaxKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + int dtype, + DenseTensor* out, + DenseTensor* indices) { + auto indices_type = phi::TransToPhiDataType(dtype); + T init = std::is_floating_point::value + ? (-1 * std::numeric_limits::infinity()) + : std::numeric_limits::lowest(); + if (indices_type == DataType::INT32) { + ScanWithIndicesKernel, Context>( + dev_ctx, x, axis, init, out, indices); + } else if (indices_type == DataType::INT64) { + ScanWithIndicesKernel, Context>( + dev_ctx, x, axis, init, out, indices); + } +} + +template +void CumminKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + int dtype, + DenseTensor* out, + DenseTensor* indices) { + auto indices_type = phi::TransToPhiDataType(dtype); + T init = std::is_floating_point::value ? std::numeric_limits::infinity() + : std::numeric_limits::max(); + if (indices_type == DataType::INT32) { + ScanWithIndicesKernel, Context>( + dev_ctx, x, axis, init, out, indices); + } else if (indices_type == DataType::INT64) { + ScanWithIndicesKernel, Context>( + dev_ctx, x, axis, init, out, indices); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(cummax, + GPU, + ALL_LAYOUT, + phi::CummaxKernel, + float, + double, + int32_t, + int64_t) {} + +PD_REGISTER_KERNEL(cummin, + GPU, + ALL_LAYOUT, + phi::CumminKernel, + float, + double, + int32_t, + int64_t) {} diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index ff5f4c8650795..7d1e100c882db 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -211,6 +211,8 @@ from .tensor.math import tan # noqa: F401 from .tensor.math import cosh # noqa: F401 from .tensor.math import cumsum # noqa: F401 +from .tensor.math import cummax # noqa: F401 +from .tensor.math import cummin # noqa: F401 from .tensor.math import cumprod # noqa: F401 from .tensor.math import logcumsumexp # noqa: F401 from .tensor.math import logit # noqa: F401 @@ -444,6 +446,8 @@ 'empty_like', 'eye', 'cumsum', + 'cummax', + 'cummin', 'cumprod', 'logaddexp', 'logcumsumexp', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index ba13034fc56ed..8e3fb9ce1d3c4 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -148,6 +148,8 @@ from .math import tan # noqa: F401 from .math import cosh # noqa: F401 from .math import cumsum # noqa: F401 +from .math import cummax # noqa: F401 +from .math import cummin # noqa: F401 from .math import cumprod # noqa: F401 from .math import logcumsumexp # noqa: F401 from .math import logit # noqa: F401 @@ -341,6 +343,8 @@ 'cos', 'cosh', 'cumsum', + 'cummax', + 'cummin', 'cumprod', 'logcumsumexp', 'logit', diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 32338dbcf2d6e..e30960d037831 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -3371,6 +3371,155 @@ def cumsum(x, axis=None, dtype=None, name=None): return _cum_sum_(**kwargs) +def cummax(x, axis=None, dtype='int64', name=None): + """ + The cumulative max of the elements along a given axis. + + Note: + The first element of the result is the same as the first element of the input. + + Args: + x (Tensor): The input tensor needed to be cummaxed. + axis (int, optional): The dimension to accumulate along. -1 means the last dimension. The default (None) is to compute the cummax over the flattened array. + dtype (str, optional): The data type of the indices tensor, can be int32, int64. The default value is int64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + out (Tensor), The result of cummax operation. The dtype of cummax result is same with input x. + + indices (Tensor), The corresponding index results of cummax operation. + + Examples: + .. code-block:: python + + import paddle + + data = paddle.to_tensor([-1, 5, 0, -2, -3, 2]) + data = paddle.reshape(data, (2, 3)) + + y = paddle.cummax(data) + # value: [-1, 5, 5, 5, 5, 5] + # indcies: [0, 1, 1, 1, 1, 1] + + y = paddle.cummax(data, axis=0) + # value: [[-1, 5, 0] + # [-1, 5, 2]] + # indcies: [[0, 0, 0] + # [0, 0, 1]] + + y = paddle.cummax(data, axis=-1) + # value: [[-1, 5, 5] + # [-2, -2, 2]] + # indcies: [[0, 1, 1] + # [0, 0, 2]] + + y = paddle.cummax(data, dtype='int64') + print(y[1].dtype) + # indcies type: paddle.int64 + """ + if axis is None: + axis = -1 + x = x.flatten(0, len(x.shape) - 1) + + check_dtype(dtype, 'dtype', ['int32', 'int64'], 'cummax') + dtype = convert_np_dtype_to_dtype_(dtype) + + if in_dynamic_mode(): + return _C_ops.cummax(x, axis, dtype) + else: + check_variable_and_dtype( + x, + 'x', + ['float32', 'float64', 'int32', 'int64'], + 'cummax', + ) + check_type(x, 'x', (Variable), 'cummax') + helper = LayerHelper('cummax', **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + indices = helper.create_variable_for_type_inference(dtype='int64') + helper.append_op( + type='cummax', + inputs={'x': x}, + outputs={'out': out, 'indices': indices}, + attrs={'axis': axis, 'dtype': dtype}, + ) + return out, indices + + +def cummin(x, axis=None, dtype='int64', name=None): + """ + The cumulative min of the elements along a given axis. + + Note: + The first element of the result is the same as the first element of the input. + + Args: + x (Tensor): The input tensor needed to be cummined. + axis (int, optional): The dimension to accumulate along. -1 means the last dimension. The default (None) is to compute the cummin over the flattened array. + dtype (str, optional): The data type of the indices tensor, can be int32, int64. The default value is int64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + out (Tensor), The result of cummin operation. The dtype of cummin result is same with input x. + + indices (Tensor), The corresponding index results of cummin operation. + + Examples: + .. code-block:: python + + import paddle + data = paddle.to_tensor([-1, 5, 0, -2, -3, 2]) + data = paddle.reshape(data, (2, 3)) + + y = paddle.cummin(data) + # value: [-1, -1, -1, -2, -3, -3] + # indcies: [0, 0, 0, 3, 4, 4] + + y = paddle.cummin(data, axis=0) + # value: [[-1, 5, 0] + # [-2, -3, 0]] + # indcies: [[0, 0, 0] + # [1, 1, 0]] + + y = paddle.cummin(data, axis=-1) + # value: [[-1, -1, -1] + # [-2, -3, -3]] + # indcies: [[0, 0, 0] + # [0, 1, 1]] + + y = paddle.cummin(data, dtype='int64') + print(y[1].dtype) + # indcies type: paddle.int64 + """ + if axis is None: + axis = -1 + x = x.flatten(0, len(x.shape) - 1) + + check_dtype(dtype, 'dtype', ['int32', 'int64'], 'cummin') + dtype = convert_np_dtype_to_dtype_(dtype) + + if in_dynamic_mode(): + return _C_ops.cummin(x, axis, dtype) + else: + check_variable_and_dtype( + x, + 'x', + ['float32', 'float64', 'int32', 'int64'], + 'cummin', + ) + check_type(x, 'x', (Variable), 'cummin') + helper = LayerHelper('cummin', **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + indices = helper.create_variable_for_type_inference(dtype='int64') + helper.append_op( + type='cummin', + inputs={'x': x}, + outputs={'out': out, 'indices': indices}, + attrs={'axis': axis, 'dtype': dtype}, + ) + return out, indices + + def logcumsumexp(x, axis=None, dtype=None, name=None): r""" The logarithm of the cumulative summation of the exponentiation of the elements along a given axis. diff --git a/test/legacy_test/test_cummax_op.py b/test/legacy_test/test_cummax_op.py new file mode 100644 index 0000000000000..1ff5cb2442a63 --- /dev/null +++ b/test/legacy_test/test_cummax_op.py @@ -0,0 +1,244 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +import numpy as np +from eager_op_test import OpTest + +import paddle +from paddle import fluid +from paddle.fluid import core + + +def cummax_dim2(arr, axis=None): + if axis is None: + arr = arr.flatten() + cummax = np.maximum.accumulate(arr) + shape = arr.shape + indices = np.zeros(shape).astype('int32') + max_val = -sys.maxsize + max_ind = 0 + for i in range(shape[0]): + if arr[i] >= max_val: + max_val = max(arr[i], max_val) + max_ind = i + indices[i] = i + else: + indices[i] = max_ind + else: + cummax = np.maximum.accumulate(arr, axis) + shape = arr.shape + indices = np.zeros(shape).astype('int32') + if axis < 0: + axis = axis + len(shape) + if axis == 0: + for j in range(shape[1]): + max_ind = 0 + max_val = -sys.maxsize + for i in range(shape[0]): + if arr[i][j] >= max_val: + max_val = arr[i][j] + max_ind = i + indices[i][j] = i + else: + indices[i][j] = max_ind + elif axis == 1: + for i in range(shape[0]): + max_ind = 0 + max_val = -sys.maxsize + for j in range(shape[1]): + if arr[i][j] >= max_val: + max_val = arr[i][j] + max_ind = j + indices[i][j] = j + else: + indices[i][j] = max_ind + else: + raise Exception("unfeasible axis") + return cummax, indices + + +class TestCummaxOp(OpTest): + def setUp(self): + self.op_type = "cummax" + self.python_api = paddle.cummax + self.dtype = np.float64 + self.axis = -1 + self.indices_type = 3 + self.input_data = np.random.random((10, 10)).astype(self.dtype) + self.set_attrs() + + self.inputs = {'x': self.input_data} + self.attrs = {'axis': self.axis, 'dtype': self.indices_type} + self.np_res, self.np_ind = cummax_dim2(self.input_data, axis=self.axis) + self.outputs = {'out': self.np_res, 'indices': self.np_ind} + + def set_attrs(self): + pass + + def test_check_output(self): + paddle.enable_static() + self.check_output() + + def test_check_grad(self): + paddle.enable_static() + self.check_grad(['x'], 'out') + + +class TestCummaxOpAxis1(TestCummaxOp): + def set_attrs(self): + self.axis = 0 + + +class TestCummaxOpAxis2(TestCummaxOp): + def set_attrs(self): + self.axis = -2 + + +class TestCummaxOpIndexType(TestCummaxOp): + def set_attrs(self): + self.indices_type = 2 + + +class TestCummaxAPI(unittest.TestCase): + def run_cases(self): + data_np = np.random.random((100, 100)).astype(np.float32) + data = paddle.to_tensor(data_np) + + y, indices = paddle.cummax(data) + z, ind = cummax_dim2(data_np) + np.testing.assert_array_equal(z, y.numpy()) + np.testing.assert_array_equal(ind, indices.numpy()) + + y, indices = paddle.cummax(data, axis=0) + z, ind = cummax_dim2(data_np, axis=0) + np.testing.assert_array_equal(z, y.numpy()) + np.testing.assert_array_equal(ind, indices.numpy()) + + y, indices = paddle.cummax(data, axis=-1) + z, ind = cummax_dim2(data_np, axis=-1) + np.testing.assert_array_equal(z, y.numpy()) + np.testing.assert_array_equal(ind, indices.numpy()) + + y, indices = paddle.cummax(data, axis=-2) + z, ind = cummax_dim2(data_np, axis=-2) + np.testing.assert_array_equal(z, y.numpy()) + np.testing.assert_array_equal(ind, indices.numpy()) + + y, indices = paddle.cummax(data, axis=-2, dtype='int32') + z, ind = cummax_dim2(data_np, axis=-2) + np.testing.assert_array_equal(z, y.numpy()) + np.testing.assert_array_equal(ind, indices.numpy()) + self.assertTrue(indices.dtype == core.VarDesc.VarType.INT32) + + data_np = np.random.randint(0, 10, size=(100, 100)).astype(np.int32) + data = paddle.to_tensor(data_np) + y, indices = paddle.cummax(data, axis=0) + z, ind = cummax_dim2(data_np, axis=0) + np.testing.assert_array_equal(z, y.numpy()) + np.testing.assert_array_equal(ind, indices.numpy()) + + def run_static(self, use_gpu=False): + with fluid.program_guard(fluid.Program()): + data_np = np.random.random((100, 100)).astype(np.float32) + x = paddle.static.data('x', [100, 100]) + y1, indices1 = paddle.cummax(x) + y2, indices2 = paddle.cummax(x, axis=0) + y3, indices3 = paddle.cummax(x, axis=-1) + y4, indices4 = paddle.cummax(x, axis=-2) + y5, indices5 = paddle.cummax(x, axis=-2, dtype=np.int32) + + place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + out = exe.run( + feed={'x': data_np}, + fetch_list=[ + y1.name, + indices1.name, + y2.name, + indices2.name, + y3.name, + indices3.name, + y4.name, + indices4.name, + y5.name, + indices5.name, + ], + ) + + z, ind = cummax_dim2(data_np) + np.testing.assert_allclose(z, out[0], rtol=1e-05) + np.testing.assert_allclose(ind, out[1], rtol=1e-05) + + z, ind = cummax_dim2(data_np, axis=0) + np.testing.assert_allclose(z, out[2], rtol=1e-05) + np.testing.assert_allclose(ind, out[3], rtol=1e-05) + + z, ind = cummax_dim2(data_np, axis=-1) + np.testing.assert_allclose(z, out[4], rtol=1e-05) + np.testing.assert_allclose(ind, out[5], rtol=1e-05) + + z, ind = cummax_dim2(data_np, axis=-2) + np.testing.assert_allclose(z, out[6], rtol=1e-05) + np.testing.assert_allclose(ind, out[7], rtol=1e-05) + + z, ind = cummax_dim2(data_np, axis=-2) + np.testing.assert_allclose(z, out[8], rtol=1e-05) + np.testing.assert_allclose(ind, out[9], rtol=1e-05) + + def test_cpu(self): + paddle.disable_static(paddle.fluid.CPUPlace()) + self.run_cases() + paddle.enable_static() + self.run_static() + + def test_gpu(self): + if not fluid.core.is_compiled_with_cuda(): + return + paddle.disable_static(paddle.fluid.CUDAPlace(0)) + self.run_cases() + paddle.enable_static() + self.run_static(use_gpu=True) + + def test_errors(self): + paddle.enable_static() + with fluid.program_guard(fluid.Program()): + + def test_x_type(): + data = [1, 2, 3] + y, indices = paddle.cummax(data, axis=0) + + self.assertRaises(TypeError, test_x_type) + paddle.disable_static() + + def test_indices_type(): + data_np = np.random.random((10, 10)).astype(np.float32) + data = paddle.to_tensor(data_np) + y, indices = paddle.cummax(data, dtype='float32') + + self.assertRaises(ValueError, test_indices_type) + + def test_axis_outrange(): + data_np = np.random.random(100).astype(np.float32) + data = paddle.to_tensor(data_np) + y, indices = paddle.cummax(data, axis=-2) + + self.assertRaises(IndexError, test_axis_outrange) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/legacy_test/test_cummin_op.py b/test/legacy_test/test_cummin_op.py new file mode 100644 index 0000000000000..dc542ebe90077 --- /dev/null +++ b/test/legacy_test/test_cummin_op.py @@ -0,0 +1,245 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +import numpy as np +from eager_op_test import OpTest + +import paddle +from paddle import fluid +from paddle.fluid import core + + +def cummin_dim2(arr, axis=None): + if axis is None: + arr = arr.flatten() + cummin = np.minimum.accumulate(arr) + shape = arr.shape + indices = np.zeros(shape).astype('int32') + min_val = sys.maxsize + min_ind = 0 + for i in range(shape[0]): + if arr[i] <= min_val: + min_val = min(arr[i], min_val) + min_ind = i + indices[i] = i + else: + indices[i] = min_ind + else: + cummin = np.minimum.accumulate(arr, axis) + shape = arr.shape + indices = np.zeros(shape).astype('int32') + if axis < 0: + axis = axis + len(shape) + if axis == 0: + for j in range(shape[1]): + min_ind = 0 + min_val = sys.maxsize + for i in range(shape[0]): + if arr[i][j] <= min_val: + min_val = arr[i][j] + min_ind = i + indices[i][j] = i + else: + indices[i][j] = min_ind + elif axis == 1: + for i in range(shape[0]): + min_ind = 0 + min_val = sys.maxsize + for j in range(shape[1]): + if arr[i][j] <= min_val: + min_val = arr[i][j] + min_ind = j + indices[i][j] = j + else: + indices[i][j] = min_ind + else: + raise Exception("unfeasible axis") + return cummin, indices + + +class TestCumminOp(OpTest): + def setUp(self): + self.op_type = "cummin" + self.python_api = paddle.cummin + self.dtype = np.float64 + self.axis = -1 + self.indices_type = 3 + self.input_data = np.random.random((10, 10)).astype(self.dtype) + self.set_attrs() + + self.inputs = {'x': self.input_data} + self.attrs = {'axis': self.axis, 'dtype': self.indices_type} + self.np_res, self.np_ind = cummin_dim2(self.input_data, axis=self.axis) + self.outputs = {'out': self.np_res, 'indices': self.np_ind} + + def set_attrs(self): + pass + + def test_check_output(self): + paddle.enable_static() + self.check_output() + + def test_check_grad(self): + paddle.enable_static() + self.check_grad(['x'], 'out') + + +class TestCuinOpAxis1(TestCumminOp): + def set_attrs(self): + self.axis = 0 + + +class TestCumminOpAxis2(TestCumminOp): + def set_attrs(self): + self.axis = -2 + + +class TestCumminOpIndexType(TestCumminOp): + def set_attrs(self): + self.indices_type = 2 + + +class TestCumminAPI(unittest.TestCase): + def run_cases(self): + data_np = np.random.random((100, 100)).astype(np.float32) + data = paddle.to_tensor(data_np) + + y, indices = paddle.cummin(data) + z, ind = cummin_dim2(data_np) + np.testing.assert_array_equal(z, y.numpy()) + np.testing.assert_array_equal(ind, indices.numpy()) + + y, indices = paddle.cummin(data, axis=0) + z, ind = cummin_dim2(data_np, axis=0) + np.testing.assert_array_equal(z, y.numpy()) + np.testing.assert_array_equal(ind, indices.numpy()) + + y, indices = paddle.cummin(data, axis=-1) + z, ind = cummin_dim2(data_np, axis=-1) + np.testing.assert_array_equal(z, y.numpy()) + np.testing.assert_array_equal(ind, indices.numpy()) + + y, indices = paddle.cummin(data, axis=-2) + z, ind = cummin_dim2(data_np, axis=-2) + np.testing.assert_array_equal(z, y.numpy()) + np.testing.assert_array_equal(ind, indices.numpy()) + + y, indices = paddle.cummin(data, axis=-2, dtype='int32') + z, ind = cummin_dim2(data_np, axis=-2) + np.testing.assert_array_equal(z, y.numpy()) + np.testing.assert_array_equal(ind, indices.numpy()) + self.assertTrue(indices.dtype == core.VarDesc.VarType.INT32) + + data_np = np.random.randint(0, 10, size=(100, 100)).astype(np.int32) + data = paddle.to_tensor(data_np) + y, indices = paddle.cummin(data, axis=0) + z, ind = cummin_dim2(data_np, axis=0) + np.testing.assert_array_equal(z, y.numpy()) + np.testing.assert_array_equal(ind, indices.numpy()) + + def run_static(self, use_gpu=False): + with fluid.program_guard(fluid.Program()): + data_np = np.random.random((100, 100)).astype(np.float32) + x = paddle.static.data('x', [100, 100]) + y1, indices1 = paddle.cummin(x) + y2, indices2 = paddle.cummin(x, axis=0) + y3, indices3 = paddle.cummin(x, axis=-1) + y4, indices4 = paddle.cummin(x, axis=-2) + y5, indices5 = paddle.cummin(x, axis=-2, dtype=np.int32) + + place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + out = exe.run( + feed={'x': data_np}, + fetch_list=[ + y1.name, + indices1.name, + y2.name, + indices2.name, + y3.name, + indices3.name, + y4.name, + indices4.name, + y5.name, + indices5.name, + ], + ) + + z, ind = cummin_dim2(data_np) + np.testing.assert_allclose(z, out[0], rtol=1e-05) + np.testing.assert_allclose(ind, out[1], rtol=1e-05) + + z, ind = cummin_dim2(data_np, axis=0) + np.testing.assert_allclose(z, out[2], rtol=1e-05) + np.testing.assert_allclose(ind, out[3], rtol=1e-05) + + z, ind = cummin_dim2(data_np, axis=-1) + np.testing.assert_allclose(z, out[4], rtol=1e-05) + np.testing.assert_allclose(ind, out[5], rtol=1e-05) + + z, ind = cummin_dim2(data_np, axis=-2) + np.testing.assert_allclose(z, out[6], rtol=1e-05) + np.testing.assert_allclose(ind, out[7], rtol=1e-05) + + z, ind = cummin_dim2(data_np, axis=-2) + np.testing.assert_allclose(z, out[8], rtol=1e-05) + np.testing.assert_allclose(ind, out[9], rtol=1e-05) + + def test_cpu(self): + paddle.disable_static(paddle.fluid.CPUPlace()) + self.run_cases() + paddle.enable_static() + self.run_static() + + def test_gpu(self): + if not fluid.core.is_compiled_with_cuda(): + return + paddle.disable_static(paddle.fluid.CUDAPlace(0)) + self.run_cases() + paddle.enable_static() + self.run_static(use_gpu=True) + + def test_errors(self): + paddle.enable_static() + with fluid.program_guard(fluid.Program()): + + def test_x_type(): + data = [1, 2, 3] + y, indices = paddle.cummin(data, axis=0) + + self.assertRaises(TypeError, test_x_type) + + paddle.disable_static() + + def test_indices_type(): + data_np = np.random.random((10, 10)).astype(np.float32) + data = paddle.to_tensor(data_np) + y, indices = paddle.cummin(data, dtype='float32') + + self.assertRaises(ValueError, test_indices_type) + + def test_axis_outrange(): + data_np = np.random.random(100).astype(np.float32) + data = paddle.to_tensor(data_np) + y, indices = paddle.cummin(data, axis=-2) + + self.assertRaises(IndexError, test_axis_outrange) + + +if __name__ == '__main__': + unittest.main()