For how to package a dataset as a Bunch, see the improvements described in 『AI 专属数据库的定制』 (customizing an AI-specific database).

PyTables is the combination of Python and the HDF5 database/file standard. It is designed to optimize I/O performance and make the most of the available hardware, and it also supports compression.
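To get a feel for the API before the full example below, here is a minimal, self-contained sketch (with a hypothetical file name demo.h5) that writes a compressed, chunked array using tb.Filters and reads a slice back:

import numpy as np
import tables as tb

# zlib compression at level 7; shuffle is left off, as in the example below
filters = tb.Filters(complevel=7, complib='zlib', shuffle=False)
with tb.open_file('demo.h5', 'w', filters=filters) as h5:
    arr = np.arange(10000, dtype='uint8').reshape(100, 100)
    # create_carray stores the data as a chunked array, so the filters apply
    h5.create_carray('/', 'x', obj=arr, title='compressed demo array')

with tb.open_file('demo.h5', 'r') as h5:
    print(h5.root.x[:5, :5])  # slicing reads only the needed chunks from disk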

All of the code below was run in a Jupyter Notebook:

import sys
sys.path.append('E:/xinlib')
from base.filez import DataBunch
import tables as tb
import numpy as np


def bunch2hdf5(root):
    '''
    Here I only wrap the Cifar10, Cifar100, MNIST and Fashion MNIST datasets;
    users can append their own datasets as well.
    '''
    db = DataBunch(root)
    filters = tb.Filters(complevel=7, shuffle=False)
    # Compression filters are enabled, so the file is saved as `.h5c`,
    # but it could just as well be saved as `.h5`.
    with tb.open_file(f'{root}X.h5c', 'w', filters=filters, title="Xinet's dataset") as h5:
        for name in db.keys():
            h5.create_group('/', name, title=f'{db[name].url}')
            if name != 'cifar100':
                h5.create_array(h5.root[name], 'trainX', db[name].trainX, title='训练数据')
                h5.create_array(h5.root[name], 'trainY', db[name].trainY, title='训练标签')
                h5.create_array(h5.root[name], 'testX', db[name].testX, title='测试数据')
                h5.create_array(h5.root[name], 'testY', db[name].testY, title='测试标签')
            else:
                h5.create_array(h5.root[name], 'trainX', db[name].trainX, title='训练数据')
                h5.create_array(h5.root[name], 'testX', db[name].testX, title='测试数据')
                h5.create_array(h5.root[name], 'train_coarse_labels', db[name].train_coarse_labels, title='超类训练标签')
                h5.create_array(h5.root[name], 'test_coarse_labels', db[name].test_coarse_labels, title='超类测试标签')
                h5.create_array(h5.root[name], 'train_fine_labels', db[name].train_fine_labels, title='子类训练标签')
                h5.create_array(h5.root[name], 'test_fine_labels', db[name].test_fine_labels, title='子类测试标签')

        # Store the CIFAR label names (decoded from bytes) alongside the data.
        for k in ['cifar10', 'cifar100']:
            for name in db[k].meta.keys():
                name = name.decode()
                if name.endswith('names'):
                    label_names = np.asanyarray([label_name.decode() for label_name in db[k].meta[name.encode()]])
                    h5.create_array(h5.root[k], name, label_names, title='标签名称')

This completes the conversion from Bunch to HDF5:

root = 'E:/Data/Zip/'
bunch2hdf5(root)
h5c = tb.open_file('E:/Data/Zip/X.h5c')
h5c
File(filename=E:/Data/Zip/X.h5c, title="Xinet's dataset", mode='r', root_uep='/', filters=Filters(complevel=7, complib='zlib', shuffle=False, bitshuffle=False, fletcher32=False, least_significant_digit=None))
/ (RootGroup) "Xinet's dataset"
/cifar10 (Group) 'https://www.cs.toronto.edu/~kriz/cifar.html'
/cifar10/label_names (Array(10,)) '标签名称'
  atom := StringAtom(itemsize=10, shape=(), dflt=b'')
  maindim := 0
  flavor := 'numpy'
  byteorder := 'irrelevant'
  chunkshape := None
/cifar10/testX (Array(10000, 32, 32, 3)) '测试数据'
  atom := UInt8Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'irrelevant'
  chunkshape := None
/cifar10/testY (Array(10000,)) '测试标签'
  atom := Int32Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'little'
  chunkshape := None
/cifar10/trainX (Array(50000, 32, 32, 3)) '训练数据'
  atom := UInt8Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'irrelevant'
  chunkshape := None
/cifar10/trainY (Array(50000,)) '训练标签'
  atom := Int32Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'little'
  chunkshape := None
/cifar100 (Group) 'https://www.cs.toronto.edu/~kriz/cifar.html'
/cifar100/coarse_label_names (Array(20,)) '标签名称'
  atom := StringAtom(itemsize=30, shape=(), dflt=b'')
  maindim := 0
  flavor := 'numpy'
  byteorder := 'irrelevant'
  chunkshape := None
/cifar100/fine_label_names (Array(100,)) '标签名称'
  atom := StringAtom(itemsize=13, shape=(), dflt=b'')
  maindim := 0
  flavor := 'numpy'
  byteorder := 'irrelevant'
  chunkshape := None
/cifar100/testX (Array(10000, 32, 32, 3)) '测试数据'
  atom := UInt8Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'irrelevant'
  chunkshape := None
/cifar100/test_coarse_labels (Array(10000,)) '超类测试标签'
  atom := Int32Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'little'
  chunkshape := None
/cifar100/test_fine_labels (Array(10000,)) '子类测试标签'
  atom := Int32Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'little'
  chunkshape := None
/cifar100/trainX (Array(50000, 32, 32, 3)) '训练数据'
  atom := UInt8Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'irrelevant'
  chunkshape := None
/cifar100/train_coarse_labels (Array(50000,)) '超类训练标签'
  atom := Int32Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'little'
  chunkshape := None
/cifar100/train_fine_labels (Array(50000,)) '子类训练标签'
  atom := Int32Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'little'
  chunkshape := None
/fashion_mnist (Group) 'https://github.com/zalandoresearch/fashion-mnist'
/fashion_mnist/testX (Array(10000, 28, 28, 1)) '测试数据'
  atom := UInt8Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'irrelevant'
  chunkshape := None
/fashion_mnist/testY (Array(10000,)) '测试标签'
  atom := Int32Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'little'
  chunkshape := None
/fashion_mnist/trainX (Array(60000, 28, 28, 1)) '训练数据'
  atom := UInt8Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'irrelevant'
  chunkshape := None
/fashion_mnist/trainY (Array(60000,)) '训练标签'
  atom := Int32Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'little'
  chunkshape := None
/mnist (Group) 'http://yann.lecun.com/exdb/mnist'
/mnist/testX (Array(10000, 28, 28, 1)) '测试数据'
  atom := UInt8Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'irrelevant'
  chunkshape := None
/mnist/testY (Array(10000,)) '测试标签'
  atom := Int32Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'little'
  chunkshape := None
/mnist/trainX (Array(60000, 28, 28, 1)) '训练数据'
  atom := UInt8Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'irrelevant'
  chunkshape := None
/mnist/trainY (Array(60000,)) '训练标签'
  atom := Int32Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'little'
  chunkshape := None

From the structure above you can see that Cifar10, Cifar100, MNIST and Fashion MNIST have all been packaged, together with each dataset's accompanying information, such as label names and numeric features (all stored as arrays).
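To check what ended up in the file, you can also walk its nodes; a small sketch using the PyTables walk_nodes API that prints every array together with its shape and dtype:

# List every array node in X.h5c with its path, shape and dtype
for node in h5c.walk_nodes('/', classname='Array'):
    print(node._v_pathname, node.shape, node.dtype)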

%%time
arr = h5c.root.cifar100.trainX.read()  # reading the data is very fast
Wall time: 125 ms
arr.shape
(50000, 32, 32, 3)
h5c.root
/ (RootGroup) "Xinet's dataset"
  children := ['cifar10' (Group), 'cifar100' (Group), 'fashion_mnist' (Group), 'mnist' (Group)]

How to use X.h5c

Below we take Cifar100 as an example to show how to use our home-made dataset X.h5c. (I have uploaded it to Baidu Netdisk, 链接: https://pan.baidu.com/s/1hsbMhv3MDlOES3UDDmOQiw 密码: qlb7, so you can download it and use it directly; you can also generate it yourself, which I recommend, since it deepens your understanding of the datasets.)

cifar100 = h5c.root.cifar100
cifar100
/cifar100 (Group) 'https://www.cs.toronto.edu/~kriz/cifar.html'
  children := ['coarse_label_names' (Array), 'fine_label_names' (Array), 'testX' (Array), 'test_coarse_labels' (Array), 'test_fine_labels' (Array), 'trainX' (Array), 'train_coarse_labels' (Array), 'train_fine_labels' (Array)]

'coarse_label_names' holds the coarse-grained (superclass) label names, while 'fine_label_names' holds the fine-grained label names.

The contents can be fetched either with the read() method or by indexing.

coarse_label_names = cifar100.coarse_label_names[:]
# or
coarse_label_names = cifar100.coarse_label_names.read()
coarse_label_names.astype('str')
array(['aquatic_mammals', 'fish', 'flowers', 'food_containers',
       'fruit_and_vegetables', 'household_electrical_devices',
       'household_furniture', 'insects', 'large_carnivores',
       'large_man-made_outdoor_things', 'large_natural_outdoor_scenes',
       'large_omnivores_and_herbivores', 'medium_mammals',
       'non-insect_invertebrates', 'people', 'reptiles', 'small_mammals',
       'trees', 'vehicles_1', 'vehicles_2'], dtype='<U30')
fine_label_names = cifar100.fine_label_names[:].astype('str')
fine_label_names
array(['apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee',
       'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus',
       'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle',
       'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch',
       'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant',
       'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house',
       'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
       'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',
       'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter',
       'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate',
       'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road',
       'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk',
       'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar',
       'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone',
       'television', 'tiger', 'tractor', 'train', 'trout', 'tulip',
       'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman',
       'worm'], dtype='<U13')
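With both name arrays in hand, a stored label index can be mapped back to a readable name. The sketch below assumes, as in the CIFAR-100 meta files, that the coarse/fine label values are positions into coarse_label_names and fine_label_names:

coarse_names = cifar100.coarse_label_names[:].astype('str')
fine_names = cifar100.fine_label_names[:].astype('str')
i = 0  # index of the training sample we inspect
print(coarse_names[cifar100.train_coarse_labels[i]],
      fine_names[cifar100.train_fine_labels[i]])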

'testX' and 'trainX' hold the test data and the training data respectively, and the remaining nodes follow the same naming convention.

For example, let's take a look at the training data and labels:

trainX = cifar100.trainX
train_coarse_labels = cifar100.train_coarse_labels
train_coarse_labels[:]

array([11, 15,  4, ...,  8,  7,  1])

The shape is (50000, 32, 32, 3). As before, the data can be fetched either by indexing or with read():

train_data = trainX[:]
print(train_data[0].shape)
print(train_data.dtype)
(32, 32, 3)
uint8
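Since trainX lives on disk, a slice reads only the requested rows rather than the whole array; a quick sketch:

small = trainX[:100]  # partial, out-of-core read of the first 100 images
print(small.shape)    # expected: (100, 32, 32, 3)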

Of course, we can also operate on trainX directly.

for x in cifar100.trainX:
    y = x * 2
    break

print(y.shape)
(32, 32, 3)
h5c.get_node(h5c.root.cifar100, 'trainX')
/cifar100/trainX (Array(50000, 32, 32, 3)) '训练数据'
  atom := UInt8Atom(shape=(), dflt=0)
  maindim := 0
  flavor := 'numpy'
  byteorder := 'irrelevant'
  chunkshape := None

Going further, we can define an iterator to fetch the data in batches:

trainX = cifar100.trainX
train_coarse_labels = cifar100.train_coarse_labels
def data_iter(X, Y, batch_size):
    '''Yield (data, label) minibatches; indices are shuffled for training splits.'''
    n = X.nrows
    idx = np.arange(n)
    if X.name.startswith('train'):
        np.random.shuffle(idx)
    for i in range(0, n, batch_size):
        k = idx[i: min(n, i + batch_size)].tolist()
        yield np.take(X, k, 0), np.take(Y, k, 0)
for x, y in data_iter(trainX, train_coarse_labels, 8):
    print(x.shape, y)
    break
(8, 32, 32, 3) [ 7  7  0 15  4  8  8  3]
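As a small usage sketch, a batch from this iterator can be normalized to float32 in [0, 1], the usual preprocessing step before feeding it to a model:

for x, y in data_iter(trainX, train_coarse_labels, 32):
    x = x.astype('float32') / 255.0  # scale uint8 pixels to [0, 1]
    print(x.shape, x.dtype, x.min(), x.max(), y[:5])
    break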

For more usage details, see: 使用迭代器获取 Cifar 等常用数据集 (using iterators to fetch Cifar and other common datasets).
