首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tensorflow 1 Session.run使用通用语句编码器嵌入语句花费了太多时间

tensorflow 1 Session.run使用通用语句编码器嵌入语句花费了太多时间
EN

Stack Overflow用户
提问于 2020-07-14 07:25:28
回答 2查看 272关注 0票数 1

与烧瓶REST 一起使用张量流

如何减少session.run 时间?

我在REST中使用tf 1/2,而不是在服务器上使用它,而不是服务它。

我试过tensorflow 1和2。

tensorflow 1花了太多时间。

tensorflow 2甚至没有返回文本的向量。

in tensorflow 1初始化需要2-4秒,session.run需要5-8秒.随着我不断地满足要求,时间也在增加。

tensorflow 1

代码语言:javascript
复制
import tensorflow.compat.v1 as tfo
import tensorflow_hub as hub
tfo.disable_eager_execution()

module_url = "https://tfhub.dev/google/universal-sentence-encoder-qa/3"
# Import the Universal Sentence Encoder's TF Hub module
embed = hub.Module(module_url)

def convert_text_to_vector(text):
    # Compute a representation for each message, showing various lengths supported.
    try:
        #text = "qwerty" or ["qwerty"]
        if isinstance(text, str):
            text = [text]
        with tfo.Session() as session:
            t_time = time.time()
            session.run([tfo.global_variables_initializer(), tfo.tables_initializer()])
            m_time = time.time()
            message_embeddings = session.run(embed(text))
            vector_array = message_embeddings.tolist()[0]
        return vector_array
    except Exception as err:
        raise Exception(str(err))

tensorflow 2

它被困在vector_array = embedding_fn(text)

代码语言:javascript
复制
import tensorflow as tf
import tensorflow_hub as hub
module_url = "https://tfhub.dev/google/universal-sentence-encoder-qa/3"
embedding_fn = hub.load(module_url)

@tf.function
def convert_text_to_vector(text):
    try:
        #text = ["qwerty"]
        vector_array = embedding_fn(text)
        return vector_array
    except Exception as err:
        raise Exception(str(err))
EN

回答 2

Stack Overflow用户

发布于 2020-07-14 09:37:14

对于tensorflow 2版本,我做了很少的修正。基本上,我遵循了您提供的通用句子编码器中的示例。

代码语言:javascript
复制
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
module_url = "https://tfhub.dev/google/universal-sentence-encoder-qa/3"
embedding_fn = hub.load(module_url)

@tf.function
def convert_text_to_vector(text):
  try:
      vector_array = embedding_fn.signatures['question_encoder'](
          tf.constant(text))
      return vector_array['outputs']
  except Exception as err:
      raise Exception(str(err))

### run the function
vector = convert_text_to_vector(['is this helpful ?'])
print(vector.shape())
票数 0
EN

Stack Overflow用户

发布于 2020-07-14 10:16:34

代码语言:javascript
复制
from flask import Flask
from flask_restplus import Api, Resource
from werkzeug.utils import cached_property

import tensorflow as tf
import tensorflow_hub as hub
module_url = "https://tfhub.dev/google/universal-sentence-encoder-qa/3"
embedding_fn = hub.load(module_url)


app = Flask(__name__)

@app.route('/embedding', methods=['POST'])
def entry_point(args):
    if args.get("text"):
        text_term = args.get("text")
        if isinstance(text_term, str):
            text_term = [text_term]
        vectors = convert_text_to_vector(text_term)
    return vectors



@tf.function
def convert_text_to_vector(text):
    try:
        vector_array = embedding_fn.signatures['question_encoder'](tf.constant(text))
        return vector_array['outputs']
    except Exception as err:
        raise Exception(str(err))


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, debug=True)

"""
 ----- Requirements.txt ----
flask-restplus==0.13.0
Flask==1.1.1
Werkzeug==0.15.5
tensorboard==2.2.2
tensorboard-plugin-wit==1.6.0.post3
tensorflow==2.2.0
tensorflow-estimator==2.2.0
tensorflow-hub==0.8.0
tensorflow-text==2.2.1
"""
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62890024

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档