首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Skorch:帮助为多个输出构造分类器

Skorch:帮助为多个输出构造分类器
EN

Stack Overflow用户
提问于 2019-07-26 15:05:29
回答 1查看 676关注 0票数 3

我正试图通过翻译一个简单的py手电模型来学习skorch,该模型预测了一组MNIST多位数字图片中包含的2位数。这些图片包含两个重叠的数字,这是输出标签(y)。我得到了以下错误:

ValueError: Stratified CV requires explicitely passing a suitable y

我遵循了“使用SciKit学习和skorch的MNIST”笔记本,并通过创建自定义的get_loss函数应用了“从前返回多个值”中概述的多个输出修复。

数据范围包括:

  • X:(40000, 1, 4, 28)
  • y:(40000, 2)

代码:

代码语言:javascript
复制
class Flatten(nn.Module):
    """A custom layer that views an input as 1D."""

    def forward(self, input):
        return input.view(input.size(0), -1)


class CNN(nn.Module):

    def __init__(self):
        super(CNN, self).__init__()

        self.conv1 = nn.Conv2d(1, 32, 3)
        self.pool1 = nn.MaxPool2d((2, 2))
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.pool2 = nn.MaxPool2d((2, 2))
        self.flatten = Flatten()
        self.fc1 = nn.Linear(2880, 64)
        self.drop1 = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(64, 10)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.drop1(x)
        out_first_digit = self.fc2(x)
        out_second_digit = self.fc3(x)

        return out_first_digit, out_second_digit


torch.manual_seed(0)

class CNN_net(NeuralNetClassifier):
    def get_loss(self, y_pred, y_true, *args, **kwargs):

        loss1 = F.cross_entropy(y_pred[0], y_true[:,0])
        loss2 = F.cross_entropy(y_pred[1], y_true[:,1])

        return 0.5 * (loss1 + loss2)

net = CNN_net(
    CNN,
    max_epochs=5,
    lr=0.1,
    device=device,
)

net.fit(X_train, y_train);
  1. 我需要修改y的格式吗?
  2. 我是否需要构造额外的自定义函数(预测)?
  3. 还有其他建议吗?
EN

回答 1

Stack Overflow用户

发布于 2019-07-30 14:24:06

skorch的NeuralNetClassifier在默认情况下应用分层交叉验证分割,为您提供培训期间验证准确性等指标。当然,这使得您的数据可以以这种方式拆分是必要的。因为每个图像都有两个标签,所以没有什么简单的方法可以进行分层分割(尽管有是方法)。

有两种解决办法:

  1. 完全禁用列车分割(通过train_split=None),并在培训期间丢失验证
  2. 通过train_split=skorch.dataset.CVSplit(5, stratified=False)将列车拆分改为非分层。

由于我猜您希望在培训期间使用验证指标,所以最终代码应该如下所示:

代码语言:javascript
复制
net = CNN_net(
    CNN,
    max_epochs=5,
    lr=0.1,
    device=device,
    train_split=skorch.dataset.CVSplit(5, stratified=False),
)

net.fit(X_train, y_train);
票数 4
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/57222733

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档