Skip to content

Commit

Permalink
Add viterbi decode (#35778) (#36615)
Browse files Browse the repository at this point in the history
* add viterbi decode cpu kernel

* add viterbi decoder api in paddle.text

* add a data buffer once to avoid create many small pieces of data buffer frequently

* fix viterbi max_seq_length bug

* fix seq_len=1 bug

* fix device context

* move split out of for loop

* remove INVERSE_SUB

* remove 2 GET_CAST_MASK

* remove 1 loop

* remove Functor

* add to_static deploy code

* use MAX_FUNC instead of ELE_MAX

* add MaxFunctor

* impl max_func

* remove MaxFunctor

* remove cast op

* use REGISTER_OP_WITHOUT_GRADIENT

* add viterbi cuda kernel

* add FIX_BLOCKDIM_CASE macro

* add MKL add, mul; add get data mask

* add arange mkl impl

* add CPU Argmax

* add cpu gather

* use EXECUTE_MKL_ELEMENT_BINARY_OP instead of some ADD, MUL

* use SameDimsBinaryOP instead of EXECUTE_MKL_ELEMENT_BINARY_OP

* use SAME_DIMS_ELEMENT_BINARY_OP

* add SimpleBroadcastBinaryOP

* use int instead of int64_t to accelerate

* optimize SimpleBroadcastBinaryOP

* optimize SimpleBroadcastBinaryOP

* optimize performance in both single thread and multithread situation

* remove useless line

* remove useless code

* add CREATE_TENSOR_BUFFER macro

* add INIT_REQUIRED_TENSOR macro

* add comment

* fix windows ci

* add viterbi unittest

* remove cuda add functor

* remove cuda equal

* remove a template function

* fix windows ci

* fix windows dtype

* remove some template instance

* remove useless header file

* remove some blockdim

* remove transpose impl

* accelerate cpu performance on single thread situation

* viterbi_decode->crf_decode

* rename crf params name

* add viterbi api test

* remove useless import

* add enable_static

* use viterbi decoder

* fix viterbi len=1

* fix  viterbi unittest

* remove useless comments

* reconstruct viterbi decode

* remove ADD,SUB,MUL structure

* fix coverage

* remove CREATE_TENSOR

* add name args

* crf.py->ops.py; with_start_stop_tag->include_start_end_tag

* update crf_decode en docs

* fix viterbi decode en docs

* fix some review comments

* add FIXED_BLOCK_DIM_CASE in cuda

* push_back->emplace_back

* crf_decode->viterbi_decode; include_start_end_tag->include_bos_eos_tag

* paddle.text.ops.viterbi_decode->paddle.text.viterbi_decode

* fix viterbi_decode en docs
  • Loading branch information
joey12300 committed Oct 23, 2021
1 parent 6840cf5 commit 1906c74
Show file tree
Hide file tree
Showing 7 changed files with 996 additions and 4 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/operators/elementwise/elementwise_op_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims,
x_dims, y_dims, x_dims_array[i], y_dims_array[i], i));
if ((x_dims_array[i] > 1 || y_dims_array[i] > 1) ||
(x_dims_array[i] == 1 && y_dims_array[i] == 1)) {
out_dims_array[i] = std::max(x_dims_array[i], y_dims_array[i]);
out_dims_array[i] = (std::max)(x_dims_array[i], y_dims_array[i]);
} else {
out_dims_array[i] = -1;
}
Expand Down Expand Up @@ -1779,7 +1779,7 @@ void CommonElementwiseBroadcastForward(
const framework::Tensor *y, framework::Tensor *z,
const framework::DDim &x_dims, const framework::DDim &y_dims, Functor func,
int axis, const bool is_xsize_larger = true) {
int max_dim = std::max(x_dims.size(), y_dims.size());
int max_dim = (std::max)(x_dims.size(), y_dims.size());
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
PADDLE_ENFORCE_GE(
axis, 0,
Expand Down
109 changes: 109 additions & 0 deletions paddle/fluid/operators/viterbi_decode_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/viterbi_decode_op.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class ViterbiDecodeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "ViterbiDecode");
OP_INOUT_CHECK(ctx->HasInput("Transition"), "Input", "Transition",
"ViterbiDecode");
OP_INOUT_CHECK(ctx->HasInput("Length"), "Input", "Length", "ViterbiDecode");
OP_INOUT_CHECK(ctx->HasOutput("Scores"), "Output", "Scores",
"ViterbiDecode");
OP_INOUT_CHECK(ctx->HasOutput("Path"), "Output", "Path", "ViterbiDecode");
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ(in_dims.size(), 3,
platform::errors::InvalidArgument(
"The rank of Input in ViterbiDecode must be 3. But "
"received Input's rank is %d.",
in_dims.size()));
auto length_dims = ctx->GetInputDim("Length");
PADDLE_ENFORCE_EQ(length_dims.size(), 1,
platform::errors::InvalidArgument(
"The rank of Length in ViterbiDecode must be 1. But "
"received Length's rank is %d.",
length_dims.size()));
auto transition_dims = ctx->GetInputDim("Transition");
PADDLE_ENFORCE_EQ(
transition_dims.size(), 2,
platform::errors::InvalidArgument(
"The rank of Transition in ViterbiDecode must be 2. But "
"received Transition's rank is %d.",
transition_dims.size()));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
in_dims[0], length_dims[0],
platform::errors::InvalidArgument(
"The batch size of Input and Length should be equal."));
PADDLE_ENFORCE_EQ(in_dims[2], transition_dims[0],
platform::errors::InvalidArgument(
"The number of tags of Input (%d) and Transition "
"(%d) should be equal.",
transition_dims[0], in_dims[2]));
}
ctx->SetOutputDim("Scores", length_dims);
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
};

class ViterbiDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"Input",
"The unary emission tensor. The shape of Input must be (batch_size,"
"sequence_length, num_tags). ");
AddInput("Transition",
"The transition matrix. The shape of Transition must be ( "
"num_tags, num_tags). ");
AddInput("Length",
"The input length tensor storing real length of each sequence for "
"correctness. The shape of Length MUST be (batch_size).");
AddOutput("Scores",
"The scores tensor containing the score for the Viterbi "
"sequence. The shape of Scores MUST be (batch_size).");
AddOutput("Path",
"The paths tensor containing the highest scoring tag indices. "
"The shape of Scores MUST be (batch_size, sequence_length).");
AddAttr<bool>("include_bos_eos_tag",
"If set to True, the last row and the last column of "
"transitions will be considered as start tag.")
.SetDefault(true);
AddComment(R"DOC(
)DOC");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace platform = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(viterbi_decode, ops::ViterbiDecodeOp,
ops::ViterbiDecodeOpMaker);
REGISTER_OP_CPU_KERNEL(
viterbi_decode, ops::ViterbiDecodeKernel<platform::CPUDeviceContext, float>,
ops::ViterbiDecodeKernel<platform::CPUDeviceContext, double>);
200 changes: 200 additions & 0 deletions paddle/fluid/operators/viterbi_decode_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_functor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/viterbi_decode_op.h"

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

namespace paddle {
namespace operators {

#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...) \
case (1 << (log2_block_dim)): { \
constexpr auto kBlockDim = (1 << (log2_block_dim)); \
__VA_ARGS__; \
} break

#define FIXED_BLOCK_DIM_CASE(...) \
FIXED_BLOCK_DIM_CASE_BASE(10, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__);

int64_t ComputeBlockSize(int64_t col) {
if (col > 512)
return 1024;
else if (col > 256)
return 512;
else if (col > 128)
return 256;
else if (col > 64)
return 128;
else if (col > 32)
return 64;
else if (col > 16)
return 32;
else if (col > 8)
return 16;
else
return 8;
}

template <template <typename T> typename BinaryFunctor, typename T>
struct BinaryOperation<platform::CUDADeviceContext, BinaryFunctor, T> {
void operator()(const platform::CUDADeviceContext& dev_ctx, const Tensor& lhs,
const Tensor& rhs, Tensor* output) {
std::vector<const Tensor*> ins{&lhs, &rhs};
std::vector<Tensor*> outs{output};
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, -1, BinaryFunctor<T>());
}
};

template <template <typename T> typename CompareFunctor, typename T>
struct GetMask<platform::CUDADeviceContext, CompareFunctor, T> {
void operator()(const framework::ExecutionContext& ctx, const Tensor& lhs,
const Tensor& rhs, Tensor* mask) {
std::vector<const Tensor*> ins = {&lhs, &rhs};
std::vector<Tensor*> outs = {mask};
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, int64_t, T>(
dev_ctx, ins, &outs, CompareFunctor<int64_t>());
}
};

template <typename T, typename IndType, size_t BlockDim>
__global__ void ArgmaxCUDAKernel(const int64_t height, // n * h
const int64_t width, // c
const int64_t post_size, // h
const T* in, IndType* out_idx, T* out) {
typedef cub::BlockReduce<cub::KeyValuePair<int, T>, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
cub::ArgMax reducer;
T init = (std::numeric_limits<T>::lowest)(); // for windows compile
for (int idx = blockIdx.x; idx < height; idx += gridDim.x) {
cub::KeyValuePair<int, T> kv_pair = {-1, init};
int h = idx / post_size;
int w = idx % post_size;
for (int k = threadIdx.x; k < width; k += blockDim.x) {
kv_pair =
reducer({k, in[h * width * post_size + k * post_size + w]}, kv_pair);
}
kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, reducer);
if (threadIdx.x == 0) {
// return max, argmax
if (out_idx != nullptr) out_idx[idx] = static_cast<IndType>(kv_pair.key);
if (out != nullptr) out[idx] = kv_pair.value;
}
__syncthreads();
}
}

