首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >大图像数据集中计算均方误差最快的方法是哪一种?

大图像数据集中计算均方误差最快的方法是哪一种?
EN

Stack Overflow用户
提问于 2020-09-30 06:06:27
回答 1查看 107关注 0票数 0

我试图计算图像数据集(CIFAR-10)中的均方误差。我有一个维度的numpy array of 5*10000*32*32*3,换句话说,有5批10000张图像,每个图像都有32*32*3的维数。这些图像属于10类图像。我已经计算了每个类的平均值,现在我试图计算50000幅图像的均方误差,wrt,10幅平均图像。以下是代码:

代码语言:javascript
复制
for i in range(0, 5):
  for j in range(0, 10000):
      min_diff, min_class = float('inf'), 0
      for avg in class_avg:  # avg class comprises of 10 average images
          temp = mse(avg[1], images[i][j])
          if temp < min_diff:
              min_diff = temp
              min_class = avg[0]
      train_pred[i][j] = min_class

Problem:有什么办法让它更快吗?有矮胖的魔法吗?谢谢。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-09-30 06:48:53

您可以使用expand_dimstile

有很多方法来扩展数组的维度,我将使用其中之一,类似于[:,None,:],这在中间增加了一个新的轴。

下面是一个如何将这两种方法结合起来以完成任务的示例:

代码语言:javascript
复制
test = np.ones((5,100,32,32,3)) # batches of images 
average = np.ones((10,32,32,3)) # the 10 images 
average = average[None,None,...] # reshape to (1,1,10,32,32,3)

test = test[:,:,None,...] # insert an axis 
test = np.tile(test,(1,1,10,1,1,1)) # reshape to (5,100,10,32,32,3)
print(test.shape,average.shape)

mse = ((test-average)**2).mean(axis=(3,4,5))
class_idx = np.argmin(mse,axis=-1)

更新

使用expand_dimstile的目的是避免使用for-loop。但是,np.tile操作将创建10个原始数组的副本,如果数组很大,这肯定会损害性能。为了避免使用np.tile,您可以尝试下面的代码:

代码语言:javascript
复制
labels = np.empty((5,100,10))
average = np.ones((10,32,32,3))
average = average[None,...]

test = np.ones((5,100,32,32,3))

for ind in range(10):
    labels[...,ind] = ((test-average[:,ind,...])**2).mean(axis=(2,3,4))
labels = np.argmin(labels,axis=-1)  
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64131772

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档