首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何允许文本输入到TensorFlow模型?

如何允许文本输入到TensorFlow模型?
EN

Stack Overflow用户
提问于 2020-01-04 01:05:02
回答 1查看 317关注 0票数 1

我正在用TensorFlow开发一个自定义的文本分类模型,现在我想用TensorFlow serving来设置它,以便进行生产部署。该模型基于通过单独的模型计算的文本嵌入进行预测,该模型要求将原始文本编码为向量。

我现在以一种有点脱节的方式工作,一个服务完成所有的文本预处理,然后计算嵌入,然后将嵌入作为嵌入的文本向量发送到文本分类器。如果我们能将所有这些都捆绑到一个TensorFlow服务模型中,尤其是最初的文本预处理步骤,那就太好了。

这就是我被困住的地方。如何构造一个作为原始文本输入的张量(或其他TensorFlow原语)?您是否需要做一些特殊的事情来标记标记向量组件映射的查找表,以便将其保存为模型包的一部分?

作为参考,这里是我现在所拥有的大致情况:

代码语言:javascript
复制
input = tf.placeholder(tf.float32, [None, 510], name='input')

# lots of steps omitted for brevity/clarity

outputs = tf.linalg.matmul(outputs, terminal_layer, transpose_b=True, name='output')

sess = tf.Session()
tf.saved_model.simple_save(sess,
                           'model.pb',
                           inputs={'input': input}, outputs={'output': outputs})
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-01-04 04:45:20

这被证明是相对简单的,这要归功于作为TensorFlow标准库的一部分的tf.lookup.StaticVocabularyTable

我的模型是使用词袋方法,而不是保持顺序,尽管这将是对代码的一个非常简单的更改。

假设您有一个对词汇表进行编码的list对象(我称之为vocab)和您想要使用的相应术语/标记嵌入矩阵(我称之为raw_term_embeddings,因为我将其强制到张量中),代码将如下所示:

代码语言:javascript
复制
initalizer = tf.lookup.KeyValueTensorInitializer(vocab, np.arange(len(vocab)))
lut = tf.lookup.StaticVocabularyTable(initalizer, 1) # the one here is the out of vocab size
lut.initializer.run(session=sess) # pushes the LUT onto the session

input = tf.placeholder(tf.string, [None, None], name='input')

ones_at = lut.lookup(input)
encoded_text = tf.math.reduce_sum(tf.one_hot(ones_at, tf.dtypes.cast(lut.size(), np.int32)), axis=0, keepdims=True)

# I didn't build an embedding for the out of vocabulary token
term_embeddings = tf.convert_to_tensor(np.vstack([raw_term_embeddings]), dtype=tf.float32)
embedded_text = tf.linalg.matmul(encoded_text, term_embeddings)

# then use embedded_text for the remainder of the model

其中一个小技巧是确保在加载模型时将legacy_init_op=tf.tables_initializer()传递给save函数,以提示TensorFlow初始化文本编码的查找表。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/59582516

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档