Lambdaカクテル

京都在住Webエンジニアの日記です

Invite link for Scalaわいわいランド

CatsのfoldM(foldLeftM)を勉強した

foldMについてちょっと勉強したのでメモ。

文脈

このへんの会話が流れてきたけど、そういえばfoldM使ったことなかったな、と思った。

とりいそぎ、foldMについて勉強してみるか。

おさらい: foldLeft

foldMは正確?にはfoldLeftMなのだけれど、いったんfoldLeftについても復習しておこう。

www.scala-lang.org

def foldLeft[B](z: B)(op: (B, A) ⇒ B): B

あるSeq[A]に対して、初期値B(B, A) => Bとなる関数を渡すと、Seq[A]は畳み込まれてBになる。

例えばsumを以下のようにして実装できる:

val xs: Seq[Int] = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9)

xs.foldLeft(0)(_ + _)
xs.reduce(_ + _) // same

しかしこれだと型はSeq[Int]からIntになるだけなのであまり変わり映えしない。同じことがreduceで可能だ。

より面白い例として、リストを逆転させることができる:

val xs: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9)

xs.foldLeft[List[Int]](Nil)((lis, x) => x * 2 :: lis) // => List(18, 16, 14, 12, 10, 8, 6, 4, 2)

畳み込みというと語弊があるが、左からリストを走査していくのでそれを元に値を組み立てているのだ。reduce(A, A) => Aの関数しか扱えないので、こういった処理はreduceではできない。

foldLeft自体の解説は色々あるので参考にしてほしい。

dev.classmethod.jp

foldM

さて、おなじみの関数型ライブラリCatsはコレクションにfoldMを導入する:

https://typelevel.org/cats/api/cats/Foldable.html

型シグネチャを見てみよう:

def foldM[G[_], A, B](fa: F[A], z: B)(f: (B, A) ⇒ G[B])(implicit G: Monad[G]): G[B]

実際にこのコードを使うときはコレクションの拡張メソッドとして呼び出すので、実際は以下のような形になる:

def foldM[G[_], B](z: B)(f: (B, A) ⇒ G[B])(implicit G: Monad[G]): G[B]
// 参考
def foldLeft[B](z: B)(op: (B, A) ⇒ B): B

すると、差分となるのは関数の部分がBの代わりにG[B]を返すこと、そしてG[_]はモナドでなければならないという2点のみだ。

ちなみにfoldMMはモナドのMだ。同じく、語尾にAがついているメソッドがあるとしたら、そいつはアプリカティブ版だ。

foldMの振る舞い

foldMは、foldと同様に畳み込みを行うが、その過程でモナディックな合成を行ってくれる。まずは何もしないモナドであるIdfoldMに渡してみよう:

import cats.implicits.{*, given}
import cats.Id

Seq(1, 2, 3).foldM[Id, Int](0)((a, x) => a + x) // => 6: Int

IdMonadのインスタンスだが、実際はId[A]は単なるAのラッパーとして振る舞い、何も特殊な振る舞いをしない。したがってこれは普通のfoldLeftを実行しているのと変わりない。実行結果は全ての値を足した6になった。

さて、今度は本物のモナドを使ってみよう。こういうときはOptionを使うと良さそうだ:

Seq(1, 2, 3).foldM[Option, Int](0)((a, x) => Some(a + x))
// => Some(6)

実行結果はSome(6)だ。foldMOptionを適用すると、Some(_)である間だけ畳み込みを続行する。常時Someを返すようにしているので、先程のバージョンとあまり変わりはない。

もし計算結果が10を超えたらNoneを返すようにするとどうだろう:

def mySum(xs: Seq[Int]) = xs.foldM[Option, Int](0)((a, x) =>
  (a + x) match {
    case n if n > 10 => None
    case n           => Some(n)
  }
)

mySum(Seq(1, 2, 3)) // => Some(6)
mySum(Seq(1, 2, 3, 4)) // => Some(10)
mySum(Seq(1, 2, 3, 4, 5)) // => None

Optionの特性が発揮され、10を超えたタイミングで計算が中断した。これは畳み込み処理を強制脱出させるのに便利だ。

