博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
语义分割模型的优化
阅读量:4142 次
发布时间:2019-05-25

本文共 9851 字,大约阅读时间需要 32 分钟。

语义分割模型的优化

当发现验证集指标和训练集指标相差较大时,主要可以检查这些原因:

  1. 数据集类别芜杂、数据量不够,先检查数据集和数据迭代器的质量。
  2. 如果训练集较快拟合,则模型过于庞大,降低了鲁棒性,可以降低batchsize或减少层数、卷积核数量。
  3. 若训练集指标也不正常,将学习率调整到[1e-2, 1e-6],各测试一遍。若loss曲线仍有问题,调整损失函数和激活函数。再不行,换个网络和显卡试试吧。

目录

常用优化技巧

  1. 数据增强

    使用 albumentations,随机旋转、镜像、模糊、色彩映射、噪音、曝光、翻转、色差平移。

  2. 模型的投票

    融合多个模型进行集成,采用多个模型对各个像素点进行投票,对于每张图片在不同scale(不同放大比例、旋转参数)下的结果融合。

  3. TTA (Test-Time Augmentation) ,即测试时的数据增

  4. 强。

    mask_xor = (mask^bg)&mask

  5. 联合损失函数BCE+Dice+Focal+lovasz_softmax

    将不同损失函数以一定权重混合,获得较好的鲁棒性。

  6. CRF后处理

    SegNet做语义分割时通常在末端加入CRF模块做后处理,旨在进一步精修边缘的分割结果。

关于学习率的优化

  1. Warm up 预热
    采用小学习率预训练。当训练时出现训练指标上升、验证集指标不动的奇怪现象时,此方法极其有效。预训练后再加大LR。
  2. 按val_miou减小
    使用keras.callbacks.ReduceLROnPlateau,根据val_miou的变化来动态调整lr。
  3. 余弦退火

损失函数和激活函数、通用评估指标

损失函数

损失函数选用keras.losses.binary_crossentropy

激活函数

单通道输出(二分类,[batch_size, 512, 512 ,1]),选用Sigmoid。

多通道输出(多分类,[batch_size, 512, 512 , N]),选用Softmax,其中第一层是背景层。

通用评估指标 IOU 代码

from tensorflow.keras import backend as Kdef Iou_score(y_true, y_pred):    '''总体的IOU'''    smooth = 1e-5    threhold = 0.5    # score calculation    y_pred = K.greater(y_pred, threhold)    y_pred = K.cast(y_pred, K.floatx())    intersection = K.sum(y_true * y_pred, axis=[0,1,2])    '''    这里y_pred为四维,[16,512,512,2],    axis=[0,1,2]时输出的intersection是每个类别的得分(准确的数量,[20, 30])。    [16,512,512,1]时候也通用。    '''    union = K.sum(y_true + y_pred, axis=[0,1,2]) - intersection    return (intersection + smooth) / (union + smooth)

实现数据增强的语义分割数据迭代器代码

