Skip to content

Commit

Permalink
[Fix] Fix magnitude_range in RandAug (#249)
Browse files Browse the repository at this point in the history
* add increasing in solarize and posterize

* fix linting

* Revert "add increasing in solarize and posterize"

This reverts commit 128af36.

* revise according to comments
  • Loading branch information
LXXXXR committed May 12, 2021
1 parent f415c49 commit 8c90a87
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
17 changes: 12 additions & 5 deletions mmcls/datasets/pipelines/auto_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@ class RandAugment(object):
augmentation. For those which have magnitude, (given to the fact
they are named differently in different augmentation, )
`magnitude_key` and `magnitude_range` shall be the magnitude
argument (str) and the range of magnitude (tuple in the format or
(minval, maxval)), respectively.
argument (str) and the range of magnitude (tuple in the format of
(val1, val2)), respectively. Note that val1 is not necessarily
less than val2.
num_policies (int): Number of policies to select from policies each
time.
magnitude_level (int | float): Magnitude level for all the augmentation
Expand All @@ -85,6 +86,10 @@ class RandAugment(object):
Note:
`magnitude_std` will introduce some randomness to policy, modified by
https://github.com/rwightman/pytorch-image-models
When magnitude_std=0, we calculate the magnitude as follows:
.. math::
magnitude = magnitude_level / total_level * (val2 - val1) + val1
"""

def __init__(self,
Expand Down Expand Up @@ -130,18 +135,20 @@ def _process_policies(self, policies):
processed_policy = copy.deepcopy(policy)
magnitude_key = processed_policy.pop('magnitude_key', None)
if magnitude_key is not None:
minval, maxval = processed_policy.pop('magnitude_range')
val1, val2 = processed_policy.pop('magnitude_range')
magnitude_value = (self.magnitude_level / self.total_level
) * float(maxval - minval) + minval
) * float(val2 - val1) + val1

# if magnitude_std is positive number or 'inf', move
# magnitude_value randomly.
maxval = max(val1, val2)
minval = min(val1, val2)
if self.magnitude_std == 'inf':
magnitude_value = random.uniform(minval, magnitude_value)
elif self.magnitude_std > 0:
magnitude_value = random.gauss(magnitude_value,
self.magnitude_std)
magnitude_value = min(maxval, max(0, magnitude_value))
magnitude_value = min(maxval, max(minval, magnitude_value))
processed_policy.update({magnitude_key: magnitude_value})
processed_policies.append(processed_policy)
return processed_policies
Expand Down
28 changes: 28 additions & 0 deletions tests/test_pipelines/test_auto_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,34 @@ def test_rand_augment():
# apply rotation with prob=0.
assert (results['img'] == results['ori_img']).all()

# test case where magnitude_range is reversed
random.seed(1)
np.random.seed(0)
results = construct_toy_data()
reversed_policies = [
dict(
type='Translate',
magnitude_key='magnitude',
magnitude_range=(1, 0),
pad_val=128,
prob=1.,
direction='horizontal'),
dict(type='Invert', prob=1.),
dict(
type='Rotate',
magnitude_key='angle',
magnitude_range=(30, 0),
prob=0.)
]
transform = dict(
type='RandAugment',
policies=reversed_policies,
num_policies=1,
magnitude_level=30)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
assert (results['img'] == results['ori_img']).all()

# test case where num_policies = 2
random.seed(0)
np.random.seed(0)
Expand Down

0 comments on commit 8c90a87

Please sign in to comment.