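I have a 1-D tensor of numbers in LibTorch (C++), and I want to test each number against a > condition.
Here is my attempt. In the if statement, the condition is whether the i-th element of the tensor is greater than 0.5.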
#include <torch/torch.h>
#include <iostream>
using namespace torch::indexing;
torch::Tensor frogs;
int main() {
frogs = torch::rand({11});
for (int i = 0; i<10; ++i) {
if (frogs.index({i}).item() > 0.5) {
std::cout << frogs.index({i}).item() << " \n";
}
}
return 0;
}
It fails to compile with the following error:
Consolidate compiler generated dependencies of target mujoco_gym
[ 50%] Building CXX object CMakeFiles/mujoco_gym.dir/tester.cpp.o
/home/iii/tor/m_gym/tester.cpp: In function ‘int main()’:
/home/iii/tor/m_gym/tester.cpp:18:37: error: no match for ‘operator>’ (operand types are ‘c10::Scalar’ and ‘double’)
18 | if (frogs.index({i}).item() > 0.5) {
| ~~~~~~~~~~~~~~~~~~~~~~~ ^ ~~~
| | |
| | double
| c10::Scalar
In file included from /home/iii/tor/m_gym/libtorch/include/c10/util/string_view.h:5,
from /home/iii/tor/m_gym/libtorch/include/c10/util/StringUtil.h:6,
from /home/iii/tor/m_gym/libtorch/include/c10/util/Exception.h:6,
from /home/iii/tor/m_gym/libtorch/include/c10/core/Device.h:5,
from /home/iii/tor/m_gym/libtorch/include/ATen/core/TensorBody.h:11,
from /home/iii/tor/m_gym/libtorch/include/ATen/core/Tensor.h:3,
from /home/iii/tor/m_gym/libtorch/include/ATen/Tensor.h:3,
from /home/iii/tor/m_gym/libtorch/include/torch/csrc/autograd/function_hook.h:3,
from /home/iii/tor/m_gym/libtorch/include/torch/csrc/autograd/cpp_hook.h:2,
from /home/iii/tor/m_gym/libtorch/include/torch/csrc/autograd/variable.h:6,
from /home/iii/tor/m_gym/libtorch/include/torch/csrc/autograd/autograd.h:3,
from /home/iii/tor/m_gym/libtorch/include/torch/csrc/api/include/torch/autograd.h:3,
from /home/iii/tor/m_gym/libtorch/include/torch/csrc/api/include/torch/all.h:7,
from /home/iii/tor/m_gym/libtorch/include/torch/csrc/api/include/torch/torch.h:3,
from /home/iii/tor/m_gym/tester.cpp:1:
/home/iii/tor/m_gym/libtorch/include/c10/util/reverse_iterator.h:200:23: note: candidate: ‘template<class _Iterator> constexpr bool c10::operator>(const c10::reverse_iterator<_Iterator>&, const c10::reverse_iterator<_Iterator>&)’
200 | inline constexpr bool operator>(
| ^~~~~~~~
/home/iii/tor/m_gym/libtorch/include/c10/util/reverse_iterator.h:200:23: note: template argument deduction/substitution failed:
/home/iii/tor/m_gym/tester.cpp:18:39: note: ‘c10::Scalar’ is not derived from ‘const c10::reverse_iterator<_Iterator>’
18 | if (frogs.index({i}).item() > 0.5) {
| ^~~
Posted on 2022-08-25 04:47:32
You can access tensor elements with operator[]. There is no need for data_ptr (in fact, it is better to avoid it):
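std::cout << frogs[i];
You can also compare a tensor directly with a scalar. This returns a boolean tensor whose entries are the results of comparing each element with the scalar. Try this to see what I mean: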
auto t = torch::randn({3,4});
std::cout << t << (t > 0);
You can use argwhere to get a tensor containing the indices of all nonzero elements. For example:
auto t = torch::randn({3,4});
// this prints the indices of the positive entries in the tensor t
std::cout << torch::argwhere(t > 0);
Posted on 2022-08-22 05:33:07
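I called LibTorch's .data_ptr on the tensor to get a raw array that can be indexed and yields a float, which I can use in the conditional statement. The key line is float* temp_arr = frogs.data_ptr<float>();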
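#include <torch/torch.h>
#include <cstdint>
#include <iostream>
torch::Tensor frogs;
int main() {
frogs = torch::rand({11});
// data_ptr<float>() is valid here because torch::rand returns a contiguous float32 tensor
float* temp_arr = frogs.data_ptr<float>();
// loop over every element (the tensor has 11, so a bound of 10 would skip the last one)
for (int64_t i = 0; i < frogs.numel(); ++i) {
if (temp_arr[i] > 0.5) {
std::cout << temp_arr[i] << " \n";
}
}
return 0;
}
https://stackoverflow.com/questions/73439423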