UNet/UNet++多类别分割

本文是一个UNet/UNet++多类别分割的实操,不介绍原理。
本文使用的代码: https://github.com/zonasw/unet-nested-multiple-classification

运行demo

  1. 下载代码:git clone https://github.com/zonasw/unet-nested-multiple-classification.git
  2. 下载demo数据集(或者从百度网盘下载,提取密码: dq7j)并解压到data文件夹中,该数据集中包含checkpoints, images, masks, test四个文件夹,其中images是图像数据集,masks是该数据集对应的标签,test是测试数据,checkpoints是在该数据集上预训练的模型。
  3. 训练 python train.py
  4. 推理 python inference.py -m ./data/checkpoints/epoch_10.pth -i ./data/test/input -o ./data/test/output

该数据集包含1500张128x128的图像,图像是程序生成的,包含三种类别: 背景、圆形、矩形,如下:
UNet/UNet++多类别分割

该模型识别背景,圆形,矩形三种类别,使用如下图像进行推理:
UNet/UNet++多类别分割

得到的推理结果为三个图像,这三个图像分别是背景、圆、矩形(白色像素为预测结果):
UNet/UNet++多类别分割

该数据集是由程序生成的,图像对应的标签是一个8位的单通道图像,值为相应的类别索引。

关于标签

假设有如下图像,该图像是一个10x10大小的图像,图像周围是空白背景,中心位置是一个圆形:
UNet/UNet++多类别分割
该图像包含两个类别,背景和圆,则背景位置对应的标签的像素值应该为0,圆对应的标签像素值应该为1,像下面这样:
0    0    0    0    0    0    0    0    0    0
0    0    0    0    0    0    0    0    0    0
0    0    0    1    1    1    1    0    0    0
0    0    1    1    1    1    1    1    0    0
0    0    1    1    1    1    1    1    0    0
0    0    1    1    1    1    1    1    0    0
0    0    1    1    1    1    1    1    0    0
0    0    0    1    1    1    1    0    0    0
0    0    0    0    0    0    0    0    0    0
0    0    0    0    0    0    0    0    0    0
由于该标签图像的值只包含0和1,所以它看起来整个图都是黑色的。

由于标签图像是8位的单通道图像,所以该方法支持最多256种类别。

损失函数

在计算多类别任务损失时,最开始是使用了交叉熵损失函数,交叉熵损失函数容易受到类别不平衡影响,后来改用了一种基于IOU的损失函数lovaszSoftmax,效果显著提升。