読者です 読者をやめる 読者になる 読者になる

Itsukaraの日記

最新IT技術を勉強・実践中。最近はDeep Learningに注力。

Deep Learning最新論文の再現コードを試行(DeepMindのA3C)

機械学習 ITのお勉強 Python

前回の記事で書きましたように、DeepMind社の最新論文Asynchronous Methods for Deep Reinforcement Learning、16 Jun 2016に書かれた手法A3C(Asynchronous Advantage Actor-critic)の再現コードGithubで見つけたので、実際に走らせて試行中。

Pongの学習結果

約27時間(36.5M steps)の学習を行った結果が下記です。横軸は学習量(steps、ゲーム画面のframe数に相当)で1M steps単位です。縦軸はゲームSOREです。
f:id:Itsukara:20160729115917p:plain

Github掲載グラフ(A3C-LSTM)では、下記のように、12h(18M steps)あたりで急激にゲームSOREが上がり、16h(24M steps)でほぼ最大になっています。
https://github.com/miyosuda/async_deep_reinforce/raw/master/docs/graph_24h_lstm.png

Github掲載グラフと当方グラフでスケールが異なり申し訳ないのですが、Github掲載グラフの16h(24M steps)近辺と、当方グラフの24M steps近辺を比較してみると、当方グラフでは、ゲームSCOREの急激な上昇がありません。当方のグラフの形は、Github記載の別のグラフ(A3C-FF、下記)と近いです。

https://github.com/miyosuda/async_deep_reinforce/raw/master/docs/graph_24h.png

残念ながら、なぜ、このような結果になったかは不明です。

上記学習後のプレイ動画は下記です。まだまだ学習が不十分で、あまり上手くありません。
youtu.be

Breakoutの学習結果

ほぼ1日掛けて学習させた結果が下記です(途中中断あり)。横軸は学習量(steps、ゲーム画面のframe数に相当)で1M steps単位です。縦軸はゲームSOREです。学習量は全体で52.5M stepsで、学習速度は約800 steps/秒だったので、時間としては約18.5時間です。
f:id:Itsukara:20160729132400p:plain

当方環境で52.5M stepsの学習を完了した後のゲームプレイ動画を下記に載せます。DQNで評価した時と較べて、格段に賢くなっています。なお、動画は、10回くらいプレイしたなかでSCOREが一番高いもので、833点取れてます。ちなみに、このゲームの最高点数は864点*1なので、ほぼ満点です。
youtu.be

論文との比較(学習性能)

論文によると、BreakoutはA3Cの適用効果が特に大きいゲームの1つになります。幾つかの手法の適用効果を比較した図を論文から抜粋しました。

1つ目のグラフはBreakoutのゲームSCOREの「5回の平均値」が書かれており、約6時間で400ぐらいまで上がっています。
f:id:Itsukara:20160729123352p:plain

2つ目のグラフは、横軸がTraining epochsで、縦軸はゲームSOREです。
f:id:Itsukara:20160729135958p:plain

「1 Training epochs」は4M stepsに対応するので、80M stepsでSCOREが400になっています。これらを元に計算すると、学習性能は3700 steps/秒です。当方のマシンは4コア(8仮想CPU)+GPUで800 steps/秒であり、論文のマシンは16コア(32仮想CPU、当方マシンの4倍)で、学習性能は4.6倍なので、当方マシンでの学習性能は、ほぼ妥当と思われます。

性能差の理由は、論文のマシンは最新CPUであることと、Github掲載のプログラムはCPU利用率が70%程度しか出ていない*2ことが原因と思われます。

なお、上記は、A3Cのスレッド効率を示す図です。Breakoutでは、スレッド数が変わっても「学習量(Training epochs)」あたりのSCOREのグラフは、ほぼ変わらないことを示してます。スレッド数に比例して学習量は増えるので、逆に言うと、スレッド数に比例した学習性能が出ているということです。

論文との比較(グラフの形)

当方の評価と論文でグラフの形を比較します。当方の環境では、ゲームSCOREの「最大値」が400になるまで30M steps掛かっています。論文では、20M stepsでゲームSCOREの「平均値」が400になるので、論文と比べると、あまり良い結果は出ていません。

