/
losses.py
1197 lines (943 loc) · 40.2 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Built-in loss functions.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import six
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.ops.losses import util as tf_losses_util
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls
@keras_export('keras.losses.Loss')
class Loss(object):
  """Loss base class.

  To be implemented by subclasses:
  * `call()`: Contains the logic for loss calculation using `y_true`, `y_pred`.

  Example subclass implementation:
  ```
  class MeanSquaredError(Loss):
    def call(self, y_true, y_pred):
      y_pred = ops.convert_to_tensor(y_pred)
      y_true = math_ops.cast(y_true, y_pred.dtype)
      return K.mean(math_ops.square(y_pred - y_true), axis=-1)
  ```

  When used with `tf.distribute.Strategy`, outside of built-in training loops
  such as `tf.keras` `compile` and `fit`, please use 'SUM' or 'NONE' reduction
  types, and reduce losses explicitly in your training loop. Using 'AUTO' or
  'SUM_OVER_BATCH_SIZE' will raise an error.

  Please see
  https://www.tensorflow.org/tutorials/distribute/custom_training for more
  details on this.

  You can implement 'SUM_OVER_BATCH_SIZE' using global batch size like:
  ```
  with strategy.scope():
    loss_obj = tf.keras.losses.CategoricalCrossentropy(
        reduction=tf.keras.losses.Reduction.NONE)
    ....
    loss = (tf.reduce_sum(loss_obj(labels, predictions)) *
            (1. / global_batch_size))
  ```

  Args:
    reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to loss.
      Default value is `AUTO`. `AUTO` indicates that the reduction option will
      be determined by the usage context. For almost all cases this defaults to
      `SUM_OVER_BATCH_SIZE`.
      When used with `tf.distribute.Strategy`, outside of built-in training
      loops such as `tf.keras` `compile` and `fit`, using `AUTO` or
      `SUM_OVER_BATCH_SIZE` will raise an error. Please see
      https://www.tensorflow.org/tutorials/distribute/custom_training
      for more details on this.
    name: Optional name for the op.
  """

  def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name=None):
    # Reject unknown reduction values up front, before storing anything.
    losses_utils.ReductionV2.validate(reduction)
    self.reduction = reduction
    self.name = name

  def __call__(self, y_true, y_pred, sample_weight=None):
    """Invokes the `Loss` instance.

    Args:
      y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`
      y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`
      sample_weight: Optional `sample_weight` acts as a
        coefficient for the loss. If a scalar is provided, then the loss is
        simply scaled by the given value. If `sample_weight` is a tensor of size
        `[batch_size]`, then the total loss for each sample of the batch is
        rescaled by the corresponding element in the `sample_weight` vector. If
        the shape of `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be
        broadcasted to this shape), then each loss element of `y_pred` is scaled
        by the corresponding value of `sample_weight`. (Note on `dN-1`: all loss
        functions reduce by 1 dimension, usually axis=-1.)

    Returns:
      Weighted loss float `Tensor`. If `reduction` is `NONE`, this has
        shape `[batch_size, d0, .. dN-1]`; otherwise, it is scalar. (Note `dN-1`
        because all loss functions reduce by 1 dimension, usually axis=-1.)

    Raises:
      ValueError: If the shape of `sample_weight` is invalid.
    """
    # If we are wrapping a lambda function strip '<>' from the name as it is not
    # accepted in scope name.
    scope_name = 'lambda' if self.name == '<lambda>' else self.name
    graph_ctx = tf_utils.graph_context_for_symbolic_tensors(
        y_true, y_pred, sample_weight)
    # Fall back to the class name when no explicit name was given.
    with K.name_scope(scope_name or self.__class__.__name__), graph_ctx:
      losses = self.call(y_true, y_pred)
      return losses_utils.compute_weighted_loss(
          losses, sample_weight, reduction=self._get_reduction())

  @classmethod
  def from_config(cls, config):
    """Instantiates a `Loss` from its config (output of `get_config()`).

    Args:
      config: Output of `get_config()`.

    Returns:
      A `Loss` instance.
    """
    return cls(**config)

  def get_config(self):
    """Returns a dict holding the `reduction` and `name` of this loss."""
    return {'reduction': self.reduction, 'name': self.name}

  @abc.abstractmethod
  @doc_controls.for_subclass_implementers
  def call(self, y_true, y_pred):
    """Invokes the `Loss` instance.

    Args:
      y_true: Ground truth values, with the same shape as 'y_pred'.
      y_pred: The predicted values.

    Raises:
      NotImplementedError: Always, in this base implementation; subclasses
        must override `call`.
    """
    # The exception must actually be raised: merely constructing it would let
    # the base method fall through and return `None`.
    raise NotImplementedError('Must be implemented in subclasses.')

  def _get_reduction(self):
    """Handles `AUTO` reduction cases and returns the reduction value.

    Raises:
      ValueError: If a distribution strategy is active and the reduction is
        `AUTO` or `SUM_OVER_BATCH_SIZE`.
    """
    if distribution_strategy_context.has_strategy() and (
        self.reduction == losses_utils.ReductionV2.AUTO or
        self.reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE):
      raise ValueError(
          'Please use `tf.keras.losses.Reduction.SUM` or '
          '`tf.keras.losses.Reduction.NONE` for loss reduction when losses are '
          'used with `tf.distribute.Strategy` outside of the built-in training '
          'loops. You can implement '
          '`tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` using global batch '
          'size like:\n```\nwith strategy.scope():\n'
          ' loss_obj = tf.keras.losses.CategoricalCrossentropy('
          'reduction=tf.keras.losses.Reduction.NONE)\n....\n'
          ' loss = tf.reduce_sum(loss_obj(labels, predictions)) * '
          '(1. / global_batch_size)\n```\nPlease see '
          'https://www.tensorflow.org/tutorials/distribute/custom_training'
          ' for more details.')

    # Without a strategy, `AUTO` resolves to `SUM_OVER_BATCH_SIZE`.
    if self.reduction == losses_utils.ReductionV2.AUTO:
      return losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE
    return self.reduction
