首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在tensorflow-hub预训练模型之后添加LSTM层

在tensorflow-hub预训练模型之后添加LSTM层
EN

Stack Overflow用户
提问于 2021-08-01 16:41:04
回答 1查看 126关注 0票数 0

我正在使用Tensorflow-hub预训练的Word2vec模型进行文本分类。我正在寻求将LSTM层添加到keras模型中。为此,我使用了以下代码:

代码语言:javascript
复制
model = tf.keras.models.Sequential()
model.add(hub.KerasLayer(hub.load('https://tfhub.dev/google/Wiki-words-250/2'), 
                        input_shape=[], 
                        dtype=tf.string, 
                        trainable=True))

添加LSTM层后:

代码语言:javascript
复制
model.add(tf.keras.layers.LSTM(32))

它向我展示了以下错误:

代码语言:javascript
复制
~\anaconda3\lib\site-packages\tensorflow\python\keras\engine\input_spec.py in assert_input_compatibility(input_spec, inputs, layer_name)
    174       ndim = x.shape.ndims
    175       if ndim != spec.ndim:
--> 176         raise ValueError('Input ' + str(input_index) + ' of layer ' +
    177                          layer_name + ' is incompatible with the layer: '
    178                          'expected ndim=' + str(spec.ndim) + ', found ndim=' +

ValueError: Input 0 of layer lstm_0 is incompatible with the layer: expected ndim=3, found ndim=2. Full shape received: [None, 250]

任何帮助都是值得欣赏的。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-08-01 20:42:27

您可以调整hub.KerasLayer的输出

代码语言:javascript
复制
model.add(hub.KerasLayer(hub.load('https://tfhub.dev/google/Wiki-words-250/2'), 
                        input_shape=[], 
                        dtype=tf.string, 
                        trainable=True))

model.add(tf.keras.layers.Reshape((250, 1)))
model.add(tf.keras.layers.LSTM(32))

model.summary()

Layer (type)                 Output Shape              Param #   
=================================================================
keras_layer_4 (KerasLayer)   (None, 250)               252343750 
_________________________________________________________________
reshape_2 (Reshape)          (None, 250, 1)            0         
_________________________________________________________________
lstm_2 (LSTM)                (None, 32)                4352      
=================================================================
Total params: 252,348,102
Trainable params: 252,348,102
Non-trainable params: 0
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68612469

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档