简要说明:
augment: 对单张图像执行 TTA 变换;
batch_augment: 对一批图像执行 TTA 变换;
deaugment_boxes: 对预测框执行逆变换(与对图像所做的变换相反),将其从增强后图像的坐标系映射回原始图像的坐标系。
class BaseWheatTTA:
    """Interface for one test-time-augmentation (TTA) step. author: @shonenkov

    Subclasses override:
      * ``augment`` -- transform a single image;
      * ``batch_augment`` -- transform a batch of images;
      * ``deaugment_boxes`` -- map boxes predicted on the augmented image
        back into the original image's coordinates.
    """

    # Side length in pixels used by subclasses when mirroring/rotating box
    # coordinates. The same value is used for both axes, so square images
    # are assumed.
    image_size = 512

    def augment(self, image):
        """Return the augmented ``image``; must be overridden."""
        raise NotImplementedError()

    def batch_augment(self, images):
        """Return the augmented batch ``images``; must be overridden."""
        raise NotImplementedError()

    def deaugment_boxes(self, boxes):
        """Return ``boxes`` mapped back to original coordinates; must be overridden."""
        raise NotImplementedError()
class TTAHorizontalFlip(BaseWheatTTA):
    """Mirror images along the row (height) axis. author: @shonenkov

    NOTE: images are laid out (C, H, W) (the demo displays them via
    ``permute(1, 2, 0)``), so flipping dim 1 reverses the rows. Visually
    this is an up-down mirror; the box inverse therefore mirrors the
    y coordinates (columns 1 and 3 of ``[x1, y1, x2, y2]``).
    """

    def augment(self, image):
        """Flip a single (C, H, W) image along H."""
        return image.flip(1)

    def batch_augment(self, images):
        """Flip a (N, C, H, W) batch along H."""
        return images.flip(2)

    def deaugment_boxes(self, boxes):
        """Undo the flip on an array of ``[x1, y1, x2, y2]`` boxes.

        Returns a new array and leaves ``boxes`` untouched, matching
        ``TTARotate90``. Writing in place would corrupt any caller that
        still holds the prediction array. y1 and y2 are swapped while
        mirroring so that ordered input stays ordered (y1 <= y2).
        Assumes a 2-D numpy array (uses ``.copy()``).
        """
        res_boxes = boxes.copy()
        res_boxes[:, [1, 3]] = self.image_size - boxes[:, [3, 1]]
        return res_boxes
class TTAVerticalFlip(BaseWheatTTA):
    """Mirror images along the column (width) axis. author: @shonenkov

    NOTE: images are laid out (C, H, W) (the demo displays them via
    ``permute(1, 2, 0)``), so flipping dim 2 reverses the columns. Visually
    this is a left-right mirror; the box inverse therefore mirrors the
    x coordinates (columns 0 and 2 of ``[x1, y1, x2, y2]``).
    """

    def augment(self, image):
        """Flip a single (C, H, W) image along W."""
        return image.flip(2)

    def batch_augment(self, images):
        """Flip a (N, C, H, W) batch along W."""
        return images.flip(3)

    def deaugment_boxes(self, boxes):
        """Undo the flip on an array of ``[x1, y1, x2, y2]`` boxes.

        Returns a new array and leaves ``boxes`` untouched, matching
        ``TTARotate90``. Writing in place would corrupt any caller that
        still holds the prediction array. x1 and x2 are swapped while
        mirroring so that ordered input stays ordered (x1 <= x2).
        Assumes a 2-D numpy array (uses ``.copy()``).
        """
        res_boxes = boxes.copy()
        res_boxes[:, [0, 2]] = self.image_size - boxes[:, [2, 0]]
        return res_boxes
