首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何从tflite模型中输出形状数组[ 1,28,28,1]作为android中的图像

如何从tflite模型中输出形状数组[ 1,28,28,1]作为android中的图像
EN

Stack Overflow用户
提问于 2020-06-12 06:46:07
回答 1查看 745关注 0票数 0

我保存了一个tflite模型,其输入和输出详细信息如下:

  1. 输入 :[{'name':‘稠密_4_input’,'index':0,'shape':数组( 1,100,dtype=int32),'shape_signature':数组( 1,100,dtype=int32),'dtype':,‘量化’:(0.0,0),‘量化_参数’:{‘缩放’:数组([],dtype=float32),‘零点’:数组([],dtype=int32),‘量化_维’:0},“稀疏参数”:{}}]
  2. 输出:[{“名称”:“标识”、“索引”:22、“形状”:数组( 1、28、28、1、1、dtype=int32)、“shape_signature”:数组( 1、28、28、1、1、dtype=int32)、“dtype”:、“量化”:(0.0、0)、“量化参数”:{‘缩放’:数组([]、dtype=float32)、‘零点’:数组([]、dtype=int32)、‘量化_维度’:0},“稀疏参数”:{}}]

如何使用Java和Tensorflow解释器在android应用程序上将输出显示为图像?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-06-12 07:25:26

代码语言:javascript
复制
import android.content.res.AssetManager
import android.graphics.Bitmap
import android.util.Log
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.Tensor
import java.io.FileInputStream
import java.lang.StringBuilder
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.channels.FileChannel

class ImgPredictor(val assetManager: AssetManager, modelFilename: String) {
    private var tflite: Interpreter

    private var input: ByteBuffer
    private var output: ByteBuffer

    init {
        val tfliteOptions = Interpreter.Options()

        val fd = assetManager.openFd(modelFilename)
        val inputStream = FileInputStream(fd.fileDescriptor)
        val fileChannel: FileChannel = inputStream.channel
        val startOffset: Long = fd.startOffset
        val declaredLength: Long = fd.declaredLength
        val mbb = fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
        tflite = Interpreter(mbb, tfliteOptions)
        Log.i("ImgPredictor", "interpreter: ${tflite.detail()}")
        input = ByteBuffer.allocate(100 * Int.SIZE_BYTES)
        input.order(ByteOrder.nativeOrder())

        output = ByteBuffer.allocate(1 * 28 * 28 * 1 * Int.SIZE_BYTES)
        output.order(ByteOrder.nativeOrder())
    }

    fun predict(data: IntArray): Bitmap {
        val startTs = System.currentTimeMillis();

        input.clear()
        output.clear()

        input.rewind()
        for (i in 0 until 100) {
            input.putInt(data[i])
        }
        tflite.run(input, output)
        val bitmap = Bitmap.createBitmap(28, 28, Bitmap.Config.ARGB_8888);
        // vector is your int[] of ARGB
        bitmap.copyPixelsFromBuffer(output)
        return bitmap
    }
}

fun Tensor.detail(): String {
    return "[shape: ${this.shape().toList()} dataType: ${this.dataType()}, bytes: ${this.numBytes()}]"
}

fun Interpreter.detail(): String {
    val sb = StringBuilder("interpreter: \n")
    sb.append("input: { \n")
    for (i in 0 until this.inputTensorCount) {
        sb.append("    ").append(this.getInputTensor(i).detail()).append("\n")
    }
    sb.append("}, \n")

    sb.append("output: { \n")
    for (i in 0 until this.outputTensorCount) {
        sb.append("    ").append(this.getOutputTensor(i).detail()).append("\n")
    }
    sb.append("}")
    return sb.toString()
}

您可以在这里查看官方教程以获得更多详细信息:对象检测解释器示例

但是,你应该注意到以下几点:

  1. implementation 'org.tensorflow:tensorflow-lite:x.x.x'与您的PC保持相同的版本,因为某些操作系统可能无法在较低版本中工作。
  2. 使用一些细节函数打印解释器输入/输出。
  3. 检查输入输出数据缓冲区顺序endian。
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62338930

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档