给定两个0和1的数组,a和b,我希望它们的点积应该等于a & b中的1的个数(其中&是按位and)。所以我希望np.dot(a, b) == np.sum(a & b)。下面的代码表明这通常不是真的。对于较短的向量似乎是正确的,但随着向量变长,更容易发现失败的情况。这里有没有bug,或者我明显漏掉了什么?
# andsum_dot_test.py
import sys
import numpy as np
if len(sys.argv) != 4:
print(f"usage: {sys.argv[0]} seed size reps")
exit(1)
seed = int(sys.argv[1])
size = int(sys.argv[2])
reps = int(sys.argv[3])
rng = np.random.default_rng(seed)
for _ in range(reps):
a = rng.integers(2, size=size, dtype=np.uint8)
b = rng.integers(2, size=size, dtype=np.uint8)
assert np.sum(a & b) == np.dot(a, b)$ python -c "import sys; print(sys.version)"
3.8.5 (default, Jul 27 2020, 08:42:51)
[GCC 10.1.0]
$ python -c "import numpy; print(numpy.__version__)"
1.19.1$ python andsum_dot_test.py 1234 500 10000
$ python andsum_dot_test.py 1234 5000 1
Traceback (most recent call last):
File "andsum_dot_test.py", line 17, in <module>
assert np.sum(a & b) == np.dot(a, b)
AssertionError发布于 2020-08-09 17:56:53
np.dot()返回一个与输入向量具有相同数据类型的值,因此断言应该是等价的模256:
assert np.sum(a & b) % 256 == np.dot(a, b)要明确的是,在长np.uint8数组的情况下,原始断言失败,因为8位值太小,无法容纳乘积的总和。似乎没有一种明显的方法可以在不强制其输入的数据类型(例如np.dot(a.astype(np.uint64), b))的情况下从np.dot返回替代数据类型。
import sys
import numpy as np
if len(sys.argv) != 4:
print(f"usage: {sys.argv[0]} seed size reps")
exit(1)
seed = int(sys.argv[1])
size = int(sys.argv[2])
reps = int(sys.argv[3])
rng = np.random.default_rng(seed)
dtype = np.uint8
k = np.iinfo(dtype).max + 1
for _ in range(reps):
a = rng.integers(2, size=size, dtype=dtype)
b = rng.integers(2, size=size, dtype=dtype)
assert np.sum(a & b) % k == np.dot(a, b)https://stackoverflow.com/questions/63324103
复制相似问题