PyTorch 官方示例中加载数据时,使用的都是已经定义好的 Dataset 和 DataLoader;那么,如何加载自己本地的图片以及对应的 label 呢?
假设标签文件的格式如下,每行是一个图片名及其对应的 label:
image1 label1
image2 label2
...
imagen labeln
实验中我采用的数据格式如下:每行以图片的文件名开头,后面跟 10 个以空格分隔的浮点数;下面的程序在读取时会丢弃最后一个数值,只取前 9 个作为该图片的 label,即每一个 label 是一个 9 维的向量
1_-2_pitch_100_yaw_0_lat_29.7553171_lng_-95.3675684.jpg 0.304295635957 0.952577642997 0.0614006041909 0.0938333659301 -0.995587916479 0.126405046864 -0.999368204665 0.0355414055005 0.382030624629 0.0
1_0_pitch_100_yaw_0_lat_29.7553171_lng_-95.3675684.jpg 0.271224474168 0.962516121742 0.061399602839 0.128727689658 -0.991679979588 0.126495313272 -0.999999890616 0.000467726796359 0.381981952872 0.0
1_2_pitch_100_yaw_0_lat_29.7553171_lng_-95.3675684.jpg 0.237868729379 0.971297311632 0.0614713240576 0.163626102983 -0.986522426721 0.1265439964 -0.999400990041 -0.0346072406472 0.382020891324 0.0
1.1_-2_pitch_100_yaw_0_lat_29.7553171_lng_-95.3675684.jpg 0.303575822293 0.95280728383 0.0675229548933 0.0939225945957 -0.995579502714 0.138745857429 -0.999376861795 0.0352971402251 0.410670255038 0.1
1.1_0_pitch_100_yaw_0_lat_29.7553171_lng_-95.3675684.jpg 0.270745576918 0.962650940154 0.0674654115238 0.128659340525 -0.991688849436 0.138685653232 -0.999999909615 0.000425170029598 0.410739827476 0.1
1.1_2_pitch_100_yaw_0_lat_29.7553171_lng_-95.3675684.jpg 0.23757921143 0.971368168253 0.0674866175928 0.16322766122 -0.986588430204 0.138789623782 -0.999406504329 -0.0344476284471 0.410661183171 0.1
1.2_-2_pitch_100_yaw_0_lat_29.7553171_lng_-95.3675684.jpg 0.305474635089 0.952200213882 0.0736939767933 0.0939968709874 -0.995572492712 0.150981626608 -0.999370773952 0.0354690875311 0.437620875774 0.2
1.2_0_pitch_100_yaw_0_lat_29.7553171_lng_-95.3675684.jpg 0.270346113421 0.962763199836 0.073518963401 0.128433455959 -0.991718129002 0.150964425444 -0.999999924062 0.000389711583812 0.437667827367 0.2
1.2_2_pitch_100_yaw_0_lat_29.7553171_lng_-95.3675684.jpg 0.237337349604 0.971427291403 0.0734898449879 0.162895476227 -0.986643331617 0.150931800731 -0.999411541516 -0.0343011761519 0.437608139736 0.2
1.3_-2_pitch_100_yaw_0_lat_29.7553171_lng_-95.3675684.jpg 0.305514664536 0.952187371137 0.0795990377393 0.0941741911595 -0.995555735115 0.162914965783 -0.999378340534 0.0352552474342 0.462816755558 0.3
1.3_0_pitch_100_yaw_0_lat_29.7553171_lng_-95.3675684.jpg 0.272366931798 0.962193459998 0.0796135882128 0.128398130503 -0.991722703221 0.162940731132 -0.999999935257 0.000359841646368 0.462733965419 0.3
...
源程序如下
import torch
import torch.nn as nn
import math
import os
from PIL import Image
import random
from torchvision import datasets, transforms
import torch.utils.data as data
from torch.autograd import Variable
# Select the first CUDA device only when CUDA is actually usable; calling
# set_device(0) unconditionally raises on CPU-only machines and would make
# this module impossible to import there.
if torch.cuda.is_available():
    torch.cuda.set_device(0)
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Extra keyword arguments for data loading.
# NOTE(review): presumably forwarded to torch.utils.data.DataLoader -- the
# consumer is not visible in this part of the file; confirm.
kwargs = {'num_workers': 1, 'pin_memory': True}
# Number of samples per mini-batch.
batch_size = 8
# load the data
def random_choose_data(label_path, holdout_size=200000, train_size=150000,
                       test_size=50000):
    """Split the lines of a label file into train and test lists.

    The global ``random`` generator is seeded with 1, ``holdout_size`` lines
    are sampled and set aside, and the remaining distinct lines are shuffled.
    The first ``train_size`` of them become the training list and the next
    ``test_size`` the test list.

    Args:
        label_path: Path of the text file holding one sample per line.
        holdout_size: Number of lines sampled and excluded from both results.
        train_size: Maximum number of lines in the training list.
        test_size: Maximum number of lines in the test list.

    Returns:
        A ``(train_label, test_label)`` tuple of lists of raw lines (line
        endings are kept). The lists are shorter than requested when too few
        lines remain after the hold-out.

    Raises:
        ValueError: If the file has fewer than ``holdout_size`` lines
            (raised by ``random.sample``).
    """
    random.seed(1)
    # Use a context manager so the file handle is always closed.
    with open(label_path) as label_file:
        lines = label_file.readlines()
    held_out = random.sample(lines, holdout_size)
    # A set has no stable iteration order across interpreter runs (string
    # hashing is randomized), so sort the remainder before shuffling;
    # otherwise the fixed seed would not give a reproducible split.
    remaining = sorted(set(lines) - set(held_out))
    random.shuffle(remaining)
    train_label = remaining[:train_size]
    test_label = remaining[train_size:train_size + test_size]
    return train_label, test_label
