Itsukaraの日記

最新IT技術を勉強・実践中。最近はDeep Learningに注力。

Variational Autoencoderの20次元中間層内距離の確認結果

昨日(8/31)、全脳アーキテクチャ勉強会のオフ会に参加したところ、NTT研究所の山田さんがVariational Autoencoderの話をされ、MNISTのデータに対して20ノードの中間層まで圧縮できる話を伺いました。

これだけ圧縮すると、いわゆる「おばあちゃんニューロン」のように、20個のノードのそれぞれが0〜9の数字の1つを表すようになっているのではないかと思ったのですが、そうはならないだろうとのこと。

気になったので、一応、確認してみました。

結果、そのとおりでした(おばあちゃんニューロンにはなっていない)。

また、中間層が2個の時は下図のように各数字の図形データの中間層の値が、近くに固まる傾向がありましたが、20個(20次元)でも同様の傾向があることが「数値的」に確認できました(20次元は視覚化ができないので数値的に確認しました)。
f:id:Itsukara:20160901045859p:plain

各数字の図形データに対する中間層の値の分布

学習が終わった後で、各数字の図形データに対して中間層の値の分布を示したものが下記です。各1つの折れ線が1つの図形データに対する中間層の値の分布を表しており、数百個の図形データでの中間層の値の分布を示したものです。1つの数字に対して、中間層は様々な分布になることが分かります。

  • 数字0に対する中間層の値の分布

f:id:Itsukara:20160901040332p:plain

  • 数字1に対する中間層の値の分布

f:id:Itsukara:20160901040500p:plain

  • 数字2に対する中間層の値の分布

f:id:Itsukara:20160901050614p:plain

なお、全ての数字に対する分布図を見ると、中間層では値が殆ど0のノードが多く、値が0意外なのは10個程度でした。この10という数字が偶然なのか、データの種類が10個であることに関係しているのか、不明です。

各数字の図形データの中間層の重心どうしの距離

各数字の図形データの中間層データ(群)に対して、重心(計算的上は平均)を求め、重心間の距離を求めた結果が下記です。重心どうしの距離は、最大でも12であり、結構近いことが分かります。ちなみに、i行j列の成分は、数字iの中間層データの重心と、数字jの中間層データの重心の距離です。

[[ 0 11  7  6  7  6  4  7  6  4]
 [11  0  5  7  7  7  6  6  6  6]
 [ 7  5  0  5  7  8  6  6  6  8]
 [ 6  7  5  0 12  3  7  6  5  7]
 [ 7  7  7 12  0 12  5  7  6  2]
 [ 6  7  8  3 12  0  6  7  6  6]
 [ 4  6  6  7  5  6  0 10  7  5]
 [ 7  6  6  6  7  7 10  0  8  3]
 [ 6  6  6  5  6  6  7  8  0  4]
 [ 4  6  8  7  2  6  5  3  4  0]]

ちなみに、テンソル間のデータの距離は、要素ごとの距離の2乗和として計算しています。

各数字の図形データの重心からの距離

各数字の図形データの重心からの距離を求めたのが下記です。重心間の距離の最大値である12よりも遠くにあるデータが結構あることが分かりました。これでは、正しい重心の近くにない可能性があると思い、次項で確認しました。

  • 数字0に対する中間層データの重心からの距離

f:id:Itsukara:20160901041749p:plain

  • 数字1に対する中間層データの重心からの距離

f:id:Itsukara:20160901041756p:plain

  • 数字2に対する中間層データの重心からの距離

f:id:Itsukara:20160901050851p:plain

各数字の図形データの重心が一番近い重心である比率

各数字の図形データの各データに対して、そのデータが属する重心が他の重心よりも近い比率を求めました。その結果、約9割以上は、正しい重心の近くにあることが分かりました。つまり、各数字の中間層データは、それなりに近くに固まっている事が分かります。

Ratio of 'center of 0 is_nearest' = 99.0%
Ratio of 'center of 1 is_nearest' = 97.0%
Ratio of 'center of 2 is_nearest' = 91.3%
Ratio of 'center of 3 is_nearest' = 93.3%
Ratio of 'center of 4 is_nearest' = 90.9%
Ratio of 'center of 5 is_nearest' = 92.4%
Ratio of 'center of 6 is_nearest' = 91.9%
Ratio of 'center of 7 is_nearest' = 93.3%
Ratio of 'center of 8 is_nearest' = 89.2%
Ratio of 'center of 9 is_nearest' = 88.1%

別の記事で、MNISTデータのK平均法での分類を行っていますが、正しく分類できる比率は71.5%でした。これに対して、上記では約90%が正しく分類されています。つまり、VAEによって次元が784次元から20次元に圧縮されるだけでなく、距離による分類という観点で基のデータより上手く分類されていることが分かります。

ソースコード

ここで公開されていたコード(jupyter IPython)を基に、上記の確認用のコードを追加しました。追加後のコードは下記にあります。なお、html化したファイルも置きましたので、結果だけを見るならば、htmlファイルをご覧ください。ご参考まで。
github.com

感想

疑問点を確認することが出来ましたが、「中間層の活性ノード数が、入力したデータの種類数とほぼ等しい」という理由が、新たな疑問点となりました。