首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >当运行交叉验证时,Tensorflow保存的模型变得更大。

当运行交叉验证时,Tensorflow保存的模型变得更大。
EN

Stack Overflow用户
提问于 2016-03-29 08:29:01
回答 1查看 494关注 0票数 0

在Tensorflow上运行交叉验证的正确方法是什么?下面是我的代码片段:

代码语言:javascript
复制
class TextCNN:
  ...
  def train(self):
    saver = tf.train.Saver(tf.all_variables())
    with tf.Session() as sess:
      ...
      # training loop
      ...
      # training finished
      path = saver.save(sess, "{:s}/model.{:d}".format(self.checkpoint_dir, self.test_fold))

if __name__ == "__main__":
  for i in range(CV_SIZE):
    cnn = TextCNN(i)
    cnn.train()

折叠0保存的型号大小约为2M。但对于4米左右的1倍,6米左右的2倍,等等。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2016-03-29 15:33:46

我的猜测是,TextCNN构造函数和train()方法正在向默认图(tf.get_default_graph())中添加节点,而保存的模型包含所有以前的图,因此它是“意外二次”的,并且随着__main__循环的每一次迭代而增长。

幸运的是,解决办法很简单。只需按以下方式重写主循环:

代码语言:javascript
复制
if __name__ == "__main__":
  for i in range(CV_SIZE):
    with tf.Graph().as_default():  # Performs training in a new, empty graph.
      cnn = TextCNN(i)
      cnn.train()

这将为循环的每一次迭代创建一个新的空图。因此,保存的模型将不包含上一次迭代中的节点(和变量),模型大小应该保持不变。

注意,如果可能的话,您应该尝试在所有迭代中重用相同的图。但是,我意识到,如果图的结构从一次迭代到下一次迭代,这可能是不可能的。

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/36279144

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档