I am new to TensorFlow and could not find solutions to the following questions.
How can I retrain the im2txt model so that the dataset it was trained on is not lost, adding my new dataset to the MSCOCO dataset to caption new images (i.e. training dataset = MSCOCO dataset + my new dataset)? Can someone share the detailed procedure and the problems I may face while retraining? Also, following the TensorFlow tutorial for the im2txt model, can the model be deployed with TensorFlow on mobile, i.e. can it caption images on a phone in real time? Can someone share the detailed steps on how to do this? Posted on 2018-04-04 06:14:04
After a few weeks of struggle, I was able to run and execute the im2txt model on Android. Since I found the solutions across different blogs and different questions and answers, I felt it might be useful to share the steps below so that all (or most) of the solutions are in one place.
You need to clone the TensorFlow project from https://github.com/tensorflow/tensorflow/releases/tag/v1.5.0 for freezing the graph and for other utilities.
Download the im2txt model from https://github.com/KranthiGV/Pretrained-Show-and-Tell-model. By following the steps described in that link, you can successfully run inference on a Linux desktop and generate captions, after renaming some variables in the graph to overcome a NotFoundError of the kind "Key lstm/basic_lstm_cell/bias not found in checkpoint".
Now we need to freeze the existing model to obtain a frozen graph that can be used in Android/iOS.
In the cloned TensorFlow project, using freeze_graph.py (tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py), you can freeze the graph of any model by issuing the following commands. An example of command-line usage:
bazel build tensorflow/python/tools:freeze_graph && \
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=some_graph_def.pb \
--input_checkpoint=model.ckpt-8361242 \
--output_graph=/tmp/frozen_graph.pb --output_node_names=softmax \
--input_binary=true
We need to provide all the output_node_names required to run the model. From "How to freeze an im2txt model?" the output node names are 'softmax', 'lstm/initial_state' and 'lstm/state'. When I ran the freeze-graph command with those output node names, I got the error "AssertionError: softmax is not in graph".
From Steph and Jeff's answers on how to freeze an im2txt model (How to freeze an im2txt model?):
The current model checkpoint (the ckpt.data, ckpt.index and ckpt.meta files along with graph.pbtxt) should be loaded in inference mode (see InferenceWrapper in im2txt). It builds a graph with the correct names 'softmax', 'lstm/initial_state' and 'lstm/state'. Save this graph (with the same ckpt format), and then you can apply the freeze_graph script to obtain the frozen model.
To do this in Pretrained-Show-and-Tell-model\im2txt\im2txt\inference_utils\inference_wrapper.base.py, just add something like saver.save(sess, "model/ckpt4") after saver.restore(sess, checkpoint_path) in def _restore_fn(sess):. Then rebuild and run_inference, and you will get a model that can be frozen, transformed, and optionally memmapped, to be loaded by iOS and Android apps.
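The patched _restore_fn can be sketched as follows. This is an illustration only: StubSaver and the checkpoint path are stand-ins for TensorFlow's tf.train.Saver and the real checkpoint file, so the point (save right after restore, inside _restore_fn) is visible without running TensorFlow.

```python
# Sketch of the one-line patch to _restore_fn in inference_wrapper_base.py.
# A stub saver stands in for tf.train.Saver so the call ordering is visible.

class StubSaver:
    """Records restore/save calls in order."""
    def __init__(self):
        self.calls = []

    def restore(self, sess, path):
        self.calls.append(("restore", path))

    def save(self, sess, path):
        self.calls.append(("save", path))


def make_restore_fn(saver, checkpoint_path):
    def _restore_fn(sess):
        saver.restore(sess, checkpoint_path)
        # The added line: re-save the inference-mode graph so that
        # freeze_graph can find 'softmax', 'lstm/initial_state'
        # and 'lstm/state' in the new checkpoint.
        saver.save(sess, "model/ckpt4")
    return _restore_fn


saver = StubSaver()
make_restore_fn(saver, "model.ckpt-stub")(sess=None)
print(saver.calls)  # restore runs first, then the save that writes ckpt4
```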
Now I ran the command below:
python tensorflow/python/tools/freeze_graph.py \
--input_meta_graph=/tmp/ckpt4.meta \
--input_checkpoint=/tmp/ckpt4 \
--output_graph=/tmp/ckpt4_frozen.pb \
--output_node_names="softmax,lstm/initial_state,lstm/state" \
--input_binary=true
I loaded the obtained ckpt4_frozen.pb file in the Android app and got the error "java.lang.IllegalArgumentException: No OpKernel was registered to support Op 'DecodeJpeg'. Registered devices: CPU, Registered kernels: [Node: decode/DecodeJpeg = DecodeJpeg[acceptable_fraction=1, channels=3, dct_method="", fancy_upscaling=true, ratio=1, try_recover_truncated=false]]".
From https://github.com/tensorflow/tensorflow/issues/2883:
Since DecodeJpeg is not supported as part of the core, you need to strip it out of the graph first:
bazel build tensorflow/python/tools:strip_unused && \
bazel-bin/tensorflow/python/tools/strip_unused \
--input_graph=ckpt4_frozen.pb \
--output_graph=ckpt4_frozen_stripped_graph.pb \
--input_node_names=convert_image/Cast,input_feed,lstm/state_feed \
--output_node_names=softmax,lstm/initial_state,lstm/state \
--input_binary=true
When I tried to load ckpt4_frozen_stripped_graph.pb in Android I ran into errors, so I followed Jeff's answer (inference.py on frozen graph) and used the Graph Transform Tool instead of tools:strip_unused:
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=/tmp/ckpt4_frozen.pb \
--out_graph=/tmp/ckpt4_frozen_transformed.pb \
--inputs="convert_image/Cast,input_feed,lstm/state_feed" \
--outputs="softmax,lstm/initial_state,lstm/state" \
--transforms='
strip_unused_nodes(type=float, shape="1,299,299,3")
fold_constants(ignore_errors=true)
fold_batch_norms
fold_old_batch_norms'
I could successfully load the obtained ckpt4_frozen_transformed.pb on Android. When I provide the input node with a float array of RGB image pixels, I can successfully fetch the output of the "lstm/initial_state" node.
Now the challenge is to understand the beam search in "Pretrained-Show-and-Tell-model\im2txt\im2txt\inference_utils\caption_generator.py" and implement the same on the Android side.
If you observe the Python script caption_generator.py, at
softmax, new_states, metadata = self.model.inference_step(sess, input_feed, state_feed)
input_feed is an int32 array and state_feed is a multidimensional float array.
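The beam search in caption_generator.py boils down to the following pure-Python sketch. This is an illustration, not the im2txt code: inference_step is replaced by a stub returning made-up next-word probabilities, and the Caption/TopN classes are reduced to tuples and lists, so only the control flow (expand each partial caption by the beam_size best words, move finished sentences to the complete list) is shown.

```python
import math
import heapq

# Pure-Python sketch of the beam search in im2txt's caption_generator.py.
# The real inference_step runs the LSTM on the frozen graph; here a stub
# returns fixed next-word probabilities so the control flow is testable.

START_ID, END_ID = 1, 2   # ids of <S> and </S> in a toy vocabulary
VOCAB_SIZE = 5
MAX_CAPTION_LENGTH = 20
BEAM_SIZE = 2

def inference_step(input_feed, state_feed):
    """Stub for model.inference_step: one probability row per caption."""
    probs = []
    for w in input_feed:
        row = [0.05] * VOCAB_SIZE
        row[3] = 0.5 if w != 3 else 0.1       # word 3 is likely at first
        row[END_ID] = 0.6 if w == 3 else 0.1  # then the sentence ends
        probs.append(row)
    return probs, list(state_feed)            # state unchanged in the stub

def beam_search():
    # A partial caption is (sentence, lstm_state, logprob).
    partial = [([START_ID], "state0", 0.0)]
    complete = []
    for _ in range(MAX_CAPTION_LENGTH):
        input_feed = [sent[-1] for sent, _, _ in partial]
        state_feed = [state for _, state, _ in partial]
        probs, new_states = inference_step(input_feed, state_feed)
        candidates = []
        for (sent, _, logprob), row, state in zip(partial, probs, new_states):
            # For this partial caption, take the BEAM_SIZE most probable words.
            for w in heapq.nlargest(BEAM_SIZE, range(VOCAB_SIZE),
                                    key=row.__getitem__):
                if row[w] < 1e-12:
                    continue  # avoid log(0)
                candidates.append((sent + [w], state,
                                   logprob + math.log(row[w])))
        partial = []
        for cand in sorted(candidates, key=lambda c: c[2], reverse=True):
            if cand[0][-1] == END_ID:
                complete.append(cand)        # sentence finished
            elif len(partial) < BEAM_SIZE:
                partial.append(cand)         # keep searching from it
        if not partial:
            break  # ran out of partial candidates (happens when beam_size = 1)
    return complete or partial

best = max(beam_search(), key=lambda c: c[2])
print(best[0])  # → [1, 3, 2]
```

The Java code later in this answer follows the same structure, with TopN holding the partial and complete captions.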
On the Android side, I tried feeding an int array to "input_feed"; since there is no Java API to feed a multidimensional array, I fed a flat float array to lstm/state_feed, as fetched previously from the "lstm/initial_state" node.
I got two errors: one that input_feed expects int64, and "java.lang.IllegalArgumentException: -input rank(-1) <= split_dim < input rank (1), but got 1" at lstm/state_feed.
For the first error, I changed the input_feed data type from int32 to int64.
For the second error, the node expects a rank-2 tensor. If you look at the TensorFlow Java source, the float array we feed is converted into a rank-1 tensor; we should feed the data in such a way that a rank-2 tensor is created. But at the time, browsing the TensorFlow Java source, I could not find any public Android API to feed a multidimensional float array; there is an API that can create a rank-2 tensor, but it is not exposed to the Android API. So I rebuilt libtensorflow_inference.so and libandroid_tensorflow_inference_java.jar with the rank-2 tensor creation call enabled. (For the build process, refer to https://blog.mindorks.com/android-tensorflow-machine-learning-example-ff0e9b2654cc)
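Conceptually, the rank-2 requirement comes down to pairing the flat float buffer with explicit dims such as {batch, state_size}, so the native side can rebuild the 2-D tensor from a row-major layout. A minimal Python sketch of that layout (an illustration only, not the TensorFlow Java API; the state values and sizes are made up):

```python
# Row-major layout of a [batch, state_size] LSTM state: the flat float
# buffer plus explicit dims is enough to rebuild the rank-2 tensor.

def flatten(state_2d):
    """[batch][state_size] -> (flat row-major list, shape)."""
    batch, state_size = len(state_2d), len(state_2d[0])
    flat = [v for row in state_2d for v in row]
    return flat, (batch, state_size)

def unflatten(flat, shape):
    """Rebuild the rank-2 view from the flat buffer and its dims."""
    batch, state_size = shape
    return [flat[i * state_size:(i + 1) * state_size] for i in range(batch)]

# Two partial captions, toy state size of 3.
state = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
flat, shape = flatten(state)
print(flat, shape)   # the row-major buffer and the dims fed alongside it
assert unflatten(flat, shape) == state
```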
Now I can run inference on Android and get one caption for an image, but the accuracy is very low. The reason for getting only one caption is that I did not find a way to fetch the output as a multidimensional array, which is required to generate multiple captions for a single image.
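In principle, the flat fetch buffer can still be split per caption with a vocab-size stride, since a [batch, vocab_size] softmax tensor is laid out row-major. A short Python sketch of the idea (the toy vocabulary size and probability values are made up):

```python
# A fetched [batch, vocab_size] softmax tensor arrives as one flat
# row-major buffer; each partial caption's word probabilities are a
# contiguous vocab_size-long slice of it.

def split_by_caption(flat_softmax, vocab_size):
    """Split the flat fetch buffer into one probability row per caption."""
    assert len(flat_softmax) % vocab_size == 0
    return [flat_softmax[i:i + vocab_size]
            for i in range(0, len(flat_softmax), vocab_size)]

VOCAB = 4  # toy vocabulary size
buf = [0.1, 0.2, 0.3, 0.4,   # caption 0
       0.4, 0.3, 0.2, 0.1]   # caption 1
rows = split_by_caption(buf, VOCAB)
print(rows[1])  # → [0.4, 0.3, 0.2, 0.1]
```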
String actualFilename = labelFilename.split("file:///android_asset/")[1];
vocab = new Vocabulary(assetManager.open(actualFilename));
inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
final Graph g = c.inferenceInterface.graph();
final Operation inputOperation = g.operation(inputName);
if (inputOperation == null) {
throw new RuntimeException("Failed to find input Node '" + inputName + "'");
}
final Operation outPutOperation = g.operation(outputName);
if (outPutOperation == null) {
throw new RuntimeException("Failed to find output Node '" + outputName + "'");
}
// The shape of the output is [N, NUM_CLASSES], where N is the batch size.
int numClasses = (int) inferenceInterface.graph().operation(outputName)
.output(0).shape().size(1);
Log.i(TAG, "Read " + vocab.totalWords() + " labels, output layer size is " + numClasses);
// Ideally, inputSize could have been retrieved from the shape of the input operation. Alas,
// the placeholder node for input in the graphdef typically used does not specify a shape, so it
// must be passed in as a parameter.
this.inputSize = inputSize;
// Pre-allocate buffers.
outputNames = new String[]{outputName + ":0"};
outputs = new float[numClasses];
inferenceInterface.feed(inputName + ":0", pixels, inputSize, inputSize, 3);
inferenceInterface.run(outputNames, runStats);
inferenceInterface.fetch(outputName + ":0", outputs);
startIm2txtBeamSearch(outputs);//beam search implemented in Java
private void startIm2txtBeamSearch(float[] outputs) {
int beam_size = 1;
//TODO:Prepare vocab ids from file
ArrayList<Integer> vocab_ids = new ArrayList<>();
vocab_ids.add(1);
int vocab_end_id = 2;
float length_normalization_factor = 0;
int maxCaptionLength = 20;
Graph g = inferenceInterface.graph();
//node input feed
String input_feed_node_name = "input_feed";
Operation inputOperation = g.operation(input_feed_node_name);
if (inputOperation == null) {
throw new RuntimeException("Failed to find input Node '" + input_feed_node_name + "'");
}
String output_feed_node_name = "softmax";
Operation outPutOperation = g.operation(output_feed_node_name);
if (outPutOperation == null) {
throw new RuntimeException("Failed to find output Node '" + output_feed_node_name + "'");
}
int output_feed_node_numClasses = (int) outPutOperation.output(0).shape().size(1);
Log.i(TAG, "Output layer " + output_feed_node_name + ", output layer size is " + output_feed_node_numClasses);
FloatBuffer output_feed_output = FloatBuffer.allocate(output_feed_node_numClasses);
//float [][] output_feed_output = new float[numClasses][];
//node state feed
String input_state_feed_node_name = "lstm/state_feed";
inputOperation = g.operation(input_state_feed_node_name);
if (inputOperation == null) {
throw new RuntimeException("Failed to find input Node '" + input_state_feed_node_name + "'");
}
String output_state_feed_node_name = "lstm/state";
outPutOperation = g.operation(output_state_feed_node_name);
if (outPutOperation == null) {
throw new RuntimeException("Failed to find output Node '" + output_state_feed_node_name + "'");
}
int output_state_feed_node_numClasses = (int) outPutOperation.output(0).shape().size(1);
Log.i(TAG, "Output layer " + output_state_feed_node_name + ", output layer size is " + output_state_feed_node_numClasses);
FloatBuffer output_state_output = FloatBuffer.allocate(output_state_feed_node_numClasses);
//float[][] output_state_output= new float[numClasses][];
String[] output_nodes = new String[]{output_feed_node_name, output_state_feed_node_name};
Caption initialBeam = new Caption(vocab_ids, outputs, (float) 0.0, (float) 0.0);
TopN partialCaptions = new TopN(beam_size);
partialCaptions.push(initialBeam);
TopN completeCaption = new TopN(beam_size);
captionLengthLoop:
for (int i = maxCaptionLength; i >= 0; i--) {
List<Caption> partialCaptionsList = new LinkedList<>(partialCaptions.extract(false));
partialCaptions.reset();
long[] input_feed = new long[partialCaptionsList.size()];
float[][] state_feed = new float[partialCaptionsList.size()][];
for (int j = 0; j < partialCaptionsList.size(); j++) {
Caption curCaption = partialCaptionsList.get(j);
ArrayList<Integer> senArray = curCaption.getSentence();
input_feed[j] = senArray.get(senArray.size() - 1);
state_feed[j] = curCaption.getState();
}
//feeding
inferenceInterface.feed(input_feed_node_name, input_feed, new long[]{input_feed.length});
inferenceInterface.feed(input_state_feed_node_name, state_feed, new long[]{state_feed.length});
//run
inferenceInterface.run(output_nodes, runStats);
//fetching
inferenceInterface.fetch(output_feed_node_name, output_feed_output);
inferenceInterface.fetch(output_state_feed_node_name, output_state_output);
float[] word_probabilities = new float[partialCaptionsList.size()];
float[] new_state = new float[partialCaptionsList.size()];
for (int k = 0; k < partialCaptionsList.size(); k++) {
word_probabilities = output_feed_output.array();
//output_feed_output.get(word_probabilities[k]);
new_state = output_state_output.array();
//output_feed_output.get(state[k]);
// For this partial caption, get the beam_size most probable next words.
Map<Integer, Float> word_and_probs = new LinkedHashMap<>();
//key: word id (index into the softmax row); value: its probability
for (int l = 0; l < word_probabilities.length; l++) {
word_and_probs.put(l, word_probabilities[l]);
}
//sorting
// word_and_probs = word_and_probs.entrySet().stream()
// .sorted(Map.Entry.comparingByValue())
// .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue,(e1, e2) -> e1, LinkedHashMap::new));
word_and_probs = MapUtil.sortByValue(word_and_probs);
//considering first (beam size probabilities)
LinkedHashMap<Integer, Float> final_word_and_probs = new LinkedHashMap<>();
for (int key : word_and_probs.keySet()) {
final_word_and_probs.put(key, word_and_probs.get(key));
if (final_word_and_probs.size() == beam_size)
break;
}
for (int w : final_word_and_probs.keySet()) {
float p = final_word_and_probs.get(w);
if (p < 1e-12) {// Avoid log(0).
Log.d(TAG, "p is < 1e-12");
continue;
}
Caption partialCaption = partialCaptionsList.get(k);
ArrayList<Integer> sentence = new ArrayList<>(partialCaption.getSentence());
sentence.add(w);
float logprob = (float) (partialCaption.getPorb() + Math.log(p));
float score = logprob;
Caption beam = new Caption(sentence, new_state, logprob, score);
if (w == vocab_end_id) {
completeCaption.push(beam);
} else {
partialCaptions.push(beam);
}
}
if (partialCaptions.getSize() == 0)//run out of partial candidates; happens when beam_size = 1.
break captionLengthLoop;
}
//clear buffers to retrieve the subsequent outputs
output_feed_output.clear();
output_state_output.clear();
output_feed_output = null;
output_state_output = null;
output_feed_output = FloatBuffer.allocate(output_feed_node_numClasses);
output_state_output = FloatBuffer.allocate(output_state_feed_node_numClasses);
Log.d(TAG, "----" + i + " Iteration completed----");
}
Log.d(TAG, "----Total Iterations completed----");
LinkedList<Caption> completeCaptions = completeCaption.extract(true);
for (Caption cap : completeCaptions) {
ArrayList<Integer> wordids = cap.getSentence();
StringBuffer caption = new StringBuffer();
boolean isFirst = true;
for (int word : wordids) {
if (!isFirst)
caption.append(" ");
caption.append(vocab.getWord(word));
isFirst = false;
}
Log.d(TAG, "Cap score = " + Math.exp(cap.getScore()) + " and Caption is " + caption);
}
}//Vocabulary
public class Vocabulary {
String TAG = Vocabulary.class.getSimpleName();
String start_word = "<S>", end_word = "</S>", unk_word = "<UNK>";
ArrayList<String> words;
public Vocabulary(File vocab_file) {
loadVocabsFromFile(vocab_file);
}
public Vocabulary(InputStream vocab_file_stream) {
words = readLinesFromFileAndLoadWords(new InputStreamReader(vocab_file_stream));
}
public Vocabulary(String vocab_file_path) {
File vocabFile = new File(vocab_file_path);
loadVocabsFromFile(vocabFile);
}
private void loadVocabsFromFile(File vocabFile) {
try {
this.words = readLinesFromFileAndLoadWords(new FileReader(vocabFile));
//Log.d(TAG, "Words read from file = " + words.size());
} catch (FileNotFoundException e) {
e.printStackTrace();
}
}
private ArrayList<String> readLinesFromFileAndLoadWords(InputStreamReader file_reader) {
ArrayList<String> words = new ArrayList<>();
try (BufferedReader br = new BufferedReader(file_reader)) {
String line;
while ((line = br.readLine()) != null) {
// process the line.
words.add(line.split(" ")[0].trim());
}
br.close();
if (!words.contains(unk_word))
words.add(unk_word);
} catch (IOException e) {
e.printStackTrace();
}
return words;
}
public String getWord(int word_id) {
if (words != null)
if (word_id >= 0 && word_id < words.size())
return words.get(word_id);
return "No word found, Maybe Vocab File not loaded";
}
public int totalWords() {
if (words != null)
return words.size();
return 0;
}
}//MapUtil
public class MapUtil {
public static <K, V extends Comparable<? super V>> Map<K, V> sortByValue(Map<K, V> map) {
List<Map.Entry<K, V>> list = new ArrayList<>(map.entrySet());
list.sort(new Comparator<Map.Entry<K, V>>() {
@Override
public int compare(Map.Entry<K, V> o1, Map.Entry<K, V> o2) {
if (o1.getValue() instanceof Float && o2.getValue() instanceof Float) {
Float o1Float = (Float) o1.getValue();
Float o2Float = (Float) o2.getValue();
return o1Float >= o2Float ? -1 : 1;
}
return 0;
}
});
Map<K, V> result = new LinkedHashMap<>();
for (Map.Entry<K, V> entry : list) {
result.put(entry.getKey(), entry.getValue());
}
return result;
}
}//Caption
public class Caption implements Comparable<Caption> {
private ArrayList<Integer> sentence;
private float[] state;
private float porb;
private float score;
public Caption(ArrayList<Integer> sentence, float[] state, float porb, float score) {
this.sentence = sentence;
this.state = state;
this.porb = porb;
this.score = score;
}
public ArrayList<Integer> getSentence() {
return sentence;
}
public void setSentence(ArrayList<Integer> sentence) {
this.sentence = sentence;
}
public float[] getState() {
return state;
}
public void setState(float[] state) {
this.state = state;
}
public float getPorb() {
return porb;
}
public void setPorb(float porb) {
this.porb = porb;
}
public float getScore() {
return score;
}
public void setScore(float score) {
this.score = score;
}
@Override
public int compareTo(@NonNull Caption oc) {
if (score == oc.score)
return 0;
if (score < oc.score)
return -1;
else
return 1;
}
}//TopN
public class TopN {
//Maintains the top n elements of an incrementally provided set.
int n;
LinkedList<Caption> data;
public TopN(int n) {
this.n = n;
this.data = new LinkedList<>();
}
public int getSize() {
if (data != null)
return data.size();
return 0;
}
//Pushes a new element
public void push(Caption x) {
if (data != null) {
if (getSize() < n) {
data.add(x);
} else {
data.removeLast();
data.add(x);
}
}
}
//Extracts all elements from the TopN. This is a destructive operation.
//The only method that can be called immediately after extract() is reset().
//Args:
//sort: Whether to return the elements in descending sorted order.
//Returns: A list of data; the top n elements provided to the set.
public LinkedList<Caption> extract(boolean sort) {
if (sort) {
Collections.sort(data);
}
return data;
}
//Returns the TopN to an empty state.
public void reset() {
if (data != null) data.clear();
}
}
Even though the accuracy is very low, I am sharing this because it may be useful for people who want to load the Show and Tell model in Android.
https://stackoverflow.com/questions/44141279