From 1906c74622664394a0967d0a257eee501611343c Mon Sep 17 00:00:00 2001 From: Jack Zhou Date: Sun, 24 Oct 2021 07:59:35 +0800 Subject: [PATCH] Add viterbi decode (#35778) (#36615) * add viterbi decode cpu kernel * add viterbi decoder api in paddle.text * add a data buffer once to avoid create many small pieces of data buffer frequently * fix viterbi max_seq_length bug * fix seq_len=1 bug * fix device context * move split out of for loop * remove INVERSE_SUB * remove 2 GET_CAST_MASK * remove 1 loop * remove Functor * add to_static deploy code * use MAX_FUNC instead of ELE_MAX * add MaxFunctor * impl max_func * remove MaxFunctor * remove cast op * use REGISTER_OP_WITHOUT_GRADIENT * add viterbi cuda kernel * add FIX_BLOCKDIM_CASE macro * add MKL add, mul; add get data mask * add arange mkl impl * add CPU Argmax * add cpu gather * use EXECUTE_MKL_ELEMENT_BINARY_OP instead of some ADD, MUL * use SameDimsBinaryOP instead of EXECUTE_MKL_ELEMENT_BINARY_OP * use SAME_DIMS_ELEMENT_BINARY_OP * add SimpleBroadcastBinaryOP * use int instead of int64_t to accelerate * optimize SimpleBroadcastBinaryOP * optimize SimpleBroadcastBinaryOP * optimize performance in both single thread and multithread situation * remove useless line * remove useless code * add CREATE_TENSOR_BUFFER macro * add INIT_REQUIRED_TENSOR macro * add comment * fix windows ci * add viterbi unittest * remove cuda add functor * remove cuda equal * remove a template function * fix windows ci * fix windows dtype * remove some template instance * remove useless header file * remove some blockdim * remove transpose impl * accelerate cpu performance on single thread situation * viterbi_decode->crf_decode * rename crf params name * add viterbi api test * remove useless import * add enable_static * use viterbi decoder * fix viterbi len=1 * fix viterbi unittest * remove useless comments * reconstruct viterbi decode * remove ADD,SUB,MUL structure * fix coverage * remove CREATE_TENSOR * add name args * crf.py->ops.py; with_start_stop_tag->include_start_end_tag * update crf_decode en docs * fix viterbi decode en docs * fix some review comments * add FIXED_BLOCK_DIM_CASE in cuda * push_back->emplace_back * crf_decode->viterbi_decode; include_start_end_tag->include_bos_eos_tag * paddle.text.ops.viterbi_decode->paddle.text.viterbi_decode * fix viterbi_decode en docs --- .../elementwise/elementwise_op_function.h | 4 +- paddle/fluid/operators/viterbi_decode_op.cc | 109 +++++ paddle/fluid/operators/viterbi_decode_op.cu | 200 +++++++++ paddle/fluid/operators/viterbi_decode_op.h | 415 ++++++++++++++++++ .../tests/unittests/test_viterbi_decode_op.py | 134 ++++++ python/paddle/text/__init__.py | 6 +- python/paddle/text/viterbi_decode.py | 132 ++++++ 7 files changed, 996 insertions(+), 4 deletions(-) create mode 100644 paddle/fluid/operators/viterbi_decode_op.cc create mode 100644 paddle/fluid/operators/viterbi_decode_op.cu create mode 100644 paddle/fluid/operators/viterbi_decode_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_viterbi_decode_op.py create mode 100644 python/paddle/text/viterbi_decode.py diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index 312978a010b30..2df7dd06f2cc8 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -240,7 +240,7 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims, x_dims, y_dims, x_dims_array[i], y_dims_array[i], i)); if ((x_dims_array[i] > 1 || y_dims_array[i] > 1) || (x_dims_array[i] == 1 && y_dims_array[i] == 1)) { - out_dims_array[i] = std::max(x_dims_array[i], y_dims_array[i]); + out_dims_array[i] = (std::max)(x_dims_array[i], y_dims_array[i]); } else { out_dims_array[i] = -1; } @@ -1779,7 +1779,7 @@ void CommonElementwiseBroadcastForward( const framework::Tensor *y, framework::Tensor *z, const framework::DDim &x_dims, const framework::DDim &y_dims, Functor func, int axis, const bool is_xsize_larger = true) { - int max_dim = std::max(x_dims.size(), y_dims.size()); + int max_dim = (std::max)(x_dims.size(), y_dims.size()); axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); PADDLE_ENFORCE_GE( axis, 0, diff --git a/paddle/fluid/operators/viterbi_decode_op.cc b/paddle/fluid/operators/viterbi_decode_op.cc new file mode 100644 index 0000000000000..bf1cdeed65a84 --- /dev/null +++ b/paddle/fluid/operators/viterbi_decode_op.cc @@ -0,0 +1,109 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/viterbi_decode_op.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class ViterbiDecodeOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "ViterbiDecode"); + OP_INOUT_CHECK(ctx->HasInput("Transition"), "Input", "Transition", + "ViterbiDecode"); + OP_INOUT_CHECK(ctx->HasInput("Length"), "Input", "Length", "ViterbiDecode"); + OP_INOUT_CHECK(ctx->HasOutput("Scores"), "Output", "Scores", + "ViterbiDecode"); + OP_INOUT_CHECK(ctx->HasOutput("Path"), "Output", "Path", "ViterbiDecode"); + auto in_dims = ctx->GetInputDim("Input"); + PADDLE_ENFORCE_EQ(in_dims.size(), 3, + platform::errors::InvalidArgument( + "The rank of Input in ViterbiDecode must be 3. But " + "received Input's rank is %d.", + in_dims.size())); + auto length_dims = ctx->GetInputDim("Length"); + PADDLE_ENFORCE_EQ(length_dims.size(), 1, + platform::errors::InvalidArgument( + "The rank of Length in ViterbiDecode must be 1. But " + "received Length's rank is %d.", + length_dims.size())); + auto transition_dims = ctx->GetInputDim("Transition"); + PADDLE_ENFORCE_EQ( + transition_dims.size(), 2, + platform::errors::InvalidArgument( + "The rank of Transition in ViterbiDecode must be 2. But " + "received Transition's rank is %d.", + transition_dims.size())); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ( + in_dims[0], length_dims[0], + platform::errors::InvalidArgument( + "The batch size of Input and Length should be equal.")); + PADDLE_ENFORCE_EQ(in_dims[2], transition_dims[0], + platform::errors::InvalidArgument( + "The number of tags of Input (%d) and Transition " + "(%d) should be equal.", + transition_dims[0], in_dims[2])); + } + ctx->SetOutputDim("Scores", length_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); + } +}; + +class ViterbiDecodeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput( + "Input", + "The unary emission tensor. The shape of Input must be (batch_size," + "sequence_length, num_tags). "); + AddInput("Transition", + "The transition matrix. The shape of Transition must be ( " + "num_tags, num_tags). "); + AddInput("Length", + "The input length tensor storing real length of each sequence for " + "correctness. The shape of Length MUST be (batch_size)."); + AddOutput("Scores", + "The scores tensor containing the score for the Viterbi " + "sequence. The shape of Scores MUST be (batch_size)."); + AddOutput("Path", + "The paths tensor containing the highest scoring tag indices. " + "The shape of Scores MUST be (batch_size, sequence_length)."); + AddAttr("include_bos_eos_tag", + "If set to True, the last row and the last column of " + "transitions will be considered as start tag.") + .SetDefault(true); + AddComment(R"DOC( + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace platform = paddle::platform; +REGISTER_OP_WITHOUT_GRADIENT(viterbi_decode, ops::ViterbiDecodeOp, + ops::ViterbiDecodeOpMaker); +REGISTER_OP_CPU_KERNEL( + viterbi_decode, ops::ViterbiDecodeKernel, + ops::ViterbiDecodeKernel); diff --git a/paddle/fluid/operators/viterbi_decode_op.cu b/paddle/fluid/operators/viterbi_decode_op.cu new file mode 100644 index 0000000000000..086ff05b08461 --- /dev/null +++ b/paddle/fluid/operators/viterbi_decode_op.cu @@ -0,0 +1,200 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/elementwise/elementwise_functor.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" +#include "paddle/fluid/operators/gather.cu.h" +#include "paddle/fluid/operators/viterbi_decode_op.h" + +#ifdef __NVCC__ +#include "cub/cub.cuh" +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif + +namespace paddle { +namespace operators { + +#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...) \ + case (1 << (log2_block_dim)): { \ + constexpr auto kBlockDim = (1 << (log2_block_dim)); \ + __VA_ARGS__; \ + } break + +#define FIXED_BLOCK_DIM_CASE(...) \ + FIXED_BLOCK_DIM_CASE_BASE(10, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__); + +int64_t ComputeBlockSize(int64_t col) { + if (col > 512) + return 1024; + else if (col > 256) + return 512; + else if (col > 128) + return 256; + else if (col > 64) + return 128; + else if (col > 32) + return 64; + else if (col > 16) + return 32; + else if (col > 8) + return 16; + else + return 8; +} + +template