首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >将内存块分配给py-火炬输出张量(C++ API)

将内存块分配给py-火炬输出张量(C++ API)
EN

Stack Overflow用户
提问于 2022-04-08 00:28:31
回答 1查看 406关注 0票数 1

我正在使用py-torch训练一个线性模型,并使用“保存”函数调用将其保存到一个文件中。我有另一段代码,它在C++中加载模型并执行推理。我想指示火炬CPP库在最终输出张量中使用一个特定的内存块。这有可能吗?如果是,怎么做?下面你可以看到一个小小的例子,说明我正在努力实现的目标。

代码语言:javascript
复制
#include <iostream>
#include <memory>

#include <torch/script.h>

int main(int argc, const char* argv[]) {
  if (argc != 3) {
    std::cerr << "usage: example-app <path-to-exported-script-module>\n";
    return -1;
  }
  long numElements = (1024*1024)/sizeof(float) * atoi(argv[2]);

  float *a = new float[numElements]; 
  float *b = new float[numElements];
  float *c = new float[numElements*4];

  for (int i = 0; i < numElements; i++){
    a[i] = i;
    b[i] = -i;
  }

  //auto options = torch::TensorOptions().dtype(torch::kFloat64);
  at::Tensor a_t = torch::from_blob((float*) a, {numElements,1});
  at::Tensor b_t = torch::from_blob((float*) b, {numElements,1});
  at::Tensor out = torch::from_blob((float*) c, {numElements,4});

  at::Tensor c_t = at::cat({a_t,b_t}, 1);
  at::Tensor d_t = at::reshape(c_t, {numElements,2});

  torch::jit::script::Module module;
  try {
    module = torch::jit::load(argv[1]);
  }
  catch (const c10::Error& e) {
    return -1;
  }


  out =  module.forward({d_t}).toTensor();
  std::cout<< out.sizes() << "\n";

  delete [] a;
  delete [] b;
  delete [] c;

  return 0;
}

因此,我将内存分配到"c“中,然后从这个内存中抛出一个张量。我把这个记忆存储在一个名为"out“的张量中。当我调用前向方法时,我加载模型。我观察到,得到的数据被复制/移动到"out“张量中。然而,我想指示火炬直接存储到“外面”的记忆。这个是可能的吗?

EN

回答 1

Stack Overflow用户

发布于 2022-04-28 02:26:10

在lib手电源代码中的某个地方(我不记得在哪里,我会尝试找到文件),有一个操作符,如下所示(注意最后&)

代码语言:javascript
复制
torch::tensor& operator=(torch::Tensor rhs) &&;

如果我没记错的话。基本上,torch假设如果将张量rhs分配给rvalue引用张量,那么实际上意味着将rhs复制到底层存储中。

所以在你的情况下,那将是

代码语言:javascript
复制
std::move(out) = module.forward({d_t}).toTensor();

代码语言:javascript
复制
torch::from_blob((float*) c, {numElements,4}) = module.forward({d_t}).toTensor();
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71790378

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档