-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add viterbi decode #35778
Add viterbi decode #35778
Changes from 73 commits
93e70a9
3f33b76
c176fca
f137f6b
dcbc972
d258e07
f75a4be
3ec4097
188d933
69e1f85
face1f1
08daa51
a0777ff
6ddc7d4
36f371b
9525039
2698874
8d0b3f6
5d2259b
425ceaf
e17cff1
ee08aab
f48240a
7082e26
a893027
9cf6fc4
a1ee241
2540dc2
6a4f579
8ded54b
0965c95
2951a99
ea2155d
a5cfe57
b789dbf
c722163
e70f560
979db61
d543b21
fd3417f
10e1056
5f23d46
49493f9
564117b
6e46a75
c5071e1
2bee6c6
bc7ccf8
60b3eb2
34e5c8c
13de468
2713eef
a1d8709
3317ce1
c8a6695
1cfb27c
0986697
d0f70a6
5fb7e72
a027c24
058d2ac
edf1761
7f0b2ac
ba34cb8
332375e
f4687bf
f6fc897
1c68da2
d116447
c91752d
c0841f1
b788042
2586df5
76e31b5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/operators/viterbi_decode_op.h" | ||
#include "paddle/fluid/framework/op_registry.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class ViterbiDecodeOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
void InferShape(framework::InferShapeContext* ctx) const override { | ||
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "ViterbiDecode"); | ||
OP_INOUT_CHECK(ctx->HasInput("Transition"), "Input", "Transition", | ||
"ViterbiDecode"); | ||
OP_INOUT_CHECK(ctx->HasInput("Length"), "Input", "Length", "ViterbiDecode"); | ||
OP_INOUT_CHECK(ctx->HasOutput("Scores"), "Output", "Scores", | ||
"ViterbiDecode"); | ||
OP_INOUT_CHECK(ctx->HasOutput("Path"), "Output", "Path", "ViterbiDecode"); | ||
auto in_dims = ctx->GetInputDim("Input"); | ||
PADDLE_ENFORCE_EQ(in_dims.size(), 3, | ||
platform::errors::InvalidArgument( | ||
"The rank of Input in ViterbiDecode must be 3. But " | ||
"received Input's rank is %d.", | ||
in_dims.size())); | ||
auto length_dims = ctx->GetInputDim("Length"); | ||
PADDLE_ENFORCE_EQ(length_dims.size(), 1, | ||
platform::errors::InvalidArgument( | ||
"The rank of Length in ViterbiDecode must be 1. But " | ||
"received Length's rank is %d.", | ||
length_dims.size())); | ||
auto transition_dims = ctx->GetInputDim("Transition"); | ||
PADDLE_ENFORCE_EQ( | ||
transition_dims.size(), 2, | ||
platform::errors::InvalidArgument( | ||
"The rank of Transition in ViterbiDecode must be 2. But " | ||
"received Transition's rank is %d.", | ||
transition_dims.size())); | ||
if (ctx->IsRuntime()) { | ||
PADDLE_ENFORCE_EQ( | ||
in_dims[0], length_dims[0], | ||
platform::errors::InvalidArgument( | ||
"The batch size of Input and Length should be equal.")); | ||
PADDLE_ENFORCE_EQ(in_dims[2], transition_dims[0], | ||
platform::errors::InvalidArgument( | ||
"The number of tags of Input (%d) and Transition " | ||
"(%d) should be equal.", | ||
transition_dims[0], in_dims[2])); | ||
} | ||
ctx->SetOutputDim("Scores", length_dims); | ||
} | ||
|
||
protected: | ||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext& ctx) const override { | ||
return framework::OpKernelType( | ||
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), | ||
ctx.device_context()); | ||
} | ||
}; | ||
|
||
// Declares inputs, outputs, and attributes of viterbi_decode.
class ViterbiDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput(
        "Input",
        "The unary emission tensor. The shape of Input must be (batch_size,"
        "sequence_length, num_tags). ");
    AddInput("Transition",
             "The transition matrix. The shape of Transition must be ( "
             "num_tags, num_tags). ");
    AddInput("Length",
             "The input length tensor storing real length of each sequence for "
             "correctness. The shape of Length MUST be (batch_size).");
    AddOutput("Scores",
              "The scores tensor containing the score for the Viterbi "
              "sequence. The shape of Scores MUST be (batch_size).");
    // BUGFIX: this description previously said "The shape of Scores" — it
    // documents Path.
    AddOutput("Path",
              "The paths tensor containing the highest scoring tag indices. "
              "The shape of Path MUST be (batch_size, sequence_length).");
    AddAttr<bool>("include_bos_eos_tag",
                  "If set to True, the last row and the last column of "
                  "transitions will be considered as start tag.")
        .SetDefault(true);
    // The DOC block was previously empty; give the op a user-visible summary.
    AddComment(R"DOC(
Decode the highest scoring sequence of tags for each sequence in the batch
with the Viterbi algorithm, given unary emission scores (Input), a tag
transition matrix (Transition), and the real length of each sequence
(Length). Outputs the best score per sequence (Scores) and the
corresponding tag indices (Path).
)DOC");
  }
};
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators;
namespace platform = paddle::platform;
// Viterbi decoding produces discrete tag paths, so no gradient is defined.
REGISTER_OP_WITHOUT_GRADIENT(viterbi_decode, ops::ViterbiDecodeOp,
                             ops::ViterbiDecodeOpMaker);
