首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >获取内部Keras层的预测值。

获取内部Keras层的预测值。
EN

Stack Overflow用户
提问于 2022-06-10 13:00:23
回答 1查看 94关注 0票数 1

我有一个像这样的TensorFlow模型-

我希望知道特定输入的红色标记层(5个浮点值)的值,以检查模型在此层(注意层)如何响应。我需要这个值,这样我就可以知道我的注意力层是否正确地提取了值。

由于模型是一个端到端模型,我不知道如何提取内部层的值以供特定的输入。有人能帮忙吗?

EN

回答 1

Stack Overflow用户

发布于 2022-06-10 13:19:44

您可以编写class Callback,然后传递输入并检查所需每一层的输出:

代码语言:javascript
复制
import tensorflow as tf
import numpy as np

class CustomCallback(tf.keras.callbacks.Callback):
   def __init__(self):
        self.data = np.random.rand(1,10)
   def on_epoch_end(self, epoch, logs=None):
        dns_layer = self.model.layers[6]
        outputs = dns_layer(self.data)
        tf.print(f'\n input: {self.data}')
        tf.print(f'\n output: {outputs}')


x_train = tf.random.normal((10, 32, 32))
y_train = tf.random.uniform((10, 1), maxval=10)

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.LSTM(256, input_shape=(x_train.shape[1], x_train.shape[2]), return_sequences=True))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.LSTM(256))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(10, activation='softmax')) 
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(5, activation='softmax'))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(10, activation='softmax')) 
model.compile(optimizer='adam', loss = tf.keras.losses.SparseCategoricalCrossentropy(False))
model.summary()    

for layer in model.layers:
    print(layer)

model.fit(x_train, y_train , epochs=3, callbacks=[CustomCallback()], batch_size=32)

输出:

代码语言:javascript
复制
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 lstm (LSTM)                 (None, 32, 256)           295936    
                                                                 
 dropout (Dropout)           (None, 32, 256)           0         
                                                                 
 lstm_1 (LSTM)               (None, 256)               525312    
                                                                 
 dropout_1 (Dropout)         (None, 256)               0         
                                                                 
 dense (Dense)               (None, 10)                2570      
                                                                 
 dropout_2 (Dropout)         (None, 10)                0         
                                                                 
 dense_1 (Dense)             (None, 5)                 55        
                                                                 
 dropout_3 (Dropout)         (None, 5)                 0         
                                                                 
 dense_2 (Dense)             (None, 10)                60        
                                                                 
=================================================================
Total params: 823,933
Trainable params: 823,933
Non-trainable params: 0
_________________________________________________________________
<keras.layers.recurrent_v2.LSTM object at 0x7f6e2163dbd0>
<keras.layers.core.dropout.Dropout object at 0x7f6da1d2efd0>
<keras.layers.recurrent_v2.LSTM object at 0x7f6d9dfe0a50>
<keras.layers.core.dropout.Dropout object at 0x7f6d9de1ec90>
<keras.layers.core.dense.Dense object at 0x7f6d9de04dd0>
<keras.layers.core.dropout.Dropout object at 0x7f6d9dd549d0>
<keras.layers.core.dense.Dense object at 0x7f6d9dd8ec90>
<keras.layers.core.dropout.Dropout object at 0x7f6d9dedd650>
<keras.layers.core.dense.Dense object at 0x7f6d9ddc2ed0>
Epoch 1/3
1/1 [==============================] - ETA: 0s - loss: 2.4188
 input: [[0.91498145 0.98430978 0.22720893 0.76032816 0.78405846 0.72664182
  0.7772921  0.9851892  0.41715033 0.21014543]]

 output: [[0.5767021  0.04140956 0.1909151  0.06737834 0.12359484]]
1/1 [==============================] - 12s 12s/step - loss: 2.4188
Epoch 2/3
1/1 [==============================] - ETA: 0s - loss: 2.4111
 input: [[0.91498145 0.98430978 0.22720893 0.76032816 0.78405846 0.72664182
  0.7772921  0.9851892  0.41715033 0.21014543]]

 output: [[0.5780218  0.04101932 0.18909878 0.06769065 0.12416941]]
1/1 [==============================] - 0s 376ms/step - loss: 2.4111
Epoch 3/3
1/1 [==============================] - ETA: 0s - loss: 2.3978
 input: [[0.91498145 0.98430978 0.22720893 0.76032816 0.78405846 0.72664182
  0.7772921  0.9851892  0.41715033 0.21014543]]

 output: [[0.579072   0.04067017 0.1874026  0.0679936  0.12486164]]
1/1 [==============================] - 0s 458ms/step - loss: 2.3978
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72574698

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档