Lambdaカクテル

京都在住Webエンジニアの日記です

Invite link for Scalaわいわいランド

Tensorflo兎 Scalaで遊んだ記録その4(XORを学習させる)

あけましておめでとうございます。2023年は兎年ですね。

そして、兎といえばTensorFlo兎(テンソルフロウ)。今回もTensorflow Scalaをやっていきます。

XORをニューラルネットワークで実装する

platanios.org

公式ページのサンプルを見ながら、

techblog.glpgs.com

この記事を見つつ、Tensorflow ScalaでXORを実装してみました。いきなりソースコードを示します。

import org.platanios.tensorflow
import java.nio.file.Paths

  def learn: Unit = {
    import tensorflow.api._
    import tensorflow.api.learn.Model
    import tensorflow.api.tensors.Tensor

    // 教師データをテンソルで直接定義する。
    val trainDS = Tensor[Float](Seq(0, 0), Seq(0, 1), Seq(1, 0), Seq(1, 1))
    println(trainDS.summarize())

    // 教師データのラベル(XORした結果)。ニューラルネットワークのShapeの都合でtransposeして形を合わせている
    val trainLabels = Tensor[Float](Seq(0, 1, 1, 0)).transpose()
    println(trainLabels.summarize())

    import tensorflow.api.learn.layers._

    // 入力。
    val input = Input(FLOAT32, Shape(4, 2))

    // 教師データの入力もInputとして扱う。
    val trainInput = Input(FLOAT32, Shape(4, 1))

    // 層を定義する。まず幅2の線形レイヤ、次にReLUレイヤ、次に1つの線形レイヤで出力する。
    val layer = Linear[Float]("inputLinear", 2, useBias = true) >> ReLU[Float]("hidden ReLU") >> Linear[Float]("outputLinear", 1)

    // 学習で誤差をフィードバックするための差分を定義する、損失関数を定義する。
    val loss = L2Loss[Float, Float]("l2loss") >> Mean("loss/mean") >> ScalarSummary(name = "Loss", tag = "Loss") // 本当はMSEが欲しいんだけど・・・

    // 損失をどう最適化するかを決定する最適化アルゴリズム。ここではAdaGrad法とした。
    val optimizer = tensorflow.api.ops.training.optimizers.AdaGrad(0.01f)

    // 入力、教師入力、層、損失関数、最適化関数の5つ組でシンプルな教師あり学習モデルを定義する。
    val model = Model.simpleSupervised(input, trainInput, layer, loss, optimizer)

    // さて、学習途上のデータをセーブしつつ学習するといろいろ都合が良いのでパスを定義する。
    val summariesDir = Paths.get("/tmp/summaries")
    // 最終的な学習を担うEstimatorを定義する。
    val estimator = tensorflow.api.learn.estimators.InMemoryEstimator(
      model,
      // checkpointをここに保存せよという指定。
      configurationBase = tensorflow.api.learn.Configuration(Some(summariesDir)),
      // 学習中に定期的に統計情報とチェックポイントを保存させるための設定。
      trainHooks = Set(
        tensorflow.api.learn.hooks.SummarySaver(summariesDir, tensorflow.api.learn.hooks.StepHookTrigger(100)),
        tensorflow.api.learn.hooks.CheckpointSaver(summariesDir, tensorflow.api.learn.hooks.StepHookTrigger(1000))
      ),
      // TensorBoardというダッシュボードをHTTP経由で表示させるための設定。
      tensorBoardConfig = tensorflow.api.config.TensorBoardConfig(summariesDir),
    )

    // 最初に定義したデータをTensorからDataSetに変換する。
    val trainDataSet = tensorflow.api.ops.data.Data.datasetFromTensors(trainDS)
    val trainLabelsDataSet = tensorflow.api.ops.data.Data.datasetFromTensors(trainLabels)
    val trainData = trainDataSet.zip(trainLabelsDataSet).repeat().shuffle(10000).prefetch(10)

    // 学習データを用いてEstimatorに学習させる。最大学習回数を50万回とする。
    estimator.train(() => trainData, tensorflow.api.learn.StopCriteria(maxSteps = Some(500000L)))
  }

するとこんな感じで学習状況が http://localhost:6006/ で確認できる。

どんどんエラーが小さくなっているのがわかる。

この調子で今年もモリモリ機械学習をやっていきたい。今回はここまで。

感想

型があるので入力補完などが強力に働いて作業しやすかった。30分くらいで簡単なXORが実装できた。

とはいえ、型があってもランタイムに発生するエラー(入力層と中間層のShapeが合ってないよ〜といったエラー)があり、そこはコンパイル時に分かってほしいな〜と思った。

基本的にC++のTensorFlowをJNIで呼び出すというライブラリなので、C++のTensorFlowのドキュメントなどを参考にすればよいことがわかった。

★記事をRTしてもらえると喜びます
Webアプリケーション開発関連の記事を投稿しています.読者になってみませんか?