発展例: ルールエンジン

以下のようなPersonデータがあるとする。

import cats.implicits.{*, given}

case class Person(name: String, age: Short, employed: Boolean, living: String)

val x = Person("windymelt", 30, employed = true, "Kyoto")
val y = Person("foo", 24, employed = false, "Kyoto")
val z = Person("bar", 32, employed = true, "Hokkaido")
val a = Person("buzz", 15, employed = false, "Kyoto")

Personをあらかじめ定められたルールに適合するかどうか判定できるだろうか。ルールは以下のように定義する:

type Rule = Person => Either[String, Person]

val adult: Rule = p => Right(p).filterOrElse(_.age >= 20, "Should be an adult")
val free: Rule = p => Right(p).filterOrElse(!_.employed, "Should not be employed")
val kyoto: Rule = p => Right(p).filterOrElse(_.living == "Kyoto", "Should be living in Kyoto")

例えば、アルコールを購入するためには大人である必要がある:

// rules may be added
def validateAlcohol0(
    p: Person,
): Either[String, Person] = for {
  rp <- Right(p)
  ad <- adult(rp)
} yield ad

ルールが1つだけならばadultに渡すだけで良いが、今後ルールが追加されることを考えると、forを使っておいたほうが無難だ。

数年後、法律が厳しくなって身分証の提示が必須になったとする。これをコードに反映させよう:

def checkId(p: Person): Option[String] = Some("123456") // stub
val shouldHaveId: Rule = p => Right(p).filterOrElse(checkId(_).isDefined, "Should have ID card")
def validateAlcohol1(
    p: Person,
): Either[String, Person] = for {
  rp <- Right(p)
  ad <- adult(rp)
  id <- shouldHaveId(ad)
} yield id

validateAlcohol1(x) // => Right(x)

検証を実行すると2つのルールが検証されることがわかる。この調子で今度は雇用のためのルールを定義しよう。成人で現在就職しておらず、京都に住んでいる人が欲しいとする:

def validateHiring0(p: Person): Either[String, Person] = for {
  rp <- Right(p)
  ad <- adult(rp)
  fr <- free(ad)
  ky <- kyoto(fr)
} yield ky

forを使うとこのような実装になるはずだ。

ここで1つ問題がある。ルールを増減させた別のルールを定義するには、コードをほぼ複製してコンパイルし直さなければならないのだ。また、どう見てもボイラープレートにしかなっていない変数があり、冗長だ。コレクションにルールを集めて、うまく使えないだろうか?

foldMの出番

さて、ここからがfoldMの使い所になる。foldMfor文によるモナドの積み重ねをコレクションに押し込めると考えることができる。

// ルールのコレクションを定義する
val alcoholRules = Seq(adult)
// foldMでそれを全て適用する
def validateAlcohol(
    p: Person,
): Either[String, Person] = alcoholRules.foldLeftM(p)((p, f) => f(p))

val hiringRules = Seq(adult, free, kyoto)
def validateHiring(
    p: Person,
): Either[String, Person] = hiringRules.foldLeftM(p)((p, f) => f(p)) // for comprehensionをfoldの形にできる

これにより、動的にルールの数を増減できるようになった。例えば、夜間に限って身分証を出さないとアルコールを買うことができない、といった動的な定義も可能になる。

def isNight: Boolean = ???

val alcoholRules = Seq(adult)
val alcoholRulesNight = Seq(adult, shouldHaveId)

def validateAlcohol(
    p: Person,
): Either[String, Person] = {
  val rs = if (isNight) alcoholRulesNight else alcoholRules
  rs.foldLeftM(p)((p, f) => f(p))
}

まとめ

  • foldMfoldLeftのモナディック版だ。
  • モナディック版であるということは、畳み込む過程でモナドの合成が同時に行なわれるということを意味する。
  • foldMを使うことで、for式を使ったルールの積み重ねを動的に書けるようになる可能性がある。
★記事をRTしてもらえると喜びます
Webアプリケーション開発関連の記事を投稿しています.読者になってみませんか?