详解 MNIST 数据集

MNIST数据集已经是一个被”嚼烂”了的数据集,很多教程都会对它”下手”,几乎成为一个“典范”.不过有些人可能对它还不是很了解,下面来介绍一下.MNIST数据集可在http://yann.lecun.com/exdb/mnist/获取,它包含了四个部分:Trainingsetimages:train-images-idx3-ubyte.gz(9.9MB,解压后47

大家好,又见面了,我是你们的朋友全栈君。

MNIST 数据集已经是一个被”嚼烂”了的数据集, 很多教程都会对它”下手”, 几乎成为一个 “典范”. 不过有些人可能对它还不是很了解, 下面来介绍一下.

MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取, 它包含了四个部分:

  • Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
  • Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)
  • Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)
  • Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)

MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据.

不妨新建一个文件夹 – mnist, 将数据集下载到 mnist 以后, 解压即可:

dataset

图片是以字节的形式进行存储, 我们需要把它们读取到 NumPy array 中, 以便训练和测试算法.

import os
import struct
import numpy as np

def load_mnist(path, kind='train'):
    """Load MNIST data from `path`"""
    labels_path = os.path.join(path,
                               '%s-labels-idx1-ubyte'
                               % kind)
    images_path = os.path.join(path,
                               '%s-images-idx3-ubyte'
                               % kind)
    with open(labels_path, 'rb') as lbpath:
        magic, n = struct.unpack('>II',
                                 lbpath.read(8))
        labels = np.fromfile(lbpath,
                             dtype=np.uint8)

    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack('>IIII',
                                               imgpath.read(16))
        images = np.fromfile(imgpath,
                             dtype=np.uint8).reshape(len(labels), 784)

    return images, labels

load_mnist 函数返回两个数组, 第一个是一个 n x m 维的 NumPy array(images), 这里的 n 是样本数(行数), m 是特征数(列数). 训练数据集包含 60,000 个样本, 测试数据集包含 10,000 样本. 在 MNIST 数据集中的每张图片由 28 x 28 个像素点构成, 每个像素点用一个灰度值表示. 在这里, 我们将 28 x 28 的像素展开为一个一维的行向量, 这些行向量就是图片数组里的行(每行 784 个值, 或者说每行就是代表了一张图片). load_mnist 函数返回的第二个数组(labels) 包含了相应的目标变量, 也就是手写数字的类标签(整数 0-9).

第一次见的话, 可能会觉得我们读取图片的方式有点奇怪:

magic, n = struct.unpack('>II', lbpath.read(8))
labels = np.fromfile(lbpath, dtype=np.uint8)

为了理解这两行代码, 我们先来看一下 MNIST 网站上对数据集的介绍:

TRAINING SET LABEL FILE (train-labels-idx1-ubyte):

[offset] [type]          [value]          [description] 
0000     32 bit integer  0x00000801(2049) magic number (MSB first) 
0004     32 bit integer  60000            number of items 
0008     unsigned byte   ??               label 
0009     unsigned byte   ??               label 
........ 
xxxx     unsigned byte   ??               label
The labels values are 0 to 9.

通过使用上面两行代码, 我们首先读入 magic number, 它是一个文件协议的描述, 也是在我们调用 fromfile 方法将字节读入 NumPy array 之前在文件缓冲中的 item 数(n). 作为参数值传入 struct.unpack>II 有两个部分:

  • >: 这是指大端(用来定义字节是如何存储的); 如果你还不知道什么是大端和小端, Endianness 是一个非常好的解释. (关于大小端, 更多内容可见<<深入理解计算机系统 – 2.1 节信息存储>>)
  • I: 这是指一个无符号整数.

通过执行下面的代码, 我们将会从刚刚解压 MNIST 数据集后的 mnist 目录下加载 60,000 个训练样本和 10,000 个测试样本.

为了了解 MNIST 中的图片看起来到底是个啥, 让我们来对它们进行可视化处理. 从 feature matrix 中将 784-像素值 的向量 reshape 为之前的 28*28 的形状, 然后通过 matplotlib 的 imshow 函数进行绘制:

import matplotlib.pyplot as plt

fig, ax = plt.subplots(
    nrows=2,
    ncols=5,
    sharex=True,
    sharey=True, )

ax = ax.flatten()
for i in range(10):
    img = X_train[y_train == i][0].reshape(28, 28)
    ax[i].imshow(img, cmap='Greys', interpolation='nearest')

ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()

我们现在应该可以看到一个 2*5 的图片, 里面分别是 0-9 单个数字的图片.

0-9

此外, 我们还可以绘制某一数字的多个样本图片, 来看一下这些手写样本到底有多不同:

fig, ax = plt.subplots(
    nrows=5,
    ncols=5,
    sharex=True,
    sharey=True, )

ax = ax.flatten()
for i in range(25):
    img = X_train[y_train == 7][i].reshape(28, 28)
    ax[i].imshow(img, cmap='Greys', interpolation='nearest')

ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()

执行上面的代码后, 我们应该看到数字 7 的 25 个不同形态:

7

