首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow模型使用Flask抛出此图的一个元素

Tensorflow模型使用Flask抛出此图的一个元素
EN

Stack Overflow用户
提问于 2019-07-30 19:50:19
回答 1查看 90关注 0票数 1

我有这样一个类,用于查找给定句子的句子嵌入

代码语言:javascript
复制
class Embeddings:
    def __init__(self):
        self.embedding_model_url = config_obj.tf_model_url
        self.embedding_model = hub.Module(self.embedding_model_url)
        self.messages = tf.placeholder(dtype=tf.string, shape=[None])
        self.output = self.embedding_model(self.messages)
        # self.initialize_graph()

    @staticmethod
    def initialize_graph():
        with tf.Session() as session:
            session.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

    def get_sentence_embeddings(self, sentence):

        with tf.Session() as session:
            session.run([tf.global_variables_initializer(), tf.tables_initializer()])
            result = session.run(self.output, feed_dict={self.messages: [sentence]})

        return result


if __name__ == '__main__':
    sentence = "GoAir is waiving cancellation and change fees for Bhubaneswar, Kolkata and Ranchi flights for travel between May 2 and May 5, the airline said in a statement"
    tf_object = Embeddings()
    embeddings = tf_object.get_sentence_embeddings(sentence)
    print(embeddings)

这是一个独立的应用程序,但当我尝试将其与Flask集成时,如下所示

代码语言:javascript
复制
from sentence_embeddings import Embeddings

embedding_obj = Embeddings()
@app.route('/get-similar-claims', methods=['POST'])
def get_similar_claims():
    params = request.get_json()
    claim = params.get("claim", "")
    num_results = params.get("num_results", 10)
    t0 = time.time()
    # claim_embeddings = ""
    claim_embeddings = embedding_obj.get_sentence_embeddings(claim)
    logger.info("Time taken to calculate sentence embeddings - {}".format(round(time.time() - t0, 4)))
    return Response(json.dumps(claim_embeddings), mimetype='application/json')

if __name__ == '__main__':
    app.run('0.0.0.0', 5001)

它抛出一个错误

代码语言:javascript
复制
File "/Users/anuragsharma/claim_similarity/api/app.py", line 32, in get_similar_claims
    claim_embeddings = embedding_obj.get_sentence_embeddings(claim)
  File "/Users/anuragsharma/claim_similarity/api/sentence_embeddings.py", line 31, in get_sentence_embeddings
    result = session.run(self.output, feed_dict={self.messages: tf.convert_to_tensor(sentence)})
  File "/Users/anuragsharma/anaconda3/envs/similarity-search/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 929, in run
    run_metadata_ptr)
  File "/Users/anuragsharma/anaconda3/envs/similarity-search/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1095, in _run
    'Cannot interpret feed_dict key as Tensor: ' + e.args[0])
TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder:0", shape=(?,), dtype=string) is not an element of this graph.
I0730 16:44:33.384036 123145584836608 _internal.py:122] 127.0.0.1 - - [30/Jul/2019 16:44:33] "POST /get-similar-claims HTTP/1.1" 500 -
EN

回答 1

Stack Overflow用户

发布于 2019-07-30 22:36:04

问题是,由于flask中的线程,您正在失去tf图的上下文。当您加载模型时,您需要保存对tf图的引用,以便以后使用它。

代码语言:javascript
复制
def __init__():
    self.model = load_model()
    self.graph = tf.get_default_graph()

def predict():
    with self.graph.as_default():
        self.model.predict()
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/57270982

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档