图像分类学习:X光胸片诊断识别----迁移学习 (2)

图像分类学习:X光胸片诊断识别----迁移学习

好,想在继续往下走

下面呢给出了四个训练模型,实战中我们只需要挑其中一个进行训练就好,其他的模型要注释掉,下面代码上四个模型我都会分析

# inception------------------------------------------------------inception模型,有趣的是它可以翻译为盗梦空间 model = models.inception_v3(pretrained=True) # inception_v3是一个预训练模型, pretrained=True执行后会把模型下载到我们的电脑上 model.aux_logits = False # 是否给模型创建辅助,具体增么个辅助太复杂,请观众老爷们自行谷歌 num_fc_in = model.fc.in_features # 提取fc层固定的参数 # 改变全连接层,2分类问题,out_features = 2 model.fc = nn.Linear(num_fc_in, num_classes) # 修改fc层参数为num_classes = 4(最前面前面定义了) # alexnet--------------------------------------------------------alexnet模型 model = models.alexnet(pretrained=True) # alexnet是一个预训练模型, pretrained=True执行后会把模型下载到我们的电脑上 num_fc_in = model.classifier[6].in_features # 提取fc层固定的参数 model.fc = torch.nn.Linear(num_fc_in, num_classes) # 修改fc层参数为num_classes = 4(最前面前面定义了) model.classifier[6] = model.fc #将图层初始化为model.fc #相当于model.classifier[6] = torch.nn.Linear(num_fc_in, num_classes) # 建立VGG16迁移学习模型------------------------------------------------vgg16模型 model = torchvision.models.vgg16(pretrained=True)# vgg16是一个预训练模型, pretrained=True执行后会把模型下载到我们的电脑上 # 先将模型参数改为不可更新 for param in model.parameters(): param.requires_grad = False # 再更改最后一层的输出,至此网络只能更改该层参数 model.classifier[6] = nn.Linear(4096, num_classes) model.classifier = torch.nn.Sequential( # 修改全连接层 自动梯度会恢复为默认值 torch.nn.Linear(25088, 4096), torch.nn.ReLU(), torch.nn.Dropout(p=0.5), torch.nn.Linear(4096, 4096), torch.nn.Dropout(p=0.5), torch.nn.Linear(4096, num_classes)) # resnet18---------------------------------------------------------------resnet模型(和前几个模型差不多,自己脑部吧) model = models.resnet18(pretrained=True) # 全连接层的输入通道in_channels个数 num_fc_in = model.fc.in_features # 改变全连接层,2分类问题,out_features = 2 model.fc = nn.Linear(num_fc_in, num_classes) 继续,解释都在注释里了 # 定义训练函数 def train_model(model, dataloaders, criterion, optimizer, mundde_epochs=25): since = time.time() # 返回当前时间的时间戳(1970纪元后经过的浮点秒数) # state_dict变量存放训练过程中需要学习的权重和偏执系数,state_dict作为python的字典对象将每一层的参数映射成tensor张量, # 需要注意的是torch.nn.Module模块中的state_dict只包含卷积层和全连接层的参数 best_model_wts = copy.deepcopy(model.state_dict()) # copy是一个复制函数 best_acc = 0.0 # 下面这个迭代就是一个进度条的输出,从0到9显示进度 for epoch in range(num_epochs): print('Epoch {}/{}'.format(epoch, num_epochs - 1)) print('-' * 10) # 下面这个迭代,范围就两个'train', 'val',对应不执行不同的训练模式 for phase in ['train', 'val']: if phase == 'train': model.train() else: model.eval() running_loss = 0.0 running_corrects = 0.0 for inputs, labels in dataloaders[phase]: # 下面这行代码的意思是将所有最开始读取数据时的tensor变量copy一份到device所指定的GPU或CPU上去, # 之后的运算都在GPU或CPU上进行 inputs, labels = inputs.to(device), labels.to(device) optimizer.zero_grad() # 模型梯度设为0 # 接下来所有的tensor运算产生的新的节点都是不可求导的 with torch.set_grad_enabled(phase == 'train'): outputs = model(inputs) # output等于把inputs放到指定设备上去运算 loss = criterion(outputs, labels) # loss为outputs和labels的交叉熵损失 # 举例:output = torch.max(input, dim) # 输入 # input是softmax函数输出的一个tensor # dim是max函数索引的维度0 / 1,0是每列的最大值,1是每行的最大值 # 输出 # 函数会返回两个tensor,第一个tensor是每行的最大值,softmax的输出中最大的是1,所以第一个tensor是全1的tensor; # 第二个tensor是每行最大值的索引。 _, preds = torch.max(outputs, 1) if phase == 'train': loss.backward() # 反向传播计算得到每个参数的梯度值 optimizer.step() # 通过梯度下降执行一步参数更新 running_loss += loss.item() * inputs.size(0) running_corrects += (preds == labels).sum().item() epoch_loss = running_loss / len(dataloaders[phase].dataset) epoch_acc = running_corrects / len(dataloaders[phase].dataset) print('{} loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc)) if phase == 'val' and epoch_acc > best_acc: best_acc = epoch_acc best_model_wts = copy.deepcopy(model.state_dict()) print() time_elapsed = time.time() - since print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) print('Best val Acc: {:.4f}'.format(best_acc)) model.load_state_dict(best_model_wts) return model 继续往下看 # 定义优化器和损失函数 model = model.to(device) # 前面解释过了 optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # optimizer = optim.Adam(model.classifier.parameters(), lr=0.0001) # sched = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1) criterion = nn.CrossEntropyLoss() # 交叉熵损失函数

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:https://www.heiqu.com/wsszxs.html