from tensorflow.keras.utils import Sequenceimport random, os, gc, cv2import numpy as npseed = 295random.seed(seed)import osimport cv2 as cvimport albumentations as A # pip install albumentations -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.comclass SequenceData(Sequence):    def __init__(self, images_dir_X, images_dir_Y, img_size=(256,256),imgOnly=False,                 batch_size=1, classes=None, imgListTxt=None, isOneHot=True, dataEnhancement=True):        with open(imgListTxt) as f:            # 为了对应样本和标签 读入的文件名不带后缀            self.datas = list(f.readlines())        self.images_dir_X = images_dir_X        self.images_dir_Y = images_dir_Y        self.batch_size = batch_size        self.L = len(self.datas)        self.img_size = img_size        self.index = random.sample(range(self.L), self.L)        self.classes=classes        self.isOneHot=isOneHot        self.imgOnly=imgOnly                self.dataEnhancement=dataEnhancement        prob = 0.4        self.transform = A.Compose([            # A.RandomCrop(width=256, height=256),            A.HorizontalFlip(p=0.5),            A.RandomBrightnessContrast(brightness_limit=(0, 0.2),                                       contrast_limit=(0, 0.2), p=prob),            A.Rotate(limit=30, interpolation=cv2.INTER_CUBIC, border_mode=4, p=prob),            A.RandomGamma(gamma_limit=(80, 120), eps=1e-07, p=prob),            A.MotionBlur(blur_limit=5, p=prob),            A.IAASharpen(p=prob),            A.IAAPerspective(p=prob),            A.GaussNoise(var_limit=(10.0, 50.0), p=prob),                        A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=110,                                          val_shift_limit=10, p=prob),            A.RGBShift(r_shift_limit=5, g_shift_limit=5, b_shift_limit=5, p=prob),            A.OpticalDistortion(distort_limit=0.05, shift_limit=0.05, interpolation=2, p=prob),            A.GridDistortion(num_steps=5, distort_limit=0.3, interpolation=2, p=prob),        ])                   # 返回长度,通过len(
<你的实例>
)调用 def __len__(self): return int(np.ceil(len(self.datas) / self.batch_size)) # 通过索引获取a[0],a[1]这种 def __getitem__(self, idx): batch_indexs = self.index[idx:(idx+self.batch_size)] batch_datas = [self.datas[k] for k in batch_indexs] images = self.load_image_from_directory( images_dir=self.images_dir_X, img_size=self.img_size, suffix='.jpg', imgList=batch_datas) if self.imgOnly: return images labels = self.load_image_from_directory( images_dir=self.images_dir_Y, img_size=self.img_size, isLabel=True, isOneHot=self.isOneHot, classes=self.classes, suffix='.png', imgList=batch_datas) if self.dataEnhancement: images,labels = self.dataEnhancing(images,labels) return images,labels def load_image_from_directory(self, images_dir, img_size, isLabel=False, classes=None,isOneHot=True, suffix='.png', dtype=np.float32, imgList=None): """ 从数据集目录中加载图像数组 :param images_dir: 数据集目录,原图或标签,该文件夹下直接是图像,三通道,24位深度 :param img_size: 网络要求的图像大小 :param isLabel: 是否为标签, 测试集的标签由于只是显示作用,不参与训练,因此保持默认False即可 :param classes: 类,与isLabel相匹配,isLabel=True时,必须赋值 :param suffix: 图像后缀 :param dtype: 图像数据类型 :param imgList: 图像文件名列表 :return: 返回图像数组,[图像个数,高,宽,通道数] """ images_path = [] for fname in imgList: # if fname.endswith(suffix) and not fname.startswith('.'): # 数组填充,将图像绝对路径添加至数组images_path末尾 images_path.append(os.path.join(images_dir, fname[:-1]+suffix)) images_path = sorted(images_path) # 按顺序整理 # print(len(images_path)) images = [] # 创建空数组 for i, path in enumerate(images_path): img = cv.imdecode(np.fromfile(path, dtype=np.uint8), cv.IMREAD_COLOR) # 可为中文路径读取图像 # img = cv.imread(path) # 读取图像, 无中文路径 if img.shape[:2] is not img_size: # img = cv.resize(img, dsize=img_size, # interpolation=cv.INTER_CUBIC) # 不满足条件重置图像尺寸 img = self.resize_img_keep_ratio(img=img, target_size=img_size) img = img[:, :, ::-1] # 交换图像通道 if isLabel: if isOneHot: # 创建空数组用于储存图像数组 newImg = np.zeros(img.shape[:2] + (len(classes),), dtype=dtype) for j, value in enumerate(classes.values()): newImg[np.bitwise_and(np.bitwise_and( img[:, :, 0] == value[0], img[:, :, 1] == value[1]), img[:, :, 2] == value[2]), j] = 1 # 把标签转为one-hot形式 img = newImg else: img = img[:,:,0] / 255.0 img[img>=0.5] = 1.0 img[img<0.5] = 0.0 img = np.expand_dims(img, axis=-1) # 扩充维度 images.append(img) # images.append(self.resize_img_keep_ratio(img=img, target_size=img_size)) images = np.array(images, dtype=dtype) # 改变数组类型 if not isLabel: images = images/255.0 # print(images.shape) return images def resize_img_keep_ratio(self, img=None,img_name=None,target_size=(256,256)): ''' 1.resize图片,先计算最长边的resize的比例,然后按照该比例resize。 2.计算四个边需要padding的像素宽度,然后padding ''' if img is None: img = cv2.imread(img_name) old_size= img.shape[0:2] ratio = min(float(target_size[i])/(old_size[i]) for i in range(len(old_size))) new_size = tuple([int(i*ratio) for i in old_size]) img = cv2.resize(img,(new_size[1], new_size[0]),interpolation=cv2.INTER_CUBIC) #注意插值算法 pad_w = target_size[1] - new_size[1] pad_h = target_size[0] - new_size[0] top,bottom = pad_h//2, pad_h-(pad_h//2) left,right = pad_w//2, pad_w -(pad_w//2) img_new = cv2.copyMakeBorder(img,top,bottom,left,right,cv2.BORDER_CONSTANT,None,(0,0,0)) # return cv2.cvtColor(img_new, cv2.COLOR_BGR2RGB) return img_new def dataEnhancing(self,images,labels,dtype=np.float32): '''数据增强''' transformed_images,transformed_masks=[],[] # for image,mask in zip(images,labels): # print(images.shape[0]) for i in range(images.shape[0]): image=np.asarray(images[i,...]*255.0).astype(np.uint8) mask=labels[i,...] if(mask.shape[-1]==1): mask=mask.reshape(mask.shape[:-1]) transformed = self.transform(image=image, mask=mask) outImg = np.asarray(transformed["image"]).astype(dtype)/255.0 transformed_images.append(transformed["image"]) transformed_masks.append(transformed["mask"]) transformed_images=np.asarray(transformed_images,dtype=dtype) transformed_masks=np.asarray(transformed_masks,dtype=dtype) # print(transformed_images.shape,transformed_masks.shape) return transformed_images,transformed_masks # from tensorflow.keras.preprocessing.image import array_to_img# def showImg(frame):# array_to_img(frame).show()# classes = dict(# [('background', [0, 0, 0]), ('object', [255, 255, 255])]) # trainSDG = SequenceData(r".\ourData\combine\images\1",# r"ourData\combine\labels\1",# img_size=(512,512), classes=classes, batch_size=1,# imgListTxt=r".\ourData\combine\sets\train.txt", dataEnhancement=False)# xx=trainSDG.__getitem__(9)# x0=np.array(xx[0])[0,...]*255.0# x1=np.array(xx[1])[0,...]# showImg(np.expand_dims(x1[:,:,0],axis=-1))# showImg(np.expand_dims(x1[:,:,1],axis=-1))# showImg(x0)# from tensorflow.keras.preprocessing.image import array_to_img# def showImg(frame):# array_to_img(frame).show()# classes = dict(# [('background', [0, 0, 0]), ('object', [255, 255, 255])]) # trainSDG = SequenceData(r".\ourData\combine\images\1",# r"ourData\combine\labels\1", isOneHot=False,# img_size=(512,512), classes=classes, batch_size=2,# imgListTxt=r".\ourData\combine\sets\train.txt", dataEnhancement=False)# xx=trainSDG.__getitem__(5)# x0=np.array(xx[0])[0,...]*255.0# x1=np.array(xx[1])[0,...]# showImg(x1[:,:])# showImg(x0)

