首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在Keras视频应答机中的扁平化示例

在Keras视频应答机中的扁平化示例
EN

Stack Overflow用户
提问于 2018-08-08 22:31:58
回答 1查看 24关注 0票数 1

在凯拉斯的视频问答示例(https://keras.io/getting-started/functional-api-guide/)中,卷积神经网络末尾的vision_model.add(Flatten())做了什么?为什么需要它?

完整的源代码:

代码语言:javascript
复制
from keras.layers import Conv2D, MaxPooling2D, Flatten
from keras.layers import Input, LSTM, Embedding, Dense
from keras.models import Model, Sequential

# First, let's define a vision model using a Sequential model.
# This model will encode an image into a vector.
vision_model = Sequential()
vision_model.add(Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=(224, 224, 3)))
vision_model.add(Conv2D(64, (3, 3), activation='relu'))
vision_model.add(MaxPooling2D((2, 2)))
vision_model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))
vision_model.add(Conv2D(128, (3, 3), activation='relu'))
vision_model.add(MaxPooling2D((2, 2)))
vision_model.add(Conv2D(256, (3, 3), activation='relu', padding='same'))
vision_model.add(Conv2D(256, (3, 3), activation='relu'))
vision_model.add(Conv2D(256, (3, 3), activation='relu'))
vision_model.add(MaxPooling2D((2, 2)))
vision_model.add(Flatten())

然后,稍后

代码语言:javascript
复制
from keras.layers import TimeDistributed

video_input = Input(shape=(100, 224, 224, 3))
# This is our video encoded via the previously trained vision_model (weights are reused)
encoded_frame_sequence = TimeDistributed(vision_model)(video_input)  # the output will be a sequence of vectors
encoded_video = LSTM(256)(encoded_frame_sequence)  # the output will be a vector
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-08-08 22:36:42

运行:

代码语言:javascript
复制
vision_model.summary()

我们得到:

代码语言:javascript
复制
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 224, 224, 64)      1792      
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 222, 222, 64)      36928     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 111, 111, 64)      0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 111, 111, 128)     73856     
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 109, 109, 128)     147584    
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 54, 54, 128)       0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 54, 54, 256)       295168    
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 52, 52, 256)       590080    
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 50, 50, 256)       590080    
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 25, 25, 256)       0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 160000)            0         
=================================================================
Total params: 1,735,488
Trainable params: 1,735,488
Non-trainable params: 0

vision_model.add(Flatten())vision_model.add(MaxPooling2D((2, 2)))从(None,25,25,256)展平为(None,160000)

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/51749280

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档