class LossFunctionWrapper(Loss):
  """Adapts a plain loss function to the `Loss` class interface.

  Args:
    fn: The loss function to wrap, with signature `fn(y_true, y_pred,
      **kwargs)`.
    reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to loss.
      Default value is `AUTO`. `AUTO` indicates that the reduction option will
      be determined by the usage context. For almost all cases this defaults to
      `SUM_OVER_BATCH_SIZE`.
      When used with `tf.distribute.Strategy`, outside of built-in training
      loops such as `tf.keras` `compile` and `fit`, using `AUTO` or
      `SUM_OVER_BATCH_SIZE` will raise an error. Please see
      https://www.tensorflow.org/tutorials/distribute/custom_training
      for more details on this.
    name: (Optional) name for the loss.
    **kwargs: The keyword arguments that are passed on to `fn`.
  """

  def __init__(self,
               fn,
               reduction=losses_utils.ReductionV2.AUTO,
               name=None,
               **kwargs):
    super(LossFunctionWrapper, self).__init__(name=name, reduction=reduction)
    # Extra keyword arguments are stored and forwarded to `fn` on every call.
    self._fn_kwargs = kwargs
    self.fn = fn

  def call(self, y_true, y_pred):
    """Invokes the `LossFunctionWrapper` instance.

    Args:
      y_true: Ground truth values.
      y_pred: The predicted values.

    Returns:
      Loss values per sample.
    """
    both_are_tensors = (
        tensor_util.is_tensor(y_pred) and tensor_util.is_tensor(y_true))
    if both_are_tensors:
      # Only tensor inputs go through the squeeze/expand helper.
      y_pred, y_true = tf_losses_util.squeeze_or_expand_dimensions(
          y_pred, y_true)
    return self.fn(y_true, y_pred, **self._fn_kwargs)

  def get_config(self):
    """Returns the base config extended with the stored `fn` kwargs.

    Tensor or variable kwargs are evaluated to concrete values with `K.eval`.
    """
    fn_config = {
        key: K.eval(value) if tf_utils.is_tensor_or_variable(value) else value
        for key, value in six.iteritems(self._fn_kwargs)
    }
    merged = dict(super(LossFunctionWrapper, self).get_config())
    # On a key clash the `fn` kwargs win over the base entries.
    merged.update(fn_config)
    return merged
@keras_export('keras.losses.MeanSquaredError')
class MeanSquaredError(LossFunctionWrapper):
  """Computes the mean of squares of errors between labels and predictions.

  `loss = square(y_true - y_pred)`

  Usage:

  ```python
  mse = tf.keras.losses.MeanSquaredError()
  loss = mse([0., 0., 1., 1.], [1., 1., 1., 0.])
  print('Loss: ', loss.numpy())  # Loss: 0.75
  ```

  Usage with the `compile` API:

  ```python
  model = tf.keras.Model(inputs, outputs)
  model.compile('sgd', loss=tf.keras.losses.MeanSquaredError())
  ```
  """

  def __init__(self,
               reduction=losses_utils.ReductionV2.AUTO,
               name='mean_squared_error'):
    # Delegates the computation to the module-level `mean_squared_error`.
    super(MeanSquaredError, self).__init__(
        mean_squared_error, reduction=reduction, name=name)
@keras_export('keras.losses.MeanAbsoluteError')
class MeanAbsoluteError(LossFunctionWrapper):
  """Computes the mean of absolute difference between labels and predictions.

  `loss = abs(y_true - y_pred)`

  Usage:

  ```python
  mae = tf.keras.losses.MeanAbsoluteError()
  loss = mae([0., 0., 1., 1.], [1., 1., 1., 0.])
  print('Loss: ', loss.numpy())  # Loss: 0.75
  ```

  Usage with the `compile` API:

  ```python
  model = tf.keras.Model(inputs, outputs)
  model.compile('sgd', loss=tf.keras.losses.MeanAbsoluteError())
  ```
  """

  def __init__(self,
               reduction=losses_utils.ReductionV2.AUTO,
               name='mean_absolute_error'):
    # Delegates the computation to the module-level `mean_absolute_error`.
    super(MeanAbsoluteError, self).__init__(
        mean_absolute_error, reduction=reduction, name=name)