ちゃんと比較するには、論文と同じように、5回の平均SCOREを計算する必要がありそうです。Githubの他のコードでは、そのような処理も入っているのですが、何分、当方の環境では性能が出ないので....

A3Cの性能に関して

それにしても、前回のDQNと今回のA3Cの評価結果、及び、上記に引用した「論文の1つ目のグラフ」を見ると、DQNとA3Cの学習性能の差は桁違いですね。DQNは、DeepMind社の2013年の論文で取り上げられて話題になったわけですが、その後の3年間で、学習性能がずいぶんと高まったようです。Deep Learningは進化が激しいので、最新の成果を使わないとダメですね。

最新の論文について

今回の評価を通し、最新の論文を見る必要があると感じました。DeepMind社は、かなり多数の論文を公開しているので、これらの論文に目を通したほうが良さそうです。DeepMind社の論文公開サイトこの論文「Learning to learn by gradient descent by gradient descent」が気になったので、読んでいるところです。

Gradient Descentは、高速化の「ノウハウ」がいろいろあって、「ノウハウ」に従って手作業で最適化されているのが現状ですが、この論文では、「ノウハウ」や手作業の部分を、Deep Learningで自動的に最適化しようとしています。幾つかの実例では、手作業よりも良い結果が出ているようです。機械学習の勉強を始めた時、「ノウハウ」とか、泥臭いところが多いので、その部分も機械学習で対応できないのかと感じましたが、実際に行われて、論文も出ているのですね。

ちなみに、今欲しいと思っていることは、1つの手法で学習した結果を、他の手法で活用できないかということです。例えば、A3Cには、A3C-FFとA3C-LSTMがあり、後者のほうが最終的な結果が良くなることが多いのですが、学習の初期段階では、なかなかSCOREが上がりません。そこで、A3C-FFで学習させた結果を、A3C-LSTMで学習するときの教師として使えないかと考えています。

Breakout評価中のトラブル

実は、A3C再現コードをBreakoutで評価している際に、途中で計算結果がおかしくなるという問題が発生しました。具体的には、学習中の画面に表示される計算結果がNaN(非数)になっていました。

対応に困ったのですが、30分おきに途中結果をSAVEし(Control-Cで結果がSAVEされる)、SAVEした内容を元に再開するという形で凌ぎました(SAVE内容をLOADする機能が組み込まれていました)。

また、原因を調べるために、30分間の待ち時間に、計算状況やデバッグ用の情報を出力するコードを追加しました。その結果、「Pi」という変数の値が急激に小さくなるのが原因の模様(1回表示される毎に約1/10になり、1e-19ぐらいになる。その後、NaNになる)。

そこで、PiがNaNになったときはControl-Cの入力を促すメッセージを表示し、手動でControl-Cを打ち、途中結果をSAVEするなど、場当たり的な対応を行っていました。しかし、改めてWebで調べたところ、この記事に書かれていること(値が0.0になる変数のlogを計算している)が原因と思われました。そこで、この記事を参考に修正したところ、NaNは発生しなくなりました。

ソースコードの修正

上記のトラブル対応も含め、下記の修正を行っています。まだGithubには載せてませんが、近々、差分を載せる予定です。

  • 元のコードはPython2.7だったので、Python3.5ベースに修正。
  • 計算途中でPiの値がNaNになる部分があったので修正。
  • ゲームSCOREの状況表示を追加。
  • 処理速度(Steps/秒)表示を追加。
  • デモプレイの動画記録機能を追加。

感想

途中に書きましたが、ともかく、Deep Learningは進化が激しいので、最新の情報を得て、それを活用することが重要と感じました。

*1:https://atariage.com/manual_html_page.php?SoftwareID=889

*2:スレッド処理にPythonのthreadモジュールを使っており、GIL(Global Interpreter Lock)が発生するため。Pythonのmultiprocessingを使えば回避できると、GithubのIssues#1に書かれてますが、まだ未対応の模様