Custom labels in Python Keras ImageDataGenerator

Cubilla • 5 years ago • 1494 views

I am currently building a CNN model to classify whether a font is Arial, Verdana, Times New Roman or Georgia. That makes 16 classes in total, because I also want to detect whether the style is regular, bold, italic or bold italic: 4 fonts * 4 styles = 16 classes.
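One way to picture the 16 classes is as every (font, style) pair. The sketch below enumerates them with `itertools.product`. The font-major ordering is an assumption for illustration only, since the post does not say which pair each folder '1' to '16' holds; `class_index` is a hypothetical helper.

```python
from itertools import product

FONTS = ["Arial", "Verdana", "Times New Roman", "Georgia"]
STYLES = ["regular", "bold", "italic", "bold italic"]

# Hypothetical font-major ordering: class index = 4 * font_idx + style_idx.
CLASSES = list(product(FONTS, STYLES))


def class_index(font: str, style: str) -> int:
    """Return the 0-based class index of a (font, style) pair under this ordering."""
    return 4 * FONTS.index(font) + STYLES.index(style)


print(len(CLASSES))  # 4 fonts * 4 styles
for i, (font, style) in enumerate(CLASSES):
    print(f"class {i + 1:02d}: {font}, {style}")
```

Under this assumed ordering, the two high bits of a 4-bit class code would identify the font and the two low bits the style.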

Training data set: 800 image patches of 256 × 256 pixels (50 per class)
Validation data set: 320 image patches of 256 × 256 pixels (20 per class)
Testing data set: 160 image patches of 256 × 256 pixels (10 per class)

Here is my initial code:

import numpy as np
import keras
from keras import backend as K
from keras.models import Sequential
from keras.layers import Activation, BatchNormalization, Conv2D, Dense, Flatten
from keras.optimizers import Adam
from keras.metrics import categorical_crossentropy
from keras.preprocessing.image import ImageDataGenerator
import itertools
import matplotlib.pyplot as plt
import pickle


image_width = 256
image_height = 256

train_path = 'font_model_data/train'
valid_path = 'font_model_data/valid'
test_path = 'font_model_data/test'

# Class sub-folders are named '1' ... '16'. Listing them explicitly fixes the
# label order: folder '1' -> index 0, ..., folder '16' -> index 15.
class_names = [str(i) for i in range(1, 17)]

# Note: target_size is (height, width).
train_batches = ImageDataGenerator().flow_from_directory(
    train_path, target_size=(image_height, image_width),
    classes=class_names, batch_size=16)
valid_batches = ImageDataGenerator().flow_from_directory(
    valid_path, target_size=(image_height, image_width),
    classes=class_names, batch_size=16)
test_batches = ImageDataGenerator().flow_from_directory(
    test_path, target_size=(image_height, image_width),
    classes=class_names, batch_size=160)

# Inspect one batch: each label is a one-hot vector of length 16.
imgs, labels = next(train_batches)
print(labels)

# CNN model
model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(image_height, image_width, 3)),
    Flatten(),
    Dense(16, activation='softmax'),  # I want to make it 4
])
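Passing `classes` explicitly matters with these folder names. According to the Keras documentation for `flow_from_directory`, when `classes` is omitted the class list is inferred from the subdirectory names in alphanumeric order, and the resulting mapping is exposed as `class_indices`. With folders '1' to '16', string order differs from numeric order. This pure-Python sketch only mimics that ordering rule; it does not call Keras.

```python
# Folder names used in the post: '1' ... '16'
folders = [str(i) for i in range(1, 17)]

# Explicit list (as passed via `classes=`): index follows list order.
explicit_indices = {name: i for i, name in enumerate(folders)}

# Inferred list: names sorted as strings ('1', '10', '11', ..., '2', ...).
inferred_indices = {name: i for i, name in enumerate(sorted(folders))}

for name in ["1", "2", "10", "16"]:
    print(name, explicit_indices[name], inferred_indices[name])
```

In a real run, printing `train_batches.class_indices` shows the mapping Keras actually used.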

I plan to have 4 output nodes in the network:

4 Output Nodes (4 bits):
Class 01 - 0000
Class 02 - 0001
Class 03 - 0010
Class 04 - 0011
Class 05 - 0100
Class 06 - 0101
Class 07 - 0110
Class 08 - 0111
Class 09 - 1000
Class 10 - 1001
Class 11 - 1010
Class 12 - 1011
Class 13 - 1100
Class 14 - 1101
Class 15 - 1110
Class 16 - 1111
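The table is the class index (Class 01 → 0, ..., Class 16 → 15) written as a 4-bit binary number, most significant bit first. A minimal NumPy sketch of the encoding and its inverse; `index_to_bits` and `bits_to_index` are hypothetical helpers, not Keras functions.

```python
import numpy as np


def index_to_bits(index: int, n_bits: int = 4) -> np.ndarray:
    """Encode a 0-based class index as a float vector of bits, MSB first."""
    shifts = np.arange(n_bits - 1, -1, -1)  # [3, 2, 1, 0]
    return ((index >> shifts) & 1).astype("float32")


def bits_to_index(bits) -> int:
    """Decode a bit vector (MSB first) back to the 0-based class index."""
    bits = np.asarray(bits).astype(int)
    weights = 2 ** np.arange(len(bits) - 1, -1, -1)  # [8, 4, 2, 1]
    return int(bits @ weights)


for cls in (1, 2, 10, 16):  # class numbers as in the table
    print(f"Class {cls:02d} ->", index_to_bits(cls - 1).astype(int))
```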

The labels that ImageDataGenerator produces are 16-element one-hot vectors:

[[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]
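Each row above is a one-hot vector, so an image's class index is the position of its 1 (`argmax`); with the explicit `classes` list, the 1 at position 1 in the first row means folder '2', i.e. Class 02. A minimal NumPy sketch, using the first four printed rows as sample data, that turns such a batch into the 4-bit codes from the table:

```python
import numpy as np

# First four rows copied from the printed batch (16 columns each).
onehot = np.zeros((4, 16), dtype="float32")
onehot[[0, 1, 2, 3], [1, 2, 3, 0]] = 1.0

# Position of the 1 in each row = 0-based class index.
indices = onehot.argmax(axis=1)
print(indices)  # [1 2 3 0]

# Vectorized conversion of the whole batch to 4-bit codes (MSB first).
bits = ((indices[:, None] >> np.arange(3, -1, -1)) & 1).astype("float32")
print(bits)
```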

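What I want instead is a 4-bit label for each class:

labels = [[0, 0, 0, 0],
          [0, 0, 0, 1],
          [0, 0, 1, 0],
          [0, 0, 1, 1],
          [0, 1, 0, 0],
          [0, 1, 0, 1],
          [0, 1, 1, 0],
          [0, 1, 1, 1],
          [1, 0, 0, 0],
          [1, 0, 0, 1],
          [1, 0, 1, 0],
          [1, 0, 1, 1],
          [1, 1, 0, 0],
          [1, 1, 0, 1],
          [1, 1, 1, 0],
          [1, 1, 1, 1]]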

The goal is for my network's output layer (the last Dense layer) to have 4 output nodes instead of 16.
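One common approach, sketched here under assumptions rather than as a definitive fix: keep `flow_from_directory` as it is and wrap its iterator in a Python generator that converts each one-hot batch to 4-bit targets. The last layer would then become `Dense(4, activation='sigmoid')` trained with `loss='binary_crossentropy'`, because each bit is an independent yes/no output; softmax would force the 4 outputs to sum to 1. `onehot_to_bits`, `to_bit_batches` and `predictions_to_class` are hypothetical helper names. The sketch uses only NumPy, with a fake batch list standing in for `train_batches`.

```python
import numpy as np


def onehot_to_bits(onehot: np.ndarray, n_bits: int = 4) -> np.ndarray:
    """Convert one-hot labels (N, 16) to bit labels (N, 4), MSB first."""
    indices = onehot.argmax(axis=1)
    shifts = np.arange(n_bits - 1, -1, -1)
    return ((indices[:, None] >> shifts) & 1).astype("float32")


def to_bit_batches(batches, n_bits: int = 4):
    """Yield (images, bit_labels) for every (images, one-hot labels) batch.

    `batches` can be any iterable of (x, y) pairs, e.g. the iterator returned
    by flow_from_directory (which loops forever, so this generator does too).
    """
    for x, y in batches:
        yield x, onehot_to_bits(y, n_bits)


def predictions_to_class(probs: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Threshold sigmoid outputs (N, 4) and decode each row to a 0-based class index."""
    bits = (probs >= threshold).astype(int)
    weights = 2 ** np.arange(bits.shape[1] - 1, -1, -1)
    return bits @ weights


# Fake source standing in for train_batches: one batch of 3 small images.
fake_x = np.zeros((3, 8, 8, 3), dtype="float32")
fake_y = np.eye(16, dtype="float32")[[1, 9, 15]]
x, y_bits = next(to_bit_batches([(fake_x, fake_y)]))
print(y_bits)
print(predictions_to_class(np.array([[0.1, 0.9, 0.2, 0.7]])))
```

With Keras 2.x, the wrapped iterators could then be passed as `model.fit_generator(to_bit_batches(train_batches), steps_per_epoch=len(train_batches), validation_data=to_bit_batches(valid_batches), validation_steps=len(valid_batches))`; newer versions accept the same generator in `model.fit`. Because each bit is thresholded independently, a single wrong bit yields a different class.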
