Lambdaカクテル

京都在住Webエンジニアの日記です

Invite link for Scalaわいわいランド

Seq[(K, V)]について、キーが同じである限り結合し続ける処理

こういうメソッドを定義したい。

どうしてこういうメソッドが欲しかったのかというと、自作の解説動画生成ツールでBGMをうまく割り当てる処理のために必要だったのだ。

BGMと尺

BGMと尺が以下の形式で与えられる。

val pairs: Seq[Pair]= Seq(
  "a.mp3" -> FiniteDuration(10, "seconds"),
  "a.mp3" -> FiniteDuration(15, "seconds"),
  "b.mp3" -> FiniteDuration(5, "seconds"),
  "c.mp3" -> FiniteDuration(10, "seconds"),
  "c.mp3" -> FiniteDuration(10, "seconds"),
)

どうしてこうなっているかというのは割愛するとして、これらの入力を以下のような形式に変換したい。

val pairs: Seq[Pair]= Seq(
  "a.mp3" -> FiniteDuration(25, "seconds"),
  "b.mp3" -> FiniteDuration(5, "seconds"),
  "c.mp3" -> FiniteDuration(20, "seconds"),
)

2-tupleの左側で表現されるキーが合致している限り、2-tupleの右側の値を結合し続けたい。これをうまくコードに落しこめないだろうか。

より抽象的には、def reduction[A : Eq, B : Monoid](xs: Seq[(A, B)]): Seq[(A, B)]を定義したい。

やった

まず前提となるコードがこれ。

import scala.concurrent.duration.FiniteDuration
import cats._
import cats.implicits._

// こういうペアがあるとする。
type Pair = (String, FiniteDuration)

// そのペアがつらなったリストがあるとする。
val pairs: Seq[Pair]= Seq(
  "a.mp3" -> FiniteDuration(10, "seconds"),
  "a.mp3" -> FiniteDuration(15, "seconds"),
  "b.mp3" -> FiniteDuration(5, "seconds"),
  "c.mp3" -> FiniteDuration(10, "seconds"),
  "c.mp3" -> FiniteDuration(10, "seconds"),
)
// pairs: Seq[Pair] = List(("a.mp3", 10 seconds), ("a.mp3", 15 seconds), ("b.mp3", 5 seconds), ("c.mp3", 10 seconds), ("c.mp3", 10 seconds))

よく見ると、これはSeqを畳みこむ作業だといえそうだ。

val f = (x: Pair) => (y: Pair) => x._1 -> y._1 match {
  case k -> l if k == l => Seq(k -> (x._2 |+| y._2))
  case _ => Seq(x, y)
}

pairs.foldLeft[Seq[Pair]](Seq("" -> FiniteDuration(0, "seconds"))){ case (x, y) =>
  val combined = f(x.last)(y)
  x.take(x.length - 1) ++ combined
}
// res1: Seq[Pair] = List(("", 0 seconds), ("a.mp3", 25 seconds), ("b.mp3", 5 seconds), ("c.mp3", 20 seconds))

できた。foldLeftの第一引数にSeq.emptyを使うとうまくいかなかった。なんでだろう?

もっとスマートな方法もありそうだけど、まあ今回は動いたのでよし。

追記: 別解たち (20221123T2046+0900)

ツイッターで苦しんでいると別解を教えてくれる人たちがいた。

headに積み重ねていくパターン

1つのfoldLeftにうまく押し込めたパターンだ。

pairs.foldLeft(Seq.empty[Pairs]) {
  case ((headKey, headValue) +: tail, (k, v)) if headKey == k =>
    (headKey -> (headValue |+| v)) +: tail
  case (ls, pair) =>
    pair +: ls
}.reverse

foldLeftなので、関数ブロックには(xs, x)の形で引数が渡ってくる:

Visual scala referenceより引用

superruzafa.github.io

headKeykとが等しい場合は、いい感じに先頭をcombineするという操作を繰り返す。この処理の面白い点は、元のpairsを頭から走査しつつ、先頭要素を積み重ねていくので、最終的な順番は反転してしまい、最後にreverseする必要があるということ。リスト処理だと定番の状況ですね。

ListMapを使うパターン

