Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix magnitude_range in RandAug #249

Merged
merged 4 commits into from
May 12, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions mmcls/datasets/pipelines/auto_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,18 +510,23 @@ class Solarize(object):

Args:
thr (int | float): The threshold above which the pixels value will be
inverted.
inverted when incresing is set to False.
prob (float): The probability for solarizing therefore should be in
range [0, 1]. Defaults to 0.5.
increasing (bool): When setting to True, the meaning of thr is
8 - actual thr.
"""

def __init__(self, thr, prob=0.5):
def __init__(self, thr, prob=0.5, increasing=False):
assert isinstance(thr, (int, float)), 'The thr type must '\
f'be int or float, but got {type(thr)} instead.'
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
f'got {prob} instead.'

self.thr = thr
if increasing:
self.thr = 256 - thr
else:
self.thr = thr
LXXXXR marked this conversation as resolved.
Show resolved Hide resolved
self.prob = prob

def __call__(self, results):
Expand Down Expand Up @@ -588,18 +593,23 @@ class Posterize(object):
"""Posterize images (reduce the number of bits for each color channel).

Args:
bits (int | float): Number of bits for each pixel in the output img,
which should be less or equal to 8.
bits (int | float): Number of bits for each pixel in the output img
when increasing is False, which should be less or equal to 8.
prob (float): The probability for posterizing therefore should be in
range [0, 1]. Defaults to 0.5.
increasing (bool): When setting to True, the meaning of bits is
8 - actual number of bits.
"""

def __init__(self, bits, prob=0.5):
def __init__(self, bits, prob=0.5, increasing=False):
assert bits <= 8, f'The bits must be less than 8, got {bits} instead.'
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
f'got {prob} instead.'

self.bits = int(bits)
if increasing:
self.bits = 8 - int(bits)
else:
self.bits = int(bits)
self.prob = prob

def __call__(self, results):
Expand Down
24 changes: 24 additions & 0 deletions tests/test_pipelines/test_auto_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,18 @@ def test_solarize():
assert (results['img'] == img_solarized).all()
assert (results['img'] == results['img2']).all()

# test case when thr=156
results = construct_toy_data_photometric()
transform = dict(type='Solarize', thr=156, prob=1., increasing=True)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
img_solarized = np.array([[0, 127, 0], [1, 128, 1], [2, 126, 2]],
dtype=np.uint8)
img_solarized = np.stack([img_solarized, img_solarized, img_solarized],
axis=-1)
assert (results['img'] == img_solarized).all()
assert (results['img'] == results['img2']).all()


def test_solarize_add():
# test assertion for invalid type of magnitude
Expand Down Expand Up @@ -822,6 +834,18 @@ def test_posterize():
assert (results['img'] == img_posterized).all()
assert (results['img'] == results['img2']).all()

# test case when bits=5, incresing= True
results = construct_toy_data_photometric()
transform = dict(type='Posterize', bits=5, prob=1., increasing=True)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
img_posterized = np.array([[0, 128, 224], [0, 96, 224], [0, 128, 224]],
dtype=np.uint8)
img_posterized = np.stack([img_posterized, img_posterized, img_posterized],
axis=-1)
assert (results['img'] == img_posterized).all()
assert (results['img'] == results['img2']).all()


def test_contrast(nb_rand_test=100):

Expand Down