Skip to content

Commit

Permalink
fix cross_entropy when run static graph mode of mlu and npu (PaddlePa…
Browse files Browse the repository at this point in the history
  • Loading branch information
qipengh committed Mar 30, 2022
1 parent cb8afc2 commit 489a64e
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions python/paddle/nn/functional/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -1818,12 +1818,16 @@ def cross_entropy(input,
helper = LayerHelper('softmax_with_cross_entropy', **locals())
softmax = helper.create_variable_for_type_inference(dtype=input.dtype)
out = helper.create_variable_for_type_inference(dtype=input.dtype)

outputs = {'Softmax': softmax, 'Loss': out}
if core.is_compiled_with_npu() or core.is_compiled_with_mlu():
backprop = helper.create_variable_for_type_inference(dtype=input.dtype)
outputs['Backprop'] = backprop
helper.append_op(
type='softmax_with_cross_entropy',
inputs={'Logits': input,
'Label': label},
outputs={'Softmax': softmax,
'Loss': out},
outputs=outputs,
attrs=attrs)

if weight is not None:
Expand Down

0 comments on commit 489a64e

Please sign in to comment.