在Tensorflow上运行交叉验证的正确方法是什么?下面是我的代码片段:
class TextCNN:
...
def train(self):
saver = tf.train.Saver(tf.all_variables())
with tf.Session() as sess:
...
# training loop
...
# training finished
path = saver.save(sess, "{:s}/model.{:d}".format(self.checkpoint_dir, self.test_fold))
if __name__ == "__main__":
for i in range(CV_SIZE):
cnn = TextCNN(i)
cnn.train()折叠0保存的型号大小约为2M。但对于4米左右的1倍,6米左右的2倍,等等。
发布于 2016-03-29 15:33:46
我的猜测是,TextCNN构造函数和train()方法正在向默认图(tf.get_default_graph())中添加节点,而保存的模型包含所有以前的图,因此它是“意外二次”的,并且随着__main__循环的每一次迭代而增长。
幸运的是,解决办法很简单。只需按以下方式重写主循环:
if __name__ == "__main__":
for i in range(CV_SIZE):
with tf.Graph().as_default(): # Performs training in a new, empty graph.
cnn = TextCNN(i)
cnn.train()这将为循环的每一次迭代创建一个新的空图。因此,保存的模型将不包含上一次迭代中的节点(和变量),模型大小应该保持不变。
注意,如果可能的话,您应该尝试在所有迭代中重用相同的图。但是,我意识到,如果图的结构从一次迭代到下一次迭代,这可能是不可能的。
https://stackoverflow.com/questions/36279144
复制相似问题