在PyTorch中自定义网络时,需要继承nn.Module类,并重写__init__(self)和forward方法,分别用来定义网络的组成和前向传播。下面是一个简单的例子。
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    """Minimal custom network: two 5x5 convolutions, each followed by ReLU.

    Shows the two methods a custom ``nn.Module`` provides: ``__init__``
    declares the layers and ``forward`` defines the forward pass.
    """

    def __init__(self):
        super(Model, self).__init__()
        # 1 input channel -> 20 feature maps, then 20 -> 20; both 5x5 kernels.
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU to ``x`` and return the result."""
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))


# Next, a look at PSPNet as introduced in the paper: the architecture is very
# simple -- a PPM (pyramid pooling module) is attached after ResNet.
此外PSPNet还采用了辅助损失分支。
import functools
import math
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from torch.autograd import Variable
from torch.nn import functional as F

from libs import InPlaceABN, InPlaceABNSync

# Forwarded as ``affine`` to the normalisation layer on the residual shortcut.
affine_par = True

# Normalisation layer used by the backbone: InPlaceABNSync with its
# activation explicitly set to 'none'.
BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


# ResNet bottleneck block
class Bottleneck(nn.Module):
    """Residual bottleneck: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    Args:
        inplanes: Number of input channels.
        planes: Width of the two inner convolutions; the block outputs
            ``planes * expansion`` channels.
        stride: Stride of the 3x3 convolution.
        dilation: Base dilation of the 3x3 convolution.
        downsample: Optional module applied to the input to produce the
            shortcut branch; when ``None`` the input itself is used.
        fist_dilation: Accepted but not used by this block (the spelling is
            kept so existing callers keep working).
        multi_grid: Multiplier applied to ``dilation`` for the 3x3 convolution.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
                 fist_dilation=1, multi_grid=1):
        super(Bottleneck, self).__init__()
        # Padding equals the effective dilation so the 3x3 convolution keeps
        # the spatial size when stride is 1.
        effective_dilation = dilation * multi_grid
        out_planes = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=effective_dilation,
                               dilation=effective_dilation, bias=False)
        self.bn2 = BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=False)
        self.relu_inplace = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.dilation = dilation
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        residual = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            residual = self.downsample(x)

        out = out + residual
        return self.relu_inplace(out)


# PPM (pyramid pooling) module
class PSPModule(nn.Module):
    """Pyramid pooling module.

    Pools the input to each grid size in ``sizes``, projects every pooled map
    to ``out_features`` channels, resizes the results back to the input
    resolution, concatenates them with the input and fuses everything with a
    3x3 convolution.

    Reference:
        Zhao, Hengshuang, et al. *"Pyramid scene parsing network."*
    """

    def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6)):
        super(PSPModule, self).__init__()
        self.stages = nn.ModuleList(
            [self._make_stage(features, out_features, size) for size in sizes])
        # The fusion convolution sees the input plus one out_features-wide map
        # per pyramid level.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(features + len(sizes) * out_features, out_features,
                      kernel_size=3, padding=1, dilation=1, bias=False),
            InPlaceABNSync(out_features),
            nn.Dropout2d(0.1),
        )

    def _make_stage(self, features, out_features, size):
        """Build one pyramid level: adaptive average pool -> 1x1 conv -> ABN."""
        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        conv = nn.Conv2d(features, out_features, kernel_size=1, bias=False)
        bn = InPlaceABNSync(out_features)
        return nn.Sequential(prior, conv, bn)

    def forward(self, feats):
        """Return the fused multi-scale context features for ``feats``."""
        h, w = feats.size(2), feats.size(3)
        # F.interpolate replaces the deprecated F.upsample; the arguments and
        # the result are the same.
        priors = [
            F.interpolate(input=stage(feats), size=(h, w), mode='bilinear',
                          align_corners=True)
            for stage in self.stages
        ] + [feats]
        return self.bottleneck(torch.cat(priors, 1))


# The complete PSPNet
class ResNet(nn.Module):
    """Dilated ResNet backbone with a PSP head and an auxiliary-loss branch.

    Args:
        block: Residual block class; must expose ``expansion``.
        layers: Number of blocks in each of the four residual stages.
        num_classes: Number of output channels of both prediction branches.
    """

    def __init__(self, block, layers, num_classes):
        super(ResNet, self).__init__()
        # Channel count entering the next residual stage (the stem ends with
        # 128 channels); updated by _make_layer.
        self.inplanes = 128

        # Stem: three 3x3 convolutions, the first one with stride 2.
        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=False)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=False)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=False)

        self.relu = nn.ReLU(inplace=False)
        # ceil_mode=True rounds the pooled size up instead of down.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1,
                                    ceil_mode=True)

        # The last two stages keep stride 1 and use dilation instead.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
                                       dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                       dilation=4, multi_grid=(1, 1, 1))

        # Main branch: pyramid pooling followed by a 1x1 classifier.
        self.head = nn.Sequential(
            PSPModule(2048, 512),
            nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0,
                      bias=True),
        )

        # Auxiliary-loss branch, fed with the output of layer3.
        self.dsn = nn.Sequential(
            nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1),
            InPlaceABNSync(512),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0,
                      bias=True),
        )

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
                    multi_grid=1):
        """Stack ``blocks`` residual blocks into one stage.

        Only the first block receives ``stride`` and, when the shape changes,
        a 1x1 projection shortcut. ``multi_grid`` is applied per block when it
        is a tuple (cycled by block index); any other value yields a
        multiplier of 1.
        """
        out_planes = planes * block.expansion

        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                BatchNorm2d(out_planes, affine=affine_par),
            )

        def grid_for(index):
            # Per-block dilation multiplier.
            if isinstance(multi_grid, tuple):
                return multi_grid[index % len(multi_grid)]
            return 1

        stage = [block(self.inplanes, planes, stride, dilation=dilation,
                       downsample=downsample, multi_grid=grid_for(0))]
        self.inplanes = out_planes
        for i in range(1, blocks):
            stage.append(block(self.inplanes, planes, dilation=dilation,
                               multi_grid=grid_for(i)))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Return ``[main_logits, auxiliary_logits]``.

        The shape comments trace an example input of (1, 3, 769, 769) with
        19 classes. Both outputs stay at the backbone resolution; no
        upsampling to the input size happens here.
        """
        x = self.relu1(self.bn1(self.conv1(x)))  # (1, 64, 385, 385)
        x = self.relu2(self.bn2(self.conv2(x)))  # (1, 64, 385, 385)
        x = self.relu3(self.bn3(self.conv3(x)))  # (1, 128, 385, 385)
        x = self.maxpool(x)                      # (1, 128, 193, 193)
        x = self.layer1(x)                       # (1, 256, 193, 193)
        x = self.layer2(x)                       # (1, 512, 97, 97)
        x = self.layer3(x)                       # (1, 1024, 97, 97)
        x_dsn = self.dsn(x)                      # (1, 19, 97, 97)
        x = self.layer4(x)                       # (1, 2048, 97, 97)
        x = self.head(x)                         # (1, 19, 97, 97)
        return [x, x_dsn]


def Res_Deeplab(num_classes=21):
    """Build the network with Bottleneck blocks in stages of [3, 4, 23, 3]."""
    model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes)
    return model