PyTorch Basics Tutorial - Model Writing (Part 2)
When a network contains groups of repeated layers, you can wrap each group in an nn.Sequential container and build the groups through a `_make_layer` method.
ResNet18/34 is used as the example below.
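As a minimal sketch of the idea only (not part of the ResNet code that follows; the helper name make_repeated_convs is hypothetical), repeated layers can be collected in a Python list and unpacked into nn.Sequential, which is exactly what `_make_layer` will do with residual blocks:

import torch.nn as nn

# Toy sketch: stack the same conv/BN/ReLU pattern several times and
# wrap the whole stack in a single nn.Sequential module.
def make_repeated_convs(channels, num_repeats):
    layers = []
    for _ in range(num_repeats):
        layers.append(nn.Conv2d(channels, channels, kernel_size=3, padding=1))
        layers.append(nn.BatchNorm2d(channels))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)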
Model diagram: (figure of the ResNet18/34 architecture omitted here)
In ResNet, a unit with a cross-layer shortcut connection is called a Residual block; its structure is shown below. When the input and output channels do not match, a dedicated unit (the downsample branch) is needed to bring the two to the same shape so they can be added.
Implementation:
import torch.nn as nn


def conv3x3(in_channels, out_channels, stride=1):
    # 3x3 convolution with padding; bias is omitted because BatchNorm follows
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride=stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # the second conv keeps the channel count (out_channels -> out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # project the shortcut when its shape differs from the main branch
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
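A quick sanity check of the downsample case described above (assuming the imports, conv3x3, and BasicBlock defined here are in scope; the tensor sizes are illustrative):

import torch

# A stride-2 block that changes 64 -> 128 channels needs a downsample
# branch so the shortcut matches the main branch's shape.
downsample = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
    nn.BatchNorm2d(128),
)
block = BasicBlock(64, 128, stride=2, downsample=downsample)
x = torch.randn(1, 64, 56, 56)
print(block(x).shape)  # torch.Size([1, 128, 28, 28])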
_make_layer implementation:
def _make_layer(self, block, out_channels, blocks, stride=1):
    # the dashed (projection) shortcut in the diagram: a 1x1 conv + BN that
    # reshapes the shortcut when the stride or channel count changes
    downsample = None
    if stride != 1 or self.in_channels != out_channels * block.expansion:
        downsample = nn.Sequential(
            nn.Conv2d(self.in_channels, out_channels * block.expansion,
                      kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_channels * block.expansion),
        )
    layers = []
    # only the first block of a stage changes the stride/channels
    layers.append(block(self.in_channels, out_channels, stride, downsample))
    self.in_channels = out_channels * block.expansion
    for _ in range(1, blocks):
        layers.append(block(self.in_channels, out_channels))
    return nn.Sequential(*layers)
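To trace the channel bookkeeping, here is a standalone sketch (a hypothetical trace, not part of the original code; it assumes BasicBlock and conv3x3 from above are in scope) reproducing what self._make_layer(BasicBlock, 128, 2, stride=2) builds for ResNet18's second stage, when self.in_channels is 64:

import torch

# First block changes 64 -> 128 channels and halves the resolution,
# so it gets the dashed 1x1-conv shortcut; the second block does not.
downsample = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
    nn.BatchNorm2d(128),
)
layer2 = nn.Sequential(
    BasicBlock(64, 128, stride=2, downsample=downsample),
    BasicBlock(128, 128),
)
print(layer2(torch.randn(1, 64, 56, 56)).shape)  # torch.Size([1, 128, 28, 28])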
The full ResNet implementation:
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__()
        self.in_channels = 64
        # stem: 7x7 conv + BN + ReLU + max pool
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # four stages of residual blocks; stages 2-4 halve the resolution
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * block.expansion),
            )
        layers = []
        # the dashed (projection) shortcut goes to the first block of the stage
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        # update in_channels for the remaining blocks of this stage
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x2 = self.layer1(x)
        x3 = self.layer2(x2)
        x4 = self.layer3(x3)
        x5 = self.layer4(x4)
        x6 = self.avgpool(x5)
        x7 = x6.view(x6.size(0), -1)
        return self.fc(x7)
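Note that nn.AvgPool2d(7, stride=1) assumes 224x224 inputs, which leave a 7x7 feature map after layer4. If the network should accept other input sizes, a common substitution (not in the original code) is an adaptive pooling layer:

# optional alternative: collapse any spatial size to 1x1 before the fc layer
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))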
resnet18 implementation:
def resnet18():
    model = ResNet(BasicBlock, [2, 2, 2, 2])
    return model
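A quick smoke test (the 224x224 input size matches the AvgPool2d(7) assumption above):

import torch

model = resnet18()
x = torch.randn(2, 3, 224, 224)  # batch of 2 RGB images
out = model(x)
print(out.shape)                 # torch.Size([2, 1000])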
resnet34 uses the same BasicBlock; only the per-stage block counts change to [3, 4, 6, 3]:
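def resnet34():
    model = ResNet(BasicBlock, [3, 4, 6, 3])
    return model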