Training Deeply Supervised Salient Object Detection with Short Connections on Your Own Dataset (TensorFlow version)

Code: gbyy422990/salience_object_detection
Thanks to the author: gbyy42299 (CSDN)

Training setup

Download the code and pretrained model

Code (Baidu Cloud backup), password: 5qv2
VGG16.npy (Baidu Cloud backup), password: cy99

Place the downloaded vgg16.npy file in the root directory of the code.
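
A quick sanity check that the weights file loads. This is only a sketch; it assumes vgg16.npy is a pickled Python dict of layer weights, which is the usual format for this file:

import numpy as np

# Load the VGG16 weight dictionary and list its layer names
vgg = np.load('vgg16.npy', encoding='latin1', allow_pickle=True).item()
print(sorted(vgg.keys()))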

Preparing the dataset

Create the folders

The dataset is stored in the dataset folder under the root directory. First delete all files and subfolders inside dataset, then create four new folders: plane, planelabel, planetest, and planetestlabel. They hold, respectively, the training images, the training masks, the validation images, and the validation masks.
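
A minimal sketch to create the four folders (assuming it is run from the repository root):

import os

# Create the four dataset subfolders under ./dataset
for folder in ['plane', 'planelabel', 'planetest', 'planetestlabel']:
    os.makedirs(os.path.join('dataset', folder), exist_ok=True)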

Preparing the samples

Sample format

The format is shown in the figure below: the left image is a sample as placed in the plane and planetest folders, and the right image is the corresponding mask as placed in planelabel and planetestlabel.
(Figure: example input image and its mask.)

Resize

The input images and their masks must all be 400 pixels wide by 300 pixels high, three-channel (24-bit) images. Place your own images into the four folders accordingly.

If you need to resize your images to 400 wide by 300 high, you can use the following code:

import cv2
import glob
import os

# Resize three-channel (color) images in place
def threeChannel(inDir):

    for jpgfile in glob.glob(inDir):
        img = cv2.imread(jpgfile)
        resized = cv2.resize(img, (400, 300), interpolation=cv2.INTER_CUBIC)
        cv2.imwrite(jpgfile, resized, [int(cv2.IMWRITE_PNG_COMPRESSION), 9])
        print(jpgfile)
    print('----------------------------------------------------')


# Resize single-channel (grayscale) images in place
def oneChannel(inDir):
    for jpgfile in glob.glob(inDir):
        img = cv2.imread(jpgfile, cv2.IMREAD_GRAYSCALE)
        resized = cv2.resize(img, (400, 300), interpolation=cv2.INTER_CUBIC)
        cv2.imwrite(jpgfile, resized, [int(cv2.IMWRITE_PNG_COMPRESSION), 9])
        print(jpgfile)
    print('----------------------------------------------------')


# Resize all four folders
if __name__ == '__main__':
    inDir = "plane/*.png"
    inDir2 = "planelabel/*.png"
    inDir3 = "planetest/*.png"
    inDir4 = "planetestlabel/*.png"

    threeChannel(inDir)
    threeChannel(inDir3)
    oneChannel(inDir2)
    oneChannel(inDir4)
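
A quick check that the resize worked. This is only a sketch; run it from the same directory as the resize script (the one containing the four folders):

import cv2
import glob

# Every PNG in the four folders should now be 400 wide and 300 high
for folder in ['plane', 'planelabel', 'planetest', 'planetestlabel']:
    for path in glob.glob(folder + '/*.png'):
        img = cv2.imread(path)
        # cv2.imread returns (height, width, channels)
        assert img.shape[:2] == (300, 400), '{}: {}'.format(path, img.shape)
print('All images are 400x300')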

If you need to convert single-channel masks into three-channel images, you can use the following code (adapt it to your own situation):

import cv2
import glob
import os


def main(inDir):
    for jpgfile in glob.glob(inDir):
        img = cv2.imread(jpgfile, 0)
        merged = cv2.merge([img, img, img])
        cv2.imwrite(jpgfile, merged, [int(cv2.IMWRITE_PNG_COMPRESSION), 9])
        print(jpgfile)
    print('----------------------------------------------------')


if __name__ == '__main__':
    inDir = "planetestlabel/*.png"
    main(inDir)
    inDir2 = "planelabel/*.png"
    main(inDir2)

Generate the CSV files

Modify the csc_generator.py file in the root directory and run it; this generates two CSV files in the root directory. The modified code is shown below:

#coding:utf-8
import os
import csv

def create_csv(dirname):
    path = './dataset/'+ dirname +'/'
    name = os.listdir(path)
    #name.sort(key=lambda x: int(x.split('.')[0]))
    #print(name)
    with open (dirname+'.csv','w', encoding="UTF8", newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['data', 'label'])
        for n in name:
            if n[-4:] == '.png':
                print(n)
                # with open('data_'+dirname+'.csv','rb') as f:
                str1 = './dataset/'+str(dirname) +'/'+ str(n)
                str2 = './dataset/' + str(dirname) + 'label/' + str(n[:-4] + '.png')
                writer.writerow([str1, str2])
            else:
                pass

if __name__ == "__main__":
    create_csv('plane')
    create_csv('planetest')
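
After running it, each CSV has a header row plus one row per image, pairing an image with its mask. For example, plane.csv looks like this (the file names are only placeholders):

data,label
./dataset/plane/0001.png,./dataset/planelabel/0001.png
./dataset/plane/0002.png,./dataset/planelabel/0002.png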

Start training

Modify train.py as follows, then run it:

# coding:utf-8
# Bin GAO

import os
import tensorflow as tf
import numpy as np
import argparse
import pandas as pd
import model
import time

from model import train_op
from model import loss_CE, loss_IOU

h = 300  # 4032
w = 400  # 3024
c_image = 3
c_label = 1
g_mean = [142.53, 129.53, 120.20]

parser = argparse.ArgumentParser()
parser.add_argument('--pretrained',
                    type=int,
                    default=1) # whether to load a previously saved (pretrained) model

parser.add_argument('--data_dir',
                    default='./plane.csv') # training set CSV

parser.add_argument('--test_dir',
                    default='./planetest.csv') # validation set CSV

parser.add_argument('--model_dir',
                    default='./model1') # checkpoints are saved in this folder

parser.add_argument('--epochs',
                    type=int,
                    default=100) # number of training epochs

parser.add_argument('--epochs_per_eval',
                    type=int,
                    default=1) # run validation every N epochs

parser.add_argument('--logdir',
                    default='./logs1') # TensorBoard logs are written to this folder

parser.add_argument('--batch_size',
                    type=int,
                    default=1) # batch size

parser.add_argument('--is_cross_entropy',
                    action='store_true',
                    default=True)

parser.add_argument('--learning_rate',
                    type=float,
                    default=1e-3) # initial learning rate

# learning-rate decay rate
parser.add_argument('--decay_rate',
                    type=float,
                    default=0.9)

# decay step (number of steps between decays)
parser.add_argument('--decay_step',
                    type=int,
                    default=100)

parser.add_argument('--weight',
                    nargs='+',
                    type=float,
                    default=[1.0, 1.0])

parser.add_argument('--random_seed',
                    type=int,
                    default=1234)

parser.add_argument('--gpu',
                    type=str,
                    default=1)

flags = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def set_config():
    '''Alternative: allow GPU memory to grow on demand
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    session = tf.Session(config=config)
    '''

    # limit the fraction of GPU memory this process may use
    os.environ['CUDA_VISIBLE_DEVICES'] = str(flags.gpu)
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1)
    config = tf.ConfigProto(gpu_options=gpu_options)
    return tf.Session(config=config)


def data_augmentation(image, label, training=True):
    if training:
        image_label = tf.concat([image, label], axis=-1)
        print('image label shape concat', image_label.get_shape())

        maybe_flipped = tf.image.random_flip_left_right(image_label)
        maybe_flipped = tf.image.random_flip_up_down(maybe_flipped)
        # maybe_flipped = tf.random_crop(maybe_flipped,size=[h/2,w/2,image_label.get_shape()[-1]])

        image = maybe_flipped[:, :, :-1]
        mask = maybe_flipped[:, :, -1:]

        # image = tf.image.random_brightness(image, 0.7)
        # image = tf.image.random_hue(image, 0.3)
        # random contrast (optional)
        # tf.image.random_contrast(image,lower=0.3,upper=1.0)
        return image, mask


def read_csv(queue, augmentation=True):
    csv_reader = tf.TextLineReader(skip_header_lines=1)

    _, csv_content = csv_reader.read(queue)

    image_path, label_path = tf.decode_csv(csv_content, record_defaults=[[""], [""]])

    image_file = tf.read_file(image_path)
    label_file = tf.read_file(label_path)

    image = tf.image.decode_jpeg(image_file, channels=3)
    image.set_shape([h, w, c_image])
    image = tf.cast(image, tf.float32)

    label = tf.image.decode_jpeg(label_file, channels=1)
    label.set_shape([h, w, c_label])

    label = tf.cast(label, tf.float32)
    # label = label / (tf.reduce_max(label) + 1e-7)
    label = label / 255

    # data augmentation
    if augmentation:
        image, label = data_augmentation(image, label)
    else:
        pass
    return image, label


def main(flags):
    current_time = time.strftime("%m/%d/%H/%M/%S")

    # log directories for training and validation
    train_logdir = os.path.join(flags.logdir, "plane", current_time)
    test_logdir = os.path.join(flags.logdir, "planetest", current_time)

    if not os.path.exists(train_logdir):
        os.makedirs(train_logdir)
    if not os.path.exists(test_logdir):
        os.makedirs(test_logdir)

    train = pd.read_csv(flags.data_dir)

    num_train = train.shape[0]

    test = pd.read_csv(flags.test_dir)
    num_test = test.shape[0]

    tf.reset_default_graph()
    X = tf.placeholder(tf.float32, shape=[None, h, w, c_image], name='X')
    y = tf.placeholder(tf.float32, shape=[None, h, w, c_label], name='y')
    mode = tf.placeholder(tf.bool, name='mode')

    score_dsn6_up, score_dsn5_up, score_dsn4_up, score_dsn3_up, score_dsn2_up, score_dsn1_up, upscore_fuse = model.unet(
        X, mode)

    # print(score_dsn6_up.get_shape().as_list())

    loss6 = loss_CE(score_dsn6_up, y)
    loss5 = loss_CE(score_dsn5_up, y)
    loss4 = loss_CE(score_dsn4_up, y)
    loss3 = loss_CE(score_dsn3_up, y)
    loss2 = loss_CE(score_dsn2_up, y)
    loss1 = loss_CE(score_dsn1_up, y)
    loss_fuse = loss_CE(upscore_fuse, y)
    tf.summary.scalar("CE6", loss6)
    tf.summary.scalar("CE5", loss5)
    tf.summary.scalar("CE4", loss4)
    tf.summary.scalar("CE3", loss3)
    tf.summary.scalar("CE2", loss2)
    tf.summary.scalar("CE1", loss1)
    tf.summary.scalar("CE_fuse", loss_fuse)

    Loss = loss6 + loss5 + loss4 + loss3 + loss2 + 2 * loss1 + loss_fuse
    tf.summary.scalar("CE_total", Loss)

    global_step = tf.Variable(0, dtype=tf.int64, trainable=False, name='global_step')

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    learning_rate = tf.train.exponential_decay(flags.learning_rate, global_step,
                                               decay_steps=flags.decay_step,
                                               decay_rate=flags.decay_rate, staircase=True)

    with tf.control_dependencies(update_ops):
        training_op = train_op(Loss, learning_rate)

    # CSV files to read the training and validation data from
    train_csv = tf.train.string_input_producer(['plane.csv'])
    test_csv = tf.train.string_input_producer(['planetest.csv'])

    train_image, train_label = read_csv(train_csv, augmentation=True)
    test_image, test_label = read_csv(test_csv, augmentation=False)

    # batch_size is the number of samples per returned batch; capacity is the queue size
    X_train_batch_op, y_train_batch_op = tf.train.shuffle_batch([train_image, train_label], batch_size=flags.batch_size,
                                                                capacity=flags.batch_size * 5,
                                                                min_after_dequeue=flags.batch_size * 2,
                                                                allow_smaller_final_batch=True)

    X_test_batch_op, y_test_batch_op = tf.train.batch([test_image, test_label], batch_size=flags.batch_size,
                                                      capacity=flags.batch_size * 2, allow_smaller_final_batch=True)

    print('Shuffle batch done')
    # tf.summary.scalar('loss/Cross_entropy', CE_op)
    score_dsn6_up = tf.nn.sigmoid(score_dsn6_up)
    score_dsn5_up = tf.nn.sigmoid(score_dsn5_up)
    score_dsn4_up = tf.nn.sigmoid(score_dsn4_up)
    score_dsn3_up = tf.nn.sigmoid(score_dsn3_up)
    score_dsn2_up = tf.nn.sigmoid(score_dsn2_up)
    score_dsn1_up = tf.nn.sigmoid(score_dsn1_up)
    upscore_fuse = tf.nn.sigmoid(upscore_fuse)
    print(upscore_fuse.get_shape().as_list())

    tf.add_to_collection('inputs', X)
    tf.add_to_collection('inputs', mode)
    tf.add_to_collection('score_dsn6_up', score_dsn6_up)
    tf.add_to_collection('score_dsn5_up', score_dsn5_up)
    tf.add_to_collection('score_dsn4_up', score_dsn4_up)
    tf.add_to_collection('score_dsn3_up', score_dsn3_up)
    tf.add_to_collection('score_dsn2_up', score_dsn2_up)
    tf.add_to_collection('score_dsn1_up', score_dsn1_up)
    tf.add_to_collection('upscore_fuse', upscore_fuse)

    tf.summary.image('Input Image:', X)
    tf.summary.image('Label:', y)
    tf.summary.image('score_dsn6_up:', score_dsn6_up)
    tf.summary.image('score_dsn5_up:', score_dsn5_up)
    tf.summary.image('score_dsn4_up:', score_dsn4_up)
    tf.summary.image('score_dsn3_up:', score_dsn3_up)
    tf.summary.image('score_dsn2_up:', score_dsn2_up)
    tf.summary.image('score_dsn1_up:', score_dsn1_up)
    tf.summary.image('upscore_fuse:', upscore_fuse)

    tf.summary.scalar("learning_rate", learning_rate)

    # histogram summaries record the value distribution of these tensors
    tf.summary.histogram('score_dsn1_up:', score_dsn1_up)
    tf.summary.histogram('score_dsn2_up:', score_dsn2_up)
    tf.summary.histogram('score_dsn3_up:', score_dsn3_up)
    tf.summary.histogram('score_dsn4_up:', score_dsn4_up)
    tf.summary.histogram('score_dsn5_up:', score_dsn5_up)
    tf.summary.histogram('score_dsn6_up:', score_dsn6_up)
    tf.summary.histogram('upscore_fuse:', upscore_fuse)

    # merge all summary ops into a single op so each one does not have to be run by hand
    summary_op = tf.summary.merge_all()

    with tf.Session() as sess:
        train_writer = tf.summary.FileWriter(train_logdir, sess.graph)
        test_writer = tf.summary.FileWriter(test_logdir)

        init = tf.global_variables_initializer()
        sess.run(init)

        saver = tf.train.Saver()

        if flags.pretrained == 1:
            if os.path.exists(flags.model_dir) and tf.train.checkpoint_exists(flags.model_dir):
                latest_check_point = tf.train.latest_checkpoint(flags.model_dir)
                saver.restore(sess, latest_check_point)
                print('start with pre-trained model')
            else:
                print('no model')
        else:
            print('start without pre-trained model')
            if not os.path.exists(flags.model_dir):
                os.mkdir(flags.model_dir)

        try:
            # global_step = tf.train.get_global_step(sess.graph)

            # tf.train.string_input_producer registers QueueRunners in the graph by default;
            # they must be started with tf.train.start_queue_runners(sess=sess), otherwise
            # sess.run() will hang waiting on the queue. The Coordinator is used to
            # synchronize and cleanly stop the reader threads.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            for epoch in range(flags.epochs):
                for step in range(0, num_train, flags.batch_size):
                    X_train, y_train = sess.run([X_train_batch_op, y_train_batch_op])
                    _, step_ce, step_summary, global_step_value = sess.run([training_op, Loss, summary_op, global_step],
                                                                           feed_dict={X: X_train, y: y_train,
                                                                                      mode: True})
                    train_writer.add_summary(step_summary, global_step_value)
                    print('epoch:{} step:{} loss_CE:{}'.format(epoch + 1, global_step_value, step_ce))

                for step in range(0, num_test, flags.batch_size):
                    X_test, y_test = sess.run([X_test_batch_op, y_test_batch_op])
                    step_ce, step_summary = sess.run([Loss, summary_op], feed_dict={X: X_test, y: y_test, mode: False})

                    test_writer.add_summary(step_summary, epoch *
                                            (num_train // flags.batch_size)
                                            + step // flags.batch_size * num_train // num_test)
                    print('Test loss_CE:{}'.format(step_ce))
                saver.save(sess, '{}/model.ckpt'.format(flags.model_dir))

        finally:
            coord.request_stop()
            coord.join(threads)
            saver.save(sess, "{}/model.ckpt".format(flags.model_dir))


if __name__ == '__main__':
    # set_config()
    main(flags)
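
A typical way to launch training from the repository root (the flag values below are only examples; every flag corresponds to an argument defined above):

python train.py --data_dir ./plane.csv --test_dir ./planetest.csv --model_dir ./model1 --logdir ./logs1 --epochs 100 --batch_size 1 --gpu 0 --pretrained 0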

Saliency detection with the trained model

Under construction…
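
Until that part is written, here is a minimal inference sketch. It assumes a checkpoint saved by the training script above (model.ckpt and its .meta file in ./model1) and relies on the 'inputs' and 'upscore_fuse' collections that train.py registers; the test image path is only a placeholder:

# coding:utf-8
import cv2
import numpy as np
import tensorflow as tf

h, w = 300, 400  # must match the training resolution

with tf.Session() as sess:
    # Restore the graph structure and the weights saved by train.py
    saver = tf.train.import_meta_graph('./model1/model.ckpt.meta')
    saver.restore(sess, './model1/model.ckpt')

    # The placeholders and the fused saliency output were stored in collections
    X, mode = tf.get_collection('inputs')
    upscore_fuse = tf.get_collection('upscore_fuse')[0]

    # Load a test image (placeholder path) and resize it to the training size
    img = cv2.imread('./dataset/planetest/0001.png')
    img = cv2.resize(img, (w, h), interpolation=cv2.INTER_CUBIC).astype(np.float32)

    # Run the network; upscore_fuse already went through a sigmoid, so scale to 0-255
    pred = sess.run(upscore_fuse, feed_dict={X: img[np.newaxis, ...], mode: False})
    cv2.imwrite('saliency.png', (pred[0, :, :, 0] * 255).astype(np.uint8))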