From 9e8fb9fd0b4e6ad840991823f7342ca6227ddb62 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 4 Jul 2021 18:14:04 +0200 Subject: [PATCH] Create `utils/augmentations.py` (#3877) * Create `utils/augmentations.py` * cleanup --- utils/augmentations.py | 244 +++++++++++++++++++++++++++++++++++++++++ utils/datasets.py | 241 +--------------------------------------- 2 files changed, 250 insertions(+), 235 deletions(-) create mode 100644 utils/augmentations.py diff --git a/utils/augmentations.py b/utils/augmentations.py new file mode 100644 index 000000000000..f7b13165daf0 --- /dev/null +++ b/utils/augmentations.py @@ -0,0 +1,244 @@ +# YOLOv5 image augmentation functions + +import random + +import cv2 +import math +import numpy as np + +from utils.general import segment2box, resample_segments +from utils.metrics import bbox_ioa + + +def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5): + # HSV color-space augmentation + if hgain or sgain or vgain: + r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1 # random gains + hue, sat, val = cv2.split(cv2.cvtColor(im, cv2.COLOR_BGR2HSV)) + dtype = im.dtype # uint8 + + x = np.arange(0, 256, dtype=r.dtype) + lut_hue = ((x * r[0]) % 180).astype(dtype) + lut_sat = np.clip(x * r[1], 0, 255).astype(dtype) + lut_val = np.clip(x * r[2], 0, 255).astype(dtype) + + img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))) + cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=im) # no return needed + + +def hist_equalize(im, clahe=True, bgr=False): + # Equalize histogram on BGR image 'img' with img.shape(n,m,3) and range 0-255 + yuv = cv2.cvtColor(im, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV) + if clahe: + c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) + yuv[:, :, 0] = c.apply(yuv[:, :, 0]) + else: + yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0]) # equalize Y channel histogram + return cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR if bgr else cv2.COLOR_YUV2RGB) # convert YUV image to RGB + + +def replicate(im, labels): + # Replicate labels + h, w = im.shape[:2] + boxes = labels[:, 1:].astype(int) + x1, y1, x2, y2 = boxes.T + s = ((x2 - x1) + (y2 - y1)) / 2 # side length (pixels) + for i in s.argsort()[:round(s.size * 0.5)]: # smallest indices + x1b, y1b, x2b, y2b = boxes[i] + bh, bw = y2b - y1b, x2b - x1b + yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw)) # offset x, y + x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh] + im[y1a:y2a, x1a:x2a] = im[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax] + labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0) + + return im, labels + + +def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32): + # Resize and pad image while meeting stride-multiple constraints + shape = im.shape[:2] # current shape [height, width] + if isinstance(new_shape, int): + new_shape = (new_shape, new_shape) + + # Scale ratio (new / old) + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + if not scaleup: # only scale down, do not scale up (for better test mAP) + r = min(r, 1.0) + + # Compute padding + ratio = r, r # width, height ratios + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding + if auto: # minimum rectangle + dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding + elif scaleFill: # stretch + dw, dh = 0.0, 0.0 + new_unpad = (new_shape[1], new_shape[0]) + ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios + + dw /= 2 # divide padding into 2 sides + dh /= 2 + + if shape[::-1] != new_unpad: # resize + im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR) + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border + return im, ratio, (dw, dh) + + +def random_perspective(im, targets=(), segments=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0, + border=(0, 0)): + # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10)) + # targets = [cls, xyxy] + + height = im.shape[0] + border[0] * 2 # shape(h,w,c) + width = im.shape[1] + border[1] * 2 + + # Center + C = np.eye(3) + C[0, 2] = -im.shape[1] / 2 # x translation (pixels) + C[1, 2] = -im.shape[0] / 2 # y translation (pixels) + + # Perspective + P = np.eye(3) + P[2, 0] = random.uniform(-perspective, perspective) # x perspective (about y) + P[2, 1] = random.uniform(-perspective, perspective) # y perspective (about x) + + # Rotation and Scale + R = np.eye(3) + a = random.uniform(-degrees, degrees) + # a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations + s = random.uniform(1 - scale, 1 + scale) + # s = 2 ** random.uniform(-scale, scale) + R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s) + + # Shear + S = np.eye(3) + S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # x shear (deg) + S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # y shear (deg) + + # Translation + T = np.eye(3) + T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width # x translation (pixels) + T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height # y translation (pixels) + + # Combined rotation matrix + M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT + if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed + if perspective: + im = cv2.warpPerspective(im, M, dsize=(width, height), borderValue=(114, 114, 114)) + else: # affine + im = cv2.warpAffine(im, M[:2], dsize=(width, height), borderValue=(114, 114, 114)) + + # Visualize + # import matplotlib.pyplot as plt + # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel() + # ax[0].imshow(img[:, :, ::-1]) # base + # ax[1].imshow(img2[:, :, ::-1]) # warped + + # Transform label coordinates + n = len(targets) + if n: + use_segments = any(x.any() for x in segments) + new = np.zeros((n, 4)) + if use_segments: # warp segments + segments = resample_segments(segments) # upsample + for i, segment in enumerate(segments): + xy = np.ones((len(segment), 3)) + xy[:, :2] = segment + xy = xy @ M.T # transform + xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2] # perspective rescale or affine + + # clip + new[i] = segment2box(xy, width, height) + + else: # warp boxes + xy = np.ones((n * 4, 3)) + xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1 + xy = xy @ M.T # transform + xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8) # perspective rescale or affine + + # create new boxes + x = xy[:, [0, 2, 4, 6]] + y = xy[:, [1, 3, 5, 7]] + new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T + + # clip + new[:, [0, 2]] = new[:, [0, 2]].clip(0, width) + new[:, [1, 3]] = new[:, [1, 3]].clip(0, height) + + # filter candidates + i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10) + targets = targets[i] + targets[:, 1:5] = new[i] + + return im, targets + + +def copy_paste(im, labels, segments, probability=0.5): + # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy) + n = len(segments) + if probability and n: + h, w, c = im.shape # height, width, channels + im_new = np.zeros(im.shape, np.uint8) + for j in random.sample(range(n), k=round(probability * n)): + l, s = labels[j], segments[j] + box = w - l[3], l[2], w - l[1], l[4] + ioa = bbox_ioa(box, labels[:, 1:5]) # intersection over area + if (ioa < 0.30).all(): # allow 30% obscuration of existing labels + labels = np.concatenate((labels, [[l[0], *box]]), 0) + segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1)) + cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED) + + result = cv2.bitwise_and(src1=im, src2=im_new) + result = cv2.flip(result, 1) # augment segments (flip left-right) + i = result > 0 # pixels to replace + # i[:, :] = result.max(2).reshape(h, w, 1) # act over ch + im[i] = result[i] # cv2.imwrite('debug.jpg', img) # debug + + return im, labels, segments + + +def cutout(im, labels): + # Applies image cutout augmentation https://arxiv.org/abs/1708.04552 + h, w = im.shape[:2] + + # create random masks + scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16 # image size fraction + for s in scales: + mask_h = random.randint(1, int(h * s)) + mask_w = random.randint(1, int(w * s)) + + # box + xmin = max(0, random.randint(0, w) - mask_w // 2) + ymin = max(0, random.randint(0, h) - mask_h // 2) + xmax = min(w, xmin + mask_w) + ymax = min(h, ymin + mask_h) + + # apply random color mask + im[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)] + + # return unobscured labels + if len(labels) and s > 0.03: + box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32) + ioa = bbox_ioa(box, labels[:, 1:5]) # intersection over area + labels = labels[ioa < 0.60] # remove >60% obscured labels + + return labels + + +def mixup(im, labels, im2, labels2): + # Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf + r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0 + im = (im * r + im2 * (1 - r)).astype(np.uint8) + labels = np.concatenate((labels, labels2), 0) + return im, labels + + +def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n) + # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio + w1, h1 = box1[2] - box1[0], box1[3] - box1[1] + w2, h2 = box2[2] - box2[0], box2[3] - box2[1] + ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio + return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates diff --git a/utils/datasets.py b/utils/datasets.py index 8560f7cfeb88..5c76a908c559 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -1,4 +1,4 @@ -# Dataset utils and dataloaders +# YOLOv5 dataset utils and dataloaders import glob import hashlib @@ -14,7 +14,6 @@ from threading import Thread import cv2 -import math import numpy as np import torch import torch.nn.functional as F @@ -23,9 +22,9 @@ from torch.utils.data import Dataset from tqdm import tqdm +from utils.augmentations import augment_hsv, copy_paste, letterbox, mixup, random_perspective from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \ - xyn2xy, segment2box, segments2boxes, resample_segments, clean_str -from utils.metrics import bbox_ioa + xyn2xy, segments2boxes, clean_str from utils.torch_utils import torch_distributed_zero_first # Parameters @@ -523,12 +522,10 @@ def __getitem__(self, index): img, labels = load_mosaic(self, index) shapes = None - # MixUp https://arxiv.org/pdf/1710.09412.pdf + # MixUp augmentation if random.random() < hyp['mixup']: - img2, labels2 = load_mosaic(self, random.randint(0, self.n - 1)) - r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0 - img = (img * r + img2 * (1 - r)).astype(np.uint8) - labels = np.concatenate((labels, labels2), 0) + img, labels = mixup(img, labels, *load_mosaic(self, random.randint(0, self.n - 1))) + else: # Load image @@ -639,32 +636,6 @@ def load_image(self, index): return self.imgs[index], self.img_hw0[index], self.img_hw[index] # img, hw_original, hw_resized -def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5): - if hgain or sgain or vgain: - r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1 # random gains - hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV)) - dtype = img.dtype # uint8 - - x = np.arange(0, 256, dtype=r.dtype) - lut_hue = ((x * r[0]) % 180).astype(dtype) - lut_sat = np.clip(x * r[1], 0, 255).astype(dtype) - lut_val = np.clip(x * r[2], 0, 255).astype(dtype) - - img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))) - cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img) # no return needed - - -def hist_equalize(img, clahe=True, bgr=False): - # Equalize histogram on BGR image 'img' with img.shape(n,m,3) and range 0-255 - yuv = cv2.cvtColor(img, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV) - if clahe: - c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) - yuv[:, :, 0] = c.apply(yuv[:, :, 0]) - else: - yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0]) # equalize Y channel histogram - return cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR if bgr else cv2.COLOR_YUV2RGB) # convert YUV image to RGB - - def load_mosaic(self, index): # loads images in a 4-mosaic @@ -796,205 +767,6 @@ def load_mosaic9(self, index): return img9, labels9 -def replicate(img, labels): - # Replicate labels - h, w = img.shape[:2] - boxes = labels[:, 1:].astype(int) - x1, y1, x2, y2 = boxes.T - s = ((x2 - x1) + (y2 - y1)) / 2 # side length (pixels) - for i in s.argsort()[:round(s.size * 0.5)]: # smallest indices - x1b, y1b, x2b, y2b = boxes[i] - bh, bw = y2b - y1b, x2b - x1b - yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw)) # offset x, y - x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh] - img[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax] - labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0) - - return img, labels - - -def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32): - # Resize and pad image while meeting stride-multiple constraints - shape = img.shape[:2] # current shape [height, width] - if isinstance(new_shape, int): - new_shape = (new_shape, new_shape) - - # Scale ratio (new / old) - r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) - if not scaleup: # only scale down, do not scale up (for better test mAP) - r = min(r, 1.0) - - # Compute padding - ratio = r, r # width, height ratios - new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) - dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding - if auto: # minimum rectangle - dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding - elif scaleFill: # stretch - dw, dh = 0.0, 0.0 - new_unpad = (new_shape[1], new_shape[0]) - ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios - - dw /= 2 # divide padding into 2 sides - dh /= 2 - - if shape[::-1] != new_unpad: # resize - img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) - top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) - left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) - img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border - return img, ratio, (dw, dh) - - -def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0, - border=(0, 0)): - # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10)) - # targets = [cls, xyxy] - - height = img.shape[0] + border[0] * 2 # shape(h,w,c) - width = img.shape[1] + border[1] * 2 - - # Center - C = np.eye(3) - C[0, 2] = -img.shape[1] / 2 # x translation (pixels) - C[1, 2] = -img.shape[0] / 2 # y translation (pixels) - - # Perspective - P = np.eye(3) - P[2, 0] = random.uniform(-perspective, perspective) # x perspective (about y) - P[2, 1] = random.uniform(-perspective, perspective) # y perspective (about x) - - # Rotation and Scale - R = np.eye(3) - a = random.uniform(-degrees, degrees) - # a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations - s = random.uniform(1 - scale, 1 + scale) - # s = 2 ** random.uniform(-scale, scale) - R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s) - - # Shear - S = np.eye(3) - S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # x shear (deg) - S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # y shear (deg) - - # Translation - T = np.eye(3) - T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width # x translation (pixels) - T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height # y translation (pixels) - - # Combined rotation matrix - M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT - if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed - if perspective: - img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(114, 114, 114)) - else: # affine - img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114)) - - # Visualize - # import matplotlib.pyplot as plt - # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel() - # ax[0].imshow(img[:, :, ::-1]) # base - # ax[1].imshow(img2[:, :, ::-1]) # warped - - # Transform label coordinates - n = len(targets) - if n: - use_segments = any(x.any() for x in segments) - new = np.zeros((n, 4)) - if use_segments: # warp segments - segments = resample_segments(segments) # upsample - for i, segment in enumerate(segments): - xy = np.ones((len(segment), 3)) - xy[:, :2] = segment - xy = xy @ M.T # transform - xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2] # perspective rescale or affine - - # clip - new[i] = segment2box(xy, width, height) - - else: # warp boxes - xy = np.ones((n * 4, 3)) - xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1 - xy = xy @ M.T # transform - xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8) # perspective rescale or affine - - # create new boxes - x = xy[:, [0, 2, 4, 6]] - y = xy[:, [1, 3, 5, 7]] - new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T - - # clip - new[:, [0, 2]] = new[:, [0, 2]].clip(0, width) - new[:, [1, 3]] = new[:, [1, 3]].clip(0, height) - - # filter candidates - i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10) - targets = targets[i] - targets[:, 1:5] = new[i] - - return img, targets - - -def copy_paste(img, labels, segments, probability=0.5): - # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy) - n = len(segments) - if probability and n: - h, w, c = img.shape # height, width, channels - im_new = np.zeros(img.shape, np.uint8) - for j in random.sample(range(n), k=round(probability * n)): - l, s = labels[j], segments[j] - box = w - l[3], l[2], w - l[1], l[4] - ioa = bbox_ioa(box, labels[:, 1:5]) # intersection over area - if (ioa < 0.30).all(): # allow 30% obscuration of existing labels - labels = np.concatenate((labels, [[l[0], *box]]), 0) - segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1)) - cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED) - - result = cv2.bitwise_and(src1=img, src2=im_new) - result = cv2.flip(result, 1) # augment segments (flip left-right) - i = result > 0 # pixels to replace - # i[:, :] = result.max(2).reshape(h, w, 1) # act over ch - img[i] = result[i] # cv2.imwrite('debug.jpg', img) # debug - - return img, labels, segments - - -def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n) - # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio - w1, h1 = box1[2] - box1[0], box1[3] - box1[1] - w2, h2 = box2[2] - box2[0], box2[3] - box2[1] - ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio - return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates - - -def cutout(image, labels): - # Applies image cutout augmentation https://arxiv.org/abs/1708.04552 - h, w = image.shape[:2] - - # create random masks - scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16 # image size fraction - for s in scales: - mask_h = random.randint(1, int(h * s)) - mask_w = random.randint(1, int(w * s)) - - # box - xmin = max(0, random.randint(0, w) - mask_w // 2) - ymin = max(0, random.randint(0, h) - mask_h // 2) - xmax = min(w, xmin + mask_w) - ymax = min(h, ymin + mask_h) - - # apply random color mask - image[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)] - - # return unobscured labels - if len(labels) and s > 0.03: - box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32) - ioa = bbox_ioa(box, labels[:, 1:5]) # intersection over area - labels = labels[ioa < 0.60] # remove >60% obscured labels - - return labels - - def create_folder(path='./new'): # Create folder if os.path.exists(path): @@ -1012,7 +784,6 @@ def flatten_recursive(path='../datasets/coco128'): def extract_boxes(path='../datasets/coco128'): # from utils.datasets import *; extract_boxes() # Convert detection dataset into classification dataset, with one directory per class - path = Path(path) # images dir shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None # remove existing files = list(path.rglob('*.*'))