点击下方卡片,关注“新机器视觉”公众号
视觉/图像重磅干货,第一时间送达
本期以landsat提取植被为例,更新如何用python实现随机森林遥感图像分类,原作者王振庆@知乎,仅用于学术分享
随机森林,顾名思义是用随机的方式建立一个森林,森林里面有很多的决策树组成,随机森林的每一棵决策树之间是没有关联的。
假如有N个样本,则有放回的随机选择N个样本(每次随机选择一个样本,然后返回继续选择)。这选择好了的N个样本用来训练一个决策树,作为决策树根节点处的样本。当每个样本有M个属性时,在决策树的每个节点需要分裂时,随机从这M个属性中选取出m个属性,满足条件m << M。然后从这m个属性中采用某种策略(比如说信息增益)来选择1个属性作为该节点的分裂属性。决策树形成过程中每个节点都要按照步骤2来分裂(很容易理解,如果下一次该节点选出来的那一个属性是刚刚其父节点分裂时用过的属性,则该节点已经达到了叶子节点,无须继续分裂了)。一直到不能够再分裂为止。注意整个决策树形成过程中没有进行剪枝。
以landsat提取植被为例,其实不论什么影像分什么类,操作都是一样的。
(1) 制作样本
a. 数字矢量化样本标签图
随机森林属于监督分类,监督分类是一定需要样本的。我们在Arcgis(ENVI也可)中目视解译矢量化一些植被与非植被的典型样本,然后【要素转栅格】将矢量数据转为栅格标签图。其中要注意:植被与非植被的值要设置为不同;转栅格的范围要与遥感图像一致。这样做的目的是为了方便抓取与标签图对应位置的遥感图像各波段值。
图1是landset的真彩色图像,图2是数字化样本并转成栅格的标签图像,标签图为单波段灰度图,为了更好地展示,我进行了RGB渲染。其中绿色的为植被样本,紫色的为非植被样本。
b. 样本数据集制作
样本数据集为txt,格式如图3所示。每行的前7个数为landset的7个波段值,Vegetation和Non-Vegetation表示该数据为植被还是非植被。具体制作过程直接上代码,注释很详细。
图3 样本数据集示意图
import gdal
import os
import random
def readTif(fileName):
dataset = gdal.Open(fileName)
if dataset == None:
print(fileName+"文件无法打开")
return dataset
Landset_Path = r"D:\ROI.tif"
LabelPath = r"D:\label.tif"
txt_Path = r"D:\data.txt"
dataset = readTif(Landset_Path)
Tif_width = dataset.RasterXSize
Tif_height = dataset.RasterYSize
Tif_bands = dataset.RasterCount
Tif_geotrans = dataset.GetGeoTransform()
Landset_data = dataset.ReadAsArray(0,0,Tif_width,Tif_height)
dataset = readTif(LabelPath)
Label_data = dataset.ReadAsArray(0,0,Tif_width,Tif_height)
if os.path.exists(txt_Path):
os.remove(txt_Path)
file_write_obj = open(txt_Path, 'w')
count = 0
for i in range(Label_data.shape[0]):
for j in range(Label_data.shape[1]):
if(Label_data[i][j] == 1):
var = ""
for k in range(Landset_data.shape[0]):
var = var + str(Landset_data[k][i][j])+","
var = var + "Vegetation"
file_write_obj.writelines(var)
file_write_obj.write('\n')
count = count + 1
Threshold = count
count = 0
for i in range(10000000000):
X_random = random.randint(0,Label_data.shape[0]-1)
Y_random = random.randint(0,Label_data.shape[1]-1)
if(Label_data[X_random][Y_random] == 0):
var = ""
for k in range(Landset_data.shape[0]):
var = var + str(Landset_data[k][X_random][Y_random])+","
var = var + "Non-Vegetation"
file_write_obj.writelines(var)
file_write_obj.write('\n')
count = count + 1
if(count == Threshold):
break
file_write_obj.close()
(2) 模型训练
随机森林模型我们采用sklearn库中自带的随机森林模型RandomForestClassifier。具体训练过程直接上代码,注释很详细。
from sklearn.ensemble import RandomForestClassifier
import numpy as np
from sklearn import model_selection
import pickle
def Iris_label(s):
it={b'Vegetation':0, b'Non-Vegetation':1}
return it[s]
path=r"D:\data.txt"
SavePath = r"D:\model.pickle"
data=np.loadtxt(path, dtype=float, delimiter=',', converters={7:Iris_label} )
x,y=np.split(data,indices_or_sections=(7,),axis=1)
x=x[:,0:7]
train_data,test_data,train_label,test_label = model_selection.train_test_split(x,y, random_state=1, train_size=0.9,test_size=0.1)
classifier = RandomForestClassifier(n_estimators=100,
bootstrap = True,
max_features = 'sqrt')
classifier.fit(train_data, train_label.ravel())
print("训练集:",classifier.score(train_data,train_label))
print("测试集:",classifier.score(test_data,test_label))
file = open(SavePath, "wb")
pickle.dump(classifier, file)
file.close()
(3) 模型预测
训练好了模型,就该进行我们遥感图像的预测了。具体预测过程依旧直接上代码,注释很详细。import numpy as np
import gdal
import pickle
def readTif(fileName):
dataset = gdal.Open(fileName)
if dataset == None:
print(fileName+"文件无法打开")
return dataset
def writeTiff(im_data,im_geotrans,im_proj,path):
if 'int8' in im_data.dtype.name:
datatype = gdal.GDT_Byte
elif 'int16' in im_data.dtype.name:
datatype = gdal.GDT_UInt16
else:
datatype = gdal.GDT_Float32
if len(im_data.shape) == 3:
im_bands, im_height, im_width = im_data.shape
elif len(im_data.shape) == 2:
im_data = np.array([im_data])
im_bands, im_height, im_width = im_data.shape
driver = gdal.GetDriverByName("GTiff")
dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
if(dataset!= None):
dataset.SetGeoTransform(im_geotrans)
dataset.SetProjection(im_proj)
for i in range(im_bands):
dataset.GetRasterBand(i+1).WriteArray(im_data[i])
del dataset
RFpath = r"D:\model.pickle"
Landset_Path = r"D:\20130514_ROI.tif"
SavePath = r"D:\save.tif"
dataset = readTif(Landset_Path)
Tif_width = dataset.RasterXSize
Tif_height = dataset.RasterYSize
Tif_geotrans = dataset.GetGeoTransform()
Tif_proj = dataset.GetProjection()
Landset_data = dataset.ReadAsArray(0,0,Tif_width,Tif_height)
file = open(RFpath, "rb")
rf_model = pickle.load(file)
file.close()
data = np.zeros((Landset_data.shape[0],Landset_data.shape[1]*Landset_data.shape[2]))
for i in range(Landset_data.shape[0]):
data[i] = Landset_data[i].flatten()
data = data.swapaxes(0,1)
pred = rf_model.predict(data)
pred = pred.reshape(Landset_data.shape[1],Landset_data.shape[2])*255
pred = pred.astype(np.uint8)
writeTiff(pred,Tif_geotrans,Tif_proj,SavePath)
—版权声明—
仅用于学术分享,版权属于原作者。
若有侵权,请联系微信号:yiyang-sy 删除或修改!