Неточный прогноз в пользовательской модели MLKit

Попытка использовать переобученную модель MobileNet для прогнозирования пород собак, но при использовании модели через Firebase MLKit она не может правильно предсказать породу собак. Настольная модель и модель tflite способны правильно предсказать породу, но с использованием одного и того же изображения мопса, настольная модель и модель tflite (на рабочем столе) на 87,8% уверены, что это мопс; тогда как на MLKit достоверность составляет 1,47x10-2%.

Я подозреваю, что проблема в моей предварительной обработке изображения в коде приложения. В документах показано, как масштабировать пиксели в диапазон -1,0, 1,0; что, согласно коду для функции предварительной обработки изображений keras, является тем, что требуется.

Вот моя функция infer(iStream), в которой, я думаю, может лежать ошибка. Любая помощь приветствуется, это сводит меня с ума.

private fun infer(iStream: InputStream?) {
    Log.d("ML_TAG", "infer")
    val bmp = Bitmap.createScaledBitmap(BitmapFactory.decodeStream(iStream), 224, 224, true)
    i.setImageBitmap(bmp)
    val bNum = 0
    val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
    for (x in 0..223) {
        for (y in 0..223) {
            val px = bmp.getPixel(x, y)
            input[bNum][x][y][0] = (Color.red(px) - 127) / 255.0f
            input[bNum][x][y][1] = (Color.green(px) - 127) / 255.0f
            input[bNum][x][y][2] = (Color.blue(px) - 127) / 255.0f
        }
    }

    val inputs = FirebaseModelInputs.Builder()
        .add(input)
        .build()

    interpreter.run(inputs, ioOpts).addOnSuccessListener { res ->
        val o = res.getOutput<kotlin.Array<FloatArray>>(0)
        val prob = o[0]

        val r = BufferedReader(InputStreamReader(assets.open("retrained_labels.txt")))
        val arrToSort = arrayListOf<Pair<String, Float>>()
        val rArr = r.readLines()
        for (i in prob.indices) {
            val p = Pair(rArr[i], prob[i])
            arrToSort.add(p)
        }
        val sortedList = arrToSort.sortedWith(compareByDescending {it.second})
        val topFive = sortedList.slice(0..4)
        arrToSort.forEach {
            if (it.first == "pug") {
                Log.i("ML_TAG", "Pug: ${it.second}")
            }
        }
        sortedList.forEach {
            if(it.first == "pug") {
                Log.i("ML_TAG", "Pug: ${it.second}")
            }
        }
        topFive.forEach {
            Log.i("ML_TAG", "${it.first}: ${it.second}")
        }
    }
        .addOnFailureListener { res ->
            Log.e("ML_TAG", res.message)
        }
}

person Joshua Feltimo    schedule 27.02.2019    source источник
comment
В каком формате ваша модель принимает входные данные? Шаги предварительной обработки зависят от того, является ли это int8/uint8 или float.   -  person Sachin Joglekar    schedule 01.03.2019
comment
@SachinJoglekar float32   -  person Joshua Feltimo    schedule 03.03.2019


Ответы (1)


Я думаю, что (Color.red(px) - 127) / 255.0f масштабируется до [-0,5, 0,5]. (Color.red(px) - 127) / 128.0f дает лучшие результаты?

person Kevin Cheung    schedule 20.03.2019