为了得到我的预训练模型的预测/输出;model为卷积图像的每一帧(列)预测一个符号,并且有必要对logits (RNN的输出)进行后处理,以发出实际的预测符号序列。构建模型的代码可以在here中找到。
logits = graph.get_tensor_by_name("fully_connected/BiasAdd:0")
decoded, _ = tf.nn.ctc_greedy_decoder(logits, seq_len)
prediction = sess.run(decoded,
feed_dict={
input: image,
seq_len: seq_lengths,
rnn_keep_prob: 1.0,
})Prediction是一个包含所有预测符号的SparseTensorValue。Decoded是非空张量的稀疏张量。最后,我解析生成的SparseTensorValue以获得所需的字符串。
我想通过tensorflow serving或tflite使用这个经过训练的模型进行推理,但是为了继续,我需要指出模型的输出节点。考虑到稀疏张量的性质,我无法通过名称来表示它。有没有一种方法可以让我使用这个模型进行正确的推理?
我见过许多以类似的方式使用this等ctc解码器进行预测的示例,但是,在不紧密依赖tensorflow api的情况下,没有使用这些模型进行推理的示例,我不确定如何继续。
发布于 2021-02-16 09:17:57
可以将模型保存为tf saved_model格式。之后,您可以使用tensorflow-serving-api包的CLI工具saved_model_cli通过:saved_model_cli show --dir . --all检查所有模型签名。有了它,您将看到输入和输出形状的所有信息。默认签名称为default_serving。
https://stackoverflow.com/questions/65776162
复制相似问题