首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在推理过程中从自定义Tensorflow/Keras层提取中间变量(TF 2.0)

在推理过程中从自定义Tensorflow/Keras层提取中间变量(TF 2.0)
EN

Stack Overflow用户
提问于 2019-11-21 22:28:46
回答 1查看 868关注 0票数 5

下面是一些背景知识:

我主要使用TensorFlow2.0的Keras函数模型实现了一个NLP分类模型。模型架构是一个非常简单的LSTM网络,在LSTM和密集输出层之间增加了一个关注层。注意力层来自this Kaggle kernel (从51行开始)。

我将经过训练的模型封装在一个简单的Flask应用程序中,并获得了相当准确的预测。除了预测特定输入的类别之外,我还从前面提到的关注层输出注意力权重向量"a“的值,这样我就可以可视化应用于输入序列的权重。

我目前提取注意力权重变量的方法是有效的,但似乎效率非常低,因为我预测输出类,然后使用中间Keras模型手动计算注意力向量。在Flask应用程序中,推理如下所示:

代码语言:javascript
复制
# Load the trained model
model = tf.keras.models.load_model('saved_model.h5')

# Extract the trained weights and biases of the trained attention layer
attention_weights = model.get_layer('attention').get_weights()

# Create an intermediate model that outputs the activations of the LSTM layer
intermediate_model = tf.keras.Model(inputs=model.input, outputs=model.get_layer('bi-lstm').output)

# Predict the output class using the trained model
model_score = model.predict(input)

# Obtain LSTM activations by predicting the output again using the intermediate model
lstm_activations = intermediate_model.predict(input)

# Use the intermediate LSTM activations and the trained model attention layer weights and biases to calculate the attention vector.  
# Maths from the custom Attention Layer (heavily modified for the sake of brevity)
eij = tf.keras.backend.dot(lstm_activations, attention_weights)
a = tf.keras.backend.exp(eij)
attention_vector = a

我认为我应该能够将注意力向量作为模型输出的一部分,但我正在努力弄清楚如何实现这一点。理想情况下,我应该在一次正向传递中从自定义注意力层提取注意力向量,而不是提取各种中间模型值并进行第二次计算。

例如:

代码语言:javascript
复制
model_score = model.predict(input)

model_score[0] # The predicted class label or probability
model_score[1] # The attention vector, a

我想我缺少一些关于Tensorflow/Keras如何抛出变量以及何时/如何访问这些值以包含为模型输出的基本知识。任何建议都将不胜感激。

EN

回答 1

Stack Overflow用户

发布于 2019-12-11 07:13:57

经过更多的研究,我终于拼凑出了一个可行的解决方案。我将在这里为任何未来疲惫的互联网旅行者总结这篇文章。

第一个线索来自this github thread.,在那里定义的注意力层似乎建立在前面提到的Kaggle内核中的注意力层上。github用户将return_attention标志添加到层初始化,当启用该标志时,除了层输出中的加权RNN输出向量外,还包括注意力向量。

我还在同一个GitHub线程中添加了this user建议的get_config函数,它使我们能够保存和重新加载经过训练的模型。我必须将return_attention标志添加到get_config,否则在尝试使用return_attention=True加载已保存的模型时,TF将抛出列表迭代错误。

进行了这些更改后,需要更新模型定义以捕获附加层输出。

代码语言:javascript
复制
inputs = Input(shape=(max_sequence_length,))
lstm = Bidirectional(LSTM(lstm1_units, return_sequences=True))(inputs)
# Added 'attention_vector' to capture the second layer output
attention, attention_vector = Attention(max_sequence_length, return_attention=True)(lstm)
x = Dense(dense_units, activation="softmax")(attention)

最后,也是最重要的一块来自this Stackoverflow answer.,这里描述的方法允许我们输出多个结果,同时只优化其中一个结果。代码的变化很微妙,但却非常重要。我已经在我为实现此功能所做的更改中添加了下面的注释。

代码语言:javascript
复制
model = Model(
    inputs=inputs,
    outputs=[x, attention_vector] # Original value:  outputs=x
    )

model.compile(
    loss=['categorical_crossentropy', None], # Original value: loss='categorical_crossentropy'
    optimizer=optimizer,
    metrics=[BinaryAccuracy(name='accuracy')])

有了这些更改,我重新训练了模型,瞧!model.predict()的输出现在是一个列表,其中包含分数及其关联的注意力向量。

变化的结果是相当戏剧性的。使用这种新方法在10k个示例上运行推理大约需要20分钟。使用中间模型的旧方法在同一数据集上执行推理需要大约33分钟。

对于任何感兴趣的人,这里是我修改后的关注层:

代码语言:javascript
复制
from tensorflow.python.keras.layers import Layer
from tensorflow.keras import initializers, regularizers, constraints
from tensorflow.keras import backend as K


class Attention(Layer):
    def __init__(self, step_dim,
                W_regularizer=None, b_regularizer=None,
                W_constraint=None, b_constraint=None,
                bias=True, return_attention=True, **kwargs):
        self.supports_masking = True
        self.init = initializers.get('glorot_uniform')

        self.W_regularizer = regularizers.get(W_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)

        self.W_constraint = constraints.get(W_constraint)
        self.b_constraint = constraints.get(b_constraint)

        self.bias = bias

        self.step_dim = step_dim
        self.features_dim = 0
        self.return_attention = return_attention
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        assert len(input_shape) == 3

        self.W = self.add_weight(shape=(input_shape[-1],),
                                 initializer=self.init,
                                 name='{}_W'.format(self.name),
                                 regularizer=self.W_regularizer,
                                 constraint=self.W_constraint)
        self.features_dim = input_shape[-1]

        if self.bias:
            self.b = self.add_weight(shape=(input_shape[1],),
                                     initializer='zero',
                                     name='{}_b'.format(self.name),
                                     regularizer=self.b_regularizer,
                                     constraint=self.b_constraint)
        else:
            self.b = None

        self.built = True

    def compute_mask(self, input, input_mask=None):
        return None

    def call(self, x, mask=None):
        features_dim = self.features_dim
        step_dim = self.step_dim

        eij = K.reshape(K.dot(K.reshape(x, (-1, features_dim)),
                              K.reshape(self.W, (features_dim, 1))), (-1, step_dim))

        if self.bias:
            eij += self.b

        eij = K.tanh(eij)

        a = K.exp(eij)

        if mask is not None:
            a *= K.cast(mask, K.floatx())

        a /= K.cast(K.sum(a, axis=1, keepdims=True) + K.epsilon(), K.floatx())

        a = K.expand_dims(a)
        weighted_input = x * a
        result = K.sum(weighted_input, axis=1)

        if self.return_attention:
            return [result, a]
        return result

    def compute_output_shape(self, input_shape):
        if self.return_attention:
            return [(input_shape[0], self.features_dim),
                    (input_shape[0], input_shape[1])]
        else:
            return input_shape[0], self.features_dim

    def get_config(self):
        config = {
            'step_dim': self.step_dim,
            'W_regularizer': regularizers.serialize(self.W_regularizer),
            'b_regularizer': regularizers.serialize(self.b_regularizer),
            'W_constraint': constraints.serialize(self.W_constraint),
            'b_constraint': constraints.serialize(self.b_constraint),
            'bias': self.bias,
            'return_attention': self.return_attention
        }

        base_config = super(Attention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/58977262

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档