首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >计算精度(在lstm模型中)

计算精度(在lstm模型中)
EN

Stack Overflow用户
提问于 2021-05-01 00:00:39
回答 1查看 69关注 0票数 0

我正在尝试获得我的准确性,我有这样的代码:

代码语言:javascript
复制
            num_correct = 0.0
            for inputs, labels in dataloader(
            valid_features, valid_labels, batch_size=batch_size, sequence_length=20):

                top_val, top_class = torch.exp(output).topk(1)
                num_correct += torch.sum(top_class.squeeze() == labels)
#...
            print(#...,
                  "Accuracy: {:.3f}".format(num_correct*1.0 / len(valid_labels) *1.0)

它总是打印0.000,所以我决定将原始值打印到num_correctprint(top_class.squeeze(), labels)

代码语言:javascript
复制
tensor([ 1,  3,  3,  ...,  3,  4,  3], device='cuda:0') tensor([ 1,  1,  3,  ...,  3,  3,  3], device='cuda:0')
tensor([ 4,  3,  1,  ...,  4,  4,  3], device='cuda:0') tensor([ 4,  3,  1,  ...,  4,  4,  3], device='cuda:0')
tensor([ 2,  4,  2,  ...,  4,  4,  4], device='cuda:0') tensor([ 3,  4,  1,  ...,  4,  4,  4], device='cuda:0')
tensor([ 0,  1,  3,  ...,  2,  3,  0], device='cuda:0') tensor([ 0,  1,  3,  ...,  2,  2,  0], device='cuda:0')

这些数据看起来非常准确。所以..。我可以将其提取到numpy并完成,但有一种pytorch方法。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-05-01 02:10:58

这是可行的:

代码语言:javascript
复制
num_correct += torch.sum(torch.eq(top_class.squeeze(), labels)).item()
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/67337015

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档