首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >sess.run()不运行?

sess.run()不运行?
EN

Stack Overflow用户
提问于 2018-05-03 02:54:05
回答 1查看 1.3K关注 0票数 0

我是新来的,学习tensorflow,遇到了问题。

代码语言:javascript
复制
import model_method
fittt(model_method.build(self,...),...parameters...)

上面的内容在main.py导入model_method.py中。函数fittt在main.py中:

代码语言:javascript
复制
def fittt(model,...):
    model.fit(...)

build() in model_method.py:

代码语言:javascript
复制
def build(self,...):
    self.op_C,self.op_A = self.function_A(...)
    self.op_B = self.function_B(self.op_C,...)

fit() in model_method.py:

代码语言:javascript
复制
def fit(self,...):
    sess = tf.Session(graph=self.graph,config=config)
    BB,AA = sess.run([self.op_B,self.op_A],feed_dict)

为了检查运行过程,我在function_A()function_B()的开头model_method.py中添加了pdb.set_trace(),如下所示:

代码语言:javascript
复制
def function_A(self,...):
    pdb.set_trace()
    ......

def function_B(self,...):
    pdb.set_trace()
    ......

两个pdb.set_trace()只在构建()调用时停止,而在sess.run(self.op_B、self.op_A、feed_dict)调用时不工作。这意味着sess.run() 实际上没有运行 function_A()和function_B()。我想知道为什么,也想知道如何使这两个功能发挥作用?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-05-03 05:07:46

通过调用model_method.build()函数,您可以创建一个计算图。在这个调用中,每一行代码都被执行(因此pdb停止了)。

但是,tf.Session.run(...)只执行计算图中计算获取值所必需的部分(在您的示例中是self.op_Aself.op_B )。该函数不会再次执行整个build()函数。

因此,当您运行pdb.set_trace()时,sess.run(...)没有执行是因为它们不是有效的Tensor对象,因此不是计算图的一部分。

更新

请考虑以下几点:

代码语言:javascript
复制
class My_Model:

  def __init__(self):
      self.np_input = np.random.normal(size=(10,2)) # 10x2

  def build(self):
      self._in = tf.placeholder(dtype=tf.float32, shape=[10, None]) # matrix 10xN
      W_exception = tf.random_normal(dtype=tf.float32, shape=[3,3]) # matrix 3x3
      W_success = tf.random_normal(dtype=tf.float32, shape=[2,3]) # matrix 2x3
      self.op_exception = tf.matmul(self._in, W_exception) # [10x2] x [3x3] = ERROR
      self.op_success = tf.matmul(self._in, W_success) # [10x2] x [2x3] = [10x3]
      print('Computational Graph Built')

  def fit_success(self):
      with tf.Session() as sess:
          res = sess.run(self.op_success, feed_dict={self._in : self.np_input})
          print('Result shape: {}'.format(res.shape))

  def fit_exception(self):
      with tf.Session() as sess:
          res = sess.run(self.op_exception, feed_dict={self._in : self.np_input})
          print('Result shape: {}'.format(res.shape))

然后打电话:

代码语言:javascript
复制
m = My_Model()
m.build()
#> Computational Graph Built

m.fit_success()
#> Result shape: (10, 3)

m.fit_exception()
#> InvalidArgumentError: Matrix size-incompatible: In[0]: [10,2], In[1]: [3,3]

所以解释一下你在那里看到的。我们首先在build()函数中定义计算图。_in是我们的输入张量;None表示维数1是动态确定的-也就是说,一旦我们提供了一个具有指定值的张量。

然后我们定义了两个矩阵W_exceptionW_success,它们具有指定的所有维数,它们的值都是随机生成的。

然后我们定义了两个运算,矩阵乘法,每个运算返回一个张量。

我们调用了build()函数并创建了计算图,print()函数也被执行,但没有添加到图中。这里没有计算。事实上,它甚至不可能,因为没有指定_in的值。

现在,为了证明只需要计算所需的部分,我们调用fit_success()函数,它简单地将输入张量_in乘以W_success张量(具有正确的维数)。我们得到一个形状正确的张量: 10x3。请注意,我们没有收到由于尺寸不匹配而无法计算op_exception的错误。这是因为我们不需要它来评估op_success

最后,当我们试图用相同的输入张量计算op_exception时,我只是说明了异常确实会被抛出。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/50146239

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档