iPX社員によるブログ

iPX社員が"社の動向"から"自身の知見や趣味"、"セミナーなどのおすすめ情報"に至るまで幅広い話題を投下していくブログ。社の雰囲気を感じ取っていただけたら幸いです。

Behavior CloningとGAILについて

久しぶりの投票になりますが、パルハットです。今回強化学習関連のテーマでの内容になります。最近複数のプロジェクトが強化学習関連の案件でした。主のstable baselinesという強化学習関連のフレームワークを使っていました。そこで色々DQN(Deep Q Network)とかPPO2(Proximal Policy Optimization)といったアルゴリズムを使ったていましたが、今回のメイン内容はBehavior CloningとGAILGenerative Adversarial Imitation Learningの紹介です。

Behavior CloningとGAIL

Behavior Cloningは行動を真似て学習する方法の一種類であります、摸倣学習とも言われています。摸倣学習といえば逆強化学習もありますが、それ以外にGAILも摸倣学習の一種類です。逆強化の場合エキスパートとの行動履歴から報酬を推定するように学習を行います。一方Behavior CloningとGAILは教師あり学習みたいにエキスパートの行動履歴を正解データとして使って損失関数を最小化しながら、上手に摸倣できたらいい報酬を与えるようになっています。Behavior CloningとGAILの場合エキスパートの行動履歴に主にPIDコントローラもしくは人間による手作業からえられてたデータを使うことになっています。 BCとGAILの実行例は概ね一緒です。

エキスパートのデータを作成

一番最初はエキスパートの履歴データを作成します。状況によってPIDによってデータを作成するか、人間の手作業によってデータを作成するかはわかりますが、私が関わった案件でPIDによってエキスパートのデータを履歴を作成したことがあります。その場合stable-baselinesが提供しているメソッドを使えば簡単にデータを作成できます。

env = gym.make('Pendulum-v0')
def dummy_expert(_obs):
    return env.action_space.sample()

generate_expert_traj(dumpy_expert, 'dummpy_expert_pendulum', env, n_episodes=10)

ここからdummy_expert_pendulum.npzというエキスパート履歴データが作成されます。次のステップでそれを用いて学習を行う。

Behavior Cloningを用いてPre-Train(転移学習)を行う

エキスパートの履歴をデータを読んで、データ・セットを作成する。

dataset = ExpertDataset(expert_path='dummpy_expert_pendulum.npz',
                        traj_limitation=1, batch_size=128)

次に使えたアルゴリズムを用いて学習を回しす。純粋のBehavior cloningの場合PPO2等使って転移学習をおこないます、GAILの場合そのまま使います。

model = PPO2('MlpPolicy', 'Pendulum-v0', verbose=1)
model.pretrain(dataset, n_epochs=1000)
#GAILの場合
model = GAIL('MlpPolicy', 'Pendulum-v0', dataset, verbose=1)
model.learn(total_timesteps=1000)
model.save("gail_pendulum")

実際の実行例

以下のコードをそのまま実行したら確認可能です。

import gym

from stable_baselines import GAIL, SAC
from stable_baselines.gail import ExpertDataset, generate_expert_traj

model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1)
generate_expert_traj(model, 'expert_pendulum', n_timesteps=100, n_episodes=10)

dataset = ExpertDataset(expert_path='expert_pendulum.npz', traj_limitation=10, verbose=1)

model = GAIL('MlpPolicy', 'Pendulum-v0', dataset, verbose=1)
model.learn(total_timesteps=1000)
model.save("gail_pendulum")

del model

model = GAIL.load("gail_pendulum")

env = gym.make('Pendulum-v0')
obs = env.reset()
while True:
  action, _states = model.predict(obs)
  obs, rewards, dones, info = env.step(action)
  env.render()

まとめ

今回強化学習関連内容でした、現在ウチの周り皆強化学習をやっているので、IT業界でも結構熱いテーマになっています、と言っても直接AIに絡んで来るので、これからも強化学習関連の新しい技術が関心を持つでしょう。ということで今回はここまでにさせていただきます。