From 36f371b16ca22fcbb89e1d9d390efc8f7a72e49e Mon Sep 17 00:00:00 2001
From: zhoushunjie
Date: Mon, 20 Sep 2021 00:52:39 +0800
Subject: [PATCH] impl max_func

---
 paddle/fluid/operators/viterbi_decode_op.h | 46 ++++++++++++++++------
 1 file changed, 35 insertions(+), 11 deletions(-)

diff --git a/paddle/fluid/operators/viterbi_decode_op.h b/paddle/fluid/operators/viterbi_decode_op.h
index ac5c86134d5d7..38436c23f4184 100644
--- a/paddle/fluid/operators/viterbi_decode_op.h
+++ b/paddle/fluid/operators/viterbi_decode_op.h
@@ -16,6 +16,7 @@ limitations under the License. */

 #include
 #include
+#include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/activation_op.h"
 #include "paddle/fluid/operators/arg_min_max_op_base.h"
@@ -30,7 +31,6 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/fc.h"
 #include "paddle/fluid/operators/math/functors.h"
 #include "paddle/fluid/operators/math/math_function.h"
-#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 #include "paddle/fluid/operators/transpose_op.h"
 #include "paddle/fluid/operators/unique_op.h"
 #include "paddle/fluid/operators/utils.h"
@@ -39,6 +39,16 @@
 namespace paddle {
 namespace operators {
 using LoDTensor = framework::LoDTensor;
+using DDim = framework::DDim;
+template <typename T, size_t D, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

 #define CREATE_TENSOR(tensor, dtype, ...) \
   LoDTensor tensor;                       \
@@ -65,7 +75,6 @@ using LoDTensor = framework::LoDTensor;
       dev_ctx);                           \
   cast_functor.template apply()

-template <typename DeviceContext, typename T>
 struct MaxFunctor {
   template <typename DeviceContext, typename X, typename Y, typename Dim>
   void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
@@ -73,15 +82,29 @@ using LoDTensor = framework::LoDTensor;
   }
 };

-template <typename DeviceContext, typename T>
+template <typename DeviceContext, typename T, size_t D>
 inline void MAX_FUNC(const framework::ExecutionContext& ctx,
                      const Tensor* input, Tensor* output,
                      const std::vector<int>& dims) {
-  auto cast_out_dtype =
-      static_cast<framework::proto::VarType::Type>(output->type());
-  framework::VisitDataType(cast_out_dtype,
-                           ReduceKernelFunctor<DeviceContext, T, MaxFunctor<DeviceContext, T>>(
-                               input, output, dims, false, false, ctx));
+  auto x = EigenTensor<T, D>::From(*input);
+  auto x_rank = static_cast<int>(x.dimensions().size());
+  auto reduce_dim = Eigen::array<int, 1>();
+  auto& dev_ctx = ctx.template device_context<DeviceContext>();
+  std::vector<int> dims_ref = dims;
+  for (size_t i = 0; i < dims_ref.size(); ++i) {
+    if (dims_ref[i] < 0) dims_ref[i] = x_rank + dims_ref[i];
+    reduce_dim[i] = dims_ref[i];
+  }
+  DDim out_dims = output->dims();
+  auto& place = *dev_ctx.eigen_device();
+  MaxFunctor functor;
+  if (D == 1) {
+    auto out = EigenScalar<T>::From(*output);
+    functor(place, &x, &out, reduce_dim);
+  } else {
+    auto out = EigenTensor<T, (D - 1)>::From(*output, out_dims);
+    functor(place, &x, &out, reduce_dim);
+  }
 }

 class TensorBuffer {
@@ -223,8 +246,8 @@ class ViterbiDecodeKernel : public framework::OpKernel<T> {

       ADD(alpha_exp, trans_exp, alpha_trn_sum, T);

-      MAX_FUNC<DeviceContext, T>(ctx, &alpha_trn_sum, &alpha_max,
-                                 std::vector<int>({1}));
+      MAX_FUNC<DeviceContext, T, 3>(ctx, &alpha_trn_sum, &alpha_max,
+                                    std::vector<int>({1}));

       auto alpha_argmax_temp = alpha_argmax_unbind[i - 1];
       alpha_argmax_temp.Resize({batch_size, n_labels});
@@ -261,7 +284,8 @@ class ViterbiDecodeKernel : public framework::OpKernel<T> {
     }

     // scores, last_ids = alpha.max(1), alpha.argmax(1)
-    MAX_FUNC<DeviceContext, T>(ctx, &alpha, scores, std::vector<int>({1}));
+    MAX_FUNC<DeviceContext, T, 2>(ctx, &alpha, scores,
+                                  std::vector<int>({1}));
     ArgMinMaxFunctor argmax2;
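
Reviewer note (illustration only, not part of the patch): MAX_FUNC now calls
Eigen's tensor reduction API directly instead of going through
ReduceKernelFunctor. The self-contained sketch below shows the same
maximum-over-one-dimension pattern with plain Eigen::Tensor, outside the
Paddle framework wrappers; the tensor shapes and variable names are invented
for the example.

    // Minimal sketch of the Eigen reduction pattern used by MAX_FUNC.
    // Assumes only Eigen's unsupported Tensor module; shapes are made up.
    #include <unsupported/Eigen/CXX11/Tensor>
    #include <iostream>

    int main() {
      // Rank-3 input, analogous to alpha_trn_sum with
      // (batch_size, n_labels, n_labels) = (2, 3, 3).
      Eigen::Tensor<float, 3> x(2, 3, 3);
      x.setRandom();

      // Reduce over dimension 1, as MAX_FUNC does for std::vector<int>({1}).
      // maximum() drops that dimension, yielding a rank-2 (2, 3) result;
      // this is why the patch parameterizes MAX_FUNC on the input rank D and
      // builds the output as EigenTensor<T, D - 1> (EigenScalar when D == 1).
      Eigen::array<int, 1> reduce_dim = {1};
      Eigen::Tensor<float, 2> y = x.maximum(reduce_dim);

      std::cout << "reduced shape: " << y.dimension(0) << " x "
                << y.dimension(1) << "\n";
      return 0;
    }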