Lambdaカクテル

京都在住Webエンジニアの日記です

Invite link for Scalaわいわいランド

ScalaとApache Sparkで線形回帰学習をやってみる + 簡単なSpark使い方メモ

1年間病院にかからなかったということで褒美の図書カードを健康保険組合にもらったので、こういう本を購入した。

鈍器っぽい。この本ではまずはscikit-learnを用いて線形回帰をやってみるという内容になっている。具体的には、以下のことをやっている:

  • OECDのデータから人々の幸福度合いのデータを得る
  • IMFのデータから国のGDPデータを得る
  • データを整形して国ごとにJOINする
  • うまくデータにフィットしそうな線形モデルを学習させる

実際に線形モデルがフィットするのかはともかくとして、このような構成になっている。

この内容を、Apache SparkというScalaで動くデータ分析エンジン(要するに、ScalaのPandas)でなぞってみたというエントリ。

Apache Spark

自分は普段からScalaを使っているので、Scalaで動くpandasとかscikit-learnの代替ライブラリとしてApache Sparkを使う。

spark.apache.org

Sparkはデータ分析のための分散エンジンで、ScalaのAPIが提供されているのが特徴。ローカルなノードでも動かすことができる。要するに、型付きのPandas、Scikit-learnだ

SparkはSpark Shellというシェルからも使うことができるけど、今回はScalaから呼び出すことにする。

Scala-cli

Scala-cliはScalaの包括的な実行環境になることを目指して開発されているツールで、その機能のうちの一つにScala Script の実行がある。

Scala Scriptとは拡張子が.scまたは.scala.scになっているファイルで、Scalaのコードを書くと直接実行できる。

このへんは id:tanishiking24 の記事が詳しい。

tanishiking24.hatenablog.com

Scalaを普通にsbtでビルドするときはプロジェクトディレクトリを作成して依存関係の設定を別ファイルに置いたりする必要があった。Scala Scriptを使うことで、LL言語のようにその場でScalaコードを実行できる。

// script.sc

def greeting(): Unit = println("Hello, World!")

greeting()
$ scala-cli ./script.sc
Hello, World!

Scala ScriptでSparkを使う

Scala Scriptでは依存性をスクリプト内に直接記述できるので、SparkをScala Scriptで使うにはファイル冒頭に以下のように記載する。

//> using scala "2.13"
//> using lib "org.apache.spark::spark-core:3.3.1"
//> using lib "org.apache.spark::spark-sql:3.3.1"
//> using lib "org.apache.spark::spark-mllib:3.3.1"

Sparkは今のところScala 3に対応していないようなので、明示的にScala 2.13を指定している。

事前に用意するデータ

書籍の指示に従って、以下のデータをCSVでダウンロードしておく。

基本的にデータを操作するのはSparkでやるので、最新のデータさえあればよい。

コード

まず最初に全てのコードを示す。以下のコードで回帰分析を行うことができた。

//> using scala "2.13"
//> using lib "org.apache.spark::spark-core:3.3.1"
//> using lib "org.apache.spark::spark-sql:3.3.1"
//> using lib "org.apache.spark::spark-mllib:3.3.1"

// Java17だとうまく動かないことがあるので11で動かすこと
// https://stackoverflow.com/questions/73465937/apache-spark-3-3-0-breaks-on-java-17-with-cannot-access-class-sun-nio-ch-direct

import org.apache.spark.sql.SparkSession

println("Apache Spark Regression Example")

// Scikit-learn、Keras、TensorFlowによる実践機械学習第2版
// pp.23 の線形モデルの実装をApache Sparkで実装してみる

// SparkSessionをまず作成する必要がある。ドキュメントでspark.という表記が登場した場合はこのSparkSessionのことを指している。
val spark = SparkSession
  .builder()
  .appName("Spark-Exercise-Regression")
  .config("spark.master", "local") // 実行するマスターノードを指定するのが必須なのでlocalとする
  .getOrCreate()

// https://homl.info/4
val oecdBli = spark.read
  .format("csv")
  .option("header", "true")
  .load("BLI_03012023082923436.csv")
// https://homl.info/5
val gdpPerCapita =
  spark.read.format("csv").option("header", "true").load("WEOOct2022all.csv")

// oecdBliを'Life satisfaction' かつ 男女を絞り込まない総合点数で絞り込み、 "LOCATION" カラムと "Value" カラムを残す。
val oecdBliFiltered = oecdBli
  .filter("`INDICATOR2` == 'SW_LIFS' AND `INEQUALITY6` = 'TOT'")
  .select("LOCATION", "Value")

// そして WEOは "WEO Subject Code" == "NGDP" でフィルタし、"ISO" カラムと 年のカラムとを残す。
val year = "2015"
val gdp = "NGDPDPC" // GDP per capita, constant price, dollar
val gdpPerCapitaFiltered = gdpPerCapita
  .filter(s"`WEO Subject Code` == '${gdp}'")
  .select("ISO", year, "WEO Subject Code")

