
Integrating the im2txt model on an Android phone

Stack Overflow user
Asked 2017-05-23 17:22:19
1 answer · 455 views · 0 followers · 0 votes

I am new to TensorFlow and could not find solutions to these questions:

  1. How can I retrain the im2txt model on my new dataset without losing the dataset the im2txt model was originally trained on, i.e. add my new dataset to the MSCOCO dataset so it can caption new images (training dataset = MSCOCO dataset + my new dataset)? Please share the detailed procedure and the problems I might face while retraining.
  2. I have found the TensorFlow tutorial for running the Inception model on Android against a real-time camera feed. Can this approach also be applied to the im2txt model, i.e. can it caption images on a mobile device in real time? Please share detailed steps on how to do this.

1 Answer

Stack Overflow user

Answered 2018-04-04 06:14:04

After a few weeks of struggle, I was able to run and execute the im2txt model on Android. Since I pieced the solution together from different blogs and different questions and answers, I felt it might be useful to have the whole (or at least most of the) solution in one place, so I am sharing the steps below.

You need to clone the TensorFlow project https://github.com/tensorflow/tensorflow/releases/tag/v1.5.0 for the freeze-graph and other utilities.

Download the im2txt model from https://github.com/KranthiGV/Pretrained-Show-and-Tell-model. By following the steps described at that link, you can successfully run inference and generate captions on a Linux desktop, after renaming some variables in the graph (to overcome "NotFoundError (see above for traceback): Key lstm/basic_lstm_cell/bias not found in checkpoint" kinds of errors).

Now we need to freeze the existing model to obtain a frozen graph that can be used in Android/iOS.

From the cloned TensorFlow project, using freeze_graph.py (tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py), one can freeze the graph from any model by supplying the following command. An example of command-line usage:

Code language: bash
bazel build tensorflow/python/tools:freeze_graph && \
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=some_graph_def.pb \
--input_checkpoint=model.ckpt-8361242 \
--output_graph=/tmp/frozen_graph.pb --output_node_names=softmax \
--input_binary=true

We need to supply all the output_node_names required to run the model. From the inference code we can list the output node names as 'softmax', 'lstm/initial_state' and 'lstm/state'. When I ran the freeze-graph command supplying those output node names, I got the error "AssertionError: softmax is not in graph".

From Steph's and Jeff's answers to How to freeze an im2txt model? (How to freeze an im2txt model?):

The current model checkpoint files (ckpt.data, ckpt.index and ckpt.meta), together with graph.pbtxt, should be loaded in inference mode (see InferenceWrapper in im2txt). That builds a graph with the correct names 'softmax', 'lstm/initial_state' and 'lstm/state'. Save this graph (with the same ckpt format), and then the freeze_graph script can be applied to obtain the frozen model.

To do this in Pretrained-Show-and-Tell-model\im2txt\im2txt\inference_utils\inference_wrapper.base.py, just add something like saver.save(sess, "model/ckpt4") after saver.restore(sess, checkpoint_path) in def _restore_fn(sess):. Then rebuild and run run_inference, and you will obtain a model that can be frozen, transformed, and optionally memmapped, to be loaded by iOS and Android apps.

Now I ran the command below:

Code language: bash
python tensorflow/python/tools/freeze_graph.py  \
--input_meta_graph=/tmp/ckpt4.meta \
--input_checkpoint=/tmp/ckpt4 \
--output_graph=/tmp/ckpt4_frozen.pb \
--output_node_names="softmax,lstm/initial_state,lstm/state" \
--input_binary=true

I then loaded the obtained ckpt4_frozen.pb file into the Android application and got the error: "java.lang.IllegalArgumentException: No OpKernel was registered to support Op 'DecodeJpeg'. Registered devices: CPU, Registered kernels: [Node: decode/DecodeJpeg = DecodeJpeg[acceptable_fraction=1, channels=3, dct_method="", fancy_upscaling=true, ratio=1, try_recover_truncated=false]]".

From https://github.com/tensorflow/tensorflow/issues/2883:

Since DecodeJpeg isn't supported as part of the core Android ops, you need to strip it out of the graph first:

Code language: bash
bazel build tensorflow/python/tools:strip_unused && \
bazel-bin/tensorflow/python/tools/strip_unused \
--input_graph=ckpt4_frozen.pb \
--output_graph=ckpt4_frozen_stripped_graph.pb \
--input_node_names=convert_image/Cast,input_feed,lstm/state_feed \
--output_node_names=softmax,lstm/initial_state,lstm/state \
--input_binary=true

When I tried to load ckpt4_frozen_stripped_graph.pb in Android, I ran into errors. So I followed Jeff's answer (inference.py on frozen graph) and, instead of tools:strip_unused, used the graph transform tool:

Code language: bash
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=/tmp/ckpt4_frozen.pb \
--out_graph=/tmp/ckpt4_frozen_transformed.pb \
--inputs="convert_image/Cast,input_feed,lstm/state_feed" \
--outputs="softmax,lstm/initial_state,lstm/state" \
--transforms='
      strip_unused_nodes(type=float, shape="1,299,299,3")
      fold_constants(ignore_errors=true) 
      fold_batch_norms
      fold_old_batch_norms' 

I could successfully load the obtained ckpt4_frozen_transformed.pb on Android. When I fed a float array of RGB image pixel values to the input node, I could successfully fetch output from the "lstm/initial_state" node.
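For reference, the flat RGB float array is typically produced from packed ARGB pixels the way the stock TensorFlow Android demo does it. The sketch below is a stdlib-only illustration of that conversion; the mean/std normalization values (128/128) are an assumption borrowed from the demo, not something stated in this answer, so check what the im2txt preprocessing expects before reusing them.

```python
def argb_to_rgb_floats(pixels, mean=128.0, std=128.0):
    """Convert packed ARGB ints (as from Android's Bitmap.getPixels)
    into a flat [R, G, B, R, G, B, ...] list of normalized floats.
    mean/std are assumed values from the TF Android demo."""
    floats = []
    for p in pixels:
        r = (p >> 16) & 0xFF  # extract the red byte
        g = (p >> 8) & 0xFF   # extract the green byte
        b = p & 0xFF          # extract the blue byte
        floats.extend([(r - mean) / std, (g - mean) / std, (b - mean) / std])
    return floats
```

The resulting flat array is what gets passed to the feed call for the image input node.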

Now the challenge was to understand the beam search in "Pretrained-Show-and-Tell-model\im2txt\im2txt\inference_utils\caption_generator.py", which also has to be implemented on the Android side.

If you look at the Python script caption_generator.py, at

Code language: python
softmax, new_states, metadata = self.model.inference_step(sess,input_feed,state_feed)

input_feed is an int32 array and state_feed is a multidimensional float array.
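To make the port easier to follow, here is a stdlib-only sketch of the beam-search loop in caption_generator.py, with the TensorFlow inference_step call stubbed out as a plain function (the real one feeds input_feed/state_feed and fetches softmax and the new LSTM state). The names and simplifications here are mine, not the original script's; length normalization and metadata are omitted.

```python
import math

def beam_search(inference_step, start_id, end_id, beam_size=3, max_len=20):
    """Each partial caption is (sentence, state, logprob). Every step,
    expand each partial caption with its beam_size most probable next
    words and keep the beam_size best overall."""
    partial = [([start_id], None, 0.0)]  # initial state would come from the image
    complete = []
    for _ in range(max_len - 1):
        input_feed = [sent[-1] for sent, _, _ in partial]   # last word per beam
        state_feed = [state for _, state, _ in partial]     # LSTM state per beam
        softmax, new_states = inference_step(input_feed, state_feed)
        next_partial = []
        for i, (sent, _, logprob) in enumerate(partial):
            probs = softmax[i]
            top = sorted(range(len(probs)), key=lambda w: probs[w], reverse=True)[:beam_size]
            for w in top:
                p = probs[w]
                if p < 1e-12:  # avoid log(0)
                    continue
                cand = (sent + [w], new_states[i], logprob + math.log(p))
                (complete if w == end_id else next_partial).append(cand)
        partial = sorted(next_partial, key=lambda c: c[2], reverse=True)[:beam_size]
        if not partial:  # ran out of partial candidates (happens when beam_size = 1)
            break
    return sorted(complete or partial, key=lambda c: c[2], reverse=True)
```

The Java implementation below mirrors this loop, with TopN playing the role of the sorted-and-truncated candidate lists.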

On the Android side, I tried feeding an int array for "input_feed". Since there is no Java API to feed a multidimensional array, I fed a flat float array to lstm/state_feed, as previously obtained from the "lstm/initial_state" node.

I got two errors: the first was that input_feed expects int64, and the second was "java.lang.IllegalArgumentException: -input rank(-1) <= split_dim < input rank (1), but got 1" at lstm/state_feed.

For the first error, I changed the input_feed data type from int32 to int64.

The second error means the node expects a rank-two tensor. If you look at the TensorFlow Java source, the float array we feed is converted into a rank-one tensor; we would need to feed the data in such a way that a rank-two tensor is created. At the time, browsing the TensorFlow Java source, I could not find any API exposed to Android for feeding a multidimensional float array as a rank-two tensor. So I rebuilt libtensorflow_inference.so and libandroid_tensorflow_inference_java.jar, enabling a feed call that creates a rank-two tensor. (For the build process, refer to https://blog.mindorks.com/android-tensorflow-machine-learning-example-ff0e9b2654cc.)

Now I can run inference on Android and get a single caption for an image, but the accuracy is very low. The reason it is limited to one caption is that I did not find a way to fetch the output as a multidimensional array, which is necessary to generate more than one caption for a single image.
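Fetching more than one caption per image would essentially come down to fetching beam_size * vocab_size floats from "softmax" and splitting the flat buffer into one row per partial caption. A small, hypothetical helper sketching that demultiplexing step (the rank-2 fetch API itself is the missing piece):

```python
def split_rows(flat, row_len):
    """Split a flat fetched buffer into rows, e.g. the 'softmax' output
    into [beam_size][vocab_size] word probabilities."""
    assert len(flat) % row_len == 0, "buffer is not a whole number of rows"
    return [flat[i:i + row_len] for i in range(0, len(flat), row_len)]
```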

Code language: java
String actualFilename = labelFilename.split("file:///android_asset/")[1];

vocab = new Vocabulary(assetManager.open(actualFilename));


inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
final Graph g = inferenceInterface.graph();

final Operation inputOperation = g.operation(inputName);
if (inputOperation == null) {
    throw new RuntimeException("Failed to find input Node '" + inputName + "'");
}
final Operation outPutOperation = g.operation(outputName);

if (outPutOperation == null) {
    throw new RuntimeException("Failed to find output Node '" + outputName + "'");
}

// The shape of the output is [N, NUM_CLASSES], where N is the batch size.
int numClasses = (int) inferenceInterface.graph().operation(outputName)
        .output(0).shape().size(1);


Log.i(TAG, "Read " + vocab.totalWords() + " labels, output layer size is " + numClasses);

// Ideally, inputSize could have been retrieved from the shape of the input operation.  Alas,
// the placeholder node for input in the graphdef typically used does not specify a shape, so it
// must be passed in as a parameter.
this.inputSize = inputSize;

// Pre-allocate buffers.
outputNames = new String[]{outputName + ":0"};
outputs = new float[numClasses];
inferenceInterface.feed(inputName + ":0", pixels, inputSize, inputSize, 3);


inferenceInterface.run(outputNames, runStats);


inferenceInterface.fetch(outputName + ":0", outputs);



startIm2txtBeamSearch(outputs);

// Beam search implemented in Java

Code language: java
private void startIm2txtBeamSearch(float[] outputs) {

        int beam_size = 1;
        //TODO:Prepare vocab ids from file
        ArrayList<Integer> vocab_ids = new ArrayList<>();
        vocab_ids.add(1);
        int vocab_end_id = 2;
        float lenth_normalization_factor = 0;
        int maxCaptionLength = 20;
        Graph g = inferenceInterface.graph();


        //node input feed
        String input_feed_node_name = "input_feed";
        Operation inputOperation = g.operation(input_feed_node_name);
        if (inputOperation == null) {
            throw new RuntimeException("Failed to find input Node '" + input_feed_node_name + "'");
        }

        String output_feed_node_name = "softmax";
        Operation outPutOperation = g.operation(output_feed_node_name);
        if (outPutOperation == null) {
            throw new RuntimeException("Failed to find output Node '" + output_feed_node_name + "'");
        }
        int output_feed_node_numClasses = (int) outPutOperation.output(0).shape().size(1);
        Log.i(TAG, "Output layer " + output_feed_node_name + ", output layer size is " + output_feed_node_numClasses);
        FloatBuffer output_feed_output = FloatBuffer.allocate(output_feed_node_numClasses);
        //float [][] output_feed_output = new float[numClasses][];

        //node state feed
        String input_state_feed_node_name = "lstm/state_feed";
        inputOperation = g.operation(input_state_feed_node_name);
        if (inputOperation == null) {
            throw new RuntimeException("Failed to find input Node '" + input_state_feed_node_name + "'");
        }
        String output_state_feed_node_name = "lstm/state";
        outPutOperation = g.operation(output_state_feed_node_name);
        if (outPutOperation == null) {
            throw new RuntimeException("Failed to find output Node '" + output_state_feed_node_name + "'");
        }
        int output_state_feed_node_numClasses = (int) outPutOperation.output(0).shape().size(1);
        Log.i(TAG, "Output layer " + output_state_feed_node_name + ", output layer size is " + output_state_feed_node_numClasses);
        FloatBuffer output_state_output = FloatBuffer.allocate(output_state_feed_node_numClasses);
        //float[][] output_state_output= new float[numClasses][];
        String[] output_nodes = new String[]{output_feed_node_name, output_state_feed_node_name};


        Caption initialBeam = new Caption(vocab_ids, outputs, (float) 0.0, (float) 0.0);
        TopN partialCaptions = new TopN(beam_size);
        partialCaptions.push(initialBeam);
        TopN completeCaption = new TopN(beam_size);


        captionLengthLoop:
        for (int i = maxCaptionLength; i >= 0; i--) {
            List<Caption> partialCaptionsList = new LinkedList<>(partialCaptions.extract(false));
            partialCaptions.reset();

            long[] input_feed = new long[partialCaptionsList.size()];
            float[][] state_feed = new float[partialCaptionsList.size()][];

            for (int j = 0; j < partialCaptionsList.size(); j++) {
                Caption curCaption = partialCaptionsList.get(j);
                ArrayList<Integer> senArray = curCaption.getSentence();
                input_feed[j] = senArray.get(senArray.size() - 1);
                state_feed[j] = curCaption.getState();
            }
            //feeding
            inferenceInterface.feed(input_feed_node_name, input_feed, new long[]{input_feed.length});

            inferenceInterface.feed(input_state_feed_node_name, state_feed, new long[]{state_feed.length});


            //run
            inferenceInterface.run(output_nodes, runStats);

            //fetching
            inferenceInterface.fetch(output_feed_node_name, output_feed_output);
            inferenceInterface.fetch(output_state_feed_node_name, output_state_output);

            float[] word_probabilities = new float[partialCaptionsList.size()];
            float[] new_state = new float[partialCaptionsList.size()];
            for (int k = 0; k < partialCaptionsList.size(); k++) {
                word_probabilities = output_feed_output.array();
                //output_feed_output.get(word_probabilities[k]);
                new_state = output_state_output.array();
                //output_feed_output.get(state[k]);

                // For this partial caption, get the beam_size most probable next words.
                Map<Integer, Float> word_and_probs = new LinkedHashMap<>();
                //key is index of probability; value is index = word
                for (int l = 0; l < word_probabilities.length; l++) {
                    word_and_probs.put(l, word_probabilities[l]);
                }
                //sorting
//                word_and_probs = word_and_probs.entrySet().stream()
//                        .sorted(Map.Entry.comparingByValue())
//                        .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue,(e1, e2) -> e1, LinkedHashMap::new));
                word_and_probs = MapUtil.sortByValue(word_and_probs);
                //considering first (beam size probabilities)
                LinkedHashMap<Integer, Float> final_word_and_probs = new LinkedHashMap<>();

                for (int key : word_and_probs.keySet()) {
                    final_word_and_probs.put(key, word_and_probs.get(key));
                    if (final_word_and_probs.size() == beam_size)
                        break;
                }

                for (int w : final_word_and_probs.keySet()) {
                    float p = final_word_and_probs.get(w);
                    if (p < 1e-12) {//# Avoid log(0).
                        Log.d(TAG, "p is < 1e-12");
                        continue;
                    }
                    Caption partialCaption = partialCaptionsList.get(k);
                    ArrayList<Integer> sentence = new ArrayList<>(partialCaption.getSentence());
                    sentence.add(w);
                    float logprob = (float) (partialCaption.getPorb() + Math.log(p));
                    float score = logprob;
                    Caption beam = new Caption(sentence, new_state, logprob, score);
                    if (w == vocab_end_id) {
                        completeCaption.push(beam);
                    } else {
                        partialCaptions.push(beam);
                    }
                }
                if (partialCaptions.getSize() == 0)//run out of partial candidates; happens when beam_size = 1.
                    break captionLengthLoop;
            }


            //clear buffer retrieve sub sequent output
            output_feed_output.clear();
            output_state_output.clear();
            output_feed_output = null;
            output_state_output = null;
            output_feed_output = FloatBuffer.allocate(output_feed_node_numClasses);
            output_state_output = FloatBuffer.allocate(output_state_feed_node_numClasses);
            Log.d(TAG, "----" + i + " Iteration completed----");
        }
        Log.d(TAG, "----Total Iterations completed----");
        LinkedList<Caption> completeCaptions = completeCaption.extract(true);

        for (Caption cap : completeCaptions) {

            ArrayList<Integer> wordids = cap.getSentence();
            StringBuffer caption = new StringBuffer();
            boolean isFirst = true;
            for (int word : wordids) {
                if (!isFirst)
                    caption.append(" ");
                caption.append(vocab.getWord(word));
                isFirst = false;
            }
            Log.d(TAG, "Cap score = " + Math.exp(cap.getScore()) + " and Caption is " + caption);
        }

    }

// Vocabulary

Code language: java
    public class Vocabulary {
        String TAG = Vocabulary.class.getSimpleName();
        String start_word = "<S>", end_word = "</S>", unk_word = "<UNK>";
        ArrayList<String> words;

        public Vocabulary(File vocab_file) {
            loadVocabsFromFile(vocab_file);
        }

        public Vocabulary(InputStream vocab_file_stream) {
            words = readLinesFromFileAndLoadWords(new InputStreamReader(vocab_file_stream));
        }

        public Vocabulary(String vocab_file_path) {
            File vocabFile = new File(vocab_file_path);
            loadVocabsFromFile(vocabFile);
        }

        private void loadVocabsFromFile(File vocabFile) {
            try {
                this.words = readLinesFromFileAndLoadWords(new FileReader(vocabFile));
                //Log.d(TAG, "Words read from file = " + words.size());
            } catch (FileNotFoundException e) {
                e.printStackTrace();
            }
        }


        private ArrayList<String> readLinesFromFileAndLoadWords(InputStreamReader file_reader) {
            ArrayList<String> words = new ArrayList<>();
            try (BufferedReader br = new BufferedReader(file_reader)) {
                String line;
                while ((line = br.readLine()) != null) {
                    // process the line.
                    words.add(line.split(" ")[0].trim());
                }
                br.close();
                if (!words.contains(unk_word))
                    words.add(unk_word);
            } catch (IOException e) {
                e.printStackTrace();
            }

            return words;
        }

        public String getWord(int word_id) {
            if (words != null)
                if (word_id >= 0 && word_id < words.size())
                    return words.get(word_id);
            return "No word found, Maybe Vocab File not loaded";
        }

        public int totalWords() {
            if (words != null)
                return words.size();
            return 0;
        }
    }

//MapUtil

Code language: java
public class MapUtil {


    public static <K, V extends Comparable<? super V>> Map<K, V> sortByValue(Map<K, V> map) {
        List<Map.Entry<K, V>> list = new ArrayList<>(map.entrySet());
        list.sort(new Comparator<Map.Entry<K, V>>() {
            @Override
            public int compare(Map.Entry<K, V> o1, Map.Entry<K, V> o2) {
                if (o1.getValue() instanceof Float && o2.getValue() instanceof Float) {
                    // Descending order: higher probabilities first. Float.compare
                    // keeps the comparator consistent when values are equal.
                    return Float.compare((Float) o2.getValue(), (Float) o1.getValue());
                }
                return 0;
            }
        });

        Map<K, V> result = new LinkedHashMap<>();
        for (Map.Entry<K, V> entry : list) {
            result.put(entry.getKey(), entry.getValue());
        }

        return result;
    }

}

// Caption

Code language: java
    public class Caption implements Comparable<Caption> {

        private ArrayList<Integer> sentence;
        private float[] state;
        private float porb;
        private float score;

        public Caption(ArrayList<Integer> sentence, float[] state, float porb, float score) {
            this.sentence = sentence;
            this.state = state;
            this.porb = porb;
            this.score = score;
        }

        public ArrayList<Integer> getSentence() {
            return sentence;
        }

        public void setSentence(ArrayList<Integer> sentence) {
            this.sentence = sentence;
        }

        public float[] getState() {
            return state;
        }

        public void setState(float[] state) {
            this.state = state;
        }

        public float getPorb() {
            return porb;
        }

        public void setPorb(float porb) {
            this.porb = porb;
        }

        public float getScore() {
            return score;
        }

        public void setScore(float score) {
            this.score = score;
        }

        @Override
        public int compareTo(@NonNull Caption oc) {
            if (score == oc.score)
                return 0;
            if (score < oc.score)
                return -1;
            else
                return 1;
        }
    }

//TopN

Code language: java
 public class TopN {

    //Maintains the top n elements of an incrementally provided set.
    int n;
    LinkedList<Caption> data;


    public TopN(int n) {
        this.n = n;
        this.data = new LinkedList<>();
    }

    public int getSize() {
        if (data != null)
            return data.size();
        return 0;
    }

    // Pushes a new element, keeping only the n highest-scoring captions.
    public void push(Caption x) {
        if (data != null) {
            if (getSize() < n) {
                data.add(x);
            } else {
                // Replace the current minimum only when the new caption scores higher.
                Caption min = Collections.min(data);
                if (x.compareTo(min) > 0) {
                    data.remove(min);
                    data.add(x);
                }
            }
        }
    }

    // Extracts all elements from the TopN. This is a destructive operation.
    // The only method that can be called immediately after extract() is reset().
    // Args:
    //   sort: Whether to return the elements in descending sorted order.
    // Returns: A list of data; the top n elements provided to the set.

    public LinkedList<Caption> extract(boolean sort) {
        if (sort) {
            Collections.sort(data);
        }
        return data;
    }

    //Returns the TopN to an empty state.
    public void reset() {
        if (data != null) data.clear();
    }

}

Even though the accuracy is very low, I am sharing this since it may be useful for anyone who wants to load the Show and Tell model on Android.

3 votes

Page content provided by Stack Overflow.
Original link:

https://stackoverflow.com/questions/44141279