首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow 2.0:如何在功能级别实现具有融合的网络?

Tensorflow 2.0:如何在功能级别实现具有融合的网络?
EN

Stack Overflow用户
提问于 2020-03-30 08:46:48
回答 1查看 362关注 0票数 1

我正在尝试在tensorflow中为预测任务实现一个小模型,其中有两个信号作为输入,这两个信号分别通过几个层,然后在后面的层中组合以生成输出预测。从本质上讲,模型的工作原理如下:

代码语言:javascript
复制
(Signal A) -> [L 1] -> [L 2] -> ... -> [L k] 
                                            \
                                             \
                                               -> [L k+1] ->...-> [Final Layer] -> Output
                                             /
                                            /
(Signal B) -> [L 1] -> [L 2] -> ... -> [L k]

其中L i是网络的不同层。在融合之前,网络的第一部分对于两个信号是相同的。在TensorFlow2.0中实现此模型的正确方法是什么?我认为Sequential在这个场景中不是一个选项,但是我可以通过Functional API来实现吗?或者我应该通过Model Subclassing来实现?根据我所读到的,这两种方法似乎没有太大的区别。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-03-30 08:59:03

这是functional API中模型的模板,您可以根据需要更改层。

您的基本模型(两者通用)-

代码语言:javascript
复制
from tensorflow.keras.layers import Input, Conv1D, Concatenate, MaxPooling1D, Flatten, Dense, GlobalMaxPooling1D, subtract, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l1, l2
from tensorflow.keras.optimizers import Adam
import tensorflow.keras.backend as K

# baseline model

input_shape = (256, 1) # assuming your signals have length 256, 1 channel

# conv base model
sig_input = Input(input_shape)
cnn1 = Conv1D(64,3,activation='relu',input_shape=input_shape, kernel_regularizer=l2(2e-4))(sig_input)
mp1 = MaxPooling1D()(cnn1)
mp1 = BatchNormalization()(mp1)
cnn2 = Conv1D(128,3,activation='relu', kernel_regularizer=l2(2e-4))(mp1)
mp2 = MaxPooling1D()(cnn2)
mp2 = BatchNormalization()(mp2)
cnn3 = Conv1D(128,3,activation='relu', kernel_regularizer=l2(2e-4))(mp2)
mp3 = MaxPooling1D()(cnn3)
mp3 = BatchNormalization()(mp3)
cnn4 = Conv1D(256,3,activation='relu', kernel_regularizer=l2(2e-4))(mp3)
mp4 = MaxPooling1D()(cnn4)
mp4 = BatchNormalization()(mp4)
flat = Flatten()(mp4)
embed = Dense(64, activation="sigmoid")(flat)

conv_base = Model(sig_input, embed)

conv_base.summary()

网络摘要:

代码语言:javascript
复制
Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_6 (InputLayer)         [(None, 256, 1)]          0         
_________________________________________________________________
conv1d_12 (Conv1D)           (None, 254, 64)           256       
_________________________________________________________________
max_pooling1d_12 (MaxPooling (None, 127, 64)           0         
_________________________________________________________________
batch_normalization_12 (Batc (None, 127, 64)           256       
_________________________________________________________________
conv1d_13 (Conv1D)           (None, 125, 128)          24704     
_________________________________________________________________
max_pooling1d_13 (MaxPooling (None, 62, 128)           0         
_________________________________________________________________
batch_normalization_13 (Batc (None, 62, 128)           512       
_________________________________________________________________
conv1d_14 (Conv1D)           (None, 60, 128)           49280     
_________________________________________________________________
max_pooling1d_14 (MaxPooling (None, 30, 128)           0         
_________________________________________________________________
batch_normalization_14 (Batc (None, 30, 128)           512       
_________________________________________________________________
conv1d_15 (Conv1D)           (None, 28, 256)           98560     
_________________________________________________________________
max_pooling1d_15 (MaxPooling (None, 14, 256)           0         
_________________________________________________________________
batch_normalization_15 (Batc (None, 14, 256)           1024      
_________________________________________________________________
flatten_3 (Flatten)          (None, 3584)              0         
_________________________________________________________________
dense_3 (Dense)              (None, 64)                229440    
=================================================================
Total params: 404,544
Trainable params: 403,392
Non-trainable params: 1,152

第二个聚变网络-

代码语言:javascript
复制
left_input = Input(input_shape)
right_input = Input(input_shape)

# encode each of the two inputs into a vector with the base conv model
encoded_l = conv_base(left_input)
encoded_r = conv_base(right_input)



fusion = Concatenate()([encoded_l,encoded_r]) # this can be any other fusion method too

prediction = Dense(1, activation='sigmoid')(fusion)

twin_net = Model([left_input,right_input],prediction)

optimizer = Adam(0.001)

twin_net.compile(loss="binary_crossentropy",optimizer=optimizer)

twin_net.summary()
代码语言:javascript
复制
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_7 (InputLayer)            [(None, 256, 1)]     0                                            
__________________________________________________________________________________________________
input_8 (InputLayer)            [(None, 256, 1)]     0                                            
__________________________________________________________________________________________________
model_2 (Model)                 (None, 64)           404544      input_7[0][0]                    
                                                                 input_8[0][0]                    
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 128)          0           model_2[1][0]                    
                                                                 model_2[2][0]                    
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 1)            129         concatenate[0][0]                
==================================================================================================
Total params: 404,673
Trainable params: 403,521
Non-trainable params: 1,152
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/60922607

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档