首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >我是否在我的tensorflow代码中完全重置了计算图?

我是否在我的tensorflow代码中完全重置了计算图?
EN

Stack Overflow用户
提问于 2016-04-28 14:50:42
回答 1查看 130关注 0票数 2

我正在试着做一些实验。在每一次小批量之后,我都会尝试重新构建计算图。但我有一种感觉,那就是有些问题。当我为第一个小批处理植入W1、W2、W3的初始值时,我希望得到更新。然而,从第二个小批量开始,我没有得到我期望的更新。是否有可能在每次迭代时检查计算图是什么样子的?

代码语言:javascript
复制
import tensorflow as tf
import numpy as np
bsize = 5
Xset = np.random.uniform(0,1,(60000,6*20)) * 50
Yset = Xset[:,0]

Wone = np.random.normal(0, .35, (6,6))
Wtwo = np.random.normal(0, .35, (6,6))
Wthree = np.random.normal(0, .35, (6,6))

Results = []

for q in range(1):
    for k in range(40):
        from tensorflow.python.framework import ops
        ops.reset_default_graph()
        tf.reset_default_graph()
        tf.InteractiveSession()
        x1 = tf.placeholder(tf.float32, shape=(bsize,6*20))
        y = tf.placeholder(tf.float32, shape=(bsize,1))
        x = tf.reshape(x1,[bsize,6,20])
        InitialState = tf.zeros((6,bsize))
        h = InitialState
        W1 = tf.Variable(tf.convert_to_tensor(Wone,dtype = tf.float32),name = "W1")
        W2 = tf.Variable(tf.convert_to_tensor(Wtwo,dtype = tf.float32),name = "W2")
        W3 = tf.Variable(tf.convert_to_tensor(Wthree,dtype = tf.float32),name = "W3")


#create list
        lis = []
        for q in range(10):
            pit = np.random.uniform(-1,1)
            #print pit

            if(pit<0) or q == 0 or pit==0 or pit > 0:
                lis.append(q)

        for p in lis:
            h = tf.matmul(W1,h) + tf.matmul(W2,tf.transpose(x[:,:,p]))
            h = tf.nn.relu(h)

        hstar = h
        output = tf.matmul(W3,hstar)
        output1 = output[0:1,:]

        loss = tf.reduce_sum(tf.sub(tf.transpose(output1) ,y)*tf.sub(tf.transpose(output1) ,y))

        opt = tf.train.AdamOptimizer()
        opt_operation = opt.minimize(loss)

        for h in range(1):
            with tf.Session() as sess:
                sess.run(tf.initialize_all_variables())
                a,b,RLoss,_ = sess.run([hstar,output,loss,opt_operation], feed_dict = {x1:Xset[(bsize*k):(bsize*k+bsize),:],y:Yset[bsize*k:k*bsize+bsize,None]})

                print RLoss, k
EN

回答 1

Stack Overflow用户

发布于 2016-04-29 02:34:19

仅仅是tf.reset_default_graph()就足够了。您可以通过检查tf.get_default_graph().as_graph_def()来查看图形的外观,它是tensorflow.GraphDef模式实现的here的一个协议

特别是,要获取图中节点的所有名称,您可以这样做

代码语言:javascript
复制
[n.name for n in tf.get_default_graph().as_graph_def().node]
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/36907288

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档