Lambdaカクテル

京都在住Webエンジニアの日記です

Invite link for Scalaわいわいランド

Scala 3でONNX Runtimeを走らせ、ウサギの画像分類タスクを実行させた

最近のAI技術の発展には目を見張るものがあるが、そんな中でも言語・ライブラリ・フレームワーク間の互換性は今ひとつ進んでいないのが現状で、TensorFlowとPyTorchとscikit-learnとでは別々の形式をモデルとして利用しているし、もちろんPythonで動かすしかないという現状がここ最近まではあった(最近MLをやるようになったので認識が雑だが、最近始めた人からはこう見えているという話)。

そんな中登場したONNX(Open Neural Network Exchange)は、群雄割拠しているDeep Learningモデルの共通規格として最近脚光を浴びつつある。

onnx.ai

ja.wikipedia.org

イケてる。カッコいいね。

また、ONNXは規格だけではなく軽量なランタイムも用意しており、ONNX Runtimeと呼ばれている(VOICEVOXもこれを呼び出している)。そしてこのONNX RuntimeはC APIを経由して提供されるので、FFIが利用可能な言語であれば既存のモデルを走らせて画像分類や音声合成などが行えるというモジュラーな仕組みになっている。

で、ありがたいことにScalaからもONNX Runtimeが呼び出せる仕組みが存在する。それがONNX-Scalaである。

github.com

このバインディングとONNX Runtime、そしてSqueezeNetという既存の画像分類モデルを使って、今年の干支であるウサギを正しく判定できるか試してみる*1

ONNX Runtime

まずは、公式ドキュメントの手順にそってONNX Runtimeをインストールしておく。

onnxruntime.ai

ちなみにインストールにはpipが必要である。それでいいのか。

ちなみに音声合成ソフトウェアであるVOICEVOXも音声合成モデルとしてONNXを使っている。

github.com

モデル

今回はSqueezeNetという画像分類モデルを使う。ONNX形式にしたときのサイズは5MB程度で、計算量が少なくて成績良好、というのがウリのモデルである。

en.wikipedia.org

ONNX版のモデルの詳細は以下の通り。

github.com

ONNXのサイトから.onnxファイルをダウンロードできるので、手元に落としておく。

curl -O "https://media.githubusercontent.com/media/onnx/models/main/vision/classification/squeezenet/model/squeezenet1.0-12.onnx"

また、分類結果はIMAGENET 1000という分類で出力される。

IMAGENETの何番がどの物体に相当するかは以下のページを見るとよい。

deeplearning.cms.waikato.ac.nz

画像

Wikimediaからかわいいウサちゃんの画像をいくつか用意する。

commons.wikimedia.org

commons.wikimedia.org

まず、入力画像のサイズは少なくとも224x224である必要があるため、GIMPで適宜切り出して縮小しておく。

ウサちゃん

ウサちゃん(2)

今回は特別なライブラリはあまり使わずにそのまま画像をバイト配列として読み込みたい。今回は画像がRGBの順で1バイトずつ左から右に、右端まで行ったら次の行の左端から・・・というフォーマットで格納するPPM形式で画像を保存する。

ja.wikipedia.org

NHWCとNCHW

突然だが、画像を読み込むにあたって、気にしなければならないことがある。ピクセルの格納順序である。PPM形式では、画像処理用語で言うところのNHWCという順序でピクセルがシリアライズされている。NHWCでは、左上のピクセルのR,G,B,その右のピクセルのR,G,B, ... 右端のR,G,B, 次の行の左端の...という順序でバイトが格納される。

一方で、Deep LearningではNCHWという順序でピクセルをメモリに格納するのが一般的なようだ。今回もこの順序での格納が必要だった。すなわち、左上のピクセルのR,その右のピクセルのR,...右端のR, ... 右下のピクセルのR, 左上のピクセルのG, ...という順序でバイトが格納される。

口で言ってもわかりにくいのでPyTorch Channels Last memory format perf optimization and oneDNN integration plan. · GitHubから画像を引用する。

このような順序になっている。

この順序入れ替えのためのコードがスクリプトの大勢を占めることになった。

Scala Scriptを書く

さていきなりだがScalaのコードを書いていく。ライブラリのREADMEを参考に、ちょっと色々工夫して画像を読み込んだ。

//> using scala "3.2"
//> using dep "org.emergent-order::onnx-scala-backends:0.17.0"

import java.nio.file.{Files, Paths}
import org.emergentorder.onnx.Tensors._
import org.emergentorder.onnx.Tensors.Tensor._
import org.emergentorder.onnx.backends._
import org.emergentorder.compiletime._
import org.emergentorder.io.kjaer.compiletime._

// 非同期実行ランタイムとして必要
import cats.effect.unsafe.implicits.global

// ONNXモデルを読み込む
val squeezenetBytes = Files.readAllBytes(Paths.get("squeezenet1.0-12.onnx"))
val squeezenet = new ORTModelBackend(squeezenetBytes)

