Lambdaカクテル

京都在住Webエンジニアの日記です

Invite link for Scalaわいわいランド

Cyclic Barrierで安全なじゃんけんを実装する feat. Scala + Cats Effect

じゃんけんという遊びがある。

じゃんけんは、離散的に定義された三つの手(グー・チョキ・パー)の非推移的な優劣関係――グーはチョキに勝ち、チョキはパーに勝ち、パーはグーに勝つ――を用いて勝敗を決定する二人以上参加可能な競技的ゲームです。遊戯手順は、参加者が向き合って「じゃんけん」の掛け声とともに拳を振り下ろし、合図(「ぽん」「ほい」など)で選択した手を同時に提示し、その瞬間に優劣規則を適用して勝者・敗者・あいこ(同手による引き分け)を判定し、あいこの場合は同一手順を繰り返すだけという、実装・解析が容易で確率論やゲーム理論の導入例としても活用される簡潔なゲームです。 -- ChatGPT o3

この知的なゲームには問題がある。「じゃん」「けん」「ぽん」のタイミングで各プレイヤーが同期しなければならないのだ。同期せずに手を出すことは重大なルール違反だ。

Scalaの並行プログラミング用のライブラリであるCats Effectにも同期を制御するためのプリミティブが沢山用意されている。特にその中でもCyclicBarrierはこの目的にぴったりだ。 この記事ではCyclicBarrierを利用して複数プレイヤーが同期してじゃんけんを行えるようにする処理を実装しよう。

環境

この記事ではScala 3.7、Cats Effect 3.6.1を使っているものとする。ソースコードの冒頭には以下の記述があるものとする:

//> using scala 3.7
//> using deps "org.typelevel::cats-effect:3.6.1"

また、プログラムの実行にはScala CLIを利用する:

% scala-cli code.scala.sc

じゃんけんプレイヤーの挙動

じゃんけんは以下のようなタイミングで実行されるものとする:

  • (ランダムな時間待つ)
  • 「じゃん」
  • (全員が出すまで同期する)
  • (ランダムな時間待つ)
  • 「けん」
  • (全員が出すまで同期する)
  • (ランダムな時間待つ)
  • 「ぽん」
  • (全員が出すまで同期する)
  • (手を出す)

このとき、手を出したタイミングが50ms以上ずれた場合はじゃんけんが失敗するものとしよう。

じゃんけんプレイヤーの素朴な実装

まずは同期を何も行わないバージョンを用意しよう。

import cats.effect.std.Random
import cats.effect.*
import scala.concurrent.duration.*

def janken(): IO[Unit] = {
  for {
    n <- Random[IO].betweenInt(100, 1000)
    _ <- IO.sleep(n.milliseconds)
    _ <- IO.println(s"じゃん")
    m <- Random[IO].betweenInt(100, 1000)
    _ <- IO.sleep(m.milliseconds)
    _ <- IO.println(s"けん")
    l <- Random[IO].betweenInt(100, 1000)
    _ <- IO.sleep(l.milliseconds)
    _ <- IO.println(s"ぽん")
    hand <- Random[IO].oneOf('✊', '✌', '✋')
    _ <- IO.println(hand)
  } yield ()
}

Cats Effectでは、非同期タスクはIO[結果型]という型で扱う。Scalaの標準にあるFutureと違って、即時に実行されないという違いがある。

この実装は以下のようにして実行できる。

import cats.effect.unsafe.implicits.global // 非同期タスク実行用のランタイム

def run = janken() // 後から複数人に拡張する

run.unsafeRunSync() // 実行
% scala-cli code.scala.sc
じゃん
けん
ぽん
✋

パーですね。

複数人で動作させる

Cats Effectでは、IOのリストを同時並行的に実行するためのparSequenceというメソッドが利用できる。List[IO[A]]に対してparSequenceすると、IO[List[A]]になる。つまり、「非同期タスクのリスト」から「リストを計算する非同期タスク」にまとめてくれるのだ。JavaScriptで言うところのPromise.allみたいなやつだ。

ちなみにparSequenceを呼んでも実行はされない。あくまで非同期タスクをまとめて別の非同期タスクにしてくれる、というやつで、実行するかどうかはこっちに任せてくれる。

run を以下のように書き換えよう:

def run = List(janken(), janken(), janken()).parSequence
% scala-cli code.scala.sc
じゃん
じゃん
けん
じゃん
けん
けん
ぽん
✋
ぽん
✊
ぽん
✊

