700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > unet图片数据增强_numpy实现深度学习遥感图像语义分割数据增强(支持多波段)

unet图片数据增强_numpy实现深度学习遥感图像语义分割数据增强(支持多波段)

时间:2023-11-11 02:22:46

相关推荐

unet图片数据增强_numpy实现深度学习遥感图像语义分割数据增强(支持多波段)

前言

数据增强是指对训练样本数据进行某种变换操作,从而生成新数据的过程。数据增强的根本目的是得到充足的样本数据量,避免模型训练过程中产生过拟合现象。

正文

对于遥感影像来说,由于成像过程传感器对同一地物在不同角度拍摄会在影像上展现出不同的位置和形态,所以经过变换的样本可以使模型更好地学习地物的旋转不变的特征,从而更好地适应不同形态的图像。因此我们对训练数据进行几何变换(包括水平翻转、竖直翻转以及对角翻转)的数据增强操作。原图像水平翻转垂直翻转对角翻转

代码实现:

因为如果多波段的话,利用OpenCV对图像进行翻转就会报错,所以我们使用numpy进行翻转的实现。

import gdal

import numpy as np

import os

import cv2

# 读取tif数据集

def readTif(fileName, xoff = 0, yoff = 0, data_width = 0, data_height = 0):

dataset = gdal.Open(fileName)

if dataset == None:

print(fileName + "文件无法打开")

# 栅格矩阵的列数

width = dataset.RasterXSize

# 栅格矩阵的行数

height = dataset.RasterYSize

# 波段数

bands = dataset.RasterCount

# 获取数据

if(data_width == 0 and data_height == 0):

data_width = width

data_height = height

data = dataset.ReadAsArray(xoff, yoff, data_width, data_height)

# 获取仿射矩阵信息

geotrans = dataset.GetGeoTransform()

# 获取投影信息

proj = dataset.GetProjection()

return width, height, bands, data, geotrans, proj

# 保存tif文件函数

def writeTiff(im_data, im_geotrans, im_proj, path):

if 'int8' in im_data.dtype.name:

datatype = gdal.GDT_Byte

elif 'int16' in im_data.dtype.name:

datatype = gdal.GDT_UInt16

else:

datatype = gdal.GDT_Float32

if len(im_data.shape) == 3:

im_bands, im_height, im_width = im_data.shape

elif len(im_data.shape) == 2:

im_data = np.array([im_data])

im_bands, im_height, im_width = im_data.shape

#创建文件

driver = gdal.GetDriverByName("GTiff")

dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)

if(dataset!= None):

dataset.SetGeoTransform(im_geotrans) #写入仿射变换参数

dataset.SetProjection(im_proj) #写入投影

for i in range(im_bands):

dataset.GetRasterBand(i + 1).WriteArray(im_data[i])

del dataset

train_image_path = r"Data\train\image"

train_label_path = r"Data\train\label"

# 进行几何变换数据增强

imageList = os.listdir(train_image_path)

labelList = os.listdir(train_label_path)

tran_num = len(imageList) + 1

for i in range(len(imageList)):

# 图像

img_file = train_image_path + "\\" + imageList[i]

im_width, im_height, im_bands, im_data, im_geotrans, im_proj = readTif(img_file)

# 标签

label_file = train_label_path + "\\" + labelList[i]

label = cv2.imread(label_file)

# 图像水平翻转

im_data_hor = np.flip(im_data, axis = 2)

hor_path = train_image_path + "\\" + str(tran_num) + imageList[i][-4:]

writeTiff(im_data_hor, im_geotrans, im_proj, hor_path)

# 标签水平翻转

Hor = cv2.flip(label, 1)

hor_path = train_label_path + "\\" + str(tran_num) + labelList[i][-4:]

cv2.imwrite(hor_path, Hor)

tran_num += 1

# 图像垂直翻转

im_data_vec = np.flip(im_data, axis = 1)

vec_path = train_image_path + "\\" + str(tran_num) + imageList[i][-4:]

writeTiff(im_data_vec, im_geotrans, im_proj, vec_path)

# 标签垂直翻转

Vec = cv2.flip(label, 0)

vec_path = train_label_path + "\\" + str(tran_num) + labelList[i][-4:]

cv2.imwrite(vec_path, Vec)

tran_num += 1

# 图像对角镜像

im_data_dia = np.flip(im_data_vec, axis = 2)

dia_path = train_image_path + "\\" + str(tran_num) + imageList[i][-4:]

writeTiff(im_data_dia, im_geotrans, im_proj, dia_path)

# 标签对角镜像

Dia = cv2.flip(label, -1)

dia_path = train_label_path + "\\" + str(tran_num) + labelList[i][-4:]

cv2.imwrite(dia_path, Dia)

tran_num += 1

后记

有问题欢迎留言评论,觉得不错可以动动手指点个赞同&喜欢

我的其他文字:馨意:keras遥感图像Unet语义分割(支持多波段&多类)​馨意:遥感大图像深度学习忽略边缘(划窗)预测​

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。