转载地址:http://vrzti.baihongyu.com/

你可能感兴趣的文章
两个linux内核rootkit--之二:adore-ng
查看>>
两个linux内核rootkit--之一:enyelkm
查看>>
关于linux栈的一个深层次的问题
查看>>
rootkit related
查看>>
配置文件的重要性------轻化操作
查看>>
又是缓存惹的祸!!!
查看>>
为什么要实现程序指令和程序数据的分离?
查看>>
我对C++ string和length方法的一个长期误解------从protobuf序列化说起(没处理好会引起数据丢失、反序列化失败哦!)
查看>>
一起来看看protobuf中容易引起bug的一个细节
查看>>
无protobuf协议情况下的反序列化------貌似无解, 其实有解!
查看>>
make -n(仅列出命令, 但不会执行)用于调试makefile
查看>>
makefile中“-“符号的使用
查看>>
go语言如何从终端逐行读取数据?------用bufio包
查看>>
go的值类型和引用类型------重要的概念
查看>>
求二叉树中结点的最大值(所有结点的值都是正整数)
查看>>
用go的flag包来解析命令行参数
查看>>
来玩下go的http get
查看>>
感受一下go协程goroutine------协程在手,说go就go
查看>>
队列和栈的本质区别
查看>>
matlab中inline的用法
查看>>