pysparkでデータ加工する時、処理済み中間ファイルの生成処理をスキップする書き方を考えた
pysparkでデータ加工のプロセスを書いていると。
加工処理①→加工処理②→加工処理③のような流れで、
加工処理ごとに中間ファイルを出力しておき、
途中から(例えば、加工処理①は行わずに、②③のみ)再実行できるようにしたい場合があります。
※ビッグデータを扱う場合は、処理時間がかかるため、こういった要件は多いと思います。
再実行を行う場合に、
スキップしたい箇所をコメントアウトしたりなどの対応を行うと、ミスの原因となるので、
中間ファイルが存在するかをチェックして、処理をスキップさせるようにしたいのですが、
都度if文を書くと、冗長な記載が多くなり、これもミスの原因となります。
運用中のバッチ処理であれば、Luigiのようなツールを使うと便利なのですが、
データ分析の試行錯誤をしている最中の場合は、
分析プロセス全体のコードの見通しが悪くなるので、扱いにくいと思います。
spotify/luigi | GitHub
https://github.com/spotify/luigi
以上を踏まえて、都合の良いやり方をいろいろと考えたのですが。
今のところ、中間ファイルの存在確認・加工処理実行を行う関数を定義して、
プロセスを書く方法に落ち着いています。
具体的な方法を、本エントリで紹介します。
想定シナリオ
まず、中間ファイルを出力させる想定シナリオの例を示します。
prod_id, shop_id, priceのカラムを持つDataFrameが与えられ、
商品(prod_id)毎の最安値(price)と、最安値での取扱店舗(shop_id)を求めたいとします。
以下のようなコードになります。
from pyspark.sql import functions as F
from pyspark.sql.window import Window
df = spark.createDataFrame([
["prod_a", "shop_1", 100],
["prod_a", "shop_2", 400],
["prod_a", "shop_3", 500],
["prod_b", "shop_1", 200],
["prod_b", "shop_2", 700],
["prod_b", "shop_4", 200],
["prod_b", "shop_5", 400],
["prod_b", "shop_6", 500],
], ["prod_id", "shop_id", "price"])
df.withColumn('min_price', F.min('price').over(Window.partitionBy('prod_id'))) \
.where(F.col("price") == F.col("min_price")) \
.groupBy("prod_id") \
.agg(F.collect_set("shop_id").alias("shops"), F.first("price").alias("price")) \
.show()
出力される結果は、以下の通りです。
+-------+----------------+-----+
|prod_id| shops|price|
+-------+----------------+-----+
| prod_a| [shop_1]| 100|
| prod_b|[shop_1, shop_4]| 200|
+-------+----------------+-----+
中間ファイルに出力する対応
前項のシナリオは、大きく以下の2処理に分かれているので、
前半処理の結果を中間ファイルに出力することを考えます。
- 商品毎の最安値を求めて、該当価格のレコードのみに絞り込む
- 商品毎の店舗リストを求める
# 対象データが大きい場合、このような処理では、
# 一度中間ファイルに出力しなければ、リソース不足に陥るケースもあると思います。
関数の定義
まず、以下のcheckpoint関数を定義します。
この関数では、出力ファイルの存在を確認し、
ファイルが存在しなければ、processに渡された処理を実行して、ファイルに出力します。
戻り値は、出力ファイル(parquet)のDataFrameです。
# プロセスの実行とDatasetのparquetファイル出力、既に存在する場合は処理せず読込のみ
import time, os
def checkpoint(process, file):
start = time.time()
if os.path.exists("{}/_SUCCESS".format(file)):
print("==> {} : skip".format(file))
return spark.read.parquet(file)
dataset = process()
dataset.write.mode("overwrite").parquet(file)
print("==> {} : write / {} s".format(file, time.time() - start))
return spark.read.parquet(file)
# ここでは、os.path.existsでチェックしてますが、
# 実際は、hdfsやS3などのファイルをチェックする処理になります。
中間ファイル出力の実装
定義したcheckpoint関数を使って、
中間ファイルを出力しながら処理を実行させると、以下のようなコードになります。
※DataFrameの定義までは同じなので省略※
df_filterd = checkpoint(
process=lambda: df
.withColumn('min_price', F.min('price').over(Window.partitionBy('prod_id')))
.where(F.col("price") == F.col("min_price")),
file="/tmp/df_filterd"
)
df_filterd.groupBy("prod_id") \
.agg(F.collect_set("shop_id").alias("shops"), F.first("price").alias("price")) \
.show()
# ここではプロセスをlambdaで書いていますが、
# ある程度ボリュームのある処理であれば関数にした方が良いと思います。
また、checkpoint関数で、spark.read.parquetが呼ばれますが、
遅延評価なので、
実際にファイルアクセスが行われるのは、show()が実行されるタイミングになります。
この性質のおかげで、次のようなケースで不要な読み込みをスキップする事が出来ます。
加工処理①→加工処理②→加工処理③の流れで、
加工処理②、③がそれぞれ先行処理の出力のみを参照する場合の処理は、以下の処理が実行されます。
- 加工処理①
- 加工処理①の出力
- 加工処理①の読込
- 加工処理②
- 加工処理②の出力
- 加工処理②の読込
- 加工処理③
- 加工処理③の出力
ここで、加工処理③のみを再実行したい(加工処理①②の出力が存在する)場合は、
以下の処理のみが実行され、「加工処理①の読込」は実行されません。
- 加工処理②の読込
- 加工処理③
- 加工処理③の出力
このような書き方で、
間に処理を追加したり、ハイパーパラメータを変更したりなどの試行錯誤を行うケースに、
ある程度、見通しよくコードをメンテナンスできると思います。