我正在尝试学习一些Keras语法并使用初始v3示例。
我有一个4类多类分类玩具问题,所以我从示例中更改了以下行:
NB_CLASS = 4 # number of classes
DIM_ORDERING = 'tf' # 'th' (channels, width, height) or 'tf' (width, height, channels)我的玩具数据集有以下几个维度:
然后,我尝试使用以下代码来训练模型:
# fit the model on the batches generated by datagen.flow()
# https://github.com/fchollet/keras/issues/1627
# http://keras.io/models/sequential/#sequential-model-methods
checkpointer = ModelCheckpoint(filepath="/tmp/weights.hdf5", verbose=1, save_best_only=True)
model.fit_generator(datagen.flow(X_train, Y_train,
batch_size=32),
nb_epoch=10,
samples_per_epoch=32,
class_weight=None, #classWeights,
verbose=2,
validation_data=(X_test, Y_test),
callbacks=[checkpointer])然后,我得到以下错误:
Exception: The model expects 2 input arrays, but only received one array. Found: array with shape (179, 4)`这可能与此有关,因为盗梦空间希望拥有辅助分类器(Szegedy等人,2014年)
model = Model(input=img_input, output=[preds, aux_preds])如何将这两个标签赋予Keras中的模型,因为它也不是高级的程序员?
发布于 2017-11-06 18:09:24
在它的第一部分中,您将看到如何使用以下方法从目录中加载数据:
.flow_from_directory(
train_data_dir,
target_size=(img_width, img_height),
batch_size=batch_size,
class_mode='binary')为了输入不同的类,您必须将图像放在每个类的一个文件夹中(请注意,通过传递标签,可能有另一种方法)。还请注意,在您的示例中,class_mode不能使用“二进制”(我认为您应该使用“分类”):
`"binary"`: binary targets (if there are only two classes),
`"categorical"`: categorical targets,然后您可以使用已经在Keras中的inceptionv3模型:
from keras.applications import InceptionV3
cnn = InceptionV3(...)还要注意的是,您训练InceptionV3的例子太少了,因为这个模型非常大(请检查这里的大小)。在这种情况下,您可以做的是传输学习,在InceptionV3上使用预先训练过的权重。请参阅使用预先培训过的网络的瓶颈特性的部分:在本教程中立即达到90%的准确率。
发布于 2018-04-11 21:47:31
错误消息与validation_data参数有关:在使用model.fit_generator时,验证数据也应该通过ImageDataGenerator对象传递(就像您已经对培训数据所做的那样)。这与缺少辅助分类器无关-- Keras 不实现辅助分类器。中的初始不实现辅助分类器。模型来自原始论文(这是尝试迁移学习而不是完全培训的另一个原因)。
更新代码以使用生成器提供验证数据:
datagen = ImageDataGenerator()
model.fit_generator(datagen.flow(X_train, Y_train, batch_size=32),
nb_epoch=10,
steps_per_epoch=len(X_train) / 32,
class_weight=None,
verbose=2,
validation_data=datagen.flow(X_test, Y_test, batch_size=32),
validation_steps=len(X_test) / 32,
callbacks=[checkpointer])请注意,我已经将参数samples_per_epoch更新为较新的steps_per_epoch。
https://stackoverflow.com/questions/37641854
复制相似问题