当前位置 主页 > 网站技术 > 代码类 >

    关于Pytorch的MNIST数据集的预处理详解

    栏目:代码类 时间:2020-01-10 12:09

    关于Pytorch的MNIST数据集的预处理详解

    MNIST的准确率达到99.7%

    用于MNIST的卷积神经网络(CNN)的实现,具有各种技术,例如数据增强,丢失,伪随机化等。

    操作系统:ubuntu18.04

    显卡:GTX1080ti

    python版本:2.7(3.7)

    网络架构

    具有4层的CNN具有以下架构。

    输入层:784个节点(MNIST图像大小)

    第一卷积层:5x5x32

    第一个最大池层

    第二卷积层:5x5x64

    第二个最大池层

    第三个完全连接层:1024个节点

    输出层:10个节点(MNIST的类数)

    用于改善CNN性能的工具

    采用以下技术来改善CNN的性能。

    1. Data augmentation

    通过以下方式将列车数据的数量增加到5倍

    随机旋转:每个图像在[-15°,+ 15°]范围内随机旋转。

    随机移位:每个图像在两个轴上随机移动一个范围为[-2pix,+ 2pix]的值。

    零中心归一化:将像素值减去(PIXEL_DEPTH / 2)并除以PIXEL_DEPTH。

    2. Parameter initializers

    重量初始化器:xaiver初始化器

    偏差初始值设定项:常量(零)初始值设定项

    3. Batch normalization

    所有卷积/完全连接的层都使用批量标准化。

    4. Dropout

    The third fully-connected layer employes dropout technique.

    5. Exponentially decayed learning rate

    A learning rate is decayed every after one-epoch.

    代码部分

    第一步:了解MNIST数据集

    MNIST数据集是一个手写体数据集,一共60000张图片,所有的图片都是28×28的,下载数据集的地址:数据集官网。这个数据集由四部分组成,分别是:

    train-images-idx3-ubyte.gz: training set images (9912422 bytes) 
    train-labels-idx1-ubyte.gz: training set labels (28881 bytes) 
    t10k-images-idx3-ubyte.gz: test set images (1648877 bytes) 
    t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
    

    也就是一个训练图片集,一个训练标签集,一个测试图片集,一个测试标签集;我们可以看出这个其实并不是普通的文本文件

    或是图片文件,而是一个压缩文件,下载并解压出来,我们看到的是二进制文件。

    第二步:加载MNIST数据集

    先引入一些库文件

    import torchvision,torch
    import torchvision.transforms as transforms
    from torch.utils.data import Dataset, DataLoader
    import matplotlib.pyplot as plt
    

    加载MNIST数据集有很多方法:

    方法一:在pytorch下可以直接调用torchvision.datasets里面的MNIST数据集(这是官方写好的数据集类)

    train = torchvision.datasets.MNIST(root='./mnist/',train=True, transform= transforms.ToTensor())
    

    返回值为一个元组(train_data,train_target)(这个类使用的时候也有坑,必须用train[i]索引才能使用 transform功能)

    一般是与torch.utils.data.DataLoader配合使用

    dataloader = DataLoader(train, batch_size=50,shuffle=True, num_workers=4)
    for step, (x, y) in enumerate(dataloader):
     b_x = x.shape
     b_y = y.shape
     print 'Step: ', step, '| train_data的维度' ,b_x,'| train_target的维度',b_y

    如图将60000张图片的数据分为1200份,每份包含50张图像,这样并行处理数据能有效加快计算速度