class TTARotate90(BaseWheatTTA):
    """Rotate images by 90 degrees in the (H, W) plane. author: @shonenkov"""

    def augment(self, image):
        """Rotate a single (C, H, W) image once, from the H axis towards W."""
        return torch.rot90(image, 1, (1, 2))

    def batch_augment(self, images):
        """Rotate a (N, C, H, W) batch once, from the H axis towards W."""
        return torch.rot90(images, 1, (2, 3))

    def deaugment_boxes(self, boxes):
        """Map ``[x1, y1, x2, y2]`` boxes from the rotated frame back.

        With ``torch.rot90(k=1)`` a point (x, y) of the original image moves
        to (x', y') = (y, S - x), where S is ``image_size``. The inverse is
        therefore x = S - y' and y = x'.

        Columns are chosen so that ordered input (x1 <= x2, y1 <= y2) gives
        ordered output. Previously the corners came out swapped and relied on
        ``TTACompose.prepare_boxes`` to reorder them. Results through
        ``TTACompose`` are unchanged. Returns a new array; ``boxes`` is not
        modified.
        """
        res_boxes = boxes.copy()
        # x range comes from the mirrored y range: S - y2' <= S - y1'.
        res_boxes[:, [0, 2]] = self.image_size - boxes[:, [3, 1]]
        # y range is the rotated x range, order preserved.
        res_boxes[:, [1, 3]] = boxes[:, [0, 2]]
        return res_boxes
class TTACompose(BaseWheatTTA):
    """Chain several TTA transforms into a single one. author: @shonenkov

    Images go through ``transforms`` in list order. Predicted boxes are
    un-transformed in the opposite order and then normalised so that each
    box satisfies x1 <= x2 and y1 <= y2.
    """

    def __init__(self, transforms):
        # Ordered sequence of BaseWheatTTA-like objects.
        self.transforms = transforms

    def augment(self, image):
        """Run ``image`` through every transform's ``augment``, first to last."""
        result = image
        for step in self.transforms:
            result = step.augment(result)
        return result

    def batch_augment(self, images):
        """Run a batch through every transform's ``batch_augment``, first to last."""
        result = images
        for step in self.transforms:
            result = step.batch_augment(result)
        return result

    def prepare_boxes(self, boxes):
        """Return a copy of ``boxes`` with each coordinate pair in ascending order.

        Column 0 gets the smaller and column 2 the larger of columns (0, 2).
        Columns (1, 3) are handled the same way. ``boxes`` is a 2-D numpy
        array and is not modified.
        """
        ordered = boxes.copy()
        for low_col, high_col in ((0, 2), (1, 3)):
            pair = boxes[:, [low_col, high_col]]
            ordered[:, low_col] = np.min(pair, axis=1)
            ordered[:, high_col] = np.max(pair, axis=1)
        return ordered

    def deaugment_boxes(self, boxes):
        """Undo every transform on ``boxes`` (last to first), then normalise them."""
        undone = boxes
        for step in reversed(self.transforms):
            undone = step.deaugment_boxes(undone)
        return self.prepare_boxes(undone)
# Try your own combinations of TTA transforms here:
transform = TTACompose([
    TTARotate90(),
    TTAVerticalFlip(),
])


def _draw_boxes(img, boxes, color=(0, 1, 0), thickness=2):
    """Draw ``[x1, y1, x2, y2]`` boxes onto ``img`` in place and return it.

    Coordinates are cast to plain Python ints. OpenCV needs integer corner
    points, and some OpenCV versions may reject numpy scalar types.
    The default colour (0, 1, 0) assumes float images scaled to [0, 1].
    TODO confirm against the dataset's normalisation.
    """
    for box in boxes:
        x1, y1, x2, y2 = (int(v) for v in box[:4])
        cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness)
    return img


fig, ax = plt.subplots(1, 3, figsize=(16, 6))

image, image_id = dataset[5]
# (C, H, W) tensor -> (H, W, C) numpy copy that OpenCV can draw on.
numpy_image = image.permute(1, 2, 0).cpu().numpy().copy()
ax[0].imshow(numpy_image)
ax[0].set_title('original')

tta_image = transform.augment(image)
tta_image_numpy = tta_image.permute(1, 2, 0).cpu().numpy().copy()
# Inference only: skip building an autograd graph.
with torch.no_grad():
    det = net(tta_image.unsqueeze(0).float().cuda(), torch.tensor([1]).float().cuda())
boxes, scores = process_det(0, det)
_draw_boxes(tta_image_numpy, boxes)
ax[1].imshow(tta_image_numpy)
ax[1].set_title('tta')

# Map predictions from the augmented frame back onto the original image.
boxes = transform.deaugment_boxes(boxes)
_draw_boxes(numpy_image, boxes)
ax[2].imshow(numpy_image)
ax[2].set_title('deaugment predictions');