@keras_export('keras.losses.MeanAbsolutePercentageError')
class MeanAbsolutePercentageError(LossFunctionWrapper):
  """Computes the mean absolute percentage error between `y_true` and `y_pred`.

  `loss = 100 * abs(y_true - y_pred) / y_true`

  Usage:

  ```python
  mape = tf.keras.losses.MeanAbsolutePercentageError()
  loss = mape([0., 0., 1., 1.], [1., 1., 1., 0.])
  print('Loss: ', loss.numpy())  # Loss: 5e+08
  ```

  Usage with the `compile` API:

  ```python
  model = tf.keras.Model(inputs, outputs)
  model.compile('sgd', loss=tf.keras.losses.MeanAbsolutePercentageError())
  ```
  """

  def __init__(self,
               reduction=losses_utils.ReductionV2.AUTO,
               name='mean_absolute_percentage_error'):
    # Delegates to the module-level `mean_absolute_percentage_error`.
    super(MeanAbsolutePercentageError, self).__init__(
        mean_absolute_percentage_error, reduction=reduction, name=name)
@keras_export('keras.losses.MeanSquaredLogarithmicError')
class MeanSquaredLogarithmicError(LossFunctionWrapper):
  """Computes the mean squared logarithmic error between `y_true` and `y_pred`.

  `loss = square(log(y_true) - log(y_pred))`

  Usage:

  ```python
  msle = tf.keras.losses.MeanSquaredLogarithmicError()
  loss = msle([0., 0., 1., 1.], [1., 1., 1., 0.])
  print('Loss: ', loss.numpy())  # Loss: 0.36034
  ```

  Usage with the `compile` API:

  ```python
  model = tf.keras.Model(inputs, outputs)
  model.compile('sgd', loss=tf.keras.losses.MeanSquaredLogarithmicError())
  ```
  """

  def __init__(self,
               reduction=losses_utils.ReductionV2.AUTO,
               name='mean_squared_logarithmic_error'):
    # Delegates to the module-level `mean_squared_logarithmic_error`.
    super(MeanSquaredLogarithmicError, self).__init__(
        mean_squared_logarithmic_error, reduction=reduction, name=name)
@keras_export('keras.losses.BinaryCrossentropy')
class BinaryCrossentropy(LossFunctionWrapper):
  """Computes the cross-entropy loss between true labels and predicted labels.

  Use this cross-entropy loss when there are only two label classes (assumed to
  be 0 and 1). For each example, there should be a single floating-point value
  per prediction.

  In the snippet below, each of the four examples has only a single
  floating-point value, and both `y_pred` and `y_true` have the shape
  `[batch_size]`.

  Usage:

  ```python
  bce = tf.keras.losses.BinaryCrossentropy()
  loss = bce([0., 0., 1., 1.], [1., 1., 1., 0.])
  print('Loss: ', loss.numpy())  # Loss: 11.522857
  ```

  Usage with the `tf.keras` API:

  ```python
  model = tf.keras.Model(inputs, outputs)
  model.compile('sgd', loss=tf.keras.losses.BinaryCrossentropy())
  ```

  Args:
    from_logits: Whether to interpret `y_pred` as a tensor of
      [logit](https://en.wikipedia.org/wiki/Logit) values. By default, we assume
      that `y_pred` contains probabilities (i.e., values in [0, 1]).
      Note: Using from_logits=True may be more numerically stable.
    label_smoothing: Float in [0, 1]. When 0, no smoothing occurs. When > 0, we
      compute the loss between the predicted labels and a smoothed version of
      the true labels, where the smoothing squeezes the labels towards 0.5.
      Larger values of `label_smoothing` correspond to heavier smoothing.
    reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to loss.
      Default value is `AUTO`. `AUTO` indicates that the reduction option will
      be determined by the usage context. For almost all cases this defaults to
      `SUM_OVER_BATCH_SIZE`.
      When used with `tf.distribute.Strategy`, outside of built-in training
      loops such as `tf.keras` `compile` and `fit`, using `AUTO` or
      `SUM_OVER_BATCH_SIZE` will raise an error. Please see
      https://www.tensorflow.org/tutorials/distribute/custom_training
      for more details on this.
    name: (Optional) Name for the op.
  """

  def __init__(self,
               from_logits=False,
               label_smoothing=0,
               reduction=losses_utils.ReductionV2.AUTO,
               name='binary_crossentropy'):
    # `from_logits` and `label_smoothing` are forwarded to the wrapped
    # `binary_crossentropy` function on every call.
    super(BinaryCrossentropy, self).__init__(
        binary_crossentropy,
        reduction=reduction,
        name=name,
        from_logits=from_logits,
        label_smoothing=label_smoothing)
    # Also exposed as a public attribute on the instance.
    self.from_logits = from_logits
