pysparkでデータ加工のプロセスを書いていると。

加工処理①→加工処理②→加工処理③のような流れで、
加工処理ごとに中間ファイルを出力しておき、
途中から(例えば、加工処理①は行わずに、②③のみ)再実行できるようにしたい場合があります。
※ビッグデータを扱う場合は、処理時間がかかるため、こういった要件は多いと思います。

再実行を行う場合に、
スキップしたい箇所をコメントアウトしたりなどの対応を行うと、ミスの原因となるので、
中間ファイルが存在するかをチェックして、処理をスキップさせるようにしたいのですが、
都度if文を書くと、冗長な記載が多くなり、これもミスの原因となります。

運用中のバッチ処理であれば、Luigiのようなツールを使うと便利なのですが、
データ分析の試行錯誤をしている最中の場合は、
分析プロセス全体のコードの見通しが悪くなるので、扱いにくいと思います。

spotify/luigi | GitHub
https://github.com/spotify/luigi

以上を踏まえて、都合の良いやり方をいろいろと考えたのですが。
今のところ、中間ファイルの存在確認・加工処理実行を行う関数を定義して、
プロセスを書く方法に落ち着いています。

具体的な方法を、本エントリで紹介します。

想定シナリオ

まず、中間ファイルを出力させる想定シナリオの例を示します。

prod_id, shop_id, priceのカラムを持つDataFrameが与えられ、
商品(prod_id)毎の最安値(price)と、最安値での取扱店舗(shop_id)を求めたいとします。

以下のようなコードになります。

from pyspark.sql import functions as F
from pyspark.sql.window import Window

df = spark.createDataFrame([
  ["prod_a", "shop_1", 100],
  ["prod_a", "shop_2", 400],
  ["prod_a", "shop_3", 500],
  ["prod_b", "shop_1", 200],
  ["prod_b", "shop_2", 700],
  ["prod_b", "shop_4", 200],
  ["prod_b", "shop_5", 400],
  ["prod_b", "shop_6", 500],
], ["prod_id", "shop_id", "price"])

df.withColumn('min_price', F.min('price').over(Window.partitionBy('prod_id'))) \
 .where(F.col("price") == F.col("min_price")) \
 .groupBy("prod_id") \
 .agg(F.collect_set("shop_id").alias("shops"), F.first("price").alias("price")) \
 .show()

出力される結果は、以下の通りです。

+-------+----------------+-----+
|prod_id|           shops|price|
+-------+----------------+-----+
| prod_a|        [shop_1]|  100|
| prod_b|[shop_1, shop_4]|  200|
+-------+----------------+-----+

中間ファイルに出力する対応

前項のシナリオは、大きく以下の2処理に分かれているので、
前半処理の結果を中間ファイルに出力することを考えます。

  • 商品毎の最安値を求めて、該当価格のレコードのみに絞り込む
  • 商品毎の店舗リストを求める

# 対象データが大きい場合、このような処理では、
# 一度中間ファイルに出力しなければ、リソース不足に陥るケースもあると思います。

関数の定義

まず、以下のcheckpoint関数を定義します。
この関数では、出力ファイルの存在を確認し、
ファイルが存在しなければ、processに渡された処理を実行して、ファイルに出力します。
戻り値は、出力ファイル(parquet)のDataFrameです。

# プロセスの実行とDatasetのparquetファイル出力、既に存在する場合は処理せず読込のみ
import time, os
def checkpoint(process, file):
  start = time.time()
  if os.path.exists("{}/_SUCCESS".format(file)):
    print("==> {} : skip".format(file))
    return spark.read.parquet(file)
  dataset = process()
  dataset.write.mode("overwrite").parquet(file)
  print("==> {} : write / {} s".format(file, time.time() - start))
  return spark.read.parquet(file)

# ここでは、os.path.existsでチェックしてますが、
# 実際は、hdfsやS3などのファイルをチェックする処理になります。

中間ファイル出力の実装

定義したcheckpoint関数を使って、
中間ファイルを出力しながら処理を実行させると、以下のようなコードになります。

※DataFrameの定義までは同じなので省略※

df_filterd = checkpoint(
    process=lambda: df
      .withColumn('min_price', F.min('price').over(Window.partitionBy('prod_id')))
      .where(F.col("price") == F.col("min_price")),
    file="/tmp/df_filterd"
)

df_filterd.groupBy("prod_id") \
  .agg(F.collect_set("shop_id").alias("shops"), F.first("price").alias("price")) \
  .show()

# ここではプロセスをlambdaで書いていますが、
# ある程度ボリュームのある処理であれば関数にした方が良いと思います。

また、checkpoint関数で、spark.read.parquetが呼ばれますが、
遅延評価なので、
実際にファイルアクセスが行われるのは、show()が実行されるタイミングになります。
この性質のおかげで、次のようなケースで不要な読み込みをスキップする事が出来ます。

加工処理①→加工処理②→加工処理③の流れで、
加工処理②、③がそれぞれ先行処理の出力のみを参照する場合の処理は、以下の処理が実行されます。

  • 加工処理①
  • 加工処理①の出力
  • 加工処理①の読込
  • 加工処理②
  • 加工処理②の出力
  • 加工処理②の読込
  • 加工処理③
  • 加工処理③の出力

ここで、加工処理③のみを再実行したい(加工処理①②の出力が存在する)場合は、
以下の処理のみが実行され、「加工処理①の読込」は実行されません。

  • 加工処理②の読込
  • 加工処理③
  • 加工処理③の出力

このような書き方で、
間に処理を追加したり、ハイパーパラメータを変更したりなどの試行錯誤を行うケースに、
ある程度、見通しよくコードをメンテナンスできると思います。