首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >model.evaluation : model.prediction与Keras损失不匹配

model.evaluation : model.prediction与Keras损失不匹配
EN

Stack Overflow用户
提问于 2020-03-26 08:04:51
回答 2查看 531关注 0票数 1

我应用了本教程https://github.com/Hvass-Labs/TensorFlow-Tutorials/blob/master/23_Time-Series-Prediction.ipynb (在不同的数据集上),该模型不计算单个输出的均方误差,因此我在比较函数中添加了以下行:

代码语言:javascript
复制
    mean_squared_error(signal_true,signal_pred)

但在测试数据上,预测的损失和均方误差与model.evaluation的损失和均方误差不同。来自model.evaluation (损失、mae、mse)的错误(测试集):

代码语言:javascript
复制
    [0.013499056920409203, 0.07980187237262726, 0.013792216777801514]

来自单个目标(输出)的错误:

代码语言:javascript
复制
    Target0 0.167851388666284
    Target1 0.6068108648555771
    Target2 0.1710370357827747
    Target3 2.747463225418181
    Target4 1.7965991690103074
    Target5 0.9065426398192563 

我认为在训练模型时可能会有问题,但我找不到它的确切位置。我真的很感谢你的帮助。

谢谢

EN

回答 2

Stack Overflow用户

发布于 2020-03-26 13:04:09

培训损失和评估损失之间可能存在差异的原因有很多。

  • 某些操作,如批处理规范化,在预测时被禁用-这可以在某些体系结构中产生很大的差异,尽管如果您正确使用批处理规范,通常不会这样做。用于训练的
  • 均方误差是在整个时期内平均的,而评估只发生在模型的最新“最佳”版本上。
  • 如果分割不是随机的,则可能是由于数据集中的差异。
  • 您可能在没有意识到的情况下使用了不同的度量。

我不确定你到底遇到了什么问题,但它可能是由许多不同的东西引起的,而且通常很难调试。

票数 1
EN

Stack Overflow用户

发布于 2020-08-17 23:37:17

我也遇到了同样的问题,并找到了解决方案。希望这和你遇到的问题是一样的。

事实证明,model.predict不会像generator.labels那样以相同的顺序返回预测,这就是为什么当我尝试手动计算(使用scikit learn度量函数)时,MSE要大得多。

代码语言:javascript
复制
>>> model.evaluate(valid_generator, return_dict=True)['mean_squared_error']
13.17293930053711
>>> mean_squared_error(valid_generator.labels, model.predict(valid_generator)[:,0])
91.1225401637833

我的快速和肮脏的解决方案:

代码语言:javascript
复制
valid_generator.reset()  # Necessary for starting from first batch
all_labels = []
all_pred = []
for i in range(len(valid_generator)):  # Necessary for avoiding infinite loop
    x = next(valid_generator)
    pred_i = model.predict(x[0])[:,0]
    labels_i = x[1]
    all_labels.append(labels_i)
    all_pred.append(pred_i)
    print(np.shape(pred_i), np.shape(labels_i))

cat_labels = np.concatenate(all_labels)
cat_pred = np.concatenate(all_pred)

结果是:

代码语言:javascript
复制
>>> mean_squared_error(cat_labels, cat_pred)
13.172956865002352

这可以做得更优雅,但足以让我确认我对问题的假设,并恢复一些理智。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/60859191

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档