我保存了一个 TFLite 模型,其输入和输出张量的详细信息如下(原帖中的具体内容未被保留):
如何在 Android 应用中使用 Java 和 TensorFlow Lite 解释器将模型输出显示为图像?
发布于 2020-06-12 07:25:26
import android.content.res.AssetManager
import android.graphics.Bitmap
import android.util.Log
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.Tensor
import java.io.FileInputStream
import java.lang.StringBuilder
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.channels.FileChannel
/**
 * Wraps a TensorFlow Lite [Interpreter] that is fed [INPUT_SIZE] ints and whose output is
 * rendered into an [OUTPUT_WIDTH] x [OUTPUT_HEIGHT] ARGB_8888 [Bitmap].
 *
 * The model is memory-mapped from the app's assets when the object is constructed.
 *
 * NOTE(review): the buffer sizes assume the model's input tensor holds [INPUT_SIZE] 4-byte
 * values and its output tensor holds OUTPUT_WIDTH * OUTPUT_HEIGHT 4-byte values. Compare them
 * against the tensor shapes logged at construction time.
 * NOTE(review): `AssetManager.openFd` generally requires the asset to be stored uncompressed
 * in the APK. Confirm the build configuration.
 *
 * @param assetManager asset manager used to open [modelFilename].
 * @param modelFilename path of the `.tflite` model inside the assets directory.
 */
class ImgPredictor(val assetManager: AssetManager, modelFilename: String) {
    private val tflite: Interpreter = Interpreter(mapModel(modelFilename), Interpreter.Options())

    init {
        Log.i("ImgPredictor", "interpreter: ${tflite.detail()}")
    }

    // Reused across predict() calls: INPUT_SIZE ints in native byte order.
    private val input: ByteBuffer =
        ByteBuffer.allocate(INPUT_SIZE * Int.SIZE_BYTES).order(ByteOrder.nativeOrder())

    // Reused across predict() calls: one 4-byte value per output pixel, in native byte order.
    private val output: ByteBuffer =
        ByteBuffer.allocate(OUTPUT_WIDTH * OUTPUT_HEIGHT * Int.SIZE_BYTES).order(ByteOrder.nativeOrder())

    /**
     * Runs the model on the first [INPUT_SIZE] values of [data] and copies the raw output
     * bytes into a new bitmap.
     *
     * NOTE(review): `copyPixelsFromBuffer` copies bytes verbatim. This only yields a sensible
     * image if the model emits one packed 32-bit pixel per cell in the bitmap's in-memory
     * layout. Confirm the output tensor's dtype and format.
     *
     * @param data model input; must contain at least [INPUT_SIZE] values (extra values are ignored).
     * @return a new [OUTPUT_WIDTH] x [OUTPUT_HEIGHT] ARGB_8888 bitmap.
     * @throws IllegalArgumentException if [data] has fewer than [INPUT_SIZE] values.
     */
    fun predict(data: IntArray): Bitmap {
        require(data.size >= INPUT_SIZE) { "expected at least $INPUT_SIZE inputs, got ${data.size}" }

        input.clear()
        for (i in 0 until INPUT_SIZE) {
            input.putInt(data[i])
        }
        // Interpreter.run requires the caller to leave the buffer at the correct read position.
        // After the puts above, the position is at the end, so move it back to the start.
        input.rewind()

        output.clear()
        tflite.run(input, output)
        // run() writes the result into the buffer and leaves the position past the data.
        // copyPixelsFromBuffer reads from the current position, so rewind first.
        output.rewind()

        val bitmap = Bitmap.createBitmap(OUTPUT_WIDTH, OUTPUT_HEIGHT, Bitmap.Config.ARGB_8888)
        bitmap.copyPixelsFromBuffer(output)
        return bitmap
    }

    /**
     * Memory-maps [modelFilename] from [assetManager] as a read-only buffer.
     *
     * Both the stream and the asset descriptor are closed before returning. A FileChannel
     * mapping stays valid after its channel is closed.
     */
    private fun mapModel(modelFilename: String): ByteBuffer {
        val fd = assetManager.openFd(modelFilename)
        try {
            return FileInputStream(fd.fileDescriptor).use { stream ->
                stream.channel.map(FileChannel.MapMode.READ_ONLY, fd.startOffset, fd.declaredLength)
            }
        } finally {
            fd.close()
        }
    }

    private companion object {
        /** Number of ints written to the model's input tensor. */
        const val INPUT_SIZE = 100

        /** Width in pixels of the output image. */
        const val OUTPUT_WIDTH = 28

        /** Height in pixels of the output image. */
        const val OUTPUT_HEIGHT = 28
    }
}
/**
 * One-line summary of this tensor for logging: its shape, data type and size in bytes.
 */
fun Tensor.detail(): String =
    "[shape: ${shape().toList()} dataType: ${dataType()}, bytes: ${numBytes()}]"
/**
 * Multi-line description of this interpreter for logging.
 *
 * Lists every input tensor and then every output tensor, one per line, each formatted with
 * [Tensor.detail].
 */
fun Interpreter.detail(): String {
    val interpreter = this
    return buildString {
        append("interpreter: \n")
        append("input: { \n")
        repeat(interpreter.inputTensorCount) { index ->
            append(" ").append(interpreter.getInputTensor(index).detail()).append("\n")
        }
        append("}, \n")
        append("output: { \n")
        repeat(interpreter.outputTensorCount) { index ->
            append(" ").append(interpreter.getOutputTensor(index).detail()).append("\n")
        }
        append("}")
    }
}
您可以在这里查看官方教程以获得更多详细信息:对象检测解释器示例。
但需要注意:
build.gradle 中声明的 `implementation 'org.tensorflow:tensorflow-lite:x.x.x'` 应与您在 PC 上转换模型时使用的 TensorFlow 版本保持一致,因为某些算子(op)在较低版本中可能不受支持。
来源:https://stackoverflow.com/questions/62338930
复制相似问题