首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >时间分布式层keras

时间分布式层keras
EN

Stack Overflow用户
提问于 2020-03-30 05:01:32
回答 1查看 778关注 0票数 0

我正在尝试理解keras/tensorflow中的时间分布层。据我所知,它是一种包装器,可以例如处理一系列图像。

现在我想知道如何在不使用时间分布层的情况下设计一个时间分布式网络。

例如,如果我有一个由3个图像组成的序列,每个图像都有1个通道,像素尺寸为256x256px,那么首先应该由CNN处理,然后由LSTM单元处理。我对时间分布层的输入将是(N,3,256,256,1),其中N是批处理大小。

然后CNN将有3个输出,这些输出被馈送到LSTM单元。

现在,在不使用时间分布层的情况下,是否可以通过设置一个具有3个不同输入和3个相似CNN的网络来完成相同的任务?然后,3个CNN的输出可以被平坦和连接。

这与时间分布式方法有什么不同吗?

提前谢谢你,

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-03-31 02:05:19

我为你创造了一个原型。我使用了最少的层和任意的单元/内核/过滤器,你可以随意改变它们。它首先创建一个cnn模型,该模型接受大小为(256,256,1)的输入。它使用相同的cnn模型3次(对于序列中的三张图像)来提取特征。它使用Lambda层来堆叠所有的功能,以将其放回序列中。然后,该序列通过LSTM层。我选择让LSTM为每个示例返回一个特征向量,但如果您希望输出也是一个序列,则可以将其更改为return_sequences=True。您还可以添加最终的附加层以使其适应您的需求。

代码语言:javascript
复制
from tensorflow.keras.layers import Input, LSTM, Conv2D, Flatten, Lambda
from tensorflow.keras import Model
import tensorflow.keras.backend as K

def create_cnn_model():
  inp = Input(shape=(256,256,1))
  x = Conv2D(filters=16, kernel_size=5, strides=2)(inp)
  x = Flatten()(x)
  model = Model(inputs=inp, outputs=x, name='cnn_Model')
  return model


def combined_model():
  cnn_model = create_cnn_model()
  inp_1 = Input(shape=(256,256,1))
  inp_2 = Input(shape=(256,256,1))
  inp_3 = Input(shape=(256,256,1))

  out_1 = cnn_model(inp_1)
  out_2 = cnn_model(inp_2)
  out_3 = cnn_model(inp_3)

  lstm_inp = [out_1, out_2, out_3]
  lstm_inp = Lambda(lambda x: K.stack(x, axis=-2))(lstm_inp)
  x = LSTM(units=32, return_sequences=False)(lstm_inp)

  model = Model(inputs=[inp_1, inp_2, inp_3], outputs=x)
  return model

现在按如下方式创建模型:

代码语言:javascript
复制
model = combined_model()

检查摘要:

代码语言:javascript
复制
model.summary()

它将打印:

代码语言:javascript
复制
Model: "model_14"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_53 (InputLayer)           [(None, 256, 256, 1) 0                                            
__________________________________________________________________________________________________
input_54 (InputLayer)           [(None, 256, 256, 1) 0                                            
__________________________________________________________________________________________________
input_55 (InputLayer)           [(None, 256, 256, 1) 0                                            
__________________________________________________________________________________________________
cnn_Model (Model)               (None, 254016)       416         input_53[0][0]                   
                                                                 input_54[0][0]                   
                                                                 input_55[0][0]                   
__________________________________________________________________________________________________
lambda_3 (Lambda)               (None, 3, 254016)    0           cnn_Model[1][0]                  
                                                                 cnn_Model[2][0]                  
                                                                 cnn_Model[3][0]                  
__________________________________________________________________________________________________
lstm_13 (LSTM)                  (None, 32)           32518272    lambda_3[0][0]                   
==================================================================================================
Total params: 32,518,688
Trainable params: 32,518,688
Non-trainable params: 0

内部cnn模型摘要可以打印出来:

代码语言:javascript
复制
model.get_layer('cnn_Model').summary()

它当前打印:

代码语言:javascript
复制
Model: "cnn_Model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_52 (InputLayer)        [(None, 256, 256, 1)]     0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 126, 126, 16)      416       
_________________________________________________________________
flatten_6 (Flatten)          (None, 254016)            0         
=================================================================
Total params: 416
Trainable params: 416
Non-trainable params: 0
_________________________

您的模型需要一个列表作为输入。该列表的长度应为3 (因为一个序列中有3个图像)。列表的每个元素都应该是形状的numpy数组(batch_size,256,256,1)。我在下面使用了一个批量大小为1的虚拟示例:

代码语言:javascript
复制
import numpy as np

a = np.zeros((256,256,1)) # first image filled with zeros
b = np.zeros((256,256,1)) # second image filled with zeros
c = np.zeros((256,256,1)) # third image filled with zeros

a = np.expand_dims(a, 0) # adding batch dimension to make it (1, 256, 256, 1)
b = np.expand_dims(b, 0) # same here
c = np.expand_dims(c, 0) # same here


model.compile(loss='mse', optimizer='adam')
# train your model with model.fit(....)

e = model.predict([a,b,c]) # a,b and c have shape of (1, 256, 256, 1) where the first 1 is the batch size
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/60920672

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档