// CPU kernels for float and double emission tensors.
REGISTER_OP_CPU_KERNEL(
    viterbi_decode, ops::ViterbiDecodeKernel<platform::CPUDeviceContext, float>,
    ops::ViterbiDecodeKernel<platform::CPUDeviceContext, double>);
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/operators/elementwise/elementwise_functor.h" | ||
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" | ||
#include "paddle/fluid/operators/gather.cu.h" | ||
#include "paddle/fluid/operators/viterbi_decode_op.h" | ||
|
||
#ifdef __NVCC__ | ||
#include "cub/cub.cuh" | ||
#endif | ||
#ifdef __HIPCC__ | ||
#include <hipcub/hipcub.hpp> | ||
namespace cub = hipcub; | ||
#endif | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
// One `case` label of a switch over the thread-block size: binds the
// compile-time constant kBlockDim = 2^log2_block_dim and runs the given
// statement(s) with it, so the kernel can be instantiated per block size.
#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...)  \
  case (1 << (log2_block_dim)): {                       \
    constexpr auto kBlockDim = (1 << (log2_block_dim)); \
    __VA_ARGS__;                                        \
  } break

// Covers every block size ComputeBlockSize can return: 1024 down to 8.
#define FIXED_BLOCK_DIM_CASE(...)               \
  FIXED_BLOCK_DIM_CASE_BASE(10, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__);
