为什么python float乘以torch.long会得到torch.float,而用torch.long给float赋权则会得到torch.long?
>>> a = 0.9
>>> b = torch.tensor(2, dtype=torch.long)
>>> foo = a * b
>>> print(foo, foo.dtype)
tensor(1.8000) torch.float32
>>> bar = a ** b
>>> print(bar, bar.dtype)
tensor(0) torch.int64发布于 2020-01-21 23:50:52
这看起来像是一个bug,可能是pytorch将**绑定到__rpow__或__pow__的方式。
例如,如果你尝试了0.9 - torch.tensor(2),因为0.9不是一个张量,所以它会被解释为torch.tensor(2).__rsub__(0.9),它可以正常工作。**的行为与此相同,但torch.tensor(2).__rpow__(0.9)错误地返回了数据类型为int64的tensor(0)。
同时,您可以使用torch.tensor(0.9) ** torch.tensor(2)。
https://stackoverflow.com/questions/59827509
复制相似问题