首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在Tensorflow 2中实现小批量梯度下降?

如何在Tensorflow 2中实现小批量梯度下降?
EN

Stack Overflow用户
提问于 2020-09-01 12:04:28
回答 1查看 2.2K关注 0票数 1

我对机器学习和Tensorflow相对较新,我想尝试在MNIST数据集上实现小型批处理梯度下降。然而,我不知道该如何实施。

(附带说明:训练图像(28 in乘28 in)和标签存储在Numpy数组中)

目前,我可以看到两种不同的实现方法:

  1. My训练图像位于一个60万,28,28的Numpy数组中。将其重新构造为25 (num批)、2400 (批处理中的num映像)、28、28,然后使用for循环调用每个批,并将其传递给model.compile()方法。对于这种方法,我唯一担心的是for循环本身是缓慢的,而矢量化实现则要快得多。

  1. 将图像和标签组合到tensorflow dataset对象中,然后调用Dataset.batch()方法和Dataset.prefetch()方法,然后将数据传递给model.compile()方法。唯一的问题是,我的数据不是一个Numpy数组,我觉得它比tensorflow dataset对象具有更大的灵活性。

这两种方法中哪一种最好实现,还是有第三种方法是我不知道的最好的方法?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-09-01 12:25:30

Keras的model.fit方法有一个内置的model.fit参数(因为您用keras标记了这个问题,我假设您正在使用它)。我相信,这可能是最好的优化方法,以实现您正在寻找的。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63687295

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档