首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >多通道图像数据集上的卷积网训练

多通道图像数据集上的卷积网训练
EN

Stack Overflow用户
提问于 2017-08-22 14:43:14
回答 1查看 695关注 0票数 1

我正试图从零开始实现一个卷积神经网络,我无法弄清楚如何对rgb这样的三维多通道图像执行(矢量化)操作。在遵循文章和教程(如本CS231n教程 )之后,实现单个输入的网络是非常清楚的,因为输入层将是一个3d矩阵,但是数据集中总是有多个数据点。因此,我想不出如何在整个数据集中实现这些网络的向量化操作。

我已经实现了一个以三维矩阵为输入的网络,但现在我意识到它不能在整个数据集上工作,但我必须一次传播一个输入,我真的不知道conv网是否在整个数据集上向量化,如果是的话,我如何向化我的卷积网络来处理多通道图像?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-10-04 16:57:32

如果我正确地回答了你的问题,你基本上是在问如何做一个小批量的卷积层,这将是一个四维张量。

简单地说,您想要将每个输入单独地处理成一个批处理,并对每个输入应用卷积。不使用循环进行矢量化的代码是相当简单的。

矢量化实现通常基于im2col技术,它将四维输入张量转化为一个巨大的矩阵,并进行矩阵乘法。下面是在python中使用numpy.lib.stride_tricks实现前向传递的方法:

代码语言:javascript
复制
import numpy as np

def conv_forward(x, w, b, stride, pad):
  N, C, H, W = x.shape
  F, _, HH, WW = w.shape

  # Check dimensions
  assert (W + 2 * pad - WW) % stride == 0, 'width does not work'
  assert (H + 2 * pad - HH) % stride == 0, 'height does not work'

  # Pad the input
  p = pad
  x_padded = np.pad(x, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant')

  # Figure out output dimensions
  H += 2 * pad
  W += 2 * pad
  out_h = (H - HH) / stride + 1
  out_w = (W - WW) / stride + 1

  # Perform an im2col operation by picking clever strides
  shape = (C, HH, WW, N, out_h, out_w)
  strides = (H * W, W, 1, C * H * W, stride * W, stride)
  strides = x.itemsize * np.array(strides)
  x_stride = np.lib.stride_tricks.as_strided(x_padded,
                                             shape=shape, strides=strides)
  x_cols = np.ascontiguousarray(x_stride)
  x_cols.shape = (C * HH * WW, N * out_h * out_w)

  # Now all our convolutions are a big matrix multiply
  res = w.reshape(F, -1).dot(x_cols) + b.reshape(-1, 1)

  # Reshape the output
  res.shape = (F, N, out_h, out_w)
  out = res.transpose(1, 0, 2, 3)
  out = np.ascontiguousarray(out)
  return out

请注意,它使用了线性代数库的一些重要特性,这些特性在numpy中实现,但可能不在库中。

顺便说一句,您通常不希望将整个数据集作为一个批处理来推送--将其分成几个批。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45820735

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档