// 画像ファイルを読み込む
val rabbitPpmBytes =
  Files.readAllBytes(Paths.get("Domestic-rabbit-Lilly-washing-0a-224.ppm"))

// PPM形式のヘッダは不要でデータだけ欲しいので落とす(バイナリエディタを見て位置を決定した)
val rabbitPpmBytesBody = rabbitPpmBytes.drop(16 * 3 + 13)

// PPMはByte配列であり、NHWC順になっているので、floatに変換してからNCHW順に並べ替える。

// Javaはsigned byteでデータを読み取るので0xffでandして符号を外させる(0〜255に変換する)
val data = rabbitPpmBytesBody.map(b =>
  (b & 0xff).toFloat
)

// 次元入れ替えを行うだけのボイラープレート
def transform(a: Array[Float]): Array[Float] = {
// サイズと次元の情報
  val height: Int = 224
  val width: Int = 224
  val channel: Int = 3

// NHWC形式からNCHW形式に変換する
  val nchwTensor: Array[Array[Array[Float]]] = {
    val c = channel
    val h = height
    val w = width

    val nchwArray: Array[Array[Array[Float]]] = Array.ofDim[Float](c, h, w)
    val nhwcArray: Array[Array[Array[Float]]] = Array.ofDim[Float](h, w, c)

    // バイト配列からNHWC形式のテンソルに復元
    var idx = 0
    for {
      hh <- 0 until h
      ww <- 0 until w
      cc <- 0 until c
    } {
      nhwcArray(hh)(ww)(cc) = a(idx)
      idx += 1
    }

    // NHWC形式からNCHW形式に変換
    for {
      hh <- 0 until h
      ww <- 0 until w
      cc <- 0 until c
    } {
      nchwArray(cc)(hh)(ww) = nhwcArray(hh)(ww)(cc)
    }

    nchwArray
  }
  nchwTensor.flatten.flatten
}

// NHWC形式をNCHW形式に変換する
val transformedData = transform(data)

// サイズが合致するか確認している
println(s"data size: ${transformedData.size}")
println(s"required size: ${3 * 224 * 224}")

// 入力テンソルのShapeを設定する
val shape = 1 #: 3 #: 224 #: 224 #: SNil
val tensorShapeDenotation =
  "Batch" ##: "Channel" ##: "Height" ##: "Width" ##: TSNil

val tensorDenotation: String & Singleton = "Image"

// float配列とメタデータをもとに入力テンソルを作る
val imageTens =
  Tensor(transformedData, tensorDenotation, tensorShapeDenotation, shape)

// モデルに画像を投入する
val out = squeezenet.fullModel[
  Float,
  "ImageNetClassification",
  "Batch" ##: "Class" ##: TSNil,
  1 #: 1000 #: 1 #: 1 #: SNil
](Tuple(imageTens))

// 実行結果のShapeを確認
out.shape.unsafeRunSync()
// val res0: Array[Int] = Array(1, 1000, 1, 1)

// モデルを実行する
val calcdata = out.data.unsafeRunSync()

// 最も出力が高いラベルを5つ取出す
val classified = calcdata.indices.sortBy(calcdata).reverse.take(5)
println(classified)
println(calcdata.max)

実行結果は以下の通りになった:

data size: 150528
required size: 150528
Vector(330, 331, 332, 434, 174)
0.9106789

最も可能性が高いのは330番、つまりwood rabbit, cottontail, cottontail rabbitであり、91%の確率でウサギであると無事判定している。

wood rabbit, cottontail, cottontail rabbit

もう一方のウサちゃんでも実行してみる:

data size: 150528
required size: 150528
Vector(331, 330, 24, 348, 349)
0.9466005

最も可能性が高いのは331番、つまりhareであり、94%の確率でノウサギであると無事判定している。

hare

感想

型の具体的な意味などはこれから調べていくが、何をやっているかの想像は既にだいたいついている。静的にテンソルのサイズを型で表現して決めているな〜とかが見てとれるのが面白かった。実行速度もとても速かった。画像のロードが面倒だったが、たぶん便利なメソッドがどこかに生えているであろう。簡潔な記述で簡潔に動作する良いライブラリだった。Scala 3でも動作するのがお利口さんだ。

ちょっと迷ったのが、画像のピクセルの階調は[0, 1]なのか[0, 255]なのかである。実際は[0, 255]floatを渡してやればよいのだが、0〜1の範囲に正規化しなければならないと思い込んでしまった。(もちろん判定に失敗して、鎖帷子とか意味不明な判定が行なわれた)

今後はLLMみたいな言語モデルや、分類だけではない画像生成モデルでも動作するかを確かめていこうと思う。

*1:5月になって干支の話をするのは完全に間に合っていないのだが、本当は正月のネタに画像分類をやりたかったのだ

★記事をRTしてもらえると喜びます
Webアプリケーション開発関連の記事を投稿しています.読者になってみませんか?