Py学习  »  Python

【教程】Python实现随机森林遥感图像分类

新机器视觉 • 2 年前 • 597 次点击  

点击下方卡片,关注“新机器视觉”公众号

视觉/图像重磅干货,第一时间送达

本期以landsat提取植被为例,更新如何用python实现随机森林遥感图像分类,原作者王振庆@知乎,仅用于学术分享


1
随机森林(RandomForest)

随机森林,顾名思义是用随机的方式建立一个森林,森林里面有很多的决策树组成,随机森林的每一棵决策树之间是没有关联的。

随机森林的构造过程:
假如有N个样本,则有放回的随机选择N个样本(每次随机选择一个样本,然后返回继续选择)。这选择好了的N个样本用来训练一个决策树,作为决策树根节点处的样本。
当每个样本有M个属性时,在决策树的每个节点需要分裂时,随机从这M个属性中选取出m个属性,满足条件m << M。然后从这m个属性中采用某种策略(比如说信息增益)来选择1个属性作为该节点的分裂属性。
决策树形成过程中每个节点都要按照步骤2来分裂(很容易理解,如果下一次该节点选出来的那一个属性是刚刚其父节点分裂时用过的属性,则该节点已经达到了叶子节点,无须继续分裂了)。一直到不能够再分裂为止。注意整个决策树形成过程中没有进行剪枝。
2
随机森林遥感图像分类


以landsat提取植被为例,其实不论什么影像分什么类,操作都是一样的。

(1) 制作样本

a. 数字矢量化样本标签图

随机森林属于监督分类,监督分类是一定需要样本的。我们在Arcgis(ENVI也可)中目视解译矢量化一些植被与非植被的典型样本,然后【要素转栅格】将矢量数据转为栅格标签图。其中要注意:植被与非植被的值要设置为不同;转栅格的范围要与遥感图像一致。这样做的目的是为了方便抓取与标签图对应位置的遥感图像各波段值。

图1 Lanset真彩色图像

图2 栅格标签图

图1是landset的真彩色图像,图2是数字化样本并转成栅格的标签图像,标签图为单波段灰度图,为了更好地展示,我进行了RGB渲染。其中绿色的为植被样本,紫色的为非植被样本。

b. 样本数据集制作

样本数据集为txt,格式如图3所示。每行的前7个数为landset的7个波段值,Vegetation和Non-Vegetation表示该数据为植被还是非植被。具体制作过程直接上代码,注释很详细。

图3 样本数据集示意图

import gdalimport osimport random#读取tif数据集def readTif(fileName):    dataset = gdal.Open(fileName)    if dataset == None:        print(fileName+"文件无法打开")    return datasetLandset_Path = r"D:\ROI.tif"LabelPath = r"D:\label.tif"txt_Path = r"D:\data.txt"# 读取图像数据dataset = readTif(Landset_Path)Tif_width = dataset.RasterXSize #栅格矩阵的列数


    
Tif_height = dataset.RasterYSize #栅格矩阵的行数Tif_bands = dataset.RasterCount #波段数Tif_geotrans = dataset.GetGeoTransform()#获取仿射矩阵信息Landset_data = dataset.ReadAsArray(0,0,Tif_width,Tif_height)dataset = readTif(LabelPath)Label_data = dataset.ReadAsArray(0,0,Tif_width,Tif_height)# 写之前,先检验文件是否存在,存在就删掉if os.path.exists(txt_Path):    os.remove(txt_Path)# 以写的方式打开文件,如果文件不存在,就会自动创建file_write_obj = open(txt_Path, 'w')#首先收集植被类别样本,#遍历所有像素值,#为植被的像元全部收集。count = 0for i in range(Label_data.shape[0]):    for j in range(Label_data.shape[1]):        #  我设置的植被类别在标签图中像元值为1        if(Label_data[i][j] == 1):            var = ""            for k in range(Landset_data.shape[0]):                var = var + str(Landset_data[k][i][j])+","            var = var + "Vegetation"            file_write_obj.writelines(var)            file_write_obj.write('\n')            count = count + 1


    
#其次收集非植被类别样本,#因为非植被样本比植被样本多很多,#所以采用在所有非植被类别中随机选择非植被样本,#数量与植被样本数量保持一致。Threshold = countcount = 0for i in range(10000000000):    X_random = random.randint(0,Label_data.shape[0]-1)    Y_random = random.randint(0,Label_data.shape[1]-1)    #  我设置的非植被类别在标签图中像元值为0    if(Label_data[X_random][Y_random] == 0):        var = ""        for k in range(Landset_data.shape[0]):            var = var + str(Landset_data[k][X_random][Y_random])+","        var = var + "Non-Vegetation"        file_write_obj.writelines(var)        file_write_obj.write('\n')        count = count + 1    if(count == Threshold):        breakfile_write_obj.close()