@keras_export('keras.losses.CategoricalCrossentropy')
class CategoricalCrossentropy(LossFunctionWrapper):
  """Computes the crossentropy loss between the labels and predictions.

  Use this crossentropy loss function when there are two or more label classes.
  We expect labels to be provided in a `one_hot` representation. If you want to
  provide labels as integers, please use `SparseCategoricalCrossentropy` loss.
  There should be `# classes` floating point values per feature.

  In the snippet below, there are `# classes` floating point values per
  example. The shape of both `y_pred` and `y_true` is
  `[batch_size, num_classes]`.

  Usage:

  ```python
  cce = tf.keras.losses.CategoricalCrossentropy()
  loss = cce(
    [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]],
    [[.9, .05, .05], [.05, .89, .06], [.05, .01, .94]])
  print('Loss: ', loss.numpy())  # Loss: 0.0945
  ```

  Usage with the `compile` API:

  ```python
  model = tf.keras.Model(inputs, outputs)
  model.compile('sgd', loss=tf.keras.losses.CategoricalCrossentropy())
  ```

  Args:
    from_logits: Whether `y_pred` is expected to be a logits tensor. By default,
      we assume that `y_pred` encodes a probability distribution.
      Note: Using from_logits=True may be more numerically stable.
    label_smoothing: Float in [0, 1]. When > 0, label values are smoothed,
      meaning the confidence on label values are relaxed. e.g.
      `label_smoothing=0.2` means that we will use a value of `0.1` for label
      `0` and `0.9` for label `1`.
    reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to loss.
      Default value is `AUTO`. `AUTO` indicates that the reduction option will
      be determined by the usage context. For almost all cases this defaults to
      `SUM_OVER_BATCH_SIZE`.
      When used with `tf.distribute.Strategy`, outside of built-in training
      loops such as `tf.keras` `compile` and `fit`, using `AUTO` or
      `SUM_OVER_BATCH_SIZE` will raise an error. Please see
      https://www.tensorflow.org/tutorials/distribute/custom_training
      for more details on this.
    name: Optional name for the op.
  """

  def __init__(self,
               from_logits=False,
               label_smoothing=0,
               reduction=losses_utils.ReductionV2.AUTO,
               name='categorical_crossentropy'):
    # `from_logits` and `label_smoothing` are forwarded to the wrapped
    # `categorical_crossentropy` function on every call.
    super(CategoricalCrossentropy, self).__init__(
        categorical_crossentropy,
        reduction=reduction,
        name=name,
        from_logits=from_logits,
        label_smoothing=label_smoothing)
@keras_export('keras.losses.SparseCategoricalCrossentropy')
class SparseCategoricalCrossentropy(LossFunctionWrapper):
  """Computes the crossentropy loss between the labels and predictions.

  Use this crossentropy loss function when there are two or more label classes.
  We expect labels to be provided as integers. If you want to provide labels
  using `one-hot` representation, please use `CategoricalCrossentropy` loss.
  There should be `# classes` floating point values per feature for `y_pred`
  and a single floating point value per feature for `y_true`.

  In the snippet below, there is a single floating point value per example for
  `y_true` and `# classes` floating point values per example for `y_pred`.
  The shape of `y_true` is `[batch_size]` and the shape of `y_pred` is
  `[batch_size, num_classes]`.

  Usage:

  ```python
  cce = tf.keras.losses.SparseCategoricalCrossentropy()
  loss = cce(
    tf.convert_to_tensor([0, 1, 2]),
    tf.convert_to_tensor([[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]]))
  print('Loss: ', loss.numpy())  # Loss: 0.3239
  ```

  Usage with the `compile` API:

  ```python
  model = tf.keras.Model(inputs, outputs)
  model.compile('sgd', loss=tf.keras.losses.SparseCategoricalCrossentropy())
  ```

  Args:
    from_logits: Whether `y_pred` is expected to be a logits tensor. By default,
      we assume that `y_pred` encodes a probability distribution.
      Note: Using from_logits=True may be more numerically stable.
    reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to loss.
      Default value is `AUTO`. `AUTO` indicates that the reduction option will
      be determined by the usage context. For almost all cases this defaults to
      `SUM_OVER_BATCH_SIZE`.
      When used with `tf.distribute.Strategy`, outside of built-in training
      loops such as `tf.keras` `compile` and `fit`, using `AUTO` or
      `SUM_OVER_BATCH_SIZE` will raise an error. Please see
      https://www.tensorflow.org/tutorials/distribute/custom_training
      for more details on this.
    name: Optional name for the op.
  """

  def __init__(self,
               from_logits=False,
               reduction=losses_utils.ReductionV2.AUTO,
               name='sparse_categorical_crossentropy'):
    # `from_logits` is forwarded to the wrapped
    # `sparse_categorical_crossentropy` function on every call.
    super(SparseCategoricalCrossentropy, self).__init__(
        sparse_categorical_crossentropy,
        reduction=reduction,
        name=name,
        from_logits=from_logits)
@keras_export('keras.losses.Hinge')
class Hinge(LossFunctionWrapper):
  """Computes the hinge loss between `y_true` and `y_pred`.

  `loss = maximum(1 - y_true * y_pred, 0)`

  `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are
  provided we will convert them to -1 or 1.

  Usage:

  ```python
  h = tf.keras.losses.Hinge()
  loss = h([-1., 1., 1.], [0.6, -0.7, -0.5])

  # loss = max(0, 1 - y_true * y_pred) = [1.6 + 1.7 + 1.5] / 3

  print('Loss: ', loss.numpy())  # Loss: 1.6
  ```

  Usage with the `compile` API:

  ```python
  model = tf.keras.Model(inputs, outputs)
  model.compile('sgd', loss=tf.keras.losses.Hinge())
  ```
  """

  def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name='hinge'):
    # Delegates the computation to the module-level `hinge`.
    super(Hinge, self).__init__(hinge, reduction=reduction, name=name)
