首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在tensorflow中使用两种不同的模型

在tensorflow中使用两种不同的模型
EN

Stack Overflow用户
提问于 2017-08-18 02:51:36
回答 2查看 1.5K关注 0票数 2

我试着使用两种不同的移动网络模型。下面是我如何初始化模型的代码。

代码语言:javascript
复制
def initialSetup():
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    start_time = timeit.default_timer()

    # This takes 2-5 seconds to run
    # Unpersists graph from file
    with tf.gfile.FastGFile('age/output_graph.pb', 'rb') as f:
        age_graph_def = tf.GraphDef()
        age_graph_def.ParseFromString(f.read())
        tf.import_graph_def(age_graph_def, name='')

    with tf.gfile.FastGFile('output_graph.pb', 'rb') as f:
        gender_graph_def = tf.GraphDef()
        gender_graph_def.ParseFromString(f.read())
        tf.import_graph_def(gender_graph_def, name='')

    print ('Took {} seconds to unpersist the graph'.format(timeit.default_timer() - start_time))

由于两者都是不同的模型,我如何使用它来预测呢?

更新

代码语言:javascript
复制
initialSetup()

age_session = tf.Session(graph=age_graph_def)
gender_session = tf.Session(graph=gender_graph_def)

with tf.Session() as sess:
    start_time = timeit.default_timer()

    # Feed the image_data as input to the graph and get first prediction
    softmax_tensor = age_session.graph.get_tensor_by_name('final_result:0')

    print ('Took {} seconds to feed data to graph'.format(timeit.default_timer() - start_time))

    while True:
        # Capture frame-by-frame
        ret, frame = video_capture.read()

误差

追溯(最近一次调用):文件"C:/Users/Desktop/untitled/testimg/testimg/combo.py",第48行,在age_session = tf.Session(graph=age_graph_def)文件"C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py",第1292行“中,在init超级(会话,self).init(target,图,文件“C:\ Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py",Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py",第529行,在init raise (‘图形必须是tf.Graph,但获得%s’%类型(图形)) TypeError:图必须是tf.Graph,但是在:> Traceback (最近调用的最后一次调用):文件”C:Program第587行中忽略了异常。在del中,如果self._session不是None: AttributeError:'Session‘对象没有属性'_session’

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2017-08-18 15:00:01

当您在同一个图中处理多个模型时,请使用名称作用域为单个张量提供可预测的名称。例如,您可以重写initial_setup()如下:

代码语言:javascript
复制
def initialSetup():
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    start_time = timeit.default_timer()

    # This takes 2-5 seconds to run
    # Unpersists graph from file
    with tf.gfile.FastGFile('age/output_graph.pb', 'rb') as f:
        age_graph_def = tf.GraphDef()
        age_graph_def.ParseFromString(f.read())
        tf.import_graph_def(age_graph_def, name='age_model')

    with tf.gfile.FastGFile('output_graph.pb', 'rb') as f:
        gender_graph_def = tf.GraphDef()
        gender_graph_def.ParseFromString(f.read())
        tf.import_graph_def(gender_graph_def, name='gender_model')

    print ('Took {} seconds to unpersist the graph'.format(timeit.default_timer() - start_time))

现在,来自age_graph_def的所有节点的名称将以"age_model/"作为前缀,而来自gender_graph_def的所有节点的名称将以"gender_model/"作为前缀。它们都是同一个默认图的一部分,因此您可以使用一个没有tf.Session参数的graph来访问这两个模型。

代码语言:javascript
复制
initialSetup()

with tf.Session() as sess:
    start_time = timeit.default_timer()

    # Feed the image_data as input to the graph and get first prediction
    softmax_tensor = sess.graph.get_tensor_by_name('age_model/final_result:0')

    # Alternatively, to get a tensor from the gender model:
    # tensor = sess.graph.get_tensor_by_name('gender_model/...')

    print ('Took {} seconds to feed data to graph'.format(timeit.default_timer() - start_time))

    while True:
        # Capture frame-by-frame
        ret, frame = video_capture.read()
票数 3
EN

Stack Overflow用户

发布于 2017-08-18 07:36:36

tf.Session需要一个tf.Graph实例而不是tf.GraphDef,下面的步骤修复了这个问题。

代码语言:javascript
复制
def initialSetup():
    with tf.gfile.FastGFile('age/output_graph.pb', 'rb') as f:
        age_graph_def = tf.GraphDef()
        age_graph_def.ParseFromString(f.read())
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(age_graph_def, name='')
            age_graph = graph

   ...
   return age_graph, gender_graph

age_graph, gender_graph = initial_setup() 
age_session = tf.Session(graph=age_graph)
...
# also delete the following line, as it creates another new context 
with tf.Session() as sess:
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45747769

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档