fix api sigmoid_focal_loss to final state (#45207)
wanghuancoder committed Aug 17, 2022
1 parent a79d4a7 commit 2105d14
Showing 1 changed file with 44 additions and 15 deletions.
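
The change moves the dynamic-graph path of sigmoid_focal_loss onto the new final-state operators: in_dygraph_mode() now gets its own complete branch built from _C_ops.final_state_* calls, the legacy _C_ops.* calls survive only under _in_legacy_dygraph(), and the old _non_static_mode() wrapper around both is removed. Both branches evaluate the focal loss of Lin et al., FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), as an element-wise reweighting of sigmoid cross-entropy. As a minimal sketch of that math, using public paddle ops rather than the private _C_ops bindings in the diff (the helper name here is ours, not the library's):

import paddle
import paddle.nn.functional as F

def focal_loss_sketch(logit, label, alpha=0.25, gamma=2.0):
    # Unreduced sigmoid cross-entropy: the `loss` tensor both branches start from.
    ce = F.binary_cross_entropy_with_logits(logit, label, reduction='none')
    pred = F.sigmoid(logit)
    # p_t: predicted probability of the ground-truth class.
    p_t = pred * label + (1 - pred) * (1 - label)
    # alpha_t: class-balance weight, alpha for positives and 1 - alpha for negatives.
    alpha_t = alpha * label + (1 - alpha) * (1 - label)
    # (1 - p_t)^gamma down-weights easy, well-classified examples.
    return alpha_t * (1 - p_t) ** gamma * ce
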
python/paddle/nn/functional/loss.py (+44, -15)
@@ -2616,23 +2616,54 @@ def sigmoid_focal_loss(logit,
"Expected one dimension of normalizer in sigmoid_focal_loss but got {}."
.format(normalizer_dims))

if _non_static_mode():
if in_dygraph_mode():
place = _current_expected_place()
one = _C_ops.final_state_full(logit.shape, float(1.0), logit.dtype,
place)
if in_dygraph_mode():
place = _current_expected_place()
one = _C_ops.final_state_full(logit.shape, float(1.0), logit.dtype,
place)

loss = _C_ops.final_state_sigmoid_cross_entropy_with_logits(
logit, label, False, -100)
loss = _C_ops.final_state_sigmoid_cross_entropy_with_logits(
logit, label, False, -100)

elif _in_legacy_dygraph():
one = _varbase_creator(dtype=logit.dtype)
_C_ops.fill_constant(one, 'value', float(1.0), 'force_cpu', False,
'dtype', one.dtype, 'str_value', '1.0',
'shape', logit.shape)
loss = _C_ops.sigmoid_cross_entropy_with_logits(logit, label)
pred = _C_ops.final_state_sigmoid(logit)

p_t = _C_ops.final_state_add(
_C_ops.final_state_multiply(pred, label),
_C_ops.final_state_multiply(_C_ops.final_state_subtract(one, pred),
_C_ops.final_state_subtract(one,
label)))

alpha = fluid.dygraph.base.to_variable([alpha], dtype=loss.dtype)
alpha_t = _C_ops.final_state_add(
_C_ops.final_state_multiply(alpha, label),
_C_ops.final_state_multiply(_C_ops.final_state_subtract(one, alpha),
_C_ops.final_state_subtract(one,
label)))
loss = _C_ops.final_state_multiply(alpha_t, loss)

gamma = fluid.dygraph.base.to_variable([gamma], dtype=loss.dtype)
gamma_t = _C_ops.final_state_pow(_C_ops.elementwise_sub(one, p_t),
gamma)
loss = _C_ops.final_state_multiply(gamma_t, loss)

if normalizer is not None:
loss = _C_ops.final_state_divide(loss, normalizer)

if reduction == "sum":
return _C_ops.final_state_sum(loss, [], None, False)
elif reduction == "mean":
return _C_ops.final_state_mean_all(loss)

return loss

elif _in_legacy_dygraph():
one = _varbase_creator(dtype=logit.dtype)
_C_ops.fill_constant(one, 'value', float(1.0), 'force_cpu', False,
'dtype', one.dtype, 'str_value', '1.0', 'shape',
logit.shape)
loss = _C_ops.sigmoid_cross_entropy_with_logits(logit, label)

pred = _C_ops.sigmoid(logit)

p_t = _C_ops.elementwise_add(
_C_ops.elementwise_mul(pred, label),
_C_ops.elementwise_mul(_C_ops.elementwise_sub(one, pred),
@@ -2656,8 +2687,6 @@ def sigmoid_focal_loss(logit,
         if reduction == "sum":
             return _C_ops.reduce_sum(loss, 'reduce_all', True)
         elif reduction == "mean":
-            if in_dygraph_mode():
-                return _C_ops.final_state_mean_all(loss)
             return _C_ops.mean(loss)

         return loss
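
For reference, the public entry point this code backs is paddle.nn.functional.sigmoid_focal_loss(logit, label, normalizer=None, alpha=0.25, gamma=2.0, reduction='sum', name=None). A minimal dygraph usage sketch (the tensor values are illustrative, not taken from the commit):

import paddle
import paddle.nn.functional as F

logit = paddle.to_tensor([[0.97, 0.91, 0.03], [0.55, 0.43, 0.71]])
label = paddle.to_tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])
# Normalize by the number of foreground (positive) labels.
fg_num = paddle.sum(paddle.cast(label > 0.5, dtype='float32'))
loss = F.sigmoid_focal_loss(logit, label, normalizer=fg_num, reduction='mean')
print(loss)  # scalar Tensor; reduction='mean' takes the final_state_mean_all path above
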