You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
			
				
					77 lines
				
				2.4 KiB
			
		
		
			
		
	
	
					77 lines
				
				2.4 KiB
			| 
											4 days ago
										 | # Copied from https://github.com/mlcommons/training/blob/637c82f9e699cd6caf108f92efb2c1d446b630e0/single_stage_detector/ssd/transforms.py
 | ||
|  | 
 | ||
|  | import torch
 | ||
|  | import torchvision
 | ||
|  | 
 | ||
|  | from torch import nn, Tensor
 | ||
|  | from torchvision.transforms import functional as F
 | ||
|  | from torchvision.transforms import transforms as T
 | ||
|  | from typing import List, Tuple, Dict, Optional
 | ||
|  | 
 | ||
|  | from PIL import Image
 | ||
# Disable Pillow's decompression-bomb guard so arbitrarily large images can be
# opened without DecompressionBombWarning/DecompressionBombError.
# NOTE(review): this mutates a PIL.Image module attribute, so it affects every
# image opened in this process and removes a safety check; it is only
# appropriate when the image inputs are trusted.
Image.MAX_IMAGE_PIXELS = None
 | ||
|  | from typing import Any
 | ||
|  | 
 | ||
# Optional accelerated image backend. Default to None and replace it only if
# the package is importable; callers check `accimage is not None`.
accimage = None
try:
    import accimage
except ImportError:
    # Intentional fallback: accimage is optional, PIL alone is used without it.
    pass
 | ||
|  | 
 | ||
@torch.jit.unused
def _is_pil_image(img: Any) -> bool:
    """Return True if ``img`` is a PIL image.

    When the optional ``accimage`` backend was imported successfully,
    ``accimage.Image`` instances are accepted as well. Marked
    ``@torch.jit.unused`` so TorchScript does not try to compile it.
    """
    accepted_types = (Image.Image,) if accimage is None else (Image.Image, accimage.Image)
    return isinstance(img, accepted_types)
 | ||
|  | 
 | ||
def get_image_size_tensor(img: Tensor) -> List[int]:
    """Return ``[width, height]`` of an image tensor.

    Height and width are read from the last two dimensions of ``img`` after it
    has been validated by torchvision's ``_assert_image_tensor``.

    NOTE(review): ``_assert_image_tensor`` lives in a private torchvision
    module whose path has changed between releases (``functional_tensor`` vs
    ``_functional_tensor``); confirm it against the pinned torchvision version.
    """
    torchvision.transforms._functional_tensor._assert_image_tensor(img)
    height, width = img.shape[-2], img.shape[-1]
    return [width, height]
 | ||
|  | 
 | ||
@torch.jit.unused
def get_image_size_pil(img: Any) -> List[int]:
    """Return the size of a PIL (or accimage) image as a list.

    For PIL images ``img.size`` is ``(width, height)``.

    Raises:
        TypeError: if ``img`` is not accepted by ``_is_pil_image``.
    """
    if not _is_pil_image(img):
        raise TypeError("Unexpected type {}".format(type(img)))
    size = img.size
    return list(size)
 | ||
|  | 
 | ||
def get_image_size(img: Tensor) -> List[int]:
    """Return the spatial size of ``img`` as ``[width, height]``.

    Args:
        img (PIL Image or Tensor): image whose size is queried. Tensors are
            measured by ``get_image_size_tensor``; any other object is handed to
            ``get_image_size_pil``, which raises ``TypeError`` for non-PIL input.

    Returns:
        List[int]: ``[width, height]``.
    """
    if isinstance(img, torch.Tensor):
        size = get_image_size_tensor(img)
    else:
        size = get_image_size_pil(img)
    return size
 | ||
|  | 
 | ||
class Compose:
    """Apply a sequence of ``(image, target)`` transforms in order.

    Each element of ``transforms`` is called as ``t(image, target)`` and must
    return a new ``(image, target)`` pair, which is passed to the next one.
    """

    def __init__(self, transforms):
        # Stored as given (not copied), so later mutation of the sequence by
        # the caller is visible here.
        self.transforms = transforms

    def __call__(self, image, target=None):
        """Run every transform in order and return the final ``(image, target)``.

        ``target`` defaults to ``None`` so image-only pipelines can call
        ``compose(image)``. This mirrors the ``forward(image, target=None)``
        signature of the detection transforms defined in this module.
        """
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

    def __repr__(self):
        # Same layout as torchvision's Compose: one transform per indented line.
        lines = [f"{type(self).__name__}("]
        lines.extend(f"    {t}" for t in self.transforms)
        lines.append(")")
        return "\n".join(lines)
 | ||
|  | 
 | ||
|  | 
 | ||
class RandomHorizontalFlip(T.RandomHorizontalFlip):
    """Horizontally flip an image, and its detection target, with probability ``self.p``.

    When the flip is applied and ``target`` is given:
    * box columns 0 and 2 are mirrored about the image width (presumably
      ``(x_min, y_min, x_max, y_max)`` boxes; TODO confirm with the dataset);
    * ``target["masks"]``, if present, is flipped along its last dimension.

    NOTE: ``target["boxes"]`` is written in place and the ``target`` dict is
    updated, so the caller's objects change as well as the returned ones.
    """

    def forward(self, image: Tensor,
                target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        should_flip = torch.rand(1) < self.p
        if should_flip:
            image = F.hflip(image)
            if target is not None:
                img_width, _ = get_image_size(image)
                boxes = target["boxes"]
                # new x_min = W - old x_max, new x_max = W - old x_min. The
                # right-hand side uses advanced indexing (a copy), so the swap
                # is safe before the in-place write into the same tensor.
                boxes[:, [0, 2]] = img_width - boxes[:, [2, 0]]
                if "masks" in target:
                    target["masks"] = target["masks"].flip(-1)
        return image, target
 | ||
|  | 
 | ||
|  | 
 | ||
class ToTensor(nn.Module):
    """Convert ``image`` with torchvision's ``F.to_tensor``; ``target`` passes through unchanged."""

    def forward(self, image: Tensor,
                target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        return F.to_tensor(image), target
 |