我在YouTube https://www.youtube.com/watch?v=jx9iyQZhSwI上观看了以下视频,其中显示可以在Tensorflow中使用G收音机和MNIST数据集的学习模型。我已经读过和写过,它是有可能使用在Gradio,但我有问题,它的实施。有人知道怎么做吗?我的cnn火炬代码
import torch.nn as nn
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(
in_channels=1,
out_channels=16,
kernel_size=5,
stride=1,
padding=2,
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
)
self.conv2 = nn.Sequential(
nn.Conv2d(16, 32, 5, 1, 2),
nn.ReLU(),
nn.MaxPool2d(2),
)
# fully connected layer, output 10 classes
self.out = nn.Linear(32 * 7 * 7, 10)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
# flatten the output of conv2 to (batch_size, 32 * 7 * 7)
x = x.view(x.size(0), -1)
output = self.out(x)
return output, x # return x for visualization通过观看,我发现我需要改变Gradio使用的功能
def predict_image(img):
img_3d=img.reshape(-1,28,28)
im_resize=img_3d/255.0
prediction=CNN(im_resize)
pred=np.argmax(prediction)
return pred发布于 2022-04-05 11:52:03
如果我把你的问题搞错了,我很抱歉,但据我所知,当你试图用你的函数预测图像时,你会出现一个错误。
这里有两个可能的提示。也许您已经实现了它们,但我不知道,因为代码片段非常小。
首先。是否将模型设置为评估模式?
CNN.eval()在您完成模型的培训之后,要评估输入,而不需要对模型进行培训。
其次,也许您需要在输入张量"im_resize“中添加第四个维度。通常,您的模型期望输入的通道数、批大小、高度和宽度都有一个维度。此外,我无法判断您的输入是否属于数据类型torch.tensor。如果没有,首先将数组转换为张量。
可以将批处理维度添加到输入张量中,方法是
im_resize = im_resize.unsqueeze(0)我希望我能正确理解你的问题,并能帮助你。
https://stackoverflow.com/questions/71629683
复制相似问题