首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >我尝试使用functional API在tensorflow 2.x中创建模型,但得到LSTM层不兼容错误

我尝试使用functional API在tensorflow 2.x中创建模型,但得到LSTM层不兼容错误
EN

Stack Overflow用户
提问于 2020-07-25 10:12:43
回答 1查看 32关注 0票数 0

错误读取:

代码语言:javascript
复制
Input 0 of layer lstm_28 is incompatible with the layer: expected ndim=3, found ndim=4. Full shape received: [None, None, 15, 12]

在LSTM层,输入tf.nn.embedding_lookup(embedding, neighbor)的形状=( 15,12 ),一个None是批量大小,它是如何得到None,None,15,12的大小的?如何处理这个错误?下面是我创建的虚拟模型。

代码语言:javascript
复制
    def create_model(embedding, embedding_dim, samp_size):
        
        
        node = Input(shape=(None,), dtype=tf.int64)
        neighbor = Input(shape=(None, samp_size), dtype=tf.int64)
        label = Input(shape=(None,), dtype=tf.int64)
        
        cell = LSTMCell(embedding_dim,)
        _,h,c = LSTM(embedding_size, return_sequences=True, return_state=True)(tf.nn.embedding_lookup(embedding, neighbor))
        predict_info = tf.squeeze(Dense(1, activation='relu'))(h)
        
        return h
    
    
    
    node_size = 1000
    embedding_dim = 12
    sampling_size = 15
    embedding = tf.random.uniform([node_size, embedding_dim])
    
    model = create_model (embedding, embedding_dim, sampling_size)
EN

回答 1

Stack Overflow用户

发布于 2020-07-25 11:11:14

使用Keras functional API时,请勿将批处理维包括为None。例如,如果您的输入是维度(batch_size、image_w、image_h、image_channels),则如下所示:

代码语言:javascript
复制
inp = tf.keras.Input(shape=(IMG_W, IMG_H, IMG_CH))
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63083607

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档