Skip to content

Commit

Permalink
fix seq_len=1 bug
Browse files Browse the repository at this point in the history
  • Loading branch information
joey12300 committed Sep 16, 2021
1 parent f137f6b commit dcbc972
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions paddle/fluid/operators/viterbi_decode_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,8 @@ class ViterbiDecodeCPUKernel : public framework::OpKernel<T> {
auto n_labels = static_cast<int>(input->dims()[2]);

// Create a large int data buffer
int buffer_size = batch_size * seq_len +
batch_size * n_labels * (seq_len - 1) + 7 * batch_size +
2;
int buffer_size = batch_size * seq_len + batch_size * n_labels * seq_len +
7 * batch_size + 2;
CREATE_TENSOR(int_buffer, int64_t, buffer_size);
TensorBuffer int_tensor_buffer(int_buffer);

Expand Down Expand Up @@ -185,7 +184,7 @@ class ViterbiDecodeCPUKernel : public framework::OpKernel<T> {
Tensor alpha_max =
float_tensor_buffer.GetBufferBlock({batch_size, n_labels});
Tensor alpha_argmax =
int_tensor_buffer.GetBufferBlock({seq_len - 1, batch_size, n_labels});
int_tensor_buffer.GetBufferBlock({seq_len, batch_size, n_labels});
auto alpha_argmax_unbind = Unbind(alpha_argmax);
Tensor alpha_nxt =
float_tensor_buffer.GetBufferBlock({batch_size, n_labels});
Expand Down

0 comments on commit dcbc972

Please sign in to comment.