Pytorch源码阅读(二):transforms模块

注:笔者阅读的pytorch版本为1.7.0,torchvision版本为0.6

前言

在这篇博客文章中,我主要来写关于pytorchtransforms模块,该模块提供了对图像各种预处理方法,位于torchvision/transforms/transforms.py,这些方法会应用在模型训练推理前,对图像进行预处理,再将处理后的图像送进深度网络中训练与推理。我想写这篇文章对这些预处理方法,进行学习理解,尽可能从源码角度,形象直观地展示这些图像预处理是如何对图像进行转换的。

transforms : 读[trænsˈfɔːm],变换、转换、移动。 eg. Fourier transform 傅里叶变换; It was an event that would transform my life. 那是能够彻底改变我一生的一件事。

transforms 模块相关源码分析

transforms.Compose类

先来看下transforms.Compose在实际代码中是如何运用的,下面代码是使用pytorch中的Dataset方式定义ImageNet数据集,也就说ImageNet继承自data.Dataset

1
2
3
4
5
6
7
train_dataset = torchvision.datasets.ImageNet(train_path,
        transform=transforms.Compose([
            transforms.Resize((32, 32)),  # 将图片缩放到指定大小(h,w)或者保持长宽比并缩放最短的边到int大小

            transforms.CenterCrop(32),
            transforms.ToTensor()])
        )

从上面代码可以看出来transforms模块定义的对象,作为参数传入给ImageNet,在《pytorch源码(一)》中,了解到,通过for循环可以遍历Dataset对象获取图像数据,这篇文章介绍的transforms模块定义的类,一般在遍历Dataset获取图像前对图像进行预处理,那么通过for循环得到的图像就是进行处理后的图像。

下面来分析transforms.Compose源码。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string

Compose是一个容器,它是对多个transforms模块定义转换对象transform组合,本质上是对列表的包装(装饰模式?)。

这里我将这些转换对象类定义为transform,包括transforms.CenterCroptransforms.ToTensor等等。

该类的构造器(constructor)函数入参为列表类型,列表里是多个transform转换对象;在__call__中,通过for循环遍历transforms列表,对图像依次进行调用t(img),可以看到每个transform都是一个可调用对象,通过依次调用这些transform,对图像进行预处理,这些处理是按照列表顺序处理。注:list是一种有序的集合。

transforms.Resize类

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class Resize(object):
    """Resize the input PIL Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be scaled.

        Returns:
            PIL Image: Rescaled image.
        """
        return F.resize(img, self.size, self.interpolation)

    def __repr__(self):
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)

先分析Resize构造器方法,参数size,为转换后的图像像素大小,如果size参数是这样的序列(h,w),输出大小将与此匹配。如果size是int,图像的较小边等于此数字。如果height > width,则图像将被重新缩放到$\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)$。另一个参数为interpolation,翻译为插值,默认值为双线性插值,这里我猜测下,当需要将图像分辨率变大,比如300*400提高至600*500分辨率,因为分辨率乘积代表像素点的个数,分辨率增大,那么原有的图像像素点不够,需要通过插值的方式新生成一部分像素点,这个参数就是控制如何产生这部分的像素点,在后面分析源码我会详细地说明插值方法。

分析__call__方法,它的内部的实现直接调用了torchvision/transforms/functional.py模块中的方法,functional.py模块包含了很多对图像转换的具体方法实现,比如resize(img, size, interpolation=Image.BILINEAR)方法的具体实现等等。

下面我使用Resize通过代码将475 * 300大小图片转换为237 * 150大小的图片。

1
2
3
4
5
6
7
8
from torchvision.transforms import transforms
from PIL import Image

im = Image.open("resources/dog-4671215_1280.jpg")
resize = transforms.Resize((150, 237))
im = resize(im)
im.show()
im.save('resources/dog-150_237.jpg')

475 * 300 缩小为 237 * 150

下面我来具体分析functional.py模块中的resize方法,源码如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def resize(img, size, interpolation=Image.BILINEAR):
    r"""Resize the input PIL Image to the given size.

    Args:
        img (PIL Image): Image to be resized.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), the output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number maintaing
            the aspect ratio. i.e, if height > width, then image will be rescaled to
            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``

    Returns:
        PIL Image: Resized image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))

    if isinstance(size, int):
        w, h = img.size
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
            return img.resize((ow, oh), interpolation)
        else:
            oh = size
            ow = int(size * w / h)
            return img.resize((ow, oh), interpolation)
    else:
        return img.resize(size[::-1], interpolation)

resize函数是定义在functional.py模块中的一个函数,前两个if语句,用来做参数类型校验,主要用到python中内置的isinstance(x, A_tuple)函数,返回对象是类的实例还是子类的实例,值得注意的是,第二个参数,既可以是一个类对象,也可以是一个包含多个类对象的元组。关于Iterable的理解可以查看这篇博客理解Python的Iterable和Iterator

其他部分的代码也很容易理解,isinstance(size, int)True时,size为图像较小的边的输出大小,输出图像另一边的大小根据原图的比例计算得出。

最后的else分支也很容易理解,在这里可以学习到的一点是,如何取得一个列表倒序的结果?这里用了切片的方式size[::-1],来将入参是传入的(h, w)逆序变成(w, h),再传入PIL.Image对象的resize()方法,也就是说pytorch本身没有自己实现对图像的resize方法,而是在底层调用的PIL图像库的方法。

下面先插入一些对python中切片的学习,不知道为什么我每次遇到切片表达方式都要重新查一遍各种切片具体是代表的什么含义(😖)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
Last login: Wed Mar 10 14:19:12 on ttys001
>>> L = list(range(10))
>>> L
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> L[::-1]  # 倒序排列并且每1个取一个
[9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
>>> L[::-2]  # 倒序排列并且每2个取一个 
[9, 7, 5, 3, 1]
>>> L[::2]   # 正序排列并且每2个取一个
[0, 2, 4, 6, 8]
>>> L[:]     # 原样复制
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> L[1:5]   # 取第一个到第五个
[1, 2, 3, 4]

参考:python切片

至于PIL.Image对象的resize()方法具体分析,内部实现是比较复杂的,先暂时省略,后面再补(🤦)‍️。

transforms.CenterCrop类

CenterCrop类的功能是依据给定的size从中心裁剪,在pytorch中的实现非常简单,这里就不贴出来源码了,简单看下它的构造器方法,参数size为(h, w),代表裁剪后的图像分辨率大小,而如果size为int型,那么裁剪为(size, size)大小的正方形中心图像。例如,下面在475 * 300大小图像的中心部分裁剪出120 * 120大小的图像: 475 * 300 裁剪出 120 * 120

transforms.ToTensor

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8

    In the other cases, tensors are returned without scaling.
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.to_tensor(pic)

    def __repr__(self):
        return self.__class__.__name__ + '()'

ToTensor从类名字就可以看出来这个类主要将PIL.Image对象和numpy.ndarray对象转换为torch.FloatTensor对象,值得注意的是,最终的转换得到的Tensor对象的size(C x H x W),通道数在前面,也就是说pytorch在处理图像时,一般情况下,将三维的图像排列是(通道数,高度,宽度)。另外,该类还会将原本的0~255的RGB三原色强度值进行归一化处理,统一除以255,使得值在0.0~1.0之间。

comments powered by Disqus
Built with Hugo
主题 StackJimmy 设计