Tensorflow Liteを使ってみた(画像分類)

はじめに

Tensorflowで学習させた結果をTensorflow Liteで使えるようにするまでの道のりをメモします。
Tensorflow Liteは公式ドキュメントの情報がイマイチなのですが、サンプルソースが公開されているのでここから必要な部分を切り出しました。
今回は単純な画像分類をやってみました。

前提

Tensorflowでモデルを作成済である事。

モデルのコンバート

TensorflowのモデルをTensorflow Lite用にコンバートします。
公式ドキュメントには「コマンドライン」と「Python」の2つの方法が記載されています。
今回はコマンドラインで実施しました。

tflite_convert --output_file=hoge.tflite --keras_model_file=hoge.h5

Tensorflow Light AARの読み込み

app\build.gradleを修正します。

dependencies {
    implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
}

tfliteの圧縮を無効にする

build.gradleを修正します。

android {
    aaptOptions {
        noCompress "tflite"
    }
}

assetsを作成

assetsフォルダを作成して、その中にコンバートしたモデルを格納します。
labels.txtを公式を参考に作成して格納します。

MainActivity

今回はMainActivityのonCreate内に全部書いてしまいます。
こんな感じになりました。

@Override
protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    setContentView(R.layout.activity_main);

    try {
        classifier = Classifier.create(this);
    } catch (IOException e) {
        System.out.println(e);
    }

    AssetManager as = getResources().getAssets();
    try {
        image = getAssets().open("images/test.gif");
    } catch (IOException e) {
        System.out.println(e);
    }
    Bitmap bitmap= Bitmap.createScaledBitmap(BitmapFactory.decodeStream(image), 100, 100, false);

    List results = classifier.recognizeImage(bitmap);
}

処理の概要

基本的には上記のコードを動かすのに必要な部分を公式サンプルから持ってきました。
処理の概要については公式ドキュメントに記載されています。
ここでは主な変更点をメモします。

公式サンプルのClassifierのcreateを呼んでいます。
サンプルの引数は削除して使いました。
ClassifierのloadModelFile内のgetModelPath()のtffileのファイル名だけは、修正しておく必要があります。

classifier = Classifier.create(this);

今回分類したい画像ファイルを読み込みます。

image = getAssets().open("images/test.gif");

今回分類したい画像ファイルとtffile(学習した画像ファイル)のサイズを合わせています。

Bitmap bitmap= Bitmap.createScaledBitmap(BitmapFactory.decodeStream(image), 100, 100, false);

画像分類を実行します。
成功すれば、resultsに結果が格納されます。

List results = classifier.recognizeImage(bitmap);

備考

分類したい画像ファイルと学習した画像ファイルのサイズが違うと、以下の例外が発生します。
この例外が発生した場合は画像ファイルのサイズを合わせてあげましょう。

java.lang.IllegalArgumentException: Cannot convert between a TensorFlowLite buffer with xxxxx bytes and a ByteBuffer with xxxxx bytes.

感想

思っていたより公式サンプルが複雑でした。
モデルの読み込みなんか、Kerasみたいに1行で書けるようにして欲しいなと思いました。

model = keras.models.load_model('hoge.h5')