首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >numpy确实覆盖了==操作符,因为我不能理解下面的python代码

numpy确实覆盖了==操作符,因为我不能理解下面的python代码
EN

Stack Overflow用户
提问于 2016-10-15 15:15:29
回答 3查看 168关注 0票数 1
代码语言:javascript
复制
image_size = 28
num_labels = 10

def reformat(dataset, labels):
  dataset = dataset.reshape((-1, image_size * image_size)).astype(np.float32)
  # Map 0 to [1.0, 0.0, 0.0 ...], 1 to [0.0, 1.0, 0.0 ...]
  labels = (np.arange(num_labels) == labels[:,None]).astype(np.float32)
  return dataset, labels
train_dataset, train_labels = reformat(train_dataset, train_labels)
valid_dataset, valid_labels = reformat(valid_dataset, valid_labels)
test_dataset, test_labels = reformat(test_dataset, test_labels)
print('Training set', train_dataset.shape, train_labels.shape)
print('Validation set', valid_dataset.shape, valid_labels.shape)
print('Test set', test_dataset.shape, test_labels.shape)

这行是什么意思?

代码语言:javascript
复制
labels = (np.arange(num_labels) == labels[:,None]).astype(np.float32)

代码来自https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/udacity/2_fullyconnected.ipynb

EN

回答 3

Stack Overflow用户

发布于 2016-10-15 15:21:48

在numpy中,当比较两个numpy数组时,==运算符的含义有所不同(就像在这一行中所做的那样),所以是的,它在这种意义上是重载的。它按元素比较两个numpy数组,并返回一个大小与两个输入相同的布尔型numpy数组。对于>=<等其他比较也是如此。

例如。

代码语言:javascript
复制
import numpy as np
print(np.array([5,8,2]) == np.array([5,3,2]))
# [True False True]
print((np.array([5,8,2]) == np.array([5,3,2])).astype(np.float32))
# [1. 0. 1.]
票数 3
EN

Stack Overflow用户

发布于 2016-10-15 15:22:55

对于Numpy数组,==运算符是一个元素级操作,它返回一个布尔数组。astype函数将布尔值True转换为1.0,将False转换为0.0,如注释中所述。

票数 1
EN

Stack Overflow用户

发布于 2016-10-16 01:08:31

https://docs.python.org/3/reference/expressions.html#value-comparisons描述了与==类似的值比较。虽然默认比较是identity x is y,但它首先检查参数是否实现了__eq__方法。数字、列表和字典实现了它们自己的版本。numpy也是如此。

numpy __eq__的独特之处在于,如果可能,它会逐个元素进行比较,并返回相同大小的布尔数组。

代码语言:javascript
复制
In [426]: [1,2,3]==[1,2,3]
Out[426]: True
In [427]: z1=np.array([1,2,3]); z2=np.array([1,2,3])
In [428]: z1==z2
Out[428]: array([ True,  True,  True], dtype=bool)
In [432]: z1=np.array([1,2,3]); z2=np.array([1,2,4])
In [433]: z1==z2
Out[433]: array([ True,  True, False], dtype=bool)
In [434]: (z1==z2).astype(float)     # change bool to float
Out[434]: array([ 1.,  1.,  0.])

一个常见的问题是“为什么我会得到这个ValueError?”

代码语言:javascript
复制
In [435]: if z1==z2: print('yes')
...
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

这是因为比较会产生这个数组,它有多个True/False值。

浮点数的比较也是一个常见的问题。查看iscloseallclose它,问题就出现了。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/40056209

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档