
    A Summary of PyTorch Image Transformation Functions

    Author: libo-coder  Date: 2021-07-29 18:51

    Contents
    • 1. Required Python modules
    • 2. PyTorch image transformation functions
      • 2.1 Checking the image data type
      • 2.2 to_tensor(pic)
      • 2.3 to_pil_image(pic, mode=None)
      • 2.4 normalize(tensor, mean, std)
      • 2.5 resize(img, size, interpolation=Image.BILINEAR)
      • 2.6 pad(img, padding, fill=0, padding_mode='constant')
      • 2.7 crop(img, i, j, h, w)
      • 2.8 center_crop(img, output_size)
      • 2.9 resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR)
      • 2.10 hflip(img)
      • 2.11 vflip(img)
      • 2.12 five_crop(img, size)
      • 2.13 ten_crop(img, size, vertical_flip=False)
      • 2.14 adjust_brightness(img, brightness_factor)
      • 2.15 adjust_contrast(img, contrast_factor)
      • 2.16 adjust_saturation(img, saturation_factor)
      • 2.17 adjust_hue(img, hue_factor)
      • 2.18 adjust_gamma(img, gamma, gain=1)
      • 2.19 rotate(img, angle, resample=False, expand=False, center=None)
      • 2.20 affine(img, angle, translate, scale, shear, resample=0, fillcolor=None)
      • 2.21 to_grayscale(img, num_output_channels=1)
    • References

      1. Required Python modules

      The torchvision package of PyTorch provides many image transformation functions. They are defined in:

      torchvision/transforms/functional.py

      from __future__ import division
      import torch
      import sys
      import math
      from PIL import Image, ImageOps, ImageEnhance, PILLOW_VERSION
      try:
       import accimage
      except ImportError:
       accimage = None
      import numpy as np
      import numbers
      import collections
      import warnings
       import matplotlib.pyplot as plt
      
      if sys.version_info < (3, 3):
       Sequence = collections.Sequence
       Iterable = collections.Iterable
      else:
       Sequence = collections.abc.Sequence
       Iterable = collections.abc.Iterable
      

      We will use the following image as an example:

      img_file = "test.jpe"
      img = Image.open(img_file)
      width, height = img.size #(750, 815)
      img.show()
      

      2. PyTorch image transformation functions

      2.1 Checking the image data type

      # Check the image type: PIL Image, tensor, or numpy array
      def _is_pil_image(img):
       if accimage is not None:
        return isinstance(img, (Image.Image, accimage.Image))
       else:
        return isinstance(img, Image.Image)
      
      def _is_tensor_image(img):
       return torch.is_tensor(img) and img.ndimension() == 3
      
      def _is_numpy_image(img):
       return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
      
      # example:
      _is_pil_image(img)
      # True
      
      _is_tensor_image(img)
      # False
      
      _is_numpy_image(img)
      # False
      
      _is_numpy_image(np.array(img))
      # True
      

      2.2 to_tensor(pic)

      Converts a PIL Image or numpy.ndarray to a tensor.

      def to_tensor(pic):
       """
       Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
      
       Returns:
        Tensor: Converted image.
       """
       if not(_is_pil_image(pic) or _is_numpy_image(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
      
       if isinstance(pic, np.ndarray):
        # handle numpy array
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        if isinstance(img, torch.ByteTensor):
         return img.float().div(255)
        else:
         return img
      
       if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)
      
       # handle PIL Image
       if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
       elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
       elif pic.mode == 'F':
        img = torch.from_numpy(np.array(pic, np.float32, copy=False))
       elif pic.mode == '1':
        img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
       else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
       # PIL image mode: L, P, I, F, RGB, YCbCr, RGBA, CMYK
       if pic.mode == 'YCbCr':
        nchannel = 3
       elif pic.mode == 'I;16':
        nchannel = 1
       else:
        nchannel = len(pic.mode)
       img = img.view(pic.size[1], pic.size[0], nchannel)
       # put it from HWC to CHW format
       # yikes, this transpose takes 80% of the loading time/CPU
       img = img.transpose(0, 1).transpose(0, 2).contiguous()
       if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
       else:
        return img
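
      A quick usage sketch for the test image loaded above (just for illustration): the HWC uint8 PIL image becomes a CHW float tensor scaled to [0, 1].

      img_tensor = to_tensor(img)
      img_tensor.shape  # torch.Size([3, 815, 750]) -- (C, H, W), while img.size is (w, h) = (750, 815)
      img_tensor.dtype  # torch.float32, pixel values scaled from [0, 255] to [0.0, 1.0]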
      

      2.3 to_pil_image(pic, mode=None)

      Converts a tensor or an ndarray to a PIL Image.

      def to_pil_image(pic, mode=None):
       """
       Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
      
       .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
      
       Returns:
        PIL Image: Image converted to PIL Image.
       """
       if not(isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
        raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))
      
       elif isinstance(pic, torch.Tensor):
        if pic.ndimension() not in {2, 3}:
         raise ValueError('pic should be 2/3 dimensional. Got {} '\
              'dimensions.'.format(pic.ndimension()))
      
        elif pic.ndimension() == 2:
         # if 2D image, add channel dimension (CHW)
         pic.unsqueeze_(0)
      
       elif isinstance(pic, np.ndarray):
        if pic.ndim not in {2, 3}:
         raise ValueError('pic should be 2/3 dimensional. Got {} '\
              'dimensions.'.format(pic.ndim))
      
        elif pic.ndim == 2:
         # if 2D image, add channel dimension (HWC)
         pic = np.expand_dims(pic, 2)
      
       npimg = pic
       if isinstance(pic, torch.FloatTensor):
        pic = pic.mul(255).byte()
       if isinstance(pic, torch.Tensor):
        npimg = np.transpose(pic.numpy(), (1, 2, 0))
      
       if not isinstance(npimg, np.ndarray):
        raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' +
            'not {}'.format(type(npimg)))
      
       if npimg.shape[2] == 1:
        expected_mode = None
        npimg = npimg[:, :, 0]
        if npimg.dtype == np.uint8:
         expected_mode = 'L'
        elif npimg.dtype == np.int16:
         expected_mode = 'I;16'
        elif npimg.dtype == np.int32:
         expected_mode = 'I'
        elif npimg.dtype == np.float32:
         expected_mode = 'F'
        if mode is not None and mode != expected_mode:
         raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}"
               .format(mode, npimg.dtype, expected_mode))
        mode = expected_mode
      
       elif npimg.shape[2] == 4:
        permitted_4_channel_modes = ['RGBA', 'CMYK']
        if mode is not None and mode not in permitted_4_channel_modes:
         raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes))
      
        if mode is None and npimg.dtype == np.uint8:
         mode = 'RGBA'
       else:
        permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
        if mode is not None and mode not in permitted_3_channel_modes:
         raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes))
        if mode is None and npimg.dtype == np.uint8:
         mode = 'RGB'
      
       if mode is None:
        raise TypeError('Input type {} is not supported'.format(npimg.dtype))
      
       return Image.fromarray(npimg, mode=mode)
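
      A minimal round-trip sketch (assuming the img_tensor produced by to_tensor in section 2.2): converting the tensor back gives a PIL Image with the original size and mode.

      img_pil = to_pil_image(img_tensor)
      img_pil.size  # (750, 815) -- same as the original image
      img_pil.mode  # 'RGB'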
      

      2.4 normalize(tensor, mean, std)

      Normalizes a tensor image with mean and standard deviation: each channel is transformed in place as channel = (channel - mean) / std.

      def normalize(tensor, mean, std):
       """
       Args:
        tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        mean (sequence): Sequence of means for each channel.
         std (sequence): Sequence of standard deviations for each channel.
      
       Returns:
        Tensor: Normalized Tensor image.
       """
       if not _is_tensor_image(tensor):
        raise TypeError('tensor is not a torch image.')
      
       # This is faster than using broadcasting, don't change without benchmarking
       for t, m, s in zip(tensor, mean, std):
        t.sub_(m).div_(s)
       return tensor
      
      # example
      img_tensor = to_tensor(img)  # convert the PIL image to a (C, H, W) float tensor first
      mean = [0.485, 0.456, 0.406]
      std = [0.229, 0.224, 0.225]
      img_normalize = normalize(img_tensor, mean, std)
      
      # vis
      ax1 = plt.subplot(1, 2, 1)
      ax1.imshow(img)
      ax1.axis("off")
      ax1.set_title("orig img")
      ax2 = plt.subplot(1, 2, 2)
      ax2.imshow(to_pil_image(img_normalize))
      ax2.axis("off")
      ax2.set_title("normalize img")
      plt.show()
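
      Note that normalize() works in place, so img_tensor itself is modified by the call above. A small sketch of how to keep the original tensor around (just an illustration):

      img_tensor = to_tensor(img)
      img_normalize = normalize(img_tensor.clone(), mean, std)  # clone() leaves img_tensor untouched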
      

      2.5 resize(img, size, interpolation=Image.BILINEAR)

      Resizes the input PIL Image to the given size.
      The parameter size is the desired output size.
      If size is a sequence (h, w), the image is resized to exactly (h, w).
      If size is an int, the shorter side of the image is resized to this value while keeping the aspect ratio.

      def resize(img, size, interpolation=Image.BILINEAR):
       """
       Args:
        img (PIL Image): Image to be resized.
        size (sequence or int): Desired output size. 
        interpolation (int, optional): Desired interpolation. Default is
         ``PIL.Image.BILINEAR``
       Returns:
        PIL Image: Resized image.
       """
       if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
       if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))
      
       if isinstance(size, int):
        w, h = img.size
        if (w <= h and w == size) or (h <= w and h == size):
         return img
        if w < h:
         ow = size
         oh = int(size * h / w)
         return img.resize((ow, oh), interpolation)
        else:
         oh = size
         ow = int(size * w / h)
         return img.resize((ow, oh), interpolation)
       else:
        return img.resize(size[::-1], interpolation)
      
      # example:
      img_resize_256x256 = resize(img, (256, 256)) # (256, 256)
      img_resize_256 = resize(img, 256) # (256, 278)
      
      # vis
      ax1 = plt.subplot(1, 3, 1)
      ax1.imshow(img)
      ax1.axis("off")
      ax1.set_title("orig img")
      ax2 = plt.subplot(1, 3, 2)
      ax2.imshow(img_resize_256x256)
      ax2.axis("off")
      ax2.set_title("resize_256x256 img")
      ax3 = plt.subplot(1, 3, 3)
      ax3.imshow(img_resize_256)
      ax3.axis("off")
      ax3.set_title("resize_256 img")
      plt.show()
      

      2.6 pad(img, padding, fill=0, padding_mode='constant')

      Pads all sides of the given PIL Image with the specified padding mode and fill value.
      The parameter padding is an int or a tuple.

      padding:

      • If it is a single int, every side is padded by that amount.
      • If it is a tuple of length 2, it gives the padding for left/right and top/bottom respectively.
      • If it is a tuple of length 4, it gives the padding for the left, top, right and bottom sides respectively.

      The parameter fill is the pixel fill value, 0 by default. If it is a tuple of length 3, it is used to fill the R, G, B channels respectively. It is only used when padding_mode='constant'.

      The parameter padding_mode selects the type of padding: constant, edge, reflect or symmetric. The default is constant (fill with a constant value).

      constant - pad with the constant value fill.

      edge - pad with the last value at the edge of the image.

      reflect - pad with the reflection of the image (without repeating the last pixel on the edge).
      For example, padding [1, 2, 3, 4] with 2 elements on each side in reflect mode gives:
      [3, 2, 1, 2, 3, 4, 3, 2]

      symmetric - pad with the reflection of the image (repeating the last pixel on the edge).
      For example, padding [1, 2, 3, 4] with 2 elements on each side in symmetric mode gives:
      [2, 1, 1, 2, 3, 4, 4, 3]
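
      Since the non-constant modes are implemented with np.pad (see the code below), the two examples above can be checked directly (a minimal sketch):

      np.pad([1, 2, 3, 4], 2, mode='reflect')    # array([3, 2, 1, 2, 3, 4, 3, 2])
      np.pad([1, 2, 3, 4], 2, mode='symmetric')  # array([2, 1, 1, 2, 3, 4, 4, 3])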

      def pad(img, padding, fill=0, padding_mode='constant'):
       """
       Args:
        img (PIL Image): Image to be padded.
        padding (int or tuple): Padding on each border. 
        fill: Pixel fill value for constant fill. Default is 0. 
        padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. 
            Default is constant.
       Returns:
        PIL Image: Padded image.
       """
       if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
      
       if not isinstance(padding, (numbers.Number, tuple)):
        raise TypeError('Got inappropriate padding arg')
       if not isinstance(fill, (numbers.Number, str, tuple)):
        raise TypeError('Got inappropriate fill arg')
       if not isinstance(padding_mode, str):
        raise TypeError('Got inappropriate padding_mode arg')
      
       if isinstance(padding, Sequence) and len(padding) not in [2, 4]:
        raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
             "{} element tuple".format(len(padding)))
      
       assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \
        'Padding mode should be either constant, edge, reflect or symmetric'
      
       if padding_mode == 'constant':
        if img.mode == 'P':
         palette = img.getpalette()
         image = ImageOps.expand(img, border=padding, fill=fill)
         image.putpalette(palette)
         return image
      
        return ImageOps.expand(img, border=padding, fill=fill)
       else:
        if isinstance(padding, int):
         pad_left = pad_right = pad_top = pad_bottom = padding
        if isinstance(padding, Sequence) and len(padding) == 2:
         pad_left = pad_right = padding[0]
         pad_top = pad_bottom = padding[1]
        if isinstance(padding, Sequence) and len(padding) == 4:
         pad_left = padding[0]
         pad_top = padding[1]
         pad_right = padding[2]
         pad_bottom = padding[3]
      
        if img.mode == 'P':
         palette = img.getpalette()
         img = np.asarray(img)
         img = np.pad(img, 
             ((pad_top, pad_bottom), (pad_left, pad_right)), 
             padding_mode)
         img = Image.fromarray(img)
         img.putpalette(palette)
         return img
      
        img = np.asarray(img)
        # RGB image
        if len(img.shape) == 3:
         img = np.pad(img, 
             ((pad_top, pad_bottom), 
             (pad_left, pad_right), 
             (0, 0)), 
             padding_mode)
        # Grayscale image
        if len(img.shape) == 2:
         img = np.pad(img, 
             ((pad_top, pad_bottom), (pad_left, pad_right)), 
             padding_mode)
      
        return Image.fromarray(img)
      
      # example:
      img_padding = pad(img, (10, 20, 30, 40), fill=128)  # (750, 815) -> (790, 875)
      
      # vis
      ax1 = plt.subplot(1, 2, 1)
      ax1.imshow(img)
      ax1.axis("off")
      ax1.set_title("orig img")
      ax2 = plt.subplot(1, 2, 2)
      ax2.imshow(img_padding)
      ax2.axis("off")
      ax2.set_title("padding img")
      plt.show()
      

      2.7 crop(img, i, j, h, w)

      Crops the given PIL Image.

      def crop(img, i, j, h, w):
       """
       Args:
        img (PIL Image): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.
      
       Returns:
        PIL Image: Cropped image.
       """
       if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
      
       return img.crop((j, i, j + w, i + h))
      
      # example
      img_crop = crop(img, 100, 100, 500, 500)	# (750, 815) -> (500, 500)
      
      ax1 = plt.subplot(1, 2, 1)
      ax1.imshow(img)
      ax1.axis("off")
      ax1.set_title("orig img")
      ax2 = plt.subplot(1, 2, 2)
      ax2.imshow(img_crop)
      ax2.axis("off")
      ax2.set_title("crop img")
      plt.show()
      

      2.8 center_crop(img, output_size)

      def center_crop(img, output_size):
       if isinstance(output_size, numbers.Number):
        output_size = (int(output_size), int(output_size))
       w, h = img.size
       th, tw = output_size
       i = int(round((h - th) / 2.))
       j = int(round((w - tw) / 2.))
       return crop(img, i, j, th, tw)
      
      #example
      img_centercrop = center_crop(img, (256, 256))	# (750, 815) -> (256, 256)
      
      ax1 = plt.subplot(1, 2, 1)
      ax1.imshow(img)
      ax1.axis("off")
      ax1.set_title("orig img")
      ax2 = plt.subplot(1, 2, 2)
      ax2.imshow(img_centercrop)
      ax2.axis("off")
      ax2.set_title("centercrop img")
      plt.show()
      

      2.9 resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR)

      Crops the given PIL Image and resizes it to the desired size.

      def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
       """
       Args:
        img (PIL Image): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.
        size (sequence or int): Desired output size. Same semantics as ``resize``.
        interpolation (int, optional): Desired interpolation. Default is
         ``PIL.Image.BILINEAR``.
       Returns:
        PIL Image: Cropped image.
       """
       assert _is_pil_image(img), 'img should be PIL Image'
       img = crop(img, i, j, h, w)
       img = resize(img, size, interpolation)
       return img
      
      # example
      img_resizedcrop = resized_crop(img, 100, 100, 500, 500, (256, 256))	# (750, 815) -> (500, 500) -> (256, 256)
      
      ax1 = plt.subplot(1, 2, 1)
      ax1.imshow(img)
      ax1.axis("off")
      ax1.set_title("orig img")
      ax2 = plt.subplot(1, 2, 2)
      ax2.imshow(img_resizedcrop)
      ax2.axis("off")
      ax2.set_title("resizedcrop img")
      plt.show()
      

      2.10 hflip(img)

      Horizontally flips the given PIL Image.

      def hflip(img):
       """
       Args:
        img (PIL Image): Image to be flipped.
      
       Returns:
        PIL Image: Horizontally flipped image.
       """
       if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
      
       return img.transpose(Image.FLIP_LEFT_RIGHT)
      

      2.11 vflip(img)

      Vertically flips the given PIL Image.

      def vflip(img):
       """
       Args:
        img (PIL Image): Image to be flipped.
      
       Returns:
        PIL Image: Vertically flipped image.
       """
       if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
      
       return img.transpose(Image.FLIP_TOP_BOTTOM)
      
      # example:
      img_hflip = hflip(img)
      img_vflip = vflip(img)
      
      ax1 = plt.subplot(1, 3, 1)
      ax1.imshow(img)
      ax1.axis("off")
      ax1.set_title("orig img")
      ax2 = plt.subplot(1, 3, 2)
      ax2.imshow(img_hflip)
      ax2.axis("off")
      ax2.set_title("hflip img")
      ax3 = plt.subplot(1, 3, 3)
      ax3.imshow(img_vflip)
      ax3.axis("off")
      ax3.set_title("vflip img")
      plt.show()
      

      2.12 five_crop(img, size)

      Crops the given PIL Image into the four corners and the central crop, i.e. five sub-images in total.

      def five_crop(img, size):
       """
       Args:
        size (sequence or int): Desired output size of the crop. If size is an
         int instead of sequence like (h, w), a square crop (size, size) is
         made.
      
       Returns:
        tuple: tuple (tl, tr, bl, br, center)
          Corresponding top left, top right, bottom left, 
          bottom right and center crop.
       """
       if isinstance(size, numbers.Number):
        size = (int(size), int(size))
       else:
        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
      
       w, h = img.size
       crop_h, crop_w = size
       if crop_w > w or crop_h > h:
        raise ValueError("Requested crop size {} is bigger than input size {}".format(size,
                            (h, w)))
       tl = img.crop((0, 0, crop_w, crop_h))
       tr = img.crop((w - crop_w, 0, w, crop_h))
       bl = img.crop((0, h - crop_h, crop_w, h))
       br = img.crop((w - crop_w, h - crop_h, w, h))
       center = center_crop(img, (crop_h, crop_w))
       return (tl, tr, bl, br, center)
      
      # example:
      img_tl, img_tr, img_bl, img_br, img_center = five_crop(img, (400, 400))
      
      ax1 = plt.subplot(2, 3, 1)
      ax1.imshow(img)
      ax1.axis("off")
      ax1.set_title("orig img")
      ax2 = plt.subplot(2, 3, 2)
      ax2.imshow(img_tl)
      ax2.axis("off")
      ax2.set_title("tl img")
      ax3 = plt.subplot(2, 3, 3)
      ax3.imshow(img_tr)
      ax3.axis("off")
      ax3.set_title("tr img")
      ax4 = plt.subplot(2, 3, 4)
      ax4.imshow(img_bl)
      ax4.axis("off")
      ax4.set_title("bl img")
      ax5 = plt.subplot(2, 3, 5)
      ax5.imshow(img_br)
      ax5.axis("off")
      ax5.set_title("br img")
      ax6 = plt.subplot(2, 3, 6)
      ax6.imshow(img_center)
      ax6.axis("off")
      ax6.set_title("center img")
      plt.show()
      

      2.13 ten_crop(img, size, vertical_flip=False)

      Crops the given PIL Image into the four corners and the central crop, and additionally the same five crops of the flipped image, giving ten sub-images in total. Horizontal flipping is used by default.

      def ten_crop(img, size, vertical_flip=False):
       """
       Args:
        size (sequence or int): Desired output size of the crop. If size is an
         int instead of sequence like (h, w), a square crop (size, size) is
         made.
        vertical_flip (bool): Use vertical flipping instead of horizontal
      
       Returns:
        tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
        Corresponding top left, top right, bottom left, bottom right and center crop
        and same for the flipped image.
       """
       if isinstance(size, numbers.Number):
        size = (int(size), int(size))
       else:
        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
      
       first_five = five_crop(img, size)
      
       if vertical_flip:
        img = vflip(img)
       else:
        img = hflip(img)
      
       second_five = five_crop(img, size)
       return first_five + second_five
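
      A short usage sketch (assuming the same test image): the result is the five crops of the original image followed by the five crops of its horizontally flipped version.

      # example:
      crops = ten_crop(img, (400, 400))
      len(crops)  # 10: (tl, tr, bl, br, center) of the original image, then of the flipped image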
      

      2.14 adjust_brightness(img, brightness_factor)

      def adjust_brightness(img, brightness_factor):
       """
       Args:
        img (PIL Image): PIL Image to be adjusted.
        brightness_factor (float): How much to adjust the brightness.
         Can be any non negative number. 
         0 gives a black image, 
         1 gives the original image,
         2 increases the brightness by a factor of 2.
      
       Returns:
        PIL Image: Brightness adjusted image.
       """
       if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
      
       enhancer = ImageEnhance.Brightness(img)
       img = enhancer.enhance(brightness_factor)
       return img
      
      # example:
      img_adjust_brightness = adjust_brightness(img, 2.5)
      
      # vis
      ax1 = plt.subplot(1, 2, 1)
      ax1.imshow(img)
      ax1.axis("off")
      ax1.set_title("orig img")
      ax2 = plt.subplot(1, 2, 2)
      ax2.imshow(img_adjust_brightness)
      ax2.axis("off")
      ax2.set_title("adjust_brightness img")
      plt.show()
      

      2.15 adjust_contrast(img, contrast_factor)

      Adjusts the contrast of an image.

      def adjust_contrast(img, contrast_factor):
       """
       Args:
        img (PIL Image): PIL Image to be adjusted.
        contrast_factor (float): How much to adjust the contrast. 
         Can be any non negative number. 
         0 gives a solid gray image, 
         1 gives the original image, 
         2 increases the contrast by a factor of 2.
      
       Returns:
        PIL Image: Contrast adjusted image.
       """
       if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
      
       enhancer = ImageEnhance.Contrast(img)
       img = enhancer.enhance(contrast_factor)
       return img
      
      # example:
      img_adjust_contrast = adjust_contrast(img, 2.5)
      
      # vis
      ax1 = plt.subplot(1, 2, 1)
      ax1.imshow(img)
      ax1.axis("off")
      ax1.set_title("orig img")
      ax2 = plt.subplot(1, 2, 2)
      ax2.imshow(img_adjust_contrast)
      ax2.axis("off")
      ax2.set_title("adjust_contrast img")
      plt.show()
      

      2.16 adjust_saturation(img, saturation_factor)

      Adjusts the color saturation of an image.

      def adjust_saturation(img, saturation_factor):
       """
       Args:
        img (PIL Image): PIL Image to be adjusted.
        saturation_factor (float): How much to adjust the saturation. 
         0 will give a black and white image, 
         1 will give the original image while
         2 will enhance the saturation by a factor of 2.
      
       Returns:
        PIL Image: Saturation adjusted image.
       """
       if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
      
       enhancer = ImageEnhance.Color(img)
       img = enhancer.enhance(saturation_factor)
       return img
      
      # example
      img_adjust_saturation = adjust_saturation(img, 2.5)
      
      # vis
      ax1 = plt.subplot(1, 2, 1)
      ax1.imshow(img)
      ax1.axis("off")
      ax1.set_title("orig img")
      ax2 = plt.subplot(1, 2, 2)
      ax2.imshow(img_adjust_saturation)
      ax2.axis("off")
      ax2.set_title("adjust_saturation img")
      plt.show()
      

      2.17 adjust_hue(img, hue_factor)

      Adjusts the hue of an image.

      The image is converted to HSV space, the intensities in the hue channel (H) are shifted cyclically, and the result is then converted back to the original image mode.

      The parameter hue_factor is the amount by which the H channel is shifted; it must lie in the interval [-0.5, 0.5].
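
      The cyclic shift relies on uint8 arithmetic wrapping around past 255 (a minimal sketch, not part of the original source):

      np_h = np.array([250, 10], dtype=np.uint8)
      with np.errstate(over='ignore'):
          np_h += np.uint8(20)
      np_h  # array([14, 30], dtype=uint8) -- 250 + 20 wraps around, so the hue shift is cyclic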

      def adjust_hue(img, hue_factor):
       """
       Args:
        img (PIL Image): PIL Image to be adjusted.
        hue_factor (float): How much to shift the hue channel. 
         Should be in [-0.5, 0.5]. 
         0.5 and -0.5 give complete reversal of hue channel in
         HSV space in positive and negative direction respectively.
         0 means no shift. 
         Therefore, both -0.5 and 0.5 will give an image
         with complementary colors while 0 gives the original image.
      
       Returns:
        PIL Image: Hue adjusted image.
       """
       if not(-0.5 <= hue_factor <= 0.5):
        raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
      
       if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
      
       input_mode = img.mode
       if input_mode in {'L', '1', 'I', 'F'}:
        return img
      
       h, s, v = img.convert('HSV').split()
      
       np_h = np.array(h, dtype=np.uint8)
       # uint8 addition take cares of rotation across boundaries
       with np.errstate(over='ignore'):
        np_h += np.uint8(hue_factor * 255)
       h = Image.fromarray(np_h, 'L')
      
       img = Image.merge('HSV', (h, s, v)).convert(input_mode)
       return img
      
      # example:
      img_adjust_hue = adjust_hue(img, 0.5)
      
      # vis
      ax1 = plt.subplot(1, 2, 1)
      ax1.imshow(img)
      ax1.axis("off")
      ax1.set_title("orig img")
      ax2 = plt.subplot(1, 2, 2)
      ax2.imshow(img_adjust_hue)
      ax2.axis("off")
      ax2.set_title("adjust_hue img")
      plt.show()
      

      2.18 adjust_gamma(img, gamma, gain=1)

      Performs gamma correction on an image (also known as the Power Law Transform): intensities are mapped as I_out = 255 * gain * (I_in / 255) ** gamma.

      def adjust_gamma(img, gamma, gain=1):
       """
       Args:
        img (PIL Image): PIL Image to be adjusted.
         gamma (float): Non negative real number, the gamma value in the formula above.
         gamma larger than 1 make the shadows darker,
         while gamma smaller than 1 make dark regions lighter.
        gain (float): The constant multiplier.
       """
       if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
      
       if gamma < 0:
        raise ValueError('Gamma should be a non-negative real number')
      
       input_mode = img.mode
       img = img.convert('RGB')
      
       gamma_map = [255 * gain * pow(ele / 255., gamma) for ele in range(256)] * 3
       img = img.point(gamma_map) # use PIL's point-function to accelerate this part
      
       img = img.convert(input_mode)
       return img
      # example:
      img_adjust_gamma = adjust_gamma(img, 0.5)
      
      # vis
      ax1 = plt.subplot(1, 2, 1)
      ax1.imshow(img)
      ax1.axis("off")
      ax1.set_title("orig img")
      ax2 = plt.subplot(1, 2, 2)
      ax2.imshow(img_adjust_gamma)
      ax2.axis("off")
      ax2.set_title("adjust_gamma img")
      plt.show()
      

      2.19 rotate(img, angle, resample=False, expand=False, center=None)

      Rotates the image.

      The parameter resample:
      Possible values are PIL.Image.NEAREST, PIL.Image.BILINEAR and PIL.Image.BICUBIC.
      If resample is omitted, or if the image mode is 1 or P, PIL.Image.NEAREST is used.

      The parameter expand:
      If expand=True, the output image is enlarged so that it can hold the whole rotated image.
      If expand=False or omitted, the output image has the same size as the input image.
      expand assumes rotation around the center and no translation.

      def rotate(img, angle, resample=False, expand=False, center=None):
       """
       Args:
        img (PIL Image): PIL Image to be rotated.
        angle (float or int): Rotation angle in degrees, counter clockwise.
        resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or 
           ``PIL.Image.BICUBIC``, optional):
        expand (bool, optional): Optional expansion flag.
        center (2-tuple, optional): Optional center of rotation.
         Origin is the upper left corner.
         Default is the center of the image.
       """
      
       if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
      
       return img.rotate(angle, resample, expand, center)
      
      # example:
      img_rotate = rotate(img, 60)
      
      # vis
      ax1 = plt.subplot(1, 2, 1)
      ax1.imshow(img)
      ax1.axis("off")
      ax1.set_title("orig img")
      ax2 = plt.subplot(1, 2, 2)
      ax2.imshow(img_rotate)
      ax2.axis("off")
      ax2.set_title("rotate img")
      plt.show()
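
      A brief sketch of the effect of expand (assuming the same test image): with expand=True the canvas grows to contain the whole rotated image, otherwise the corners are cut off.

      img_rotate_keep = rotate(img, 60)                 # output size stays (750, 815)
      img_rotate_expand = rotate(img, 60, expand=True)  # canvas is enlarged to fit the rotated image
      img_rotate_keep.size, img_rotate_expand.size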
      
