首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在collate_fn中使用LibTorch

如何在collate_fn中使用LibTorch
EN

Stack Overflow用户
提问于 2021-09-28 08:56:35
回答 1查看 149关注 0票数 0

我试图实现一个基于图像的回归,使用CNN在lib手电。问题是,我的图像有不同的大小,这将导致异常批处理图像。

首先,我创建了我的dataset

代码语言:javascript
复制
auto set = MyDataSet(pathToData).map(torch::data::transforms::Stack<>());

然后我创建了dataLoader

代码语言:javascript
复制
auto dataLoader = torch::data::make_data_loader(
    std::move(set),
    torch::data::DataLoaderOptions().batch_size(batchSize).workers(numWorkersDataLoader)
);

异常将在列车循环中抛出批处理数据:

代码语言:javascript
复制
for (torch::data::Example<> &batch: *dataLoader) {
        processBatch(model, optimizer, counter, batch);
}

对于大于1的批处理大小(批处理大小为1,一切都很好,因为没有任何堆叠)。例如,我将使用2的批处理大小获得以下错误:

代码语言:javascript
复制
...
what():  stack expects each tensor to be equal size, but got [3, 1264, 532] at entry 0 and [3, 299, 294] at entry 1

例如,我读到可以使用collate_fn来实现一些填充(例如这里),我只是不知道在哪里实现它。例如,torch::data::DataLoaderOptions不提供这样的东西。

有人知道怎么做吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-09-29 07:42:22

我现在有解决办法了。总之,我在Conv和Denselayers中分割我的CNN,并在批处理结构中使用torch::nn::AdaptiveMaxPool2d的输出。

为了做到这一点,我必须修改我的数据集、Net和train/val/test-方法。在我的网络中,我增加了两个额外的转发函数。第一个通过所有Conv层传递数据,并返回AdaptiveMaxPool2d-Layer的输出。第二种方法是将数据通过所有密集的层。在实践中,这看起来像是:

代码语言:javascript
复制
torch::Tensor forwardConLayer(torch::Tensor x) {
    x = torch::relu(conv1(x));
    x = torch::relu(conv2(x));
    x = torch::relu(conv3(x));
    x = torch::relu(ada1(x));
    x = torch::flatten(x);
    return x;
}

torch::Tensor forwardDenseLayer(torch::Tensor x) {
    x = torch::relu(lin1(x));
    x = lin2(x);
    return x;
}

然后重写get_batch方法,并使用forwardConLayer计算每个批处理条目。为了训练(正确),我在构造批处理之前调用zero_grad()。所有这些看起来都是这样的:

代码语言:javascript
复制
std::vector<ExampleType> get_batch(at::ArrayRef<size_t> indices) override {
    // impl from bash.h
    this->net.zero_grad();
    std::vector<ExampleType> batch;
    batch.reserve(indices.size());
    for (const auto i : indices) {
        ExampleType batchEntry = get(i);
        auto batchEntryData = (batchEntry.data).unsqueeze(0);
        auto newBatchEntryData = this->net.forwardConLayer(batchEntryData);             
        batchEntry.data = newBatchEntryData;
        batch.push_back(batchEntry);
    }
    return batch;
}

最后,我在所有我通常会调用forwardDenseLayer的地方调用forward,例如:

代码语言:javascript
复制
    for (torch::data::Example<> &batch: *dataLoader) {
        auto data = batch.data;
        auto target = batch.target.squeeze();

        auto output = model.forwardDenseLayer(data);
        auto loss = torch::mse_loss(output, target);
        LOG(INFO) << "Batch loss: " << loss.item<double>();

        loss.backward();
        optimizer.step();
    }

更新

如果数据服务器的工作人员数不是0,则此解决方案似乎会导致错误。错误是:

代码语言:javascript
复制
terminate called after thro9wing an instance of 'std::runtime_error'
  what(): one of the variables needed for gradient computation has been modified by an inplace operation: [CPUFloatType [3, 12, 3, 3]] is at version 2; expected version 1 instead. ...

这个错误确实是有意义的,因为在批处理过程中,数据正在传递CNN的头部。解决这个“问题”的方法是将工人人数设置为0。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69358675

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档