PyTorch自定义dataloader以及在迭代过程中返回image的名称(4)


# Define a custom data loader that returns each sample's data and its corresponding label
def default_loader(path):
    """Load the image stored at ``path`` and return it as an RGB PIL image.

    Args:
        path: File-system path of the image to read.

    Returns:
        The result of ``Image.open(...).convert('RGB')`` for that file.
    """
    # Open the file ourselves so the handle is closed deterministically when
    # the ``with`` block exits, instead of relying on PIL's lazy loading to
    # release it at some later point.
    with open(path, 'rb') as f:
        img = Image.open(f)
        # convert() decodes the pixel data into a new image while f is still open.
        return img.convert('RGB')


class myImageFloder(data.Dataset):  # subclass of the PyTorch Dataset class
    """Dataset that yields ``(image, label tensor, file name)`` triples.

    The samples are described by ``label``, an iterable of whitespace-separated
    text lines. The first token of each line is an image file name relative to
    ``root``; the following tokens are numeric label values.
    """

    def __init__(self, root, label, transform=None, target_transform=None, loader=default_loader):
        """Index every label line whose image file exists under ``root``.

        Args:
            root: Directory that the file names in ``label`` are relative to.
            label: Iterable of strings (e.g. a list of lines), one per sample.
            transform: Optional callable applied to each loaded image.
            target_transform: Stored on the instance; it is not applied
                anywhere in this class.
            loader: Callable that takes a path and returns the loaded image.
        """
        c = 0  # number of non-blank label lines seen (not the number of images kept)
        imgs = []
        class_names = ['regression']
        for line in label:
            cls = line.split()
            if not cls:
                # Blank line: there is no file name to pop, so skip it
                # instead of raising IndexError.
                continue
            fn = cls.pop(0)
            if os.path.isfile(os.path.join(root, fn)):
                # Every remaining token except the last one is parsed as a float.
                # NOTE(review): the original comment said "access the last
                # label", but this slice drops the last token -- confirm
                # against the label file format.
                imgs.append((fn, tuple(float(v) for v in cls[:-1])))
            c = c + 1
        print('the total image is', c)
        print(class_names)
        self.root = root
        self.imgs = imgs  # list of (file name, tuple of float labels)
        self.classes = class_names
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, label tensor, file name)`` for the sample at ``index``."""
        fn, label = self.imgs[index]
        img = self.loader(os.path.join(self.root, fn))
        if self.transform is not None:
            img = self.transform(img)
        # Return the image data, its label, and the file name it came from.
        return img, torch.Tensor(label), fn

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.imgs)

    def getName(self):
        """Return the list of class names (``['regression']``)."""
        return self.classes

实际上只需继承Dataset这个类,并重写其中的两个函数 `__getitem__` 和 `__len__`,并且保证返回的变量类型是torch.Tensor即可

看dataloader定义方式以及如何在dataloader中加载数据

# ToTensor maps pixel values from [0, 255] to [0, 1]
mytransform = transforms.Compose([
    transforms.ToTensor(),
])

test_data_root = "/home/ying/data/google_streetview_train_test1"
data_label = "/home/ying/data/google_streetview_train_test1/label.txt"
# test_label="/home/ying/data/google_streetview_train_test1/label.txt"

# random_choose_data is defined elsewhere; presumably it splits the label
# lines into a train list and a test list -- verify against its definition.
train_label, test_label = random_choose_data(data_label)

# Build the dataset first, then wrap it in a shuffling DataLoader.
test_dataset = myImageFloder(
    root=test_data_root,
    label=test_label,
    transform=mytransform,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=True,
    **kwargs
)
...
# Number of batches the loader yields; used as the denominator of the progress
# message instead of a hard-coded count that only fits one dataset/batch size.
num_batches = len(test_loader)
# Each batch unpacks into the three values returned by the dataset's
# __getitem__ (images, labels, file names); i is the enumerate() batch index.
for i, (test_images, test_labels, fn) in enumerate(test_loader):
    test_images = Variable(test_images.cuda())
    test_labels = Variable(test_labels.cuda())
    outputs = cnn(test_images)
    print(outputs.data[0])  # model output for the first sample of the batch
    print(fn)  # file names of the samples in this batch
    loss = criterion(outputs, test_labels)
    print("Iter [%d/%d] Test_Loss: %.4f" % (i + 1, num_batches, loss.data[0]))

刚刚在myImageFloder中定义的 `__getitem__` 的返回值,实际上就是 `for i, (test_images, test_labels, fn) in enumerate(test_loader):` 中括号里得到的对象,其中第一个i是与enumerate相关的index

这样就能够在模型test的时候观察哪些数据误差比较大并且进行输出

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:https://www.heiqu.com/1f2de4ccea99914fdd992d479ea496a6.html