另外, 我们也可以选择将 MNIST 图片数据和标签保存为 CSV 文件, 这样就可以在不支持特殊的字节格式的程序中打开数据集. 但是, 有一点要说明, CSV 的文件格式将会占用更多的磁盘空间, 如下所示:

  • train_img.csv: 109.5 MB
  • train_labels.csv: 120 KB
  • test_img.csv: 18.3 MB
  • test_labels: 20 KB

如果我们打算保存这些 CSV 文件, 在将 MNIST 数据集加载入 NumPy array 以后, 我们应该执行下列代码:

np.savetxt('train_img.csv', X_train,
           fmt='%i', delimiter=',')
np.savetxt('train_labels.csv', y_train,
           fmt='%i', delimiter=',')
np.savetxt('test_img.csv', X_test,
           fmt='%i', delimiter=',')
np.savetxt('test_labels.csv', y_test,
           fmt='%i', delimiter=',')

一旦将数据集保存为 CSV 文件, 我们也可以用 NumPy 的 genfromtxt 函数重新将它们加载入程序中:

X_train = np.genfromtxt('train_img.csv',
                        dtype=int, delimiter=',')
y_train = np.genfromtxt('train_labels.csv',
                        dtype=int, delimiter=',')
X_test = np.genfromtxt('test_img.csv',
                       dtype=int, delimiter=',')
y_test = np.genfromtxt('test_labels.csv',
                       dtype=int, delimiter=',')

不过, 从 CSV 文件中加载 MNIST 数据将会显著发给更长的时间, 因此如果可能的话, 还是建议你维持数据集原有的字节格式.

参考:
– Book , Python Machine Learning.

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/125840.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • 虚拟机安装gcc「建议收藏」

    虚拟机安装gcc「建议收藏」(确保虚拟机网络通畅)1在虚拟机桌面上右击【打开终端】2输入sudoapt-getupdate回车输入开机密码(输入时密码并不显示,输完后回车)等待出现下一步指示3输入sudoapt-getinstallgcc回车输入密码等待(我等了约30min)出现下一步指示4输入gcc–version回车完…

  • 基础版的音频功放电路(A类功放电路、B类功放电路、C类功放电路、D类功放电路、G类功放电路、H类功放电路、K类功放电路、T类功放电路)[通俗易懂]

    基础版的音频功放电路(A类功放电路、B类功放电路、C类功放电路、D类功放电路、G类功放电路、H类功放电路、K类功放电路、T类功放电路)[通俗易懂]  作为硬件工程师,特别是做纯粹模拟电路、应用于音频功放的工程师,对于A类,B类,AB类,D类,G类,H类,T类功放应该特别熟悉。大多数工程师或许只知道其中的一部分、或者知道大概,为了让更多的工程师掌握更加详尽的音频功放知识,下文对以上说的音频功放做详细的说明。  引用地址:http://www.eepw.com.cn/article/201611/340477.htm  功放,顾名思义,就是功率放大的缩写。与电压或者电流放大来说,功放要求获得一定的、不失真的功率,一般在大信号状态下工作,因此,功放电路

  • zencart模板制作步骤详解

    zencart模板制作步骤详解
    1,在includes/template下面新建个文件夹叫你新模板的名字就可以了,这里我就叫yourname

    2,把includes/template/defalut_template 这个文件夹下面的所有的文件夹和文件复制到你刚刚新建的文件夹里面去yourname

    3,把template_info.php这个文件用dw打开,出现在你眼前的是php代码这个你可以不用管,你只用把[$template_name=’DefaultTemplate’;

  • Pycharm安装Pytorch过程

    Pycharm安装Pytorch过程关键字:torch1.8.1+cu111,torchvision0.9.1+cu111,torchaudio===0.8.1之前一个项目用到了pytorch,当时试了好多方案安装pytorch终于后终于成功把它装上了。很久没用后把anaconda卸载了…重新打开这个项目emmmm…心情有点复杂。写个文档记录下我怎么重装的(留给下一次重装看bushi)。直接上pytorch官网(https://pytorch.org/get-started/locally/)选好对应版本号,复制然后试图直接安装时报

  • 移动通信网络结构「建议收藏」

    移动通信网络结构「建议收藏」移动通信网络架构蜂窝网络架构蜂窝系统:(小区制系统)将所要覆盖的地区划分为若干个小区,每个小区的半径可视用户的发布密度在1-10km左右,在每个小区设立一个基站为本小区范围内用户服务;与之相对应的网络称为蜂窝式网络。特点:用户容量大、服务性能较好、频谱利用率较高、用户终端小巧且电池使用时间长,辐射小等。问题:频率复用牵扯到系统的复杂性、重选、切换、漫游、位置登记、更新和管理以及系统鉴权等。…

  • ubuntu root默认密码(初始密码)

    ubuntu root默认密码(初始密码)ubuntu安装好后,root初始密码(默认密码)不知道,需要设置。1、先用安装时候的用户登录进入系统2、输入:sudopasswd按回车3、输入新密码,重复输入密码,最后提示passwd

    2022年10月24日

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号