首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >有状态LSTM:何时重置状态?

有状态LSTM:何时重置状态?
EN

Stack Overflow用户
提问于 2017-08-10 21:08:13
回答 2查看 4.7K关注 0票数 9

给定具有维数的X (m样本、n序列和k特征)和具有维数的y标签(m样本,0/1)

假设我想要训练一个有状态的LSTM (按照keras定义,其中“state= True”意味着每个样本之间的序列之间不重置单元状态-如果我错了,请纠正我!),是每个时代还是每个样本基础上重新设置状态?

示例:

代码语言:javascript
复制
for e in epoch:
    for m in X.shape[0]:          #for each sample
        for n in X.shape[1]:      #for each sequence
            #train_on_batch for model...
            #model.reset_states()  (1) I believe this is 'stateful = False'?
        #model.reset_states()      (2) wouldn't this make more sense?
    #model.reset_states()          (3) This is what I usually see...

总之,我不确定是在每个序列或每个时代之后重置状态(在所有m个样本被训练成X之后)。

我们非常感谢你的建议。

EN

回答 2

Stack Overflow用户

发布于 2017-08-10 21:31:23

如果使用stateful=True,通常会在每个时代结束时或每两个样本重新设置状态。如果您想在每个示例之后重置状态,则这相当于使用stateful=False

关于您提供的循环:

代码语言:javascript
复制
for e in epoch:
    for m in X.shape[0]:          #for each sample
        for n in X.shape[1]:      #for each sequence

请注意,X的维数不完全是

代码语言:javascript
复制
 (m samples, n sequences, k features)

维数实际上是

代码语言:javascript
复制
(batch size, number of timesteps, number of features)

因此,您不应该有内环:

代码语言:javascript
复制
for n in X.shape[1]

现在,关于循环

代码语言:javascript
复制
for m in X.shape[0]

由于批处理的枚举是在keras中自动完成的,所以您也不必实现这个循环(除非您希望每两个样本重置状态)。因此,如果您只想在每个时代结束时重置,您只需要外部循环。

以下是这种体系结构的一个示例(取自这篇博客文章):

代码语言:javascript
复制
batch_size = 1
model = Sequential()
model.add(LSTM(16, batch_input_shape=(batch_size, X.shape[1], X.shape[2]), stateful=True))
model.add(Dense(y.shape[1], activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
for i in range(300):
    model.fit(X, y, epochs=1, batch_size=batch_size, verbose=2, shuffle=False)
    model.reset_states()
票数 8
EN

Stack Overflow用户

发布于 2022-02-21 20:57:54

或者,似乎可以进行自定义回调。这避免了在循环中调用fit,这是一项代价很高的工作。类似于Tensorflow LSTM/GRU每一时期重置一次,而不是每一批新批的东西

代码语言:javascript
复制
gru_layer = model.layers[1]

class CustomCallback(tf.keras.callbacks.Callback):
   def __init__(self, gru_layer):
        self.gru_layer = gru_layer
   def on_epoch_end(self, epoch, logs=None):
        self.gru_layer.reset_states()
        
model.fit(train_dataset, validation_data=validation_dataset, \
    epochs=EPOCHS, callbacks = [EarlyS, CustomCallback(gru_layer)], verbose=1)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45623480

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档