首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >输出MLP预测评分

输出MLP预测评分
EN

Stack Overflow用户
提问于 2022-08-17 23:29:11
回答 2查看 47关注 0票数 1

按照keras教程,MLP分类在这里:分类/分类

我能够成功地训练一个模型,并使用下面的代码打印出3种预测标签。我也想打印预测分数。在文档中,我似乎找不到如何做到这一点。

代码语言:javascript
复制
# Create a model for inference.
model_for_inference = keras.Sequential([text_vectorizer, shallow_mlp_model])

# Create a small dataset just for demoing inference.
inference_dataset = make_dataset(test_df.sample(100), is_train=False)
text_batch, label_batch = next(iter(inference_dataset))
predicted_probabilities = model_for_inference.predict(text_batch)

# Perform inference.
for i, text in enumerate(text_batch[:5]):
    label = label_batch[i].numpy()[None, ...]
    print(f"Abstract: {text}")
    predicted_proba = [proba for proba in predicted_probabilities[i]]
    top_3_labels = [
        x
        for _, x in sorted(
            zip(predicted_probabilities[i], lookup.get_vocabulary()),
            key=lambda pair: pair[0],
            reverse=True,
        )
    ][:3]
    print(f"Predicted Label(s): ({', '.join([label for label in top_3_labels])})")
    print(" ")
EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2022-08-18 08:24:14

若要获取概率,请更改此部分:

代码语言:javascript
复制
top_3_labels = [
        x
        for _, x in sorted(
            zip(predicted_probabilities[i], lookup.get_vocabulary()),
            key=lambda pair: pair[0],
            reverse=True,
        )
    ][:3]
    print(f"Predicted Label(s): ({', '.join([label for label in top_3_labels])})")
    print(" ")

对此:

代码语言:javascript
复制
top_3_labels = [
        (p, x)
        for p, x in sorted(
            zip(predicted_probabilities[i], lookup.get_vocabulary()),
            key=lambda pair: pair[0],
            reverse=True,
        )
    ][:3]
    print(f"Predicted Label(s): ({', '.join([l[1] for l in top_3_labels])})")
    print(f"Predicted Probabilities(s): ({', '.join([l[0] for l in top_3_labels])})")
    print(" ")
票数 1
EN

Stack Overflow用户

发布于 2022-08-17 23:45:16

您可以预测最有可能的标签(将[:3]更改为[0]),并使用一些学习函数来获得标准的准确性。

此外,您还可以获取标签的所有概率,并使用滑雪板的顶K精度

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73395917

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档