仕事でコードを書いていて、タイトルのようなケースに遭遇した。
val a, b, c, ... = ??? // 定数 val x = f(a) // なんかを計算したり生成する val y = g(x, b) val z = h(y, c) ... val 最終的に欲しい値 = (x, y, z)
こういうの。
これがもし仮に以下のようなパターンだったら、あーReader Monadですね〜となって抽象化完了する(はず)。
が、今回は共通の値を渡したいわけではなく、直前の値を再利用しつつ、新たな値を計算し、その結果を集めたい。
素朴な解: scanLeft
する
Tuple
のことをいったん忘れ、結果がSeq
で得られても良いのであれば、Catsなどの高級なライブラリを使わずとも、標準でscanLeft
という良いメソッドが用意されている。
val f: Int => Int = _ + 1 val g: Int => Int = _ + 2 val h: Int => Int = _ + 3 Seq(f, g, h).scanLeft(0){ case x -> ff => ff(x) } // => Seq(0, 1, 3, 6)
最初に渡した0が残ってしまうが、まあこれは捨てればよい。
モナディックに
ちょっと発想を変えると、f
, g
, h
を変形して、以下のような形にすることもできる。
Int
を受け取り、(受け取ったIntと計算結果とを結合したもの, 計算結果)
を返す- つまり
Int => (Seq[Int], Int)
これWriterモナドですね。ぜんぜん気付かなかった。「直前の」みたいな用語が出てくる時点でmonadicであることが暗示されていたのにそれに気付けなかった。
Writerモナドは文字列の結合とかログみたいな文脈で登場しがちだけれど、結合していくものが文字列である必要はなくて、Semigroup
でさえあれば動く。Seq[Int]
はSemigroup
のインスタンスなので、Writer[Seq[Int], Int]
を構成できる:
val f: Int => Int = _ + 1 val g: Int => Int = _ + 2 val h: Int => Int = _ + 3 import cats.syntax._ import cats.data.Writer import cats.instances._ val lift: (Int => Int) => Int => Writer[Seq[Int], Int] = f => x => Writer(Seq(f(x)), f(x)) val (mf, mg, mh) = (lift(f), lift(g), lift(h)) mf(0).flatMap(mg).flatMap(mh) // => WriterT(List(1, 3, 6), 6)
がしかし、結構不恰好になってしまった。Writer自体が結果の値ではなく関数を保持するようにしても良いのだが、そうすると
val lift2: (Int => Int) => Writer[Seq[Int], Int => Int] =
f => Writer(Seq(???), f)
という形になってしまい、何も記録できない。Writerを使う場合のメリットは、一度に2つ以上の要素を記録させることができる点だ。
Writerモナドを使うのは初めてなので、もっと筋が良い書き方があるかもしれない。
追記
がくぞさんに、Stateモナドでもいいよと教えてもらった。
前の結果を次で使いたい場合 State でもいいかもです?https://t.co/AZUWsXHvDQ pic.twitter.com/xeXpSA4Gtt
— がくぞ (@gakuzzzz) 2023年3月6日
import cats._ import cats.data._ import cats.syntax.all._ val f: Int => Int = _ + 1 val g: Int => Int = _ + 2 val h: Int => Int = _ + 3 val lift: (Int => Int) => State[Int, Int] = f => State(x => (f(x), f(x))) Seq(f, g, h).traverse(lift).runA(0).value // => Vector(1, 3, 6)
f,g,hはやっていることは同じだが、lift
のシグネチャは(Int => Int) => State[Int, Int]
だ。そして、State
の中身にはx => (f(x), f(x))
が入っている。
次にやっていることがちょっと難しそうだ。
まずSeq(f, g, h)
で f, g, hをまとめている。
次に.traverse(lift)
。traverse
のシグネチャはdef traverse[G[_]: Applicative, A, B](fa: F[A])(f: A => G[B]): G[F[B]]
。メソッドで呼び出されているので実際はfa
にはSeq(f, g, h)
が、f
にはlift
が入っている。
つまり、Seq(f, g, h).traverse(lift)
の型はState[Int, Seq[Int]]
になっている。
そして、runA
だ。runA
は初期状態を与えてStateを動かすが、run
と違って最終状態を捨てて結果だけ返す。
ちなみにrun
するとこのような感じになる:
Seq(f, g, h).traverse(lift).run(0).value // => (6, Vector(1, 3, 6))
そもそもSeq
をState
にtraverse
できることが結構な驚きだった。Seq[State[ほげ, ふが]]
をtraverseすると、順に状態を更新しながら結果を集めるという振舞いになる。
Traverse
は高カインド型が2つも出てくるので、いまだに覚えられていない。