ハイパーパラメータチューニングのための、pyspark用の自前関数
pysparkでハイパーパラメータのチューニングを行いたいとき、
MLflow、MLlibのModel Selectionなどの方法がありますが。
ツールをインストールしにくい環境だったり、
フレームワークに乗せられるように実装を手直しするのが手間な時のために、
評価結果の記録、パラメータグリッドを作る自前関数を作ってみました。
MLflow
https://mlflow.org/
ML Tuning: model selection and hyperparameter tuning | spark.apache.org
https://spark.apache.org/docs/latest/ml-tuning.html
自前関数の定義
実験結果の保存、読込用に2つの関数を定義します。
- save_experiment_results
- load_experiment_results
実験結果は、"/experiment_results/dt={}/experiment_id={}"という形で、
実験日・実験のIDの階層を持ったパスに保存します。
次に、パラメータグリッドを作るために、以下の関数を定義します。
- params_grid_searched
デフォルトのパラメータと、探索するパラメータのリストを指定して呼び出します。
各関数の利用例は、次の節で紹介します。
import hashlib, json, datetime, os
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.dataframe import DataFrame
from typing import Optional, Iterator
spark = SparkSession.builder.appName('local').getOrCreate()
### 実験結果を保存する
def save_experiment_results(
parent_path: str, # 実験結果保存先の親パス
params: dict, # ハイパーパラメータ
metrics: dict # 結果指標
) -> None:
dt = datetime.datetime.now().strftime("%Y%m%d")
experiment_id = hashlib.sha256(json.dumps(sorted(params.items(), key=lambda x: x[0])).encode()).hexdigest()
result = list()
for k in params.keys():
result.append(['parameter', k, str(params[k])])
for k in metrics.keys():
result.append(['metrics', k, str(metrics[k])])
df = spark.createDataFrame(result, ["type", "key", "value"])
df.write.mode("overwrite").parquet(
"{}/experiment_results/dt={}/experiment_id={}".format(parent_path, dt, experiment_id))
### 実験結果を読み込む
def load_experiment_results(
parent_path: str, # 実験結果保存先の親パス
dt_start: str = None, dt_end: str = None, # 対象実験結果期間
last_days: int = 3 # 直近n日間の結果を読み込む時に指定
) -> Optional[DataFrame]:
if dt_start is None or dt_end is None:
dt_end = datetime.datetime.now().strftime("%Y%m%d")
dt_start = (datetime.datetime.now() - datetime.timedelta(days=last_days)).strftime("%Y%m%d")
df = spark.read.parquet(
"{}/experiment_results".format(parent_path)).where(F.col('dt').between(dt_start, dt_end))
if df.count() == 0:
return None
pdf_keys = df.select('type', 'key').distinct().orderBy('type', 'key').toPandas()
selects = [F.col('dt'), F.col('experiment_id')]
aggs = list()
for t in ['parameter', 'metrics']:
pdf_key_type = pdf_keys[pdf_keys['type'] == t]['key'].values
for k in pdf_key_type:
col_k = k.replace('.', '_').replace(' ', '')
selects.append(F.when((F.col('type') == F.lit(t)) & (F.col('key') == k), F.col('value')).alias(col_k))
aggs.append(F.first(col_k, ignorenulls=True).alias(col_k))
return df.select(selects).groupBy('dt', 'experiment_id').agg(*aggs)
### grid search用のパラメータリストを作る
def params_grid_searched(
default_params: dict, # 既定のパラメータ辞書
grid: dict # 探索するパラメータの辞書
) -> Iterator[dict]:
def expand_grid(i: int, g: dict, l: list, param: dict):
key = list(g.keys())[i]
for v in g[key]:
param[key] = v
if i + 1 < len(g):
expand_grid(i + 1, g, l, param)
else:
l.append(param.copy())
grid_param_list = []
expand_grid(0, grid, grid_param_list, {})
for p in grid_param_list:
grid_params = default_params.copy()
for k in p:
grid_params[k] = p[k]
yield grid_params
自前関数の使用例
次に、定義した関数を利用して、ハイパーパラメータの探索を行います。
データを準備した後、
デフォルトパラメータ及び探索用パラメータグリッドを定義して、
params_grid_searched
の結果をfor文で繰り返し処理します。
各実験の最後に、save_experiment_results
でパラメータ・メトリクスを保存します。
全ての実験が終わったら、load_experiment_results
で結果を一覧します。
if __name__ == '__main__':
_temporary_path = "file:///tmp/pyspark_expriment_store"
# prepare test data
testdata = spark.read.format("libsvm") \
.load("file://{}/data/mllib/sample_libsvm_data.txt".format(os.environ['SPARK_HOME']))
df_train, df_valid = testdata.randomSplit([0.8, 0.2])
# prepare parameter grid
params_default = {
'maxIter': 10,
'reg': 1.0,
'elasticNet': 1.0
}
params_grid = {
'reg': [1.0, 0.5],
'elasticNet': [1.0, 0.5, 0.1]
}
for param in params_grid_searched(params_default, params_grid):
# fit & transform
from pyspark.ml.classification import LogisticRegression
lr = LogisticRegression(maxIter=param['maxIter'], regParam=param['reg'], elasticNetParam=param['elasticNet'])
df_predict = lr.fit(df_train).transform(df_valid)
# validation
total = df_predict.count()
true_positive = df_predict.where((F.col('prediction') == 1.0) & (F.col('label') == 1.0)).count()
true_negative = df_predict.where((F.col('prediction') == 0.0) & (F.col('label') == 0.0)).count()
false_positive = df_predict.where((F.col('prediction') == 1.0) & (F.col('label') == 0.0)).count()
false_negative = df_predict.where((F.col('prediction') == 0.0) & (F.col('label') == 1.0)).count()
metrics = {
'total': total,
'true_positive': true_positive,
'true_negative': true_negative,
'false_positive': false_positive,
'false_negative': false_negative,
'precision': true_positive / (true_positive + false_positive),
'recall': true_positive / (true_positive + false_negative),
'accuracy': (true_positive + true_negative) / total
}
# save params & metrics
save_experiment_results(_temporary_path, param, metrics)
# show experiment results
load_experiment_results(_temporary_path).show()
上記の実行結果は、以下のようになります。
+--------+--------------------+----------+-------+---+------------------+--------------+--------------+------------------+------+-----+-------------+-------------+
| dt| experiment_id|elasticNet|maxIter|reg| accuracy|false_negative|false_positive| precision|recall|total|true_negative|true_positive|
+--------+--------------------+----------+-------+---+------------------+--------------+--------------+------------------+------+-----+-------------+-------------+
|20200422|f392485425ca64d73...| 0.1| 10|1.0| 1.0| 0| 0| 1.0| 1.0| 15| 7| 8|
|20200422|f55554b9263a3461a...| 1.0| 10|0.5|0.5333333333333333| 0| 7|0.5333333333333333| 1.0| 15| 0| 8|
|20200422|177eada5d905becf4...| 1.0| 10|1.0|0.5333333333333333| 0| 7|0.5333333333333333| 1.0| 15| 0| 8|
|20200422|14ebdaf8188fe938b...| 0.1| 10|0.5| 1.0| 0| 0| 1.0| 1.0| 15| 7| 8|
|20200422|e572a60f58da08090...| 0.5| 10|1.0|0.5333333333333333| 0| 7|0.5333333333333333| 1.0| 15| 0| 8|
|20200422|1befffa3ca8a4b8cc...| 0.5| 10|0.5|0.9333333333333333| 0| 1|0.8888888888888888| 1.0| 15| 6| 8|
+--------+--------------------+----------+-------+---+------------------+--------------+--------------+------------------+------+-----+-------------+-------------+
また、for文の繰り返しの中で中間データを作る場合は、
以下のエントリで考えた、処理済み中間ファイルの生成処理スキップと組み合わせてやると、
探索にかかる時間を短くする事が出来ると思います。
pysparkでデータ加工する時、処理済み中間ファイルの生成処理をスキップする書き方を考えた | takemikami.com
https://takemikami.com/2019/11/01/pyspark.html
基本的には、チューニング用のフレームワークに乗せた方が好ましいのだと思いますが、
あまり手を掛けずに、パラメータを変化させて試したい場合は、このようなやり方でも十分なのかなと思います。