[Eager] Add warpctc yaml (#44617)
ZHUI committed Jul 26, 2022
1 parent b6e8480 commit 33cc0f7
Showing 11 changed files with 73 additions and 26 deletions.
12 changes: 12 additions & 0 deletions paddle/phi/api/yaml/legacy_api.yaml
@@ -2403,6 +2403,18 @@
    func : viterbi_decode
    data_type : input

- api : warpctc
  args : (Tensor logits, Tensor label, Tensor logits_length, Tensor labels_length, int blank, bool norm_by_times)
  output : Tensor(loss), Tensor(warpctcgrad)
  infer_meta :
    func : WarpctcInferMeta
  kernel :
    func : warpctc
    data_type: logits
    optional: logits_length, labels_length
    intermediate: warpctcgrad
  backward : warpctc_grad

- api : where
  args : (Tensor condition, Tensor x, Tensor y)
  output : Tensor
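For orientation, here is a rough sketch of how this forward entry is meant to be used from eager-mode Python once code generation runs. The final_state_warpctc binding name is taken from the loss.py change later in this commit; the shapes, dtypes, and the single-tensor return (only loss, since warpctcgrad is marked intermediate) are assumptions for illustration rather than verified behaviour.

# Hedged sketch: calling the generated eager op directly.
import numpy as np
import paddle
from paddle import _C_ops

paddle.disable_static()
Tmax, B, D = 5, 2, 8   # max time steps, batch size, classes including blank
logits = paddle.to_tensor(np.random.rand(Tmax, B, D).astype("float32"))
label = paddle.to_tensor(np.random.randint(1, D, (B, 3)).astype("int32"))
logits_length = paddle.to_tensor(np.full([B], Tmax, dtype="int64"))
labels_length = paddle.to_tensor(np.full([B], 3, dtype="int64"))

# Arguments follow the YAML order; warpctcgrad is an intermediate output,
# so only the per-sequence loss is expected back in Python.
loss = _C_ops.final_state_warpctc(logits, label, logits_length, labels_length,
                                  0, False)  # blank=0, norm_by_times=False
print(loss.shape)  # expected: [B, 1]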
12 changes: 12 additions & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
@@ -2376,6 +2376,18 @@
  inplace : (out_grad -> x_grad)
  backward : unsqueeze_double_grad

- backward_api : warpctc_grad
  forward : warpctc (Tensor logits, Tensor label, Tensor logits_length, Tensor labels_length, int blank, bool norm_by_times) -> Tensor(loss), Tensor(warpctcgrad)
  args : (Tensor logits, Tensor logits_length, Tensor warpctcgrad, Tensor loss_grad, int blank, bool norm_by_times)
  output : Tensor(logits_grad)
  infer_meta :
    func : UnchangedInferMeta
    param : [logits]
  kernel :
    func : warpctc_grad
    optional : logits_length
  no_need_buffer : logits

- backward_api : where_grad
  forward : where (Tensor condition, Tensor x, Tensor y) -> Tensor(out)
  args : (Tensor condition, Tensor x, Tensor y, Tensor out_grad)
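To make the forward/backward wiring concrete, below is a hedged end-to-end sketch through the public paddle.fluid.layers.warpctc layer (updated later in this commit), which in eager mode should dispatch to the entries above; the shapes and the expected gradient shape are illustrative assumptions, not verified output.

# Hedged sketch: forward through the eager warpctc API, backward through warpctc_grad.
import numpy as np
import paddle

paddle.disable_static()
Tmax, B, D = 5, 2, 8
logits = paddle.to_tensor(np.random.rand(Tmax, B, D).astype("float32"),
                          stop_gradient=False)
label = paddle.to_tensor(np.random.randint(1, D, (B, 3)).astype("int32"))
logits_length = paddle.to_tensor(np.full([B], Tmax, dtype="int64"))
labels_length = paddle.to_tensor(np.full([B], 3, dtype="int64"))

loss = paddle.fluid.layers.warpctc(logits, label, blank=0, norm_by_times=False,
                                   input_length=logits_length,
                                   label_length=labels_length)
loss.sum().backward()      # routes through the warpctc_grad kernel registered above
print(logits.grad.shape)   # expected: [Tmax, B, D], matching logits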
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/multiary.cc
@@ -2049,7 +2049,7 @@ void WarpctcInferMeta(const MetaTensor& logits,
const MetaTensor& labels_length,
int blank,
bool norm_by_times,
MetaTensor* warpctc_grad,
MetaTensor* warpctcgrad,
MetaTensor* loss) {
auto logits_dims = logits.dims();
int sequence_width = 0;
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/multiary.h
@@ -358,7 +358,7 @@ void WarpctcInferMeta(const MetaTensor& logits,
const MetaTensor& labels_length,
int blank,
bool norm_by_times,
MetaTensor* warpctc_grad,
MetaTensor* warpctcgrad,
MetaTensor* loss);

void WhereInferMeta(const MetaTensor& condition,
18 changes: 9 additions & 9 deletions paddle/phi/kernels/impl/warpctc_grad_kernel_impl.h
@@ -29,33 +29,33 @@ namespace phi {

template <typename T, typename Context>
void WarpctcGradKernel(const Context& dev_ctx,
const DenseTensor& warpctc_grad,
const DenseTensor& logits,
const DenseTensor& loss_grad,
const paddle::optional<DenseTensor>& logits_length,
const DenseTensor& warpctcgrad,
const DenseTensor& loss_grad,
int blank,
bool norm_by_times,
DenseTensor* logits_grad) {
dev_ctx.template Alloc<T>(logits_grad);

if (logits_length.is_initialized()) {
int max_seq_length = warpctc_grad.dims()[0]; // Tmax
int num_sequences = warpctc_grad.dims()[1]; // B
int seq_width = warpctc_grad.dims()[2]; // D
int max_seq_length = warpctcgrad.dims()[0]; // Tmax
int num_sequences = warpctcgrad.dims()[1]; // B
int seq_width = warpctcgrad.dims()[2]; // D

// B
auto logits_len_e = EigenTensor<int64_t, 1>::From(*logits_length);
// (B, 1)
auto loss_grad_e = EigenTensor<T, 2>::From(loss_grad);
// (T, B, D)
auto warpctc_grad_e = EigenTensor<T, 3>::From(warpctc_grad);
auto warpctcgrad_e = EigenTensor<T, 3>::From(warpctcgrad);

auto logits_grad_e = EigenTensor<T, 3>::From(*logits_grad);

Eigen::DSizes<int, 3> grad_shape(1, num_sequences, 1);
Eigen::DSizes<int, 3> bcast(max_seq_length, 1, seq_width);
auto logits_g = warpctc_grad_e *
loss_grad_e.reshape(grad_shape).broadcast(bcast).eval();
auto logits_g =
warpctcgrad_e * loss_grad_e.reshape(grad_shape).broadcast(bcast).eval();

auto* place = dev_ctx.eigen_device();
if (norm_by_times) {
@@ -71,7 +71,7 @@ void WarpctcGradKernel(const Context& dev_ctx,
} else {
paddle::operators::math::UnpaddingLoDTensorFunctor<Context, T>()(
dev_ctx,
warpctc_grad,
warpctcgrad,
logits_grad,
-1,
0,
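The core of the padded-tensor branch above is an elementwise chain-rule step: warpctcgrad, computed during the forward pass, holds d(loss)/d(logits) with shape (Tmax, B, D), and it is scaled by the upstream per-sequence gradient loss_grad of shape (B, 1) broadcast over the time and class dimensions (norm_by_times additionally divides each sequence by its length). A minimal numpy sketch of that broadcast, with illustrative shapes:

# Hedged numpy equivalent of the Eigen reshape/broadcast in WarpctcGradKernel.
import numpy as np

Tmax, B, D = 4, 2, 3
warpctcgrad = np.random.rand(Tmax, B, D).astype("float32")  # saved in the forward pass
loss_grad = np.random.rand(B, 1).astype("float32")          # upstream d(total)/d(loss)

logits_grad = warpctcgrad * loss_grad.reshape(1, B, 1)      # broadcast to (Tmax, B, D)
assert logits_grad.shape == (Tmax, B, D)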
12 changes: 6 additions & 6 deletions paddle/phi/kernels/impl/warpctc_kernel_impl.h
@@ -233,8 +233,8 @@ void WarpctcKernel(const Context& dev_ctx,
const paddle::optional<DenseTensor>& labels_length,
int blank,
bool norm_by_times,
DenseTensor* warpctc_grad,
DenseTensor* loss) {
DenseTensor* loss,
DenseTensor* warpctcgrad) {
size_t num_sequences, sequence_width, max_sequence_length;
paddle::framework::Vector<size_t> logits_lod;
paddle::framework::Vector<size_t> label_lod;
@@ -383,11 +383,11 @@

// warpctc computes loss and gradient in one call, gradient data also stored
// in batch format
warpctc_grad->Resize(warpctc_logits.dims());
T* warpctc_grad_data = dev_ctx.template Alloc<T>(warpctc_grad);
warpctcgrad->Resize(warpctc_logits.dims());
T* warpctcgrad_data = dev_ctx.template Alloc<T>(warpctcgrad);

phi::funcs::SetConstant<Context, T>()(
dev_ctx, warpctc_grad, static_cast<T>(0));
dev_ctx, warpctcgrad, static_cast<T>(0));

// warpctc accesses labels in CPU memory
DenseTensor warpctc_label;
@@ -439,7 +439,7 @@
T* warpctc_loss_data = dev_ctx.template HostAlloc<T>(&warpctc_loss);
WarpCTCFunctor<Context, T>()(dev_ctx,
warpctc_logits_data,
warpctc_grad_data,
warpctcgrad_data,
warpctc_label_data,
warpctc_label_lengths.data(),
warpctc_logits_lengths.data(),
4 changes: 2 additions & 2 deletions paddle/phi/kernels/warpctc_grad_kernel.h
@@ -21,10 +21,10 @@ namespace phi {

template <typename T, typename Context>
void WarpctcGradKernel(const Context& dev_ctx,
const DenseTensor& warpctc_grad,
const DenseTensor& logits,
const DenseTensor& loss_grad,
const paddle::optional<DenseTensor>& logits_length,
const DenseTensor& warpctcgrad,
const DenseTensor& loss_grad,
int blank,
bool norm_by_times,
DenseTensor* logits_grad);
4 changes: 2 additions & 2 deletions paddle/phi/kernels/warpctc_kernel.h
@@ -27,7 +27,7 @@ void WarpctcKernel(const Context& dev_ctx,
const paddle::optional<DenseTensor>& labels_length,
int blank,
bool norm_by_times,
DenseTensor* warpctc_grad,
DenseTensor* loss);
DenseTensor* loss,
DenseTensor* warpctcgrad);

} // namespace phi
4 changes: 2 additions & 2 deletions paddle/phi/ops/compat/warpctc_sig.cc
@@ -20,13 +20,13 @@ KernelSignature WarpctcOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("warpctc",
{"Logits", "Label", "LogitsLength", "LabelLength"},
{"blank", "norm_by_times"},
{"WarpCTCGrad", "Loss"});
{"Loss", "WarpCTCGrad"});
}

KernelSignature WarpctcGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("warpctc_grad",
{"WarpCTCGrad", "Logits", "Loss@GRAD", "LogitsLength"},
{"Logits", "LogitsLength", "WarpCTCGrad", "Loss@GRAD"},
{"blank", "norm_by_times"},
{"Logits@GRAD"});
}
9 changes: 9 additions & 0 deletions python/paddle/fluid/layers/loss.py
@@ -546,6 +546,15 @@ def warpctc(input,
fetch_list=[cost.name])
print(output)
"""
    if in_dygraph_mode():
        if input_length is None or label_length is None:
            raise ValueError(
                "input_length and label_length must not be None in dygraph mode!"
            )
        loss_out = _C_ops.final_state_warpctc(input, label, input_length,
                                              label_length, blank,
                                              norm_by_times)
        return loss_out
    if _non_static_mode():
        if input_length is None or label_length is None:
            raise ValueError(
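As a small illustration of the guard added in the new in_dygraph_mode branch (a sketch assuming typical shapes, not verified output): omitting the sequence lengths in eager mode raises the ValueError above, while supplying them routes the call to final_state_warpctc.

# Hedged sketch of the dygraph-mode length check.
import numpy as np
import paddle

paddle.disable_static()
logits = paddle.to_tensor(np.random.rand(5, 2, 8).astype("float32"))
label = paddle.to_tensor(np.random.randint(1, 8, (2, 3)).astype("int32"))

try:
    paddle.fluid.layers.warpctc(logits, label)  # input_length/label_length left as None
except ValueError as err:
    print(err)  # input_length and label_length must not be None in dygraph mode!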
20 changes: 17 additions & 3 deletions python/paddle/fluid/tests/unittests/test_warpctc_op.py
@@ -191,6 +191,16 @@ def forward(self):
return self.loss


def python_api(logits,
               label,
               logits_length=None,
               labels_length=None,
               blank=0,
               norm_by_times=False):
    return paddle.fluid.layers.warpctc(logits, label, blank, norm_by_times,
                                       logits_length, labels_length)


class TestWarpCTCOp(OpTest):

def config(self):
@@ -280,6 +290,8 @@ def config(self):

def setUp(self):
self.op_type = "warpctc"
self.python_api = python_api
self.python_out_sig = ["Loss"]
self.config()

logits = np.random.uniform(
@@ -344,7 +356,7 @@ def setUp(self):
}

def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)

def test_check_grad(self):
self.outputs['WarpCTCGrad'] = self.gradient
@@ -387,6 +399,8 @@ def config(self):

def setUp(self):
self.op_type = "warpctc"
self.python_api = python_api
self.python_out_sig = ["Loss"]
self.config()

logits = np.random.uniform(
Expand Down Expand Up @@ -451,11 +465,11 @@ def setUp(self):
}

def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)

def test_check_grad(self):
self.outputs['WarpCTCGrad'] = self.gradient
self.check_grad(["Logits"], "Loss")
self.check_grad(["Logits"], "Loss", check_eager=True)


class TestWarpCTCOpError(unittest.TestCase):
