首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >我在tensorflow中对Densenet的配置是否错误?

我在tensorflow中对Densenet的配置是否错误?
EN

Stack Overflow用户
提问于 2021-12-01 11:36:13
回答 1查看 44关注 0票数 0

当我运行下面粘贴的代码时,模型只是训练“乘数”=1或=4。在google colab→中运行相同的代码只是训练multiplier=1

我在这里使用DenseNet的方式有什么错误吗?

提前感谢,感谢您的帮助!

代码语言:javascript
复制
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.densenet import DenseNet201
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy


random_array = np.random.rand(128,128,3)
image = tf.convert_to_tensor(
    random_array
)
label = tf.constant(0)



model = DenseNet201(
    include_top=False, weights='imagenet', input_tensor=None,
    input_shape=(128, 128, 3), pooling=None, classes=2
)
model.compile(
optimizer=Adam(),
loss=BinaryCrossentropy(),
metrics=['accuracy'],
)


for multiplier in range(1,20):

    print(f"Using multiplier {multiplier}")
    x_train = np.array([image]*multiplier)
    y_train = np.array([label]*multiplier)



    try: 
        model.fit(x=x_train,y=y_train, epochs=2)
    except:
        print("Not training...")
        pass

如果训练没有开始,则输出为:

代码语言:javascript
复制
2021-12-01 11:48:40.372387: W tensorflow/core/framework/op_kernel.cc:1680] Invalid argument: required broadcastable shapes
2021-12-01 11:48:40.372660: W tensorflow/core/framework/op_kernel.cc:1680] Invalid argument: required broadcastable shapes
2021-12-01 11:48:40.372734: W tensorflow/core/framework/op_kernel.cc:1680] Invalid argument: required broadcastable shapes
EN

回答 1

Stack Overflow用户

发布于 2021-12-02 11:35:52

显然,如果使用自定义GlobalAveragePooling (而不是ImageNet的标准224x224x3 )和include_top = False,则有必要添加自定义input_shape和致密层:

代码语言:javascript
复制
base_model = DenseNet201(
    include_top=False, weights='imagenet', input_tensor=None,
    input_shape=(128, 128, 3),
    pooling=None, classes=2
)

x= base_model.output
x = GlobalAveragePooling2D(name = "avg_pool")(x)
outputs = Dense(2, activation=tf.nn.softmax, name="predictions")(x)

model = Model(base_model.input, outputs)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70183494

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档