首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >TensorFlow2: ResNet50 - ValueError

TensorFlow2: ResNet50 - ValueError
EN

Stack Overflow用户
提问于 2021-05-12 17:29:10
回答 1查看 299关注 0票数 0

我试图使用传输学习使用ResNet-50在TensorFlow2和Keras上的CIFAR-10数据集,其中有(32,32,3)图像。

默认的ResNet-50的第一个conv层使用的过滤器大小为(7,7),步幅= 2,由此产生的CIFAR-10在空间上减少了太多,这是必须避免的。作为“黑客”,图像试图从(32,32)提升到(224,224)。守则是:

代码语言:javascript
复制
import tensorflow.keras as K

# Define KerasTensor as input-
input_t = K.Input(shape = (32, 32, 3))

res_model = K.applications.ResNet50(
    include_top = False,
    weights = "imagenet",
    input_tensor = input_t
)

# Since CIFAR-10 dataset is small as compared to ImageNet, the images are upscaled to (224, 224)-
to_res = (224, 224)

model = K.models.Sequential()
model.add(K.layers.Lambda(lambda image: tf.image.resize(image, to_res))) 
model.add(res_model)
model.add(K.layers.Flatten())
model.add(K.layers.BatchNormalization())
model.add(K.layers.Dense(units = 10, activation = 'softmax'))

# Choose an optimizer and loss function for training-
loss_fn = tf.keras.losses.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.SGD(learning_rate = 0.1, momentum = 0.9)

model.compile(
    # loss = 'categorical_crossentropy',
    loss = loss_fn,
    # optimizer = K.optimizers.RMSprop(lr=2e-5),
    optimizer = optimizer,
    metrics=['accuracy']
)

history = model.fit(
    x = X_train, y = y_train,
    batch_size = batch_size, epochs = 10,
    validation_data = (X_test, y_test),
    # callbacks=[check_point]
    )

我得到错误的原因是:

Epoch 1/10警告:tensorflow:模型是为输入KerasTensor(type_spec=TensorSpec (None,32,32,3),dtype=tf.float32,name=' input _1'),name=' input _1',description=(由图层‘input_1’创建)构造的,但它被调用的输入形状不兼容(None,224,224,3)。

ValueError回溯(最近一次调用)

in () 2x= X_train,y= y_train,3 batch_size = batch_size,epochs = 10,-->4 validation_data = (X_test,y_test),5# callbacks=check_point 6)

9帧

包装器中的/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py (*args,**kwargs) 975,例外情况除外,e:#pylint:else=else除了976 (如果hasattr(e,"ag_error_metadata"):-> 977 e.ag_error_metadata.to_exception(e) 978 if : 979 e.ag_error_metadata.to_exception(E)e.ag_error_metadata.to_exception(E)978)

ValueError:在用户代码中:

ValueError:输入0与层resnet50不兼容:预期的shape=(None,32,32,3),found shape=(None,224,224,3)

EN

回答 1

Stack Overflow用户

发布于 2021-07-20 09:46:19

模型的输入仍然是(32,32,3)

代码语言:javascript
复制
input_t = K.Input(shape = (32, 32, 3))
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/67508398

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档