社区所有版块导航
Python
python开源   Django   Python   DjangoApp   pycharm  
DATA
docker   Elasticsearch  
aigc
aigc   chatgpt  
WEB开发
linux   MongoDB   Redis   DATABASE   NGINX   其他Web框架   web工具   zookeeper   tornado   NoSql   Bootstrap   js   peewee   Git   bottle   IE   MQ   Jquery  
机器学习
机器学习算法  
Python88.com
反馈   公告   社区推广  
产品
短视频  
印度
印度  
Py学习  »  机器学习算法

【深度学习】PyTorch如何构建图像数据的测试集?(附代码)

机器学习初学者 • 2 年前 • 545 次点击  

如何构建图像数据的测试集?假设图片已经按照类别分好文件夹,从中间取出一定比例的图片做测试,图片类别标签保存在csv文件。

需求

在每个子文件夹中选取一定数量的图片,移动到test文件夹中,并在test文件夹中进行乱序并重新编号;

需要记录答案,即test中图片的原始对应关系,以的格式,保存在csv中。

代码构建

import osimport mathimport randomimport csv
ORI_DIR = './flower_data_new/'TEST_DIR = 'test'TEST_RATE = 0.4SUFFIXES = '.jpg'HEADERS = ['id', 'category']# 计算需要移动到测试集中的图像数量categoryNumList = []categoryNames = os.listdir(ORI_DIR)for categoryName in categoryNames: categoryDir = os.path.join(ORI_DIR, categoryName) categoryFiles = [ file for file in os.listdir(categoryDir) if file.endswith(SUFFIXES) ] categoryNum = len(categoryFiles) categoryNumList.append(categoryNum)selectedNumList = [math.floor(num * TEST_RATE) for num in categoryNumList]testTotalNum = sum(selectedNumList)# 构造乱序数组randomList = [i for i in range(testTotalNum)]for i in range(testTotalNum): rand = random.randint(0, testTotalNum - 1) randomList[i], randomList[rand] = randomList[rand], randomList[i]# 创建测试集文件夹if not os.path.exists(TEST_DIR): os.mkdir(TEST_DIR)# 取每类的最后部分,移动到test文件夹,并记录每张图片的id和类别index = 0rows = []for (i, num) in enumerate(selectedNumList): categoryName = categoryNames[i] destDirPath = os.path.join(ORI_DIR, categoryName) files = [ file for file in os.listdir(destDirPath) if file.endswith(SUFFIXES) ] for name in files[len(files) - num:]: oriPath = os.path.join(destDirPath, name) destName = str(randomList[index]) + SUFFIXES destPath = os.path.join(TEST_DIR, destName) rows.append([randomList[index], i]) os.rename(oriPath, destPath) index += 1# 先将rows排序,再存入csv rows = sorted(rows, key=lambda x: x[0])with open('answer.csv', 'w', newline='') as f: f_csv = csv.writer(f) f_csv.writerow(HEADERS) f_csv.writerows(rows)

此代码假设您的训练文件夹名为 train,测试文件夹名为 test。代码将使用 ImageFolder 获取所有图像路径和对应的标签,并计算需要移动到测试集中的图像数量。然后,它会随机选择需要移动到测试集中的图像,并将它们复制到测试集文件夹中。

在最后一步中,代码将用CSV文件记录每个图像的文件名和标签。可以像使用 answer.csv 文件记录文件名和标签一样将其更改为您想要的格式。

请注意,此代码中使用了 random.sample 方法来随机选择要移动到测试集中的图像。它确保了选中的图像是独立且随机的。



Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/153339
 
545 次点击