SmashKs/OneShoot

View on GitHub
presentation/src/main/java/smash/ks/com/oneshoot/classifiers/TFLiteImageClassifier.kt

Summary

Maintainability
A
1 hr
Test Coverage
/*
 * Copyright (C) 2018 The Smash Ks Open Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package smash.ks.com.oneshoot.classifiers

import android.annotation.SuppressLint
import android.content.res.AssetManager
import android.graphics.Bitmap
import org.tensorflow.lite.Interpreter
import smash.ks.com.data.local.ml.Classifier
import smash.ks.com.data.local.ml.Classifier.Recognition
import smash.ks.com.ext.const.DEFAULT_INT
import smash.ks.com.ext.const.DEFAULT_STR
import java.io.BufferedReader
import java.io.FileInputStream
import java.io.IOException
import java.io.InputStreamReader
import java.lang.Float.compare
import java.lang.Math.min
import java.nio.ByteBuffer
import java.nio.ByteBuffer.allocateDirect
import java.nio.ByteOrder.nativeOrder
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel.MapMode.READ_ONLY
import java.util.ArrayList
import java.util.PriorityQueue
import kotlin.experimental.and

/**
 * An image [Classifier] backed by a TensorFlow Lite [Interpreter].
 *
 * The input tensor is filled with packed RGB bytes (three bytes per pixel) and the output tensor is
 * read as one unsigned byte per label, so this presumably targets a quantized model — confirm
 * against the bundled model file.
 *
 * Instances are obtained through [create]. After [close] has been called, [recognizeImage] throws.
 */
