首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >GRadi-Py手电筒MNIST数字识别器

GRadi-Py手电筒MNIST数字识别器
EN

Stack Overflow用户
提问于 2022-03-26 16:07:28
回答 1查看 153关注 0票数 1

我在YouTube https://www.youtube.com/watch?v=jx9iyQZhSwI上观看了以下视频,其中显示可以在Tensorflow中使用G收音机和MNIST数据集的学习模型。我已经读过和写过,它是有可能使用在Gradio,但我有问题,它的实施。有人知道怎么做吗?我的cnn火炬代码

代码语言:javascript
复制
import torch.nn as nn
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(         
            nn.Conv2d(
                in_channels=1,              
                out_channels=16,            
                kernel_size=5,              
                stride=1,                   
                padding=2,                  
            ),                              
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),    
        )
        self.conv2 = nn.Sequential(         
            nn.Conv2d(16, 32, 5, 1, 2),     
            nn.ReLU(),                      
            nn.MaxPool2d(2),                
        )
        # fully connected layer, output 10 classes
        self.out = nn.Linear(32 * 7 * 7, 10)
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)       
        output = self.out(x)
        return output, x    # return x for visualization

通过观看,我发现我需要改变Gradio使用的功能

代码语言:javascript
复制
def predict_image(img):
  img_3d=img.reshape(-1,28,28)
  im_resize=img_3d/255.0
  prediction=CNN(im_resize)
  pred=np.argmax(prediction)
  return pred
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-04-05 11:52:03

如果我把你的问题搞错了,我很抱歉,但据我所知,当你试图用你的函数预测图像时,你会出现一个错误。

这里有两个可能的提示。也许您已经实现了它们,但我不知道,因为代码片段非常小。

首先。是否将模型设置为评估模式?

代码语言:javascript
复制
CNN.eval()

在您完成模型的培训之后,要评估输入,而不需要对模型进行培训。

其次,也许您需要在输入张量"im_resize“中添加第四个维度。通常,您的模型期望输入的通道数、批大小、高度和宽度都有一个维度。此外,我无法判断您的输入是否属于数据类型torch.tensor。如果没有,首先将数组转换为张量。

可以将批处理维度添加到输入张量中,方法是

代码语言:javascript
复制
im_resize = im_resize.unsqueeze(0)

我希望我能正确理解你的问题,并能帮助你。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71629683

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档