# def my data loader, return the data and corresponding label
def default_loader(path):
    """Open the image file at ``path`` and return it converted to RGB."""
    image = Image.open(path)
    # convert() returns a PIL image object in the requested mode.
    rgb_image = image.convert('RGB')
    return rgb_image
class myImageFloder(data.Dataset):
    """Dataset pairing image files with float-vector targets.

    Every entry of ``label`` is a whitespace-separated text line whose first
    field is an image file name (relative to ``root``) and whose remaining
    fields are numbers. The last number of each line is dropped; the others
    form the target vector. Entries whose image file does not exist under
    ``root`` are skipped.

    Args:
        root: Directory that contains the image files.
        label: Iterable of label lines (e.g. a list of strings).
        transform: Optional callable applied to the loaded image.
        target_transform: Optional callable applied to the target tensor.
        loader: Callable that takes a file path and returns the image.
    """

    def __init__(self, root, label, transform=None, target_transform=None,
                 loader=default_loader):
        imgs = []
        class_names = ['regression']
        for line in label:
            fields = line.split()
            if not fields:
                # Blank lines (such as a trailing empty line) hold no sample.
                continue
            fn = fields[0]
            if os.path.isfile(os.path.join(root, fn)):
                # Keep every numeric field except the last one as the target.
                imgs.append((fn, tuple(float(v) for v in fields[1:-1])))
        # NOTE(review): the indentation of the original counter was lost;
        # this reports the number of entries whose image exists -- confirm.
        print('the total image is', len(imgs))
        print(class_names)
        self.root = root
        self.imgs = imgs
        self.classes = class_names
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, target, file_name)`` for the sample at ``index``."""
        fn, label = self.imgs[index]
        img = self.loader(os.path.join(self.root, fn))
        if self.transform is not None:
            img = self.transform(img)
        target = torch.Tensor(label)
        # target_transform was previously stored but never applied.
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target, fn

    def __len__(self):
        """Return the number of usable samples."""
        return len(self.imgs)

    def getName(self):
        """Return the list of class names (``['regression']``)."""
        return self.classes