__global__ void ARangeKernel(int64_t* data, int num, int64_t scale) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
for (int start = idx; idx < num; idx += gridDim.x) {
data[idx] = idx * scale;
}
}

template <>
struct ARange<platform::CUDADeviceContext> {
void operator()(const platform::CUDADeviceContext& dev_ctx, int64_t* data,
int num, int64_t scale) {
int64_t kBlockDim = ComputeBlockSize(num);
// kBlockDim > num at most of time, so we can set grid = 1
ARangeKernel<<<1, kBlockDim, 0, dev_ctx.stream()>>>(data, num, scale);
}
};

template <typename T, typename IndType>
struct Argmax<platform::CUDADeviceContext, T, IndType> {
void operator()(const framework::ExecutionContext& ctx, const Tensor& input,
Tensor* out_idx, Tensor* out, int axis) {
framework::DDim input_dims = input.dims();
int64_t numel = input.numel();
int64_t groups = numel / input_dims[axis];
int64_t pre = 1;
int64_t post = 1;
int64_t n = input_dims[axis];
for (int i = 0; i < axis; i++) {
pre *= input_dims[i];
}
for (int i = axis + 1; i < input_dims.size(); i++) {
post *= input_dims[i];
}
const auto& dev_ctx = ctx.cuda_device_context();
auto cu_stream = dev_ctx.stream();
int64_t max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize().x;
int64_t height = pre * post;
int64_t width = n;
int64_t grid_size = height < max_grid_dimx ? height : max_grid_dimx;
const T* in_data = input.data<T>();
IndType* out_idx_data = out_idx->data<IndType>();
T* out_data = out->data<T>();
switch (ComputeBlockSize(width)) {
FIXED_BLOCK_DIM_CASE(
ArgmaxCUDAKernel<T, IndType,
kBlockDim><<<grid_size, kBlockDim, 0, cu_stream>>>(
height, width, post, in_data, out_idx_data, out_data));
}
}
};

template <typename T>
struct GetMaxValue<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& dev_ctx,
const Tensor& input, T* max_value) {
Tensor out_data;
out_data.Resize(framework::make_ddim({1}));
out_data.mutable_data<T>(platform::CUDAPlace());
switch (ComputeBlockSize(input.numel())) {
FIXED_BLOCK_DIM_CASE(
ArgmaxCUDAKernel<T, T,
kBlockDim><<<1, kBlockDim, 0, dev_ctx.stream()>>>(
1, input.numel(), 1, input.data<int64_t>(), nullptr,
out_data.data<int64_t>()));
}
Tensor max_value_tensor;
framework::TensorCopy(out_data, platform::CPUPlace(), &max_value_tensor);
*max_value = max_value_tensor.data<T>()[0];
}
};

template <typename T, typename IndexT>
struct Gather<platform::CUDADeviceContext, T, IndexT> {
void operator()(const platform::CUDADeviceContext& ctx, const Tensor& src,
const Tensor& index, Tensor* output) {
GPUGather<T, IndexT>(ctx, src, index, output);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace platform = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
viterbi_decode,
ops::ViterbiDecodeKernel<platform::CUDADeviceContext, float>,
ops::ViterbiDecodeKernel<platform::CUDADeviceContext, double>);
Loading

0 comments on commit 1906c74

Please sign in to comment.