基于Kaggle的图像分类(CIFAR-10)
Image Classification (CIFAR-10) on Kaggle
一直在使用Gluon’s data package数据包直接获得张量格式的图像数据集。然而,在实际应用中,图像数据集往往以图像文件的形式存在。将从原始图像文件开始,逐步组织、读取并将文件转换为张量格式。对CIFAR-10数据集进行了一个实验。这是计算机视觉领域的一个重要数据集。现在,将应用前面几节中所学的知识来参加Kaggle竞赛,该竞赛解决CIFAR-10图像分类问题。
比赛的网址是https://www.kaggle.com/c/cifar-10
图1显示了比赛网页上的信息。为了提交结果,请先在Kaggle网站注册一个帐户。
Fig. 1 CIFAR-10 image classification competition webpage information. The dataset for the competition can be accessed by clicking the “Data” tab.
首先,导入比赛所需的软件包或模块。
import collections
from d2l import mxnet as d2l
import math
from mxnet import autograd, gluon, init, npx
from mxnet.gluon import nn
import os
import pandas as pd
import shutil
import time
npx.set_np()
1. Obtaining and Organizing the Dataset
比赛数据分为训练集和测试集。训练集包含50000帧图像。测试集包含30万帧图像,其中10000帧图像用于评分,而其29万帧包括非评分图像,以防止手动标记测试集和提交标记结果。两个数据集中的图像格式都是PNG,高度和宽度都是32个像素和三个颜色通道(RGB)。图像覆盖1010类别:飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船和卡车。图中左上角显示了数据集中飞机、汽车和鸟类的一些图像。
1.1. Downloading the Dataset
登录Kaggle后,点击图1所示CIFAR-10图像分类竞赛网页上的“数据”选项卡,点击“全部下载”按钮下载数据集。在../data中解压缩下载的文件,并在其中解压缩train.7z和test.7z之后,将在以下路径中找到整个数据集:
../data/cifar-10/train/[1-50000].png
../data/cifar-10/test/[1-300000].png
../data/cifar-10/trainLabels.csv
../data/cifar-10/sampleSubmission.csv
这里的“训练”和“测试”文件夹分别包含训练和测试图像,trainLabels.csv有训练图像的标签和sample_submission.csv是提交的样本。为了便于入门,提供了一个小规模的数据集示例:包含第一个1000帧训练图像和55随机测试图像。要使用Kaggle竞赛的完整数据集,需要将以下demo变量设置为False。
#@save
d2l.DATA_HUB[\'cifar10_tiny\'] = (d2l.DATA_URL + \'kaggle_cifar10_tiny.zip\',
\'2068874e4b9a9f0fb07ebe0ad2b29754449ccacd\')
# If you use the full dataset downloaded for the Kaggle competition, set the
# demo variable to False
demo = True
if demo:
data_dir = d2l.download_extract(\'cifar10_tiny\')
else:
data_dir = \'../data/cifar-10/\'
1.2. Organizing the Dataset
需要组织数据集来促进模型的训练和测试。让首先从csv文件中读取标签。以下函数返回一个字典,该字典将不带扩展名的文件名映射到其标签。
#@save
def read_csv_labels(fname):
"""Read fname to return a name to label dictionary."""
with open(fname, \'r\') as f:
# Skip the file header line (column name)
lines = f.readlines()[1:]
tokens = [l.rstrip().split(\',\') for l in lines]
return dict(((name, label) for name, label in tokens))
labels = read_csv_labels(os.path.join(data_dir, \'trainLabels.csv\'))
print(\'# training examples:\', len(labels))
print(\'# classes:\', len(set(labels.values())))
# training examples: 1000
# classes: 10
接下来,定义reorg_train_valid函数来从原始训练集中分割验证集。此函数中的参数valid_ratio是验证集中的示例数与原始训练集中的示例数的比率。特别是让n是具有最少示例的类的图像数,以及r是比率,那么将使用最大值(⌊nr⌋,1),每个类的图像作为验证集。让以valid_ratio=0.1为例。从最初的训练开始50000帧图像,会有45000帧。当调整超参数时,用于训练并存储在路径“train_valid_test/train”中的图像,而另一个5000帧图像将作为验证集存储在“train_valid_test/train”路径中。组织好数据后,同一类的图像将被放在同一个文件夹下,以便以后阅读。
#@save
def copyfile(filename, target_dir):
"""Copy a file into a target directory."""
d2l.mkdir_if_not_exist(target_dir)
shutil.copy(filename, target_dir)
#@save
def reorg_train_valid(data_dir, labels, valid_ratio):
# The number of examples of the class with the least examples in the
# training dataset
n = collections.Counter(labels.values()).most_common()[-1][1]
# The number of examples per class for the validation set
n_valid_per_label = max(1, math.floor(n * valid_ratio))
label_count = {}
for train_file in os.listdir(os.path.join(data_dir, \'train\')):
label = labels[train_file.split(\'.\')[0]]
fname = os.path.join(data_dir, \'train\', train_file)
# Copy to train_valid_test/train_valid with a subfolder per class
copyfile(fname, os.path.join(data_dir, \'train_valid_test\',
\'train_valid\', label))
if label not in label_count or label_count[label] < n_valid_per_label:
# Copy to train_valid_test/valid
copyfile(fname, os.path.join(data_dir, \'train_valid_test\',
\'valid\', label))
label_count[label] = label_count.get(label, 0) + 1
else:
# Copy to train_valid_test/train
copyfile(fname, os.path.join(data_dir, \'train_valid_test\',
\'train\', label))
return n_valid_per_label
下面的reorg_test函数用于组织测试集,以便于预测期间的读数。
#@save
def reorg_test(data_dir):
for test_file in os.listdir(os.path.join(data_dir, \'test\')):
copyfile(os.path.join(data_dir, \'test\', test_file),
os.path.join(data_dir, \'train_valid_test\', \'test\',
\'unknown\'))
使用一个函数来调用先前定义的read_csv_labels、reorg_train_valid和reorg_test函数。
def reorg_cifar10_data(data_dir, valid_ratio):
labels = read_csv_labels(os.path.join(data_dir, \'trainLabels.csv\'))
reorg_train_valid(data_dir, labels, valid_ratio)
reorg_test(data_dir)