我有这样一个类,用于查找给定句子的句子嵌入
class Embeddings:
def __init__(self):
self.embedding_model_url = config_obj.tf_model_url
self.embedding_model = hub.Module(self.embedding_model_url)
self.messages = tf.placeholder(dtype=tf.string, shape=[None])
self.output = self.embedding_model(self.messages)
# self.initialize_graph()
@staticmethod
def initialize_graph():
with tf.Session() as session:
session.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
def get_sentence_embeddings(self, sentence):
with tf.Session() as session:
session.run([tf.global_variables_initializer(), tf.tables_initializer()])
result = session.run(self.output, feed_dict={self.messages: [sentence]})
return result
if __name__ == '__main__':
sentence = "GoAir is waiving cancellation and change fees for Bhubaneswar, Kolkata and Ranchi flights for travel between May 2 and May 5, the airline said in a statement"
tf_object = Embeddings()
embeddings = tf_object.get_sentence_embeddings(sentence)
print(embeddings)这是一个独立的应用程序,但当我尝试将其与Flask集成时,如下所示
from sentence_embeddings import Embeddings
embedding_obj = Embeddings()
@app.route('/get-similar-claims', methods=['POST'])
def get_similar_claims():
params = request.get_json()
claim = params.get("claim", "")
num_results = params.get("num_results", 10)
t0 = time.time()
# claim_embeddings = ""
claim_embeddings = embedding_obj.get_sentence_embeddings(claim)
logger.info("Time taken to calculate sentence embeddings - {}".format(round(time.time() - t0, 4)))
return Response(json.dumps(claim_embeddings), mimetype='application/json')
if __name__ == '__main__':
app.run('0.0.0.0', 5001)它抛出一个错误
File "/Users/anuragsharma/claim_similarity/api/app.py", line 32, in get_similar_claims
claim_embeddings = embedding_obj.get_sentence_embeddings(claim)
File "/Users/anuragsharma/claim_similarity/api/sentence_embeddings.py", line 31, in get_sentence_embeddings
result = session.run(self.output, feed_dict={self.messages: tf.convert_to_tensor(sentence)})
File "/Users/anuragsharma/anaconda3/envs/similarity-search/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 929, in run
run_metadata_ptr)
File "/Users/anuragsharma/anaconda3/envs/similarity-search/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1095, in _run
'Cannot interpret feed_dict key as Tensor: ' + e.args[0])
TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder:0", shape=(?,), dtype=string) is not an element of this graph.
I0730 16:44:33.384036 123145584836608 _internal.py:122] 127.0.0.1 - - [30/Jul/2019 16:44:33] "POST /get-similar-claims HTTP/1.1" 500 -发布于 2019-07-30 22:36:04
问题是,由于flask中的线程,您正在失去tf图的上下文。当您加载模型时,您需要保存对tf图的引用,以便以后使用它。
def __init__():
self.model = load_model()
self.graph = tf.get_default_graph()
def predict():
with self.graph.as_default():
self.model.predict()https://stackoverflow.com/questions/57270982
复制相似问题