@keras_export('keras.losses.SquaredHinge')
class SquaredHinge(LossFunctionWrapper):
  """Computes the squared hinge loss between `y_true` and `y_pred`.

  `loss = square(maximum(1 - y_true * y_pred, 0))`

  `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are
  provided we will convert them to -1 or 1.

  Usage:

  ```python
  sh = tf.keras.losses.SquaredHinge()
  loss = sh([-1., 1., 1.], [0.6, -0.7, -0.5])

  # loss = (max(0, 1 - y_true * y_pred))^2 = [1.6^2 + 1.7^2 + 1.5^2] / 3

  print('Loss: ', loss.numpy())  # Loss: 2.566666
  ```

  Usage with the `compile` API:

  ```python
  model = tf.keras.Model(inputs, outputs)
  model.compile('sgd', loss=tf.keras.losses.SquaredHinge())
  ```
  """

  def __init__(self,
               reduction=losses_utils.ReductionV2.AUTO,
               name='squared_hinge'):
    # Delegates the computation to the module-level `squared_hinge`.
    super(SquaredHinge, self).__init__(
        squared_hinge, reduction=reduction, name=name)
@keras_export('keras.losses.CategoricalHinge')
class CategoricalHinge(LossFunctionWrapper):
  """Computes the categorical hinge loss between `y_true` and `y_pred`.

  `loss = maximum(neg - pos + 1, 0)`
  where `pos = sum(y_true * y_pred)` and `neg = max((1 - y_true) * y_pred)`,
  both taken over the last axis.

  Usage:

  ```python
  ch = tf.keras.losses.CategoricalHinge()
  loss = ch([0., 1., 1.], [1., 0., 1.])
  print('Loss: ', loss.numpy())  # Loss: 1.0
  ```

  Usage with the `compile` API:

  ```python
  model = tf.keras.Model(inputs, outputs)
  model.compile('sgd', loss=tf.keras.losses.CategoricalHinge())
  ```

  Args:
    reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to loss.
      Default value is `AUTO`.
    name: Optional name for the op.
  """

  def __init__(self,
               reduction=losses_utils.ReductionV2.AUTO,
               name='categorical_hinge'):
    # Delegates the computation to the module-level `categorical_hinge`.
    super(CategoricalHinge, self).__init__(
        categorical_hinge, name=name, reduction=reduction)
@keras_export('keras.losses.Poisson')
class Poisson(LossFunctionWrapper):
  """Computes the Poisson loss between `y_true` and `y_pred`.

  `loss = y_pred - y_true * log(y_pred)`

  Usage:

  ```python
  p = tf.keras.losses.Poisson()
  loss = p([1., 9., 2.], [4., 8., 12.])
  print('Loss: ', loss.numpy())  # Loss: -0.35702705
  ```

  Usage with the `compile` API:

  ```python
  model = tf.keras.Model(inputs, outputs)
  model.compile('sgd', loss=tf.keras.losses.Poisson())
  ```
  """

  def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name='poisson'):
    # Delegates the computation to the module-level `poisson`.
    super(Poisson, self).__init__(poisson, reduction=reduction, name=name)
@keras_export('keras.losses.LogCosh')
class LogCosh(LossFunctionWrapper):
  """Computes the logarithm of the hyperbolic cosine of the prediction error.

  `logcosh = log((exp(x) + exp(-x))/2)`,
  where x is the error `y_pred - y_true`.

  Usage:

  ```python
  l = tf.keras.losses.LogCosh()
  loss = l([0., 1., 1.], [1., 0., 1.])
  print('Loss: ', loss.numpy())  # Loss: 0.289
  ```

  Usage with the `compile` API:

  ```python
  model = tf.keras.Model(inputs, outputs)
  model.compile('sgd', loss=tf.keras.losses.LogCosh())
  ```
  """

  def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name='logcosh'):
    # Delegates the computation to the module-level `logcosh`.
    super(LogCosh, self).__init__(logcosh, reduction=reduction, name=name)
@keras_export('keras.losses.KLDivergence')
class KLDivergence(LossFunctionWrapper):
  """Computes Kullback-Leibler divergence loss between `y_true` and `y_pred`.

  `loss = y_true * log(y_true / y_pred)`

  See: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence

  Usage:

  ```python
  k = tf.keras.losses.KLDivergence()
  loss = k([.4, .9, .2], [.5, .8, .12])
  print('Loss: ', loss.numpy())  # Loss: 0.11891246
  ```

  Usage with the `compile` API:

  ```python
  model = tf.keras.Model(inputs, outputs)
  model.compile('sgd', loss=tf.keras.losses.KLDivergence())
  ```
  """

  def __init__(self,
               reduction=losses_utils.ReductionV2.AUTO,
               name='kullback_leibler_divergence'):
    # Delegates to the module-level `kullback_leibler_divergence`.
    super(KLDivergence, self).__init__(
        kullback_leibler_divergence, reduction=reduction, name=name)
