Add viterbi decode #35778

Merged — 74 commits, merged on Oct 21, 2021

Commits
93e70a9
add viterbi decode cpu kernel
joey12300 Sep 15, 2021
3f33b76
add viterbi decoder api in paddle.text
joey12300 Sep 15, 2021
c176fca
add a data buffer once to avoid create many small pieces of data buff…
joey12300 Sep 16, 2021
f137f6b
fix viterbi max_seq_length bug
joey12300 Sep 16, 2021
dcbc972
fix seq_len=1 bug
joey12300 Sep 16, 2021
d258e07
fix device context
joey12300 Sep 17, 2021
f75a4be
move split out of for loop
joey12300 Sep 17, 2021
3ec4097
remove INVERSE_SUB
joey12300 Sep 17, 2021
188d933
remove 2 GET_CAST_MASK
joey12300 Sep 18, 2021
69e1f85
remove 1 loop
joey12300 Sep 18, 2021
face1f1
remove Functor
joey12300 Sep 19, 2021
08daa51
add to_static deploy code
joey12300 Sep 19, 2021
a0777ff
use MAX_FUNC instead of ELE_MAX
joey12300 Sep 19, 2021
6ddc7d4
add MaxFunctor
joey12300 Sep 19, 2021
36f371b
impl max_func
joey12300 Sep 19, 2021
9525039
remove MaxFunctor
joey12300 Sep 19, 2021
2698874
remove cast op
joey12300 Sep 20, 2021
8d0b3f6
use REGISTER_OP_WITHOUT_GRADIENT
joey12300 Sep 21, 2021
5d2259b
add viterbi cuda kernel
joey12300 Sep 22, 2021
425ceaf
add FIX_BLOCKDIM_CASE macro
joey12300 Sep 22, 2021
e17cff1
add MKL add, mul; add get data mask
joey12300 Sep 22, 2021
ee08aab
add arange mkl impl
joey12300 Sep 22, 2021
f48240a
add CPU Argmax
joey12300 Sep 22, 2021
7082e26
add cpu gather
joey12300 Sep 22, 2021
a893027
use EXECUTE_MKL_ELEMENT_BINARY_OP instead of some ADD, MUL
joey12300 Sep 22, 2021
9cf6fc4
use SameDimsBinaryOP instead of EXECUTE_MKL_ELEMENT_BINARY_OP
joey12300 Sep 22, 2021
a1ee241
use SAME_DIMS_ELEMENT_BINARY_OP
joey12300 Sep 22, 2021
2540dc2
add SimpleBroadcastBinaryOP
joey12300 Sep 22, 2021
6a4f579
use int instead of int64_t to accelerate
joey12300 Sep 23, 2021
8ded54b
optimize SimpleBroadcastBinaryOP
joey12300 Sep 23, 2021
0965c95
optimize SimpleBroadcastBinaryOP
joey12300 Sep 23, 2021
2951a99
optimize performance in both single thread and multithread situation
joey12300 Sep 23, 2021
ea2155d
remove useless line
joey12300 Sep 23, 2021
a5cfe57
remove useless code
joey12300 Sep 24, 2021
b789dbf
add CREATE_TENSOR_BUFFER macro
joey12300 Sep 24, 2021
c722163
add INIT_REQUIRED_TENSOR macro
joey12300 Sep 24, 2021
e70f560
add comment
joey12300 Sep 24, 2021
979db61
fix windows ci
joey12300 Sep 24, 2021
d543b21
add viterbi unittest
joey12300 Sep 24, 2021
fd3417f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
joey12300 Sep 24, 2021
10e1056
remove cuda add functor
joey12300 Sep 24, 2021
5f23d46
remove cuda equal
joey12300 Sep 24, 2021
49493f9
remove a template function
joey12300 Sep 24, 2021
564117b
fix windows ci
joey12300 Sep 24, 2021
6e46a75
fix windows dtype
joey12300 Sep 24, 2021
c5071e1
remove some template instance
joey12300 Sep 24, 2021
2bee6c6
remove useless header file
joey12300 Sep 25, 2021
bc7ccf8
remove some blockdim
joey12300 Sep 25, 2021
60b3eb2
remove transpose impl
joey12300 Sep 26, 2021
34e5c8c
accelerate cpu performance on single thread situation
joey12300 Sep 26, 2021
13de468
viterbi_decode->crf_decode
joey12300 Sep 29, 2021
2713eef
rename crf params name
joey12300 Sep 29, 2021
a1d8709
add viterbi api test
joey12300 Sep 29, 2021
3317ce1
remove useless import
joey12300 Sep 29, 2021
c8a6695
add enable_static
joey12300 Sep 29, 2021
1cfb27c
use viterbi decoder
joey12300 Oct 8, 2021
0986697
fix viterbi len=1
joey12300 Oct 9, 2021
d0f70a6
fix viterbi unittest
joey12300 Oct 9, 2021
5fb7e72
remove useless comments
joey12300 Oct 9, 2021
a027c24
reconstruct viterbi decode
joey12300 Oct 11, 2021
058d2ac
remove ADD,SUB,MUL structure
joey12300 Oct 11, 2021
edf1761
fix coverage
joey12300 Oct 12, 2021
7f0b2ac
remove CREATE_TENSOR
joey12300 Oct 12, 2021
ba34cb8
add name args
joey12300 Oct 13, 2021
332375e
crf.py->ops.py; with_start_stop_tag->include_start_end_tag
joey12300 Oct 15, 2021
f4687bf
update crf_decode en docs
joey12300 Oct 15, 2021
f6fc897
fix viterbi decode en docs
joey12300 Oct 17, 2021
1c68da2
fix some review comments
joey12300 Oct 18, 2021
d116447
add FIXED_BLOCK_DIM_CASE in cuda
joey12300 Oct 18, 2021
c91752d
push_back->emplace_back
joey12300 Oct 18, 2021
c0841f1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
joey12300 Oct 18, 2021
b788042
crf_decode->viterbi_decode; include_start_end_tag->include_bos_eos_tag
joey12300 Oct 19, 2021
2586df5
paddle.text.ops.viterbi_decode->paddle.text.viterbi_decode
joey12300 Oct 20, 2021
76e31b5
fix viterbi_decode en docs
joey12300 Oct 21, 2021
4 changes: 2 additions & 2 deletions paddle/fluid/operators/elementwise/elementwise_op_function.h
@@ -240,7 +240,7 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims,
x_dims, y_dims, x_dims_array[i], y_dims_array[i], i));
if ((x_dims_array[i] > 1 || y_dims_array[i] > 1) ||
(x_dims_array[i] == 1 && y_dims_array[i] == 1)) {
-out_dims_array[i] = std::max(x_dims_array[i], y_dims_array[i]);
+out_dims_array[i] = (std::max)(x_dims_array[i], y_dims_array[i]);
} else {
Review comment (Contributor Author): for windows ci

out_dims_array[i] = -1;
}
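For context on the parenthesized calls in this file: a minimal sketch of the Windows min/max macro collision that (std::max) works around (the windows.h include is an assumption about the CI toolchain, not part of this diff):

#include <algorithm>
// #include <windows.h>  // without NOMINMAX, this defines function-like
//                       // min/max macros that break bare std::max calls

int main() {
  int a = 3, b = 7;
  // A bare std::max(a, b) can be rewritten by the max(...) macro.
  // Parenthesizing the name suppresses function-like macro expansion,
  // so this always resolves to the std:: function template:
  int m = (std::max)(a, b);
  return m == 7 ? 0 : 1;
}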
@@ -1779,7 +1779,7 @@ void CommonElementwiseBroadcastForward(
const framework::Tensor *y, framework::Tensor *z,
const framework::DDim &x_dims, const framework::DDim &y_dims, Functor func,
int axis, const bool is_xsize_larger = true) {
-int max_dim = std::max(x_dims.size(), y_dims.size());
+int max_dim = (std::max)(x_dims.size(), y_dims.size());
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
PADDLE_ENFORCE_GE(
axis, 0,
109 changes: 109 additions & 0 deletions paddle/fluid/operators/viterbi_decode_op.cc
@@ -0,0 +1,109 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/viterbi_decode_op.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class ViterbiDecodeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "ViterbiDecode");
OP_INOUT_CHECK(ctx->HasInput("Transition"), "Input", "Transition",
"ViterbiDecode");
OP_INOUT_CHECK(ctx->HasInput("Length"), "Input", "Length", "ViterbiDecode");
OP_INOUT_CHECK(ctx->HasOutput("Scores"), "Output", "Scores",
"ViterbiDecode");
OP_INOUT_CHECK(ctx->HasOutput("Path"), "Output", "Path", "ViterbiDecode");
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ(in_dims.size(), 3,
platform::errors::InvalidArgument(
"The rank of Input in ViterbiDecode must be 3. But "
"received Input's rank is %d.",
in_dims.size()));
auto length_dims = ctx->GetInputDim("Length");
PADDLE_ENFORCE_EQ(length_dims.size(), 1,
platform::errors::InvalidArgument(
"The rank of Length in ViterbiDecode must be 1. But "
"received Length's rank is %d.",
length_dims.size()));
auto transition_dims = ctx->GetInputDim("Transition");
PADDLE_ENFORCE_EQ(
transition_dims.size(), 2,
platform::errors::InvalidArgument(
"The rank of Transition in ViterbiDecode must be 2. But "
"received Transition's rank is %d.",
transition_dims.size()));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
in_dims[0], length_dims[0],
platform::errors::InvalidArgument(
"The batch size of Input and Length should be equal."));
PADDLE_ENFORCE_EQ(in_dims[2], transition_dims[0],
platform::errors::InvalidArgument(
"The number of tags of Input (%d) and Transition "
"(%d) should be equal.",
transition_dims[0], in_dims[2]));
}
ctx->SetOutputDim("Scores", length_dims);
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
};

class ViterbiDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"Input",
"The unary emission tensor. The shape of Input must be (batch_size,"
"sequence_length, num_tags). ");
AddInput("Transition",
"The transition matrix. The shape of Transition must be ( "
"num_tags, num_tags). ");
AddInput("Length",
"The input length tensor storing real length of each sequence for "
"correctness. The shape of Length MUST be (batch_size).");
AddOutput("Scores",
"The scores tensor containing the score for the Viterbi "
"sequence. The shape of Scores MUST be (batch_size).");
AddOutput("Path",
"The paths tensor containing the highest scoring tag indices. "
"The shape of Scores MUST be (batch_size, sequence_length).");
AddAttr<bool>("include_bos_eos_tag",
"If set to True, the last row and the last column of "
"transitions will be considered as start tag.")
.SetDefault(true);
AddComment(R"DOC(
)DOC");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace platform = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(viterbi_decode, ops::ViterbiDecodeOp,
ops::ViterbiDecodeOpMaker);
REGISTER_OP_CPU_KERNEL(
viterbi_decode, ops::ViterbiDecodeKernel<platform::CPUDeviceContext, float>,
ops::ViterbiDecodeKernel<platform::CPUDeviceContext, double>);
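To make the operator's contract above concrete, here is a minimal, self-contained sketch of Viterbi decoding for a single sequence with include_bos_eos_tag disabled (plain C++; the function and variable names are illustrative, not the kernel's actual implementation):

#include <utility>
#include <vector>

// Illustrative single-sequence Viterbi decode. Assumes seq_len >= 1.
// emissions: [seq_len][num_tags], transition: [num_tags][num_tags].
// Returns the best path score and the highest scoring tag sequence,
// matching the op's Scores/Path outputs for batch_size == 1.
std::pair<float, std::vector<int>> ViterbiDecodeSketch(
    const std::vector<std::vector<float>>& emissions,
    const std::vector<std::vector<float>>& transition) {
  const int seq_len = static_cast<int>(emissions.size());
  const int num_tags = static_cast<int>(transition.size());
  std::vector<std::vector<float>> alpha(seq_len,
                                        std::vector<float>(num_tags));
  std::vector<std::vector<int>> backpointers(seq_len,
                                             std::vector<int>(num_tags));
  alpha[0] = emissions[0];  // scores at step 0 are the first emissions
  for (int t = 1; t < seq_len; ++t) {
    for (int j = 0; j < num_tags; ++j) {
      // Find the best previous tag i for current tag j.
      float best = alpha[t - 1][0] + transition[0][j];
      int best_i = 0;
      for (int i = 1; i < num_tags; ++i) {
        float score = alpha[t - 1][i] + transition[i][j];
        if (score > best) { best = score; best_i = i; }
      }
      alpha[t][j] = best + emissions[t][j];
      backpointers[t][j] = best_i;
    }
  }
  // Select the best final tag, then walk the backpointers.
  int last = 0;
  for (int j = 1; j < num_tags; ++j) {
    if (alpha[seq_len - 1][j] > alpha[seq_len - 1][last]) last = j;
  }
  std::vector<int> path(seq_len);
  path[seq_len - 1] = last;
  for (int t = seq_len - 1; t > 0; --t) {
    path[t - 1] = backpointers[t][path[t]];
  }
  return {alpha[seq_len - 1][last], path};
}

The CPU and CUDA kernels in this PR run the same recurrence batched over all sequences at once, using Length to mask steps past each sequence's real end.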
200 changes: 200 additions & 0 deletions paddle/fluid/operators/viterbi_decode_op.cu
@@ -0,0 +1,200 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_functor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/viterbi_decode_op.h"

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

namespace paddle {
namespace operators {

#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...) \
case (1 << (log2_block_dim)): { \
constexpr auto kBlockDim = (1 << (log2_block_dim)); \
__VA_ARGS__; \
} break

#define FIXED_BLOCK_DIM_CASE(...) \
FIXED_BLOCK_DIM_CASE_BASE(10, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__);
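// Note: ComputeBlockSize (below) rounds an element count up to the next
// power of two, clamped to [8, 1024]; FIXED_BLOCK_DIM_CASE then turns that
// runtime value into the compile-time constant kBlockDim via switch cases,
// which cub::BlockReduce requires as a template argument.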

int64_t ComputeBlockSize(int64_t col) {
if (col > 512)
return 1024;
else if (col > 256)
return 512;
else if (col > 128)
return 256;
else if (col > 64)
return 128;
else if (col > 32)
return 64;
else if (col > 16)
return 32;
else if (col > 8)
return 16;
else
return 8;
}

template <template <typename T> typename BinaryFunctor, typename T>
struct BinaryOperation<platform::CUDADeviceContext, BinaryFunctor, T> {
void operator()(const platform::CUDADeviceContext& dev_ctx, const Tensor& lhs,
const Tensor& rhs, Tensor* output) {
std::vector<const Tensor*> ins{&lhs, &rhs};
std::vector<Tensor*> outs{output};
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, -1, BinaryFunctor<T>());
}
};

template <template <typename T> typename CompareFunctor, typename T>
struct GetMask<platform::CUDADeviceContext, CompareFunctor, T> {
void operator()(const framework::ExecutionContext& ctx, const Tensor& lhs,
const Tensor& rhs, Tensor* mask) {
std::vector<const Tensor*> ins = {&lhs, &rhs};
std::vector<Tensor*> outs = {mask};
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, int64_t, T>(
dev_ctx, ins, &outs, CompareFunctor<int64_t>());
}
};

// Treats the input as [pre, width, post]: height (= pre * post) rows of
// length `width` are distributed across blocks, and each block reduces
// one row with cub::BlockReduce to get the row max and/or its argmax.
template <typename T, typename IndType, size_t BlockDim>
__global__ void ArgmaxCUDAKernel(const int64_t height, // n * h
const int64_t width, // c
const int64_t post_size, // h
const T* in, IndType* out_idx, T* out) {
typedef cub::BlockReduce<cub::KeyValuePair<int, T>, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
cub::ArgMax reducer;
T init = (std::numeric_limits<T>::lowest)(); // for windows compile
for (int idx = blockIdx.x; idx < height; idx += gridDim.x) {
cub::KeyValuePair<int, T> kv_pair = {-1, init};
int h = idx / post_size;
int w = idx % post_size;
for (int k = threadIdx.x; k < width; k += blockDim.x) {
kv_pair =
reducer({k, in[h * width * post_size + k * post_size + w]}, kv_pair);
}
kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, reducer);
if (threadIdx.x == 0) {
// return max, argmax
if (out_idx != nullptr) out_idx[idx] = static_cast<IndType>(kv_pair.key);
if (out != nullptr) out[idx] = kv_pair.value;
}
__syncthreads();
}
}

__global__ void ARangeKernel(int64_t* data, int num, int64_t scale) {
// Grid-stride loop: each thread writes elements starting at its global
// index, striding by the total number of launched threads.
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
idx += blockDim.x * gridDim.x) {
data[idx] = idx * scale;
}
}

template <>
struct ARange<platform::CUDADeviceContext> {
void operator()(const platform::CUDADeviceContext& dev_ctx, int64_t* data,
int num, int64_t scale) {
int64_t kBlockDim = ComputeBlockSize(num);
// kBlockDim >= num most of the time, so a grid size of 1 suffices.
ARangeKernel<<<1, kBlockDim, 0, dev_ctx.stream()>>>(data, num, scale);
}
};

template <typename T, typename IndType>
struct Argmax<platform::CUDADeviceContext, T, IndType> {
void operator()(const framework::ExecutionContext& ctx, const Tensor& input,
Tensor* out_idx, Tensor* out, int axis) {
framework::DDim input_dims = input.dims();
int64_t numel = input.numel();
int64_t groups = numel / input_dims[axis];
int64_t pre = 1;
int64_t post = 1;
int64_t n = input_dims[axis];
for (int i = 0; i < axis; i++) {
pre *= input_dims[i];
}
for (int i = axis + 1; i < input_dims.size(); i++) {
post *= input_dims[i];
}
const auto& dev_ctx = ctx.cuda_device_context();
auto cu_stream = dev_ctx.stream();
int64_t max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize().x;
int64_t height = pre * post;
int64_t width = n;
int64_t grid_size = height < max_grid_dimx ? height : max_grid_dimx;
const T* in_data = input.data<T>();
IndType* out_idx_data = out_idx->data<IndType>();
T* out_data = out->data<T>();
switch (ComputeBlockSize(width)) {
FIXED_BLOCK_DIM_CASE(
ArgmaxCUDAKernel<T, IndType,
kBlockDim><<<grid_size, kBlockDim, 0, cu_stream>>>(
height, width, post, in_data, out_idx_data, out_data));
}
}
};

// Reduces a device tensor to its maximum element and copies it to the
// host. In this op it is used with T = int64_t (sequence lengths),
// matching the int64_t buffer accesses below.
template <typename T>
struct GetMaxValue<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& dev_ctx,
const Tensor& input, T* max_value) {
Tensor out_data;
out_data.Resize(framework::make_ddim({1}));
out_data.mutable_data<T>(platform::CUDAPlace());
switch (ComputeBlockSize(input.numel())) {
FIXED_BLOCK_DIM_CASE(
ArgmaxCUDAKernel<T, T,
kBlockDim><<<1, kBlockDim, 0, dev_ctx.stream()>>>(
1, input.numel(), 1, input.data<int64_t>(), nullptr,
out_data.data<int64_t>()));
}
Tensor max_value_tensor;
framework::TensorCopy(out_data, platform::CPUPlace(), &max_value_tensor);
Review comment (Contributor): Just to confirm — doesn't max_value_tensor need memory allocated before the TensorCopy?

Reply (Contributor Author): No. TensorCopy calls mutable_data internally to allocate the memory.
*max_value = max_value_tensor.data<T>()[0];
}
};

template <typename T, typename IndexT>
struct Gather<platform::CUDADeviceContext, T, IndexT> {
void operator()(const platform::CUDADeviceContext& ctx, const Tensor& src,
const Tensor& index, Tensor* output) {
GPUGather<T, IndexT>(ctx, src, index, output);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace platform = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
viterbi_decode,
ops::ViterbiDecodeKernel<platform::CUDADeviceContext, float>,
ops::ViterbiDecodeKernel<platform::CUDADeviceContext, double>);
Review comment (Contributor): We should investigate whether fp16 can be supported here, to prepare for later optimization work. If the composed APIs cannot support it, fp16 support can be left out for now.
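A hypothetical sketch of the extended registration, should the composed primitives gain fp16 support (not part of this PR; shown only for the investigation suggested above):

// Hypothetical only — valid if every primitive the kernel composes
// supports platform::float16:
REGISTER_OP_CUDA_KERNEL(
    viterbi_decode,
    ops::ViterbiDecodeKernel<platform::CUDADeviceContext, float>,
    ops::ViterbiDecodeKernel<platform::CUDADeviceContext, double>,
    ops::ViterbiDecodeKernel<platform::CUDADeviceContext,
                             platform::float16>);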