// 最後に "Location" == "ISO" でInner JOINする。
import org.apache.spark.sql.Column
val joined = oecdBliFiltered
  .join(
    gdpPerCapitaFiltered,
    oecdBliFiltered.col("Location") === gdpPerCapitaFiltered.col("ISO"),
    "inner"
  )
  .select("ISO", "Value", year)
  .withColumnRenamed("Value", "Satisfaction")
  .withColumnRenamed(year, gdp)
  .withColumns(
    Map(
      "Satisfaction" -> new Column("Satisfaction").cast("double"),
      gdp -> new Column(gdp).cast("double")
    )
  )

joined.show()

joined.drop("ISO").select("Satisfaction", gdp).write.option("header", true).csv("./result.csv")

// Linear RegressionがFeaturesのためにVectorを要求するので、VectorAssemblerでVectorにカラムを変換する
import org.apache.spark.ml.feature.VectorAssembler
val va = new VectorAssembler().setInputCols(Array(gdp)).setOutputCol("NGDPVec")

import org.apache.spark.ml.regression.LinearRegression
val lr = new LinearRegression()
  .setMaxIter(50)
  .setRegParam(0.1)
  .setFeaturesCol("NGDPVec")
  .setLabelCol("Satisfaction")
val model = lr.fit(va.transform(joined))
val coeffs = model.coefficients.toArray.mkString("[", ", ", "]")
println(s"intercept: ${model.intercept}, coefficients: ${coeffs}")

spark.stop() // 必須

いくつか注意点があったのでメモしておく。

Java17だとうまく動かない

普通動くと思っていたが、Spark 3.3.0はJava 17だとエラーが出てちゃんと動作しない。

stackoverflow.com

Java 11で動作させることでこの問題は解消する。ASDFなどをインストールしておいて、JVMをぽんぽん切り替えられる環境を構築しておくとこういうときに怒らずに済むのでおすすめ。

SparkSessionを作成する必要がある

SparkのドキュメントではSpark Shellを使っている前提で話が進んでいくので、Scalaから使おうとするとなんか必要なオブジェクトが無い、ということがままある。その最たるものがspark変数で、これはSparkSession型である。

Spark ShellではなくScalaから呼び出す場合はSparkSessionというのを作成する必要がある。

SparkSessionというのはTensorFlowのsessionとかと同じで、計算エンジンに対するコネクションハンドラのようなもの。Sparkは分散型エンジンなので、計算ノードが他のマシンで動作しているような状況が当然にある。こうした状況を抽象化するために、なんらかの方法でSparkSessionを確保してこれに対して処理を依頼するという感じになっている。DBのコネクションハンドラとかと同じ感じ。

1台で計算するような通常の利用では計算ノードはlocalhostなので、以下のようにしてsessionを作成する。

val spark = SparkSession
  .builder()
  .appName("Spark-Exercise-Regression")
  .config("spark.master", "local") // 実行するマスターノードを指定するのが必須なのでlocalとする
  .getOrCreate()

上掲した通り、spark.masterという必須configがあるので、これにlocalを指定すればよい。

CSVのロード

CSVはspark.read.csv(...)とすれば読み込めるが、ヘッダを含むデータの場合はspark.read.format("csv").option("header", "true").load(...)と書く必要がある。Sparkは、こういうメソッドチェインで物事をなんとかする傾向がある。

これにより読み込まれたデータはDataFrameという形式になる。だいたいPandasのDataFrameと同じだと思ってよさそう。

カラムを絞り込む

SparkでDataFrameの特定のカラムだけ欲しい場合は、df.select("foo", "bar", ...)と書けばよい。自分はfilterとかで検索していたが、filterは行を絞り込むのでちょっと違う。

// oecdBliを'Life satisfaction' かつ 男女を絞り込まない総合点数で絞り込み、 "LOCATION" カラムと "Value" カラムを残す。
val oecdBliFiltered = oecdBli.filter("`INDICATOR2` == 'SW_LIFS' AND `INEQUALITY6` = 'TOT'").select("LOCATION", "Value")

上掲のコードのように、filterの引数はSQLを受け付けるので便利。

Join

DataFrame同士をjoinするにはdf.joinを使う。

// 最後に "Location" == "ISO" でInner JOINする。
val joined = oecdBliFiltered
  .join(gdpPerCapitaFiltered, oecdBliFiltered.col("Location") === gdpPerCapitaFiltered.col("ISO"), "inner")

いくつか呼び出し方にオーバーロードがあるが、一番使いやすくて応用が効くのは上掲したdf.join(df2, column, how)の形式だろう。innerのかわりにleftなどを指定すると他のjoinアルゴリズムを利用できる。

カラムのリネーム

.withColumnRenamed(from, to)という読んで字の如しなメソッドがあるので、これを呼ぶとカラムをリネームできる。

カラムのキャスト

線形回帰学習をさせるには、カラムがDoubleである必要がある。

.withColumns.castを組み合わせると複数のカラムをキャストできる。

  .withColumns(Map(
    "Satisfaction" -> new Column("Satisfaction").cast("double"),
    gdp -> new Column(gdp).cast("double"),
  ))

DataFrame をCSVに書き出す