@keras_export('keras.losses.Huber')
class Huber(LossFunctionWrapper):
  """Computes the Huber loss between `y_true` and `y_pred`.

  For each value x in `error = y_true - y_pred`:

  ```
  loss = 0.5 * x^2                  if |x| <= d
  loss = 0.5 * d^2 + d * (|x| - d)  if |x| > d
  ```
  where d is `delta`. See: https://en.wikipedia.org/wiki/Huber_loss

  Usage:

  ```python
  l = tf.keras.losses.Huber()
  loss = l([0., 1., 1.], [1., 0., 1.])
  print('Loss: ', loss.numpy())  # Loss: 0.333
  ```

  Usage with the `compile` API:

  ```python
  model = tf.keras.Model(inputs, outputs)
  model.compile('sgd', loss=tf.keras.losses.Huber())
  ```

  Args:
    delta: A float, the point where the Huber loss function changes from a
      quadratic to linear.
    reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to loss.
      Default value is `AUTO`. `AUTO` indicates that the reduction option will
      be determined by the usage context. For almost all cases this defaults to
      `SUM_OVER_BATCH_SIZE`.
      When used with `tf.distribute.Strategy`, outside of built-in training
      loops such as `tf.keras` `compile` and `fit`, using `AUTO` or
      `SUM_OVER_BATCH_SIZE` will raise an error. Please see
      https://www.tensorflow.org/tutorials/distribute/custom_training
      for more details on this.
    name: Optional name for the op.
  """

  def __init__(self,
               delta=1.0,
               reduction=losses_utils.ReductionV2.AUTO,
               name='huber_loss'):
    # `delta` is forwarded to the wrapped `huber_loss` function on every call.
    super(Huber, self).__init__(
        huber_loss, reduction=reduction, name=name, delta=delta)
@keras_export('keras.metrics.mean_squared_error',
              'keras.metrics.mse',
              'keras.metrics.MSE',
              'keras.losses.mean_squared_error',
              'keras.losses.mse',
              'keras.losses.MSE')
def mean_squared_error(y_true, y_pred):
  """Computes the mean squared error over the last axis.

  Args:
    y_true: The ground truth values; cast to the dtype of `y_pred`.
    y_pred: The predicted values; converted to a tensor.

  Returns:
    The mean over `axis=-1` of the squared difference between `y_pred` and
    `y_true`.
  """
  predictions = ops.convert_to_tensor(y_pred)
  targets = math_ops.cast(y_true, predictions.dtype)
  squared_errors = math_ops.squared_difference(predictions, targets)
  return K.mean(squared_errors, axis=-1)
@keras_export('keras.metrics.mean_absolute_error',
              'keras.metrics.mae',
              'keras.metrics.MAE',
              'keras.losses.mean_absolute_error',
              'keras.losses.mae',
              'keras.losses.MAE')
def mean_absolute_error(y_true, y_pred):
  """Computes the mean absolute error over the last axis.

  Args:
    y_true: The ground truth values; cast to the dtype of `y_pred`.
    y_pred: The predicted values; converted to a tensor.

  Returns:
    The mean over `axis=-1` of `abs(y_pred - y_true)`.
  """
  predictions = ops.convert_to_tensor(y_pred)
  targets = math_ops.cast(y_true, predictions.dtype)
  absolute_errors = math_ops.abs(predictions - targets)
  return K.mean(absolute_errors, axis=-1)
@keras_export('keras.metrics.mean_absolute_percentage_error',
              'keras.metrics.mape',
              'keras.metrics.MAPE',
              'keras.losses.mean_absolute_percentage_error',
              'keras.losses.mape',
              'keras.losses.MAPE')
def mean_absolute_percentage_error(y_true, y_pred):
  """Computes the mean absolute percentage error over the last axis.

  Args:
    y_true: The ground truth values; cast to the dtype of `y_pred`.
    y_pred: The predicted values; converted to a tensor.

  Returns:
    `100 * mean(abs((y_true - y_pred) / max(abs(y_true), epsilon)))`, with the
    mean taken over `axis=-1`.
  """
  predictions = ops.convert_to_tensor(y_pred)
  targets = math_ops.cast(y_true, predictions.dtype)
  error = targets - predictions
  # Clamp the denominator at epsilon so zero targets do not divide by zero.
  denominator = K.maximum(math_ops.abs(targets), K.epsilon())
  relative_error = math_ops.abs(error / denominator)
  return 100. * K.mean(relative_error, axis=-1)
@keras_export('keras.metrics.mean_squared_logarithmic_error',
              'keras.metrics.msle',
              'keras.metrics.MSLE',
              'keras.losses.mean_squared_logarithmic_error',
              'keras.losses.msle',
              'keras.losses.MSLE')
def mean_squared_logarithmic_error(y_true, y_pred):
  """Computes the mean squared logarithmic error over the last axis.

  Args:
    y_true: The ground truth values; cast to the dtype of `y_pred`.
    y_pred: The predicted values; converted to a tensor.

  Returns:
    The mean over `axis=-1` of the squared difference between
    `log(max(y_pred, epsilon) + 1)` and `log(max(y_true, epsilon) + 1)`.
  """
  predictions = ops.convert_to_tensor(y_pred)
  targets = math_ops.cast(y_true, predictions.dtype)
  # Both inputs are clamped at epsilon before the `log(x + 1)` transform.
  log_predictions = math_ops.log(K.maximum(predictions, K.epsilon()) + 1.)
  log_targets = math_ops.log(K.maximum(targets, K.epsilon()) + 1.)
  squared_log_errors = math_ops.squared_difference(
      log_predictions, log_targets)
  return K.mean(squared_log_errors, axis=-1)
