# Copied from https://github.com/mlcommons/training/blob/637c82f9e699cd6caf108f92efb2c1d446b630e0/single_stage_detector/ssd/transforms.py

import torch
import torchvision

from torch import nn, Tensor
from torchvision.transforms import functional as F
from torchvision.transforms import transforms as T
from typing import List, Tuple, Dict, Optional

from PIL import Image
Image.MAX_IMAGE_PIXELS = None
from typing import Any

try:
    import accimage
except ImportError:
    accimage = None
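
# Note added in this copy (not in the upstream file): accimage is an optional,
# faster image backend for torchvision; when it is unavailable, the helper below
# simply treats every image as a plain PIL image.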
@torch.jit.unused
def _is_pil_image(img: Any) -> bool:
    if accimage is not None:
        return isinstance(img, (Image.Image, accimage.Image))
    else:
        return isinstance(img, Image.Image)

def get_image_size_tensor(img: Tensor) -> List[int]:
    # Returns (w, h) of tensor image
    torchvision.transforms._functional_tensor._assert_image_tensor(img)
    return [img.shape[-1], img.shape[-2]]

@torch.jit.unused
def get_image_size_pil(img: Any) -> List[int]:
    if _is_pil_image(img):
        return list(img.size)
    raise TypeError("Unexpected type {}".format(type(img)))

def get_image_size(img: Tensor) -> List[int]:
    """Returns the size of an image as [width, height].
    Args:
        img (PIL Image or Tensor): The image to be checked.
    Returns:
        List[int]: The image size.
    """
    if isinstance(img, torch.Tensor):
        return get_image_size_tensor(img)

    return get_image_size_pil(img)
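
# Note added in this copy (not in the upstream file): both branches return
# [width, height], so a CHW tensor of shape (3, 480, 640) yields [640, 480],
# matching the (width, height) ordering of PIL's Image.size.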


class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target
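

# Usage sketch (illustrative, not part of the upstream file): every transform in
# this module takes and returns an (image, target) pair, so a typical detection
# pipeline is built as
#     transforms = Compose([ToTensor(), RandomHorizontalFlip(p=0.5)])
#     image, target = transforms(pil_image, target)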


class RandomHorizontalFlip(T.RandomHorizontalFlip):
    def forward(self, image: Tensor,
                target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        if torch.rand(1) < self.p:
            image = F.hflip(image)
            if target is not None:
                width, _ = get_image_size(image)
                target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
                if "masks" in target:
                    target["masks"] = target["masks"].flip(-1)
        return image, target
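

# Worked example for the box flip above (added for clarity, not in the upstream
# file): with image width 100, a box [x1, y1, x2, y2] = [10, 20, 50, 60] maps to
# [100 - 50, 20, 100 - 10, 60] = [50, 20, 90, 60]; indexing columns [2, 0] on the
# right-hand side keeps x1 <= x2 after the reflection, and masks only need a flip
# along their last (width) dimension.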


class ToTensor(nn.Module):
    def forward(self, image: Tensor,
                target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        image = F.to_tensor(image)
        return image, target
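

# Minimal smoke test (an illustrative sketch added in this copy, not part of the
# upstream MLCommons file); it assumes the usual detection-style target dict with
# a "boxes" tensor in (x1, y1, x2, y2) format.
if __name__ == "__main__":
    pil_image = Image.new("RGB", (64, 48))
    target = {"boxes": torch.tensor([[10.0, 5.0, 30.0, 40.0]]),
              "labels": torch.tensor([1])}
    pipeline = Compose([ToTensor(), RandomHorizontalFlip(p=1.0)])
    out_image, out_target = pipeline(pil_image, target)
    assert out_image.shape == (3, 48, 64)
    # With p=1.0 the flip always fires: x-coordinates map to [64 - 30, 64 - 10].
    assert out_target["boxes"].tolist() == [[34.0, 5.0, 54.0, 40.0]]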