我试图使用传输学习使用ResNet-50在TensorFlow2和Keras上的CIFAR-10数据集,其中有(32,32,3)图像。
默认的ResNet-50的第一个conv层使用的过滤器大小为(7,7),步幅= 2,由此产生的CIFAR-10在空间上减少了太多,这是必须避免的。作为“黑客”,图像试图从(32,32)提升到(224,224)。守则是:
import tensorflow.keras as K
# Define KerasTensor as input-
input_t = K.Input(shape = (32, 32, 3))
res_model = K.applications.ResNet50(
include_top = False,
weights = "imagenet",
input_tensor = input_t
)
# Since CIFAR-10 dataset is small as compared to ImageNet, the images are upscaled to (224, 224)-
to_res = (224, 224)
model = K.models.Sequential()
model.add(K.layers.Lambda(lambda image: tf.image.resize(image, to_res)))
model.add(res_model)
model.add(K.layers.Flatten())
model.add(K.layers.BatchNormalization())
model.add(K.layers.Dense(units = 10, activation = 'softmax'))
# Choose an optimizer and loss function for training-
loss_fn = tf.keras.losses.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.SGD(learning_rate = 0.1, momentum = 0.9)
model.compile(
# loss = 'categorical_crossentropy',
loss = loss_fn,
# optimizer = K.optimizers.RMSprop(lr=2e-5),
optimizer = optimizer,
metrics=['accuracy']
)
history = model.fit(
x = X_train, y = y_train,
batch_size = batch_size, epochs = 10,
validation_data = (X_test, y_test),
# callbacks=[check_point]
)我得到错误的原因是:
Epoch 1/10警告:tensorflow:模型是为输入KerasTensor(type_spec=TensorSpec (None,32,32,3),dtype=tf.float32,name=' input _1'),name=' input _1',description=(由图层‘input_1’创建)构造的,但它被调用的输入形状不兼容(None,224,224,3)。
ValueError回溯(最近一次调用)
in () 2x= X_train,y= y_train,3 batch_size = batch_size,epochs = 10,-->4 validation_data = (X_test,y_test),5# callbacks=check_point 6)
9帧
包装器中的/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py (*args,**kwargs) 975,例外情况除外,e:#pylint:else=else除了976 (如果hasattr(e,"ag_error_metadata"):-> 977 e.ag_error_metadata.to_exception(e) 978 if : 979 e.ag_error_metadata.to_exception(E)e.ag_error_metadata.to_exception(E)978)
ValueError:在用户代码中:
ValueError:输入0与层resnet50不兼容:预期的shape=(None,32,32,3),found shape=(None,224,224,3)
发布于 2021-07-20 09:46:19
模型的输入仍然是(32,32,3)
input_t = K.Input(shape = (32, 32, 3))https://stackoverflow.com/questions/67508398
复制相似问题