def _maybe_convert_labels(y_true):
  """Converts binary labels into -1/1.

  Args:
    y_true: Label tensor.

  Returns:
    `2 * y_true - 1` when every element of `y_true` equals 0 or 1; otherwise
    `y_true` unchanged.
  """
  zero_mask = math_ops.equal(y_true, 0)
  one_mask = math_ops.equal(y_true, 1)
  all_binary = math_ops.reduce_all(math_ops.logical_or(zero_mask, one_mask))

  # Maps 0 -> -1 and 1 -> 1 only when all labels are binary.
  return smart_cond.smart_cond(
      all_binary, lambda: 2. * y_true - 1., lambda: y_true)
@keras_export('keras.metrics.squared_hinge', 'keras.losses.squared_hinge')
def squared_hinge(y_true, y_pred):
  """Computes the squared hinge loss between `y_true` and `y_pred`.

  Args:
    y_true: The ground truth values. `y_true` values are expected to be -1 or 1.
      If binary (0 or 1) labels are provided we will convert them to -1 or 1.
    y_pred: The predicted values.

  Returns:
    Tensor with one scalar loss entry per sample.
  """
  predictions = ops.convert_to_tensor(y_pred)
  labels = _maybe_convert_labels(math_ops.cast(y_true, predictions.dtype))
  margin = math_ops.maximum(1. - labels * predictions, 0.)
  return K.mean(math_ops.square(margin), axis=-1)
@keras_export('keras.metrics.hinge', 'keras.losses.hinge')
def hinge(y_true, y_pred):
  """Computes the hinge loss between `y_true` and `y_pred`.

  Args:
    y_true: The ground truth values. `y_true` values are expected to be -1 or 1.
      If binary (0 or 1) labels are provided they will be converted to -1 or 1.
    y_pred: The predicted values.

  Returns:
    Tensor with one scalar loss entry per sample.
  """
  predictions = ops.convert_to_tensor(y_pred)
  labels = _maybe_convert_labels(math_ops.cast(y_true, predictions.dtype))
  margin = math_ops.maximum(1. - labels * predictions, 0.)
  return K.mean(margin, axis=-1)
@keras_export('keras.losses.categorical_hinge')
def categorical_hinge(y_true, y_pred):
  """Computes the categorical hinge loss between `y_true` and `y_pred`.

  `loss = maximum(neg - pos + 1, 0)`
  where `pos = sum(y_true * y_pred)` and `neg = max((1 - y_true) * y_pred)`,
  both taken over the last axis.

  Args:
    y_true: The ground truth values; cast to the dtype of `y_pred`. The values
      are used as given: no conversion between 0/1 and -1/1 labels is applied
      here. The formula presumably expects 0/1 (e.g. one-hot) labels.
    y_pred: The predicted values.

  Returns:
    A tensor with the last axis of the inputs reduced away.
  """
  y_pred = ops.convert_to_tensor(y_pred)
  y_true = math_ops.cast(y_true, y_pred.dtype)
  # Score mass on the entries selected by `y_true`.
  pos = math_ops.reduce_sum(y_true * y_pred, axis=-1)
  # Largest score among the entries weighted by `1 - y_true`.
  neg = math_ops.reduce_max((1. - y_true) * y_pred, axis=-1)
  return math_ops.maximum(0., neg - pos + 1.)
def huber_loss(y_true, y_pred, delta=1.0):
  """Computes Huber loss value.

  For each value x in `error = y_true - y_pred`:

  ```
  loss = 0.5 * x^2                  if |x| <= d
  loss = 0.5 * d^2 + d * (|x| - d)  if |x| > d
  ```
  where d is `delta`. See: https://en.wikipedia.org/wiki/Huber_loss

  Args:
    y_true: tensor of true targets.
    y_pred: tensor of predicted targets.
    delta: A float, the point where the Huber loss function changes from a
      quadratic to linear.

  Returns:
    Tensor of element-wise Huber loss values. No reduction over the last axis
    is applied in this function.
  """
  y_pred = math_ops.cast(y_pred, dtype=K.floatx())
  y_true = math_ops.cast(y_true, dtype=K.floatx())
  # Bring `delta` to the same dtype as the error so that integer or
  # differently-typed tensor/array values do not clash in the ops below.
  delta = math_ops.cast(delta, dtype=K.floatx())
  error = math_ops.subtract(y_pred, y_true)
  abs_error = math_ops.abs(error)
  # `quadratic` is |x| capped at delta; `linear` is the excess beyond delta.
  quadratic = math_ops.minimum(abs_error, delta)
  linear = math_ops.subtract(abs_error, quadratic)
  return math_ops.add(
      math_ops.multiply(
          ops.convert_to_tensor(0.5, dtype=quadratic.dtype),
          math_ops.multiply(quadratic, quadratic)),
      math_ops.multiply(delta, linear))