当然のことながら、非同期じゃんけんが行われてしまう。

手を出すタイミングを監視する

手を出すタイミングをちゃんと記録し、50ms以上のずれがあった場合はこれを知ることができるようにしよう。

jankenを書き換えて、出した手とその時刻を記録するようにしよう:

def janken(): IO[(Char, FiniteDuration)] = {
  for {
    // ...
    hand <- Random[IO].oneOf('✊', '✌', '✋')
    now <- Clock[IO].realTime
    _ <- IO.println(hand)
  } yield (hand, now)
}

次に、それぞれの手が出たタイミングが一定以内に収まることを確認する関数を定義しよう:

def validate(hands: Seq[(Char, FiniteDuration)]): Either["bad", "ok"] = {
  import cats.syntax.apply.*

  val times = hands.map(_._2)
  def withinThreshold(t1: FiniteDuration)(t2: FiniteDuration): Boolean = {
    val diff = if t1 < t2 then t2 - t1 else t1 - t2
    diff < 50.milliseconds
  }

  val results = List(withinThreshold) <*> times <*> times

  results.reduce(_ && _) match
    case true  => Right("ok")
    case false => Left("bad")
}

runも以下のように結果を表示するようにしよう:

def run = for {
  list <- List(janken(), janken(), janken()).parSequence
  _ <- IO.println(list)
  _ <- IO.println(validate(list))
} yield ()

するとどうだろう:

% scala-cli code.scala.sc
じゃん
じゃん
けん
じゃん
けん
けん
ぽん
✊
ぽん
✋
ぽん
✊
List((✊,1750435469729381 microseconds), (✋,1750435469607853 microseconds), (✊,1750435468856257 microseconds))
Left(bad)

ダメですね。

CyclicBarrier

さて、ここでCyclicBarrierが登場する。サイクリックバリアとは、複数の並行実行している処理が一時的に同期するためのプリミティブで、利用者は「待機」のみの操作が可能だ。あらかじめ指定した数だけ待機が揃えば、全ての待機状態が解放されて処理が進む。いちど解放されたバリアはまた何度も待機に入ることができる。このためCyclicと呼ばれている。工場なんかで両手を使わなければ動かないプレス機の安全装置なんかは待機数2のサイクリックバリアだと言ってもよいかもしれない。

Cats Effectでサイクリックバリアを作るには、cats.effect.std.CyclicBarrierを利用する:

val cb: IO[CyclicBarrier[IO]] = CyclicBarrier[IO](3)

サイクリックバリア自体が状態を持つプリミティブなので、CyclicBarrierを作る操作自体も非同期タスクとしてIO扱いになる。

使うときは以下のようにfor中で使うことになるだろう:

for {
  cb <- CyclicBarrier[IO](3)
  // ...
  _ <- cb.await // 1しか埋まらないので永遠に止まり続ける
} yield ()

この例では特に並行処理せずにawaitを呼び出しているので、永遠に数が埋まらずに待ち続けることになる。awaitは「バリアカウントを1増やして待機」するための操作で、CyclicBarrierのほぼ唯一の操作だ。awaitのシグネチャはIO[Unit]だ。「待つ」ことも非同期タスクだからね。

CyclicBarrier 導入

さて、jankenCyclicBarrierを受け取って同期を取るようにしよう。「じゃん」「けん」「ぽん」のタイミングで同期を取る必要があるので、CyclicBarrierも3つ必要だ:

def janken(
    jan: CyclicBarrier[IO],
    ken: CyclicBarrier[IO],
    pon: CyclicBarrier[IO]
): IO[(Char, FiniteDuration)] = {
  for {
    n <- Random[IO].betweenInt(100, 1000)
    _ <- IO.sleep(n.milliseconds)
    _ <- IO.println(s"じゃん")
    _ <- jan.await // 追加
    m <- Random[IO].betweenInt(100, 1000)
    _ <- IO.sleep(m.milliseconds)
    _ <- IO.println(s"けん")
    _ <- ken.await // 追加
    l <- Random[IO].betweenInt(100, 1000)
    _ <- IO.sleep(l.milliseconds)
    _ <- IO.println(s"ぽん")
    _ <- pon.await // 追加
    hand <- Random[IO].oneOf('✊', '✌', '✋')
    now <- Clock[IO].realTime
    _ <- IO.println(hand)
  } yield (hand, now)

なにやら大仰になってきたが、やっていることは「じゃん」などと言った後のタイミングで同期させているだけだ。

CycilcBarrierを作って渡す

後は呼び出す側でCyclicBarrierを作って渡すだけだ:

def run = for {
  jan <- CyclicBarrier[IO](3)
  ken <- CyclicBarrier[IO](3)
  pon <- CyclicBarrier[IO](3)
  list <- List(
    janken(jan, ken, pon),
    janken(jan, ken, pon),
    janken(jan, ken, pon)
  ).parSequence
  _ <- IO.println(list)
  _ <- IO.println(validate(list))
} yield ()
% scala-cli code.scala.sc
じゃん
じゃん
じゃん
けん
けん
けん
ぽん
ぽん
ぽん
✊
✊
✌
List((✊,1750436411931683 microseconds), (✌,1750436411933206 microseconds), (✊,1750436411931549 microseconds))
Right(ok)

おお

ちょっとリファクタする

参加者数がハードコードされた感じになっているのでちょっとリファクタする:

val playerCount = 3

def run = for {
  jan <- CyclicBarrier[IO](playerCount)
  ken <- CyclicBarrier[IO](playerCount)
  pon <- CyclicBarrier[IO](playerCount)
  list <- List.fill(playerCount)(janken(jan, ken, pon)).parSequence
  _ <- IO.println(list)
  _ <- IO.println(validate(list))
} yield ()

勝敗も判定させる

せっかくなので勝敗も判定させたい。まずは勝敗判定用のロジックを用意しよう:

def winner(hands: Seq[Char]): Option[Char] = {
  val handKinds = hands.toSet
  if handKinds.size != 2 then return None // あいこ
  else
    handKinds.toSeq.sorted match
      case Seq('✊', '✌') => Some('✊')
      case Seq('✋', '✌') => Some('✌')
      case Seq('✊', '✋') => Some('✋')
}

あとは、勝敗が決定するまで繰り返すだけだ。

まずはrunが勝敗を返すようにしよう:

def run = for {
  jan <- CyclicBarrier[IO](playerCount)
  ken <- CyclicBarrier[IO](playerCount)
  pon <- CyclicBarrier[IO](playerCount)
  list <- List.fill(playerCount)(janken(jan, ken, pon)).parSequence
  _ <- IO.println(list)
  _ <- IO.println(validate(list))
  result = winner(list.map(_._1)) // 修正
  _ <- IO.println(result) // 修正
} yield result // 修正

Someになるまで実行し続けるのは、Cats Effectでは簡単:

def runUntilWin = run.iterateUntil(_.isDefined)

あとはrunのかわりにrunUntilWinを実行させるだけだ:

runUntilWin.unsafeRunSync()

すると、勝敗が決定するまで繰り返してくれる:

% scala-cli code.scala.sc
じゃん
じゃん
じゃん
けん
けん
けん
ぽん
ぽん
ぽん
✌
✌
✌
List((✌,1750437528797143 microseconds), (✌,1750437528798949 microseconds), (✌,1750437528797094 microseconds))
Right(ok)
None
じゃん
じゃん
じゃん
けん
けん
けん
ぽん
ぽん
ぽん
✋
✌
✌
List((✌,1750437530625564 microseconds), (✋,1750437530625527 microseconds), (✌,1750437530625836 microseconds))
Right(ok)
Some()

30人でバトルさせる

人数を切り出してあるので、大人数でバトルしてもらうこともできる。30人でバトルさせるとどうなるか見てみる。ついでに何回かかるかも見てみよう:

def runUntilWin = for {
  counter <- Ref.of[IO, Int](0)
  _ <- (counter.update(_ + 1) >> run).iterateUntil(_.isDefined)
  n <- counter.get
  _ <- IO.println(s"$n 回かかりました。おつかれさまでした")
} yield ()

あとは回数を増やして実行する(待ち時間もこっそり短くしよう):

val playerCount = 30

List((✋,1750443600817269 microseconds), (✊,1750443600817268 microseconds), (✊,1750443600817268 microseconds), (✊,1750443600817267 microseconds), (✋,1750443600817264 microseconds), (✋,1750443600817263 microseconds), (✋,1750443600817270 microseconds), (✊,1750443600817270 microseconds), (✊,1750443600817265 microseconds), (✋,1750443600817272 microseconds), (✊,1750443600817271 microseconds), (✊,1750443600817263 microseconds), (✊,1750443600817269 microseconds), (✊,1750443600817266 microseconds), (✋,1750443600817262 microseconds), (✊,1750443600817266 microseconds), (✋,1750443600817266 microseconds), (✊,1750443600817261 microseconds), (✋,1750443600817271 microseconds), (✋,1750443600817265 microseconds), (✋,1750443600817260 microseconds), (✊,1750443600817267 microseconds), (✊,1750443600817265 microseconds), (✊,1750443600817259 microseconds), (✋,1750443600817257 microseconds), (✋,1750443600817261 microseconds), (✊,1750443600817258 microseconds), (✋,1750443600817250 microseconds), (✊,1750443600817249 microseconds), (✋,1750443600817262 microseconds))
Right(ok)
Some(✋)
17010 回かかりました。おつかれさまでした

結論

30人でじゃんけんをしないほうが良い。

余談

本当は50人で実験したかったけれど全然終わらなかったので30人に変更したという経緯がある。

練習問題

一斉に30人でじゃんけんするのではなく、同時に2〜3人ずつじゃんけんして勝ち上がっていくようにしてみよう。最初は15組が対戦し、最終的に1組になる。

ソースコード

//> using scala 3.7
//> using deps "org.typelevel::cats-effect:3.6.1"

import cats.effect.std.Random
import cats.effect.*
import cats.effect.std.CyclicBarrier
import scala.concurrent.duration.*

def janken(
    jan: CyclicBarrier[IO],
    ken: CyclicBarrier[IO],
    pon: CyclicBarrier[IO]
): IO[(Char, FiniteDuration)] = {
  for {
    n <- Random[IO].betweenInt(1, 2)
    _ <- IO.sleep(n.milliseconds)
    _ <- IO.println(s"じゃん")
    _ <- jan.await
    m <- Random[IO].betweenInt(1, 2)
    _ <- IO.sleep(m.milliseconds)
    _ <- IO.println(s"けん")
    _ <- ken.await
    l <- Random[IO].betweenInt(1, 2)
    _ <- IO.sleep(l.milliseconds)
    _ <- IO.println(s"ぽん")
    _ <- pon.await
    hand <- Random[IO].oneOf('✊', '✌', '✋')
    now <- Clock[IO].realTime
    _ <- IO.println(hand)
  } yield (hand, now)
}

def validate(hands: Seq[(Char, FiniteDuration)]): Either["bad", "ok"] = {
  import cats.syntax.apply.*

  val times = hands.map(_._2)
  def withinThreshold(t1: FiniteDuration)(t2: FiniteDuration): Boolean = {
    val diff = if t1 < t2 then t2 - t1 else t1 - t2
    diff < 50.milliseconds
  }

  val results = List(withinThreshold) <*> times <*> times

  results.reduce(_ && _) match
    case true  => Right("ok")
    case false => Left("bad")
}

def winner(hands: Seq[Char]): Option[Char] = {
  val handKinds = hands.toSet
  if handKinds.size != 2 then return None // あいこ
  else
    handKinds.toSeq.sorted match
      case Seq('✊', '✌') => Some('✊')
      case Seq('✋', '✌') => Some('✌')
      case Seq('✊', '✋') => Some('✋')
}

val playerCount = 30

def run = for {
  jan <- CyclicBarrier[IO](playerCount)
  ken <- CyclicBarrier[IO](playerCount)
  pon <- CyclicBarrier[IO](playerCount)
  list <- List.fill(playerCount)(janken(jan, ken, pon)).parSequence
  _ <- IO.println(list)
  _ <- IO.println(validate(list))
  result = winner(list.map(_._1))
  _ <- IO {
    if result.isDefined then
      println(list)
      println(validate(list))
      println(result)
  }
  _ <- IO.println(result)
} yield result

def runUntilWin = for {
  counter <- Ref.of[IO, Int](0)
  _ <- (counter.update(_ + 1) >> run).iterateUntil(_.isDefined)
  n <- counter.get
  _ <- IO.println(s"$n 回かかりました。おつかれさまでした")
} yield ()

import cats.effect.unsafe.implicits.global
runUntilWin.unsafeRunSync()
★記事をRTしてもらえると喜びます
Webアプリケーション開発関連の記事を投稿しています.読者になってみませんか?