(2) 模型训练

随机森林模型我们采用sklearn库中自带的随机森林模型RandomForestClassifier。具体训练过程直接上代码,注释很详细。

from sklearn.ensemble import RandomForestClassifierimport numpy as npfrom sklearn import model_selection


    
import pickle #  定义字典,便于来解析样本数据集txtdef Iris_label(s):    it={b'Vegetation':0, b'Non-Vegetation':1}    return it[s]path=r"D:\data.txt"SavePath = r"D:\model.pickle"#  1.读取数据集data=np.loadtxt(path, dtype=float, delimiter=',', converters={7:Iris_label} )#  converters={7:Iris_label}中“7”指的是第8列:将第8列的str转化为label(number)#  2.划分数据与标签x,y=np.split(data,indices_or_sections=(7,),axis=1) #x为数据,y为标签x=x[:,0:7] #选取前7个波段作为特征train_data,test_data,train_label,test_label = model_selection.train_test_split(x,y, random_state=1, train_size=0.9,test_size=0.1)#  3.用100个树来创建随机森林模型,训练随机森林classifier = RandomForestClassifier(n_estimators=100,                                bootstrap = True,                               max_features = 'sqrt')classifier.fit(train_data, train_label.ravel())#ravel函数拉伸到一维#  4.计算随机森林的准确率print("训练集:",classifier.score(train_data,train_label))print("测试集:",classifier.score(test_data,test_label))#  5.保存模型#以二进制的方式打开文件:


    
file = open(SavePath, "wb")#将模型写入文件:pickle.dump(classifier, file)#最后关闭文件:file.close()

(3) 模型预测

训练好了模型,就该进行我们遥感图像的预测了。具体预测过程依旧直接上代码,注释很详细。
import numpy as npimport gdalimport pickle #读取tif数据集def readTif(fileName):    dataset = gdal.Open(fileName)    if dataset == None:        print(fileName+"文件无法打开")    return dataset#保存tif文件函数def writeTiff(im_data,im_geotrans,im_proj,path):    if 'int8' in im_data.dtype.name:        datatype = gdal.GDT_Byte    elif 'int16' in im_data.dtype.name:        datatype = gdal.GDT_UInt16    else:        datatype = gdal.GDT_Float32    if len(im_data.shape) == 3:        im_bands, im_height, im_width = im_data.shape


    
    elif len(im_data.shape) == 2:        im_data = np.array([im_data])        im_bands, im_height, im_width = im_data.shape    #创建文件    driver = gdal.GetDriverByName("GTiff")    dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)    if(dataset!= None):        dataset.SetGeoTransform(im_geotrans) #写入仿射变换参数        dataset.SetProjection(im_proj) #写入投影    for i in range(im_bands):        dataset.GetRasterBand(i+1).WriteArray(im_data[i])    del dataset   RFpath = r"D:\model.pickle"Landset_Path = r"D:\20130514_ROI.tif"SavePath = r"D:\save.tif"dataset = readTif(Landset_Path)Tif_width = dataset.RasterXSize #栅格矩阵的列数Tif_height = dataset.RasterYSize #栅格矩阵的行数Tif_geotrans = dataset.GetGeoTransform()#获取仿射矩阵信息Tif_proj = dataset.GetProjection()#获取投影信息Landset_data = dataset.ReadAsArray(0,0,Tif_width,Tif_height)#调用保存好的模型#以读二进制的方式打开文件file = open(RFpath, "rb")#把模型从文件中读取出来rf_model = pickle.load(file)#关闭文件file.close()


    
#用读入的模型进行预测#  在与测试前要调整一下数据的格式data = np.zeros((Landset_data.shape[0],Landset_data.shape[1]*Landset_data.shape[2]))for i in range(Landset_data.shape[0]):    data[i] = Landset_data[i].flatten() data = data.swapaxes(0,1)#  对调整好格式的数据进行预测pred = rf_model.predict(data)#  同样地,我们对预测好的数据调整为我们图像的格式pred = pred.reshape(Landset_data.shape[1],Landset_data.shape[2])*255pred = pred.astype(np.uint8)#  将结果写到tif图像里writeTiff(pred,Tif_geotrans,Tif_proj,SavePath)
预测结果如下:
图4 RF预测结果结果图像


—版权声明—

仅用于学术分享,版权属于原作者。

若有侵权,请联系微信号:yiyang-sy 删除或修改!


—THE END—
Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/121921
 
597 次点击