class TFLiteImageClassifier private constructor(
    private var interpreter: Interpreter?,
    private val labelList: List<String>,
    private val inputSize: Int
) : Classifier {
    companion object {
        private const val MASK_VALUE = 0xFF
        private const val FIRST_UNIT = 8
        private const val SECOND_UNIT = 16

        private const val MAX_RESULTS = 3
        private const val BATCH_SIZE = 1
        private const val PIXEL_SIZE = 3
        private const val THRESHOLD = 0.1f

        /**
         * Builds a classifier from a model file and a label file stored in the app's assets.
         *
         * @param assetManager asset manager used to open both files.
         * @param modelPath asset path of the TensorFlow Lite model.
         * @param labelPath asset path of the label file, one label per line.
         * @param inputSize width and height, in pixels, of the square image fed to the model.
         * @return a ready-to-use classifier.
         * @throws IOException if either asset cannot be opened or read.
         * @throws IllegalArgumentException if [inputSize] is not positive.
         */
        @Throws(IOException::class)
        fun create(
            assetManager: AssetManager,
            modelPath: String,
            labelPath: String,
            inputSize: Int
        ): TFLiteImageClassifier {
            require(inputSize > 0) { "inputSize must be positive but was $inputSize." }

            return TFLiteImageClassifier(
                Interpreter(loadModelFile(assetManager, modelPath)),
                loadLabelList(assetManager, labelPath),
                inputSize
            )
        }

        /**
         * Memory-maps the model asset at [modelPath] read-only.
         *
         * The descriptor and the stream are closed before returning; a mapping stays valid after
         * the channel that created it has been closed.
         */
        @Throws(IOException::class)
        private fun loadModelFile(assetManager: AssetManager, modelPath: String): MappedByteBuffer {
            val descriptor = assetManager.openFd(modelPath)

            try {
                return FileInputStream(descriptor.fileDescriptor).use { stream ->
                    stream.channel.map(READ_ONLY, descriptor.startOffset, descriptor.declaredLength)
                }
            } finally {
                descriptor.close()
            }
        }

        /**
         * Reads the label asset at [labelPath], one label per line.
         *
         * Lines are kept verbatim (including blank ones) because a label's position in the list is
         * used as its index into the model output.
         */
        @Throws(IOException::class)
        private fun loadLabelList(assetManager: AssetManager, labelPath: String): List<String> =
            BufferedReader(InputStreamReader(assetManager.open(labelPath))).useLines { it.toList() }
    }

    override val statString = DEFAULT_STR

    /**
     * Runs the model on [bitmap] and returns the best matching labels.
     *
     * @param bitmap image to classify; it is scaled to the model's input size when its dimensions
     * differ from it.
     * @return at most [MAX_RESULTS] recognitions whose confidence exceeds [THRESHOLD], ordered by
     * descending confidence.
     * @throws IllegalArgumentException if the classifier has already been closed.
     */
    override fun recognizeImage(bitmap: Bitmap): List<Recognition> {
        val tflite = requireNotNull(interpreter) { "The classifier has already been closed." }
        val input = convertBitmapToByteBuffer(bitmap)
        // One row per batch entry, one byte per label.
        val output = Array(BATCH_SIZE) { ByteArray(labelList.size) }

        tflite.run(input, output)

        return getSortedResult(output)
    }

    /** Closes the underlying interpreter and drops the reference to it. Safe to call repeatedly. */
    override fun close() {
        interpreter?.close()
        interpreter = null
    }

    /**
     * Packs [bitmap] into a direct, native-ordered buffer of R, G, B bytes in pixel order.
     *
     * The pixel array holds exactly `inputSize * inputSize` entries, so a bitmap of any other size
     * is scaled first instead of overflowing the array or producing a misaligned tensor.
     */
    private fun convertBitmapToByteBuffer(bitmap: Bitmap): ByteBuffer {
        val scaled = if (bitmap.width == inputSize && bitmap.height == inputSize) {
            bitmap
        } else {
            Bitmap.createScaledBitmap(bitmap, inputSize, inputSize, true)
        }
        val pixels = IntArray(inputSize * inputSize)

        scaled.getPixels(pixels, 0, inputSize, 0, 0, inputSize, inputSize)
        // Only the temporary copy is released; the caller keeps ownership of its own bitmap.
        if (scaled !== bitmap) scaled.recycle()

        val buffer = allocateDirect(BATCH_SIZE * inputSize * inputSize * PIXEL_SIZE).order(nativeOrder())

        for (pixel in pixels) {
            // Each pixel is a packed int: red in bits 16-23, green in bits 8-15, blue in bits 0-7.
            buffer.put((pixel shr SECOND_UNIT and MASK_VALUE).toByte())
            buffer.put((pixel shr FIRST_UNIT and MASK_VALUE).toByte())
            buffer.put((pixel and MASK_VALUE).toByte())
        }

        return buffer
    }

    /**
     * Converts the raw model output into recognitions.
     *
     * @param labelProbArray model output; only the first row is read, one byte per label.
     * @return at most [MAX_RESULTS] recognitions above [THRESHOLD], highest confidence first.
     */
    @SuppressLint("DefaultLocale")
    private fun getSortedResult(labelProbArray: Array<ByteArray>): List<Recognition> {
        // Reversed comparison so that the head of the queue is the highest confidence.
        val queue = PriorityQueue<Recognition>(MAX_RESULTS) { lhs, rhs ->
            compare(requireNotNull(rhs.confidence), requireNotNull(lhs.confidence))
        }
        val scores = labelProbArray[0]

        labelList.forEachIndexed { index, label ->
            // Widen to Int before masking: the byte is an unsigned 0..255 score, and masking it as
            // a Byte would keep the sign and turn every score of 128 or more into a negative value.
            val confidence = (scores[index].toInt() and MASK_VALUE) / MASK_VALUE.toFloat()

            if (THRESHOLD < confidence) {
                queue.add(Recognition(index.toString(), label, confidence))
            }
        }

        val recognitions = ArrayList<Recognition>()

        repeat(min(queue.size, MAX_RESULTS)) { recognitions.add(queue.poll()) }

        return recognitions
    }

    /** Intentionally a no-op: this classifier does not collect statistics. */
    override fun enableStatLogging(debug: Boolean) {}
}