最近のAI技術の発展には目を見張るものがあるが、そんな中でも言語・ライブラリ・フレームワーク間の互換性は今ひとつ進んでいないのが現状で、TensorFlowとPyTorchとscikit-learnとでは別々の形式をモデルとして利用しているし、もちろんPythonで動かすしかないという現状がここ最近まではあった(最近MLをやるようになったので認識が雑だが、最近始めた人からはこう見えているという話)。
そんな中登場したONNX(Open Neural Network Exchange)は、群雄割拠しているDeep Learningモデルの共通規格として最近脚光を浴びつつある。
イケてる。カッコいいね。
また、ONNXは規格だけではなく軽量なランタイムも用意しており、ONNX Runtimeと呼ばれている(VOICEVOXもこれを呼び出している)。そしてこのONNX RuntimeはC APIを経由して提供されるので、FFIが利用可能な言語であれば既存のモデルを走らせて画像分類や音声合成などが行えるというモジュラーな仕組みになっている。
で、ありがたいことにScalaからもONNX Runtimeが呼び出せる仕組みが存在する。それがONNX-Scalaである。
このバインディングとONNX Runtime、そしてSqueezeNetという既存の画像分類モデルを使って、今年の干支であるウサギを正しく判定できるか試してみる*1。
ONNX Runtime
まずは、公式ドキュメントの手順にそってONNX Runtimeをインストールしておく。
ちなみにインストールにはpipが必要である。それでいいのか。
ちなみに音声合成ソフトウェアであるVOICEVOXも音声合成モデルとしてONNXを使っている。
モデル
今回はSqueezeNetという画像分類モデルを使う。ONNX形式にしたときのサイズは5MB程度で、計算量が少なくて成績良好、というのがウリのモデルである。
ONNX版のモデルの詳細は以下の通り。
ONNXのサイトから.onnx
ファイルをダウンロードできるので、手元に落としておく。
curl -O "https://media.githubusercontent.com/media/onnx/models/main/vision/classification/squeezenet/model/squeezenet1.0-12.onnx"
また、分類結果はIMAGENET 1000という分類で出力される。
IMAGENETの何番がどの物体に相当するかは以下のページを見るとよい。
deeplearning.cms.waikato.ac.nz
画像
Wikimediaからかわいいウサちゃんの画像をいくつか用意する。
まず、入力画像のサイズは少なくとも224x224である必要があるため、GIMPで適宜切り出して縮小しておく。
今回は特別なライブラリはあまり使わずにそのまま画像をバイト配列として読み込みたい。今回は画像がRGBの順で1バイトずつ左から右に、右端まで行ったら次の行の左端から・・・というフォーマットで格納するPPM形式で画像を保存する。
NHWCとNCHW
突然だが、画像を読み込むにあたって、気にしなければならないことがある。ピクセルの格納順序である。PPM形式では、画像処理用語で言うところのNHWCという順序でピクセルがシリアライズされている。NHWCでは、左上のピクセルのR,G,B,その右のピクセルのR,G,B, ... 右端のR,G,B, 次の行の左端の...
という順序でバイトが格納される。
一方で、Deep LearningではNCHWという順序でピクセルをメモリに格納するのが一般的なようだ。今回もこの順序での格納が必要だった。すなわち、左上のピクセルのR,その右のピクセルのR,...右端のR, ... 右下のピクセルのR, 左上のピクセルのG, ...
という順序でバイトが格納される。
口で言ってもわかりにくいのでPyTorch Channels Last memory format perf optimization and oneDNN integration plan. · GitHubから画像を引用する。
このような順序になっている。
この順序入れ替えのためのコードがスクリプトの大勢を占めることになった。
Scala Scriptを書く
さていきなりだがScalaのコードを書いていく。ライブラリのREADMEを参考に、ちょっと色々工夫して画像を読み込んだ。
//> using scala "3.2" //> using dep "org.emergent-order::onnx-scala-backends:0.17.0" import java.nio.file.{Files, Paths} import org.emergentorder.onnx.Tensors._ import org.emergentorder.onnx.Tensors.Tensor._ import org.emergentorder.onnx.backends._ import org.emergentorder.compiletime._ import org.emergentorder.io.kjaer.compiletime._ // 非同期実行ランタイムとして必要 import cats.effect.unsafe.implicits.global // ONNXモデルを読み込む val squeezenetBytes = Files.readAllBytes(Paths.get("squeezenet1.0-12.onnx")) val squeezenet = new ORTModelBackend(squeezenetBytes) // 画像ファイルを読み込む val rabbitPpmBytes = Files.readAllBytes(Paths.get("Domestic-rabbit-Lilly-washing-0a-224.ppm")) // PPM形式のヘッダは不要でデータだけ欲しいので落とす(バイナリエディタを見て位置を決定した) val rabbitPpmBytesBody = rabbitPpmBytes.drop(16 * 3 + 13) // PPMはByte配列であり、NHWC順になっているので、floatに変換してからNCHW順に並べ替える。 // Javaはsigned byteでデータを読み取るので0xffでandして符号を外させる(0〜255に変換する) val data = rabbitPpmBytesBody.map(b => (b & 0xff).toFloat ) // 次元入れ替えを行うだけのボイラープレート def transform(a: Array[Float]): Array[Float] = { // サイズと次元の情報 val height: Int = 224 val width: Int = 224 val channel: Int = 3 // NHWC形式からNCHW形式に変換する val nchwTensor: Array[Array[Array[Float]]] = { val c = channel val h = height val w = width val nchwArray: Array[Array[Array[Float]]] = Array.ofDim[Float](c, h, w) val nhwcArray: Array[Array[Array[Float]]] = Array.ofDim[Float](h, w, c) // バイト配列からNHWC形式のテンソルに復元 var idx = 0 for { hh <- 0 until h ww <- 0 until w cc <- 0 until c } { nhwcArray(hh)(ww)(cc) = a(idx) idx += 1 } // NHWC形式からNCHW形式に変換 for { hh <- 0 until h ww <- 0 until w cc <- 0 until c } { nchwArray(cc)(hh)(ww) = nhwcArray(hh)(ww)(cc) } nchwArray } nchwTensor.flatten.flatten } // NHWC形式をNCHW形式に変換する val transformedData = transform(data) // サイズが合致するか確認している println(s"data size: ${transformedData.size}") println(s"required size: ${3 * 224 * 224}") // 入力テンソルのShapeを設定する val shape = 1 #: 3 #: 224 #: 224 #: SNil val tensorShapeDenotation = "Batch" ##: "Channel" ##: "Height" ##: "Width" ##: TSNil val tensorDenotation: String & Singleton = "Image" // float配列とメタデータをもとに入力テンソルを作る val imageTens = Tensor(transformedData, tensorDenotation, tensorShapeDenotation, shape) // モデルに画像を投入する val out = squeezenet.fullModel[ Float, "ImageNetClassification", "Batch" ##: "Class" ##: TSNil, 1 #: 1000 #: 1 #: 1 #: SNil ](Tuple(imageTens)) // 実行結果のShapeを確認 out.shape.unsafeRunSync() // val res0: Array[Int] = Array(1, 1000, 1, 1) // モデルを実行する val calcdata = out.data.unsafeRunSync() // 最も出力が高いラベルを5つ取出す val classified = calcdata.indices.sortBy(calcdata).reverse.take(5) println(classified) println(calcdata.max)
実行結果は以下の通りになった:
data size: 150528 required size: 150528 Vector(330, 331, 332, 434, 174) 0.9106789
最も可能性が高いのは330番、つまりwood rabbit, cottontail, cottontail rabbit
であり、91%の確率でウサギであると無事判定している。
もう一方のウサちゃんでも実行してみる:
data size: 150528 required size: 150528 Vector(331, 330, 24, 348, 349) 0.9466005
最も可能性が高いのは331番、つまりhare
であり、94%の確率でノウサギであると無事判定している。
感想
型の具体的な意味などはこれから調べていくが、何をやっているかの想像は既にだいたいついている。静的にテンソルのサイズを型で表現して決めているな〜とかが見てとれるのが面白かった。実行速度もとても速かった。画像のロードが面倒だったが、たぶん便利なメソッドがどこかに生えているであろう。簡潔な記述で簡潔に動作する良いライブラリだった。Scala 3でも動作するのがお利口さんだ。
ちょっと迷ったのが、画像のピクセルの階調は[0, 1]
なのか[0, 255]
なのかである。実際は[0, 255]
のfloat
を渡してやればよいのだが、0〜1の範囲に正規化しなければならないと思い込んでしまった。(もちろん判定に失敗して、鎖帷子とか意味不明な判定が行なわれた)
今後はLLMみたいな言語モデルや、分類だけではない画像生成モデルでも動作するかを確かめていこうと思う。
*1:5月になって干支の話をするのは完全に間に合っていないのだが、本当は正月のネタに画像分類をやりたかったのだ