Skip to content

Commit

Permalink
impl max_func
Browse files Browse the repository at this point in the history
  • Loading branch information
joey12300 committed Sep 19, 2021
1 parent 6ddc7d4 commit 36f371b
Showing 1 changed file with 35 additions and 11 deletions.
46 changes: 35 additions & 11 deletions paddle/fluid/operators/viterbi_decode_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License. */
#include <type_traits>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/arg_min_max_op_base.h"
Expand All @@ -30,7 +31,6 @@ limitations under the License. */
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/fluid/operators/math/functors.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/operators/unique_op.h"
#include "paddle/fluid/operators/utils.h"
Expand All @@ -39,6 +39,16 @@ namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

#define CREATE_TENSOR(tensor, dtype, ...) \
LoDTensor tensor; \
Expand All @@ -65,23 +75,36 @@ using LoDTensor = framework::LoDTensor;
dev_ctx); \
cast_functor.template apply<dtype>()

template <typename T>
struct MaxFunctor {
template <typename DeviceContext, typename X, typename Y, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
y->device(place) = x->maximum(dim);
}
};

template <typename DeviceContext, typename T>
template <typename DeviceContext, typename T, size_t D, size_t R_D>
inline void MAX_FUNC(const framework::ExecutionContext& ctx,
const Tensor* input, Tensor* output,
const std::vector<int>& dims) {
auto cast_out_dtype =
static_cast<framework::proto::VarType::Type>(output->type());
framework::VisitDataType(cast_out_dtype,
ReduceKernelFunctor<DeviceContext, T, MaxFunctor<T>>(
input, output, dims, false, false, ctx));
auto x = EigenTensor<T, D>::From(*input);
auto x_rank = static_cast<int>(x.dimensions().size());
auto reduce_dim = Eigen::array<int, R_D>();
auto& dev_ctx = ctx.template device_context<DeviceContext>();
std::vector<int> dims_ref = dims;
for (size_t i = 0; i < dims_ref.size(); ++i) {
if (dims_ref[i] < 0) dims_ref[i] = x_rank + dims_ref[i];
reduce_dim[i] = dims_ref[i];
}
DDim out_dims = output->dims();
auto& place = *dev_ctx.eigen_device();
MaxFunctor functor;
if (D == 1) {
auto out = EigenScalar<T>::From(*output);
functor(place, &x, &out, reduce_dim);
} else {
auto out = EigenTensor<T, (D - R_D)>::From(*output, out_dims);
functor(place, &x, &out, reduce_dim);
}
}

class TensorBuffer {
Expand Down Expand Up @@ -223,8 +246,8 @@ class ViterbiDecodeKernel : public framework::OpKernel<T> {

ADD(alpha_exp, trans_exp, alpha_trn_sum, T);

MAX_FUNC<DeviceContext, T>(ctx, &alpha_trn_sum, &alpha_max,
std::vector<int>({1}));
MAX_FUNC<DeviceContext, T, 3, 1>(ctx, &alpha_trn_sum, &alpha_max,
std::vector<int>({1}));

auto alpha_argmax_temp = alpha_argmax_unbind[i - 1];
alpha_argmax_temp.Resize({batch_size, n_labels});
Expand Down Expand Up @@ -261,7 +284,8 @@ class ViterbiDecodeKernel : public framework::OpKernel<T> {
}

// scores, last_ids = alpha.max(1), alpha.argmax(1)
MAX_FUNC<DeviceContext, T>(ctx, &alpha, scores, std::vector<int>({1}));
MAX_FUNC<DeviceContext, T, 2, 1>(ctx, &alpha, scores,
std::vector<int>({1}));
ArgMinMaxFunctor<DeviceContext, T, int64_t, 2, ArgMinMaxType::kArgMax>
argmax2;

Expand Down

0 comments on commit 36f371b

Please sign in to comment.