Itsukaraの日記

最新IT技術を勉強・実践中。最近はDeep Learningに注力。

【DL、RL】A3Cでのpseudo-countの実装について

DeepMind社の論文[「[1606.01868] Unifying Count-Based Exploration and Intrinsic Motivation」記載のpseudo-countのA3C on pythonでの実装について、防備録を残しておきます。なお、論文は抽象化して書かれていますが、当方の当面のターゲットはAtari2600のゲームなので、その文脈で記載します。

pseudo-countとは

pseudo-countは、ゲームを何回かプレーする中で、同じ画面が出てきた回数をカウントすることを目的としています。これにより、滅多に出現しない画面には大きな価値を与え、その画面に繋がる行動を強化します。

このようなカウント法としては、ゲーム画面を全て記録しておきマッチングするという方法もあります。バイナリーデータとして全く同じ画面を探すのならば、ハッシングなどの技法で直ぐに見つかるものの、探索する画面との差が少ないものを探すためには既存画面数に比例した時間が掛かってしまいます*1。また、どの程度の差まで許容するかなど、判断が難しい部分もあります。

pseudo-countでは、画面の画素毎に、過去の画面でその画素の値が出現した比率を計算し、この比率を全画素について掛け算することで、その画面の出現確率を近似します。この計算だけでは、計算した確率は非常に小さくなってしまうのですが、論文では、ひと工夫することで、現実的な値を求めます。

まず、新たな画面xが追加される前のxの出現比率\rhoと、追加された後の出現比率\rho'を考えます。既存画面数をn、既存画面内でのxの出現回数をNとすると、次のような関係になることが期待されます。

  • \rho=\frac{N}{n}
  • \rho'=\frac{N+1}{n+1}

なお、画素毎では、上記のように画面の個数や出現回数が求まりますが、画面に対しては、上述のように各画素の出現比率の積で画面の出現確率を計算するので、実際には画面の出現回数を数えるわけではありません。そこで、「期待されます」という表現にしてます。nNは、既存の画面の個数とxの出現回数を擬似的に表すもので、それぞれ、pseudo-count total、pseudo-count functionと呼んでいます。

これらの式から逆にNを求めると、次に様になります。

  • N=\frac{\rho(1-\rho')}{\rho'-\rho}≒\frac{\rho}{\rho'-\rho}

上記で求めたNをpseudo-count(擬似的な出現回数)と考えて、出現回数が少ない画面は、希少価値が高いと考えて、Rewardを与えるというものです。これによって、あまり出現したことが無い画面、例えば、Montezuma's Revengeで部屋の鍵を取る場所に到達した画面、別の部屋の画面などに到達する行動が強化されます。

具体的なReward(R)の値としては、下記になります*2

  • R=\beta(N+0.01)^{-1/2}

\betaは、Atari2600の代表的なゲームを評価して求め、A3Cでは0.01にしたとのこと。

pseudo-countの計算

当方プログラムでのpseudo-countの計算方法は下記としました。

新たな画面xのi番目の画素の値の、既存画面の同じ画素での全出現回数をx^iとすると、\rho\rho'は、次の式で求めることができます。

  • \rho=\frac{x^1}{n}\frac{x^2}{n}...\frac{x^k}{n}
  • \rho'=\frac{x^1+1}{n+1}\frac{x^2+1}{n+1}...\frac{x^k+1}{n+1}

これらを直接計算してからNを計算しても良いのですが、これらの値は非常に小さくなり、差を求めるのは不適切そうなので、Nの式を、次のように変形しました。
N≒\frac{\rho}{\rho'-\rho}=\frac{\frac{\rho}{\rho'}}{1-\frac{\rho}{\rho'}}

そのうえで、\frac{\rho}{\rho'}は、次のように計算しました。

  • \frac{\rho}{\rho'}=\frac{\frac{x^1}{n}\frac{x^2}{n}...\frac{x^k}{n}}{\frac{x^1+1}{n+1}\frac{x^2+1}{n+1}...\frac{x^k+1}{n+1}}=(\frac{x^1}{x^1+1}\frac{n+1}{n})(\frac{x^2}{x^2+1}\frac{n+1}{n})...(\frac{x^k}{x^k+1}\frac{n+1}{n})

pythonでは下記のコードにしました(下記は擬似コード)。なお、vは画素の最大値です。

#初期化
value_count = np.zeros((k, v + 1), dtype=np.float32)

#value_countへのimageの追加とreward(R)の計算
nr = (n + 1.0)/n
vcount = value_count[range(k), image]
value_count[range(k), image] += 1.0
r_over_rp = np.prod(nr * vcount / (1.0 + vcount))
N = r_over_rp / (1.0 - r_over_rp)
R = beta / sqrt(N + 0.01)
n += 1

上記のコードは、メモリ上の非連続領域にアクセするためか、計算に結構時間が掛かります。上記を入れる前は800STEPS/秒だった学習速度が、800STEPS/秒まで落ちてしまいました。何とかならないものでしょうか...

追記

pseudo-countを入れたコードを下記に公開しました。途中でControl-Cで止めた時に、countの値がファイルに保存できないという問題がありますが、とりあえず、公開させていただきました。
github.com

*1:画面の特徴量特徴量で分類することである程度高速化するなどの手法は考えられると思いますが...

*2:論文では他の式も評価してます