{"id":7215,"date":"2024-04-11T11:01:01","date_gmt":"2024-04-11T03:01:01","guid":{"rendered":""},"modified":"2024-04-11T11:01:01","modified_gmt":"2024-04-11T03:01:01","slug":"\u76ee\u6807\u68c0\u6d4b\u4e4b\u6d4b\u8bd5\u65f6\u95f4\u589e\u5f3a\uff08TTA\uff09","status":"publish","type":"post","link":"https:\/\/mushiming.com\/7215.html","title":{"rendered":"\u76ee\u6807\u68c0\u6d4b\u4e4b\u6d4b\u8bd5\u65f6\u95f4\u589e\u5f3a\uff08TTA\uff09"},"content":{"rendered":"

\n <\/path> \n<\/svg> <\/p>\n

<h2>TTA<\/h2>\n

<p>\u7b80\u5355\u8bf4\u660e\uff1a
augment\uff1a\u5bf9\u56fe\u50cf\u8fdb\u884cTTA\u64cd\u4f5c\uff1b
batch_augment\uff1a\u5bf9\u4e00\u6279\u56fe\u50cf\u8fdb\u884cTTA\u64cd\u4f5c\uff1b
deaugment_boxes\uff1a\u5bf9\u9884\u6d4b\u6846\u8fdb\u884cTTA\u7684\u9006\u64cd\u4f5c\uff0c\u8be5\u64cd\u4f5c\u4e0e\u5bf9\u8f93\u5165\u56fe\u50cf\u6240\u505a\u7684\u53d8\u6362\u76f8\u53cd\u3002<\/p>\n

class<\/span> BaseWheatTTA<\/span>:<\/span>\n    \"\"\" author: @shonenkov \"\"\"<\/span>\n    image_size =<\/span> 512<\/span>\n\n    def<\/span> augment<\/span>(<\/span>self,<\/span> image)<\/span>:<\/span>\n        raise<\/span> NotImplementedError\n    \n    def<\/span> batch_augment<\/span>(<\/span>self,<\/span> images)<\/span>:<\/span>\n        raise<\/span> NotImplementedError\n    \n    def<\/span> deaugment_boxes<\/span>(<\/span>self,<\/span> boxes)<\/span>:<\/span>\n        raise<\/span> NotImplementedError\n\nclass<\/span> TTAHorizontalFlip<\/span>(<\/span>BaseWheatTTA)<\/span>:<\/span>\n    \"\"\" author: @shonenkov \"\"\"<\/span>\n\n    def<\/span> augment<\/span>(<\/span>self,<\/span> image)<\/span>:<\/span>\n        return<\/span> image.<\/span>flip(<\/span>1<\/span>)<\/span>\n    \n    def<\/span> batch_augment<\/span>(<\/span>self,<\/span> images)<\/span>:<\/span>\n        return<\/span> images.<\/span>flip(<\/span>2<\/span>)<\/span>\n    \n    def<\/span> deaugment_boxes<\/span>(<\/span>self,<\/span> boxes)<\/span>:<\/span>\n        boxes[<\/span>:<\/span>,<\/span> [<\/span>1<\/span>,<\/span>3<\/span>]<\/span>]<\/span> =<\/span> self.<\/span>image_size -<\/span> boxes[<\/span>:<\/span>,<\/span> [<\/span>3<\/span>,<\/span>1<\/span>]<\/span>]<\/span>\n        return<\/span> boxes\n\nclass<\/span> TTAVerticalFlip<\/span>(<\/span>BaseWheatTTA)<\/span>:<\/span>\n    \"\"\" author: @shonenkov \"\"\"<\/span>\n    \n    def<\/span> augment<\/span>(<\/span>self,<\/span> image)<\/span>:<\/span>\n        return<\/span> image.<\/span>flip(<\/span>2<\/span>)<\/span>\n    \n    def<\/span> batch_augment<\/span>(<\/span>self,<\/span> images)<\/span>:<\/span>\n        return<\/span> images.<\/span>flip(<\/span>3<\/span>)<\/span>\n    \n    def<\/span> deaugment_boxes<\/span>(<\/span>self,<\/span> boxes)<\/span>:<\/span>\n        boxes[<\/span>:<\/span>,<\/span> [<\/span>0<\/span>,<\/span>2<\/span>]<\/span>]<\/span> =<\/span> self.<\/span>image_size -<\/span> 
boxes[<\/span>:<\/span>,<\/span> [<\/span>2<\/span>,<\/span>0<\/span>]<\/span>]<\/span>\n        return<\/span> boxes\n    \nclass<\/span> TTARotate90<\/span>(<\/span>BaseWheatTTA)<\/span>:<\/span>\n    \"\"\" author: @shonenkov \"\"\"<\/span>\n    \n    def<\/span> augment<\/span>(<\/span>self,<\/span> image)<\/span>:<\/span>\n        return<\/span> torch.<\/span>rot90(<\/span>image,<\/span> 1<\/span>,<\/span> (<\/span>1<\/span>,<\/span> 2<\/span>)<\/span>)<\/span>\n\n    def<\/span> batch_augment<\/span>(<\/span>self,<\/span> images)<\/span>:<\/span>\n        return<\/span> torch.<\/span>rot90(<\/span>images,<\/span> 1<\/span>,<\/span> (<\/span>2<\/span>,<\/span> 3<\/span>)<\/span>)<\/span>\n    \n    def<\/span> deaugment_boxes<\/span>(<\/span>self,<\/span> boxes)<\/span>:<\/span>\n        res_boxes =<\/span> boxes.<\/span>copy(<\/span>)<\/span>\n        res_boxes[<\/span>:<\/span>,<\/span> [<\/span>0<\/span>,<\/span>2<\/span>]<\/span>]<\/span> =<\/span> self.<\/span>image_size -<\/span> boxes[<\/span>:<\/span>,<\/span> [<\/span>1<\/span>,<\/span>3<\/span>]<\/span>]<\/span>\n        res_boxes[<\/span>:<\/span>,<\/span> [<\/span>1<\/span>,<\/span>3<\/span>]<\/span>]<\/span> =<\/span> boxes[<\/span>:<\/span>,<\/span> [<\/span>2<\/span>,<\/span>0<\/span>]<\/span>]<\/span>\n        return<\/span> res_boxes\n\nclass<\/span> TTACompose<\/span>(<\/span>BaseWheatTTA)<\/span>:<\/span>\n    \"\"\" author: @shonenkov \"\"\"<\/span>\n    def<\/span> __init__<\/span>(<\/span>self,<\/span> transforms)<\/span>:<\/span>\n        self.<\/span>transforms =<\/span> transforms\n        \n    def<\/span> augment<\/span>(<\/span>self,<\/span> image)<\/span>:<\/span>\n        for<\/span> transform in<\/span> self.<\/span>transforms:<\/span>\n            image =<\/span> transform.<\/span>augment(<\/span>image)<\/span>\n        return<\/span> image\n    \n    def<\/span> batch_augment<\/span>(<\/span>self,<\/span> images)<\/span>:<\/span>\n        for<\/span> transform in<\/span> 
self.<\/span>transforms:<\/span>\n            images =<\/span> transform.<\/span>batch_augment(<\/span>images)<\/span>\n        return<\/span> images\n    \n    def<\/span> prepare_boxes<\/span>(<\/span>self,<\/span> boxes)<\/span>:<\/span>\n        result_boxes =<\/span> boxes.<\/span>copy(<\/span>)<\/span>\n        result_boxes[<\/span>:<\/span>,<\/span>0<\/span>]<\/span> =<\/span> np.<\/span>min<\/span>(<\/span>boxes[<\/span>:<\/span>,<\/span> [<\/span>0<\/span>,<\/span>2<\/span>]<\/span>]<\/span>,<\/span> axis=<\/span>1<\/span>)<\/span>\n        result_boxes[<\/span>:<\/span>,<\/span>2<\/span>]<\/span> =<\/span> np.<\/span>max<\/span>(<\/span>boxes[<\/span>:<\/span>,<\/span> [<\/span>0<\/span>,<\/span>2<\/span>]<\/span>]<\/span>,<\/span> axis=<\/span>1<\/span>)<\/span>\n        result_boxes[<\/span>:<\/span>,<\/span>1<\/span>]<\/span> =<\/span> np.<\/span>min<\/span>(<\/span>boxes[<\/span>:<\/span>,<\/span> [<\/span>1<\/span>,<\/span>3<\/span>]<\/span>]<\/span>,<\/span> axis=<\/span>1<\/span>)<\/span>\n        result_boxes[<\/span>:<\/span>,<\/span>3<\/span>]<\/span> =<\/span> np.<\/span>max<\/span>(<\/span>boxes[<\/span>:<\/span>,<\/span> [<\/span>1<\/span>,<\/span>3<\/span>]<\/span>]<\/span>,<\/span> axis=<\/span>1<\/span>)<\/span>\n        return<\/span> result_boxes\n    \n    def<\/span> deaugment_boxes<\/span>(<\/span>self,<\/span> boxes)<\/span>:<\/span>\n        for<\/span> transform in<\/span> self.<\/span>transforms[<\/span>:<\/span>:<\/span>-<\/span>1<\/span>]<\/span>:<\/span>\n            boxes =<\/span> transform.<\/span>deaugment_boxes(<\/span>boxes)<\/span>\n        return<\/span> self.<\/span>prepare_boxes(<\/span>boxes)<\/span>\n<\/code><\/pre>\n

<h2>TTA\u7684\u7528\u6cd5<\/h2>\n
# you can try own combinations:<\/span>\ntransform =<\/span> TTACompose(<\/span>[<\/span>\n    TTARotate90(<\/span>)<\/span>,<\/span>\n    TTAVerticalFlip(<\/span>)<\/span>,<\/span>\n]<\/span>)<\/span>\n\nfig,<\/span> ax =<\/span> plt.<\/span>subplots(<\/span>1<\/span>,<\/span> 3<\/span>,<\/span> figsize=<\/span>(<\/span>16<\/span>,<\/span> 6<\/span>)<\/span>)<\/span>\n\nimage,<\/span> image_id =<\/span> dataset[<\/span>5<\/span>]<\/span>\n\nnumpy_image =<\/span> image.<\/span>permute(<\/span>1<\/span>,<\/span>2<\/span>,<\/span>0<\/span>)<\/span>.<\/span>cpu(<\/span>)<\/span>.<\/span>numpy(<\/span>)<\/span>.<\/span>copy(<\/span>)<\/span>\n\nax[<\/span>0<\/span>]<\/span>.<\/span>imshow(<\/span>numpy_image)<\/span>;<\/span>\nax[<\/span>0<\/span>]<\/span>.<\/span>set_title(<\/span>'original'<\/span>)<\/span>\n\ntta_image =<\/span> transform.<\/span>augment(<\/span>image)<\/span>\ntta_image_numpy =<\/span> tta_image.<\/span>permute(<\/span>1<\/span>,<\/span>2<\/span>,<\/span>0<\/span>)<\/span>.<\/span>cpu(<\/span>)<\/span>.<\/span>numpy(<\/span>)<\/span>.<\/span>copy(<\/span>)<\/span>\n\ndet =<\/span> net(<\/span>tta_image.<\/span>unsqueeze(<\/span>0<\/span>)<\/span>.<\/span>float<\/span>(<\/span>)<\/span>.<\/span>cuda(<\/span>)<\/span>,<\/span> torch.<\/span>tensor(<\/span>[<\/span>1<\/span>]<\/span>)<\/span>.<\/span>float<\/span>(<\/span>)<\/span>.<\/span>cuda(<\/span>)<\/span>)<\/span>\nboxes,<\/span> scores =<\/span> process_det(<\/span>0<\/span>,<\/span> det)<\/span>\n\nfor<\/span> box in<\/span> boxes:<\/span>\n    cv2.<\/span>rectangle(<\/span>tta_image_numpy,<\/span> (<\/span>box[<\/span>0<\/span>]<\/span>,<\/span> box[<\/span>1<\/span>]<\/span>)<\/span>,<\/span> (<\/span>box[<\/span>2<\/span>]<\/span>,<\/span>  box[<\/span>3<\/span>]<\/span>)<\/span>,<\/span> (<\/span>0<\/span>,<\/span> 1<\/span>,<\/span> 0<\/span>)<\/span>,<\/span> 
2<\/span>)<\/span>\n\nax[<\/span>1<\/span>]<\/span>.<\/span>imshow(<\/span>tta_image_numpy)<\/span>;<\/span>\nax[<\/span>1<\/span>]<\/span>.<\/span>set_title(<\/span>'tta'<\/span>)<\/span>\n    \nboxes =<\/span> transform.<\/span>deaugment_boxes(<\/span>boxes)<\/span>\n\nfor<\/span> box in<\/span> boxes:<\/span>\n    cv2.<\/span>rectangle(<\/span>numpy_image,<\/span> (<\/span>box[<\/span>0<\/span>]<\/span>,<\/span> box[<\/span>1<\/span>]<\/span>)<\/span>,<\/span> (<\/span>box[<\/span>2<\/span>]<\/span>,<\/span>  box[<\/span>3<\/span>]<\/span>)<\/span>,<\/span> (<\/span>0<\/span>,<\/span> 1<\/span>,<\/span> 0<\/span>)<\/span>,<\/span> 2<\/span>)<\/span>\n    \nax[<\/span>2<\/span>]<\/span>.<\/span>imshow(<\/span>numpy_image)<\/span>;<\/span>\nax[<\/span>2<\/span>]<\/span>.<\/span>set_title(<\/span>'deaugment predictions'<\/span>)<\/span>;<\/span>\n<\/code><\/pre>\n

\"\u76ee\u6807\u68c0\u6d4b\u4e4b\u6d4b\u8bd5\u65f6\u95f4\u589e\u5f3a\uff08TTA\uff09<\/p>\n

<h2>\u53c2\u8003<\/h2>\n