|
||
// Picks the CUDA thread-block size for reducing a row of `col` elements:
// the smallest power of two in [8, 1024] that is >= col, capped at 1024
// for wider rows.
int64_t ComputeBlockSize(int64_t col) {
  for (int64_t block = 8; block < 1024; block <<= 1) {
    if (col <= block) {
      return block;
    }
  }
  return 1024;
}
|
||
// CUDA specialization of BinaryOperation: applies the elementwise binary
// functor to lhs and rhs (with broadcasting) and writes into `output`.
template <template <typename T> typename BinaryFunctor, typename T>
struct BinaryOperation<platform::CUDADeviceContext, BinaryFunctor, T> {
  void operator()(const platform::CUDADeviceContext& dev_ctx, const Tensor& lhs,
                  const Tensor& rhs, Tensor* output) {
    std::vector<const Tensor*> ins{&lhs, &rhs};
    std::vector<Tensor*> outs{output};
    // axis = -1: use the default trailing-dimension broadcast rule.
    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
        dev_ctx, ins, &outs, -1, BinaryFunctor<T>());
  }
};
|
||
// CUDA specialization of GetMask: elementwise comparison of two
// same-shaped tensors, producing a mask of type T in `mask`.
// NOTE(review): the comparison reads the inputs as int64_t
// (CompareFunctor<int64_t>) regardless of the output type T — this
// assumes lhs/rhs always hold int64 data; confirm against the callers.
template <template <typename T> typename CompareFunctor, typename T>
struct GetMask<platform::CUDADeviceContext, CompareFunctor, T> {
  void operator()(const framework::ExecutionContext& ctx, const Tensor& lhs,
                  const Tensor& rhs, Tensor* mask) {
    std::vector<const Tensor*> ins = {&lhs, &rhs};
    std::vector<Tensor*> outs = {mask};
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, int64_t, T>(
        dev_ctx, ins, &outs, CompareFunctor<int64_t>());
  }
};
|
||
// Block-parallel argmax over the `width` axis of a tensor viewed as
// [pre, width, post], flattened to height = pre * post rows. Each thread
// block reduces one row (grid-striding over remaining rows), writing the
// row maximum to `out` and its index to `out_idx`; either output pointer
// may be nullptr to skip that result.
template <typename T, typename IndType, size_t BlockDim>
__global__ void ArgmaxCUDAKernel(const int64_t height,     // n * h
                                 const int64_t width,      // c
                                 const int64_t post_size,  // h
                                 const T* in, IndType* out_idx, T* out) {
  typedef cub::BlockReduce<cub::KeyValuePair<int, T>, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  cub::ArgMax reducer;
  // Parenthesized to dodge the Windows min/max macros. // for windows compile
  T init = (std::numeric_limits<T>::lowest)();
  for (int idx = blockIdx.x; idx < height; idx += gridDim.x) {
    cub::KeyValuePair<int, T> kv_pair = {-1, init};
    int h = idx / post_size;
    int w = idx % post_size;
    // Each thread scans a strided slice of the row, keeping its local best.
    for (int k = threadIdx.x; k < width; k += blockDim.x) {
      kv_pair =
          reducer({k, in[h * width * post_size + k * post_size + w]}, kv_pair);
    }
    kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, reducer);
    if (threadIdx.x == 0) {
      // return max, argmax
      if (out_idx != nullptr) out_idx[idx] = static_cast<IndType>(kv_pair.key);
      if (out != nullptr) out[idx] = kv_pair.value;
    }
    // temp_storage is reused on the next row; sync before overwriting it.
    __syncthreads();
  }
}
|
||
// Fills data[i] = i * scale for i in [0, num).
__global__ void ARangeKernel(int64_t* data, int num, int64_t scale) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  // BUGFIX: stride by the total number of launched threads so every index
  // is written exactly once. The original declared an unused `start` and
  // strided by gridDim.x only, which made each thread redundantly rewrite
  // the whole tail of the range (same values, wasted work).
  for (int idx = tid; idx < num; idx += blockDim.x * gridDim.x) {
    data[idx] = idx * scale;
  }
}
|
||
// CUDA specialization of ARange: writes the sequence
// [0, scale, 2*scale, ...] of length `num` into device memory.
template <>
struct ARange<platform::CUDADeviceContext> {
  void operator()(const platform::CUDADeviceContext& dev_ctx, int64_t* data,
                  int num, int64_t scale) {
    int64_t kBlockDim = ComputeBlockSize(num);
    // kBlockDim > num at most of time, so we can set grid = 1
    ARangeKernel<<<1, kBlockDim, 0, dev_ctx.stream()>>>(data, num, scale);
  }
};
|
||
template <typename T, typename IndType> | ||
struct Argmax<platform::CUDADeviceContext, T, IndType> { | ||
void operator()(const framework::ExecutionContext& ctx, const Tensor& input, | ||
Tensor* out_idx, Tensor* out, int axis) { | ||
framework::DDim input_dims = input.dims(); | ||
int64_t numel = input.numel(); | ||
int64_t groups = numel / input_dims[axis]; | ||
int64_t pre = 1; | ||
int64_t post = 1; | ||
int64_t n = input_dims[axis]; | ||
for (int i = 0; i < axis; i++) { | ||
pre *= input_dims[i]; | ||
} | ||
for (int i = axis + 1; i < input_dims.size(); i++) { | ||
post *= input_dims[i]; | ||
} | ||
const auto& dev_ctx = ctx.cuda_device_context(); | ||
auto cu_stream = dev_ctx.stream(); | ||
int64_t max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize().x; | ||
int64_t height = pre * post; | ||
int64_t width = n; | ||
int64_t grid_size = height < max_grid_dimx ? height : max_grid_dimx; | ||
const T* in_data = input.data<T>(); | ||
IndType* out_idx_data = out_idx->data<IndType>(); | ||
T* out_data = out->data<T>(); | ||
switch (ComputeBlockSize(width)) { | ||
FIXED_BLOCK_DIM_CASE( | ||
ArgmaxCUDAKernel<T, IndType, | ||
kBlockDim><<<grid_size, kBlockDim, 0, cu_stream>>>( | ||
height, width, post, in_data, out_idx_data, out_data)); | ||
} | ||
} | ||
}; | ||
|
||
template <typename T> | ||
struct GetMaxValue<platform::CUDADeviceContext, T> { | ||
void operator()(const platform::CUDADeviceContext& dev_ctx, | ||
const Tensor& input, T* max_value) { | ||
Tensor out_data; | ||
out_data.Resize(framework::make_ddim({1})); | ||
out_data.mutable_data<T>(platform::CUDAPlace()); | ||
switch (ComputeBlockSize(input.numel())) { | ||
FIXED_BLOCK_DIM_CASE( | ||
ArgmaxCUDAKernel<T, T, | ||
kBlockDim><<<1, kBlockDim, 0, dev_ctx.stream()>>>( | ||
1, input.numel(), 1, input.data<int64_t>(), nullptr, | ||
out_data.data<int64_t>())); | ||
} | ||
Tensor max_value_tensor; | ||
framework::TensorCopy(out_data, platform::CPUPlace(), &max_value_tensor); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 想确认一下 TensorCopy之前不需要对max_value_tensor分配内存吗? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 不需要了,TensorCopy会调用mutable_data分配显存 |
||
*max_value = max_value_tensor.data<T>()[0]; | ||
} | ||
}; | ||
|
||
// CUDA specialization of Gather: thin wrapper over GPUGather, collecting
// the rows of `src` selected by `index` into `output`.
template <typename T, typename IndexT>
struct Gather<platform::CUDADeviceContext, T, IndexT> {
  void operator()(const platform::CUDADeviceContext& ctx, const Tensor& src,
                  const Tensor& index, Tensor* output) {
    GPUGather<T, IndexT>(ctx, src, index, output);
  }
};
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators;
namespace platform = paddle::platform;
// GPU kernels for float and double emission tensors.
REGISTER_OP_CUDA_KERNEL(
    viterbi_decode,
    ops::ViterbiDecodeKernel<platform::CUDADeviceContext, float>,
    ops::ViterbiDecodeKernel<platform::CUDADeviceContext, double>);
Reviewer comment: investigate whether fp16 can be supported here, to prepare for later optimization work; if the composite APIs do not support fp16, it is acceptable to leave it unsupported for now.

Reviewer comment: this change is for the Windows CI.