
A Code Walkthrough of Remote Sensing Image Segmentation Based on SegNet and UNet

Date: 2020-05-02 05:15:45


Contents

Preface
Overview
Code Structure
Code Walkthrough
    Dataset generation: gen_dataset.py
    UNet model training: unet_train.py
    Mask fusion: combind.py
    UNet model prediction: unet_predict.py
    Ensembling classification results: ensemble.py
    SegNet model training: segnet_train.py

Preface

After a semester of classes, I used the winter break to read through past papers and some competition code, and I am now writing up what I learned. These notes focus on hands-on practice, mainly walking through and analyzing the code.

Overview

To start, here is a small project to share: an entry to a remote sensing image segmentation competition based on SegNet and UNet. The code comes from GitHub; this post is a brief walkthrough of the project.

Code Structure

The project contains four subdirectories: deprecated, ensemble, segnet, and unet. deprecated holds the author's draft code; ensemble integrates the classification results of the different models; segnet and unet each contain the network architecture, the training and prediction code, and the script that splits the data into training and test sets for the corresponding network.
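Roughly, the layout looks like this (a sketch inferred from the description above and the scripts discussed below; the exact file placement in the repository may differ):

    project/
        deprecated/    draft code
        ensemble/      ensemble.py: voting over the models' classification results
        segnet/        segnet_train.py plus prediction and dataset-split scripts
        unet/          gen_dataset.py, unet_train.py, unet_predict.py, combind.py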

Code Walkthrough

Dataset generation: gen_dataset.py

import cv2
import random
import os
import numpy as np
from tqdm import tqdm

img_w = 256
img_h = 256

# the dataset consists of 5 images in total
image_sets = ['1.png', '2.png', '3.png', '4.png', '5.png']

def gamma_transform(img, gamma):
    gamma_table = [np.power(x / 255.0, gamma) * 255.0 for x in range(256)]
    gamma_table = np.round(np.array(gamma_table)).astype(np.uint8)
    # LUT (Look-Up Table): remapping pixel values through a LUT changes the image's exposure and color
    return cv2.LUT(img, gamma_table)

def random_gamma_transform(img, gamma_vari):
    log_gamma_vari = np.log(gamma_vari)
    alpha = np.random.uniform(-log_gamma_vari, log_gamma_vari)
    gamma = np.exp(alpha)
    return gamma_transform(img, gamma)

# rotate the image
def rotate(xb, yb, angle):
    M_rotate = cv2.getRotationMatrix2D((img_w / 2, img_h / 2), angle, 1)
    xb = cv2.warpAffine(xb, M_rotate, (img_w, img_h))
    yb = cv2.warpAffine(yb, M_rotate, (img_w, img_h))
    return xb, yb

def blur(img):
    # cv2.blur(img, (size, size)) smooths img with a size x size mean filter
    img = cv2.blur(img, (3, 3))
    return img

# add noise
def add_noise(img):
    for i in range(200):  # add salt noise
        temp_x = np.random.randint(0, img.shape[0])
        temp_y = np.random.randint(0, img.shape[1])
        img[temp_x][temp_y] = 255
    return img

# data augmentation: rotation, gamma transform, blurring, and noise
def data_augment(xb, yb):
    if np.random.random() < 0.25:
        xb, yb = rotate(xb, yb, 90)
    if np.random.random() < 0.25:
        xb, yb = rotate(xb, yb, 180)
    if np.random.random() < 0.25:
        xb, yb = rotate(xb, yb, 270)
    if np.random.random() < 0.25:
        xb = cv2.flip(xb, 1)  # flipCode > 0: flip around the y-axis
        yb = cv2.flip(yb, 1)
    if np.random.random() < 0.25:
        xb = random_gamma_transform(xb, 1.0)
    if np.random.random() < 0.25:
        xb = blur(xb)
    if np.random.random() < 0.2:
        xb = add_noise(xb)
    return xb, yb

# build the dataset
def creat_dataset(image_num=50000, mode='original'):
    print('creating dataset...')
    # len(image_sets) = 5
    image_each = image_num / len(image_sets)
    g_count = 0
    for i in tqdm(range(len(image_sets))):
        count = 0
        # read the source image and the label image
        src_img = cv2.imread('./data/src/' + image_sets[i])  # 3 channels
        label_img = cv2.imread('./data/road_label/' + image_sets[i], cv2.IMREAD_GRAYSCALE)  # single channel
        X_height, X_width, _ = src_img.shape
        while count < image_each:
            # img_w = img_h = 256
            random_width = random.randint(0, X_width - img_w - 1)
            random_height = random.randint(0, X_height - img_h - 1)
            # randomly crop an img_h x img_w patch
            src_roi = src_img[random_height: random_height + img_h, random_width: random_width + img_w, :]
            label_roi = label_img[random_height: random_height + img_h, random_width: random_width + img_w]
            # in augment mode, apply data augmentation to both the source and the label patch
            if mode == 'augment':
                src_roi, label_roi = data_augment(src_roi, label_roi)
            visualize = np.zeros((256, 256)).astype(np.uint8)
            visualize = label_roi * 50
            # write out the patches
            cv2.imwrite(('./unet_train/visualize/%d.png' % g_count), visualize)
            cv2.imwrite(('./unet_train/road/src/%d.png' % g_count), src_roi)
            cv2.imwrite(('./unet_train/road/label/%d.png' % g_count), label_roi)
            count += 1
            g_count += 1

if __name__ == '__main__':
    creat_dataset(mode='augment')
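As a quick sanity check on these augmentation helpers, the following standalone sketch (synthetic arrays instead of real imagery, with gamma_transform redefined locally so the snippet runs on its own) shows the gamma LUT shifting overall brightness and the rotation warping image and label with the same transform so they stay aligned:

    import cv2
    import numpy as np

    img_w = img_h = 256

    def gamma_transform(img, gamma):
        table = np.round(np.array([(x / 255.0) ** gamma * 255.0
                                   for x in range(256)])).astype(np.uint8)
        return cv2.LUT(img, table)

    # a random uint8 "image" and "label" stand in for a real crop
    xb = np.random.randint(0, 256, (img_h, img_w, 3), dtype=np.uint8)
    yb = np.random.randint(0, 2, (img_h, img_w), dtype=np.uint8)

    dark = gamma_transform(xb, 2.0)    # gamma > 1 darkens
    bright = gamma_transform(xb, 0.5)  # gamma < 1 brightens
    print(xb.mean(), dark.mean(), bright.mean())

    # the same affine warp is applied to image and label
    M = cv2.getRotationMatrix2D((img_w / 2, img_h / 2), 90, 1)
    xb_rot = cv2.warpAffine(xb, M, (img_w, img_h))
    yb_rot = cv2.warpAffine(yb, M, (img_w, img_h))
    print(xb_rot.shape, yb_rot.shape)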

UNet model training: unet_train.py

#coding=utf-8
import matplotlib
# matplotlib.use('Agg') must come before "import matplotlib.pyplot as plt";
# it disables the interactive backend so figures are only saved, never shown
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import argparse
import numpy as np
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, BatchNormalization, Reshape, Permute, Activation, Input
from keras.utils.np_utils import to_categorical
from keras.preprocessing.image import img_to_array
from keras.callbacks import ModelCheckpoint
from sklearn.preprocessing import LabelEncoder
from keras.models import Model
from keras.layers.merge import concatenate
from PIL import Image
import cv2
import random
import os
from tqdm import tqdm

os.environ["CUDA_VISIBLE_DEVICES"] = "4"

# fix the random seed so runs are reproducible and results can be compared on the same data
seed = 7
np.random.seed(seed)

#data_shape = 360*480
img_w = 256
img_h = 256
# one of the classes is background
#n_label = 4+1
n_label = 1

# 5 classes in total
classes = [0., 1., 2., 3., 4.]
labelencoder = LabelEncoder()
labelencoder.fit(classes)

image_sets = ['1.png', '2.png', '3.png']

def load_img(path, grayscale=False):
    if grayscale:
        # cv2.IMREAD_GRAYSCALE reads the image as a single-channel grayscale image;
        # otherwise cv2.imread reads it as a 3-channel color image
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    else:
        img = cv2.imread(path)
    # normalize to [0, 1]
    img = np.array(img, dtype="float") / 255.0
    return img

# training data path
filepath = './unet_train/'

# split into training and validation sets; 25% of the data is used for validation
def get_train_val(val_rate=0.25):
    train_url = []
    train_set = []
    val_set = []
    for pic in os.listdir(filepath + 'src'):
        train_url.append(pic)
    random.shuffle(train_url)
    total_num = len(train_url)
    val_num = int(val_rate * total_num)
    # after shuffling, the first 25% becomes the validation set, the remaining 75% the training set
    for i in range(len(train_url)):
        if i < val_num:
            val_set.append(train_url[i])
        else:
            train_set.append(train_url[i])
    return train_set, val_set

# data for training
def generateData(batch_size, data=[]):
    while True:
        train_data = []
        train_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            train_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label)
            train_label.append(label)
            if batch % batch_size == 0:
                train_data = np.array(train_data)
                train_label = np.array(train_label)
                yield (train_data, train_label)
                train_data = []
                train_label = []
                batch = 0

# data for validation
def generateValidData(batch_size, data=[]):
    while True:
        valid_data = []
        valid_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            valid_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label)
            valid_label.append(label)
            if batch % batch_size == 0:
                valid_data = np.array(valid_data)
                valid_label = np.array(valid_label)
                yield (valid_data, valid_label)
                valid_data = []
                valid_label = []
                batch = 0

# define the UNet; overall it is a symmetric U-shaped structure
def unet():
    inputs = Input((3, img_w, img_h))

    conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(inputs)
    conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(pool1)
    conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(pool2)
    conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(pool3)
    conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(pool4)
    conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(conv5)
    # note: no further pooling here; conv5 stays at 16x16 so the upsampled
    # feature map below matches conv4 (32x32), as the printed summary confirms

    # upsampling enlarges the feature map by simple repetition.
    # UpSampling2D(size=size)(x) repeats the rows and columns of x size[0] and size[1] times.
    # For example, with size = (2, 2), [[1,2],[3,4]] becomes [[1,1,2,2],[1,1,2,2],[3,3,4,4],[3,3,4,4]]
    up6 = concatenate([UpSampling2D(size=(2, 2))(conv5), conv4], axis=1)
    conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(up6)
    conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv6)

    up7 = concatenate([UpSampling2D(size=(2, 2))(conv6), conv3], axis=1)
    conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(up7)
    conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv7)

    up8 = concatenate([UpSampling2D(size=(2, 2))(conv7), conv2], axis=1)
    conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(up8)
    conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv8)

    up9 = concatenate([UpSampling2D(size=(2, 2))(conv8), conv1], axis=1)
    conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(up9)
    conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv9)

    conv10 = Conv2D(n_label, (1, 1), activation="sigmoid")(conv9)
    #conv10 = Conv2D(n_label, (1, 1), activation="softmax")(conv9)

    model = Model(inputs=inputs, outputs=conv10)
    # use binary cross-entropy for the two-class case; plain (categorical) cross-entropy
    # would work too, since the multi-class loss also covers the binary case
    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

def train(args):
    EPOCHS = 10
    BS = 16  # batch size
    #model = SegNet()
    model = unet()
    modelcheck = ModelCheckpoint(args['model'], monitor='val_accuracy', save_best_only=True, mode='max')
    callable = [modelcheck]
    train_set, val_set = get_train_val()
    train_numb = len(train_set)
    valid_numb = len(val_set)
    print("the number of train data is", train_numb)
    print("the number of val data is", valid_numb)
    # max_q_size sets the maximum size of the internal training queue
    H = model.fit_generator(generator=generateData(BS, train_set), steps_per_epoch=train_numb // BS,
                            epochs=EPOCHS, verbose=1,
                            validation_data=generateValidData(BS, val_set), validation_steps=valid_numb // BS,
                            callbacks=callable, max_q_size=1)

    # plot the training loss and accuracy
    # plt.style.use('ggplot') applies the ggplot style to the figures.
    # The available styles (plt.style.available) are:
    # ['bmh', 'classic', 'dark_background', 'fast', 'fivethirtyeight', 'ggplot', 'grayscale',
    #  'seaborn-bright', 'seaborn-colorblind', 'seaborn-dark-palette', 'seaborn-dark',
    #  'seaborn-darkgrid', 'seaborn-deep', 'seaborn-muted', 'seaborn-notebook', 'seaborn-paper',
    #  'seaborn-pastel', 'seaborn-poster', 'seaborn-talk', 'seaborn-ticks', 'seaborn-white',
    #  'seaborn-whitegrid', 'seaborn', 'Solarize_Light2', 'tableau-colorblind10', '_classic_test']
    plt.style.use("ggplot")
    plt.figure()
    N = EPOCHS
    plt.plot(np.arange(0, N), H.history["loss"], label="train_loss")
    plt.plot(np.arange(0, N), H.history["val_loss"], label="val_loss")
    plt.plot(np.arange(0, N), H.history["acc"], label="train_acc")
    plt.plot(np.arange(0, N), H.history["val_acc"], label="val_acc")
    plt.title("Training Loss and Accuracy on U-Net Satellite Seg")
    plt.xlabel("Epoch #")
    plt.ylabel("Loss/Accuracy")
    # place the legend in the lower-left corner
    plt.legend(loc="lower left")
    plt.savefig(args["plot"])

# command-line arguments: help text and defaults
def args_parse():
    # construct the argument parser and parse the arguments
    ap = argparse.ArgumentParser()
    ap.add_argument("-d", "--data", help="training data's path", default=True)
    ap.add_argument("-m", "--model", required=True, help="path to output model")
    ap.add_argument("-p", "--plot", type=str, default="plot.png", help="path to output accuracy/loss plot")
    args = vars(ap.parse_args())
    return args

if __name__ == '__main__':
    args = args_parse()
    filepath = args['data']
    train(args)
    #predict()

To see the exact shape of each layer's input and output tensors in the UNet, we print the model summary:

__________________________________________________________________________________________________
Layer (type)                    Output Shape          Param #   Connected to
==================================================================================================
input_7 (InputLayer)            (None, 3, 256, 256)   0
conv2d_79 (Conv2D)              (None, 32, 256, 256)  896       input_7[0][0]
conv2d_80 (Conv2D)              (None, 32, 256, 256)  9248      conv2d_79[0][0]
max_pooling2d_29 (MaxPooling2D) (None, 32, 128, 128)  0         conv2d_80[0][0]
conv2d_81 (Conv2D)              (None, 64, 128, 128)  18496     max_pooling2d_29[0][0]
conv2d_82 (Conv2D)              (None, 64, 128, 128)  36928     conv2d_81[0][0]
max_pooling2d_30 (MaxPooling2D) (None, 64, 64, 64)    0         conv2d_82[0][0]
conv2d_83 (Conv2D)              (None, 128, 64, 64)   73856     max_pooling2d_30[0][0]
conv2d_84 (Conv2D)              (None, 128, 64, 64)   147584    conv2d_83[0][0]
max_pooling2d_31 (MaxPooling2D) (None, 128, 32, 32)   0         conv2d_84[0][0]
conv2d_85 (Conv2D)              (None, 256, 32, 32)   295168    max_pooling2d_31[0][0]
conv2d_86 (Conv2D)              (None, 256, 32, 32)   590080    conv2d_85[0][0]
max_pooling2d_32 (MaxPooling2D) (None, 256, 16, 16)   0         conv2d_86[0][0]
conv2d_87 (Conv2D)              (None, 512, 16, 16)   1180160   max_pooling2d_32[0][0]
conv2d_88 (Conv2D)              (None, 512, 16, 16)   2359808   conv2d_87[0][0]
up_sampling2d_13 (UpSampling2D) (None, 512, 32, 32)   0         conv2d_88[0][0]
concatenate_13 (Concatenate)    (None, 768, 32, 32)   0         up_sampling2d_13[0][0], conv2d_86[0][0]
conv2d_89 (Conv2D)              (None, 256, 32, 32)   1769728   concatenate_13[0][0]
conv2d_90 (Conv2D)              (None, 256, 32, 32)   590080    conv2d_89[0][0]
up_sampling2d_14 (UpSampling2D) (None, 256, 64, 64)   0         conv2d_90[0][0]
concatenate_14 (Concatenate)    (None, 384, 64, 64)   0         up_sampling2d_14[0][0], conv2d_84[0][0]
conv2d_91 (Conv2D)              (None, 128, 64, 64)   442496    concatenate_14[0][0]
conv2d_92 (Conv2D)              (None, 128, 64, 64)   147584    conv2d_91[0][0]
up_sampling2d_15 (UpSampling2D) (None, 128, 128, 128) 0         conv2d_92[0][0]
concatenate_15 (Concatenate)    (None, 192, 128, 128) 0         up_sampling2d_15[0][0], conv2d_82[0][0]
conv2d_93 (Conv2D)              (None, 64, 128, 128)  110656    concatenate_15[0][0]
conv2d_94 (Conv2D)              (None, 64, 128, 128)  36928     conv2d_93[0][0]
up_sampling2d_16 (UpSampling2D) (None, 64, 256, 256)  0         conv2d_94[0][0]
concatenate_16 (Concatenate)    (None, 96, 256, 256)  0         up_sampling2d_16[0][0], conv2d_80[0][0]
conv2d_95 (Conv2D)              (None, 32, 256, 256)  27680     concatenate_16[0][0]
conv2d_96 (Conv2D)              (None, 32, 256, 256)  9248      conv2d_95[0][0]
conv2d_97 (Conv2D)              (None, 1, 256, 256)   33        conv2d_96[0][0]
==================================================================================================
Total params: 7,846,657
Trainable params: 7,846,657
Non-trainable params: 0
__________________________________________________________________________________________________
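For reference, a summary like the one above can presumably be reproduced with something like the following. The (None, 3, 256, 256) input shape implies channels-first tensor ordering, which must be set explicitly when the backend defaults to channels-last; this snippet assumes the unet() builder from unet_train.py is in scope.

    from keras import backend as K

    # the (None, 3, 256, 256) shapes only come out under channels-first ordering
    K.set_image_data_format('channels_first')

    model = unet()   # the builder defined in unet_train.py
    model.summary()  # prints the layer / output shape / parameter table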

Mask fusion: combind.py

#coding=utf-8
import numpy as np
import cv2
import csv
from tqdm import tqdm

# the three groups of per-class masks
mask1_pool = ['testing1_vegetation_predict.png', 'testing1_building_predict.png',
              'testing1_water_predict.png', 'testing1_road_predict.png']
mask2_pool = ['testing2_vegetation_predict.png', 'testing2_building_predict.png',
              'testing2_water_predict.png', 'testing2_road_predict.png']
mask3_pool = ['testing3_vegetation_predict.png', 'testing3_building_predict.png',
              'testing3_water_predict.png', 'testing3_road_predict.png']

# 0: none  1: vegetation  2: building  3: water  4: road
# output images after mask fusion
img_sets = ['pre1.png', 'pre2.png', 'pre3.png']

def combind_all_mask():
    for mask_num in tqdm(range(3)):
        if mask_num == 0:
            # start from an all-black (all-zero) image with the same size as the original
            final_mask = np.zeros((5142, 5664), np.uint8)
        elif mask_num == 1:
            final_mask = np.zeros((2470, 4011), np.uint8)
        elif mask_num == 2:
            final_mask = np.zeros((6116, 3356), np.uint8)
        #final_mask = cv2.imread('final_1_8bits_predict.png', 0)
        if mask_num == 0:
            mask_pool = mask1_pool
        elif mask_num == 1:
            mask_pool = mask2_pool
        elif mask_num == 2:
            mask_pool = mask3_pool
        final_name = img_sets[mask_num]
        for idx, name in enumerate(mask_pool):
            img = cv2.imread('./predict_mask/' + name, 0)
            height, width = img.shape
            label_value = idx + 1  # corresponding label value
            for i in tqdm(range(height)):  # priority: building > water > road > vegetation
                for j in range(width):
                    # fuse the masks
                    if img[i, j] == 255:
                        # If this pixel is white in several binary masks, which class does it
                        # belong to? Resolve by priority: building > water > road > vegetation
                        if label_value == 2:
                            final_mask[i, j] = label_value
                        elif label_value == 3 and final_mask[i, j] != 2:
                            final_mask[i, j] = label_value
                        elif label_value == 4 and final_mask[i, j] != 2 and final_mask[i, j] != 3:
                            final_mask[i, j] = label_value
                        elif label_value == 1 and final_mask[i, j] == 0:
                            final_mask[i, j] = label_value
        cv2.imwrite('./final_result/' + final_name, final_mask)

print('combinding mask...')
combind_all_mask()
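The triple loop above visits every pixel in Python, which is very slow on masks of roughly 5000 x 6000 pixels. The same priority rule (building > water > road > vegetation) can be expressed with NumPy boolean indexing; below is a minimal sketch, assuming the same 255-valued binary masks, where writing the classes from lowest to highest priority lets later assignments overwrite earlier ones:

    import numpy as np

    def combine_masks(vegetation, building, water, road):
        # merge four binary masks (255 = class present) into one label map
        final = np.zeros(vegetation.shape, np.uint8)
        # lowest priority first, so higher-priority writes win
        final[vegetation == 255] = 1  # vegetation
        final[road == 255] = 4        # road
        final[water == 255] = 3       # water
        final[building == 255] = 2    # building
        return final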

UNet model prediction: unet_predict.py

import cv2
import random
import numpy as np
import os
import argparse
from keras.preprocessing.image import img_to_array
from keras.models import load_model
from sklearn.preprocessing import LabelEncoder

# run on the GPU with index 1
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

TEST_SET = ['1.png', '2.png', '3.png']
image_size = 256

classes = [0., 1., 2., 3., 4.]
labelencoder = LabelEncoder()
labelencoder.fit(classes)

def args_parse():
    # construct the argument parser and parse the arguments
    ap = argparse.ArgumentParser()
    ap.add_argument("-m", "--model", required=True, help="path to trained model")
    ap.add_argument("-s", "--stride", required=False, help="crop slide stride", type=int, default=image_size)
    args = vars(ap.parse_args())
    return args

def predict(args):
    # load the trained convolutional neural network
    print("[INFO] loading network...")
    model = load_model(args["model"])
    stride = args['stride']
    for n in range(len(TEST_SET)):
        path = TEST_SET[n]
        # load the test image
        image = cv2.imread('./test/' + path)
        h, w, _ = image.shape
        # How to predict on a large image? The network was trained on 256x256 inputs,
        # so at test time it is also fed 256x256 crops. First zero-pad the original
        # image so that the padded size is exactly divisible by 256.
        padding_h = (h // stride + 1) * stride
        padding_w = (w // stride + 1) * stride
        padding_img = np.zeros((padding_h, padding_w, 3), dtype=np.uint8)
        # the part beyond the original image stays zero
        padding_img[0:h, 0:w, :] = image[:, :, :]
        #padding_img = padding_img.astype("float") / 255.0
        padding_img = img_to_array(padding_img)
        print('src:', padding_img.shape)
        mask_whole = np.zeros((padding_h, padding_w), dtype=np.uint8)
        for i in range(padding_h // stride):
            for j in range(padding_w // stride):
                # take the crop at the corresponding position in the padded image
                crop = padding_img[:3, i*stride:i*stride+image_size, j*stride:j*stride+image_size]
                _, ch, cw = crop.shape
                if ch != 256 or cw != 256:
                    print('invalid size!')
                    continue
                crop = np.expand_dims(crop, axis=0)
                # In fit, verbose=0 is silent, verbose=1 shows a progress bar and
                # verbose=2 prints one line per epoch; in evaluate/predict,
                # verbose=0 is silent and verbose=1 shows a progress bar.
                pred = model.predict(crop, verbose=2)
                #print(np.unique(pred))
                pred = pred.reshape((256, 256)).astype(np.uint8)
                #print('pred:', pred.shape)
                mask_whole[i*stride:i*stride+image_size, j*stride:j*stride+image_size] = pred[:, :]
        # crop the stitched mask back to the original image size
        cv2.imwrite('./predict/pre' + str(n+1) + '.png', mask_whole[0:h, 0:w])

if __name__ == '__main__':
    args = args_parse()
    predict(args)
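The padding arithmetic deserves a concrete example. For a hypothetical 700 x 1000 test image with stride = 256:

    h, w, stride = 700, 1000, 256
    padding_h = (h // stride + 1) * stride   # (2 + 1) * 256 = 768
    padding_w = (w // stride + 1) * stride   # (3 + 1) * 256 = 1024
    print(padding_h, padding_w)              # 768 1024, both divisible by 256

One quirk worth noting: when h or w is already an exact multiple of the stride, this formula still adds one extra row or column of all-zero tiles. That is harmless, since the stitched mask is cropped back to (h, w) at the end.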

Ensembling classification results: ensemble.py

import numpy as np
import cv2
import argparse

RESULT_PREFIXX = ['./result1/', './result2/', './result3/']

# each mask has 5 classes: 0~4
def vote_per_image(image_id):
    result_list = []
    for j in range(len(RESULT_PREFIXX)):
        im = cv2.imread(RESULT_PREFIXX[j] + str(image_id) + '.png', 0)
        result_list.append(im)
    # vote pixel by pixel
    height, width = result_list[0].shape
    vote_mask = np.zeros((height, width))
    for h in range(height):
        for w in range(width):
            # one vote counter per pixel: 5 possible classes, so record is a 1x5 array
            record = np.zeros((1, 5))
            # for each model's prediction, tally the class vote at this position
            for n in range(len(result_list)):
                mask = result_list[n]
                pixel = mask[h, w]
                #print('pix:', pixel)
                record[0, pixel] += 1
            # ensemble: the class with the most votes becomes the final label
            label = record.argmax()
            #print(label)
            vote_mask[h, w] = label
    cv2.imwrite('vote_mask' + str(image_id) + '.png', vote_mask)

# three result images in total
vote_per_image(3)
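Because the vote runs per pixel in Python, it is also slow on large masks. Below is a vectorized sketch of the same majority vote (counting votes per class with boolean planes; ties resolve to the smallest class index, just like record.argmax() above):

    import numpy as np

    def vote_masks(masks, n_classes=5):
        # pixel-wise majority vote over label maps with values 0..n_classes-1
        stack = np.stack(masks, axis=0)  # (n_models, H, W)
        votes = np.stack([(stack == c).sum(axis=0) for c in range(n_classes)], axis=0)
        return votes.argmax(axis=0).astype(np.uint8)

    # example with three 2x2 "predictions"
    a = np.array([[0, 1], [2, 2]], np.uint8)
    b = np.array([[0, 1], [3, 2]], np.uint8)
    c = np.array([[4, 1], [3, 2]], np.uint8)
    print(vote_masks([a, b, c]))
    # [[0 1]
    #  [3 2]]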

SegNet model training: segnet_train.py

#coding=utf-8
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import argparse
import numpy as np
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, BatchNormalization, Reshape, Permute, Activation
from keras.utils.np_utils import to_categorical
from keras.preprocessing.image import img_to_array
from keras.callbacks import ModelCheckpoint
from sklearn.preprocessing import LabelEncoder
from PIL import Image
import cv2
import random
import os
from tqdm import tqdm

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

seed = 7
np.random.seed(seed)

#data_shape = 360*480
img_w = 256
img_h = 256
# one of the classes is background
n_label = 4 + 1

classes = [0., 1., 2., 3., 4.]
labelencoder = LabelEncoder()
labelencoder.fit(classes)

image_sets = ['1.png', '2.png', '3.png']

def load_img(path, grayscale=False):
    if grayscale:
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    else:
        img = cv2.imread(path)
    img = np.array(img, dtype="float") / 255.0
    return img

filepath = './train/'

def get_train_val(val_rate=0.25):
    train_url = []
    train_set = []
    val_set = []
    for pic in os.listdir(filepath + 'src'):
        train_url.append(pic)
    random.shuffle(train_url)
    total_num = len(train_url)
    val_num = int(val_rate * total_num)
    for i in range(len(train_url)):
        if i < val_num:
            val_set.append(train_url[i])
        else:
            train_set.append(train_url[i])
    return train_set, val_set

# data for training
def generateData(batch_size, data=[]):
    while True:
        train_data = []
        train_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            train_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label).reshape((img_w * img_h,))
            train_label.append(label)
            if batch % batch_size == 0:
                train_data = np.array(train_data)
                train_label = np.array(train_label).flatten()
                train_label = labelencoder.transform(train_label)
                train_label = to_categorical(train_label, num_classes=n_label)
                train_label = train_label.reshape((batch_size, img_w * img_h, n_label))
                yield (train_data, train_label)
                train_data = []
                train_label = []
                batch = 0

# data for validation
def generateValidData(batch_size, data=[]):
    while True:
        valid_data = []
        valid_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            valid_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label).reshape((img_w * img_h,))
            valid_label.append(label)
            if batch % batch_size == 0:
                valid_data = np.array(valid_data)
                valid_label = np.array(valid_label).flatten()
                valid_label = labelencoder.transform(valid_label)
                valid_label = to_categorical(valid_label, num_classes=n_label)
                valid_label = valid_label.reshape((batch_size, img_w * img_h, n_label))
                yield (valid_data, valid_label)
                valid_data = []
                valid_label = []
                batch = 0

def SegNet():
    model = Sequential()
    # encoder
    model.add(Conv2D(64, (3, 3), strides=(1, 1), input_shape=(3, img_w, img_h), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(64, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2), dim_ordering='th'))
    # (128, 128)
    model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2), dim_ordering='th'))
    # (64, 64)
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2), dim_ordering='th'))
    # (32, 32)
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2), dim_ordering='th'))
    # (16, 16)
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2), dim_ordering='th'))
    # (8, 8)
    # decoder
    model.add(UpSampling2D(size=(2, 2)))
    # (16, 16)
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(UpSampling2D(size=(2, 2)))
    # (32, 32)
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(UpSampling2D(size=(2, 2)))
    # (64, 64)
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(UpSampling2D(size=(2, 2)))
    # (128, 128)
    model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(UpSampling2D(size=(2, 2)))
    # (256, 256)
    model.add(Conv2D(64, (3, 3), strides=(1, 1), input_shape=(3, img_w, img_h), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(64, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(n_label, (1, 1), strides=(1, 1), padding='same'))
    model.add(Reshape((n_label, img_w * img_h)))
    # swap axis 1 and axis 2, equivalent to np.swapaxes(layer, 1, 2)
    model.add(Permute((2, 1)))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
    return model

def train(args):
    EPOCHS = 30
    BS = 16
    model = SegNet()
    modelcheck = ModelCheckpoint(args['model'], monitor='val_acc', save_best_only=True, mode='max')
    callable = [modelcheck]
    train_set, val_set = get_train_val()
    train_numb = len(train_set)
    valid_numb = len(val_set)
    print("the number of train data is", train_numb)
    print("the number of val data is", valid_numb)
    H = model.fit_generator(generator=generateData(BS, train_set), steps_per_epoch=train_numb // BS,
                            epochs=EPOCHS, verbose=1,
                            validation_data=generateValidData(BS, val_set), validation_steps=valid_numb // BS,
                            callbacks=callable, max_q_size=1)

    # plot the training loss and accuracy
    plt.style.use("ggplot")
    plt.figure()
    N = EPOCHS
    plt.plot(np.arange(0, N), H.history["loss"], label="train_loss")
    plt.plot(np.arange(0, N), H.history["val_loss"], label="val_loss")
    plt.plot(np.arange(0, N), H.history["acc"], label="train_acc")
    plt.plot(np.arange(0, N), H.history["val_acc"], label="val_acc")
    plt.title("Training Loss and Accuracy on SegNet Satellite Seg")
    plt.xlabel("Epoch #")
    plt.ylabel("Loss/Accuracy")
    plt.legend(loc="lower left")
    plt.savefig(args["plot"])

def args_parse():
    # construct the argument parser and parse the arguments
    ap = argparse.ArgumentParser()
    ap.add_argument("-a", "--augment", help="using data augment or not", action="store_true", default=False)
    ap.add_argument("-m", "--model", required=True, help="path to output model")
    ap.add_argument("-p", "--plot", type=str, default="plot.png", help="path to output accuracy/loss plot")
    args = vars(ap.parse_args())
    return args

if __name__ == '__main__':
    args = args_parse()
    if args['augment'] == True:
        filepath = './aug/train/'
    train(args)
    #predict()
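The subtle part of segnet_train.py is the label pipeline: each 256 x 256 label image is flattened, one-hot encoded, and reshaped to (batch, img_w * img_h, n_label) so that it lines up with the (None, 65536, n_label) softmax output produced by the Reshape/Permute tail of the network. A scaled-down illustration (4 x 4 labels, 3 classes, NumPy only, with np.eye standing in for to_categorical):

    import numpy as np

    img_w = img_h = 4   # scaled-down stand-ins for 256
    n_label = 3
    batch_size = 2

    # two fake integer label images with values in {0, 1, 2}
    labels = np.random.randint(0, n_label, (batch_size, img_h, img_w))

    flat = labels.reshape(batch_size * img_w * img_h)  # flatten the whole batch
    one_hot = np.eye(n_label)[flat]                    # stand-in for to_categorical
    train_label = one_hot.reshape(batch_size, img_w * img_h, n_label)
    print(train_label.shape)  # (2, 16, 3): matches the (None, H*W, n_label) softmax output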

Likewise, to see what each layer's input and output tensors look like in SegNet, we print the shapes:

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_98 (Conv2D)           (None, 64, 256, 256)      1792
batch_normalization_1 (Batch (None, 64, 256, 256)      1024
conv2d_99 (Conv2D)           (None, 64, 256, 256)      36928
batch_normalization_2 (Batch (None, 64, 256, 256)      1024
max_pooling2d_33 (MaxPooling (None, 64, 128, 128)      0
conv2d_100 (Conv2D)          (None, 128, 128, 128)     73856
batch_normalization_3 (Batch (None, 128, 128, 128)     512
conv2d_101 (Conv2D)          (None, 128, 128, 128)     147584
batch_normalization_4 (Batch (None, 128, 128, 128)     512
max_pooling2d_34 (MaxPooling (None, 128, 64, 64)       0
conv2d_102 (Conv2D)          (None, 256, 64, 64)       295168
batch_normalization_5 (Batch (None, 256, 64, 64)       256
conv2d_103 (Conv2D)          (None, 256, 64, 64)       590080
batch_normalization_6 (Batch (None, 256, 64, 64)       256
conv2d_104 (Conv2D)          (None, 256, 64, 64)       590080
batch_normalization_7 (Batch (None, 256, 64, 64)       256
max_pooling2d_35 (MaxPooling (None, 256, 32, 32)       0
conv2d_105 (Conv2D)          (None, 512, 32, 32)       1180160
batch_normalization_8 (Batch (None, 512, 32, 32)       128
conv2d_106 (Conv2D)          (None, 512, 32, 32)       2359808
batch_normalization_9 (Batch (None, 512, 32, 32)       128
conv2d_107 (Conv2D)          (None, 512, 32, 32)       2359808
batch_normalization_10 (Batc (None, 512, 32, 32)       128
max_pooling2d_36 (MaxPooling (None, 512, 16, 16)       0
conv2d_108 (Conv2D)          (None, 512, 16, 16)       2359808
batch_normalization_11 (Batc (None, 512, 16, 16)       64
conv2d_109 (Conv2D)          (None, 512, 16, 16)       2359808
batch_normalization_12 (Batc (None, 512, 16, 16)       64
conv2d_110 (Conv2D)          (None, 512, 16, 16)       2359808
batch_normalization_13 (Batc (None, 512, 16, 16)       64
max_pooling2d_37 (MaxPooling (None, 512, 8, 8)         0
up_sampling2d_17 (UpSampling (None, 512, 16, 16)       0
conv2d_111 (Conv2D)          (None, 512, 16, 16)       2359808
batch_normalization_14 (Batc (None, 512, 16, 16)       64
conv2d_112 (Conv2D)          (None, 512, 16, 16)       2359808
batch_normalization_15 (Batc (None, 512, 16, 16)       64
conv2d_113 (Conv2D)          (None, 512, 16, 16)       2359808
batch_normalization_16 (Batc (None, 512, 16, 16)       64
up_sampling2d_18 (UpSampling (None, 512, 32, 32)       0
conv2d_114 (Conv2D)          (None, 512, 32, 32)       2359808
batch_normalization_17 (Batc (None, 512, 32, 32)       128
conv2d_115 (Conv2D)          (None, 512, 32, 32)       2359808
batch_normalization_18 (Batc (None, 512, 32, 32)       128
conv2d_116 (Conv2D)          (None, 512, 32, 32)       2359808
batch_normalization_19 (Batc (None, 512, 32, 32)       128
up_sampling2d_19 (UpSampling (None, 512, 64, 64)       0
conv2d_117 (Conv2D)          (None, 256, 64, 64)       1179904
batch_normalization_20 (Batc (None, 256, 64, 64)       256
conv2d_118 (Conv2D)          (None, 256, 64, 64)       590080
batch_normalization_21 (Batc (None, 256, 64, 64)       256
conv2d_119 (Conv2D)          (None, 256, 64, 64)       590080
batch_normalization_22 (Batc (None, 256, 64, 64)       256
up_sampling2d_20 (UpSampling (None, 256, 128, 128)     0
conv2d_120 (Conv2D)          (None, 128, 128, 128)     295040
batch_normalization_23 (Batc (None, 128, 128, 128)     512
conv2d_121 (Conv2D)          (None, 128, 128, 128)     147584
batch_normalization_24 (Batc (None, 128, 128, 128)     512
up_sampling2d_21 (UpSampling (None, 128, 256, 256)     0
conv2d_122 (Conv2D)          (None, 64, 256, 256)      73792
batch_normalization_25 (Batc (None, 64, 256, 256)      1024
conv2d_123 (Conv2D)          (None, 64, 256, 256)      36928
batch_normalization_26 (Batc (None, 64, 256, 256)      1024
conv2d_124 (Conv2D)          (None, 1, 256, 256)       65
reshape_1 (Reshape)          (None, 1, 65536)          0
permute_1 (Permute)          (None, 65536, 1)          0
activation_1 (Activation)    (None, 65536, 1)          0
=================================================================
Total params: 31,795,841
Trainable params: 31,791,425
Non-trainable params: 4,416
_________________________________________________________________