@keras_export('keras.losses.logcosh')
def logcosh(y_true, y_pred):
  """Logarithm of the hyperbolic cosine of the prediction error.

  `log(cosh(x))` is approximately equal to `(x ** 2) / 2` for small `x` and
  to `abs(x) - log(2)` for large `x`. This means that 'logcosh' works mostly
  like the mean squared error, but will not be so strongly affected by the
  occasional wildly incorrect prediction.

  Arguments:
    y_true: tensor of true targets.
    y_pred: tensor of predicted targets.

  Returns:
    Tensor with one scalar loss entry per sample.
  """
  y_pred = ops.convert_to_tensor(y_pred)
  y_true = math_ops.cast(y_true, y_pred.dtype)

  def _logcosh(x):
    # `log(cosh(x))` rewritten as `x + softplus(-2x) - log(2)`. The constant
    # is cast to the dtype of `x` so that non-default float inputs
    # (e.g. float64) do not hit a dtype mismatch.
    log_two = math_ops.cast(math_ops.log(2.), x.dtype)
    return x + nn.softplus(-2. * x) - log_two

  return K.mean(_logcosh(y_pred - y_true), axis=-1)
@keras_export('keras.metrics.categorical_crossentropy',
              'keras.losses.categorical_crossentropy')
def categorical_crossentropy(y_true,
                             y_pred,
                             from_logits=False,
                             label_smoothing=0):
  """Computes the categorical crossentropy loss.

  Args:
    y_true: tensor of true targets.
    y_pred: tensor of predicted targets.
    from_logits: Whether `y_pred` is expected to be a logits tensor. By default,
      we assume that `y_pred` encodes a probability distribution.
    label_smoothing: Float in [0, 1]. If > `0` then smooth the labels: each
      label becomes `y_true * (1 - label_smoothing) +
      label_smoothing / num_classes`, where `num_classes` is the size of the
      last axis of `y_true`.

  Returns:
    Categorical crossentropy loss value.
  """
  y_pred = ops.convert_to_tensor(y_pred)
  y_true = math_ops.cast(y_true, y_pred.dtype)
  label_smoothing = ops.convert_to_tensor(label_smoothing, dtype=K.floatx())

  def _smooth_labels():
    # Read the class count from the last axis rather than axis 1, so inputs
    # with extra leading dimensions (e.g. `[batch, time, classes]`) are
    # smoothed with the right count. Identical for rank-2 inputs.
    # NOTE(review): assumes the class axis is the last one, matching the
    # default axis of `K.categorical_crossentropy` -- confirm.
    num_classes = math_ops.cast(array_ops.shape(y_true)[-1], y_pred.dtype)
    return y_true * (1.0 - label_smoothing) + (label_smoothing / num_classes)

  # Smoothing is applied only when `label_smoothing` is non-zero.
  y_true = smart_cond.smart_cond(label_smoothing,
                                 _smooth_labels, lambda: y_true)
  return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)
@keras_export('keras.metrics.sparse_categorical_crossentropy',
              'keras.losses.sparse_categorical_crossentropy')
def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1):
  """Computes the sparse categorical crossentropy loss.

  Thin wrapper that forwards all arguments to
  `K.sparse_categorical_crossentropy`.

  Args:
    y_true: Ground truth values.
    y_pred: The predicted values.
    from_logits: Whether `y_pred` is expected to be a logits tensor.
    axis: Forwarded unchanged to the backend function; defaults to `-1`.

  Returns:
    The value returned by `K.sparse_categorical_crossentropy`.
  """
  loss = K.sparse_categorical_crossentropy(
      y_true, y_pred, axis=axis, from_logits=from_logits)
  return loss
@keras_export('keras.metrics.binary_crossentropy',
              'keras.losses.binary_crossentropy')
def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0):
  """Computes the binary crossentropy loss, averaged over the last axis.

  Args:
    y_true: Ground truth values; cast to the dtype of `y_pred`.
    y_pred: The predicted values; converted to a tensor.
    from_logits: Whether `y_pred` is expected to be a logits tensor.
    label_smoothing: When non-zero, labels are replaced by
      `y_true * (1 - label_smoothing) + 0.5 * label_smoothing`, i.e. squeezed
      towards 0.5.

  Returns:
    The mean over `axis=-1` of `K.binary_crossentropy`.
  """
  predictions = ops.convert_to_tensor(y_pred)
  labels = math_ops.cast(y_true, predictions.dtype)
  smoothing = ops.convert_to_tensor(label_smoothing, dtype=K.floatx())

  def _squeeze_towards_half():
    return labels * (1.0 - smoothing) + 0.5 * smoothing

  # Smoothing is applied only when `label_smoothing` is non-zero.
  smoothed_labels = smart_cond.smart_cond(
      smoothing, _squeeze_towards_half, lambda: labels)
  elementwise = K.binary_crossentropy(
      smoothed_labels, predictions, from_logits=from_logits)
  return K.mean(elementwise, axis=-1)
@keras_export('keras.metrics.kullback_leibler_divergence',
'keras.metrics.kld',
'keras.metrics.KLD',
'keras.losses.kullback_leibler_divergence',