首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >用于创建小批量的切片

用于创建小批量的切片
EN

Stack Overflow用户
提问于 2020-07-02 22:25:44
回答 1查看 282关注 0票数 1

我打算为我的深度学习神经网络程序创建迷你批次,从一个由'm‘个示例组成的训练集。我试过了:

代码语言:javascript
复制
# First Shuffle (X, Y)
permutation = list(np.random.permutation(m))
shuffled_X = X[:, permutation]
shuffled_Y = Y[:, permutation].reshape((1,m))

# Partition (shuffled_X, shuffled_Y). Minus the end case where mini-batch will contain lesser number of training samples.
num_complete_minibatches = math.floor(m/mini_batch_size) # number of mini batches of size mini_batch_size in your partitionning
for k in range(0, num_complete_minibatches):
    ### START CODE HERE ### (approx. 2 lines)
    mini_batch_X = shuffled_X[mini_batch_size*k:mini_batch_size*(k+2)]
    mini_batch_Y = shuffled_Y[mini_batch_size*k:mini_batch_size*(k+2)]

但这给了我以下结果:

代码语言:javascript
复制
shape of the 1st mini_batch_X: (128, 148)
shape of the 2nd mini_batch_X: (128, 148)
shape of the 3rd mini_batch_X: (12288, 148)
shape of the 1st mini_batch_Y: (1, 148)
shape of the 2nd mini_batch_Y: (0, 148)
shape of the 3rd mini_batch_Y: (1, 148)
mini batch sanity check: [ 0.90085595 -0.7612069   0.2344157 ]

预期输出为:

代码语言:javascript
复制
shape of the 1st mini_batch_X   (12288, 64)
shape of the 2nd mini_batch_X   (12288, 64)
shape of the 3rd mini_batch_X   (12288, 20)
shape of the 1st mini_batch_Y   (1, 64)
shape of the 2nd mini_batch_Y   (1, 64)
shape of the 3rd mini_batch_Y   (1, 20)
mini batch sanity check [ 0.90085595 -0.7612069 0.2344157 ] 

我确信我已经实现的切片有问题,但无法解决它。任何帮助都是非常感谢的。谢谢!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-07-03 02:14:00

我认为你没有正确地切分numpy数组。最初,当您对数组进行混洗时,这种方式是正确的。您不希望对第一个维度进行切片,因此请保持使用:的方式,并使用<Start Index>:<End Index>对第二个维度进行切片。这就是我在下面的代码中所做的事情。

代码语言:javascript
复制
for k in range(num_complete_minibatches+1):
### START CODE HERE ### (approx. 2 lines)
    mini_batch_X = shuffled_X[:,mini_batch_size*(k):mini_batch_size*(k+1)]
    mini_batch_Y = shuffled_Y[:,mini_batch_size*(k):mini_batch_size*(k+1)]
    print(mini_batch_X.shape,mini_batch_Y.shape)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62698610

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档