readの逆操作としてwriteが用意されている。これはSparkSessionではなくDataFrameに対して呼び出す。

df.write.option("header", true).csv("./result.csv")

Sparkが分散計算エンジンである都合上、単一のCSVファイルが得られるわけではなく、いくつかのファイルの集合が得られるという仕組みになっているようだ。今回は./result.csv以下にpart-***といった感じのファイルが1つだけ出力された。

カラムを1つのベクトルにまとめる

Sparkの回帰関連の機能を使うためには、FeaturesとLabelという2つのカラムが必要になる。

  • Features
    • 入力。訳すなら「特徴量」?
    • GDPから暮らしの満足度を予想するという今回のタスクでは、「GDP」がFeatures
    • 複数の特徴量を扱うために、ベクトルで入力する必要がある
  • Label
    • 出力。
    • GDPから暮らしの満足度を予想するという今回のタスクでは、「満足度」
    • こちらはスカラ値

このメンタルモデルは回帰に限らず分類問題でも登場する。

Featuresはベクトル型なので、いったんGDPが入っているカラムを要素数1のベクトルに変換する必要がある。そのためにはVectorAssemblerを使えばよい。

// Linear RegressionがFeaturesのためにVectorを要求するので、VectorAssemblerでVectorにカラムを変換する
import org.apache.spark.ml.feature.VectorAssembler
val va = new VectorAssembler().setInputCols(Array(gdp)).setOutputCol("NGDPVec")

この処理によって新たにNGDPVecカラムが生え、その型はvector<double>になる。

線形回帰モデルを訓練する

前述したFeaturesカラムとLabelカラムをもとに線形回帰モデルを訓練する。

線形回帰モデルを作るには、org.apache.spark.ml.regression.LinearRegressionのインスタンスを経由する。

一度インスタンスを作成した後は、setterを使って各種設定を行い、最終的にfitメソッドを呼び出すことで訓練が行われたモデルが作成される。

import org.apache.spark.ml.regression.LinearRegression
val lr = new LinearRegression()
  .setMaxIter(50)
  .setRegParam(0.1)
  .setFeaturesCol("NGDPVec")
  .setLabelCol("Satisfaction")
val model = lr.fit(va.transform(joined))
val coeffs = model.coefficients.toArray.mkString("[", ", ", "]")
println(s"intercept: ${model.intercept}, coefficients: ${coeffs}")

訓練結果の切片(intercept)と係数ベクトル(coefficients)は、モデルのinterceptメソッド・coefficientsメソッドを呼び出すことで得られる。

今回学習させた結果、以下のような出力が得られた。

intercept: 5.89829422298187, coefficients: [2.1020826210258197E-5]

本ではそれぞれ4.85, 4.91E-5となっていて、ちょっとデータが違うのかもしれない。

グラフ

Spark自体にはビジュアライズ機能が付いていない?っぽい。VegasというSparkに対応したビジュアライズライブラリがあるのだけれど、数年前に開発が放棄されてScala 2.12でも動かないという有様なので使えない。Plotly-scalaという、Plotly.jsに対応したJSONを吐くライブラリがあって、これは最近までメンテされている。

github.com

本格的に使うならこれで良いのだけれど、今回はプロットさえできればいいのでGnuplotで先程出力したCSVをプロットする。

$ gnuplot
set datafile separator ","
plot "result.csv/part-00000-5d53b5cc-20da-4e51-805c-2b2638184054-c000.csv" using "NGDPDPC":"Satisfaction", x*2.1020826210258197E-5+5.89829422298187

するとこのようなグラフが得られ、線形回帰がフィットしていることがわかる。

Gnuplotで線形回帰モデルがデータにフィットしている様子を図示

感想

Apache SparkとGnuplotとを使って簡単な線形回帰の問題を解くことができた。教科書と違う数字が出てきたのは、計算する年度やデータの取り方がなんか違うのだろうと思う(どのデータを使えばいいかは教科書には書かれていないので)。

自分はPandasを使ったことがあるけれど、Sparkはまたちょっと違う操作感で面白かった。具体的にはSQLを使えるあたりとか、強い型のおかげで使い方がすぐ分かり、実行時エラーを何度も繰り返すというPythonあるあるな辛さが無かったのが良かった。

SparkはPandasよりも守備範囲が広いライブラリ、というかエンジンなので、どうしても起動には少し時間がかかるなという印象を受けた。また分散計算エンジンである都合上、デバッグ出力がかなり出力される。ロガーの出力をいじったら消せそうだけど、UNIX的なパイプラインとして使う感じのプログラムではないことは確かだ。

Numpy/Pandasと比べて情報の絶対量が少ないことを除いては、Pandasよりも操作性は良いと感じる。あと依存性で悩まされることもまったく無かった(初回の実行で全て読み込まれて動いた)。

あとSparkのドキュメントの見通しが悪すぎる。どこ見たらいいんだ・・・と途方に暮れる。

参考文献

spark.apache.org

sparkbyexamples.com

★記事をRTしてもらえると喜びます
Webアプリケーション開発関連の記事を投稿しています.読者になってみませんか?