SmashKs/OneShoot
presentation/src/main/java/smash/ks/com/oneshoot/classifiers/TFLiteObjectDetectionAPIModel.kt
/*
 * Copyright (C) 2018 The Smash Ks Open Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package smash.ks.com.oneshoot.classifiers

import android.content.res.AssetManager
import android.graphics.Bitmap
import android.graphics.RectF
import android.os.Trace
import com.devrapid.kotlinknifer.logw
import org.tensorflow.lite.Interpreter
import smash.ks.com.data.local.ml.Classifier
import smash.ks.com.data.local.ml.Classifier.Recognition
import smash.ks.com.ext.const.DEFAULT_STR
import java.io.BufferedReader
import java.io.FileInputStream
import java.io.IOException
import java.io.InputStream
import java.io.InputStreamReader
import java.lang.Float.compare
import java.lang.Float.parseFloat
import java.lang.Math.exp
import java.lang.Math.min
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel
import java.util.ArrayList
import java.util.Comparator
import java.util.PriorityQueue
import java.util.StringTokenizer
import java.util.Vector

/**
 * Wrapper for frozen detection models trained using the TensorFlow Object Detection API:
 * github.com/tensorflow/models/tree/master/research/object_detection
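 *
 * A minimal usage sketch; the asset names and input size below are illustrative
 * assumptions, not values defined in this file:
 *
 * ```
 * val detector = TFLiteObjectDetectionAPIModel.create(
 *     context.assets,
 *     "mobilenet_ssd.tflite",                   // hypothetical model asset
 *     "file:///android_asset/coco_labels.txt",  // hypothetical label asset
 *     300)                                      // model input size in pixels
 * val results = detector.recognizeImage(bitmap)
 * detector.close()
 * ```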
 */
class TFLiteObjectDetectionAPIModel private constructor() : Classifier {
    companion object {
        // Sentinel initial value for the per-anchor best-class search.
        private const val TOP_CLASS_SCORE = -1000f

        // Only return up to this many detections.
        private const val LIMITATION = 10

        // Detections whose sigmoid score falls below this are discarded.
        private const val MINIMUM_THRESHOLD = 0.001f
        // A box is four coordinates: top, left, bottom and right.
        private const val MAXIMUM_RECTANGLE = 4
        // Colour channels per pixel in the input tensor (R, G, B).
        private const val PRE_BUFFER = 3

        // For unpacking ARGB pixels: 0xFF masks a single channel byte;
        // the shifts select the green (8) and red (16) bytes.
        private const val MASK_VALUE = 0xFF
        private const val FIRST_UNIT = 8
        private const val SECOND_UNIT = 16

        // Indices into a decoded box: (yMin, xMin, yMax, xMax). The same indices are
        // reused below as rows of the box priors and as colour-channel offsets.
        private const val TOP_POINT_OF_RECT = 0
        private const val LEFT_POINT_OF_RECT = 1
        private const val BOTTOM_POINT_OF_RECT = 2
        private const val RIGHT_POINT_OF_RECT = 3

        // The number of anchor boxes and classes this SSD model outputs.
        private const val NUM_RESULTS = 1917
        private const val NUM_CLASSES = 91

        private const val Y_SCALE = 10.0f
        private const val X_SCALE = 10.0f
        private const val H_SCALE = 5.0f
        private const val W_SCALE = 5.0f

        /** Memory-map the model file in Assets.  */
        @Throws(IOException::class)
        private fun loadModelFile(assets: AssetManager, modelFilename: String): MappedByteBuffer {
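            // Memory-mapping lets the TFLite interpreter read the model weights in place,
            // without copying them onto the Java heap.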
            val fileDescriptor = assets.openFd(modelFilename)
            val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
            val fileChannel = inputStream.channel
            val startOffset = fileDescriptor.startOffset
            val declaredLength = fileDescriptor.declaredLength

            return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
        }

        /**
         * Initializes a TensorFlow Lite interpreter for classifying images.
         *
         * @param assetManager  The asset manager to be used to load assets.
         * @param modelFilename The filepath of the TensorFlow Lite model (a FlatBuffer in assets).
         * @param labelFilename The filepath of the label file for classes.
         * @param inputSize     The width/height, in pixels, of the square input the model expects.
         */
        @Throws(IOException::class)
        fun create(
            assetManager: AssetManager,
            modelFilename: String,
            labelFilename: String,
            inputSize: Int
        ): Classifier = TFLiteObjectDetectionAPIModel().apply {
            loadCoderOptions(assetManager, "file:///android_asset/box_priors.txt", boxPriors)

            val actualFilename = labelFilename.removePrefix("file:///android_asset/")
            val labelsInput: InputStream = assetManager.open(actualFilename)

            BufferedReader(InputStreamReader(labelsInput)).useLines { lines -> lines.forEach { labels.add(it) } }
            this.inputSize = inputSize
            // Let a failed model load propagate: `create` declares @Throws(IOException::class),
            // and swallowing the exception here would leave `tfLite` null and silently skip inference.
            tfLite = Interpreter(loadModelFile(assetManager, modelFilename))

            // Pre-allocate buffers.
            img = Array(1) { Array(inputSize) { Array(inputSize) { FloatArray(PRE_BUFFER) } } }
            intValues = IntArray(this.inputSize * this.inputSize)
            outputLocations = Array(1) { Array(NUM_RESULTS) { FloatArray(MAXIMUM_RECTANGLE) } }
            outputClasses = Array(1) { Array(NUM_RESULTS) { FloatArray(NUM_CLASSES) } }
        }
    }

    // Config values.
    private var inputSize = 0

    // Box priors: four rows (yCenter, xCenter, height, width), one column per anchor box.
    private val boxPriors = Array(MAXIMUM_RECTANGLE) { FloatArray(NUM_RESULTS) }

    // Pre-allocated buffers.
    private val labels = Vector<String>()
    private var tfLite: Interpreter? = null
    private lateinit var intValues: IntArray
    private lateinit var outputLocations: Array<Array<FloatArray>>
    private lateinit var outputClasses: Array<Array<FloatArray>>
    private lateinit var img: Array<Array<Array<FloatArray>>>

    override val statString = DEFAULT_STR

    /** Logistic sigmoid: maps a raw class logit to a confidence score in (0, 1). */
    private fun expit(x: Float) = (1.0 / (1.0 + exp((-x).toDouble()))).toFloat()

    @Throws(IOException::class)
    private fun loadCoderOptions(assetManager: AssetManager, locationFilename: String, boxPriors: Array<FloatArray>) {
        // Try to be intelligent about opening from assets or sdcard depending on prefix.
        val assetPrefix = "file:///android_asset/"
        val inputStream = if (locationFilename.startsWith(assetPrefix)) {
            assetManager.open(locationFilename.removePrefix(assetPrefix))
        }
        else {
            FileInputStream(locationFilename)
        }

        BufferedReader(InputStreamReader(inputStream)).useLines {
            it.forEachIndexed { lineNum, line ->
                val st = StringTokenizer(line, ", ")
                var priorIndex = 0

                while (st.hasMoreTokens()) {
                    val token = st.nextToken()
                    try {
                        val number = parseFloat(token)

                        boxPriors[lineNum][priorIndex++] = number
                    }
                    catch (e: NumberFormatException) {
                        // Skip non-numeric tokens silently.
                    }
                }
                check(priorIndex == NUM_RESULTS) { "BoxPrior length mismatch: $priorIndex vs $NUM_RESULTS" }
            }
        }

        logw("Loaded box priors!")
    }

    private fun decodeCenterSizeBoxes(predictions: Array<Array<FloatArray>>) {
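        // SSD center-size decoding, done in place: each anchor's raw (ty, tx, th, tw)
        // is combined with its prior (priorY, priorX, priorH, priorW) as
        //   yCenter = ty / Y_SCALE * priorH + priorY
        //   xCenter = tx / X_SCALE * priorW + priorX
        //   h       = exp(th / H_SCALE) * priorH
        //   w       = exp(tw / W_SCALE) * priorW
        // and then rewritten as the corners (yMin, xMin, yMax, xMax).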
        for (i in 0 until NUM_RESULTS) {
            val yCenter = predictions[0][i][TOP_POINT_OF_RECT] /
                          Y_SCALE *
                          boxPriors[BOTTOM_POINT_OF_RECT][i] +
                          boxPriors[0][i]
            val xCenter = predictions[0][i][LEFT_POINT_OF_RECT] /
                          X_SCALE *
                          boxPriors[RIGHT_POINT_OF_RECT][i] +
                          boxPriors[1][i]
            val h = exp((predictions[0][i][BOTTOM_POINT_OF_RECT] / H_SCALE).toDouble()).toFloat() *
                    boxPriors[BOTTOM_POINT_OF_RECT][i]
            val w = exp((predictions[0][i][RIGHT_POINT_OF_RECT] / W_SCALE).toDouble()).toFloat() *
                    boxPriors[RIGHT_POINT_OF_RECT][i]

            val yMin = yCenter - h / 2f
            val xMin = xCenter - w / 2f
            val yMax = yCenter + h / 2f
            val xMax = xCenter + w / 2f

            predictions[0][i][TOP_POINT_OF_RECT] = yMin
            predictions[0][i][LEFT_POINT_OF_RECT] = xMin
            predictions[0][i][BOTTOM_POINT_OF_RECT] = yMax
            predictions[0][i][RIGHT_POINT_OF_RECT] = xMax
        }
    }

    override fun recognizeImage(bitmap: Bitmap): List<Recognition> {
        // Log this method so that it can be analyzed with systrace.
        Trace.beginSection("recognizeImage")
        Trace.beginSection("preprocessBitmap")
        // Preprocess the image data from packed 0-255 ARGB ints to floats
        // normalized to [-1, 1), i.e. (channel - 128) / 128.
        bitmap.getPixels(intValues, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)

        val halfRange = (MASK_VALUE + 1) / 2f  // 128

        for (i in 0 until inputSize) {
            for (j in 0 until inputSize) {
                val pixel = intValues[j * inputSize + i]

                // The rect constants 0/1/2 are reused here purely as channel indices,
                // giving a tensor channel order of (R, G, B).
                img[0][j][i][BOTTOM_POINT_OF_RECT] = (pixel and MASK_VALUE) / halfRange - 1f              // blue
                img[0][j][i][LEFT_POINT_OF_RECT] = (pixel shr FIRST_UNIT and MASK_VALUE) / halfRange - 1f // green
                img[0][j][i][TOP_POINT_OF_RECT] = (pixel shr SECOND_UNIT and MASK_VALUE) / halfRange - 1f // red
            }
        }
        Trace.endSection()  // preprocessBitmap

        // Allocate fresh output buffers for this run; the interpreter fills them below.
        Trace.beginSection("feed")
        outputLocations = Array(1) { Array(NUM_RESULTS) { FloatArray(MAXIMUM_RECTANGLE) } }
        outputClasses = Array(1) { Array(NUM_RESULTS) { FloatArray(NUM_CLASSES) } }
        Trace.endSection()  // feed

        // Run the inference call.
        Trace.beginSection("run")
        tfLite?.runForMultipleInputsOutputs(arrayOf<Any>(img),
                                            hashMapOf<Int, Any>(0 to outputLocations,
                                                                1 to outputClasses))  // inputArray, outputMap
        Trace.endSection()  // run
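        // At this point outputLocations holds raw center-size box encodings and
        // outputClasses holds raw per-class logits for each of the NUM_RESULTS anchors.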

        decodeCenterSizeBoxes(outputLocations)
        // Find the best detections.
        val pq = PriorityQueue(1, Comparator<Recognition> { lhs, rhs ->
            // Intentionally reversed to put high confidence at the head of the queue.
            compare(requireNotNull(rhs.confidence), requireNotNull(lhs.confidence))
        })

        // Pick the best class for each anchor and scale its box back to pixel coordinates.
        for (i in 0 until NUM_RESULTS) {
            var topClassScore = TOP_CLASS_SCORE
            var topClassScoreIndex = -1

            // Skip the first catch-all class.
            for (j in 1 until NUM_CLASSES) {
                val score = expit(outputClasses[0][i][j])

                if (score > topClassScore) {
                    topClassScoreIndex = j
                    topClassScore = score
                }
            }

            if (topClassScore > MINIMUM_THRESHOLD) {
                val detection = RectF(outputLocations[0][i][LEFT_POINT_OF_RECT] * inputSize,
                                      outputLocations[0][i][TOP_POINT_OF_RECT] * inputSize,
                                      outputLocations[0][i][RIGHT_POINT_OF_RECT] * inputSize,
                                      outputLocations[0][i][BOTTOM_POINT_OF_RECT] * inputSize)

                // Report the sigmoid score so the confidence matches the value that was
                // thresholded above and that orders the queue, rather than the raw logit.
                pq.add(Recognition(i.toString(),
                                   labels[topClassScoreIndex],
                                   topClassScore,
                                   detection))
            }
        }

        val recognitions = ArrayList<Recognition>()

        // Drain at most LIMITATION detections, highest confidence first.
        repeat(min(pq.size, LIMITATION)) { recognitions.add(pq.poll()) }
        Trace.endSection()  // recognizeImage

        return recognitions
    }

    // No-op: this implementation does not collect run-time stats.
    override fun enableStatLogging(logStats: Boolean) {}

    override fun close() {
        tfLite?.close()
        tfLite = null
    }
}