そう、今回のキモは元々の順序を維持しなければならないということである。a.mp3b.mp3の順序がどうでも良いということであれば、単にgroupMapすれば良い(そのかわり、a, b, aというパターンであっても強制的にa, bに還元されてしまうが):

// いちどMapを経由するため順序が失われる
pairs.groupMapReduce(_._1)(identity) { case (kvs1, kvs2) => (kvs1._1, kvs1._2 |+| kvs2._2) }.values.toSeq
// res1: Seq[Pair] = List(("a", 25 seconds), ("b", 5 seconds), ("c", 20 seconds))

がくぞさんはより簡潔にfoldMapしている:

pairs.foldMap(Map(_)).toSeq

groupMapReduceすると明示的にグルーピングしてMapの構造を作ることを意識するが、foldMapだとMap(_)とするだけでよい。Map[K, V : Monoid]のとき、自動的にMap[K, V]はモノイドになるから、そのままfoldできるのだ。

foldRightするパターン

ソート済みならもうちょっと平たく書けるらしい。

pairs.foldRight(Seq.empty[Pair]) {
  case ((k1, d1), (k2, d2) +: rest) if k1 == k2 => (k1, d1 + d2) +: rest
  case (pair, acc) => pair +: acc
}

この解は、前掲したKugiyaJさんの解とちょうど向きが逆----つまりfoldLeftのかわりにfoldRightを使っている----だ。

ソート済みでない場合はどうなるかというと:

def reducingRight(xs: Seq[Pair]): Seq[Pair] = 
  xs.foldRight(Seq.empty[Pair]) {
    case ((k1, d1), (k2, d2) +: rest) if k1 == k2 => (k1, d1 + d2) +: rest
    case (pair, acc) => pair +: acc
  }

val xs2 = Seq(
  "a" -> FiniteDuration(10, "seconds"),
  "a" -> FiniteDuration(15, "seconds"),
  "b" -> FiniteDuration(5, "seconds"),
  "a" -> FiniteDuration(15, "seconds"),
)
// xs2: Seq[(String, FiniteDuration)] = List(("a", 10 seconds), ("a", 15 seconds), ("b", 5 seconds), ("a", 15 seconds))

reducingRight(xs2)
// res1: Seq[Pair] = List(("a", 25 seconds), ("b", 5 seconds), ("a", 15 seconds))

値が分かれてしまった。しかしこの場合は本来の仕様に合致している(キーが同じである限り結合し続ける)のでこれでいい。元々は、その時流すBGMの尺を計算するためのアルゴリズムなので、結果が分かれるのは全然問題がないのだ。

リストを舐めつつ積み重ねていくような処理は、foldRightのほうが筋良く書けそうだ。

処理を分けるパターン

Catsを使わないパターンを教えていただいた。

import scala.annotation.tailrec
import scala.concurrent.duration.FiniteDuration

object Main {

    type Pair = (String, FiniteDuration)

  
  def main(args: Array[String]): Unit = {

    val pairs: Seq[Pair] = Seq(
      "a.mp3" -> FiniteDuration(10, "seconds"),
      "a.mp3" -> FiniteDuration(15, "seconds"),
      "b.mp3" -> FiniteDuration(5, "seconds"),
      "c.mp3" -> FiniteDuration(10, "seconds"),
      "c.mp3" -> FiniteDuration(10, "seconds")
    )

    val result = pairs
      .group { (a, b) => a._1 == b._1 }
      .map(_.reduceLeft { (a, b) => a._1 -> (a._2 + b._2) })

    result.foreach(println)
  }

  implicit class SeqOps[A](list: Seq[A]) {

    def group(isSameGroup: (A, A) => Boolean): List[List[A]] = {

      @tailrec
      def go(list: List[A], acc: List[List[A]]): List[List[A]] = list match {
        case Nil => acc.reverse
        case head :: _ =>
          val (sameGroup, rest) = list.span(isSameGroup(head, _))
          go(rest, sameGroup +: acc)
      }

      go(list.toList, Nil)
    }
  }
}

結語

foldRightでくっつけるパターンが最も簡潔で面白かった。今回は「全てのキーを集めてくる」必要はなくて、隣接してさえいればいいのでfoldMapする必要はなさそうだった。

みなさまありがとうございました〜

★記事をRTしてもらえると喜びます
Webアプリケーション開発関連の記事を投稿しています.読者になってみませんか?