Medical Image Segmentation Networks
ZOU

1 UNet

1.1 Architecture and Rationale

This section introduces a network that is particularly well suited to medical image segmentation: UNet.

First, a few characteristics of medical images that explain why UNet fits medical image segmentation so well:

  1. Medical images have relatively simple semantics and fixed structures. For that very reason, both the low-level features and the high-level semantic features are important, so the skip connections (feature concatenation) of the U-shaped architecture are put to particularly good use.
  2. Medical images are scarce and hard to acquire, and large networks overfit easily on them, so a comparatively small network such as UNet is a better fit. In fact, on small datasets, state-of-the-art segmentation models have been found to hold no real advantage over a lightweight UNet.
  3. Medical images are often multi-modal, so medical imaging tasks frequently call for custom network designs to extract the features of different modalities. The lightweight, structurally simple UNet leaves more room for such modification (hence its many variants).

Now for the architecture itself. As the name suggests, UNet is a U-shaped network, and this shape lets it capture both contextual and positional information.

The network splits roughly into two halves: the left side is the feature-extraction (encoder) path and the right side is the feature-fusion (decoder) path,

so an input passes through a high resolution → encoding → low resolution → decoding → high resolution pipeline.

In the encoder, each downsampling module consists of two 3 x 3 convolution layers (each followed by a ReLU) plus one 2 x 2 max-pooling layer, and this module is applied 4 times. In the decoder that follows, each module consists of one up-convolution, a feature concatenation (concat), and two 3 x 3 convolution layers (ReLU), again applied 4 times so that it mirrors the encoder exactly. A final 1 x 1 convolution reduces the channel count to the required number, producing the output map.

Why the skip connections help: the enlarged feature maps produced by up-convolution lack information at their borders. Every downsampling step distills features but inevitably loses some edge detail, and features lost this way cannot be recovered by upsampling alone. Concatenating the encoder features with the decoder features therefore restores the lost edge information, as the sketch below illustrates.
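
A minimal sketch of what the concatenation does (the tensors here are hypothetical stand-ins for an encoder feature map and the matching up-sampled decoder feature map):

import torch

e1 = torch.randn(1, 64, 256, 256)  # encoder feature map (64 channels)
d2 = torch.randn(1, 64, 256, 256)  # up-sampled decoder feature map
cat = torch.cat((e1, d2), dim=1)   # concatenate along the channel axis
print(cat.shape)                   # torch.Size([1, 128, 256, 256])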

1.2 Code

import torch.nn as nn
import torch

# Basic building block of the encoder (feature-extraction) path
class conv_block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(conv_block, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3,
                      stride=1, padding=1, bias=True),
            # Batch normalization after the convolution keeps the activations
            # from growing too large before the ReLU, which would make
            # training unstable
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3,
                      stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x

# Basic building block of the decoder (feature-fusion) path
class up_conv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_ch, out_ch, kernel_size=3,
                      stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.up(x)
        return x

# The UNet backbone
class UNet(nn.Module):
    def __init__(self, in_ch=3, out_ch=2):
        super(UNet, self).__init__()

        # channel widths for the five resolution levels
        n1 = 64
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        # max-pooling layers
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # encoder convolutions (left side)
        self.Conv1 = conv_block(in_ch, filters[0])
        self.Conv2 = conv_block(filters[0], filters[1])
        self.Conv3 = conv_block(filters[1], filters[2])
        self.Conv4 = conv_block(filters[2], filters[3])
        self.Conv5 = conv_block(filters[3], filters[4])

        # decoder up-sampling and fusion convolutions (right side)
        self.Up5 = up_conv(filters[4], filters[3])
        self.Up_conv5 = conv_block(filters[4], filters[3])

        self.Up4 = up_conv(filters[3], filters[2])
        self.Up_conv4 = conv_block(filters[3], filters[2])

        self.Up3 = up_conv(filters[2], filters[1])
        self.Up_conv3 = conv_block(filters[2], filters[1])

        self.Up2 = up_conv(filters[1], filters[0])
        self.Up_conv2 = conv_block(filters[1], filters[0])

        self.Conv = nn.Conv2d(filters[0], out_ch, kernel_size=1,
                              stride=1, padding=0)

    # Forward pass; the output has the same spatial size as the input
    def forward(self, x):
        e1 = self.Conv1(x)

        e2 = self.Maxpool1(e1)
        e2 = self.Conv2(e2)

        e3 = self.Maxpool2(e2)
        e3 = self.Conv3(e3)

        e4 = self.Maxpool3(e3)
        e4 = self.Conv4(e4)

        e5 = self.Maxpool4(e4)
        e5 = self.Conv5(e5)

        # the first up-sampling step halves the channel count
        d5 = self.Up5(e5)

        # concatenate the e4 feature map with d5 along the channel axis
        d5 = torch.cat((e4, d5), dim=1)

        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        d4 = torch.cat((e3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((e2, d3), dim=1)  # concatenate e2 with d3
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((e1, d2), dim=1)  # concatenate e1 with d2
        d2 = self.Up_conv2(d2)

        out = self.Conv(d2)

        return out
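
A quick shape check of the model above (the input size is an assumption; any height and width divisible by 16 work):

model = UNet(in_ch=3, out_ch=2)
x = torch.randn(1, 3, 256, 256)
print(model(x).shape)  # torch.Size([1, 2, 256, 256])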

References:

  1. 图像分割必备知识点 | Unet详解 理论+代码 (忽逢桃林, cnblogs.com)
  2. unet模型及代码解析 (静待缘起, CSDN)

2 VNet

2.1 Architecture and Rationale

VNet is an improved version of UNet whose construction closely parallels UNet's. Its defining strength is that it processes three-dimensional volumes efficiently.

The figure below shows the VNet architecture. It keeps UNet's fusion of encoder feature maps into the decoder to enlarge the effective context (note that the reference implementation below fuses the skip features by element-wise addition rather than concatenation), and it replaces the down- and up-sampling operations with strided and transposed convolutions. Apart from switching the primary input to 3D volumes, its biggest improvement is the ResNet-style short skip connection (the grey paths) added around each stage of convolutions, effectively introducing residual blocks into UNet. The original paper reports that this change helps VNet's training converge.
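
The residual pattern amounts to adding a stage's input back onto its convolution output before the activation. A minimal sketch (the channel count and volume size are assumptions; this mirrors the ResidualConvBlock in the code below):

import torch
import torch.nn as nn

conv = nn.Conv3d(16, 16, kernel_size=3, padding=1)
relu = nn.ReLU(inplace=True)
x = torch.randn(1, 16, 32, 64, 64)
out = relu(conv(x) + x)  # short skip connection: input added back before ReLU
print(out.shape)         # torch.Size([1, 16, 32, 64, 64])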

2.2 Code

import torch
from torch import nn
import torch.nn.functional as F

# Plain stack of n_stages (Conv3d -> normalization -> ReLU) layers
class ConvBlock(nn.Module):
    def __init__(self, n_stages, n_filters_in,
                 n_filters_out, normalization='none'):
        super(ConvBlock, self).__init__()

        ops = []
        for i in range(n_stages):
            if i == 0:
                input_channel = n_filters_in
            else:
                input_channel = n_filters_out

            ops.append(nn.Conv3d(input_channel, n_filters_out,
                                 3, padding=1))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16,
                                        num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            elif normalization != 'none':
                assert False
            ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


# Residual variant: the block input is added back before the final ReLU
# (the ResNet-style short skip connection described in Section 2.1)
class ResidualConvBlock(nn.Module):
    def __init__(self, n_stages, n_filters_in,
                 n_filters_out, normalization='none'):
        super(ResidualConvBlock, self).__init__()

        ops = []
        for i in range(n_stages):
            if i == 0:
                input_channel = n_filters_in
            else:
                input_channel = n_filters_out

            ops.append(nn.Conv3d(input_channel, n_filters_out,
                                 3, padding=1))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16,
                                        num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            elif normalization != 'none':
                assert False

            if i != n_stages - 1:
                ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = (self.conv(x) + x)
        x = self.relu(x)
        return x


# Downsampling by a strided convolution (kernel size equals the stride)
class DownsamplingConvBlock(nn.Module):
    def __init__(self, n_filters_in, n_filters_out,
                 stride=2, normalization='none'):
        super(DownsamplingConvBlock, self).__init__()

        ops = []
        if normalization != 'none':
            ops.append(nn.Conv3d(n_filters_in, n_filters_out,
                                 stride, padding=0, stride=stride))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16,
                                        num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            else:
                assert False
        else:
            ops.append(nn.Conv3d(n_filters_in, n_filters_out,
                                 stride, padding=0, stride=stride))

        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


# Upsampling by a transposed convolution
class UpsamplingDeconvBlock(nn.Module):
    def __init__(self, n_filters_in, n_filters_out,
                 stride=2, normalization='none'):
        super(UpsamplingDeconvBlock, self).__init__()

        ops = []
        if normalization != 'none':
            ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out,
                                          stride, padding=0, stride=stride))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16,
                                        num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            else:
                assert False
        else:
            ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out,
                                          stride, padding=0, stride=stride))

        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


# Alternative upsampling: trilinear interpolation followed by a 3 x 3 convolution
class Upsampling(nn.Module):
    def __init__(self, n_filters_in, n_filters_out,
                 stride=2, normalization='none'):
        super(Upsampling, self).__init__()

        ops = []
        ops.append(nn.Upsample(scale_factor=stride, mode='trilinear',
                               align_corners=False))
        ops.append(nn.Conv3d(n_filters_in, n_filters_out,
                             kernel_size=3, padding=1))
        if normalization == 'batchnorm':
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == 'groupnorm':
            ops.append(nn.GroupNorm(num_groups=16,
                                    num_channels=n_filters_out))
        elif normalization == 'instancenorm':
            ops.append(nn.InstanceNorm3d(n_filters_out))
        elif normalization != 'none':
            assert False
        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class VNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=2, n_filters=16,
                 normalization='none', has_dropout=False):
        super(VNet, self).__init__()
        self.has_dropout = has_dropout

        self.block_one = ConvBlock(1, n_channels, n_filters,
                                   normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters,
                                                  normalization=normalization)

        self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2,
                                   normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4,
                                                  normalization=normalization)

        self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4,
                                     normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8,
                                                    normalization=normalization)

        self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8,
                                    normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16,
                                                   normalization=normalization)

        self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16,
                                    normalization=normalization)
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8,
                                                   normalization=normalization)

        self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8,
                                   normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4,
                                                  normalization=normalization)

        self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4,
                                     normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2,
                                                    normalization=normalization)

        self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2,
                                     normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters,
                                                    normalization=normalization)

        self.block_nine = ConvBlock(1, n_filters, n_filters,
                                    normalization=normalization)
        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)

        # dropout rate = 0.5; dropout is used in two places
        # (end of the encoder and end of the decoder)
        self.dropout = nn.Dropout3d(p=0.5, inplace=False)
        # self.__init_weight()

    def encoder(self, input):
        x1 = self.block_one(input)
        x1_dw = self.block_one_dw(x1)

        x2 = self.block_two(x1_dw)
        x2_dw = self.block_two_dw(x2)

        x3 = self.block_three(x2_dw)
        x3_dw = self.block_three_dw(x3)

        x4 = self.block_four(x3_dw)
        x4_dw = self.block_four_dw(x4)

        x5 = self.block_five(x4_dw)
        # x5 = F.dropout3d(x5, p=0.5, training=True)
        if self.has_dropout:
            x5 = self.dropout(x5)

        res = [x1, x2, x3, x4, x5]

        return res

    def decoder(self, features):
        x1 = features[0]
        x2 = features[1]
        x3 = features[2]
        x4 = features[3]
        x5 = features[4]

        # skip connections: encoder features are fused by element-wise addition
        x5_up = self.block_five_up(x5)
        x5_up = x5_up + x4

        x6 = self.block_six(x5_up)
        x6_up = self.block_six_up(x6)
        x6_up = x6_up + x3

        x7 = self.block_seven(x6_up)
        x7_up = self.block_seven_up(x7)
        x7_up = x7_up + x2

        x8 = self.block_eight(x7_up)
        x8_up = self.block_eight_up(x8)
        x8_up = x8_up + x1
        x9 = self.block_nine(x8_up)
        # x9 = F.dropout3d(x9, p=0.5, training=True)
        if self.has_dropout:
            x9 = self.dropout(x9)

        out = self.out_conv(x9)
        return out

    def forward(self, input, turnoff_drop=False):
        # turnoff_drop temporarily disables dropout for this forward pass
        if turnoff_drop:
            has_dropout = self.has_dropout
            self.has_dropout = False
        features = self.encoder(input)
        out = self.decoder(features)
        if turnoff_drop:
            self.has_dropout = has_dropout
        return out
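
A quick shape check (single-modality input and volume size are assumptions; each spatial dimension must be divisible by 16):

model = VNet(n_channels=1, n_classes=2, normalization='batchnorm')
x = torch.randn(1, 1, 64, 128, 128)
print(model(x).shape)  # torch.Size([1, 2, 64, 128, 128])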

References:

  1. UNet、3D-UNet、VNet 区别 (阿里云小仙女, CSDN)

3 UNet++

3.1 Architecture and Rationale

UNet++ improves on UNet by following every downsampling step with upsampling and feature concatenation. The author reasons that if shallow and deep features really are equally important in medical image processing, why should the network, like UNet, descend through four downsampling steps before coming back up? He therefore built a network that densely connects the features of levels 1 through 4. In the author's words, this fills in the hollow middle of the original UNet.

At first, only the outputs running from each level to the end point were considered; that network could not be trained, because gradients could not be backpropagated through it. Only after skip connections were also added between the intermediate nodes and the end point did the network become trainable.

Compared with UNet, one drawback of this network is the larger parameter count. The author's view is that more parameters are not inherently better; they should be spent where they count. He therefore uses deep supervision to prune UNet++: train the full network as usual, then prune it at test time. This cuts the parameter count sharply at only a very small cost in accuracy; a rough sketch of the idea follows.
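
As a sketch of the deep-supervision idea (this is not part of the reference code below; the heads and feature tensors are hypothetical stand-ins for the full-resolution nodes x0_1 ... x0_4 of UNet++):

import torch
import torch.nn as nn

# One hypothetical 1 x 1 segmentation head per nested depth. Training
# supervises all four; at test time any single head can be read out,
# which effectively prunes the deeper sub-networks away.
heads = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=1) for _ in range(4)])
feats = [torch.randn(1, 64, 96, 96) for _ in range(4)]  # stand-ins for x0_1..x0_4
outputs = [head(f) for head, f in zip(heads, feats)]
print([tuple(o.shape) for o in outputs])  # four full-resolution score maps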

3.2 Code

import torch
import torch.nn as nn

class conv_block_nested(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch):
        super(conv_block_nested, self).__init__()
        self.activation = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3,
                               padding=1, bias=True)
        self.bn1 = nn.BatchNorm2d(mid_ch)
        self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3,
                               padding=1, bias=True)
        self.bn2 = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.activation(x)

        x = self.conv2(x)
        x = self.bn2(x)
        output = self.activation(x)

        return output

class Nested_UNet(nn.Module):
    def __init__(self, in_ch=3, out_ch=1):
        super(Nested_UNet, self).__init__()

        # five channel widths; filters[4] feeds the deepest node conv4_0
        n1 = 64
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # up-sampling shared by all decoder nodes
        self.Up = nn.Upsample(scale_factor=2, mode="bilinear",
                              align_corners=True)

        # each block applies two 3 x 3 convolutions
        self.conv0_0 = conv_block_nested(in_ch, filters[0], filters[0])
        self.conv1_0 = conv_block_nested(filters[0], filters[1], filters[1])
        self.conv2_0 = conv_block_nested(filters[1], filters[2], filters[2])
        self.conv3_0 = conv_block_nested(filters[2], filters[3], filters[3])
        self.conv4_0 = conv_block_nested(filters[3], filters[4], filters[4])

        self.conv0_1 = conv_block_nested(filters[0] + filters[1],
                                         filters[0], filters[0])
        self.conv1_1 = conv_block_nested(filters[1] + filters[2],
                                         filters[1], filters[1])
        self.conv2_1 = conv_block_nested(filters[2] + filters[3],
                                         filters[2], filters[2])
        self.conv3_1 = conv_block_nested(filters[3] + filters[4],
                                         filters[3], filters[3])

        # input channels: x0_0 + x0_1 + up-sampled x1_1, and so on
        self.conv0_2 = conv_block_nested(filters[0]*2 + filters[1],
                                         filters[0], filters[0])
        self.conv1_2 = conv_block_nested(filters[1]*2 + filters[2],
                                         filters[1], filters[1])
        self.conv2_2 = conv_block_nested(filters[2]*2 + filters[3],
                                         filters[2], filters[2])

        self.conv0_3 = conv_block_nested(filters[0]*3 + filters[1],
                                         filters[0], filters[0])
        self.conv1_3 = conv_block_nested(filters[1]*3 + filters[2],
                                         filters[1], filters[1])

        self.conv0_4 = conv_block_nested(filters[0]*4 + filters[1],
                                         filters[0], filters[0])

        self.final = nn.Conv2d(filters[0], out_ch, kernel_size=1)

    def forward(self, x):
        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0,
                                       self.Up(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0,
                                       self.Up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1,
                                       self.Up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0,
                                       self.Up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1,
                                       self.Up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2,
                                       self.Up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0,
                                       self.Up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1,
                                       self.Up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2,
                                       self.Up(x2_2)], 1))

        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3,
                                       self.Up(x1_3)], 1))

        output = self.final(x0_4)

        return output
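
A quick shape check of the model above (the input size is an assumption; any height and width divisible by 16 work):

model = Nested_UNet(in_ch=3, out_ch=1)
x = torch.randn(1, 3, 224, 224)
print(model(x).shape)  # torch.Size([1, 1, 224, 224])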

References:

  1. 研习U-Net (知乎, zhihu.com)
  2. Biomedical Image Segmentation: UNet++, by Jingles (Hong Jing), Towards Data Science