首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Keras LSTM在线学习

Keras LSTM在线学习
EN

Stack Overflow用户
提问于 2020-02-19 04:20:18
回答 1查看 1.2K关注 0票数 4

我有一个非常可预测的序列。在下面,您可以看到其中的一部分:

代码语言:javascript
复制
deque([4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4])

基本上,这是一个变化的,但仍然可以预测的数目4s,其次是8,和28每三个8s。

我想要为在线预测建立一个非常简单的LSTM模型:每次一个新的数字到达,它就被附加在deque的右边。因此,LSTM是在由deque的0:seq_length元素组成的旧序列上训练的,训练目标是seq_length元素。然后,对1:seq_length+1元素执行窗口移位和预测。最后,deque的最左边元素被丢弃。我的直觉告诉我,这应该使网络记忆序列。

然而,我的网络只回答了4个,但令人惊讶的是,它只回答了8个,几乎所有的时间都没有。然后,一个(长),稍后,它回到回答只有4。

我的模型结构如图所示。当然,我已经为seq_length和lstm_cells尝试过不同的价值观,但没有一个给了我成功。这些数据来自于最近一轮的调查:

代码语言:javascript
复制
seq_length = 64  #Length of the sequence to be inserted into the LSTM
vocab_size = 4  #Size of the final dense layer of the model
lstm_cells = 16  #Size of the LSTM layer

model = Sequential()
model.add(LSTM(lstm_cells, input_shape=(seq_length, 1)))
model.add(Dense(vocab_size))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['categorical_accuracy'])

下面是如何在模型上准备、训练和预测数据的方法。变量序列是在这篇文章的开头显示的deque。我维护一个列表词汇表= 4,8,28,它是在看到新的数字时建立在执行时间上的,所以vocabi将类I转换成相应的序列数。然后,我创建一个字典图例来做相反的事情。这或多或少就是正在进行的在线循环:

代码语言:javascript
复制
while True:

    # Receives new number and puts it into the deque:
    sequence.append(generateNextNumber())

    # At this point, please note that the length of the deque is seq_length + 1.

    # Dictionary to convert numbers to classes:
    legend = dict([(v, k) for k, v in enumerate(vocab)])
    # Converts the deque into a list:
    seq_list = list(sequence)
    # Each iteration is comprised of 1 training and 1 prediction. These are the training sequence and target:
    train_seq = [ [legend[i]] for i in seq_list[:seq_length] ]
    train_target = legend[ seq_list[seq_length] ]
    # And the prediction sequence just shifts the window by 1:
    pred_seq = [ [legend[i]] for i in seq_list[1:] ]

    # Batches data into a batch of size 1:
    x = np.zeros((1, seq_length, 1))
    y = np.zeros((1, vocab_size))
    x[0,:] = train_seq
    y[0,:] = to_categorical( train_target, num_classes=vocab_size )
    # Online training:
    model.fit(x=x, y=y, batch_size=1, epochs=1, verbose=0)

    # Now that one training step is done, make a prediction:
    x_pred = np.zeros((1, seq_length, 1))
    x[0,:] = pred_seq
    predicted_onehot = model.predict(x_pred)
    # Avoids "index out of range" erros when the LSTM vocab is still being built:
    predicted_index = min(np.argmax(predicted_onehot), len(vocab)-1)
    predicted_number = vocab[ predicted_index ]
    # Reverts deque length to seq_length:
    sequence.popleft()

最后,这是一个示例输出:

代码语言:javascript
复制
HIT! Current hit rate: 34.753665869071725 (predicted: 4, sequence was: deque([4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4]))

HIT! Current hit rate: 34.75566735175926 (predicted: 4, sequence was: deque([4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4]))

Predicted 4 but it was 8

HIT! Current hit rate: 34.75660255820374 (predicted: 4, sequence was: deque([4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4]))

HIT! Current hit rate: 34.758603766640086 (predicted: 4, sequence was: deque([4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4]))

HIT! Current hit rate: 34.7606048523142 (predicted: 4, sequence was: deque([4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4]))

HIT! Current hit rate: 34.76260581523739 (predicted: 4, sequence was: deque([4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4]))

HIT! Current hit rate: 34.76460665542095 (predicted: 4, sequence was: deque([4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4]))

HIT! Current hit rate: 34.76660737287616 (predicted: 4, sequence was: deque([4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4]))

Predicted 4 but it was 28

HIT! Current hit rate: 34.767541707556425 (predicted: 4, sequence was: deque([4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4]))

出什么事了?

先谢谢你。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-02-19 21:19:40

我有个问题要说:

代码语言:javascript
复制
x[0,:] = pred_seq

应该是:

代码语言:javascript
复制
x_pred[0,:] = pred_seq

现在,一切都或多或少地正常工作了。我仍然将这个问题留在这里,因为它提供了一些关于LSTM在线学习的很好的见解。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/60292957

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档