hyperasでkeras-rlのハイパーパラメータチューニングをやってみる
ハイパーパラメータを最適化するために探索を行う、
hyperoptというpythonのライブラリがありますが。
これをニューラルネットワークライブラリのkerasで利用するための、
hyperasというラッパーがあります。
hyperopt | github
https://github.com/hyperopt/hyperopt
hyperas | github
https://github.com/maxpumperla/hyperas
このエントリでは、このhyperasを、
kerasの強化学習ライブラリであるkeras-rlで動かす流れを説明します。
hyperasのサンプルの理解
hyperasのREADME.mdにexampleのコードが記載されているので、
まずはこのコードを見てhyperasの使い方を把握します。
ハイパーパラメータの探索を行っている処理は、以下の部分です。
best_run, best_model = optim.minimize(model=model,
data=data,
algo=tpe.suggest,
max_evals=5,
trials=Trials())
このoptim.minimizeのパラメータで、
algo, trialsはhyperoptで定義されているパラメータ、max_evalsは試行回数です。
残りのmodelとdataを自分で定義することになります。
exampleのコードを読むと、
dataに定義したdata関数では、
使用するデータセットを読み込み、戻り値に返却していることがわかります。
def data():
※途中省略※
(x_train, y_train), (x_test, y_test) = mnist.load_data()
※途中省略※
return x_train, y_train, x_test, y_test
modelに定義したmodel関数では、
data関数の戻り値を入力としたfittingを行い、
最小化したいパラメータを、戻り値として返却していることがわかります。
※exampleでは、精度を最大化したいので、accuracyに-1掛けた値を返却。
def model(x_train, y_train, x_test, y_test):
※途中省略※
model.fit(x_train, y_train,
batch_size=,
epochs=1,
verbose=2,
validation_data=(x_test, y_test))
score, acc = model.evaluate(x_test, y_test, verbose=0)
※途中省略※
return {'loss': -acc, 'status': STATUS_OK, 'model': model}
また、model関数では、
ハイパーパラメータを````というテンプレートで、
変更できるよう定義しています。
# hyperas内部ではmodel関数をテンプレートとして
# パラメータを流し込んでPythonのコードを生成しているようです。
※一部抜粋
model.add(Dropout())
※一部抜粋
model.compile(loss='categorical_crossentropy', metrics=['accuracy'],
optimizer=)
keras-rlへの適用
hyperasのサンプルの理解が理解出来たところで、
これをkeras-rlに適用できないかを考えてみます。
強化学習の場合は、training/validationデータを使うのでは無く、
environmentが返却するobservationとrewardに基づいて学習を行います。
そのため、hyperasに渡すdata関数が定義できません。
そこで考え方を変えて、
model関数側でenvironmentを定義し、model関数で学習の処理が完結するようにします。
optim.minimizeのdataパラメータにはdata関数を渡す必要があるので、
data関数は何も行わないダミー関数として定義しておきます。
data関数の戻り値がmodel関数の引数になるので、どちらも無しで揃えておきます。
このような考え方で、
hyperasとkeras-rlのサンプルを組み合わせて書いたサンプルを以下のgistに置いています。
https://gist.github.com/takemikami/70ec73a76ad8c8c6cb41b17ce6cd9a77
このサンプルでは、CartPoleのEnvironmentに対して、
DQNAgentのモデルのノード数、EpsGreedyQPolicyのepsの値を探索させるように記載しています。